experimental ARS stuff

Connor Olding 2018-03-26 16:32:00 +02:00
parent 4765104c7a
commit 5b98023073
2 changed files with 132 additions and 43 deletions
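
Context for the commit title: ARS is Augmented Random Search (Mania et al., 2018, "Simple random search provides a competitive approach to reinforcement learning"). Its core update, sketched below as self-contained Lua with illustrative names (not code from this repo): evaluate each noise direction in both signs, keep the b best directions, and step along the reward-weighted noise, scaled by the standard deviation of the rewards used. The perturbation scale (deviation) is applied by the caller when evaluating.

    -- ARS update sketch (illustrative; all names are assumptions, not the script's).
    -- theta: flat parameter table; noises[i]: direction i; alpha: step size; b: top directions kept.
    local function ars_step(theta, r_pos, r_neg, noises, alpha, b)
        local order = {}
        for i = 1, #noises do order[i] = i end
        -- rank directions by the better of their two rewards.
        table.sort(order, function(x, y)
            return math.max(r_pos[x], r_neg[x]) > math.max(r_pos[y], r_neg[y])
        end)
        -- stddev of the 2b rewards actually used, for scale invariance.
        local used, sum = {}, 0
        for k = 1, b do
            used[#used + 1] = r_pos[order[k]]
            used[#used + 1] = r_neg[order[k]]
        end
        for _, r in ipairs(used) do sum = sum + r end
        local mean, sumsq = sum / #used, 0
        for _, r in ipairs(used) do sumsq = sumsq + (r - mean) ^ 2 end
        local dev = math.sqrt(sumsq / #used)
        if dev == 0 then dev = 1 end
        -- theta += alpha / (b * dev) * sum_k (r+ - r-) * noise_k
        for k = 1, b do
            local i = order[k]
            local scale = alpha / (b * dev) * (r_pos[i] - r_neg[i])
            for j = 1, #theta do
                theta[j] = theta[j] + scale * noises[i][j]
            end
        end
    end

The diff below follows this shape fairly closely, with a few deviations noted where they appear.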

@@ -3,11 +3,10 @@ please be mindful when sharing it.
 however, feel free to copy any snippets of code you find useful.
 
 TODOs: (that i can remember right now)
-- finish implementing backprop
-- replace evolution strategy algorithm with
-  something that utilizes backprop like PPO
+- finish implementing ARS
+  '-> running mean/std normalization of all inputs (aka ARS V2)
+  '-> normalize and/or embed sprite inputs
 - settle on a network architecture
-- normalize and/or embed sprite inputs
 - fix lag-frames skipped-inputs bug
 - detect frames when Mario is in a controllable state
 - fix offscreen sprites sometimes being visible to network
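
The "ARS V2" item refers to whitening observations with running statistics before feeding them to the policy. A minimal sketch of such a normalizer in Lua, using Welford's online algorithm (a hypothetical helper, not present in this repo):

    -- running mean/std input normalization sketch (assumed names throughout).
    local Normalizer = {}
    Normalizer.__index = Normalizer

    function Normalizer.new(n)
        local self = setmetatable({}, Normalizer)
        self.n, self.count = n, 0
        self.mean, self.m2 = {}, {}
        for i = 1, n do self.mean[i], self.m2[i] = 0, 0 end
        return self
    end

    -- Welford's online update, one observation vector at a time.
    function Normalizer:observe(x)
        self.count = self.count + 1
        for i = 1, self.n do
            local delta = x[i] - self.mean[i]
            self.mean[i] = self.mean[i] + delta / self.count
            self.m2[i] = self.m2[i] + delta * (x[i] - self.mean[i])
        end
    end

    function Normalizer:normalize(x)
        local out = {}
        for i = 1, self.n do
            local var = self.count > 1 and self.m2[i] / (self.count - 1) or 1
            local dev = math.max(math.sqrt(var), 1e-8)
            out[i] = (x[i] - self.mean[i]) / dev
        end
        return out
    end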

main.lua

@@ -27,23 +27,29 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
 --randomseed(11)
 
 local playable_mode = false
+local start_big = true
+local starting_lives = 0
 --
+local init_zeros = true -- instead of he_normal noise or whatever.
 local frameskip = 4
 -- true greedy epsilon has both deterministic and det_epsilon set.
 local deterministic = true -- use argmax on outputs instead of random sampling.
-local det_epsilon = true -- take random actions with probability eps.
+local det_epsilon = false -- take random actions with probability eps.
 local eps_start = 1.0 * frameskip / 64
 local eps_stop = 0.1 * eps_start
 local eps_frames = 2000000
 --
-local epoch_trials = 40
-local negate_trials = true -- try pairs of normal and negated noise directions.
+local epoch_trials = 20
+local epoch_top_trials = 10 -- new with ARS.
 local unperturbed_trial = true -- do a trial without any noise.
-local learning_rate = 0.3 -- bigger now that i'm trying a different technique.
-local deviation = 0.05
+local negate_trials = true -- try pairs of normal and negated noise directions.
+-- ^ note that this now doubles the effective trials.
+local learning_rate = 0.01
+local deviation = 0.06
 --
-local cap_time = 400
-local timer_loser = 1/3
+local cap_time = 100 --400
+local timer_loser = 0 --1/3
+local decrement_reward = true
 --
 local enable_overlay = playable_mode
 local enable_network = not playable_mode
@@ -118,6 +124,7 @@ local jp_lut = {
 local epoch_i = 0
 local base_params
 local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
+local trial_neg = true
 local trial_noise = {}
 local trial_rewards = {}
 local trials_remaining = 0
@@ -141,7 +148,7 @@ local jp
 
 local screen_scroll_delta
 local reward
-local all_rewards = {}
+--local all_rewards = {}
 
 local powerup_old
 local status_old
@ -280,19 +287,9 @@ local function make_network(input_size)
nn_x:feed(nn_y) nn_x:feed(nn_y)
nn_ty:feed(nn_y) nn_ty:feed(nn_y)
if true then
nn_y = nn_y:feed(nn.Dense(128)) nn_y = nn_y:feed(nn.Dense(128))
--nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(64))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Relu())
else
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
end
nn_z = nn_y nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#jp_lut)) nn_z = nn_z:feed(nn.Dense(#jp_lut))
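
The network body that survives this hunk is a single hidden layer; the data flow, paraphrased from the diff:

    -- after this change (comments only; see the hunk above):
    --   nn_y -> Dense(128) -> Relu -> nn_z -> Dense(#jp_lut)   (one output per joypad combo)
    -- the 128/64/48 Relu stack and the 128/128 Gelu branch are both gone;
    -- a Gelu after the remaining Dense(128) is left commented out to try later.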
@@ -563,17 +560,48 @@ local function prepare_epoch()
     -- (the os.time() as of here) and calling nn.normal() each trial.
     for i = 1, epoch_trials do
         local noise = nn.zeros(#base_params)
-        if negate_trials and i % 2 == 0 then -- every other iteration...
-            for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end
-        else
-            for j = 1, #base_params do noise[j] = nn.normal() end
-        end
+        for j = 1, #base_params do noise[j] = nn.normal() end
         trial_noise[i] = noise
     end
     trial_i = -1
 end
 
+local function load_next_pair()
+    trial_i = trial_i + 1
+
+    if trial_i == 0 and not unperturbed_trial then
+        trial_i = 1
+        trial_neg = true
+    end
+
+    local W = nn.copy(base_params)
+
+    if trial_i > 0 then
+        if trial_neg then
+            print('trial', trial_i, 'positive')
+            local noise = trial_noise[trial_i]
+            for i, v in ipairs(base_params) do
+                W[i] = v + deviation * noise[i]
+            end
+        else
+            trial_i = trial_i - 1
+            print('trial', trial_i, 'negative')
+            local noise = trial_noise[trial_i]
+            for i, v in ipairs(base_params) do
+                W[i] = v - deviation * noise[i]
+            end
+        end
+
+        trial_neg = not trial_neg
+    else
+        print("test trial")
+    end
+
+    network:distribute(W)
+end
+
 local function load_next_trial()
+    if negate_trials then return load_next_pair() end
     trial_i = trial_i + 1
     local W = nn.copy(base_params)
     if trial_i == 0 and not unperturbed_trial then
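
Spelled out, the schedule load_next_pair() produces (assuming unperturbed_trial = true and epoch_trials = 2) is:

    -- call 1: trial 0  -> W = base_params                          ("test trial")
    -- call 2: trial 1  -> W = base_params + deviation * noise[1]   (positive)
    -- call 3: trial 1  -> W = base_params - deviation * noise[1]   (negative)
    -- call 4: trial 2  -> W = base_params + deviation * noise[2]
    -- call 5: trial 2  -> W = base_params - deviation * noise[2]
    -- i.e. one epoch now runs 1 + 2 * epoch_trials rollouts,
    -- which is what the "doubles the effective trials" comment refers to.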
@@ -628,11 +656,9 @@ end
 
 local function learn_from_epoch()
     print()
 
-    print('rewards:', trial_rewards)
-    for _, v in ipairs(trial_rewards) do
-        insert(all_rewards, v)
-    end
+    --print('rewards:', trial_rewards)
+    --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
 
     if unperturbed_trial then
         local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
@@ -642,14 +668,56 @@ local function learn_from_epoch()
     end
 
     local step = nn.zeros(#base_params)
-    local shaped_rewards = fitness_shaping(trial_rewards)
-    local altogether = learning_rate / (epoch_trials * deviation)
-    for i = 1, epoch_trials do
-        local reward = shaped_rewards[i]
+
+    -- new stuff
+    local best_rewards
+    if negate_trials then
+        -- select one (the best) reward of each pos/neg pair.
+        best_rewards = {}
+        for i = 1, epoch_trials do
+            local ind = (i - 1) * 2 + 1
+            local pos = trial_rewards[ind + 0]
+            local neg = trial_rewards[ind + 1]
+            best_rewards[i] = max(pos, neg)
+        end
+    else
+        best_rewards = nn.copy(trial_rewards)
+    end
+
+    local indices = {}
+    for i = 1, #best_rewards do indices[i] = i end
+    sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
+    --print("indices:", indices)
+    for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
+    print("best trials:", indices)
+
+    local top_rewards = {}
+    for i = 1, #trial_rewards do top_rewards[i] = 0 end
+    for _, ind in ipairs(indices) do
+        local sind = (ind - 1) * 2 + 1
+        top_rewards[sind + 0] = trial_rewards[sind + 0]
+        top_rewards[sind + 1] = trial_rewards[sind + 1]
+    end
+    print("top:", top_rewards)
+
+    local reward_mean, reward_dev = calc_mean_dev(top_rewards)
+    --print("mean, dev:", reward_mean, reward_dev)
+    if reward_dev == 0 then reward_dev = 1 end
+    for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
+    --print("scaled:", top_rewards)
+
+    -- NOTE: step no longer directly incorporates learning_rate.
+    for i = 1, epoch_trials do
+        local ind = (i - 1) * 2 + 1
+        local pos = top_rewards[ind + 0]
+        local neg = top_rewards[ind + 1]
+        local reward = pos - neg
         local noise = trial_noise[i]
         for j, v in ipairs(noise) do
-            step[j] = step[j] + altogether * (reward * v)
+            step[j] = step[j] + (reward * v) / epoch_trials
         end
     end
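
For comparison with the paper's update (a sketch of the correspondence; the mapping to variable names is my reading of the code above):

    -- ARS:  theta += alpha / (b * sigma_R) * sum over top-b of (r+ - r-) * delta
    -- here: alpha   = learning_rate (applied in the next hunk, not in step),
    --       sigma_R = reward_dev (top_rewards are pre-divided by it),
    --       delta   = trial_noise[i], with non-top pairs zeroed out;
    -- one difference: the sum is averaged over epoch_trials (all N directions)
    -- rather than over epoch_top_trials (b) as in the paper.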
@@ -658,7 +726,7 @@ local function learn_from_epoch()
     print("step stddev:", step_dev)
 
     for i, v in ipairs(base_params) do
-        base_params[i] = v + step[i]
+        base_params[i] = v + learning_rate * step[i]
     end
 
     if enable_network then
@@ -677,9 +745,15 @@ local function do_reset()
     if state == 'dead' and get_timer() == 0 then state = 'timeup' end
     print("reward:", reward, "("..state..")")
 
-    if trial_i >= 0 then trial_rewards[trial_i] = reward end
+    if trial_i >= 0 then
+        if trial_i == 0 or not negate_trials then
+            trial_rewards[trial_i] = reward
+        else
+            trial_rewards[#trial_rewards + 1] = reward
+        end
+    end
 
-    if epoch_i == 0 or trial_i == epoch_trials then
+    if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
         if epoch_i > 0 then learn_from_epoch() end
        epoch_i = epoch_i + 1
         prepare_epoch()
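
The bookkeeping above is easy to misread, so, spelled out (matching do_reset and load_next_pair):

    -- reward storage with negate_trials enabled:
    --   trial_rewards[0]      unperturbed run (trial 0), if enabled
    --   trial_rewards[1], [2] direction 1: +deviation, -deviation
    --   trial_rewards[3], [4] direction 2: +deviation, -deviation ...
    -- the epoch ends right after the last negative trial: load_next_pair()
    -- has decremented trial_i back to epoch_trials and left trial_neg true.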
@@ -692,10 +766,18 @@ local function do_reset()
     coins_old = R(0x7ED) * 10 + R(0x7EE)
     score_old = get_score()
 
-    W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
+    -- set number of lives. (mario gets n+1 chances)
+    W(0x75A, starting_lives)
+
+    if start_big then
+        -- make mario "super".
+        W(0x754, 0)
+        W(0x756, 1)
+    end
 
     --max_time = min(log(epoch_i) * 10 + 100, cap_time)
-    max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
+    --max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
+    --max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
     max_time = ceil(max_time)
 
     if once then
@@ -720,6 +802,12 @@ local function init()
     network:print()
     print("parameters:", network.n_param)
 
+    if init_zeros then
+        local W = network:collect()
+        for i, w in ipairs(W) do W[i] = 0 end
+        network:distribute(W)
+    end
+
     emu.poweron()
     emu.unpause()
     emu.speedmode("turbo")
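
Zero-initializing every weight is a reasonable default for this kind of random search, since exploration comes entirely from the parameter perturbations rather than from the starting point. For contrast, a he_normal-style init (what the init_zeros comment alludes to) would look roughly like this sketch; fan_in per layer is assumed known, and nn.normal() is the sampler the script already uses:

    -- illustrative only: fill a flat weight table with He-normal noise.
    local function he_normal_fill(W, fan_in)
        local scale = math.sqrt(2 / fan_in)  -- He et al. 2015: stddev = sqrt(2 / fan_in)
        for i = 1, #W do W[i] = nn.normal() * scale end
    end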
@@ -794,6 +882,8 @@ local function doit(dummy)
     local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
     screen_scroll_delta = 0
 
+    if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
+
     if not ingame_paused then reward = reward + reward_delta end
 
     --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')