temp 3
This commit is contained in:
parent
a1429a6271
commit
08f476e6ac
14 changed files with 1228 additions and 168 deletions
13
ars.lua
13
ars.lua
|
@ -10,8 +10,10 @@ local exp = math.exp
|
|||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local print = print
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
|
@ -72,16 +74,15 @@ local function kinda_lipschitz(dir, pos, neg, mid)
|
|||
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
||||
momentum)
|
||||
momentum, beta)
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.param_rate = base_rate
|
||||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic == nil and true or antithetic
|
||||
self.momentum = momentum or 0
|
||||
self.beta = beta or 1.0
|
||||
|
||||
self.poptop = poptop or popsize
|
||||
assert(self.poptop <= popsize)
|
||||
|
@ -189,8 +190,9 @@ function Ars:tell(scored, unperturbed_score)
|
|||
reward = reward / reward_dev
|
||||
end
|
||||
|
||||
local scale = reward / self.poptop * self.beta / 2
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + reward * v / self.poptop
|
||||
step[j] = step[j] + scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -200,8 +202,9 @@ function Ars:tell(scored, unperturbed_score)
|
|||
if reward ~= 0 then
|
||||
local noisy = self.noise[ind]
|
||||
|
||||
local scale = reward / self.poptop * self.beta
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + reward * v / self.poptop
|
||||
step[j] = step[j] + scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
14
config.lua
14
config.lua
|
@ -26,15 +26,17 @@ local defaults = {
|
|||
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||
|
||||
-- network layers:
|
||||
embed = true, -- set to false to use a hard-coded tile embedding.
|
||||
reduce_tiles = 0, -- TODO: write description
|
||||
hidden = false, -- use a hidden layer with ReLU/GELU activation.
|
||||
hidden_size = 128,
|
||||
layernorm = false, -- use a LayerNorm layer after said activation.
|
||||
reduce_tiles = false,
|
||||
bias_out = true,
|
||||
|
||||
-- network evaluation (sampling joypad):
|
||||
frameskip = 4,
|
||||
prob_frameskip = 0.0,
|
||||
max_frameskip = 6,
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
deterministic = false, -- use argmax on outputs instead of random sampling.
|
||||
det_epsilon = false, -- take random actions with probability eps.
|
||||
|
@ -42,12 +44,16 @@ local defaults = {
|
|||
-- evolution strategy and non-rate hyperparemeters:
|
||||
es = 'ars',
|
||||
ars_lips = false, -- for ARS.
|
||||
epoch_top_trials = 9999, -- for ARS.
|
||||
epoch_top_trials = 9999, -- for ARS, Guided.
|
||||
alpha = 0.5, -- for Guided.
|
||||
beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat.
|
||||
past_grads = 1, -- for Guided. keeps a history of n past steps taken.
|
||||
|
||||
-- sampling:
|
||||
deviation = 1.0,
|
||||
unperturbed_trial = true, -- perform an extra trial without any noise.
|
||||
-- this is good for logging, so i'd recommend it.
|
||||
attempts = 1, -- TODO: document.
|
||||
epoch_trials = 50,
|
||||
graycode = false, -- for ARS.
|
||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||
|
@ -114,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials,
|
|||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||
assert(not (cfg.es == 'snes' and cfg.negate_trials),
|
||||
"cfg.negate_trials is not yet compatible with SNES")
|
||||
assert(not (cfg.es == 'guided' and cfg.graycode),
|
||||
"cfg.graycode is not compatible with Guided")
|
||||
assert(cfg.es ~= 'guided' or cfg.negate_trials,
|
||||
"cfg.negate_trials must be true to use Guided")
|
||||
|
||||
return cfg
|
||||
|
|
62
es_test.lua
62
es_test.lua
|
@ -1,21 +1,27 @@
|
|||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local print = print
|
||||
|
||||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
local guided = require("guided")
|
||||
|
||||
-- try it all out on a dummy problem.
|
||||
|
||||
local function typeof(t) return getmetatable(t).__index end
|
||||
|
||||
local function square(x) return x * x end
|
||||
|
||||
-- this function's global minimum is arange(dims) + 1.
|
||||
-- xNES should be able to find it almost exactly.
|
||||
local function spherical(x)
|
||||
local sum = 0
|
||||
for i, v in ipairs(x) do sum = sum + square(v - i) end
|
||||
--for i, v in ipairs(x) do sum = sum + square(v - i) end
|
||||
for i, v in ipairs(x) do sum = sum + square(v - i / #x) end
|
||||
-- we need to negate this to turn it into a maximization problem.
|
||||
return -sum
|
||||
end
|
||||
|
@ -26,10 +32,34 @@ local dims = 100
|
|||
local popsize = dims + 1
|
||||
local sigma_init = 0.5
|
||||
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
|
||||
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
||||
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
||||
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
|
||||
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
||||
es.min_refresh = 0.7 -- FIXME: needs a better interface.
|
||||
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
||||
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
|
||||
es.min_refresh = 0.5 -- FIXME: needs a better interface.
|
||||
|
||||
if typeof(es) == xnes.Xnes
|
||||
or typeof(es) == snes.Snes
|
||||
then
|
||||
-- use IGO recommendations
|
||||
local pop5 = max(1, floor(es.popsize / 5))
|
||||
|
||||
local sum = 0
|
||||
for i=1, es.popsize do
|
||||
local maybe = i < pop5 and 1 or 0
|
||||
es.utility[i] = maybe
|
||||
sum = sum + maybe
|
||||
end
|
||||
--for i, v in ipairs(es.utility) do es.utility[i] = v / sum end
|
||||
|
||||
local util = require "util"
|
||||
util.normalize_sums(es.utility)
|
||||
|
||||
es.param_rate = 0.39
|
||||
es.sigma_rate = 0.39
|
||||
es.covar_rate = 0.39
|
||||
es.adaptive = false
|
||||
end
|
||||
|
||||
if false then -- TODO: delete me
|
||||
local nn = require("nn")
|
||||
|
@ -64,30 +94,44 @@ local asked = nil -- for caching purposes.
|
|||
local noise = nil -- for caching purposes.
|
||||
local current_cost = spherical(es:params())
|
||||
|
||||
local past_grads = {}
|
||||
local pgi = 0
|
||||
local pgn = 10
|
||||
|
||||
for i=1, iterations do
|
||||
if getmetatable(es).__index == snes.Snes then
|
||||
if typeof(es) == snes.Snes and es.min_refresh ~= 1 then
|
||||
asked, noise = es:ask_mix()
|
||||
elseif getmetatable(es).__index == ars.Ars then
|
||||
elseif typeof(es) == ars.Ars then
|
||||
asked, noise = es:ask()
|
||||
elseif typeof(es) == guided.Guided then
|
||||
asked, noise = es:ask(past_grads)
|
||||
else
|
||||
asked, noise = es:ask(asked, noise)
|
||||
end
|
||||
|
||||
local scores = {}
|
||||
for i, v in ipairs(asked) do
|
||||
scores[i] = spherical(v)
|
||||
end
|
||||
|
||||
if getmetatable(es).__index == ars.Ars then
|
||||
if typeof(es) == ars.Ars then
|
||||
es:tell(scores)--, current_cost) -- use lips
|
||||
elseif typeof(es) == guided.Guided then
|
||||
local step = es:tell(scores)
|
||||
|
||||
for _, v in ipairs(step) do
|
||||
past_grads[pgi + 1] = v
|
||||
pgi = (pgi + 1) % (pgn * #step)
|
||||
end
|
||||
past_grads.shape = {floor(#past_grads / #step), #step}
|
||||
else
|
||||
es:tell(scores)
|
||||
end
|
||||
|
||||
current_cost = spherical(es:params())
|
||||
--if i % 100 == 0 then
|
||||
if i % 100 == 0 then
|
||||
local sigma = es.sigma
|
||||
if getmetatable(es).__index == snes.Snes then
|
||||
if typeof(es) == snes.Snes then
|
||||
sigma = 0
|
||||
for i, v in ipairs(es.std) do sigma = sigma + v end
|
||||
sigma = sigma / #es.std
|
||||
|
|
194
guided.lua
Normal file
194
guided.lua
Normal file
|
@ -0,0 +1,194 @@
|
|||
-- Guided Evolutionary Strategies
|
||||
-- https://arxiv.org/abs/1806.10230
|
||||
|
||||
-- this is just ARS extended to utilize gradients
|
||||
-- approximated from previous iterations.
|
||||
|
||||
-- for simplicity:
|
||||
-- antithetic is always true
|
||||
-- momentum is always 0
|
||||
-- no graycode/lipschitz nonsense
|
||||
|
||||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local max = math.max
|
||||
local print = print
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local dot_mv = nn.dot_mv
|
||||
local transpose = nn.transpose
|
||||
local normal = nn.normal
|
||||
local prod = nn.prod
|
||||
local uniform = nn.uniform
|
||||
local zeros = nn.zeros
|
||||
|
||||
local qr = require "qr2"
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
|
||||
local Guided = Base:extend()
|
||||
|
||||
local function collect_best_indices(scored, top)
|
||||
-- select one (the best) reward of each pos/neg pair.
|
||||
local best_rewards
|
||||
best_rewards = {}
|
||||
for i = 1, #scored / 2 do
|
||||
local pos = scored[i * 2 - 1]
|
||||
local neg = scored[i * 2 - 0]
|
||||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
|
||||
local indices = argsort(best_rewards, function(a, b) return a > b end)
|
||||
|
||||
for i = top + 1, #best_rewards do indices[i] = nil end
|
||||
return indices
|
||||
end
|
||||
|
||||
function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta)
|
||||
-- sigma: scale of random perturbations.
|
||||
-- alpha: blend between full parameter space and its gradient subspace.
|
||||
-- 1.0 is roughly equivalent to ARS.
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.param_rate = base_rate
|
||||
self.sigma = sigma or 1.0
|
||||
self.alpha = alpha or 0.5
|
||||
self.beta = beta or 1.0
|
||||
|
||||
self.poptop = poptop or popsize
|
||||
assert(self.poptop <= popsize)
|
||||
self.popsize = self.popsize * 2 -- antithetic
|
||||
|
||||
self._params = zeros(self.dims)
|
||||
--self.accum = zeros(self.dims) -- momentum
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
||||
function Guided:params(new_params)
|
||||
if new_params ~= nil then
|
||||
assert(#self._params == #new_params, "new parameters have the wrong size")
|
||||
for i, v in ipairs(new_params) do self._params[i] = v end
|
||||
end
|
||||
return self._params
|
||||
end
|
||||
|
||||
function Guided:decay(param_decay, sigma_decay)
|
||||
-- FIXME: multiplying by sigma probably isn't correct anymore.
|
||||
-- is this correct now?
|
||||
if param_decay > 0 then
|
||||
local scale = self.sigma / sqrt(self.dims)
|
||||
scale = scale * self.beta
|
||||
scale = scale * self.param_rate / (self.sigma * self.sigma)
|
||||
|
||||
scale = 1 - param_decay * scale
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function Guided:ask(grads)
|
||||
local asked = {}
|
||||
local noise = {}
|
||||
|
||||
local n_grad = 0
|
||||
local gnoise, U, dummy, left, right
|
||||
if grads ~= nil and #grads > 0 then
|
||||
n_grad = grads.shape[1]
|
||||
gnoise = zeros(n_grad)
|
||||
|
||||
U, dummy = qr(transpose(grads))
|
||||
--print(nn.pp(transpose(U), "%9.4f"))
|
||||
|
||||
left = sqrt(self.alpha / self.dims)
|
||||
right = sqrt((1 - self.alpha) / n_grad)
|
||||
--print(left, right)
|
||||
end
|
||||
|
||||
for i = 1, self.popsize do
|
||||
local asking = zeros(self.dims)
|
||||
local noisy = zeros(self.dims)
|
||||
asked[i] = asking
|
||||
noise[i] = noisy
|
||||
|
||||
if i % 2 == 0 then
|
||||
local old_noisy = noise[i - 1]
|
||||
for j, v in ipairs(old_noisy) do
|
||||
noisy[j] = -v
|
||||
end
|
||||
elseif n_grad == 0 then
|
||||
local scale = self.sigma / sqrt(self.dims)
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = scale * normal()
|
||||
end
|
||||
else
|
||||
for j = 1, self.dims do noisy[j] = normal() end
|
||||
for j = 1, n_grad do gnoise[j] = normal() end
|
||||
local noisier = dot_mv(U, gnoise)
|
||||
for j, v in ipairs(noisy) do
|
||||
noisy[j] = self.sigma * (left * v + right * noisier[j])
|
||||
end
|
||||
end
|
||||
|
||||
for j, v in ipairs(self._params) do
|
||||
asking[j] = v + noisy[j]
|
||||
end
|
||||
end
|
||||
|
||||
self.noise = noise
|
||||
return asked, noise
|
||||
end
|
||||
|
||||
function Guided:tell(scored, unperturbed_score)
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local indices = collect_best_indices(scored, self.poptop)
|
||||
|
||||
local top_rewards = {}
|
||||
for _, ind in ipairs(indices) do
|
||||
insert(top_rewards, scored[ind * 2 - 1])
|
||||
insert(top_rewards, scored[ind * 2 - 0])
|
||||
end
|
||||
|
||||
local step = zeros(self.dims)
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
for i, ind in ipairs(indices) do
|
||||
local pos = top_rewards[i * 2 - 1]
|
||||
local neg = top_rewards[i * 2 - 0]
|
||||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[ind * 2 - 1]
|
||||
-- NOTE: technically this reward divide isn't part of guided search.
|
||||
reward = reward / reward_dev
|
||||
|
||||
local scale = reward / self.poptop * self.beta / 2
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local coeff = self.param_rate / (self.sigma * self.sigma)
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + coeff * step[i]
|
||||
end
|
||||
|
||||
self.noise = nil
|
||||
|
||||
return step
|
||||
end
|
||||
|
||||
return {
|
||||
--collect_best_indices = collect_best_indices, -- ars.lua has more features
|
||||
Guided = Guided,
|
||||
}
|
166
main.lua
166
main.lua
|
@ -20,6 +20,12 @@ local trial_rewards = {}
|
|||
local trials_remaining = 0
|
||||
local es -- evolution strategy.
|
||||
|
||||
local attempt_i = 0
|
||||
local sub_rewards = {}
|
||||
|
||||
local past_grads = {} -- for Guided
|
||||
local pgi = 0 -- past_grads_index
|
||||
|
||||
local trial_frames = 0
|
||||
local total_frames = 0
|
||||
local lagless_count = 0
|
||||
|
@ -93,6 +99,7 @@ local util = require("util")
|
|||
local argmax = util.argmax
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
|
||||
local clamp = util.clamp
|
||||
local copy = util.copy
|
||||
local empty = util.empty
|
||||
|
@ -149,13 +156,21 @@ local network
|
|||
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
|
||||
local function make_network(input_size)
|
||||
nn_x = nn.Input({input_size})
|
||||
nn_tx = nn.Input({gcfg.tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
||||
|
||||
local embed_dim = cfg.embed and 2 or 3
|
||||
|
||||
if cfg.embed then
|
||||
nn_tx = nn.Input({gcfg.tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim))
|
||||
else -- new tile inputs.
|
||||
nn_tx = nn.Input({gcfg.tile_count * 3})
|
||||
nn_ty = nn_tx
|
||||
end
|
||||
|
||||
nn_tz = nn_ty
|
||||
if cfg.reduce_tiles then
|
||||
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
|
||||
if cfg.reduce_tiles > 0 then
|
||||
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true))
|
||||
nn_tz = nn_tz:feed(nn.Relu())
|
||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||
end
|
||||
|
@ -185,6 +200,7 @@ end
|
|||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
local guided = require("guided")
|
||||
|
||||
local function prepare_epoch()
|
||||
trial_neg = false
|
||||
|
@ -213,6 +229,8 @@ local function prepare_epoch()
|
|||
local dummy
|
||||
if cfg.es == 'ars' then
|
||||
trial_params, dummy = es:ask(precision)
|
||||
elseif cfg.es == 'guided' then
|
||||
trial_params, dummy = es:ask(past_grads)
|
||||
elseif cfg.es == 'snes' then
|
||||
trial_params, dummy = es:ask_mix()
|
||||
else
|
||||
|
@ -223,6 +241,7 @@ local function prepare_epoch()
|
|||
end
|
||||
|
||||
local function load_next_trial()
|
||||
attempt_i = 1
|
||||
if cfg.negate_trials then
|
||||
trial_neg = not trial_neg
|
||||
else
|
||||
|
@ -272,8 +291,16 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
print(("step mean: %9.6f"):format(step_mean))
|
||||
print(("step stddev: %9.6f"):format(step_dev))
|
||||
|
||||
if cfg.es == 'guided' and cfg.past_grads > 0 then
|
||||
for _, v in ipairs(step) do
|
||||
past_grads[pgi + 1] = v
|
||||
pgi = (pgi + 1) % (cfg.past_grads * #step)
|
||||
end
|
||||
past_grads.shape = {floor(#past_grads / #step), #step}
|
||||
end
|
||||
|
||||
es:decay(cfg.param_decay, cfg.sigma_decay)
|
||||
|
||||
|
@ -329,65 +356,62 @@ local function joypad_mash(button)
|
|||
joypad.write(1, jp_mash)
|
||||
end
|
||||
|
||||
local function loadlevel(world, level)
|
||||
-- TODO: move to smb.lua. rename to load_level.
|
||||
if world == 0 then world = random(1, 8) end
|
||||
if level == 0 then level = random(1, 4) end
|
||||
emu.poweron()
|
||||
while emu.framecount() < 60 do
|
||||
if emu.framecount() == 32 then
|
||||
local area = game.area_lut[world * 10 + level]
|
||||
game.W(0x75F, world - 1)
|
||||
game.W(0x75C, level - 1)
|
||||
game.W(0x760, area)
|
||||
end
|
||||
if emu.framecount() == 42 then
|
||||
game.W(0x7A0, 0) -- world screen timer (reduces startup time)
|
||||
end
|
||||
joypad_mash('start')
|
||||
emu.frameadvance()
|
||||
end
|
||||
end
|
||||
|
||||
local function do_reset()
|
||||
local state = game.get_state()
|
||||
-- be a little more descriptive.
|
||||
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
|
||||
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 then
|
||||
print('test trial reward:', reward, "("..state..")")
|
||||
elseif cfg.negate_trials then
|
||||
--local dir = trial_neg and "negative" or "positive"
|
||||
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||
--if cfg.attempts > 1 and attempt_i >= cfg.attempts then
|
||||
attempt_i = attempt_i + 1
|
||||
sub_rewards[#sub_rewards + 1] = reward
|
||||
--print(sub_rewards)
|
||||
|
||||
if trial_neg then
|
||||
local pos = trial_rewards[#trial_rewards]
|
||||
local neg = reward
|
||||
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
||||
print(fmt:format(floor(trial_i / 2),
|
||||
pos, neg, last_trial_state, state))
|
||||
if #sub_rewards >= cfg.attempts then
|
||||
if cfg.attempts == 1 then
|
||||
reward = sub_rewards[1]
|
||||
else
|
||||
local sub_mean, sub_std = calc_mean_dev(sub_rewards)
|
||||
reward = floor(sub_mean)
|
||||
--local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards)
|
||||
--reward = floor(sub_mean - sub_std)
|
||||
end
|
||||
empty(sub_rewards)
|
||||
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 then
|
||||
print('test trial reward:', reward, "("..state..")")
|
||||
elseif cfg.negate_trials then
|
||||
--local dir = trial_neg and "negative" or "positive"
|
||||
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
|
||||
|
||||
if trial_neg then
|
||||
local pos = trial_rewards[#trial_rewards]
|
||||
local neg = reward
|
||||
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
|
||||
print(fmt:format(floor(trial_i / 2),
|
||||
pos, neg, last_trial_state, state))
|
||||
end
|
||||
last_trial_state = state
|
||||
else
|
||||
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
||||
end
|
||||
|
||||
if trial_i == 0 or not cfg.negate_trials then
|
||||
trial_rewards[trial_i] = reward
|
||||
else
|
||||
trial_rewards[#trial_rewards + 1] = reward
|
||||
end
|
||||
last_trial_state = state
|
||||
else
|
||||
print('trial', trial_i, 'reward:', reward, "("..state..")")
|
||||
end
|
||||
|
||||
if trial_i == 0 or not cfg.negate_trials then
|
||||
trial_rewards[trial_i] = reward
|
||||
else
|
||||
trial_rewards[#trial_rewards + 1] = reward
|
||||
end
|
||||
end
|
||||
|
||||
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
|
||||
if epoch_i > 0 then learn_from_epoch() end
|
||||
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
||||
prepare_epoch()
|
||||
collectgarbage()
|
||||
if any_random then
|
||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||
state_saved = false
|
||||
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
|
||||
if epoch_i > 0 then learn_from_epoch() end
|
||||
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
||||
prepare_epoch()
|
||||
collectgarbage()
|
||||
if any_random then
|
||||
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||
state_saved = false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -427,7 +451,9 @@ local function do_reset()
|
|||
trial_frames = 0
|
||||
emu.frameadvance() -- prevents emulator from quirking up.
|
||||
|
||||
load_next_trial()
|
||||
if attempt_i > cfg.attempts then
|
||||
load_next_trial()
|
||||
end
|
||||
|
||||
reset = false
|
||||
end
|
||||
|
@ -450,7 +476,7 @@ local function init()
|
|||
if not playing then emu.speedmode("turbo") end
|
||||
|
||||
if not any_random then
|
||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||
end
|
||||
|
||||
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
|
||||
|
@ -481,7 +507,10 @@ local function init()
|
|||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials,
|
||||
cfg.momentum)
|
||||
cfg.momentum, cfg.beta)
|
||||
elseif cfg.es == 'guided' then
|
||||
es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta)
|
||||
else
|
||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||
end
|
||||
|
@ -537,6 +566,7 @@ local function doit(dummy)
|
|||
empty(game.sprite_input)
|
||||
empty(game.tile_input)
|
||||
empty(game.extra_input)
|
||||
empty(game.new_input)
|
||||
|
||||
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
|
||||
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
|
@ -616,12 +646,18 @@ local function doit(dummy)
|
|||
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
|
||||
nn.reshape(X, 1, gcfg.input_size)
|
||||
nn.reshape(game.tile_input, 1, gcfg.tile_count)
|
||||
nn.reshape(game.new_input, 1, gcfg.tile_count * 3)
|
||||
|
||||
trial_frames = trial_frames + cfg.frameskip
|
||||
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + cfg.frameskip
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||
local outputs
|
||||
if cfg.embed then
|
||||
outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||
else
|
||||
outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input})
|
||||
end
|
||||
|
||||
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
||||
if cfg.det_epsilon and random() < eps then
|
||||
|
@ -695,8 +731,12 @@ while true do
|
|||
end
|
||||
|
||||
local delta = lagless_count - last_decision_frame
|
||||
local doot = jp == nil or delta >= cfg.frameskip
|
||||
doot = doot and random() >= cfg.prob_frameskip
|
||||
local doot = true
|
||||
if jp ~= nil then
|
||||
doot = delta >= cfg.frameskip
|
||||
doot = doot and random() >= cfg.prob_frameskip
|
||||
doot = doot or delta >= cfg.max_frameskip
|
||||
end
|
||||
doit(not doot)
|
||||
if doot then last_decision_frame = lagless_count end
|
||||
|
||||
|
|
46
nn.lua
46
nn.lua
|
@ -1,20 +1,15 @@
|
|||
local assert = assert
|
||||
local ceil = math.ceil
|
||||
local cos = math.cos
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local huge = math.huge
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local open = io.open
|
||||
local pairs = pairs
|
||||
local pi = math.pi
|
||||
local print = print
|
||||
local remove = table.remove
|
||||
local sin = math.sin
|
||||
local sqrt = math.sqrt
|
||||
local tanh = math.tanh
|
||||
local tostring = tostring
|
||||
|
@ -105,19 +100,26 @@ end
|
|||
|
||||
-- ndarray-ish stuff and more involved math
|
||||
|
||||
local function pp_join(sep, fmt, t, a, b)
|
||||
a = a or 1
|
||||
b = b or #t
|
||||
local s = ''
|
||||
for i = a, b do
|
||||
s = s..fmt:format(t[i])
|
||||
if i ~= b then s = s..sep end
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
||||
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
||||
-- pretty-prints an nd-array.
|
||||
fmt = fmt or '%10.7f,'
|
||||
fmt = fmt or '%10.7f'
|
||||
sep = sep or ','
|
||||
ti = ti or 0
|
||||
di = di or 1
|
||||
depth = depth or 0
|
||||
|
||||
if t.shape == nil then
|
||||
local s = '['
|
||||
for i = 1, #t do s = s..fmt:format(t[i]) end
|
||||
return s..']'..sep..'\n'
|
||||
end
|
||||
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
|
||||
|
||||
local dim = t.shape[di]
|
||||
|
||||
|
@ -134,11 +136,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
|||
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
|
||||
ti = ti + ti_step
|
||||
end
|
||||
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
|
||||
s = s..indent..']'..sep
|
||||
if islast then s = s..'\n' end
|
||||
else
|
||||
s = s..indent..'['
|
||||
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
|
||||
s = s..']'..sep..'\n'
|
||||
s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n'
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
@ -265,6 +266,20 @@ local function dot(a, b, ax_a, ax_b, out)
|
|||
return out
|
||||
end
|
||||
|
||||
local function transpose(x, out)
|
||||
assert(#x.shape == 2) -- TODO: handle ndarrays like numpy.
|
||||
local rows = x.shape[1]
|
||||
local cols = x.shape[2]
|
||||
local y = out or zeros{cols, rows}
|
||||
-- TODO: simplify? y can be consecutive for sure.
|
||||
for i = 1, rows do
|
||||
for j = 1, cols do
|
||||
y[(j - 1) * rows + i] = x[(i - 1) * cols + j]
|
||||
end
|
||||
end
|
||||
return y
|
||||
end
|
||||
|
||||
-- nodal
|
||||
|
||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||
|
@ -875,6 +890,7 @@ return {
|
|||
ppi = ppi,
|
||||
dot_mv = dot_mv,
|
||||
dot = dot,
|
||||
transpose = transpose,
|
||||
traverse = traverse,
|
||||
traverse_all = traverse_all,
|
||||
|
||||
|
|
230
presets.lua
230
presets.lua
|
@ -33,7 +33,7 @@ make_preset{
|
|||
|
||||
init_zeros = true,
|
||||
|
||||
reduce_tiles = true,
|
||||
reduce_tiles = 5,
|
||||
bias_out = false,
|
||||
|
||||
deterministic = false,
|
||||
|
@ -147,6 +147,15 @@ make_preset{
|
|||
parent = 'ars',
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
-- momentum = 0.5, -- this is default.
|
||||
param_rate = 1.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip',
|
||||
parent = 'ars',
|
||||
|
@ -155,15 +164,6 @@ make_preset{
|
|||
prob_frameskip = 0.25,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
momentum = 0.5,
|
||||
param_rate = 1.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-big',
|
||||
parent = 'ars',
|
||||
|
@ -204,6 +204,216 @@ make_preset{
|
|||
momentum = 0.99,
|
||||
}
|
||||
|
||||
-- new stuff for 2019:
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip-more',
|
||||
parent = 'ars',
|
||||
|
||||
-- old:
|
||||
--frameskip = 2,
|
||||
--prob_frameskip = 0.5,
|
||||
--max_frameskip = 60,
|
||||
-- new:
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip-more-3',
|
||||
parent = 'ars-skip-more',
|
||||
|
||||
attempts = 3, -- per trial. score = mean(scores) - stdev(scores)
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes-skip-more-3',
|
||||
parent = 'snes3',
|
||||
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'guided',
|
||||
epoch_top_trials = 20,
|
||||
deterministic = true,
|
||||
deviation = 0.1,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.00368,
|
||||
param_decay = 0.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided2',
|
||||
parent = 'guided',
|
||||
|
||||
past_grads = 2,
|
||||
-- after epoch 50, trying this:
|
||||
--param_rate = 0.05,
|
||||
-- after epoch 50+20, stepping back to this:
|
||||
param_rate = 0.01,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided10',
|
||||
parent = 'guided',
|
||||
|
||||
past_grads = 10,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided69', -- the nice one
|
||||
parent = 'guided',
|
||||
|
||||
deviation = 0.05,
|
||||
epoch_top_trials = 10,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.006,
|
||||
|
||||
past_grads = 4,
|
||||
alpha = 0.25,
|
||||
}
|
||||
|
||||
-- TODO: yet another preset. try building up from 1 trial ARS to something good.
|
||||
make_preset{
|
||||
name = 'redux',
|
||||
|
||||
min_time = 300,
|
||||
max_time = 300,
|
||||
timer_loser = 1/1,
|
||||
|
||||
score_multiplier = 1,
|
||||
|
||||
init_zeros = true,
|
||||
|
||||
deterministic = true,
|
||||
|
||||
es = 'guided',
|
||||
past_grads = 2, -- for Guided.
|
||||
alpha = 0.25, -- for Guided.
|
||||
|
||||
ars_lips = false, -- for ARS.
|
||||
beta = 1.0, -- fix the default.
|
||||
|
||||
epoch_top_trials = 4, -- for ARS, Guided.
|
||||
epoch_trials = 5,
|
||||
attempts = 1, -- TODO: document.
|
||||
|
||||
deviation = 1.0, -- 0.1
|
||||
base_rate = 1.0,
|
||||
param_decay = 0.01,
|
||||
|
||||
graycode = false, -- for ARS.
|
||||
min_refresh = 0.1, -- for SNES.
|
||||
sigma_decay = 0.0, -- for SNES, xNES.
|
||||
momentum = 0.0, -- for ARS.
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'redux_big',
|
||||
parent = 'redux',
|
||||
|
||||
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||
hidden = true, -- use a hidden layer with ReLU/GELU activation.
|
||||
hidden_size = 64,
|
||||
layernorm = true, -- use a LayerNorm layer after said activation.
|
||||
reduce_tiles = false,
|
||||
bias_out = false,
|
||||
|
||||
-- gets stuck pretty quick, so tweak some stuff...
|
||||
epoch_top_trials = 8,
|
||||
epoch_trials = 10,
|
||||
deviation = 1.0,
|
||||
base_rate = 0.15,
|
||||
param_decay = 0.05,
|
||||
past_grads = 4,
|
||||
alpha = 0.25,
|
||||
-- well it doesn't get stuck anymore, but regular redux works much better.
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided-skip-more-3',
|
||||
parent = 'guided',
|
||||
|
||||
--param_rate = 0.00368, -- should probably be this instead...
|
||||
param_rate = 0.01,
|
||||
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided-skip-more-3-again',
|
||||
parent = 'guided-skip-more-3',
|
||||
|
||||
param_rate = 0.08, --0.0316,
|
||||
deviation = 0.5,
|
||||
alpha = 0.1, --0.5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'crazy',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'guided',
|
||||
epoch_top_trials = 15,
|
||||
deterministic = false,
|
||||
deviation = 1.0,
|
||||
epoch_trials = 15,
|
||||
param_rate = 1.0,
|
||||
param_decay = 0.0,
|
||||
alpha = 0.0316,
|
||||
--attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips2',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
--epoch_trials = 10,
|
||||
param_rate = 0.147,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips3',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
param_rate = 0.5,
|
||||
deviation = 0.02, -- added after like 272 epochs
|
||||
param_decay = 0.0276, -- added after like 62 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'hard-embed',
|
||||
parent = 'big-scroll-hidden',
|
||||
|
||||
embed = false,
|
||||
reduce_tiles = 5,
|
||||
hidden_size = 54,
|
||||
|
||||
epoch_top_trials = 20,
|
||||
deterministic = true,
|
||||
deviation = 0.01,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.368,
|
||||
param_decay = 0.0138,
|
||||
momentum = 0.5,
|
||||
beta = 1.0,
|
||||
}
|
||||
|
||||
-- end of new stuff
|
||||
|
||||
make_preset{
|
||||
name = 'play',
|
||||
|
||||
|
|
111
qr.lua
Normal file
111
qr.lua
Normal file
|
@ -0,0 +1,111 @@
|
|||
local min = math.min
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require "nn"
|
||||
local dot = nn.dot
|
||||
local reshape = nn.reshape
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function minor(x, d)
|
||||
assert(#x.shape == 2)
|
||||
assert(d <= x.shape[1] and d <= x.shape[2])
|
||||
|
||||
local m = zeros(x.shape)
|
||||
|
||||
-- fill diagonals.
|
||||
--for i = 1, d do m[(i - 1) * m.shape[2] + i] = 1 end
|
||||
for i = 1, d * m.shape[2], m.shape[2] + 1 do m[i] = 1 end
|
||||
|
||||
-- copy values.
|
||||
for i = d + 1, m.shape[1] do
|
||||
for j = d + 1, m.shape[2] do
|
||||
local ind = (i - 1) * m.shape[2] + j
|
||||
m[ind] = x[ind]
|
||||
end
|
||||
end
|
||||
|
||||
return m
|
||||
end
|
||||
|
||||
local function norm(a) -- vector norm
|
||||
local sum = 0
|
||||
for _, v in ipairs(a) do sum = sum + v * v end
|
||||
return sqrt(sum)
|
||||
end
|
||||
|
||||
local function householder(x)
|
||||
local rows = x.shape[1]
|
||||
local cols = x.shape[2]
|
||||
local iters = min(rows - 1, cols)
|
||||
|
||||
local q = nil
|
||||
local vec = zeros(rows)
|
||||
local z = x
|
||||
|
||||
for k = 1, iters do
|
||||
z = minor(z, k - 1)
|
||||
|
||||
-- extract a column.
|
||||
for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end
|
||||
|
||||
local a = norm(vec)
|
||||
-- negate the norm if the original diagonal is non-negative.
|
||||
local ind = (k - 1) * cols + k
|
||||
if x[ind] > 0 then a = -a end
|
||||
|
||||
vec[k] = vec[k] + a
|
||||
|
||||
local a = norm(vec)
|
||||
if a == 0 then a = 1 end -- FIXME: should probably just raise an error.
|
||||
for i, v in ipairs(vec) do vec[i] = v / a end
|
||||
|
||||
-- construct the householder reflection: mat = I - 2 * vec * vec.T
|
||||
local mat = zeros{rows, rows}
|
||||
for i = 1, rows do
|
||||
for j = 1, rows do
|
||||
local ind = (i - 1) * rows + j
|
||||
local diag = i == j and 1 or 0
|
||||
mat[ind] = diag - 2 * vec[i] * vec[j]
|
||||
end
|
||||
end
|
||||
|
||||
--print(nn.pp(mat, "%9.3f"))
|
||||
if q == nil then q = mat else q = dot(mat, q) end
|
||||
|
||||
z = dot(mat, z)
|
||||
end
|
||||
|
||||
return transpose(q), dot(q, x) -- Q, R
|
||||
end
|
||||
|
||||
local function qr(x)
|
||||
-- a wrapper for the householder method that will return reduced matrices.
|
||||
assert(#x.shape == 2)
|
||||
|
||||
local q, r = householder(x)
|
||||
|
||||
local rows = x.shape[1]
|
||||
local cols = x.shape[2]
|
||||
if cols >= rows then return q, r end
|
||||
|
||||
-- trim q in-place.
|
||||
q.shape[2] = cols
|
||||
local ind = 1
|
||||
for i = 1, rows do
|
||||
for j = 1, cols do
|
||||
--ind = (i - 1) * cols + j
|
||||
q[ind] = q[(i - 1) * rows + j]
|
||||
ind = ind + 1
|
||||
end
|
||||
end
|
||||
for i = rows * cols + 1, #q do q[i] = nil end
|
||||
|
||||
-- trim r in-place.
|
||||
r.shape[1] = r.shape[2]
|
||||
for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end
|
||||
|
||||
return q, r
|
||||
end
|
||||
|
||||
return qr
|
76
qr2.lua
Normal file
76
qr2.lua
Normal file
|
@ -0,0 +1,76 @@
|
|||
local min = math.min
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require "nn"
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function qr(a)
|
||||
-- FIXME: if first column is exactly zero,
|
||||
-- and cols > rows, Q @ R will not reconstruct the input.
|
||||
-- this isn't too bad since an input like that is invalid anyway,
|
||||
-- but i feel like it should be salvageable.
|
||||
|
||||
-- actually the scope of the problem is much larger than that.
|
||||
-- an input like
|
||||
--[=[
|
||||
[[0, 0, 0, 0]
|
||||
[1, 0, 1, 1]
|
||||
[0, 0, 2, 2]
|
||||
[0, 0, 3, 3]]
|
||||
--]=]
|
||||
-- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4).
|
||||
-- hmm. maybe we can detect this and reverse the matmul to identity if necessary?
|
||||
|
||||
assert(#a.shape == 2)
|
||||
local rows = a.shape[1]
|
||||
local cols = a.shape[2]
|
||||
local small = min(rows, cols)
|
||||
|
||||
local q = transpose(a)
|
||||
local r = zeros{small, cols}
|
||||
|
||||
for i = 1, cols do
|
||||
local i0 = (i - 1) * rows + 1
|
||||
local i1 = i * rows
|
||||
|
||||
for j = 1, min(i - 1, small) do
|
||||
local j0 = (j - 1) * rows + 1
|
||||
local j1 = j * rows
|
||||
local i_to_j = j0 - i0
|
||||
|
||||
local num = 0
|
||||
local den = 0
|
||||
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
|
||||
for k = j0, j1 do den = den + q[k] * q[k] end
|
||||
print(num, den)
|
||||
if den == 0 then den = 1 end -- TODO: should probably just error.
|
||||
|
||||
local x = num / den
|
||||
r[(j - 1) * cols + i] = x
|
||||
for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
|
||||
end
|
||||
|
||||
if i <= small then
|
||||
local sum = 0
|
||||
for k = i0, i1 do sum = sum + q[k] * q[k] end
|
||||
local norm = sqrt(sum)
|
||||
|
||||
if norm == 0 then
|
||||
--norm = 1
|
||||
--q[i0 + i - 1] = 1 -- FIXME: not robust.
|
||||
r[(i - 1) * cols + i] = 0
|
||||
else
|
||||
for k = i0, i1 do q[k] = q[k] / norm end
|
||||
r[(i - 1) * cols + i] = norm
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
for k = small * rows + 1, #q do q[k] = nil end
|
||||
q.shape[1] = small
|
||||
|
||||
return transpose(q), r
|
||||
end
|
||||
|
||||
return qr
|
72
qr_test.lua
Normal file
72
qr_test.lua
Normal file
|
@ -0,0 +1,72 @@
|
|||
local globalize = require "strict"
|
||||
local nn = require "nn"
|
||||
local qr = require "qr"
|
||||
local qr2 = require "qr2"
|
||||
|
||||
local A
|
||||
|
||||
if false then
|
||||
A = {
|
||||
12, -51, 4,
|
||||
6, 167, -68,
|
||||
-4, 24, -41,
|
||||
|
||||
-1, 1, 0,
|
||||
2, 0, 3,
|
||||
}
|
||||
A = nn.reshape(A, 5, 3)
|
||||
|
||||
elseif false then
|
||||
A = {
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14
|
||||
}
|
||||
A = nn.reshape(A, 5, 3)
|
||||
|
||||
else
|
||||
A = {
|
||||
1, 2, 0,
|
||||
2, 3, 1,
|
||||
3, 4, 0,
|
||||
4, 5, 1,
|
||||
5, 6, 0,
|
||||
}
|
||||
|
||||
A = {
|
||||
1, 0, 0,
|
||||
2, 0, 1,
|
||||
3, 0, 0,
|
||||
4, 0, 1,
|
||||
5, 0, 0,
|
||||
}
|
||||
|
||||
A = nn.reshape(A, 5, 3)
|
||||
--A = nn.transpose(A)
|
||||
|
||||
end
|
||||
|
||||
print("A")
|
||||
print(nn.pp(A, "%9.4f"))
|
||||
print()
|
||||
|
||||
local Q, R = qr2(A)
|
||||
|
||||
print("Q")
|
||||
print(nn.pp(Q, "%9.4f"))
|
||||
print()
|
||||
|
||||
print("R")
|
||||
print(nn.pp(R, "%9.4f"))
|
||||
print()
|
||||
|
||||
print("Q @ R")
|
||||
print(nn.pp(nn.dot(Q, R), "%9.4f"))
|
||||
print()
|
||||
|
||||
--print("Q @ Q.T = I")
|
||||
--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f"))
|
||||
--print()
|
||||
|
||||
--A = nn.reshape(A, 5, 3)
|
||||
--Q, R = qr(A)
|
217
smb.lua
217
smb.lua
|
@ -1,12 +1,16 @@
|
|||
-- disassembly used for reference:
|
||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
|
||||
local band = bit.band
|
||||
local floor = math.floor
|
||||
local emu = emu
|
||||
local gui = gui
|
||||
|
||||
local util = require("util")
|
||||
|
||||
local band = bit.band
|
||||
local clamp = util.clamp
|
||||
local empty = util.empty
|
||||
local emu = emu
|
||||
local floor = math.floor
|
||||
local gui = gui
|
||||
local insert = table.insert
|
||||
|
||||
local R = memory.readbyteunsigned
|
||||
local W = memory.writebyte
|
||||
local function S(addr) return util.signbyte(R(addr)) end
|
||||
|
@ -76,13 +80,84 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
|||
-8, -38,
|
||||
}
|
||||
|
||||
local tile_embedding = {
|
||||
-- handmade trinary encoding.
|
||||
-- we have 57 valid tile types and 27 permutations to work with.
|
||||
[0x00] = { 0, 0, 0}, -- air
|
||||
[0x10] = { 1, -1, 1}, -- vertical pipe (top left) (enterable)
|
||||
[0x11] = { 1, -1, 1}, -- vertical pipe (top right) (enterable)
|
||||
[0x12] = { 0, -1, 1}, -- vertical pipe (top left)
|
||||
[0x13] = { 0, -1, 1}, -- vertical pipe (top right)
|
||||
[0x14] = { 0, -1, -1}, -- vertical pipe (left)
|
||||
[0x15] = { 0, -1, -1}, -- vertical pipe (right)
|
||||
[0x16] = {-1, -1, -1}, --
|
||||
[0x17] = {-1, -1, -1}, --
|
||||
[0x18] = {-1, -1, -1}, --
|
||||
[0x19] = {-1, -1, -1}, --
|
||||
[0x1A] = {-1, -1, -1}, --
|
||||
[0x1B] = {-1, -1, -1}, --
|
||||
[0x1C] = { 0, -1, 0}, -- horizontal pipe (top left)
|
||||
[0x1D] = { 0, -1, 0}, -- horizontal pipe (top)
|
||||
[0x1E] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (top)
|
||||
[0x1F] = { 0, -1, 0}, -- horizontal pipe (bottom left)
|
||||
[0x20] = { 0, -1, 0}, -- horizontal pipe (bottom)
|
||||
[0x21] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (bottom)
|
||||
[0x22] = { 0, 0, 0}, --
|
||||
[0x23] = { 0, 0, 0}, -- block being hit (either breakable or ?)
|
||||
[0x24] = { 0, 0, 0}, --
|
||||
[0x25] = { 0, 0, 1}, -- flagpole
|
||||
[0x26] = { 0, 0, 0}, --
|
||||
[0x51] = { 1, 1, 0}, -- breakable brick block
|
||||
[0x52] = { 1, 1, 0}, -- breakable brick block (again?)
|
||||
[0x54] = { 0, 1, 0}, -- regular ground
|
||||
[0x55] = { 0, 0, 0}, --
|
||||
[0x56] = { 0, 0, 0}, --
|
||||
[0x57] = {-1, 1, 1}, -- star brick block
|
||||
[0x58] = { 1, 1, -1}, -- coin brick block (many coins)
|
||||
[0x59] = { 0, 0, 0}, --
|
||||
[0x5A] = { 0, 0, 0}, --
|
||||
[0x5B] = { 0, 0, 0}, --
|
||||
[0x5C] = { 0, 0, 0}, --
|
||||
[0x5D] = { 1, 1, -1}, -- coin brick block (many coins) (again?)
|
||||
[0x5E] = { 0, 0, 0}, --
|
||||
[0x5F] = { 0, 0, 0}, --
|
||||
[0x60] = {-1, 0, 0}, -- invisible 1-up block
|
||||
[0x61] = { 0, 1, -1}, -- chocolate block (usually used for stairs)
|
||||
[0x62] = { 0, 0, 0}, --
|
||||
[0x63] = { 0, 0, 0}, --
|
||||
[0x64] = { 0, 0, 0}, --
|
||||
[0x65] = { 0, 0, 0}, --
|
||||
[0x66] = { 0, 0, 0}, --
|
||||
[0x67] = { 0, 0, 0}, --
|
||||
[0x68] = { 0, 0, 0}, --
|
||||
[0x69] = { 0, 0, 0}, --
|
||||
[0x6B] = { 0, 0, 0}, --
|
||||
[0x6C] = { 0, 0, 0}, --
|
||||
[0x89] = { 0, 0, 0}, --
|
||||
[0xC0] = {-1, 1, -1}, -- coin ? block
|
||||
[0xC1] = {-1, 1, 0}, -- mushroom ? block
|
||||
[0xC2] = { 0, 0, -1}, -- coin
|
||||
[0xC3] = { 0, 0, 0}, --
|
||||
[0xC4] = { 0, 1, 1}, -- hit block
|
||||
[0xC5] = { 0, 0, 0}, --
|
||||
}
|
||||
|
||||
-- TODO: reinterface to one "input" array visible to main.lua.
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
local extra_input = {}
|
||||
local new_input = {}
|
||||
|
||||
local overlay = false
|
||||
|
||||
local function embed_tile(t)
|
||||
local out = new_input
|
||||
local embedded = tile_embedding[t]
|
||||
insert(out, embedded[1])
|
||||
insert(out, embedded[2])
|
||||
insert(out, embedded[3])
|
||||
end
|
||||
|
||||
local function get_timer()
|
||||
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
|
||||
end
|
||||
|
@ -130,6 +205,7 @@ end
|
|||
|
||||
local function mark_tile(x, y, t)
|
||||
tile_input[#tile_input+1] = tile_lut[t]
|
||||
embed_tile(t)
|
||||
if t == 0 then return end
|
||||
if overlay then
|
||||
gui.box(x-8, y-8, x+8, y+8)
|
||||
|
@ -289,7 +365,8 @@ local function handle_tiles()
|
|||
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||
-- for y = 0, 12 do
|
||||
-- afaik the bottom row is always a copy of the second to bottom,
|
||||
-- and the top is always air, so drop those from the inputs:
|
||||
-- and the top is always air (except underground!),
|
||||
-- so drop those from the inputs:
|
||||
for y = 1, 11 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
|
@ -306,6 +383,117 @@ local function handle_tiles()
|
|||
end
|
||||
end
|
||||
|
||||
local function load_level(world, level)
|
||||
if world == 0 then world = random(1, 8) end
|
||||
if level == 0 then level = random(1, 4) end
|
||||
emu.poweron()
|
||||
local jp_mash = {
|
||||
up = false, down = false, left = false, right = false,
|
||||
A = false, B = false, select = false, start = false,
|
||||
}
|
||||
while emu.framecount() < 60 do
|
||||
if emu.framecount() == 32 then
|
||||
local area = area_lut[world * 10 + level]
|
||||
W(0x75F, world - 1)
|
||||
W(0x75C, level - 1)
|
||||
W(0x760, area)
|
||||
end
|
||||
if emu.framecount() == 42 then
|
||||
W(0x7A0, 0) -- world screen timer (reduces startup time)
|
||||
end
|
||||
jp_mash['start'] = emu.framecount() % 2 == 1
|
||||
joypad.write(1, jp_mash)
|
||||
emu.frameadvance()
|
||||
end
|
||||
end
|
||||
|
||||
-- new stuff.
|
||||
|
||||
local function tile_from_xy(x, y)
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
|
||||
local tile_x = floor((x + tile_scroll_remainder) / 16)
|
||||
local tile_y = floor(y / 16)
|
||||
local tile_t = 0 -- default to air
|
||||
if tile_y < 0 then
|
||||
tile_y = 0
|
||||
elseif tile_y > 12 then
|
||||
tile_y = 12
|
||||
else
|
||||
local col = (tile_x + tile_scroll) % 32
|
||||
local addr = col < 16 and 0x500 or 0x5D0
|
||||
tile_t = R(addr + tile_y * 16 + (col % 16))
|
||||
end
|
||||
return tile_t, tile_x, tile_y
|
||||
end
|
||||
|
||||
local function new_stuff()
|
||||
-- obviously very work in progress.
|
||||
|
||||
empty(new_input)
|
||||
|
||||
local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
|
||||
-- normalized mario x
|
||||
-- normalized mario y
|
||||
insert(new_input, (mario_x - 112) / 64)
|
||||
insert(new_input, (mario_y - 160) / 96)
|
||||
|
||||
-- type of tile we're standing on
|
||||
-- type of tile we're occupying
|
||||
|
||||
gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)
|
||||
|
||||
local mario_tile_t, mario_tile_x, mario_tile_y =
|
||||
tile_from_xy(mario_x + 8, mario_y - 8)
|
||||
|
||||
--[[
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
|
||||
local sy = mario_tile_y * 16 + 40
|
||||
gui.box(sx-8, sy-8, sx+8, sy+8)
|
||||
gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
|
||||
--]]
|
||||
|
||||
embed_tile(new_input, mario_tile_t)
|
||||
|
||||
-- type of tile to the right, excluding space (small eyeheight)
|
||||
-- how tall non-space extends upward from that tile
|
||||
-- how tall non-space extends downward from that tile
|
||||
-- type of tile to the right, excluding space (large eyeheight)
|
||||
|
||||
-- type of tile to the left, excluding space (small eyeheight)
|
||||
-- how tall non-space extends upward from that tile
|
||||
-- how tall non-space extends downward from that tile
|
||||
-- type of tile to the left, excluding space (large eyeheight)
|
||||
|
||||
-- type of enemy (nearest down-right from mario's upper left)
|
||||
-- normalized enemy x
|
||||
-- normalized enemy y
|
||||
|
||||
-- VISUALIZE:
|
||||
|
||||
--local tile_col = R(0x6A0)
|
||||
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||
local tile_scroll_remainder = R(0x73F) % 16
|
||||
for y = 1, 11 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
local addr = col < 16 and 0x500 or 0x5D0
|
||||
local t = R(addr + y * 16 + (col % 16))
|
||||
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||
local sy = y * 16 + 40
|
||||
|
||||
if t ~= 0 then
|
||||
gui.box(sx-8, sy-8, sx+8, sy+8)
|
||||
gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
-- TODO: don't expose these; provide interfaces for everything needed.
|
||||
R=R,
|
||||
|
@ -315,6 +503,7 @@ overlay=overlay,
|
|||
|
||||
valid_tiles=valid_tiles,
|
||||
area_lut=area_lut,
|
||||
embed_tile=embed_tile,
|
||||
|
||||
sprite_input=sprite_input,
|
||||
tile_input=tile_input,
|
||||
|
@ -323,16 +512,24 @@ extra_input=extra_input,
|
|||
get_timer=get_timer,
|
||||
get_score=get_score,
|
||||
set_timer=set_timer,
|
||||
mark_sprite=mark_sprite,
|
||||
mark_tile=mark_tile,
|
||||
get_state=get_state,
|
||||
|
||||
getxy=getxy,
|
||||
paused=paused,
|
||||
get_state=get_state,
|
||||
advance=advance,
|
||||
|
||||
mark_sprite=mark_sprite,
|
||||
mark_tile=mark_tile,
|
||||
|
||||
handle_enemies=handle_enemies,
|
||||
handle_fireballs=handle_fireballs,
|
||||
handle_blocks=handle_blocks,
|
||||
handle_hammers=handle_hammers,
|
||||
handle_misc=handle_misc,
|
||||
handle_tiles=handle_tiles,
|
||||
|
||||
advance=advance,
|
||||
load_level=load_level,
|
||||
|
||||
new_stuff=new_stuff,
|
||||
new_input=new_input,
|
||||
}
|
||||
|
|
136
snes.lua
136
snes.lua
|
@ -29,9 +29,12 @@ local normalize_sums = util.normalize_sums
|
|||
local pdf = util.pdf
|
||||
local weighted_mann_whitney = util.weighted_mann_whitney
|
||||
|
||||
local xnes = require "xnes"
|
||||
local make_utility = xnes.make_utility
|
||||
|
||||
local Snes = Base:extend()
|
||||
|
||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
|
@ -41,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
|||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
self.adaptive = adaptive == nil and true or adaptive
|
||||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.utility = make_utility(self.popsize)
|
||||
|
||||
self.rate_init = self.sigma_rate
|
||||
|
||||
self.mean = zeros{dims}
|
||||
|
@ -147,60 +153,84 @@ function Snes:ask_mix(start_anew)
|
|||
|
||||
-- perform importance mixing.
|
||||
|
||||
local mean_old = self.mean
|
||||
local mean_old = self.mean_old or self.mean
|
||||
local mean_new = self.mean
|
||||
local std_old = self.std_old or self.std
|
||||
local std_new = self.std
|
||||
|
||||
self.new_asked = {}
|
||||
self.new_noise = {}
|
||||
|
||||
local marked = {}
|
||||
for p=1, min(#self.old_asked, self.popsize) do
|
||||
local a = self.old_asked[p]
|
||||
|
||||
-- TODO: cache probs?
|
||||
local function compute_probabilities(a)
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
return prob_new, prob_old
|
||||
end
|
||||
|
||||
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
||||
if uniform() < accept then
|
||||
--print(("accepted old sample %i with probability %f"):format(p, accept))
|
||||
else
|
||||
-- insert in reverse as not to screw up
|
||||
-- the indices when removing later.
|
||||
insert(marked, 1, p)
|
||||
local all_asked, all_noise, all_score = {}, {}, {}
|
||||
|
||||
for p=1, #self.old_asked do
|
||||
do
|
||||
local pp = floor(uniform() * #self.old_asked) + 1
|
||||
local a = self.old_asked[pp]
|
||||
local prob_new, prob_old = compute_probabilities(a)
|
||||
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
||||
if uniform() < accept then
|
||||
--print(("accepted old sample %i with probability %f"):format(pp, accept))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, self.old_noise[pp])
|
||||
insert(all_score, self.old_score[pp])
|
||||
end
|
||||
end
|
||||
end
|
||||
for _, p in ipairs(marked) do
|
||||
remove(self.old_asked, p)
|
||||
remove(self.old_noise, p)
|
||||
remove(self.old_score, p)
|
||||
|
||||
do
|
||||
local a, n = {}, {}
|
||||
for i=1, self.dims do n[i] = normal() end
|
||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||
local prob_new, prob_old = compute_probabilities(a)
|
||||
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
||||
if uniform() < accept then
|
||||
--print(("accepted new sample %i with probability %f"):format(#all_asked, accept))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, n)
|
||||
insert(all_score, false)
|
||||
end
|
||||
end
|
||||
|
||||
-- TODO: early stopping, making sure it doesn't affect performance.
|
||||
end
|
||||
|
||||
while #self.old_asked + #self.new_asked < self.popsize do
|
||||
local a = {}
|
||||
local n = {}
|
||||
while #all_asked > self.popsize do
|
||||
local pp = floor(uniform() * #all_asked) + 1
|
||||
--print(("removing sample %i to fit popsize"):format(pp))
|
||||
remove(all_asked, pp)
|
||||
remove(all_noise, pp)
|
||||
remove(all_score, pp)
|
||||
end
|
||||
|
||||
while #all_asked < self.popsize do
|
||||
local a, n = {}, {}
|
||||
for i=1, self.dims do n[i] = normal() end
|
||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||
--print(("unconditionally added new sample %i"):format(#all_asked))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, n)
|
||||
insert(all_score, false)
|
||||
end
|
||||
|
||||
-- can't cache here!
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
|
||||
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
||||
if uniform() < accept then
|
||||
-- split all_ tables back into old_ and new_.
|
||||
self.old_asked, self.old_noise, self.old_score = {}, {}, {}
|
||||
self.new_asked, self.new_noise = {}, {}
|
||||
for i, score in ipairs(all_score) do
|
||||
local a, n = all_asked[i], all_noise[i]
|
||||
if score ~= false then
|
||||
insert(self.old_asked, a)
|
||||
insert(self.old_noise, n)
|
||||
insert(self.old_score, score)
|
||||
else
|
||||
insert(self.new_asked, a)
|
||||
insert(self.new_noise, n)
|
||||
--print(("accepted new sample %i with probability %f"):format(0, accept))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -210,15 +240,15 @@ end
|
|||
function Snes:tell(scored)
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local asked = self.asked
|
||||
local noise = self.noise
|
||||
local asked = self.mixing and self.new_asked or self.asked
|
||||
local noise = self.mixing and self.new_noise or self.noise
|
||||
if self.mixing then
|
||||
-- note: modifies, in-place, externally exposed tables.
|
||||
for i, v in ipairs(asked) do insert(self.old_asked, v) end
|
||||
for i, v in ipairs(noise) do insert(self.old_noise, v) end
|
||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||
asked = self.old_asked
|
||||
noise = self.old_noise
|
||||
-- note that these modify tables referenced externally in-place.
|
||||
for i, v in ipairs(self.new_asked) do insert(asked, v) end
|
||||
for i, v in ipairs(self.new_noise) do insert(noise, v) end
|
||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||
scored = self.old_score
|
||||
end
|
||||
assert(asked and noise, ":tell() called before :ask()")
|
||||
|
@ -231,8 +261,9 @@ function Snes:tell(scored)
|
|||
local g_mean = zeros{self.dims}
|
||||
local g_std = zeros{self.dims}
|
||||
|
||||
--[[
|
||||
local utilize = true
|
||||
local utility
|
||||
local utility = self.utility
|
||||
|
||||
if utilize then
|
||||
utility = {}
|
||||
|
@ -242,17 +273,18 @@ function Snes:tell(scored)
|
|||
else
|
||||
utility = normalize_sums(scored, {})
|
||||
end
|
||||
--]]
|
||||
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[p]
|
||||
local noise_p = noise[arg[p]]
|
||||
|
||||
for i, v in ipairs(g_mean) do
|
||||
g_mean[i] = v + utility[p] * noise_p[i]
|
||||
g_mean[i] = v + self.utility[p] * noise_p[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(g_std) do
|
||||
local n = noise_p[i]
|
||||
g_std[i] = v + utility[p] * (n * n - 1)
|
||||
g_std[i] = v + self.utility[p] * (n * n - 1)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -261,7 +293,9 @@ function Snes:tell(scored)
|
|||
step[i] = self.std[i] * v
|
||||
end
|
||||
|
||||
self.mean_old = {}
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean_old[i] = v
|
||||
self.mean[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
|
||||
|
@ -273,7 +307,7 @@ function Snes:tell(scored)
|
|||
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
|
||||
end
|
||||
|
||||
self:adapt(asked, otherwise, utility)
|
||||
if self.adaptive then self:adapt(asked, otherwise, self.utility) end
|
||||
|
||||
return step
|
||||
end
|
||||
|
@ -291,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities)
|
|||
weights[p] = prob_big / prob_now
|
||||
end
|
||||
|
||||
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||
--print("p:", p)
|
||||
local u, p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||
--print(("u, p: %6.3f, %6.3f"):format(u, p))
|
||||
|
||||
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
||||
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
|
||||
print("learning rate -:", self.sigma_rate)
|
||||
--print("learning rate -:", self.sigma_rate)
|
||||
else
|
||||
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
|
||||
print("learning rate +:", self.sigma_rate)
|
||||
--print("learning rate +:", self.sigma_rate)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
51
util.lua
51
util.lua
|
@ -14,6 +14,7 @@ local random = math.random
|
|||
local select = select
|
||||
local sort = table.sort
|
||||
local sqrt = math.sqrt
|
||||
local type = type
|
||||
|
||||
local function sign(x)
|
||||
-- remember that 0 is truthy in Lua.
|
||||
|
@ -83,6 +84,28 @@ local function calc_mean_dev(x)
|
|||
return mean, sqrt(dev)
|
||||
end
|
||||
|
||||
local function calc_mean_dev_unbiased(x)
|
||||
-- NOTE: this uses an approximation; there is still a little bias.
|
||||
assert(#x > 1)
|
||||
|
||||
local mean = 0
|
||||
for i, v in ipairs(x) do
|
||||
mean = mean + v / #x
|
||||
end
|
||||
|
||||
-- via Gurland, John; Tripathi, Ram C. (1971):
|
||||
-- A Simple Approximation for Unbiased Estimation of the Standard Deviation
|
||||
local divisor = #x - 1.5 + 1 / (8 * (#x - 1))
|
||||
|
||||
local dev = 0
|
||||
for i, v in ipairs(x) do
|
||||
local delta = v - mean
|
||||
dev = dev + delta * delta / divisor
|
||||
end
|
||||
|
||||
return mean, sqrt(dev)
|
||||
end
|
||||
|
||||
local function normalize(x, out)
|
||||
out = out or x
|
||||
local mean, dev = calc_mean_dev(x)
|
||||
|
@ -246,7 +269,31 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
|
|||
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
|
||||
local p = cdf((U - mean) / std)
|
||||
|
||||
if s0_sum > s1_sum then return 1 - p else return p end
|
||||
local u = U / (w0_sum * w1_sum)
|
||||
if s0_sum > s1_sum then return u, 1 - p else return u, p end
|
||||
end
|
||||
|
||||
local function expect_cossim(n)
|
||||
-- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
|
||||
-- this is the expected absolute cosine similarity between
|
||||
-- two standard normally-distributed random vectors both of size n.
|
||||
assert(n > 0)
|
||||
|
||||
-- abs(error) < 1e-8
|
||||
if n >= 128000 then
|
||||
return 1 / sqrt(pi / 2 * n + 1)
|
||||
elseif n >= 80 then
|
||||
poly = (2.4674010 * n + -2.4673232) * n + 1.2274046
|
||||
return 1 / sqrt(sqrt(poly))
|
||||
end
|
||||
-- fall-through when it's faster just to compute iteratively.
|
||||
|
||||
even = n % 2 == 0
|
||||
res = even and 2.0 or 1.0
|
||||
for i = even and 2 or 1, n - 1, 2 do
|
||||
res = res * (i / (i + 1))
|
||||
end
|
||||
return even and res / pi or res
|
||||
end
|
||||
|
||||
return {
|
||||
|
@ -260,6 +307,7 @@ return {
|
|||
softchoice=softchoice,
|
||||
empty=empty,
|
||||
calc_mean_dev=calc_mean_dev,
|
||||
calc_mean_dev_unbiased=calc_mean_dev_unbiased,
|
||||
normalize=normalize,
|
||||
normalize_wrt=normalize_wrt,
|
||||
normalize_sums=normalize_sums,
|
||||
|
@ -276,4 +324,5 @@ return {
|
|||
pdf=pdf,
|
||||
cdf=cdf,
|
||||
weighted_mann_whitney=weighted_mann_whitney,
|
||||
expect_cossim=expect_cossim,
|
||||
}
|
||||
|
|
8
xnes.lua
8
xnes.lua
|
@ -70,6 +70,8 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
|||
-- note: this is technically the co-standard-deviation.
|
||||
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
||||
function Xnes:params(new_mean)
|
||||
|
@ -153,6 +155,8 @@ function Xnes:tell(scored, noise)
|
|||
local noise = noise or self.noise
|
||||
assert(noise, "missing noise argument")
|
||||
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local arg = argsort(scored, function(a, b) return a > b end)
|
||||
|
||||
local g_delta = zeros{self.dims}
|
||||
|
@ -173,7 +177,7 @@ function Xnes:tell(scored, noise)
|
|||
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
|
||||
local temp = self.utility[p] * zzt
|
||||
g_covars[ind] = g_covars[ind] + temp
|
||||
traced = traced + temp
|
||||
if i == j then traced = traced + temp end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -181,7 +185,7 @@ function Xnes:tell(scored, noise)
|
|||
local g_sigma = traced / self.dims
|
||||
|
||||
for i=1, self.dims do
|
||||
local ind = (i - 1) * self.dims + i
|
||||
local ind = (i - 1) * self.dims + i -- diagonal
|
||||
g_covars[ind] = g_covars[ind] - g_sigma
|
||||
end
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue