diff --git a/config.lua b/config.lua index b9307f0..2386e6c 100644 --- a/config.lua +++ b/config.lua @@ -22,10 +22,10 @@ local common_cfg = { init_zeros = false, -- instead of random normal noise. -- network inputs (connections): - time_inputs = true, -- binary inputs of global frame count + time_inputs = true, -- insert binary inputs of a frame counter. -- network layers: - layernorm = false, + layernorm = false, -- (doesn't do anything right now) reduce_tiles = false, bias_out = true, @@ -43,6 +43,7 @@ local common_cfg = { -- sampling: deviation = 1.0, unperturbed_trial = true, -- perform an extra trial without any noise. + -- this is good for logging, so i'd recommend it. graycode = false, -- for ARS. negate_trials = true, -- try pairs of normal and negated noise directions. -- AKA antithetic sampling. note that this doubles the number of trials. @@ -80,6 +81,39 @@ if preset == 'snes' then sigma_decay = 0.01, -- note: multiplied by learning_rate. } +elseif preset == 'snes2' then + + cfg = { + es = 'snes', + + log_fn = 'logs-snes2.csv', + params_fn = 'params-snes2.txt', + + start_big = true, + min_time = 300, + timer_loser = 1.0, + + score_multiplier = 0, + + init_zeros = true, + + reduce_tiles = true, + bias_out = false, + + deterministic = false, + + deviation = 0.5, + negate_trials = false, + min_refresh = 0.5, + + epoch_trials = 100, + + learning_rate = 0.01, + mean_adapt = 1.0, + weight_decay = 0.02, + sigma_decay = 0.01, + } + elseif preset == 'xnes' then cfg = { @@ -92,7 +126,6 @@ elseif preset == 'xnes' then min_time = 300, timer_loser = 1.0, - decrement_reward = false, score_multiplier = 0, init_zeros = true, @@ -104,32 +137,67 @@ elseif preset == 'xnes' then deviation = 0.5, negate_trials = false, - min_refresh = 0.1, epoch_trials = 50, learning_rate = 0.01, - mean_adapt = 1.0, - weight_decay = 0.0, - sigma_decay = 0.0, + } + +elseif preset == 'xnes2' then + + cfg = { + es = 'xnes', + + log_fn = 'logs-xnes2.csv', + params_fn = 'params-xnes2.txt', + + start_big = true, + min_time = 300, + timer_loser = 1.0, + + score_multiplier = 0, + + init_zeros = true, + + reduce_tiles = true, + bias_out = false, + + deterministic = false, + + deviation = 0.5, + negate_trials = true, + + epoch_trials = 25, + + learning_rate = 0.01, + mean_adapt = 0.5, + weight_decay = 0.01, + sigma_decay = 0.0016, --0.00128, } elseif preset == 'ars' then cfg = { es = 'ars', - epoch_top_trials = 20, + epoch_top_trials = 20 * 2, ars_lips = false, log_fn = 'logs-ars.csv', params_fn = 'params-ars.txt', + start_big = true, min_time = 300, timer_loser = 1.0, - deviation = 0.1, + bias_out = false, - epoch_trials = 25, + deterministic = false, + + graycode = false, + deviation = 0.1, + negate_trials = false, + + epoch_trials = 25 * 2, learning_rate = 1.0, weight_decay = 0.0025,