basic learning
This commit is contained in:
parent
75ad46bfe9
commit
e2d29352c4
2 changed files with 391 additions and 184 deletions
359
main.lua
359
main.lua
|
@ -20,6 +20,63 @@ local function globalize(t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- configuration and globals.
|
||||||
|
|
||||||
|
math.randomseed(10)
|
||||||
|
|
||||||
|
local learning_rate = 1e-2
|
||||||
|
local deviation = 2e-2
|
||||||
|
|
||||||
|
local epoch_i = 0
|
||||||
|
local epoch_trials = 12
|
||||||
|
local base_params
|
||||||
|
local trial_i = 0
|
||||||
|
local trial_noise = {}
|
||||||
|
local trial_rewards = {}
|
||||||
|
local trials_remaining = 0
|
||||||
|
|
||||||
|
local force_start = false
|
||||||
|
local force_start_old = false
|
||||||
|
|
||||||
|
local enable_overlay = false
|
||||||
|
local enable_network = true
|
||||||
|
|
||||||
|
local startsave = savestate.create(1)
|
||||||
|
|
||||||
|
local poketime = false
|
||||||
|
|
||||||
|
local sprite_input = {}
|
||||||
|
local tile_input = {}
|
||||||
|
|
||||||
|
local reward
|
||||||
|
local powerup_old
|
||||||
|
local status_old
|
||||||
|
local coins_old
|
||||||
|
|
||||||
|
local once = false
|
||||||
|
local reset = true
|
||||||
|
|
||||||
|
local ok_routines = {
|
||||||
|
[0x4] = true, -- sliding down flagpole
|
||||||
|
[0x5] = true, -- end of level auto-walk
|
||||||
|
[0x7] = true, -- start of level auto-walk
|
||||||
|
[0x8] = true, -- normal (in control)
|
||||||
|
[0x9] = true, -- acquiring mushroom
|
||||||
|
[0xA] = true, -- losing big mario
|
||||||
|
[0xB] = true, -- uhh
|
||||||
|
[0xC] = true, -- acquiring fireflower
|
||||||
|
}
|
||||||
|
|
||||||
|
local bad_states = {
|
||||||
|
power = true,
|
||||||
|
waiting_demo = true,
|
||||||
|
playing_demo = true,
|
||||||
|
unknown = true,
|
||||||
|
lose = true,
|
||||||
|
}
|
||||||
|
|
||||||
|
local state_old = ''
|
||||||
|
|
||||||
-- localize some stuff.
|
-- localize some stuff.
|
||||||
|
|
||||||
local print = print
|
local print = print
|
||||||
|
@ -51,6 +108,14 @@ local arshift = bit.arshift
|
||||||
local rol = bit.rol
|
local rol = bit.rol
|
||||||
local ror = bit.ror
|
local ror = bit.ror
|
||||||
|
|
||||||
|
-- utilities.
|
||||||
|
|
||||||
|
local function boolean_xor(a, b)
|
||||||
|
if a and b then return false end
|
||||||
|
if not a and not b then return false end
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
|
||||||
local function clamp(x, l, u) return min(max(x, l), u) end
|
local function clamp(x, l, u) return min(max(x, l), u) end
|
||||||
|
|
||||||
local function argmax(...)
|
local function argmax(...)
|
||||||
|
@ -70,6 +135,35 @@ local function argmax2(t)
|
||||||
return t[1] > t[2]
|
return t[1] > t[2]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function empty(t)
|
||||||
|
for k, _ in pairs(t) do t[k] = nil end
|
||||||
|
return t
|
||||||
|
end
|
||||||
|
|
||||||
|
local function calc_mean_dev(x)
|
||||||
|
local mean = 0
|
||||||
|
for i, v in ipairs(x) do
|
||||||
|
mean = mean + v / #x
|
||||||
|
end
|
||||||
|
|
||||||
|
local dev = 0
|
||||||
|
for i, v in ipairs(x) do
|
||||||
|
local delta = v - mean
|
||||||
|
dev = dev + delta * delta / #x
|
||||||
|
end
|
||||||
|
|
||||||
|
return mean, dev
|
||||||
|
end
|
||||||
|
|
||||||
|
local function normalize(x, out)
|
||||||
|
out = out or x
|
||||||
|
local mean, dev = calc_mean_dev(x)
|
||||||
|
if dev <= 0 then dev = 1 end
|
||||||
|
local devs = sqrt(dev)
|
||||||
|
for i, v in ipairs(x) do out[i] = (v - mean) / devs end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
-- game-agnostic stuff (i.e. the network itself)
|
-- game-agnostic stuff (i.e. the network itself)
|
||||||
|
|
||||||
package.loaded['nn'] = nil -- DEBUG
|
package.loaded['nn'] = nil -- DEBUG
|
||||||
|
@ -100,9 +194,6 @@ end
|
||||||
|
|
||||||
-- and here we go with the game stuff.
|
-- and here we go with the game stuff.
|
||||||
|
|
||||||
local enable_overlay = false
|
|
||||||
local enable_network = true
|
|
||||||
|
|
||||||
--[[
|
--[[
|
||||||
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||||
--]]
|
--]]
|
||||||
|
@ -142,14 +233,17 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||||
-8, -38,
|
-8, -38,
|
||||||
}
|
}
|
||||||
|
|
||||||
local startsave = savestate.create(1)
|
local function get_timer()
|
||||||
|
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||||
|
end
|
||||||
|
|
||||||
local poketime = false
|
local function set_timer(time)
|
||||||
|
W(0x7F8, floor(time / 100))
|
||||||
|
W(0x7F9, floor((time / 10) % 10))
|
||||||
|
W(0x7FA, floor(time % 10))
|
||||||
|
end
|
||||||
|
|
||||||
local sprite_input = {}
|
local function mark_sprite(x, y, t)
|
||||||
local tile_input = {}
|
|
||||||
|
|
||||||
local function shitsprite(x, y, t)
|
|
||||||
if x < 0 or x >= 256 or y < 0 or y > 224 then
|
if x < 0 or x >= 256 or y < 0 or y > 224 then
|
||||||
sprite_input[#sprite_input+1] = 0
|
sprite_input[#sprite_input+1] = 0
|
||||||
sprite_input[#sprite_input+1] = 0
|
sprite_input[#sprite_input+1] = 0
|
||||||
|
@ -168,7 +262,7 @@ local function shitsprite(x, y, t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local function shittile(x, y, t)
|
local function mark_tile(x, y, t)
|
||||||
tile_input[#tile_input+1] = t
|
tile_input[#tile_input+1] = t
|
||||||
if t == 0 then return end
|
if t == 0 then return end
|
||||||
if enable_overlay then
|
if enable_overlay then
|
||||||
|
@ -200,23 +294,9 @@ local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
|
||||||
return sx, sy
|
return sx, sy
|
||||||
end
|
end
|
||||||
|
|
||||||
local reward
|
|
||||||
local powerup_old
|
|
||||||
local status_old
|
|
||||||
local coins_old
|
|
||||||
|
|
||||||
local once = false
|
|
||||||
local reset = true
|
|
||||||
|
|
||||||
emu.poweron()
|
|
||||||
emu.unpause()
|
|
||||||
emu.speedmode("normal")
|
|
||||||
|
|
||||||
local function opermode() return R(0x770) end
|
|
||||||
local function paused() return band(R(0x776), 1) end
|
local function paused() return band(R(0x776), 1) end
|
||||||
local function subroutine() return R(0xE) end
|
|
||||||
|
|
||||||
local function getstate()
|
local function get_state()
|
||||||
if R(0xE) == 0xFF then return 'power' end
|
if R(0xE) == 0xFF then return 'power' end
|
||||||
if R(0x774) > 0 then return 'lagging' end
|
if R(0x774) > 0 then return 'lagging' end
|
||||||
if R(0x7A2) > 0 then return 'waiting_demo' end
|
if R(0x7A2) > 0 then return 'waiting_demo' end
|
||||||
|
@ -238,35 +318,14 @@ local function getstate()
|
||||||
return 'unknown'
|
return 'unknown'
|
||||||
end
|
end
|
||||||
|
|
||||||
local ok_routines = {
|
|
||||||
[0x4] = true, -- sliding down flagpole
|
|
||||||
[0x5] = true, -- end of level auto-walk
|
|
||||||
[0x7] = true, -- start of level auto-walk
|
|
||||||
[0x8] = true, -- normal (in control)
|
|
||||||
[0x9] = true, -- acquiring mushroom
|
|
||||||
[0xA] = true, -- losing big mario
|
|
||||||
[0xB] = true, -- uhh
|
|
||||||
[0xC] = true, -- acquiring fireflower
|
|
||||||
}
|
|
||||||
|
|
||||||
local fuckstates = {
|
|
||||||
power = true,
|
|
||||||
waiting_demo = true,
|
|
||||||
playing_demo = true,
|
|
||||||
unknown = true,
|
|
||||||
lose = true,
|
|
||||||
paused = true,
|
|
||||||
}
|
|
||||||
|
|
||||||
local function advance()
|
local function advance()
|
||||||
emu.frameadvance()
|
emu.frameadvance()
|
||||||
while emu.lagged() do emu.frameadvance() end -- skip lag frames.
|
while emu.lagged() do emu.frameadvance() end -- skip lag frames.
|
||||||
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
|
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
|
||||||
end
|
end
|
||||||
|
|
||||||
local state_old = ''
|
|
||||||
while false do
|
while false do
|
||||||
local state = getstate()
|
local state = get_state()
|
||||||
if state ~= state_old then
|
if state ~= state_old then
|
||||||
print(emu.framecount(), state)
|
print(emu.framecount(), state)
|
||||||
state_old = state
|
state_old = state
|
||||||
|
@ -301,9 +360,9 @@ local function handle_enemies()
|
||||||
x, y = x + x_off, y + y_off
|
x, y = x + x_off, y + y_off
|
||||||
end
|
end
|
||||||
if invisible then
|
if invisible then
|
||||||
shitsprite(0, 0, 0)
|
mark_sprite(0, 0, 0)
|
||||||
else
|
else
|
||||||
shitsprite(x, y, tid + 1)
|
mark_sprite(x, y, tid + 1)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -316,9 +375,9 @@ local function handle_fireballs()
|
||||||
local state = R(0x24 + i)
|
local state = R(0x24 + i)
|
||||||
local invisible = state == 0
|
local invisible = state == 0
|
||||||
if invisible then
|
if invisible then
|
||||||
shitsprite(0, 0, 0)
|
mark_sprite(0, 0, 0)
|
||||||
else
|
else
|
||||||
shitsprite(x, y, 257)
|
mark_sprite(x, y, 257)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -330,9 +389,9 @@ local function handle_blocks()
|
||||||
local state = R(0x26 + i)
|
local state = R(0x26 + i)
|
||||||
local invisible = state == 0
|
local invisible = state == 0
|
||||||
if invisible then
|
if invisible then
|
||||||
shitsprite(0, 0, 0)
|
mark_sprite(0, 0, 0)
|
||||||
else
|
else
|
||||||
shitsprite(x, y, 258)
|
mark_sprite(x, y, 258)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -347,9 +406,9 @@ local function handle_hammers()
|
||||||
if state ~= 0
|
if state ~= 0
|
||||||
and state >= 0x30
|
and state >= 0x30
|
||||||
then
|
then
|
||||||
shitsprite(x, y, state + 1)
|
mark_sprite(x, y, state + 1)
|
||||||
else
|
else
|
||||||
shitsprite(0, 0, 0)
|
mark_sprite(0, 0, 0)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -360,9 +419,9 @@ local function handle_misc()
|
||||||
x, y = x + 8, y + 8
|
x, y = x + 8, y + 8
|
||||||
local state = R(0x33 + i)
|
local state = R(0x33 + i)
|
||||||
if state ~= 0 then
|
if state ~= 0 then
|
||||||
shitsprite(x, y, state + 1)
|
mark_sprite(x, y, state + 1)
|
||||||
else
|
else
|
||||||
shitsprite(0, 0, 0)
|
mark_sprite(0, 0, 0)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -383,25 +442,104 @@ local function handle_tiles()
|
||||||
end
|
end
|
||||||
local sx = x * 16 + 8 - tile_scroll_remainder
|
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||||
local sy = y * 16 + 40
|
local sy = y * 16 + 40
|
||||||
shittile(sx, sy, t)
|
mark_tile(sx, sy, t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local function doreset()
|
local function learn_from_epoch()
|
||||||
print("resetting in state:", getstate())
|
print()
|
||||||
|
print('rewards:', trial_rewards)
|
||||||
|
normalize(trial_rewards)
|
||||||
|
print('normalized:', trial_rewards)
|
||||||
|
|
||||||
|
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
|
||||||
|
|
||||||
|
local step = nn.zeros(#base_params)
|
||||||
|
for i = 1, epoch_trials do
|
||||||
|
local reward = trial_rewards[i]
|
||||||
|
local noise = trial_noise[i]
|
||||||
|
for j, v in ipairs(noise) do
|
||||||
|
step[j] = step[j] + reward * v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local magnitude = learning_rate / deviation
|
||||||
|
print('stepping with magnitude', magnitude)
|
||||||
|
-- throw the division from the averaging in there too.
|
||||||
|
local altogether = magnitude / epoch_trials
|
||||||
|
for i, v in ipairs(step) do
|
||||||
|
step[i] = altogether * v
|
||||||
|
end
|
||||||
|
|
||||||
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
|
print("step mean:", step_mean)
|
||||||
|
print("step stddev:", step_dev)
|
||||||
|
|
||||||
|
if step_dev > 1e-8 then
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
base_params[i] = v + step[i]
|
||||||
|
end
|
||||||
|
else
|
||||||
|
-- we didn't get anywhere. step in a random direction.
|
||||||
|
local noise = trial_noise[1]
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
base_params[i] = v + magnitude * noise[i]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
network:distribute(base_params)
|
||||||
|
|
||||||
|
network:save()
|
||||||
|
|
||||||
|
print()
|
||||||
|
end
|
||||||
|
|
||||||
|
local function prepare_epoch()
|
||||||
|
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
|
||||||
|
base_params = network:collect()
|
||||||
|
empty(trial_noise)
|
||||||
|
empty(trial_rewards)
|
||||||
|
for i = 1, epoch_trials do
|
||||||
|
local noise = nn.zeros(#base_params)
|
||||||
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
|
trial_noise[i] = noise
|
||||||
|
end
|
||||||
|
trial_i = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
local function load_next_trial()
|
||||||
|
trial_i = trial_i + 1
|
||||||
|
print('loading trial', trial_i)
|
||||||
|
local W = nn.copy(base_params)
|
||||||
|
local noise = trial_noise[trial_i]
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
W[i] = v + deviation * noise[i]
|
||||||
|
end
|
||||||
|
network:distribute(W)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function do_reset()
|
||||||
|
print("resetting in state:", get_state())
|
||||||
|
|
||||||
|
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
||||||
|
|
||||||
|
if epoch_i == 0 or trial_i == epoch_trials then
|
||||||
|
if epoch_i > 0 then learn_from_epoch() else network:reset() end
|
||||||
|
epoch_i = epoch_i + 1
|
||||||
|
prepare_epoch()
|
||||||
|
end
|
||||||
|
|
||||||
if once then
|
if once then
|
||||||
savestate.load(startsave)
|
savestate.load(startsave)
|
||||||
print("end of trial reward:", reward)
|
print("end of trial reward:", reward)
|
||||||
print()
|
|
||||||
else
|
else
|
||||||
savestate.save(startsave)
|
savestate.save(startsave)
|
||||||
end
|
end
|
||||||
once = true
|
once = true
|
||||||
|
|
||||||
-- bit of a hack:
|
-- bit of a hack:
|
||||||
if getstate() == 'loading' then advance() end
|
if get_state() == 'loading' then advance() end
|
||||||
reward = 0
|
reward = 0
|
||||||
powerup_old = R(0x754)
|
powerup_old = R(0x754)
|
||||||
status_old = R(0x756)
|
status_old = R(0x756)
|
||||||
|
@ -413,13 +551,29 @@ local function doreset()
|
||||||
|
|
||||||
emu.frameadvance() -- prevents emulator from quirking up.
|
emu.frameadvance() -- prevents emulator from quirking up.
|
||||||
|
|
||||||
|
print()
|
||||||
|
load_next_trial()
|
||||||
|
|
||||||
reset = false
|
reset = false
|
||||||
if network ~= nil then network:reset() end -- FIXME: hack
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function init()
|
||||||
|
network = make_network(279, 8)
|
||||||
|
network:reset()
|
||||||
|
print("parameters:", network.n_param)
|
||||||
|
|
||||||
|
emu.poweron()
|
||||||
|
emu.unpause()
|
||||||
|
emu.speedmode("normal")
|
||||||
|
end
|
||||||
|
|
||||||
|
init()
|
||||||
|
|
||||||
while true do
|
while true do
|
||||||
while fuckstates[getstate()] do
|
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||||
--gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F')
|
|
||||||
|
while bad_states[get_state()] do
|
||||||
|
--gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
|
||||||
-- mash the start button until we have control.
|
-- mash the start button until we have control.
|
||||||
-- TODO: learn this too.
|
-- TODO: learn this too.
|
||||||
--local jp = joypad.read(1)
|
--local jp = joypad.read(1)
|
||||||
|
@ -437,44 +591,39 @@ while true do
|
||||||
|
|
||||||
reset = true
|
reset = true
|
||||||
|
|
||||||
gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F')
|
|
||||||
advance()
|
advance()
|
||||||
|
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||||
|
|
||||||
-- bit of a hack:
|
-- bit of a hack:
|
||||||
while getstate() == "loading" do advance() end
|
while get_state() == "loading" do advance() end
|
||||||
state_old = getstate()
|
state_old = get_state()
|
||||||
end
|
end
|
||||||
|
|
||||||
if reset then doreset() end
|
if reset then do_reset() end
|
||||||
|
|
||||||
if not enable_network then
|
if not enable_network then
|
||||||
-- infinite time cheat. super handy for testing.
|
-- infinite time cheat. super handy for testing.
|
||||||
if R(0xE) == 8 then
|
if R(0xE) == 8 then
|
||||||
W(0x7F8, 9)
|
set_timer(667)
|
||||||
W(0x7F9, 9)
|
|
||||||
W(0x7FA, 10)
|
|
||||||
poketime = true
|
poketime = true
|
||||||
elseif poketime then
|
elseif poketime then
|
||||||
poketime = false
|
poketime = false
|
||||||
W(0x7F8, 0)
|
set_timer(1)
|
||||||
W(0x7F9, 0)
|
|
||||||
W(0x7FA, 1)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
-- infinite lives.
|
-- infinite lives.
|
||||||
W(0x75A, 1)
|
W(0x75A, 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- empty input lists without creating a new table.
|
empty(sprite_input)
|
||||||
for k, v in pairs(sprite_input) do sprite_input[k] = nil end
|
empty(tile_input)
|
||||||
for k, v in pairs(tile_input) do tile_input[k] = nil end
|
|
||||||
|
|
||||||
-- player
|
-- player
|
||||||
-- TODO: add check if mario is playable.
|
-- TODO: check if mario is playable.
|
||||||
local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||||
local powerup = R(0x754)
|
local powerup = R(0x754)
|
||||||
local status = R(0x756)
|
local status = R(0x756)
|
||||||
shitsprite(x + 8, y + 24, -powerup - 1)
|
mark_sprite(x + 8, y + 24, -powerup - 1)
|
||||||
|
|
||||||
handle_enemies()
|
handle_enemies()
|
||||||
handle_fireballs()
|
handle_fireballs()
|
||||||
|
@ -484,6 +633,8 @@ while true do
|
||||||
handle_misc()
|
handle_misc()
|
||||||
handle_tiles()
|
handle_tiles()
|
||||||
|
|
||||||
|
local ingame_paused = get_state() == "paused"
|
||||||
|
|
||||||
local coins = R(0x7ED) * 10 + R(0x7EE)
|
local coins = R(0x7ED) * 10 + R(0x7EE)
|
||||||
local coins_delta = coins - coins_old
|
local coins_delta = coins - coins_old
|
||||||
-- handle wrap-around.
|
-- handle wrap-around.
|
||||||
|
@ -493,45 +644,48 @@ while true do
|
||||||
-- 2 is fire mario.
|
-- 2 is fire mario.
|
||||||
local status_delta = clamp(status - status_old, -1, 1)
|
local status_delta = clamp(status - status_old, -1, 1)
|
||||||
local screen_scroll_delta = R(0x775)
|
local screen_scroll_delta = R(0x775)
|
||||||
local flagpole_bonus = subroutine() == 4 and 1 or 0
|
local flagpole_bonus = R(0xE) == 4 and 1 or 0
|
||||||
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||||
|
|
||||||
reward = reward + reward_delta
|
if not ingame_paused then reward = reward + reward_delta end
|
||||||
|
|
||||||
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
||||||
--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
|
--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
|
||||||
gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||||
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||||
|
|
||||||
if getstate() == 'dead' and state_old ~= 'dead' then
|
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||||
print("dead. lives remaining:", R(0x75A, 0))
|
print("dead. lives remaining:", R(0x75A, 0))
|
||||||
if R(0x75A, 0) == 0 then reset = true end
|
if R(0x75A, 0) == 0 then reset = true end
|
||||||
end
|
end
|
||||||
if getstate() == 'lose' then
|
if get_state() == 'lose' then
|
||||||
print("ran out of lives.")
|
print("ran out of lives.")
|
||||||
reset = true
|
reset = true
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- lose a point for every frame paused.
|
||||||
|
if ingame_paused then reward = reward - 1 end
|
||||||
|
|
||||||
-- every few frames mario stands still, forcibly decrease the timer.
|
-- every few frames mario stands still, forcibly decrease the timer.
|
||||||
|
-- this includes having the game paused.
|
||||||
|
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
||||||
|
local timer = get_timer()
|
||||||
local timer_loser = 1/5
|
local timer_loser = 1/5
|
||||||
if math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||||
local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
|
||||||
timer = clamp(timer - 1, 0, 400)
|
timer = clamp(timer - 1, 0, 400)
|
||||||
W(0x7F8, floor(timer / 100))
|
set_timer(timer)
|
||||||
W(0x7F9, floor((timer / 10) % 10))
|
|
||||||
W(0x7FA, floor(timer % 10))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- if we've run out of time while the game is paused...
|
||||||
|
-- that's cheating! unpause.
|
||||||
|
force_start = ingame_paused and timer == 0
|
||||||
|
|
||||||
local X = {} -- TODO: cache.
|
local X = {} -- TODO: cache.
|
||||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
||||||
|
--error(#X)
|
||||||
|
|
||||||
if network == nil then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
network = make_network(#X, 8); network:reset()
|
|
||||||
print("parameters:", network.n_param)
|
|
||||||
end
|
|
||||||
|
|
||||||
if enable_network and getstate() == 'playing' then
|
|
||||||
local outputs = network:forward(X)
|
local outputs = network:forward(X)
|
||||||
local softmaxed = {
|
local softmaxed = {
|
||||||
outputs[nn_z[1]],
|
outputs[nn_z[1]],
|
||||||
|
@ -553,12 +707,25 @@ while true do
|
||||||
start = argmax2(softmaxed[7]),
|
start = argmax2(softmaxed[7]),
|
||||||
select = argmax2(softmaxed[8]),
|
select = argmax2(softmaxed[8]),
|
||||||
}
|
}
|
||||||
|
if force_start then
|
||||||
|
jp = {
|
||||||
|
up = false,
|
||||||
|
down = false,
|
||||||
|
left = false,
|
||||||
|
right = false,
|
||||||
|
A = false,
|
||||||
|
B = false,
|
||||||
|
start = force_start_old,
|
||||||
|
select = false,
|
||||||
|
}
|
||||||
|
end
|
||||||
joypad.write(1, jp)
|
joypad.write(1, jp)
|
||||||
end
|
end
|
||||||
|
|
||||||
coins_old = coins
|
coins_old = coins
|
||||||
powerup_old = powerup
|
powerup_old = powerup
|
||||||
status_old = status
|
status_old = status
|
||||||
state_old = getstate()
|
force_start_old = force_start
|
||||||
|
state_old = get_state()
|
||||||
advance()
|
advance()
|
||||||
end
|
end
|
||||||
|
|
206
nn.lua
206
nn.lua
|
@ -13,6 +13,7 @@ local cos = math.cos
|
||||||
local sin = math.sin
|
local sin = math.sin
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
|
local open = io.open
|
||||||
|
|
||||||
local bor = bit.bor
|
local bor = bit.bor
|
||||||
|
|
||||||
|
@ -24,6 +25,17 @@ local function contains(t, a)
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function prod(x, ...)
|
||||||
|
if type(x) == "table" then
|
||||||
|
return prod(unpack(x))
|
||||||
|
end
|
||||||
|
local ret = x
|
||||||
|
for i = 1, select("#", ...) do
|
||||||
|
ret = ret * select(i, ...)
|
||||||
|
end
|
||||||
|
return ret
|
||||||
|
end
|
||||||
|
|
||||||
local function normal() -- box muller
|
local function normal() -- box muller
|
||||||
return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform())
|
return sqrt(-2 * log(uniform() + 1e-8) + 1e-8) * cos(2 * pi * uniform())
|
||||||
end
|
end
|
||||||
|
@ -58,28 +70,53 @@ local function copy(t) -- shallow copy
|
||||||
end
|
end
|
||||||
|
|
||||||
local function allocate(t, out, init)
|
local function allocate(t, out, init)
|
||||||
-- FIXME: this code is fucking disgusting.
|
|
||||||
|
|
||||||
out = out or {}
|
out = out or {}
|
||||||
assert(type(out) == "table", type(out))
|
|
||||||
if type(t) == "number" then
|
|
||||||
local size = t
|
local size = t
|
||||||
if init ~= nil then
|
if init ~= nil then
|
||||||
return init(zeros(size, out))
|
return init(zeros(size, out))
|
||||||
else
|
else
|
||||||
return zeros(size, out)
|
return zeros(size, out)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function levelorder(field, node_in, nodes)
|
||||||
|
-- horribly inefficient.
|
||||||
|
nodes = nodes or {}
|
||||||
|
local q = {node_in}
|
||||||
|
while #q > 0 do
|
||||||
|
local node = q[1]
|
||||||
|
remove(q, 1)
|
||||||
|
insert(nodes, node)
|
||||||
|
for _, child in ipairs(node[field]) do
|
||||||
|
q[#q+1] = child
|
||||||
end
|
end
|
||||||
local topsize = t[1]
|
|
||||||
t = copy(t)
|
|
||||||
remove(t, 1)
|
|
||||||
if #t == 1 then t = t[1] end
|
|
||||||
for i = 1, topsize do
|
|
||||||
local res = allocate(t, nil, init)
|
|
||||||
assert(res ~= nil)
|
|
||||||
insert(out, res)
|
|
||||||
end
|
end
|
||||||
return out
|
return nodes
|
||||||
|
end
|
||||||
|
|
||||||
|
local function traverse(node_in, node_out, nodes)
|
||||||
|
nodes = nodes or {}
|
||||||
|
local down = levelorder('children', node_in, {})
|
||||||
|
local up = levelorder('parents', node_out, {})
|
||||||
|
local seen = {}
|
||||||
|
for _, node in ipairs(up) do
|
||||||
|
seen[node] = bor(seen[node] or 0, 1)
|
||||||
|
end
|
||||||
|
for _, node in ipairs(down) do
|
||||||
|
seen[node] = bor(seen[node] or 0, 2)
|
||||||
|
if seen[node] == 3 then
|
||||||
|
insert(nodes, node)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return nodes
|
||||||
|
end
|
||||||
|
|
||||||
|
local function traverse_all(nodes_in, nodes_out, nodes)
|
||||||
|
local all_in = {children={}}
|
||||||
|
local all_out = {parents={}}
|
||||||
|
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
||||||
|
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
||||||
|
return traverse(all_in, all_out, nodes or {})
|
||||||
end
|
end
|
||||||
|
|
||||||
local Weights = Base:extend()
|
local Weights = Base:extend()
|
||||||
|
@ -95,24 +132,13 @@ function Weights:init(weight_init)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Weights:allocate(fan_in, fan_out)
|
function Weights:allocate(fan_in, fan_out)
|
||||||
|
self.size = prod(self.shape)
|
||||||
return allocate(self.size, self, function(t)
|
return allocate(self.size, self, function(t)
|
||||||
--print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
|
--print('initializing weights of size', self.size, 'with fans', fan_in, fan_out)
|
||||||
return self.weight_init(t, fan_in, fan_out)
|
return self.weight_init(t, fan_in, fan_out)
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
--[[
|
|
||||||
local w = Weights(init_he_uniform)
|
|
||||||
w.size = {16, 16}
|
|
||||||
w:allocate(16, 16)
|
|
||||||
print(w)
|
|
||||||
do return end
|
|
||||||
|
|
||||||
local w = zeros(16)
|
|
||||||
for i = 1, #w do w[i] = normal() * 1920 / 2560 end
|
|
||||||
print(w)
|
|
||||||
--]]
|
|
||||||
|
|
||||||
local counter = {}
|
local counter = {}
|
||||||
function Layer:init(name)
|
function Layer:init(name)
|
||||||
assert(type(name) == "string")
|
assert(type(name) == "string")
|
||||||
|
@ -152,18 +178,7 @@ function Layer:_new_weights(init)
|
||||||
return w
|
return w
|
||||||
end
|
end
|
||||||
|
|
||||||
local function prod(x, ...)
|
function Layer:get_size()
|
||||||
if type(x) == "table" then
|
|
||||||
return prod(unpack(x))
|
|
||||||
end
|
|
||||||
local ret = x
|
|
||||||
for i = 1, select("#", ...) do
|
|
||||||
ret = ret * select(i, ...)
|
|
||||||
end
|
|
||||||
return ret
|
|
||||||
end
|
|
||||||
|
|
||||||
function Layer:getsize()
|
|
||||||
local size = 0
|
local size = 0
|
||||||
for i, w in ipairs(self.weights) do size = size + prod(w.size) end
|
for i, w in ipairs(self.weights) do size = size + prod(w.size) end
|
||||||
return size
|
return size
|
||||||
|
@ -237,8 +252,8 @@ end
|
||||||
|
|
||||||
function Dense:make_shape(parent)
|
function Dense:make_shape(parent)
|
||||||
self.size_in = parent.size_out
|
self.size_in = parent.size_out
|
||||||
self.coeffs.size = {self.dim, self.size_in}
|
self.coeffs.shape = {self.size_in, self.dim}
|
||||||
self.biases.size = self.dim
|
self.biases.shape = self.dim
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:forward(X)
|
function Dense:forward(X)
|
||||||
|
@ -246,11 +261,11 @@ function Dense:forward(X)
|
||||||
self.cache = self.cache or zeros(self.size_out)
|
self.cache = self.cache or zeros(self.size_out)
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
for i = 1, #self.coeffs do
|
for i = 1, self.dim do
|
||||||
local res = 0
|
local res = 0
|
||||||
local c = self.coeffs[i]
|
local c = (i - 1) * #X
|
||||||
for j = 1, #X do
|
for j = 1, #X do
|
||||||
res = res + X[j] * c[j]
|
res = res + X[j] * self.coeffs[c + j]
|
||||||
end
|
end
|
||||||
Y[i] = res + self.biases[i]
|
Y[i] = res + self.biases[i]
|
||||||
end
|
end
|
||||||
|
@ -281,46 +296,6 @@ function Softmax:forward(X)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
local function levelorder(field, node_in, nodes)
|
|
||||||
-- horribly inefficient.
|
|
||||||
nodes = nodes or {}
|
|
||||||
local q = {node_in}
|
|
||||||
while #q > 0 do
|
|
||||||
local node = q[1]
|
|
||||||
remove(q, 1)
|
|
||||||
insert(nodes, node)
|
|
||||||
for _, child in ipairs(node[field]) do
|
|
||||||
q[#q+1] = child
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return nodes
|
|
||||||
end
|
|
||||||
|
|
||||||
local function traverse(node_in, node_out, nodes)
|
|
||||||
nodes = nodes or {}
|
|
||||||
local down = levelorder('children', node_in, {})
|
|
||||||
local up = levelorder('parents', node_out, {})
|
|
||||||
local seen = {}
|
|
||||||
for _, node in ipairs(up) do
|
|
||||||
seen[node] = bor(seen[node] or 0, 1)
|
|
||||||
end
|
|
||||||
for _, node in ipairs(down) do
|
|
||||||
seen[node] = bor(seen[node] or 0, 2)
|
|
||||||
if seen[node] == 3 then
|
|
||||||
insert(nodes, node)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return nodes
|
|
||||||
end
|
|
||||||
|
|
||||||
local function traverse_all(nodes_in, nodes_out, nodes)
|
|
||||||
local all_in = {children={}}
|
|
||||||
local all_out = {parents={}}
|
|
||||||
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
|
||||||
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
|
||||||
return traverse(all_in, all_out, nodes or {})
|
|
||||||
end
|
|
||||||
|
|
||||||
function Model:init(nodes_in, nodes_out)
|
function Model:init(nodes_in, nodes_out)
|
||||||
assert(#nodes_in > 0, #nodes_in)
|
assert(#nodes_in > 0, #nodes_in)
|
||||||
assert(#nodes_out > 0, #nodes_out)
|
assert(#nodes_out > 0, #nodes_out)
|
||||||
|
@ -338,7 +313,7 @@ function Model:reset()
|
||||||
self.n_param = 0
|
self.n_param = 0
|
||||||
for _, node in ipairs(self.nodes) do
|
for _, node in ipairs(self.nodes) do
|
||||||
node:init_weights()
|
node:init_weights()
|
||||||
self.n_param = self.n_param + node:getsize()
|
self.n_param = self.n_param + node:get_size()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -359,10 +334,75 @@ function Model:forward(X)
|
||||||
return outputs
|
return outputs
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Model:collect()
|
||||||
|
-- return a flat array of all the weights in the graph.
|
||||||
|
-- if Lua had slices, we wouldn't need this. future library idea?
|
||||||
|
assert(self.n_param >= 0, self.n_param)
|
||||||
|
local W = zeros(self.n_param)
|
||||||
|
local i = 0
|
||||||
|
for _, node in ipairs(self.nodes) do
|
||||||
|
for _, w in ipairs(node.weights) do
|
||||||
|
for j, v in ipairs(w) do
|
||||||
|
W[i+j] = v
|
||||||
|
end
|
||||||
|
i = i + #w
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return W
|
||||||
|
end
|
||||||
|
|
||||||
|
function Model:distribute(W)
|
||||||
|
-- inverse operation of collect().
|
||||||
|
local i = 0
|
||||||
|
for _, node in ipairs(self.nodes) do
|
||||||
|
for _, w in ipairs(node.weights) do
|
||||||
|
for j, v in ipairs(w) do
|
||||||
|
w[j] = W[i+j]
|
||||||
|
end
|
||||||
|
i = i + #w
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function Model:save(fn)
|
||||||
|
local fn = fn or 'network.txt'
|
||||||
|
local f = open(fn, 'w')
|
||||||
|
if f == nil then error("Failed to save network to file "..fn) end
|
||||||
|
local W = self:collect()
|
||||||
|
for i, v in ipairs(W) do
|
||||||
|
f:write(v)
|
||||||
|
f:write('\n')
|
||||||
|
end
|
||||||
|
f:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
function Model:load(fn)
|
||||||
|
local fn = fn or 'network.txt'
|
||||||
|
local f = open(fn, 'r')
|
||||||
|
if f == nil then
|
||||||
|
error("Failed to load network from file "..fn)
|
||||||
|
end
|
||||||
|
|
||||||
|
local W = zeros(self.n_param)
|
||||||
|
local i = 0
|
||||||
|
for line in f:lines() do
|
||||||
|
i = i + 1
|
||||||
|
local n = tonumber(line)
|
||||||
|
if n == nil then
|
||||||
|
error("Failed reading line "..tostring(i).." of file "..fn)
|
||||||
|
end
|
||||||
|
W[i] = n
|
||||||
|
end
|
||||||
|
f:close()
|
||||||
|
|
||||||
|
self:distribute(W)
|
||||||
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
uniform = uniform,
|
uniform = uniform,
|
||||||
normal = normal,
|
normal = normal,
|
||||||
|
|
||||||
|
copy = copy,
|
||||||
zeros = zeros,
|
zeros = zeros,
|
||||||
init_zeros = init_zeros,
|
init_zeros = init_zeros,
|
||||||
init_he_uniform = init_he_uniform,
|
init_he_uniform = init_he_uniform,
|
||||||
|
|
Loading…
Reference in a new issue