diff --git a/config.lua b/config.lua index 13d04d5..4dacf7e 100644 --- a/config.lua +++ b/config.lua @@ -1,4 +1,4 @@ -local preset = 'snes' +preset = rawget(_G, 'preset') or 'ars' local common_cfg = { -- read-only modes: @@ -26,6 +26,8 @@ local common_cfg = { -- network layers: layernorm = false, + reduce_tiles = false, + bias_out = true, -- network evaluation (sampling joypad): frameskip = 4, @@ -39,6 +41,7 @@ local common_cfg = { epoch_top_trials = 9999, -- for ARS. -- sampling: + deviation = 1.0, unperturbed_trial = true, -- perform an extra trial without any noise. graycode = false, -- for ARS. negate_trials = true, -- try pairs of normal and negated noise directions. @@ -47,7 +50,7 @@ local common_cfg = { -- epoch-related rates: learning_rate = 1.0, - mean_adapt = 1.0, -- for xNES. + mean_adapt = 1.0, -- for SNES, xNES. weight_decay = 0.0, sigma_decay = 0.0, -- for SNES. } @@ -68,15 +71,49 @@ if preset == 'snes' then epoch_trials = 100, negate_trials = false, + deviation = 1.0, min_refresh = 0.2, learning_rate = 0.1, -- TODO: rename to learn_primary or something. mean_adapt = 0.5, -- TODO: rename to learn_secondary or something. - deviation = 1.0, - weight_decay = 0.01, -- note: multiplied by its std, and mean_adapt. + weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt. sigma_decay = 0.01, -- note: multiplied by learning_rate. } +elseif preset == 'xnes' then + + cfg = { + es = 'xnes', + + log_fn = 'logs-xnes.csv', + params_fn = 'params-xnes.txt', + + start_big = true, + min_time = 300, + timer_loser = 1.0, + + decrement_reward = false, + score_multiplier = 0, + + init_zeros = true, + + reduce_tiles = true, + bias_out = false, + + deterministic = false, + + deviation = 0.5, + negate_trials = false, + min_refresh = 0.1, + + epoch_trials = 50, + + learning_rate = 0.01, + mean_adapt = 1.0, + weight_decay = 0.0, + sigma_decay = 0.0, + } + elseif preset == 'ars' then cfg = { @@ -90,10 +127,11 @@ elseif preset == 'ars' then min_time = 300, timer_loser = 1.0, + deviation = 0.1, + epoch_trials = 25, learning_rate = 1.0, - deviation = 0.1, weight_decay = 0.0025, } diff --git a/main.lua b/main.lua index 7223897..0e94d08 100644 --- a/main.lua +++ b/main.lua @@ -2,6 +2,7 @@ local globalize = require("strict") -- configuration. +globalize{preset = arg} local cfg = require("config") local gcfg = require("gameconfig") @@ -151,10 +152,12 @@ local function make_network(input_size) nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2)) nn_tz = nn_ty - --nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2}) - --nn_tz = nn_tz:feed(nn.DenseBroadcast(5)) - --nn_tz = nn_tz:feed(nn.Relu()) - -- note: due to a quirk in Merge, we don't need to flatten nn_tz. + if cfg.reduce_tiles then + nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2}) + nn_tz = nn_tz:feed(nn.DenseBroadcast(5)) + nn_tz = nn_tz:feed(nn.Relu()) + -- note: due to a quirk in Merge, we don't need to flatten nn_tz. + end nn_y = nn.Merge() nn_x:feed(nn_y) @@ -171,7 +174,7 @@ local function make_network(input_size) --]] nn_z = nn_y - nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true) + nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true, cfg.bias_out) nn_z = nn_z:feed(nn.Softmax()) return nn.Model({nn_x, nn_tx}, {nn_z}) end