Connor Olding 2017-07-06 03:26:27 +00:00
parent 3d7741aa6e
commit 693eeb991e
2 changed files with 154 additions and 35 deletions

main.lua

@ -24,15 +24,27 @@ end
--randomseed(11)
local enable_overlay = false
local enable_network = true
local playable_mode = false
--
local epoch_trials = 24
local learning_rate = 3.2e-3
local deviation = 1e-1 / epoch_trials
local deterministic = false -- use argmax on outputs instead of random sampling.
local det_epsilon = true -- take random actions with probability eps.
local eps_start = 0.50
local eps_stop = 0.05
local eps_frames = 60*60*60
local consider_past_rewards = false
local learn_start_select = false
--
local epoch_trials = 40 -- 24
local learning_rate = 1e-3
local deviation = 1e-2 -- 4e-3
--
local timer_loser = 1/3
local cap_time = 400
local timer_loser = 1/3
--
local enable_overlay = playable_mode
local enable_network = not playable_mode
local input_size = 281 -- TODO: let the script figure this out for us.
local epoch_i = 0
local base_params
@ -41,6 +53,9 @@ local trial_noise = {}
local trial_rewards = {}
local trials_remaining = 0
local trial_frames = 0
local total_frames = 0
local force_start = false
local force_start_old = false
@ -51,12 +66,15 @@ local max_time
local sprite_input = {}
local tile_input = {}
local extra_input = {}
local reward
local all_rewards = {}
local powerup_old
local status_old
local coins_old
local score_old
local once = false
local reset = true
@ -125,6 +143,8 @@ end
local function clamp(x, l, u) return min(max(x, l), u) end
local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
local function argmax(...)
local max_i = 0
local max_v = -999999999
@ -142,6 +162,14 @@ local function argmax2(t)
return t[1] > t[2]
end
local function rchoice2(t)
return t[1] > random()
end
local function rbool()
return 0.5 >= random()
end
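-- a quick note on the sampling helpers (example values are hypothetical):
-- argmax2 always takes the likelier option, rchoice2 samples it, since
-- math.random() is uniform on [0, 1):
--   local p = {0.7, 0.3}
--   argmax2(p)  --> always true (press the button)
--   rchoice2(p) --> true with probability 0.7
--   rbool()     --> a fair coin flip, used for the epsilon exploration below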
local function empty(t)
for k, _ in pairs(t) do t[k] = nil end
return t
@ -171,6 +199,15 @@ local function normalize(x, out)
return out
end
local function normalize_wrt(x, s, out)
out = out or x
local mean, dev = calc_mean_dev(s)
if dev <= 0 then dev = 1 end
local devs = sqrt(dev)
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
return out
end
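-- intent of normalize_wrt, assuming calc_mean_dev returns a mean and a
-- spread statistic: standardize x using the statistics of a reference
-- sample s (here, every reward ever seen) rather than x itself, e.g.
--   normalize_wrt(trial_rewards, all_rewards)
-- if calc_mean_dev returns a variance, sqrt(dev) is just the standard
-- deviation; if it already returns the deviation, the sqrt deliberately
-- softens the rescaling.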
-- game-agnostic stuff (i.e. the network itself)
package.loaded['nn'] = nil -- DEBUG
@ -188,10 +225,12 @@ local function make_network(input_size, buttons)
nn_y = nn_y:feed(nn.Dense(input_size))
nn_y = nn_y:feed(nn.Gelu())
else
nn_y = nn_y:feed(nn.Dense(114))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(57))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(64))
nn_y = nn_y:feed(nn.Gelu())
nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Gelu())
end
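-- the hidden stack grows from two ReLU layers (114, 57) to three GELU
-- layers (128, 64, 48). for reference, the common tanh approximation of
-- GELU (whether nn.lua uses exactly this form is an assumption):
--   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))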
for i = 1, buttons do
nn_z[i] = nn_y
@ -247,6 +286,14 @@ local function get_timer()
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
end
local function get_score()
return R(0x7DE) * 10000 +
R(0x7DF) * 1000 +
R(0x7E0) * 100 +
R(0x7E1) * 10 +
R(0x7E2)
end
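-- SMB keeps the score as one decimal digit per byte, so the weighted sum
-- above just reassembles the integer: bytes {0, 1, 2, 5, 0} at
-- 0x7DE..0x7E2 decode to 0*10000 + 1*1000 + 2*100 + 5*10 + 0 = 1250.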
local function set_timer(time)
W(0x7F8, floor(time / 100))
W(0x7F9, floor((time / 10) % 10))
@ -460,7 +507,16 @@ end
local function learn_from_epoch()
print()
print('rewards:', trial_rewards)
normalize(trial_rewards)
for _, v in ipairs(trial_rewards) do
insert(all_rewards, v)
end
if consider_past_rewards then
normalize_wrt(trial_rewards, all_rewards)
else
normalize(trial_rewards)
end
--print('normalized:', trial_rewards)
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
@ -501,9 +557,12 @@ local function learn_from_epoch()
base_params[i] = v + step[i]
end
network:distribute(base_params)
network:save()
if enable_network then
network:distribute(base_params)
network:save()
else
print("note: not updating weights in playable mode.")
end
print()
end
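-- this hunk only shows the tail of learn_from_epoch, but the update has the
-- shape of an evolution-strategies step. a minimal sketch of how step is
-- presumably accumulated (the actual code lies outside this hunk):
--   for i = 1, #base_params do
--       local acc = 0
--       for j = 1, epoch_trials do
--           acc = acc + trial_rewards[j] * trial_noise[j][i]
--       end
--       step[i] = learning_rate * acc / (epoch_trials * deviation)
--   end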
@ -534,7 +593,7 @@ local function load_next_trial()
end
local function do_reset()
print("resetting in state:", get_state())
print("resetting in state: "..get_state()..". reward:", reward)
if trial_i > 0 then trial_rewards[trial_i] = reward end
@ -550,13 +609,15 @@ local function do_reset()
powerup_old = R(0x754)
status_old = R(0x756)
coins_old = R(0x7ED) * 10 + R(0x7EE)
score_old = get_score()
-- set lives to 0. you only got one shot!
-- unless you get a 1-up, in which case, please continue!
W(0x75A, 0)
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
max_time = ceil(max_time)
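-- the retuned curriculum: 360 / epoch_trials reduces to the old constant 15
-- when epoch_trials = 24, so the time limit now tracks total trials rather
-- than epochs. worked example with epoch_trials = 40, epoch_i = 41:
--   max_time = ceil(8 * sqrt(9 * 40) + 100) = ceil(251.79) = 252
-- growing as the square root of epochs until it saturates at cap_time = 400.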
if once then
savestate.load(startsave)
@ -575,7 +636,7 @@ local function do_reset()
end
local function init()
network = make_network(279, 8)
network = make_network(input_size, learn_start_select and 8 or 6)
network:reset()
print("parameters:", network.n_param)
@ -583,13 +644,14 @@ local function init()
emu.unpause()
emu.speedmode("turbo")
network:load()
local res, err = pcall(network.load, network)
if res == false then print(err) end
end
init()
local dummy_softmax_values = {0, 0}
while true do
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
@ -638,6 +700,7 @@ while true do
empty(sprite_input)
empty(tile_input)
empty(extra_input)
-- player
-- TODO: check if mario is playable.
@ -646,6 +709,10 @@ while true do
local status = R(0x756)
mark_sprite(x + 8, y + 24, -powerup - 1)
local vx, vy = S(0x57), S(0x9F)
insert(extra_input, vx)
insert(extra_input, vy)
handle_enemies()
handle_fireballs()
-- blocks being hit. not interactable; we don't care!
@ -666,7 +733,10 @@ while true do
local status_delta = clamp(status - status_old, -1, 1)
local screen_scroll_delta = R(0x775)
local flagpole_bonus = R(0xE) == 4 and 1 or 0
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
local score_delta = get_score() - score_old
if score_delta < 0 then score_delta = 0 end
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
@ -674,8 +744,9 @@ while true do
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
--gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
if get_state() == 'dead' and state_old ~= 'dead' then
--print("dead. lives remaining:", R(0x75A, 0))
@ -698,7 +769,9 @@ while true do
timer = timer - 1
end
timer = clamp(timer, 0, max_time)
set_timer(timer)
if enable_network then
set_timer(timer)
end
-- if we've run out of time while the game is paused...
-- that's cheating! unpause.
@ -707,10 +780,19 @@ while true do
local X = {} -- TODO: cache.
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
for i, v in ipairs(tile_input) do insert(X, v / 256) end
--error(#X)
for i, v in ipairs(extra_input) do insert(X, v / 256) end
if #X ~= input_size then error("input_size should be "..tostring(#X)..", not "..tostring(input_size)) end
if enable_network and (get_state() == 'playing' or ingame_paused) then
local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward(X)
-- TODO: predict the *rewards* of all possible actions?
-- that's how DQN seems to work anyway.
-- ah, but A3C just returns probabilities,
-- besides the critic?
local softmaxed = {
outputs[nn_z[1]],
outputs[nn_z[2]],
@ -718,19 +800,29 @@ while true do
outputs[nn_z[4]],
outputs[nn_z[5]],
outputs[nn_z[6]],
outputs[nn_z[7]],
outputs[nn_z[8]],
learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
}
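-- with learn_start_select = false, slots 7 and 8 fall back to
-- dummy_softmax_values = {0, 0}: argmax2({0, 0}) is 0 > 0 and
-- rchoice2({0, 0}) is 0 > random(), both always false, so start and
-- select simply never get pressed.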
local jp = {
up = argmax2(softmaxed[1]),
down = argmax2(softmaxed[2]),
left = argmax2(softmaxed[3]),
right = argmax2(softmaxed[4]),
A = argmax2(softmaxed[5]),
B = argmax2(softmaxed[6]),
start = argmax2(softmaxed[7]),
select = argmax2(softmaxed[8]),
up = choose(softmaxed[1]),
down = choose(softmaxed[2]),
left = choose(softmaxed[3]),
right = choose(softmaxed[4]),
A = choose(softmaxed[5]),
B = choose(softmaxed[6]),
start = choose(softmaxed[7]),
select = choose(softmaxed[8]),
}
if det_epsilon then
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
for k, v in pairs(jp) do
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
if random() < eps and ss_ok then jp[k] = rbool() end
end
end
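-- worked example of the epsilon schedule: eps_frames = 60*60*60 = 216000
-- frames, i.e. one hour at 60 fps. halfway in, at total_frames = 108000,
-- eps = lerp(0.50, 0.05, 0.5) = 0.275, so each eligible button is
-- re-rolled about 27.5% of the time; since lerp clamps t to [0, 1],
-- eps never decays below eps_stop = 0.05.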
if force_start then
jp = {
up = false,
@ -743,6 +835,7 @@ while true do
select = false,
}
end
joypad.write(1, jp)
end
@ -751,5 +844,6 @@ while true do
status_old = status
force_start_old = force_start
state_old = get_state()
score_old = get_score()
advance()
end

nn.lua

@ -173,6 +173,10 @@ function Layer:forward_deterministic(...)
return self:forward(...)
end
function Layer:backward()
error("Unimplemented.")
end
function Layer:_new_weights(init)
local w = Weights(init)
insert(self.weights, w)
@ -227,6 +231,11 @@ function Input:forward(X)
return X
end
function Input:backward(dY)
assert(#dY == self.size_out)
return zeros(#dY)
end
function Relu:init()
Layer.init(self, "Relu")
end
@ -242,6 +251,18 @@ function Relu:forward(X)
return Y
end
function Relu:backward(dY)
assert(#dY == self.size_out)
self.dcache = self.dcache or zeros(self.size_in)
local Y = self.cache
local dX = self.dcache
-- pass gradient only where the unit fired; relu output is never negative,
-- so the gate must be strictly positive.
for i = 1, #dY do dX[i] = Y[i] > 0 and dY[i] or 0 end
assert(#dX == self.size_in)
return dX
end
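-- quick numeric check of the gradient gate (hypothetical values): with
-- cached activations Y = {1.5, 0, 2} and upstream gradient dY = {1, 1, 1},
-- backward yields dX = {1, 0, 1}: gradient flows only through units that
-- were active in the forward pass.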
function Gelu:init()
Layer.init(self, "Gelu")
end
@ -383,8 +404,12 @@ function Model:distribute(W)
end
end
function Model:default_filename()
return ('network%07i.txt'):format(self.n_param)
end
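-- e.g. a model with n_param = 12345 saves to 'network0012345.txt'. keying
-- the filename to the parameter count means a resized network won't
-- silently load or clobber weights from an incompatible architecture.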
function Model:save(fn)
local fn = fn or 'network.txt'
local fn = fn or self:default_filename()
local f = open(fn, 'w')
if f == nil then error("Failed to save network to file "..fn) end
local W = self:collect()
@ -396,7 +421,7 @@ function Model:save(fn)
end
function Model:load(fn)
local fn = fn or 'network.txt'
local fn = fn or self:default_filename()
local f = open(fn, 'r')
if f == nil then
error("Failed to load network from file "..fn)