refactor learn_from_epoch
This commit is contained in:
parent
ec19774af5
commit
3030e83d00
1 changed files with 119 additions and 89 deletions
208
main.lua
208
main.lua
|
@ -48,6 +48,7 @@ local last_trial_state
|
||||||
|
|
||||||
-- localize some stuff.
|
-- localize some stuff.
|
||||||
|
|
||||||
|
local assert = assert
|
||||||
local print = print
|
local print = print
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local pairs = pairs
|
local pairs = pairs
|
||||||
|
@ -79,6 +80,7 @@ local arshift = bit.arshift
|
||||||
local rol = bit.rol
|
local rol = bit.rol
|
||||||
local ror = bit.ror
|
local ror = bit.ror
|
||||||
|
|
||||||
|
local emu = emu
|
||||||
local gui = gui
|
local gui = gui
|
||||||
|
|
||||||
local util = require("util")
|
local util = require("util")
|
||||||
|
@ -177,8 +179,6 @@ local function prepare_epoch()
|
||||||
for i = 1, cfg.epoch_trials do
|
for i = 1, cfg.epoch_trials do
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
if cfg.graycode then
|
if cfg.graycode then
|
||||||
--local precision = 1 / cfg.deviation
|
|
||||||
--print(cfg.deviation, precision)
|
|
||||||
for j = 1, #base_params do
|
for j = 1, #base_params do
|
||||||
noise[j] = exp(-precision * nn.uniform())
|
noise[j] = exp(-precision * nn.uniform())
|
||||||
end
|
end
|
||||||
|
@ -244,6 +244,107 @@ local function load_next_trial()
|
||||||
network:distribute(W)
|
network:distribute(W)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function collect_best_indices()
|
||||||
|
-- select one (the best) reward of each pos/neg pair.
|
||||||
|
local best_rewards = {}
|
||||||
|
if cfg.negate_trials then
|
||||||
|
for i = 1, cfg.epoch_trials do
|
||||||
|
local ind = (i - 1) * 2 + 1
|
||||||
|
local pos = trial_rewards[ind + 0]
|
||||||
|
local neg = trial_rewards[ind + 1]
|
||||||
|
best_rewards[i] = max(pos, neg)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
best_rewards = copy(trial_rewards)
|
||||||
|
end
|
||||||
|
|
||||||
|
local indices = {}
|
||||||
|
for i = 1, #best_rewards do indices[i] = i end
|
||||||
|
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||||
|
|
||||||
|
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||||
|
return indices
|
||||||
|
end
|
||||||
|
|
||||||
|
local function kinda_lipschitz(dir, pos, neg, mid)
|
||||||
|
local _, dev = calc_mean_dev(dir)
|
||||||
|
local c0 = neg - mid
|
||||||
|
local c1 = pos - mid
|
||||||
|
local l0 = abs(3 * c1 + c0)
|
||||||
|
local l1 = abs(c1 + 3 * c0)
|
||||||
|
return max(l0, l1) / (2 * dev)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function make_step_paired(rewards, current_cost)
|
||||||
|
local step = nn.zeros(#base_params)
|
||||||
|
local _, reward_dev = calc_mean_dev(rewards)
|
||||||
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
|
for i = 1, cfg.epoch_trials do
|
||||||
|
local ind = (i - 1) * 2 + 1
|
||||||
|
local pos = rewards[ind + 0]
|
||||||
|
local neg = rewards[ind + 1]
|
||||||
|
local reward = pos - neg
|
||||||
|
if reward ~= 0 then
|
||||||
|
local noise = trial_noise[i]
|
||||||
|
|
||||||
|
if cfg.ars_lips then
|
||||||
|
local lips = kinda_lipschitz(noise, pos, neg, current_cost)
|
||||||
|
reward = reward / lips / cfg.deviation
|
||||||
|
else
|
||||||
|
reward = reward / reward_dev
|
||||||
|
end
|
||||||
|
|
||||||
|
for j, v in ipairs(noise) do
|
||||||
|
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return step
|
||||||
|
end
|
||||||
|
|
||||||
|
local function make_step(rewards)
|
||||||
|
local step = nn.zeros(#base_params)
|
||||||
|
local _, reward_dev = calc_mean_dev(rewards)
|
||||||
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
|
for i = 1, cfg.epoch_trials do
|
||||||
|
local reward = rewards[i] / reward_dev
|
||||||
|
if reward ~= 0 then
|
||||||
|
local noise = trial_noise[i]
|
||||||
|
|
||||||
|
for j, v in ipairs(noise) do
|
||||||
|
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return step
|
||||||
|
end
|
||||||
|
|
||||||
|
local function amsgrad(step) -- in-place!
|
||||||
|
if mom1 == nil then mom1 = nn.zeros(#step) end
|
||||||
|
if mom2 == nil then mom2 = nn.zeros(#step) end
|
||||||
|
if mom2max == nil then mom2max = nn.zeros(#step) end
|
||||||
|
|
||||||
|
local b1_t = pow(cfg.adam_b1, epoch_i)
|
||||||
|
local b2_t = pow(cfg.adam_b2, epoch_i)
|
||||||
|
|
||||||
|
-- NOTE: with LuaJIT, splitting this loop would
|
||||||
|
-- almost certainly be faster.
|
||||||
|
for i, v in ipairs(step) do
|
||||||
|
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
||||||
|
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
||||||
|
mom2max[i] = max(mom2[i], mom2max[i])
|
||||||
|
if cfg.adam_debias then
|
||||||
|
local num = (mom1[i] / (1 - b1_t))
|
||||||
|
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
||||||
|
step[i] = num / den
|
||||||
|
else
|
||||||
|
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
local function learn_from_epoch()
|
local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
--print('rewards:', trial_rewards)
|
--print('rewards:', trial_rewards)
|
||||||
|
@ -259,37 +360,17 @@ local function learn_from_epoch()
|
||||||
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
||||||
end
|
end
|
||||||
|
|
||||||
local step = nn.zeros(#base_params)
|
local delta_rewards = {} -- only used for logging.
|
||||||
|
|
||||||
local delta_rewards -- only used for logging.
|
|
||||||
local best_rewards
|
|
||||||
if cfg.negate_trials then
|
if cfg.negate_trials then
|
||||||
delta_rewards = {}
|
|
||||||
for i = 1, cfg.epoch_trials do
|
for i = 1, cfg.epoch_trials do
|
||||||
local ind = (i - 1) * 2 + 1
|
local ind = (i - 1) * 2 + 1
|
||||||
local pos = trial_rewards[ind + 0]
|
local pos = trial_rewards[ind + 0]
|
||||||
local neg = trial_rewards[ind + 1]
|
local neg = trial_rewards[ind + 1]
|
||||||
delta_rewards[i] = abs(pos - neg)
|
delta_rewards[i] = abs(pos - neg)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- select one (the best) reward of each pos/neg pair.
|
|
||||||
best_rewards = {}
|
|
||||||
for i = 1, cfg.epoch_trials do
|
|
||||||
local ind = (i - 1) * 2 + 1
|
|
||||||
local pos = trial_rewards[ind + 0]
|
|
||||||
local neg = trial_rewards[ind + 1]
|
|
||||||
best_rewards[i] = max(pos, neg)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
best_rewards = copy(trial_rewards)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local indices = {}
|
local indices = collect_best_indices()
|
||||||
for i = 1, #best_rewards do indices[i] = i end
|
|
||||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
|
||||||
|
|
||||||
--print("indices:", indices)
|
|
||||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
|
||||||
print("best trials:", indices)
|
print("best trials:", indices)
|
||||||
|
|
||||||
local top_rewards = {}
|
local top_rewards = {}
|
||||||
|
@ -301,77 +382,29 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
--print("top:", top_rewards)
|
--print("top:", top_rewards)
|
||||||
|
|
||||||
local top_delta_rewards = {} -- only used for printing.
|
if cfg.negate_trials then
|
||||||
for i, ind in ipairs(indices) do
|
local top_delta_rewards = {} -- only used for printing.
|
||||||
local sind = (ind - 1) * 2 + 1
|
for i, ind in ipairs(indices) do
|
||||||
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
local sind = (ind - 1) * 2 + 1
|
||||||
end
|
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
||||||
print("best deltas:", top_delta_rewards)
|
|
||||||
|
|
||||||
if not cfg.ars_lips then
|
|
||||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
|
||||||
--print("mean, dev:", _, reward_dev)
|
|
||||||
if reward_dev == 0 then reward_dev = 1 end
|
|
||||||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
|
||||||
end
|
|
||||||
|
|
||||||
for i = 1, cfg.epoch_trials do
|
|
||||||
local ind = (i - 1) * 2 + 1
|
|
||||||
local pos = top_rewards[ind + 0]
|
|
||||||
local neg = top_rewards[ind + 1]
|
|
||||||
local reward = pos - neg
|
|
||||||
if reward ~= 0 then
|
|
||||||
local noise = trial_noise[i]
|
|
||||||
|
|
||||||
if cfg.ars_lips then
|
|
||||||
local _, dev = calc_mean_dev(noise)
|
|
||||||
local c0 = neg - current_cost
|
|
||||||
local c1 = pos - current_cost
|
|
||||||
local l0 = abs(3 * c1 + c0)
|
|
||||||
local l1 = abs(c1 + 3 * c0)
|
|
||||||
local lips = max(l0, l1) / (2 * dev)
|
|
||||||
--reward = pos / lips - neg / lips
|
|
||||||
local old_reward = reward
|
|
||||||
reward = reward / lips
|
|
||||||
reward = reward / cfg.deviation -- FIXME: hack?
|
|
||||||
--print(("trial %i reward: %.0f -> %.5f"):format(i, old_reward, reward))
|
|
||||||
end
|
|
||||||
|
|
||||||
for j, v in ipairs(noise) do
|
|
||||||
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
print("best deltas:", top_delta_rewards)
|
||||||
|
end
|
||||||
|
|
||||||
|
local step
|
||||||
|
if cfg.negate_trials then
|
||||||
|
step = make_step_paired(top_rewards, current_cost)
|
||||||
|
else
|
||||||
|
step = make_step(top_rewards)
|
||||||
end
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
print("step mean:", step_mean)
|
print("step mean:", step_mean)
|
||||||
print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
--print("full step stddev:", cfg.learning_rate * step_dev)
|
|
||||||
|
|
||||||
local momstep_mean, momstep_dev = 0, 0
|
local momstep_mean, momstep_dev = 0, 0
|
||||||
if cfg.adamant then
|
if cfg.adamant then
|
||||||
if mom1 == nil then mom1 = nn.zeros(#step) end
|
amsgrad(step)
|
||||||
if mom2 == nil then mom2 = nn.zeros(#step) end
|
|
||||||
if mom2max == nil then mom2max = nn.zeros(#step) end
|
|
||||||
|
|
||||||
local b1_t = pow(cfg.adam_b1, epoch_i)
|
|
||||||
local b2_t = pow(cfg.adam_b2, epoch_i)
|
|
||||||
|
|
||||||
-- NOTE: with LuaJIT, splitting this loop would
|
|
||||||
-- almost certainly be faster.
|
|
||||||
for i, v in ipairs(step) do
|
|
||||||
mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v
|
|
||||||
mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v
|
|
||||||
mom2max[i] = max(mom2[i], mom2max[i])
|
|
||||||
if cfg.adam_debias then
|
|
||||||
local num = (mom1[i] / (1 - b1_t))
|
|
||||||
local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps
|
|
||||||
step[i] = num / den
|
|
||||||
else
|
|
||||||
step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
momstep_mean, momstep_dev = calc_mean_dev(step)
|
momstep_mean, momstep_dev = calc_mean_dev(step)
|
||||||
print("amsgrad mean:", momstep_mean)
|
print("amsgrad mean:", momstep_mean)
|
||||||
print("amsgrad stddev:", momstep_dev)
|
print("amsgrad stddev:", momstep_dev)
|
||||||
|
@ -480,11 +513,8 @@ local function do_reset()
|
||||||
game.W(0x756, 1)
|
game.W(0x756, 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
--max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
|
|
||||||
--max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
|
|
||||||
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
|
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
|
||||||
max_time = ceil(max_time)
|
max_time = ceil(max_time)
|
||||||
--max_time = cfg.cap_time
|
|
||||||
|
|
||||||
if once then
|
if once then
|
||||||
savestate.load(startsave)
|
savestate.load(startsave)
|
||||||
|
|
Loading…
Reference in a new issue