-- Configuration module.
--
-- Picks a preset by name from the global `preset` (falling back to 'ars' when
-- it is nil or an empty string), layers the preset over `common_cfg`, fills in
-- rate defaults, derives a few values, sanity-checks the result, and returns
-- the final `cfg` table.
--
-- After setup, reading a key that is neither set on `cfg` nor declared in
-- `common_cfg` raises an error, except for `log_fn`, `params_fn` and
-- `stats_fn`, which simply read as nil when unset.

local preset = rawget(_G, 'preset')
preset = preset ~= nil and preset ~= '' and preset or 'ars'

-- defaults shared by every preset; presets override individual keys.
local common_cfg = {
    -- read-only modes:
    playable_mode = false,
    playback_mode = false,

    -- controlling in-game mechanics:
    starting_world = 1, -- set to 0 to randomize the world every epoch.
    starting_level = 1, -- set to 0 to randomize the level every epoch.
    start_big = false,
    starting_lives = 0,
    min_time = 100,
    max_time = 300,
    timer_loser = 1/2, -- decrement the timer when mario is slacking off.

    -- rewards:
    decrement_reward = false, -- bad idea, encourages mario to run into goombas.
    score_multiplier = 1, -- how much the ingame score influences our rewards.

    -- network initialization:
    init_zeros = false, -- instead of random normal noise.

    -- network inputs (connections):
    time_inputs = true, -- insert binary inputs of a frame counter.

    -- network layers:
    layernorm = false, -- (doesn't do anything right now)
    reduce_tiles = false,
    bias_out = true,

    -- network evaluation (sampling joypad):
    frameskip = 4,
    -- true greedy epsilon has both deterministic and det_epsilon set.
    deterministic = false, -- use argmax on outputs instead of random sampling.
    det_epsilon = false, -- take random actions with probability eps.

    -- evolution strategy and non-rate hyperparameters:
    es = 'ars',
    ars_lips = false, -- for ARS.
    epoch_top_trials = 9999, -- for ARS.

    -- sampling:
    deviation = 1.0,
    unperturbed_trial = true, -- perform an extra trial without any noise.
                              -- this is good for logging, so i'd recommend it.
    graycode = false, -- for ARS.
    negate_trials = true, -- try pairs of normal and negated noise directions.
                          -- AKA antithetic sampling.
                          -- note that this doubles the number of trials.
    min_refresh = 0.1, -- for SNES.

    -- epoch-related rates:
    base_rate = 1.0, -- param_rate, sigma_rate, and covar_rate will base on this
                     -- if you don't specify them individually.
                     -- (as written below: param_rate only does so for ARS,
                     -- and covar_rate falls back to sigma_rate.)
    param_decay = 0.0,
    sigma_decay = 0.0, -- for SNES, xNES.
}

local cfg

if preset == 'snes' then
    cfg = {
        es = 'snes',

        log_fn = 'logs-snes.csv',
        params_fn = 'params-snes.txt',

        starting_lives = 1,
        min_time = 300,
        timer_loser = 1.0,

        epoch_trials = 100,
        negate_trials = false,

        deviation = 1.0,
        min_refresh = 0.2,

        param_rate = 0.5,
        sigma_rate = 0.1,
        param_decay = 0.02, -- note: multiplied by its std, and param_rate.
        sigma_decay = 0.01, -- note: multiplied by sigma_rate.
    }

elseif preset == 'snes2' then
    cfg = {
        es = 'snes',

        log_fn = 'logs-snes2.csv',
        params_fn = 'params-snes2.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,

        score_multiplier = 0,

        init_zeros = true,

        reduce_tiles = true,
        bias_out = false,

        deterministic = false,

        deviation = 0.5,
        negate_trials = false,
        min_refresh = 0.5,

        epoch_trials = 100,

        param_rate = 1.0,
        sigma_rate = 0.01,
        param_decay = 0.02,
        sigma_decay = 0.01,
    }

elseif preset == 'xnes' then
    cfg = {
        es = 'xnes',

        log_fn = 'logs-xnes.csv',
        params_fn = 'params-xnes.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,

        score_multiplier = 0,

        init_zeros = true,

        reduce_tiles = true,
        bias_out = false,

        deterministic = false,

        deviation = 0.5,
        negate_trials = false,

        epoch_trials = 50,

        param_rate = 1.0,
        sigma_rate = 0.01,
        covar_rate = 0.01,
    }

elseif preset == 'xnes2' then
    cfg = {
        es = 'xnes',

        log_fn = 'logs-xnes2.csv',
        params_fn = 'params-xnes2.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,

        score_multiplier = 0,

        init_zeros = true,

        reduce_tiles = true,
        bias_out = false,

        deterministic = false,

        deviation = 1.0,
        negate_trials = true, --false,

        epoch_trials = 10, --50,

        param_rate = 0.5,
        sigma_rate = 0.04,
        covar_rate = 0.04,
        param_decay = 0.004,
        sigma_decay = 0.00128,
    }

elseif preset == 'ars' then
    cfg = {
        es = 'ars',
        epoch_top_trials = 20 * 2,
        ars_lips = false,

        log_fn = 'logs-ars.csv',
        params_fn = 'params-ars.txt',

        start_big = true,
        min_time = 300,
        timer_loser = 1.0,

        bias_out = false,

        deterministic = false,

        graycode = false,
        deviation = 0.1,
        negate_trials = false,

        epoch_trials = 25 * 2,

        param_rate = 1.0,
        param_decay = 0.0025,
    }

elseif preset == 'play' then
    cfg = {
        playable_mode = true,
    }

else
    error("invalid preset: "..tostring(preset))
end

-- TODO: so, uhh..
-- what happens when playback_mode is true but unperturbed_trial is false?

-- lenient lookup while defaults are filled in: unset keys read from
-- common_cfg, and keys declared nowhere (e.g. param_rate) read as nil.
setmetatable(cfg, {__index=common_cfg}) -- gets overridden later.

if cfg.es == 'ars' then
    if cfg.param_rate == nil then cfg.param_rate = cfg.base_rate end
else
    if cfg.param_rate == nil then cfg.param_rate = 1.0 end
end
if cfg.sigma_rate == nil then cfg.sigma_rate = cfg.base_rate end
if cfg.covar_rate == nil then cfg.covar_rate = cfg.sigma_rate end

-- strict lookup from here on: reading an undeclared key is an error,
-- reported at the caller's position (error level 2).
setmetatable(cfg, {
    __index = function(_, key)
        local default = common_cfg[key]
        if default ~= nil then return default end
        -- optional filenames: allowed to be absent.
        if key == 'log_fn' then return nil end
        if key == 'params_fn' then return nil end
        if key == 'stats_fn' then return nil end
        error("cannot use undeclared config '" .. tostring(key) .. "'", 2)
    end
})

-- clamp epoch_top_trials so it never exceeds epoch_trials.
-- epoch_trials is only declared by some presets (the 'play' preset has none),
-- so look it up without going through the strict __index above; otherwise
-- merely loading such a preset would raise "cannot use undeclared config".
local epoch_trials = rawget(cfg, 'epoch_trials')
if epoch_trials == nil then epoch_trials = common_cfg.epoch_trials end
if epoch_trials ~= nil then
    cfg.epoch_top_trials = math.min(epoch_trials, cfg.epoch_top_trials)
end

-- derived epsilon settings.
cfg.eps_start = 1.0 * cfg.frameskip / 64
cfg.eps_stop = 0.1 * cfg.eps_start
cfg.eps_frames = 1000000

-- the overlay and the network are mutually exclusive, keyed on playable_mode.
cfg.enable_overlay = cfg.playable_mode
cfg.enable_network = not cfg.playable_mode

-- sanity checks on option combinations.
assert(not cfg.ars_lips or cfg.unperturbed_trial,
       "cfg.unperturbed_trial must be true to use cfg.ars_lips")

assert(not cfg.ars_lips or cfg.negate_trials,
       "cfg.negate_trials must be true to use cfg.ars_lips")

assert(not (cfg.es == 'snes' and cfg.negate_trials),
       "cfg.negate_trials is not yet compatible with SNES")

return cfg