diff --git a/config.lua b/config.lua index 67323fe..9e5d4e9 100644 --- a/config.lua +++ b/config.lua @@ -1,3 +1,5 @@ +local preset = 'snes' + local function approx_cossim(dim) return math.pow(1.521 * dim - 0.521, -0.5026) end @@ -13,75 +15,107 @@ local function intmap(x) end local common_cfg = { + -- read-only modes: playable_mode = false, playback_mode = false, + + -- controlling in-game mechanics: + starting_world = 1, -- set to 0 to randomize the world every epoch. + starting_level = 1, -- set to 0 to randomize the level every epoch. start_big = false, starting_lives = 0, + min_time = 100, + max_time = 300, + timer_loser = 1/2, -- decrement the timer when mario is slacking off. + -- rewards: + decrement_reward = false, -- bad idea, encourages mario to run into goombas. + score_multiplier = 1, -- how much the ingame score influences our rewards. + + -- network initialization: + init_zeros = false, -- instead of random normal noise. + + -- network inputs (connections): + time_inputs = true, -- binary inputs of global frame count + + -- network layers: + layernorm = false, + + -- network evaluation (sampling joypad): frameskip = 4, -- true greedy epsilon has both deterministic and det_epsilon set. deterministic = false, -- use argmax on outputs instead of random sampling. det_epsilon = false, -- take random actions with probability eps. - layernorm = false, - init_zeros = true, -- instead of he_normal noise or whatever. - graycode = false, - unperturbed_trial = true, -- do a trial without any noise. + -- evolution strategy and non-rate hyperparemeters: + es = 'ars', + ars_lips = false, -- for ARS. + epoch_top_trials = 9999, -- for ARS. + + -- sampling: + unperturbed_trial = true, -- perform an extra trial without any noise. + graycode = false, -- for ARS. negate_trials = true, -- try pairs of normal and negated noise directions. -- AKA antithetic sampling. note that this doubles the number of trials. - time_inputs = true, -- binary inputs of global frame count - normalize_inputs = false, + min_refresh = 0.1, -- for SNES. + -- epoch-related rates: learning_rate = 1.0, - mean_adapt = 1.0, -- for xNES + mean_adapt = 1.0, -- for xNES. weight_decay = 0.0, - sigma_decay = 0.0, - min_refresh = 0.2, - - es = 'ars', - ars_lips = false, - adamant = false, -- run steps through AMSgrad. - adam_b1 = math.pow(10, -1 / 1), -- fewer trials, more momentum! - adam_b2 = math.pow(10, -1 / 50), - adam_eps = intmap(-1), -- focus on b1 rather than b2. - adam_debias = true, - - min_time = 100, - max_time = 300, - timer_loser = 1/2, - decrement_reward = false, -- bad idea, encourages mario to run into goombas. - score_multiplier = 1, -- how much the ingame score influences our rewards. - - starting_world = 1, -- set to 0 for random! - starting_level = 1, -- set to 0 for random! + sigma_decay = 0.0, -- for SNES. } -local cfg = { - log_fn = 'logs-snes.csv', - params_fn = 'params-snes.txt', +local cfg +if preset == 'snes' then - decrement_reward = true, - score_multiplier = 5, + cfg = { + es = 'snes', - starting_world = 0, - starting_level = 1, - --starting_lives = 1, - min_time = 300, - max_time = 300, + log_fn = 'logs-snes.csv', + params_fn = 'params-snes.txt', - deterministic = false, --true, + starting_lives = 1, + min_time = 300, + timer_loser = 1.0, - epoch_trials = 32, - epoch_top_trials = 9999, - negate_trials = true, + epoch_trials = 100, + negate_trials = false, - es = 'snes', - learning_rate = 0.5, - mean_adapt = 0.5, - deviation = 0.5, - weight_decay = 0.025, - sigma_decay = 0.001, -} + min_refresh = 0.2, + + learning_rate = 0.1, -- TODO: rename to learn_primary or something. + mean_adapt = 0.5, -- TODO: rename to learn_secondary or something. + deviation = 1.0, + weight_decay = 0.01, -- note: multiplied by its std, and mean_adapt. + sigma_decay = 0.01, -- note: multiplied by learning_rate. + } + +elseif preset == 'ars' then + + cfg = { + es = 'ars', + epoch_top_trials = 20, + ars_lips = false, + + log_fn = 'logs-ars.csv', + params_fn = 'params-ars.txt', + + min_time = 300, + timer_loser = 1.0, + + epoch_trials = 25, + + learning_rate = 1.0, + deviation = 0.1, + weight_decay = 0.0025, + } + +else + + error("invalid preset: "..tostring(preset)) + +end -- TODO: so, uhh.. -- what happens when playback_mode is true but unperturbed_trial is false? @@ -108,8 +142,7 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial, "cfg.unperturbed_trial must be true to use cfg.ars_lips") assert(not cfg.ars_lips or cfg.negate_trials, "cfg.negate_trials must be true to use cfg.ars_lips") - -assert(not cfg.adamant, - "cfg.adamant not yet re-implemented") +assert(not (cfg.es == 'snes' and cfg.negate_trials), + "cfg.negate_trials is not yet compatible with SNES") return cfg