various work
This commit is contained in:
parent
e2d29352c4
commit
9017af0d13
2 changed files with 88 additions and 44 deletions
112
main.lua
112
main.lua
|
@ -22,13 +22,19 @@ end
|
|||
|
||||
-- configuration and globals.
|
||||
|
||||
math.randomseed(10)
|
||||
--randomseed(11)
|
||||
|
||||
local learning_rate = 1e-2
|
||||
local deviation = 2e-2
|
||||
local enable_overlay = false
|
||||
local enable_network = true
|
||||
--
|
||||
local epoch_trials = 24
|
||||
local learning_rate = 3.2e-3
|
||||
local deviation = 1e-1 / epoch_trials
|
||||
--
|
||||
local timer_loser = 1/3
|
||||
local cap_time = 400
|
||||
|
||||
local epoch_i = 0
|
||||
local epoch_trials = 12
|
||||
local base_params
|
||||
local trial_i = 0
|
||||
local trial_noise = {}
|
||||
|
@ -38,17 +44,16 @@ local trials_remaining = 0
|
|||
local force_start = false
|
||||
local force_start_old = false
|
||||
|
||||
local enable_overlay = false
|
||||
local enable_network = true
|
||||
|
||||
local startsave = savestate.create(1)
|
||||
|
||||
local poketime = false
|
||||
local max_time
|
||||
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
|
||||
local reward
|
||||
|
||||
local powerup_old
|
||||
local status_old
|
||||
local coins_old
|
||||
|
@ -89,8 +94,10 @@ local ceil = math.ceil
|
|||
local min = math.min
|
||||
local max = math.max
|
||||
local exp = math.exp
|
||||
local log = math.log
|
||||
local sqrt = math.sqrt
|
||||
local random = math.random
|
||||
local randomseed = math.randomseed
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
local unpack = table.unpack or unpack
|
||||
|
@ -177,12 +184,15 @@ local function make_network(input_size, buttons)
|
|||
nn_x = nn.Input(input_size)
|
||||
nn_y = nn_x
|
||||
nn_z = {}
|
||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
||||
--nn_y = nn_y:feed(nn.Relu())
|
||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
||||
--nn_y = nn_y:feed(nn.Relu())
|
||||
if false then
|
||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
else
|
||||
nn_y = nn_y:feed(nn.Dense(114))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(57))
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
end
|
||||
for i = 1, buttons do
|
||||
nn_z[i] = nn_y
|
||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||
|
@ -451,7 +461,7 @@ local function learn_from_epoch()
|
|||
print()
|
||||
print('rewards:', trial_rewards)
|
||||
normalize(trial_rewards)
|
||||
print('normalized:', trial_rewards)
|
||||
--print('normalized:', trial_rewards)
|
||||
|
||||
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
|
||||
|
||||
|
@ -465,7 +475,7 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local magnitude = learning_rate / deviation
|
||||
print('stepping with magnitude', magnitude)
|
||||
--print('stepping with magnitude', magnitude)
|
||||
-- throw the division from the averaging in there too.
|
||||
local altogether = magnitude / epoch_trials
|
||||
for i, v in ipairs(step) do
|
||||
|
@ -473,19 +483,22 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
print("step mean:", step_mean)
|
||||
if step_dev < 1e-8 then
|
||||
-- we didn't get anywhere. step in a random direction.
|
||||
print("stepping randomly.")
|
||||
local noise = trial_noise[1]
|
||||
local devsqrt = sqrt(deviation)
|
||||
for i, v in ipairs(step) do
|
||||
step[i] = devsqrt * noise[i]
|
||||
end
|
||||
|
||||
step_mean, step_dev = calc_mean_dev(step)
|
||||
end
|
||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||
print("step stddev:", step_dev)
|
||||
|
||||
if step_dev > 1e-8 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + step[i]
|
||||
end
|
||||
else
|
||||
-- we didn't get anywhere. step in a random direction.
|
||||
local noise = trial_noise[1]
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + magnitude * noise[i]
|
||||
end
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v + step[i]
|
||||
end
|
||||
|
||||
network:distribute(base_params)
|
||||
|
@ -513,8 +526,9 @@ local function load_next_trial()
|
|||
print('loading trial', trial_i)
|
||||
local W = nn.copy(base_params)
|
||||
local noise = trial_noise[trial_i]
|
||||
local devsqrt = sqrt(deviation)
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + deviation * noise[i]
|
||||
W[i] = v + devsqrt * noise[i]
|
||||
end
|
||||
network:distribute(W)
|
||||
end
|
||||
|
@ -525,19 +539,11 @@ local function do_reset()
|
|||
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
||||
|
||||
if epoch_i == 0 or trial_i == epoch_trials then
|
||||
if epoch_i > 0 then learn_from_epoch() else network:reset() end
|
||||
if epoch_i > 0 then learn_from_epoch() end
|
||||
epoch_i = epoch_i + 1
|
||||
prepare_epoch()
|
||||
end
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
print("end of trial reward:", reward)
|
||||
else
|
||||
savestate.save(startsave)
|
||||
end
|
||||
once = true
|
||||
|
||||
-- bit of a hack:
|
||||
if get_state() == 'loading' then advance() end
|
||||
reward = 0
|
||||
|
@ -549,9 +555,20 @@ local function do_reset()
|
|||
-- unless you get a 1-up, in which case, please continue!
|
||||
W(0x75A, 0)
|
||||
|
||||
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||
max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
|
||||
|
||||
if once then
|
||||
savestate.load(startsave)
|
||||
--print("end of trial reward:", reward)
|
||||
else
|
||||
savestate.save(startsave)
|
||||
end
|
||||
once = true
|
||||
|
||||
emu.frameadvance() -- prevents emulator from quirking up.
|
||||
|
||||
print()
|
||||
--print()
|
||||
load_next_trial()
|
||||
|
||||
reset = false
|
||||
|
@ -564,7 +581,11 @@ local function init()
|
|||
|
||||
emu.poweron()
|
||||
emu.unpause()
|
||||
emu.speedmode("normal")
|
||||
emu.speedmode("turbo")
|
||||
|
||||
network:load()
|
||||
local res, err = pcall(network.load, network)
|
||||
if res == false then print(err) end
|
||||
end
|
||||
|
||||
init()
|
||||
|
@ -647,6 +668,8 @@ while true do
|
|||
local flagpole_bonus = R(0xE) == 4 and 1 or 0
|
||||
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||
|
||||
-- TODO: add ingame score to reward.
|
||||
|
||||
if not ingame_paused then reward = reward + reward_delta end
|
||||
|
||||
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
||||
|
@ -655,7 +678,7 @@ while true do
|
|||
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||
|
||||
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||
print("dead. lives remaining:", R(0x75A, 0))
|
||||
--print("dead. lives remaining:", R(0x75A, 0))
|
||||
if R(0x75A, 0) == 0 then reset = true end
|
||||
end
|
||||
if get_state() == 'lose' then
|
||||
|
@ -664,17 +687,18 @@ while true do
|
|||
end
|
||||
|
||||
-- lose a point for every frame paused.
|
||||
if ingame_paused then reward = reward - 1 end
|
||||
--if ingame_paused then reward = reward - 1 end
|
||||
if ingame_paused then reward = reward - 402; reset = true end
|
||||
|
||||
-- every few frames mario stands still, forcibly decrease the timer.
|
||||
-- this includes having the game paused.
|
||||
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
||||
local timer = get_timer()
|
||||
local timer_loser = 1/5
|
||||
if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
timer = clamp(timer - 1, 0, 400)
|
||||
set_timer(timer)
|
||||
if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||
timer = timer - 1
|
||||
end
|
||||
timer = clamp(timer, 0, max_time)
|
||||
set_timer(timer)
|
||||
|
||||
-- if we've run out of time while the game is paused...
|
||||
-- that's cheating! unpause.
|
||||
|
|
20
nn.lua
20
nn.lua
|
@ -124,6 +124,7 @@ local Layer = Base:extend()
|
|||
local Model = Base:extend()
|
||||
local Input = Layer:extend()
|
||||
local Relu = Layer:extend()
|
||||
local Gelu = Layer:extend()
|
||||
local Dense = Layer:extend()
|
||||
local Softmax = Layer:extend()
|
||||
|
||||
|
@ -241,6 +242,24 @@ function Relu:forward(X)
|
|||
return Y
|
||||
end
|
||||
|
||||
-- Constructor for the Gelu layer. Takes no arguments and simply delegates to
-- Layer.init, passing the string "Gelu" (presumably the layer's name;
-- Layer.init is not shown here -- confirm).
function Gelu:init()
|
||||
Layer.init(self, "Gelu")
|
||||
end
|
||||
|
||||
-- Forward pass of the Gelu layer.
-- X: array whose length must equal self.size_in (asserted below).
-- Returns the table stored in self.cache, with Y[i] = X[i] * sigmoid(1.704 * X[i])
-- for i = 1..#X. The same table is reused and overwritten on every call, so
-- callers that need to keep a result across calls must copy it.
function Gelu:forward(X)
|
||||
assert(#X == self.size_in)
|
||||
-- allocate the output buffer on first use only; later calls reuse it.
-- (zeros is a helper defined elsewhere; presumably a zero-filled array of the given length.)
self.cache = self.cache or zeros(self.size_out)
|
||||
local Y = self.cache
|
||||
|
||||
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
-- x / (1 + exp(-k*x)) is x * sigmoid(k*x).
-- NOTE(review): the commonly cited sigmoid approximation of GELU uses
-- k = 1.702; confirm that 1.704 here is intentional.
|
||||
for i = 1, #X do
|
||||
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
|
||||
end
|
||||
|
||||
-- sanity check: the cached buffer must have the layer's declared output size.
assert(#Y == self.size_out)
|
||||
return Y
|
||||
end
|
||||
|
||||
function Dense:init(dim)
|
||||
Layer.init(self, "Dense")
|
||||
assert(type(dim) == "number")
|
||||
|
@ -413,6 +432,7 @@ return {
|
|||
Model = Model,
|
||||
Input = Input,
|
||||
Relu = Relu,
|
||||
Gelu = Gelu,
|
||||
Dense = Dense,
|
||||
Softmax = Softmax,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue