reimplement softchoice and redo noise generation
This commit is contained in:
parent
bb44d6696e
commit
d696bd8c21
1 changed file with 70 additions and 36 deletions
104
main.lua
104
main.lua
|
@ -26,30 +26,35 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
||||||
|
|
||||||
--randomseed(11)
|
--randomseed(11)
|
||||||
|
|
||||||
|
local defer_prints = true
|
||||||
|
|
||||||
local playable_mode = false
|
local playable_mode = false
|
||||||
local start_big = true
|
local start_big = true
|
||||||
local starting_lives = 0
|
local starting_lives = 0
|
||||||
--
|
--
|
||||||
local init_zeros = true -- instead of he_normal noise or whatever.
|
local init_zeros = false -- instead of he_normal noise or whatever.
|
||||||
local frameskip = 4
|
local frameskip = 4
|
||||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
local deterministic = false -- use argmax on outputs instead of random sampling.
|
||||||
local det_epsilon = false -- take random actions with probability eps.
|
local det_epsilon = false -- take random actions with probability eps.
|
||||||
local eps_start = 1.0 * frameskip / 64
|
local eps_start = 1.0 * frameskip / 64
|
||||||
local eps_stop = 0.1 * eps_start
|
local eps_stop = 0.1 * eps_start
|
||||||
local eps_frames = 2000000
|
local eps_frames = 4000000
|
||||||
--
|
--
|
||||||
local epoch_trials = 20
|
local epoch_trials = 18
|
||||||
local epoch_top_trials = 10 -- new with ARS.
|
local epoch_top_trials = 9 -- new with ARS.
|
||||||
local unperturbed_trial = true -- do a trial without any noise.
|
local unperturbed_trial = true -- do a trial without any noise.
|
||||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||||
-- ^ note that this now doubles the effective trials.
|
-- ^ note that this now doubles the effective trials.
|
||||||
local learning_rate = 0.01
|
local deviation = 0.025
|
||||||
local deviation = 0.06
|
--- Evaluate the power-law fit (1.521 * dim - 0.521) ^ -0.5026.
-- Judging by the name, this approximates a cosine similarity as a function
-- of dimensionality -- NOTE(review): confirm what the fit was derived from.
-- The result is used as a divisor when setting learning_rate.
-- @param dim number of dimensions (a number; 1 yields approximately 1).
-- @return the fitted value as a number.
local function approx_cossim(dim)
    -- the ^ operator replaces math.pow, which was deprecated in Lua 5.3
    -- and is missing from default 5.4 builds; results are identical.
    return (1.521 * dim - 0.521) ^ -0.5026
end
|
||||||
|
local learning_rate = 0.01 / approx_cossim(7051)
|
||||||
--
|
--
|
||||||
local cap_time = 100 --400
|
local cap_time = 200 --400
|
||||||
local timer_loser = 0 --1/3
|
local timer_loser = 0 --1/3
|
||||||
local decrement_reward = true
|
local decrement_reward = false -- bad idea, encourages mario to kill himself
|
||||||
--
|
--
|
||||||
local enable_overlay = playable_mode
|
local enable_overlay = playable_mode
|
||||||
local enable_network = not playable_mode
|
local enable_network = not playable_mode
|
||||||
|
@ -159,6 +164,7 @@ local once = false
|
||||||
local reset = true
|
local reset = true
|
||||||
|
|
||||||
local state_old = ''
|
local state_old = ''
|
||||||
|
local last_trial_state
|
||||||
|
|
||||||
-- localize some stuff.
|
-- localize some stuff.
|
||||||
|
|
||||||
|
@ -222,6 +228,18 @@ local function argmax(...)
|
||||||
return max_i
|
return max_i
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Sample a 1-based index from a discrete probability distribution.
-- Draws t = random() once, then walks the running sum of the arguments and
-- returns the first index whose cumulative total exceeds t.
-- Assumes `random` is the file's localized math.random, returning a value
-- in [0, 1) -- TODO confirm against the "localize some stuff" section.
-- @param ... one probability per choice (e.g. softmax outputs); presumably
--            non-negative and summing to about 1.
-- @return the chosen index, or nothing when called with no arguments.
local function softchoice(...)
    local n = select("#", ...)
    local t = random()
    local psum = 0
    for i = 1, n do
        local p = select(i, ...)
        psum = psum + p
        if t < psum then
            return i
        end
    end
    -- floating-point rounding (or NaN outputs) can leave the cumulative sum
    -- at or below t; fall back to the last choice instead of returning nil,
    -- which callers would then use as a table index.
    if n > 0 then
        return n
    end
