diff --git a/config.lua b/config.lua
index 5821ef7..756af60 100644
--- a/config.lua
+++ b/config.lua
@@ -12,48 +12,67 @@ local function intmap(x)
     return math.pow(10, x / 2)
 end
 
-local cfg = {
-    log_fn = 'log.csv', -- can be nil to disable logging.
-
+-- defaults shared by every run; cfg falls back to this table through __index.
+local common_cfg = {
     defer_prints = true,
 
     playable_mode = false,
-    start_big = false, --true
-    starting_lives = 0, --1
+    playback_mode = false,
+    start_big = false,
+    starting_lives = 0,
 
-    init_zeros = true, -- instead of he_normal noise or whatever.
     frameskip = 4,
     -- true greedy epsilon has both deterministic and det_epsilon set.
     deterministic = false, -- use argmax on outputs instead of random sampling.
     det_epsilon = false, -- take random actions with probability eps.
 
+    layernorm = false,
+    init_zeros = true, -- instead of he_normal noise or whatever.
     graycode = false,
-    epoch_trials = 64 * (7/8),
-    epoch_top_trials = 40 * (7/8), -- new with ARS.
     unperturbed_trial = true, -- do a trial without any noise.
     negate_trials = true, -- try pairs of normal and negated noise directions.
-    time_inputs = true, -- binary inputs of global frame count
     -- ^ note that this now doubles the effective trials.
-    deviation = intmap(-3),
-    learning_rate = intmap(-4),
-    weight_decay = intmap(-6),
+    time_inputs = true, -- binary inputs of global frame count
+    ars_lips = false, -- declared so the assert at the end of this file can read it.
 
-    adamant = true, -- run steps through AMSgrad.
-    adam_b1 = math.pow(10, -1 / 15),
-    adam_b2 = math.pow(10, -1 / 100),
-    adam_eps = intmap(-8),
-    adam_debias = false,
+    adamant = false, -- run steps through AMSgrad.
 
-    cap_time = 222, --400
+    cap_time = 300,
     timer_loser = 1/2,
-    decrement_reward = false, -- bad idea, encourages mario to kill himself
+    decrement_reward = false, -- bad idea, encourages mario to run into goombas.
+}
 
-    playback_mode = false,
+-- settings for this run; any key not set here is looked up in common_cfg.
+local cfg = {
+    log_fn = 'log.csv', -- can be nil to disable logging.
+    params_fn = nil, -- can be nil to generate based on param count.
+
+    deterministic = true,
+
+    epoch_trials = 20,
+    epoch_top_trials = 10,
+    learning_rate = 1.0,
+
+    deviation = 0.1,
+    weight_decay = 0.004,
+
+    cap_time = 300,
+    starting_lives = 1,
 }
 
 -- TODO: so, uhh..
 -- what happens when playback_mode is true but unperturbed_trial is false?
 
+-- missing keys fall back to common_cfg; names declared in neither table raise.
+setmetatable(cfg, {
+    __index = function(t, n)
+        if common_cfg[n] ~= nil then return common_cfg[n] end
+        -- `params_fn = nil` in the constructor stores no key, so permit it here.
+        if n == 'params_fn' then return nil end
+        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
+    end
+})
+
 cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
 
 cfg.eps_start = 1.0 * cfg.frameskip / 64
@@ -62,8 +81,8 @@ cfg.eps_frames = 1000000
 cfg.enable_overlay = cfg.playable_mode
 cfg.enable_network = not cfg.playable_mode
 
-return setmetatable(cfg, {
-    __index = function(t, n)
-        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
-    end
-})
+-- ars_lips has a default in common_cfg, so this read cannot trip the strict __index.
+assert(not cfg.ars_lips or cfg.unperturbed_trial,
+       "cfg.unperturbed_trial must be true to use cfg.ars_lips")
+
+return cfg
diff --git a/main.lua b/main.lua
index a6f3e6c..a93d48c 100644
--- a/main.lua
+++ b/main.lua
@@ -227,6 +227,7 @@ local function make_network(input_size)
     nn_ty:feed(nn_y)
 
     nn_y = nn_y:feed(nn.Dense(128))
+    if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
     if cfg.deterministic then
         nn_y = nn_y:feed(nn.Relu())
     else
diff --git a/nn.lua b/nn.lua
index 49e7a63..16e5c88 100644
--- a/nn.lua
+++ b/nn.lua
@@ -311,6 +311,7 @@ local Gelu = Layer:extend()
 local Dense = Layer:extend()
 local Softmax = Layer:extend()
 local Embed = Layer:extend()
+local LayerNorm = Layer:extend()
 
 function Weights:init(weight_init)
     self.weight_init = weight_init
@@ -618,6 +619,47 @@ function Embed:forward(X)
     return Y
 end
 
+-- LayerNorm: rescales its input to zero mean and (up to eps) unit variance.
+-- eps (optional number, default 1e-5) is added to the variance before the square root.
+function LayerNorm:init(eps)
+    Layer.init(self, "LayerNorm")
+    if eps == nil then eps = 1e-5 end
+    assert(type(eps) == "number")
+    self.eps = eps
+end
+
+-- records the batch size and replaces self.cache with a buffer from cache().
+function LayerNorm:reset_cache(bs)
+    self.bs = bs
+    self.cache = cache(bs, self.shape_out)
+end
+
+-- returns self.cache holding (X - mean) / sqrt(var + eps), with mean and var
+-- taken over every element of X.
+function LayerNorm:forward(X)
+    local bs = checkshape(X, self.shape_in)
+    if self.bs ~= bs then self:reset_cache(bs) end
+
+    local mean = 0
+    for _, v in ipairs(X) do
+        mean = mean + v / #X
+    end
+
+    local var = 0
+    for i, v in ipairs(X) do
+        local delta = v - mean
+        self.cache[i] = delta
+        var = var + delta * delta / #X
+    end
+
+    local std = sqrt(var + self.eps)
+    for i, v in ipairs(self.cache) do
+        self.cache[i] = v / std
+    end
+
+    return self.cache
+end
+
 function Model:init(nodes_in, nodes_out)
     assert(#nodes_in > 0, #nodes_in)
     assert(#nodes_out > 0, #nodes_out)
@@ -764,14 +806,15 @@ return {
     traverse = traverse,
     traverse_all = traverse_all,
 
-    Weights = Weights,
-    Layer   = Layer,
-    Model   = Model,
-    Input   = Input,
-    Merge   = Merge,
-    Relu    = Relu,
-    Gelu    = Gelu,
-    Dense   = Dense,
-    Softmax = Softmax,
-    Embed   = Embed,
+    Weights   = Weights,
+    Layer     = Layer,
+    Model     = Model,
+    Input     = Input,
+    Merge     = Merge,
+    Relu      = Relu,
+    Gelu      = Gelu,
+    Dense     = Dense,
+    Softmax   = Softmax,
+    Embed     = Embed,
+    LayerNorm = LayerNorm,
 }