-- Separable Natural Evolution Strategies
-- this particular implementation is based on:
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
-- not to be confused with the Super Nintendo Entertainment System.

local assert = assert
local floor = math.floor
local ipairs = ipairs
local log = math.log
local exp = math.exp
local max = math.max
local sqrt = math.sqrt

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local clamp = util.clamp

local Snes = Base:extend()

-- NOTE: duplicated in xnes.lua!
--- Build the rank-based fitness-shaping utilities.
-- utility[k] = max(0, log(popsize/2 + 1) - log(k)) / sum - 1/popsize,
-- where sum is the total of the max(...) terms, so the utilities sum to zero.
-- Index 1 corresponds to the best-ranked sample.
-- @tparam number popsize number of samples per generation
-- @tparam[opt] table out table to fill in place; a new table is used if nil
-- @treturn table the utilities (entries 1..popsize)
local function make_utility(popsize, out)
    local utility = out or {}
    local temp = log(popsize / 2 + 1)
    for i=1, popsize do utility[i] = max(0, temp - log(i)) end

    -- bound the loops by popsize rather than using ipairs: a reused `out`
    -- table may hold extra entries past popsize that must not be counted.
    local sum = 0
    for i=1, popsize do sum = sum + utility[i] end
    for i=1, popsize do utility[i] = utility[i] / sum - 1 / popsize end
    return utility
end

--- Set up the search distribution.
-- @tparam number dims number of parameters being optimized
-- @tparam[opt] number popsize samples per generation
--   (default 4 + floor(3 * log(dims)); doubled when antithetic)
-- @tparam[opt] number learning_rate step size of the std update
--   (default (3 + log(dims)) / (5 * sqrt(dims)))
-- @tparam[opt=1] number sigma initial std of every dimension
-- @tparam[opt=false] boolean antithetic sample mirrored pairs of noise
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
    self.dims = dims
    -- heuristic borrowed from CMA-ES (also the paper's default):
    -- 4 + floor(3 * ln(d)).
    self.popsize = popsize or 4 + floor(3 * log(dims))
    -- SNES default for the sigma learning rate from the paper's table of
    -- defaults. (3/5 * (3 + log d) / (d * sqrt d) is the xNES value.)
    self.learning_rate = learning_rate or (3 + log(dims)) / (5 * sqrt(dims))
    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false

    -- each antithetic draw produces a pair of samples.
    if self.antithetic then self.popsize = self.popsize * 2 end

    self.utility = make_utility(self.popsize)

    self.mean = zeros{dims}
    self.std = zeros{dims}
    for i=1, self.dims do self.std[i] = self.sigma end

    -- learning rate of the mean update (1 in the paper).
    self.mean_adapt = 1.0
end

--- Get, and optionally overwrite in place, the current mean.
-- @tparam[opt] table new_mean values to copy into the mean; must match in size
-- @treturn table the (live) mean table
function Snes:params(new_mean)
    if new_mean ~= nil then
        assert(#self.mean == #new_mean, "new parameters have the wrong size")
        for i, v in ipairs(new_mean) do self.mean[i] = v end
    end
    return self.mean
end

--- Draw one sample: asked = mean + std * noise, noise ~ N(0, 1) per dimension.
-- @tparam[opt] table asked output table for the sample
-- @tparam[opt] table noise output table for the standard-normal noise
-- @return asked, noise
function Snes:ask_once(asked, noise)
    asked = asked or zeros(self.dims)
    noise = noise or {}

    for i=1, self.dims do noise[i] = normal() end
    noise.shape = {#noise}

    for i, v in ipairs(noise) do asked[i] = self.mean[i] + self.std[i] * v end

    return asked, noise
end

--- Draw a mirrored pair of samples sharing one noise vector:
-- asked0 = mean + std * noise0 and asked1 = mean - std * noise0,
-- with noise1 = -noise0.
-- @return asked0, asked1, noise0, noise1
function Snes:ask_twice(asked0, asked1, noise0, noise1)
    asked0 = asked0 or zeros(self.dims)
    asked1 = asked1 or zeros(self.dims)
    noise0 = noise0 or {}
    noise1 = noise1 or {}

    for i=1, self.dims do noise0[i] = normal() end
    noise0.shape = {#noise0}

    for i, v in ipairs(noise0) do
        asked0[i] = self.mean[i] + self.std[i] * v
        asked1[i] = self.mean[i] - self.std[i] * v
    end

    for i, v in ipairs(noise0) do noise1[i] = -v end
    -- keep noise1 consistent with noise0 / ask_once, which both set .shape.
    noise1.shape = {#noise1}

    return asked0, asked1, noise0, noise1
end

-- NOTE: duplicated in xnes.lua!
--- Return a list of parameters for the user to score, and later pass to :tell().
-- The noise used is remembered so :tell() can be called without it.
-- @tparam[opt] table asked list of popsize output tables
-- @tparam[opt] table noise list of popsize noise tables
-- @return asked, noise
function Snes:ask(asked, noise)
    if asked == nil then
        asked = {}
        for i=1, self.popsize do asked[i] = zeros(self.dims) end
    end
    if noise == nil then
        noise = {}
        for i=1, self.popsize do noise[i] = zeros(self.dims) end
    end

    if self.antithetic then
        -- step by 2: each call fills slots i and i+1 with a mirrored pair.
        -- popsize is always even here because init doubles it.
        for i=1, self.popsize, 2 do
            self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
        end
    else
        for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end
    end

    self.noise = noise
    return asked, noise
end

--- Update the mean and std from the scores of the last :ask() batch.
-- Higher scores are treated as better (sorted descending).
-- @tparam table scored one score per sample, in the order they were asked
-- @tparam[opt] table noise the noise list from :ask(); defaults to the stored one
function Snes:tell(scored, noise)
    noise = noise or self.noise
    assert(noise, "missing noise argument")

    local arg = argsort(scored, function(a, b) return a > b end)

    -- natural-gradient estimates for the mean and (log-)std,
    -- weighting each sample's noise by the utility of its rank.
    local g_mean = zeros{self.dims}
    local g_std = zeros{self.dims}
    for p=1, self.popsize do
        local u = self.utility[p]
        local noise_p = noise[arg[p]]
        for i=1, self.dims do
            local n = noise_p[i]
            g_mean[i] = g_mean[i] + u * n
            g_std[i] = g_std[i] + u * (n * n - 1)
        end
    end

    -- the mean step is scaled by the std from *before* this update.
    for i, v in ipairs(self.mean) do
        self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
    end

    for i, v in ipairs(self.std) do
        self.std[i] = v * exp(self.learning_rate / 2 * g_std[i])
    end

    -- bookkeeping: the stored noise belongs to the batch just consumed.
    self.noise = nil
end

return {
    make_utility = make_utility,
    Snes = Snes,
}