refactor game and utility functions
This commit is contained in:
parent
7f34de8e7c
commit
a836314b8b
5 changed files with 549 additions and 481 deletions
|
@ -66,41 +66,6 @@ local gcfg = {
|
|||
select = false, start = false, B = false, A = false,
|
||||
},
|
||||
},
|
||||
|
||||
rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||
0, -40, -- 0x00
|
||||
6, -38,
|
||||
15, -37,
|
||||
22, -32,
|
||||
28, -28,
|
||||
32, -22,
|
||||
37, -14,
|
||||
39, -6,
|
||||
40, 0, -- 0x08
|
||||
38, 7,
|
||||
37, 15,
|
||||
33, 23,
|
||||
27, 29,
|
||||
22, 33,
|
||||
14, 37,
|
||||
6, 39,
|
||||
0, 41, -- 0x10
|
||||
-7, 40,
|
||||
-16, 38,
|
||||
-22, 34,
|
||||
-28, 28,
|
||||
-34, 23,
|
||||
-38, 16,
|
||||
-40, 8,
|
||||
-40, -0, -- 0x18
|
||||
-40, -6,
|
||||
-38, -14,
|
||||
-34, -22,
|
||||
-28, -28,
|
||||
-22, -32,
|
||||
-16, -36,
|
||||
-8, -38,
|
||||
},
|
||||
}
|
||||
|
||||
return setmetatable(gcfg, {
|
||||
|
|
493
main.lua
493
main.lua
|
@ -31,10 +31,6 @@ local startsave = savestate.create(1)
|
|||
local poketime = false
|
||||
local max_time
|
||||
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
local extra_input = {}
|
||||
|
||||
local jp
|
||||
|
||||
local screen_scroll_delta
|
||||
|
@ -74,9 +70,6 @@ local insert = table.insert
|
|||
local remove = table.remove
|
||||
local unpack = table.unpack or unpack
|
||||
local sort = table.sort
|
||||
local R = memory.readbyteunsigned
|
||||
local S = memory.readbyte --signed
|
||||
local W = memory.writebyte
|
||||
|
||||
local band = bit.band
|
||||
local bor = bit.bor
|
||||
|
@ -90,6 +83,16 @@ local ror = bit.ror
|
|||
|
||||
local gui = gui
|
||||
|
||||
local util = require("util")
|
||||
local argmax = util.argmax
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local clamp = util.clamp
|
||||
local copy = util.copy
|
||||
local empty = util.empty
|
||||
local lerp = util.lerp
|
||||
local softchoice = util.softchoice
|
||||
local unperturbed_rank = util.unperturbed_rank
|
||||
|
||||
-- utilities.
|
||||
|
||||
local log_map = {
|
||||
|
@ -126,97 +129,6 @@ local function log_csv(t)
|
|||
f:close()
|
||||
end
|
||||
|
||||
local function boolean_xor(a, b)
|
||||
if a and b then return false end
|
||||
if not a and not b then return false end
|
||||
return true
|
||||
end
|
||||
|
||||
local function signbyte(x)
|
||||
if x >= 128 then x = 256 - x end
|
||||
return x
|
||||
end
|
||||
|
||||
local _invlog2 = 1 / log(2)
|
||||
local function log2(x) return log(x) * _invlog2 end
|
||||
|
||||
local function clamp(x, l, u) return min(max(x, l), u) end
|
||||
|
||||
local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
|
||||
|
||||
local function argmax(...)
|
||||
local max_i = 0
|
||||
local max_v = -999999999
|
||||
for i=1, select("#", ...) do
|
||||
local v = select(i, ...)
|
||||
if v > max_v then
|
||||
max_i = i
|
||||
max_v = v
|
||||
end
|
||||
end
|
||||
return max_i
|
||||
end
|
||||
|
||||
local function softchoice(...)
|
||||
local t = random()
|
||||
local psum = 0
|
||||
for i=1, select("#", ...) do
|
||||
local p = select(i, ...)
|
||||
psum = psum + p
|
||||
if t < psum then
|
||||
return i
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function argmax2(t)
|
||||
return t[1] > t[2]
|
||||
end
|
||||
|
||||
local function rchoice2(t)
|
||||
return t[1] > random()
|
||||
end
|
||||
|
||||
local function rbool()
|
||||
return 0.5 >= random()
|
||||
end
|
||||
|
||||
local function empty(t)
|
||||
for k, _ in pairs(t) do t[k] = nil end
|
||||
return t
|
||||
end
|
||||
|
||||
local function calc_mean_dev(x)
|
||||
local mean = 0
|
||||
for i, v in ipairs(x) do
|
||||
mean = mean + v / #x
|
||||
end
|
||||
|
||||
local dev = 0
|
||||
for i, v in ipairs(x) do
|
||||
local delta = v - mean
|
||||
dev = dev + delta * delta / #x
|
||||
end
|
||||
|
||||
return mean, sqrt(dev)
|
||||
end
|
||||
|
||||
local function normalize(x, out)
|
||||
out = out or x
|
||||
local mean, dev = calc_mean_dev(x)
|
||||
if dev <= 0 then dev = 1 end
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / dev end
|
||||
return out
|
||||
end
|
||||
|
||||
local function normalize_wrt(x, s, out)
|
||||
out = out or x
|
||||
local mean, dev = calc_mean_dev(s)
|
||||
if dev <= 0 then dev = 1 end
|
||||
for i, v in ipairs(x) do out[i] = (v - mean) / dev end
|
||||
return out
|
||||
end
|
||||
|
||||
-- network parameters.
|
||||
|
||||
package.loaded['nn'] = nil -- DEBUG
|
||||
|
@ -247,219 +159,10 @@ local function make_network(input_size)
|
|||
end
|
||||
|
||||
-- and here we go with the game stuff.
|
||||
-- which was all refactored out, so this comment looks a little silly now.
|
||||
|
||||
-- disassembly used for reference:
|
||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
|
||||
local function get_timer()
|
||||
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||
end
|
||||
|
||||
local function get_score()
|
||||
return R(0x7DE) * 10000 +
|
||||
R(0x7DF) * 1000 +
|
||||
R(0x7E0) * 100 +
|
||||
R(0x7E1) * 10 +
|
||||
R(0x7E2)
|
||||
end
|
||||
|
||||
local function set_timer(time)
|
||||
W(0x7F8, floor(time / 100))
|
||||
W(0x7F9, floor((time / 10) % 10))
|
||||
W(0x7FA, floor(time % 10))
|
||||
end
|
||||
|
||||
local function mark_sprite(x, y, t)
|
||||
if x < 0 or x >= 256 or y < 0 or y > 224 then
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
sprite_input[#sprite_input+1] = 0
|
||||
else
|
||||
sprite_input[#sprite_input+1] = x
|
||||
sprite_input[#sprite_input+1] = y
|
||||
sprite_input[#sprite_input+1] = t
|
||||
end
|
||||
if t == 0 then return end
|
||||
if cfg.enable_overlay then
|
||||
gui.box(x-4, y-4, x+4, y+4)
|
||||
--gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
|
||||
gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
|
||||
--gui.text(x-5, y-3+9, ("%02X"):format(x), '#FFFFFF', '#0000003F')
|
||||
end
|
||||
end
|
||||
|
||||
local function mark_tile(x, y, t)
|
||||
tile_input[#tile_input+1] = t
|
||||
if t == 0 then return end
|
||||
if cfg.enable_overlay then
|
||||
gui.box(x-8, y-8, x+8, y+8)
|
||||
gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
|
||||
end
|
||||
end
|
||||
|
||||
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
|
||||
local spl_l = R(0x71A)
|
||||
local spl_r = R(0x71B)
|
||||
local sx_l = R(0x71C)
|
||||
local sx_r = R(0x71D)
|
||||
|
||||
local x = R(x_addr + i)
|
||||
local y = R(y_addr + i)
|
||||
local sx, sy = x, y
|
||||
if pageloc_addr ~= nil then
|
||||
local page = R(pageloc_addr + i)
|
||||
sx = sx - sx_l - (spl_l - page) * 256
|
||||
else
|
||||
sx = sx - sx_l
|
||||
end
|
||||
if hipos_addr ~= nil then
|
||||
local hipos = S(hipos_addr + i)
|
||||
sy = sy + (signbyte(hipos) - 1) * 256
|
||||
end
|
||||
|
||||
return sx, sy
|
||||
end
|
||||
|
||||
local function paused() return band(R(0x776), 1) end
|
||||
|
||||
local function get_state()
|
||||
if R(0xE) == 0xFF then return 'power' end
|
||||
if R(0x774) > 0 then return 'lagging' end
|
||||
if R(0x7A2) > 0 then return 'waiting_demo' end
|
||||
if R(0x717) > 0 then return 'playing_demo' end
|
||||
-- if R(0x770) == 0xFF then return 'power' end
|
||||
if paused() ~= 0 then return 'paused' end
|
||||
if R(0xE) == 0 then return 'world_screen' end
|
||||
-- if R(0x712) == 1 then return 'deadmusic' end
|
||||
if R(0x7CA) == 0x94 then return 'dead' end
|
||||
if R(0xE) == 4 then return 'win_flagpole' end
|
||||
if R(0xE) == 5 then return 'win_walking' end
|
||||
if R(0xE) == 6 then return 'lose' end
|
||||
-- if R(0x770) == 0 then return 'not_playing' end
|
||||
if R(0x770) == 2 then return 'win_castle' end
|
||||
if R(0x772) == 2 then return 'no_control' end
|
||||
if R(0x772) == 3 then return 'playing' end
|
||||
if R(0x770) == 1 then return 'loading' end
|
||||
if R(0x770) == 3 then return 'lose' end
|
||||
return 'unknown'
|
||||
end
|
||||
|
||||
local function advance()
|
||||
emu.frameadvance()
|
||||
while emu.lagged() do emu.frameadvance() end -- skip lag frames.
|
||||
while R(0x774) > 0 do emu.frameadvance() end -- also lag frames.
|
||||
end
|
||||
|
||||
local function handle_enemies()
|
||||
-- enemies, flagpole
|
||||
for i = 0, 5 do
|
||||
local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
|
||||
x, y = x + 8, y + 16
|
||||
local tid = R(0x16 + i)
|
||||
local flags = R(0xF + i)
|
||||
--local offscr = R(0x3D8 + i)
|
||||
local invisible = tid < 0x10 and flags == 0
|
||||
if tid == 0x30 then y = y - 8 end -- flagpole flag
|
||||
if tid == 0x31 then y = y - 8 end -- castle flag
|
||||
if tid == 0x16 then x, y = x - 4, y - 12 end -- fireworks
|
||||
if tid >= 0x24 and tid <= 0x29 then x, y = x + 16, y - 12 end -- moving platforms
|
||||
if tid == 0x2D then x, y = x, y end -- bowser (TODO: determine head or body)
|
||||
if tid == 0x15 then x, y = x, y - 12 end -- bowser fire
|
||||
if tid == 0x32 then x, y = x, y - 8 end -- spring
|
||||
-- tid == 0x35 -- toad
|
||||
if tid == 0x1D or tid == 0x1B then -- rotating fire bars
|
||||
x, y = x - 4, y - 12
|
||||
-- this is a mess... gotta find out its rotation and then project.
|
||||
-- TODO: handle long fire bars too
|
||||
local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
|
||||
gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
|
||||
local x_off, y_off = gcfg.rotation_offsets[rot*2+1], gcfg.rotation_offsets[rot*2+2]
|
||||
x, y = x + x_off, y + y_off
|
||||
end
|
||||
if invisible then
|
||||
mark_sprite(0, 0, 0)
|
||||
else
|
||||
mark_sprite(x, y, tid + 1)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_fireballs()
|
||||
for i = 0, 1 do
|
||||
local x, y = getxy(i, 0x8D, 0xD5, 0x74, 0xBC)
|
||||
x, y = x + 4, y + 4
|
||||
local state = R(0x24 + i)
|
||||
local invisible = state == 0
|
||||
if invisible then
|
||||
mark_sprite(0, 0, 0)
|
||||
else
|
||||
mark_sprite(x, y, 257)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_blocks()
|
||||
for i = 0, 3 do
|
||||
local x, y = getxy(i, 0x8F, 0xD7, 0x76, 0xBE)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x26 + i)
|
||||
local invisible = state == 0
|
||||
if invisible then
|
||||
mark_sprite(0, 0, 0)
|
||||
else
|
||||
mark_sprite(x, y, 258)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_hammers()
|
||||
-- hammers, coins, score bonus text...
|
||||
for i = 0, 8 do
|
||||
local x, y = getxy(i, 0x93, 0xDB, 0x7A, 0xC2)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x2A + i)
|
||||
-- skip coin effect states. not interactable; we don't care!
|
||||
if state ~= 0 and state >= 0x30 then
|
||||
mark_sprite(x, y, state + 1)
|
||||
else
|
||||
mark_sprite(0, 0, 0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_misc()
|
||||
for i = 0, 0 do
|
||||
local x, y = getxy(i, 0x9C, 0xE4, 0x83, 0xCB)
|
||||
x, y = x + 8, y + 8
|
||||
local state = R(0x33 + i)
|
||||
if state ~= 0 then
|
||||
mark_sprite(x, y, state + 1)
|
||||
else
|
||||
mark_sprite(0, 0, 0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function handle_tiles()
|
||||
--local tile_col = R(0x6A0)
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||
for y = 0, 12 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
local t
|
||||
if col < 16 then
|
||||
t = R(0x500 + y * 16 + (col % 16))
|
||||
else
|
||||
t = R(0x5D0 + y * 16 + (col % 16))
|
||||
end
|
||||
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||
local sy = y * 16 + 40
|
||||
mark_tile(sx, sy, t)
|
||||
end
|
||||
end
|
||||
end
|
||||
local game = require("smb")
|
||||
game.overlay = cfg.enable_overlay
|
||||
|
||||
-- learning and evaluation.
|
||||
|
||||
|
@ -514,7 +217,7 @@ local function load_next_pair()
|
|||
trial_neg = true
|
||||
end
|
||||
|
||||
local W = nn.copy(base_params)
|
||||
local W = copy(base_params)
|
||||
|
||||
if trial_i > 0 then
|
||||
if trial_neg then
|
||||
|
@ -544,7 +247,7 @@ end
|
|||
local function load_next_trial()
|
||||
if cfg.negate_trials then return load_next_pair() end
|
||||
trial_i = trial_i + 1
|
||||
local W = nn.copy(base_params)
|
||||
local W = copy(base_params)
|
||||
if trial_i == 0 and not cfg.unperturbed_trial then
|
||||
trial_i = 1
|
||||
end
|
||||
|
@ -560,41 +263,6 @@ local function load_next_trial()
|
|||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function fitness_shaping(rewards)
|
||||
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
||||
local decreasing = nn.copy(rewards)
|
||||
sort(decreasing, function(a, b) return a > b end)
|
||||
local shaped_returns = {}
|
||||
local lamb = #rewards
|
||||
|
||||
local denom = 0
|
||||
for i, v in ipairs(rewards) do
|
||||
local l = log2(lamb / 2 + 1)
|
||||
local r = log2(nn.indexof(decreasing, v))
|
||||
denom = denom + max(0, l - r)
|
||||
end
|
||||
|
||||
for i, v in ipairs(rewards) do
|
||||
local l = log2(lamb / 2 + 1)
|
||||
local r = log2(nn.indexof(decreasing, v))
|
||||
local numer = max(0, l - r)
|
||||
insert(shaped_returns, numer / denom + 1 / lamb)
|
||||
end
|
||||
|
||||
return shaped_returns
|
||||
end
|
||||
|
||||
local function unperturbed_rank(rewards, unperturbed_reward)
|
||||
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
||||
local nth_place = 1
|
||||
for i, v in ipairs(rewards) do
|
||||
if v > unperturbed_reward then
|
||||
nth_place = nth_place + 1
|
||||
end
|
||||
end
|
||||
return nth_place
|
||||
end
|
||||
|
||||
local function learn_from_epoch()
|
||||
print()
|
||||
--print('rewards:', trial_rewards)
|
||||
|
@ -632,7 +300,7 @@ local function learn_from_epoch()
|
|||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
else
|
||||
best_rewards = nn.copy(trial_rewards)
|
||||
best_rewards = copy(trial_rewards)
|
||||
end
|
||||
|
||||
local indices = {}
|
||||
|
@ -750,15 +418,6 @@ local function learn_from_epoch()
|
|||
test_trial = current_cost or 0,
|
||||
}
|
||||
|
||||
-- trying a heuristic...
|
||||
--[[
|
||||
if delta_std < trial_std then
|
||||
cfg.deviation = cfg.deviation * 0.933
|
||||
-- this one might be bad...
|
||||
cfg.weight_decay = cfg.weight_decay * 0.933
|
||||
end
|
||||
--]]
|
||||
|
||||
if cfg.enable_network then
|
||||
network:distribute(base_params)
|
||||
network:save(cfg.params_fn)
|
||||
|
@ -786,9 +445,9 @@ local function joypad_mash(button)
|
|||
end
|
||||
|
||||
local function do_reset()
|
||||
local state = get_state()
|
||||
local state = game.get_state()
|
||||
-- be a little more descriptive.
|
||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
|
||||
|
||||
if trial_i >= 0 and cfg.defer_prints then
|
||||
if trial_i == 0 then
|
||||
|
@ -825,20 +484,20 @@ local function do_reset()
|
|||
prepare_epoch()
|
||||
end
|
||||
|
||||
if get_state() == 'loading' then advance() end -- kind of a hack.
|
||||
if game.get_state() == 'loading' then game.advance() end -- kind of a hack.
|
||||
reward = 0
|
||||
powerup_old = R(0x754)
|
||||
status_old = R(0x756)
|
||||
coins_old = R(0x7ED) * 10 + R(0x7EE)
|
||||
score_old = get_score()
|
||||
powerup_old = game.R(0x754)
|
||||
status_old = game.R(0x756)
|
||||
coins_old = game.R(0x7ED) * 10 + game.R(0x7EE)
|
||||
score_old = game.get_score()
|
||||
|
||||
-- set number of lives. (mario gets n+1 chances)
|
||||
W(0x75A, cfg.starting_lives)
|
||||
game.W(0x75A, cfg.starting_lives)
|
||||
|
||||
if cfg.start_big then
|
||||
-- make mario "super".
|
||||
W(0x754, 0)
|
||||
W(0x756, 1)
|
||||
game.W(0x754, 0)
|
||||
game.W(0x756, 1)
|
||||
end
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
|
||||
|
@ -896,20 +555,20 @@ local function prepare_reset()
|
|||
end
|
||||
|
||||
local function doit(dummy)
|
||||
local ingame_paused = get_state() == "paused"
|
||||
local ingame_paused = game.get_state() == "paused"
|
||||
|
||||
-- every few frames mario stands still, forcibly decrease the timer.
|
||||
-- this includes having the game paused.
|
||||
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
||||
-- well, that shouldn't happen anymore now that i've disabled left+right.
|
||||
local timer = get_timer()
|
||||
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
local timer = game.get_timer()
|
||||
if ingame_paused or random() > 1 - cfg.timer_loser and game.R(0x1D) == 0 and game.R(0x57) == 0 then
|
||||
timer = timer - 1
|
||||
end
|
||||
if not cfg.playback_mode then
|
||||
timer = clamp(timer, 0, max_time)
|
||||
if cfg.enable_network then
|
||||
set_timer(timer)
|
||||
game.set_timer(timer)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -918,7 +577,7 @@ local function doit(dummy)
|
|||
local tf2 = (total_frames - tf0 - tf1) / 1000000
|
||||
gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')
|
||||
|
||||
screen_scroll_delta = screen_scroll_delta + R(0x775)
|
||||
screen_scroll_delta = screen_scroll_delta + game.R(0x775)
|
||||
|
||||
if dummy == true then
|
||||
-- don't invoke AI this frame. (keep holding the old inputs)
|
||||
|
@ -926,37 +585,35 @@ local function doit(dummy)
|
|||
return
|
||||
end
|
||||
|
||||
empty(sprite_input)
|
||||
empty(tile_input)
|
||||
empty(extra_input)
|
||||
empty(game.sprite_input)
|
||||
empty(game.tile_input)
|
||||
empty(game.extra_input)
|
||||
|
||||
-- TODO: check if mario is in a playable state.
|
||||
local x, y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
local powerup = R(0x754)
|
||||
local status = R(0x756)
|
||||
mark_sprite(x + 8, y + 24, -powerup - 1)
|
||||
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
local powerup = game.R(0x754)
|
||||
local status = game.R(0x756)
|
||||
game.mark_sprite(x + 8, y + 24, -powerup - 1)
|
||||
|
||||
local vx, vy = S(0x57), S(0x9F)
|
||||
-- i shouldn't need to do this if it's signed, but apparently...
|
||||
insert(extra_input, signbyte(vx) * 16)
|
||||
insert(extra_input, signbyte(vy) * 16)
|
||||
local vx, vy = game.S(0x57), game.S(0x9F)
|
||||
insert(game.extra_input, vx * 16)
|
||||
insert(game.extra_input, vy * 16)
|
||||
|
||||
if cfg.time_inputs then
|
||||
for i=2,5 do
|
||||
local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181
|
||||
insert(extra_input, v)
|
||||
insert(game.extra_input, v)
|
||||
end
|
||||
end
|
||||
|
||||
handle_enemies()
|
||||
handle_fireballs()
|
||||
-- blocks being hit. not interactable; we don't care!
|
||||
--handle_blocks()
|
||||
handle_hammers()
|
||||
handle_misc()
|
||||
handle_tiles()
|
||||
game.handle_enemies()
|
||||
game.handle_fireballs()
|
||||
--game.handle_blocks() -- blocks being hit. not interactable; we don't care!
|
||||
game.handle_hammers()
|
||||
game.handle_misc()
|
||||
game.handle_tiles()
|
||||
|
||||
local coins = R(0x7ED) * 10 + R(0x7EE)
|
||||
local coins = game.R(0x7ED) * 10 + game.R(0x7EE)
|
||||
local coins_delta = coins - coins_old
|
||||
-- handle wrap-around.
|
||||
if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
|
||||
|
@ -964,9 +621,9 @@ local function doit(dummy)
|
|||
local powerup_delta = powerup_old - powerup
|
||||
-- 2 is fire mario.
|
||||
local status_delta = clamp(status - status_old, -1, 1)
|
||||
local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0
|
||||
local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0
|
||||
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||
local score_delta = get_score() - score_old
|
||||
local score_delta = game.get_score() - score_old
|
||||
if score_delta < 0 then score_delta = 0 end
|
||||
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
||||
screen_scroll_delta = 0
|
||||
|
@ -978,11 +635,11 @@ local function doit(dummy)
|
|||
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||
gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
|
||||
|
||||
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||
--print("dead. lives remaining:", R(0x75A, 0))
|
||||
if R(0x75A, 0) == 0 then prepare_reset() end
|
||||
if game.get_state() == 'dead' and state_old ~= 'dead' then
|
||||
--print("dead. lives remaining:", game.R(0x75A, 0))
|
||||
if game.R(0x75A, 0) == 0 then prepare_reset() end
|
||||
end
|
||||
if get_state() == 'lose' then
|
||||
if game.get_state() == 'lose' then
|
||||
-- this shouldn't happen if we catch the deaths as above.
|
||||
print("ran out of lives.")
|
||||
if not cfg.playback_mode then prepare_reset() end
|
||||
|
@ -997,25 +654,25 @@ local function doit(dummy)
|
|||
force_start = ingame_paused and timer == 0
|
||||
|
||||
local X = {}
|
||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(game.sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
|
||||
nn.reshape(X, 1, gcfg.input_size)
|
||||
nn.reshape(tile_input, 1, gcfg.tile_count)
|
||||
nn.reshape(game.tile_input, 1, gcfg.tile_count)
|
||||
|
||||
trial_frames = trial_frames + cfg.frameskip
|
||||
if cfg.enable_network and get_state() == 'playing' or ingame_paused then
|
||||
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + cfg.frameskip
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||
|
||||
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
||||
if cfg.det_epsilon and random() < eps then
|
||||
local i = floor(random() * #gcfg.jp_lut) + 1
|
||||
jp = nn.copy(gcfg.jp_lut[i], jp)
|
||||
jp = copy(gcfg.jp_lut[i], jp)
|
||||
else
|
||||
local choose = cfg.deterministic and argmax or softchoice
|
||||
local ind = choose(unpack(outputs[nn_z]))
|
||||
jp = nn.copy(gcfg.jp_lut[ind], jp)
|
||||
jp = copy(gcfg.jp_lut[ind], jp)
|
||||
end
|
||||
|
||||
if force_start then
|
||||
|
@ -1036,41 +693,41 @@ local function doit(dummy)
|
|||
powerup_old = powerup
|
||||
status_old = status
|
||||
force_start_old = force_start
|
||||
state_old = get_state()
|
||||
score_old = get_score()
|
||||
state_old = game.get_state()
|
||||
score_old = game.get_score()
|
||||
end
|
||||
|
||||
init()
|
||||
|
||||
while true do
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while gcfg.bad_states[get_state()] do
|
||||
while gcfg.bad_states[game.get_state()] do
|
||||
-- mash the start button until we have control.
|
||||
joypad_mash('start')
|
||||
prepare_reset()
|
||||
|
||||
advance()
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
game.advance()
|
||||
gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while get_state() == "loading" do advance() end -- kind of a hack.
|
||||
state_old = get_state()
|
||||
while game.get_state() == "loading" do game.advance() end -- kind of a hack.
|
||||
state_old = game.get_state()
|
||||
end
|
||||
|
||||
if reset then do_reset() end
|
||||
|
||||
if not cfg.enable_network then
|
||||
-- infinite time cheat. super handy for testing.
|
||||
if R(0xE) == 8 then
|
||||
set_timer(667)
|
||||
if game.R(0xE) == 8 then
|
||||
game.set_timer(667)
|
||||
poketime = true
|
||||
elseif poketime then
|
||||
poketime = false
|
||||
set_timer(1)
|
||||
game.set_timer(1)
|
||||
end
|
||||
|
||||
-- infinite lives.
|
||||
W(0x75A, 1)
|
||||
game.W(0x75A, 1)
|
||||
end
|
||||
|
||||
-- FIXME: if the game lags then we might miss our frame to change inputs!
|
||||
|
@ -1081,5 +738,5 @@ while true do
|
|||
-- jp might still be nil if we're not ingame or we're not playing.
|
||||
if jp ~= nil then joypad.write(1, jp) end
|
||||
|
||||
advance()
|
||||
game.advance()
|
||||
end
|
||||
|
|
36
nn.lua
36
nn.lua
|
@ -21,28 +21,11 @@ local unpack = table.unpack or unpack
|
|||
|
||||
local Base = require("Base")
|
||||
|
||||
local util = require("util")
|
||||
|
||||
-- hacks
|
||||
local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) end
|
||||
|
||||
-- general utilities
|
||||
|
||||
local function copy(t, out) -- shallow copy
|
||||
assert(type(t) == "table")
|
||||
local out = out or {}
|
||||
for k, v in pairs(t) do out[k] = v end
|
||||
return out
|
||||
end
|
||||
|
||||
local function indexof(t, a)
|
||||
assert(type(t) == "table")
|
||||
for k, v in pairs(t) do if v == a then return k end end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function contains(t, a)
|
||||
return indexof(t, a) ~= nil
|
||||
end
|
||||
|
||||
-- math utilities
|
||||
|
||||
local function prod(x, ...)
|
||||
|
@ -198,7 +181,7 @@ end
|
|||
|
||||
-- Build a zero-filled buffer whose shape is `shape` with the batch size `bs`
-- prepended as the first dimension.
-- Returns nil when no batch size is given.
-- NOTE(review): assumes `zeros` accepts a shape table -- confirm against its
-- definition elsewhere in nn.lua.
local function cache(bs, shape)
    if bs == nil then return nil end
    local fullshape = util.copy(shape)
    -- table.insert takes (table, position, value). the previous call passed
    -- (fullshape, bs, 1), which inserted the value 1 at position `bs` and
    -- only gave the intended result when bs happened to be 1.
    insert(fullshape, 1, bs)
    return zeros(fullshape)
end
|
||||
|
@ -274,19 +257,19 @@ local function traverse(node_in, node_out, nodes, dummy_mode)
|
|||
if seen_up[node] then
|
||||
local all_parents_added = true
|
||||
for _, parent in ipairs(node.parents) do
|
||||
if not contains(nodes, parent) then
|
||||
if not util.contains(nodes, parent) then
|
||||
all_parents_added = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if not contains(nodes, node) and all_parents_added then
|
||||
if not util.contains(nodes, node) and all_parents_added then
|
||||
insert(nodes, node)
|
||||
end
|
||||
for _, child in ipairs(node.children) do insert(q, child) end
|
||||
end
|
||||
end
|
||||
|
||||
if dummy_mode then remove(nodes, indexof(nodes, node_in)) end
|
||||
if dummy_mode then remove(nodes, util.indexof(nodes, node_in)) end
|
||||
|
||||
return nodes
|
||||
end
|
||||
|
@ -705,7 +688,7 @@ function Model:forward(inputs)
|
|||
local outputs = {}
|
||||
for i, node in ipairs(self.nodes) do
|
||||
--print(i, node.name)
|
||||
if contains(self.nodes_in, node) then
|
||||
if util.contains(self.nodes_in, node) then
|
||||
local X = inputs[node]
|
||||
assert(X ~= nil, ("missing input for node %s"):format(node.name))
|
||||
assert(X.shape, ("missing shape for node %s"):format(node.name))
|
||||
|
@ -713,7 +696,7 @@ function Model:forward(inputs)
|
|||
else
|
||||
values[node] = node:propagate(values)
|
||||
end
|
||||
if contains(self.nodes_out, node) then
|
||||
if util.contains(self.nodes_out, node) then
|
||||
outputs[node] = values[node]
|
||||
end
|
||||
end
|
||||
|
@ -805,9 +788,6 @@ function Model:load(fn)
|
|||
end
|
||||
|
||||
return {
|
||||
copy = copy,
|
||||
indexof = indexof,
|
||||
contains = contains,
|
||||
prod = prod,
|
||||
uniform = uniform,
|
||||
normal = normal,
|
||||
|
|
290
smb.lua
Normal file
290
smb.lua
Normal file
|
@ -0,0 +1,290 @@
|
|||
-- disassembly used for reference:
|
||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
|
||||
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||
0, -40, -- 0x00
|
||||
6, -38,
|
||||
15, -37,
|
||||
22, -32,
|
||||
28, -28,
|
||||
32, -22,
|
||||
37, -14,
|
||||
39, -6,
|
||||
40, 0, -- 0x08
|
||||
38, 7,
|
||||
37, 15,
|
||||
33, 23,
|
||||
27, 29,
|
||||
22, 33,
|
||||
14, 37,
|
||||
6, 39,
|
||||
0, 41, -- 0x10
|
||||
-7, 40,
|
||||
-16, 38,
|
||||
-22, 34,
|
||||
-28, 28,
|
||||
-34, 23,
|
||||
-38, 16,
|
||||
-40, 8,
|
||||
-40, -0, -- 0x18
|
||||
-40, -6,
|
||||
-38, -14,
|
||||
-34, -22,
|
||||
-28, -28,
|
||||
-22, -32,
|
||||
-16, -36,
|
||||
-8, -38,
|
||||
}
|
||||
|
||||
local util = require("util")
|
||||
local R = memory.readbyteunsigned
|
||||
local W = memory.writebyte
|
||||
-- read the byte at `addr` (unsigned) and run it through util.signbyte.
local function S(addr)
    local raw = R(addr)
    return util.signbyte(raw)
end
|
||||
|
||||
-- TODO: reinterface to one "input" array visible to main.lua.
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
local extra_input = {}
|
||||
|
||||
local overlay = false
|
||||
|
||||
local band = bit.band
|
||||
local floor = math.floor
|
||||
|
||||
-- combine the three bytes at 0x7F8..0x7FA as decimal digits
-- (hundreds, tens, ones) into a single number.
local function get_timer()
    local hundreds = R(0x7F8)
    local tens = R(0x7F9)
    local ones = R(0x7FA)
    return hundreds * 100 + tens * 10 + ones
end
|
||||
|
||||
-- combine the five bytes at 0x7DE..0x7E2 as decimal digits,
-- most significant first, into a single number.
local function get_score()
    local score = 0
    for addr = 0x7DE, 0x7E2 do
        score = score * 10 + R(addr)
    end
    return score
end
|
||||
|
||||
-- split `time` into hundreds/tens/ones digits and write them
-- to 0x7F8, 0x7F9 and 0x7FA respectively (the inverse of get_timer).
local function set_timer(time)
    local hundreds = floor(time / 100)
    local tens = floor((time / 10) % 10)
    local ones = floor(time % 10)
    W(0x7F8, hundreds)
    W(0x7F9, tens)
    W(0x7FA, ones)
end
|
||||
|
||||
-- append one (x, y, t) triple to sprite_input.
-- positions outside x in [0, 256) / y in [0, 224] are recorded as (0, 0, 0).
-- when the overlay is enabled and t is non-zero, also draw a box and
-- a label showing t at the given position.
local function mark_sprite(x, y, t)
    local offscreen = x < 0 or x >= 256 or y < 0 or y > 224
    local n = #sprite_input
    if offscreen then
        sprite_input[n + 1] = 0
        sprite_input[n + 2] = 0
        sprite_input[n + 3] = 0
    else
        sprite_input[n + 1] = x
        sprite_input[n + 2] = y
        sprite_input[n + 3] = t
    end

    -- nothing to draw for empty slots or when the overlay is off.
    if t == 0 or not overlay then return end

    gui.box(x-4, y-4, x+4, y+4)
    gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
end
|
||||
|
||||
-- append tile value `t` to tile_input.
-- when the overlay is enabled and t is non-zero, also draw a box
-- around (x, y) with t printed in hex.
local function mark_tile(x, y, t)
    tile_input[#tile_input+1] = t

    -- nothing to draw for empty tiles or when the overlay is off.
    if t == 0 or not overlay then return end

    gui.box(x-8, y-8, x+8, y+8)
    gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
end
|
||||
|
||||
-- compute the on-screen position of object slot `i`.
--   x_addr, y_addr: base addresses of the object's x and y bytes.
--   pageloc_addr:   optional base address of the object's page byte;
--                   when given, the page difference to the screen's
--                   page (0x71A) is applied in steps of 256 pixels.
--   hipos_addr:     optional base address of the object's high y byte;
--                   when given, y is offset by (hipos - 1) * 256.
-- returns sx, sy.
-- NOTE(review): 0x71A/0x71C are presumably the page and x position of the
-- left screen edge -- confirm against the disassembly.
local function getxy(i, x_addr, y_addr, pageloc_addr, hipos_addr)
    local screen_page = R(0x71A)
    local screen_x = R(0x71C)

    local sx = R(x_addr + i) - screen_x
    if pageloc_addr ~= nil then
        local page = R(pageloc_addr + i)
        sx = sx - (screen_page - page) * 256
    end

    local sy = R(y_addr + i)
    if hipos_addr ~= nil then
        local hipos = S(hipos_addr + i)
        sy = sy + (hipos - 1) * 256
    end

    return sx, sy
end
|
||||
|
||||
-- returns bit 0 of the byte at 0x776 (0 or 1).
-- get_state treats a non-zero result as 'paused'.
local function paused()
    local flags = R(0x776)
    return band(flags, 1)
end
|
||||
|
||||
-- classify the current game state as a string by probing several
-- memory locations. the checks are ordered; the first match wins.
-- returns 'unknown' when nothing matches.
local function get_state()
    local mode = R(0xE)
    if mode == 0xFF then return 'power' end
    if R(0x774) > 0 then return 'lagging' end
    if R(0x7A2) > 0 then return 'waiting_demo' end
    if R(0x717) > 0 then return 'playing_demo' end
    -- (disabled) R(0x770) == 0xFF -> 'power'
    if paused() ~= 0 then return 'paused' end
    if mode == 0 then return 'world_screen' end
    -- (disabled) R(0x712) == 1 -> 'deadmusic'
    if R(0x7CA) == 0x94 then return 'dead' end

    if mode == 4 then
        return 'win_flagpole'
    elseif mode == 5 then
        return 'win_walking'
    elseif mode == 6 then
        return 'lose'
    end

    -- (disabled) R(0x770) == 0 -> 'not_playing'
    local oper_mode = R(0x770)
    if oper_mode == 2 then return 'win_castle' end

    local oper_task = R(0x772)
    if oper_task == 2 then return 'no_control' end
    if oper_task == 3 then return 'playing' end

    if oper_mode == 1 then return 'loading' end
    if oper_mode == 3 then return 'lose' end
    return 'unknown'
end
|
||||
|
||||
-- advance the emulator by at least one frame, then keep advancing
-- while the emulator reports lag or while the byte at 0x774 is non-zero.
local function advance()
    -- always step once; repeat for as long as the emulator reports a lag frame.
    repeat
        emu.frameadvance()
    until not emu.lagged()

    -- also lag frames.
    while R(0x774) > 0 do
        emu.frameadvance()
    end
end
|
||||
|
||||
-- record the six enemy slots (enemies, flagpole, platforms, ...) as sprites.
-- each slot's position is nudged depending on its type id so the marker
-- sits on the object; slots considered invisible are recorded as empty.
local function handle_enemies()
    for i = 0, 5 do
        local x, y = getxy(i, 0x87, 0xCF, 0x6E, 0xB6)
        x, y = x + 8, y + 16

        local tid = R(0x16 + i)
        local flags = R(0xF + i)
        --local offscr = R(0x3D8 + i)
        local invisible = tid < 0x10 and flags == 0

        -- the type ids below are mutually exclusive, so at most one applies.
        if tid == 0x30 or tid == 0x31 then -- flagpole flag, castle flag
            y = y - 8
        elseif tid == 0x16 then -- fireworks
            x, y = x - 4, y - 12
        elseif tid >= 0x24 and tid <= 0x29 then -- moving platforms
            x, y = x + 16, y - 12
        elseif tid == 0x15 then -- bowser fire
            y = y - 12
        elseif tid == 0x32 then -- spring
            y = y - 8
        elseif tid == 0x1D or tid == 0x1B then -- rotating fire bars
            x, y = x - 4, y - 12
            -- this is a mess... gotta find out its rotation and then project.
            -- TODO: handle long fire bars too
            local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
            gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
            local pair = rot * 2
            x = x + rotation_offsets[pair + 1]
            y = y + rotation_offsets[pair + 2]
        end
        -- tid == 0x2D -- bowser: no adjustment (TODO: determine head or body)
        -- tid == 0x35 -- toad

        if invisible then
            mark_sprite(0, 0, 0)
        else
            mark_sprite(x, y, tid + 1)
        end
    end
end
|
||||
|
||||
-- record the two fireball slots as sprites with type 257.
-- a slot whose state byte (0x24 + i) is zero is recorded as empty.
local function handle_fireballs()
    for slot = 0, 1 do
        local x, y = getxy(slot, 0x8D, 0xD5, 0x74, 0xBC)
        local active = R(0x24 + slot) ~= 0
        if active then
            mark_sprite(x + 4, y + 4, 257)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
||||
|
||||
-- record the four block slots as sprites with type 258.
-- a slot whose state byte (0x26 + i) is zero is recorded as empty.
local function handle_blocks()
    for slot = 0, 3 do
        local x, y = getxy(slot, 0x8F, 0xD7, 0x76, 0xBE)
        local active = R(0x26 + slot) ~= 0
        if active then
            mark_sprite(x + 8, y + 8, 258)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
||||
|
||||
-- record the nine misc-object slots (hammers, coins, score bonus text...).
-- only states of 0x30 and above are recorded; everything else
-- (including the coin effect states) is recorded as empty.
local function handle_hammers()
    for slot = 0, 8 do
        local x, y = getxy(slot, 0x93, 0xDB, 0x7A, 0xC2)
        local state = R(0x2A + slot)
        -- skip coin effect states. not interactable; we don't care!
        -- (state >= 0x30 already implies state ~= 0.)
        if state >= 0x30 then
            mark_sprite(x + 8, y + 8, state + 1)
        else
            mark_sprite(0, 0, 0)
        end
    end
end
|
||||
|
||||
-- record the single object slot whose state lives at 0x33.
-- a zero state is recorded as empty.
local function handle_misc()
    local slot = 0 -- only one slot here.
    local x, y = getxy(slot, 0x9C, 0xE4, 0x83, 0xCB)
    local state = R(0x33 + slot)
    if state == 0 then
        mark_sprite(0, 0, 0)
    else
        mark_sprite(x + 8, y + 8, state + 1)
    end
end
|
||||
|
||||
-- record the visible tiles: 13 rows by 17 columns, row by row.
-- the sub-tile scroll remainder (0..15) is appended to extra_input first.
-- tile data is read from two 16-column buffers at 0x500 and 0x5D0,
-- selected by which half of the 32-column ring the column falls in.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local scroll = R(0x73F)
    local tile_scroll = floor(scroll / 16) + R(0x71A) * 16
    local remainder = scroll % 16
    extra_input[#extra_input+1] = remainder

    for y = 0, 12 do
        local row = y * 16
        local sy = row + 40
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local base
            if col < 16 then base = 0x500 else base = 0x5D0 end
            local t = R(base + row + (col % 16))
            mark_tile(x * 16 + 8 - remainder, sy, t)
        end
    end
end
|
||||
|
||||
-- module interface.
local exports = {
    -- TODO: don't expose these; provide interfaces for everything needed.
    R=R,
    W=W,
    S=S,

    sprite_input=sprite_input,
    tile_input=tile_input,
    extra_input=extra_input,

    get_timer=get_timer,
    get_score=get_score,
    set_timer=set_timer,
    mark_sprite=mark_sprite,
    mark_tile=mark_tile,
    getxy=getxy,
    paused=paused,
    get_state=get_state,
    advance=advance,
    handle_enemies=handle_enemies,
    handle_fireballs=handle_fireballs,
    handle_blocks=handle_blocks,
    handle_hammers=handle_hammers,
    handle_misc=handle_misc,
    handle_tiles=handle_tiles,
}

-- `overlay` is deliberately NOT a plain field of the table above.
-- mark_sprite and mark_tile read the file-local `overlay` variable, so a
-- plain field would only hold a copy of it: assigning `module.overlay = true`
-- from outside would never reach the drawing code. instead, reads and writes
-- of the `overlay` key are forwarded to the local through a metatable.
return setmetatable(exports, {
    __index = function(_, key)
        if key == 'overlay' then return overlay end
        return nil
    end,
    __newindex = function(tbl, key, value)
        if key == 'overlay' then
            overlay = value
        else
            rawset(tbl, key, value)
        end
    end,
})
|
176
util.lua
Normal file
176
util.lua
Normal file
|
@ -0,0 +1,176 @@
|
|||
-- TODO: reorganize function order.
|
||||
|
||||
local assert = assert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local pairs = pairs
|
||||
local random = math.random
|
||||
local select = select
|
||||
local sqrt= math.sqrt
|
||||
|
||||
-- Reinterprets an unsigned byte (0..255) as a signed two's-complement value.
-- 0..127 are returned unchanged; 128..255 map to -128..-1.
-- (The previous `256 - x` produced a positive magnitude, e.g. 255 -> 1.)
local function signbyte(x)
    if x >= 128 then
        return x - 256
    end
    return x
end
|
||||
|
||||
-- Logical exclusive-or on Lua truthiness: true when exactly one of a and b is
-- truthy (anything other than nil/false). Always returns a real boolean.
local function boolean_xor(a, b)
    -- `not` coerces each operand to a boolean, so comparing the two negations
    -- is true exactly when the operands differ in truthiness
    return (not a) ~= (not b)
end
|
||||
|
||||
-- Precomputed 1 / ln(2), so log2 needs a single logarithm per call.
local _invlog2 = 1 / log(2)

-- Base-2 logarithm of x, computed as ln(x) * (1 / ln(2)).
local function log2(x)
    local natural = log(x)
    return natural * _invlog2
end
|
||||
|
||||
-- Restricts x to the range [l, u].
-- The lower bound is applied first, so if the bounds are crossed (l > u)
-- the result is u.
local function clamp(x, l, u)
    local raised = max(x, l)
    return min(raised, u)
end
|
||||
|
||||
-- Linear interpolation: returns a at t = 0 and b at t = 1.
-- t is clamped to [0, 1] first, so the result never extrapolates past the
-- endpoints.
local function lerp(a, b, t)
    local frac = clamp(t, 0, 1)
    return a + (b - a) * frac
end
|
||||
|
||||
-- Returns the 1-based position of the largest argument.
-- Ties resolve to the earliest position. Returns 0 when called with no
-- arguments, or when no argument is comparable (all NaN).
--
-- The running maximum starts unset instead of at a magic sentinel, so inputs
-- that are all very negative (or -math.huge) still yield a valid position
-- rather than 0.
local function argmax(...)
    local max_i = 0
    local max_v = nil
    for i = 1, select("#", ...) do
        -- parentheses truncate select's results to the i-th argument only
        local v = (select(i, ...))
        -- `v == v` skips NaN, which would otherwise poison every comparison
        if v == v and (max_v == nil or v > max_v) then
            max_i = i
            max_v = v
        end
    end
    return max_i
end
|
||||
|
||||
-- Samples a 1-based index, treating the arguments as probabilities.
-- A uniform draw t in [0, 1) is compared against the running sum of the
-- arguments; the first index whose cumulative sum exceeds t is returned.
--
-- If the draw lands past the total (e.g. the weights sum to 0.9999999 due to
-- floating-point rounding), the last index with a positive weight is returned
-- instead of falling through and yielding nil. Returns nil only when there is
-- no positive weight at all (including zero arguments).
local function softchoice(...)
    local t = random()
    local psum = 0
    local last_positive = nil
    for i = 1, select("#", ...) do
        -- parentheses truncate select's results to the i-th argument only
        local p = (select(i, ...))
        psum = psum + p
        if t < psum then
            return i
        end
        if p > 0 then last_positive = i end
    end
    return last_positive
end
|
||||
|
||||
-- Removes every key from t in place and returns the same table.
-- Assigning nil to existing keys during a pairs traversal is permitted.
local function empty(t)
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
|
||||
|
||||
-- Returns the mean and the population standard deviation of the sequence x.
-- Each term is divided by the element count before being accumulated.
-- An empty sequence yields 0, 0.
local function calc_mean_dev(x)
    local count = #x

    local mean = 0
    for _, value in ipairs(x) do
        mean = mean + value / count
    end

    local variance = 0
    for _, value in ipairs(x) do
        local diff = value - mean
        variance = variance + diff * diff / count
    end

    return mean, sqrt(variance)
end
|
||||
|
||||
-- Standardizes x to zero mean and unit deviation using its own statistics.
-- Writes into `out` when given, otherwise overwrites x in place; returns the
-- table that was written.
local function normalize(x, out)
    local dst = out or x
    local mean, dev = calc_mean_dev(x)
    -- a zero deviation would divide by zero; use a unit scale instead
    if dev <= 0 then
        dev = 1
    end
    for i, value in ipairs(x) do
        dst[i] = (value - mean) / dev
    end
    return dst
end
|
||||
|
||||
-- Standardizes x using the mean and deviation of a separate sample s.
-- Writes into `out` when given, otherwise overwrites x in place; returns the
-- table that was written.
local function normalize_wrt(x, s, out)
    local dst = out or x
    local mean, dev = calc_mean_dev(s)
    -- a zero deviation would divide by zero; use a unit scale instead
    if dev <= 0 then
        dev = 1
    end
    for i, value in ipairs(x) do
        dst[i] = (value - mean) / dev
    end
    return dst
end
|
||||
|
||||
-- Rank-based fitness shaping: replaces raw rewards with utilities that depend
-- only on each reward's rank among all rewards.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
--
-- For a reward at rank r (1 = best) among lamb rewards, the utility is
--   max(0, log2(lamb/2 + 1) - log2(r)) / denom + 1 / lamb
-- where denom is the sum of the max(...) term over all rewards. Equal rewards
-- share the best rank among them.
--
-- rewards: sequence of numbers; it is not modified.
-- Returns a new sequence of shaped values, index-aligned with rewards
-- (empty when rewards is empty).
--
-- This version is self-contained: the earlier body referenced `nn.copy`,
-- `nn.indexof`, `sort` and `insert`, none of which are in scope in this file.
local function fitness_shaping(rewards)
    local lamb = #rewards

    -- sorted copy, best reward first
    local decreasing = {}
    for i = 1, lamb do decreasing[i] = rewards[i] end
    table.sort(decreasing, function(a, b) return a > b end)

    -- walking backwards leaves the first (best) position for duplicated values
    local rank_of = {}
    for pos = lamb, 1, -1 do rank_of[decreasing[pos]] = pos end

    local l = log2(lamb / 2 + 1) -- loop-invariant
    local numers = {}
    local denom = 0
    for i = 1, lamb do
        local numer = max(0, l - log2(rank_of[rewards[i]]))
        numers[i] = numer
        denom = denom + numer
    end

    -- denom is positive whenever lamb >= 1: rank 1 contributes l - 0 > 0
    local shaped_returns = {}
    for i = 1, lamb do
        shaped_returns[i] = numers[i] / denom + 1 / lamb
    end

    return shaped_returns
end
|
||||
|
||||
-- Returns the 1-based rank of unperturbed_reward among rewards:
-- one plus the number of rewards strictly greater than it (ties do not
-- push the rank down).
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local function unperturbed_rank(rewards, unperturbed_reward)
    local better = 0
    for _, reward in ipairs(rewards) do
        if reward > unperturbed_reward then
            better = better + 1
        end
    end
    return better + 1
end
|
||||
|
||||
-- Shallow copy: assigns every key/value pair of t onto `out` (or onto a fresh
-- table when `out` is not given) and returns the destination.
-- Existing keys in `out` that t does not have are left alone.
-- Raises an assertion error when t is not a table.
local function copy(t, out) -- shallow copy
    assert(type(t) == "table")
    local dst = out
    if not dst then
        dst = {}
    end
    for key, value in pairs(t) do
        dst[key] = value
    end
    return dst
end
|
||||
|
||||
-- Returns a key of t whose value equals a (by ==), or nil when there is none.
-- Traversal uses pairs, so with duplicated values which key comes back is
-- unspecified. Raises an assertion error when t is not a table.
local function indexof(t, a)
    assert(type(t) == "table")
    local found = nil
    for key, value in pairs(t) do
        if value == a then
            found = key
            break
        end
    end
    return found
end
|
||||
|
||||
-- True when some value in t equals a; delegates the search to indexof.
local function contains(t, a)
    local key = indexof(t, a)
    -- compare against nil explicitly: `false` is a legal table key
    return key ~= nil
end
|
||||
|
||||
-- Two-way argmax as a boolean: true when t[1] is strictly greater than t[2],
-- false otherwise (including ties).
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
|
||||
|
||||
-- Random two-way choice: draws a uniform number in [0, 1) and returns true
-- when t[1] is strictly greater than it, i.e. true with probability t[1]
-- for t[1] within [0, 1].
local function rchoice2(t)
    local p = t[1]
    local roll = random()
    return p > roll
end
|
||||
|
||||
-- Returns a random boolean: true when a uniform draw in [0, 1) is at most 0.5.
local function rbool()
    local roll = random()
    return roll <= 0.5
end
|
||||
|
||||
-- Module interface: every helper is exported under its own name.
local exports = {}

exports.signbyte = signbyte
exports.boolean_xor = boolean_xor
exports.log2 = log2
exports.clamp = clamp
exports.lerp = lerp
exports.argmax = argmax
exports.softchoice = softchoice
exports.empty = empty
exports.calc_mean_dev = calc_mean_dev
exports.normalize = normalize
exports.normalize_wrt = normalize_wrt
exports.fitness_shaping = fitness_shaping
exports.unperturbed_rank = unperturbed_rank
exports.copy = copy
exports.indexof = indexof
exports.contains = contains
exports.argmax2 = argmax2
exports.rchoice2 = rchoice2
exports.rbool = rbool

return exports
|
Loading…
Add table
Reference in a new issue