diff --git a/config.lua b/config.lua index 5ccf87d..96c56ff 100644 --- a/config.lua +++ b/config.lua @@ -53,27 +53,28 @@ local common_cfg = { } local cfg = { - log_fn = 'logs-xnes.csv', - params_fn = 'params-xnes.txt', + log_fn = 'logs-snes.csv', + params_fn = 'params-snes.txt', decrement_reward = true, score_multiplier = 5, - starting_world = 1, + starting_world = 0, starting_level = 1, - starting_lives = 1, + --starting_lives = 1, cap_time = 300, - deterministic = true, + deterministic = false, --true, - epoch_trials = 50, + epoch_trials = 32, epoch_top_trials = 9999, - negate_trials = false, + negate_trials = true, - es = 'xnes', - learning_rate = 0.14, - deviation = 1.0, - weight_decay = 0.0, + es = 'snes', + learning_rate = 0.5, + mean_adapt = 0.5, + deviation = 0.5, + weight_decay = 0.025, } -- TODO: so, uhh.. diff --git a/main.lua b/main.lua index 83ec759..1b29113 100644 --- a/main.lua +++ b/main.lua @@ -148,10 +148,13 @@ local function make_network(input_size) nn_x = nn.Input({input_size}) nn_tx = nn.Input({gcfg.tile_count}) nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2)) - nn_tz = nn_ty:feed(nn.Reshape{13, 17 * 2}) - nn_tz = nn_tz:feed(nn.DenseBroadcast(5)) - nn_tz = nn_tz:feed(nn.Relu()) + + nn_tz = nn_ty + --nn_tz = nn_tz:feed(nn.Reshape{13, 17 * 2}) + --nn_tz = nn_tz:feed(nn.DenseBroadcast(5)) + --nn_tz = nn_tz:feed(nn.Relu()) -- note: due to a quirk in Merge, we don't need to flatten nn_tz. + nn_y = nn.Merge() nn_x:feed(nn_y) nn_tz:feed(nn_y)