make filenames local to main
This commit is contained in:
parent
4a09280be4
commit
50a7ba78f9
1 changed files with 9 additions and 6 deletions
15
main.lua
15
main.lua
|
@ -7,6 +7,9 @@ local gcfg = require("gameconfig")
|
||||||
|
|
||||||
-- state.
|
-- state.
|
||||||
|
|
||||||
|
local params_fn
|
||||||
|
local std_fn
|
||||||
|
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
local base_params
|
local base_params
|
||||||
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
||||||
|
@ -313,10 +316,9 @@ local function learn_from_epoch()
|
||||||
|
|
||||||
if cfg.enable_network then
|
if cfg.enable_network then
|
||||||
network:distribute(base_params)
|
network:distribute(base_params)
|
||||||
network:save(cfg.params_fn)
|
network:save(params_fn)
|
||||||
|
|
||||||
if cfg.es == 'snes' then
|
if cfg.es == 'snes' then
|
||||||
local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt"
|
|
||||||
local f = assert(open(std_fn, "w"))
|
local f = assert(open(std_fn, "w"))
|
||||||
for _, v in ipairs(es.std) do f:write(("%f\n"):format(v)) end
|
for _, v in ipairs(es.std) do f:write(("%f\n"):format(v)) end
|
||||||
f:close()
|
f:close()
|
||||||
|
@ -471,8 +473,11 @@ local function init()
|
||||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||||
end
|
end
|
||||||
|
|
||||||
if exists(cfg.params_fn) then
|
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
|
||||||
network:load(cfg.params_fn)
|
std_fn = params_fn:gsub(".txt", "")..".sigma.txt"
|
||||||
|
|
||||||
|
if exists(params_fn) then
|
||||||
|
network:load(params_fn)
|
||||||
end
|
end
|
||||||
|
|
||||||
if cfg.es == 'xnes' then
|
if cfg.es == 'xnes' then
|
||||||
|
@ -488,7 +493,6 @@ local function init()
|
||||||
-- TODO: clean this up into an interface:
|
-- TODO: clean this up into an interface:
|
||||||
es.mean_adapt = cfg.mean_adapt
|
es.mean_adapt = cfg.mean_adapt
|
||||||
|
|
||||||
local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt"
|
|
||||||
if exists(std_fn) then
|
if exists(std_fn) then
|
||||||
local f = assert(open(std_fn, "r"))
|
local f = assert(open(std_fn, "r"))
|
||||||
for i=1, network.n_param do
|
for i=1, network.n_param do
|
||||||
|
@ -496,7 +500,6 @@ local function init()
|
||||||
end
|
end
|
||||||
f:close()
|
f:close()
|
||||||
end
|
end
|
||||||
|
|
||||||
elseif cfg.es == 'ars' then
|
elseif cfg.es == 'ars' then
|
||||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
|
|
Loading…
Add table
Reference in a new issue