refactor ARS out of main (breaks a bunch of stuff)
This commit is contained in:
parent d3e6441c40
commit fe9494b0d5
4 changed files with 269 additions and 196 deletions
@ -3,6 +3,7 @@ please be mindful when sharing it.
however, feel free to copy any snippets of code you find useful.

TODOs: (that i can remember right now)
- normalize `for i=a,b` code style
- normalize and/or embed sprite type inputs
- settle on a network architecture
- compute how many input neurons the network needs instead of hardcoding

209
ars.lua
Normal file
@ -0,0 +1,209 @@
-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks (lips) by myself.
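For reference, the update this file works toward is the basic ARS scheme from the paper above (my paraphrase, not part of the commit): perturb the parameters θ along random directions δ, roll out the policy at θ + σδ and θ − σδ, keep the top-b directions by reward, and step along the reward-weighted sum of their noise, scaled by the reward deviation:

% ARS update, roughly in the notation of Mania et al. (2018):
\theta \leftarrow \theta + \frac{\alpha}{b\,\sigma_R} \sum_{k=1}^{b} \left[ r(\theta + \sigma\delta_{(k)}) - r(\theta - \sigma\delta_{(k)}) \right] \delta_{(k)}

Ars:tell below follows this shape, with reward_dev playing the role of σ_R and the kinda_lipschitz scaling swapped in when an unperturbed score is supplied.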

local abs = math.abs
local exp = math.exp
local floor = math.floor
local ipairs = ipairs
local log = math.log
local max = math.max
local pow = math.pow
local print = print
local sqrt = math.sqrt

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev

local Ars = Base:extend()

local function collect_best_indices(scored, top, antithetic)
    -- select one (the best) reward of each pos/neg pair.
    local best_rewards
    if antithetic then
        best_rewards = {}
        for i = 1, #scored, 2 do
            local ind = floor(i / 2) + 1
            local pos = scored[i + 0]
            local neg = scored[i + 1]
            best_rewards[ind] = max(pos, neg)
        end
    else
        best_rewards = scored
    end

    local indices = argsort(best_rewards, function(a, b) return a > b end)

    for i = top + 1, #best_rewards do indices[i] = nil end
    return indices
end
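Note the layout this assumes (my reading of the code, not stated in the commit): with antithetic sampling, scored interleaves each direction's positive and negative rewards, so direction d occupies slots 2d-1 and 2d. A tiny hypothetical example:

-- two directions, antithetic, keep only the best one:
local scored = { 10, 4,   -- direction 1: +noise scored 10, -noise scored 4
                 -2, 7 }  -- direction 2: +noise scored -2, -noise scored 7
-- best_rewards becomes { max(10, 4), max(-2, 7) } = { 10, 7 },
-- so collect_best_indices(scored, 1, true) returns { 1 }.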

local function kinda_lipschitz(dir, pos, neg, mid)
    local _, dev = calc_mean_dev(dir)
    local c0 = neg - mid
    local c1 = pos - mid
    local l0 = abs(3 * c1 + c0)
    local l1 = abs(c1 + 3 * c0)
    return max(l0, l1) / (2 * dev)
end
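As far as I can tell (my reading, not documented in the commit), this treats the three rewards as samples of a quadratic in the step scale, f(-1) = neg, f(0) = mid, f(+1) = pos, and returns the larger endpoint slope of that fit, normalized by the deviation of the direction vector:

% with c_1 = pos - mid and c_0 = neg - mid, the quadratic through the three points
% has endpoint slopes f'(+1) = (3 c_1 + c_0) / 2 and f'(-1) = -(c_1 + 3 c_0) / 2,
% so the value returned is
\frac{\max\left( \lvert 3 c_1 + c_0 \rvert,\; \lvert c_1 + 3 c_0 \rvert \right)}{2\,\mathrm{dev}}

i.e. a rough local Lipschitz estimate, used in tell to rescale the pos - neg reward difference.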

local function amsgrad(step) -- in-place! -- TODO: fix this.
    if mom1 == nil then mom1 = nn.zeros(#step) end
    if mom2 == nil then mom2 = nn.zeros(#step) end
    if mom2max == nil then mom2max = nn.zeros(#step) end

    local b1_t = pow(cfg.adam_b1, epoch_i)
    local b2_t = pow(cfg.adam_b2, epoch_i)

    -- NOTE: with LuaJIT, splitting this loop would
    -- almost certainly be faster.
    for i, v in ipairs(step) do
        mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
        mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
        mom2max[i] = max(mom2[i], mom2max[i])
        if cfg.adam_debias then
            local num = (mom1[i] / (1 - b1_t))
            local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
            step[i] = num / den
        else
            step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
        end
    end
end
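This is the standard AMSGrad recurrence (Reddi et al., 2018) applied element-wise to the step vector, with adam_debias adding the usual Adam bias correction:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \qquad
\hat{v}_t = \max(\hat{v}_{t-1}, v_t) \qquad
\mathrm{step}_t = \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon}

The function only shapes the step; the learning rate is applied by the caller.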

function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false

    self.poptop = poptop or popsize
    assert(self.poptop <= popsize)
    if self.antithetic then self.popsize = self.popsize * 2 end

    self._params = nn.zeros(self.dims)
end

function Ars:params(new_params)
    if new_params ~= nil then
        assert(#self._params == #new_params, "new parameters have the wrong size")
        for i, v in ipairs(new_params) do self._params[i] = v end
    end
    return self._params
end

function Ars:ask(graycode)
    local asked = {}
    local noise = {}

    for i = 1, self.popsize do
        local asking = zeros(self.dims)
        local noisy = zeros(self.dims)
        asked[i] = asking

        if self.antithetic and i % 2 == 0 then
            for j, v in ipairs(self._params) do
                asking[i] = v - noisy[j]
            end

        else
            if graycode ~= nil then
                for j = 1, self.dims do
                    noisy[j] = exp(-precision * nn.uniform())
                end
                for j = 1, self.dims do
                    noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j]
                end
            else
                for j = 1, self.dims do
                    noisy[j] = self.sigma * nn.normal()
                end
            end

            for j, v in ipairs(self._params) do
                asking[j] = v + noisy[j]
            end
        end

        noise[i] = noisy
    end

    self.noise = noise
    return asked, noise
end

function Ars:tell(scored, unperturbed_score)
    local indices = collect_best_indices(scored, self.poptop, self.antithetic)
    --print("best trials:", indices)

    local top_rewards = {}
    for i = 1, #scored do top_rewards[i] = 0 end
    for _, ind in ipairs(indices) do
        local sind = (ind - 1) * 2 + 1
        top_rewards[sind + 0] = scored[sind + 0]
        top_rewards[sind + 1] = scored[sind + 1]
    end
    --print("top:", top_rewards)

    if self.antithetic then
        local top_delta_rewards = {} -- only used for printing.
        for i, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
        end
        --print("best deltas:", top_delta_rewards)
    end

    local step = nn.zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    if reward_dev == 0 then reward_dev = 1 end

    if self.antithetic then
        for i = 1, floor(self.popsize / 2) do
            local ind = (i - 1) * 2 + 1
            local pos = top_rewards[ind + 0]
            local neg = top_rewards[ind + 1]
            local reward = pos - neg
            if reward ~= 0 then
                local noisy = self.noise[i]

                if unperturbed_score ~= nil then
                    local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
                    reward = reward / lips / self.sigma
                else
                    reward = reward / reward_dev
                end

                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    else
        for i = 1, self.popsize do
            local reward = top_rewards[i] / reward_dev
            if reward ~= 0 then
                local noisy = self.noise[i]

                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    end

    for i, v in ipairs(self._params) do
        self._params[i] = v + self.learning_rate * step[i]
    end

    self.asked = nil
end

return {
    Ars = Ars,
}
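A minimal sketch of how I'd expect the new module to be driven (hypothetical driver code, not part of this commit, and the commit message itself warns things are still broken; evaluate is a stand-in for a full emulator rollout):

local ars = require("ars")

-- stand-in reward: how close the parameters are to an arbitrary target.
local function evaluate(params)
    local r = 0
    for _, v in ipairs(params) do r = r - math.abs(v - 0.5) end
    return r
end

-- 100 parameters, 8 directions, keep the best 4, antithetic sampling.
local es = ars.Ars(100, 8, 4, 0.01, 0.05, true)

for epoch = 1, 10 do
    local asked = es:ask()            -- 16 candidate parameter vectors (8 +/- pairs)
    local scored = {}
    for i, params in ipairs(asked) do
        scored[i] = evaluate(params)
    end
    es:tell(scored)                   -- updates es:params() in place
end

main.lua's init() constructs it the same way, passing network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation and cfg.negate_trials.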
@ -28,10 +28,11 @@ local common_cfg = {
graycode = false,
unperturbed_trial = true, -- do a trial without any noise.
negate_trials = true, -- try pairs of normal and negated noise directions.
-- ^ note that this now doubles the effective trials.
-- AKA antithetic sampling. note that this doubles the number of trials.
time_inputs = true, -- binary inputs of global frame count
normalize_inputs = false,

es = 'ars',
ars_lips = false,
adamant = false, -- run steps through AMSgrad.
adam_b1 = math.pow(10, -1 / 1), -- fewer trials, more momentum!
@ -90,4 +91,7 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
assert(not cfg.ars_lips or cfg.negate_trials,
"cfg.negate_trials must be true to use cfg.ars_lips")

assert(not cfg.adamant,
"cfg.adamant not yet re-implemented")

return cfg
247
main.lua
@ -11,12 +11,13 @@ local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_noise = {}
local trial_params --= {}
local trial_rewards = {}
local trials_remaining = 0
local mom1 -- first moments in AMSgrad.
local mom2 -- second moments in AMSgrad.
local mom2max -- running element-wise maximum of mom2.
local es -- evolution strategy.

local trial_frames = 0
local total_frames = 0
@ -35,7 +36,6 @@ local jp

local screen_scroll_delta
local reward
--local all_rewards = {}

local powerup_old
local status_old
@ -172,191 +172,50 @@ end

-- learning and evaluation.

local ars = require("ars")

local function prepare_epoch()
trial_neg = false

base_params = network:collect()
if cfg.playback_mode then return end

print('preparing epoch '..tostring(epoch_i)..'.')
empty(trial_noise)
empty(trial_rewards)

local precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
local precision
if cfg.graycode then
precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
print(("chosen precision: %.2f"):format(precision))
end

for i = 1, cfg.epoch_trials do
local noise = nn.zeros(#base_params)
if cfg.graycode then
for j = 1, #base_params do
noise[j] = exp(-precision * nn.uniform())
end
for j = 1, #base_params do
noise[j] = nn.uniform() < 0.5 and noise[j] or -noise[j]
end
local dummy
if es == 'ars' then
trial_params, dummy = es:ask(precision)
else
for j = 1, #base_params do
noise[j] = cfg.deviation * nn.normal()
end
end
trial_noise[i] = noise
trial_params, dummy = es:ask()
end

trial_i = -1
end

local function load_next_pair()
trial_i = trial_i + 1
if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1
trial_neg = true
end

local W = copy(base_params)

if trial_i > 0 then
if trial_neg then
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v + noise[i]
end

else
trial_i = trial_i - 1
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v - noise[i]
end
end

trial_neg = not trial_neg
end

network:distribute(W)
end

local function load_next_trial()
if cfg.negate_trials then return load_next_pair() end
if cfg.negate_trials then trial_neg = not trial_neg end
trial_i = trial_i + 1
local W = copy(base_params)
if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1
end
if trial_i > 0 then
print('loading trial', trial_i)
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v + noise[i]
end
--print('loading trial', trial_i)
network:distribute(trial_params[trial_i])
else
print("test trial")
end
network:distribute(W)
end

local function collect_best_indices()
-- select one (the best) reward of each pos/neg pair.
local best_rewards = {}
if cfg.negate_trials then
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
best_rewards[i] = max(pos, neg)
end
else
best_rewards = copy(trial_rewards)
end

local indices = argsort(best_rewards, function(a, b) return a > b end)

for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
return indices
end

local function kinda_lipschitz(dir, pos, neg, mid)
local _, dev = calc_mean_dev(dir)
local c0 = neg - mid
local c1 = pos - mid
local l0 = abs(3 * c1 + c0)
local l1 = abs(c1 + 3 * c0)
return max(l0, l1) / (2 * dev)
end

local function make_step_paired(rewards, current_cost)
local step = nn.zeros(#base_params)
local _, reward_dev = calc_mean_dev(rewards)
if reward_dev == 0 then reward_dev = 1 end

for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = rewards[ind + 0]
local neg = rewards[ind + 1]
local reward = pos - neg
if reward ~= 0 then
local noise = trial_noise[i]

if cfg.ars_lips then
local lips = kinda_lipschitz(noise, pos, neg, current_cost)
reward = reward / lips / cfg.deviation
else
reward = reward / reward_dev
end

for j, v in ipairs(noise) do
step[j] = step[j] + reward * v / cfg.epoch_top_trials
end
end
end
return step
end

local function make_step(rewards)
local step = nn.zeros(#base_params)
local _, reward_dev = calc_mean_dev(rewards)
if reward_dev == 0 then reward_dev = 1 end

for i = 1, cfg.epoch_trials do
local reward = rewards[i] / reward_dev
if reward ~= 0 then
local noise = trial_noise[i]

for j, v in ipairs(noise) do
step[j] = step[j] + reward * v / cfg.epoch_top_trials
end
end
end
return step
end

local function amsgrad(step) -- in-place!
if mom1 == nil then mom1 = nn.zeros(#step) end
if mom2 == nil then mom2 = nn.zeros(#step) end
if mom2max == nil then mom2max = nn.zeros(#step) end

local b1_t = pow(cfg.adam_b1, epoch_i)
local b2_t = pow(cfg.adam_b2, epoch_i)

-- NOTE: with LuaJIT, splitting this loop would
-- almost certainly be faster.
for i, v in ipairs(step) do
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
mom2max[i] = max(mom2[i], mom2max[i])
if cfg.adam_debias then
local num = (mom1[i] / (1 - b1_t))
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
step[i] = num / den
else
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
end
--print("test trial")
network:distribute(base_params)
end
end

local function learn_from_epoch()
print()
--print('rewards:', trial_rewards)

--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end

local current_cost = trial_rewards[0] -- may be nil!

@ -369,58 +228,45 @@ local function learn_from_epoch()

local delta_rewards = {} -- only used for logging.
if cfg.negate_trials then
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
delta_rewards[i] = abs(pos - neg)
for i = 1, #trial_rewards, 2 do
local ind = floor(i / 2) + 1
local pos = trial_rewards[i + 0]
local neg = trial_rewards[i + 1]
delta_rewards[ind] = abs(pos - neg)
end
end

local indices = collect_best_indices()
print("best trials:", indices)

local top_rewards = {}
for i = 1, #trial_rewards do top_rewards[i] = 0 end
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = trial_rewards[sind + 0]
top_rewards[sind + 1] = trial_rewards[sind + 1]
end
--print("top:", top_rewards)

if cfg.negate_trials then
local top_delta_rewards = {} -- only used for printing.
for i, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
end
print("best deltas:", top_delta_rewards)
end

local step
if cfg.negate_trials then
step = make_step_paired(top_rewards, current_cost)
if es == 'ars' then
es:tell(trial_rewards, current_cost)
else
step = make_step(top_rewards)
es:tell(trial_rewards)
end

local step_mean, step_dev = 0, 0
--[[ TODO
local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean)
print("step stddev:", step_dev)
--]]

local momstep_mean, momstep_dev = 0, 0
--[[ TODO
if cfg.adamant then
amsgrad(step)
momstep_mean, momstep_dev = calc_mean_dev(step)
print("amsgrad mean:", momstep_mean)
print("amsgrad stddev:", momstep_dev)
end
--]]

base_params = es:params()

for i, v in ipairs(base_params) do
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
base_params[i] = v * (1 - cfg.weight_decay)
end

es:params(base_params)

local trial_mean, trial_std = calc_mean_dev(trial_rewards)
local delta_mean, delta_std = calc_mean_dev(delta_rewards)
local weight_mean, weight_std = calc_mean_dev(base_params)
@ -465,6 +311,7 @@ local function joypad_mash(button)
end

local function loadlevel(world, level)
-- TODO: move to smb.lua. rename to load_level.
if world == 0 then world = random(1, 8) end
if level == 0 then level = random(1, 4) end
emu.poweron()
@ -499,7 +346,8 @@ local function do_reset()
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(trial_i, pos, neg, last_trial_state, state))
print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
end
last_trial_state = state
else
@ -517,7 +365,7 @@ local function do_reset()
end
end

if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
@ -527,6 +375,11 @@ local function do_reset()
end
end

max_time = min(6 * sqrt(480 / #trial_params * (epoch_i - 1)) + 60, cfg.cap_time)
max_time = ceil(max_time)

-- TODO: game.reset(cfg.starting_lives, cfg.start_big)

if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
reward = 0
powerup_old = game.R(0x754)
@ -543,8 +396,7 @@ local function do_reset()
game.W(0x756, 1)
end

max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
max_time = ceil(max_time)
-- end of game.reset()

if state_saved then
savestate.load(startsave)
@ -585,6 +437,13 @@ local function init()

local res, err = pcall(network.load, network, cfg.params_fn)
if res == false then print(err) end

if cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
else
error("Unknown evolution strategy specified: " .. tostring(cfg.es))
end
end

local function prepare_reset()