experimental ARS stuff
This commit is contained in:
parent
4765104c7a
commit
5b98023073
2 changed files with 132 additions and 43 deletions
7
_NOTICE
7
_NOTICE
|
@ -3,11 +3,10 @@ please be mindful when sharing it.
|
||||||
however, feel free to copy any snippets of code you find useful.
|
however, feel free to copy any snippets of code you find useful.
|
||||||
|
|
||||||
TODOs: (that i can remember right now)
|
TODOs: (that i can remember right now)
|
||||||
- finish implementing backprop
|
- finish implementing ARS
|
||||||
- replace evolution strategy algorithm with
|
'-> running mean/std normalization of all inputs (aka ARS V2)
|
||||||
something that utilizes backprop like PPO
|
'-> normalize and/or embed sprite inputs
|
||||||
- settle on a network architecture
|
- settle on a network architecture
|
||||||
- normalize and/or embed sprite inputs
|
|
||||||
- fix lag-frames skipped-inputs bug
|
- fix lag-frames skipped-inputs bug
|
||||||
- detect frames when Mario is in a controllable state
|
- detect frames when Mario is in a controllable state
|
||||||
- fix offscreen sprites sometimes being visible to network
|
- fix offscreen sprites sometimes being visible to network
|
||||||
|
|
168
main.lua
168
main.lua
|
@ -27,23 +27,29 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
||||||
--randomseed(11)
|
--randomseed(11)
|
||||||
|
|
||||||
local playable_mode = false
|
local playable_mode = false
|
||||||
|
local start_big = true
|
||||||
|
local starting_lives = 0
|
||||||
--
|
--
|
||||||
|
local init_zeros = true -- instead of he_normal noise or whatever.
|
||||||
local frameskip = 4
|
local frameskip = 4
|
||||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||||
local det_epsilon = true -- take random actions with probability eps.
|
local det_epsilon = false -- take random actions with probability eps.
|
||||||
local eps_start = 1.0 * frameskip / 64
|
local eps_start = 1.0 * frameskip / 64
|
||||||
local eps_stop = 0.1 * eps_start
|
local eps_stop = 0.1 * eps_start
|
||||||
local eps_frames = 2000000
|
local eps_frames = 2000000
|
||||||
--
|
--
|
||||||
local epoch_trials = 40
|
local epoch_trials = 20
|
||||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
local epoch_top_trials = 10 -- new with ARS.
|
||||||
local unperturbed_trial = true -- do a trial without any noise.
|
local unperturbed_trial = true -- do a trial without any noise.
|
||||||
local learning_rate = 0.3 -- bigger now that i'm trying a different technique.
|
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||||
local deviation = 0.05
|
-- ^ note that this now doubles the effective trials.
|
||||||
|
local learning_rate = 0.01
|
||||||
|
local deviation = 0.06
|
||||||
--
|
--
|
||||||
local cap_time = 400
|
local cap_time = 100 --400
|
||||||
local timer_loser = 1/3
|
local timer_loser = 0 --1/3
|
||||||
|
local decrement_reward = true
|
||||||
--
|
--
|
||||||
local enable_overlay = playable_mode
|
local enable_overlay = playable_mode
|
||||||
local enable_network = not playable_mode
|
local enable_network = not playable_mode
|
||||||
|
@ -118,6 +124,7 @@ local jp_lut = {
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
local base_params
|
local base_params
|
||||||
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
||||||
|
local trial_neg = true
|
||||||
local trial_noise = {}
|
local trial_noise = {}
|
||||||
local trial_rewards = {}
|
local trial_rewards = {}
|
||||||
local trials_remaining = 0
|
local trials_remaining = 0
|
||||||
|
@ -141,7 +148,7 @@ local jp
|
||||||
|
|
||||||
local screen_scroll_delta
|
local screen_scroll_delta
|
||||||
local reward
|
local reward
|
||||||
local all_rewards = {}
|
--local all_rewards = {}
|
||||||
|
|
||||||
local powerup_old
|
local powerup_old
|
||||||
local status_old
|
local status_old
|
||||||
|
@ -280,19 +287,9 @@ local function make_network(input_size)
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
nn_ty:feed(nn_y)
|
nn_ty:feed(nn_y)
|
||||||
|
|
||||||
if true then
|
nn_y = nn_y:feed(nn.Dense(128))
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
--nn_y = nn_y:feed(nn.Gelu())
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
nn_y = nn_y:feed(nn.Dense(64))
|
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
|
||||||
nn_y = nn_y:feed(nn.Dense(48))
|
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
|
||||||
else
|
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
|
||||||
end
|
|
||||||
|
|
||||||
nn_z = nn_y
|
nn_z = nn_y
|
||||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||||
|
@ -563,17 +560,48 @@ local function prepare_epoch()
|
||||||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
if negate_trials and i % 2 == 0 then -- every other iteration...
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end
|
|
||||||
else
|
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
|
||||||
end
|
|
||||||
trial_noise[i] = noise
|
trial_noise[i] = noise
|
||||||
end
|
end
|
||||||
trial_i = -1
|
trial_i = -1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function load_next_pair()
|
||||||
|
trial_i = trial_i + 1
|
||||||
|
if trial_i == 0 and not unperturbed_trial then
|
||||||
|
trial_i = 1
|
||||||
|
trial_neg = true
|
||||||
|
end
|
||||||
|
|
||||||
|
local W = nn.copy(base_params)
|
||||||
|
|
||||||
|
if trial_i > 0 then
|
||||||
|
if trial_neg then
|
||||||
|
print('trial', trial_i, 'positive')
|
||||||
|
local noise = trial_noise[trial_i]
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
W[i] = v + deviation * noise[i]
|
||||||
|
end
|
||||||
|
|
||||||
|
else
|
||||||
|
trial_i = trial_i - 1
|
||||||
|
print('trial', trial_i, 'negative')
|
||||||
|
local noise = trial_noise[trial_i]
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
W[i] = v - deviation * noise[i]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
trial_neg = not trial_neg
|
||||||
|
else
|
||||||
|
print("test trial")
|
||||||
|
end
|
||||||
|
|
||||||
|
network:distribute(W)
|
||||||
|
end
|
||||||
|
|
||||||
local function load_next_trial()
|
local function load_next_trial()
|
||||||
|
if negate_trials then return load_next_pair() end
|
||||||
trial_i = trial_i + 1
|
trial_i = trial_i + 1
|
||||||
local W = nn.copy(base_params)
|
local W = nn.copy(base_params)
|
||||||
if trial_i == 0 and not unperturbed_trial then
|
if trial_i == 0 and not unperturbed_trial then
|
||||||
|
@ -628,11 +656,9 @@ end
|
||||||
|
|
||||||
local function learn_from_epoch()
|
local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
print('rewards:', trial_rewards)
|
--print('rewards:', trial_rewards)
|
||||||
|
|
||||||
for _, v in ipairs(trial_rewards) do
|
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
|
||||||
insert(all_rewards, v)
|
|
||||||
end
|
|
||||||
|
|
||||||
if unperturbed_trial then
|
if unperturbed_trial then
|
||||||
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
||||||
|
@ -642,14 +668,56 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local step = nn.zeros(#base_params)
|
local step = nn.zeros(#base_params)
|
||||||
local shaped_rewards = fitness_shaping(trial_rewards)
|
|
||||||
|
|
||||||
local altogether = learning_rate / (epoch_trials * deviation)
|
-- new stuff
|
||||||
|
|
||||||
|
local best_rewards
|
||||||
|
if negate_trials then
|
||||||
|
-- select one (the best) reward of each pos/neg pair.
|
||||||
|
best_rewards = {}
|
||||||
|
for i = 1, epoch_trials do
|
||||||
|
local ind = (i - 1) * 2 + 1
|
||||||
|
local pos = trial_rewards[ind + 0]
|
||||||
|
local neg = trial_rewards[ind + 1]
|
||||||
|
best_rewards[i] = max(pos, neg)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
best_rewards = nn.copy(trial_rewards)
|
||||||
|
end
|
||||||
|
|
||||||
|
local indices = {}
|
||||||
|
for i = 1, #best_rewards do indices[i] = i end
|
||||||
|
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||||
|
|
||||||
|
--print("indices:", indices)
|
||||||
|
for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||||
|
print("best trials:", indices)
|
||||||
|
|
||||||
|
local top_rewards = {}
|
||||||
|
for i = 1, #trial_rewards do top_rewards[i] = 0 end
|
||||||
|
for _, ind in ipairs(indices) do
|
||||||
|
local sind = (ind - 1) * 2 + 1
|
||||||
|
top_rewards[sind + 0] = trial_rewards[sind + 0]
|
||||||
|
top_rewards[sind + 1] = trial_rewards[sind + 1]
|
||||||
|
end
|
||||||
|
print("top:", top_rewards)
|
||||||
|
|
||||||
|
local reward_mean, reward_dev = calc_mean_dev(top_rewards)
|
||||||
|
--print("mean, dev:", reward_mean, reward_dev)
|
||||||
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
|
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||||
|
--print("scaled:", top_rewards)
|
||||||
|
|
||||||
|
-- NOTE: step no longer directly incorporates learning_rate.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local reward = shaped_rewards[i]
|
local ind = (i - 1) * 2 + 1
|
||||||
|
local pos = top_rewards[ind + 0]
|
||||||
|
local neg = top_rewards[ind + 1]
|
||||||
|
local reward = pos - neg
|
||||||
local noise = trial_noise[i]
|
local noise = trial_noise[i]
|
||||||
for j, v in ipairs(noise) do
|
for j, v in ipairs(noise) do
|
||||||
step[j] = step[j] + altogether * (reward * v)
|
step[j] = step[j] + (reward * v) / epoch_trials
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -658,7 +726,7 @@ local function learn_from_epoch()
|
||||||
print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
|
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v + step[i]
|
base_params[i] = v + learning_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
if enable_network then
|
if enable_network then
|
||||||
|
@ -677,9 +745,15 @@ local function do_reset()
|
||||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||||
print("reward:", reward, "("..state..")")
|
print("reward:", reward, "("..state..")")
|
||||||
|
|
||||||
if trial_i >= 0 then trial_rewards[trial_i] = reward end
|
if trial_i >= 0 then
|
||||||
|
if trial_i == 0 or not negate_trials then
|
||||||
|
trial_rewards[trial_i] = reward
|
||||||
|
else
|
||||||
|
trial_rewards[#trial_rewards + 1] = reward
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
if epoch_i == 0 or trial_i == epoch_trials then
|
if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
|
||||||
if epoch_i > 0 then learn_from_epoch() end
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
epoch_i = epoch_i + 1
|
epoch_i = epoch_i + 1
|
||||||
prepare_epoch()
|
prepare_epoch()
|
||||||
|
@ -692,10 +766,18 @@ local function do_reset()
|
||||||
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
||||||
score_old = get_score()
|
score_old = get_score()
|
||||||
|
|
||||||
W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
|
-- set number of lives. (mario gets n+1 chances)
|
||||||
|
W(0x75A, starting_lives)
|
||||||
|
|
||||||
|
if start_big then
|
||||||
|
-- make mario "super".
|
||||||
|
W(0x754, 0)
|
||||||
|
W(0x756, 1)
|
||||||
|
end
|
||||||
|
|
||||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||||
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||||
|
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||||
max_time = ceil(max_time)
|
max_time = ceil(max_time)
|
||||||
|
|
||||||
if once then
|
if once then
|
||||||
|
@ -720,6 +802,12 @@ local function init()
|
||||||
network:print()
|
network:print()
|
||||||
print("parameters:", network.n_param)
|
print("parameters:", network.n_param)
|
||||||
|
|
||||||
|
if init_zeros then
|
||||||
|
local W = network:collect()
|
||||||
|
for i, w in ipairs(W) do W[i] = 0 end
|
||||||
|
network:distribute(W)
|
||||||
|
end
|
||||||
|
|
||||||
emu.poweron()
|
emu.poweron()
|
||||||
emu.unpause()
|
emu.unpause()
|
||||||
emu.speedmode("turbo")
|
emu.speedmode("turbo")
|
||||||
|
@ -794,6 +882,8 @@ local function doit(dummy)
|
||||||
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
||||||
screen_scroll_delta = 0
|
screen_scroll_delta = 0
|
||||||
|
|
||||||
|
if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
|
||||||
|
|
||||||
if not ingame_paused then reward = reward + reward_delta end
|
if not ingame_paused then reward = reward + reward_delta end
|
||||||
|
|
||||||
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||||
|
|
Loading…
Reference in a new issue