various work
This commit is contained in:
parent
e2d29352c4
commit
9017af0d13
2 changed files with 88 additions and 44 deletions
104
main.lua
104
main.lua
|
@ -22,13 +22,19 @@ end
|
||||||
|
|
||||||
-- configuration and globals.
|
-- configuration and globals.
|
||||||
|
|
||||||
math.randomseed(10)
|
--randomseed(11)
|
||||||
|
|
||||||
local learning_rate = 1e-2
|
local enable_overlay = false
|
||||||
local deviation = 2e-2
|
local enable_network = true
|
||||||
|
--
|
||||||
|
local epoch_trials = 24
|
||||||
|
local learning_rate = 3.2e-3
|
||||||
|
local deviation = 1e-1 / epoch_trials
|
||||||
|
--
|
||||||
|
local timer_loser = 1/3
|
||||||
|
local cap_time = 400
|
||||||
|
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
local epoch_trials = 12
|
|
||||||
local base_params
|
local base_params
|
||||||
local trial_i = 0
|
local trial_i = 0
|
||||||
local trial_noise = {}
|
local trial_noise = {}
|
||||||
|
@ -38,17 +44,16 @@ local trials_remaining = 0
|
||||||
local force_start = false
|
local force_start = false
|
||||||
local force_start_old = false
|
local force_start_old = false
|
||||||
|
|
||||||
local enable_overlay = false
|
|
||||||
local enable_network = true
|
|
||||||
|
|
||||||
local startsave = savestate.create(1)
|
local startsave = savestate.create(1)
|
||||||
|
|
||||||
local poketime = false
|
local poketime = false
|
||||||
|
local max_time
|
||||||
|
|
||||||
local sprite_input = {}
|
local sprite_input = {}
|
||||||
local tile_input = {}
|
local tile_input = {}
|
||||||
|
|
||||||
local reward
|
local reward
|
||||||
|
|
||||||
local powerup_old
|
local powerup_old
|
||||||
local status_old
|
local status_old
|
||||||
local coins_old
|
local coins_old
|
||||||
|
@ -89,8 +94,10 @@ local ceil = math.ceil
|
||||||
local min = math.min
|
local min = math.min
|
||||||
local max = math.max
|
local max = math.max
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
|
local log = math.log
|
||||||
local sqrt = math.sqrt
|
local sqrt = math.sqrt
|
||||||
local random = math.random
|
local random = math.random
|
||||||
|
local randomseed = math.randomseed
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
local unpack = table.unpack or unpack
|
local unpack = table.unpack or unpack
|
||||||
|
@ -177,12 +184,15 @@ local function make_network(input_size, buttons)
|
||||||
nn_x = nn.Input(input_size)
|
nn_x = nn.Input(input_size)
|
||||||
nn_y = nn_x
|
nn_y = nn_x
|
||||||
nn_z = {}
|
nn_z = {}
|
||||||
|
if false then
|
||||||
nn_y = nn_y:feed(nn.Dense(input_size))
|
nn_y = nn_y:feed(nn.Dense(input_size))
|
||||||
|
nn_y = nn_y:feed(nn.Gelu())
|
||||||
|
else
|
||||||
|
nn_y = nn_y:feed(nn.Dense(114))
|
||||||
nn_y = nn_y:feed(nn.Relu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
nn_y = nn_y:feed(nn.Dense(57))
|
||||||
--nn_y = nn_y:feed(nn.Relu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
--nn_y = nn_y:feed(nn.Dense(floor(input_size / 16)))
|
end
|
||||||
--nn_y = nn_y:feed(nn.Relu())
|
|
||||||
for i = 1, buttons do
|
for i = 1, buttons do
|
||||||
nn_z[i] = nn_y
|
nn_z[i] = nn_y
|
||||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||||
|
@ -451,7 +461,7 @@ local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
print('rewards:', trial_rewards)
|
print('rewards:', trial_rewards)
|
||||||
normalize(trial_rewards)
|
normalize(trial_rewards)
|
||||||
print('normalized:', trial_rewards)
|
--print('normalized:', trial_rewards)
|
||||||
|
|
||||||
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
|
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
|
||||||
|
|
||||||
|
@ -465,7 +475,7 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local magnitude = learning_rate / deviation
|
local magnitude = learning_rate / deviation
|
||||||
print('stepping with magnitude', magnitude)
|
--print('stepping with magnitude', magnitude)
|
||||||
-- throw the division from the averaging in there too.
|
-- throw the division from the averaging in there too.
|
||||||
local altogether = magnitude / epoch_trials
|
local altogether = magnitude / epoch_trials
|
||||||
for i, v in ipairs(step) do
|
for i, v in ipairs(step) do
|
||||||
|
@ -473,20 +483,23 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
print("step mean:", step_mean)
|
if step_dev < 1e-8 then
|
||||||
|
-- we didn't get anywhere. step in a random direction.
|
||||||
|
print("stepping randomly.")
|
||||||
|
local noise = trial_noise[1]
|
||||||
|
local devsqrt = sqrt(deviation)
|
||||||
|
for i, v in ipairs(step) do
|
||||||
|
step[i] = devsqrt * noise[i]
|
||||||
|
end
|
||||||
|
|
||||||
|
step_mean, step_dev = calc_mean_dev(step)
|
||||||
|
end
|
||||||
|
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||||
print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
|
|
||||||
if step_dev > 1e-8 then
|
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v + step[i]
|
base_params[i] = v + step[i]
|
||||||
end
|
end
|
||||||
else
|
|
||||||
-- we didn't get anywhere. step in a random direction.
|
|
||||||
local noise = trial_noise[1]
|
|
||||||
for i, v in ipairs(base_params) do
|
|
||||||
base_params[i] = v + magnitude * noise[i]
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
network:distribute(base_params)
|
network:distribute(base_params)
|
||||||
|
|
||||||
|
@ -513,8 +526,9 @@ local function load_next_trial()
|
||||||
print('loading trial', trial_i)
|
print('loading trial', trial_i)
|
||||||
local W = nn.copy(base_params)
|
local W = nn.copy(base_params)
|
||||||
local noise = trial_noise[trial_i]
|
local noise = trial_noise[trial_i]
|
||||||
|
local devsqrt = sqrt(deviation)
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
W[i] = v + deviation * noise[i]
|
W[i] = v + devsqrt * noise[i]
|
||||||
end
|
end
|
||||||
network:distribute(W)
|
network:distribute(W)
|
||||||
end
|
end
|
||||||
|
@ -525,19 +539,11 @@ local function do_reset()
|
||||||
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
||||||
|
|
||||||
if epoch_i == 0 or trial_i == epoch_trials then
|
if epoch_i == 0 or trial_i == epoch_trials then
|
||||||
if epoch_i > 0 then learn_from_epoch() else network:reset() end
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
epoch_i = epoch_i + 1
|
epoch_i = epoch_i + 1
|
||||||
prepare_epoch()
|
prepare_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
if once then
|
|
||||||
savestate.load(startsave)
|
|
||||||
print("end of trial reward:", reward)
|
|
||||||
else
|
|
||||||
savestate.save(startsave)
|
|
||||||
end
|
|
||||||
once = true
|
|
||||||
|
|
||||||
-- bit of a hack:
|
-- bit of a hack:
|
||||||
if get_state() == 'loading' then advance() end
|
if get_state() == 'loading' then advance() end
|
||||||
reward = 0
|
reward = 0
|
||||||
|
@ -549,9 +555,20 @@ local function do_reset()
|
||||||
-- unless you get a 1-up, in which case, please continue!
|
-- unless you get a 1-up, in which case, please continue!
|
||||||
W(0x75A, 0)
|
W(0x75A, 0)
|
||||||
|
|
||||||
|
--max_time = min(log(epoch_i) * 10 + 100, cap_time)
|
||||||
|
max_time = min(8 * sqrt(15 * (epoch_i - 1)) + 100, cap_time)
|
||||||
|
|
||||||
|
if once then
|
||||||
|
savestate.load(startsave)
|
||||||
|
--print("end of trial reward:", reward)
|
||||||
|
else
|
||||||
|
savestate.save(startsave)
|
||||||
|
end
|
||||||
|
once = true
|
||||||
|
|
||||||
emu.frameadvance() -- prevents emulator from quirking up.
|
emu.frameadvance() -- prevents emulator from quirking up.
|
||||||
|
|
||||||
print()
|
--print()
|
||||||
load_next_trial()
|
load_next_trial()
|
||||||
|
|
||||||
reset = false
|
reset = false
|
||||||
|
@ -564,7 +581,11 @@ local function init()
|
||||||
|
|
||||||
emu.poweron()
|
emu.poweron()
|
||||||
emu.unpause()
|
emu.unpause()
|
||||||
emu.speedmode("normal")
|
emu.speedmode("turbo")
|
||||||
|
|
||||||
|
network:load()
|
||||||
|
local res, err = pcall(network.load, network)
|
||||||
|
if res == false then print(err) end
|
||||||
end
|
end
|
||||||
|
|
||||||
init()
|
init()
|
||||||
|
@ -647,6 +668,8 @@ while true do
|
||||||
local flagpole_bonus = R(0xE) == 4 and 1 or 0
|
local flagpole_bonus = R(0xE) == 4 and 1 or 0
|
||||||
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
local reward_delta = screen_scroll_delta + status_delta * 256 + flagpole_bonus
|
||||||
|
|
||||||
|
-- TODO: add ingame score to reward.
|
||||||
|
|
||||||
if not ingame_paused then reward = reward + reward_delta end
|
if not ingame_paused then reward = reward + reward_delta end
|
||||||
|
|
||||||
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
--gui.text(4, 12, ("%02X"):format(#sprite_input), '#FFFFFF', '#0000003F')
|
||||||
|
@ -655,7 +678,7 @@ while true do
|
||||||
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
gui.text(112, 12, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
||||||
|
|
||||||
if get_state() == 'dead' and state_old ~= 'dead' then
|
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||||
print("dead. lives remaining:", R(0x75A, 0))
|
--print("dead. lives remaining:", R(0x75A, 0))
|
||||||
if R(0x75A, 0) == 0 then reset = true end
|
if R(0x75A, 0) == 0 then reset = true end
|
||||||
end
|
end
|
||||||
if get_state() == 'lose' then
|
if get_state() == 'lose' then
|
||||||
|
@ -664,17 +687,18 @@ while true do
|
||||||
end
|
end
|
||||||
|
|
||||||
-- lose a point for every frame paused.
|
-- lose a point for every frame paused.
|
||||||
if ingame_paused then reward = reward - 1 end
|
--if ingame_paused then reward = reward - 1 end
|
||||||
|
if ingame_paused then reward = reward - 402; reset = true end
|
||||||
|
|
||||||
-- every few frames mario stands still, forcibly decrease the timer.
|
-- every few frames mario stands still, forcibly decrease the timer.
|
||||||
-- this includes having the game paused.
|
-- this includes having the game paused.
|
||||||
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
-- TODO: more robust. doesn't detect moonwalking against a wall.
|
||||||
local timer = get_timer()
|
local timer = get_timer()
|
||||||
local timer_loser = 1/5
|
if ingame_paused or random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||||
if ingame_paused or math.random() > 1 - timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
timer = timer - 1
|
||||||
timer = clamp(timer - 1, 0, 400)
|
|
||||||
set_timer(timer)
|
|
||||||
end
|
end
|
||||||
|
timer = clamp(timer, 0, max_time)
|
||||||
|
set_timer(timer)
|
||||||
|
|
||||||
-- if we've run out of time while the game is paused...
|
-- if we've run out of time while the game is paused...
|
||||||
-- that's cheating! unpause.
|
-- that's cheating! unpause.
|
||||||
|
|
20
nn.lua
20
nn.lua
|
@ -124,6 +124,7 @@ local Layer = Base:extend()
|
||||||
local Model = Base:extend()
|
local Model = Base:extend()
|
||||||
local Input = Layer:extend()
|
local Input = Layer:extend()
|
||||||
local Relu = Layer:extend()
|
local Relu = Layer:extend()
|
||||||
|
local Gelu = Layer:extend()
|
||||||
local Dense = Layer:extend()
|
local Dense = Layer:extend()
|
||||||
local Softmax = Layer:extend()
|
local Softmax = Layer:extend()
|
||||||
|
|
||||||
|
@ -241,6 +242,24 @@ function Relu:forward(X)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Gelu:init()
|
||||||
|
Layer.init(self, "Gelu")
|
||||||
|
end
|
||||||
|
|
||||||
|
function Gelu:forward(X)
|
||||||
|
assert(#X == self.size_in)
|
||||||
|
self.cache = self.cache or zeros(self.size_out)
|
||||||
|
local Y = self.cache
|
||||||
|
|
||||||
|
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
|
||||||
|
for i = 1, #X do
|
||||||
|
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
|
||||||
|
end
|
||||||
|
|
||||||
|
assert(#Y == self.size_out)
|
||||||
|
return Y
|
||||||
|
end
|
||||||
|
|
||||||
function Dense:init(dim)
|
function Dense:init(dim)
|
||||||
Layer.init(self, "Dense")
|
Layer.init(self, "Dense")
|
||||||
assert(type(dim) == "number")
|
assert(type(dim) == "number")
|
||||||
|
@ -413,6 +432,7 @@ return {
|
||||||
Model = Model,
|
Model = Model,
|
||||||
Input = Input,
|
Input = Input,
|
||||||
Relu = Relu,
|
Relu = Relu,
|
||||||
|
Gelu = Gelu,
|
||||||
Dense = Dense,
|
Dense = Dense,
|
||||||
Softmax = Softmax,
|
Softmax = Softmax,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue