From 3eebbc534a701f79e599e2f192ea505f56111df7 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sun, 10 Jun 2018 16:40:20 +0200 Subject: [PATCH] add SNES optimizer --- main.lua | 27 ++++++++++ snes.lua | 159 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 snes.lua diff --git a/main.lua b/main.lua index 1cf985f..3104ddc 100644 --- a/main.lua +++ b/main.lua @@ -175,6 +175,7 @@ end -- learning and evaluation. local ars = require("ars") +local snes = require("snes") local xnes = require("xnes") local function prepare_epoch() @@ -251,6 +252,9 @@ local function learn_from_epoch() if cfg.es == 'xnes' then print("sigma:", es.sigma) + elseif cfg.es == 'snes' then + local sigma_mean, sigma_dev = calc_mean_dev(es.std) + print("sigma:", sigma_mean, sigma_dev) end local step_mean, step_dev = 0, 0 @@ -298,6 +302,14 @@ local function learn_from_epoch() if cfg.enable_network then network:distribute(base_params) network:save(cfg.params_fn) + + if cfg.es == 'snes' then + local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt" + local f = assert(open(std_fn, "w")) + for _, v in ipairs(es.std) do f:write(("%f\n"):format(v)) end + f:close() + end + else print("note: not updating weights in playable mode.") end @@ -458,6 +470,21 @@ local function init() cfg.learning_rate, cfg.deviation, cfg.negate_trials) -- TODO: clean this up into an interface: es.mean_adapt = cfg.mean_adapt + elseif cfg.es == 'snes' then + es = snes.Snes(network.n_param, cfg.epoch_trials, + cfg.learning_rate, cfg.deviation, cfg.negate_trials) + -- TODO: clean this up into an interface: + es.mean_adapt = cfg.mean_adapt + + local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt" + if exists(std_fn) then + local f = assert(open(std_fn, "r")) + for i=1, network.n_param do + es.std[i] = assert(tonumber(assert(f:read()))) + end + f:close() + end + elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials) diff --git a/snes.lua b/snes.lua new file mode 100644 index 0000000..ad28d18 --- /dev/null +++ b/snes.lua @@ -0,0 +1,159 @@ +-- Separable Natural Evolution Strategies +-- this particular implementation is based on: +-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf +-- not to be confused with the Super Nintendo Entertainment System. + +local assert = assert +local floor = math.floor +local ipairs = ipairs +local log = math.log +local exp = math.exp +local max = math.max + +local Base = require "Base" + +local nn = require "nn" +local normal = nn.normal +local zeros = nn.zeros + +local util = require "util" +local argsort = util.argsort +local clamp = util.clamp + +local Snes = Base:extend() + +-- NOTE: duplicated in xnes.lua! +local function make_utility(popsize, out) + local utility = out or {} + local temp = log(popsize / 2 + 1) + for i=1, popsize do utility[i] = max(0, temp - log(i)) end + local sum = 0 + for _, v in ipairs(utility) do sum = sum + v end + for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end + return utility +end + +function Snes:init(dims, popsize, learning_rate, sigma, antithetic) + -- heuristic borrowed from CMA-ES: + self.dims = dims + self.popsize = popsize or 4 + (3 * floor(log(dims))) + self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.sigma = sigma or 1 + self.antithetic = antithetic and true or false + + if self.antithetic then self.popsize = self.popsize * 2 end + + self.utility = make_utility(self.popsize) + + self.mean = zeros{dims} + self.std = zeros{dims} + for i=1, self.dims do self.std[i] = self.sigma end + + self.mean_adapt = 1.0 +end + +function Snes:params(new_mean) + if new_mean ~= nil then + assert(#self.mean == #new_mean, "new parameters have the wrong size") + for i, v in ipairs(new_mean) do self.mean[i] = v end + end + return self.mean +end + +function Snes:ask_once(asked, noise) + asked = asked or zeros(self.dims) + noise = noise or {} + + for i=1, self.dims do noise[i] = normal() end + noise.shape = {#noise} + + for i, v in ipairs(noise) do asked[i] = self.mean[i] + self.std[i] * v end + + return asked, noise +end + +function Snes:ask_twice(asked0, asked1, noise0, noise1) + asked0 = asked0 or zeros(self.dims) + asked1 = asked1 or zeros(self.dims) + noise0 = noise0 or {} + noise1 = noise1 or {} + + for i=1, self.dims do noise0[i] = normal() end + noise0.shape = {#noise0} + + for i, v in ipairs(noise0) do + asked0[i] = self.mean[i] + self.std[i] * v + asked1[i] = self.mean[i] - self.std[i] * v + end + for i, v in ipairs(noise0) do noise1[i] = -v end + + return asked0, asked1, noise0, noise1 +end + +-- NOTE: duplicated in xnes.lua! +function Snes:ask(asked, noise) + -- return a list of parameters for the user to score, + -- and later pass to :tell(). + if asked == nil then + asked = {} + for i=1, self.popsize do asked[i] = zeros(self.dims) end + end + if noise == nil then + noise = {} + for i=1, self.popsize do noise[i] = zeros(self.dims) end + end + + if self.antithetic then + for i=1, self.popsize do + self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1]) + end + else + for i=1, self.popsize do + self:ask_once(asked[i], noise[i]) + end + end + + self.noise = noise + return asked, noise +end + +function Snes:tell(scored, noise) + local noise = noise or self.noise + assert(noise, "missing noise argument") + + local arg = argsort(scored, function(a, b) return a > b end) + + local g_mean = zeros{self.dims} + for p=1, self.popsize do + local noise_p = noise[arg[p]] + for i, v in ipairs(g_mean) do + g_mean[i] = v + self.utility[p] * noise_p[i] + end + end + + local g_std = zeros{self.dims} + for p=1, self.popsize do + local noise_p = noise[arg[p]] + for i, v in ipairs(g_std) do + local n = noise_p[i] + g_std[i] = v + self.utility[p] * (n * n - 1) + end + end + + for i, v in ipairs(self.mean) do + self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i] + end + + for i, v in ipairs(self.std) do + self.std[i] = v * exp(self.learning_rate / 2 * g_std[i]) + end + + -- bookkeeping: + self.noise = nil +end + +return { + make_utility = make_utility, + + Snes = Snes, +}