diff --git a/main.lua b/main.lua
index 1df0544..96b1894 100644
--- a/main.lua
+++ b/main.lua
@@ -52,7 +52,8 @@ local timer_loser = 1/3
 
 local enable_overlay = playable_mode
 local enable_network = not playable_mode
-local input_size = 281 -- TODO: let the script figure this out for us.
+local input_size = 60 -- TODO: let the script figure this out for us.
+local tile_count = 17 * 13
 
 local ok_routines = {
 	[0x4] = true, -- sliding down flagpole
@@ -231,17 +232,17 @@ package.loaded['nn'] = nil -- DEBUG
 local nn = require("nn")
 
 local network
-local nn_x
-local nn_y
-local nn_z
+local nn_x, nn_tx, nn_ty, nn_y, nn_z
 local function make_network(input_size, buttons)
 	nn_x = nn.Input({input_size})
-	nn_y = nn_x
+	nn_tx = nn.Input({tile_count})
+	nn_ty = nn_tx:feed(nn.Embed(256, 2))
+	nn_y = nn.Merge()
+	nn_x:feed(nn_y)
+	nn_ty:feed(nn_y)
 	nn_z = {}
-	if false then
-		nn_y = nn_y:feed(nn.Dense(input_size))
-		nn_y = nn_y:feed(nn.Gelu())
-	else
+
+	if true then
 		nn_y = nn_y:feed(nn.Dense(128))
 		nn_y = nn_y:feed(nn.Gelu())
 		nn_y = nn_y:feed(nn.Dense(64))
@@ -249,13 +250,14 @@ local function make_network(input_size, buttons)
 		nn_y = nn_y:feed(nn.Dense(48))
 		nn_y = nn_y:feed(nn.Gelu())
 	end
+
 	for i = 1, buttons do
 		nn_z[i] = nn_y
 		nn_z[i] = nn_z[i]:feed(nn.Dense(2))
 		nn_z[i] = nn_z[i]:feed(nn.Softmax())
 	end
 
-	return nn.Model({nn_x}, nn_z)
+	return nn.Model({nn_x, nn_tx}, nn_z)
 end
 
 -- and here we go with the game stuff.
@@ -495,7 +497,7 @@ local function handle_tiles()
 	--local tile_col = R(0x6A0)
 	local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
 	local tile_scroll_remainder = R(0x73F) % 16
-	tile_input[#tile_input+1] = tile_scroll_remainder
+	extra_input[#extra_input+1] = tile_scroll_remainder
 	for y = 0, 12 do
 		for x = 0, 16 do
 			local col = (x + tile_scroll) % 32
@@ -787,19 +789,14 @@ local function doit(dummy)
 
 	local X = {}
 	for i, v in ipairs(sprite_input) do insert(X, v / 256) end
-	for i, v in ipairs(tile_input) do insert(X, v / 256) end
 	for i, v in ipairs(extra_input) do insert(X, v / 256) end
 	nn.reshape(X, 1, input_size)
+	nn.reshape(tile_input, 1, tile_count)
 
 	if enable_network and get_state() == 'playing' or ingame_paused then
 		local choose = deterministic and argmax2 or rchoice2
 
-		local outputs = network:forward({[nn_x]=X})
-
-		-- TODO: predict the *rewards* of all possible actions?
-		-- that's how DQN seems to work anyway.
-		-- ah, but A3C just returns probabilities,
-		-- besides the critic?
+		local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
 
 		local softmaxed = {
 			outputs[nn_z[1]],
diff --git a/nn.lua b/nn.lua
index ca75d20..4da53aa 100644
--- a/nn.lua
+++ b/nn.lua
@@ -95,6 +95,11 @@ local function init_zeros(t, fan_in, fan_out)
 	return t
 end
 
+local function init_uniform(t, fan_in, fan_out)
+	for i = 1, #t do t[i] = uniform() * 2 - 1 end
+	return t
+end
+
 local function init_he_uniform(t, fan_in, fan_out)
 	local s = sqrt(6 / fan_in)
 	for i = 1, #t do t[i] = (uniform() * 2 - 1) * s end
@@ -298,10 +303,12 @@ local Weights = Base:extend()
 local Layer = Base:extend()
 local Model = Base:extend()
 local Input = Layer:extend()
+local Merge = Layer:extend()
 local Relu = Layer:extend()
 local Gelu = Layer:extend()
 local Dense = Layer:extend()
 local Softmax = Layer:extend()
+local Embed = Layer:extend()
 
 function Weights:init(weight_init)
 	self.weight_init = weight_init
@@ -333,7 +340,7 @@ function Layer:make_shape(parent)
 end
 
 function Layer:feed(child)
-	assert(self.shape_out ~= nil)
+	assert(self.shape_out ~= nil, "missing output shape: "..self.name)
 	child:make_shape(self)
 	insert(self.children, child)
 	insert(child.parents, self)
@@ -360,7 +367,7 @@ end
 
 function Layer:get_size()
 	local size = 0
-	for i, w in ipairs(self.weights) do size = size + prod(w.size) end
+	for i, w in ipairs(self.weights) do size = size + prod(w.shape) end
 	return size
 end
 
@@ -417,6 +424,43 @@ function Input:backward(dY)
 	return zeros(#dY)
 end
 
+function Merge:init()
+	Layer.init(self, "Merge")
+	self.size = 0
+	self.shape_in = 0
+end
+
+function Merge:make_shape(parent)
+	self.size = self.size + prod(parent.shape_out)
+
+	self.shape_in = self.shape_in + 1 -- TODO: more robust.
+	self.shape_out = {self.size}
+end
+
+function Merge:reset_cache(bs)
+	self.bs = bs
+
+	self.cache = cache(bs, self.shape_out)
+end
+
+function Merge:_propagate(edges, deterministic)
+	assert(#edges == self.shape_in)
+	local bs = edges[1].shape[1]
+	if bs ~= self.bs then self:reset_cache(bs) end
+	local Y = self.cache
+
+	local yi = 1
+	for i, X in ipairs(edges) do
+		for _, x in ipairs(X) do
+			Y[yi] = x
+			yi = yi + 1
+		end
+	end
+
+	checkshape(Y, self.shape_out)
+	return Y
+end
+
 function Relu:init()
 	Layer.init(self, "Relu")
 end
@@ -595,6 +639,51 @@ end
 
 --return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.cache
 --end
+function Embed:init(vocab, dim)
+	Layer.init(self, "Embed")
+	assert(type(vocab) == "number")
+	assert(type(dim) == "number")
+	self.vocab = vocab
+	self.dim = dim
+	self.lut = self:_new_weights(init_uniform)
+	self.lut.shape = {self.vocab, self.dim}
+end
+
+function Embed:make_shape(parent)
+	self.shape_in = parent.shape_out
+	self.shape_out = {parent.shape_out[1] * self.dim}
+end
+
+function Embed:reset_cache(bs)
+	self.bs = bs
+
+	self.cache = cache(bs, self.shape_out)
+	self.cache_x = cache(bs, self.shape_in)
+end
+
+function Embed:forward(X)
+	local bs = checkshape(X, self.shape_in)
+	if self.bs ~= bs then self:reset_cache(bs) end
+	local Y = self.cache
+
+	for i = 1, #X do
+		-- only needed for backwards pass.
+		self.cache_x[i] = X[i]
+	end
+
+	local yi = 0
+	for i, x in ipairs(X) do
+		local xi = x * self.dim
+		for j = 1, self.dim do
+			Y[yi+j] = self.lut[xi + j]
+		end
+		-- advance the output cursor by one embedding width per symbol;
+		-- bumping yi inside the inner loop as well would skip slots.
+		yi = yi + self.dim
+	end
+
+	checkshape(Y, self.shape_out)
+	return Y
+end
+
 function Model:init(nodes_in, nodes_out)
 	assert(#nodes_in > 0, #nodes_in)
 	assert(#nodes_out > 0, #nodes_out)
@@ -735,8 +824,10 @@ return {
 	Layer = Layer,
 	Model = Model,
 	Input = Input,
+	Merge = Merge,
 	Relu = Relu,
 	Gelu = Gelu,
 	Dense = Dense,
 	Softmax = Softmax,
+	Embed = Embed,
 }