diff --git a/ars.lua b/ars.lua index 9398afb..cb6f6da 100644 --- a/ars.lua +++ b/ars.lua @@ -10,8 +10,10 @@ local exp = math.exp local floor = math.floor local insert = table.insert local ipairs = ipairs +local log = math.log local max = math.max local print = print +local sqrt = math.sqrt local Base = require "Base" @@ -72,16 +74,15 @@ local function kinda_lipschitz(dir, pos, neg, mid) end function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic, - momentum) + momentum, beta) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) self.param_rate = base_rate - self.sigma_rate = base_rate - self.covar_rate = base_rate self.sigma = sigma or 1 self.antithetic = antithetic == nil and true or antithetic self.momentum = momentum or 0 + self.beta = beta or 1.0 self.poptop = poptop or popsize assert(self.poptop <= popsize) @@ -189,8 +190,9 @@ function Ars:tell(scored, unperturbed_score) reward = reward / reward_dev end + local scale = reward / self.poptop * self.beta / 2 for j, v in ipairs(noisy) do - step[j] = step[j] + reward * v / self.poptop + step[j] = step[j] + scale * v end end end @@ -200,8 +202,9 @@ function Ars:tell(scored, unperturbed_score) if reward ~= 0 then local noisy = self.noise[ind] + local scale = reward / self.poptop * self.beta for j, v in ipairs(noisy) do - step[j] = step[j] + reward * v / self.poptop + step[j] = step[j] + scale * v end end end diff --git a/config.lua b/config.lua index ee338f3..e04c419 100644 --- a/config.lua +++ b/config.lua @@ -26,15 +26,17 @@ local defaults = { time_inputs = true, -- insert binary inputs of a frame counter. -- network layers: + embed = true, -- set to false to use a hard-coded tile embedding. + reduce_tiles = 0, -- TODO: write description hidden = false, -- use a hidden layer with ReLU/GELU activation. hidden_size = 128, layernorm = false, -- use a LayerNorm layer after said activation. 
- reduce_tiles = false, bias_out = true, -- network evaluation (sampling joypad): frameskip = 4, prob_frameskip = 0.0, + max_frameskip = 6, -- true greedy epsilon has both deterministic and det_epsilon set. deterministic = false, -- use argmax on outputs instead of random sampling. det_epsilon = false, -- take random actions with probability eps. @@ -42,12 +44,16 @@ local defaults = { -- evolution strategy and non-rate hyperparemeters: es = 'ars', ars_lips = false, -- for ARS. - epoch_top_trials = 9999, -- for ARS. + epoch_top_trials = 9999, -- for ARS, Guided. + alpha = 0.5, -- for Guided. + beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat. + past_grads = 1, -- for Guided. keeps a history of n past steps taken. -- sampling: deviation = 1.0, unperturbed_trial = true, -- perform an extra trial without any noise. -- this is good for logging, so i'd recommend it. + attempts = 1, -- TODO: document. epoch_trials = 50, graycode = false, -- for ARS. negate_trials = true, -- try pairs of normal and negated noise directions. @@ -114,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials, "cfg.negate_trials must be true to use cfg.ars_lips") assert(not (cfg.es == 'snes' and cfg.negate_trials), "cfg.negate_trials is not yet compatible with SNES") +assert(not (cfg.es == 'guided' and cfg.graycode), + "cfg.graycode is not compatible with Guided") +assert(cfg.es ~= 'guided' or cfg.negate_trials, + "cfg.negate_trials must be true to use Guided") return cfg diff --git a/es_test.lua b/es_test.lua index bf76479..264155b 100644 --- a/es_test.lua +++ b/es_test.lua @@ -1,21 +1,27 @@ local floor = math.floor +local insert = table.insert local ipairs = ipairs local log = math.log +local max = math.max local print = print local ars = require("ars") local snes = require("snes") local xnes = require("xnes") +local guided = require("guided") -- try it all out on a dummy problem. 
+local function typeof(t) return getmetatable(t).__index end + local function square(x) return x * x end -- this function's global minimum is arange(dims) + 1. -- xNES should be able to find it almost exactly. local function spherical(x) local sum = 0 - for i, v in ipairs(x) do sum = sum + square(v - i) end + --for i, v in ipairs(x) do sum = sum + square(v - i) end + for i, v in ipairs(x) do sum = sum + square(v - i / #x) end -- we need to negate this to turn it into a maximization problem. return -sum end @@ -26,10 +32,34 @@ local dims = 100 local popsize = dims + 1 local sigma_init = 0.5 --local es = xnes.Xnes(dims, popsize, 0.1, sigma_init) ---local es = snes.Snes(dims, popsize, 0.1, sigma_init) +local es = snes.Snes(dims, popsize, 0.1, sigma_init) --local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true) -local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true) -es.min_refresh = 0.7 -- FIXME: needs a better interface. +--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true) +--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5) +es.min_refresh = 0.5 -- FIXME: needs a better interface. + +if typeof(es) == xnes.Xnes +or typeof(es) == snes.Snes +then + -- use IGO recommendations + local pop5 = max(1, floor(es.popsize / 5)) + + local sum = 0 + for i=1, es.popsize do + local maybe = i < pop5 and 1 or 0 + es.utility[i] = maybe + sum = sum + maybe + end + --for i, v in ipairs(es.utility) do es.utility[i] = v / sum end + + local util = require "util" + util.normalize_sums(es.utility) + + es.param_rate = 0.39 + es.sigma_rate = 0.39 + es.covar_rate = 0.39 + es.adaptive = false +end if false then -- TODO: delete me local nn = require("nn") @@ -64,30 +94,44 @@ local asked = nil -- for caching purposes. local noise = nil -- for caching purposes. 
local current_cost = spherical(es:params()) +local past_grads = {} +local pgi = 0 +local pgn = 10 + for i=1, iterations do - if getmetatable(es).__index == snes.Snes then + if typeof(es) == snes.Snes and es.min_refresh ~= 1 then asked, noise = es:ask_mix() - elseif getmetatable(es).__index == ars.Ars then + elseif typeof(es) == ars.Ars then asked, noise = es:ask() + elseif typeof(es) == guided.Guided then + asked, noise = es:ask(past_grads) else asked, noise = es:ask(asked, noise) end + local scores = {} for i, v in ipairs(asked) do scores[i] = spherical(v) end - if getmetatable(es).__index == ars.Ars then + if typeof(es) == ars.Ars then es:tell(scores)--, current_cost) -- use lips + elseif typeof(es) == guided.Guided then + local step = es:tell(scores) + + for _, v in ipairs(step) do + past_grads[pgi + 1] = v + pgi = (pgi + 1) % (pgn * #step) + end + past_grads.shape = {floor(#past_grads / #step), #step} else es:tell(scores) end current_cost = spherical(es:params()) - --if i % 100 == 0 then if i % 100 == 0 then local sigma = es.sigma - if getmetatable(es).__index == snes.Snes then + if typeof(es) == snes.Snes then sigma = 0 for i, v in ipairs(es.std) do sigma = sigma + v end sigma = sigma / #es.std diff --git a/guided.lua b/guided.lua new file mode 100644 index 0000000..0d18fab --- /dev/null +++ b/guided.lua @@ -0,0 +1,194 @@ +-- Guided Evolutionary Strategies +-- https://arxiv.org/abs/1806.10230 + +-- this is just ARS extended to utilize gradients +-- approximated from previous iterations. 
+ +-- for simplicity: +-- antithetic is always true +-- momentum is always 0 +-- no graycode/lipschitz nonsense + +local floor = math.floor +local insert = table.insert +local ipairs = ipairs +local max = math.max +local print = print +local sqrt = math.sqrt + +local Base = require "Base" + +local nn = require "nn" +local dot_mv = nn.dot_mv +local transpose = nn.transpose +local normal = nn.normal +local prod = nn.prod +local uniform = nn.uniform +local zeros = nn.zeros + +local qr = require "qr2" + +local util = require "util" +local argsort = util.argsort +local calc_mean_dev = util.calc_mean_dev + +local Guided = Base:extend() + +local function collect_best_indices(scored, top) + -- select one (the best) reward of each pos/neg pair. + local best_rewards + best_rewards = {} + for i = 1, #scored / 2 do + local pos = scored[i * 2 - 1] + local neg = scored[i * 2 - 0] + best_rewards[i] = max(pos, neg) + end + + local indices = argsort(best_rewards, function(a, b) return a > b end) + + for i = top + 1, #best_rewards do indices[i] = nil end + return indices +end + +function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta) + -- sigma: scale of random perturbations. + -- alpha: blend between full parameter space and its gradient subspace. + -- 1.0 is roughly equivalent to ARS. 
+ self.dims = dims + self.popsize = popsize or 4 + (3 * floor(log(dims))) + base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.param_rate = base_rate + self.sigma = sigma or 1.0 + self.alpha = alpha or 0.5 + self.beta = beta or 1.0 + + self.poptop = poptop or popsize + assert(self.poptop <= popsize) + self.popsize = self.popsize * 2 -- antithetic + + self._params = zeros(self.dims) + --self.accum = zeros(self.dims) -- momentum + + self.evals = 0 +end + +function Guided:params(new_params) + if new_params ~= nil then + assert(#self._params == #new_params, "new parameters have the wrong size") + for i, v in ipairs(new_params) do self._params[i] = v end + end + return self._params +end + +function Guided:decay(param_decay, sigma_decay) + -- FIXME: multiplying by sigma probably isn't correct anymore. + -- is this correct now? + if param_decay > 0 then + local scale = self.sigma / sqrt(self.dims) + scale = scale * self.beta + scale = scale * self.param_rate / (self.sigma * self.sigma) + + scale = 1 - param_decay * scale + for i, v in ipairs(self._params) do + self._params[i] = scale * v + end + end +end + +function Guided:ask(grads) + local asked = {} + local noise = {} + + local n_grad = 0 + local gnoise, U, dummy, left, right + if grads ~= nil and #grads > 0 then + n_grad = grads.shape[1] + gnoise = zeros(n_grad) + + U, dummy = qr(transpose(grads)) + --print(nn.pp(transpose(U), "%9.4f")) + + left = sqrt(self.alpha / self.dims) + right = sqrt((1 - self.alpha) / n_grad) + --print(left, right) + end + + for i = 1, self.popsize do + local asking = zeros(self.dims) + local noisy = zeros(self.dims) + asked[i] = asking + noise[i] = noisy + + if i % 2 == 0 then + local old_noisy = noise[i - 1] + for j, v in ipairs(old_noisy) do + noisy[j] = -v + end + elseif n_grad == 0 then + local scale = self.sigma / sqrt(self.dims) + for j = 1, self.dims do + noisy[j] = scale * normal() + end + else + for j = 1, self.dims do noisy[j] = normal() end + for j = 1, 
n_grad do gnoise[j] = normal() end + local noisier = dot_mv(U, gnoise) + for j, v in ipairs(noisy) do + noisy[j] = self.sigma * (left * v + right * noisier[j]) + end + end + + for j, v in ipairs(self._params) do + asking[j] = v + noisy[j] + end + end + + self.noise = noise + return asked, noise +end + +function Guided:tell(scored, unperturbed_score) + self.evals = self.evals + #scored + + local indices = collect_best_indices(scored, self.poptop) + + local top_rewards = {} + for _, ind in ipairs(indices) do + insert(top_rewards, scored[ind * 2 - 1]) + insert(top_rewards, scored[ind * 2 - 0]) + end + + local step = zeros(self.dims) + local _, reward_dev = calc_mean_dev(top_rewards) + if reward_dev == 0 then reward_dev = 1 end + + for i, ind in ipairs(indices) do + local pos = top_rewards[i * 2 - 1] + local neg = top_rewards[i * 2 - 0] + local reward = pos - neg + if reward ~= 0 then + local noisy = self.noise[ind * 2 - 1] + -- NOTE: technically this reward divide isn't part of guided search. + reward = reward / reward_dev + + local scale = reward / self.poptop * self.beta / 2 + for j, v in ipairs(noisy) do + step[j] = step[j] + scale * v + end + end + end + + local coeff = self.param_rate / (self.sigma * self.sigma) + for i, v in ipairs(self._params) do + self._params[i] = v + coeff * step[i] + end + + self.noise = nil + + return step +end + +return { + --collect_best_indices = collect_best_indices, -- ars.lua has more features + Guided = Guided, +} diff --git a/main.lua b/main.lua index 61df8eb..2cf4ebe 100644 --- a/main.lua +++ b/main.lua @@ -20,6 +20,12 @@ local trial_rewards = {} local trials_remaining = 0 local es -- evolution strategy. 
+local attempt_i = 0 +local sub_rewards = {} + +local past_grads = {} -- for Guided +local pgi = 0 -- past_grads_index + local trial_frames = 0 local total_frames = 0 local lagless_count = 0 @@ -93,6 +99,7 @@ local util = require("util") local argmax = util.argmax local argsort = util.argsort local calc_mean_dev = util.calc_mean_dev +local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased local clamp = util.clamp local copy = util.copy local empty = util.empty @@ -149,13 +156,21 @@ local network local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z local function make_network(input_size) nn_x = nn.Input({input_size}) - nn_tx = nn.Input({gcfg.tile_count}) - nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2)) + + local embed_dim = cfg.embed and 2 or 3 + + if cfg.embed then + nn_tx = nn.Input({gcfg.tile_count}) + nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim)) + else -- new tile inputs. + nn_tx = nn.Input({gcfg.tile_count * 3}) + nn_ty = nn_tx + end nn_tz = nn_ty - if cfg.reduce_tiles then - nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2}) - nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true)) + if cfg.reduce_tiles > 0 then + nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim}) + nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true)) nn_tz = nn_tz:feed(nn.Relu()) -- note: due to a quirk in Merge, we don't need to flatten nn_tz. 
end @@ -185,6 +200,7 @@ end local ars = require("ars") local snes = require("snes") local xnes = require("xnes") +local guided = require("guided") local function prepare_epoch() trial_neg = false @@ -213,6 +229,8 @@ local function prepare_epoch() local dummy if cfg.es == 'ars' then trial_params, dummy = es:ask(precision) + elseif cfg.es == 'guided' then + trial_params, dummy = es:ask(past_grads) elseif cfg.es == 'snes' then trial_params, dummy = es:ask_mix() else @@ -223,6 +241,7 @@ local function prepare_epoch() end local function load_next_trial() + attempt_i = 1 if cfg.negate_trials then trial_neg = not trial_neg else @@ -272,8 +291,16 @@ local function learn_from_epoch() end local step_mean, step_dev = calc_mean_dev(step) - print("step mean:", step_mean) - print("step stddev:", step_dev) + print(("step mean: %9.6f"):format(step_mean)) + print(("step stddev: %9.6f"):format(step_dev)) + + if cfg.es == 'guided' and cfg.past_grads > 0 then + for _, v in ipairs(step) do + past_grads[pgi + 1] = v + pgi = (pgi + 1) % (cfg.past_grads * #step) + end + past_grads.shape = {floor(#past_grads / #step), #step} + end es:decay(cfg.param_decay, cfg.sigma_decay) @@ -329,65 +356,62 @@ local function joypad_mash(button) joypad.write(1, jp_mash) end -local function loadlevel(world, level) - -- TODO: move to smb.lua. rename to load_level. - if world == 0 then world = random(1, 8) end - if level == 0 then level = random(1, 4) end - emu.poweron() - while emu.framecount() < 60 do - if emu.framecount() == 32 then - local area = game.area_lut[world * 10 + level] - game.W(0x75F, world - 1) - game.W(0x75C, level - 1) - game.W(0x760, area) - end - if emu.framecount() == 42 then - game.W(0x7A0, 0) -- world screen timer (reduces startup time) - end - joypad_mash('start') - emu.frameadvance() - end -end - local function do_reset() local state = game.get_state() -- be a little more descriptive. 
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end - if trial_i >= 0 then - if trial_i == 0 then - print('test trial reward:', reward, "("..state..")") - elseif cfg.negate_trials then - --local dir = trial_neg and "negative" or "positive" - --print('trial', trial_i, dir, 'reward:', reward, "("..state..")") + --if cfg.attempts > 1 and attempt_i >= cfg.attempts then + attempt_i = attempt_i + 1 + sub_rewards[#sub_rewards + 1] = reward + --print(sub_rewards) - if trial_neg then - local pos = trial_rewards[#trial_rewards] - local neg = reward - local fmt = "trial %i rewards: %+i, %+i (%s, %s)" - print(fmt:format(floor(trial_i / 2), - pos, neg, last_trial_state, state)) + if #sub_rewards >= cfg.attempts then + if cfg.attempts == 1 then + reward = sub_rewards[1] + else + local sub_mean, sub_std = calc_mean_dev(sub_rewards) + reward = floor(sub_mean) + --local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards) + --reward = floor(sub_mean - sub_std) + end + empty(sub_rewards) + + if trial_i >= 0 then + if trial_i == 0 then + print('test trial reward:', reward, "("..state..")") + elseif cfg.negate_trials then + --local dir = trial_neg and "negative" or "positive" + --print('trial', trial_i, dir, 'reward:', reward, "("..state..")") + + if trial_neg then + local pos = trial_rewards[#trial_rewards] + local neg = reward + local fmt = "trial %i rewards: %+i, %+i (%s, %s)" + print(fmt:format(floor(trial_i / 2), + pos, neg, last_trial_state, state)) + end + last_trial_state = state + else + print('trial', trial_i, 'reward:', reward, "("..state..")") + end + + if trial_i == 0 or not cfg.negate_trials then + trial_rewards[trial_i] = reward + else + trial_rewards[#trial_rewards + 1] = reward end - last_trial_state = state - else - print('trial', trial_i, 'reward:', reward, "("..state..")") end - if trial_i == 0 or not cfg.negate_trials then - trial_rewards[trial_i] = reward - else - trial_rewards[#trial_rewards + 1] = reward - end - end - - if epoch_i == 0 or 
(trial_i == #trial_params and trial_neg) then - if epoch_i > 0 then learn_from_epoch() end - if not cfg.playback_mode then epoch_i = epoch_i + 1 end - prepare_epoch() - collectgarbage() - if any_random then - loadlevel(cfg.starting_world, cfg.starting_level) - state_saved = false + if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then + if epoch_i > 0 then learn_from_epoch() end + if not cfg.playback_mode then epoch_i = epoch_i + 1 end + prepare_epoch() + collectgarbage() + if any_random then + game.load_level(cfg.starting_world, cfg.starting_level) + state_saved = false + end end end @@ -427,7 +451,9 @@ local function do_reset() trial_frames = 0 emu.frameadvance() -- prevents emulator from quirking up. - load_next_trial() + if attempt_i > cfg.attempts then + load_next_trial() + end reset = false end @@ -450,7 +476,7 @@ local function init() if not playing then emu.speedmode("turbo") end if not any_random then - loadlevel(cfg.starting_world, cfg.starting_level) + game.load_level(cfg.starting_world, cfg.starting_level) end params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param) @@ -481,7 +507,10 @@ local function init() elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.base_rate, cfg.deviation, cfg.negate_trials, - cfg.momentum) + cfg.momentum, cfg.beta) + elseif cfg.es == 'guided' then + es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, + cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta) else error("Unknown evolution strategy specified: " + tostring(cfg.es)) end @@ -537,6 +566,7 @@ local function doit(dummy) empty(game.sprite_input) empty(game.tile_input) empty(game.extra_input) + empty(game.new_input) local controllable = game.R(0x757) == 0 and game.R(0x758) == 0 local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5) @@ -616,12 +646,18 @@ local function doit(dummy) for i, v in ipairs(game.extra_input) do insert(X, v / 256) end nn.reshape(X, 1, gcfg.input_size) 
nn.reshape(game.tile_input, 1, gcfg.tile_count) + nn.reshape(game.new_input, 1, gcfg.tile_count * 3) trial_frames = trial_frames + cfg.frameskip if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then total_frames = total_frames + cfg.frameskip - local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input}) + local outputs + if cfg.embed then + outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input}) + else + outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input}) + end local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames) if cfg.det_epsilon and random() < eps then @@ -695,8 +731,12 @@ while true do end local delta = lagless_count - last_decision_frame - local doot = jp == nil or delta >= cfg.frameskip - doot = doot and random() >= cfg.prob_frameskip + local doot = true + if jp ~= nil then + doot = delta >= cfg.frameskip + doot = doot and random() >= cfg.prob_frameskip + doot = doot or delta >= cfg.max_frameskip + end doit(not doot) if doot then last_decision_frame = lagless_count end diff --git a/nn.lua b/nn.lua index ecb085e..d702b40 100644 --- a/nn.lua +++ b/nn.lua @@ -1,20 +1,15 @@ local assert = assert -local ceil = math.ceil local cos = math.cos local exp = math.exp -local floor = math.floor local huge = math.huge local insert = table.insert local ipairs = ipairs local log = math.log local max = math.max -local min = math.min local open = io.open -local pairs = pairs local pi = math.pi local print = print local remove = table.remove -local sin = math.sin local sqrt = math.sqrt local tanh = math.tanh local tostring = tostring @@ -105,19 +100,26 @@ end -- ndarray-ish stuff and more involved math +local function pp_join(sep, fmt, t, a, b) + a = a or 1 + b = b or #t + local s = '' + for i = a, b do + s = s..fmt:format(t[i]) + if i ~= b then s = s..sep end + end + return s +end + local function pp(t, fmt, sep, ti, di, depth, isfirst, islast) -- pretty-prints an nd-array. 
- fmt = fmt or '%10.7f,' + fmt = fmt or '%10.7f' sep = sep or ',' ti = ti or 0 di = di or 1 depth = depth or 0 - if t.shape == nil then - local s = '[' - for i = 1, #t do s = s..fmt:format(t[i]) end - return s..']'..sep..'\n' - end + if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end local dim = t.shape[di] @@ -134,11 +136,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast) s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim) ti = ti + ti_step end - if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end + s = s..indent..']'..sep + if islast then s = s..'\n' end else - s = s..indent..'[' - for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end - s = s..']'..sep..'\n' + s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n' end return s end @@ -265,6 +266,20 @@ local function dot(a, b, ax_a, ax_b, out) return out end +local function transpose(x, out) + assert(#x.shape == 2) -- TODO: handle ndarrays like numpy. + local rows = x.shape[1] + local cols = x.shape[2] + local y = out or zeros{cols, rows} + -- TODO: simplify? y can be consecutive for sure. + for i = 1, rows do + for j = 1, cols do + y[(j - 1) * rows + i] = x[(i - 1) * cols + j] + end + end + return y +end + -- nodal local function traverse(node_in, node_out, nodes, dummy_mode) @@ -875,6 +890,7 @@ return { ppi = ppi, dot_mv = dot_mv, dot = dot, + transpose = transpose, traverse = traverse, traverse_all = traverse_all, diff --git a/presets.lua b/presets.lua index bb87166..9a397f9 100644 --- a/presets.lua +++ b/presets.lua @@ -33,7 +33,7 @@ make_preset{ init_zeros = true, - reduce_tiles = true, + reduce_tiles = 5, bias_out = false, deterministic = false, @@ -147,6 +147,15 @@ make_preset{ parent = 'ars', } +make_preset{ + name = 'ars-lips', + parent = 'ars', + + ars_lips = true, +-- momentum = 0.5, -- this is default. 
+ param_rate = 1.0, +} + make_preset{ name = 'ars-skip', parent = 'ars', @@ -155,15 +164,6 @@ make_preset{ prob_frameskip = 0.25, } -make_preset{ - name = 'ars-lips', - parent = 'ars', - - ars_lips = true, - momentum = 0.5, - param_rate = 1.0, -} - make_preset{ name = 'ars-big', parent = 'ars', @@ -204,6 +204,216 @@ make_preset{ momentum = 0.99, } +-- new stuff for 2019: + +make_preset{ + name = 'ars-skip-more', + parent = 'ars', + + -- old: + --frameskip = 2, + --prob_frameskip = 0.5, + --max_frameskip = 60, + -- new: + frameskip = 3, + prob_frameskip = 0.5, + max_frameskip = 5, +} + +make_preset{ + name = 'ars-skip-more-3', + parent = 'ars-skip-more', + + attempts = 3, -- per trial. score = mean(scores) - stdev(scores) +} + +make_preset{ + name = 'snes-skip-more-3', + parent = 'snes3', + + frameskip = 3, + prob_frameskip = 0.5, + max_frameskip = 5, + attempts = 3, +} + +make_preset{ + name = 'guided', + parent = 'big-scroll-reduced', + + es = 'guided', + epoch_top_trials = 20, + deterministic = true, + deviation = 0.1, + epoch_trials = 20, + param_rate = 0.00368, + param_decay = 0.0, +} + +make_preset{ + name = 'guided2', + parent = 'guided', + + past_grads = 2, + -- after epoch 50, trying this: + --param_rate = 0.05, + -- after epoch 50+20, stepping back to this: + param_rate = 0.01, +} + +make_preset{ + name = 'guided10', + parent = 'guided', + + past_grads = 10, +} + +make_preset{ + name = 'guided69', -- the nice one + parent = 'guided', + + deviation = 0.05, + epoch_top_trials = 10, + epoch_trials = 20, + param_rate = 0.006, + + past_grads = 4, + alpha = 0.25, +} + +-- TODO: yet another preset. try building up from 1 trial ARS to something good. +make_preset{ + name = 'redux', + + min_time = 300, + max_time = 300, + timer_loser = 1/1, + + score_multiplier = 1, + + init_zeros = true, + + deterministic = true, + + es = 'guided', + past_grads = 2, -- for Guided. + alpha = 0.25, -- for Guided. + + ars_lips = false, -- for ARS. + beta = 1.0, -- fix the default. 
+ + epoch_top_trials = 4, -- for ARS, Guided. + epoch_trials = 5, + attempts = 1, -- TODO: document. + + deviation = 1.0, -- 0.1 + base_rate = 1.0, + param_decay = 0.01, + + graycode = false, -- for ARS. + min_refresh = 0.1, -- for SNES. + sigma_decay = 0.0, -- for SNES, xNES. + momentum = 0.0, -- for ARS. +} + +make_preset{ + name = 'redux_big', + parent = 'redux', + + time_inputs = true, -- insert binary inputs of a frame counter. + hidden = true, -- use a hidden layer with ReLU/GELU activation. + hidden_size = 64, + layernorm = true, -- use a LayerNorm layer after said activation. + reduce_tiles = false, + bias_out = false, + + -- gets stuck pretty quick, so tweak some stuff... + epoch_top_trials = 8, + epoch_trials = 10, + deviation = 1.0, + base_rate = 0.15, + param_decay = 0.05, + past_grads = 4, + alpha = 0.25, + -- well it doesn't get stuck anymore, but regular redux works much better. +} + +make_preset{ + name = 'guided-skip-more-3', + parent = 'guided', + + --param_rate = 0.00368, -- should probably be this instead... 
+ param_rate = 0.01, + + frameskip = 3, + prob_frameskip = 0.5, + max_frameskip = 5, + attempts = 3, +} + +make_preset{ + name = 'guided-skip-more-3-again', + parent = 'guided-skip-more-3', + + param_rate = 0.08, --0.0316, + deviation = 0.5, + alpha = 0.1, --0.5, +} + +make_preset{ + name = 'crazy', + parent = 'big-scroll-reduced', + + es = 'guided', + epoch_top_trials = 15, + deterministic = false, + deviation = 1.0, + epoch_trials = 15, + param_rate = 1.0, + param_decay = 0.0, + alpha = 0.0316, + --attempts = 3, +} + +make_preset{ + name = 'ars-lips2', + parent = 'ars', + + ars_lips = true, + --epoch_trials = 10, + param_rate = 0.147, +} + +make_preset{ + name = 'ars-lips3', + parent = 'ars', + + ars_lips = true, + param_rate = 0.5, + deviation = 0.02, -- added after like 272 epochs + param_decay = 0.0276, -- added after like 62 epochs +} + +make_preset{ + name = 'hard-embed', + parent = 'big-scroll-hidden', + + embed = false, + reduce_tiles = 5, + hidden_size = 54, + + epoch_top_trials = 20, + deterministic = true, + deviation = 0.01, + epoch_trials = 20, + param_rate = 0.368, + param_decay = 0.0138, + momentum = 0.5, + beta = 1.0, +} + +-- end of new stuff + make_preset{ name = 'play', diff --git a/qr.lua b/qr.lua new file mode 100644 index 0000000..d03ef1e --- /dev/null +++ b/qr.lua @@ -0,0 +1,111 @@ +local min = math.min +local sqrt = math.sqrt + +local nn = require "nn" +local dot = nn.dot +local reshape = nn.reshape +local transpose = nn.transpose +local zeros = nn.zeros + +local function minor(x, d) + assert(#x.shape == 2) + assert(d <= x.shape[1] and d <= x.shape[2]) + + local m = zeros(x.shape) + + -- fill diagonals. + --for i = 1, d do m[(i - 1) * m.shape[2] + i] = 1 end + for i = 1, d * m.shape[2], m.shape[2] + 1 do m[i] = 1 end + + -- copy values. 
+ for i = d + 1, m.shape[1] do + for j = d + 1, m.shape[2] do + local ind = (i - 1) * m.shape[2] + j + m[ind] = x[ind] + end + end + + return m +end + +local function norm(a) -- vector norm + local sum = 0 + for _, v in ipairs(a) do sum = sum + v * v end + return sqrt(sum) +end + +local function householder(x) + local rows = x.shape[1] + local cols = x.shape[2] + local iters = min(rows - 1, cols) + + local q = nil + local vec = zeros(rows) + local z = x + + for k = 1, iters do + z = minor(z, k - 1) + + -- extract a column. + for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end + + local a = norm(vec) + -- negate the norm if the original diagonal is non-negative. + local ind = (k - 1) * cols + k + if x[ind] > 0 then a = -a end + + vec[k] = vec[k] + a + + local a = norm(vec) + if a == 0 then a = 1 end -- FIXME: should probably just raise an error. + for i, v in ipairs(vec) do vec[i] = v / a end + + -- construct the householder reflection: mat = I - 2 * vec * vec.T + local mat = zeros{rows, rows} + for i = 1, rows do + for j = 1, rows do + local ind = (i - 1) * rows + j + local diag = i == j and 1 or 0 + mat[ind] = diag - 2 * vec[i] * vec[j] + end + end + + --print(nn.pp(mat, "%9.3f")) + if q == nil then q = mat else q = dot(mat, q) end + + z = dot(mat, z) + end + + return transpose(q), dot(q, x) -- Q, R +end + +local function qr(x) + -- a wrapper for the householder method that will return reduced matrices. + assert(#x.shape == 2) + + local q, r = householder(x) + + local rows = x.shape[1] + local cols = x.shape[2] + if cols >= rows then return q, r end + + -- trim q in-place. + q.shape[2] = cols + local ind = 1 + for i = 1, rows do + for j = 1, cols do + --ind = (i - 1) * cols + j + q[ind] = q[(i - 1) * rows + j] + ind = ind + 1 + end + end + for i = rows * cols + 1, #q do q[i] = nil end + + -- trim r in-place. 
+ r.shape[1] = r.shape[2] + for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end + + return q, r +end + +return qr diff --git a/qr2.lua b/qr2.lua new file mode 100644 index 0000000..8d5f98b --- /dev/null +++ b/qr2.lua @@ -0,0 +1,76 @@ +local min = math.min +local sqrt = math.sqrt + +local nn = require "nn" +local transpose = nn.transpose +local zeros = nn.zeros + +local function qr(a) + -- FIXME: if first column is exactly zero, + -- and cols > rows, Q @ R will not reconstruct the input. + -- this isn't too bad since an input like that is invalid anyway, + -- but i feel like it should be salvageable. + + -- actually the scope of the problem is much larger than that. + -- an input like + --[=[ + [[0, 0, 0, 0] + [1, 0, 1, 1] + [0, 0, 2, 2] + [0, 0, 3, 3]] + --]=] + -- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4). + -- hmm. maybe we can detect this and reverse the matmul to identity if necessary? + + assert(#a.shape == 2) + local rows = a.shape[1] + local cols = a.shape[2] + local small = min(rows, cols) + + local q = transpose(a) + local r = zeros{small, cols} + + for i = 1, cols do + local i0 = (i - 1) * rows + 1 + local i1 = i * rows + + for j = 1, min(i - 1, small) do + local j0 = (j - 1) * rows + 1 + local j1 = j * rows + local i_to_j = j0 - i0 + + local num = 0 + local den = 0 + for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end + for k = j0, j1 do den = den + q[k] * q[k] end + print(num, den) + if den == 0 then den = 1 end -- TODO: should probably just error. + + local x = num / den + r[(j - 1) * cols + i] = x + for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end + end + + if i <= small then + local sum = 0 + for k = i0, i1 do sum = sum + q[k] * q[k] end + local norm = sqrt(sum) + + if norm == 0 then + --norm = 1 + --q[i0 + i - 1] = 1 -- FIXME: not robust. 
+ r[(i - 1) * cols + i] = 0 + else + for k = i0, i1 do q[k] = q[k] / norm end + r[(i - 1) * cols + i] = norm + end + end + end + + for k = small * rows + 1, #q do q[k] = nil end + q.shape[1] = small + + return transpose(q), r +end + +return qr diff --git a/qr_test.lua b/qr_test.lua new file mode 100644 index 0000000..46faed3 --- /dev/null +++ b/qr_test.lua @@ -0,0 +1,72 @@ +local globalize = require "strict" +local nn = require "nn" +local qr = require "qr" +local qr2 = require "qr2" + +local A + +if false then + A = { + 12, -51, 4, + 6, 167, -68, + -4, 24, -41, + + -1, 1, 0, + 2, 0, 3, + } + A = nn.reshape(A, 5, 3) + +elseif false then + A = { + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 + } + A = nn.reshape(A, 5, 3) + +else + A = { + 1, 2, 0, + 2, 3, 1, + 3, 4, 0, + 4, 5, 1, + 5, 6, 0, + } + + A = { + 1, 0, 0, + 2, 0, 1, + 3, 0, 0, + 4, 0, 1, + 5, 0, 0, + } + + A = nn.reshape(A, 5, 3) + --A = nn.transpose(A) + +end + +print("A") +print(nn.pp(A, "%9.4f")) +print() + +local Q, R = qr2(A) + +print("Q") +print(nn.pp(Q, "%9.4f")) +print() + +print("R") +print(nn.pp(R, "%9.4f")) +print() + +print("Q @ R") +print(nn.pp(nn.dot(Q, R), "%9.4f")) +print() + +--print("Q @ Q.T = I") +--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f")) +--print() + +--A = nn.reshape(A, 5, 3) +--Q, R = qr(A) diff --git a/smb.lua b/smb.lua index a8f46cf..6d22d97 100644 --- a/smb.lua +++ b/smb.lua @@ -1,12 +1,16 @@ -- disassembly used for reference: -- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM -local band = bit.band -local floor = math.floor -local emu = emu -local gui = gui - local util = require("util") + +local band = bit.band +local clamp = util.clamp +local empty = util.empty +local emu = emu +local floor = math.floor +local gui = gui +local insert = table.insert + local R = memory.readbyteunsigned local W = memory.writebyte local function S(addr) return util.signbyte(R(addr)) end @@ -76,13 +80,84 @@ local 
-- maps a tile id to three values, each one of -1, 0, or 1.
local tile_embedding = {
    -- handmade trinary encoding.
    -- we have 57 valid tile types and 27 permutations to work with.
    [0x00] = { 0,  0,  0}, -- air
    [0x10] = { 1, -1,  1}, -- vertical pipe (top left) (enterable)
    [0x11] = { 1, -1,  1}, -- vertical pipe (top right) (enterable)
    [0x12] = { 0, -1,  1}, -- vertical pipe (top left)
    [0x13] = { 0, -1,  1}, -- vertical pipe (top right)
    [0x14] = { 0, -1, -1}, -- vertical pipe (left)
    [0x15] = { 0, -1, -1}, -- vertical pipe (right)
    [0x16] = {-1, -1, -1}, --
    [0x17] = {-1, -1, -1}, --
    [0x18] = {-1, -1, -1}, --
    [0x19] = {-1, -1, -1}, --
    [0x1A] = {-1, -1, -1}, --
    [0x1B] = {-1, -1, -1}, --
    [0x1C] = { 0, -1,  0}, -- horizontal pipe (top left)
    [0x1D] = { 0, -1,  0}, -- horizontal pipe (top)
    [0x1E] = { 0, -1,  0}, -- horizontal pipe joining vertical pipe (top)
    [0x1F] = { 0, -1,  0}, -- horizontal pipe (bottom left)
    [0x20] = { 0, -1,  0}, -- horizontal pipe (bottom)
    [0x21] = { 0, -1,  0}, -- horizontal pipe joining vertical pipe (bottom)
    [0x22] = { 0,  0,  0}, --
    [0x23] = { 0,  0,  0}, -- block being hit (either breakable or ?)
    [0x24] = { 0,  0,  0}, --
    [0x25] = { 0,  0,  1}, -- flagpole
    [0x26] = { 0,  0,  0}, --
    [0x51] = { 1,  1,  0}, -- breakable brick block
    [0x52] = { 1,  1,  0}, -- breakable brick block (again?)
    [0x54] = { 0,  1,  0}, -- regular ground
    [0x55] = { 0,  0,  0}, --
    [0x56] = { 0,  0,  0}, --
    [0x57] = {-1,  1,  1}, -- star brick block
    [0x58] = { 1,  1, -1}, -- coin brick block (many coins)
    [0x59] = { 0,  0,  0}, --
    [0x5A] = { 0,  0,  0}, --
    [0x5B] = { 0,  0,  0}, --
    [0x5C] = { 0,  0,  0}, --
    [0x5D] = { 1,  1, -1}, -- coin brick block (many coins) (again?)
    [0x5E] = { 0,  0,  0}, --
    [0x5F] = { 0,  0,  0}, --
    [0x60] = {-1,  0,  0}, -- invisible 1-up block
    [0x61] = { 0,  1, -1}, -- chocolate block (usually used for stairs)
    [0x62] = { 0,  0,  0}, --
    [0x63] = { 0,  0,  0}, --
    [0x64] = { 0,  0,  0}, --
    [0x65] = { 0,  0,  0}, --
    [0x66] = { 0,  0,  0}, --
    [0x67] = { 0,  0,  0}, --
    [0x68] = { 0,  0,  0}, --
    [0x69] = { 0,  0,  0}, --
    [0x6B] = { 0,  0,  0}, --
    [0x6C] = { 0,  0,  0}, --
    [0x89] = { 0,  0,  0}, --
    [0xC0] = {-1,  1, -1}, -- coin ? block
    [0xC1] = {-1,  1,  0}, -- mushroom ? block
    [0xC2] = { 0,  0, -1}, -- coin
    [0xC3] = { 0,  0,  0}, --
    [0xC4] = { 0,  1,  1}, -- hit block
    [0xC5] = { 0,  0,  0}, --
}

-- embedding used for tile ids that have no entry above (same as air).
local unknown_tile_embedding = {0, 0, 0}

-- TODO: reinterface to one "input" array visible to main.lua.
local sprite_input = {}
local tile_input = {}
local extra_input = {}
local new_input = {}
local overlay = false

-- append the three-value embedding of tile id `t` to `out`.
--
-- t:   tile id, used as a key into tile_embedding.
-- out: optional destination array; defaults to new_input,
--      which is where this function has always written.
--
-- ids without an entry in tile_embedding are embedded as all zeros
-- instead of raising an "attempt to index a nil value" error.
local function embed_tile(t, out)
    out = out or new_input
    local embedded = tile_embedding[t] or unknown_tile_embedding
    insert(out, embedded[1])
    insert(out, embedded[2])
    insert(out, embedded[3])
end
-- look up the tile under the point (x, y).
--
-- returns tile_t, tile_x, tile_y:
--   tile_x is floor((x + scroll remainder) / 16),
--   tile_y is floor(y / 16) clamped to 0..12,
--   tile_t is the byte read for that tile, or 0 (air) whenever
--   tile_y had to be clamped.
-- NOTE(review): x and y are presumably screen-space pixels, and 0x73F/0x71A
-- presumably hold the horizontal scroll and current page -- confirm against
-- the disassembly.
local function tile_from_xy(x, y)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16

    local tile_x = floor((x + tile_scroll_remainder) / 16)
    local tile_y = floor(y / 16)
    local tile_t = 0 -- default to air
    if tile_y < 0 then
        tile_y = 0
    elseif tile_y > 12 then
        tile_y = 12
    else
        -- tiles are read from one of two 16-column areas,
        -- selected by the column within a 32-column window.
        local col = (tile_x + tile_scroll) % 32
        local addr = col < 16 and 0x500 or 0x5D0
        tile_t = R(addr + tile_y * 16 + (col % 16))
    end
    return tile_t, tile_x, tile_y
end

-- rebuild new_input from the current frame and draw debugging overlays.
--
-- new_input ends up holding two scaled mario coordinates followed by the
-- embedding of one tile looked up near mario. takes and returns nothing.
-- NOTE(review): presumes util.empty clears its argument in place -- confirm.
local function new_stuff()
    -- obviously very work in progress.

    empty(new_input)

    local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)

    -- normalized mario x
    -- normalized mario y
    insert(new_input, (mario_x - 112) / 64)
    insert(new_input, (mario_y - 160) / 96)

    -- type of tile we're standing on
    -- type of tile we're occupying

    gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)

    local mario_tile_t, mario_tile_x, mario_tile_y =
        tile_from_xy(mario_x + 8, mario_y - 8)

    --[[
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
    local sy = mario_tile_y * 16 + 40
    gui.box(sx-8, sy-8, sx+8, sy+8)
    gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
    --]]

    -- embed_tile takes the tile id as its first argument and appends to
    -- new_input on its own; passing new_input first made it look up a
    -- table as a tile id and fail.
    embed_tile(mario_tile_t)

    -- type of tile to the right, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the right, excluding space (large eyeheight)

    -- type of tile to the left, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the left, excluding space (large eyeheight)

    -- type of enemy (nearest down-right from mario's upper left)
    -- normalized enemy x
    -- normalized enemy y

    -- VISUALIZE:

    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    for y = 1, 11 do
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local addr = col < 16 and 0x500 or 0x5D0
            local t = R(addr + y * 16 + (col % 16))
            local sx = x * 16 + 8 - tile_scroll_remainder
            local sy = y * 16 + 40

            -- outline and label every non-air tile with its id in hex.
            if t ~= 0 then
                gui.box(sx-8, sy-8, sx+8, sy+8)
                gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
            end
        end
    end
end
for everything needed. R=R, @@ -315,6 +503,7 @@ overlay=overlay, valid_tiles=valid_tiles, area_lut=area_lut, +embed_tile=embed_tile, sprite_input=sprite_input, tile_input=tile_input, @@ -323,16 +512,24 @@ extra_input=extra_input, get_timer=get_timer, get_score=get_score, set_timer=set_timer, -mark_sprite=mark_sprite, -mark_tile=mark_tile, +get_state=get_state, + getxy=getxy, paused=paused, -get_state=get_state, -advance=advance, + +mark_sprite=mark_sprite, +mark_tile=mark_tile, + handle_enemies=handle_enemies, handle_fireballs=handle_fireballs, handle_blocks=handle_blocks, handle_hammers=handle_hammers, handle_misc=handle_misc, handle_tiles=handle_tiles, + +advance=advance, +load_level=load_level, + +new_stuff=new_stuff, +new_input=new_input, } diff --git a/snes.lua b/snes.lua index 38fd3bb..c2c7cd2 100644 --- a/snes.lua +++ b/snes.lua @@ -29,9 +29,12 @@ local normalize_sums = util.normalize_sums local pdf = util.pdf local weighted_mann_whitney = util.weighted_mann_whitney +local xnes = require "xnes" +local make_utility = xnes.make_utility + local Snes = Base:extend() -function Snes:init(dims, popsize, base_rate, sigma, antithetic) +function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive) -- heuristic borrowed from CMA-ES: self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) @@ -41,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic) self.covar_rate = base_rate self.sigma = sigma or 1 self.antithetic = antithetic and true or false + self.adaptive = adaptive == nil and true or adaptive if self.antithetic then self.popsize = self.popsize * 2 end + self.utility = make_utility(self.popsize) + self.rate_init = self.sigma_rate self.mean = zeros{dims} @@ -147,60 +153,84 @@ function Snes:ask_mix(start_anew) -- perform importance mixing. 
- local mean_old = self.mean + local mean_old = self.mean_old or self.mean local mean_new = self.mean local std_old = self.std_old or self.std local std_new = self.std - self.new_asked = {} - self.new_noise = {} - - local marked = {} - for p=1, min(#self.old_asked, self.popsize) do - local a = self.old_asked[p] - - -- TODO: cache probs? + local function compute_probabilities(a) local prob_new = 0 local prob_old = 0 for i, v in ipairs(a) do prob_new = prob_new + pdf(v, mean_new[i], std_new[i]) prob_old = prob_old + pdf(v, mean_old[i], std_old[i]) end + return prob_new, prob_old + end - local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1) - if uniform() < accept then - --print(("accepted old sample %i with probability %f"):format(p, accept)) - else - -- insert in reverse as not to screw up - -- the indices when removing later. - insert(marked, 1, p) + local all_asked, all_noise, all_score = {}, {}, {} + + for p=1, #self.old_asked do + do + local pp = floor(uniform() * #self.old_asked) + 1 + local a = self.old_asked[pp] + local prob_new, prob_old = compute_probabilities(a) + local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1) + if uniform() < accept then + --print(("accepted old sample %i with probability %f"):format(pp, accept)) + insert(all_asked, a) + insert(all_noise, self.old_noise[pp]) + insert(all_score, self.old_score[pp]) + end end - end - for _, p in ipairs(marked) do - remove(self.old_asked, p) - remove(self.old_noise, p) - remove(self.old_score, p) + + do + local a, n = {}, {} + for i=1, self.dims do n[i] = normal() end + for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end + local prob_new, prob_old = compute_probabilities(a) + local accept = max(1 - prob_old / prob_new, self.min_refresh) + if uniform() < accept then + --print(("accepted new sample %i with probability %f"):format(#all_asked, accept)) + insert(all_asked, a) + insert(all_noise, n) + insert(all_score, false) + end + end + + -- TODO: early stopping, 
making sure it doesn't affect performance. end - while #self.old_asked + #self.new_asked < self.popsize do - local a = {} - local n = {} + while #all_asked > self.popsize do + local pp = floor(uniform() * #all_asked) + 1 + --print(("removing sample %i to fit popsize"):format(pp)) + remove(all_asked, pp) + remove(all_noise, pp) + remove(all_score, pp) + end + + while #all_asked < self.popsize do + local a, n = {}, {} for i=1, self.dims do n[i] = normal() end for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end + --print(("unconditionally added new sample %i"):format(#all_asked)) + insert(all_asked, a) + insert(all_noise, n) + insert(all_score, false) + end - -- can't cache here! - local prob_new = 0 - local prob_old = 0 - for i, v in ipairs(a) do - prob_new = prob_new + pdf(v, mean_new[i], std_new[i]) - prob_old = prob_old + pdf(v, mean_old[i], std_old[i]) - end - - local accept = max(1 - prob_old / prob_new, self.min_refresh) - if uniform() < accept then + -- split all_ tables back into old_ and new_. + self.old_asked, self.old_noise, self.old_score = {}, {}, {} + self.new_asked, self.new_noise = {}, {} + for i, score in ipairs(all_score) do + local a, n = all_asked[i], all_noise[i] + if score ~= false then + insert(self.old_asked, a) + insert(self.old_noise, n) + insert(self.old_score, score) + else insert(self.new_asked, a) insert(self.new_noise, n) - --print(("accepted new sample %i with probability %f"):format(0, accept)) end end @@ -210,15 +240,15 @@ end function Snes:tell(scored) self.evals = self.evals + #scored - local asked = self.asked - local noise = self.noise + local asked = self.mixing and self.new_asked or self.asked + local noise = self.mixing and self.new_noise or self.noise if self.mixing then + -- note: modifies, in-place, externally exposed tables. 
+ for i, v in ipairs(asked) do insert(self.old_asked, v) end + for i, v in ipairs(noise) do insert(self.old_noise, v) end + for i, v in ipairs(scored) do insert(self.old_score, v) end asked = self.old_asked noise = self.old_noise - -- note that these modify tables referenced externally in-place. - for i, v in ipairs(self.new_asked) do insert(asked, v) end - for i, v in ipairs(self.new_noise) do insert(noise, v) end - for i, v in ipairs(scored) do insert(self.old_score, v) end scored = self.old_score end assert(asked and noise, ":tell() called before :ask()") @@ -231,8 +261,9 @@ function Snes:tell(scored) local g_mean = zeros{self.dims} local g_std = zeros{self.dims} +--[[ local utilize = true - local utility + local utility = self.utility if utilize then utility = {} @@ -242,17 +273,18 @@ function Snes:tell(scored) else utility = normalize_sums(scored, {}) end +--]] for p=1, self.popsize do - local noise_p = noise[p] + local noise_p = noise[arg[p]] for i, v in ipairs(g_mean) do - g_mean[i] = v + utility[p] * noise_p[i] + g_mean[i] = v + self.utility[p] * noise_p[i] end for i, v in ipairs(g_std) do local n = noise_p[i] - g_std[i] = v + utility[p] * (n * n - 1) + g_std[i] = v + self.utility[p] * (n * n - 1) end end @@ -261,7 +293,9 @@ function Snes:tell(scored) step[i] = self.std[i] * v end + self.mean_old = {} for i, v in ipairs(self.mean) do + self.mean_old[i] = v self.mean[i] = v + self.param_rate * step[i] end @@ -273,7 +307,7 @@ function Snes:tell(scored) otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i]) end - self:adapt(asked, otherwise, utility) + if self.adaptive then self:adapt(asked, otherwise, self.utility) end return step end @@ -291,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities) weights[p] = prob_big / prob_now end - local p = weighted_mann_whitney(qualities, qualities, nil, weights) - --print("p:", p) + local u, p = weighted_mann_whitney(qualities, qualities, nil, weights) + --print(("u, p: %6.3f, %6.3f"):format(u, p)) if p < 0.5 
-- mean and (approximately) unbiased standard deviation of the array x.
-- requires at least two elements (asserted).
-- returns mean, deviation.
local function calc_mean_dev_unbiased(x)
    -- NOTE: this uses an approximation; there is still a little bias.
    assert(#x > 1)

    local mean = 0
    for i, v in ipairs(x) do
        mean = mean + v / #x
    end

    -- via Gurland, John; Tripathi, Ram C. (1971):
    -- A Simple Approximation for Unbiased Estimation of the Standard Deviation
    local divisor = #x - 1.5 + 1 / (8 * (#x - 1))

    local dev = 0
    for i, v in ipairs(x) do
        local delta = v - mean
        dev = dev + delta * delta / divisor
    end

    return mean, sqrt(dev)
end

local function expect_cossim(n)
    -- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
    -- this is the expected absolute cosine similarity between
    -- two standard normally-distributed random vectors both of size n.
    assert(n > 0)

    -- abs(error) < 1e-8
    if n >= 128000 then
        return 1 / sqrt(pi / 2 * n + 1)
    elseif n >= 80 then
        -- every temporary here is declared local so that calling this
        -- function never creates (or trips strict-mode checks on) globals.
        local poly = (2.4674010 * n + -2.4673232) * n + 1.2274046
        return 1 / sqrt(sqrt(poly))
    end
    -- fall-through when it's faster just to compute iteratively.

    -- the product starts at 2 for even n and 1 for odd n,
    -- and even n picks up the extra division by pi at the end.
    local even = n % 2 == 0
    local res = even and 2.0 or 1.0
    for i = even and 2 or 1, n - 1, 2 do
        res = res * (i / (i + 1))
    end
    return even and res / pi or res
end