refactor learn_from_epoch
This commit is contained in:
parent
ec19774af5
commit
3030e83d00
1 changed files with 119 additions and 89 deletions
208
main.lua
208
main.lua
|
@ -48,6 +48,7 @@ local last_trial_state
|
|||
|
||||
-- localize some stuff.
|
||||
|
||||
local assert = assert
|
||||
local print = print
|
||||
local ipairs = ipairs
|
||||
local pairs = pairs
|
||||
|
@ -79,6 +80,7 @@ local arshift = bit.arshift
|
|||
local rol = bit.rol
|
||||
local ror = bit.ror
|
||||
|
||||
local emu = emu
|
||||
local gui = gui
|
||||
|
||||
local util = require("util")
|
||||
|
@ -177,8 +179,6 @@ local function prepare_epoch()
|
|||
for i = 1, cfg.epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
if cfg.graycode then
|
||||
--local precision = 1 / cfg.deviation
|
||||
--print(cfg.deviation, precision)
|
||||
for j = 1, #base_params do
|
||||
noise[j] = exp(-precision * nn.uniform())
|
||||
end
|
||||
|
@ -244,6 +244,107 @@ local function load_next_trial()
|
|||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function collect_best_indices()
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
local best_rewards = {}
|
||||
if cfg.negate_trials then
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
else
|
||||
best_rewards = copy(trial_rewards)
|
||||
end
|
||||
|
||||
local indices = {}
|
||||
for i = 1, #best_rewards do indices[i] = i end
|
||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||
|
||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
return indices
|
||||
end
|
||||
|
||||
local function kinda_lipschitz(dir, pos, neg, mid)
|
||||
local _, dev = calc_mean_dev(dir)
|
||||
local c0 = neg - mid
|
||||
local c1 = pos - mid
|
||||
local l0 = abs(3 * c1 + c0)
|
||||
local l1 = abs(c1 + 3 * c0)
|
||||
return max(l0, l1) / (2 * dev)
|
||||
end
|
||||
|
||||
local function make_step_paired(rewards, current_cost)
|
||||
local step = nn.zeros(#base_params)
|
||||
local _, reward_dev = calc_mean_dev(rewards)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = rewards[ind + 0]
|
||||
local neg = rewards[ind + 1]
|
||||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noise = trial_noise[i]
|
||||
|
||||
if cfg.ars_lips then
|
||||
local lips = kinda_lipschitz(noise, pos, neg, current_cost)
|
||||
reward = reward / lips / cfg.deviation
|
||||
else
|
||||
reward = reward / reward_dev
|
||||
end
|
||||
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||
end
|
||||
end
|
||||
end
|
||||
return step
|
||||
end
|
||||
|
||||
local function make_step(rewards)
|
||||
local step = nn.zeros(#base_params)
|
||||
local _, reward_dev = calc_mean_dev(rewards)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local reward = rewards[i] / reward_dev
|
||||
if reward ~= 0 then
|
||||
local noise = trial_noise[i]
|
||||
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||
end
|
||||
end
|
||||
end
|
||||
return step
|
||||
end
|
||||
|
||||
local function amsgrad(step) -- in-place!
|
||||
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||
if mom2max == nil then mom2max = nn.zeros(#step) end
|
||||
|
||||
local b1_t = pow(cfg.adam_b1, epoch_i)
|
||||
local b2_t = pow(cfg.adam_b2, epoch_i)
|
||||
|
||||
-- NOTE: with LuaJIT, splitting this loop would
|
||||
-- almost certainly be faster.
|
||||
for i, v in ipairs(step) do
|
||||
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
||||
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
||||
mom2max[i] = max(mom2[i], mom2max[i])
|
||||
if cfg.adam_debias then
|
||||
local num = (mom1[i] / (1 - b1_t))
|
||||
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
||||
step[i] = num / den
|
||||
else
|
||||
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function learn_from_epoch()
|
||||
print()
|
||||
--print('rewards:', trial_rewards)
|
||||
|
@ -259,37 +360,17 @@ local function learn_from_epoch()
|
|||
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
||||
end
|
||||
|
||||
local step = nn.zeros(#base_params)
|
||||
|
||||
local delta_rewards -- only used for logging.
|
||||
local best_rewards
|
||||
local delta_rewards = {} -- only used for logging.
|
||||
if cfg.negate_trials then
|
||||
delta_rewards = {}
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
delta_rewards[i] = abs(pos - neg)
|
||||
end
|
||||
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
best_rewards = {}
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
else
|
||||
best_rewards = copy(trial_rewards)
|
||||
end
|
||||
|
||||
local indices = {}
|
||||
for i = 1, #best_rewards do indices[i] = i end
|
||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||
|
||||
--print("indices:", indices)
|
||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
local indices = collect_best_indices()
|
||||
print("best trials:", indices)
|
||||
|
||||
local top_rewards = {}
|
||||
|
@ -301,77 +382,29 @@ local function learn_from_epoch()
|
|||
end
|
||||
--print("top:", top_rewards)
|
||||
|
||||
local top_delta_rewards = {} -- only used for printing.
|
||||
for i, ind in ipairs(indices) do
|
||||
local sind = (ind - 1) * 2 + 1
|
||||
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||
end
|
||||
print("best deltas:", top_delta_rewards)
|
||||
|
||||
if not cfg.ars_lips then
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
--print("mean, dev:", _, reward_dev)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||
end
|
||||
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = top_rewards[ind + 0]
|
||||
local neg = top_rewards[ind + 1]
|
||||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noise = trial_noise[i]
|
||||
|
||||
if cfg.ars_lips then
|
||||
local _, dev = calc_mean_dev(noise)
|
||||
local c0 = neg - current_cost
|
||||
local c1 = pos - current_cost
|
||||
local l0 = abs(3 * c1 + c0)
|
||||
local l1 = abs(c1 + 3 * c0)
|
||||
local lips = max(l0, l1) / (2 * dev)
|
||||
--reward = pos / lips - neg / lips
|
||||
local old_reward = reward
|
||||
reward = reward / lips
|
||||
reward = reward / cfg.deviation -- FIXME: hack?
|
||||
--print(("trial %i reward: %.0f -> %.5f"):format(i, old_reward, reward))
|
||||
end
|
||||
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||
end
|
||||
if cfg.negate_trials then
|
||||
local top_delta_rewards = {} -- only used for printing.
|
||||
for i, ind in ipairs(indices) do
|
||||
local sind = (ind - 1) * 2 + 1
|
||||
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||
end
|
||||
print("best deltas:", top_delta_rewards)
|
||||
end
|
||||
|
||||
local step
|
||||
if cfg.negate_trials then
|
||||
step = make_step_paired(top_rewards, current_cost)
|
||||
else
|
||||
step = make_step(top_rewards)
|
||||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
--print("full step stddev:", cfg.learning_rate * step_dev)
|
||||
|
||||
local momstep_mean, momstep_dev = 0, 0
|
||||
if cfg.adamant then
|
||||
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||
if mom2max == nil then mom2max = nn.zeros(#step) end
|
||||
|
||||
local b1_t = pow(cfg.adam_b1, epoch_i)
|
||||
local b2_t = pow(cfg.adam_b2, epoch_i)
|
||||
|
||||
-- NOTE: with LuaJIT, splitting this loop would
|
||||
-- almost certainly be faster.
|
||||
for i, v in ipairs(step) do
|
||||
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
||||
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
||||
mom2max[i] = max(mom2[i], mom2max[i])
|
||||
if cfg.adam_debias then
|
||||
local num = (mom1[i] / (1 - b1_t))
|
||||
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
||||
step[i] = num / den
|
||||
else
|
||||
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
||||
end
|
||||
end
|
||||
|
||||
amsgrad(step)
|
||||
momstep_mean, momstep_dev = calc_mean_dev(step)
|
||||
print("amsgrad mean:", momstep_mean)
|
||||
print("amsgrad stddev:", momstep_dev)
|
||||
|
@ -480,11 +513,8 @@ local function do_reset()
|
|||
game.W(0x756, 1)
|
||||
end
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
|
||||
--max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
|
||||
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
|
||||
max_time = ceil(max_time)
|
||||
--max_time = cfg.cap_time
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
|
|
Loading…
Reference in a new issue