smbot/main.lua

753 lines
21 KiB
Lua
Raw Normal View History

2018-04-02 06:28:00 -07:00
local globalize = require("strict")
2017-06-28 02:33:18 -07:00
2017-09-07 11:41:44 -07:00
-- configuration.
2017-06-28 21:51:02 -07:00
local cfg = require("config")
local gcfg = require("gameconfig")
2017-09-07 11:41:44 -07:00
-- state.

-- training progress.
local epoch_i = 0
local base_params
-- NOTE: trial 0 is an unperturbed trial, if enabled.
local trial_i = -1
local trial_neg = true
local trial_noise, trial_rewards = {}, {}
local trials_remaining = 0

-- AMSgrad accumulators; allocated on first use by amsgrad().
local mom1    -- first moments.
local mom2    -- second moments.
local mom2max -- running element-wise maximum of mom2.

-- frame counters.
local trial_frames, total_frames = 0, 0

-- emulator/session bookkeeping.
local force_start, force_start_old = false, false
local startsave = savestate.create(1)
local poketime = false
local max_time

-- per-trial values; (re)initialized in do_reset().
local jp
local screen_scroll_delta
local reward
--local all_rewards = {}

-- previous-frame game values, used to compute deltas.
local powerup_old, status_old, coins_old, score_old

local once = false
local reset = true
local state_old = ''
local last_trial_state
2017-06-28 21:51:02 -07:00
2017-06-28 02:33:18 -07:00
-- localize some stuff.
2018-05-13 16:34:08 -07:00
local assert = assert
2017-06-28 17:14:56 -07:00
local print = print
2017-06-28 02:33:18 -07:00
local ipairs = ipairs
local pairs = pairs
local select = select
2018-05-03 06:33:17 -07:00
local open = io.open
2017-06-28 02:33:18 -07:00
local abs = math.abs
local floor = math.floor
local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
2018-04-02 07:29:12 -07:00
local pow = math.pow
2017-06-29 02:50:33 -07:00
local log = math.log
2017-06-28 02:33:18 -07:00
local sqrt = math.sqrt
local random = math.random
2017-06-29 02:50:33 -07:00
local randomseed = math.randomseed
2017-06-28 02:33:18 -07:00
local insert = table.insert
local remove = table.remove
local unpack = table.unpack or unpack
2017-09-07 11:53:37 -07:00
local sort = table.sort
2017-06-28 02:33:18 -07:00
local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local bnot = bit.bnot
local lshift = bit.lshift
local rshift = bit.rshift
local arshift = bit.arshift
local rol = bit.rol
local ror = bit.ror
2018-05-13 16:34:08 -07:00
local emu = emu
2018-04-02 07:29:12 -07:00
local gui = gui
2018-05-12 13:38:51 -07:00
local util = require("util")
local argmax = util.argmax
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local unperturbed_rank = util.unperturbed_rank
2018-05-12 13:55:04 -07:00
local game = require("smb")
game.overlay = cfg.enable_overlay
2017-06-28 21:51:02 -07:00
-- utilities.
2018-05-03 06:33:17 -07:00
-- maps each log key to its (1-based) CSV column.
-- every key listed here must be supplied on each log_csv call.
local log_map = {
    epoch = 1,
    trial_mean = 2,
    trial_std = 3,
    delta_mean = 4,
    delta_std = 5,
    step_std = 6,
    adam_std = 7,
    weight_mean = 8,
    weight_std = 9,
    test_trial = 10,
}

-- appends one row to the CSV file named by cfg.log_fn.
-- does nothing when cfg.log_fn is nil.
--
-- t: table mapping every key of log_map to a value; values are written
--    with tostring().
-- raises an error on an unknown key, a missing key,
-- or when the log file cannot be opened.
local function log_csv(t)
    if cfg.log_fn == nil then return end

    -- build and validate the whole row before touching the file,
    -- so that a bad key can never leave an open handle behind.
    local values = {}
    for k, v in pairs(t) do
        local i = log_map[k]
        if i == nil then error("Unexpected log key "..tostring(k)) end
        values[i] = tostring(v)
    end
    for k, i in pairs(log_map) do
        if values[i] == nil then error("Missing log key "..tostring(k)) end
    end

    local f = open(cfg.log_fn, 'a')
    if f == nil then error("Failed to open log file "..cfg.log_fn) end
    f:write(table.concat(values, ","), '\n')
    f:close()
end
2017-09-09 12:46:35 -07:00
-- network parameters.
2017-06-28 02:33:18 -07:00
package.loaded['nn'] = nil -- DEBUG
local nn = require("nn")
local network

-- graph nodes; kept at file scope because the forward pass
-- addresses inputs and outputs by node.
local nn_x, nn_tx, nn_ty, nn_y, nn_z

-- builds the policy network and returns it as an nn.Model.
-- input_size: length of the dense input vector.
-- side effect: assigns the file-level nodes nn_x, nn_tx, nn_ty, nn_y, nn_z.
-- NOTE(review): relies on feed() returning the node that was fed into,
-- as the chained assignments imply. confirm against nn.lua.
local function make_network(input_size)
    -- two inputs: the dense vector, and tile indices that get embedded.
    nn_x = nn.Input({input_size})
    nn_tx = nn.Input({gcfg.tile_count})
    nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))

    -- join both branches.
    local merged = nn.Merge()
    nn_x:feed(merged)
    nn_ty:feed(merged)

    -- single hidden layer; the activation depends on cfg.deterministic.
    local hidden = merged:feed(nn.Dense(128))
    if cfg.deterministic then
        hidden = hidden:feed(nn.Relu())
    else
        hidden = hidden:feed(nn.Gelu())
    end
    if cfg.layernorm then hidden = hidden:feed(nn.LayerNorm()) end
    nn_y = hidden

    -- one softmax output per entry in the joypad lookup table.
    nn_z = hidden:feed(nn.Dense(#gcfg.jp_lut)):feed(nn.Softmax())

    return nn.Model({nn_x, nn_tx}, {nn_z})
end
2017-09-09 12:46:35 -07:00
-- learning and evaluation.
2017-09-07 11:41:44 -07:00
-- snapshots the current network parameters into base_params and,
-- unless in playback mode, draws one noise vector per trial for the
-- coming epoch. resets trial_i to -1.
local function prepare_epoch()
    base_params = network:collect()
    if cfg.playback_mode then return end

    print('preparing epoch '..tostring(epoch_i)..'.')

    empty(trial_noise)
    empty(trial_rewards)

    -- precision is only used by the graycode noise below,
    -- but is always computed from cfg.deviation.
    local precision = (pow(cfg.deviation, 1/-0.51175585) - 8.68297257) / 1.66484392
    if cfg.graycode then
        print(("chosen precision: %.2f"):format(precision))
    end

    local n_params = #base_params
    for trial = 1, cfg.epoch_trials do
        local noise = nn.zeros(n_params)
        if cfg.graycode then
            -- magnitudes first, then signs: two separate passes,
            -- so the random draws happen in this exact order.
            for j = 1, n_params do
                noise[j] = exp(-precision * nn.uniform())
            end
            for j = 1, n_params do
                if not (nn.uniform() < 0.5) then noise[j] = -noise[j] end
            end
        else
            for j = 1, n_params do
                noise[j] = cfg.deviation * nn.normal()
            end
        end
        trial_noise[trial] = noise
    end

    trial_i = -1
end
2018-03-26 07:32:00 -07:00
-- advances to the next trial when trials are run as +noise/-noise pairs,
-- and loads the corresponding weights into the network.
-- each noise vector is used twice: added first, then subtracted.
-- trial_neg flips after every perturbed trial; when it is true on entry,
-- the noise is added, otherwise trial_i is rewound and it is subtracted.
local function load_next_pair()
    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        -- skip the unperturbed trial.
        trial_i = 1
        trial_neg = true
    end

    local W = copy(base_params)
    if trial_i > 0 then
        -- second half of a pair: undo the increment to reuse the same noise.
        if not trial_neg then trial_i = trial_i - 1 end

        local noise = trial_noise[trial_i]
        if trial_neg then
            for k, w in ipairs(base_params) do
                W[k] = w + noise[k]
            end
        else
            for k, w in ipairs(base_params) do
                W[k] = w - noise[k]
            end
        end
        trial_neg = not trial_neg
    end

    network:distribute(W)
end
2017-09-07 11:41:44 -07:00
-- advances trial_i and loads that trial's weights into the network.
-- defers to load_next_pair() when cfg.negate_trials is set.
-- trial 0 (if enabled) runs the unperturbed base parameters.
local function load_next_trial()
    if cfg.negate_trials then return load_next_pair() end

    trial_i = trial_i + 1
    if trial_i == 0 and not cfg.unperturbed_trial then
        trial_i = 1
    end

    local W = copy(base_params)
    if trial_i > 0 then
        print('loading trial', trial_i)
        local noise = trial_noise[trial_i]
        for k, w in ipairs(base_params) do
            W[k] = w + noise[k]
        end
    else
        print("test trial")
    end

    network:distribute(W)
end
2018-05-13 16:34:08 -07:00
-- returns the indices of the best-scoring trials, best first,
-- truncated to at most cfg.epoch_top_trials entries.
-- with cfg.negate_trials, each pos/neg pair is scored by its better half
-- and the returned indices refer to pairs.
local function collect_best_indices()
    local best_rewards
    if cfg.negate_trials then
        -- select one (the best) reward of each pos/neg pair.
        best_rewards = {}
        for pair = 1, cfg.epoch_trials do
            local pos = trial_rewards[2 * pair - 1]
            local neg = trial_rewards[2 * pair]
            best_rewards[pair] = max(pos, neg)
        end
    else
        best_rewards = copy(trial_rewards)
    end

    local count = #best_rewards
    local indices = {}
    for i = 1, count do indices[i] = i end

    -- descending by reward.
    sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)

    -- drop everything past the top trials.
    for i = count, cfg.epoch_top_trials + 1, -1 do indices[i] = nil end

    return indices
end
-- computes a scaling factor from a pair of rewards around a midpoint.
-- dir: the noise vector of the pair; only its deviation is used.
-- pos, neg: rewards of the +noise and -noise trials.
-- mid: reward of the unperturbed trial.
-- returns max(|3*(pos-mid) + (neg-mid)|, |(pos-mid) + 3*(neg-mid)|) / (2*dev).
local function kinda_lipschitz(dir, pos, neg, mid)
    local dev = select(2, calc_mean_dev(dir))
    local d_neg = neg - mid
    local d_pos = pos - mid
    local worst = max(abs(3 * d_pos + d_neg), abs(d_pos + 3 * d_neg))
    return worst / (2 * dev)
end
-- builds the update direction from paired (+noise/-noise) trial rewards.
-- rewards: flat array laid out as pos1, neg1, pos2, neg2, ...
-- current_cost: reward of the unperturbed trial; only used
--               when cfg.ars_lips is set.
-- returns a new vector with one entry per parameter.
local function make_step_paired(rewards, current_cost)
    local step = nn.zeros(#base_params)

    local _, reward_dev = calc_mean_dev(rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero.

    for pair = 1, cfg.epoch_trials do
        local pos = rewards[2 * pair - 1]
        local neg = rewards[2 * pair]
        local weight = pos - neg
        -- pairs with identical rewards contribute nothing; skip them.
        if weight ~= 0 then
            local noise = trial_noise[pair]
            if cfg.ars_lips then
                weight = weight / kinda_lipschitz(noise, pos, neg, current_cost) / cfg.deviation
            else
                weight = weight / reward_dev
            end
            for j, n in ipairs(noise) do
                step[j] = step[j] + weight * n / cfg.epoch_top_trials
            end
        end
    end

    return step
end
-- builds the update direction from unpaired trial rewards.
-- rewards: array with one reward per trial.
-- each noise vector is weighted by its reward divided by the deviation
-- of all rewards, and by 1 / cfg.epoch_top_trials.
-- returns a new vector with one entry per parameter.
local function make_step(rewards)
    local step = nn.zeros(#base_params)

    local _, reward_dev = calc_mean_dev(rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero.

    for trial = 1, cfg.epoch_trials do
        local weight = rewards[trial] / reward_dev
        -- zero-weight trials contribute nothing; skip them.
        if weight ~= 0 then
            for j, n in ipairs(trial_noise[trial]) do
                step[j] = step[j] + weight * n / cfg.epoch_top_trials
            end
        end
    end

    return step
end
-- applies AMSgrad-style moment scaling to step, in-place!
-- the moment buffers (mom1, mom2, mom2max) are file-level state and are
-- created on first use with the same length as step.
-- with cfg.adam_debias, both moments are bias-corrected
-- using epoch_i as the time step.
local function amsgrad(step)
    local size = #step
    if mom1 == nil then mom1 = nn.zeros(size) end
    if mom2 == nil then mom2 = nn.zeros(size) end
    if mom2max == nil then mom2max = nn.zeros(size) end

    local b1, b2, eps = cfg.adam_b1, cfg.adam_b2, cfg.adam_eps
    local b1_t = pow(b1, epoch_i)
    local b2_t = pow(b2, epoch_i)

    -- NOTE: with LuaJIT, splitting this loop would
    --       almost certainly be faster.
    for i, g in ipairs(step) do
        local m = b1 * mom1[i] + (1 - b1) * g
        local v = b2 * mom2[i] + (1 - b2) * g * g
        local v_max = max(v, mom2max[i])
        mom1[i] = m
        mom2[i] = v
        mom2max[i] = v_max

        if cfg.adam_debias then
            local num = (m / (1 - b1_t))
            local den = sqrt(v_max / (1 - b2_t)) + eps
            step[i] = num / den
        else
            step[i] = m / (sqrt(v_max) + eps)
        end
    end
end
2017-06-28 21:51:02 -07:00
-- turns the rewards collected during the finished epoch into a parameter
-- update: selects the top trials, builds a step from their noise vectors,
-- optionally passes it through amsgrad(), applies it (with weight decay)
-- to base_params, logs statistics, and saves the network.
-- reads trial_rewards and trial_noise; mutates base_params in-place.
local function learn_from_epoch()
    print()
    --print('rewards:', trial_rewards)

    --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end

    local current_cost = trial_rewards[0] -- may be nil!
    if cfg.unperturbed_trial then
        local nth_place = unperturbed_rank(trial_rewards, current_cost)
        -- a rank of 1 means our gradient is uninformative.
        print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
    end

    local delta_rewards = {} -- only used for logging.
    if cfg.negate_trials then
        for i = 1, cfg.epoch_trials do
            local ind = (i - 1) * 2 + 1
            local pos = trial_rewards[ind + 0]
            local neg = trial_rewards[ind + 1]
            delta_rewards[i] = abs(pos - neg)
        end
    end

    local indices = collect_best_indices()
    print("best trials:", indices)

    -- keep only the rewards of the selected trials; zero out the rest.
    local top_rewards = {}
    for i = 1, #trial_rewards do top_rewards[i] = 0 end
    if cfg.negate_trials then
        -- indices refer to pairs; rewards are laid out pos1, neg1, pos2, ...
        for _, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_rewards[sind + 0] = trial_rewards[sind + 0]
            top_rewards[sind + 1] = trial_rewards[sind + 1]
        end
    else
        -- indices refer directly to trials; applying the pair layout here
        -- would pick the wrong trials (or nonexistent ones).
        for _, ind in ipairs(indices) do
            top_rewards[ind] = trial_rewards[ind]
        end
    end
    --print("top:", top_rewards)

    if cfg.negate_trials then
        local top_delta_rewards = {} -- only used for printing.
        for i, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
        end
        print("best deltas:", top_delta_rewards)
    end

    local step
    if cfg.negate_trials then
        step = make_step_paired(top_rewards, current_cost)
    else
        step = make_step(top_rewards)
    end

    local step_mean, step_dev = calc_mean_dev(step)
    print("step mean:", step_mean)
    print("step stddev:", step_dev)

    local momstep_mean, momstep_dev = 0, 0
    if cfg.adamant then
        amsgrad(step) -- rewrites step in-place.
        momstep_mean, momstep_dev = calc_mean_dev(step)
        print("amsgrad mean:", momstep_mean)
        print("amsgrad stddev:", momstep_dev)
    end

    -- ascend along the step, with weight decay.
    for i, v in ipairs(base_params) do
        base_params[i] = v + cfg.learning_rate * step[i] - cfg.weight_decay * v
    end

    local trial_mean, trial_std = calc_mean_dev(trial_rewards)
    -- NOTE(review): delta_rewards is empty when cfg.negate_trials is off;
    -- confirm calc_mean_dev returns something loggable for an empty table.
    local delta_mean, delta_std = calc_mean_dev(delta_rewards)
    local weight_mean, weight_std = calc_mean_dev(base_params)

    log_csv{
        epoch       = epoch_i,
        trial_mean  = trial_mean,
        trial_std   = trial_std,
        delta_mean  = delta_mean,
        delta_std   = delta_std,
        step_std    = step_dev,
        adam_std    = momstep_dev,
        weight_mean = weight_mean,
        weight_std  = weight_std,
        test_trial  = current_cost or 0,
    }

    if cfg.enable_network then
        network:distribute(base_params)
        network:save(cfg.params_fn)
    else
        print("note: not updating weights in playable mode.")
    end

    print()
end
-- presses the given button on every odd frame (and releases it on every
-- even frame) on controller 1, with all other buttons released.
-- button: one of 'up', 'down', 'left', 'right', 'A', 'B', 'select', 'start'.
-- raises an error, attributed to the caller, for any other value.
local function joypad_mash(button)
    local jp_mash = {
        up = false,
        down = false,
        left = false,
        right = false,
        A = false,
        B = false,
        select = false,
        start = false,
    }
    -- assert() ignores extra arguments, so it cannot take an error level;
    -- use error() to blame the caller, and only build the message on failure.
    if jp_mash[button] ~= false then
        error("invalid button: "..tostring(button), 2)
    end
    jp_mash[button] = emu.framecount() % 2 == 1
    joypad.write(1, jp_mash)
end
2017-06-28 21:51:02 -07:00
-- ends the current trial and sets up the next one:
-- reports and records the reward, finishes the epoch when all trials have
-- run, restores the starting savestate (creating it on the first call),
-- and loads the next trial's weights.
local function do_reset()
    local state = game.get_state()
    -- be a little more descriptive.
    if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end

    -- report the finished trial and record its reward.
    if trial_i < 0 then
        -- no trial is active yet; nothing to record.
        print("reward:", reward, "("..state..")")
    elseif trial_i == 0 then
        print('test trial reward:', reward, "("..state..")")
        trial_rewards[trial_i] = reward
    elseif cfg.negate_trials then
        --local dir = trial_neg and "negative" or "positive"
        --print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
        if trial_neg then
            -- both halves of the pair are done; print them together.
            -- the positive half is the most recently stored reward.
            local pos = trial_rewards[#trial_rewards]
            local neg = reward
            local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
            print(fmt:format(trial_i, pos, neg, last_trial_state, state))
        end
        last_trial_state = state
        trial_rewards[#trial_rewards + 1] = reward
    else
        print('trial', trial_i, 'reward:', reward, "("..state..")")
        trial_rewards[trial_i] = reward
    end

    -- start a new epoch on the very first reset,
    -- and whenever the last trial of an epoch has finished.
    if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
        if epoch_i > 0 then learn_from_epoch() end
        if not cfg.playback_mode then epoch_i = epoch_i + 1 end
        prepare_epoch()
    end

    if game.get_state() == 'loading' then game.advance() end -- kind of a hack.

    -- NOTE: these are read before the savestate is loaded below.
    reward = 0
    powerup_old = game.R(0x754)
    status_old = game.R(0x756)
    coins_old = game.R(0x7ED) * 10 + game.R(0x7EE)
    score_old = game.get_score()

    -- set number of lives. (mario gets n+1 chances)
    game.W(0x75A, cfg.starting_lives)

    if cfg.start_big then
        -- make mario "super".
        game.W(0x754, 0)
        game.W(0x756, 1)
    end

    -- the time limit grows with the epoch number, up to cfg.cap_time.
    max_time = ceil(min(6 * sqrt(480 / cfg.epoch_trials * (epoch_i - 1)) + 60, cfg.cap_time))

    -- the first call creates the starting savestate; later calls restore it.
    if not once then
        savestate.save(startsave)
    else
        savestate.load(startsave)
    end
    once = true

    jp = nil
    screen_scroll_delta = 0
    trial_frames = 0

    emu.frameadvance() -- prevents emulator from quirking up.

    load_next_trial()

    reset = false
end
2017-06-28 21:51:02 -07:00
-- one-time setup: builds and resets the network, boots the game,
-- mashes start through the title screen, then tries to load saved
-- parameters from cfg.params_fn (a failed load is printed, not fatal).
local function init()
    network = make_network(gcfg.input_size)
    network:reset()
    network:print()
    print("parameters:", network.n_param)

    if cfg.init_zeros then
        -- overwrite every parameter with zero.
        local params = network:collect()
        for k in ipairs(params) do params[k] = 0 end
        network:distribute(params)
    end

    emu.poweron()
    emu.unpause()
    emu.speedmode("turbo")

    -- FIXME: don't hardcode this.
    while emu.framecount() < 195 do
        joypad_mash('start')
        emu.frameadvance()
    end
    --print(emu.framecount())

    local ok, err = pcall(network.load, network, cfg.params_fn)
    if not ok then print(err) end
end
2018-04-03 09:13:11 -07:00
-- requests a reset before the next frame is processed.
-- ignored in playback mode.
local function prepare_reset()
    if not cfg.playback_mode then
        reset = true
    end
end
2017-09-07 14:20:53 -07:00
-- per-frame update: manages the in-game timer, draws the overlay text,
-- accumulates reward, detects end-of-trial conditions, and (unless this is
-- a dummy frame) runs the network to choose the next joypad state.
-- dummy: when exactly true, only the timer/overlay/scroll bookkeeping runs
--        and the previously chosen inputs are kept.
local function doit(dummy)
    local ingame_paused = game.get_state() == "paused"

    -- every few frames mario stands still, forcibly decrease the timer.
    -- this includes having the game paused.
    -- TODO: more robust. doesn't detect moonwalking against a wall.
    --       well, that shouldn't happen anymore now that i've disabled left+right.
    local timer = game.get_timer()
    if ingame_paused
    or (random() > 1 - cfg.timer_loser and game.R(0x1D) == 0 and game.R(0x57) == 0) then
        timer = timer - 1
    end
    if not cfg.playback_mode then
        timer = clamp(timer, 0, max_time)
        if cfg.enable_network then
            game.set_timer(timer)
        end
    end

    -- split the frame counter into groups of three digits for display.
    -- the millions group is floored explicitly, instead of relying on
    -- %i truncating a fractional value.
    local tf0 = total_frames % 1000
    local tf1 = (total_frames % 1000000 - tf0) / 1000
    local tf2 = floor(total_frames / 1000000)
    gui.text(12, 212, ("%03i,%03i,%03i"):format(tf2,tf1,tf0), '#FFFFFF', '#0000003F')

    screen_scroll_delta = screen_scroll_delta + game.R(0x775)

    if dummy == true then
        -- don't invoke AI this frame. (keep holding the old inputs)
        gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
        return
    end

    empty(game.sprite_input)
    empty(game.tile_input)
    empty(game.extra_input)

    -- TODO: check if mario is in a playable state.
    local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    local powerup = game.R(0x754)
    local status = game.R(0x756)
    game.mark_sprite(x + 8, y + 24, -powerup - 1)

    local vx, vy = game.S(0x57), game.S(0x9F)
    insert(game.extra_input, vx * 16)
    insert(game.extra_input, vy * 16)

    if cfg.time_inputs then
        -- feed bits 2 through 5 of the trial's frame counter as inputs.
        for i=2,5 do
            local v = band(trial_frames, lshift(1, i)) == 0 and -181 or 181
            insert(game.extra_input, v)
        end
    end

    game.handle_enemies()
    game.handle_fireballs()
    --game.handle_blocks() -- blocks being hit. not interactable; we don't care!
    game.handle_hammers()
    game.handle_misc()
    game.handle_tiles()

    -- NOTE: the coin, powerup and status deltas are computed here
    -- but are not part of the reward as currently written.
    local coins = game.R(0x7ED) * 10 + game.R(0x7EE)
    local coins_delta = coins - coins_old
    -- handle wrap-around.
    if coins_delta < 0 then coins_delta = 100 + coins - coins_old end
    -- remember that 0 is big mario and 1 is small mario.
    local powerup_delta = powerup_old - powerup
    -- 2 is fire mario.
    local status_delta = clamp(status - status_old, -1, 1)

    local flagpole_bonus = game.R(0xE) == 4 and cfg.frameskip or 0
    --local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
    local score_delta = game.get_score() - score_old
    if score_delta < 0 then score_delta = 0 end
    local reward_delta = screen_scroll_delta + score_delta + flagpole_bonus
    screen_scroll_delta = 0

    if cfg.decrement_reward and reward_delta == 0 then reward_delta = reward_delta - 1 end

    if not ingame_paused then reward = reward + reward_delta end

    --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
    gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')

    if game.get_state() == 'dead' and state_old ~= 'dead' then
        --print("dead. lives remaining:", game.R(0x75A, 0))
        if game.R(0x75A, 0) == 0 then prepare_reset() end
    end
    if game.get_state() == 'lose' then
        -- this shouldn't happen if we catch the deaths as above.
        print("ran out of lives.")
        if not cfg.playback_mode then prepare_reset() end
    end

    -- lose a point for every frame paused.
    --if ingame_paused then reward = reward - 1 end
    if ingame_paused then reward = reward - 402; prepare_reset() end

    -- if we've run out of time while the game is paused...
    -- that's cheating! unpause.
    force_start = ingame_paused and timer == 0

    -- assemble the dense input vector.
    local X = {}
    for i, v in ipairs(game.sprite_input) do insert(X, v / 256) end
    for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
    nn.reshape(X, 1, gcfg.input_size)
    nn.reshape(game.tile_input, 1, gcfg.tile_count)

    trial_frames = trial_frames + cfg.frameskip

    -- NOTE(review): as written, a paused game runs this branch even when
    -- cfg.enable_network is off. the parentheses only make the existing
    -- precedence explicit; confirm that this is intended.
    if (cfg.enable_network and game.get_state() == 'playing') or ingame_paused then
        total_frames = total_frames + cfg.frameskip

        local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})

        local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
        if cfg.det_epsilon and random() < eps then
            -- pick a uniformly random action.
            local i = floor(random() * #gcfg.jp_lut) + 1
            jp = copy(gcfg.jp_lut[i], jp)
        else
            local choose = cfg.deterministic and argmax or softchoice
            local ind = choose(unpack(outputs[nn_z]))
            jp = copy(gcfg.jp_lut[ind], jp)
        end

        if force_start then
            -- alternate the start button between frames to unpause.
            jp = {
                up = false,
                down = false,
                left = false,
                right = false,
                A = false,
                B = false,
                start = force_start_old,
                select = false,
            }
        end
    end

    coins_old = coins
    powerup_old = powerup
    status_old = status
    force_start_old = force_start
    state_old = game.get_state()
    score_old = game.get_score()
end
2017-09-09 12:46:35 -07:00
init()

-- main loop: one iteration per emulated frame, forever.
while true do
    gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

    while gcfg.bad_states[game.get_state()] do
        -- mash the start button until we have control.
        joypad_mash('start')
        prepare_reset()

        game.advance()
        gui.text(4, 12, game.get_state(), '#FFFFFF', '#0000003F')

        -- kind of a hack.
        while game.get_state() == "loading" do
            game.advance()
        end
        state_old = game.get_state()
    end

    if reset then do_reset() end

    if not cfg.enable_network then
        -- infinite time cheat. super handy for testing.
        if game.R(0xE) == 8 then
            game.set_timer(667)
            poketime = true
        elseif poketime then
            game.set_timer(1)
            poketime = false
        end

        -- infinite lives.
        game.W(0x75A, 1)
    end

    -- FIXME: if the game lags then we might miss our frame to change inputs!
    --        don't rely on emu.framecount.
    -- hold the old inputs unless there are none yet,
    -- or this frame is a multiple of cfg.frameskip.
    local hold_inputs = jp ~= nil and emu.framecount() % cfg.frameskip ~= 0
    doit(hold_inputs)

    -- jp might still be nil if we're not ingame or we're not playing.
    if jp ~= nil then joypad.write(1, jp) end

    game.advance()
end