add xNES optimizer
This commit is contained in:
parent
fe9494b0d5
commit
bcb6cb9da1
4 changed files with 240 additions and 18 deletions
2
ars.lua
2
ars.lua
|
@ -201,7 +201,7 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
self._params[i] = v + self.learning_rate * step[i]
|
self._params[i] = v + self.learning_rate * step[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
self.asked = nil
|
self.noise = nil
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -90,6 +90,8 @@ assert(not cfg.ars_lips or cfg.unperturbed_trial,
|
||||||
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
|
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
|
||||||
assert(not cfg.ars_lips or cfg.negate_trials,
|
assert(not cfg.ars_lips or cfg.negate_trials,
|
||||||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||||
|
assert(not cfg.es == 'xnes' or not cfg.negate_trials,
|
||||||
|
"cfg.negate_trials is not yet compatible with xNES")
|
||||||
|
|
||||||
assert(not cfg.adamant,
|
assert(not cfg.adamant,
|
||||||
"cfg.adamant not yet re-implemented")
|
"cfg.adamant not yet re-implemented")
|
||||||
|
|
48
main.lua
48
main.lua
|
@ -50,27 +50,28 @@ local last_trial_state
|
||||||
|
|
||||||
-- localize some stuff.
|
-- localize some stuff.
|
||||||
|
|
||||||
local assert = assert
|
|
||||||
local print = print
|
|
||||||
local ipairs = ipairs
|
|
||||||
local pairs = pairs
|
|
||||||
local select = select
|
|
||||||
local open = io.open
|
|
||||||
local abs = math.abs
|
local abs = math.abs
|
||||||
local floor = math.floor
|
local assert = assert
|
||||||
local ceil = math.ceil
|
local ceil = math.ceil
|
||||||
local min = math.min
|
local collectgarbage = collectgarbage
|
||||||
local max = math.max
|
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
local pow = math.pow
|
local floor = math.floor
|
||||||
|
local insert = table.insert
|
||||||
|
local ipairs = ipairs
|
||||||
local log = math.log
|
local log = math.log
|
||||||
local sqrt = math.sqrt
|
local max = math.max
|
||||||
|
local min = math.min
|
||||||
|
local open = io.open
|
||||||
|
local pairs = pairs
|
||||||
|
local pow = math.pow
|
||||||
|
local print = print
|
||||||
local random = math.random
|
local random = math.random
|
||||||
local randomseed = math.randomseed
|
local randomseed = math.randomseed
|
||||||
local insert = table.insert
|
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
local unpack = table.unpack or unpack
|
local select = select
|
||||||
local sort = table.sort
|
local sort = table.sort
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
local unpack = table.unpack or unpack
|
||||||
|
|
||||||
local band = bit.band
|
local band = bit.band
|
||||||
local bor = bit.bor
|
local bor = bit.bor
|
||||||
|
@ -173,6 +174,7 @@ end
|
||||||
-- learning and evaluation.
|
-- learning and evaluation.
|
||||||
|
|
||||||
local ars = require("ars")
|
local ars = require("ars")
|
||||||
|
local xnes = require("xnes")
|
||||||
|
|
||||||
local function prepare_epoch()
|
local function prepare_epoch()
|
||||||
trial_neg = false
|
trial_neg = false
|
||||||
|
@ -180,7 +182,7 @@ local function prepare_epoch()
|
||||||
base_params = network:collect()
|
base_params = network:collect()
|
||||||
if cfg.playback_mode then return end
|
if cfg.playback_mode then return end
|
||||||
|
|
||||||
print('preparing epoch '..tostring(epoch_i)..'.')
|
print('preparing epoch '..tostring(epoch_i)..'...')
|
||||||
empty(trial_rewards)
|
empty(trial_rewards)
|
||||||
|
|
||||||
local precision
|
local precision
|
||||||
|
@ -190,7 +192,7 @@ local function prepare_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
local dummy
|
local dummy
|
||||||
if es == 'ars' then
|
if cfg.es == 'ars' then
|
||||||
trial_params, dummy = es:ask(precision)
|
trial_params, dummy = es:ask(precision)
|
||||||
else
|
else
|
||||||
trial_params, dummy = es:ask()
|
trial_params, dummy = es:ask()
|
||||||
|
@ -236,12 +238,16 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
if es == 'ars' then
|
if cfg.es == 'ars' and cfg.ars_lips then
|
||||||
es:tell(trial_rewards, current_cost)
|
es:tell(trial_rewards, current_cost)
|
||||||
else
|
else
|
||||||
es:tell(trial_rewards)
|
es:tell(trial_rewards)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if cfg.es == 'xnes' then
|
||||||
|
print("sigma:", es.sigma)
|
||||||
|
end
|
||||||
|
|
||||||
local step_mean, step_dev = 0, 0
|
local step_mean, step_dev = 0, 0
|
||||||
--[[ TODO
|
--[[ TODO
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
|
@ -369,6 +375,7 @@ local function do_reset()
|
||||||
if epoch_i > 0 then learn_from_epoch() end
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
||||||
prepare_epoch()
|
prepare_epoch()
|
||||||
|
collectgarbage()
|
||||||
if any_random then
|
if any_random then
|
||||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||||
state_saved = false
|
state_saved = false
|
||||||
|
@ -438,7 +445,14 @@ local function init()
|
||||||
local res, err = pcall(network.load, network, cfg.params_fn)
|
local res, err = pcall(network.load, network, cfg.params_fn)
|
||||||
if res == false then print(err) end
|
if res == false then print(err) end
|
||||||
|
|
||||||
if cfg.es == 'ars' then
|
if cfg.es == 'xnes' then
|
||||||
|
-- if you get an out of memory error, you can't use xNES. sorry!
|
||||||
|
-- maybe there'll be a patch for FCEUX in the future.
|
||||||
|
local trials = cfg.epoch_trials
|
||||||
|
if cfg.negate_trials then trials = trials * 2 end
|
||||||
|
es = xnes.Xnes(network.n_param, trials, cfg.learning_rate,
|
||||||
|
cfg.deviation, cfg.negate_trials)
|
||||||
|
elseif cfg.es == 'ars' then
|
||||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||||
else
|
else
|
||||||
|
|
206
xnes.lua
Normal file
206
xnes.lua
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
-- Exponential Natural Evolution Strategies
|
||||||
|
-- http://people.idsia.ch/~juergen/xNES2010gecco.pdf
|
||||||
|
-- not to be confused with the Nintendo Entertainment System.
|
||||||
|
|
||||||
|
local assert = assert
|
||||||
|
local exp = math.exp
|
||||||
|
local floor = math.floor
|
||||||
|
local ipairs = ipairs
|
||||||
|
local log = math.log
|
||||||
|
local max = math.max
|
||||||
|
local pairs = pairs
|
||||||
|
local pow = math.pow
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
local unpack = table.unpack or unpack
|
||||||
|
|
||||||
|
local Base = require "Base"
|
||||||
|
|
||||||
|
local nn = require "nn"
|
||||||
|
local normal = nn.normal
|
||||||
|
local zeros = nn.zeros
|
||||||
|
|
||||||
|
local util = require "util"
|
||||||
|
local argsort = util.argsort
|
||||||
|
|
||||||
|
local Xnes = Base:extend()
|
||||||
|
|
||||||
|
local function dot_mv(mat, vec, out)
|
||||||
|
-- treats matrix as a matrix.
|
||||||
|
-- treats vec as a column vector, flattened.
|
||||||
|
assert(#mat.shape == 2)
|
||||||
|
local d0, d1 = unpack(mat.shape)
|
||||||
|
assert(d1 == #vec)
|
||||||
|
|
||||||
|
local out_shape = {d0}
|
||||||
|
if out == nil then
|
||||||
|
out = zeros(out_shape)
|
||||||
|
else
|
||||||
|
assert(d0 == #out, "given output is the wrong size")
|
||||||
|
end
|
||||||
|
|
||||||
|
for i=1, d0 do
|
||||||
|
local sum = 0
|
||||||
|
for j=1, d1 do
|
||||||
|
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
|
||||||
|
end
|
||||||
|
out[i] = sum
|
||||||
|
end
|
||||||
|
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
|
local function make_utility(popsize, out)
|
||||||
|
local utility = out or {}
|
||||||
|
local temp = log(popsize / 2 + 1)
|
||||||
|
for i=1, popsize do utility[i] = max(0, temp - log(i)) end
|
||||||
|
local sum = 0
|
||||||
|
for _, v in ipairs(utility) do sum = sum + v end
|
||||||
|
for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end
|
||||||
|
return utility
|
||||||
|
end
|
||||||
|
|
||||||
|
local function make_covars(dims, sigma, out)
|
||||||
|
local covars = out or zeros{dims, dims}
|
||||||
|
local c = sigma / dims
|
||||||
|
-- simplified form of the determinant of the matrix we're going to create:
|
||||||
|
local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1)
|
||||||
|
-- multiplying by this constant makes the determinant 1:
|
||||||
|
local m = pow(1 / det, 1 / dims)
|
||||||
|
|
||||||
|
local filler = c * m
|
||||||
|
for i=1, #covars do covars[i] = filler end
|
||||||
|
-- diagonals:
|
||||||
|
for i=1, dims do covars[i + dims * (i - 1)] = m end
|
||||||
|
|
||||||
|
return covars
|
||||||
|
end
|
||||||
|
|
||||||
|
function Xnes:init(dims, popsize, learning_rate, sigma)
|
||||||
|
-- heuristic borrowed from CMA-ES:
|
||||||
|
self.dims = dims
|
||||||
|
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||||
|
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||||
|
self.sigma = sigma or 1
|
||||||
|
|
||||||
|
self.utility = make_utility(self.popsize)
|
||||||
|
|
||||||
|
self.mean = zeros{dims}
|
||||||
|
-- note: this is technically the co-standard-deviation.
|
||||||
|
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||||
|
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||||
|
|
||||||
|
--self.log_sigma = log(self.sigma)
|
||||||
|
--self.log_covars = zeros{dims, dims}
|
||||||
|
--for i, v in ipairs(self.covars) do self.log_covars[i] = log(v) end
|
||||||
|
end
|
||||||
|
|
||||||
|
function Xnes:params(new_mean, new_covars)
|
||||||
|
if new_mean ~= nil then
|
||||||
|
assert(#self.mean == #new_mean, "new parameters have the wrong size")
|
||||||
|
for i, v in ipairs(new_mean) do self.mean[i] = v end
|
||||||
|
end
|
||||||
|
if new_covars ~= nil then
|
||||||
|
-- TODO: assert determinant of new_covars is 1.
|
||||||
|
error("TODO")
|
||||||
|
end
|
||||||
|
return self.mean
|
||||||
|
end
|
||||||
|
|
||||||
|
function Xnes:ask_once(asked, noise)
|
||||||
|
asked = asked or zeros(self.dims)
|
||||||
|
noise = noise or {}
|
||||||
|
|
||||||
|
for i=1, self.dims do noise[i] = normal() end
|
||||||
|
noise.shape = {#noise}
|
||||||
|
|
||||||
|
dot_mv(self.covars, noise, asked)
|
||||||
|
for i, v in ipairs(asked) do asked[i] = self.mean[i] + self.sigma * v end
|
||||||
|
|
||||||
|
return asked, noise
|
||||||
|
end
|
||||||
|
|
||||||
|
function Xnes:ask(asked, noise)
|
||||||
|
-- return a list of parameters for the user to score,
|
||||||
|
-- and later pass to :tell().
|
||||||
|
if asked == nil then
|
||||||
|
asked = {}
|
||||||
|
for i=1, self.popsize do asked[i] = zeros(self.dims) end
|
||||||
|
end
|
||||||
|
if noise == nil then
|
||||||
|
noise = {}
|
||||||
|
for i=1, self.popsize do noise[i] = zeros(self.dims) end
|
||||||
|
end
|
||||||
|
for i=1, self.popsize do self:ask_once(asked[i], noise[i]) end
|
||||||
|
self.noise = noise
|
||||||
|
return asked, noise
|
||||||
|
end
|
||||||
|
|
||||||
|
function Xnes:tell(scored, noise)
|
||||||
|
local noise = noise or self.noise
|
||||||
|
assert(noise, "missing noise argument")
|
||||||
|
|
||||||
|
local arg = argsort(scored, function(a, b) return a > b end)
|
||||||
|
|
||||||
|
local g_delta = zeros{self.dims}
|
||||||
|
for p=1, self.popsize do
|
||||||
|
local noise_p = noise[arg[p]]
|
||||||
|
for i=1, self.dims do
|
||||||
|
g_delta[i] = g_delta[i] + self.utility[p] * noise_p[i]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local g_covars = zeros{self.dims, self.dims}
|
||||||
|
local traced = 0
|
||||||
|
for p=1, self.popsize do
|
||||||
|
local noise_p = noise[arg[p]]
|
||||||
|
for i=1, self.dims do
|
||||||
|
for j=1, self.dims do
|
||||||
|
local ind = (i - 1) * self.dims + j
|
||||||
|
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
|
||||||
|
local temp = self.utility[p] * zzt
|
||||||
|
g_covars[ind] = g_covars[ind] + temp
|
||||||
|
traced = traced + temp
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local g_sigma = traced / self.dims
|
||||||
|
|
||||||
|
for i=1, self.dims do
|
||||||
|
local ind = (i - 1) * self.dims + i
|
||||||
|
g_covars[ind] = g_covars[ind] - g_sigma
|
||||||
|
end
|
||||||
|
|
||||||
|
-- finally, update according to the gradients.
|
||||||
|
|
||||||
|
local dotted = dot_mv(self.covars, g_delta)
|
||||||
|
for i, v in ipairs(self.mean) do
|
||||||
|
self.mean[i] = v + self.sigma * dotted[i]
|
||||||
|
end
|
||||||
|
|
||||||
|
--[[
|
||||||
|
--self.log_sigma = self.log_sigma + self.learning_rate / 2 * g_sigma
|
||||||
|
for i, v in ipairs(self.log_covars) do
|
||||||
|
self.log_covars[i] = v + lr * g_covars[i]
|
||||||
|
end
|
||||||
|
--]]
|
||||||
|
|
||||||
|
local lr = self.learning_rate * 0.5
|
||||||
|
self.sigma = self.sigma * exp(lr * g_sigma)
|
||||||
|
for i, v in ipairs(self.covars) do
|
||||||
|
self.covars[i] = v * exp(lr * g_covars[i])
|
||||||
|
end
|
||||||
|
|
||||||
|
-- bookkeeping:
|
||||||
|
--self.sigma = exp(self.log_sigma)
|
||||||
|
--for i, v in ipairs(self.log_covars) do self.covars[i] = exp(v) end
|
||||||
|
self.noise = nil
|
||||||
|
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
dot_mv = dot_mv,
|
||||||
|
make_utility = make_utility,
|
||||||
|
make_covars = make_covars,
|
||||||
|
|
||||||
|
Xnes = Xnes,
|
||||||
|
}
|
Loading…
Reference in a new issue