add xNES preset, add options, allow preset to be specified by argument
This commit is contained in:
parent
403127bd66
commit
7800510d1f
2 changed files with 51 additions and 10 deletions
48
config.lua
48
config.lua
|
@ -1,4 +1,4 @@
|
||||||
local preset = 'snes'
|
preset = rawget(_G, 'preset') or 'ars'
|
||||||
|
|
||||||
local common_cfg = {
|
local common_cfg = {
|
||||||
-- read-only modes:
|
-- read-only modes:
|
||||||
|
@ -26,6 +26,8 @@ local common_cfg = {
|
||||||
|
|
||||||
-- network layers:
|
-- network layers:
|
||||||
layernorm = false,
|
layernorm = false,
|
||||||
|
reduce_tiles = false,
|
||||||
|
bias_out = true,
|
||||||
|
|
||||||
-- network evaluation (sampling joypad):
|
-- network evaluation (sampling joypad):
|
||||||
frameskip = 4,
|
frameskip = 4,
|
||||||
|
@ -39,6 +41,7 @@ local common_cfg = {
|
||||||
epoch_top_trials = 9999, -- for ARS.
|
epoch_top_trials = 9999, -- for ARS.
|
||||||
|
|
||||||
-- sampling:
|
-- sampling:
|
||||||
|
deviation = 1.0,
|
||||||
unperturbed_trial = true, -- perform an extra trial without any noise.
|
unperturbed_trial = true, -- perform an extra trial without any noise.
|
||||||
graycode = false, -- for ARS.
|
graycode = false, -- for ARS.
|
||||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||||
|
@ -47,7 +50,7 @@ local common_cfg = {
|
||||||
|
|
||||||
-- epoch-related rates:
|
-- epoch-related rates:
|
||||||
learning_rate = 1.0,
|
learning_rate = 1.0,
|
||||||
mean_adapt = 1.0, -- for xNES.
|
mean_adapt = 1.0, -- for SNES, xNES.
|
||||||
weight_decay = 0.0,
|
weight_decay = 0.0,
|
||||||
sigma_decay = 0.0, -- for SNES.
|
sigma_decay = 0.0, -- for SNES.
|
||||||
}
|
}
|
||||||
|
@ -68,15 +71,49 @@ if preset == 'snes' then
|
||||||
epoch_trials = 100,
|
epoch_trials = 100,
|
||||||
negate_trials = false,
|
negate_trials = false,
|
||||||
|
|
||||||
|
deviation = 1.0,
|
||||||
min_refresh = 0.2,
|
min_refresh = 0.2,
|
||||||
|
|
||||||
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
|
learning_rate = 0.1, -- TODO: rename to learn_primary or something.
|
||||||
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
|
mean_adapt = 0.5, -- TODO: rename to learn_secondary or something.
|
||||||
deviation = 1.0,
|
weight_decay = 0.02, -- note: multiplied by its std, and mean_adapt.
|
||||||
weight_decay = 0.01, -- note: multiplied by its std, and mean_adapt.
|
|
||||||
sigma_decay = 0.01, -- note: multiplied by learning_rate.
|
sigma_decay = 0.01, -- note: multiplied by learning_rate.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elseif preset == 'xnes' then
|
||||||
|
|
||||||
|
cfg = {
|
||||||
|
es = 'xnes',
|
||||||
|
|
||||||
|
log_fn = 'logs-xnes.csv',
|
||||||
|
params_fn = 'params-xnes.txt',
|
||||||
|
|
||||||
|
start_big = true,
|
||||||
|
min_time = 300,
|
||||||
|
timer_loser = 1.0,
|
||||||
|
|
||||||
|
decrement_reward = false,
|
||||||
|
score_multiplier = 0,
|
||||||
|
|
||||||
|
init_zeros = true,
|
||||||
|
|
||||||
|
reduce_tiles = true,
|
||||||
|
bias_out = false,
|
||||||
|
|
||||||
|
deterministic = false,
|
||||||
|
|
||||||
|
deviation = 0.5,
|
||||||
|
negate_trials = false,
|
||||||
|
min_refresh = 0.1,
|
||||||
|
|
||||||
|
epoch_trials = 50,
|
||||||
|
|
||||||
|
learning_rate = 0.01,
|
||||||
|
mean_adapt = 1.0,
|
||||||
|
weight_decay = 0.0,
|
||||||
|
sigma_decay = 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
elseif preset == 'ars' then
|
elseif preset == 'ars' then
|
||||||
|
|
||||||
cfg = {
|
cfg = {
|
||||||
|
@ -90,10 +127,11 @@ elseif preset == 'ars' then
|
||||||
min_time = 300,
|
min_time = 300,
|
||||||
timer_loser = 1.0,
|
timer_loser = 1.0,
|
||||||
|
|
||||||
|
deviation = 0.1,
|
||||||
|
|
||||||
epoch_trials = 25,
|
epoch_trials = 25,
|
||||||
|
|
||||||
learning_rate = 1.0,
|
learning_rate = 1.0,
|
||||||
deviation = 0.1,
|
|
||||||
weight_decay = 0.0025,
|
weight_decay = 0.0025,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
11
main.lua
11
main.lua
|
@ -2,6 +2,7 @@ local globalize = require("strict")
|
||||||
|
|
||||||
-- configuration.
|
-- configuration.
|
||||||
|
|
||||||
|
globalize{preset = arg}
|
||||||
local cfg = require("config")
|
local cfg = require("config")
|
||||||
local gcfg = require("gameconfig")
|
local gcfg = require("gameconfig")
|
||||||
|
|
||||||
|
@ -151,10 +152,12 @@ local function make_network(input_size)
|
||||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
||||||
|
|
||||||
nn_tz = nn_ty
|
nn_tz = nn_ty
|
||||||
--nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
|
if cfg.reduce_tiles then
|
||||||
--nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
|
nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2})
|
||||||
--nn_tz = nn_tz:feed(nn.Relu())
|
nn_tz = nn_tz:feed(nn.DenseBroadcast(5))
|
||||||
|
nn_tz = nn_tz:feed(nn.Relu())
|
||||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||||
|
end
|
||||||
|
|
||||||
nn_y = nn.Merge()
|
nn_y = nn.Merge()
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
|
@ -171,7 +174,7 @@ local function make_network(input_size)
|
||||||
--]]
|
--]]
|
||||||
|
|
||||||
nn_z = nn_y
|
nn_z = nn_y
|
||||||
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true)
|
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut), true, cfg.bias_out)
|
||||||
nn_z = nn_z:feed(nn.Softmax())
|
nn_z = nn_z:feed(nn.Softmax())
|
||||||
return nn.Model({nn_x, nn_tx}, {nn_z})
|
return nn.Model({nn_x, nn_tx}, {nn_z})
|
||||||
end
|
end
|
||||||
|
|
Loading…
Add table
Reference in a new issue