add xNES preset, add options, allow preset to be specified by argument

This commit is contained in:
Connor Olding 2018-06-13 03:00:42 +02:00
parent 403127bd66
commit 7800510d1f
2 changed files with 51 additions and 10 deletions

View File

@ -1,4 +1,4 @@
local preset = 'snes'
preset = rawget(_G, 'preset') or 'ars'
local common_cfg = {
-- read-only modes:
@ -26,6 +26,8 @@ local common_cfg = {
-- network layers:
layernorm = false,
reduce_tiles = false,
bias_out = true,
-- network evaluation (sampling joypad):
frameskip = 4,
@ -39,6 +41,7 @@ local common_cfg = {
epoch_top_trials = 9999, -- for ARS.
-- sampling:
deviation = 1.0,
unperturbed_trial = true, -- perform an extra trial without any noise.
graycode = false, -- for ARS.
negate_trials = true, -- try pairs of normal and negated noise directions.
@ -47,7 +50,7 @@ local common_cfg = {
-- epoch-related rates:
learning_rate = 1.0,
mean_adapt = 1.0, -- for xNES.
mean_adapt = 1.0, -- for SNES, xNES.
weight_decay = 0.0,
sigma_decay = 0.0, -- for SNES.
}
@ -68,15 +71,49 @@ if preset == 'snes' then
epoch_trials = 100,
negate_trials = false,
deviation = 1.0,
min_refresh = 0.2,
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
deviation = 1.0,
weight_decay = 0.01, -- note: multiplied by its std, and mean_adapt.
weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
sigma_decay = 0.01, -- note: multiplied by learning_rate.
}
elseif preset == 'xnes' then
cfg = {
es = 'xnes',
log_fn = 'logs-xnes.csv',
params_fn = 'params-xnes.txt',
start_big = true,
min_time = 300,
timer_loser = 1.0,
decrement_reward = false,
score_multiplier = 0,
init_zeros = true,
reduce_tiles = true,
bias_out = false,
deterministic = false,
deviation = 0.5,
negate_trials = false,
min_refresh = 0.1,
epoch_trials = 50,
learning_rate = 0.01,
mean_adapt = 1.0,
weight_decay = 0.0,
sigma_decay = 0.0,
}
elseif preset == 'ars' then
cfg = {
@ -90,10 +127,11 @@ elseif preset == 'ars' then
min_time = 300,
timer_loser = 1.0,
deviation = 0.1,
epoch_trials = 25,
learning_rate = 1.0,
deviation = 0.1,
weight_decay = 0.0025,
}

View File

@ -2,6 +2,7 @@ local globalize = require("strict")
-- configuration.
globalize{preset = arg}
local cfg = require("config")
local gcfg = require("gameconfig")
@ -151,10 +152,12 @@ local function make_network(input_size)
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
nn_tz = nn_ty
--nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
--nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
--nn_tz = nn_tz:feed(nn.Relu())
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
if cfg.reduce_tiles then
nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
nn_tz = nn_tz:feed(nn.Relu())
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
end
nn_y = nn.Merge()
nn_x:feed(nn_y)
@ -171,7 +174,7 @@ local function make_network(input_size)
--]]
nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true)
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true, cfg.bias_out)
nn_z = nn_z:feed(nn.Softmax())
return nn.Model({nn_x, nn_tx}, {nn_z})
end