add and utilize Merge and Embed layers

This commit is contained in:
Connor Olding 2017-09-07 23:06:30 +00:00
parent 6f2ffcdef7
commit acc8378980
2 changed files with 108 additions and 20 deletions

View file

@ -52,7 +52,8 @@ local timer_loser = 1/3
local enable_overlay = playable_mode local enable_overlay = playable_mode
local enable_network = not playable_mode local enable_network = not playable_mode
local input_size = 281 -- TODO: let the script figure this out for us. local input_size = 60 -- TODO: let the script figure this out for us.
local tile_count = 17 * 13
local ok_routines = { local ok_routines = {
[0x4] = true, -- sliding down flagpole [0x4] = true, -- sliding down flagpole
@ -231,17 +232,17 @@ package.loaded['nn'] = nil -- DEBUG
local nn = require("nn") local nn = require("nn")
local network local network
local nn_x local nn_x, nn_tx, nn_ty, nn_y, nn_z
local nn_y
local nn_z
local function make_network(input_size, buttons) local function make_network(input_size, buttons)
nn_x = nn.Input({input_size}) nn_x = nn.Input({input_size})
nn_y = nn_x nn_tx = nn.Input({tile_count})
nn_ty = nn_tx:feed(nn.Embed(256, 2))
nn_y = nn.Merge()
nn_x:feed(nn_y)
nn_ty:feed(nn_y)
nn_z = {} nn_z = {}
if false then
nn_y = nn_y:feed(nn.Dense(input_size)) if true then
nn_y = nn_y:feed(nn.Gelu())
else
nn_y = nn_y:feed(nn.Dense(128)) nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(64)) nn_y = nn_y:feed(nn.Dense(64))
@ -249,13 +250,14 @@ local function make_network(input_size, buttons)
nn_y = nn_y:feed(nn.Dense(48)) nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Gelu())
end end
for i = 1, buttons do for i = 1, buttons do
nn_z[i] = nn_y nn_z[i] = nn_y
nn_z[i] = nn_z[i]:feed(nn.Dense(2)) nn_z[i] = nn_z[i]:feed(nn.Dense(2))
nn_z[i] = nn_z[i]:feed(nn.Softmax()) nn_z[i] = nn_z[i]:feed(nn.Softmax())
end end
return nn.Model({nn_x}, nn_z) return nn.Model({nn_x, nn_tx}, nn_z)
end end
-- and here we go with the game stuff. -- and here we go with the game stuff.
@ -495,7 +497,7 @@ local function handle_tiles()
--local tile_col = R(0x6A0) --local tile_col = R(0x6A0)
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16 local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
local tile_scroll_remainder = R(0x73F) % 16 local tile_scroll_remainder = R(0x73F) % 16
tile_input[#tile_input+1] = tile_scroll_remainder extra_input[#extra_input+1] = tile_scroll_remainder
for y = 0, 12 do for y = 0, 12 do
for x = 0, 16 do for x = 0, 16 do
local col = (x + tile_scroll) % 32 local col = (x + tile_scroll) % 32
@ -787,19 +789,14 @@ local function doit(dummy)
local X = {} local X = {}
for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(sprite_input) do insert(X, v / 256) end
for i, v in ipairs(tile_input) do insert(X, v / 256) end
for i, v in ipairs(extra_input) do insert(X, v / 256) end for i, v in ipairs(extra_input) do insert(X, v / 256) end
nn.reshape(X, 1, input_size) nn.reshape(X, 1, input_size)
nn.reshape(tile_input, 1, tile_count)
if enable_network and get_state() == 'playing' or ingame_paused then if enable_network and get_state() == 'playing' or ingame_paused then
local choose = deterministic and argmax2 or rchoice2 local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward({[nn_x]=X}) local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
-- TODO: predict the *rewards* of all possible actions?
-- that's how DQN seems to work anyway.
-- ah, but A3C just returns probabilities,
-- besides the critic?
local softmaxed = { local softmaxed = {
outputs[nn_z[1]], outputs[nn_z[1]],

95
nn.lua
View file

@ -95,6 +95,11 @@ local function init_zeros(t, fan_in, fan_out)
return t return t
end end
local function init_uniform(t, fan_in, fan_out)
	-- Fill `t` in place with samples uniform in [-1, 1).
	-- The fan sizes are accepted for signature parity with the other
	-- initializers (e.g. init_he_uniform) but are intentionally unused.
	local n = #t
	for i = 1, n do
		t[i] = uniform() * 2 - 1
	end
	return t
end
local function init_he_uniform(t, fan_in, fan_out) local function init_he_uniform(t, fan_in, fan_out)
local s = sqrt(6 / fan_in) local s = sqrt(6 / fan_in)
for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
@ -298,10 +303,12 @@ local Weights = Base:extend()
local Layer = Base:extend() local Layer = Base:extend()
local Model = Base:extend() local Model = Base:extend()
local Input = Layer:extend() local Input = Layer:extend()
local Merge = Layer:extend()
local Relu = Layer:extend() local Relu = Layer:extend()
local Gelu = Layer:extend() local Gelu = Layer:extend()
local Dense = Layer:extend() local Dense = Layer:extend()
local Softmax = Layer:extend() local Softmax = Layer:extend()
local Embed = Layer:extend()
function Weights:init(weight_init) function Weights:init(weight_init)
self.weight_init = weight_init self.weight_init = weight_init
@ -333,7 +340,7 @@ function Layer:make_shape(parent)
end end
function Layer:feed(child) function Layer:feed(child)
assert(self.shape_out ~= nil) assert(self.shape_out ~= nil, "missing output shape: "..self.name)
child:make_shape(self) child:make_shape(self)
insert(self.children, child) insert(self.children, child)
insert(child.parents, self) insert(child.parents, self)
@ -360,7 +367,7 @@ end
function Layer:get_size() function Layer:get_size()
local size = 0 local size = 0
for i, w in ipairs(self.weights) do size = size + prod(w.size) end for i, w in ipairs(self.weights) do size = size + prod(w.shape) end
return size return size
end end
@ -417,6 +424,43 @@ function Input:backward(dY)
return zeros(#dY) return zeros(#dY)
end end
function Merge:init()
	-- Concatenation layer: flattens and joins the outputs of every parent
	-- node into a single vector.
	Layer.init(self, "Merge")
	-- Both accumulate as parents attach themselves via make_shape:
	-- `size` sums element counts, `shape_in` counts incoming edges.
	self.shape_in = 0
	self.size = 0
end
function Merge:make_shape(parent)
	-- Invoked once per parent edge; grow the output by that parent's
	-- flattened size and remember how many edges feed us.
	local parent_size = prod(parent.shape_out)
	self.size = self.size + parent_size
	self.shape_in = self.shape_in + 1 -- edge count, not a real shape. TODO: more robust.
	self.shape_out = {self.size}
end
function Merge:reset_cache(bs)
	-- (Re)allocate the concatenation buffer whenever the batch size changes.
	local out_shape = self.shape_out
	self.bs = bs
	self.cache = cache(bs, out_shape)
end
function Merge:_propagate(edges, deterministic)
	-- Copy every parent's output, in edge order, into one flat buffer.
	-- `deterministic` is part of the propagation interface but irrelevant here.
	assert(#edges == self.shape_in)
	local batch = edges[1].shape[1]
	if batch ~= self.bs then self:reset_cache(batch) end
	local out = self.cache
	local n = 0
	for e = 1, #edges do
		local X = edges[e]
		for k = 1, #X do
			n = n + 1
			out[n] = X[k]
		end
	end
	checkshape(out, self.shape_out)
	return out
end
function Relu:init() function Relu:init()
Layer.init(self, "Relu") Layer.init(self, "Relu")
end end
@ -595,6 +639,51 @@ end
--return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache --return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
--end --end
function Embed:init(vocab, dim)
	-- Lookup-table embedding: each integer id in [0, vocab) selects a
	-- learned `dim`-length vector from a (vocab x dim) weight table.
	Layer.init(self, "Embed")
	assert(type(vocab) == "number")
	assert(type(dim) == "number")
	self.vocab = vocab
	self.dim = dim
	local lut = self:_new_weights(init_uniform)
	lut.shape = {vocab, dim}
	self.lut = lut
end
function Embed:make_shape(parent)
	-- Every one of the parent's N ids expands into a dim-length vector;
	-- the result is kept flattened as a single N*dim vector.
	local n_ids = parent.shape_out[1]
	self.shape_in = parent.shape_out
	self.shape_out = {n_ids * self.dim}
end
function Embed:reset_cache(bs)
	-- Allocate both the output buffer and the saved-input buffer
	-- (the latter is needed by the backward pass) for batch size `bs`.
	self.bs = bs
	self.cache_x = cache(bs, self.shape_in)
	self.cache = cache(bs, self.shape_out)
end
function Embed:forward(X)
	-- Map each integer id in X to its row of the lookup table, writing the
	-- rows back-to-back into the flat cache: output slot for id i spans
	-- Y[(i-1)*dim + 1 .. i*dim].
	local bs = checkshape(X, self.shape_in)
	if self.bs ~= bs then self:reset_cache(bs) end
	local Y = self.cache
	for i = 1, #X do
		-- only needed for backwards pass.
		self.cache_x[i] = X[i]
	end
	local yi = 0
	for _, x in ipairs(X) do
		-- ids are 0-based, so row x occupies lut[x*dim + 1 .. x*dim + dim].
		local xi = x * self.dim
		for j = 1, self.dim do
			Y[yi + j] = self.lut[xi + j]
		end
		-- BUGFIX: advance by a whole row per id. previously yi was bumped
		-- inside the inner loop as well, so outputs landed at indices
		-- 1, 3, 5, ... leaving unwritten gaps between embedded vectors.
		yi = yi + self.dim
	end
	checkshape(Y, self.shape_out)
	return Y
end
function Model:init(nodes_in, nodes_out) function Model:init(nodes_in, nodes_out)
assert(#nodes_in > 0, #nodes_in) assert(#nodes_in > 0, #nodes_in)
assert(#nodes_out > 0, #nodes_out) assert(#nodes_out > 0, #nodes_out)
@ -735,8 +824,10 @@ return {
Layer = Layer, Layer = Layer,
Model = Model, Model = Model,
Input = Input, Input = Input,
Merge = Merge,
Relu = Relu, Relu = Relu,
Gelu = Gelu, Gelu = Gelu,
Dense = Dense, Dense = Dense,
Softmax = Softmax, Softmax = Softmax,
Embed = Embed,
} }