add and utilize Merge and Embed layers
This commit is contained in:
parent
6f2ffcdef7
commit
acc8378980
2 changed files with 108 additions and 20 deletions
33
main.lua
33
main.lua
|
@ -52,7 +52,8 @@ local timer_loser = 1/3
|
||||||
local enable_overlay = playable_mode
|
local enable_overlay = playable_mode
|
||||||
local enable_network = not playable_mode
|
local enable_network = not playable_mode
|
||||||
|
|
||||||
local input_size = 281 -- TODO: let the script figure this out for us.
|
local input_size = 60 -- TODO: let the script figure this out for us.
|
||||||
|
local tile_count = 17 * 13
|
||||||
|
|
||||||
local ok_routines = {
|
local ok_routines = {
|
||||||
[0x4] = true, -- sliding down flagpole
|
[0x4] = true, -- sliding down flagpole
|
||||||
|
@ -231,17 +232,17 @@ package.loaded['nn'] = nil -- DEBUG
|
||||||
local nn = require("nn")
|
local nn = require("nn")
|
||||||
|
|
||||||
local network
|
local network
|
||||||
local nn_x
|
local nn_x, nn_tx, nn_ty, nn_y, nn_z
|
||||||
local nn_y
|
|
||||||
local nn_z
|
|
||||||
local function make_network(input_size, buttons)
|
local function make_network(input_size, buttons)
|
||||||
nn_x = nn.Input({input_size})
|
nn_x = nn.Input({input_size})
|
||||||
nn_y = nn_x
|
nn_tx = nn.Input({tile_count})
|
||||||
|
nn_ty = nn_tx:feed(nn.Embed(256, 2))
|
||||||
|
nn_y = nn.Merge()
|
||||||
|
nn_x:feed(nn_y)
|
||||||
|
nn_ty:feed(nn_y)
|
||||||
nn_z = {}
|
nn_z = {}
|
||||||
if false then
|
|
||||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
if true then
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
|
||||||
else
|
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
nn_y = nn_y:feed(nn.Dense(128))
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Gelu())
|
||||||
nn_y = nn_y:feed(nn.Dense(64))
|
nn_y = nn_y:feed(nn.Dense(64))
|
||||||
|
@ -249,13 +250,14 @@ local function make_network(input_size, buttons)
|
||||||
nn_y = nn_y:feed(nn.Dense(48))
|
nn_y = nn_y:feed(nn.Dense(48))
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Gelu())
|
||||||
end
|
end
|
||||||
|
|
||||||
for i = 1, buttons do
|
for i = 1, buttons do
|
||||||
nn_z[i] = nn_y
|
nn_z[i] = nn_y
|
||||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||||
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
||||||
end
|
end
|
||||||
|
|
||||||
return nn.Model({nn_x}, nn_z)
|
return nn.Model({nn_x, nn_tx}, nn_z)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- and here we go with the game stuff.
|
-- and here we go with the game stuff.
|
||||||
|
@ -495,7 +497,7 @@ local function handle_tiles()
|
||||||
--local tile_col = R(0x6A0)
|
--local tile_col = R(0x6A0)
|
||||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||||
local tile_scroll_remainder = R(0x73F) % 16
|
local tile_scroll_remainder = R(0x73F) % 16
|
||||||
tile_input[#tile_input+1] = tile_scroll_remainder
|
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||||
for y = 0, 12 do
|
for y = 0, 12 do
|
||||||
for x = 0, 16 do
|
for x = 0, 16 do
|
||||||
local col = (x + tile_scroll) % 32
|
local col = (x + tile_scroll) % 32
|
||||||
|
@ -787,19 +789,14 @@ local function doit(dummy)
|
||||||
|
|
||||||
local X = {}
|
local X = {}
|
||||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
|
||||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||||
nn.reshape(X, 1, input_size)
|
nn.reshape(X, 1, input_size)
|
||||||
|
nn.reshape(tile_input, 1, tile_count)
|
||||||
|
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
local choose = deterministic and argmax2 or rchoice2
|
||||||
|
|
||||||
local outputs = network:forward({[nn_x]=X})
|
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||||
|
|
||||||
-- TODO: predict the *rewards* of all possible actions?
|
|
||||||
-- that's how DQN seems to work anyway.
|
|
||||||
-- ah, but A3C just returns probabilities,
|
|
||||||
-- besides the critic?
|
|
||||||
|
|
||||||
local softmaxed = {
|
local softmaxed = {
|
||||||
outputs[nn_z[1]],
|
outputs[nn_z[1]],
|
||||||
|
|
95
nn.lua
95
nn.lua
|
@ -95,6 +95,11 @@ local function init_zeros(t, fan_in, fan_out)
|
||||||
return t
|
return t
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Fill every existing slot of `t` with `uniform() * 2 - 1` and return `t`.
-- Assuming uniform() yields values in [0, 1), results land in [-1, 1)
-- (TODO confirm against the definition of `uniform`).
-- fan_in and fan_out are accepted so the signature matches the other
-- init_* functions in this file; they are not used here.
local function init_uniform(t, fan_in, fan_out)
    local count = #t
    for index = 1, count do
        t[index] = uniform() * 2 - 1
    end
    return t
end
|
||||||
|
|
||||||
local function init_he_uniform(t, fan_in, fan_out)
|
local function init_he_uniform(t, fan_in, fan_out)
|
||||||
local s = sqrt(6 / fan_in)
|
local s = sqrt(6 / fan_in)
|
||||||
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
|
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
|
||||||
|
@ -298,10 +303,12 @@ local Weights = Base:extend()
|
||||||
local Layer = Base:extend()
|
local Layer = Base:extend()
|
||||||
local Model = Base:extend()
|
local Model = Base:extend()
|
||||||
local Input = Layer:extend()
|
local Input = Layer:extend()
|
||||||
|
local Merge = Layer:extend()
|
||||||
local Relu = Layer:extend()
|
local Relu = Layer:extend()
|
||||||
local Gelu = Layer:extend()
|
local Gelu = Layer:extend()
|
||||||
local Dense = Layer:extend()
|
local Dense = Layer:extend()
|
||||||
local Softmax = Layer:extend()
|
local Softmax = Layer:extend()
|
||||||
|
local Embed = Layer:extend()
|
||||||
|
|
||||||
function Weights:init(weight_init)
|
function Weights:init(weight_init)
|
||||||
self.weight_init = weight_init
|
self.weight_init = weight_init
|
||||||
|
@ -333,7 +340,7 @@ function Layer:make_shape(parent)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Layer:feed(child)
|
function Layer:feed(child)
|
||||||
assert(self.shape_out ~= nil)
|
assert(self.shape_out ~= nil, "missing output shape: "..self.name)
|
||||||
child:make_shape(self)
|
child:make_shape(self)
|
||||||
insert(self.children, child)
|
insert(self.children, child)
|
||||||
insert(child.parents, self)
|
insert(child.parents, self)
|
||||||
|
@ -360,7 +367,7 @@ end
|
||||||
|
|
||||||
function Layer:get_size()
|
function Layer:get_size()
|
||||||
local size = 0
|
local size = 0
|
||||||
for i, w in ipairs(self.weights) do size = size + prod(w.size) end
|
for i, w in ipairs(self.weights) do size = size + prod(w.shape) end
|
||||||
return size
|
return size
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -417,6 +424,43 @@ function Input:backward(dY)
|
||||||
return zeros(#dY)
|
return zeros(#dY)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Set up a Merge layer that has no parents attached yet.
-- Both counters start at zero and are advanced by Merge:make_shape.
function Merge:init()
    Layer.init(self, "Merge")
    -- shape_in: number of parents fed in so far.
    -- size: running total of the parents' output element counts.
    self.shape_in, self.size = 0, 0
end
|
||||||
|
|
||||||
|
-- Account for one more parent feeding into this layer.
-- Layer:feed calls this once per parent, so the output shape grows
-- with every call.
--   parent: the layer being attached; its shape_out must already be set.
function Merge:make_shape(parent)
    -- element count of this parent's output, added to the running total.
    local total = self.size + prod(parent.shape_out)
    self.size = total

    self.shape_in = self.shape_in + 1 -- TODO: more robust.
    self.shape_out = {total}
end
|
||||||
|
|
||||||
|
-- Remember the batch size and allocate a fresh output cache for it.
--   bs: batch size the cache is built for.
function Merge:reset_cache(bs)
    self.bs = bs

    local out_shape = self.shape_out
    self.cache = cache(bs, out_shape)
end
|
||||||
|
|
||||||
|
-- Join the outputs of all parents into this layer's cached output.
--   edges: array with one entry per attached parent; each entry is
--          iterated as a flat sequence and must carry a `.shape` table
--          whose first element is the batch size.
--   deterministic: accepted but not used by this layer.
-- Returns the cache table after it has been validated with checkshape.
function Merge:_propagate(edges, deterministic)
    assert(#edges == self.shape_in)

    local batch = edges[1].shape[1]
    if self.bs ~= batch then self:reset_cache(batch) end
    local out = self.cache

    -- NOTE(review): values are appended edge after edge. For a batch
    -- size above 1 this presumably does not interleave per-sample
    -- features -- confirm against the tensor layout used by `cache`.
    local written = 0
    for _, edge in ipairs(edges) do
        for _, value in ipairs(edge) do
            written = written + 1
            out[written] = value
        end
    end

    checkshape(out, self.shape_out)
    return out
end
|
||||||
|
|
||||||
function Relu:init()
|
function Relu:init()
|
||||||
Layer.init(self, "Relu")
|
Layer.init(self, "Relu")
|
||||||
end
|
end
|
||||||
|
@ -595,6 +639,51 @@ end
|
||||||
--return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
|
--return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
|
||||||
--end
|
--end
|
||||||
|
|
||||||
|
-- Set up an embedding layer backed by a lookup table of weights.
--   vocab: number of distinct tokens (must be a number).
--   dim:   number of values stored per token (must be a number).
function Embed:init(vocab, dim)
    Layer.init(self, "Embed")
    assert(type(vocab) == "number")
    assert(type(dim) == "number")
    self.vocab, self.dim = vocab, dim

    -- the lookup table: one row of `dim` weights for each of `vocab` tokens.
    local table_weights = self:_new_weights(init_uniform)
    table_weights.shape = {vocab, dim}
    self.lut = table_weights
end
|
||||||
|
|
||||||
|
-- Derive this layer's shapes from the parent it is attached to.
-- The input shape is taken from the parent as-is; the output is a single
-- dimension holding `dim` values for each element of the parent's first
-- dimension.
function Embed:make_shape(parent)
    local incoming = parent.shape_out
    self.shape_in = incoming
    self.shape_out = {incoming[1] * self.dim}
end
|
||||||
|
|
||||||
|
-- Remember the batch size and allocate fresh caches for it:
-- `cache` for the layer output and `cache_x` for a copy of the input.
--   bs: batch size the caches are built for.
function Embed:reset_cache(bs)
    self.bs = bs

    local out_shape, in_shape = self.shape_out, self.shape_in
    self.cache = cache(bs, out_shape)
    self.cache_x = cache(bs, in_shape)
end
|
||||||
|
|
||||||
|
-- Look up the embedding of every token in X and write the rows, one after
-- another, into the cached output.
--   X: sequence of token indices matching self.shape_in. Tokens are used
--      zero-based: token x selects lut entries x*dim+1 .. x*dim+dim.
-- Returns the cache table after it has been validated with checkshape.
function Embed:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end
    local Y = self.cache
    local lut, dim = self.lut, self.dim
    local cache_x = self.cache_x

    local yi = 0 -- number of output slots already filled.
    for i, x in ipairs(X) do
        -- only needed for backwards pass.
        cache_x[i] = x

        local xi = x * dim
        for j = 1, dim do
            Y[yi + j] = lut[xi + j]
        end
        -- advance once per token, by a whole row. Incrementing inside the
        -- inner loop (while also indexing with yi + j) moved the write
        -- position by two per component, skipping and overwriting slots
        -- whenever dim > 1.
        yi = yi + dim
    end

    checkshape(Y, self.shape_out)
    return Y
end
|
||||||
|
|
||||||
function Model:init(nodes_in, nodes_out)
|
function Model:init(nodes_in, nodes_out)
|
||||||
assert(#nodes_in > 0, #nodes_in)
|
assert(#nodes_in > 0, #nodes_in)
|
||||||
assert(#nodes_out > 0, #nodes_out)
|
assert(#nodes_out > 0, #nodes_out)
|
||||||
|
@ -735,8 +824,10 @@ return {
|
||||||
Layer = Layer,
|
Layer = Layer,
|
||||||
Model = Model,
|
Model = Model,
|
||||||
Input = Input,
|
Input = Input,
|
||||||
|
Merge = Merge,
|
||||||
Relu = Relu,
|
Relu = Relu,
|
||||||
Gelu = Gelu,
|
Gelu = Gelu,
|
||||||
Dense = Dense,
|
Dense = Dense,
|
||||||
Softmax = Softmax,
|
Softmax = Softmax,
|
||||||
|
Embed = Embed,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue