-- Separable Natural Evolution Strategies
-- this particular implementation is based on:
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
-- not to be confused with the Super Nintendo Entertainment System.

local abs = math.abs
local assert = assert
local exp = math.exp
local floor = math.floor
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local sqrt = math.sqrt
local insert = table.insert
local remove = table.remove

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local uniform = nn.uniform
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local cdf = util.cdf
local clamp = util.clamp
local normalize_sums = util.normalize_sums
local pdf = util.pdf
local weighted_mann_whitney = util.weighted_mann_whitney

local Snes = Base:extend()

-- Set up the search distribution and bookkeeping.
--   dims:       number of parameters being optimized.
--   popsize:    samples per generation; defaults to 4 + 3 * floor(log(dims)).
--   base_rate:  initial value for both sigma_rate and covar_rate;
--               defaults to 3/5 * (3 + log(dims)) / (dims * sqrt(dims)).
--   sigma:      initial per-dimension standard deviation; defaults to 1.
--   antithetic: when truthy, popsize is doubled so that samples can be
--               drawn as mirrored pairs (see :ask_twice()).
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
    -- heuristic borrowed from CMA-ES:
    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = 1.0
    self.sigma_rate = base_rate
    self.covar_rate = base_rate
    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false

    if self.antithetic then self.popsize = self.popsize * 2 end

    -- remembered so that :adapt() can pull sigma_rate back towards it.
    self.rate_init = self.sigma_rate

    self.mean = zeros{dims}
    self.std = zeros{dims}
    for i=1, self.dims do self.std[i] = self.sigma end

    -- samples carried over between generations by :ask_mix().
    self.old_asked = {}
    self.old_noise = {}
    self.old_score = {}

    -- samples freshly drawn by the latest :ask_mix().
    self.new_asked = {}
    self.new_noise = {}

    self.evals = 0
end

