temp 3
This commit is contained in:
parent
a1429a6271
commit
08f476e6ac
14 changed files with 1228 additions and 168 deletions
13
ars.lua
13
ars.lua
|
@ -10,8 +10,10 @@ local exp = math.exp
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
|
local log = math.log
|
||||||
local max = math.max
|
local max = math.max
|
||||||
local print = print
|
local print = print
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
|
||||||
local Base = require "Base"
|
local Base = require "Base"
|
||||||
|
|
||||||
|
@ -72,16 +74,15 @@ local function kinda_lipschitz(dir, pos, neg, mid)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
||||||
momentum)
|
momentum, beta)
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
self.param_rate = base_rate
|
self.param_rate = base_rate
|
||||||
self.sigma_rate = base_rate
|
|
||||||
self.covar_rate = base_rate
|
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
self.antithetic = antithetic == nil and true or antithetic
|
self.antithetic = antithetic == nil and true or antithetic
|
||||||
self.momentum = momentum or 0
|
self.momentum = momentum or 0
|
||||||
|
self.beta = beta or 1.0
|
||||||
|
|
||||||
self.poptop = poptop or popsize
|
self.poptop = poptop or popsize
|
||||||
assert(self.poptop <= popsize)
|
assert(self.poptop <= popsize)
|
||||||
|
@ -189,8 +190,9 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
reward = reward / reward_dev
|
reward = reward / reward_dev
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local scale = reward / self.poptop * self.beta / 2
|
||||||
for j, v in ipairs(noisy) do
|
for j, v in ipairs(noisy) do
|
||||||
step[j] = step[j] + reward * v / self.poptop
|
step[j] = step[j] + scale * v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -200,8 +202,9 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
if reward ~= 0 then
|
if reward ~= 0 then
|
||||||
local noisy = self.noise[ind]
|
local noisy = self.noise[ind]
|
||||||
|
|
||||||
|
local scale = reward / self.poptop * self.beta
|
||||||
for j, v in ipairs(noisy) do
|
for j, v in ipairs(noisy) do
|
||||||
step[j] = step[j] + reward * v / self.poptop
|
step[j] = step[j] + scale * v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
14
config.lua
14
config.lua
|
@ -26,15 +26,17 @@ local defaults = {
|
||||||
time_inputs = true, -- insert binary inputs of a frame counter.
|
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||||
|
|
||||||
-- network layers:
|
-- network layers:
|
||||||
|
embed = true, -- set to false to use a hard-coded tile embedding.
|
||||||
|
reduce_tiles = 0, -- TODO: write description
|
||||||
hidden = false, -- use a hidden layer with ReLU/GELU activation.
|
hidden = false, -- use a hidden layer with ReLU/GELU activation.
|
||||||
hidden_size = 128,
|
hidden_size = 128,
|
||||||
layernorm = false, -- use a LayerNorm layer after said activation.
|
layernorm = false, -- use a LayerNorm layer after said activation.
|
||||||
reduce_tiles = false,
|
|
||||||
bias_out = true,
|
bias_out = true,
|
||||||
|
|
||||||
-- network evaluation (sampling joypad):
|
-- network evaluation (sampling joypad):
|
||||||
frameskip = 4,
|
frameskip = 4,
|
||||||
prob_frameskip = 0.0,
|
prob_frameskip = 0.0,
|
||||||
|
max_frameskip = 6,
|
||||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||||
deterministic = false, -- use argmax on outputs instead of random sampling.
|
deterministic = false, -- use argmax on outputs instead of random sampling.
|
||||||
det_epsilon = false, -- take random actions with probability eps.
|
det_epsilon = false, -- take random actions with probability eps.
|
||||||
|
@ -42,12 +44,16 @@ local defaults = {
|
||||||
-- evolution strategy and non-rate hyperparameters:
|
-- evolution strategy and non-rate hyperparameters:
|
||||||
es = 'ars',
|
es = 'ars',
|
||||||
ars_lips = false, -- for ARS.
|
ars_lips = false, -- for ARS.
|
||||||
epoch_top_trials = 9999, -- for ARS.
|
epoch_top_trials = 9999, -- for ARS, Guided.
|
||||||
|
alpha = 0.5, -- for Guided.
|
||||||
|
beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat.
|
||||||
|
past_grads = 1, -- for Guided. keeps a history of n past steps taken.
|
||||||
|
|
||||||
-- sampling:
|
-- sampling:
|
||||||
deviation = 1.0,
|
deviation = 1.0,
|
||||||
unperturbed_trial = true, -- perform an extra trial without any noise.
|
unperturbed_trial = true, -- perform an extra trial without any noise.
|
||||||
-- this is good for logging, so i'd recommend it.
|
-- this is good for logging, so i'd recommend it.
|
||||||
|
attempts = 1, -- TODO: document.
|
||||||
epoch_trials = 50,
|
epoch_trials = 50,
|
||||||
graycode = false, -- for ARS.
|
graycode = false, -- for ARS.
|
||||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||||
|
@ -114,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials,
|
||||||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||||
assert(not (cfg.es == 'snes' and cfg.negate_trials),
|
assert(not (cfg.es == 'snes' and cfg.negate_trials),
|
||||||
"cfg.negate_trials is not yet compatible with SNES")
|
"cfg.negate_trials is not yet compatible with SNES")
|
||||||
|
assert(not (cfg.es == 'guided' and cfg.graycode),
|
||||||
|
"cfg.graycode is not compatible with Guided")
|
||||||
|
assert(cfg.es ~= 'guided' or cfg.negate_trials,
|
||||||
|
"cfg.negate_trials must be true to use Guided")
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
62
es_test.lua
62
es_test.lua
|
@ -1,21 +1,27 @@
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local log = math.log
|
local log = math.log
|
||||||
|
local max = math.max
|
||||||
local print = print
|
local print = print
|
||||||
|
|
||||||
local ars = require("ars")
|
local ars = require("ars")
|
||||||
local snes = require("snes")
|
local snes = require("snes")
|
||||||
local xnes = require("xnes")
|
local xnes = require("xnes")
|
||||||
|
local guided = require("guided")
|
||||||
|
|
||||||
-- try it all out on a dummy problem.
|
-- try it all out on a dummy problem.
|
||||||
|
|
||||||
|
local function typeof(t)
    -- returns the table an object's metatable delegates lookups to
    -- (its __index field); used below to compare against class tables.
    local meta = getmetatable(t)
    return meta.__index
end
|
||||||
|
|
||||||
local function square(x) return x * x end
|
local function square(x) return x * x end
|
||||||
|
|
||||||
-- this function's global minimum is arange(dims) + 1.
|
-- this function's global minimum is (arange(dims) + 1) / dims.
|
||||||
-- xNES should be able to find it almost exactly.
|
-- xNES should be able to find it almost exactly.
|
||||||
local function spherical(x)
|
local function spherical(x)
|
||||||
local sum = 0
|
local sum = 0
|
||||||
for i, v in ipairs(x) do sum = sum + square(v - i) end
|
--for i, v in ipairs(x) do sum = sum + square(v - i) end
|
||||||
|
for i, v in ipairs(x) do sum = sum + square(v - i / #x) end
|
||||||
-- we need to negate this to turn it into a maximization problem.
|
-- we need to negate this to turn it into a maximization problem.
|
||||||
return -sum
|
return -sum
|
||||||
end
|
end
|
||||||
|
@ -26,10 +32,34 @@ local dims = 100
|
||||||
local popsize = dims + 1
|
local popsize = dims + 1
|
||||||
local sigma_init = 0.5
|
local sigma_init = 0.5
|
||||||
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
|
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
|
||||||
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
||||||
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
|
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
|
||||||
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
||||||
es.min_refresh = 0.7 -- FIXME: needs a better interface.
|
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
|
||||||
|
es.min_refresh = 0.5 -- FIXME: needs a better interface.
|
||||||
|
|
||||||
|
if typeof(es) == xnes.Xnes
|
||||||
|
or typeof(es) == snes.Snes
|
||||||
|
then
|
||||||
|
-- use IGO recommendations
|
||||||
|
local pop5 = max(1, floor(es.popsize / 5))
|
||||||
|
|
||||||
|
local sum = 0
|
||||||
|
for i=1, es.popsize do
|
||||||
|
local maybe = i < pop5 and 1 or 0
|
||||||
|
es.utility[i] = maybe
|
||||||
|
sum = sum + maybe
|
||||||
|
end
|
||||||
|
--for i, v in ipairs(es.utility) do es.utility[i] = v / sum end
|
||||||
|
|
||||||
|
local util = require "util"
|
||||||
|
util.normalize_sums(es.utility)
|
||||||
|
|
||||||
|
es.param_rate = 0.39
|
||||||
|
es.sigma_rate = 0.39
|
||||||
|
es.covar_rate = 0.39
|
||||||
|
es.adaptive = false
|
||||||
|
end
|
||||||
|
|
||||||
if false then -- TODO: delete me
|
if false then -- TODO: delete me
|
||||||
local nn = require("nn")
|
local nn = require("nn")
|
||||||
|
@ -64,30 +94,44 @@ local asked = nil -- for caching purposes.
|
||||||
local noise = nil -- for caching purposes.
|
local noise = nil -- for caching purposes.
|
||||||
local current_cost = spherical(es:params())
|
local current_cost = spherical(es:params())
|
||||||
|
|
||||||
|
local past_grads = {}
|
||||||
|
local pgi = 0
|
||||||
|
local pgn = 10
|
||||||
|
|
||||||
for i=1, iterations do
|
for i=1, iterations do
|
||||||
if getmetatable(es).__index == snes.Snes then
|
if typeof(es) == snes.Snes and es.min_refresh ~= 1 then
|
||||||
asked, noise = es:ask_mix()
|
asked, noise = es:ask_mix()
|
||||||
elseif getmetatable(es).__index == ars.Ars then
|
elseif typeof(es) == ars.Ars then
|
||||||
asked, noise = es:ask()
|
asked, noise = es:ask()
|
||||||
|
elseif typeof(es) == guided.Guided then
|
||||||
|
asked, noise = es:ask(past_grads)
|
||||||
else
|
else
|
||||||
asked, noise = es:ask(asked, noise)
|
asked, noise = es:ask(asked, noise)
|
||||||
end
|
end
|
||||||
|
|
||||||
local scores = {}
|
local scores = {}
|
||||||
for i, v in ipairs(asked) do
|
for i, v in ipairs(asked) do
|
||||||
scores[i] = spherical(v)
|
scores[i] = spherical(v)
|
||||||
end
|
end
|
||||||
|
|
||||||
if getmetatable(es).__index == ars.Ars then
|
if typeof(es) == ars.Ars then
|
||||||
es:tell(scores)--, current_cost) -- use lips
|
es:tell(scores)--, current_cost) -- use lips
|
||||||
|
elseif typeof(es) == guided.Guided then
|
||||||
|
local step = es:tell(scores)
|
||||||
|
|
||||||
|
for _, v in ipairs(step) do
|
||||||
|
past_grads[pgi + 1] = v
|
||||||
|
pgi = (pgi + 1) % (pgn * #step)
|
||||||
|
end
|
||||||
|
past_grads.shape = {floor(#past_grads / #step), #step}
|
||||||
else
|
else
|
||||||
es:tell(scores)
|
es:tell(scores)
|
||||||
end
|
end
|
||||||
|
|
||||||
current_cost = spherical(es:params())
|
current_cost = spherical(es:params())
|
||||||
--if i % 100 == 0 then
|
|
||||||
if i % 100 == 0 then
|
if i % 100 == 0 then
|
||||||
local sigma = es.sigma
|
local sigma = es.sigma
|
||||||
if getmetatable(es).__index == snes.Snes then
|
if typeof(es) == snes.Snes then
|
||||||
sigma = 0
|
sigma = 0
|
||||||
for i, v in ipairs(es.std) do sigma = sigma + v end
|
for i, v in ipairs(es.std) do sigma = sigma + v end
|
||||||
sigma = sigma / #es.std
|
sigma = sigma / #es.std
|
||||||
|
|
194
guided.lua
Normal file
194
guided.lua
Normal file
|
@ -0,0 +1,194 @@
|
||||||
|
-- Guided Evolutionary Strategies
|
||||||
|
-- https://arxiv.org/abs/1806.10230
|
||||||
|
|
||||||
|
-- this is just ARS extended to utilize gradients
|
||||||
|
-- approximated from previous iterations.
|
||||||
|
|
||||||
|
-- for simplicity:
|
||||||
|
-- antithetic is always true
|
||||||
|
-- momentum is always 0
|
||||||
|
-- no graycode/lipschitz nonsense
|
||||||
|
|
||||||
|
local floor = math.floor
|
||||||
|
local insert = table.insert
|
||||||
|
local ipairs = ipairs
|
||||||
|
local max = math.max
|
||||||
|
local print = print
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
|
||||||
|
local Base = require "Base"
|
||||||
|
|
||||||
|
local nn = require "nn"
|
||||||
|
local dot_mv = nn.dot_mv
|
||||||
|
local transpose = nn.transpose
|
||||||
|
local normal = nn.normal
|
||||||
|
local prod = nn.prod
|
||||||
|
local uniform = nn.uniform
|
||||||
|
local zeros = nn.zeros
|
||||||
|
|
||||||
|
local qr = require "qr2"
|
||||||
|
|
||||||
|
local util = require "util"
|
||||||
|
local argsort = util.argsort
|
||||||
|
local calc_mean_dev = util.calc_mean_dev
|
||||||
|
|
||||||
|
local Guided = Base:extend()
|
||||||
|
|
||||||
|
local function collect_best_indices(scored, top)
    -- `scored` is read as consecutive pos/neg pairs: {pos1, neg1, pos2, ...}.
    -- each pair is represented by the better of its two rewards, the pairs
    -- are ordered from highest to lowest representative reward, and only
    -- the first `top` entries of that ordering are kept.
    -- (assumes util.argsort returns 1-based indices into its input in
    -- sorted order -- TODO confirm against util.lua.)
    local pair_best = {}
    for pair = 1, #scored / 2 do
        local neg_index = pair * 2
        pair_best[pair] = max(scored[neg_index - 1], scored[neg_index])
    end

    local function descending(a, b) return a > b end
    local ranking = argsort(pair_best, descending)

    -- discard everything ranked below the cutoff.
    for rank = top + 1, #pair_best do ranking[rank] = nil end
    return ranking
end
|
||||||
|
|
||||||
|
function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta)
    -- dims: number of parameters being optimized.
    -- popsize: number of noise directions per iteration. since sampling is
    --          always antithetic, self.popsize ends up being twice this.
    --          defaults to 4 + 3 * floor(log(dims)).
    -- poptop: number of best pos/neg pairs used in each update.
    --         defaults to popsize (before doubling).
    -- base_rate: stored as self.param_rate.
    -- sigma: scale of random perturbations.
    -- alpha: blend between full parameter space and its gradient subspace.
    --        1.0 is roughly equivalent to ARS.
    -- beta: extra multiplier applied to the step in tell() and decay().

    -- this file has no `local log`, so reference math.log explicitly;
    -- a bare `log` here would be a nil global whenever a default is needed.
    local log = math.log

    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = base_rate
    self.sigma = sigma or 1.0
    self.alpha = alpha or 0.5
    self.beta = beta or 1.0

    -- compare against the resolved population size rather than the raw
    -- argument, so that omitting popsize doesn't compare against nil.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize)
    self.popsize = self.popsize * 2 -- antithetic

    self._params = zeros(self.dims)
    --self.accum = zeros(self.dims) -- momentum

    self.evals = 0
end
|
||||||
|
|
||||||
|
function Guided:params(new_params)
    -- combined getter/setter for the current parameter vector.
    -- when new_params is given, its values are copied element-wise into
    -- the existing table (the table itself is never replaced).
    -- always returns the internal parameter table.
    local current = self._params
    if new_params == nil then return current end

    assert(#current == #new_params, "new parameters have the wrong size")
    for idx, value in ipairs(new_params) do current[idx] = value end
    return current
end
|
||||||
|
|
||||||
|
function Guided:decay(param_decay, sigma_decay)
    -- shrinks every parameter toward zero by a common factor.
    -- sigma_decay is accepted for interface parity but is not used here.
    -- FIXME: multiplying by sigma probably isn't correct anymore.
    -- is this correct now?
    if not (param_decay > 0) then return end

    local noise_scale = self.sigma / sqrt(self.dims)
    local rate = self.param_rate / (self.sigma * self.sigma)
    local shrink = 1 - param_decay * (noise_scale * self.beta * rate)

    local params = self._params
    for idx, value in ipairs(params) do
        params[idx] = shrink * value
    end
end
|
||||||
|
|
||||||
|
function Guided:ask(grads)
    -- samples self.popsize perturbed copies of the current parameters.
    -- grads: optional flat array carrying a .shape field; shape[1] is used
    --        as the number of stored gradient rows. nil or empty means
    --        plain isotropic sampling.
    -- returns: asked (perturbed parameter vectors), noise (the matching
    --          perturbations). the noise is also kept in self.noise
    --          so that tell() can use it.
    local dims = self.dims
    local candidates = {}
    local perturbations = {}

    local n_grad = 0
    local subspace_noise, basis, full_weight, sub_weight
    if grads ~= nil and #grads > 0 then
        n_grad = grads.shape[1]
        subspace_noise = zeros(n_grad)

        -- only the first result of qr is needed.
        basis = qr(transpose(grads))
        --print(nn.pp(transpose(basis), "%9.4f"))

        full_weight = sqrt(self.alpha / dims)
        sub_weight = sqrt((1 - self.alpha) / n_grad)
        --print(full_weight, sub_weight)
    end

    for i = 1, self.popsize do
        local candidate = zeros(dims)
        local perturbation = zeros(dims)
        candidates[i] = candidate
        perturbations[i] = perturbation

        local is_mirror = i % 2 == 0
        if is_mirror then
            -- antithetic: even entries negate the preceding odd entry.
            for j, v in ipairs(perturbations[i - 1]) do
                perturbation[j] = -v
            end
        elseif n_grad == 0 then
            -- no gradient history: isotropic noise.
            local scale = self.sigma / sqrt(dims)
            for j = 1, dims do
                perturbation[j] = scale * normal()
            end
        else
            -- blend full-space noise with noise mapped through `basis`.
            -- the order of the normal() draws is kept as-is on purpose.
            for j = 1, dims do perturbation[j] = normal() end
            for j = 1, n_grad do subspace_noise[j] = normal() end
            local projected = dot_mv(basis, subspace_noise)
            for j, v in ipairs(perturbation) do
                perturbation[j] = self.sigma * (full_weight * v + sub_weight * projected[j])
            end
        end

        for j, v in ipairs(self._params) do
            candidate[j] = v + perturbation[j]
        end
    end

    self.noise = perturbations
    return candidates, perturbations
end
|
||||||
|
|
||||||
|
function Guided:tell(scored, unperturbed_score)
    -- updates the parameters from the rewards of the last ask().
    -- scored: rewards in the same pos/neg pair order as the asked vectors.
    -- unperturbed_score: accepted for interface parity; not used here.
    -- returns: the accumulated step (before the param_rate / sigma^2
    --          coefficient is applied).
    self.evals = self.evals + #scored

    local best_pairs = collect_best_indices(scored, self.poptop)

    -- flatten the selected pairs back into {pos, neg, pos, neg, ...}.
    local top_rewards = {}
    for _, pair in ipairs(best_pairs) do
        local neg_index = pair * 2
        insert(top_rewards, scored[neg_index - 1])
        insert(top_rewards, scored[neg_index])
    end

    local step = zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    if reward_dev == 0 then reward_dev = 1 end

    for rank, pair in ipairs(best_pairs) do
        local diff = top_rewards[rank * 2 - 1] - top_rewards[rank * 2]
        if diff ~= 0 then
            -- the positive direction's noise; the negative one is its mirror.
            local direction = self.noise[pair * 2 - 1]
            -- NOTE: technically this reward divide isn't part of guided search.
            diff = diff / reward_dev

            local weight = diff / self.poptop * self.beta / 2
            for j, v in ipairs(direction) do
                step[j] = step[j] + weight * v
            end
        end
    end

    local coeff = self.param_rate / (self.sigma * self.sigma)
    local params = self._params
    for idx, value in ipairs(params) do
        params[idx] = value + coeff * step[idx]
    end

    -- the noise is single-use; ask() must be called again before tell().
    self.noise = nil

    return step
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
--collect_best_indices = collect_best_indices, -- ars.lua has more features
|
||||||
|
Guided = Guided,
|
||||||
|
}
|
166
main.lua
166
main.lua
|
@ -20,6 +20,12 @@ local trial_rewards = {}
|
||||||
local trials_remaining = 0
|
local trials_remaining = 0
|
||||||
local es -- evolution strategy.
|
local es -- evolution strategy.
|
||||||
|
|
||||||
|
local attempt_i = 0
|
||||||
|
local sub_rewards = {}
|
||||||
|
|
||||||
|
local past_grads = {} -- for Guided
|
||||||
|
local pgi = 0 -- past_grads_index
|
||||||
|
|
||||||
local trial_frames = 0
|
local trial_frames = 0
|
||||||
local total_frames = 0
|
local total_frames = 0
|
||||||
local lagless_count = 0
|
local lagless_count = 0
|
||||||
|
@ -93,6 +99,7 @@ local util = require("util")
|
||||||
local argmax = util.argmax
|
local argmax = util.argmax
|
||||||
local argsort = util.argsort
|
local argsort = util.argsort
|
||||||
local calc_mean_dev = util.calc_mean_dev
|
local calc_mean_dev = util.calc_mean_dev
|
||||||
|
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
|
||||||
local clamp = util.clamp
|
local clamp = util.clamp
|
||||||
local copy = util.copy
|
local copy = util.copy
|
||||||
local empty = util.empty
|
local empty = util.empty
|
||||||
|
@ -149,13 +156,21 @@ local network
|
||||||
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
|
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
|
||||||
local function make_network(input_size)
|
local function make_network(input_size)
|
||||||
nn_x = nn.Input({input_size})
|
nn_x = nn.Input({input_size})
|
||||||
nn_tx = nn.Input({gcfg.tile_count})
|
|
||||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
local embed_dim = cfg.embed and 2 or 3
|
||||||
|
|
||||||
|
if cfg.embed then
|
||||||
|
nn_tx = nn.Input({gcfg.tile_count})
|
||||||
|
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim))
|
||||||
|
else -- new tile inputs.
|
||||||
|
nn_tx = nn.Input({gcfg.tile_count * 3})
|
||||||
|
nn_ty = nn_tx
|
||||||
|
end
|
||||||
|
|
||||||
nn_tz = nn_ty
|
nn_tz = nn_ty
|
||||||
if cfg.reduce_tiles then
|
if cfg.reduce_tiles > 0 then
|
||||||
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
|
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim})
|
||||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
|
nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true))
|
||||||
nn_tz = nn_tz:feed(nn.Relu())
|
nn_tz = nn_tz:feed(nn.Relu())
|
||||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||||
end
|
end
|
||||||
|
@ -185,6 +200,7 @@ end
|
||||||
local ars = require("ars")
|
local ars = require("ars")
|
||||||
local snes = require("snes")
|
local snes = require("snes")
|
||||||
local xnes = require("xnes")
|
local xnes = require("xnes")
|
||||||
|
local guided = require("guided")
|
||||||
|
|
||||||
local function prepare_epoch()
|
local function prepare_epoch()
|
||||||
trial_neg = false
|
trial_neg = false
|
||||||
|
@ -213,6 +229,8 @@ local function prepare_epoch()
|
||||||
local dummy
|
local dummy
|
||||||
if cfg.es == 'ars' then
|
if cfg.es == 'ars' then
|
||||||
trial_params, dummy = es:ask(precision)
|
trial_params, dummy = es:ask(precision)
|
||||||
|
elseif cfg.es == 'guided' then
|
||||||
|
trial_params, dummy = es:ask(past_grads)
|
||||||
elseif cfg.es == 'snes' then
|
elseif cfg.es == 'snes' then
|
||||||
trial_params, dummy = es:ask_mix()
|
trial_params, dummy = es:ask_mix()
|
||||||
else
|
else
|
||||||
|
@ -223,6 +241,7 @@ local function prepare_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local function load_next_trial()
|
local function load_next_trial()
|
||||||
|
attempt_i = 1
|
||||||
if cfg.negate_trials then
|
if cfg.negate_trials then
|
||||||
trial_neg = not trial_neg
|
trial_neg = not trial_neg
|
||||||
else
|
else
|
||||||
|
@ -272,8 +291,16 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
print("step mean:", step_mean)
|
print(("step mean: %9.6f"):format(step_mean))
|
||||||
print("step stddev:", step_dev)
|
print(("step stddev: %9.6f"):format(step_dev))
|
||||||
|
|
||||||
|
if cfg.es == 'guided' and cfg.past_grads > 0 then
|
||||||
|
for _, v in ipairs(step) do
|
||||||
|
past_grads[pgi + 1] = v
|
||||||
|
pgi = (pgi + 1) % (cfg.past_grads * #step)
|
||||||
|
end
|
||||||
|
past_grads.shape = {floor(#past_grads / #step), #step}
|
||||||
|
end
|
||||||
|
|
||||||
es:decay(cfg.param_decay, cfg.sigma_decay)
|
es:decay(cfg.param_decay, cfg.sigma_decay)
|
||||||
|
|
||||||
|
@ -329,65 +356,62 @@ local function joypad_mash(button)
|
||||||
joypad.write(1, jp_mash)
|
joypad.write(1, jp_mash)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function loadlevel(world, level)
|
|
||||||
-- TODO: move to smb.lua. rename to load_level.
|
|
||||||
if world == 0 then world = random(1, 8) end
|
|
||||||
if level == 0 then level = random(1, 4) end
|
|
||||||
emu.poweron()
|
|
||||||
while emu.framecount() < 60 do
|
|
||||||
if emu.framecount() == 32 then
|
|
||||||
local area = game.area_lut[world * 10 + level]
|
|
||||||
game.W(0x75F, world - 1)
|
|
||||||
game.W(0x75C, level - 1)
|
|
||||||
game.W(0x760, area)
|
|
||||||
end
|
|
||||||
if emu.framecount() == 42 then
|
|
||||||
game.W(0x7A0, 0) -- world screen timer (reduces startup time)
|
|
||||||
end
|
|
||||||
joypad_mash('start')
|
|
||||||
emu.frameadvance()
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
local function do_reset()
|
local function do_reset()
|
||||||
local state = game.get_state()
|
local state = game.get_state()
|
||||||
-- be a little more descriptive.
|
-- be a little more descriptive.
|
||||||
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
|
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
|
||||||
|
|
||||||
if trial_i >= 0 then
|
--if cfg.attempts > 1 and attempt_i >= cfg.attempts then
|
||||||
if trial_i == 0 then
|
attempt_i = attempt_i + 1
|
||||||
print('test trial reward:', reward, "("..state..")")
|
sub_rewards[#sub_rewards + 1] = reward
|
||||||
elseif cfg.negate_trials then
|
--print(sub_rewards)
|
||||||
--local dir = trial_neg and "negative" or "positive"
|
|
||||||
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
|
||||||
|
|
||||||
if trial_neg then
|
if #sub_rewards >= cfg.attempts then
|
||||||
local pos = trial_rewards[#trial_rewards]
|
if cfg.attempts == 1 then
|
||||||
local neg = reward
|
reward = sub_rewards[1]
|
||||||
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
else
|
||||||
print(fmt:format(floor(trial_i / 2),
|
local sub_mean, sub_std = calc_mean_dev(sub_rewards)
|
||||||
pos, neg, last_trial_state, state))
|
reward = floor(sub_mean)
|
||||||
|
--local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards)
|
||||||
|
--reward = floor(sub_mean - sub_std)
|
||||||
|
end
|
||||||
|
empty(sub_rewards)
|
||||||
|
|
||||||
|
if trial_i >= 0 then
|
||||||
|
if trial_i == 0 then
|
||||||
|
print('test trial reward:', reward, "("..state..")")
|
||||||
|
elseif cfg.negate_trials then
|
||||||
|
--local dir = trial_neg and "negative" or "positive"
|
||||||
|
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||||
|
|
||||||
|
if trial_neg then
|
||||||
|
local pos = trial_rewards[#trial_rewards]
|
||||||
|
local neg = reward
|
||||||
|
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
||||||
|
print(fmt:format(floor(trial_i / 2),
|
||||||
|
pos, neg, last_trial_state, state))
|
||||||
|
end
|
||||||
|
last_trial_state = state
|
||||||
|
else
|
||||||
|
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
||||||
|
end
|
||||||
|
|
||||||
|
if trial_i == 0 or not cfg.negate_trials then
|
||||||
|
trial_rewards[trial_i] = reward
|
||||||
|
else
|
||||||
|
trial_rewards[#trial_rewards + 1] = reward
|
||||||
end
|
end
|
||||||
last_trial_state = state
|
|
||||||
else
|
|
||||||
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
|
||||||
end
|
end
|
||||||
|
|
||||||
if trial_i == 0 or not cfg.negate_trials then
|
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
|
||||||
trial_rewards[trial_i] = reward
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
else
|
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
||||||
trial_rewards[#trial_rewards + 1] = reward
|
prepare_epoch()
|
||||||
end
|
collectgarbage()
|
||||||
end
|
if any_random then
|
||||||
|
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||||
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
|
state_saved = false
|
||||||
if epoch_i > 0 then learn_from_epoch() end
|
end
|
||||||
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
|
||||||
prepare_epoch()
|
|
||||||
collectgarbage()
|
|
||||||
if any_random then
|
|
||||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
|
||||||
state_saved = false
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -427,7 +451,9 @@ local function do_reset()
|
||||||
trial_frames = 0
|
trial_frames = 0
|
||||||
emu.frameadvance() -- prevents emulator from quirking up.
|
emu.frameadvance() -- prevents emulator from quirking up.
|
||||||
|
|
||||||
load_next_trial()
|
if attempt_i > cfg.attempts then
|
||||||
|
load_next_trial()
|
||||||
|
end
|
||||||
|
|
||||||
reset = false
|
reset = false
|
||||||
end
|
end
|
||||||
|
@ -450,7 +476,7 @@ local function init()
|
||||||
if not playing then emu.speedmode("turbo") end
|
if not playing then emu.speedmode("turbo") end
|
||||||
|
|
||||||
if not any_random then
|
if not any_random then
|
||||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||||
end
|
end
|
||||||
|
|
||||||
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
|
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
|
||||||
|
@ -481,7 +507,10 @@ local function init()
|
||||||
elseif cfg.es == 'ars' then
|
elseif cfg.es == 'ars' then
|
||||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
cfg.base_rate, cfg.deviation, cfg.negate_trials,
|
cfg.base_rate, cfg.deviation, cfg.negate_trials,
|
||||||
cfg.momentum)
|
cfg.momentum, cfg.beta)
|
||||||
|
elseif cfg.es == 'guided' then
|
||||||
|
es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
|
cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta)
|
||||||
else
|
else
|
||||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||||
end
|
end
|
||||||
|
@ -537,6 +566,7 @@ local function doit(dummy)
|
||||||
empty(game.sprite_input)
|
empty(game.sprite_input)
|
||||||
empty(game.tile_input)
|
empty(game.tile_input)
|
||||||
empty(game.extra_input)
|
empty(game.extra_input)
|
||||||
|
empty(game.new_input)
|
||||||
|
|
||||||
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
|
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
|
||||||
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||||
|
@ -616,12 +646,18 @@ local function doit(dummy)
|
||||||
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
|
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
|
||||||
nn.reshape(X, 1, gcfg.input_size)
|
nn.reshape(X, 1, gcfg.input_size)
|
||||||
nn.reshape(game.tile_input, 1, gcfg.tile_count)
|
nn.reshape(game.tile_input, 1, gcfg.tile_count)
|
||||||
|
nn.reshape(game.new_input, 1, gcfg.tile_count * 3)
|
||||||
|
|
||||||
trial_frames = trial_frames + cfg.frameskip
|
trial_frames = trial_frames + cfg.frameskip
|
||||||
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
|
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
|
||||||
total_frames = total_frames + cfg.frameskip
|
total_frames = total_frames + cfg.frameskip
|
||||||
|
|
||||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
local outputs
|
||||||
|
if cfg.embed then
|
||||||
|
outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||||
|
else
|
||||||
|
outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input})
|
||||||
|
end
|
||||||
|
|
||||||
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
||||||
if cfg.det_epsilon and random() < eps then
|
if cfg.det_epsilon and random() < eps then
|
||||||
|
@ -695,8 +731,12 @@ while true do
|
||||||
end
|
end
|
||||||
|
|
||||||
local delta = lagless_count - last_decision_frame
|
local delta = lagless_count - last_decision_frame
|
||||||
local doot = jp == nil or delta >= cfg.frameskip
|
local doot = true
|
||||||
doot = doot and random() >= cfg.prob_frameskip
|
if jp ~= nil then
|
||||||
|
doot = delta >= cfg.frameskip
|
||||||
|
doot = doot and random() >= cfg.prob_frameskip
|
||||||
|
doot = doot or delta >= cfg.max_frameskip
|
||||||
|
end
|
||||||
doit(not doot)
|
doit(not doot)
|
||||||
if doot then last_decision_frame = lagless_count end
|
if doot then last_decision_frame = lagless_count end
|
||||||
|
|
||||||
|
|
46
nn.lua
46
nn.lua
|
@ -1,20 +1,15 @@
|
||||||
local assert = assert
|
local assert = assert
|
||||||
local ceil = math.ceil
|
|
||||||
local cos = math.cos
|
local cos = math.cos
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
local floor = math.floor
|
|
||||||
local huge = math.huge
|
local huge = math.huge
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local log = math.log
|
local log = math.log
|
||||||
local max = math.max
|
local max = math.max
|
||||||
local min = math.min
|
|
||||||
local open = io.open
|
local open = io.open
|
||||||
local pairs = pairs
|
|
||||||
local pi = math.pi
|
local pi = math.pi
|
||||||
local print = print
|
local print = print
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
local sin = math.sin
|
|
||||||
local sqrt = math.sqrt
|
local sqrt = math.sqrt
|
||||||
local tanh = math.tanh
|
local tanh = math.tanh
|
||||||
local tostring = tostring
|
local tostring = tostring
|
||||||
|
@ -105,19 +100,26 @@ end
|
||||||
|
|
||||||
-- ndarray-ish stuff and more involved math
|
-- ndarray-ish stuff and more involved math
|
||||||
|
|
||||||
|
local function pp_join(sep, fmt, t, a, b)
|
||||||
|
a = a or 1
|
||||||
|
b = b or #t
|
||||||
|
local s = ''
|
||||||
|
for i = a, b do
|
||||||
|
s = s..fmt:format(t[i])
|
||||||
|
if i ~= b then s = s..sep end
|
||||||
|
end
|
||||||
|
return s
|
||||||
|
end
|
||||||
|
|
||||||
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
||||||
-- pretty-prints an nd-array.
|
-- pretty-prints an nd-array.
|
||||||
fmt = fmt or '%10.7f,'
|
fmt = fmt or '%10.7f'
|
||||||
sep = sep or ','
|
sep = sep or ','
|
||||||
ti = ti or 0
|
ti = ti or 0
|
||||||
di = di or 1
|
di = di or 1
|
||||||
depth = depth or 0
|
depth = depth or 0
|
||||||
|
|
||||||
if t.shape == nil then
|
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
|
||||||
local s = '['
|
|
||||||
for i = 1, #t do s = s..fmt:format(t[i]) end
|
|
||||||
return s..']'..sep..'\n'
|
|
||||||
end
|
|
||||||
|
|
||||||
local dim = t.shape[di]
|
local dim = t.shape[di]
|
||||||
|
|
||||||
|
@ -134,11 +136,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
||||||
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
|
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
|
||||||
ti = ti + ti_step
|
ti = ti + ti_step
|
||||||
end
|
end
|
||||||
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
|
s = s..indent..']'..sep
|
||||||
|
if islast then s = s..'\n' end
|
||||||
else
|
else
|
||||||
s = s..indent..'['
|
s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n'
|
||||||
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
|
|
||||||
s = s..']'..sep..'\n'
|
|
||||||
end
|
end
|
||||||
return s
|
return s
|
||||||
end
|
end
|
||||||
|
@ -265,6 +266,20 @@ local function dot(a, b, ax_a, ax_b, out)
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function transpose(x, out)
|
||||||
|
assert(#x.shape == 2) -- TODO: handle ndarrays like numpy.
|
||||||
|
local rows = x.shape[1]
|
||||||
|
local cols = x.shape[2]
|
||||||
|
local y = out or zeros{cols, rows}
|
||||||
|
-- TODO: simplify? y can be consecutive for sure.
|
||||||
|
for i = 1, rows do
|
||||||
|
for j = 1, cols do
|
||||||
|
y[(j - 1) * rows + i] = x[(i - 1) * cols + j]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return y
|
||||||
|
end
|
||||||
|
|
||||||
-- nodal
|
-- nodal
|
||||||
|
|
||||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||||
|
@ -875,6 +890,7 @@ return {
|
||||||
ppi = ppi,
|
ppi = ppi,
|
||||||
dot_mv = dot_mv,
|
dot_mv = dot_mv,
|
||||||
dot = dot,
|
dot = dot,
|
||||||
|
transpose = transpose,
|
||||||
traverse = traverse,
|
traverse = traverse,
|
||||||
traverse_all = traverse_all,
|
traverse_all = traverse_all,
|
||||||
|
|
||||||
|
|
230
presets.lua
230
presets.lua
|
@ -33,7 +33,7 @@ make_preset{
|
||||||
|
|
||||||
init_zeros = true,
|
init_zeros = true,
|
||||||
|
|
||||||
reduce_tiles = true,
|
reduce_tiles = 5,
|
||||||
bias_out = false,
|
bias_out = false,
|
||||||
|
|
||||||
deterministic = false,
|
deterministic = false,
|
||||||
|
@ -147,6 +147,15 @@ make_preset{
|
||||||
parent = 'ars',
|
parent = 'ars',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'ars-lips',
|
||||||
|
parent = 'ars',
|
||||||
|
|
||||||
|
ars_lips = true,
|
||||||
|
-- momentum = 0.5, -- this is default.
|
||||||
|
param_rate = 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
make_preset{
|
make_preset{
|
||||||
name = 'ars-skip',
|
name = 'ars-skip',
|
||||||
parent = 'ars',
|
parent = 'ars',
|
||||||
|
@ -155,15 +164,6 @@ make_preset{
|
||||||
prob_frameskip = 0.25,
|
prob_frameskip = 0.25,
|
||||||
}
|
}
|
||||||
|
|
||||||
make_preset{
|
|
||||||
name = 'ars-lips',
|
|
||||||
parent = 'ars',
|
|
||||||
|
|
||||||
ars_lips = true,
|
|
||||||
momentum = 0.5,
|
|
||||||
param_rate = 1.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
make_preset{
|
make_preset{
|
||||||
name = 'ars-big',
|
name = 'ars-big',
|
||||||
parent = 'ars',
|
parent = 'ars',
|
||||||
|
@ -204,6 +204,216 @@ make_preset{
|
||||||
momentum = 0.99,
|
momentum = 0.99,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- new stuff for 2019:
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'ars-skip-more',
|
||||||
|
parent = 'ars',
|
||||||
|
|
||||||
|
-- old:
|
||||||
|
--frameskip = 2,
|
||||||
|
--prob_frameskip = 0.5,
|
||||||
|
--max_frameskip = 60,
|
||||||
|
-- new:
|
||||||
|
frameskip = 3,
|
||||||
|
prob_frameskip = 0.5,
|
||||||
|
max_frameskip = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'ars-skip-more-3',
|
||||||
|
parent = 'ars-skip-more',
|
||||||
|
|
||||||
|
attempts = 3, -- per trial. score = mean(scores) - stdev(scores)
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'snes-skip-more-3',
|
||||||
|
parent = 'snes3',
|
||||||
|
|
||||||
|
frameskip = 3,
|
||||||
|
prob_frameskip = 0.5,
|
||||||
|
max_frameskip = 5,
|
||||||
|
attempts = 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided',
|
||||||
|
parent = 'big-scroll-reduced',
|
||||||
|
|
||||||
|
es = 'guided',
|
||||||
|
epoch_top_trials = 20,
|
||||||
|
deterministic = true,
|
||||||
|
deviation = 0.1,
|
||||||
|
epoch_trials = 20,
|
||||||
|
param_rate = 0.00368,
|
||||||
|
param_decay = 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided2',
|
||||||
|
parent = 'guided',
|
||||||
|
|
||||||
|
past_grads = 2,
|
||||||
|
-- after epoch 50, trying this:
|
||||||
|
--param_rate = 0.05,
|
||||||
|
-- after epoch 50+20, stepping back to this:
|
||||||
|
param_rate = 0.01,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided10',
|
||||||
|
parent = 'guided',
|
||||||
|
|
||||||
|
past_grads = 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided69', -- the nice one
|
||||||
|
parent = 'guided',
|
||||||
|
|
||||||
|
deviation = 0.05,
|
||||||
|
epoch_top_trials = 10,
|
||||||
|
epoch_trials = 20,
|
||||||
|
param_rate = 0.006,
|
||||||
|
|
||||||
|
past_grads = 4,
|
||||||
|
alpha = 0.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
-- TODO: yet another preset. try building up from 1 trial ARS to something good.
|
||||||
|
make_preset{
|
||||||
|
name = 'redux',
|
||||||
|
|
||||||
|
min_time = 300,
|
||||||
|
max_time = 300,
|
||||||
|
timer_loser = 1/1,
|
||||||
|
|
||||||
|
score_multiplier = 1,
|
||||||
|
|
||||||
|
init_zeros = true,
|
||||||
|
|
||||||
|
deterministic = true,
|
||||||
|
|
||||||
|
es = 'guided',
|
||||||
|
past_grads = 2, -- for Guided.
|
||||||
|
alpha = 0.25, -- for Guided.
|
||||||
|
|
||||||
|
ars_lips = false, -- for ARS.
|
||||||
|
beta = 1.0, -- fix the default.
|
||||||
|
|
||||||
|
epoch_top_trials = 4, -- for ARS, Guided.
|
||||||
|
epoch_trials = 5,
|
||||||
|
attempts = 1, -- TODO: document.
|
||||||
|
|
||||||
|
deviation = 1.0, -- 0.1
|
||||||
|
base_rate = 1.0,
|
||||||
|
param_decay = 0.01,
|
||||||
|
|
||||||
|
graycode = false, -- for ARS.
|
||||||
|
min_refresh = 0.1, -- for SNES.
|
||||||
|
sigma_decay = 0.0, -- for SNES, xNES.
|
||||||
|
momentum = 0.0, -- for ARS.
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'redux_big',
|
||||||
|
parent = 'redux',
|
||||||
|
|
||||||
|
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||||
|
hidden = true, -- use a hidden layer with ReLU/GELU activation.
|
||||||
|
hidden_size = 64,
|
||||||
|
layernorm = true, -- use a LayerNorm layer after said activation.
|
||||||
|
reduce_tiles = false,
|
||||||
|
bias_out = false,
|
||||||
|
|
||||||
|
-- gets stuck pretty quick, so tweak some stuff...
|
||||||
|
epoch_top_trials = 8,
|
||||||
|
epoch_trials = 10,
|
||||||
|
deviation = 1.0,
|
||||||
|
base_rate = 0.15,
|
||||||
|
param_decay = 0.05,
|
||||||
|
past_grads = 4,
|
||||||
|
alpha = 0.25,
|
||||||
|
-- well it doesn't get stuck anymore, but regular redux works much better.
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided-skip-more-3',
|
||||||
|
parent = 'guided',
|
||||||
|
|
||||||
|
--param_rate = 0.00368, -- should probably be this instead...
|
||||||
|
param_rate = 0.01,
|
||||||
|
|
||||||
|
frameskip = 3,
|
||||||
|
prob_frameskip = 0.5,
|
||||||
|
max_frameskip = 5,
|
||||||
|
attempts = 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'guided-skip-more-3-again',
|
||||||
|
parent = 'guided-skip-more-3',
|
||||||
|
|
||||||
|
param_rate = 0.08, --0.0316,
|
||||||
|
deviation = 0.5,
|
||||||
|
alpha = 0.1, --0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'crazy',
|
||||||
|
parent = 'big-scroll-reduced',
|
||||||
|
|
||||||
|
es = 'guided',
|
||||||
|
epoch_top_trials = 15,
|
||||||
|
deterministic = false,
|
||||||
|
deviation = 1.0,
|
||||||
|
epoch_trials = 15,
|
||||||
|
param_rate = 1.0,
|
||||||
|
param_decay = 0.0,
|
||||||
|
alpha = 0.0316,
|
||||||
|
--attempts = 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'ars-lips2',
|
||||||
|
parent = 'ars',
|
||||||
|
|
||||||
|
ars_lips = true,
|
||||||
|
--epoch_trials = 10,
|
||||||
|
param_rate = 0.147,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'ars-lips3',
|
||||||
|
parent = 'ars',
|
||||||
|
|
||||||
|
ars_lips = true,
|
||||||
|
param_rate = 0.5,
|
||||||
|
deviation = 0.02, -- added after like 272 epochs
|
||||||
|
param_decay = 0.0276, -- added after like 62 epochs
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'hard-embed',
|
||||||
|
parent = 'big-scroll-hidden',
|
||||||
|
|
||||||
|
embed = false,
|
||||||
|
reduce_tiles = 5,
|
||||||
|
hidden_size = 54,
|
||||||
|
|
||||||
|
epoch_top_trials = 20,
|
||||||
|
deterministic = true,
|
||||||
|
deviation = 0.01,
|
||||||
|
epoch_trials = 20,
|
||||||
|
param_rate = 0.368,
|
||||||
|
param_decay = 0.0138,
|
||||||
|
momentum = 0.5,
|
||||||
|
beta = 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
-- end of new stuff
|
||||||
|
|
||||||
make_preset{
|
make_preset{
|
||||||
name = 'play',
|
name = 'play',
|
||||||
|
|
||||||
|
|
111
qr.lua
Normal file
111
qr.lua
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
local min = math.min
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
|
||||||
|
local nn = require "nn"
|
||||||
|
local dot = nn.dot
|
||||||
|
local reshape = nn.reshape
|
||||||
|
local transpose = nn.transpose
|
||||||
|
local zeros = nn.zeros
|
||||||
|
|
||||||
|
local function minor(x, d)
|
||||||
|
assert(#x.shape == 2)
|
||||||
|
assert(d <= x.shape[1] and d <= x.shape[2])
|
||||||
|
|
||||||
|
local m = zeros(x.shape)
|
||||||
|
|
||||||
|
-- fill diagonals.
|
||||||
|
--for i = 1, d do m[(i - 1) * m.shape[2] + i] = 1 end
|
||||||
|
for i = 1, d * m.shape[2], m.shape[2] + 1 do m[i] = 1 end
|
||||||
|
|
||||||
|
-- copy values.
|
||||||
|
for i = d + 1, m.shape[1] do
|
||||||
|
for j = d + 1, m.shape[2] do
|
||||||
|
local ind = (i - 1) * m.shape[2] + j
|
||||||
|
m[ind] = x[ind]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return m
|
||||||
|
end
|
||||||
|
|
||||||
|
local function norm(a) -- vector norm
|
||||||
|
local sum = 0
|
||||||
|
for _, v in ipairs(a) do sum = sum + v * v end
|
||||||
|
return sqrt(sum)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function householder(x)
|
||||||
|
local rows = x.shape[1]
|
||||||
|
local cols = x.shape[2]
|
||||||
|
local iters = min(rows - 1, cols)
|
||||||
|
|
||||||
|
local q = nil
|
||||||
|
local vec = zeros(rows)
|
||||||
|
local z = x
|
||||||
|
|
||||||
|
for k = 1, iters do
|
||||||
|
z = minor(z, k - 1)
|
||||||
|
|
||||||
|
-- extract a column.
|
||||||
|
for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end
|
||||||
|
|
||||||
|
local a = norm(vec)
|
||||||
|
-- negate the norm if the original diagonal is non-negative.
|
||||||
|
local ind = (k - 1) * cols + k
|
||||||
|
if x[ind] > 0 then a = -a end
|
||||||
|
|
||||||
|
vec[k] = vec[k] + a
|
||||||
|
|
||||||
|
local a = norm(vec)
|
||||||
|
if a == 0 then a = 1 end -- FIXME: should probably just raise an error.
|
||||||
|
for i, v in ipairs(vec) do vec[i] = v / a end
|
||||||
|
|
||||||
|
-- construct the householder reflection: mat = I - 2 * vec * vec.T
|
||||||
|
local mat = zeros{rows, rows}
|
||||||
|
for i = 1, rows do
|
||||||
|
for j = 1, rows do
|
||||||
|
local ind = (i - 1) * rows + j
|
||||||
|
local diag = i == j and 1 or 0
|
||||||
|
mat[ind] = diag - 2 * vec[i] * vec[j]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
--print(nn.pp(mat, "%9.3f"))
|
||||||
|
if q == nil then q = mat else q = dot(mat, q) end
|
||||||
|
|
||||||
|
z = dot(mat, z)
|
||||||
|
end
|
||||||
|
|
||||||
|
return transpose(q), dot(q, x) -- Q, R
|
||||||
|
end
|
||||||
|
|
||||||
|
local function qr(x)
|
||||||
|
-- a wrapper for the householder method that will return reduced matrices.
|
||||||
|
assert(#x.shape == 2)
|
||||||
|
|
||||||
|
local q, r = householder(x)
|
||||||
|
|
||||||
|
local rows = x.shape[1]
|
||||||
|
local cols = x.shape[2]
|
||||||
|
if cols >= rows then return q, r end
|
||||||
|
|
||||||
|
-- trim q in-place.
|
||||||
|
q.shape[2] = cols
|
||||||
|
local ind = 1
|
||||||
|
for i = 1, rows do
|
||||||
|
for j = 1, cols do
|
||||||
|
--ind = (i - 1) * cols + j
|
||||||
|
q[ind] = q[(i - 1) * rows + j]
|
||||||
|
ind = ind + 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
for i = rows * cols + 1, #q do q[i] = nil end
|
||||||
|
|
||||||
|
-- trim r in-place.
|
||||||
|
r.shape[1] = r.shape[2]
|
||||||
|
for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end
|
||||||
|
|
||||||
|
return q, r
|
||||||
|
end
|
||||||
|
|
||||||
|
return qr
|
76
qr2.lua
Normal file
76
qr2.lua
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
local min = math.min
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
|
||||||
|
local nn = require "nn"
|
||||||
|
local transpose = nn.transpose
|
||||||
|
local zeros = nn.zeros
|
||||||
|
|
||||||
|
local function qr(a)
|
||||||
|
-- FIXME: if first column is exactly zero,
|
||||||
|
-- and cols > rows, Q @ R will not reconstruct the input.
|
||||||
|
-- this isn't too bad since an input like that is invalid anyway,
|
||||||
|
-- but i feel like it should be salvageable.
|
||||||
|
|
||||||
|
-- actually the scope of the problem is much larger than that.
|
||||||
|
-- an input like
|
||||||
|
--[=[
|
||||||
|
[[0, 0, 0, 0]
|
||||||
|
[1, 0, 1, 1]
|
||||||
|
[0, 0, 2, 2]
|
||||||
|
[0, 0, 3, 3]]
|
||||||
|
--]=]
|
||||||
|
-- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4).
|
||||||
|
-- hmm. maybe we can detect this and reverse the matmul to identity if necessary?
|
||||||
|
|
||||||
|
assert(#a.shape == 2)
|
||||||
|
local rows = a.shape[1]
|
||||||
|
local cols = a.shape[2]
|
||||||
|
local small = min(rows, cols)
|
||||||
|
|
||||||
|
local q = transpose(a)
|
||||||
|
local r = zeros{small, cols}
|
||||||
|
|
||||||
|
for i = 1, cols do
|
||||||
|
local i0 = (i - 1) * rows + 1
|
||||||
|
local i1 = i * rows
|
||||||
|
|
||||||
|
for j = 1, min(i - 1, small) do
|
||||||
|
local j0 = (j - 1) * rows + 1
|
||||||
|
local j1 = j * rows
|
||||||
|
local i_to_j = j0 - i0
|
||||||
|
|
||||||
|
local num = 0
|
||||||
|
local den = 0
|
||||||
|
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
|
||||||
|
for k = j0, j1 do den = den + q[k] * q[k] end
|
||||||
|
print(num, den)
|
||||||
|
if den == 0 then den = 1 end -- TODO: should probably just error.
|
||||||
|
|
||||||
|
local x = num / den
|
||||||
|
r[(j - 1) * cols + i] = x
|
||||||
|
for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
|
||||||
|
end
|
||||||
|
|
||||||
|
if i <= small then
|
||||||
|
local sum = 0
|
||||||
|
for k = i0, i1 do sum = sum + q[k] * q[k] end
|
||||||
|
local norm = sqrt(sum)
|
||||||
|
|
||||||
|
if norm == 0 then
|
||||||
|
--norm = 1
|
||||||
|
--q[i0 + i - 1] = 1 -- FIXME: not robust.
|
||||||
|
r[(i - 1) * cols + i] = 0
|
||||||
|
else
|
||||||
|
for k = i0, i1 do q[k] = q[k] / norm end
|
||||||
|
r[(i - 1) * cols + i] = norm
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
for k = small * rows + 1, #q do q[k] = nil end
|
||||||
|
q.shape[1] = small
|
||||||
|
|
||||||
|
return transpose(q), r
|
||||||
|
end
|
||||||
|
|
||||||
|
return qr
|
72
qr_test.lua
Normal file
72
qr_test.lua
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
local globalize = require "strict"
|
||||||
|
local nn = require "nn"
|
||||||
|
local qr = require "qr"
|
||||||
|
local qr2 = require "qr2"
|
||||||
|
|
||||||
|
local A
|
||||||
|
|
||||||
|
if false then
|
||||||
|
A = {
|
||||||
|
12, -51, 4,
|
||||||
|
6, 167, -68,
|
||||||
|
-4, 24, -41,
|
||||||
|
|
||||||
|
-1, 1, 0,
|
||||||
|
2, 0, 3,
|
||||||
|
}
|
||||||
|
A = nn.reshape(A, 5, 3)
|
||||||
|
|
||||||
|
elseif false then
|
||||||
|
A = {
|
||||||
|
0, 1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8, 9,
|
||||||
|
10, 11, 12, 13, 14
|
||||||
|
}
|
||||||
|
A = nn.reshape(A, 5, 3)
|
||||||
|
|
||||||
|
else
|
||||||
|
A = {
|
||||||
|
1, 2, 0,
|
||||||
|
2, 3, 1,
|
||||||
|
3, 4, 0,
|
||||||
|
4, 5, 1,
|
||||||
|
5, 6, 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
A = {
|
||||||
|
1, 0, 0,
|
||||||
|
2, 0, 1,
|
||||||
|
3, 0, 0,
|
||||||
|
4, 0, 1,
|
||||||
|
5, 0, 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
A = nn.reshape(A, 5, 3)
|
||||||
|
--A = nn.transpose(A)
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
print("A")
|
||||||
|
print(nn.pp(A, "%9.4f"))
|
||||||
|
print()
|
||||||
|
|
||||||
|
local Q, R = qr2(A)
|
||||||
|
|
||||||
|
print("Q")
|
||||||
|
print(nn.pp(Q, "%9.4f"))
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("R")
|
||||||
|
print(nn.pp(R, "%9.4f"))
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Q @ R")
|
||||||
|
print(nn.pp(nn.dot(Q, R), "%9.4f"))
|
||||||
|
print()
|
||||||
|
|
||||||
|
--print("Q @ Q.T = I")
|
||||||
|
--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f"))
|
||||||
|
--print()
|
||||||
|
|
||||||
|
--A = nn.reshape(A, 5, 3)
|
||||||
|
--Q, R = qr(A)
|
217
smb.lua
217
smb.lua
|
@ -1,12 +1,16 @@
|
||||||
-- disassembly used for reference:
|
-- disassembly used for reference:
|
||||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||||
|
|
||||||
local band = bit.band
|
|
||||||
local floor = math.floor
|
|
||||||
local emu = emu
|
|
||||||
local gui = gui
|
|
||||||
|
|
||||||
local util = require("util")
|
local util = require("util")
|
||||||
|
|
||||||
|
local band = bit.band
|
||||||
|
local clamp = util.clamp
|
||||||
|
local empty = util.empty
|
||||||
|
local emu = emu
|
||||||
|
local floor = math.floor
|
||||||
|
local gui = gui
|
||||||
|
local insert = table.insert
|
||||||
|
|
||||||
local R = memory.readbyteunsigned
|
local R = memory.readbyteunsigned
|
||||||
local W = memory.writebyte
|
local W = memory.writebyte
|
||||||
local function S(addr) return util.signbyte(R(addr)) end
|
local function S(addr) return util.signbyte(R(addr)) end
|
||||||
|
@ -76,13 +80,84 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
||||||
-8, -38,
|
-8, -38,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
local tile_embedding = {
|
||||||
|
-- handmade trinary encoding.
|
||||||
|
-- we have 57 valid tile types and 27 permutations to work with.
|
||||||
|
[0x00] = { 0, 0, 0}, -- air
|
||||||
|
[0x10] = { 1, -1, 1}, -- vertical pipe (top left) (enterable)
|
||||||
|
[0x11] = { 1, -1, 1}, -- vertical pipe (top right) (enterable)
|
||||||
|
[0x12] = { 0, -1, 1}, -- vertical pipe (top left)
|
||||||
|
[0x13] = { 0, -1, 1}, -- vertical pipe (top right)
|
||||||
|
[0x14] = { 0, -1, -1}, -- vertical pipe (left)
|
||||||
|
[0x15] = { 0, -1, -1}, -- vertical pipe (right)
|
||||||
|
[0x16] = {-1, -1, -1}, --
|
||||||
|
[0x17] = {-1, -1, -1}, --
|
||||||
|
[0x18] = {-1, -1, -1}, --
|
||||||
|
[0x19] = {-1, -1, -1}, --
|
||||||
|
[0x1A] = {-1, -1, -1}, --
|
||||||
|
[0x1B] = {-1, -1, -1}, --
|
||||||
|
[0x1C] = { 0, -1, 0}, -- horizontal pipe (top left)
|
||||||
|
[0x1D] = { 0, -1, 0}, -- horizontal pipe (top)
|
||||||
|
[0x1E] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (top)
|
||||||
|
[0x1F] = { 0, -1, 0}, -- horizontal pipe (bottom left)
|
||||||
|
[0x20] = { 0, -1, 0}, -- horizontal pipe (bottom)
|
||||||
|
[0x21] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (bottom)
|
||||||
|
[0x22] = { 0, 0, 0}, --
|
||||||
|
[0x23] = { 0, 0, 0}, -- block being hit (either breakable or ?)
|
||||||
|
[0x24] = { 0, 0, 0}, --
|
||||||
|
[0x25] = { 0, 0, 1}, -- flagpole
|
||||||
|
[0x26] = { 0, 0, 0}, --
|
||||||
|
[0x51] = { 1, 1, 0}, -- breakable brick block
|
||||||
|
[0x52] = { 1, 1, 0}, -- breakable brick block (again?)
|
||||||
|
[0x54] = { 0, 1, 0}, -- regular ground
|
||||||
|
[0x55] = { 0, 0, 0}, --
|
||||||
|
[0x56] = { 0, 0, 0}, --
|
||||||
|
[0x57] = {-1, 1, 1}, -- star brick block
|
||||||
|
[0x58] = { 1, 1, -1}, -- coin brick block (many coins)
|
||||||
|
[0x59] = { 0, 0, 0}, --
|
||||||
|
[0x5A] = { 0, 0, 0}, --
|
||||||
|
[0x5B] = { 0, 0, 0}, --
|
||||||
|
[0x5C] = { 0, 0, 0}, --
|
||||||
|
[0x5D] = { 1, 1, -1}, -- coin brick block (many coins) (again?)
|
||||||
|
[0x5E] = { 0, 0, 0}, --
|
||||||
|
[0x5F] = { 0, 0, 0}, --
|
||||||
|
[0x60] = {-1, 0, 0}, -- invisible 1-up block
|
||||||
|
[0x61] = { 0, 1, -1}, -- chocolate block (usually used for stairs)
|
||||||
|
[0x62] = { 0, 0, 0}, --
|
||||||
|
[0x63] = { 0, 0, 0}, --
|
||||||
|
[0x64] = { 0, 0, 0}, --
|
||||||
|
[0x65] = { 0, 0, 0}, --
|
||||||
|
[0x66] = { 0, 0, 0}, --
|
||||||
|
[0x67] = { 0, 0, 0}, --
|
||||||
|
[0x68] = { 0, 0, 0}, --
|
||||||
|
[0x69] = { 0, 0, 0}, --
|
||||||
|
[0x6B] = { 0, 0, 0}, --
|
||||||
|
[0x6C] = { 0, 0, 0}, --
|
||||||
|
[0x89] = { 0, 0, 0}, --
|
||||||
|
[0xC0] = {-1, 1, -1}, -- coin ? block
|
||||||
|
[0xC1] = {-1, 1, 0}, -- mushroom ? block
|
||||||
|
[0xC2] = { 0, 0, -1}, -- coin
|
||||||
|
[0xC3] = { 0, 0, 0}, --
|
||||||
|
[0xC4] = { 0, 1, 1}, -- hit block
|
||||||
|
[0xC5] = { 0, 0, 0}, --
|
||||||
|
}
|
||||||
|
|
||||||
-- TODO: reinterface to one "input" array visible to main.lua.
|
-- TODO: reinterface to one "input" array visible to main.lua.
|
||||||
local sprite_input = {}
|
local sprite_input = {}
|
||||||
local tile_input = {}
|
local tile_input = {}
|
||||||
local extra_input = {}
|
local extra_input = {}
|
||||||
|
local new_input = {}
|
||||||
|
|
||||||
local overlay = false
|
local overlay = false
|
||||||
|
|
||||||
|
local function embed_tile(t)
|
||||||
|
local out = new_input
|
||||||
|
local embedded = tile_embedding[t]
|
||||||
|
insert(out, embedded[1])
|
||||||
|
insert(out, embedded[2])
|
||||||
|
insert(out, embedded[3])
|
||||||
|
end
|
||||||
|
|
||||||
local function get_timer()
|
local function get_timer()
|
||||||
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||||
end
|
end
|
||||||
|
@ -130,6 +205,7 @@ end
|
||||||
|
|
||||||
local function mark_tile(x, y, t)
|
local function mark_tile(x, y, t)
|
||||||
tile_input[#tile_input+1] = tile_lut[t]
|
tile_input[#tile_input+1] = tile_lut[t]
|
||||||
|
embed_tile(t)
|
||||||
if t == 0 then return end
|
if t == 0 then return end
|
||||||
if overlay then
|
if overlay then
|
||||||
gui.box(x-8, y-8, x+8, y+8)
|
gui.box(x-8, y-8, x+8, y+8)
|
||||||
|
@ -289,7 +365,8 @@ local function handle_tiles()
|
||||||
extra_input[#extra_input+1] = tile_scroll_remainder
|
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||||
-- for y = 0, 12 do
|
-- for y = 0, 12 do
|
||||||
-- afaik the bottom row is always a copy of the second to bottom,
|
-- afaik the bottom row is always a copy of the second to bottom,
|
||||||
-- and the top is always air, so drop those from the inputs:
|
-- and the top is always air (except underground!),
|
||||||
|
-- so drop those from the inputs:
|
||||||
for y = 1, 11 do
|
for y = 1, 11 do
|
||||||
for x = 0, 16 do
|
for x = 0, 16 do
|
||||||
local col = (x + tile_scroll) % 32
|
local col = (x + tile_scroll) % 32
|
||||||
|
@ -306,6 +383,117 @@ local function handle_tiles()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function load_level(world, level)
|
||||||
|
if world == 0 then world = random(1, 8) end
|
||||||
|
if level == 0 then level = random(1, 4) end
|
||||||
|
emu.poweron()
|
||||||
|
local jp_mash = {
|
||||||
|
up = false, down = false, left = false, right = false,
|
||||||
|
A = false, B = false, select = false, start = false,
|
||||||
|
}
|
||||||
|
while emu.framecount() < 60 do
|
||||||
|
if emu.framecount() == 32 then
|
||||||
|
local area = area_lut[world * 10 + level]
|
||||||
|
W(0x75F, world - 1)
|
||||||
|
W(0x75C, level - 1)
|
||||||
|
W(0x760, area)
|
||||||
|
end
|
||||||
|
if emu.framecount() == 42 then
|
||||||
|
W(0x7A0, 0) -- world screen timer (reduces startup time)
|
||||||
|
end
|
||||||
|
jp_mash['start'] = emu.framecount() % 2 == 1
|
||||||
|
joypad.write(1, jp_mash)
|
||||||
|
emu.frameadvance()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- new stuff.
|
||||||
|
|
||||||
|
local function tile_from_xy(x, y)
|
||||||
|
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||||
|
local tile_scroll_remainder = R(0x73F) % 16
|
||||||
|
|
||||||
|
local tile_x = floor((x + tile_scroll_remainder) / 16)
|
||||||
|
local tile_y = floor(y / 16)
|
||||||
|
local tile_t = 0 -- default to air
|
||||||
|
if tile_y < 0 then
|
||||||
|
tile_y = 0
|
||||||
|
elseif tile_y > 12 then
|
||||||
|
tile_y = 12
|
||||||
|
else
|
||||||
|
local col = (tile_x + tile_scroll) % 32
|
||||||
|
local addr = col < 16 and 0x500 or 0x5D0
|
||||||
|
tile_t = R(addr + tile_y * 16 + (col % 16))
|
||||||
|
end
|
||||||
|
return tile_t, tile_x, tile_y
|
||||||
|
end
|
||||||
|
|
||||||
|
local function new_stuff()
|
||||||
|
-- obviously very work in progress.
|
||||||
|
|
||||||
|
empty(new_input)
|
||||||
|
|
||||||
|
local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||||
|
|
||||||
|
-- normalized mario x
|
||||||
|
-- normalized mario y
|
||||||
|
insert(new_input, (mario_x - 112) / 64)
|
||||||
|
insert(new_input, (mario_y - 160) / 96)
|
||||||
|
|
||||||
|
-- type of tile we're standing on
|
||||||
|
-- type of tile we're occupying
|
||||||
|
|
||||||
|
gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)
|
||||||
|
|
||||||
|
local mario_tile_t, mario_tile_x, mario_tile_y =
|
||||||
|
tile_from_xy(mario_x + 8, mario_y - 8)
|
||||||
|
|
||||||
|
--[[
|
||||||
|
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||||
|
local tile_scroll_remainder = R(0x73F) % 16
|
||||||
|
local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
|
||||||
|
local sy = mario_tile_y * 16 + 40
|
||||||
|
gui.box(sx-8, sy-8, sx+8, sy+8)
|
||||||
|
gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
|
||||||
|
--]]
|
||||||
|
|
||||||
|
embed_tile(new_input, mario_tile_t)
|
||||||
|
|
||||||
|
-- type of tile to the right, excluding space (small eyeheight)
|
||||||
|
-- how tall non-space extends upward from that tile
|
||||||
|
-- how tall non-space extends downward from that tile
|
||||||
|
-- type of tile to the right, excluding space (large eyeheight)
|
||||||
|
|
||||||
|
-- type of tile to the left, excluding space (small eyeheight)
|
||||||
|
-- how tall non-space extends upward from that tile
|
||||||
|
-- how tall non-space extends downward from that tile
|
||||||
|
-- type of tile to the left, excluding space (large eyeheight)
|
||||||
|
|
||||||
|
-- type of enemy (nearest down-right from mario's upper left)
|
||||||
|
-- normalized enemy x
|
||||||
|
-- normalized enemy y
|
||||||
|
|
||||||
|
-- VISUALIZE:
|
||||||
|
|
||||||
|
--local tile_col = R(0x6A0)
|
||||||
|
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||||
|
local tile_scroll_remainder = R(0x73F) % 16
|
||||||
|
for y = 1, 11 do
|
||||||
|
for x = 0, 16 do
|
||||||
|
local col = (x + tile_scroll) % 32
|
||||||
|
local addr = col < 16 and 0x500 or 0x5D0
|
||||||
|
local t = R(addr + y * 16 + (col % 16))
|
||||||
|
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||||
|
local sy = y * 16 + 40
|
||||||
|
|
||||||
|
if t ~= 0 then
|
||||||
|
gui.box(sx-8, sy-8, sx+8, sy+8)
|
||||||
|
gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
-- TODO: don't expose these; provide interfaces for everything needed.
|
-- TODO: don't expose these; provide interfaces for everything needed.
|
||||||
R=R,
|
R=R,
|
||||||
|
@ -315,6 +503,7 @@ overlay=overlay,
|
||||||
|
|
||||||
valid_tiles=valid_tiles,
|
valid_tiles=valid_tiles,
|
||||||
area_lut=area_lut,
|
area_lut=area_lut,
|
||||||
|
embed_tile=embed_tile,
|
||||||
|
|
||||||
sprite_input=sprite_input,
|
sprite_input=sprite_input,
|
||||||
tile_input=tile_input,
|
tile_input=tile_input,
|
||||||
|
@ -323,16 +512,24 @@ extra_input=extra_input,
|
||||||
get_timer=get_timer,
|
get_timer=get_timer,
|
||||||
get_score=get_score,
|
get_score=get_score,
|
||||||
set_timer=set_timer,
|
set_timer=set_timer,
|
||||||
mark_sprite=mark_sprite,
|
get_state=get_state,
|
||||||
mark_tile=mark_tile,
|
|
||||||
getxy=getxy,
|
getxy=getxy,
|
||||||
paused=paused,
|
paused=paused,
|
||||||
get_state=get_state,
|
|
||||||
advance=advance,
|
mark_sprite=mark_sprite,
|
||||||
|
mark_tile=mark_tile,
|
||||||
|
|
||||||
handle_enemies=handle_enemies,
|
handle_enemies=handle_enemies,
|
||||||
handle_fireballs=handle_fireballs,
|
handle_fireballs=handle_fireballs,
|
||||||
handle_blocks=handle_blocks,
|
handle_blocks=handle_blocks,
|
||||||
handle_hammers=handle_hammers,
|
handle_hammers=handle_hammers,
|
||||||
handle_misc=handle_misc,
|
handle_misc=handle_misc,
|
||||||
handle_tiles=handle_tiles,
|
handle_tiles=handle_tiles,
|
||||||
|
|
||||||
|
advance=advance,
|
||||||
|
load_level=load_level,
|
||||||
|
|
||||||
|
new_stuff=new_stuff,
|
||||||
|
new_input=new_input,
|
||||||
}
|
}
|
||||||
|
|
136
snes.lua
136
snes.lua
|
@ -29,9 +29,12 @@ local normalize_sums = util.normalize_sums
|
||||||
local pdf = util.pdf
|
local pdf = util.pdf
|
||||||
local weighted_mann_whitney = util.weighted_mann_whitney
|
local weighted_mann_whitney = util.weighted_mann_whitney
|
||||||
|
|
||||||
|
local xnes = require "xnes"
|
||||||
|
local make_utility = xnes.make_utility
|
||||||
|
|
||||||
local Snes = Base:extend()
|
local Snes = Base:extend()
|
||||||
|
|
||||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive)
|
||||||
-- heuristic borrowed from CMA-ES:
|
-- heuristic borrowed from CMA-ES:
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
|
@ -41,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||||
self.covar_rate = base_rate
|
self.covar_rate = base_rate
|
||||||
self.sigma = sigma or 1
|
self.sigma = sigma or 1
|
||||||
self.antithetic = antithetic and true or false
|
self.antithetic = antithetic and true or false
|
||||||
|
self.adaptive = adaptive == nil and true or adaptive
|
||||||
|
|
||||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||||
|
|
||||||
|
self.utility = make_utility(self.popsize)
|
||||||
|
|
||||||
self.rate_init = self.sigma_rate
|
self.rate_init = self.sigma_rate
|
||||||
|
|
||||||
self.mean = zeros{dims}
|
self.mean = zeros{dims}
|
||||||
|
@ -147,60 +153,84 @@ function Snes:ask_mix(start_anew)
|
||||||
|
|
||||||
-- perform importance mixing.
|
-- perform importance mixing.
|
||||||
|
|
||||||
local mean_old = self.mean
|
local mean_old = self.mean_old or self.mean
|
||||||
local mean_new = self.mean
|
local mean_new = self.mean
|
||||||
local std_old = self.std_old or self.std
|
local std_old = self.std_old or self.std
|
||||||
local std_new = self.std
|
local std_new = self.std
|
||||||
|
|
||||||
self.new_asked = {}
|
local function compute_probabilities(a)
|
||||||
self.new_noise = {}
|
|
||||||
|
|
||||||
local marked = {}
|
|
||||||
for p=1, min(#self.old_asked, self.popsize) do
|
|
||||||
local a = self.old_asked[p]
|
|
||||||
|
|
||||||
-- TODO: cache probs?
|
|
||||||
local prob_new = 0
|
local prob_new = 0
|
||||||
local prob_old = 0
|
local prob_old = 0
|
||||||
for i, v in ipairs(a) do
|
for i, v in ipairs(a) do
|
||||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||||
end
|
end
|
||||||
|
return prob_new, prob_old
|
||||||
|
end
|
||||||
|
|
||||||
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
local all_asked, all_noise, all_score = {}, {}, {}
|
||||||
if uniform() < accept then
|
|
||||||
--print(("accepted old sample %i with probability %f"):format(p, accept))
|
for p=1, #self.old_asked do
|
||||||
else
|
do
|
||||||
-- insert in reverse as not to screw up
|
local pp = floor(uniform() * #self.old_asked) + 1
|
||||||
-- the indices when removing later.
|
local a = self.old_asked[pp]
|
||||||
insert(marked, 1, p)
|
local prob_new, prob_old = compute_probabilities(a)
|
||||||
|
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
||||||
|
if uniform() < accept then
|
||||||
|
--print(("accepted old sample %i with probability %f"):format(pp, accept))
|
||||||
|
insert(all_asked, a)
|
||||||
|
insert(all_noise, self.old_noise[pp])
|
||||||
|
insert(all_score, self.old_score[pp])
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
|
||||||
for _, p in ipairs(marked) do
|
do
|
||||||
remove(self.old_asked, p)
|
local a, n = {}, {}
|
||||||
remove(self.old_noise, p)
|
for i=1, self.dims do n[i] = normal() end
|
||||||
remove(self.old_score, p)
|
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||||
|
local prob_new, prob_old = compute_probabilities(a)
|
||||||
|
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
||||||
|
if uniform() < accept then
|
||||||
|
--print(("accepted new sample %i with probability %f"):format(#all_asked, accept))
|
||||||
|
insert(all_asked, a)
|
||||||
|
insert(all_noise, n)
|
||||||
|
insert(all_score, false)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- TODO: early stopping, making sure it doesn't affect performance.
|
||||||
end
|
end
|
||||||
|
|
||||||
while #self.old_asked + #self.new_asked < self.popsize do
|
while #all_asked > self.popsize do
|
||||||
local a = {}
|
local pp = floor(uniform() * #all_asked) + 1
|
||||||
local n = {}
|
--print(("removing sample %i to fit popsize"):format(pp))
|
||||||
|
remove(all_asked, pp)
|
||||||
|
remove(all_noise, pp)
|
||||||
|
remove(all_score, pp)
|
||||||
|
end
|
||||||
|
|
||||||
|
while #all_asked < self.popsize do
|
||||||
|
local a, n = {}, {}
|
||||||
for i=1, self.dims do n[i] = normal() end
|
for i=1, self.dims do n[i] = normal() end
|
||||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||||
|
--print(("unconditionally added new sample %i"):format(#all_asked))
|
||||||
|
insert(all_asked, a)
|
||||||
|
insert(all_noise, n)
|
||||||
|
insert(all_score, false)
|
||||||
|
end
|
||||||
|
|
||||||
-- can't cache here!
|
-- split all_ tables back into old_ and new_.
|
||||||
local prob_new = 0
|
self.old_asked, self.old_noise, self.old_score = {}, {}, {}
|
||||||
local prob_old = 0
|
self.new_asked, self.new_noise = {}, {}
|
||||||
for i, v in ipairs(a) do
|
for i, score in ipairs(all_score) do
|
||||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
local a, n = all_asked[i], all_noise[i]
|
||||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
if score ~= false then
|
||||||
end
|
insert(self.old_asked, a)
|
||||||
|
insert(self.old_noise, n)
|
||||||
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
insert(self.old_score, score)
|
||||||
if uniform() < accept then
|
else
|
||||||
insert(self.new_asked, a)
|
insert(self.new_asked, a)
|
||||||
insert(self.new_noise, n)
|
insert(self.new_noise, n)
|
||||||
--print(("accepted new sample %i with probability %f"):format(0, accept))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -210,15 +240,15 @@ end
|
||||||
function Snes:tell(scored)
|
function Snes:tell(scored)
|
||||||
self.evals = self.evals + #scored
|
self.evals = self.evals + #scored
|
||||||
|
|
||||||
local asked = self.asked
|
local asked = self.mixing and self.new_asked or self.asked
|
||||||
local noise = self.noise
|
local noise = self.mixing and self.new_noise or self.noise
|
||||||
if self.mixing then
|
if self.mixing then
|
||||||
|
-- note: modifies, in-place, externally exposed tables.
|
||||||
|
for i, v in ipairs(asked) do insert(self.old_asked, v) end
|
||||||
|
for i, v in ipairs(noise) do insert(self.old_noise, v) end
|
||||||
|
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||||
asked = self.old_asked
|
asked = self.old_asked
|
||||||
noise = self.old_noise
|
noise = self.old_noise
|
||||||
-- note that these modify tables referenced externally in-place.
|
|
||||||
for i, v in ipairs(self.new_asked) do insert(asked, v) end
|
|
||||||
for i, v in ipairs(self.new_noise) do insert(noise, v) end
|
|
||||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
|
||||||
scored = self.old_score
|
scored = self.old_score
|
||||||
end
|
end
|
||||||
assert(asked and noise, ":tell() called before :ask()")
|
assert(asked and noise, ":tell() called before :ask()")
|
||||||
|
@ -231,8 +261,9 @@ function Snes:tell(scored)
|
||||||
local g_mean = zeros{self.dims}
|
local g_mean = zeros{self.dims}
|
||||||
local g_std = zeros{self.dims}
|
local g_std = zeros{self.dims}
|
||||||
|
|
||||||
|
--[[
|
||||||
local utilize = true
|
local utilize = true
|
||||||
local utility
|
local utility = self.utility
|
||||||
|
|
||||||
if utilize then
|
if utilize then
|
||||||
utility = {}
|
utility = {}
|
||||||
|
@ -242,17 +273,18 @@ function Snes:tell(scored)
|
||||||
else
|
else
|
||||||
utility = normalize_sums(scored, {})
|
utility = normalize_sums(scored, {})
|
||||||
end
|
end
|
||||||
|
--]]
|
||||||
|
|
||||||
for p=1, self.popsize do
|
for p=1, self.popsize do
|
||||||
local noise_p = noise[p]
|
local noise_p = noise[arg[p]]
|
||||||
|
|
||||||
for i, v in ipairs(g_mean) do
|
for i, v in ipairs(g_mean) do
|
||||||
g_mean[i] = v + utility[p] * noise_p[i]
|
g_mean[i] = v + self.utility[p] * noise_p[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
for i, v in ipairs(g_std) do
|
for i, v in ipairs(g_std) do
|
||||||
local n = noise_p[i]
|
local n = noise_p[i]
|
||||||
g_std[i] = v + utility[p] * (n * n - 1)
|
g_std[i] = v + self.utility[p] * (n * n - 1)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -261,7 +293,9 @@ function Snes:tell(scored)
|
||||||
step[i] = self.std[i] * v
|
step[i] = self.std[i] * v
|
||||||
end
|
end
|
||||||
|
|
||||||
|
self.mean_old = {}
|
||||||
for i, v in ipairs(self.mean) do
|
for i, v in ipairs(self.mean) do
|
||||||
|
self.mean_old[i] = v
|
||||||
self.mean[i] = v + self.param_rate * step[i]
|
self.mean[i] = v + self.param_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -273,7 +307,7 @@ function Snes:tell(scored)
|
||||||
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
|
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
|
||||||
end
|
end
|
||||||
|
|
||||||
self:adapt(asked, otherwise, utility)
|
if self.adaptive then self:adapt(asked, otherwise, self.utility) end
|
||||||
|
|
||||||
return step
|
return step
|
||||||
end
|
end
|
||||||
|
@ -291,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities)
|
||||||
weights[p] = prob_big / prob_now
|
weights[p] = prob_big / prob_now
|
||||||
end
|
end
|
||||||
|
|
||||||
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
local u, p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||||
--print("p:", p)
|
--print(("u, p: %6.3f, %6.3f"):format(u, p))
|
||||||
|
|
||||||
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
||||||
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
|
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
|
||||||
print("learning rate -:", self.sigma_rate)
|
--print("learning rate -:", self.sigma_rate)
|
||||||
else
|
else
|
||||||
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
|
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
|
||||||
print("learning rate +:", self.sigma_rate)
|
--print("learning rate +:", self.sigma_rate)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
51
util.lua
51
util.lua
|
@ -14,6 +14,7 @@ local random = math.random
|
||||||
local select = select
|
local select = select
|
||||||
local sort = table.sort
|
local sort = table.sort
|
||||||
local sqrt = math.sqrt
|
local sqrt = math.sqrt
|
||||||
|
local type = type
|
||||||
|
|
||||||
local function sign(x)
|
local function sign(x)
|
||||||
-- remember that 0 is truthy in Lua.
|
-- remember that 0 is truthy in Lua.
|
||||||
|
@ -83,6 +84,28 @@ local function calc_mean_dev(x)
|
||||||
return mean, sqrt(dev)
|
return mean, sqrt(dev)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Compute the mean and an approximately unbiased standard deviation.
-- The squared deviations are divided by n - 1.5 + 1 / (8 * (n - 1))
-- rather than by n or n - 1.
-- NOTE: this uses an approximation; there is still a little bias.
-- Relies on the file-level `sqrt` local.
-- @param x sequence of numbers; must contain at least two elements.
-- @return the arithmetic mean of x.
-- @return the approximately unbiased standard deviation of x.
local function calc_mean_dev_unbiased(x)
    -- the length never changes in here, so measure it once
    -- instead of re-evaluating #x on every loop iteration.
    local n = #x
    assert(n > 1, "calc_mean_dev_unbiased needs at least two samples")

    -- each term is divided as it is added, instead of dividing the sum once.
    local mean = 0
    for _, v in ipairs(x) do
        mean = mean + v / n
    end

    -- via Gurland, John; Tripathi, Ram C. (1971):
    -- A Simple Approximation for Unbiased Estimation of the Standard Deviation
    local divisor = n - 1.5 + 1 / (8 * (n - 1))

    local dev = 0
    for _, v in ipairs(x) do
        local delta = v - mean
        dev = dev + delta * delta / divisor
    end

    return mean, sqrt(dev)
end
|
||||||
|
|
||||||
local function normalize(x, out)
|
local function normalize(x, out)
|
||||||
out = out or x
|
out = out or x
|
||||||
local mean, dev = calc_mean_dev(x)
|
local mean, dev = calc_mean_dev(x)
|
||||||
|
@ -246,7 +269,31 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
|
||||||
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
|
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
|
||||||
local p = cdf((U - mean) / std)
|
local p = cdf((U - mean) / std)
|
||||||
|
|
||||||
if s0_sum > s1_sum then return 1 - p else return p end
|
local u = U / (w0_sum * w1_sum)
|
||||||
|
if s0_sum > s1_sum then return u, 1 - p else return u, p end
|
||||||
|
end
|
||||||
|
|
||||||
|
--- Expected absolute cosine similarity between two random vectors.
-- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
-- this is the expected absolute cosine similarity between
-- two standard normally-distributed random vectors both of size n.
-- Relies on the file-level `sqrt` local and on a `pi` value from the
-- enclosing scope (presumably a `math.pi` alias; confirm it is declared
-- in this file's header).
-- @param n vector size; must be greater than zero.
-- @return the expected absolute cosine similarity, as a number.
local function expect_cossim(n)
    assert(n > 0)

    -- large n uses closed-form approximations (original author's note:
    -- abs(error) < 1e-8).
    if n >= 128000 then
        return 1 / sqrt(pi / 2 * n + 1)
    elseif n >= 80 then
        -- `local` keeps this temporary out of the global table.
        local poly = (2.4674010 * n + -2.4673232) * n + 1.2274046
        return 1 / sqrt(sqrt(poly))
    end
    -- fall-through when it's faster just to compute iteratively.

    -- both of these were implicit globals before; they are pure temporaries.
    local even = n % 2 == 0
    local res = even and 2.0 or 1.0
    -- multiply by i / (i + 1) over every other i below n, starting at
    -- 2 for even n and at 1 for odd n.
    for i = even and 2 or 1, n - 1, 2 do
        res = res * (i / (i + 1))
    end
    -- only the even case carries a factor of 1 / pi.
    return even and res / pi or res
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -260,6 +307,7 @@ return {
|
||||||
softchoice=softchoice,
|
softchoice=softchoice,
|
||||||
empty=empty,
|
empty=empty,
|
||||||
calc_mean_dev=calc_mean_dev,
|
calc_mean_dev=calc_mean_dev,
|
||||||
|
calc_mean_dev_unbiased=calc_mean_dev_unbiased,
|
||||||
normalize=normalize,
|
normalize=normalize,
|
||||||
normalize_wrt=normalize_wrt,
|
normalize_wrt=normalize_wrt,
|
||||||
normalize_sums=normalize_sums,
|
normalize_sums=normalize_sums,
|
||||||
|
@ -276,4 +324,5 @@ return {
|
||||||
pdf=pdf,
|
pdf=pdf,
|
||||||
cdf=cdf,
|
cdf=cdf,
|
||||||
weighted_mann_whitney=weighted_mann_whitney,
|
weighted_mann_whitney=weighted_mann_whitney,
|
||||||
|
expect_cossim=expect_cossim,
|
||||||
}
|
}
|
||||||
|
|
8
xnes.lua
8
xnes.lua
|
@ -70,6 +70,8 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||||
-- note: this is technically the co-standard-deviation.
|
-- note: this is technically the co-standard-deviation.
|
||||||
-- you can imagine the "s" standing for "sqrt" if you like.
|
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||||
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||||
|
|
||||||
|
self.evals = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
function Xnes:params(new_mean)
|
function Xnes:params(new_mean)
|
||||||
|
@ -153,6 +155,8 @@ function Xnes:tell(scored, noise)
|
||||||
local noise = noise or self.noise
|
local noise = noise or self.noise
|
||||||
assert(noise, "missing noise argument")
|
assert(noise, "missing noise argument")
|
||||||
|
|
||||||
|
self.evals = self.evals + #scored
|
||||||
|
|
||||||
local arg = argsort(scored, function(a, b) return a > b end)
|
local arg = argsort(scored, function(a, b) return a > b end)
|
||||||
|
|
||||||
local g_delta = zeros{self.dims}
|
local g_delta = zeros{self.dims}
|
||||||
|
@ -173,7 +177,7 @@ function Xnes:tell(scored, noise)
|
||||||
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
|
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
|
||||||
local temp = self.utility[p] * zzt
|
local temp = self.utility[p] * zzt
|
||||||
g_covars[ind] = g_covars[ind] + temp
|
g_covars[ind] = g_covars[ind] + temp
|
||||||
traced = traced + temp
|
if i == j then traced = traced + temp end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -181,7 +185,7 @@ function Xnes:tell(scored, noise)
|
||||||
local g_sigma = traced / self.dims
|
local g_sigma = traced / self.dims
|
||||||
|
|
||||||
for i=1, self.dims do
|
for i=1, self.dims do
|
||||||
local ind = (i - 1) * self.dims + i
|
local ind = (i - 1) * self.dims + i -- diagonal
|
||||||
g_covars[ind] = g_covars[ind] - g_sigma
|
g_covars[ind] = g_covars[ind] - g_sigma
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue