2017-09-07 11:34:19 -07:00
|
|
|
-- hacks for FCEUX being dumb.
-- error() and assert() are wrapped so that a full traceback (with CRLF line
-- endings) is printed before the error is raised.
local _error = error
local _assert = assert

-- replacement for the builtin error() that prints a traceback first.
--   msg:   the error value. it is raised unchanged, whatever its type.
--   level: same meaning as for the builtin: 1 (the default) blames the
--          caller of error, 2 blames the caller's caller, 0 adds no position.
-- never returns.
local function error_(msg, level)
    if level == nil then level = 1 end
    print()
    -- debug.traceback hands non-string messages back untouched, which would
    -- make the :gsub below fail; stringify them for printing only.
    local text = msg
    if text ~= nil and type(text) ~= 'string' then text = tostring(text) end
    -- the extra parentheses drop gsub's second result (the match count),
    -- which print would otherwise output as well.
    print((debug.traceback(text, 1 + level):gsub("\n", "\r\n")))
    -- skip this wrapper's own stack frame so the reported position is the
    -- one the caller asked for.
    if level > 0 then level = level + 1 end
    _error(msg, level)
end
|
2017-09-07 11:34:19 -07:00
|
|
|
-- replacement for the builtin assert() that prints a traceback on failure.
--   first argument:  the value to test.
--   second argument: optional description, printed with the traceback
--                    ("nondescript" when omitted).
-- on success, returns every argument it was given, like the builtin does
-- (so `local a, b = assert(f())` keeps working).
-- on failure, raises "assertion failed!" positioned at the caller.
local function assert_(...)
    local cond, msg = ...
    if cond then return ... end
    local text = msg or "nondescript"
    -- debug.traceback only formats string messages.
    if type(text) ~= 'string' then text = tostring(text) end
    print()
    -- parentheses drop gsub's match count so only the traceback is printed.
    print((debug.traceback(text, 2):gsub("\n", "\r\n")))
    -- level 2: blame the code that called assert, not this wrapper.
    _error("assertion failed!", 2)
end
|
2017-09-07 11:34:19 -07:00
|
|
|
-- install the traceback-printing wrappers over the builtin globals.
-- rawset is used so that no __newindex metamethod on _G gets involved.
for name, wrapper in pairs({ error = error_, assert = assert_ }) do
    rawset(_G, name, wrapper)
end
|
2017-06-28 02:33:18 -07:00
|
|
|
|
2017-09-07 11:34:19 -07:00
|
|
|
-- be strict about globals.
-- from here on, reading or assigning a global that does not exist yet raises
-- an error positioned at the offending code (level 2).
local mt = getmetatable(_G)
if mt == nil then
    mt = {}
    setmetatable(_G, mt)
end

mt.__newindex = function(_, name)
    error("cannot assign undeclared global '" .. tostring(name) .. "'", 2)
end

mt.__index = function(_, name)
    error("cannot use undeclared global '" .. tostring(name) .. "'", 2)
end

-- declares every key/value pair of t as a global, bypassing the checks above.
local function globalize(t)
    for name, value in pairs(t) do
        rawset(_G, name, value)
    end
end
|
2017-06-28 02:33:18 -07:00
|
|
|
|
2017-09-07 11:41:44 -07:00
|
|
|
-- configuration.

--randomseed(11)

-- when true, the overlay is drawn and the network does not drive the game.
local playable_mode = false

-- action selection.
-- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = true   -- take random actions with probability eps.
-- using parameters from DQN... sorta.
local eps_start = 1.0 / 60 -- i think this should be * 1/16 for atari ref.
local eps_stop = 0.1 / 60  -- "
local eps_frames = 1000000 -- frames over which eps goes from start to stop.

local consider_past_rewards = false -- normalize rewards against all past rewards.
local learn_start_select = false    -- give the network start/select outputs too.

-- evolution strategy.
local epoch_trials = 40 -- 24
local learning_rate = 1e-3
local deviation = 1e-2 -- 4e-3

-- ingame timer handling.
local cap_time = 400    -- upper bound for the per-trial time limit.
local timer_loser = 1/3 -- chance per idle frame of losing one timer tick.

local enable_overlay = playable_mode
local enable_network = not playable_mode

local input_size = 281 -- TODO: let the script figure this out for us.

-- values of RAM 0xE, labelled by what the game is doing.
local ok_routines = {
    [0x4] = true, -- sliding down flagpole
    [0x5] = true, -- end of level auto-walk
    [0x7] = true, -- start of level auto-walk
    [0x8] = true, -- normal (in control)
    [0x9] = true, -- acquiring mushroom
    [0xA] = true, -- losing big mario
    [0xB] = true, -- uhh
    [0xC] = true, -- acquiring fireflower
}

-- get_state() results during which start is mashed and a reset is scheduled.
local bad_states = {
    power = true,
    waiting_demo = true,
    playing_demo = true,
    unknown = true,
    lose = true,
}
|
|
|
|
|
|
|
|
-- state.

-- epoch/trial bookkeeping. trial_noise[i] and trial_rewards[i] belong to
-- trial i of the current epoch; base_params are the unperturbed weights.
local epoch_i, trial_i, trials_remaining = 0, 0, 0
local base_params
local trial_noise, trial_rewards = {}, {}

local trial_frames, total_frames = 0, 0

-- used to press start when time has run out while the game is paused.
local force_start, force_start_old = false, false

-- savestate that every trial restarts from.
local startsave = savestate.create(1)

local poketime = false -- whether the playable-mode timer cheat is active.
local max_time         -- timer cap for the current trial; set on reset.

-- per-frame network inputs, emptied and refilled every frame.
local sprite_input, tile_input, extra_input = {}, {}, {}

local reward           -- accumulated reward of the running trial.
local all_rewards = {} -- every raw trial reward seen so far.

-- previous-frame values used to compute deltas.
local powerup_old, status_old, coins_old, score_old

local once = false -- true once startsave has been written.
local reset = true -- request a reset before the next frame is processed.

local state_old = ''
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- localize some stuff.

local print, ipairs, pairs, select = print, ipairs, pairs, select

local abs, floor, ceil = math.abs, math.floor, math.ceil
local min, max = math.min, math.max
local exp, log, sqrt = math.exp, math.log, math.sqrt
local random, randomseed = math.random, math.randomseed

local insert, remove = table.insert, table.remove
local unpack = table.unpack or unpack

-- emulator memory access.
local R = memory.readbyteunsigned
local S = memory.readbyte --signed
local W = memory.writebyte

local band, bor, bxor, bnot = bit.band, bit.bor, bit.bxor, bit.bnot
local lshift, rshift, arshift = bit.lshift, bit.rshift, bit.arshift
local rol, ror = bit.rol, bit.ror
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- utilities.
|
|
|
|
|
|
|
|
-- exclusive-or on truthiness: returns true when exactly one of a and b is
-- truthy, false otherwise. always returns a real boolean.
local function boolean_xor(a, b)
    return (not a) ~= (not b)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- restricts x to the range [l, u]: first raised to at least l, then lowered
-- to at most u (so u wins if the bounds are given in the wrong order).
local function clamp(x, l, u)
    local raised = max(x, l)
    return min(raised, u)
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- linear interpolation: returns a when t <= 0, b when t >= 1, and a
-- proportional blend in between (t is clamped to [0, 1]).
local function lerp(a, b, t)
    local span = b - a
    return a + span * clamp(t, 0, 1)
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- returns the 1-based position of the largest argument (the first one wins
-- on ties), or 0 when called without arguments.
local function argmax(...)
    local count = select("#", ...)
    if count == 0 then return 0 end
    -- pack the arguments once instead of calling select(i, ...) for each
    -- element, and seed the search with the first value rather than a magic
    -- "very small" constant, so inputs below -999999999 are handled too.
    local values = {...}
    local max_i, max_v = 1, values[1]
    for i = 2, count do
        local v = values[i]
        if v > max_v then
            max_i, max_v = i, v
        end
    end
    return max_i
end
|
|
|
|
|
|
|
|
-- two-way argmax over a pair: true when t[1] is strictly greater than t[2].
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- random two-way choice: true when t[1] exceeds a fresh random() draw,
-- i.e. with probability t[1] for t[1] in [0, 1]. t[2] is not consulted.
local function rchoice2(t)
    local p = t[1]
    local roll = random()
    return p > roll
end
|
|
|
|
|
|
|
|
-- coin flip: true when a fresh random() draw is at most 0.5.
-- the parameter is accepted but not used.
local function rbool(t)
    local roll = random()
    return roll <= 0.5
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- removes every key from t in place (so existing references stay valid)
-- and returns t.
local function empty(t)
    -- clearing existing fields during a traversal is permitted.
    for key in next, t do
        t[key] = nil
    end
    return t
end
|
|
|
|
|
|
|
|
-- computes statistics of the sequence x.
-- returns mean, dev where dev is the mean squared deviation from the mean
-- (i.e. the population variance, NOT its square root).
-- both are 0 for an empty sequence.
local function calc_mean_dev(x)
    local count = #x

    local mean = 0
    for _, value in ipairs(x) do
        mean = mean + value / count
    end

    local dev = 0
    for _, value in ipairs(x) do
        local offset = value - mean
        dev = dev + offset * offset / count
    end

    return mean, dev
end
|
|
|
|
|
|
|
|
-- standardizes x to zero mean and unit variance using its own statistics.
--   x:   sequence of numbers.
--   out: table receiving the results; defaults to x (in-place).
-- returns out. when the variance is not positive, values are only centered.
local function normalize(x, out)
    local target = out or x
    local center, variance = calc_mean_dev(x)
    -- avoid dividing by zero when all values are equal.
    if variance <= 0 then variance = 1 end
    local scale = sqrt(variance)
    for i, value in ipairs(x) do
        target[i] = (value - center) / scale
    end
    return target
end
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- standardizes x using the mean and variance of another sequence s.
--   x:   sequence of numbers to transform.
--   s:   sequence the statistics are taken from.
--   out: table receiving the results; defaults to x (in-place).
-- returns out. when the variance of s is not positive, values are only
-- centered.
local function normalize_wrt(x, s, out)
    local target = out or x
    local center, variance = calc_mean_dev(s)
    -- avoid dividing by zero when all reference values are equal.
    if variance <= 0 then variance = 1 end
    local scale = sqrt(variance)
    for i, value in ipairs(x) do
        target[i] = (value - center) / scale
    end
    return target
end
|
|
|
|
|
2017-06-28 02:33:18 -07:00
|
|
|
-- game-agnostic stuff (i.e. the network itself)

-- drop any cached copy first so that edits to nn.lua take effect whenever
-- this script is restarted.
package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")

-- the model, plus the layers make_network() builds it from:
-- nn_x is the input layer, nn_y the last shared hidden layer,
-- nn_z the list of per-button output layers.
local network
local nn_x, nn_y, nn_z
|
|
|
|
-- builds the policy network and stores its layers in nn_x, nn_y and nn_z.
--   input_size: number of input values.
--   buttons:    number of output heads; each head is Dense(2) -> Softmax.
-- returns an nn.Model with one input and `buttons` outputs.
local function make_network(input_size, buttons)
    nn_x = nn.Input(input_size)
    nn_y = nn_x
    nn_z = {}

    -- shared trunk: Dense -> Gelu per entry. flip the flag to try a single
    -- hidden layer as wide as the input instead.
    local single_wide_layer = false
    local hidden_sizes
    if single_wide_layer then
        hidden_sizes = {input_size}
    else
        hidden_sizes = {128, 64, 48}
    end
    for _, size in ipairs(hidden_sizes) do
        nn_y = nn_y:feed(nn.Dense(size))
        nn_y = nn_y:feed(nn.Gelu())
    end

    -- every head branches off the end of the shared trunk.
    for i = 1, buttons do
        local head = nn_y:feed(nn.Dense(2))
        nn_z[i] = head:feed(nn.Softmax())
    end

    return nn.Model({nn_x}, nn_z)
end
|
|
|
|
|
|
|
|
-- and here we go with the game stuff.
|
|
|
|
|
|
|
|
--[[
|
|
|
|
https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
|
|
|
--]]
|
|
|
|
|
|
|
|
-- x/y offsets of a rotating fire bar for each of its 32 rotation steps,
-- stored flat: step r (0-based) has x at index r*2+1 and y at index r*2+2.
-- FIXME: not all of these are pixel-perfect.
local rotation_offsets = {
    -- steps 0x00-0x07
      0, -40,    6, -38,   15, -37,   22, -32,
     28, -28,   32, -22,   37, -14,   39,  -6,
    -- steps 0x08-0x0F
     40,   0,   38,   7,   37,  15,   33,  23,
     27,  29,   22,  33,   14,  37,    6,  39,
    -- steps 0x10-0x17
      0,  41,   -7,  40,  -16,  38,  -22,  34,
    -28,  28,  -34,  23,  -38,  16,  -40,   8,
    -- steps 0x18-0x1F
    -40,  -0,  -40,  -6,  -38, -14,  -34, -22,
    -28, -28,  -22, -32,  -16, -36,   -8, -38,
}
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- reads the ingame timer, stored as three decimal digits at 0x7F8-0x7FA
-- (most significant first), and returns it as one number.
local function get_timer()
    local hundreds, tens, ones = R(0x7F8), R(0x7F9), R(0x7FA)
    return hundreds * 100 + tens * 10 + ones
end
|
2017-06-28 02:33:18 -07:00
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- reads the five decimal score digits stored at 0x7DE-0x7E2 (most
-- significant first) and returns them combined into one number.
local function get_score()
    local score = 0
    for addr = 0x7DE, 0x7E2 do
        score = score * 10 + R(addr)
    end
    return score
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- writes time to the ingame timer as three decimal digits at 0x7F8-0x7FA
-- (most significant first). the hundreds digit is not wrapped, so times of
-- 1000 and above write a value larger than 9 to 0x7F8.
local function set_timer(time)
    local hundreds = floor(time / 100)
    local tens = floor((time / 10) % 10)
    local ones = floor(time % 10)
    W(0x7F8, hundreds)
    W(0x7F9, tens)
    W(0x7FA, ones)
end
|
2017-06-28 02:33:18 -07:00
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- appends one sprite (x, y, type) to sprite_input and, when the overlay is
-- enabled, draws a marker for it.
--   x, y: screen position. positions outside 0 <= x < 256, 0 <= y <= 224
--         are recorded as (0, 0, 0).
--   t:    sprite type; 0 means "empty slot" and is never drawn.
-- always appends exactly three values so that the input size stays fixed.
local function mark_sprite(x, y, t)
    local in_x, in_y, in_t = x, y, t
    if x < 0 or x >= 256 or y < 0 or y > 224 then
        in_x, in_y, in_t = 0, 0, 0
    end
    local n = #sprite_input
    sprite_input[n + 1] = in_x
    sprite_input[n + 2] = in_y
    sprite_input[n + 3] = in_t

    if t == 0 or not enable_overlay then return end

    gui.box(x-4, y-4, x+4, y+4)
    --gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
    gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
    --gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F')
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- appends one tile value to tile_input and, when the overlay is enabled,
-- draws a box with the tile's hex value at its screen position.
--   x, y: screen position of the tile's center (only used for drawing).
--   t:    tile value; 0 is recorded but never drawn.
local function mark_tile(x, y, t)
    tile_input[#tile_input+1] = t
    if t ~= 0 and enable_overlay then
        gui.box(x-8, y-8, x+8, y+8)
        gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
    end
end
|
|
|
|
|
|
|
|
-- converts the stored position of object slot i into screen coordinates.
--   i:            slot index, added to each of the addresses below.
--   x_addr:       base address of the x positions.
--   y_addr:       base address of the y positions.
--   pageloc_addr: optional base address of the horizontal page numbers;
--                 when nil, x is treated as being on the screen's page.
--   hipos_addr:   optional base address of the signed vertical "high"
--                 position bytes; a value of 1 leaves y unchanged
--                 (presumably 1 means "on screen" -- confirm against the
--                 disassembly).
-- returns sx, sy. either may lie outside the visible screen.
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
    -- page and x position of the screen's left edge. only the left edge is
    -- needed; the right-edge bytes at 0x71B/0x71D are deliberately not read.
    local screen_page = R(0x71A)
    local screen_x = R(0x71C)

    local sx = R(x_addr + i)
    local sy = R(y_addr + i)

    if pageloc_addr ~= nil then
        local page = R(pageloc_addr + i)
        sx = sx - screen_x - (screen_page - page) * 256
    else
        sx = sx - screen_x
    end

    if hipos_addr ~= nil then
        local hipos = S(hipos_addr + i)
        sy = sy + (hipos - 1) * 256
    end

    return sx, sy
end
|
|
|
|
|
|
|
|
-- returns the low bit of RAM 0x776 as a number (0 or 1), not a boolean;
-- callers compare the result against 0.
local function paused()
    local flags = R(0x776)
    return band(flags, 1)
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- classifies what the game is currently doing from a handful of RAM bytes.
-- the checks are ordered by priority; the first match wins.
-- returns one of: 'power', 'lagging', 'waiting_demo', 'playing_demo',
-- 'paused', 'world_screen', 'dead', 'win_flagpole', 'win_walking', 'lose',
-- 'win_castle', 'no_control', 'playing', 'loading', 'unknown'.
local function get_state()
    local routine = R(0xE)

    if routine == 0xFF then
        return 'power'
    elseif R(0x774) > 0 then
        return 'lagging'
    elseif R(0x7A2) > 0 then
        return 'waiting_demo'
    elseif R(0x717) > 0 then
        return 'playing_demo'
    -- elseif R(0x770) == 0xFF then return 'power'
    elseif paused() ~= 0 then
        return 'paused'
    elseif routine == 0 then
        return 'world_screen'
    -- elseif R(0x712) == 1 then return 'deadmusic'
    elseif R(0x7CA) == 0x94 then
        return 'dead'
    elseif routine == 4 then
        return 'win_flagpole'
    elseif routine == 5 then
        return 'win_walking'
    elseif routine == 6 then
        return 'lose'
    end

    local mode = R(0x770)
    -- if mode == 0 then return 'not_playing' end
    if mode == 2 then return 'win_castle' end

    local task = R(0x772)
    if task == 2 then return 'no_control' end
    if task == 3 then return 'playing' end

    if mode == 1 then return 'loading' end
    if mode == 3 then return 'lose' end
    return 'unknown'
end
|
|
|
|
|
|
|
|
-- advances emulation by at least one frame, then keeps advancing until the
-- frame is not a lag frame according to both the emulator and RAM 0x774.
local function advance()
    -- always run one frame; repeat while the emulator reports lag.
    repeat
        emu.frameadvance()
    until not emu.lagged()
    -- also lag frames.
    while R(0x774) > 0 do
        emu.frameadvance()
    end
end
|
|
|
|
|
|
|
|
-- disabled debugging aid: change the condition to true to do nothing but
-- print each change of get_state() together with its frame number.
while false do
    local current = get_state()
    if current ~= state_old then
        print(emu.framecount(), current)
        state_old = current
    end
    advance()
end
|
|
|
|
|
|
|
|
-- feeds the six enemy slots (enemies, flagpole, platforms, ...) to
-- mark_sprite. every slot produces exactly one mark_sprite call so the
-- input size stays fixed; slots considered invisible are sent as (0, 0, 0).
local function handle_enemies()
    for i = 0, 5 do
        local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
        x, y = x + 8, y + 16
        local tid = R(0x16 + i)
        local flags = R(0xF + i)
        --local offscr = R(0x3D8 + i)
        local invisible = tid < 0x10 and flags == 0

        -- per-type nudges of the marked point.
        if tid == 0x30 then y = y - 8 end -- flagpole flag
        if tid == 0x31 then y = y - 8 end -- castle flag
        if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
        if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
        -- tid == 0x2D -- bowser: no adjustment yet (TODO: determine head or body)
        if tid == 0x15 then y = y - 12 end -- bowser fire
        if tid == 0x32 then y = y - 8 end -- spring
        -- tid == 0x35 -- toad

        if tid == 0x1D or tid == 0x1B then -- rotating fire bars
            x, y = x - 4, y - 12
            -- this is a mess... gotta find out its rotation and then project.
            -- TODO: handle long fire bars too
            local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
            -- debug text is overlay output like the markers themselves, so
            -- it obeys the same switch.
            if enable_overlay then
                gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
            end
            -- fall back to no offset if the rotation value is outside the
            -- table instead of failing on nil arithmetic.
            local x_off = rotation_offsets[rot*2+1] or 0
            local y_off = rotation_offsets[rot*2+2] or 0
            x, y = x + x_off, y + y_off
        end

        if invisible then
            mark_sprite(0, 0, 0)
        else
            mark_sprite(x, y, tid + 1)
        end
    end
end
|
|
|
|
|
|
|
|
-- feeds the two fireball slots to mark_sprite as type 257.
-- slots whose state byte is 0 are sent as the empty sprite (0, 0, 0).
local function handle_fireballs()
    for slot = 0, 1 do
        local x, y = getxy(slot, 0x8D, 0xD5, 0x74, 0xBC)
        local active = R(0x24 + slot) ~= 0
        if active then
            mark_sprite(x + 4, y + 4, 257)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
|
|
|
|
|
|
|
-- feeds the four block-object slots to mark_sprite as type 258.
-- slots whose state byte is 0 are sent as the empty sprite (0, 0, 0).
local function handle_blocks()
    for slot = 0, 3 do
        local x, y = getxy(slot, 0x8F, 0xD7, 0x76, 0xBE)
        local active = R(0x26 + slot) ~= 0
        if active then
            mark_sprite(x + 8, y + 8, 258)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
|
|
|
|
|
|
|
-- feeds the nine "misc object" slots (hammers, coins, score bonus text...)
-- to mark_sprite with type state + 1.
local function handle_hammers()
    for slot = 0, 8 do
        local x, y = getxy(slot, 0x93, 0xDB, 0x7A, 0xC2)
        local state = R(0x2A + slot)
        -- skip coin effect states. not interactable; we don't care!
        -- only states of 0x30 and above are kept (which also excludes 0).
        if state >= 0x30 then
            mark_sprite(x + 8, y + 8, state + 1)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
|
|
|
|
|
|
|
-- feeds the single object slot at x 0x9C / y 0xE4 / state 0x33 to
-- mark_sprite with type state + 1, or as the empty sprite when its state
-- byte is 0.
local function handle_misc()
    local x, y = getxy(0, 0x9C, 0xE4, 0x83, 0xCB)
    local state = R(0x33)
    if state == 0 then
        mark_sprite(0, 0, 0)
    else
        mark_sprite(x + 8, y + 8, state + 1)
    end
end
|
|
|
|
|
|
|
|
-- feeds the visible tile grid to the network inputs: first the sub-tile
-- scroll offset (0-15), then 13 rows of 17 tiles via mark_tile.
-- tiles come from two 16-column buffers starting at 0x500 and 0x5D0, which
-- together form a 32-column ring indexed by the scroll position.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local scroll_x = R(0x73F)
    local first_col = floor(scroll_x / 16) + R(0x71A) * 16
    local fine_scroll = scroll_x % 16
    tile_input[#tile_input+1] = fine_scroll

    for row = 0, 12 do
        local screen_y = row * 16 + 40
        for column = 0, 16 do
            local ring_col = (column + first_col) % 32
            local buffer = ring_col < 16 and 0x500 or 0x5D0
            local tile = R(buffer + row * 16 + (ring_col % 16))
            local screen_x = column * 16 + 8 - fine_scroll
            mark_tile(screen_x, screen_y, tile)
        end
    end
end
|
|
|
|
|
2017-09-07 11:41:44 -07:00
|
|
|
-- starts a new epoch: snapshots the network's current parameters into
-- base_params, clears the previous epoch's noise and rewards, draws a fresh
-- nn.normal() noise vector for each of the epoch_trials trials, and rewinds
-- the trial counter.
local function prepare_epoch()
    print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
    base_params = network:collect()
    empty(trial_noise)
    empty(trial_rewards)

    local n_params = #base_params
    for trial = 1, epoch_trials do
        local noise = nn.zeros(n_params)
        for j = 1, n_params do
            noise[j] = nn.normal()
        end
        trial_noise[trial] = noise
    end

    trial_i = 0
end
|
|
|
|
|
|
|
|
-- moves on to the next trial of the epoch and loads its weights into the
-- network: base_params plus that trial's noise scaled by sqrt(deviation).
local function load_next_trial()
    trial_i = trial_i + 1
    print('loading trial', trial_i)

    local noise = trial_noise[trial_i]
    local scale = sqrt(deviation)
    -- start from a copy so base_params itself stays untouched.
    local perturbed = nn.copy(base_params)
    for i, param in ipairs(base_params) do
        perturbed[i] = param + scale * noise[i]
    end

    network:distribute(perturbed)
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- turns the finished epoch's trial rewards into a parameter update.
-- the rewards are normalized (in place), used to weight each trial's noise
-- vector, and the scaled sum is added to base_params (in place). the result
-- is loaded into the network and saved, unless the network is disabled.
local function learn_from_epoch()
    print()
    print('rewards:', trial_rewards)

    -- remember the raw rewards before they are normalized below.
    for _, raw in ipairs(trial_rewards) do
        insert(all_rewards, raw)
    end

    if consider_past_rewards then
        normalize_wrt(trial_rewards, all_rewards)
    else
        normalize(trial_rewards)
    end
    --print('normalized:', trial_rewards)

    -- reward-weighted sum of the noise vectors, accumulated trial by trial.
    local step = nn.zeros(#base_params)
    for trial = 1, epoch_trials do
        local weight = trial_rewards[trial]
        for j, n in ipairs(trial_noise[trial]) do
            step[j] = step[j] + weight * n
        end
    end

    local magnitude = learning_rate / deviation
    --print('stepping with magnitude', magnitude)
    -- throw the division from the averaging in there too.
    local altogether = magnitude / epoch_trials
    for j, s in ipairs(step) do
        step[j] = altogether * s
    end

    -- NOTE: calc_mean_dev's second result is a variance, so both the
    -- threshold below and the "stddev" printout refer to the variance.
    local step_mean, step_dev = calc_mean_dev(step)
    if step_dev < 1e-8 then
        -- we didn't get anywhere. step in a random direction.
        print("stepping randomly.")
        local fallback = trial_noise[1]
        local devsqrt = sqrt(deviation)
        for j in ipairs(step) do
            step[j] = devsqrt * fallback[j]
        end
        step_mean, step_dev = calc_mean_dev(step)
    end
    if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
    print("step stddev:", step_dev)

    for j, param in ipairs(base_params) do
        base_params[j] = param + step[j]
    end

    if enable_network then
        network:distribute(base_params)
        network:save()
    else
        print("note: not updating weights in playable mode.")
    end

    print()
end
|
|
|
|
|
|
|
|
-- ends the current trial and starts the next one.
-- records the finished trial's reward, learns from and prepares epochs as
-- needed, re-baselines the reward bookkeeping, restores (or, the first
-- time, creates) the starting savestate and loads the next trial's weights.
local function do_reset()
    local state = get_state()
    -- be a little more descriptive.
    if state == 'dead' and get_timer() == 0 then state = 'timeup' end
    print("resetting in state: "..state..". reward:", reward)

    -- trial_i is 0 until the first trial has been loaded.
    if trial_i > 0 then trial_rewards[trial_i] = reward end

    local epoch_finished = epoch_i == 0 or trial_i == epoch_trials
    if epoch_finished then
        -- there is nothing to learn from before the first epoch.
        if epoch_i > 0 then learn_from_epoch() end
        epoch_i = epoch_i + 1
        prepare_epoch()
    end

    -- bit of a hack:
    if get_state() == 'loading' then advance() end

    reward = 0
    powerup_old = R(0x754)
    status_old = R(0x756)
    coins_old = R(0x7ED) * 10 + R(0x7EE)
    score_old = get_score()

    -- set lives to 0. you only got one shot!
    -- unless you get a 1-up, in which case, please continue!
    W(0x75A, 0)

    -- the time limit starts at 100 and grows with the epoch number, up to
    -- cap_time.
    --max_time = min(log(epoch_i) * 10 + 100, cap_time)
    local grown = 8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100
    max_time = ceil(min(grown, cap_time))

    -- the first reset captures the starting point; later ones return to it.
    if not once then
        savestate.save(startsave)
        once = true
    else
        savestate.load(startsave)
        --print("end of trial reward:", reward)
    end

    emu.frameadvance() -- prevents emulator from quirking up.

    --print()
    load_next_trial()

    reset = false
end
|
|
|
|
|
2017-06-28 21:51:02 -07:00
|
|
|
-- one-time setup: builds and resets the network, power-cycles the emulator
-- and switches it to turbo speed, then tries to load saved weights.
local function init()
    -- six outputs (up, down, left, right, A, B), plus two for start and
    -- select when those are learned as well.
    local n_buttons = 6
    if learn_start_select then n_buttons = 8 end

    network = make_network(input_size, n_buttons)
    network:reset()
    print("parameters:", network.n_param)

    emu.poweron()
    emu.unpause()
    emu.speedmode("turbo")

    -- a failing load is reported but not fatal; the freshly reset weights
    -- are used in that case.
    local loaded, err = pcall(network.load, network)
    if loaded == false then print(err) end
end

init()
|
|
|
|
|
2017-07-05 20:26:27 -07:00
|
|
|
-- stand-in output for buttons the network has no head for; both selection
-- functions (argmax2, rchoice2) map it to "not pressed".
local dummy_softmax_values = {0, 0}

-- main loop: one iteration per emulated (non-lag) frame.
-- each iteration gathers the network inputs, updates the reward, handles
-- trial termination and the timer, and lets the network press buttons.
while true do
    gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')

    while bad_states[get_state()] do
        --gui.text(120, 124, ("%02X"):format(R(0xE)), '#FFFFFF', '#0000003F')
        -- mash the start button until we have control.
        -- TODO: learn this too.
        --local jp = joypad.read(1)
        local jp = {
            up = false,
            down = false,
            left = false,
            right = false,
            A = false,
            B = false,
            select = false,
            start = emu.framecount() % 2 == 1, -- press on odd frames only.
        }
        joypad.write(1, jp)

        reset = true

        advance()
        gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')

        -- bit of a hack:
        while get_state() == "loading" do advance() end
        state_old = get_state()
    end

    if reset then do_reset() end

    if not enable_network then
        -- infinite time cheat. super handy for testing.
        if R(0xE) == 8 then
            set_timer(667)
            poketime = true
        elseif poketime then
            poketime = false
            set_timer(1)
        end

        -- infinite lives.
        W(0x75A, 1)
    end

    empty(sprite_input)
    empty(tile_input)
    empty(extra_input)

    -- player
    -- TODO: check if mario is playable.
    local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = R(0x754)
    local status = R(0x756)
    -- the player is marked with a negative type to tell it apart.
    mark_sprite(x + 8, y + 24, -powerup - 1)

    local vx, vy = S(0x57), S(0x9F)
    insert(extra_input, vx)
    insert(extra_input, vy)

    handle_enemies()
    handle_fireballs()
    -- blocks being hit. not interactable; we don't care!
    --handle_blocks()
    handle_hammers()
    handle_misc()
    handle_tiles()

    local ingame_paused = get_state() == "paused"

    -- the coin/powerup/status deltas are computed but not part of the
    -- current reward; they are kept for the alternative formula below.
    local coins = R(0x7ED) * 10 + R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)
    local screen_scroll_delta = R(0x775)
    local flagpole_bonus = R(0xE) == 4 and 1 or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = get_score() - score_old
    -- the score drops when a savestate is restored; never count that.
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus

    -- TODO: add ingame score to reward.

    if not ingame_paused then reward = reward + reward_delta end

    --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
    --gui.text(4, 22, ("%02X"):format(#tile_input), '#FFFFFF', '#0000003F')
    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    --gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
    gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')

    if get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", R(0x75A))
        if R(0x75A) == 0 then reset = true end
    end
    if get_state() == 'lose' then
        print("ran out of lives.")
        reset = true
    end

    -- pausing ends the trial with a penalty.
    --if ingame_paused then reward = reward - 1 end
    if ingame_paused then reward = reward - 402; reset = true end

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    local timer = get_timer()
    if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
        timer = timer - 1
    end
    timer = clamp(timer, 0, max_time)
    if enable_network then
        set_timer(timer)
    end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    -- assemble the network input: sprites, tiles, extras, all scaled by 1/256.
    local X = {} -- TODO: cache.
    for _, v in ipairs(sprite_input) do X[#X+1] = v / 256 end
    for _, v in ipairs(tile_input) do X[#X+1] = v / 256 end
    for _, v in ipairs(extra_input) do X[#X+1] = v / 256 end
    if #X ~= input_size then error("input size should be: "..tostring(#X)) end

    -- the network only acts when it is enabled. the parentheses matter:
    -- without them `and` binds tighter than `or`, and a paused game would
    -- hand the controller to the network even in playable mode.
    if enable_network and (get_state() == 'playing' or ingame_paused) then
        local choose = deterministic and argmax2 or rchoice2

        local outputs = network:forward(X)

        -- TODO: predict the *rewards* of all possible actions?
        -- that's how DQN seems to work anyway.
        -- ah, but A3C just returns probabilities,
        -- besides the critic?

        local softmaxed = {
            outputs[nn_z[1]],
            outputs[nn_z[2]],
            outputs[nn_z[3]],
            outputs[nn_z[4]],
            outputs[nn_z[5]],
            outputs[nn_z[6]],
            learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
            learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
        }

        local jp = {
            up = choose(softmaxed[1]),
            down = choose(softmaxed[2]),
            left = choose(softmaxed[3]),
            right = choose(softmaxed[4]),
            A = choose(softmaxed[5]),
            B = choose(softmaxed[6]),
            start = choose(softmaxed[7]),
            select = choose(softmaxed[8]),
        }

        if det_epsilon then
            -- eps anneals from eps_start to eps_stop over eps_frames frames.
            local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
            for k, v in pairs(jp) do
                -- start/select are only randomized when they are learned.
                local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
                if random() < eps and ss_ok then jp[k] = rbool() end
            end
        end

        if force_start then
            -- press start only on the second consecutive forced frame, so
            -- the button is seen as a fresh press.
            jp = {
                up = false,
                down = false,
                left = false,
                right = false,
                A = false,
                B = false,
                start = force_start_old,
                select = false,
            }
        end

        joypad.write(1, jp)
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = get_state()
    score_old = get_score()

    -- count processed frames; this drives the epsilon schedule above, which
    -- would otherwise stay at eps_start forever.
    total_frames = total_frames + 1

    advance()
end
|