experimental ARS stuff

Connor Olding 2018-03-26 16:32:00 +02:00
parent 4765104c7a
commit 5b98023073
2 changed files with 132 additions and 43 deletions


@@ -3,11 +3,10 @@ please be mindful when sharing it.
however, feel free to copy any snippets of code you find useful.
TODOs: (that i can remember right now)
- finish implementing backprop
- replace evolution strategy algorithm with
something that utilizes backprop like PPO
- finish implementing ARS
'-> running mean/std normalization of all inputs (aka ARS V2)
'-> normalize and/or embed sprite inputs
- settle on a network architecture
- normalize and/or embed sprite inputs
- fix lag-frames skipped-inputs bug
- detect frames when Mario is in a controllable state
- fix offscreen sprites sometimes being visible to network
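
Note on the "ARS V2" item: the V2 variant of ARS (Mania et al., 2018) whitens every observation with a running mean and standard deviation. A minimal Welford-style sketch of what that could look like; none of these names exist in the repo yet:

-- hypothetical running input normalizer for ARS V2; not part of this commit.
local Normalizer = {}
Normalizer.__index = Normalizer

function Normalizer.new(n)
    local self = setmetatable({count = 0, mean = {}, m2 = {}}, Normalizer)
    for i = 1, n do self.mean[i], self.m2[i] = 0, 0 end
    return self
end

-- Welford's online update over one observation vector.
function Normalizer:update(x)
    self.count = self.count + 1
    for i, v in ipairs(x) do
        local delta = v - self.mean[i]
        self.mean[i] = self.mean[i] + delta / self.count
        self.m2[i] = self.m2[i] + delta * (v - self.mean[i])
    end
end

-- whiten an observation with the statistics gathered so far.
function Normalizer:normalize(x)
    local out = {}
    for i, v in ipairs(x) do
        local var = self.count > 1 and self.m2[i] / (self.count - 1) or 1
        local dev = math.sqrt(var)
        if dev == 0 then dev = 1 end
        out[i] = (v - self.mean[i]) / dev
    end
    return out
end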

main.lua

@@ -27,23 +27,29 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
--randomseed(11)
local playable_mode = false
local start_big = true
local starting_lives = 0
--
local init_zeros = true -- instead of he_normal noise or whatever.
local frameskip = 4
-- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = true -- take random actions with probability eps.
local det_epsilon = false -- take random actions with probability eps.
local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * eps_start
local eps_frames = 2000000
--
local epoch_trials = 40
local negate_trials = true -- try pairs of normal and negated noise directions.
local epoch_trials = 20
local epoch_top_trials = 10 -- new with ARS.
local unperturbed_trial = true -- do a trial without any noise.
local learning_rate = 0.3 -- bigger now that i'm trying a different technique.
local deviation = 0.05
local negate_trials = true -- try pairs of normal and negated noise directions.
-- ^ note that this now doubles the effective trials.
local learning_rate = 0.01
local deviation = 0.06
--
local cap_time = 400
local timer_loser = 1/3
local cap_time = 100 --400
local timer_loser = 0 --1/3
local decrement_reward = true
--
local enable_overlay = playable_mode
local enable_network = not playable_mode
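
For reference: eps_start, eps_stop, and eps_frames describe a linear anneal of the random-action probability. det_epsilon is now off, but the interpolation these settings imply would look something like this; the helper name is made up, not from the diff:

-- illustrative linear anneal of eps over eps_frames; not code from this commit.
local function current_eps(frames_elapsed)
    if not det_epsilon then return 0 end
    local t = math.min(frames_elapsed / eps_frames, 1)
    return eps_start + t * (eps_stop - eps_start)
end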
@@ -118,6 +124,7 @@ local jp_lut = {
local epoch_i = 0
local base_params
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_neg = true
local trial_noise = {}
local trial_rewards = {}
local trials_remaining = 0
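
A quick count of what these settings schedule per epoch, since trial_neg now doubles the episodes; plain arithmetic using the config values above, not code from the diff:

-- episode count per epoch under this commit's settings.
local episodes = (unperturbed_trial and 1 or 0)
               + epoch_trials * (negate_trials and 2 or 1)
print(episodes) -- 1 + 20 * 2 = 41; fills trial_rewards[1..40] plus the baseline at [0].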
@@ -141,7 +148,7 @@ local jp
local screen_scroll_delta
local reward
local all_rewards = {}
--local all_rewards = {}
local powerup_old
local status_old
@@ -280,19 +287,9 @@ local function make_network(input_size)
nn_x:feed(nn_y)
nn_ty:feed(nn_y)
if true then
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(64))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Relu())
else
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
end
nn_y = nn_y:feed(nn.Dense(128))
--nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Relu())
nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#jp_lut))
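
The Gelu branch removed here (and the commented line kept above) usually reduces to the tanh approximation of GELU. A sketch in case nn.Gelu() is opaque; the library's exact formula isn't visible in this diff, so treat this as an assumption:

-- tanh approximation of GELU; nn.Gelu() may differ in detail.
local function gelu(x)
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi)
                                    * (x + 0.044715 * x^3)))
end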
@@ -563,17 +560,48 @@ local function prepare_epoch()
-- (the os.time() as of here) and calling nn.normal() each trial.
for i = 1, epoch_trials do
local noise = nn.zeros(#base_params)
if negate_trials and i % 2 == 0 then -- every other iteration...
for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end
else
for j = 1, #base_params do noise[j] = nn.normal() end
end
for j = 1, #base_params do noise[j] = nn.normal() end
trial_noise[i] = noise
end
trial_i = -1
end
local function load_next_pair()
trial_i = trial_i + 1
if trial_i == 0 and not unperturbed_trial then
trial_i = 1
trial_neg = true
end
local W = nn.copy(base_params)
if trial_i > 0 then
if trial_neg then
print('trial', trial_i, 'positive')
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v + deviation * noise[i]
end
else
trial_i = trial_i - 1
print('trial', trial_i, 'negative')
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v - deviation * noise[i]
end
end
trial_neg = not trial_neg
else
print("test trial")
end
network:distribute(W)
end
local function load_next_trial()
if negate_trials then return load_next_pair() end
trial_i = trial_i + 1
local W = nn.copy(base_params)
if trial_i == 0 and not unperturbed_trial then
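
The pairing logic above is easier to see as a closed form: each epoch walks baseline, +noise[1], -noise[1], +noise[2], -noise[2], and so on. An equivalent restatement, for reference only:

-- closed form of what load_next_pair assembles each episode.
local function perturbed_params(base, noise, sign)
    local W = {}
    for i, v in ipairs(base) do W[i] = v + sign * deviation * noise[i] end
    return W
end
-- epoch order: base_params, then for i = 1..epoch_trials:
--   perturbed_params(base_params, trial_noise[i],  1)
--   perturbed_params(base_params, trial_noise[i], -1)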
@@ -628,11 +656,9 @@ end
local function learn_from_epoch()
print()
print('rewards:', trial_rewards)
--print('rewards:', trial_rewards)
for _, v in ipairs(trial_rewards) do
insert(all_rewards, v)
end
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
if unperturbed_trial then
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
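
unperturbed_rank isn't visible in this diff; a plausible reading, treated as an assumption, is that it returns where the baseline reward would rank among the perturbed rewards:

-- assumed shape of unperturbed_rank; the real definition lives outside
-- this diff, so treat this as a sketch of its contract.
local function unperturbed_rank(rewards, unperturbed_reward)
    local rank = 1
    for _, v in ipairs(rewards) do
        if v > unperturbed_reward then rank = rank + 1 end
    end
    return rank
end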
@@ -642,14 +668,56 @@ local function learn_from_epoch()
end
local step = nn.zeros(#base_params)
local shaped_rewards = fitness_shaping(trial_rewards)
local altogether = learning_rate / (epoch_trials * deviation)
-- new stuff
local best_rewards
if negate_trials then
-- select one (the best) reward of each pos/neg pair.
best_rewards = {}
for i = 1, epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
best_rewards[i] = max(pos, neg)
end
else
best_rewards = nn.copy(trial_rewards)
end
local indices = {}
for i = 1, #best_rewards do indices[i] = i end
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
--print("indices:", indices)
for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
print("best trials:", indices)
local top_rewards = {}
for i = 1, #trial_rewards do top_rewards[i] = 0 end
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = trial_rewards[sind + 0]
top_rewards[sind + 1] = trial_rewards[sind + 1]
end
print("top:", top_rewards)
local reward_mean, reward_dev = calc_mean_dev(top_rewards)
--print("mean, dev:", reward_mean, reward_dev)
if reward_dev == 0 then reward_dev = 1 end
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
--print("scaled:", top_rewards)
-- NOTE: step no longer directly incorporates learning_rate.
for i = 1, epoch_trials do
local reward = shaped_rewards[i]
local ind = (i - 1) * 2 + 1
local pos = top_rewards[ind + 0]
local neg = top_rewards[ind + 1]
local reward = pos - neg
local noise = trial_noise[i]
for j, v in ipairs(noise) do
step[j] = step[j] + altogether * (reward * v)
step[j] = step[j] + (reward * v) / epoch_trials
end
end
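
Stripped of the surrounding state, this hunk is close to the ARS V1-t update from the paper: keep the epoch_top_trials best pos/neg pairs, scale by the standard deviation of the surviving rewards, and average the (r+ - r-)-weighted noise (the paper divides by b, the code by epoch_trials). A self-contained restatement; calc_mean_dev is assumed to return mean and standard deviation, matching its use above:

-- standalone restatement of the update above; illustrative, not a drop-in.
local function ars_step(base, noise_tbl, rewards, top_b, lr)
    local n_pairs = #rewards / 2
    -- rank pairs by the better of their two rewards.
    local order = {}
    for i = 1, n_pairs do order[i] = i end
    table.sort(order, function(a, b)
        return math.max(rewards[2*a-1], rewards[2*a])
             > math.max(rewards[2*b-1], rewards[2*b])
    end)
    -- zero out everything but the top_b pairs, as the code above does.
    local top = {}
    for i = 1, #rewards do top[i] = 0 end
    for rank = 1, math.min(top_b, n_pairs) do
        local i = order[rank]
        top[2*i-1], top[2*i] = rewards[2*i-1], rewards[2*i]
    end
    local _, dev = calc_mean_dev(top) -- assumed: returns mean, stddev.
    if dev == 0 then dev = 1 end
    -- average (r+ - r-)-weighted noise; zeroed pairs contribute nothing.
    local step = {}
    for j = 1, #base do step[j] = 0 end
    for i = 1, n_pairs do
        local advantage = (top[2*i-1] - top[2*i]) / dev
        for j, v in ipairs(noise_tbl[i]) do
            step[j] = step[j] + advantage * v / n_pairs
        end
    end
    for j, v in ipairs(base) do base[j] = v + lr * step[j] end
end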
@@ -658,7 +726,7 @@ local function learn_from_epoch()
print("step stddev:", step_dev)
for i, v in ipairs(base_params) do
base_params[i] = v + step[i]
base_params[i] = v + learning_rate * step[i]
end
if enable_network then
@@ -677,9 +745,15 @@ local function do_reset()
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
print("reward:", reward, "("..state..")")
if trial_i >= 0 then trial_rewards[trial_i] = reward end
if trial_i >= 0 then
if trial_i == 0 or not negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
end
if epoch_i == 0 or trial_i == epoch_trials then
if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
epoch_i = epoch_i + 1
prepare_epoch()
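
The bookkeeping above packs each pair's rewards into consecutive slots, which is what the `(ind - 1) * 2 + 1` arithmetic in learn_from_epoch unpacks; spelled out:

-- slot arithmetic shared by do_reset and learn_from_epoch.
local function pair_slots(i) return 2*i - 1, 2*i end
-- pair 1 -> trial_rewards[1] (positive) and [2] (negative),
-- pair 2 -> [3] and [4], and so on up to epoch_trials pairs.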
@@ -692,10 +766,18 @@ local function do_reset()
coins_old = R(0x7ED) * 10 + R(0x7EE)
score_old = get_score()
W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
-- set number of lives. (mario gets n+1 chances)
W(0x75A, starting_lives)
if start_big then
-- make mario "super".
W(0x754, 0)
W(0x756, 1)
end
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
max_time = ceil(max_time)
if once then
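
For a sense of scale, the first commented schedule grows the time budget with the square root of elapsed epochs; evaluating it with this commit's epoch_trials = 20 and the old cap_time = 400 (an assumption, since both schedule lines are now commented out):

-- worked values for the first commented max_time schedule.
for _, e in ipairs({1, 5, 20, 80}) do
    local t = math.min(8 * math.sqrt(360 / 20 * (e - 1)) + 100, 400)
    print(e, math.ceil(t)) -- 1 -> 100, 5 -> 168, 20 -> 248, 80 -> 400 (capped)
end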
@@ -720,6 +802,12 @@ local function init()
network:print()
print("parameters:", network.n_param)
if init_zeros then
local W = network:collect()
for i, w in ipairs(W) do W[i] = 0 end
network:distribute(W)
end
emu.poweron()
emu.unpause()
emu.speedmode("turbo")
@@ -794,6 +882,8 @@ local function doit(dummy)
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
screen_scroll_delta = 0
if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
if not ingame_paused then reward = reward + reward_delta end
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')