This commit is contained in:
Connor Olding 2019-02-26 21:41:37 +01:00
parent a1429a6271
commit 08f476e6ac
14 changed files with 1228 additions and 168 deletions

13
ars.lua
View File

@ -10,8 +10,10 @@ local exp = math.exp
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local sqrt = math.sqrt
local Base = require "Base"
@ -72,16 +74,15 @@ local function kinda_lipschitz(dir, pos, neg, mid)
end
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
momentum)
momentum, beta)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.param_rate = base_rate
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic == nil and true or antithetic
self.momentum = momentum or 0
self.beta = beta or 1.0
self.poptop = poptop or popsize
assert(self.poptop <= popsize)
@ -189,8 +190,9 @@ function Ars:tell(scored, unperturbed_score)
reward = reward / reward_dev
end
local scale = reward / self.poptop * self.beta / 2
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
step[j] = step[j] + scale * v
end
end
end
@ -200,8 +202,9 @@ function Ars:tell(scored, unperturbed_score)
if reward ~= 0 then
local noisy = self.noise[ind]
local scale = reward / self.poptop * self.beta
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
step[j] = step[j] + scale * v
end
end
end

View File

@ -26,15 +26,17 @@ local defaults = {
time_inputs = true, -- insert binary inputs of a frame counter.
-- network layers:
embed = true, -- set to false to use a hard-coded tile embedding.
reduce_tiles = 0, -- TODO: write description
hidden = false, -- use a hidden layer with ReLU/GELU activation.
hidden_size = 128,
layernorm = false, -- use a LayerNorm layer after said activation.
reduce_tiles = false,
bias_out = true,
-- network evaluation (sampling joypad):
frameskip = 4,
prob_frameskip = 0.0,
max_frameskip = 6,
-- true greedy epsilon has both deterministic and det_epsilon set.
deterministic = false, -- use argmax on outputs instead of random sampling.
det_epsilon = false, -- take random actions with probability eps.
@ -42,12 +44,16 @@ local defaults = {
-- evolution strategy and non-rate hyperparemeters:
es = 'ars',
ars_lips = false, -- for ARS.
epoch_top_trials = 9999, -- for ARS.
epoch_top_trials = 9999, -- for ARS, Guided.
alpha = 0.5, -- for Guided.
beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat.
past_grads = 1, -- for Guided. keeps a history of n past steps taken.
-- sampling:
deviation = 1.0,
unperturbed_trial = true, -- perform an extra trial without any noise.
-- this is good for logging, so i'd recommend it.
attempts = 1, -- TODO: document.
epoch_trials = 50,
graycode = false, -- for ARS.
negate_trials = true, -- try pairs of normal and negated noise directions.
@ -114,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials,
"cfg.negate_trials must be true to use cfg.ars_lips")
assert(not (cfg.es == 'snes' and cfg.negate_trials),
"cfg.negate_trials is not yet compatible with SNES")
assert(not (cfg.es == 'guided' and cfg.graycode),
"cfg.graycode is not compatible with Guided")
assert(cfg.es ~= 'guided' or cfg.negate_trials,
"cfg.negate_trials must be true to use Guided")
return cfg

View File

@ -1,21 +1,27 @@
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
local guided = require("guided")
-- try it all out on a dummy problem.
local function typeof(t) return getmetatable(t).__index end
local function square(x) return x * x end
-- this function's global minimum is arange(dims) + 1.
-- xNES should be able to find it almost exactly.
local function spherical(x)
local sum = 0
for i, v in ipairs(x) do sum = sum + square(v - i) end
--for i, v in ipairs(x) do sum = sum + square(v - i) end
for i, v in ipairs(x) do sum = sum + square(v - i / #x) end
-- we need to negate this to turn it into a maximization problem.
return -sum
end
@ -26,10 +32,34 @@ local dims = 100
local popsize = dims + 1
local sigma_init = 0.5
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
es.min_refresh = 0.7 -- FIXME: needs a better interface.
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
es.min_refresh = 0.5 -- FIXME: needs a better interface.
if typeof(es) == xnes.Xnes
or typeof(es) == snes.Snes
then
-- use IGO recommendations
local pop5 = max(1, floor(es.popsize / 5))
local sum = 0
for i=1, es.popsize do
local maybe = i < pop5 and 1 or 0
es.utility[i] = maybe
sum = sum + maybe
end
--for i, v in ipairs(es.utility) do es.utility[i] = v / sum end
local util = require "util"
util.normalize_sums(es.utility)
es.param_rate = 0.39
es.sigma_rate = 0.39
es.covar_rate = 0.39
es.adaptive = false
end
if false then -- TODO: delete me
local nn = require("nn")
@ -64,30 +94,44 @@ local asked = nil -- for caching purposes.
local noise = nil -- for caching purposes.
local current_cost = spherical(es:params())
local past_grads = {}
local pgi = 0
local pgn = 10
for i=1, iterations do
if getmetatable(es).__index == snes.Snes then
if typeof(es) == snes.Snes and es.min_refresh ~= 1 then
asked, noise = es:ask_mix()
elseif getmetatable(es).__index == ars.Ars then
elseif typeof(es) == ars.Ars then
asked, noise = es:ask()
elseif typeof(es) == guided.Guided then
asked, noise = es:ask(past_grads)
else
asked, noise = es:ask(asked, noise)
end
local scores = {}
for i, v in ipairs(asked) do
scores[i] = spherical(v)
end
if getmetatable(es).__index == ars.Ars then
if typeof(es) == ars.Ars then
es:tell(scores)--, current_cost) -- use lips
elseif typeof(es) == guided.Guided then
local step = es:tell(scores)
for _, v in ipairs(step) do
past_grads[pgi + 1] = v
pgi = (pgi + 1) % (pgn * #step)
end
past_grads.shape = {floor(#past_grads / #step), #step}
else
es:tell(scores)
end
current_cost = spherical(es:params())
--if i % 100 == 0 then
if i % 100 == 0 then
local sigma = es.sigma
if getmetatable(es).__index == snes.Snes then
if typeof(es) == snes.Snes then
sigma = 0
for i, v in ipairs(es.std) do sigma = sigma + v end
sigma = sigma / #es.std

194
guided.lua Normal file
View File

@ -0,0 +1,194 @@
-- Guided Evolutionary Strategies
-- https://arxiv.org/abs/1806.10230
-- this is just ARS extended to utilize gradients
-- approximated from previous iterations.
-- for simplicity:
-- antithetic is always true
-- momentum is always 0
-- no graycode/lipschitz nonsense
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local max = math.max
local print = print
local sqrt = math.sqrt
local Base = require "Base"
local nn = require "nn"
local dot_mv = nn.dot_mv
local transpose = nn.transpose
local normal = nn.normal
local prod = nn.prod
local uniform = nn.uniform
local zeros = nn.zeros
local qr = require "qr2"
local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local Guided = Base:extend()
local function collect_best_indices(scored, top)
-- select one (the best) reward of each pos/neg pair.
local best_rewards
best_rewards = {}
for i = 1, #scored / 2 do
local pos = scored[i * 2 - 1]
local neg = scored[i * 2 - 0]
best_rewards[i] = max(pos, neg)
end
local indices = argsort(best_rewards, function(a, b) return a > b end)
for i = top + 1, #best_rewards do indices[i] = nil end
return indices
end
function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta)
-- sigma: scale of random perturbations.
-- alpha: blend between full parameter space and its gradient subspace.
-- 1.0 is roughly equivalent to ARS.
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.param_rate = base_rate
self.sigma = sigma or 1.0
self.alpha = alpha or 0.5
self.beta = beta or 1.0
self.poptop = poptop or popsize
assert(self.poptop <= popsize)
self.popsize = self.popsize * 2 -- antithetic
self._params = zeros(self.dims)
--self.accum = zeros(self.dims) -- momentum
self.evals = 0
end
function Guided:params(new_params)
if new_params ~= nil then
assert(#self._params == #new_params, "new parameters have the wrong size")
for i, v in ipairs(new_params) do self._params[i] = v end
end
return self._params
end
function Guided:decay(param_decay, sigma_decay)
-- FIXME: multiplying by sigma probably isn't correct anymore.
-- is this correct now?
if param_decay > 0 then
local scale = self.sigma / sqrt(self.dims)
scale = scale * self.beta
scale = scale * self.param_rate / (self.sigma * self.sigma)
scale = 1 - param_decay * scale
for i, v in ipairs(self._params) do
self._params[i] = scale * v
end
end
end
function Guided:ask(grads)
local asked = {}
local noise = {}
local n_grad = 0
local gnoise, U, dummy, left, right
if grads ~= nil and #grads > 0 then
n_grad = grads.shape[1]
gnoise = zeros(n_grad)
U, dummy = qr(transpose(grads))
--print(nn.pp(transpose(U), "%9.4f"))
left = sqrt(self.alpha / self.dims)
right = sqrt((1 - self.alpha) / n_grad)
--print(left, right)
end
for i = 1, self.popsize do
local asking = zeros(self.dims)
local noisy = zeros(self.dims)
asked[i] = asking
noise[i] = noisy
if i % 2 == 0 then
local old_noisy = noise[i - 1]
for j, v in ipairs(old_noisy) do
noisy[j] = -v
end
elseif n_grad == 0 then
local scale = self.sigma / sqrt(self.dims)
for j = 1, self.dims do
noisy[j] = scale * normal()
end
else
for j = 1, self.dims do noisy[j] = normal() end
for j = 1, n_grad do gnoise[j] = normal() end
local noisier = dot_mv(U, gnoise)
for j, v in ipairs(noisy) do
noisy[j] = self.sigma * (left * v + right * noisier[j])
end
end
for j, v in ipairs(self._params) do
asking[j] = v + noisy[j]
end
end
self.noise = noise
return asked, noise
end
function Guided:tell(scored, unperturbed_score)
self.evals = self.evals + #scored
local indices = collect_best_indices(scored, self.poptop)
local top_rewards = {}
for _, ind in ipairs(indices) do
insert(top_rewards, scored[ind * 2 - 1])
insert(top_rewards, scored[ind * 2 - 0])
end
local step = zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
if reward_dev == 0 then reward_dev = 1 end
for i, ind in ipairs(indices) do
local pos = top_rewards[i * 2 - 1]
local neg = top_rewards[i * 2 - 0]
local reward = pos - neg
if reward ~= 0 then
local noisy = self.noise[ind * 2 - 1]
-- NOTE: technically this reward divide isn't part of guided search.
reward = reward / reward_dev
local scale = reward / self.poptop * self.beta / 2
for j, v in ipairs(noisy) do
step[j] = step[j] + scale * v
end
end
end
local coeff = self.param_rate / (self.sigma * self.sigma)
for i, v in ipairs(self._params) do
self._params[i] = v + coeff * step[i]
end
self.noise = nil
return step
end
return {
--collect_best_indices = collect_best_indices, -- ars.lua has more features
Guided = Guided,
}

166
main.lua
View File

@ -20,6 +20,12 @@ local trial_rewards = {}
local trials_remaining = 0
local es -- evolution strategy.
local attempt_i = 0
local sub_rewards = {}
local past_grads = {} -- for Guided
local pgi = 0 -- past_grads_index
local trial_frames = 0
local total_frames = 0
local lagless_count = 0
@ -93,6 +99,7 @@ local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
@ -149,13 +156,21 @@ local network
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
local function make_network(input_size)
nn_x = nn.Input({input_size})
nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
local embed_dim = cfg.embed and 2 or 3
if cfg.embed then
nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim))
else -- new tile inputs.
nn_tx = nn.Input({gcfg.tile_count * 3})
nn_ty = nn_tx
end
nn_tz = nn_ty
if cfg.reduce_tiles then
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
if cfg.reduce_tiles > 0 then
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim})
nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true))
nn_tz = nn_tz:feed(nn.Relu())
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
end
@ -185,6 +200,7 @@ end
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
local guided = require("guided")
local function prepare_epoch()
trial_neg = false
@ -213,6 +229,8 @@ local function prepare_epoch()
local dummy
if cfg.es == 'ars' then
trial_params, dummy = es:ask(precision)
elseif cfg.es == 'guided' then
trial_params, dummy = es:ask(past_grads)
elseif cfg.es == 'snes' then
trial_params, dummy = es:ask_mix()
else
@ -223,6 +241,7 @@ local function prepare_epoch()
end
local function load_next_trial()
attempt_i = 1
if cfg.negate_trials then
trial_neg = not trial_neg
else
@ -272,8 +291,16 @@ local function learn_from_epoch()
end
local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean)
print("step stddev:", step_dev)
print(("step mean: %9.6f"):format(step_mean))
print(("step stddev: %9.6f"):format(step_dev))
if cfg.es == 'guided' and cfg.past_grads > 0 then
for _, v in ipairs(step) do
past_grads[pgi + 1] = v
pgi = (pgi + 1) % (cfg.past_grads * #step)
end
past_grads.shape = {floor(#past_grads / #step), #step}
end
es:decay(cfg.param_decay, cfg.sigma_decay)
@ -329,65 +356,62 @@ local function joypad_mash(button)
joypad.write(1, jp_mash)
end
local function loadlevel(world, level)
-- TODO: move to smb.lua. rename to load_level.
if world == 0 then world = random(1, 8) end
if level == 0 then level = random(1, 4) end
emu.poweron()
while emu.framecount() < 60 do
if emu.framecount() == 32 then
local area = game.area_lut[world * 10 + level]
game.W(0x75F, world - 1)
game.W(0x75C, level - 1)
game.W(0x760, area)
end
if emu.framecount() == 42 then
game.W(0x7A0, 0) -- world screen timer (reduces startup time)
end
joypad_mash('start')
emu.frameadvance()
end
end
local function do_reset()
local state = game.get_state()
-- be a little more descriptive.
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
if trial_i >= 0 then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif cfg.negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
--if cfg.attempts > 1 and attempt_i >= cfg.attempts then
attempt_i = attempt_i + 1
sub_rewards[#sub_rewards + 1] = reward
--print(sub_rewards)
if trial_neg then
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
if #sub_rewards >= cfg.attempts then
if cfg.attempts == 1 then
reward = sub_rewards[1]
else
local sub_mean, sub_std = calc_mean_dev(sub_rewards)
reward = floor(sub_mean)
--local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards)
--reward = floor(sub_mean - sub_std)
end
empty(sub_rewards)
if trial_i >= 0 then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif cfg.negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
if trial_neg then
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
end
last_trial_state = state
else
print('trial', trial_i, 'reward:', reward, "("..state..")")
end
if trial_i == 0 or not cfg.negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
last_trial_state = state
else
print('trial', trial_i, 'reward:', reward, "("..state..")")
end
if trial_i == 0 or not cfg.negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
end
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
collectgarbage()
if any_random then
loadlevel(cfg.starting_world, cfg.starting_level)
state_saved = false
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
collectgarbage()
if any_random then
game.load_level(cfg.starting_world, cfg.starting_level)
state_saved = false
end
end
end
@ -427,7 +451,9 @@ local function do_reset()
trial_frames = 0
emu.frameadvance() -- prevents emulator from quirking up.
load_next_trial()
if attempt_i > cfg.attempts then
load_next_trial()
end
reset = false
end
@ -450,7 +476,7 @@ local function init()
if not playing then emu.speedmode("turbo") end
if not any_random then
loadlevel(cfg.starting_world, cfg.starting_level)
game.load_level(cfg.starting_world, cfg.starting_level)
end
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
@ -481,7 +507,10 @@ local function init()
elseif cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.base_rate, cfg.deviation, cfg.negate_trials,
cfg.momentum)
cfg.momentum, cfg.beta)
elseif cfg.es == 'guided' then
es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta)
else
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end
@ -537,6 +566,7 @@ local function doit(dummy)
empty(game.sprite_input)
empty(game.tile_input)
empty(game.extra_input)
empty(game.new_input)
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
@ -616,12 +646,18 @@ local function doit(dummy)
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
nn.reshape(X, 1, gcfg.input_size)
nn.reshape(game.tile_input, 1, gcfg.tile_count)
nn.reshape(game.new_input, 1, gcfg.tile_count * 3)
trial_frames = trial_frames + cfg.frameskip
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
total_frames = total_frames + cfg.frameskip
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
local outputs
if cfg.embed then
outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
else
outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input})
end
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
if cfg.det_epsilon and random() < eps then
@ -695,8 +731,12 @@ while true do
end
local delta = lagless_count - last_decision_frame
local doot = jp == nil or delta >= cfg.frameskip
doot = doot and random() >= cfg.prob_frameskip
local doot = true
if jp ~= nil then
doot = delta >= cfg.frameskip
doot = doot and random() >= cfg.prob_frameskip
doot = doot or delta >= cfg.max_frameskip
end
doit(not doot)
if doot then last_decision_frame = lagless_count end

46
nn.lua
View File

@ -1,20 +1,15 @@
local assert = assert
local ceil = math.ceil
local cos = math.cos
local exp = math.exp
local floor = math.floor
local huge = math.huge
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pi = math.pi
local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tanh = math.tanh
local tostring = tostring
@ -105,19 +100,26 @@ end
-- ndarray-ish stuff and more involved math
local function pp_join(sep, fmt, t, a, b)
a = a or 1
b = b or #t
local s = ''
for i = a, b do
s = s..fmt:format(t[i])
if i ~= b then s = s..sep end
end
return s
end
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
-- pretty-prints an nd-array.
fmt = fmt or '%10.7f,'
fmt = fmt or '%10.7f'
sep = sep or ','
ti = ti or 0
di = di or 1
depth = depth or 0
if t.shape == nil then
local s = '['
for i = 1, #t do s = s..fmt:format(t[i]) end
return s..']'..sep..'\n'
end
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
local dim = t.shape[di]
@ -134,11 +136,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
ti = ti + ti_step
end
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
s = s..indent..']'..sep
if islast then s = s..'\n' end
else
s = s..indent..'['
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
s = s..']'..sep..'\n'
s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n'
end
return s
end
@ -265,6 +266,20 @@ local function dot(a, b, ax_a, ax_b, out)
return out
end
local function transpose(x, out)
assert(#x.shape == 2) -- TODO: handle ndarrays like numpy.
local rows = x.shape[1]
local cols = x.shape[2]
local y = out or zeros{cols, rows}
-- TODO: simplify? y can be consecutive for sure.
for i = 1, rows do
for j = 1, cols do
y[(j - 1) * rows + i] = x[(i - 1) * cols + j]
end
end
return y
end
-- nodal
local function traverse(node_in, node_out, nodes, dummy_mode)
@ -875,6 +890,7 @@ return {
ppi = ppi,
dot_mv = dot_mv,
dot = dot,
transpose = transpose,
traverse = traverse,
traverse_all = traverse_all,

View File

@ -33,7 +33,7 @@ make_preset{
init_zeros = true,
reduce_tiles = true,
reduce_tiles = 5,
bias_out = false,
deterministic = false,
@ -147,6 +147,15 @@ make_preset{
parent = 'ars',
}
make_preset{
name = 'ars-lips',
parent = 'ars',
ars_lips = true,
-- momentum = 0.5, -- this is default.
param_rate = 1.0,
}
make_preset{
name = 'ars-skip',
parent = 'ars',
@ -155,15 +164,6 @@ make_preset{
prob_frameskip = 0.25,
}
make_preset{
name = 'ars-lips',
parent = 'ars',
ars_lips = true,
momentum = 0.5,
param_rate = 1.0,
}
make_preset{
name = 'ars-big',
parent = 'ars',
@ -204,6 +204,216 @@ make_preset{
momentum = 0.99,
}
-- new stuff for 2019:
make_preset{
name = 'ars-skip-more',
parent = 'ars',
-- old:
--frameskip = 2,
--prob_frameskip = 0.5,
--max_frameskip = 60,
-- new:
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
}
make_preset{
name = 'ars-skip-more-3',
parent = 'ars-skip-more',
attempts = 3, -- per trial. score = mean(scores) - stdev(scores)
}
make_preset{
name = 'snes-skip-more-3',
parent = 'snes3',
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
attempts = 3,
}
make_preset{
name = 'guided',
parent = 'big-scroll-reduced',
es = 'guided',
epoch_top_trials = 20,
deterministic = true,
deviation = 0.1,
epoch_trials = 20,
param_rate = 0.00368,
param_decay = 0.0,
}
make_preset{
name = 'guided2',
parent = 'guided',
past_grads = 2,
-- after epoch 50, trying this:
--param_rate = 0.05,
-- after epoch 50+20, stepping back to this:
param_rate = 0.01,
}
make_preset{
name = 'guided10',
parent = 'guided',
past_grads = 10,
}
make_preset{
name = 'guided69', -- the nice one
parent = 'guided',
deviation = 0.05,
epoch_top_trials = 10,
epoch_trials = 20,
param_rate = 0.006,
past_grads = 4,
alpha = 0.25,
}
-- TODO: yet another preset. try building up from 1 trial ARS to something good.
make_preset{
name = 'redux',
min_time = 300,
max_time = 300,
timer_loser = 1/1,
score_multiplier = 1,
init_zeros = true,
deterministic = true,
es = 'guided',
past_grads = 2, -- for Guided.
alpha = 0.25, -- for Guided.
ars_lips = false, -- for ARS.
beta = 1.0, -- fix the default.
epoch_top_trials = 4, -- for ARS, Guided.
epoch_trials = 5,
attempts = 1, -- TODO: document.
deviation = 1.0, -- 0.1
base_rate = 1.0,
param_decay = 0.01,
graycode = false, -- for ARS.
min_refresh = 0.1, -- for SNES.
sigma_decay = 0.0, -- for SNES, xNES.
momentum = 0.0, -- for ARS.
}
make_preset{
name = 'redux_big',
parent = 'redux',
time_inputs = true, -- insert binary inputs of a frame counter.
hidden = true, -- use a hidden layer with ReLU/GELU activation.
hidden_size = 64,
layernorm = true, -- use a LayerNorm layer after said activation.
reduce_tiles = false,
bias_out = false,
-- gets stuck pretty quick, so tweak some stuff...
epoch_top_trials = 8,
epoch_trials = 10,
deviation = 1.0,
base_rate = 0.15,
param_decay = 0.05,
past_grads = 4,
alpha = 0.25,
-- well it doesn't get stuck anymore, but regular redux works much better.
}
make_preset{
name = 'guided-skip-more-3',
parent = 'guided',
--param_rate = 0.00368, -- should probably be this instead...
param_rate = 0.01,
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
attempts = 3,
}
make_preset{
name = 'guided-skip-more-3-again',
parent = 'guided-skip-more-3',
param_rate = 0.08, --0.0316,
deviation = 0.5,
alpha = 0.1, --0.5,
}
make_preset{
name = 'crazy',
parent = 'big-scroll-reduced',
es = 'guided',
epoch_top_trials = 15,
deterministic = false,
deviation = 1.0,
epoch_trials = 15,
param_rate = 1.0,
param_decay = 0.0,
alpha = 0.0316,
--attempts = 3,
}
make_preset{
name = 'ars-lips2',
parent = 'ars',
ars_lips = true,
--epoch_trials = 10,
param_rate = 0.147,
}
make_preset{
name = 'ars-lips3',
parent = 'ars',
ars_lips = true,
param_rate = 0.5,
deviation = 0.02, -- added after like 272 epochs
param_decay = 0.0276, -- added after like 62 epochs
}
make_preset{
name = 'hard-embed',
parent = 'big-scroll-hidden',
embed = false,
reduce_tiles = 5,
hidden_size = 54,
epoch_top_trials = 20,
deterministic = true,
deviation = 0.01,
epoch_trials = 20,
param_rate = 0.368,
param_decay = 0.0138,
momentum = 0.5,
beta = 1.0,
}
-- end of new stuff
make_preset{
name = 'play',

111
qr.lua Normal file
View File

@ -0,0 +1,111 @@
local min = math.min
local sqrt = math.sqrt
local nn = require "nn"
local dot = nn.dot
local reshape = nn.reshape
local transpose = nn.transpose
local zeros = nn.zeros
local function minor(x, d)
assert(#x.shape == 2)
assert(d <= x.shape[1] and d <= x.shape[2])
local m = zeros(x.shape)
-- fill diagonals.
--for i = 1, d do m[(i - 1) * m.shape[2] + i] = 1 end
for i = 1, d * m.shape[2], m.shape[2] + 1 do m[i] = 1 end
-- copy values.
for i = d + 1, m.shape[1] do
for j = d + 1, m.shape[2] do
local ind = (i - 1) * m.shape[2] + j
m[ind] = x[ind]
end
end
return m
end
local function norm(a) -- vector norm
local sum = 0
for _, v in ipairs(a) do sum = sum + v * v end
return sqrt(sum)
end
local function householder(x)
local rows = x.shape[1]
local cols = x.shape[2]
local iters = min(rows - 1, cols)
local q = nil
local vec = zeros(rows)
local z = x
for k = 1, iters do
z = minor(z, k - 1)
-- extract a column.
for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end
local a = norm(vec)
-- negate the norm if the original diagonal is non-negative.
local ind = (k - 1) * cols + k
if x[ind] > 0 then a = -a end
vec[k] = vec[k] + a
local a = norm(vec)
if a == 0 then a = 1 end -- FIXME: should probably just raise an error.
for i, v in ipairs(vec) do vec[i] = v / a end
-- construct the householder reflection: mat = I - 2 * vec * vec.T
local mat = zeros{rows, rows}
for i = 1, rows do
for j = 1, rows do
local ind = (i - 1) * rows + j
local diag = i == j and 1 or 0
mat[ind] = diag - 2 * vec[i] * vec[j]
end
end
--print(nn.pp(mat, "%9.3f"))
if q == nil then q = mat else q = dot(mat, q) end
z = dot(mat, z)
end
return transpose(q), dot(q, x) -- Q, R
end
local function qr(x)
-- a wrapper for the householder method that will return reduced matrices.
assert(#x.shape == 2)
local q, r = householder(x)
local rows = x.shape[1]
local cols = x.shape[2]
if cols >= rows then return q, r end
-- trim q in-place.
q.shape[2] = cols
local ind = 1
for i = 1, rows do
for j = 1, cols do
--ind = (i - 1) * cols + j
q[ind] = q[(i - 1) * rows + j]
ind = ind + 1
end
end
for i = rows * cols + 1, #q do q[i] = nil end
-- trim r in-place.
r.shape[1] = r.shape[2]
for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end
return q, r
end
return qr

76
qr2.lua Normal file
View File

@ -0,0 +1,76 @@
local min = math.min
local sqrt = math.sqrt
local nn = require "nn"
local transpose = nn.transpose
local zeros = nn.zeros
local function qr(a)
-- FIXME: if first column is exactly zero,
-- and cols > rows, Q @ R will not reconstruct the input.
-- this isn't too bad since an input like that is invalid anyway,
-- but i feel like it should be salvageable.
-- actually the scope of the problem is much larger than that.
-- an input like
--[=[
[[0, 0, 0, 0]
[1, 0, 1, 1]
[0, 0, 2, 2]
[0, 0, 3, 3]]
--]=]
-- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4).
-- hmm. maybe we can detect this and reverse the matmul to identity if necessary?
assert(#a.shape == 2)
local rows = a.shape[1]
local cols = a.shape[2]
local small = min(rows, cols)
local q = transpose(a)
local r = zeros{small, cols}
for i = 1, cols do
local i0 = (i - 1) * rows + 1
local i1 = i * rows
for j = 1, min(i - 1, small) do
local j0 = (j - 1) * rows + 1
local j1 = j * rows
local i_to_j = j0 - i0
local num = 0
local den = 0
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
for k = j0, j1 do den = den + q[k] * q[k] end
print(num, den)
if den == 0 then den = 1 end -- TODO: should probably just error.
local x = num / den
r[(j - 1) * cols + i] = x
for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
end
if i <= small then
local sum = 0
for k = i0, i1 do sum = sum + q[k] * q[k] end
local norm = sqrt(sum)
if norm == 0 then
--norm = 1
--q[i0 + i - 1] = 1 -- FIXME: not robust.
r[(i - 1) * cols + i] = 0
else
for k = i0, i1 do q[k] = q[k] / norm end
r[(i - 1) * cols + i] = norm
end
end
end
for k = small * rows + 1, #q do q[k] = nil end
q.shape[1] = small
return transpose(q), r
end
return qr

72
qr_test.lua Normal file
View File

@ -0,0 +1,72 @@
local globalize = require "strict"
local nn = require "nn"
local qr = require "qr"
local qr2 = require "qr2"
local A
if false then
A = {
12, -51, 4,
6, 167, -68,
-4, 24, -41,
-1, 1, 0,
2, 0, 3,
}
A = nn.reshape(A, 5, 3)
elseif false then
A = {
0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14
}
A = nn.reshape(A, 5, 3)
else
A = {
1, 2, 0,
2, 3, 1,
3, 4, 0,
4, 5, 1,
5, 6, 0,
}
A = {
1, 0, 0,
2, 0, 1,
3, 0, 0,
4, 0, 1,
5, 0, 0,
}
A = nn.reshape(A, 5, 3)
--A = nn.transpose(A)
end
print("A")
print(nn.pp(A, "%9.4f"))
print()
local Q, R = qr2(A)
print("Q")
print(nn.pp(Q, "%9.4f"))
print()
print("R")
print(nn.pp(R, "%9.4f"))
print()
print("Q @ R")
print(nn.pp(nn.dot(Q, R), "%9.4f"))
print()
--print("Q @ Q.T = I")
--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f"))
--print()
--A = nn.reshape(A, 5, 3)
--Q, R = qr(A)

217
smb.lua
View File

@ -1,12 +1,16 @@
-- disassembly used for reference:
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
local band = bit.band
local floor = math.floor
local emu = emu
local gui = gui
local util = require("util")
local band = bit.band
local clamp = util.clamp
local empty = util.empty
local emu = emu
local floor = math.floor
local gui = gui
local insert = table.insert
local R = memory.readbyteunsigned
local W = memory.writebyte
local function S(addr) return util.signbyte(R(addr)) end
@ -76,13 +80,84 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
-8, -38,
}
local tile_embedding = {
-- handmade trinary encoding.
-- we have 57 valid tile types and 27 permutations to work with.
[0x00] = { 0, 0, 0}, -- air
[0x10] = { 1, -1, 1}, -- vertical pipe (top left) (enterable)
[0x11] = { 1, -1, 1}, -- vertical pipe (top right) (enterable)
[0x12] = { 0, -1, 1}, -- vertical pipe (top left)
[0x13] = { 0, -1, 1}, -- vertical pipe (top right)
[0x14] = { 0, -1, -1}, -- vertical pipe (left)
[0x15] = { 0, -1, -1}, -- vertical pipe (right)
[0x16] = {-1, -1, -1}, --
[0x17] = {-1, -1, -1}, --
[0x18] = {-1, -1, -1}, --
[0x19] = {-1, -1, -1}, --
[0x1A] = {-1, -1, -1}, --
[0x1B] = {-1, -1, -1}, --
[0x1C] = { 0, -1, 0}, -- horizontal pipe (top left)
[0x1D] = { 0, -1, 0}, -- horizontal pipe (top)
[0x1E] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (top)
[0x1F] = { 0, -1, 0}, -- horizontal pipe (bottom left)
[0x20] = { 0, -1, 0}, -- horizontal pipe (bottom)
[0x21] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (bottom)
[0x22] = { 0, 0, 0}, --
[0x23] = { 0, 0, 0}, -- block being hit (either breakable or ?)
[0x24] = { 0, 0, 0}, --
[0x25] = { 0, 0, 1}, -- flagpole
[0x26] = { 0, 0, 0}, --
[0x51] = { 1, 1, 0}, -- breakable brick block
[0x52] = { 1, 1, 0}, -- breakable brick block (again?)
[0x54] = { 0, 1, 0}, -- regular ground
[0x55] = { 0, 0, 0}, --
[0x56] = { 0, 0, 0}, --
[0x57] = {-1, 1, 1}, -- star brick block
[0x58] = { 1, 1, -1}, -- coin brick block (many coins)
[0x59] = { 0, 0, 0}, --
[0x5A] = { 0, 0, 0}, --
[0x5B] = { 0, 0, 0}, --
[0x5C] = { 0, 0, 0}, --
[0x5D] = { 1, 1, -1}, -- coin brick block (many coins) (again?)
[0x5E] = { 0, 0, 0}, --
[0x5F] = { 0, 0, 0}, --
[0x60] = {-1, 0, 0}, -- invisible 1-up block
[0x61] = { 0, 1, -1}, -- chocolate block (usually used for stairs)
[0x62] = { 0, 0, 0}, --
[0x63] = { 0, 0, 0}, --
[0x64] = { 0, 0, 0}, --
[0x65] = { 0, 0, 0}, --
[0x66] = { 0, 0, 0}, --
[0x67] = { 0, 0, 0}, --
[0x68] = { 0, 0, 0}, --
[0x69] = { 0, 0, 0}, --
[0x6B] = { 0, 0, 0}, --
[0x6C] = { 0, 0, 0}, --
[0x89] = { 0, 0, 0}, --
[0xC0] = {-1, 1, -1}, -- coin ? block
[0xC1] = {-1, 1, 0}, -- mushroom ? block
[0xC2] = { 0, 0, -1}, -- coin
[0xC3] = { 0, 0, 0}, --
[0xC4] = { 0, 1, 1}, -- hit block
[0xC5] = { 0, 0, 0}, --
}
-- TODO: reinterface to one "input" array visible to main.lua.
local sprite_input = {}
local tile_input = {}
local extra_input = {}
local new_input = {}
local overlay = false
local function embed_tile(t)
local out = new_input
local embedded = tile_embedding[t]
insert(out, embedded[1])
insert(out, embedded[2])
insert(out, embedded[3])
end
local function get_timer()
return R(0x7F8) * 100 + R(0x7F9) * 10 + R(0x7FA)
end
@ -130,6 +205,7 @@ end
local function mark_tile(x, y, t)
tile_input[#tile_input+1] = tile_lut[t]
embed_tile(t)
if t == 0 then return end
if overlay then
gui.box(x-8, y-8, x+8, y+8)
@ -289,7 +365,8 @@ local function handle_tiles()
extra_input[#extra_input+1] = tile_scroll_remainder
-- for y = 0, 12 do
-- afaik the bottom row is always a copy of the second to bottom,
-- and the top is always air, so drop those from the inputs:
-- and the top is always air (except underground!),
-- so drop those from the inputs:
for y = 1, 11 do
for x = 0, 16 do
local col = (x + tile_scroll) % 32
@ -306,6 +383,117 @@ local function handle_tiles()
end
end
local function load_level(world, level)
if world == 0 then world = random(1, 8) end
if level == 0 then level = random(1, 4) end
emu.poweron()
local jp_mash = {
up = false, down = false, left = false, right = false,
A = false, B = false, select = false, start = false,
}
while emu.framecount() < 60 do
if emu.framecount() == 32 then
local area = area_lut[world * 10 + level]
W(0x75F, world - 1)
W(0x75C, level - 1)
W(0x760, area)
end
if emu.framecount() == 42 then
W(0x7A0, 0) -- world screen timer (reduces startup time)
end
jp_mash['start'] = emu.framecount() % 2 == 1
joypad.write(1, jp_mash)
emu.frameadvance()
end
end
-- new stuff.
local function tile_from_xy(x, y)
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
local tile_scroll_remainder = R(0x73F) % 16
local tile_x = floor((x + tile_scroll_remainder) / 16)
local tile_y = floor(y / 16)
local tile_t = 0 -- default to air
if tile_y < 0 then
tile_y = 0
elseif tile_y > 12 then
tile_y = 12
else
local col = (tile_x + tile_scroll) % 32
local addr = col < 16 and 0x500 or 0x5D0
tile_t = R(addr + tile_y * 16 + (col % 16))
end
return tile_t, tile_x, tile_y
end
local function new_stuff()
-- obviously very work in progress.
empty(new_input)
local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
-- normalized mario x
-- normalized mario y
insert(new_input, (mario_x - 112) / 64)
insert(new_input, (mario_y - 160) / 96)
-- type of tile we're standing on
-- type of tile we're occupying
gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)
local mario_tile_t, mario_tile_x, mario_tile_y =
tile_from_xy(mario_x + 8, mario_y - 8)
--[[
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
local tile_scroll_remainder = R(0x73F) % 16
local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
local sy = mario_tile_y * 16 + 40
gui.box(sx-8, sy-8, sx+8, sy+8)
gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
--]]
embed_tile(new_input, mario_tile_t)
-- type of tile to the right, excluding space (small eyeheight)
-- how tall non-space extends upward from that tile
-- how tall non-space extends downward from that tile
-- type of tile to the right, excluding space (large eyeheight)
-- type of tile to the left, excluding space (small eyeheight)
-- how tall non-space extends upward from that tile
-- how tall non-space extends downward from that tile
-- type of tile to the left, excluding space (large eyeheight)
-- type of enemy (nearest down-right from mario's upper left)
-- normalized enemy x
-- normalized enemy y
-- VISUALIZE:
--local tile_col = R(0x6A0)
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
local tile_scroll_remainder = R(0x73F) % 16
for y = 1, 11 do
for x = 0, 16 do
local col = (x + tile_scroll) % 32
local addr = col < 16 and 0x500 or 0x5D0
local t = R(addr + y * 16 + (col % 16))
local sx = x * 16 + 8 - tile_scroll_remainder
local sy = y * 16 + 40
if t ~= 0 then
gui.box(sx-8, sy-8, sx+8, sy+8)
gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
end
end
end
end
return {
-- TODO: don't expose these; provide interfaces for everything needed.
R=R,
@ -315,6 +503,7 @@ overlay=overlay,
valid_tiles=valid_tiles,
area_lut=area_lut,
embed_tile=embed_tile,
sprite_input=sprite_input,
tile_input=tile_input,
@ -323,16 +512,24 @@ extra_input=extra_input,
get_timer=get_timer,
get_score=get_score,
set_timer=set_timer,
mark_sprite=mark_sprite,
mark_tile=mark_tile,
get_state=get_state,
getxy=getxy,
paused=paused,
get_state=get_state,
advance=advance,
mark_sprite=mark_sprite,
mark_tile=mark_tile,
handle_enemies=handle_enemies,
handle_fireballs=handle_fireballs,
handle_blocks=handle_blocks,
handle_hammers=handle_hammers,
handle_misc=handle_misc,
handle_tiles=handle_tiles,
advance=advance,
load_level=load_level,
new_stuff=new_stuff,
new_input=new_input,
}

136
snes.lua
View File

@ -29,9 +29,12 @@ local normalize_sums = util.normalize_sums
local pdf = util.pdf
local weighted_mann_whitney = util.weighted_mann_whitney
local xnes = require "xnes"
local make_utility = xnes.make_utility
local Snes = Base:extend()
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive)
-- heuristic borrowed from CMA-ES:
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
@ -41,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
self.adaptive = adaptive == nil and true or adaptive
if self.antithetic then self.popsize = self.popsize * 2 end
self.utility = make_utility(self.popsize)
self.rate_init = self.sigma_rate
self.mean = zeros{dims}
@ -147,60 +153,84 @@ function Snes:ask_mix(start_anew)
-- perform importance mixing.
local mean_old = self.mean
local mean_old = self.mean_old or self.mean
local mean_new = self.mean
local std_old = self.std_old or self.std
local std_new = self.std
self.new_asked = {}
self.new_noise = {}
local marked = {}
for p=1, min(#self.old_asked, self.popsize) do
local a = self.old_asked[p]
-- TODO: cache probs?
local function compute_probabilities(a)
local prob_new = 0
local prob_old = 0
for i, v in ipairs(a) do
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
end
return prob_new, prob_old
end
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
if uniform() < accept then
--print(("accepted old sample %i with probability %f"):format(p, accept))
else
-- insert in reverse as not to screw up
-- the indices when removing later.
insert(marked, 1, p)
local all_asked, all_noise, all_score = {}, {}, {}
for p=1, #self.old_asked do
do
local pp = floor(uniform() * #self.old_asked) + 1
local a = self.old_asked[pp]
local prob_new, prob_old = compute_probabilities(a)
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
if uniform() < accept then
--print(("accepted old sample %i with probability %f"):format(pp, accept))
insert(all_asked, a)
insert(all_noise, self.old_noise[pp])
insert(all_score, self.old_score[pp])
end
end
end
for _, p in ipairs(marked) do
remove(self.old_asked, p)
remove(self.old_noise, p)
remove(self.old_score, p)
do
local a, n = {}, {}
for i=1, self.dims do n[i] = normal() end
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
local prob_new, prob_old = compute_probabilities(a)
local accept = max(1 - prob_old / prob_new, self.min_refresh)
if uniform() < accept then
--print(("accepted new sample %i with probability %f"):format(#all_asked, accept))
insert(all_asked, a)
insert(all_noise, n)
insert(all_score, false)
end
end
-- TODO: early stopping, making sure it doesn't affect performance.
end
while #self.old_asked + #self.new_asked < self.popsize do
local a = {}
local n = {}
while #all_asked > self.popsize do
local pp = floor(uniform() * #all_asked) + 1
--print(("removing sample %i to fit popsize"):format(pp))
remove(all_asked, pp)
remove(all_noise, pp)
remove(all_score, pp)
end
while #all_asked < self.popsize do
local a, n = {}, {}
for i=1, self.dims do n[i] = normal() end
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
--print(("unconditionally added new sample %i"):format(#all_asked))
insert(all_asked, a)
insert(all_noise, n)
insert(all_score, false)
end
-- can't cache here!
local prob_new = 0
local prob_old = 0
for i, v in ipairs(a) do
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
end
local accept = max(1 - prob_old / prob_new, self.min_refresh)
if uniform() < accept then
-- split all_ tables back into old_ and new_.
self.old_asked, self.old_noise, self.old_score = {}, {}, {}
self.new_asked, self.new_noise = {}, {}
for i, score in ipairs(all_score) do
local a, n = all_asked[i], all_noise[i]
if score ~= false then
insert(self.old_asked, a)
insert(self.old_noise, n)
insert(self.old_score, score)
else
insert(self.new_asked, a)
insert(self.new_noise, n)
--print(("accepted new sample %i with probability %f"):format(0, accept))
end
end
@ -210,15 +240,15 @@ end
function Snes:tell(scored)
self.evals = self.evals + #scored
local asked = self.asked
local noise = self.noise
local asked = self.mixing and self.new_asked or self.asked
local noise = self.mixing and self.new_noise or self.noise
if self.mixing then
-- note: modifies, in-place, externally exposed tables.
for i, v in ipairs(asked) do insert(self.old_asked, v) end
for i, v in ipairs(noise) do insert(self.old_noise, v) end
for i, v in ipairs(scored) do insert(self.old_score, v) end
asked = self.old_asked
noise = self.old_noise
-- note that these modify tables referenced externally in-place.
for i, v in ipairs(self.new_asked) do insert(asked, v) end
for i, v in ipairs(self.new_noise) do insert(noise, v) end
for i, v in ipairs(scored) do insert(self.old_score, v) end
scored = self.old_score
end
assert(asked and noise, ":tell() called before :ask()")
@ -231,8 +261,9 @@ function Snes:tell(scored)
local g_mean = zeros{self.dims}
local g_std = zeros{self.dims}
--[[
local utilize = true
local utility
local utility = self.utility
if utilize then
utility = {}
@ -242,17 +273,18 @@ function Snes:tell(scored)
else
utility = normalize_sums(scored, {})
end
--]]
for p=1, self.popsize do
local noise_p = noise[p]
local noise_p = noise[arg[p]]
for i, v in ipairs(g_mean) do
g_mean[i] = v + utility[p] * noise_p[i]
g_mean[i] = v + self.utility[p] * noise_p[i]
end
for i, v in ipairs(g_std) do
local n = noise_p[i]
g_std[i] = v + utility[p] * (n * n - 1)
g_std[i] = v + self.utility[p] * (n * n - 1)
end
end
@ -261,7 +293,9 @@ function Snes:tell(scored)
step[i] = self.std[i] * v
end
self.mean_old = {}
for i, v in ipairs(self.mean) do
self.mean_old[i] = v
self.mean[i] = v + self.param_rate * step[i]
end
@ -273,7 +307,7 @@ function Snes:tell(scored)
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
end
self:adapt(asked, otherwise, utility)
if self.adaptive then self:adapt(asked, otherwise, self.utility) end
return step
end
@ -291,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities)
weights[p] = prob_big / prob_now
end
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
--print("p:", p)
local u, p = weighted_mann_whitney(qualities, qualities, nil, weights)
--print(("u, p: %6.3f, %6.3f"):format(u, p))
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
print("learning rate -:", self.sigma_rate)
--print("learning rate -:", self.sigma_rate)
else
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
print("learning rate +:", self.sigma_rate)
--print("learning rate +:", self.sigma_rate)
end
end

View File

@ -14,6 +14,7 @@ local random = math.random
local select = select
local sort = table.sort
local sqrt = math.sqrt
local type = type
local function sign(x)
-- remember that 0 is truthy in Lua.
@ -83,6 +84,28 @@ local function calc_mean_dev(x)
return mean, sqrt(dev)
end
local function calc_mean_dev_unbiased(x)
-- NOTE: this uses an approximation; there is still a little bias.
assert(#x > 1)
local mean = 0
for i, v in ipairs(x) do
mean = mean + v / #x
end
-- via Gurland, John; Tripathi, Ram C. (1971):
-- A Simple Approximation for Unbiased Estimation of the Standard Deviation
local divisor = #x - 1.5 + 1 / (8 * (#x - 1))
local dev = 0
for i, v in ipairs(x) do
local delta = v - mean
dev = dev + delta * delta / divisor
end
return mean, sqrt(dev)
end
local function normalize(x, out)
out = out or x
local mean, dev = calc_mean_dev(x)
@ -246,7 +269,31 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
local p = cdf((U - mean) / std)
if s0_sum > s1_sum then return 1 - p else return p end
local u = U / (w0_sum * w1_sum)
if s0_sum > s1_sum then return u, 1 - p else return u, p end
end
local function expect_cossim(n)
-- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
-- this is the expected absolute cosine similarity between
-- two standard normally-distributed random vectors both of size n.
assert(n > 0)
-- abs(error) < 1e-8
if n >= 128000 then
return 1 / sqrt(pi / 2 * n + 1)
elseif n >= 80 then
poly = (2.4674010 * n + -2.4673232) * n + 1.2274046
return 1 / sqrt(sqrt(poly))
end
-- fall-through when it's faster just to compute iteratively.
even = n % 2 == 0
res = even and 2.0 or 1.0
for i = even and 2 or 1, n - 1, 2 do
res = res * (i / (i + 1))
end
return even and res / pi or res
end
return {
@ -260,6 +307,7 @@ return {
softchoice=softchoice,
empty=empty,
calc_mean_dev=calc_mean_dev,
calc_mean_dev_unbiased=calc_mean_dev_unbiased,
normalize=normalize,
normalize_wrt=normalize_wrt,
normalize_sums=normalize_sums,
@ -276,4 +324,5 @@ return {
pdf=pdf,
cdf=cdf,
weighted_mann_whitney=weighted_mann_whitney,
expect_cossim=expect_cossim,
}

View File

@ -70,6 +70,8 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
-- note: this is technically the co-standard-deviation.
-- you can imagine the "s" standing for "sqrt" if you like.
self.covars = make_covars(self.dims, self.sigma, self.covars)
self.evals = 0
end
function Xnes:params(new_mean)
@ -153,6 +155,8 @@ function Xnes:tell(scored, noise)
local noise = noise or self.noise
assert(noise, "missing noise argument")
self.evals = self.evals + #scored
local arg = argsort(scored, function(a, b) return a > b end)
local g_delta = zeros{self.dims}
@ -173,7 +177,7 @@ function Xnes:tell(scored, noise)
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
local temp = self.utility[p] * zzt
g_covars[ind] = g_covars[ind] + temp
traced = traced + temp
if i == j then traced = traced + temp end
end
end
end
@ -181,7 +185,7 @@ function Xnes:tell(scored, noise)
local g_sigma = traced / self.dims
for i=1, self.dims do
local ind = (i - 1) * self.dims + i
local ind = (i - 1) * self.dims + i -- diagonal
g_covars[ind] = g_covars[ind] - g_sigma
end