-- Review note: unified diff against main.lua, stored with its line breaks collapsed into spaces.
-- Hunks on this line: drop consider_past_rewards, add negate_trials, reword the learning_rate comment,
-- pass the network input size to nn.Input wrapped in a table, and open the negate_trials branch in prepare_epoch.
-- NOTE(review): presumably patch tools skip text placed ahead of the first "diff --git" header; confirm before applying.
diff --git a/main.lua b/main.lua index fa98ce0..14acbbd 100644 --- a/main.lua +++ b/main.lua @@ -37,12 +37,12 @@ local det_epsilon = true -- take random actions with probability eps. local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref. local eps_stop = 0.1 * 1/60 -- " local eps_frames = 1000000 -local consider_past_rewards = false local learn_start_select = false -- local epoch_trials = 40 +local negate_trials = true -- try pairs of normal and negated noise directions. local unperturbed_trial = true -- do a trial without any noise. -local learning_rate = 0.3 -- bigger now that i'm shaping trials etc. +local learning_rate = 0.3 -- bigger now that i'm stealing code etc. local deviation = 0.05 -- local cap_time = 400 @@ -231,7 +231,7 @@ local nn_x local nn_y local nn_z local function make_network(input_size, buttons) - nn_x = nn.Input(input_size) + nn_x = nn.Input({input_size}) -- the size is now passed as a one-element table. nn_y = nn_x nn_z = {} if false then @@ -526,7 +526,11 @@ local function prepare_epoch() -- (the os.time() as of here) and calling nn.normal() each trial. for i = 1, epoch_trials do local noise = nn.zeros(#base_params) - for j = 1, #base_params do noise[j] = nn.normal() end + if negate_trials and i % 2 == 0 then -- every other iteration... 
+ for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end -- mirror the previous trial's noise. + else + for j = 1, #base_params do noise[j] = nn.normal() end + end trial_noise[i] = noise end trial_i = -1 @@ -535,8 +539,6 @@ end local function load_next_trial() trial_i = trial_i + 1 local W = nn.copy(base_params) - local noise = trial_noise[trial_i] - local devsqrt = sqrt(deviation) if trial_i == 0 and not unperturbed_trial then trial_i = 1 end @@ -824,7 +826,7 @@ while true do for i, v in ipairs(sprite_input) do insert(X, v / 256) end for i, v in ipairs(tile_input) do insert(X, v / 256) end for i, v in ipairs(extra_input) do insert(X, v / 256) end - if #X ~= input_size then error("input size should be: "..tostring(#X)) end + nn.reshape(X, 1, input_size) if enable_network and get_state() == 'playing' or ingame_paused then local choose = deterministic and argmax2 or rchoice2 @@ -858,7 +860,7 @@ while true do select = choose(softmaxed[8]), } - if det_epsilon then + if det_epsilon then --and trial_i ~= 0 then (was "not trial_i == 0", which parses as (not trial_i) == 0.) local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) for k, v in pairs(jp) do local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select diff --git a/nn.lua b/nn.lua index 652f4c1..2668592 100644 --- a/nn.lua +++ b/nn.lua @@ -1,5 +1,7 @@ +local ceil = math.ceil local cos = math.cos local exp = math.exp +local floor = math.floor local insert = table.insert local ipairs = ipairs local log = math.log @@ -18,6 +20,9 @@ local unpack = table.unpack or unpack local Base = require("Base") +-- hacks +local function helpme() print((debug.traceback('helpme', 2):gsub("\n", "\r\n"))) end -- extra parens keep only the string, dropping the match count gsub also returns. + -- general utilities local function copy(t) -- shallow copy @@ -54,13 +59,23 @@ local function normal() -- box muller end local function zeros(n, out) - local out = out or {} + out = out or {} + if type(n) == 'table' then + local shape = n + n = prod(shape) + out.shape = shape -- keep the nd shape next to the flat data. + end for i = 1, n do out[i] = 0 end return out end local function arange(n, out) out = out or {} + if type(n) == 'table' 
then + local shape = n + n = prod(shape) + out.shape = shape + end for i = 1, n do out[i] = i - 1 end return out end @@ -92,6 +107,143 @@ local function init_he_normal(t, fan_in, fan_out) return t end +-- ndarray-ish stuff and more involved math + +local function pp(t, fmt, sep, ti, di, depth, isfirst, islast) + -- pretty-prints an nd-array. returns the text as a string; ti, di, depth, isfirst and islast are filled in by the recursive calls below. + fmt = fmt or '%10.7f,' + sep = sep or ',' + ti = ti or 0 -- 0-based flat offset of the slice being printed. + di = di or 1 -- axis handled by this call. + depth = depth or 0 + + if t.shape == nil then + local s = '[' + for i = 1, #t do s = s..fmt:format(t[i]) end + return s..']'..sep..'\n' + end + + local dim = t.shape[di] + + local ti_step = 1 + for dj = di + 1, #t.shape do ti_step = ti_step * t.shape[dj] end -- flat elements per index step along axis di. + + local indent = '' + for i = 1, depth do indent = indent..' ' end + + local s = '' + if di ~= #t.shape then + if isfirst then s = s..indent..'[\n' else s = s..'[\n' end + for i = 1, dim do + s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim) + ti = ti + ti_step + end + if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end + else + s = s..indent..'[' + for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end -- NOTE(review): the default fmt already ends in ',' so sep looks doubled here; confirm intent. + s = s..']'..sep..'\n' + end + return s +end + +local function ppi(t, n, ...) + -- TODO: determine maximum number of digits if n is omitted. + n = n or 1 -- field width used for each integer. + return pp(t, '%'..tostring(n)..'i', ' ', ...) 
+end + +local function checkshape_helper(shape, isbatch) + local s = '{ ' + if not isbatch then + s = s..'n, ' + end + for i, v in ipairs(shape) do + if not isbatch or i > 1 then + s = s..tostring(v)..(i ~= #shape and ', ' or ' ') + end + end + return s..'}' +end + +local function checkshape(batch, shape) + assert(type(batch) == 'table', "batch is not an array") + assert(batch.shape ~= nil, "batch is missing a shape") + if #batch.shape == 1 then + error("batch shape is incomplete", 2) + end + for n=1, #shape do + if batch.shape[n+1] ~= shape[n] then + local s1 = checkshape_helper(batch.shape, true) + local s2 = checkshape_helper(shape, false) + error("shapes do not match: "..s1.." ~= "..s2, 2) + end + end + return batch.shape[1] -- the leading dimension, which callers treat as the batch size. +end + +local function reshape(a, ...) + local new_shape = {...} + assert(#a == prod(new_shape), "new shape does not fit size") + a.shape = new_shape + return a +end + +local function cache(bs, shape) + if bs == nil then return nil end -- no batch size yet, so nothing to allocate. + local fullshape = copy(shape) + insert(fullshape, 1, bs) -- table.insert takes (list, pos, value): put bs in front of the per-sample shape. + return zeros(fullshape) +end + +local function dot(a, b, ax_a, ax_b, out) + ax_a = ax_a or #a.shape - 0 + ax_b = ax_b or #b.shape - 1 + + assert(a.shape[ax_a] == b.shape[ax_b], "dotted axes do not match") + local dim = a.shape[ax_a] + + local out_shape = {} + for di = 1, #a.shape do if di ~= ax_a then insert(out_shape, a.shape[di]) end end + for di = 1, #b.shape do if di ~= ax_b then insert(out_shape, b.shape[di]) end end + + if out == nil then + out = zeros(prod(out_shape)) + else + assert(prod(out_shape) == #out, "given output is the wrong size") + end + out.shape = out_shape + + local a0 = 1 + local a1 = 1 + local b0 = 1 + local b1 = 1 + for di = 1, ax_a - 1 do a0 = a0 * a.shape[di] end + for di = 1, ax_b - 1 do b0 = b0 * b.shape[di] end + for di = ax_a + 1, #a.shape do a1 = a1 * a.shape[di] end + for di = ax_b + 1, #b.shape do b1 = b1 * b.shape[di] end + + local o = 1 + local i_end = a0 * dim - 1 + local k_end = b0 * dim - 1 + for i = 0, i_end, 
dim do for j = 1, a1 do + for k = 0, k_end, dim do for m = 1, b1 do + local res = 0 + local x = i * a1 + j -- i advances by dim per leading block of a; scaling by a1 gives that block's real flat offset. + local y = k * b1 + m -- same for b; unchanged whenever the contracted axis is first or last. + for d = 1, dim do + res = res + a[x] * b[y] + x = x + a1 + y = y + b1 + end + out[o] = res + o = o + 1 + end end + end end + + return out +end + -- nodal local function traverse(node_in, node_out, nodes, dummy_mode) @@ -171,17 +323,17 @@ function Layer:init(name) self.parents = {} self.children = {} self.weights = {} - --self.size_in = nil - --self.size_out = nil + --self.shape_in = nil + --self.shape_out = nil end function Layer:make_shape(parent) - if self.size_in == nil then self.size_in = parent.size_out end - if self.size_out == nil then self.size_out = self.size_in end + if self.shape_in == nil then self.shape_in = parent.shape_out end + if self.shape_out == nil then self.shape_out = self.shape_in end end function Layer:feed(child) - assert(self.size_out ~= nil) + assert(self.shape_out ~= nil) child:make_shape(self) insert(self.children, child) insert(child.parents, self) @@ -216,12 +368,18 @@ function Layer:init_weights() for i, w in ipairs(self.weights) do --print("allocating weights", i, "of", self.name) for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK - w:allocate(self.size_in, self.size_out) + w:allocate(prod(self.shape_in), prod(self.shape_out)) end + self:reset_cache() +end + +function Layer:reset_cache(bs) + self.bs = bs end function Layer:_propagate(edges, deterministic) - assert(#edges == 1, #edges) -- override this if you need multiple parents. + -- override this if you need multiple parents. 
+ assert(#edges == 1, ("%s edges for node %s (expected 1)"):format(#edges, self.name)) if deterministic then return self:forward_deterministic(edges[1]) else @@ -237,25 +395,25 @@ function Layer:propagate(values, deterministic) insert(edges, X) end end - assert(#edges > 0, #edges) + assert(#edges > 0, ("%s edges for node %s (expected >0)"):format(#edges, self.name)) local Y = self:_propagate(edges, deterministic) return Y end -function Input:init(size) +function Input:init(shape) Layer.init(self, "Input") - assert(type(size) == 'number') - self.size_in = size - self.size_out = size + assert(type(shape) == 'table') + self.shape_in = shape + self.shape_out = shape end function Input:forward(X) - assert(#X == self.size_in) + checkshape(X, self.shape_in) return X end function Input:backward(dY) - assert(#dY == self.size_out) + checkshape(dY, self.shape_out) return zeros(#dY) end @@ -263,99 +421,177 @@ function Relu:init() Layer.init(self, "Relu") end +function Relu:reset_cache(bs) + print("clearing cache:", self.name, bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) + self.dcache = cache(bs, self.shape_in) +end + function Relu:forward(X) - assert(#X == self.size_in) - self.cache = self.cache or zeros(self.size_out) + local bs = checkshape(X, self.shape_in); if bs ~= self.bs then self:reset_cache(bs) end -- allocate the caches on first use or batch-size change, as Gelu, Dense and Softmax do. local Y = self.cache for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end - assert(#Y == self.size_out) + checkshape(Y, self.shape_out) return Y end function Relu:backward(dY) - assert(#dY == self.size_out) - self.dcache = self.dcache or zeros(self.size_in) + local bs = checkshape(dY, self.shape_out) local Y = self.cache local dX = self.dcache - for i = 1, #dY do dX[i] = Y[i] >= 0 and dY[i] or 0 end + for i = 1, #dY do dX[i] = Y[i] > 0 and dY[i] or 0 end -- Y is the relu output and never negative, so only Y > 0 marks an active unit. - assert(#Y == self.size_in) - return Y + checkshape(dX, self.shape_in) + return dX end function Gelu:init() Layer.init(self, "Gelu") end +function Gelu:reset_cache(bs) + print("clearing cache:", self.name, bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) + self.cache_a = cache(bs, self.shape_out) + 
self.cache_sig = cache(bs, self.shape_out) + self.dcache = cache(bs, self.shape_in) +end + function Gelu:forward(X) - assert(#X == self.size_in) - self.cache = self.cache or zeros(self.size_out) + local bs = checkshape(X, self.shape_in) + if bs ~= self.bs then self:reset_cache(bs) end local Y = self.cache + local a = self.cache_a + local sig = self.cache_sig -- NOTE: approximate form of GELU exploiting similarities to sigmoid curve. for i = 1, #X do - Y[i] = X[i] / (1 + exp(-1.704 * X[i])) + a[i] = 1.704 * X[i] -- scaled input, kept for the backward pass. + sig[i] = 1 / (1 + exp(-a[i])) -- sigmoid of the scaled input, also kept for backward. + Y[i] = X[i] * sig[i] end - assert(#Y == self.size_out) + checkshape(Y, self.shape_out) return Y end +function Gelu:backward(dY) + checkshape(dY, self.shape_out) + local Y = self.cache -- not read below. + local a = self.cache_a + local sig = self.cache_sig + local dX = self.dcache + + for i = 1, #dY do + dX[i] = dY[i] * sig[i] * (1 + a[i] * (1 - sig[i])) -- dY times the derivative of x*sigmoid(1.704*x), using the values cached by forward. + end + + checkshape(dX, self.shape_in) + return dX +end + function Dense:init(dim) Layer.init(self, "Dense") assert(type(dim) == "number") self.dim = dim - self.size_out = dim + self.shape_out = {dim} self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but... 
self.biases = self:_new_weights(init_zeros) end function Dense:make_shape(parent) - self.size_in = parent.size_out - self.coeffs.shape = {self.size_in, self.dim} - self.biases.shape = self.dim + self.shape_in = parent.shape_out + self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim} + self.biases.shape = {1, self.dim} +end + +function Dense:reset_cache(bs) + print("clearing cache:", self.name, bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) + self.cache_x = cache(bs, self.shape_in) + self.dcache = cache(bs, self.shape_in) end function Dense:forward(X) - assert(#X == self.size_in) - self.cache = self.cache or zeros(self.size_out) + local bs = checkshape(X, self.shape_in) + if self.bs ~= bs then self:reset_cache(bs) end local Y = self.cache - for i = 1, self.dim do - local res = 0 - local c = (i - 1) * #X - for j = 1, #X do - res = res + X[j] * self.coeffs[c + j] - end - Y[i] = res + self.biases[i] + for i = 1, #X do + -- only needed for backwards pass. + self.cache_x[i] = X[i] end - assert(#Y == self.size_out) + --dot_1aab(X, self.coeffs, Y) + dot(X, self.coeffs, 2, 1, Y) + + for i = 1, #Y do + Y[i] = Y[i] + self.biases[(i - 1) % self.dim + 1] -- add the bias to every batch row, not only the first. + end + + checkshape(Y, self.shape_out) return Y end +function Dense:backward(dY) + local X = self.cache_x + local dX = self.dcache + + --dot_ta(X, dY, self.coeffs.g) + dot(X, dY, 1, 1, self.coeffs.g) + + for b = 1, X.shape[1] do + local l = self.dim -- rows of dY are self.dim wide, one entry per bias. + local j = (b - 1) * l + for i = 1, l do self.biases.g[i] = self.biases.g[i] + dY[j+i] end -- row b is dY[j+1 .. j+l]; the previous j-1+i started at dY[0]. + end + + --dot_tb(dY, self.coeffs, dX) + dot(dY, self.coeffs, 2, 2, dX) + + checkshape(dX, self.shape_in) + return dX +end + function Softmax:init() Layer.init(self, "Softmax") end +function Softmax:reset_cache(bs) + print("clearing cache:", self.name, bs) + self.bs = bs + + self.cache = cache(bs, self.shape_out) +end + function Softmax:forward(X) - assert(#X == self.size_in) - self.cache = self.cache or zeros(self.size_out) + local bs = checkshape(X, self.shape_in) + if self.bs 
~= bs then self:reset_cache(bs) end local Y = self.cache local alpha = 0 local num = {} -- TODO: cache local den = 0 - for i = 1, #X do alpha = max(alpha, X[i]) end - for i = 1, #X do num[i] = exp(X[i] - alpha) end - for i = 1, #X do den = den + num[i] end - for i = 1, #X do Y[i] = num[i] / den end + for b = 1, X.shape[1] do + local l = X.shape[2] + local j = (b - 1) * l; alpha, den = 0, 0 -- restart the running max and the normaliser for each batch row. + for i = j+1, j+l do alpha = max(alpha, X[i]) end + for i = j+1, j+l do num[i] = exp(X[i] - alpha) end + for i = j+1, j+l do den = den + num[i] end + for i = j+1, j+l do Y[i] = num[i] / den end + end - assert(#Y == self.size_out) + checkshape(Y, self.shape_out) return Y end @@ -392,6 +628,7 @@ function Model:forward(inputs) if contains(self.nodes_in, node) then local X = inputs[node] assert(X ~= nil, ("missing input for node %s"):format(node.name)) + assert(X.shape, ("missing shape for node %s"):format(node.name)) values[node] = node:_propagate({X}) else values[node] = node:propagate(values) @@ -491,6 +728,10 @@ return { init_zeros = init_zeros, init_he_uniform = init_he_uniform, init_he_normal = init_he_normal, + reshape = reshape, + pp = pp, + ppi = ppi, + dot = dot, traverse = traverse, traverse_all = traverse_all,