2018-06-15 15:37:19 -07:00
|
|
|
-- Name of the preset to load. It may be supplied as a global before this
-- file runs; rawget bypasses any __index metamethod on _G (presumably a
-- strict-globals guard). A missing, false, or empty value selects 'ars'.
local preset = rawget(_G, 'preset')
if not preset or preset == '' then
  preset = 'ars'
end
|
2018-06-12 17:02:56 -07:00
|
|
|
|
2018-06-16 20:48:15 -07:00
|
|
|
-- Baseline values for every declared config option. Each preset falls back
-- to these (directly or through its parent chain) for anything it leaves unset.
local defaults = {
  -- read-only modes:
  playable_mode = false,
  playback_mode = false,

  -- controlling in-game mechanics:
  starting_world = 1, -- set to 0 to randomize the world every epoch.
  starting_level = 1, -- set to 0 to randomize the level every epoch.
  start_big = false,
  starting_lives = 0,
  min_time = 100,
  max_time = 300,
  timer_loser = 1/2, -- decrement the timer when mario is slacking off.

  -- rewards:
  decrement_reward = false, -- bad idea, encourages mario to run into goombas.
  score_multiplier = 1, -- how much the in-game score influences our rewards.

  -- network initialization:
  init_zeros = false, -- instead of random normal noise.

  -- network inputs (connections):
  time_inputs = true, -- insert binary inputs of a frame counter.

  -- network layers:
  hidden = false, -- use a hidden layer with ReLU/GELU activation.
  hidden_size = 128,
  layernorm = false, -- use a LayerNorm layer after said activation.
  reduce_tiles = false,
  bias_out = true,

  -- network evaluation (sampling joypad):
  frameskip = 4,
  -- true epsilon-greedy has both deterministic and det_epsilon set.
  deterministic = false, -- use argmax on outputs instead of random sampling.
  det_epsilon = false, -- take random actions with probability eps.

  -- evolution strategy and non-rate hyperparameters:
  es = 'ars',
  ars_lips = false, -- for ARS.
  epoch_top_trials = 9999, -- for ARS; clamped to epoch_trials once a preset is resolved.

  -- sampling:
  deviation = 1.0,
  unperturbed_trial = true, -- perform an extra trial without any noise.
  -- this is good for logging, so i'd recommend it.
  epoch_trials = 50,
  graycode = false, -- for ARS.
  negate_trials = true, -- try pairs of normal and negated noise directions.
  -- AKA antithetic sampling. note that this doubles the number of trials.
  min_refresh = 0.1, -- for SNES.

  -- epoch-related rates:
  -- base_rate is the fallback for rates a preset leaves unset:
  --   sigma_rate defaults to base_rate,
  --   covar_rate defaults to sigma_rate,
  --   param_rate defaults to base_rate under ARS, and to 1.0 otherwise.
  base_rate = 1.0,
  param_decay = 0.0,
  sigma_decay = 0.0, -- for SNES, xNES.
  momentum = 0.0, -- for ARS.
}
|
2018-06-10 07:38:25 -07:00
|
|
|
|
2018-06-16 20:48:15 -07:00
|
|
|
local presets = require("presets")

-- Follow the raw `parent` links starting at preset `cfg` (named `name`) and
-- raise a clear error if the chain loops back on itself. Without this check,
-- a cycle only surfaces later as a cryptic "'__index' chain too long" /
-- "loop in gettable" error on the first lookup of a missing key.
-- A parent name that does not resolve ends the walk; the caller's assert
-- reports that case.
local function check_parent_cycle(name, cfg)
  local seen = {[cfg] = true}
  local node = cfg
  while true do
    local parent_name = rawget(node, 'parent')
    if parent_name == nil then return end
    node = presets[parent_name]
    if node == nil then return end
    assert(not seen[node], "cyclic parent chain in preset: "..tostring(name))
    seen[node] = true
  end
end

-- Link every preset to its parent preset, or to `defaults` when it names none,
-- so that unset fields are inherited along the chain.
for name, cfg in pairs(presets) do
  local parent = defaults
  if cfg.parent ~= nil then
    parent = presets[cfg.parent]
    assert(parent, "no such parent preset: "..tostring(cfg.parent))
    check_parent_cycle(name, cfg)
  end
  setmetatable(cfg, {__index=parent})
end
|
2018-04-02 06:21:55 -07:00
|
|
|
|
2018-06-16 20:48:15 -07:00
|
|
|
-- Resolve the requested preset; an unknown name is a hard error.
-- assert passes its first argument through when it is truthy.
local cfg = assert(presets[preset], "invalid preset: "..tostring(preset))
|
|
|
|
|
2018-05-02 04:06:28 -07:00
|
|
|
-- TODO: decide how playback_mode should behave when unperturbed_trial is false;
-- that combination is not checked by the assertions below.
|
|
|
|
|
2018-06-15 15:24:55 -07:00
|
|
|
-- Fill in any learning rates the chosen preset left unspecified:
--   param_rate <- base_rate under ARS, 1.0 for other strategies
--   sigma_rate <- base_rate
--   covar_rate <- sigma_rate (after sigma_rate itself has been filled)
do
  local function fill(key, fallback)
    if cfg[key] == nil then
      cfg[key] = fallback
    end
  end

  local param_fallback = 1.0
  if cfg.es == 'ars' then
    param_fallback = cfg.base_rate
  end

  fill('param_rate', param_fallback)
  fill('sigma_rate', cfg.base_rate)
  fill('covar_rate', cfg.sigma_rate)
end
|
|
|
|
|
2018-06-16 20:48:15 -07:00
|
|
|
-- Make the resolved config strict: reads fall back to the existing parent
-- chain, and reading a key declared nowhere raises an error. The few
-- optional keys below are allowed to be absent and read as nil.
do
  local fallback = getmetatable(cfg).__index
  local optional = {
    name = true,
    log_fn = true,
    params_fn = true,
    stats_fn = true,
  }

  setmetatable(cfg, {
    __index = function(_, key)
      if fallback ~= nil then
        local value = fallback[key]
        if value ~= nil then return value end
      end
      if optional[key] then return nil end
      -- level 2 blames the code that performed the bad lookup.
      error("cannot use undeclared config '" .. tostring(key) .. "'", 2)
    end,
  })
end
|
|
|
|
|
2018-04-02 06:21:55 -07:00
|
|
|
-- Settings derived from the resolved preset.
do
  -- Never keep more top trials than are run in an epoch.
  cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)

  -- eps parameters (presumably the eps used by det_epsilon — confirm):
  -- the start value scales with frameskip, and the stop value is a tenth of it.
  local eps_start = 1.0 * cfg.frameskip / 64
  cfg.eps_start = eps_start
  cfg.eps_stop = eps_start * 0.1
  cfg.eps_frames = 1000000

  -- Playable mode shows the overlay and disables the network.
  local playable = cfg.playable_mode
  cfg.enable_overlay = playable
  cfg.enable_network = not playable
end
|
|
|
|
|
2018-05-06 20:55:58 -07:00
|
|
|
-- Reject option combinations that are known not to work together.
if cfg.ars_lips then
  assert(cfg.unperturbed_trial,
    "cfg.unperturbed_trial must be true to use cfg.ars_lips")
  assert(cfg.negate_trials,
    "cfg.negate_trials must be true to use cfg.ars_lips")
end

if cfg.es == 'snes' then
  assert(not cfg.negate_trials,
    "cfg.negate_trials is not yet compatible with SNES")
end

return cfg
|