-- Get or set the distribution mean.
-- When new_mean is given, its values are copied into self.mean;
-- it must have the same length as the current mean.
-- Always returns self.mean itself (not a copy).
function Snes:params(new_mean)
    if new_mean ~= nil then
        assert(#self.mean == #new_mean, "new parameters have the wrong size")
        for i, v in ipairs(new_mean) do self.mean[i] = v end
    end

    return self.mean
end

-- Shrink the standard deviations and/or the mean towards zero.
-- Each decay is only applied when its factor is positive.
-- The std decay runs first, so the mean decay uses the already-shrunk std.
function Snes:decay(param_decay, sigma_decay)
    if sigma_decay > 0 then
        for i, v in ipairs(self.std) do
            self.std[i] = v * (1 - self.sigma_rate * sigma_decay)
        end
    end

    if param_decay > 0 then
        for i, v in ipairs(self.mean) do
            self.mean[i] = v * (1 - self.param_rate * param_decay * self.std[i])
        end
    end
end

-- Draw a single sample: asked[i] = mean[i] + std[i] * noise[i].
-- Both tables are filled in-place when given, otherwise created.
-- Returns asked, noise.
function Snes:ask_once(asked, noise)
    asked = asked or {}
    noise = noise or {}

    for i=1, self.dims do noise[i] = normal() end

    for i, v in ipairs(noise) do
        asked[i] = self.mean[i] + self.std[i] * v
    end

    return asked, noise
end

-- Draw a mirrored pair of samples from one noise vector:
--   asked0[i] = mean[i] + std[i] * noise0[i]
--   asked1[i] = mean[i] - std[i] * noise0[i]
-- noise1 is set to the negation of noise0.
-- All four tables are filled in-place when given, otherwise created.
-- Returns asked0, asked1, noise0, noise1.
function Snes:ask_twice(asked0, asked1, noise0, noise1)
    asked0 = asked0 or zeros(self.dims)
    asked1 = asked1 or zeros(self.dims)
    noise0 = noise0 or {}
    noise1 = noise1 or {}

    for i=1, self.dims do noise0[i] = normal() end
    noise0.shape = {#noise0}

    for i, v in ipairs(noise0) do
        asked0[i] = self.mean[i] + self.std[i] * v
        asked1[i] = self.mean[i] - self.std[i] * v
    end

    for i, v in ipairs(noise0) do noise1[i] = -v end

    return asked0, asked1, noise0, noise1
end

-- Draw a full population.
-- asked and noise are optional lists of popsize preallocated tables;
-- they are filled in-place, or allocated here when nil.
-- The results are remembered on self for the following :tell().
-- Returns asked, noise.
function Snes:ask(asked, noise)
    -- return a list of parameters for the user to score,
    -- and later pass to :tell().
    self.mixing = false

    if asked == nil then
        asked = {}
        for i=1, self.popsize do asked[i] = zeros(self.dims) end
    end

    if noise == nil then
        noise = {}
        for i=1, self.popsize do noise[i] = zeros(self.dims) end
    end

    if self.antithetic then
        -- each call fills slots i and i+1, so advance two slots at a time;
        -- stepping by one would overwrite every mirrored sample with a
        -- fresh draw and run past the end of the population.
        for i=1, self.popsize, 2 do
            self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
        end
    else
        for i=1, self.popsize do
            self:ask_once(asked[i], noise[i])
        end
    end

    self.asked = asked
    self.noise = noise
    return asked, noise
end

-- Like :ask(), but tries to reuse samples from the previous generation.
-- Old samples are kept or dropped at random based on how probable they are
-- under the current distribution; new samples are then drawn (and randomly
-- rejected) until old and new together number popsize.
-- start_anew: when truthy, forget all previously kept samples first.
-- Requires self.min_refresh to be set by the caller; :init() does not set it.
-- Returns only the newly drawn samples and their noise;
-- the caller scores those and passes the scores to :tell().
function Snes:ask_mix(start_anew)
    -- TODO: refactor and merge with :ask()?
    -- fail early with a readable message rather than a nil-arithmetic error.
    -- NOTE(review): a min_refresh of 0 looks like it can loop forever below
    -- whenever the old and new distributions coincide -- confirm.
    assert(self.min_refresh ~= nil,
           "min_refresh must be set before calling :ask_mix()")

    self.mixing = true

    if start_anew then
        self.old_asked = {}
        self.old_noise = {}
        self.old_score = {}
    end

    -- perform importance mixing.

    -- NOTE(review): mean_old is the same table as mean_new, so only the
    -- change in std distinguishes the old distribution -- confirm intended.
    local mean_old = self.mean
    local mean_new = self.mean
    local std_old = self.std_old or self.std
    local std_new = self.std

    self.new_asked = {}
    self.new_noise = {}

    local marked = {}
    for p=1, min(#self.old_asked, self.popsize) do
        local a = self.old_asked[p]

        -- TODO: cache probs?
        -- NOTE(review): the per-dimension pdf values are summed here.
        local prob_new = 0
        local prob_old = 0
        for i, v in ipairs(a) do
            prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
            prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
        end

        local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
        if uniform() < accept then
            --print(("accepted old sample %i with probability %f"):format(p, accept))
        else
            -- insert in reverse as not to screw up
            -- the indices when removing later.
            insert(marked, 1, p)
        end
    end

    -- marked is in descending order, so earlier removals
    -- don't shift the positions of later ones.
    for _, p in ipairs(marked) do
        remove(self.old_asked, p)
        remove(self.old_noise, p)
        remove(self.old_score, p)
    end

    while #self.old_asked + #self.new_asked < self.popsize do
        local a = {}
        local n = {}
        for i=1, self.dims do n[i] = normal() end
        for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end

        -- can't cache here!
        local prob_new = 0
        local prob_old = 0
        for i, v in ipairs(a) do
            prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
            prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
        end

        local accept = max(1 - prob_old / prob_new, self.min_refresh)
        if uniform() < accept then
            insert(self.new_asked, a)
            insert(self.new_noise, n)
            --print(("accepted new sample %i with probability %f"):format(0, accept))
        end
    end

    return self.new_asked, self.new_noise
end

-- Update the distribution from the scores of the last asked samples.
-- scored: list of scores, in the same order as the samples returned by
--         the preceding :ask() or :ask_mix().
-- After :ask_mix(), the new samples and scores are appended to the kept
-- old ones, and the update uses the combined population.
-- Asserts that the combined population has exactly popsize members.
-- Returns the per-dimension step std[i] * g_mean[i] (using std from before
-- this update); the mean itself moves by param_rate times that step.
function Snes:tell(scored)
    self.evals = self.evals + #scored

    local asked = self.asked
    local noise = self.noise
    if self.mixing then
        asked = self.old_asked
        noise = self.old_noise
        -- note that these modify tables referenced externally in-place.
        for i, v in ipairs(self.new_asked) do insert(asked, v) end
        for i, v in ipairs(self.new_noise) do insert(noise, v) end
        for i, v in ipairs(scored) do insert(self.old_score, v) end
        scored = self.old_score
    end
    assert(asked and noise, ":tell() called before :ask()")
    assert(#asked == #noise and #asked == #scored, "length mismatch")
    assert(#scored == self.popsize)

    -- TODO: use a proper ranking function.
    -- presumably argsort returns sample indices ordered by the comparator,
    -- i.e. highest score first -- verify against util.argsort.
    local arg = argsort(scored, function(a, b) return a > b end)

    local g_mean = zeros{self.dims}
    local g_std = zeros{self.dims}

    local utilize = true
    local utility
    if utilize then
        -- rank-based utilities: position i in the ordering is worth
        -- max(log(popsize / 2 + 1) - log(i), 0).
        utility = {}
        local const = log(self.popsize * 0.5 + 1)
        for i, v in ipairs(arg) do
            utility[v] = max(const - log(i), 0)
        end
        normalize_sums(utility)
    else
        utility = normalize_sums(scored, {})
    end

    -- accumulate the utility-weighted gradient estimates.
    for p=1, self.popsize do
        local noise_p = noise[p]

        for i, v in ipairs(g_mean) do
            g_mean[i] = v + utility[p] * noise_p[i]
        end

        for i, v in ipairs(g_std) do
            local n = noise_p[i]
            g_std[i] = v + utility[p] * (n * n - 1)
        end
    end

    local step = {}
    for i, v in ipairs(g_mean) do
        step[i] = self.std[i] * v
    end

    for i, v in ipairs(self.mean) do
        self.mean[i] = v + self.param_rate * step[i]
    end

    -- "otherwise" is the std that a 1.5x larger sigma_rate would have
    -- produced; :adapt() compares it against the std actually chosen.
    local otherwise = {}
    self.std_old = {}
    for i, v in ipairs(self.std) do
        self.std_old[i] = v
        self.std[i] = v * exp(self.sigma_rate * 0.5 * g_std[i])
        otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
    end

    self:adapt(asked, otherwise, utility)

    return step
end

-- Adjust sigma_rate after an update.
-- asked:     the population used for the update.
-- otherwise: the alternative std computed with a larger sigma_rate.
-- qualities: per-sample quality values (the utilities from :tell()).
-- Each sample is weighted by its probability under the alternative std
-- relative to the current std, and the result of weighted_mann_whitney
-- decides the direction: below the threshold, sigma_rate moves 10% of the
-- way back to rate_init; otherwise it grows by 10%, capped at 1.
-- Prints the new rate either way.
function Snes:adapt(asked, otherwise, qualities)
    local weights = {}
    for p=1, self.popsize do
        local asked_p = asked[p]
        local prob_now = 0
        local prob_big = 0
        for i, v in ipairs(asked_p) do
            prob_now = prob_now + pdf(v, self.mean[i], self.std[i])
            prob_big = prob_big + pdf(v, self.mean[i], otherwise[i])
        end
        weights[p] = prob_big / prob_now
    end

    local p = weighted_mann_whitney(qualities, qualities, nil, weights)
    --print("p:", p)

    if p < 0.5 - 1 / (3 * (self.dims + 1)) then
        self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
        print("learning rate -:", self.sigma_rate)
    else
        self.sigma_rate = min(1.1 * self.sigma_rate, 1)
        print("learning rate +:", self.sigma_rate)
    end
end

return {
    Snes = Snes,
}