diff --git a/main.lua b/main.lua index d69b734..a6f3e6c 100644 --- a/main.lua +++ b/main.lua @@ -675,7 +675,7 @@ local function learn_from_epoch() print("step stddev:", step_dev) --print("full step stddev:", cfg.learning_rate * step_dev) - local momstep_mean, momstep_dev + local momstep_mean, momstep_dev = 0, 0 if cfg.adamant then if mom1 == nil then mom1 = nn.zeros(#step) end if mom2 == nil then mom2 = nn.zeros(#step) end @@ -724,9 +724,18 @@ local function learn_from_epoch() weight_std = weight_std, } + -- trying a heuristic... + --[[ + if delta_std < trial_std then + cfg.deviation = cfg.deviation * 0.933 + -- this one might be bad... + cfg.weight_decay = cfg.weight_decay * 0.933 + end + --]] + if cfg.enable_network then network:distribute(base_params) - network:save() + network:save(cfg.params_fn) else print("note: not updating weights in playable mode.") end @@ -850,7 +859,7 @@ local function init() end --print(emu.framecount()) - local res, err = pcall(network.load, network) + local res, err = pcall(network.load, network, cfg.params_fn) if res == false then print(err) end end