diff --git a/config.lua b/config.lua index 66ceae4..67323fe 100644 --- a/config.lua +++ b/config.lua @@ -36,6 +36,7 @@ local common_cfg = { mean_adapt = 1.0, -- for xNES weight_decay = 0.0, sigma_decay = 0.0, + min_refresh = 0.2, es = 'ars', ars_lips = false, diff --git a/main.lua b/main.lua index e5d17dd..a8a4a95 100644 --- a/main.lua +++ b/main.lua @@ -199,6 +199,7 @@ local function prepare_epoch() elseif cfg.es == 'snes' then local sigma_mean, sigma_dev = calc_mean_dev(es.std) --print("sigma:", sigma_mean, sigma_dev) + print("sigma 50%:", sigma_mean) print("sigma 95%:", sigma_mean + sigma_dev * 1.64485) end @@ -211,6 +212,8 @@ local function prepare_epoch() local dummy if cfg.es == 'ars' then trial_params, dummy = es:ask(precision) + elseif cfg.es == 'snes' then + trial_params, dummy = es:ask_mix() else trial_params, dummy = es:ask() end @@ -288,12 +291,12 @@ local function learn_from_epoch() if cfg.es == 'snes' then if cfg.sigma_decay > 0 then for i, v in ipairs(es.std) do - es.std[i] = v * (1 - cfg.sigma_decay) + es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay) end end if cfg.weight_decay > 0 then for i, v in ipairs(base_params) do - base_params[i] = v * (1 - cfg.weight_decay * es.std[i]) + base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i]) end end else @@ -502,6 +505,7 @@ local function init() cfg.learning_rate, cfg.deviation, cfg.negate_trials) -- TODO: clean this up into an interface: es.mean_adapt = cfg.mean_adapt + es.min_refresh = cfg.min_refresh if exists(std_fn) then local f = assert(open(std_fn, "r")) diff --git a/snes.lua b/snes.lua index ad28d18..1f43b7e 100644 --- a/snes.lua +++ b/snes.lua @@ -3,36 +3,35 @@ -- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf -- not to be confused with the Super Nintendo Entertainment System. +local abs = math.abs local assert = assert +local exp = math.exp local floor = math.floor local ipairs = ipairs local log = math.log -local exp = math.exp local max = math.max +local min = math.min +local sqrt = math.sqrt +local insert = table.insert +local remove = table.remove local Base = require "Base" local nn = require "nn" local normal = nn.normal +local uniform = nn.uniform local zeros = nn.zeros local util = require "util" local argsort = util.argsort +local cdf = util.cdf local clamp = util.clamp +local normalize_sums = util.normalize_sums +local pdf = util.pdf +local weighted_mann_whitney = util.weighted_mann_whitney local Snes = Base:extend() --- NOTE: duplicated in xnes.lua! -local function make_utility(popsize, out) - local utility = out or {} - local temp = log(popsize / 2 + 1) - for i=1, popsize do utility[i] = max(0, temp - log(i)) end - local sum = 0 - for _, v in ipairs(utility) do sum = sum + v end - for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end - return utility -end - function Snes:init(dims, popsize, learning_rate, sigma, antithetic) -- heuristic borrowed from CMA-ES: self.dims = dims @@ -43,13 +42,21 @@ function Snes:init(dims, popsize, learning_rate, sigma, antithetic) if self.antithetic then self.popsize = self.popsize * 2 end - self.utility = make_utility(self.popsize) + self.rate_init = self.learning_rate self.mean = zeros{dims} self.std = zeros{dims} for i=1, self.dims do self.std[i] = self.sigma end self.mean_adapt = 1.0 + + self.old_asked = {} + self.old_noise = {} + self.old_score = {} + self.new_asked = {} + self.new_noise = {} + + self.evals = 0 end function Snes:params(new_mean) @@ -61,12 +68,10 @@ function Snes:params(new_mean) end function Snes:ask_once(asked, noise) - asked = asked or zeros(self.dims) + asked = asked or {} noise = noise or {} for i=1, self.dims do noise[i] = normal() end - noise.shape = {#noise} - for i, v in ipairs(noise) do asked[i] = self.mean[i] + self.std[i] * v end return asked, noise @@ -90,10 +95,10 @@ function Snes:ask_twice(asked0, asked1, noise0, noise1) return asked0, asked1, noise0, noise1 end --- NOTE: duplicated in xnes.lua! function Snes:ask(asked, noise) -- return a list of parameters for the user to score, -- and later pass to :tell(). + self.mixing = false if asked == nil then asked = {} for i=1, self.popsize do asked[i] = zeros(self.dims) end @@ -113,30 +118,129 @@ function Snes:ask(asked, noise) end end + self.asked = asked self.noise = noise return asked, noise end -function Snes:tell(scored, noise) - local noise = noise or self.noise - assert(noise, "missing noise argument") +function Snes:ask_mix(start_anew) + -- TODO: refactor and merge with :ask()? + self.mixing = true + if start_anew then + self.old_asked = {} + self.old_noise = {} + self.old_score = {} + end - local arg = argsort(scored, function(a, b) return a > b end) + -- perform importance mixing. - local g_mean = zeros{self.dims} - for p=1, self.popsize do - local noise_p = noise[arg[p]] - for i, v in ipairs(g_mean) do - g_mean[i] = v + self.utility[p] * noise_p[i] + local mean_old = self.mean + local mean_new = self.mean + local std_old = self.std_old or self.std + local std_new = self.std + + self.new_asked = {} + self.new_noise = {} + + local marked = {} + for p=1, min(#self.old_asked, self.popsize) do + local a = self.old_asked[p] + local n = self.old_noise[p] + + -- TODO: cache probs? + local prob_new = 0 + local prob_old = 0 + for i, v in ipairs(a) do + prob_new = prob_new + pdf(v, mean_new[i], std_new[i]) + prob_old = prob_old + pdf(v, mean_old[i], std_old[i]) + end + + local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1) + if uniform() < accept then + --print(("accepted old sample %i with probability %f"):format(p, accept)) + else + -- insert in reverse as not to screw up + -- the indices when removing later. + insert(marked, 1, p) + end + end + for _, p in ipairs(marked) do + remove(self.old_asked, p) + remove(self.old_noise, p) + remove(self.old_score, p) + end + + while #self.old_asked + #self.new_asked < self.popsize do + local a = {} + local n = {} + for i=1, self.dims do n[i] = normal() end + for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end + + -- can't cache here! + local prob_new = 0 + local prob_old = 0 + for i, v in ipairs(a) do + prob_new = prob_new + pdf(v, mean_new[i], std_new[i]) + prob_old = prob_old + pdf(v, mean_old[i], std_old[i]) + end + + local accept = max(1 - prob_old / prob_new, self.min_refresh) + if uniform() < accept then + insert(self.new_asked, a) + insert(self.new_noise, n) + --print(("accepted new sample %i with probability %f"):format(0, accept)) end end + return self.new_asked, self.new_noise +end + +function Snes:tell(scored) + self.evals = self.evals + #scored + + local asked = self.asked + local noise = self.noise + if self.mixing then + asked = self.old_asked + noise = self.old_noise + -- note that these modify tables referenced externally in-place. + for i, v in ipairs(self.new_asked) do insert(asked, v) end + for i, v in ipairs(self.new_noise) do insert(noise, v) end + for i, v in ipairs(scored) do insert(self.old_score, v) end + scored = self.old_score + end + assert(asked and noise, ":tell() called before :ask()") + assert(#asked == #noise and #asked == #scored, "length mismatch") + assert(#scored == self.popsize) + + -- TODO: use a proper ranking function. + local arg = argsort(scored, function(a, b) return a > b end) + + local g_mean = zeros{self.dims} local g_std = zeros{self.dims} + + local utilize = true + local utility + + if utilize then + utility = {} + local const = log(self.popsize * 0.5 + 1) + for i, v in ipairs(arg) do utility[v] = max(const - log(i), 0) end + normalize_sums(utility) + else + utility = normalize_sums(scored, {}) + end + for p=1, self.popsize do - local noise_p = noise[arg[p]] + local noise_p = noise[p] + + for i, v in ipairs(g_mean) do + g_mean[i] = v + utility[p] * noise_p[i] + end + for i, v in ipairs(g_std) do local n = noise_p[i] - g_std[i] = v + self.utility[p] * (n * n - 1) + g_std[i] = v + utility[p] * (n * n - 1) end end @@ -144,16 +248,42 @@ function Snes:tell(scored, noise) self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i] end + local otherwise = {} + self.std_old = {} for i, v in ipairs(self.std) do - self.std[i] = v * exp(self.learning_rate / 2 * g_std[i]) + self.std_old[i] = v + self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i]) + otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i]) end - -- bookkeeping: - self.noise = nil + self:adapt(asked, otherwise, utility) +end + +function Snes:adapt(asked, otherwise, qualities) + local weights = {} + for p=1, self.popsize do + local asked_p = asked[p] + local prob_now = 0 + local prob_big = 0 + for i, v in ipairs(asked_p) do + prob_now = prob_now + pdf(v, self.mean[i], self.std[i]) + prob_big = prob_big + pdf(v, self.mean[i], otherwise[i]) + end + weights[p] = prob_big / prob_now + end + + local p = weighted_mann_whitney(qualities, qualities, nil, weights) + --print("p:", p) + + if p < 0.5 - 1 / (3 * (self.dims + 1)) then + self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init + print("learning rate -:", self.learning_rate) + else + self.learning_rate = min(1.1 * self.learning_rate, 1) + print("learning rate +:", self.learning_rate) + end end return { - make_utility = make_utility, - Snes = Snes, } diff --git a/util.lua b/util.lua index e03802c..89b1298 100644 --- a/util.lua +++ b/util.lua @@ -207,6 +207,43 @@ local function cdf(x) return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x))) end +local function weighted_mann_whitney(s0, s1, w0, w1) + -- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney. + if w0 == nil then + w0 = {} + for i=1, #s0 do w0[i] = 1.0 end + end + if w1 == nil then + w1 = {} + for i=1, #s1 do w1[i] = 1.0 end + end + assert(#s0 == #w0) + assert(#s1 == #w1) + + local s0_sum, s1_sum, w0_sum, w1_sum = 0, 0, 0, 0 + for i, v in ipairs(s0) do s0_sum = s0_sum + v end + for i, v in ipairs(s1) do s1_sum = s1_sum + v end + for i, v in ipairs(w0) do w0_sum = w0_sum + v end + for i, v in ipairs(w1) do w1_sum = w1_sum + v end + + local U = 0 + for i=1, #s0 do + for j=1, #s1 do + if s0[i] > s1[j] then + U = U + w0[i] * w1[j] + elseif s0[i] == s1[j] then + U = U + w0[i] * w1[j] * 0.5 + end + end + end + + local mean = w0_sum * w1_sum * 0.5 + local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6) + local p = cdf((U - mean) / std) + + if s0_sum > s1_sum then return 1 - p else return p end +end + return { signbyte=signbyte, boolean_xor=boolean_xor, @@ -232,4 +269,5 @@ return { exists=exists, pdf=pdf, cdf=cdf, + weighted_mann_whitney=weighted_mann_whitney, }