add xNES preset, add options, allow preset to be specified by argument
This commit is contained in:
parent
403127bd66
commit
7800510d1f
2 changed files with 51 additions and 10 deletions
48
config.lua
48
config.lua
|
@ -1,4 +1,4 @@
|
|||
local preset = 'snes'
|
||||
preset = rawget(_G, 'preset') or 'ars'
|
||||
|
||||
local common_cfg = {
|
||||
-- read-only modes:
|
||||
|
@ -26,6 +26,8 @@ local common_cfg = {
|
|||
|
||||
-- network layers:
|
||||
layernorm = false,
|
||||
reduce_tiles = false,
|
||||
bias_out = true,
|
||||
|
||||
-- network evaluation (sampling joypad):
|
||||
frameskip = 4,
|
||||
|
@ -39,6 +41,7 @@ local common_cfg = {
|
|||
epoch_top_trials = 9999, -- for ARS.
|
||||
|
||||
-- sampling:
|
||||
deviation = 1.0,
|
||||
unperturbed_trial = true, -- perform an extra trial without any noise.
|
||||
graycode = false, -- for ARS.
|
||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||
|
@ -47,7 +50,7 @@ local common_cfg = {
|
|||
|
||||
-- epoch-related rates:
|
||||
learning_rate = 1.0,
|
||||
mean_adapt = 1.0, -- for xNES.
|
||||
mean_adapt = 1.0, -- for SNES, xNES.
|
||||
weight_decay = 0.0,
|
||||
sigma_decay = 0.0, -- for SNES.
|
||||
}
|
||||
|
@ -68,15 +71,49 @@ if preset == 'snes' then
|
|||
epoch_trials = 100,
|
||||
negate_trials = false,
|
||||
|
||||
deviation = 1.0,
|
||||
min_refresh = 0.2,
|
||||
|
||||
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
|
||||
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
|
||||
deviation = 1.0,
|
||||
weight_decay = 0.01, -- note: multiplied by its std, and mean_adapt.
|
||||
weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
|
||||
sigma_decay = 0.01, -- note: multiplied by learning_rate.
|
||||
}
|
||||
|
||||
elseif preset == 'xnes' then
|
||||
|
||||
cfg = {
|
||||
es = 'xnes',
|
||||
|
||||
log_fn = 'logs-xnes.csv',
|
||||
params_fn = 'params-xnes.txt',
|
||||
|
||||
start_big = true,
|
||||
min_time = 300,
|
||||
timer_loser = 1.0,
|
||||
|
||||
decrement_reward = false,
|
||||
score_multiplier = 0,
|
||||
|
||||
init_zeros = true,
|
||||
|
||||
reduce_tiles = true,
|
||||
bias_out = false,
|
||||
|
||||
deterministic = false,
|
||||
|
||||
deviation = 0.5,
|
||||
negate_trials = false,
|
||||
min_refresh = 0.1,
|
||||
|
||||
epoch_trials = 50,
|
||||
|
||||
learning_rate = 0.01,
|
||||
mean_adapt = 1.0,
|
||||
weight_decay = 0.0,
|
||||
sigma_decay = 0.0,
|
||||
}
|
||||
|
||||
elseif preset == 'ars' then
|
||||
|
||||
cfg = {
|
||||
|
@ -90,10 +127,11 @@ elseif preset == 'ars' then
|
|||
min_time = 300,
|
||||
timer_loser = 1.0,
|
||||
|
||||
deviation = 0.1,
|
||||
|
||||
epoch_trials = 25,
|
||||
|
||||
learning_rate = 1.0,
|
||||
deviation = 0.1,
|
||||
weight_decay = 0.0025,
|
||||
}
|
||||
|
||||
|
|
11
main.lua
11
main.lua
|
@ -2,6 +2,7 @@ local globalize = require("strict")
|
|||
|
||||
-- configuration.
|
||||
|
||||
globalize{preset = arg}
|
||||
local cfg = require("config")
|
||||
local gcfg = require("gameconfig")
|
||||
|
||||
|
@ -151,10 +152,12 @@ local function make_network(input_size)
|
|||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
||||
|
||||
nn_tz = nn_ty
|
||||
--nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
|
||||
--nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
|
||||
--nn_tz = nn_tz:feed(nn.Relu())
|
||||
if cfg.reduce_tiles then
|
||||
nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
|
||||
nn_tz = nn_tz:feed(nn.Relu())
|
||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||
end
|
||||
|
||||
nn_y = nn.Merge()
|
||||
nn_x:feed(nn_y)
|
||||
|
@ -171,7 +174,7 @@ local function make_network(input_size)
|
|||
--]]
|
||||
|
||||
nn_z = nn_y
|
||||
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true)
|
||||
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true, cfg.bias_out)
|
||||
nn_z = nn_z:feed(nn.Softmax())
|
||||
return nn.Model({nn_x, nn_tx}, {nn_z})
|
||||
end
|
||||
|
|
Loading…
Add table
Reference in a new issue