refactor ARS out of main (breaks a bunch of stuff)

This commit is contained in:
Connor Olding 2018-06-09 17:56:18 +02:00
parent d3e6441c40
commit fe9494b0d5
4 changed files with 269 additions and 196 deletions

View file

@ -3,6 +3,7 @@ please be mindful when sharing it.
however, feel free to copy any snippets of code you find useful. however, feel free to copy any snippets of code you find useful.
TODOs: (that i can remember right now) TODOs: (that i can remember right now)
- normalize `for i=a,b` code style
- normalize and/or embed sprite type inputs - normalize and/or embed sprite type inputs
- settle on a network architecture - settle on a network architecture
- compute how many input neurons the network needs instead of hardcoding - compute how many input neurons the network needs instead of hardcoding

209
ars.lua Normal file
View file

@ -0,0 +1,209 @@
-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks (lips) by myself.
local abs = math.abs
local floor = math.floor
local ipairs = ipairs
local max = math.max
local print = print
local Base = require "Base"
local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local zeros = nn.zeros
local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local Ars = Base:extend()
local function collect_best_indices(scored, top, antithetic)
-- select one (the best) reward of each pos/neg pair.
local best_rewards
if antithetic then
best_rewards = {}
for i = 1, #scored, 2 do
local ind = floor(i / 2) + 1
local pos = scored[i + 0]
local neg = scored[i + 1]
best_rewards[ind] = max(pos, neg)
end
else
best_rewards = scored
end
local indices = argsort(best_rewards, function(a, b) return a > b end)
for i = top + 1, #best_rewards do indices[i] = nil end
return indices
end
local function kinda_lipschitz(dir, pos, neg, mid)
local _, dev = calc_mean_dev(dir)
local c0 = neg - mid
local c1 = pos - mid
local l0 = abs(3 * c1 + c0)
local l1 = abs(c1 + 3 * c0)
return max(l0, l1) / (2 * dev)
end
local function amsgrad(step) -- in-place! -- TODO: fix this.
if mom1 == nil then mom1 = nn.zeros(#step) end
if mom2 == nil then mom2 = nn.zeros(#step) end
if mom2max == nil then mom2max = nn.zeros(#step) end
local b1_t = pow(cfg.adam_b1, epoch_i)
local b2_t = pow(cfg.adam_b2, epoch_i)
-- NOTE: with LuaJIT, splitting this loop would
-- almost certainly be faster.
for i, v in ipairs(step) do
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
mom2max[i] = max(mom2[i], mom2max[i])
if cfg.adam_debias then
local num = (mom1[i] / (1 - b1_t))
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
step[i] = num / den
else
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
end
end
end
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
self.poptop = poptop or popsize
assert(self.poptop <= popsize)
if self.antithetic then self.popsize = self.popsize * 2 end
self._params = nn.zeros(self.dims)
end
function Ars:params(new_params)
if new_params ~= nil then
assert(#self._params == #new_params, "new parameters have the wrong size")
for i, v in ipairs(new_params) do self._params[i] = v end
end
return self._params
end
function Ars:ask(graycode)
local asked = {}
local noise = {}
for i = 1, self.popsize do
local asking = zeros(self.dims)
local noisy = zeros(self.dims)
asked[i] = asking
if self.antithetic and i % 2 == 0 then
for j, v in ipairs(self._params) do
asking[i] = v - noisy[j]
end
else
if graycode ~= nil then
for j = 1, self.dims do
noisy[j] = exp(-precision * nn.uniform())
end
for j = 1, self.dims do
noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j]
end
else
for j = 1, self.dims do
noisy[j] = self.sigma * nn.normal()
end
end
for j, v in ipairs(self._params) do
asking[j] = v + noisy[j]
end
end
noise[i] = noisy
end
self.noise = noise
return asked, noise
end
function Ars:tell(scored, unperturbed_score)
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
--print("best trials:", indices)
local top_rewards = {}
for i = 1, #scored do top_rewards[i] = 0 end
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = scored[sind + 0]
top_rewards[sind + 1] = scored[sind + 1]
end
--print("top:", top_rewards)
if self.antithetic then
local top_delta_rewards = {} -- only used for printing.
for i, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
end
--print("best deltas:", top_delta_rewards)
end
local step = nn.zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
if reward_dev == 0 then reward_dev = 1 end
if self.antithetic then
for i = 1, floor(self.popsize / 2) do
local ind = (i - 1) * 2 + 1
local pos = top_rewards[ind + 0]
local neg = top_rewards[ind + 1]
local reward = pos - neg
if reward ~= 0 then
local noisy = self.noise[i]
if unperturbed_score ~= nil then
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
reward = reward / lips / self.sigma
else
reward = reward / reward_dev
end
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
end
end
end
else
for i = 1, self.popsize do
local reward = top_rewards[i] / reward_dev
if reward ~= 0 then
local noisy = self.noise[i]
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
end
end
end
end
for i, v in ipairs(self._params) do
self._params[i] = v + self.learning_rate * step[i]
end
self.asked = nil
end
return {
Ars = Ars,
}

View file

@ -28,10 +28,11 @@ local common_cfg = {
graycode = false, graycode = false,
unperturbed_trial = true, -- do a trial without any noise. unperturbed_trial = true, -- do a trial without any noise.
negate_trials = true, -- try pairs of normal and negated noise directions. negate_trials = true, -- try pairs of normal and negated noise directions.
-- ^ note that this now doubles the effective trials. -- AKA antithetic sampling. note that this doubles the number of trials.
time_inputs = true, -- binary inputs of global frame count time_inputs = true, -- binary inputs of global frame count
normalize_inputs = false, normalize_inputs = false,
es = 'ars',
ars_lips = false, ars_lips = false,
adamant = false, -- run steps through AMSgrad. adamant = false, -- run steps through AMSgrad.
adam_b1 = math.pow(10, -1 / 1), -- fewer trials, more momentum! adam_b1 = math.pow(10, -1 / 1), -- fewer trials, more momentum!
@ -90,4 +91,7 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
assert(not cfg.ars_lips or cfg.negate_trials, assert(not cfg.ars_lips or cfg.negate_trials,
"cfg.negate_trials must be true to use cfg.ars_lips") "cfg.negate_trials must be true to use cfg.ars_lips")
assert(not cfg.adamant,
"cfg.adamant not yet re-implemented")
return cfg return cfg

247
main.lua
View file

@ -11,12 +11,13 @@ local epoch_i = 0
local base_params local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled. local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true local trial_neg = true
local trial_noise = {} local trial_params --= {}
local trial_rewards = {} local trial_rewards = {}
local trials_remaining = 0 local trials_remaining = 0
local mom1 -- first moments in AMSgrad. local mom1 -- first moments in AMSgrad.
local mom2 -- second moments in AMSgrad. local mom2 -- second moments in AMSgrad.
local mom2max -- running element-wise maximum of mom2. local mom2max -- running element-wise maximum of mom2.
local es -- evolution strategy.
local trial_frames = 0 local trial_frames = 0
local total_frames = 0 local total_frames = 0
@ -35,7 +36,6 @@ local jp
local screen_scroll_delta local screen_scroll_delta
local reward local reward
--local all_rewards = {}
local powerup_old local powerup_old
local status_old local status_old
@ -172,191 +172,50 @@ end
-- learning and evaluation. -- learning and evaluation.
local ars = require("ars")
local function prepare_epoch() local function prepare_epoch()
trial_neg = false
base_params = network:collect() base_params = network:collect()
if cfg.playback_mode then return end if cfg.playback_mode then return end
print('preparing epoch '..tostring(epoch_i)..'.') print('preparing epoch '..tostring(epoch_i)..'.')
empty(trial_noise)
empty(trial_rewards) empty(trial_rewards)
local precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392 local precision
if cfg.graycode then if cfg.graycode then
precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
print(("chosen precision: %.2f"):format(precision)) print(("chosen precision: %.2f"):format(precision))
end end
for i = 1, cfg.epoch_trials do local dummy
local noise = nn.zeros(#base_params) if es == 'ars' then
if cfg.graycode then trial_params, dummy = es:ask(precision)
for j = 1, #base_params do
noise[j] = exp(-precision * nn.uniform())
end
for j = 1, #base_params do
noise[j] = nn.uniform() < 0.5 and noise[j] or -noise[j]
end
else else
for j = 1, #base_params do trial_params, dummy = es:ask()
noise[j] = cfg.deviation * nn.normal()
end
end
trial_noise[i] = noise
end end
trial_i = -1 trial_i = -1
end end
local function load_next_pair()
trial_i = trial_i + 1
if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1
trial_neg = true
end
local W = copy(base_params)
if trial_i > 0 then
if trial_neg then
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v + noise[i]
end
else
trial_i = trial_i - 1
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v - noise[i]
end
end
trial_neg = not trial_neg
end
network:distribute(W)
end
local function load_next_trial() local function load_next_trial()
if cfg.negate_trials then return load_next_pair() end if cfg.negate_trials then trial_neg = not trial_neg end
trial_i = trial_i + 1 trial_i = trial_i + 1
local W = copy(base_params)
if trial_i == 0 and not cfg.unperturbed_trial then if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1 trial_i = 1
end end
if trial_i > 0 then if trial_i > 0 then
print('loading trial', trial_i) --print('loading trial', trial_i)
local noise = trial_noise[trial_i] network:distribute(trial_params[trial_i])
for i, v in ipairs(base_params) do
W[i] = v + noise[i]
end
else else
print("test trial") --print("test trial")
end network:distribute(base_params)
network:distribute(W)
end
local function collect_best_indices()
-- select one (the best) reward of each pos/neg pair.
local best_rewards = {}
if cfg.negate_trials then
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
best_rewards[i] = max(pos, neg)
end
else
best_rewards = copy(trial_rewards)
end
local indices = argsort(best_rewards, function(a, b) return a > b end)
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
return indices
end
local function kinda_lipschitz(dir, pos, neg, mid)
local _, dev = calc_mean_dev(dir)
local c0 = neg - mid
local c1 = pos - mid
local l0 = abs(3 * c1 + c0)
local l1 = abs(c1 + 3 * c0)
return max(l0, l1) / (2 * dev)
end
local function make_step_paired(rewards, current_cost)
local step = nn.zeros(#base_params)
local _, reward_dev = calc_mean_dev(rewards)
if reward_dev == 0 then reward_dev = 1 end
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = rewards[ind + 0]
local neg = rewards[ind + 1]
local reward = pos - neg
if reward ~= 0 then
local noise = trial_noise[i]
if cfg.ars_lips then
local lips = kinda_lipschitz(noise, pos, neg, current_cost)
reward = reward / lips / cfg.deviation
else
reward = reward / reward_dev
end
for j, v in ipairs(noise) do
step[j] = step[j] + reward * v / cfg.epoch_top_trials
end
end
end
return step
end
local function make_step(rewards)
local step = nn.zeros(#base_params)
local _, reward_dev = calc_mean_dev(rewards)
if reward_dev == 0 then reward_dev = 1 end
for i = 1, cfg.epoch_trials do
local reward = rewards[i] / reward_dev
if reward ~= 0 then
local noise = trial_noise[i]
for j, v in ipairs(noise) do
step[j] = step[j] + reward * v / cfg.epoch_top_trials
end
end
end
return step
end
local function amsgrad(step) -- in-place!
if mom1 == nil then mom1 = nn.zeros(#step) end
if mom2 == nil then mom2 = nn.zeros(#step) end
if mom2max == nil then mom2max = nn.zeros(#step) end
local b1_t = pow(cfg.adam_b1, epoch_i)
local b2_t = pow(cfg.adam_b2, epoch_i)
-- NOTE: with LuaJIT, splitting this loop would
-- almost certainly be faster.
for i, v in ipairs(step) do
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
mom2max[i] = max(mom2[i], mom2max[i])
if cfg.adam_debias then
local num = (mom1[i] / (1 - b1_t))
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
step[i] = num / den
else
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
end
end end
end end
local function learn_from_epoch() local function learn_from_epoch()
print() print()
--print('rewards:', trial_rewards)
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
local current_cost = trial_rewards[0] -- may be nil! local current_cost = trial_rewards[0] -- may be nil!
@ -369,58 +228,45 @@ local function learn_from_epoch()
local delta_rewards = {} -- only used for logging. local delta_rewards = {} -- only used for logging.
if cfg.negate_trials then if cfg.negate_trials then
for i = 1, cfg.epoch_trials do for i = 1, #trial_rewards, 2 do
local ind = (i - 1) * 2 + 1 local ind = floor(i / 2) + 1
local pos = trial_rewards[ind + 0] local pos = trial_rewards[i + 0]
local neg = trial_rewards[ind + 1] local neg = trial_rewards[i + 1]
delta_rewards[i] = abs(pos - neg) delta_rewards[ind] = abs(pos - neg)
end end
end end
local indices = collect_best_indices() if es == 'ars' then
print("best trials:", indices) es:tell(trial_rewards, current_cost)
local top_rewards = {}
for i = 1, #trial_rewards do top_rewards[i] = 0 end
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = trial_rewards[sind + 0]
top_rewards[sind + 1] = trial_rewards[sind + 1]
end
--print("top:", top_rewards)
if cfg.negate_trials then
local top_delta_rewards = {} -- only used for printing.
for i, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
end
print("best deltas:", top_delta_rewards)
end
local step
if cfg.negate_trials then
step = make_step_paired(top_rewards, current_cost)
else else
step = make_step(top_rewards) es:tell(trial_rewards)
end end
local step_mean, step_dev = 0, 0
--[[ TODO
local step_mean, step_dev = calc_mean_dev(step) local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean) print("step mean:", step_mean)
print("step stddev:", step_dev) print("step stddev:", step_dev)
--]]
local momstep_mean, momstep_dev = 0, 0 local momstep_mean, momstep_dev = 0, 0
--[[ TODO
if cfg.adamant then if cfg.adamant then
amsgrad(step) amsgrad(step)
momstep_mean, momstep_dev = calc_mean_dev(step) momstep_mean, momstep_dev = calc_mean_dev(step)
print("amsgrad mean:", momstep_mean) print("amsgrad mean:", momstep_mean)
print("amsgrad stddev:", momstep_dev) print("amsgrad stddev:", momstep_dev)
end end
--]]
base_params = es:params()
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v base_params[i] = v * (1 - cfg.weight_decay)
end end
es:params(base_params)
local trial_mean, trial_std = calc_mean_dev(trial_rewards) local trial_mean, trial_std = calc_mean_dev(trial_rewards)
local delta_mean, delta_std = calc_mean_dev(delta_rewards) local delta_mean, delta_std = calc_mean_dev(delta_rewards)
local weight_mean, weight_std = calc_mean_dev(base_params) local weight_mean, weight_std = calc_mean_dev(base_params)
@ -465,6 +311,7 @@ local function joypad_mash(button)
end end
local function loadlevel(world, level) local function loadlevel(world, level)
-- TODO: move to smb.lua. rename to load_level.
if world == 0 then world = random(1, 8) end if world == 0 then world = random(1, 8) end
if level == 0 then level = random(1, 4) end if level == 0 then level = random(1, 4) end
emu.poweron() emu.poweron()
@ -499,7 +346,8 @@ local function do_reset()
local pos = trial_rewards[#trial_rewards] local pos = trial_rewards[#trial_rewards]
local neg = reward local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)" local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(trial_i, pos, neg, last_trial_state, state)) print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
end end
last_trial_state = state last_trial_state = state
else else
@ -517,7 +365,7 @@ local function do_reset()
end end
end end
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch() prepare_epoch()
@ -527,6 +375,11 @@ local function do_reset()
end end
end end
max_time = min(6 * sqrt(480 / #trial_params * (epoch_i - 1)) + 60, cfg.cap_time)
max_time = ceil(max_time)
-- TODO: game.reset(cfg.starting_lives, cfg.start_big)
if game.get_state() == 'loading' then game.advance() end -- kind of a hack. if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
reward = 0 reward = 0
powerup_old = game.R(0x754) powerup_old = game.R(0x754)
@ -543,8 +396,7 @@ local function do_reset()
game.W(0x756, 1) game.W(0x756, 1)
end end
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time) -- end of game.reset()
max_time = ceil(max_time)
if state_saved then if state_saved then
savestate.load(startsave) savestate.load(startsave)
@ -585,6 +437,13 @@ local function init()
local res, err = pcall(network.load, network, cfg.params_fn) local res, err = pcall(network.load, network, cfg.params_fn)
if res == false then print(err) end if res == false then print(err) end
if cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
else
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end
end end
local function prepare_reset() local function prepare_reset()