various work

This commit is contained in:
Connor Olding 2017-06-29 09:50:33 +00:00
parent e2d29352c4
commit 9017af0d13
2 changed files with 88 additions and 44 deletions

112
main.lua
View File

@ -22,13 +22,19 @@ end
-- configuration and globals. -- configuration and globals.
math.randomseed(10) --randomseed(11)
local learning_rate = 1e-2 local enable_overlay = false
local deviation = 2e-2 local enable_network = true
--
local epoch_trials = 24
local learning_rate = 3.2e-3
local deviation = 1e-1 / epoch_trials
--
local timer_loser = 1/3
local cap_time = 400
local epoch_i = 0 local epoch_i = 0
local epoch_trials = 12
local base_params local base_params
local trial_i = 0 local trial_i = 0
local trial_noise = {} local trial_noise = {}
@ -38,17 +44,16 @@ local trials_remaining = 0
local force_start = false local force_start = false
local force_start_old = false local force_start_old = false
local enable_overlay = false
local enable_network = true
local startsave = savestate.create(1) local startsave = savestate.create(1)
local poketime = false local poketime = false
local max_time
local sprite_input = {} local sprite_input = {}
local tile_input = {} local tile_input = {}
local reward local reward
local powerup_old local powerup_old
local status_old local status_old
local coins_old local coins_old
@ -89,8 +94,10 @@ local ceil = math.ceil
local min = math.min local min = math.min
local max = math.max local max = math.max
local exp = math.exp local exp = math.exp
local log = math.log
local sqrt = math.sqrt local sqrt = math.sqrt
local random = math.random local random = math.random
local randomseed = math.randomseed
local insert = table.insert local insert = table.insert
local remove = table.remove local remove = table.remove
local unpack = table.unpack or unpack local unpack = table.unpack or unpack
@ -177,12 +184,15 @@ local function make_network(input_size, buttons)
nn_x = nn.Input(input_size) nn_x = nn.Input(input_size)
nn_y = nn_x nn_y = nn_x
nn_z = {} nn_z = {}
nn_y = nn_y:feed(nn.Dense(input_size)) if false then
nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Dense(input_size))
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16))) nn_y = nn_y:feed(nn.Gelu())
--nn_y = nn_y:feed(nn.Relu()) else
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16))) nn_y = nn_y:feed(nn.Dense(114))
--nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(57))
nn_y = nn_y:feed(nn.Relu())
end
for i = 1, buttons do for i = 1, buttons do
nn_z[i] = nn_y nn_z[i] = nn_y
nn_z[i] = nn_z[i]:feed(nn.Dense(2)) nn_z[i] = nn_z[i]:feed(nn.Dense(2))
@ -451,7 +461,7 @@ local function learn_from_epoch()
print() print()
print('rewards:', trial_rewards) print('rewards:', trial_rewards)
normalize(trial_rewards) normalize(trial_rewards)
print('normalized:', trial_rewards) --print('normalized:', trial_rewards)
local reward_mean, reward_dev = calc_mean_dev(trial_rewards) local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
@ -465,7 +475,7 @@ local function learn_from_epoch()
end end
local magnitude = learning_rate / deviation local magnitude = learning_rate / deviation
print('stepping with magnitude', magnitude) --print('stepping with magnitude', magnitude)
-- throw the division from the averaging in there too. -- throw the division from the averaging in there too.
local altogether = magnitude / epoch_trials local altogether = magnitude / epoch_trials
for i, v in ipairs(step) do for i, v in ipairs(step) do
@ -473,19 +483,22 @@ local function learn_from_epoch()
end end
local step_mean, step_dev = calc_mean_dev(step) local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean) if step_dev < 1e-8 then
-- we didn't get anywhere. step in a random direction.
print("stepping randomly.")
local noise = trial_noise[1]
local devsqrt = sqrt(deviation)
for i, v in ipairs(step) do
step[i] = devsqrt * noise[i]
end
step_mean, step_dev = calc_mean_dev(step)
end
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
print("step stddev:", step_dev) print("step stddev:", step_dev)
if step_dev > 1e-8 then for i, v in ipairs(base_params) do
for i, v in ipairs(base_params) do base_params[i] = v + step[i]
base_params[i] = v + step[i]
end
else
-- we didn't get anywhere. step in a random direction.
local noise = trial_noise[1]
for i, v in ipairs(base_params) do
base_params[i] = v + magnitude * noise[i]
end
end end
network:distribute(base_params) network:distribute(base_params)
@ -513,8 +526,9 @@ local function load_next_trial()
print('loading trial', trial_i) print('loading trial', trial_i)
local W = nn.copy(base_params) local W = nn.copy(base_params)
local noise = trial_noise[trial_i] local noise = trial_noise[trial_i]
local devsqrt = sqrt(deviation)
for i, v in ipairs(base_params) do for i, v in ipairs(base_params) do
W[i] = v + deviation * noise[i] W[i] = v + devsqrt * noise[i]
end end
network:distribute(W) network:distribute(W)
end end
@ -525,19 +539,11 @@ local function do_reset()
if trial_i > 0 then trial_rewards[trial_i] = reward end if trial_i > 0 then trial_rewards[trial_i] = reward end
if epoch_i == 0 or trial_i == epoch_trials then if epoch_i == 0 or trial_i == epoch_trials then
if epoch_i > 0 then learn_from_epoch() else network:reset() end if epoch_i > 0 then learn_from_epoch() end
epoch_i = epoch_i + 1 epoch_i = epoch_i + 1
prepare_epoch() prepare_epoch()
end end
if once then
savestate.load(startsave)
print("end of trial reward:", reward)
else
savestate.save(startsave)
end
once = true
-- bit of a hack: -- bit of a hack:
if get_state() == 'loading' then advance() end if get_state() == 'loading' then advance() end
reward = 0 reward = 0
@ -549,9 +555,20 @@ local function do_reset()
-- unless you get a 1-up, in which case, please continue! -- unless you get a 1-up, in which case, please continue!
W(0x75A, 0) W(0x75A, 0)
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
if once then
savestate.load(startsave)
--print("end of trial reward:", reward)
else
savestate.save(startsave)
end
once = true
emu.frameadvance() -- prevents emulator from quirking up. emu.frameadvance() -- prevents emulator from quirking up.
print() --print()
load_next_trial() load_next_trial()
reset = false reset = false
@ -564,7 +581,11 @@ local function init()
emu.poweron() emu.poweron()
emu.unpause() emu.unpause()
emu.speedmode("normal") emu.speedmode("turbo")
network:load()
local res, err = pcall(network.load, network)
if res == false then print(err) end
end end
init() init()
@ -647,6 +668,8 @@ while true do
local flagpole_bonus = R(0xE) == 4 and 1 or 0 local flagpole_bonus = R(0xE) == 4 and 1 or 0
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
-- TODO: add ingame score to reward.
if not ingame_paused then reward = reward + reward_delta end if not ingame_paused then reward = reward + reward_delta end
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F') --gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
@ -655,7 +678,7 @@ while true do
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
if get_state() == 'dead' and state_old ~= 'dead' then if get_state() == 'dead' and state_old ~= 'dead' then
print("dead. lives remaining:", R(0x75A, 0)) --print("dead. lives remaining:", R(0x75A, 0))
if R(0x75A, 0) == 0 then reset = true end if R(0x75A, 0) == 0 then reset = true end
end end
if get_state() == 'lose' then if get_state() == 'lose' then
@ -664,17 +687,18 @@ while true do
end end
-- lose a point for every frame paused. -- lose a point for every frame paused.
if ingame_paused then reward = reward - 1 end --if ingame_paused then reward = reward - 1 end
if ingame_paused then reward = reward - 402; reset = true end
-- every few frames mario stands still, forcibly decrease the timer. -- every few frames mario stands still, forcibly decrease the timer.
-- this includes having the game paused. -- this includes having the game paused.
-- TODO: more robust. doesn't detect moonwalking against a wall. -- TODO: more robust. doesn't detect moonwalking against a wall.
local timer = get_timer() local timer = get_timer()
local timer_loser = 1/5 if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then timer = timer - 1
timer = clamp(timer - 1, 0, 400)
set_timer(timer)
end end
timer = clamp(timer, 0, max_time)
set_timer(timer)
-- if we've run out of time while the game is paused... -- if we've run out of time while the game is paused...
-- that's cheating! unpause. -- that's cheating! unpause.

20
nn.lua
View File

@ -124,6 +124,7 @@ local Layer = Base:extend()
local Model = Base:extend() local Model = Base:extend()
local Input = Layer:extend() local Input = Layer:extend()
local Relu = Layer:extend() local Relu = Layer:extend()
local Gelu = Layer:extend()
local Dense = Layer:extend() local Dense = Layer:extend()
local Softmax = Layer:extend() local Softmax = Layer:extend()
@ -241,6 +242,24 @@ function Relu:forward(X)
return Y return Y
end end
-- Constructor for the Gelu activation layer.
-- Takes no arguments; it only forwards to Layer.init with the string "Gelu"
-- (presumably used as the layer's name -- Layer.init is defined outside
-- this view, confirm there).
function Gelu:init()
Layer.init(self, "Gelu")
end
-- Forward pass of the Gelu activation.
-- X: sequence of numbers whose length must equal self.size_in.
-- Returns a sequence of the same length holding the activated values.
-- The output buffer is stored in self.cache and reused on every call, so the
-- returned table is overwritten by the next forward pass.
function Gelu:forward(X)
    local count = #X
    assert(count == self.size_in)

    -- lazily allocate the output buffer the first time through.
    local out = self.cache
    if not out then
        out = zeros(self.size_out)
    end
    self.cache = out

    -- NOTE: sigmoid-based approximation of GELU: x * sigmoid(1.704 * x),
    -- written here as x / (1 + exp(-1.704 * x)).
    for idx = 1, count do
        local x = X[idx]
        out[idx] = x / (1 + exp(-1.704 * x))
    end

    assert(#out == self.size_out)
    return out
end
function Dense:init(dim) function Dense:init(dim)
Layer.init(self, "Dense") Layer.init(self, "Dense")
assert(type(dim) == "number") assert(type(dim) == "number")
@ -413,6 +432,7 @@ return {
Model = Model, Model = Model,
Input = Input, Input = Input,
Relu = Relu, Relu = Relu,
Gelu = Gelu,
Dense = Dense, Dense = Dense,
Softmax = Softmax, Softmax = Softmax,
} }