add and utilize Merge and Embed layers
This commit is contained in:
parent
6f2ffcdef7
commit
acc8378980
2 changed files with 108 additions and 20 deletions
33
main.lua
33
main.lua
|
@ -52,7 +52,8 @@ local timer_loser = 1/3
|
|||
local enable_overlay = playable_mode
|
||||
local enable_network = not playable_mode
|
||||
|
||||
local input_size = 281 -- TODO: let the script figure this out for us.
|
||||
local input_size = 60 -- TODO: let the script figure this out for us.
|
||||
local tile_count = 17 * 13
|
||||
|
||||
local ok_routines = {
|
||||
[0x4] = true, -- sliding down flagpole
|
||||
|
@ -231,17 +232,17 @@ package.loaded['nn'] = nil -- DEBUG
|
|||
local nn = require("nn")
|
||||
|
||||
local network
|
||||
local nn_x
|
||||
local nn_y
|
||||
local nn_z
|
||||
local nn_x, nn_tx, nn_ty, nn_y, nn_z
|
||||
local function make_network(input_size, buttons)
|
||||
nn_x = nn.Input({input_size})
|
||||
nn_y = nn_x
|
||||
nn_tx = nn.Input({tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(256, 2))
|
||||
nn_y = nn.Merge()
|
||||
nn_x:feed(nn_y)
|
||||
nn_ty:feed(nn_y)
|
||||
nn_z = {}
|
||||
if false then
|
||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
else
|
||||
|
||||
if true then
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Dense(64))
|
||||
|
@ -249,13 +250,14 @@ local function make_network(input_size, buttons)
|
|||
nn_y = nn_y:feed(nn.Dense(48))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
end
|
||||
|
||||
for i = 1, buttons do
|
||||
nn_z[i] = nn_y
|
||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
||||
end
|
||||
|
||||
return nn.Model({nn_x}, nn_z)
|
||||
return nn.Model({nn_x, nn_tx}, nn_z)
|
||||
end
|
||||
|
||||
-- and here we go with the game stuff.
|
||||
|
@ -495,7 +497,7 @@ local function handle_tiles()
|
|||
--local tile_col = R(0x6A0)
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
tile_input[#tile_input+1] = tile_scroll_remainder
|
||||
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||
for y = 0, 12 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
|
@ -787,19 +789,14 @@ local function doit(dummy)
|
|||
|
||||
local X = {}
|
||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||
nn.reshape(X, 1, input_size)
|
||||
nn.reshape(tile_input, 1, tile_count)
|
||||
|
||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||
local choose = deterministic and argmax2 or rchoice2
|
||||
|
||||
local outputs = network:forward({[nn_x]=X})
|
||||
|
||||
-- TODO: predict the *rewards* of all possible actions?
|
||||
-- that's how DQN seems to work anyway.
|
||||
-- ah, but A3C just returns probabilities,
|
||||
-- besides the critic?
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
|
||||
local softmaxed = {
|
||||
outputs[nn_z[1]],
|
||||
|
|
95
nn.lua
95
nn.lua
|
@ -95,6 +95,11 @@ local function init_zeros(t, fan_in, fan_out)
|
|||
return t
|
||||
end
|
||||
|
||||
-- Weight initializer: overwrites every slot of `t` with `uniform() * 2 - 1`
-- and returns the same table.
-- `fan_in` and `fan_out` are accepted for signature parity with the other
-- initializers but are not used here.
-- NOTE(review): presumably `uniform()` yields values in [0, 1), which would
-- make the results fall in [-1, 1) -- confirm against its definition.
local function init_uniform(t, fan_in, fan_out)
    local count = #t
    for index = 1, count do
        local sample = uniform()
        t[index] = sample * 2 - 1
    end
    return t
end
|
||||
|
||||
local function init_he_uniform(t, fan_in, fan_out)
|
||||
local s = sqrt(6 / fan_in)
|
||||
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
|
||||
|
@ -298,10 +303,12 @@ local Weights = Base:extend()
|
|||
local Layer = Base:extend()
|
||||
local Model = Base:extend()
|
||||
local Input = Layer:extend()
|
||||
local Merge = Layer:extend()
|
||||
local Relu = Layer:extend()
|
||||
local Gelu = Layer:extend()
|
||||
local Dense = Layer:extend()
|
||||
local Softmax = Layer:extend()
|
||||
local Embed = Layer:extend()
|
||||
|
||||
function Weights:init(weight_init)
|
||||
self.weight_init = weight_init
|
||||
|
@ -333,7 +340,7 @@ function Layer:make_shape(parent)
|
|||
end
|
||||
|
||||
function Layer:feed(child)
|
||||
assert(self.shape_out ~= nil)
|
||||
assert(self.shape_out ~= nil, "missing output shape: "..self.name)
|
||||
child:make_shape(self)
|
||||
insert(self.children, child)
|
||||
insert(child.parents, self)
|
||||
|
@ -360,7 +367,7 @@ end
|
|||
|
||||
-- Returns the total number of scalar parameters held by this layer,
-- i.e. the sum over all of its weights of the product of each weight's shape.
-- Only `w.shape` is consulted: that is the field layers populate
-- (e.g. Embed:init assigns `self.lut.shape`). The stale `w.size` pass that
-- preceded it has been dropped so each weight is counted exactly once.
function Layer:get_size()
    local total = 0
    for _, w in ipairs(self.weights) do
        total = total + prod(w.shape)
    end
    return total
end
|
||||
|
||||
|
@ -417,6 +424,43 @@ function Input:backward(dY)
|
|||
return zeros(#dY)
|
||||
end
|
||||
|
||||
-- Constructs a Merge layer.
-- `size` accumulates the combined element count of all parents and
-- `shape_in` counts how many parents have been attached; both start at zero
-- and are advanced by Merge:make_shape.
function Merge:init()
    Layer.init(self, "Merge")
    self.size, self.shape_in = 0, 0
end
|
||||
|
||||
-- Registers one more parent with this Merge layer.
-- Layer:feed calls `child:make_shape(self)`, so this runs once per parent
-- that is fed into the merge. Each call grows the merged width by the number
-- of elements in the parent's output shape.
function Merge:make_shape(parent)
    local parent_elements = prod(parent.shape_out)
    local merged_size = self.size + parent_elements
    self.size = merged_size

    -- shape_in is used as a plain parent counter here, not as a shape table.
    self.shape_in = self.shape_in + 1 -- TODO: more robust.
    self.shape_out = {merged_size}
end
|
||||
|
||||
-- Remembers the batch size `bs` and replaces the output buffer with a fresh
-- one obtained from `cache(bs, shape_out)`.
function Merge:reset_cache(bs)
    self.bs = bs

    local out_shape = self.shape_out
    self.cache = cache(bs, out_shape)
end
|
||||
|
||||
-- Concatenates the outputs of all parents into this layer's cached buffer.
--   edges: array of parent outputs; its length must equal the number of
--          parents registered through make_shape.
--   deterministic: accepted but not used by this layer.
-- Returns the cached buffer after it has been validated with checkshape.
-- The batch size is read from `edges[1].shape[1]`; the buffer is rebuilt
-- whenever it differs from the one the cache was made for.
-- NOTE(review): each edge is appended whole, one after another. For a batch
-- size above 1 this does not interleave values per sample -- confirm that
-- this layout is what downstream layers expect.
function Merge:_propagate(edges, deterministic)
    assert(#edges == self.shape_in)

    local batch = edges[1].shape[1]
    if self.bs ~= batch then
        self:reset_cache(batch)
    end

    local merged = self.cache
    local filled = 0
    for _, edge in ipairs(edges) do
        for _, value in ipairs(edge) do
            filled = filled + 1
            merged[filled] = value
        end
    end

    checkshape(merged, self.shape_out)
    return merged
end
|
||||
|
||||
function Relu:init()
|
||||
Layer.init(self, "Relu")
|
||||
end
|
||||
|
@ -595,6 +639,51 @@ end
|
|||
--return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
|
||||
--end
|
||||
|
||||
-- Constructs an embedding layer.
--   vocab: number of distinct indices the lookup table covers (must be a number).
--   dim:   number of values stored per index (must be a number).
-- The lookup table is created through `_new_weights` with the uniform
-- initializer and is given the shape {vocab, dim}.
function Embed:init(vocab, dim)
    Layer.init(self, "Embed")
    assert(type(vocab) == "number")
    assert(type(dim) == "number")

    self.vocab = vocab
    self.dim = dim

    local table_weights = self:_new_weights(init_uniform)
    self.lut = table_weights
    table_weights.shape = {vocab, dim}
end
|
||||
|
||||
-- Derives this layer's shapes from its parent.
-- The input shape is taken verbatim from the parent's output shape; the
-- output is flat, with `dim` values for every element along the parent's
-- first dimension.
function Embed:make_shape(parent)
    local incoming = parent.shape_out
    self.shape_in = incoming
    self.shape_out = {incoming[1] * self.dim}
end
|
||||
|
||||
-- Remembers the batch size `bs` and rebuilds both buffers for it:
-- `cache` holds the layer output, `cache_x` holds a copy of the input.
function Embed:reset_cache(bs)
    self.bs = bs

    local out_shape, in_shape = self.shape_out, self.shape_in
    self.cache = cache(bs, out_shape)
    self.cache_x = cache(bs, in_shape)
end
|
||||
|
||||
-- Forward pass: replaces every index in X with its `dim` lookup-table values.
--   X: flat array of indices, validated against shape_in by checkshape.
-- Returns the cached output buffer, in which the values for the k-th index
-- of X occupy slots (k-1)*dim + 1 .. k*dim.
-- Index x reads lut[x*dim + 1 .. x*dim + dim], so indices are treated as
-- zero-based.
-- NOTE(review): assumes every x is an integer in [0, vocab) -- nothing here
-- enforces that; confirm against the callers.
function Embed:forward(X)
    local bs = checkshape(X, self.shape_in)
    if self.bs ~= bs then self:reset_cache(bs) end

    -- Fetch the buffers only after a possible reset_cache, which replaces them.
    local Y = self.cache
    local cache_x = self.cache_x
    local lut = self.lut
    local dim = self.dim

    -- only needed for backwards pass.
    for i = 1, #X do
        cache_x[i] = X[i]
    end

    local yi = 0 -- number of output slots already written
    for _, x in ipairs(X) do
        local xi = x * dim
        for j = 1, dim do
            Y[yi + j] = lut[xi + j]
        end
        -- Advance once per index. Stepping the cursor inside the inner loop
        -- while also adding j made consecutive embeddings overlap and left
        -- every other output slot unwritten.
        yi = yi + dim
    end

    checkshape(Y, self.shape_out)
    return Y
end
|
||||
|
||||
function Model:init(nodes_in, nodes_out)
|
||||
assert(#nodes_in > 0, #nodes_in)
|
||||
assert(#nodes_out > 0, #nodes_out)
|
||||
|
@ -735,8 +824,10 @@ return {
|
|||
Layer = Layer,
|
||||
Model = Model,
|
||||
Input = Input,
|
||||
Merge = Merge,
|
||||
Relu = Relu,
|
||||
Gelu = Gelu,
|
||||
Dense = Dense,
|
||||
Softmax = Softmax,
|
||||
Embed = Embed,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue