reimplement softchoice and redo noise generation
This commit is contained in:
parent
bb44d6696e
commit
d696bd8c21
1 changed files with 70 additions and 36 deletions
106
main.lua
106
main.lua
|
@ -26,30 +26,35 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
|||
|
||||
--randomseed(11)
|
||||
|
||||
local defer_prints = true
|
||||
|
||||
local playable_mode = false
|
||||
local start_big = true
|
||||
local starting_lives = 0
|
||||
--
|
||||
local init_zeros = true -- instead of he_normal noise or whatever.
|
||||
local init_zeros = false -- instead of he_normal noise or whatever.
|
||||
local frameskip = 4
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local deterministic = false -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = false -- take random actions with probability eps.
|
||||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 2000000
|
||||
local eps_frames = 4000000
|
||||
--
|
||||
local epoch_trials = 20
|
||||
local epoch_top_trials = 10 -- new with ARS.
|
||||
local epoch_trials = 18
|
||||
local epoch_top_trials = 9 -- new with ARS.
|
||||
local unperturbed_trial = true -- do a trial without any noise.
|
||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||
-- ^ note that this now doubles the effective trials.
|
||||
local learning_rate = 0.01
|
||||
local deviation = 0.06
|
||||
local deviation = 0.025
|
||||
local function approx_cossim(dim)
|
||||
return math.pow(1.521 * dim - 0.521, -0.5026)
|
||||
end
|
||||
local learning_rate = 0.01 / approx_cossim(7051)
|
||||
--
|
||||
local cap_time = 100 --400
|
||||
local cap_time = 200 --400
|
||||
local timer_loser = 0 --1/3
|
||||
local decrement_reward = true
|
||||
local decrement_reward = false -- bad idea, encourages mario to kill himself
|
||||
--
|
||||
local enable_overlay = playable_mode
|
||||
local enable_network = not playable_mode
|
||||
|
@ -159,6 +164,7 @@ local once = false
|
|||
local reset = true
|
||||
|
||||
local state_old = ''
|
||||
local last_trial_state
|
||||
|
||||
-- localize some stuff.
|
||||
|
||||
|
@ -222,6 +228,18 @@ local function argmax(...)
|
|||
return max_i
|
||||
end
|
||||
|
||||
local function softchoice(...)
|
||||
local t = random()
|
||||
local psum = 0
|
||||
for i=1, select("#", ...) do
|
||||
local p = select(i, ...)
|
||||
psum = psum + p
|
||||
if t < psum then
|
||||
return i
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function argmax2(t)
|
||||
return t[1] > t[2]
|
||||
end
|
||||
|
@ -251,15 +269,14 @@ local function calc_mean_dev(x)
|
|||
dev = dev + delta * delta / #x
|
||||
end
|
||||
|
||||
return mean, dev
|
||||
return mean, sqrt(dev)
|
||||
end
|
||||
|
||||
local function normalize(x, out)
|
||||
out = out or x
|
||||
local mean, dev = calc_mean_dev(x)
|
||||
if dev <= 0 then dev = 1 end
|
||||
local devs = sqrt(dev)
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / dev end
|
||||
return out
|
||||
end
|
||||
|
||||
|
@ -267,8 +284,7 @@ local function normalize_wrt(x, s, out)
|
|||
out = out or x
|
||||
local mean, dev = calc_mean_dev(s)
|
||||
if dev <= 0 then dev = 1 end
|
||||
local devs = sqrt(dev)
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / dev end
|
||||
return out
|
||||
end
|
||||
|
||||
|
@ -287,10 +303,6 @@ local function make_network(input_size)
|
|||
nn_x:feed(nn_y)
|
||||
nn_ty:feed(nn_y)
|
||||
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
--nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
|
||||
nn_z = nn_y
|
||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||
nn_z = nn_z:feed(nn.Softmax())
|
||||
|
@ -560,7 +572,9 @@ local function prepare_epoch()
|
|||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||
for i = 1, epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||
-- NOTE: change in implementation: deviation is multiplied here
|
||||
-- and ONLY here now.
|
||||
for j = 1, #base_params do noise[j] = deviation * nn.normal() end
|
||||
trial_noise[i] = noise
|
||||
end
|
||||
trial_i = -1
|
||||
|
@ -577,24 +591,24 @@ local function load_next_pair()
|
|||
|
||||
if trial_i > 0 then
|
||||
if trial_neg then
|
||||
print('trial', trial_i, 'positive')
|
||||
if not defer_prints then print('trial', trial_i, 'positive') end
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + deviation * noise[i]
|
||||
W[i] = v + noise[i]
|
||||
end
|
||||
|
||||
else
|
||||
trial_i = trial_i - 1
|
||||
print('trial', trial_i, 'negative')
|
||||
if not defer_prints then print('trial', trial_i, 'positive') end
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v - deviation * noise[i]
|
||||
W[i] = v - noise[i]
|
||||
end
|
||||
end
|
||||
|
||||
trial_neg = not trial_neg
|
||||
else
|
||||
print("test trial")
|
||||
if not defer_prints then print("test trial") end
|
||||
end
|
||||
|
||||
network:distribute(W)
|
||||
|
@ -611,7 +625,7 @@ local function load_next_trial()
|
|||
print('loading trial', trial_i)
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + deviation * noise[i]
|
||||
W[i] = v + noise[i]
|
||||
end
|
||||
else
|
||||
print("test trial")
|
||||
|
@ -702,12 +716,11 @@ local function learn_from_epoch()
|
|||
end
|
||||
print("top:", top_rewards)
|
||||
|
||||
local reward_mean, reward_dev = calc_mean_dev(top_rewards)
|
||||
--print("mean, dev:", reward_mean, reward_dev)
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
--print("mean, dev:", _, reward_dev)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||
--print("scaled:", top_rewards)
|
||||
|
||||
-- NOTE: step no longer directly incorporates learning_rate.
|
||||
for i = 1, epoch_trials do
|
||||
|
@ -717,13 +730,14 @@ local function learn_from_epoch()
|
|||
local reward = pos - neg
|
||||
local noise = trial_noise[i]
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + (reward * v) / epoch_trials
|
||||
step[j] = step[j] + reward * v / epoch_top_trials
|
||||
end
|
||||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||
print("step stddev:", step_dev)
|
||||
--print("step stddev:", step_dev)
|
||||
print("full step stddev:", learning_rate * step_dev)
|
||||
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + learning_rate * step[i]
|
||||
|
@ -743,7 +757,27 @@ local function do_reset()
|
|||
local state = get_state()
|
||||
-- be a little more descriptive.
|
||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||
print("reward:", reward, "("..state..")")
|
||||
|
||||
if trial_i >= 0 and defer_prints then
|
||||
if trial_i == 0 then
|
||||
print('test trial reward:', reward, "("..state..")")
|
||||
elseif negate_trials then
|
||||
--local dir = trial_neg and "negative" or "positive"
|
||||
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||
|
||||
if trial_neg then
|
||||
local pos = trial_rewards[#trial_rewards]
|
||||
local neg = reward
|
||||
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
||||
print(fmt:format(trial_i, pos, neg, last_trial_state, state))
|
||||
end
|
||||
last_trial_state = state
|
||||
else
|
||||
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
||||
end
|
||||
else
|
||||
print("reward:", reward, "("..state..")")
|
||||
end
|
||||
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 or not negate_trials then
|
||||
|
@ -777,8 +811,9 @@ local function do_reset()
|
|||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||
max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||
max_time = ceil(max_time)
|
||||
--max_time = cap_time
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
|
@ -916,9 +951,6 @@ local function doit(dummy)
|
|||
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + frameskip
|
||||
|
||||
-- TODO: reimplement this.
|
||||
local choose = deterministic and argmax2 or rchoice2
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
|
||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||
|
@ -926,7 +958,9 @@ local function doit(dummy)
|
|||
local i = floor(random() * #jp_lut) + 1
|
||||
jp = nn.copy(jp_lut[i], jp)
|
||||
else
|
||||
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp)
|
||||
local choose = deterministic and argmax or softchoice
|
||||
local ind = choose(unpack(outputs[nn_z]))
|
||||
jp = nn.copy(jp_lut[ind], jp)
|
||||
end
|
||||
|
||||
if force_start then
|
||||
|
|
Loading…
Reference in a new issue