-- smbot/config.lua
--
-- 260 lines
-- 6.3 KiB
-- Lua

-- the preset name is taken from the global `preset` (read with rawget, so
-- metamethods on _G are not consulted). a missing or empty value
-- falls back to 'ars'.
local preset = rawget(_G, 'preset')
if not preset or preset == '' then
    preset = 'ars'
end
-- default values shared by all presets. a preset table only needs to list
-- the settings it overrides; lookups for anything else fall back to here.
local common_cfg = {
    -- read-only modes.
    playable_mode = false,
    playback_mode = false,

    -- in-game mechanics.
    starting_world = 1, -- 0 randomizes the world every epoch.
    starting_level = 1, -- 0 randomizes the level every epoch.
    start_big = false,
    starting_lives = 0,
    min_time = 100,
    max_time = 300,
    timer_loser = 0.5, -- decrement the timer when mario is slacking off.

    -- rewards.
    decrement_reward = false, -- bad idea: encourages mario to run into goombas.
    score_multiplier = 1, -- how much the in-game score influences our rewards.

    -- network initialization.
    init_zeros = false, -- zeros instead of random normal noise.

    -- network inputs (connections).
    time_inputs = true, -- insert binary inputs of a frame counter.

    -- network layers.
    layernorm = false, -- (doesn't do anything right now)
    reduce_tiles = false,
    bias_out = true,

    -- network evaluation (sampling the joypad).
    frameskip = 4,
    -- true greedy epsilon has both deterministic and det_epsilon set.
    deterministic = false, -- argmax on outputs instead of random sampling.
    det_epsilon = false, -- take random actions with probability eps.

    -- evolution strategy and non-rate hyperparameters.
    es = 'ars',
    ars_lips = false, -- for ARS.
    epoch_top_trials = 9999, -- for ARS.

    -- sampling.
    deviation = 1.0,
    -- perform an extra trial without any noise.
    -- this is good for logging, so it is recommended.
    unperturbed_trial = true,
    graycode = false, -- for ARS.
    -- try pairs of normal and negated noise directions, AKA antithetic
    -- sampling. note that this doubles the number of trials.
    negate_trials = true,
    min_refresh = 0.1, -- for SNES.

    -- epoch-related rates.
    -- param_rate, sigma_rate, and covar_rate are based on base_rate
    -- when they are not specified individually.
    base_rate = 1.0,
    param_decay = 0.0,
    sigma_decay = 0.0, -- for SNES, xNES.
}
-- per-preset overrides, keyed by preset name.
-- settings a preset leaves out are expected to come from common_cfg.
local presets = {
    snes = {
        es = 'snes',
        log_fn = 'logs-snes.csv',
        params_fn = 'params-snes.txt',

        starting_lives = 1,
        min_time = 300,
        timer_loser = 1.0,

        epoch_trials = 100,
        negate_trials = false,
        deviation = 1.0,
        min_refresh = 0.2,

        param_rate = 0.5,
        sigma_rate = 0.1,
        param_decay = 0.02, -- note: multiplied by its std, and param_rate.
        sigma_decay = 0.01, -- note: multiplied by sigma_rate.
    },

    snes2 = {
        es = 'snes',
        log_fn = 'logs-snes2.csv',
        params_fn = 'params-snes2.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,
        score_multiplier = 0,

        init_zeros = true,
        reduce_tiles = true,
        bias_out = false,

        deterministic = false,
        deviation = 0.5,
        negate_trials = false,
        min_refresh = 0.5,
        epoch_trials = 100,

        param_rate = 1.0,
        sigma_rate = 0.01,
        param_decay = 0.02,
        sigma_decay = 0.01,
    },

    xnes = {
        es = 'xnes',
        log_fn = 'logs-xnes.csv',
        params_fn = 'params-xnes.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,
        score_multiplier = 0,

        init_zeros = true,
        reduce_tiles = true,
        bias_out = false,

        deterministic = false,
        deviation = 0.5,
        negate_trials = false,
        epoch_trials = 50,

        param_rate = 1.0,
        sigma_rate = 0.01,
        covar_rate = 0.01,
    },

    xnes2 = {
        es = 'xnes',
        log_fn = 'logs-xnes2.csv',
        params_fn = 'params-xnes2.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,
        score_multiplier = 0,

        init_zeros = true,
        reduce_tiles = true,
        bias_out = false,

        deterministic = false,
        deviation = 1.0,
        negate_trials = true, --false,
        epoch_trials = 10, --50,

        param_rate = 0.5,
        sigma_rate = 0.04,
        covar_rate = 0.04,
        param_decay = 0.004,
        sigma_decay = 0.00128,
    },

    ars = {
        es = 'ars',
        epoch_top_trials = 20 * 2,
        ars_lips = false,
        log_fn = 'logs-ars.csv',
        params_fn = 'params-ars.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,

        bias_out = false,

        deterministic = false,
        graycode = false,
        deviation = 0.1,
        negate_trials = false,
        epoch_trials = 25 * 2,

        param_rate = 1.0,
        param_decay = 0.0025,
    },

    play = {
        playable_mode = true,
    },
}

-- select the active preset; unknown names are a hard error.
local cfg = presets[preset]
if cfg == nil then
    error("invalid preset: "..tostring(preset))
end
-- TODO: so, uhh..
-- what happens when playback_mode is true but unperturbed_trial is false?

-- finalize the selected preset: fill in rate defaults, install a strict
-- lookup that rejects undeclared settings, derive a few values, and
-- sanity-check incompatible combinations before returning the table.

-- temporary fallback so the rate defaults below can read inherited values.
-- this metatable is replaced by the strict one further down.
setmetatable(cfg, {__index=common_cfg})

-- rates the preset did not specify:
-- param_rate follows base_rate for ARS and is 1.0 for everything else;
-- sigma_rate follows base_rate; covar_rate follows sigma_rate.
if cfg.es == 'ars' then
    if cfg.param_rate == nil then cfg.param_rate = cfg.base_rate end
else
    if cfg.param_rate == nil then cfg.param_rate = 1.0 end
end
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end

-- strict lookup: fall back to common_cfg, allow the optional filenames to be
-- absent, and raise (blaming the caller) for any other unknown setting.
setmetatable(cfg, {
    __index = function(_, n)
        if common_cfg[n] ~= nil then return common_cfg[n] end
        if n == 'log_fn' then return nil end
        if n == 'params_fn' then return nil end
        if n == 'stats_fn' then return nil end
        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
    end
})

-- never keep more top trials than there are trials in an epoch.
-- epoch_trials has no default in common_cfg, so a preset that omits it
-- (such as 'play') would trip the strict lookup above if it were read
-- through cfg directly. only clamp when a value is actually declared.
do
    local epoch_trials = rawget(cfg, 'epoch_trials')
    if epoch_trials == nil then epoch_trials = common_cfg.epoch_trials end
    if epoch_trials ~= nil then
        cfg.epoch_top_trials = math.min(epoch_trials, cfg.epoch_top_trials)
    end
end

-- epsilon schedule values, derived from frameskip.
cfg.eps_start = 1.0 * cfg.frameskip / 64
cfg.eps_stop = 0.1 * cfg.eps_start
cfg.eps_frames = 1000000

-- playable mode turns the overlay on and the network off.
cfg.enable_overlay = cfg.playable_mode
cfg.enable_network = not cfg.playable_mode

-- reject combinations of settings that are known not to work together.
assert(not cfg.ars_lips or cfg.unperturbed_trial,
       "cfg.unperturbed_trial must be true to use cfg.ars_lips")
assert(not cfg.ars_lips or cfg.negate_trials,
       "cfg.negate_trials must be true to use cfg.ars_lips")
assert(not (cfg.es == 'snes' and cfg.negate_trials),
       "cfg.negate_trials is not yet compatible with SNES")

return cfg