allow weights/params file to be configured
This commit is contained in:
parent
90922a2bc3
commit
feaf86dc6b
1 changed files with 12 additions and 3 deletions
15
main.lua
15
main.lua
|
@ -675,7 +675,7 @@ local function learn_from_epoch()
|
|||
print("step stddev:", step_dev)
|
||||
--print("full step stddev:", cfg.learning_rate * step_dev)
|
||||
|
||||
local momstep_mean, momstep_dev
|
||||
local momstep_mean, momstep_dev = 0, 0
|
||||
if cfg.adamant then
|
||||
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||
|
@ -724,9 +724,18 @@ local function learn_from_epoch()
|
|||
weight_std = weight_std,
|
||||
}
|
||||
|
||||
-- trying a heuristic...
|
||||
--[[
|
||||
if delta_std < trial_std then
|
||||
cfg.deviation = cfg.deviation * 0.933
|
||||
-- this one might be bad...
|
||||
cfg.weight_decay = cfg.weight_decay * 0.933
|
||||
end
|
||||
--]]
|
||||
|
||||
if cfg.enable_network then
|
||||
network:distribute(base_params)
|
||||
network:save()
|
||||
network:save(cfg.params_fn)
|
||||
else
|
||||
print("note: not updating weights in playable mode.")
|
||||
end
|
||||
|
@ -850,7 +859,7 @@ local function init()
|
|||
end
|
||||
--print(emu.framecount())
|
||||
|
||||
local res, err = pcall(network.load, network)
|
||||
local res, err = pcall(network.load, network, cfg.params_fn)
|
||||
if res == false then print(err) end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in a new issue