reimplement softchoice and redo noise generation

This commit is contained in:
Connor Olding 2018-03-27 13:04:44 +02:00
parent bb44d6696e
commit d696bd8c21

104
main.lua
View file

@ -26,30 +26,35 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
--randomseed(11) --randomseed(11)
local defer_prints = true
local playable_mode = false local playable_mode = false
local start_big = true local start_big = true
local starting_lives = 0 local starting_lives = 0
-- --
local init_zeros = true -- instead of he_normal noise or whatever. local init_zeros = false -- instead of he_normal noise or whatever.
local frameskip = 4 local frameskip = 4
-- true greedy epsilon has both deterministic and det_epsilon set. -- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling. local deterministic = false -- use argmax on outputs instead of random sampling.
local det_epsilon = false -- take random actions with probability eps. local det_epsilon = false -- take random actions with probability eps.
local eps_start = 1.0 * frameskip / 64 local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * eps_start local eps_stop = 0.1 * eps_start
local eps_frames = 2000000 local eps_frames = 4000000
-- --
local epoch_trials = 20 local epoch_trials = 18
local epoch_top_trials = 10 -- new with ARS. local epoch_top_trials = 9 -- new with ARS.
local unperturbed_trial = true -- do a trial without any noise. local unperturbed_trial = true -- do a trial without any noise.
local negate_trials = true -- try pairs of normal and negated noise directions. local negate_trials = true -- try pairs of normal and negated noise directions.
-- ^ note that this now doubles the effective trials. -- ^ note that this now doubles the effective trials.
local learning_rate = 0.01 local deviation = 0.025
local deviation = 0.06 local function approx_cossim(dim)
return math.pow(1.521 * dim - 0.521, -0.5026)
end
local learning_rate = 0.01 / approx_cossim(7051)
-- --
local cap_time = 100 --400 local cap_time = 200 --400
local timer_loser = 0 --1/3 local timer_loser = 0 --1/3
local decrement_reward = true local decrement_reward = false -- bad idea, encourages mario to kill himself
-- --
local enable_overlay = playable_mode local enable_overlay = playable_mode
local enable_network = not playable_mode local enable_network = not playable_mode
@ -159,6 +164,7 @@ local once = false
local reset = true local reset = true
local state_old = '' local state_old = ''
local last_trial_state
-- localize some stuff. -- localize some stuff.
@ -222,6 +228,18 @@ local function argmax(...)
return max_i return max_i
end end
local function softchoice(...)
local t = random()
local psum = 0
for i=1, select("#", ...) do
local p = select(i, ...)
psum = psum + p
if t < psum then
return i
end
end
end
local function argmax2(t) local function argmax2(t)
return t[1] > t[2] return t[1] > t[2]
end end
@ -251,15 +269,14 @@ local function calc_mean_dev(x)
dev = dev + delta * delta / #x dev = dev + delta * delta / #x
end end
return mean, dev return mean, sqrt(dev)
end end
local function normalize(x, out) local function normalize(x, out)
out = out or x out = out or x
local mean, dev = calc_mean_dev(x) local mean, dev = calc_mean_dev(x)
if dev <= 0 then dev = 1 end if dev <= 0 then dev = 1 end
local devs = sqrt(dev) for i, v in ipairs(x) do out[i] = (v - mean) / dev end
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
return out return out
end end
@ -267,8 +284,7 @@ local function normalize_wrt(x, s, out)
out = out or x out = out or x
local mean, dev = calc_mean_dev(s) local mean, dev = calc_mean_dev(s)
if dev <= 0 then dev = 1 end if dev <= 0 then dev = 1 end
local devs = sqrt(dev) for i, v in ipairs(x) do out[i] = (v - mean) / dev end
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
return out return out
end end
@ -287,10 +303,6 @@ local function make_network(input_size)
nn_x:feed(nn_y) nn_x:feed(nn_y)
nn_ty:feed(nn_y) nn_ty:feed(nn_y)
nn_y = nn_y:feed(nn.Dense(128))
--nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Relu())
nn_z = nn_y nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#jp_lut)) nn_z = nn_z:feed(nn.Dense(#jp_lut))
nn_z = nn_z:feed(nn.Softmax()) nn_z = nn_z:feed(nn.Softmax())
@ -560,7 +572,9 @@ local function prepare_epoch()
-- (the os.time() as of here) and calling nn.normal() each trial. -- (the os.time() as of here) and calling nn.normal() each trial.
for i = 1, epoch_trials do for i = 1, epoch_trials do
local noise = nn.zeros(#base_params) local noise = nn.zeros(#base_params)
for j = 1, #base_params do noise[j] = nn.normal() end -- NOTE: change in implementation: deviation is multiplied here
-- and ONLY here now.
for j = 1, #base_params do noise[j] = deviation * nn.normal() end
trial_noise[i] = noise trial_noise[i] = noise
end end
trial_i = -1 trial_i = -1
@ -577,24 +591,24 @@ local function load_next_pair()
if trial_i > 0 then if trial_i > 0 then
if trial_neg then if trial_neg then
print('trial', trial_i, 'positive') if not defer_prints then print('trial', trial_i, 'positive') end
local noise = trial_noise[trial_i] local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
W[i] = v + deviation * noise[i] W[i] = v + noise[i]
end end
else else
trial_i = trial_i - 1 trial_i = trial_i - 1
print('trial', trial_i, 'negative') if not defer_prints then print('trial', trial_i, 'positive') end
local noise = trial_noise[trial_i] local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
W[i] = v - deviation * noise[i] W[i] = v - noise[i]
end end
end end
trial_neg = not trial_neg trial_neg = not trial_neg
else else
print("test trial") if not defer_prints then print("test trial") end
end end
network:distribute(W) network:distribute(W)
@ -611,7 +625,7 @@ local function load_next_trial()
print('loading trial', trial_i) print('loading trial', trial_i)
local noise = trial_noise[trial_i] local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
W[i] = v + deviation * noise[i] W[i] = v + noise[i]
end end
else else
print("test trial") print("test trial")
@ -702,12 +716,11 @@ local function learn_from_epoch()
end end
print("top:", top_rewards) print("top:", top_rewards)
local reward_mean, reward_dev = calc_mean_dev(top_rewards) local _, reward_dev = calc_mean_dev(top_rewards)
--print("mean, dev:", reward_mean, reward_dev) --print("mean, dev:", _, reward_dev)
if reward_dev == 0 then reward_dev = 1 end if reward_dev == 0 then reward_dev = 1 end
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
--print("scaled:", top_rewards)
-- NOTE: step no longer directly incorporates learning_rate. -- NOTE: step no longer directly incorporates learning_rate.
for i = 1, epoch_trials do for i = 1, epoch_trials do
@ -717,13 +730,14 @@ local function learn_from_epoch()
local reward = pos - neg local reward = pos - neg
local noise = trial_noise[i] local noise = trial_noise[i]
for j, v in ipairs(noise) do for j, v in ipairs(noise) do
step[j] = step[j] + (reward * v) / epoch_trials step[j] = step[j] + reward * v / epoch_top_trials
end end
end end
local step_mean, step_dev = calc_mean_dev(step) local step_mean, step_dev = calc_mean_dev(step)
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
print("step stddev:", step_dev) --print("step stddev:", step_dev)
print("full step stddev:", learning_rate * step_dev)
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
base_params[i] = v + learning_rate * step[i] base_params[i] = v + learning_rate * step[i]
@ -743,7 +757,27 @@ local function do_reset()
local state = get_state() local state = get_state()
-- be a little more descriptive. -- be a little more descriptive.
if state == 'dead' and get_timer() == 0 then state = 'timeup' end if state == 'dead' and get_timer() == 0 then state = 'timeup' end
if trial_i >= 0 and defer_prints then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
if trial_neg then
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(trial_i, pos, neg, last_trial_state, state))
end
last_trial_state = state
else
print('trial', trial_i, 'reward:', reward, "("..state..")")
end
else
print("reward:", reward, "("..state..")") print("reward:", reward, "("..state..")")
end
if trial_i >= 0 then if trial_i >= 0 then
if trial_i == 0 or not negate_trials then if trial_i == 0 or not negate_trials then
@ -777,8 +811,9 @@ local function do_reset()
--max_time = min(log(epoch_i) * 10 + 100, cap_time) --max_time = min(log(epoch_i) * 10 + 100, cap_time)
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time) --max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time) max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
max_time = ceil(max_time) max_time = ceil(max_time)
--max_time = cap_time
if once then if once then
savestate.load(startsave) savestate.load(startsave)
@ -916,9 +951,6 @@ local function doit(dummy)
if enable_network and get_state() == 'playing' or ingame_paused then if enable_network and get_state() == 'playing' or ingame_paused then
total_frames = total_frames + frameskip total_frames = total_frames + frameskip
-- TODO: reimplement this.
local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
@ -926,7 +958,9 @@ local function doit(dummy)
local i = floor(random() * #jp_lut) + 1 local i = floor(random() * #jp_lut) + 1
jp = nn.copy(jp_lut[i], jp) jp = nn.copy(jp_lut[i], jp)
else else
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp) local choose = deterministic and argmax or softchoice
local ind = choose(unpack(outputs[nn_z]))
jp = nn.copy(jp_lut[ind], jp)
end end
if force_start then if force_start then