experimental ARS stuff
This commit is contained in:
parent
4765104c7a
commit
5b98023073
2 changed files with 132 additions and 43 deletions
7
_NOTICE
7
_NOTICE
|
@ -3,11 +3,10 @@ please be mindful when sharing it.
|
|||
however, feel free to copy any snippets of code you find useful.
|
||||
|
||||
TODOs: (that i can remember right now)
|
||||
- finish implementing backprop
|
||||
- replace evolution strategy algorithm with
|
||||
something that utilizes backprop like PPO
|
||||
- finish implementing ARS
|
||||
'-> running mean/std normalization of all inputs (aka ARS V2)
|
||||
'-> normalize and/or embed sprite inputs
|
||||
- settle on a network architecture
|
||||
- normalize and/or embed sprite inputs
|
||||
- fix lag-frames skipped-inputs bug
|
||||
- detect frames when Mario is in a controllable state
|
||||
- fix offscreen sprites sometimes being visible to network
|
||||
|
|
162
main.lua
162
main.lua
|
@ -27,23 +27,29 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
|||
--randomseed(11)
|
||||
|
||||
local playable_mode = false
|
||||
local start_big = true
|
||||
local starting_lives = 0
|
||||
--
|
||||
local init_zeros = true -- instead of he_normal noise or whatever.
|
||||
local frameskip = 4
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = true -- take random actions with probability eps.
|
||||
local det_epsilon = false -- take random actions with probability eps.
|
||||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 2000000
|
||||
--
|
||||
local epoch_trials = 40
|
||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||
local epoch_trials = 20
|
||||
local epoch_top_trials = 10 -- new with ARS.
|
||||
local unperturbed_trial = true -- do a trial without any noise.
|
||||
local learning_rate = 0.3 -- bigger now that i'm trying a different technique.
|
||||
local deviation = 0.05
|
||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||
-- ^ note that this now doubles the effective trials.
|
||||
local learning_rate = 0.01
|
||||
local deviation = 0.06
|
||||
--
|
||||
local cap_time = 400
|
||||
local timer_loser = 1/3
|
||||
local cap_time = 100 --400
|
||||
local timer_loser = 0 --1/3
|
||||
local decrement_reward = true
|
||||
--
|
||||
local enable_overlay = playable_mode
|
||||
local enable_network = not playable_mode
|
||||
|
@ -118,6 +124,7 @@ local jp_lut = {
|
|||
local epoch_i = 0
|
||||
local base_params
|
||||
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
||||
local trial_neg = true
|
||||
local trial_noise = {}
|
||||
local trial_rewards = {}
|
||||
local trials_remaining = 0
|
||||
|
@ -141,7 +148,7 @@ local jp
|
|||
|
||||
local screen_scroll_delta
|
||||
local reward
|
||||
local all_rewards = {}
|
||||
--local all_rewards = {}
|
||||
|
||||
local powerup_old
|
||||
local status_old
|
||||
|
@ -280,19 +287,9 @@ local function make_network(input_size)
|
|||
nn_x:feed(nn_y)
|
||||
nn_ty:feed(nn_y)
|
||||
|
||||
if true then
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
--nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(64))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(48))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
else
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
end
|
||||
|
||||
nn_z = nn_y
|
||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||
|
@ -563,17 +560,48 @@ local function prepare_epoch()
|
|||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||
for i = 1, epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
if negate_trials and i % 2 == 0 then -- every other iteration...
|
||||
for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end
|
||||
else
|
||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||
end
|
||||
trial_noise[i] = noise
|
||||
end
|
||||
trial_i = -1
|
||||
end
|
||||
|
||||
local function load_next_pair()
|
||||
trial_i = trial_i + 1
|
||||
if trial_i == 0 and not unperturbed_trial then
|
||||
trial_i = 1
|
||||
trial_neg = true
|
||||
end
|
||||
|
||||
local W = nn.copy(base_params)
|
||||
|
||||
if trial_i > 0 then
|
||||
if trial_neg then
|
||||
print('trial', trial_i, 'positive')
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + deviation * noise[i]
|
||||
end
|
||||
|
||||
else
|
||||
trial_i = trial_i - 1
|
||||
print('trial', trial_i, 'negative')
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v - deviation * noise[i]
|
||||
end
|
||||
end
|
||||
|
||||
trial_neg = not trial_neg
|
||||
else
|
||||
print("test trial")
|
||||
end
|
||||
|
||||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function load_next_trial()
|
||||
if negate_trials then return load_next_pair() end
|
||||
trial_i = trial_i + 1
|
||||
local W = nn.copy(base_params)
|
||||
if trial_i == 0 and not unperturbed_trial then
|
||||
|
@ -628,11 +656,9 @@ end
|
|||
|
||||
local function learn_from_epoch()
|
||||
print()
|
||||
print('rewards:', trial_rewards)
|
||||
--print('rewards:', trial_rewards)
|
||||
|
||||
for _, v in ipairs(trial_rewards) do
|
||||
insert(all_rewards, v)
|
||||
end
|
||||
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
|
||||
|
||||
if unperturbed_trial then
|
||||
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
||||
|
@ -642,14 +668,56 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step = nn.zeros(#base_params)
|
||||
local shaped_rewards = fitness_shaping(trial_rewards)
|
||||
|
||||
local altogether = learning_rate / (epoch_trials * deviation)
|
||||
-- new stuff
|
||||
|
||||
local best_rewards
|
||||
if negate_trials then
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
best_rewards = {}
|
||||
for i = 1, epoch_trials do
|
||||
local reward = shaped_rewards[i]
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
else
|
||||
best_rewards = nn.copy(trial_rewards)
|
||||
end
|
||||
|
||||
local indices = {}
|
||||
for i = 1, #best_rewards do indices[i] = i end
|
||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||
|
||||
--print("indices:", indices)
|
||||
for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
print("best trials:", indices)
|
||||
|
||||
local top_rewards = {}
|
||||
for i = 1, #trial_rewards do top_rewards[i] = 0 end
|
||||
for _, ind in ipairs(indices) do
|
||||
local sind = (ind - 1) * 2 + 1
|
||||
top_rewards[sind + 0] = trial_rewards[sind + 0]
|
||||
top_rewards[sind + 1] = trial_rewards[sind + 1]
|
||||
end
|
||||
print("top:", top_rewards)
|
||||
|
||||
local reward_mean, reward_dev = calc_mean_dev(top_rewards)
|
||||
--print("mean, dev:", reward_mean, reward_dev)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||
--print("scaled:", top_rewards)
|
||||
|
||||
-- NOTE: step no longer directly incorporates learning_rate.
|
||||
for i = 1, epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = top_rewards[ind + 0]
|
||||
local neg = top_rewards[ind + 1]
|
||||
local reward = pos - neg
|
||||
local noise = trial_noise[i]
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + altogether * (reward * v)
|
||||
step[j] = step[j] + (reward * v) / epoch_trials
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -658,7 +726,7 @@ local function learn_from_epoch()
|
|||
print("step stddev:", step_dev)
|
||||
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + step[i]
|
||||
base_params[i] = v + learning_rate * step[i]
|
||||
end
|
||||
|
||||
if enable_network then
|
||||
|
@ -677,9 +745,15 @@ local function do_reset()
|
|||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||
print("reward:", reward, "("..state..")")
|
||||
|
||||
if trial_i >= 0 then trial_rewards[trial_i] = reward end
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 or not negate_trials then
|
||||
trial_rewards[trial_i] = reward
|
||||
else
|
||||
trial_rewards[#trial_rewards + 1] = reward
|
||||
end
|
||||
end
|
||||
|
||||
if epoch_i == 0 or trial_i == epoch_trials then
|
||||
if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
|
||||
if epoch_i > 0 then learn_from_epoch() end
|
||||
epoch_i = epoch_i + 1
|
||||
prepare_epoch()
|
||||
|
@ -692,10 +766,18 @@ local function do_reset()
|
|||
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
||||
score_old = get_score()
|
||||
|
||||
W(0x75A, 1) -- set number of lives. (mario gets n+1 chances)
|
||||
-- set number of lives. (mario gets n+1 chances)
|
||||
W(0x75A, starting_lives)
|
||||
|
||||
if start_big then
|
||||
-- make mario "super".
|
||||
W(0x754, 0)
|
||||
W(0x756, 1)
|
||||
end
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||
--max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||
max_time = ceil(max_time)
|
||||
|
||||
if once then
|
||||
|
@ -720,6 +802,12 @@ local function init()
|
|||
network:print()
|
||||
print("parameters:", network.n_param)
|
||||
|
||||
if init_zeros then
|
||||
local W = network:collect()
|
||||
for i, w in ipairs(W) do W[i] = 0 end
|
||||
network:distribute(W)
|
||||
end
|
||||
|
||||
emu.poweron()
|
||||
emu.unpause()
|
||||
emu.speedmode("turbo")
|
||||
|
@ -794,6 +882,8 @@ local function doit(dummy)
|
|||
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
||||
screen_scroll_delta = 0
|
||||
|
||||
if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
|
||||
|
||||
if not ingame_paused then reward = reward + reward_delta end
|
||||
|
||||
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||
|
|
Loading…
Reference in a new issue