From 50a7ba78f9188e52b276fa599fc92b5cc9ec0458 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 12 Jun 2018 05:36:24 +0200 Subject: [PATCH] make filenames local to main --- main.lua | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/main.lua b/main.lua index 8db1d99..76dfa4b 100644 --- a/main.lua +++ b/main.lua @@ -7,6 +7,9 @@ local gcfg = require("gameconfig") -- state. +local params_fn +local std_fn + local epoch_i = 0 local base_params local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled. @@ -313,10 +316,9 @@ local function learn_from_epoch() if cfg.enable_network then network:distribute(base_params) - network:save(cfg.params_fn) + network:save(params_fn) if cfg.es == 'snes' then - local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt" local f = assert(open(std_fn, "w")) for _, v in ipairs(es.std) do f:write(("%f\n"):format(v)) end f:close() @@ -471,8 +473,11 @@ local function init() loadlevel(cfg.starting_world, cfg.starting_level) end - if exists(cfg.params_fn) then - network:load(cfg.params_fn) + params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param) + std_fn = params_fn:gsub(".txt", "")..".sigma.txt" + + if exists(params_fn) then + network:load(params_fn) end if cfg.es == 'xnes' then @@ -488,7 +493,6 @@ local function init() -- TODO: clean this up into an interface: es.mean_adapt = cfg.mean_adapt - local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt" if exists(std_fn) then local f = assert(open(std_fn, "r")) for i=1, network.n_param do @@ -496,7 +500,6 @@ local function init() end f:close() end - elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials)