various work

Connor Olding 2017-06-29 09:50:33 +00:00
parent e2d29352c4
commit 9017af0d13
2 changed files with 88 additions and 44 deletions

main.lua (112 lines changed)

@@ -22,13 +22,19 @@ end
-- configuration and globals.
math.randomseed(10)
--randomseed(11)
local learning_rate = 1e-2
local deviation = 2e-2
local enable_overlay = false
local enable_network = true
--
local epoch_trials = 24
local learning_rate = 3.2e-3
local deviation = 1e-1 / epoch_trials
--
local timer_loser = 1/3
local cap_time = 400
local epoch_i = 0
local epoch_trials = 12
local base_params
local trial_i = 0
local trial_noise = {}
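A quick standalone check of what the new constants work out to; the last two expressions mirror the magnitude and per-trial scaling used in learn_from_epoch further down:

-- rough values of the new constants.
local epoch_trials = 24
local learning_rate = 3.2e-3
local deviation = 1e-1 / epoch_trials
print(deviation)                                -- about 0.00417
print(math.sqrt(deviation))                     -- about 0.0645, the perturbation scale in load_next_trial
print(learning_rate / deviation)                -- 0.768, the step magnitude in learn_from_epoch
print(learning_rate / deviation / epoch_trials) -- 0.032, the "altogether" per-trial factor
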
@@ -38,17 +44,16 @@ local trials_remaining = 0
local force_start = false
local force_start_old = false
local enable_overlay = false
local enable_network = true
local startsave = savestate.create(1)
local poketime = false
local max_time
local sprite_input = {}
local tile_input = {}
local reward
local powerup_old
local status_old
local coins_old
@@ -89,8 +94,10 @@ local ceil = math.ceil
local min = math.min
local max = math.max
local exp = math.exp
local log = math.log
local sqrt = math.sqrt
local random = math.random
local randomseed = math.randomseed
local insert = table.insert
local remove = table.remove
local unpack = table.unpack or unpack
@@ -177,12 +184,15 @@ local function make_network(input_size, buttons)
nn_x = nn.Input(input_size)
nn_y = nn_x
nn_z = {}
nn_y = nn_y:feed(nn.Dense(input_size))
nn_y = nn_y:feed(nn.Relu())
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
--nn_y = nn_y:feed(nn.Relu())
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
--nn_y = nn_y:feed(nn.Relu())
if false then
nn_y = nn_y:feed(nn.Dense(input_size))
nn_y = nn_y:feed(nn.Gelu())
else
nn_y = nn_y:feed(nn.Dense(114))
nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(57))
nn_y = nn_y:feed(nn.Relu())
end
for i = 1, buttons do
nn_z[i] = nn_y
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
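For a rough parameter count of the new fixed-width network (this assumes nn.Dense carries one bias per output unit, which the diff does not show; the input_size and button count below are made-up examples):

-- hypothetical count for input -> Dense(114) -> Relu -> Dense(57) -> Relu -> buttons x Dense(2).
local function count_params(input_size, buttons)
  local hidden1 = 114 * (input_size + 1)      -- weights plus assumed bias
  local hidden2 = 57 * (114 + 1)
  local heads   = buttons * 2 * (57 + 1)
  return hidden1 + hidden2 + heads
end
print(count_params(1000, 8))  -- 121597 for a 1000-wide input and 8 buttons
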
@@ -451,7 +461,7 @@ local function learn_from_epoch()
print()
print('rewards:', trial_rewards)
normalize(trial_rewards)
print('normalized:', trial_rewards)
--print('normalized:', trial_rewards)
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
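calc_mean_dev and normalize are not shown in this diff; a minimal sketch of what they plausibly do (a mean/stddev pass and a z-score style rescale; names and exact behaviour assumed):

local function calc_mean_dev(t)
  local mean = 0
  for _, v in ipairs(t) do mean = mean + v / #t end
  local dev = 0
  for _, v in ipairs(t) do dev = dev + (v - mean) * (v - mean) / #t end
  return mean, math.sqrt(dev)
end

local function normalize(t)
  local mean, dev = calc_mean_dev(t)
  if dev <= 0 then dev = 1 end
  for i, v in ipairs(t) do t[i] = (v - mean) / dev end
end
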
@@ -465,7 +475,7 @@ local function learn_from_epoch()
end
local magnitude = learning_rate / deviation
print('stepping with magnitude', magnitude)
--print('stepping with magnitude', magnitude)
-- throw the division from the averaging in there too.
local altogether = magnitude / epoch_trials
for i, v in ipairs(step) do
@@ -473,19 +483,22 @@ end
end
local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean)
if step_dev < 1e-8 then
-- we didn't get anywhere. step in a random direction.
print("stepping randomly.")
local noise = trial_noise[1]
local devsqrt = sqrt(deviation)
for i, v in ipairs(step) do
step[i] = devsqrt * noise[i]
end
step_mean, step_dev = calc_mean_dev(step)
end
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
print("step stddev:", step_dev)
if step_dev > 1e-8 then
for i, v in ipairs(base_params) do
base_params[i] = v + step[i]
end
else
-- we didn't get anywhere. step in a random direction.
local noise = trial_noise[1]
for i, v in ipairs(base_params) do
base_params[i] = v + magnitude * noise[i]
end
for i, v in ipairs(base_params) do
base_params[i] = v + step[i]
end
network:distribute(base_params)
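Taken together this is a plain evolution-strategies update: perturb the parameters with per-trial noise, weight each noise vector by its normalized reward, and move along the weighted sum scaled by learning_rate / deviation / epoch_trials. A self-contained toy version on a two-parameter quadratic, with uniform noise standing in for however trial_noise is actually generated:

-- toy ES update on f(p) = -(p[1]-3)^2 - (p[2]+1)^2, using the same scaling as above.
local sqrt, random = math.sqrt, math.random
local epoch_trials, learning_rate = 24, 3.2e-3
local deviation = 1e-1 / epoch_trials
local params = {0, 0}

local function reward_of(p) return -(p[1] - 3)^2 - (p[2] + 1)^2 end

for epoch = 1, 200 do
  -- sample perturbations and collect rewards.
  local noise, rewards = {}, {}
  for t = 1, epoch_trials do
    noise[t] = {}
    local trial = {}
    for i, v in ipairs(params) do
      noise[t][i] = random() * 2 - 1
      trial[i] = v + sqrt(deviation) * noise[t][i]
    end
    rewards[t] = reward_of(trial)
  end
  -- normalize rewards, then step along the reward-weighted noise.
  local mean, dev = 0, 0
  for _, r in ipairs(rewards) do mean = mean + r / epoch_trials end
  for _, r in ipairs(rewards) do dev = dev + (r - mean) ^ 2 / epoch_trials end
  dev = sqrt(dev)
  if dev <= 0 then dev = 1 end
  local altogether = learning_rate / deviation / epoch_trials
  for i = 1, #params do
    local step = 0
    for t = 1, epoch_trials do
      step = step + (rewards[t] - mean) / dev * noise[t][i]
    end
    params[i] = params[i] + altogether * step
  end
end
print(params[1], params[2])  -- ends up roughly near 3 and -1
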
@@ -513,8 +526,9 @@ local function load_next_trial()
print('loading trial', trial_i)
local W = nn.copy(base_params)
local noise = trial_noise[trial_i]
local devsqrt = sqrt(deviation)
for i, v in ipairs(base_params) do
W[i] = v + deviation * noise[i]
W[i] = v + devsqrt * noise[i]
end
network:distribute(W)
end
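Since deviation is now 1e-1 / epoch_trials, switching the trial perturbation from deviation to sqrt(deviation) makes it considerably larger:

print(1e-1 / 24, math.sqrt(1e-1 / 24))  -- about 0.00417 versus about 0.0645
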
@@ -525,19 +539,11 @@ local function do_reset()
if trial_i > 0 then trial_rewards[trial_i] = reward end
if epoch_i == 0 or trial_i == epoch_trials then
if epoch_i > 0 then learn_from_epoch() else network:reset() end
if epoch_i > 0 then learn_from_epoch() end
epoch_i = epoch_i + 1
prepare_epoch()
end
if once then
savestate.load(startsave)
print("end of trial reward:", reward)
else
savestate.save(startsave)
end
once = true
-- bit of a hack:
if get_state() == 'loading' then advance() end
reward = 0
@@ -549,9 +555,20 @@ local function do_reset()
-- unless you get a 1-up, in which case, please continue!
W(0x75A, 0)
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
if once then
savestate.load(startsave)
--print("end of trial reward:", reward)
else
savestate.save(startsave)
end
once = true
emu.frameadvance() -- prevents emulator from quirking up.
print()
--print()
load_next_trial()
reset = false
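For a feel of the new schedule of max_time (the cap applied to the in-game timer each frame), floored for display:

local cap_time = 400
for _, e in ipairs({1, 2, 5, 10, 25, 50, 95}) do
  local t = math.min(8 * math.sqrt(15 * (e - 1)) + 100, cap_time)
  print(e, math.floor(t))  -- 100, 130, 161, 192, 251, 316, 400
end
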
@@ -564,7 +581,11 @@ local function init()
emu.poweron()
emu.unpause()
emu.speedmode("normal")
emu.speedmode("turbo")
network:load()
local res, err = pcall(network.load, network)
if res == false then print(err) end
end
init()
@@ -647,6 +668,8 @@ while true do
local flagpole_bonus = R(0xE) == 4 and 1 or 0
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
-- TODO: add ingame score to reward.
if not ingame_paused then reward = reward + reward_delta end
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
@@ -655,7 +678,7 @@ while true do
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
if get_state() == 'dead' and state_old ~= 'dead' then
print("dead. lives remaining:", R(0x75A, 0))
--print("dead. lives remaining:", R(0x75A, 0))
if R(0x75A, 0) == 0 then reset = true end
end
if get_state() == 'lose' then
@@ -664,17 +687,18 @@ while true do
end
-- lose a point for every frame paused.
if ingame_paused then reward = reward - 1 end
--if ingame_paused then reward = reward - 1 end
if ingame_paused then reward = reward - 402; reset = true end
-- every few frames mario stands still, forcibly decrease the timer.
-- this includes having the game paused.
-- TODO: more robust. doesn't detect moonwalking against a wall.
local timer = get_timer()
local timer_loser = 1/5
if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
timer = clamp(timer - 1, 0, 400)
set_timer(timer)
if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
timer = timer - 1
end
timer = clamp(timer, 0, max_time)
set_timer(timer)
-- if we've run out of time while the game is paused...
-- that's cheating! unpause.
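One subtlety in the timer condition: in Lua, and binds tighter than or, so being paused is enough on its own to decrement the timer, while the random timer_loser decay only applies when both RAM checks pass. The same line, written against the same locals with explicit parentheses:

-- equivalent, explicitly parenthesized form of the condition.
if ingame_paused or ((random() > 1 - timer_loser) and R(0x1D) == 0 and R(0x57) == 0) then
  timer = timer - 1
end
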

nn.lua (20 lines changed)

@@ -124,6 +124,7 @@ local Layer = Base:extend()
local Model = Base:extend()
local Input = Layer:extend()
local Relu = Layer:extend()
local Gelu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
@@ -241,6 +242,24 @@ function Relu:forward(X)
return Y
end
function Gelu:init()
Layer.init(self, "Gelu")
end
function Gelu:forward(X)
assert(#X == self.size_in)
self.cache = self.cache or zeros(self.size_out)
local Y = self.cache
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
for i = 1, #X do
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
end
assert(#Y == self.size_out)
return Y
end
function Dense:init(dim)
Layer.init(self, "Dense")
assert(type(dim) == "number")
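The new Gelu layer uses the sigmoid-style approximation x / (1 + exp(-1.704 * x)), i.e. x times sigmoid(1.704 * x). A standalone comparison against the common tanh-based approximation (the tanh form and the test points below are not from this repo):

-- compare the sigmoid-style GELU above with the tanh-based approximation.
local exp, sqrt, pi = math.exp, math.sqrt, math.pi

local function tanh(x)
  local e = exp(2 * x)
  return (e - 1) / (e + 1)
end

local function gelu_sigmoid(x)
  return x / (1 + exp(-1.704 * x))
end

local function gelu_tanh(x)
  return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ^ 3)))
end

for _, x in ipairs({-3, -1, -0.5, 0, 0.5, 1, 3}) do
  print(x, gelu_sigmoid(x), gelu_tanh(x))
end
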
@@ -413,6 +432,7 @@ return {
Model = Model,
Input = Input,
Relu = Relu,
Gelu = Gelu,
Dense = Dense,
Softmax = Softmax,
}