refactor config vars to their own files

This commit is contained in:
Connor Olding 2018-04-02 15:21:55 +02:00
parent 66bf689e04
commit 545618c70b
3 changed files with 214 additions and 190 deletions

48
config.lua Normal file
View file

@ -0,0 +1,48 @@
local function approx_cossim(dim)
return math.pow(1.521 * dim - 0.521, -0.5026)
end
local cfg = {
defer_prints = true,
playable_mode = false,
start_big = false, --true
starting_lives = 0, --1
--
init_zeros = true, -- instead of he_normal noise or whatever.
frameskip = 4,
-- true greedy epsilon has both deterministic and det_epsilon set.
deterministic = true, -- use argmax on outputs instead of random sampling.
det_epsilon = false, -- take random actions with probability eps.
--
epoch_trials = 50,
epoch_top_trials = 25, -- new with ARS.
unperturbed_trial = true, -- do a trial without any noise.
negate_trials = true, -- try pairs of normal and negated noise directions.
time_inputs = true, -- binary inputs of global frame count
-- ^ note that this now doubles the effective trials.
deviation = 0.05, --0.075 --0.1
--learning_rate = 0.01 / approx_cossim(7051)
learning_rate = 1.0,
--learning_rate = 0.0032 / approx_cossim(66573)
--learning_rate = 0.0056 / approx_cossim(66573)
weight_decay = 0.00032, --0.001 --0.0023
--
cap_time = 200, --400
timer_loser = 1/2,
decrement_reward = false, -- bad idea, encourages mario to kill himself
}
cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
cfg.eps_start = 1.0 * cfg.frameskip / 64
cfg.eps_stop = 0.1 * cfg.eps_start
cfg.eps_frames = 1000000
cfg.enable_overlay = cfg.playable_mode
cfg.enable_network = not cfg.playable_mode
return setmetatable(cfg, {
__index = function(t, n)
error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
end
})

107
gameconfig.lua Normal file
View file

@ -0,0 +1,107 @@
local gcfg = {
input_size = 60 + 4, -- TODO: let the script figure this out for us.
tile_count = 17 * 13,
ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
},
bad_states = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
},
jp_lut = {
{ -- none
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- A
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- L
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = false,
}, { -- R
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = false,
}, { -- L + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = false,
}, { -- R + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = false,
}, { -- L + A
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = true,
}, { -- R + A
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = true,
}, { -- L + A + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = true,
}, { -- R + A + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = true,
}, { -- D
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- D + A
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- U
up = true, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
},
},
rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
0, -40, -- 0x00
6, -38,
15, -37,
22, -32,
28, -28,
32, -22,
37, -14,
39, -6,
40, 0, -- 0x08
38, 7,
37, 15,
33, 23,
27, 29,
22, 33,
14, 37,
6, 39,
0, 41, -- 0x10
-7, 40,
-16, 38,
-22, 34,
-28, 28,
-34, 23,
-38, 16,
-40, 8,
-40, -0, -- 0x18
-40, -6,
-38, -14,
-34, -22,
-28, -28,
-22, -32,
-16, -36,
-8, -38,
},
}
return setmetatable(gcfg, {
__index = function(t, n)
error("cannot use undeclared gameconfig '" .. tostring(n) .. "'", 2)
end
})

249
main.lua
View file

@ -26,108 +26,8 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
--randomseed(11)
local defer_prints = true
local playable_mode = false
local start_big = false --true
local starting_lives = 0 --1
--
local init_zeros = true -- instead of he_normal noise or whatever.
local frameskip = 4
-- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = false -- take random actions with probability eps.
local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * eps_start
local eps_frames = 1000000
--
local epoch_trials = 50
local epoch_top_trials = 25 -- new with ARS.
local unperturbed_trial = true -- do a trial without any noise.
local negate_trials = true -- try pairs of normal and negated noise directions.
local time_inputs = true -- binary inputs of global frame count
-- ^ note that this now doubles the effective trials.
local deviation = 0.05 --0.075 --0.1
local function approx_cossim(dim)
return math.pow(1.521 * dim - 0.521, -0.5026)
end
--local learning_rate = 0.01 / approx_cossim(7051)
local learning_rate = 1.0
--local learning_rate = 0.0032 / approx_cossim(66573)
--local learning_rate = 0.0056 / approx_cossim(66573)
local weight_decay = 0.00032 --0.001 --0.0023
--
local cap_time = 200 --400
local timer_loser = 1/2
local decrement_reward = false -- bad idea, encourages mario to kill himself
--
local enable_overlay = playable_mode
local enable_network = not playable_mode
local input_size = 60 + 4 -- TODO: let the script figure this out for us.
local tile_count = 17 * 13
local ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
}
local bad_states = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
}
local jp_lut = {
{ -- none
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- A
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- L
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = false,
}, { -- R
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = false,
}, { -- L + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = false,
}, { -- R + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = false,
}, { -- L + A
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = true,
}, { -- R + A
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = true,
}, { -- L + A + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = true,
}, { -- R + A + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = true,
}, { -- D
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- D + A
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- U
up = true, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
},
}
local cfg = require("config")
local gcfg = require("gameconfig")
-- state.
@ -302,17 +202,21 @@ local network
local nn_x, nn_tx, nn_ty, nn_y, nn_z
local function make_network(input_size)
nn_x = nn.Input({input_size})
nn_tx = nn.Input({tile_count})
nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(256, 2))
nn_y = nn.Merge()
nn_x:feed(nn_y)
nn_ty:feed(nn_y)
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
if cfg.deterministic then
nn_y = nn_y:feed(nn.Relu())
else
nn_y = nn_y:feed(nn.Gelu())
end
nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#jp_lut))
nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
nn_z = nn_z:feed(nn.Softmax())
return nn.Model({nn_x, nn_tx}, {nn_z})
end
@ -322,41 +226,6 @@ end
-- disassembly used for reference:
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
0, -40, -- 0x00
6, -38,
15, -37,
22, -32,
28, -28,
32, -22,
37, -14,
39, -6,
40, 0, -- 0x08
38, 7,
37, 15,
33, 23,
27, 29,
22, 33,
14, 37,
6, 39,
0, 41, -- 0x10
-7, 40,
-16, 38,
-22, 34,
-28, 28,
-34, 23,
-38, 16,
-40, 8,
-40, -0, -- 0x18
-40, -6,
-38, -14,
-34, -22,
-28, -28,
-22, -32,
-16, -36,
-8, -38,
}
local function get_timer()
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
end
@ -386,7 +255,7 @@ local function mark_sprite(x, y, t)
sprite_input[#sprite_input+1] = t
end
if t == 0 then return end
if enable_overlay then
if cfg.enable_overlay then
gui.box(x-4, y-4, x+4, y+4)
--gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
@ -397,7 +266,7 @@ end
local function mark_tile(x, y, t)
tile_input[#tile_input+1] = t
if t == 0 then return end
if enable_overlay then
if cfg.enable_overlay then
gui.box(x-8, y-8, x+8, y+8)
gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
end
@ -479,7 +348,7 @@ local function handle_enemies()
-- TODO: handle long fire bars too
local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
local x_off, y_off = rotation_offsets[rot*2+1], rotation_offsets[rot*2+2]
local x_off, y_off = gcfg.rotation_offsets[rot*2+1], gcfg.rotation_offsets[rot*2+2]
x, y = x + x_off, y + y_off
end
if invisible then
@ -578,11 +447,11 @@ local function prepare_epoch()
empty(trial_rewards)
-- TODO: save memory. generate noise as needed by saving the seed
-- (the os.time() as of here) and calling nn.normal() each trial.
for i = 1, epoch_trials do
for i = 1, cfg.epoch_trials do
local noise = nn.zeros(#base_params)
-- NOTE: change in implementation: deviation is multiplied here
-- and ONLY here now.
for j = 1, #base_params do noise[j] = deviation * nn.normal() end
for j = 1, #base_params do noise[j] = cfg.deviation * nn.normal() end
trial_noise[i] = noise
end
trial_i = -1
@ -590,7 +459,7 @@ end
local function load_next_pair()
trial_i = trial_i + 1
if trial_i == 0 and not unperturbed_trial then
if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1
trial_neg = true
end
@ -599,7 +468,7 @@ local function load_next_pair()
if trial_i > 0 then
if trial_neg then
if not defer_prints then print('trial', trial_i, 'positive') end
if not cfg.defer_prints then print('trial', trial_i, 'positive') end
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v + noise[i]
@ -607,7 +476,7 @@ local function load_next_pair()
else
trial_i = trial_i - 1
if not defer_prints then print('trial', trial_i, 'positive') end
if not cfg.defer_prints then print('trial', trial_i, 'positive') end
local noise = trial_noise[trial_i]
for i, v in ipairs(base_params) do
W[i] = v - noise[i]
@ -616,17 +485,17 @@ local function load_next_pair()
trial_neg = not trial_neg
else
if not defer_prints then print("test trial") end
if not cfg.defer_prints then print("test trial") end
end
network:distribute(W)
end
local function load_next_trial()
if negate_trials then return load_next_pair() end
if cfg.negate_trials then return load_next_pair() end
trial_i = trial_i + 1
local W = nn.copy(base_params)
if trial_i == 0 and not unperturbed_trial then
if trial_i == 0 and not cfg.unperturbed_trial then
trial_i = 1
end
if trial_i > 0 then
@ -682,7 +551,7 @@ local function learn_from_epoch()
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
if unperturbed_trial then
if cfg.unperturbed_trial then
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
-- a rank of 1 means our gradient is uninformative.
@ -694,10 +563,10 @@ local function learn_from_epoch()
-- new stuff
local best_rewards
if negate_trials then
if cfg.negate_trials then
-- select one (the best) reward of each pos/neg pair.
best_rewards = {}
for i = 1, epoch_trials do
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = trial_rewards[ind + 0]
local neg = trial_rewards[ind + 1]
@ -712,7 +581,7 @@ local function learn_from_epoch()
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
--print("indices:", indices)
for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
print("best trials:", indices)
local top_rewards = {}
@ -731,27 +600,27 @@ local function learn_from_epoch()
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
-- NOTE: step no longer directly incorporates learning_rate.
for i = 1, epoch_trials do
for i = 1, cfg.epoch_trials do
local ind = (i - 1) * 2 + 1
local pos = top_rewards[ind + 0]
local neg = top_rewards[ind + 1]
local reward = pos - neg
local noise = trial_noise[i]
for j, v in ipairs(noise) do
step[j] = step[j] + reward * v / epoch_top_trials
step[j] = step[j] + reward * v / cfg.epoch_top_trials
end
end
local step_mean, step_dev = calc_mean_dev(step)
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
--print("step stddev:", step_dev)
print("full step stddev:", learning_rate * step_dev)
print("full step stddev:", cfg.learning_rate * step_dev)
for i, v in ipairs(base_params) do
base_params[i] = v + learning_rate * step[i] - weight_decay * v
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
end
if enable_network then
if cfg.enable_network then
network:distribute(base_params)
network:save()
else
@ -782,10 +651,10 @@ local function do_reset()
-- be a little more descriptive.
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
if trial_i >= 0 and defer_prints then
if trial_i >= 0 and cfg.defer_prints then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif negate_trials then
elseif cfg.negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
@ -804,14 +673,14 @@ local function do_reset()
end
if trial_i >= 0 then
if trial_i == 0 or not negate_trials then
if trial_i == 0 or not cfg.negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
end
if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
epoch_i = epoch_i + 1
prepare_epoch()
@ -825,19 +694,19 @@ local function do_reset()
score_old = get_score()
-- set number of lives. (mario gets n+1 chances)
W(0x75A, starting_lives)
W(0x75A, cfg.starting_lives)
if start_big then
if cfg.start_big then
-- make mario "super".
W(0x754, 0)
W(0x756, 1)
end
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
--max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
--max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
max_time = ceil(max_time)
--max_time = cap_time
--max_time = cfg.cap_time
if once then
savestate.load(startsave)
@ -856,12 +725,12 @@ local function do_reset()
end
local function init()
network = make_network(input_size)
network = make_network(gcfg.input_size)
network:reset()
network:print()
print("parameters:", network.n_param)
if init_zeros then
if cfg.init_zeros then
local W = network:collect()
for i, w in ipairs(W) do W[i] = 0 end
network:distribute(W)
@ -889,11 +758,11 @@ local function doit(dummy)
-- TODO: more robust. doesn't detect moonwalking against a wall.
-- well, that shouldn't happen anymore now that i've disabled left+right.
local timer = get_timer()
if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
timer = timer - 1
end
timer = clamp(timer, 0, max_time)
if enable_network then
if cfg.enable_network then
set_timer(timer)
end
@ -924,7 +793,7 @@ local function doit(dummy)
insert(extra_input, vx)
insert(extra_input, vy)
if time_inputs then
if cfg.time_inputs then
for i=2,5 do
insert(extra_input, band(total_frames, lshift(1, i)))
end
@ -946,14 +815,14 @@ local function doit(dummy)
local powerup_delta = powerup_old - powerup
-- 2 is fire mario.
local status_delta = clamp(status - status_old, -1, 1)
local flagpole_bonus = R(0xE) == 4 and frameskip or 0
local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
local score_delta = get_score() - score_old
if score_delta < 0 then score_delta = 0 end
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
screen_scroll_delta = 0
if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
if not ingame_paused then reward = reward + reward_delta end
@ -981,22 +850,22 @@ local function doit(dummy)
local X = {}
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
for i, v in ipairs(extra_input) do insert(X, v / 256) end
nn.reshape(X, 1, input_size)
nn.reshape(tile_input, 1, tile_count)
nn.reshape(X, 1, gcfg.input_size)
nn.reshape(tile_input, 1, gcfg.tile_count)
if enable_network and get_state() == 'playing' or ingame_paused then
total_frames = total_frames + frameskip
if cfg.enable_network and get_state() == 'playing' or ingame_paused then
total_frames = total_frames + cfg.frameskip
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
if det_epsilon and random() < eps then
local i = floor(random() * #jp_lut) + 1
jp = nn.copy(jp_lut[i], jp)
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
if cfg.det_epsilon and random() < eps then
local i = floor(random() * #gcfg.jp_lut) + 1
jp = nn.copy(gcfg.jp_lut[i], jp)
else
local choose = deterministic and argmax or softchoice
local choose = cfg.deterministic and argmax or softchoice
local ind = choose(unpack(outputs[nn_z]))
jp = nn.copy(jp_lut[ind], jp)
jp = nn.copy(gcfg.jp_lut[ind], jp)
end
if force_start then
@ -1026,7 +895,7 @@ init()
while true do
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
while bad_states[get_state()] do
while gcfg.bad_states[get_state()] do
-- mash the start button until we have control.
joypad_mash('start')
reset = true
@ -1040,7 +909,7 @@ while true do
if reset then do_reset() end
if not enable_network then
if not cfg.enable_network then
-- infinite time cheat. super handy for testing.
if R(0xE) == 8 then
set_timer(667)
@ -1056,7 +925,7 @@ while true do
-- FIXME: if the game lags then we might miss our frame to change inputs!
-- don't rely on emu.framecount.
local doot = jp == nil or emu.framecount() % frameskip == 0
local doot = jp == nil or emu.framecount() % cfg.frameskip == 0
doit(not doot)
-- jp might still be nil if we're not ingame or we're not playing.