various

commit 693eeb991e (parent 3d7741aa6e)
2 changed files with 154 additions and 35 deletions

 main.lua | 160
 nn.lua   |  29

diff --git a/main.lua b/main.lua
--- a/main.lua
+++ b/main.lua
@@ -24,15 +24,27 @@ end
 
 --randomseed(11)
 
-local enable_overlay = false
-local enable_network = true
+local playable_mode = false
 --
-local epoch_trials = 24
-local learning_rate = 3.2e-3
-local deviation = 1e-1 / epoch_trials
+local deterministic = false -- use argmax on outputs instead of random sampling.
+local det_epsilon = true -- take random actions with probability eps.
+local eps_start = 0.50
+local eps_stop = 0.05
+local eps_frames = 60*60*60
+local consider_past_rewards = false
+local learn_start_select = false
+--
+local epoch_trials = 40 -- 24
+local learning_rate = 1e-3
+local deviation = 1e-2 -- 4e-3
 --
-local timer_loser = 1/3
 local cap_time = 400
+local timer_loser = 1/3
+--
+local enable_overlay = playable_mode
+local enable_network = not playable_mode
 
+local input_size = 281 -- TODO: let the script figure this out for us.
+
 local epoch_i = 0
 local base_params
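
The block of new locals above sets up epsilon-greedy exploration on top of the existing evolution-strategies loop: with det_epsilon enabled, the exploration rate anneals linearly from eps_start to eps_stop over eps_frames frames (60*60*60 is an hour of emulated play at 60 fps). A standalone sketch of the schedule, with illustrative frame counts:

    local function lerp(a, b, t)
        if t < 0 then t = 0 elseif t > 1 then t = 1 end
        return a + (b - a) * t
    end

    local eps_start, eps_stop, eps_frames = 0.50, 0.05, 60*60*60
    for _, frame in ipairs({0, 108000, 216000}) do
        print(frame, lerp(eps_start, eps_stop, frame / eps_frames))
    end
    --> 0        0.5
    --> 108000   0.275
    --> 216000   0.05
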
@@ -41,6 +53,9 @@ local trial_noise = {}
 local trial_rewards = {}
 local trials_remaining = 0
 
+local trial_frames = 0
+local total_frames = 0
+
 local force_start = false
 local force_start_old = false
 
@@ -51,12 +66,15 @@ local max_time
 
 local sprite_input = {}
 local tile_input = {}
+local extra_input = {}
 
 local reward
+local all_rewards = {}
 
 local powerup_old
 local status_old
 local coins_old
+local score_old
 
 local once = false
 local reset = true
@@ -125,6 +143,8 @@ end
 
 local function clamp(x, l, u) return min(max(x, l), u) end
+
+local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
 
 local function argmax(...)
     local max_i = 0
     local max_v = -999999999
@@ -142,6 +162,14 @@ local function argmax2(t)
     return t[1] > t[2]
 end
 
+local function rchoice2(t)
+    return t[1] > random()
+end
+
+local function rbool(t)
+    return 0.5 >= random()
+end
+
 local function empty(t)
     for k, _ in pairs(t) do t[k] = nil end
     return t
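
argmax2 keeps the old deterministic rule (press the button iff its "on" output beats its "off" output); the new rchoice2 instead treats the softmaxed pair as a Bernoulli distribution and samples from it, and rbool is the fair coin used by the epsilon-greedy override further down. A quick standalone sketch (the 0.8/0.2 pair is made up):

    math.randomseed(os.time())
    local random = math.random
    local function rchoice2(t) return t[1] > random() end
    print(rchoice2({0.8, 0.2})) -- true roughly 80% of the time
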
@@ -171,6 +199,15 @@ local function normalize(x, out)
     return out
 end
 
+local function normalize_wrt(x, s, out)
+    out = out or x
+    local mean, dev = calc_mean_dev(s)
+    if dev <= 0 then dev = 1 end
+    local devs = sqrt(dev)
+    for i, v in ipairs(x) do out[i] = (v - mean) / devs end
+    return out
+end
+
 -- game-agnostic stuff (i.e. the network itself)
 
 package.loaded['nn'] = nil -- DEBUG
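
normalize_wrt standardizes x using statistics taken from a separate sample s, so an epoch's rewards can be centered against the whole reward history instead of only against each other; an epoch where every trial did well then still yields a positive learning signal. The extra sqrt implies calc_mean_dev is treated as returning a variance here; if it actually returns a standard deviation, this would mis-scale the output. A self-contained sketch assuming mean/variance statistics:

    local function normalize_wrt(x, s)
        local mean, var = 0, 0
        for _, v in ipairs(s) do mean = mean + v end
        mean = mean / #s
        for _, v in ipairs(s) do var = var + (v - mean)^2 end
        var = var / #s
        if var <= 0 then var = 1 end
        local dev = math.sqrt(var)
        local out = {}
        for i, v in ipairs(x) do out[i] = (v - mean) / dev end
        return out
    end

    -- a good epoch standardized against a longer history: all positive.
    print(table.concat(normalize_wrt({9, 10, 11}, {2, 4, 6, 8, 10}), "  "))
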
@@ -188,10 +225,12 @@ local function make_network(input_size, buttons)
         nn_y = nn_y:feed(nn.Dense(input_size))
         nn_y = nn_y:feed(nn.Gelu())
     else
-        nn_y = nn_y:feed(nn.Dense(114))
-        nn_y = nn_y:feed(nn.Relu())
-        nn_y = nn_y:feed(nn.Dense(57))
-        nn_y = nn_y:feed(nn.Relu())
+        nn_y = nn_y:feed(nn.Dense(128))
+        nn_y = nn_y:feed(nn.Gelu())
+        nn_y = nn_y:feed(nn.Dense(64))
+        nn_y = nn_y:feed(nn.Gelu())
+        nn_y = nn_y:feed(nn.Dense(48))
+        nn_y = nn_y:feed(nn.Gelu())
     end
     for i = 1, buttons do
         nn_z[i] = nn_y
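
The hidden stack grows from two Relu layers (114 and 57 units) to three Gelu layers (128, 64, 48). For reference, the common tanh approximation of GELU; nn.lua's own implementation may differ:

    local function gelu(x)
        return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi)
                                        * (x + 0.044715 * x^3)))
    end
    print(gelu(1.0)) --> ~0.8412
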
@@ -247,6 +286,14 @@ local function get_timer()
     return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
 end
 
+local function get_score()
+    return R(0x7DE) * 10000 +
+           R(0x7DF) * 1000 +
+           R(0x7E0) * 100 +
+           R(0x7E1) * 10 +
+           R(0x7E2)
+end
+
 local function set_timer(time)
     W(0x7F8, floor(time / 100))
     W(0x7F9, floor((time / 10) % 10))
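
Like the timer, SMB keeps the score as one decimal digit per byte, so get_score recombines the bytes with powers of ten. The on-screen score carries a sixth, always-zero digit that isn't read here, so this value appears to be the displayed score divided by ten; since it only ever feeds deltas, the constant factor is harmless. A decoding sketch with made-up digit bytes:

    local digits  = {0, 1, 2, 3, 4}          -- hypothetical R(0x7DE)..R(0x7E2)
    local weights = {10000, 1000, 100, 10, 1}
    local score = 0
    for i = 1, 5 do score = score + digits[i] * weights[i] end
    print(score) --> 1234 (shown in-game as 012340)
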
@@ -460,7 +507,16 @@ end
 local function learn_from_epoch()
     print()
     print('rewards:', trial_rewards)
-    normalize(trial_rewards)
+
+    for _, v in ipairs(trial_rewards) do
+        insert(all_rewards, v)
+    end
+
+    if consider_past_rewards then
+        normalize_wrt(trial_rewards, all_rewards)
+    else
+        normalize(trial_rewards)
+    end
     --print('normalized:', trial_rewards)
 
     local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
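
Every trial's reward is appended to all_rewards before normalizing, so the history accumulates either way; consider_past_rewards only changes which sample the standardization is computed against. This is the fitness-shaping step of the evolution-strategies update that follows, roughly in the spirit of OpenAI-style ES (a sketch of the idea, not the script's exact code):

    -- step[i] ~= learning_rate / (epoch_trials * deviation)
    --           * sum over trials t of (shaped_reward[t] * noise[t][i])

Because the rewards are standardized first, rescaling the raw reward (as the score-based delta below does) leaves the update direction unchanged.
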
@@ -501,9 +557,12 @@ local function learn_from_epoch()
         base_params[i] = v + step[i]
     end
 
-    network:distribute(base_params)
-    network:save()
+    if enable_network then
+        network:distribute(base_params)
+        network:save()
+    else
+        print("note: not updating weights in playable mode.")
+    end
 
     print()
 end
@@ -534,7 +593,7 @@ local function load_next_trial()
 end
 
 local function do_reset()
-    print("resetting in state:", get_state())
+    print("resetting in state: "..get_state()..". reward:", reward)
 
     if trial_i > 0 then trial_rewards[trial_i] = reward end
 
@@ -550,13 +609,15 @@ local function do_reset()
     powerup_old = R(0x754)
     status_old = R(0x756)
     coins_old = R(0x7ED) * 10 + R(0x7EE)
+    score_old = get_score()
 
     -- set lives to 0. you only got one shot!
     -- unless you get a 1-up, in which case, please continue!
     W(0x75A, 0)
 
     --max_time = min(log(epoch_i) * 10 + 100, cap_time)
-    max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
+    max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
+    max_time = ceil(max_time)
 
     if once then
         savestate.load(startsave)
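
The hard-coded 15 in the time-limit curriculum was 360/24, i.e. tuned for the old epoch_trials = 24; deriving the factor as 360/epoch_trials means the constant no longer needs retuning when the number of trials per epoch changes (with 40 trials it comes to 9). A sketch of the resulting schedule:

    local function max_time_at(epoch_i, epoch_trials, cap_time)
        local t = 8 * math.sqrt(360 / epoch_trials * (epoch_i - 1)) + 100
        return math.ceil(math.min(t, cap_time))
    end
    print(max_time_at(1,   40, 400)) --> 100
    print(max_time_at(101, 40, 400)) --> 340
    print(max_time_at(401, 40, 400)) --> 400 (capped)
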
@@ -575,7 +636,7 @@ local function do_reset()
 end
 
 local function init()
-    network = make_network(279, 8)
+    network = make_network(input_size, learn_start_select and 8 or 6)
     network:reset()
     print("parameters:", network.n_param)
 
@@ -583,13 +644,14 @@ local function init()
     emu.unpause()
     emu.speedmode("turbo")
 
-    network:load()
     local res, err = pcall(network.load, network)
     if res == false then print(err) end
 end
 
 init()
 
+local dummy_softmax_values = {0, 0}
+
 while true do
     gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
 
@@ -638,6 +700,7 @@ while true do
 
     empty(sprite_input)
     empty(tile_input)
+    empty(extra_input)
 
     -- player
     -- TODO: check if mario is playable.
@@ -646,6 +709,10 @@ while true do
     local status = R(0x756)
     mark_sprite(x + 8, y + 24, -powerup - 1)
 
+    local vx, vy = S(0x57), S(0x9F)
+    insert(extra_input, vx)
+    insert(extra_input, vy)
+
     handle_enemies()
     handle_fireballs()
     -- blocks being hit. not interactable; we don't care!
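
0x57 and 0x9F hold the player's horizontal and vertical velocity in the standard SMB RAM map; S() is presumably the signed-byte counterpart of R(), since velocities can be negative. These two values make up the new extra_input block that grows the input vector from 279 to 281.
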
@@ -666,7 +733,10 @@ while true do
     local status_delta = clamp(status - status_old, -1, 1)
     local screen_scroll_delta = R(0x775)
     local flagpole_bonus = R(0xE) == 4 and 1 or 0
-    local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
+    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
+    local score_delta = get_score() - score_old
+    if score_delta < 0 then score_delta = 0 end
+    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
 
     -- TODO: add ingame score to reward.
 
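
The reward now uses the in-game score delta in place of the powerup/status bonus, clamped below at zero so a score reset between trials can't inject a large negative spike; screen scrolling still supplies the dense per-frame progress signal.
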
@@ -674,8 +744,9 @@ while true do
 
     --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
     --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
-    gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
-    gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
+    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
+    --gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
+    gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
 
     if get_state() == 'dead' and state_old ~= 'dead' then
         --print("dead. lives remaining:", R(0x75A, 0))
@@ -698,7 +769,9 @@ while true do
         timer = timer - 1
     end
     timer = clamp(timer, 0, max_time)
-    set_timer(timer)
+    if enable_network then
+        set_timer(timer)
+    end
 
     -- if we've run out of time while the game is paused...
     -- that's cheating! unpause.
@@ -707,10 +780,19 @@ while true do
     local X = {} -- TODO: cache.
     for i, v in ipairs(sprite_input) do insert(X, v / 256) end
     for i, v in ipairs(tile_input) do insert(X, v / 256) end
-    --error(#X)
+    for i, v in ipairs(extra_input) do insert(X, v / 256) end
+    if #X ~= input_size then error("input size should be: "..tostring(#X)) end
 
     if enable_network and get_state() == 'playing' or ingame_paused then
+        local choose = deterministic and argmax2 or rchoice2
+
         local outputs = network:forward(X)
+
+        -- TODO: predict the *rewards* of all possible actions?
+        -- that's how DQN seems to work anyway.
+        -- ah, but A3C just returns probabilities,
+        -- besides the critic?
+
         local softmaxed = {
             outputs[nn_z[1]],
             outputs[nn_z[2]],
@@ -718,19 +800,29 @@ while true do
             outputs[nn_z[4]],
             outputs[nn_z[5]],
             outputs[nn_z[6]],
-            outputs[nn_z[7]],
-            outputs[nn_z[8]],
+            learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
+            learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
         }
 
         local jp = {
-            up = argmax2(softmaxed[1]),
-            down = argmax2(softmaxed[2]),
-            left = argmax2(softmaxed[3]),
-            right = argmax2(softmaxed[4]),
-            A = argmax2(softmaxed[5]),
-            B = argmax2(softmaxed[6]),
-            start = argmax2(softmaxed[7]),
-            select = argmax2(softmaxed[8]),
+            up = choose(softmaxed[1]),
+            down = choose(softmaxed[2]),
+            left = choose(softmaxed[3]),
+            right = choose(softmaxed[4]),
+            A = choose(softmaxed[5]),
+            B = choose(softmaxed[6]),
+            start = choose(softmaxed[7]),
+            select = choose(softmaxed[8]),
         }
 
+        if det_epsilon then
+            local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
+            for k, v in pairs(jp) do
+                local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
+                if random() < eps and ss_ok then jp[k] = rbool() end
+            end
+        end
+
         if force_start then
             jp = {
                 up = false,
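
In the new det_epsilon block, each button is independently re-rolled to a coin flip with probability eps. Lua's 'and' binds tighter than 'or', so ss_ok reads as "(k is neither start nor select) or learn_start_select": the menu buttons are exempt from random exploration unless they are being learned. A precedence check:

    local k, learn_start_select = 'start', false
    print(k ~= 'start' and k ~= 'select' or learn_start_select) --> false
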
@@ -743,6 +835,7 @@ while true do
                 select = false,
             }
         end
 
         joypad.write(1, jp)
     end
+
@@ -751,5 +844,6 @@ while true do
     status_old = status
     force_start_old = force_start
     state_old = get_state()
+    score_old = get_score()
     advance()
 end
diff --git a/nn.lua b/nn.lua
--- a/nn.lua
+++ b/nn.lua
@@ -173,6 +173,10 @@ function Layer:forward_deterministic(...)
     return self:forward(...)
 end
 
+function Layer:backward()
+    error("Unimplemented.")
+end
+
 function Layer:_new_weights(init)
     local w = Weights(init)
     insert(self.weights, w)
@@ -227,6 +231,11 @@ function Input:forward(X)
     return X
 end
 
+function Input:backward(dY)
+    assert(#dY == self.size_out)
+    return zeros(#dY)
+end
+
 function Relu:init()
     Layer.init(self, "Relu")
 end
@@ -242,6 +251,18 @@ function Relu:forward(X)
     return Y
 end
 
+function Relu:backward(dY)
+    assert(#dY == self.size_out)
+    self.dcache = self.dcache or zeros(self.size_in)
+    local Y = self.cache
+    local dX = self.dcache
+
+    for i = 1, #dY do dX[i] = Y[i] > 0 and dY[i] or 0 end -- mask by active units.
+
+    assert(#dX == self.size_in)
+    return dX
+end
+
 function Gelu:init()
     Layer.init(self, "Gelu")
 end
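
These backward methods lay the groundwork for backpropagation; the ES training loop never calls them. ReLU's Jacobian is diagonal, so the gradient is just dY masked by where the unit was active, which a finite difference confirms:

    local function relu(x) return x > 0 and x or 0 end
    local x, h = 0.7, 1e-6
    print((relu(x + h) - relu(x - h)) / (2 * h)) --> ~1, since x > 0
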
@@ -383,8 +404,12 @@ function Model:distribute(W)
     end
 end
 
+function Model:default_filename()
+    return ('network%07i.txt'):format(self.n_param)
+end
+
 function Model:save(fn)
-    local fn = fn or 'network.txt'
+    local fn = fn or self:default_filename()
     local f = open(fn, 'w')
     if f == nil then error("Failed to save network to file "..fn) end
     local W = self:collect()
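
Keying the default filename to the parameter count means saved weights can only be loaded back into an architecture of the same size, which matters now that main.lua has resized its layers:

    print(('network%07i.txt'):format(1234)) --> network0001234.txt
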
@@ -396,7 +421,7 @@ function Model:save(fn)
 end
 
 function Model:load(fn)
-    local fn = fn or 'network.txt'
+    local fn = fn or self:default_filename()
     local f = open(fn, 'r')
     if f == nil then
         error("Failed to load network from file "..fn)