end
|
||||||
|
|
||||||
--- Two-way comparison: report whether the first entry beats the second.
-- @param t table whose elements 1 and 2 are comparable values.
-- @return true when t[1] is strictly greater than t[2], false otherwise.
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
|
||||||
|
@ -251,15 +269,14 @@ local function calc_mean_dev(x)
|
||||||
dev = dev + delta * delta / #x
|
dev = dev + delta * delta / #x
|
||||||
end
|
end
|
||||||
|
|
||||||
return mean, dev
|
return mean, sqrt(dev)
|
||||||
end
|
end
|
||||||
|
|
||||||
--- Shift and scale a sequence by its own mean and deviation.
-- Each element becomes (v - mean) / dev, where mean and dev come from
-- calc_mean_dev(x); a non-positive dev is replaced by 1 to avoid dividing
-- by zero.
-- @param x   sequence of numbers to normalize.
-- @param out optional destination table; defaults to x (in-place).
-- @return the destination table.
local function normalize(x, out)
    local dest = out or x
    local mean, dev = calc_mean_dev(x)
    if dev <= 0 then
        dev = 1
    end
    for i, v in ipairs(x) do
        dest[i] = (v - mean) / dev
    end
    return dest
end
|
||||||
|
|
||||||
|
@ -267,8 +284,7 @@ local function normalize_wrt(x, s, out)
|
||||||
out = out or x
|
out = out or x
|
||||||
local mean, dev = calc_mean_dev(s)
|
local mean, dev = calc_mean_dev(s)
|
||||||
if dev <= 0 then dev = 1 end
|
if dev <= 0 then dev = 1 end
|
||||||
local devs = sqrt(dev)
|
for i, v in ipairs(x) do out[i] = (v - mean) / dev end
|
||||||
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
|
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -287,10 +303,6 @@ local function make_network(input_size)
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
nn_ty:feed(nn_y)
|
nn_ty:feed(nn_y)
|
||||||
|
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
|
||||||
--nn_y = nn_y:feed(nn.Gelu())
|
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
|
||||||
|
|
||||||
nn_z = nn_y
|
nn_z = nn_y
|
||||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||||
nn_z = nn_z:feed(nn.Softmax())
|
nn_z = nn_z:feed(nn.Softmax())
|
||||||
|
@ -560,7 +572,9 @@ local function prepare_epoch()
|
||||||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
-- NOTE: change in implementation: deviation is multiplied here
|
||||||
|
-- and ONLY here now.
|
||||||
|
for j = 1, #base_params do noise[j] = deviation * nn.normal() end
|
||||||
trial_noise[i] = noise
|
trial_noise[i] = noise
|
||||||
end
|
end
|
||||||
trial_i = -1
|
trial_i = -1
|
||||||
|
@ -577,24 +591,24 @@ local function load_next_pair()
|
||||||
|
|
||||||
if trial_i > 0 then
|
if trial_i > 0 then
|
||||||
if trial_neg then
|
if trial_neg then
|
||||||
print('trial', trial_i, 'positive')
|
if not defer_prints then print('trial', trial_i, 'positive') end
|
||||||
local noise = trial_noise[trial_i]
|
local noise = trial_noise[trial_i]
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
W[i] = v + deviation * noise[i]
|
W[i] = v + noise[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
else
|
else
|
||||||
trial_i = trial_i - 1
|
trial_i = trial_i - 1
|
||||||
print('trial', trial_i, 'negative')
|
-- this branch subtracts the noise from base_params, so label it negative.
if not defer_prints then print('trial', trial_i, 'negative') end
|
||||||
local noise = trial_noise[trial_i]
|
local noise = trial_noise[trial_i]
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
W[i] = v - deviation * noise[i]
|
W[i] = v - noise[i]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
trial_neg = not trial_neg
|
trial_neg = not trial_neg
|
||||||
else
|
else
|
||||||
print("test trial")
|
if not defer_prints then print("test trial") end
|
||||||
end
|
end
|
||||||
|
|
||||||
network:distribute(W)
|
network:distribute(W)
|
||||||
|
@ -611,7 +625,7 @@ local function load_next_trial()
|
||||||
print('loading trial', trial_i)
|
print('loading trial', trial_i)
|
||||||
local noise = trial_noise[trial_i]
|
local noise = trial_noise[trial_i]
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
W[i] = v + deviation * noise[i]
|
W[i] = v + noise[i]
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
print("test trial")
|
print("test trial")
|
||||||
|
@ -702,12 +716,11 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
print("top:", top_rewards)
|
print("top:", top_rewards)
|
||||||
|
|
||||||
local reward_mean, reward_dev = calc_mean_dev(top_rewards)
|
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||||
--print("mean, dev:", reward_mean, reward_dev)
|
--print("mean, dev:", _, reward_dev)
|
||||||
if reward_dev == 0 then reward_dev = 1 end
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||||
--print("scaled:", top_rewards)
|
|
||||||
|
|
||||||
-- NOTE: step no longer directly incorporates learning_rate.
|
-- NOTE: step no longer directly incorporates learning_rate.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
|
@ -717,13 +730,14 @@ local function learn_from_epoch()
|
||||||
local reward = pos - neg
|
local reward = pos - neg
|
||||||
local noise = trial_noise[i]
|
local noise = trial_noise[i]
|
||||||
for j, v in ipairs(noise) do
|
for j, v in ipairs(noise) do
|
||||||
step[j] = step[j] + (reward * v) / epoch_trials
|
step[j] = step[j] + reward * v / epoch_top_trials
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||||
print("step stddev:", step_dev)
|
--print("step stddev:", step_dev)
|
||||||
|
print("full step stddev:", learning_rate * step_dev)
|
||||||
|
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v + learning_rate * step[i]
|
base_params[i] = v + learning_rate * step[i]
|
||||||
|
@ -743,7 +757,27 @@ local function do_reset()
|
||||||
local state = get_state()
|
local state = get_state()
|
||||||
-- be a little more descriptive.
|
-- be a little more descriptive.
|
||||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||||
|
|
||||||
|
if trial_i >= 0 and defer_prints then
|
||||||
|
if trial_i == 0 then
|
||||||
|
print('test trial reward:', reward, "("..state..")")
|
||||||
|
elseif negate_trials then
|
||||||
|
--local dir = trial_neg and "negative" or "positive"
|
||||||
|
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||||
|
|
||||||
|
if trial_neg then
|
||||||
|
local pos = trial_rewards[#trial_rewards]
|
||||||
|
local neg = reward
|
||||||
|
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
||||||
|
print(fmt:format(trial_i, pos, neg, last_trial_state, state))
|
||||||
|
end
|
||||||
|
last_trial_state = state
|
||||||
|
else
|
||||||
|
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
||||||
|
end
|
||||||
|
else
|
||||||
print("reward:", reward, "("..state..")")
|
print("reward:", reward, "("..state..")")
|
||||||
|
end
|
||||||
|
|
||||||
if trial_i >= 0 then
|
if trial_i >= 0 then
|
||||||
if trial_i == 0 or not negate_trials then
|
if trial_i == 0 or not negate_trials then
|
||||||
|
@ -777,8 +811,9 @@ local function do_reset()
|
||||||
|
|
||||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||||
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||||
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||||
max_time = ceil(max_time)
|
max_time = ceil(max_time)
|
||||||
|
--max_time = cap_time
|
||||||
|
|
||||||
if once then
|
if once then
|
||||||
savestate.load(startsave)
|
savestate.load(startsave)
|
||||||
|
@ -916,9 +951,6 @@ local function doit(dummy)
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
total_frames = total_frames + frameskip
|
total_frames = total_frames + frameskip
|
||||||
|
|
||||||
-- TODO: reimplement this.
|
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
|
||||||
|
|
||||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||||
|
|
||||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||||
|
@ -926,7 +958,9 @@ local function doit(dummy)
|
||||||
local i = floor(random() * #jp_lut) + 1
|
local i = floor(random() * #jp_lut) + 1
|
||||||
jp = nn.copy(jp_lut[i], jp)
|
jp = nn.copy(jp_lut[i], jp)
|
||||||
else
|
else
|
||||||
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp)
|
local choose = deterministic and argmax or softchoice
|
||||||
|
local ind = choose(unpack(outputs[nn_z]))
|
||||||
|
jp = nn.copy(jp_lut[ind], jp)
|
||||||
end
|
end
|
||||||
|
|
||||||
if force_start then
|
if force_start then
|
||||||
|
|
Loading…
Reference in a new issue