refactor learn_from_epoch

Connor Olding 2018-05-14 01:34:08 +02:00
parent ec19774af5
commit 3030e83d00

main.lua (208 lines changed)

@@ -48,6 +48,7 @@ local last_trial_state
-- localize some stuff.
+local assert = assert
local print = print
local ipairs = ipairs
local pairs = pairs
@@ -79,6 +80,7 @@ local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror
+local emu = emu
local gui = gui
local util = require("util")
@@ -177,8 +179,6 @@ local function prepare_epoch()
    for i = 1, cfg.epoch_trials do
        local noise = nn.zeros(#base_params)
        if cfg.graycode then
-            --local precision = 1 / cfg.deviation
-            --print(cfg.deviation, precision)
            for j = 1, #base_params do
                noise[j] = exp(-precision * nn.uniform())
            end
@@ -244,6 +244,107 @@ local function load_next_trial()
    network:distribute(W)
end

+local function collect_best_indices()
+    -- select one (the best) reward of each pos/neg pair.
+    local best_rewards = {}
+    if cfg.negate_trials then
+        for i = 1, cfg.epoch_trials do
+            local ind = (i - 1) * 2 + 1
+            local pos = trial_rewards[ind + 0]
+            local neg = trial_rewards[ind + 1]
+            best_rewards[i] = max(pos, neg)
+        end
+    else
+        best_rewards = copy(trial_rewards)
+    end
+
+    local indices = {}
+    for i = 1, #best_rewards do indices[i] = i end
+    sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
+    for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
+    return indices
+end
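Note (editorial, not part of the commit): collect_best_indices assumes trial_rewards stores each trial's positive and negative perturbation rewards in adjacent slots, which is what the (i - 1) * 2 + 1 indexing walks over; each pair is then ranked by its better half. A standalone sketch with made-up rewards:

local max, sort = math.max, table.sort
local rewards = {10, 4,  2, 8,  5, 5}   -- three pos/neg pairs: (10,4) (2,8) (5,5)
local best = {}
for i = 1, 3 do
    local ind = (i - 1) * 2 + 1
    best[i] = max(rewards[ind + 0], rewards[ind + 1])   -- 10, 8, 5
end
local indices = {1, 2, 3}
sort(indices, function(a, b) return best[a] > best[b] end)
for i = 2 + 1, #best do indices[i] = nil end   -- keep only a "top 2"
print(indices[1], indices[2])   --> 1   2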
+local function kinda_lipschitz(dir, pos, neg, mid)
+    local _, dev = calc_mean_dev(dir)
+    local c0 = neg - mid
+    local c1 = pos - mid
+    local l0 = abs(3 * c1 + c0)
+    local l1 = abs(c1 + 3 * c0)
+    return max(l0, l1) / (2 * dev)
+end
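Note (editorial, not part of the commit): one way to read kinda_lipschitz is that it fits a parabola through the three costs sampled at roughly -dev, 0 (mid), and +dev, and returns the steeper of the parabola's two endpoint slopes, i.e. a rough local Lipschitz estimate. The standalone check below passes dev directly instead of deriving it from the noise vector with calc_mean_dev; the numbers are made up.

local abs, max = math.abs, math.max
local function kinda_lipschitz(dev, pos, neg, mid)
    local c0, c1 = neg - mid, pos - mid
    return max(abs(3 * c1 + c0), abs(c1 + 3 * c0)) / (2 * dev)
end
-- f(x) = 2x^2 - x + 1 sampled at -0.5, 0, +0.5; the steeper slope is |f'(-0.5)| = 3.
print(kinda_lipschitz(0.5, 1, 2, 1))   --> 3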
+local function make_step_paired(rewards, current_cost)
+    local step = nn.zeros(#base_params)
+    local _, reward_dev = calc_mean_dev(rewards)
+    if reward_dev == 0 then reward_dev = 1 end
+
+    for i = 1, cfg.epoch_trials do
+        local ind = (i - 1) * 2 + 1
+        local pos = rewards[ind + 0]
+        local neg = rewards[ind + 1]
+        local reward = pos - neg
+        if reward ~= 0 then
+            local noise = trial_noise[i]
+            if cfg.ars_lips then
+                local lips = kinda_lipschitz(noise, pos, neg, current_cost)
+                reward = reward / lips / cfg.deviation
+            else
+                reward = reward / reward_dev
+            end
+            for j, v in ipairs(noise) do
+                step[j] = step[j] + reward * v / cfg.epoch_top_trials
+            end
+        end
+    end
+
+    return step
+end
+
+local function make_step(rewards)
+    local step = nn.zeros(#base_params)
+    local _, reward_dev = calc_mean_dev(rewards)
+    if reward_dev == 0 then reward_dev = 1 end
+
+    for i = 1, cfg.epoch_trials do
+        local reward = rewards[i] / reward_dev
+        if reward ~= 0 then
+            local noise = trial_noise[i]
+            for j, v in ipairs(noise) do
+                step[j] = step[j] + reward * v / cfg.epoch_top_trials
+            end
+        end
+    end
+
+    return step
+end
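Note (editorial, not part of the commit): both step functions accumulate reward-weighted noise, the usual shape of an ARS/evolution-strategies style update; make_step_paired weights by the difference of each pos/neg pair, make_step by the raw reward. A minimal standalone version of the paired accumulation, without the stddev or Lipschitz scaling and with made-up numbers:

local noise   = {{ 1.0, -0.5}, {-0.2, 0.8}}   -- one perturbation vector per trial
local rewards = {{ 3.0,  1.0}, { 0.5, 2.5}}   -- {pos, neg} reward per trial
local top     = 2                             -- stand-in for cfg.epoch_top_trials
local step    = {0, 0}
for i, pair in ipairs(rewards) do
    local r = pair[1] - pair[2]               -- pos - neg
    for j, v in ipairs(noise[i]) do
        step[j] = step[j] + r * v / top
    end
end
print(step[1], step[2])   --> 1.2   -1.3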
+local function amsgrad(step) -- in-place!
+    if mom1 == nil then mom1 = nn.zeros(#step) end
+    if mom2 == nil then mom2 = nn.zeros(#step) end
+    if mom2max == nil then mom2max = nn.zeros(#step) end
+
+    local b1_t = pow(cfg.adam_b1, epoch_i)
+    local b2_t = pow(cfg.adam_b2, epoch_i)
+
+    -- NOTE: with LuaJIT, splitting this loop would
+    -- almost certainly be faster.
+    for i, v in ipairs(step) do
+        mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
+        mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
+        mom2max[i] = max(mom2[i], mom2max[i])
+        if cfg.adam_debias then
+            local num = (mom1[i] / (1 - b1_t))
+            local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
+            step[i] = num / den
+        else
+            step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
+        end
+    end
+end
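Note (editorial, not part of the commit): mom2max, the running element-wise maximum of the second moment, is what turns Adam into AMSGrad; without debiasing each element becomes mom1 / (sqrt(mom2max) + eps). A one-parameter sketch using typical Adam constants, which are assumptions rather than this project's cfg values:

local b1, b2, eps = 0.9, 0.999, 1e-8
local m1, m2, m2max = 0, 0, 0
local function amsgrad_1d(g)
    m1 = b1 * m1 + (1 - b1) * g
    m2 = b2 * m2 + (1 - b2) * g * g
    m2max = math.max(m2, m2max)            -- the "max" is the AMSGrad part
    return m1 / (math.sqrt(m2max) + eps)   -- like the adam_debias == false branch
end
print(amsgrad_1d(1.0))    --> roughly 3.16 on the first call
print(amsgrad_1d(-0.5))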

local function learn_from_epoch()
    print()
    --print('rewards:', trial_rewards)
@@ -259,37 +360,17 @@ local function learn_from_epoch()
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

-    local step = nn.zeros(#base_params)
-    local delta_rewards -- only used for logging.
-    local best_rewards
+    local delta_rewards = {} -- only used for logging.
    if cfg.negate_trials then
-        delta_rewards = {}
        for i = 1, cfg.epoch_trials do
            local ind = (i - 1) * 2 + 1
            local pos = trial_rewards[ind + 0]
            local neg = trial_rewards[ind + 1]
            delta_rewards[i] = abs(pos - neg)
        end
-
-        -- select one (the best) reward of each pos/neg pair.
-        best_rewards = {}
-        for i = 1, cfg.epoch_trials do
-            local ind = (i - 1) * 2 + 1
-            local pos = trial_rewards[ind + 0]
-            local neg = trial_rewards[ind + 1]
-            best_rewards[i] = max(pos, neg)
-        end
-    else
-        best_rewards = copy(trial_rewards)
    end

-    local indices = {}
-    for i = 1, #best_rewards do indices[i] = i end
-    sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
-    --print("indices:", indices)
-    for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
+    local indices = collect_best_indices()
    print("best trials:", indices)

    local top_rewards = {}
@@ -301,77 +382,29 @@
    end
    --print("top:", top_rewards)

-    local top_delta_rewards = {} -- only used for printing.
-    for i, ind in ipairs(indices) do
-        local sind = (ind - 1) * 2 + 1
-        top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
-    end
-    print("best deltas:", top_delta_rewards)
-
-    if not cfg.ars_lips then
-        local _, reward_dev = calc_mean_dev(top_rewards)
-        --print("mean, dev:", _, reward_dev)
-        if reward_dev == 0 then reward_dev = 1 end
-        for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
-    end
-
-    for i = 1, cfg.epoch_trials do
-        local ind = (i - 1) * 2 + 1
-        local pos = top_rewards[ind + 0]
-        local neg = top_rewards[ind + 1]
-        local reward = pos - neg
-        if reward ~= 0 then
-            local noise = trial_noise[i]
-            if cfg.ars_lips then
-                local _, dev = calc_mean_dev(noise)
-                local c0 = neg - current_cost
-                local c1 = pos - current_cost
-                local l0 = abs(3 * c1 + c0)
-                local l1 = abs(c1 + 3 * c0)
-                local lips = max(l0, l1) / (2 * dev)
-                --reward = pos / lips - neg / lips
-                local old_reward = reward
-                reward = reward / lips
-                reward = reward / cfg.deviation -- FIXME: hack?
-                --print(("trial %i reward: %.0f -> %.5f"):format(i, old_reward, reward))
-            end
-            for j, v in ipairs(noise) do
-                step[j] = step[j] + reward * v / cfg.epoch_top_trials
-            end
-        end
-    end
+    if cfg.negate_trials then
+        local top_delta_rewards = {} -- only used for printing.
+        for i, ind in ipairs(indices) do
+            local sind = (ind - 1) * 2 + 1
+            top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
+        end
+        print("best deltas:", top_delta_rewards)
+    end
+
+    local step
+    if cfg.negate_trials then
+        step = make_step_paired(top_rewards, current_cost)
+    else
+        step = make_step(top_rewards)
+    end
    local step_mean, step_dev = calc_mean_dev(step)
    print("step mean:", step_mean)
    print("step stddev:", step_dev)
-    --print("full step stddev:", cfg.learning_rate * step_dev)

    local momstep_mean, momstep_dev = 0, 0
    if cfg.adamant then
-        if mom1 == nil then mom1 = nn.zeros(#step) end
-        if mom2 == nil then mom2 = nn.zeros(#step) end
-        if mom2max == nil then mom2max = nn.zeros(#step) end
-        local b1_t = pow(cfg.adam_b1, epoch_i)
-        local b2_t = pow(cfg.adam_b2, epoch_i)
-        -- NOTE: with LuaJIT, splitting this loop would
-        -- almost certainly be faster.
-        for i, v in ipairs(step) do
-            mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
-            mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
-            mom2max[i] = max(mom2[i], mom2max[i])
-            if cfg.adam_debias then
-                local num = (mom1[i] / (1 - b1_t))
-                local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
-                step[i] = num / den
-            else
-                step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
-            end
-        end
+        amsgrad(step)
        momstep_mean, momstep_dev = calc_mean_dev(step)
        print("amsgrad mean:", momstep_mean)
        print("amsgrad stddev:", momstep_dev)
@@ -480,11 +513,8 @@ local function do_reset()
        game.W(0x756, 1)
    end

-    --max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
-    --max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
    max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
    max_time = ceil(max_time)
-    --max_time = cfg.cap_time
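Note (editorial, not part of the commit): to get a feel for the surviving time-cap curve, the snippet below evaluates it for a few epochs with assumed values epoch_trials = 20 and cap_time = 300; the real values, and the unit max_time is measured in, come from the config and the rest of main.lua.

local min, sqrt, ceil = math.min, math.sqrt, math.ceil
local epoch_trials, cap_time = 20, 300
for _, epoch_i in ipairs({1, 11, 101}) do
    local t = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
    print(epoch_i, ceil(t))   --> 1 60,  11 153,  101 300
end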

    if once then
        savestate.load(startsave)