diff --git a/ars.lua b/ars.lua index cc9431a..99e3c46 100644 --- a/ars.lua +++ b/ars.lua @@ -201,7 +201,7 @@ function Ars:tell(scored, unperturbed_score) self._params[i] = v + self.learning_rate * step[i] end - self.asked = nil + self.noise = nil end return { diff --git a/config.lua b/config.lua index a3b00ba..3106c45 100644 --- a/config.lua +++ b/config.lua @@ -90,6 +90,8 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial, "cfg.unperturbed_trial must be true to use cfg.ars_lips") assert(not cfg.ars_lips or cfg.negate_trials, "cfg.negate_trials must be true to use cfg.ars_lips") +assert(not (cfg.es == 'xnes') or not cfg.negate_trials, + "cfg.negate_trials is not yet compatible with xNES") assert(not cfg.adamant, "cfg.adamant not yet re-implemented") diff --git a/main.lua b/main.lua index e51e169..0b92d01 100644 --- a/main.lua +++ b/main.lua @@ -50,27 +50,28 @@ local last_trial_state -- localize some stuff. -local assert = assert -local print = print -local ipairs = ipairs -local pairs = pairs -local select = select -local open = io.open local abs = math.abs -local floor = math.floor +local assert = assert local ceil = math.ceil -local min = math.min -local max = math.max +local collectgarbage = collectgarbage local exp = math.exp -local pow = math.pow +local floor = math.floor +local insert = table.insert +local ipairs = ipairs local log = math.log -local sqrt = math.sqrt +local max = math.max +local min = math.min +local open = io.open +local pairs = pairs +local pow = math.pow +local print = print local random = math.random local randomseed = math.randomseed -local insert = table.insert local remove = table.remove -local unpack = table.unpack or unpack +local select = select local sort = table.sort +local sqrt = math.sqrt +local unpack = table.unpack or unpack local band = bit.band local bor = bit.bor @@ -173,6 +174,7 @@ end -- learning and evaluation. 
local ars = require("ars") +local xnes = require("xnes") local function prepare_epoch() trial_neg = false @@ -180,7 +182,7 @@ local function prepare_epoch() base_params = network:collect() if cfg.playback_mode then return end - print('preparing epoch '..tostring(epoch_i)..'.') + print('preparing epoch '..tostring(epoch_i)..'...') empty(trial_rewards) local precision @@ -190,7 +192,7 @@ local function prepare_epoch() end local dummy - if es == 'ars' then + if cfg.es == 'ars' then trial_params, dummy = es:ask(precision) else trial_params, dummy = es:ask() @@ -236,12 +238,16 @@ local function learn_from_epoch() end end - if es == 'ars' then + if cfg.es == 'ars' and cfg.ars_lips then es:tell(trial_rewards, current_cost) else es:tell(trial_rewards) end + if cfg.es == 'xnes' then + print("sigma:", es.sigma) + end + local step_mean, step_dev = 0, 0 --[[ TODO local step_mean, step_dev = calc_mean_dev(step) @@ -369,6 +375,7 @@ local function do_reset() if epoch_i > 0 then learn_from_epoch() end if not cfg.playback_mode then epoch_i = epoch_i + 1 end prepare_epoch() + collectgarbage() if any_random then loadlevel(cfg.starting_world, cfg.starting_level) state_saved = false @@ -438,7 +445,14 @@ local function init() local res, err = pcall(network.load, network, cfg.params_fn) if res == false then print(err) end - if cfg.es == 'ars' then + if cfg.es == 'xnes' then + -- if you get an out of memory error, you can't use xNES. sorry! + -- maybe there'll be a patch for FCEUX in the future. 
+ local trials = cfg.epoch_trials + if cfg.negate_trials then trials = trials * 2 end + es = xnes.Xnes(network.n_param, trials, cfg.learning_rate, + cfg.deviation, cfg.negate_trials) + elseif cfg.es == 'ars' then es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials, cfg.learning_rate, cfg.deviation, cfg.negate_trials) else diff --git a/xnes.lua b/xnes.lua new file mode 100644 index 0000000..b3c2fe5 --- /dev/null +++ b/xnes.lua @@ -0,0 +1,206 @@ +-- Exponential Natural Evolution Strategies +-- http://people.idsia.ch/~juergen/xNES2010gecco.pdf +-- not to be confused with the Nintendo Entertainment System. + +local assert = assert +local exp = math.exp +local floor = math.floor +local ipairs = ipairs +local log = math.log +local max = math.max +local pairs = pairs +local pow = math.pow +local sqrt = math.sqrt +local unpack = table.unpack or unpack + +local Base = require "Base" + +local nn = require "nn" +local normal = nn.normal +local zeros = nn.zeros + +local util = require "util" +local argsort = util.argsort + +local Xnes = Base:extend() + +local function dot_mv(mat, vec, out) + -- treats matrix as a matrix. + -- treats vec as a column vector, flattened. 
+ assert(#mat.shape == 2) + local d0, d1 = unpack(mat.shape) + assert(d1 == #vec) + + local out_shape = {d0} + if out == nil then + out = zeros(out_shape) + else + assert(d0 == #out, "given output is the wrong size") + end + + for i=1, d0 do + local sum = 0 + for j=1, d1 do + sum = sum + mat[(i - 1) * d1 + j] * vec[j] + end + out[i] = sum + end + + return out +end + +local function make_utility(popsize, out) + local utility = out or {} + local temp = log(popsize / 2 + 1) + for i=1, popsize do utility[i] = max(0, temp - log(i)) end + local sum = 0 + for _, v in ipairs(utility) do sum = sum + v end + for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end + return utility +end + +local function make_covars(dims, sigma, out) + local covars = out or zeros{dims, dims} + local c = sigma / dims + -- simplified form of the determinant of the matrix we're going to create: + local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1) + -- multiplying by this constant makes the determinant 1: + local m = pow(1 / det, 1 / dims) + + local filler = c * m + for i=1, #covars do covars[i] = filler end + -- diagonals: + for i=1, dims do covars[i + dims * (i - 1)] = m end + + return covars +end + +function Xnes:init(dims, popsize, learning_rate, sigma) + -- heuristic borrowed from CMA-ES: + self.dims = dims + self.popsize = popsize or 4 + (3 * floor(log(dims))) + self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims)) + self.sigma = sigma or 1 + + self.utility = make_utility(self.popsize) + + self.mean = zeros{dims} + -- note: this is technically the co-standard-deviation. + -- you can imagine the "s" standing for "sqrt" if you like. 
+ self.covars = make_covars(self.dims, self.sigma, self.covars) + + --self.log_sigma = log(self.sigma) + --self.log_covars = zeros{dims, dims} + --for i, v in ipairs(self.covars) do self.log_covars[i] = log(v) end +end + +function Xnes:params(new_mean, new_covars) + if new_mean ~= nil then + assert(#self.mean == #new_mean, "new parameters have the wrong size") + for i, v in ipairs(new_mean) do self.mean[i] = v end + end + if new_covars ~= nil then + -- TODO: assert determinant of new_covars is 1. + error("TODO") + end + return self.mean +end + +function Xnes:ask_once(asked, noise) + asked = asked or zeros(self.dims) + noise = noise or {} + + for i=1, self.dims do noise[i] = normal() end + noise.shape = {#noise} + + dot_mv(self.covars, noise, asked) + for i, v in ipairs(asked) do asked[i] = self.mean[i] + self.sigma * v end + + return asked, noise +end + +function Xnes:ask(asked, noise) + -- return a list of parameters for the user to score, + -- and later pass to :tell(). + if asked == nil then + asked = {} + for i=1, self.popsize do asked[i] = zeros(self.dims) end + end + if noise == nil then + noise = {} + for i=1, self.popsize do noise[i] = zeros(self.dims) end + end + for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end + self.noise = noise + return asked, noise +end + +function Xnes:tell(scored, noise) + local noise = noise or self.noise + assert(noise, "missing noise argument") + + local arg = argsort(scored, function(a, b) return a > b end) + + local g_delta = zeros{self.dims} + for p=1, self.popsize do + local noise_p = noise[arg[p]] + for i=1, self.dims do + g_delta[i] = g_delta[i] + self.utility[p] * noise_p[i] + end + end + + local g_covars = zeros{self.dims, self.dims} + local traced = 0 + for p=1, self.popsize do + local noise_p = noise[arg[p]] + for i=1, self.dims do + for j=1, self.dims do + local ind = (i - 1) * self.dims + j + local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0) + local temp = self.utility[p] * zzt + g_covars[ind] = 
g_covars[ind] + temp + traced = traced + temp + end + end + end + + local g_sigma = traced / self.dims + + for i=1, self.dims do + local ind = (i - 1) * self.dims + i + g_covars[ind] = g_covars[ind] - g_sigma + end + + -- finally, update according to the gradients. + + local dotted = dot_mv(self.covars, g_delta) + for i, v in ipairs(self.mean) do + self.mean[i] = v + self.sigma * dotted[i] + end + + --[[ + --self.log_sigma = self.log_sigma + self.learning_rate / 2 * g_sigma + for i, v in ipairs(self.log_covars) do + self.log_covars[i] = v + lr * g_covars[i] + end + --]] + + local lr = self.learning_rate * 0.5 + self.sigma = self.sigma * exp(lr * g_sigma) + for i, v in ipairs(self.covars) do + self.covars[i] = v * exp(lr * g_covars[i]) + end + + -- bookkeeping: + --self.sigma = exp(self.log_sigma) + --for i, v in ipairs(self.log_covars) do self.covars[i] = exp(v) end + self.noise = nil +end + +return { + dot_mv = dot_mv, + make_utility = make_utility, + make_covars = make_covars, + + Xnes = Xnes, +}