refactor config vars to their own files
This commit is contained in:
parent
66bf689e04
commit
545618c70b
3 changed files with 214 additions and 190 deletions
48
config.lua
Normal file
48
config.lua
Normal file
|
@ -0,0 +1,48 @@
|
|||
--- Evaluates the power-law fit (1.521 * dim - 0.521) ^ -0.5026.
-- Presumably approximates a typical cosine similarity for vectors with
-- `dim` components (inferred from the name only -- TODO confirm).
-- Only the commented-out learning_rate alternatives reference it.
-- Uses the `^` operator rather than math.pow, which is deprecated in
-- Lua 5.3 and missing from default Lua 5.4 builds.
-- @param dim number of dimensions (number)
-- @return the fitted value (number)
local function approx_cossim(dim)
    return (1.521 * dim - 0.521) ^ -0.5026
end
|
||||
|
||||
-- Tunable settings for the training script; main.lua reads them as cfg.<name>.
-- Values noted as "alternative" are the ones left in the original inline notes.
local cfg = {
    -- Logging: when true, per-trial messages are printed when a trial is
    -- reset (with its reward) rather than when the trial is loaded.
    defer_prints = true,

    -- Game setup.
    playable_mode = false,  -- drives enable_overlay / enable_network below.
    start_big = false,      -- alternative: true. makes mario "super" on reset.
    starting_lives = 0,     -- alternative: 1. mario gets n+1 chances.

    -- Network behaviour.
    init_zeros = true,      -- instead of he_normal noise or whatever.
    frameskip = 4,          -- inputs are re-evaluated every this-many frames.
    -- true greedy epsilon has both deterministic and det_epsilon set.
    deterministic = true,   -- use argmax on outputs instead of random sampling.
    det_epsilon = false,    -- take random actions with probability eps.

    -- Epoch / search settings.
    epoch_trials = 50,         -- noise directions generated per epoch.
    epoch_top_trials = 25,     -- new with ARS. only the best trials are kept.
    unperturbed_trial = true,  -- do a trial without any noise.
    -- try pairs of normal and negated noise directions.
    -- note that this doubles the effective trials.
    negate_trials = true,
    time_inputs = true,        -- binary inputs of global frame count.
    deviation = 0.05,          -- alternatives: 0.075, 0.1. scales the noise.
    -- alternatives for learning_rate:
    --   0.01 / approx_cossim(7051)
    --   0.0032 / approx_cossim(66573)
    --   0.0056 / approx_cossim(66573)
    learning_rate = 1.0,
    weight_decay = 0.00032,    -- alternatives: 0.001, 0.0023

    -- Timer / reward shaping.
    cap_time = 200,            -- alternative: 400. upper bound for max_time.
    timer_loser = 0.5,         -- threshold main.lua uses when ticking the timer down.
    decrement_reward = false,  -- bad idea, encourages mario to kill himself
}

-- The number of top trials can never exceed the number of trials.
if cfg.epoch_top_trials > cfg.epoch_trials then
    cfg.epoch_top_trials = cfg.epoch_trials
end

-- Epsilon schedule: main.lua interpolates from eps_start to eps_stop
-- using total_frames / eps_frames.
cfg.eps_start = cfg.frameskip / 64
cfg.eps_stop = 0.1 * cfg.eps_start
cfg.eps_frames = 1000000

-- Derived switches: playable mode shows the overlay and disables the network.
cfg.enable_overlay = cfg.playable_mode
cfg.enable_network = not cfg.playable_mode
|
||||
|
||||
-- Export cfg with a strict metatable: reading a key that was never declared
-- raises an error instead of silently yielding nil. Level 2 makes the error
-- point at the code that did the lookup.
local strict = {}

function strict.__index(_, key)
    error("cannot use undeclared config '" .. tostring(key) .. "'", 2)
end

return setmetatable(cfg, strict)
|
107
gameconfig.lua
Normal file
107
gameconfig.lua
Normal file
|
@ -0,0 +1,107 @@
|
|||
-- Builds a set-like table: every listed name maps to true.
local function set_of(names)
    local set = {}
    for _, name in ipairs(names) do set[name] = true end
    return set
end

-- Builds one joypad state: all eight buttons released, except the ones
-- named in `pressed`, which are held.
local function pad(pressed)
    local state = {
        up = false, down = false, left = false, right = false,
        select = false, start = false, B = false, A = false,
    }
    for _, button in ipairs(pressed) do state[button] = true end
    return state
end

-- Game-specific constants; main.lua reads them as gcfg.<name>.
local gcfg = {
    input_size = 60 + 4, -- TODO: let the script figure this out for us.
    tile_count = 17 * 13,

    -- Routine ids, keyed by number, with the original descriptions.
    ok_routines = {
        [0x4] = true, -- sliding down flagpole
        [0x5] = true, -- end of level auto-walk
        [0x7] = true, -- start of level auto-walk
        [0x8] = true, -- normal (in control)
        [0x9] = true, -- acquiring mushroom
        [0xA] = true, -- losing big mario
        [0xB] = true, -- uhh
        [0xC] = true, -- acquiring fireflower
    },

    -- States in which main.lua mashes start until control is regained.
    bad_states = set_of {
        "power",
        "waiting_demo",
        "playing_demo",
        "unknown",
        "lose",
    },

    -- Joypad states the network can choose between, one per output.
    jp_lut = {
        pad {},                     -- none
        pad { "A" },                -- A
        pad { "left" },             -- L
        pad { "right" },            -- R
        pad { "left", "B" },        -- L + B
        pad { "right", "B" },       -- R + B
        pad { "left", "A" },        -- L + A
        pad { "right", "A" },       -- R + A
        pad { "left", "A", "B" },   -- L + A + B
        pad { "right", "A", "B" },  -- R + A + B
        pad { "down" },             -- D
        pad { "down", "A" },        -- D + A
        pad { "up" },               -- U
    },

    -- Flat list of x, y offset pairs; main.lua reads the pair for rotation
    -- value `rot` at indices rot*2+1 and rot*2+2.
    rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
        -- 0x00
          0, -40,     6, -38,    15, -37,    22, -32,
         28, -28,    32, -22,    37, -14,    39,  -6,
        -- 0x08
         40,   0,    38,   7,    37,  15,    33,  23,
         27,  29,    22,  33,    14,  37,     6,  39,
        -- 0x10
          0,  41,    -7,  40,   -16,  38,   -22,  34,
        -28,  28,   -34,  23,   -38,  16,   -40,   8,
        -- 0x18
        -40,  -0,   -40,  -6,   -38, -14,   -34, -22,
        -28, -28,   -22, -32,   -16, -36,    -8, -38,
    },
}
|
||||
|
||||
-- Export gcfg with a strict metatable: reading a key that was never declared
-- raises an error instead of silently yielding nil. Level 2 makes the error
-- point at the code that did the lookup.
local strict = {}

function strict.__index(_, key)
    error("cannot use undeclared gameconfig '" .. tostring(key) .. "'", 2)
end

return setmetatable(gcfg, strict)
|
249
main.lua
249
main.lua
|
@ -26,108 +26,8 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
|||
|
||||
--randomseed(11)
|
||||
|
||||
local defer_prints = true
|
||||
|
||||
local playable_mode = false
|
||||
local start_big = false --true
|
||||
local starting_lives = 0 --1
|
||||
--
|
||||
local init_zeros = true -- instead of he_normal noise or whatever.
|
||||
local frameskip = 4
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = false -- take random actions with probability eps.
|
||||
local eps_start = 1.0 * frameskip / 64
|
||||
local eps_stop = 0.1 * eps_start
|
||||
local eps_frames = 1000000
|
||||
--
|
||||
local epoch_trials = 50
|
||||
local epoch_top_trials = 25 -- new with ARS.
|
||||
local unperturbed_trial = true -- do a trial without any noise.
|
||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||
local time_inputs = true -- binary inputs of global frame count
|
||||
-- ^ note that this now doubles the effective trials.
|
||||
local deviation = 0.05 --0.075 --0.1
|
||||
local function approx_cossim(dim)
|
||||
return math.pow(1.521 * dim - 0.521, -0.5026)
|
||||
end
|
||||
--local learning_rate = 0.01 / approx_cossim(7051)
|
||||
local learning_rate = 1.0
|
||||
--local learning_rate = 0.0032 / approx_cossim(66573)
|
||||
--local learning_rate = 0.0056 / approx_cossim(66573)
|
||||
local weight_decay = 0.00032 --0.001 --0.0023
|
||||
--
|
||||
local cap_time = 200 --400
|
||||
local timer_loser = 1/2
|
||||
local decrement_reward = false -- bad idea, encourages mario to kill himself
|
||||
--
|
||||
local enable_overlay = playable_mode
|
||||
local enable_network = not playable_mode
|
||||
|
||||
local input_size = 60 + 4 -- TODO: let the script figure this out for us.
|
||||
local tile_count = 17 * 13
|
||||
|
||||
local ok_routines = {
|
||||
[0x4] = true, -- sliding down flagpole
|
||||
[0x5] = true, -- end of level auto-walk
|
||||
[0x7] = true, -- start of level auto-walk
|
||||
[0x8] = true, -- normal (in control)
|
||||
[0x9] = true, -- acquiring mushroom
|
||||
[0xA] = true, -- losing big mario
|
||||
[0xB] = true, -- uhh
|
||||
[0xC] = true, -- acquiring fireflower
|
||||
}
|
||||
|
||||
local bad_states = {
|
||||
power = true,
|
||||
waiting_demo = true,
|
||||
playing_demo = true,
|
||||
unknown = true,
|
||||
lose = true,
|
||||
}
|
||||
|
||||
local jp_lut = {
|
||||
{ -- none
|
||||
up = false, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- A
|
||||
up = false, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- L
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- R
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- L + B
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = true, A = false,
|
||||
}, { -- R + B
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = true, A = false,
|
||||
}, { -- L + A
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- R + A
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- L + A + B
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = true, A = true,
|
||||
}, { -- R + A + B
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = true, A = true,
|
||||
}, { -- D
|
||||
up = false, down = true, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- D + A
|
||||
up = false, down = true, left = false, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- U
|
||||
up = true, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
},
|
||||
}
|
||||
local cfg = require("config")
|
||||
local gcfg = require("gameconfig")
|
||||
|
||||
-- state.
|
||||
|
||||
|
@ -302,17 +202,21 @@ local network
|
|||
local nn_x, nn_tx, nn_ty, nn_y, nn_z
|
||||
--- Builds the two-input model and stores its nodes in the file-level
-- locals nn_x, nn_tx, nn_ty, nn_y and nn_z.
-- The stale pre-refactor statements (bare `tile_count`, `#jp_lut`, and the
-- unconditional Gelu feed) are dropped: those locals now live in gcfg, and
-- the extra feed would stack a second activation after the Dense layer.
-- @param input_size width of the first input (nn_x)
-- @return an nn.Model taking {nn_x, nn_tx} and producing {nn_z}
local function make_network(input_size)
    -- Two inputs: nn_x, and nn_tx with gcfg.tile_count entries.
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})
    -- NOTE(review): presumably 256 possible tile ids embedded into 2 values
    -- each -- confirm against nn.Embed's signature.
    nn_ty = nn_tx:feed(nn.Embed(256, 2))

    -- Both branches feed into a single merge node.
    nn_y = nn.Merge()
    nn_x:feed(nn_y)
    nn_ty:feed(nn_y)

    -- Hidden layer; the activation depends on cfg.deterministic.
    nn_y = nn_y:feed(nn.Dense(128))
    if cfg.deterministic then
        nn_y = nn_y:feed(nn.Relu())
    else
        nn_y = nn_y:feed(nn.Gelu())
    end

    -- One softmax output per entry of gcfg.jp_lut.
    nn_z = nn_y
    nn_z = nn_z:feed(nn.Dense(#gcfg.jp_lut))
    nn_z = nn_z:feed(nn.Softmax())
    return nn.Model({nn_x, nn_tx}, {nn_z})
end
|
||||
|
@ -322,41 +226,6 @@ end
|
|||
-- disassembly used for reference:
|
||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
|
||||
local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||
0, -40, -- 0x00
|
||||
6, -38,
|
||||
15, -37,
|
||||
22, -32,
|
||||
28, -28,
|
||||
32, -22,
|
||||
37, -14,
|
||||
39, -6,
|
||||
40, 0, -- 0x08
|
||||
38, 7,
|
||||
37, 15,
|
||||
33, 23,
|
||||
27, 29,
|
||||
22, 33,
|
||||
14, 37,
|
||||
6, 39,
|
||||
0, 41, -- 0x10
|
||||
-7, 40,
|
||||
-16, 38,
|
||||
-22, 34,
|
||||
-28, 28,
|
||||
-34, 23,
|
||||
-38, 16,
|
||||
-40, 8,
|
||||
-40, -0, -- 0x18
|
||||
-40, -6,
|
||||
-38, -14,
|
||||
-34, -22,
|
||||
-28, -28,
|
||||
-22, -32,
|
||||
-16, -36,
|
||||
-8, -38,
|
||||
}
|
||||
|
||||
local function get_timer()
|
||||
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||
end
|
||||
|
@ -386,7 +255,7 @@ local function mark_sprite(x, y, t)
|
|||
sprite_input[#sprite_input+1] = t
|
||||
end
|
||||
if t == 0 then return end
|
||||
if enable_overlay then
|
||||
if cfg.enable_overlay then
|
||||
gui.box(x-4, y-4, x+4, y+4)
|
||||
--gui.text(x-2, y-3, tostring(i), '#FFFFFF', '#00000000')
|
||||
gui.text(x-13, y-3-9, ("%+04i"):format(t), '#FFFFFF', '#0000003F')
|
||||
|
@ -397,7 +266,7 @@ end
|
|||
local function mark_tile(x, y, t)
|
||||
tile_input[#tile_input+1] = t
|
||||
if t == 0 then return end
|
||||
if enable_overlay then
|
||||
if cfg.enable_overlay then
|
||||
gui.box(x-8, y-8, x+8, y+8)
|
||||
gui.text(x-5, y-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
|
||||
end
|
||||
|
@ -479,7 +348,7 @@ local function handle_enemies()
|
|||
-- TODO: handle long fire bars too
|
||||
local rot = R(0xA0 + i) --* 0x100 + R(0x58 + i)
|
||||
gui.text(x-13, y-3+9, ("%04X"):format(rot), '#FFFFFF', '#0000003F')
|
||||
local x_off, y_off = rotation_offsets[rot*2+1], rotation_offsets[rot*2+2]
|
||||
local x_off, y_off = gcfg.rotation_offsets[rot*2+1], gcfg.rotation_offsets[rot*2+2]
|
||||
x, y = x + x_off, y + y_off
|
||||
end
|
||||
if invisible then
|
||||
|
@ -578,11 +447,11 @@ local function prepare_epoch()
|
|||
empty(trial_rewards)
|
||||
-- TODO: save memory. generate noise as needed by saving the seed
|
||||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||
for i = 1, epoch_trials do
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
-- NOTE: change in implementation: deviation is multiplied here
|
||||
-- and ONLY here now.
|
||||
for j = 1, #base_params do noise[j] = deviation * nn.normal() end
|
||||
for j = 1, #base_params do noise[j] = cfg.deviation * nn.normal() end
|
||||
trial_noise[i] = noise
|
||||
end
|
||||
trial_i = -1
|
||||
|
@ -590,7 +459,7 @@ end
|
|||
|
||||
local function load_next_pair()
|
||||
trial_i = trial_i + 1
|
||||
if trial_i == 0 and not unperturbed_trial then
|
||||
if trial_i == 0 and not cfg.unperturbed_trial then
|
||||
trial_i = 1
|
||||
trial_neg = true
|
||||
end
|
||||
|
@ -599,7 +468,7 @@ local function load_next_pair()
|
|||
|
||||
if trial_i > 0 then
|
||||
if trial_neg then
|
||||
if not defer_prints then print('trial', trial_i, 'positive') end
|
||||
if not cfg.defer_prints then print('trial', trial_i, 'positive') end
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + noise[i]
|
||||
|
@ -607,7 +476,7 @@ local function load_next_pair()
|
|||
|
||||
else
|
||||
trial_i = trial_i - 1
|
||||
-- This branch applies the negated noise (W[i] = v - noise[i]), so the label
-- must read 'negative'; it was copy-pasted as 'positive'.
if not cfg.defer_prints then print('trial', trial_i, 'negative') end
|
||||
local noise = trial_noise[trial_i]
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v - noise[i]
|
||||
|
@ -616,17 +485,17 @@ local function load_next_pair()
|
|||
|
||||
trial_neg = not trial_neg
|
||||
else
|
||||
if not defer_prints then print("test trial") end
|
||||
if not cfg.defer_prints then print("test trial") end
|
||||
end
|
||||
|
||||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function load_next_trial()
|
||||
if negate_trials then return load_next_pair() end
|
||||
if cfg.negate_trials then return load_next_pair() end
|
||||
trial_i = trial_i + 1
|
||||
local W = nn.copy(base_params)
|
||||
if trial_i == 0 and not unperturbed_trial then
|
||||
if trial_i == 0 and not cfg.unperturbed_trial then
|
||||
trial_i = 1
|
||||
end
|
||||
if trial_i > 0 then
|
||||
|
@ -682,7 +551,7 @@ local function learn_from_epoch()
|
|||
|
||||
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
|
||||
|
||||
if unperturbed_trial then
|
||||
if cfg.unperturbed_trial then
|
||||
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
||||
|
||||
-- a rank of 1 means our gradient is uninformative.
|
||||
|
@ -694,10 +563,10 @@ local function learn_from_epoch()
|
|||
-- new stuff
|
||||
|
||||
local best_rewards
|
||||
if negate_trials then
|
||||
if cfg.negate_trials then
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
best_rewards = {}
|
||||
for i = 1, epoch_trials do
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = trial_rewards[ind + 0]
|
||||
local neg = trial_rewards[ind + 1]
|
||||
|
@ -712,7 +581,7 @@ local function learn_from_epoch()
|
|||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||
|
||||
--print("indices:", indices)
|
||||
for i = epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
print("best trials:", indices)
|
||||
|
||||
local top_rewards = {}
|
||||
|
@ -731,27 +600,27 @@ local function learn_from_epoch()
|
|||
for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end
|
||||
|
||||
-- NOTE: step no longer directly incorporates learning_rate.
|
||||
for i = 1, epoch_trials do
|
||||
for i = 1, cfg.epoch_trials do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = top_rewards[ind + 0]
|
||||
local neg = top_rewards[ind + 1]
|
||||
local reward = pos - neg
|
||||
local noise = trial_noise[i]
|
||||
for j, v in ipairs(noise) do
|
||||
step[j] = step[j] + reward * v / epoch_top_trials
|
||||
step[j] = step[j] + reward * v / cfg.epoch_top_trials
|
||||
end
|
||||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||
--print("step stddev:", step_dev)
|
||||
print("full step stddev:", learning_rate * step_dev)
|
||||
print("full step stddev:", cfg.learning_rate * step_dev)
|
||||
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + learning_rate * step[i] - weight_decay * v
|
||||
base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
|
||||
end
|
||||
|
||||
if enable_network then
|
||||
if cfg.enable_network then
|
||||
network:distribute(base_params)
|
||||
network:save()
|
||||
else
|
||||
|
@ -782,10 +651,10 @@ local function do_reset()
|
|||
-- be a little more descriptive.
|
||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||
|
||||
if trial_i >= 0 and defer_prints then
|
||||
if trial_i >= 0 and cfg.defer_prints then
|
||||
if trial_i == 0 then
|
||||
print('test trial reward:', reward, "("..state..")")
|
||||
elseif negate_trials then
|
||||
elseif cfg.negate_trials then
|
||||
--local dir = trial_neg and "negative" or "positive"
|
||||
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||
|
||||
|
@ -804,14 +673,14 @@ local function do_reset()
|
|||
end
|
||||
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 or not negate_trials then
|
||||
if trial_i == 0 or not cfg.negate_trials then
|
||||
trial_rewards[trial_i] = reward
|
||||
else
|
||||
trial_rewards[#trial_rewards + 1] = reward
|
||||
end
|
||||
end
|
||||
|
||||
if epoch_i == 0 or (trial_i == epoch_trials and trial_neg) then
|
||||
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
|
||||
if epoch_i > 0 then learn_from_epoch() end
|
||||
epoch_i = epoch_i + 1
|
||||
prepare_epoch()
|
||||
|
@ -825,19 +694,19 @@ local function do_reset()
|
|||
score_old = get_score()
|
||||
|
||||
-- set number of lives. (mario gets n+1 chances)
|
||||
W(0x75A, starting_lives)
|
||||
W(0x75A, cfg.starting_lives)
|
||||
|
||||
if start_big then
|
||||
if cfg.start_big then
|
||||
-- make mario "super".
|
||||
W(0x754, 0)
|
||||
W(0x756, 1)
|
||||
end
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||
--max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time)
|
||||
max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time)
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cfg.cap_time)
|
||||
--max_time = min(8 * sqrt(360 / cfg.epoch_trials * (epoch_i - 1)) + 100, cfg.cap_time)
|
||||
max_time = min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time)
|
||||
max_time = ceil(max_time)
|
||||
--max_time = cap_time
|
||||
--max_time = cfg.cap_time
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
|
@ -856,12 +725,12 @@ local function do_reset()
|
|||
end
|
||||
|
||||
local function init()
|
||||
network = make_network(input_size)
|
||||
network = make_network(gcfg.input_size)
|
||||
network:reset()
|
||||
network:print()
|
||||
print("parameters:", network.n_param)
|
||||
|
||||
if init_zeros then
|
||||
if cfg.init_zeros then
|
||||
local W = network:collect()
|
||||
for i, w in ipairs(W) do W[i] = 0 end
|
||||
network:distribute(W)
|
||||
|
@ -889,11 +758,11 @@ local function doit(dummy)
|
|||
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
||||
-- well, that shouldn't happen anymore now that i've disabled left+right.
|
||||
local timer = get_timer()
|
||||
if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
timer = timer - 1
|
||||
end
|
||||
timer = clamp(timer, 0, max_time)
|
||||
if enable_network then
|
||||
if cfg.enable_network then
|
||||
set_timer(timer)
|
||||
end
|
||||
|
||||
|
@ -924,7 +793,7 @@ local function doit(dummy)
|
|||
insert(extra_input, vx)
|
||||
insert(extra_input, vy)
|
||||
|
||||
if time_inputs then
|
||||
if cfg.time_inputs then
|
||||
for i=2,5 do
|
||||
insert(extra_input, band(total_frames, lshift(1, i)))
|
||||
end
|
||||
|
@ -946,14 +815,14 @@ local function doit(dummy)
|
|||
local powerup_delta = powerup_old - powerup
|
||||
-- 2 is fire mario.
|
||||
local status_delta = clamp(status - status_old, -1, 1)
|
||||
local flagpole_bonus = R(0xE) == 4 and frameskip or 0
|
||||
local flagpole_bonus = R(0xE) == 4 and cfg.frameskip or 0
|
||||
--local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||
local score_delta = get_score() - score_old
|
||||
if score_delta < 0 then score_delta = 0 end
|
||||
local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
|
||||
screen_scroll_delta = 0
|
||||
|
||||
if decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
|
||||
if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end
|
||||
|
||||
if not ingame_paused then reward = reward + reward_delta end
|
||||
|
||||
|
@ -981,22 +850,22 @@ local function doit(dummy)
|
|||
local X = {}
|
||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||
nn.reshape(X, 1, input_size)
|
||||
nn.reshape(tile_input, 1, tile_count)
|
||||
nn.reshape(X, 1, gcfg.input_size)
|
||||
nn.reshape(tile_input, 1, gcfg.tile_count)
|
||||
|
||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + frameskip
|
||||
if cfg.enable_network and get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + cfg.frameskip
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
|
||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||
if det_epsilon and random() < eps then
|
||||
local i = floor(random() * #jp_lut) + 1
|
||||
jp = nn.copy(jp_lut[i], jp)
|
||||
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
||||
if cfg.det_epsilon and random() < eps then
|
||||
local i = floor(random() * #gcfg.jp_lut) + 1
|
||||
jp = nn.copy(gcfg.jp_lut[i], jp)
|
||||
else
|
||||
local choose = deterministic and argmax or softchoice
|
||||
local choose = cfg.deterministic and argmax or softchoice
|
||||
local ind = choose(unpack(outputs[nn_z]))
|
||||
jp = nn.copy(jp_lut[ind], jp)
|
||||
jp = nn.copy(gcfg.jp_lut[ind], jp)
|
||||
end
|
||||
|
||||
if force_start then
|
||||
|
@ -1026,7 +895,7 @@ init()
|
|||
while true do
|
||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||
|
||||
while bad_states[get_state()] do
|
||||
while gcfg.bad_states[get_state()] do
|
||||
-- mash the start button until we have control.
|
||||
joypad_mash('start')
|
||||
reset = true
|
||||
|
@ -1040,7 +909,7 @@ while true do
|
|||
|
||||
if reset then do_reset() end
|
||||
|
||||
if not enable_network then
|
||||
if not cfg.enable_network then
|
||||
-- infinite time cheat. super handy for testing.
|
||||
if R(0xE) == 8 then
|
||||
set_timer(667)
|
||||
|
@ -1056,7 +925,7 @@ while true do
|
|||
|
||||
-- FIXME: if the game lags then we might miss our frame to change inputs!
|
||||
-- don't rely on emu.framecount.
|
||||
local doot = jp == nil or emu.framecount() % frameskip == 0
|
||||
local doot = jp == nil or emu.framecount() % cfg.frameskip == 0
|
||||
doit(not doot)
|
||||
|
||||
-- jp might still be nil if we're not ingame or we're not playing.
|
||||
|
|
Loading…
Reference in a new issue