smbot/main.lua

565 lines
15 KiB
Lua
Raw Normal View History

2017-06-28 02:33:18 -07:00
-- be strict about globals: reading or writing any global name that is not
-- already present in _G raises an error reported at the offending line.
local mt = getmetatable(_G)
if mt == nil then mt = {}; setmetatable(_G, mt) end
mt.__newindex = function(_, name, _value)
	error("cannot assign undeclared global '" .. tostring(name) .. "'", 2)
end
mt.__index = function(_, name)
	error("cannot use undeclared global '" .. tostring(name) .. "'", 2)
end
-- Deliberately publish every key/value pair of `t` as a global. rawset
-- bypasses the strict __newindex guard installed above.
local function globalize(t)
	for name, value in pairs(t) do
		rawset(_G, name, value)
	end
end
-- localize some stuff.
2017-06-28 17:14:56 -07:00
local print = _G.print -- cache the global print as an upvalue
2017-06-28 02:33:18 -07:00
-- base library
local ipairs, pairs, select = ipairs, pairs, select
-- math
local abs, floor, ceil = math.abs, math.floor, math.ceil
local min, max = math.min, math.max
local exp, sqrt, random = math.exp, math.sqrt, math.random
-- table (unpack moved into `table` after Lua 5.1)
local insert, remove = table.insert, table.remove
local unpack = table.unpack or unpack
-- emulator memory access: R reads unsigned, S reads signed, W writes.
local R = memory.readbyteunsigned
local S = memory.readbyte --signed
local W = memory.writebyte
-- bitwise operations from the `bit` library
local band, bor, bxor, bnot = bit.band, bit.bor, bit.bxor, bit.bnot
local lshift, rshift, arshift = bit.lshift, bit.rshift, bit.arshift
local rol, ror = bit.rol, bit.ror
-- Limit x to the range [l, u]. The lower bound is applied first and the upper
-- bound second, so an inverted range (l > u) yields u.
local function clamp(x, l, u)
	if x < l then x = l end
	if x > u then x = u end
	return x
end
-- Return the 1-based position of the largest of the given numbers. Ties go to
-- the first occurrence. Returns 0 when called with no arguments.
-- The running maximum is seeded with the first argument rather than a magic
-- sentinel. This lets very negative values (or -math.huge) still be selected.
local function argmax(...)
	local n = select("#", ...)
	if n == 0 then return 0 end
	local max_i, max_v = 1, (select(1, ...))
	for i = 2, n do
		local v = select(i, ...)
		if v > max_v then
			max_i, max_v = i, v
		end
	end
	return max_i
end
-- Binary decision from a two-entry score pair: true exactly when the first
-- entry is strictly greater than the second (a tie yields false).
local function argmax2(t)
	local first, second = t[1], t[2]
	return second < first
end
-- game-agnostic stuff (i.e. the network itself)
-- drop any cached copy so `nn` is re-read from disk on every run. DEBUG
package.loaded.nn = nil -- DEBUG
local nn = require("nn")
-- `network` holds the model; make_network() sets nn_x (input node), nn_y
-- (shared hidden node) and nn_z (array of per-button output nodes).
local network, nn_x, nn_y, nn_z
-- Build the policy model: an input of `input_size` values, one shared
-- Dense(input_size) + Relu layer, then `buttons` independent heads. Each head
-- is Dense(2) + Softmax. Stores the nodes in nn_x / nn_y / nn_z and returns
-- the nn.Model.
local function make_network(input_size, buttons)
	nn_x = nn.Input(input_size)
	nn_z = {}
	-- shared hidden layer.
	-- (deeper variants adding Dense(floor(input_size / 16)) + Relu pairs were
	-- tried and are currently disabled.)
	nn_y = nn_x:feed(nn.Dense(input_size)):feed(nn.Relu())
	for b = 1, buttons do
		nn_z[b] = nn_y:feed(nn.Dense(2)):feed(nn.Softmax())
	end
	return nn.Model({ nn_x }, nn_z)
end
-- and here we go with the game stuff.
-- when true, every encoded sprite/tile is also drawn on screen with gui.box /
-- gui.text.
local enable_overlay = false
-- when true, the network drives joypad 1 while the state is 'playing'; when
-- false, the main loop applies infinite-time and infinite-lives cheats instead.
local enable_network = true
--[[
SMB disassembly; presumably the reference for the RAM addresses used below.
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
--]]
-- Pixel offsets (x, y) for the 32 rotation steps of a rotating fire bar,
-- stored flat as x0, y0, x1, y1, ...; step s is read as
-- rotation_offsets[s*2+1], rotation_offsets[s*2+2].
-- The trailing index comments are hex step numbers.
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
0, -40, -- 0x00
6, -38,
15, -37,
22, -32,
28, -28,
32, -22,
37, -14,
39, -6,
40, 0, -- 0x08
38, 7,
37, 15,
33, 23,
27, 29,
22, 33,
14, 37,
6, 39,
0, 41, -- 0x10
-7, 40,
-16, 38,
-22, 34,
-28, 28,
-34, 23,
-38, 16,
-40, 8,
-40, -0, -- 0x18
-40, -6,
-38, -14,
-34, -22,
-28, -28,
-22, -32,
-16, -36,
-8, -38,
}
-- savestate written on the first reset and reloaded on every later reset.
local startsave = savestate.create(1)
-- true while the infinite-time cheat has overwritten the timer digits.
local poketime = false
-- flat per-frame input buffers; the main loop clears them in place.
local sprite_input, tile_input = {}, {}
-- Append one sprite record (x, y, t) to sprite_input. A position outside
-- x in [0, 256) or y in [0, 224] is recorded as (0, 0, 0) instead.
-- Sprites with a non-zero type are also drawn when enable_overlay is set.
local function shitsprite(x, y, t)
	local n = #sprite_input
	local offscreen = x < 0 or x >= 256 or y < 0 or y > 224
	if offscreen then
		sprite_input[n + 1], sprite_input[n + 2], sprite_input[n + 3] = 0, 0, 0
	else
		sprite_input[n + 1], sprite_input[n + 2], sprite_input[n + 3] = x, y, t
	end
	if t ~= 0 and enable_overlay then
		gui.box(x - 4, y - 4, x + 4, y + 4)
		gui.text(x - 13, y - 3 - 9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
	end
end
-- Append one background tile id to tile_input. Only `t` is recorded; x and y
-- are used solely to draw the tile when enable_overlay is set and t ~= 0.
local function shittile(x, y, t)
	tile_input[#tile_input + 1] = t
	if t ~= 0 and enable_overlay then
		gui.box(x - 8, y - 8, x + 8, y + 8)
		gui.text(x - 5, y - 3, ("%02X"):format(t), '#FFFFFF', '#00000000')
	end
end
-- Convert an object's RAM position into coordinates relative to the screen's
-- left edge.
--   i             slot index, added to every per-slot base address below
--   x_addr        base address of the per-slot x byte (position within a page)
--   y_addr        base address of the per-slot y byte
--   pageloc_addr  optional base of the per-slot page byte; when given, the
--                 page difference to the screen's page is applied (256 px/page)
--   hipos_addr    optional base of the per-slot vertical high byte, read
--                 signed; a value of 1 adds nothing, each step away shifts
--                 y by 256 px
-- Returns sx, sy.
-- Only the left-edge bytes (0x71A page, 0x71C x) are read; the original also
-- read 0x71B / 0x71D into locals that were never used.
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
	local screen_page = R(0x71A)
	local screen_x = R(0x71C)
	local sx = R(x_addr + i) - screen_x
	local sy = R(y_addr + i)
	if pageloc_addr ~= nil then
		sx = sx - (screen_page - R(pageloc_addr + i)) * 256
	end
	if hipos_addr ~= nil then
		sy = sy + (S(hipos_addr + i) - 1) * 256
	end
	return sx, sy
end
-- per-trial bookkeeping, (re)initialised by doreset().
local reward, powerup_old, status_old, coins_old
-- once: becomes true after the first doreset(). reset: requests a doreset().
local once, reset = false, true
-- boot the console and run at normal speed.
emu.poweron()
emu.unpause()
emu.speedmode("normal")
-- Thin accessors for RAM bytes used by the state logic.
local function opermode()
	return R(0x770)
end
-- non-zero when bit 0 of 0x776 is set.
local function paused()
	return band(R(0x776), 1)
end
local function subroutine()
	return R(0xE)
end
-- Classify the current game/emulator situation as a short state name from RAM.
-- Checks run top to bottom and the first match wins, so their order encodes
-- priority (e.g. 'lagging' beats 'paused'). Previously tried and disabled:
-- 0x770 == 0xFF as 'power', 0x712 == 1 as 'deadmusic', 0x770 == 0 as
-- 'not_playing'.
local function getstate()
	local state
	if R(0xE) == 0xFF then
		state = 'power'
	elseif R(0x774) > 0 then
		state = 'lagging'
	elseif R(0x7A2) > 0 then
		state = 'waiting_demo'
	elseif R(0x717) > 0 then
		state = 'playing_demo'
	elseif paused() ~= 0 then
		state = 'paused'
	elseif R(0xE) == 0 then
		state = 'world_screen'
	elseif R(0x7CA) == 0x94 then
		state = 'dead'
	elseif R(0xE) == 4 then
		state = 'win_flagpole'
	elseif R(0xE) == 5 then
		state = 'win_walking'
	elseif R(0xE) == 6 then
		state = 'lose'
	elseif R(0x770) == 2 then
		state = 'win_castle'
	elseif R(0x772) == 2 then
		state = 'no_control'
	elseif R(0x772) == 3 then
		state = 'playing'
	elseif R(0x770) == 1 then
		state = 'loading'
	elseif R(0x770) == 3 then
		state = 'lose'
	else
		state = 'unknown'
	end
	return state
end
-- Set of routine numbers considered OK. Presumably values of the 0xE routine
-- byte (cf. getstate()'s 4/5 checks) — confirm before relying on it.
local ok_routines = {}
for _, routine in ipairs({
	0x4, -- sliding down flagpole
	0x5, -- end of level auto-walk
	0x7, -- start of level auto-walk
	0x8, -- normal (in control)
	0x9, -- acquiring mushroom
	0xA, -- losing big mario
	0xB, -- uhh
	0xC, -- acquiring fireflower
}) do
	ok_routines[routine] = true
end
-- getstate() names during which the main loop mashes Start and schedules a
-- trial reset instead of letting the network play.
local fuckstates = {}
for _, name in ipairs({ 'power', 'waiting_demo', 'playing_demo', 'unknown', 'lose', 'paused' }) do
	fuckstates[name] = true
end
-- Step the emulator by one frame, then keep stepping while the emulator
-- reports a lag frame and afterwards while RAM byte 0x774 is non-zero
-- (the script's other lag indicator).
local function advance()
	local function skip_while(cond)
		while cond() do emu.frameadvance() end
	end
	emu.frameadvance()
	skip_while(emu.lagged)
	skip_while(function() return R(0x774) > 0 end)
end
-- getstate() result observed on the previous main-loop iteration.
local state_old = ''
-- Debug aid: set to true to only print every state transition (with the
-- frame number) forever, instead of running the bot. Previously this was an
-- anonymous `while false` loop.
local trace_states = false
while trace_states do
	local state = getstate()
	if state ~= state_old then
		print(emu.framecount(), state)
		state_old = state
	end
	advance()
end
-- Encode the six enemy object slots as sprite inputs: exactly one shitsprite()
-- call per slot, in slot order. These slots also hold flags, platforms, fire
-- bars, springs, etc.; the marker position is nudged per object type id so
-- it lands on the relevant part of the object.
local function handle_enemies()
	-- enemies, flagpole
	for i = 0, 5 do
		local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
		x, y = x + 8, y + 16
		local tid = R(0x16 + i)
		local flags = R(0xF + i)
		--local offscr = R(0x3D8 + i)
		-- a low type id with a zero flags byte is encoded as an empty slot.
		local invisible = tid < 0x10 and flags == 0
		if tid == 0x30 then y = y - 8 end -- flagpole flag
		if tid == 0x31 then y = y - 8 end -- castle flag
		if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
		if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
		-- tid == 0x2D -- bowser (TODO: determine head or body)
		if tid == 0x15 then y = y - 12 end -- bowser fire
		if tid == 0x32 then y = y - 8 end -- spring
		-- tid == 0x35 -- toad
		if tid == 0x1D or tid == 0x1B then -- rotating fire bars
			x, y = x - 4, y - 12
			-- project the marker along the bar using its rotation step.
			-- TODO: handle long fire bars too
			local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
			-- debug label, drawn only with the overlay like all other labels.
			if enable_overlay then
				gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
			end
			-- rotation_offsets holds 32 (x, y) steps. Wrap the step so an
			-- out-of-range byte cannot index past the table (nil offsets would
			-- raise "attempt to perform arithmetic on a nil value").
			local step = rot % 32
			x = x + rotation_offsets[step * 2 + 1]
			y = y + rotation_offsets[step * 2 + 2]
		end
		if invisible then
			shitsprite(0, 0, 0)
		else
			shitsprite(x, y, tid + 1)
		end
	end
end
-- Encode the two fireball slots: type 257 at (x+4, y+4) when the slot's state
-- byte is non-zero, otherwise an empty (0, 0, 0) record.
local function handle_fireballs()
	for slot = 0, 1 do
		local fx, fy = getxy(slot, 0x8D, 0xD5, 0x74, 0xBC)
		if R(0x24 + slot) ~= 0 then
			shitsprite(fx + 4, fy + 4, 257)
		else
			shitsprite(0, 0, 0)
		end
	end
end
-- Encode the four block slots: type 258 at (x+8, y+8) when the slot's state
-- byte is non-zero, otherwise an empty (0, 0, 0) record.
local function handle_blocks()
	for slot = 0, 3 do
		local bx, by = getxy(slot, 0x8F, 0xD7, 0x76, 0xBE)
		if R(0x26 + slot) ~= 0 then
			shitsprite(bx + 8, by + 8, 258)
		else
			shitsprite(0, 0, 0)
		end
	end
end
-- Encode the nine misc-object slots (hammers, coins, score bonus text...).
-- Only states >= 0x30 are encoded, as type state+1 at (x+8, y+8); lower states
-- (including 0 = empty) are, per the original note, non-interactable coin
-- effects and become an empty (0, 0, 0) record.
local function handle_hammers()
	for slot = 0, 8 do
		local hx, hy = getxy(slot, 0x93, 0xDB, 0x7A, 0xC2)
		local kind = R(0x2A + slot)
		if kind >= 0x30 then
			shitsprite(hx + 8, hy + 8, kind + 1)
		else
			shitsprite(0, 0, 0)
		end
	end
end
-- Encode the single object slot whose state byte is 0x33: type state+1 at
-- (x+8, y+8) when non-zero, otherwise an empty (0, 0, 0) record.
local function handle_misc()
	local mx, my = getxy(0, 0x9C, 0xE4, 0x83, 0xCB)
	local kind = R(0x33)
	if kind == 0 then
		shitsprite(0, 0, 0)
	else
		shitsprite(mx + 8, my + 8, kind + 1)
	end
end
-- Encode the visible background as 13 rows x 17 columns of tile ids.
-- tile_input first receives the sub-tile scroll offset (0..15), then every
-- tile id row by row. The tile buffer is read as two 16-column pages
-- (0x500.. and 0x5D0..), and columns wrap modulo 32.
local function handle_tiles()
	--local tile_col = R(0x6A0)
	local fine_x = R(0x73F)
	local first_col = floor(fine_x / 16) + R(0x71A) * 16
	local shift = fine_x % 16
	tile_input[#tile_input + 1] = shift
	for row = 0, 12 do
		local sy = row * 16 + 40
		for cx = 0, 16 do
			local col = (cx + first_col) % 32
			local page_base = (col < 16) and 0x500 or 0x5D0
			local t = R(page_base + row * 16 + col % 16)
			shittile(cx * 16 + 8 - shift, sy, t)
		end
	end
end
2017-06-28 17:14:56 -07:00
-- Start a new trial. The first call saves the current emulator state into
-- `startsave`; every later call reloads it and prints the reward accumulated
-- by the trial that just ended. Then re-baselines the per-trial counters,
-- zeroes the lives counter and clears the `reset` request.
local function doreset()
print("resetting in state:", getstate())
if once then
savestate.load(startsave)
print("end of trial reward:", reward)
print()
else
savestate.save(startsave)
end
once = true
-- bit of a hack:
-- (advance one frame if currently 'loading', before reading the baselines)
if getstate() == 'loading' then advance() end
reward = 0
-- baselines for the per-frame deltas computed in the main loop.
-- coins are two decimal digits: 0x7ED tens, 0x7EE ones.
powerup_old = R(0x754)
status_old = R(0x756)
coins_old = R(0x7ED) * 10 + R(0x7EE)
-- set lives to 0. you only got one shot!
-- unless you get a 1-up, in which case, please continue!
W(0x75A, 0)
emu.frameadvance() -- prevents emulator from quirking up.
reset = false
-- NOTE(review): presumably clears the network's internal state between
-- trials; confirm against the nn module.
if network ~= nil then network:reset() end -- FIXME: hack
end
2017-06-28 02:33:18 -07:00
-- Main driver: one iteration per emulated frame (see advance()).
--   1. while getstate() is listed in `fuckstates`, mash Start and request a
--      trial reset;
--   2. run doreset() if a reset was requested;
--   3. encode the player, objects and background tiles into X;
--   4. accumulate this frame's reward and detect death / game over;
--   5. when enabled and 'playing', let the network press the buttons.
while true do
	while fuckstates[getstate()] do
		--gui.text(120, 124, ("%02X"):format(subroutine()), '#FFFFFF', '#0000003F')
		-- mash the start button until we have control.
		-- TODO: learn this too.
		--local jp = joypad.read(1)
		local jp = {
			up = false,
			down = false,
			left = false,
			right = false,
			A = false,
			B = false,
			select = false,
			-- Start is held on odd frames only, so it alternates press/release.
			start = emu.framecount() % 2 == 1,
		}
		joypad.write(1, jp)
		-- leaving one of these states always starts a fresh trial.
		reset = true
		gui.text(4, 12, getstate(), '#FFFFFF', '#0000003F')
		advance()
		-- bit of a hack:
		while getstate() == "loading" do advance() end
		state_old = getstate()
	end

	if reset then doreset() end

	if not enable_network then
		-- infinite time cheat. super handy for testing.
		if R(0xE) == 8 then
			W(0x7F8, 9)
			W(0x7F9, 9)
			W(0x7FA, 10)
			poketime = true
		elseif poketime then
			poketime = false
			W(0x7F8, 0)
			W(0x7F9, 0)
			W(0x7FA, 1)
		end
		-- infinite lives.
		W(0x75A, 1)
	end

	-- empty input lists without creating a new table.
	for k in pairs(sprite_input) do sprite_input[k] = nil end
	for k in pairs(tile_input) do tile_input[k] = nil end

	-- player. its type is -powerup - 1 (always negative), which keeps it
	-- distinct from the positive object types encoded below.
	-- TODO: add check if mario is playable.
	local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
	local powerup = R(0x754)
	local status = R(0x756)
	shitsprite(x + 8, y + 24, -powerup - 1)
	handle_enemies()
	handle_fireballs()
	-- blocks being hit. not interactable; we don't care!
	--handle_blocks()
	handle_hammers()
	handle_misc()
	handle_tiles()

	-- reward bookkeeping. coins are two decimal digits (0x7ED tens, 0x7EE ones).
	local coins = R(0x7ED) * 10 + R(0x7EE)
	-- NOTE(review): coins_delta and powerup_delta are computed but do not
	-- currently contribute to reward_delta.
	local coins_delta = coins - coins_old
	-- handle wrap-around.
	if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
	-- remember that 0 is big mario and 1 is small mario.
	local powerup_delta = powerup_old - powerup
	-- 2 is fire mario.
	local status_delta = clamp(status - status_old, -1, 1)
	local screen_scroll_delta = R(0x775)
	local flagpole_bonus = subroutine() == 4 and 1 or 0
	local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
	reward = reward + reward_delta
	--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
	--gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
	gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
	gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')

	if getstate() == 'dead' and state_old ~= 'dead' then
		-- R takes only an address (the stray `, 0` argument was dropped);
		-- read the lives counter once and reuse it.
		local lives = R(0x75A)
		print("dead. lives remaining:", lives)
		if lives == 0 then reset = true end
	end
	if getstate() == 'lose' then
		print("ran out of lives.")
		reset = true
	end

	-- every few frames mario stands still, forcibly decrease the timer.
	-- (on roughly timer_loser of the frames where R(0x1D) and R(0x57) are both
	-- 0, the three timer digits at 0x7F8..0x7FA are decremented by one.)
	local timer_loser = 1/5
	if random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
		local timer = R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
		timer = clamp(timer - 1, 0, 400)
		W(0x7F8, floor(timer / 100))
		W(0x7F9, floor((timer / 10) % 10))
		W(0x7FA, floor(timer % 10))
	end

	-- network input: every sprite value, then every tile value, scaled by 1/256.
	local X = {} -- TODO: cache.
	for _, v in ipairs(sprite_input) do X[#X + 1] = v / 256 end
	for _, v in ipairs(tile_input) do X[#X + 1] = v / 256 end
	if network == nil then
		-- the input width is only known once the first frame has been encoded.
		network = make_network(#X, 8); network:reset()
		print("parameters:", network.n_param)
	end
	if enable_network and getstate() == 'playing' then
		local outputs = network:forward(X)
		-- one 2-way softmax output per button head, in head order.
		local softmaxed = {}
		for b = 1, #nn_z do softmaxed[b] = outputs[nn_z[b]] end
		local jp = {
			up = argmax2(softmaxed[1]),
			down = argmax2(softmaxed[2]),
			left = argmax2(softmaxed[3]),
			right = argmax2(softmaxed[4]),
			A = argmax2(softmaxed[5]),
			B = argmax2(softmaxed[6]),
			start = argmax2(softmaxed[7]),
			select = argmax2(softmaxed[8]),
		}
		joypad.write(1, jp)
	end

	coins_old = coins
	powerup_old = powerup
	status_old = status
	state_old = getstate()
	advance()
end