add SNES optimizer
This commit is contained in:
parent
d87b8e7118
commit
3eebbc534a
2 changed files with 186 additions and 0 deletions
27
main.lua
27
main.lua
|
@ -175,6 +175,7 @@ end
|
|||
-- learning and evaluation.
|
||||
|
||||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
|
||||
local function prepare_epoch()
|
||||
|
@ -251,6 +252,9 @@ local function learn_from_epoch()
|
|||
|
||||
if cfg.es == 'xnes' then
|
||||
print("sigma:", es.sigma)
|
||||
elseif cfg.es == 'snes' then
|
||||
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
||||
print("sigma:", sigma_mean, sigma_dev)
|
||||
end
|
||||
|
||||
local step_mean, step_dev = 0, 0
|
||||
|
@ -298,6 +302,14 @@ local function learn_from_epoch()
|
|||
if cfg.enable_network then
|
||||
network:distribute(base_params)
|
||||
network:save(cfg.params_fn)
|
||||
|
||||
if cfg.es == 'snes' then
|
||||
local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt"
|
||||
local f = assert(open(std_fn, "w"))
|
||||
for _, v in ipairs(es.std) do f:write(("%f\n"):format(v)) end
|
||||
f:close()
|
||||
end
|
||||
|
||||
else
|
||||
print("note: not updating weights in playable mode.")
|
||||
end
|
||||
|
@ -458,6 +470,21 @@ local function init()
|
|||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
-- TODO: clean this up into an interface:
|
||||
es.mean_adapt = cfg.mean_adapt
|
||||
elseif cfg.es == 'snes' then
|
||||
es = snes.Snes(network.n_param, cfg.epoch_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
-- TODO: clean this up into an interface:
|
||||
es.mean_adapt = cfg.mean_adapt
|
||||
|
||||
local std_fn = cfg.params_fn:gsub(".txt", "")..".sigma.txt"
|
||||
if exists(std_fn) then
|
||||
local f = assert(open(std_fn, "r"))
|
||||
for i=1, network.n_param do
|
||||
es.std[i] = assert(tonumber(assert(f:read())))
|
||||
end
|
||||
f:close()
|
||||
end
|
||||
|
||||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
|
|
159
snes.lua
Normal file
159
snes.lua
Normal file
|
@ -0,0 +1,159 @@
|
|||
-- Separable Natural Evolution Strategies
|
||||
-- this particular implementation is based on:
|
||||
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
||||
-- not to be confused with the Super Nintendo Entertainment System.
|
||||
|
||||
local assert = assert
|
||||
local floor = math.floor
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local exp = math.exp
|
||||
local max = math.max
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local normal = nn.normal
|
||||
local zeros = nn.zeros
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local clamp = util.clamp
|
||||
|
||||
local Snes = Base:extend()
|
||||
|
||||
-- NOTE: duplicated in xnes.lua!
|
||||
local function make_utility(popsize, out)
|
||||
local utility = out or {}
|
||||
local temp = log(popsize / 2 + 1)
|
||||
for i=1, popsize do utility[i] = max(0, temp - log(i)) end
|
||||
local sum = 0
|
||||
for _, v in ipairs(utility) do sum = sum + v end
|
||||
for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end
|
||||
return utility
|
||||
end
|
||||
|
||||
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.utility = make_utility(self.popsize)
|
||||
|
||||
self.mean = zeros{dims}
|
||||
self.std = zeros{dims}
|
||||
for i=1, self.dims do self.std[i] = self.sigma end
|
||||
|
||||
self.mean_adapt = 1.0
|
||||
end
|
||||
|
||||
function Snes:params(new_mean)
|
||||
if new_mean ~= nil then
|
||||
assert(#self.mean == #new_mean, "new parameters have the wrong size")
|
||||
for i, v in ipairs(new_mean) do self.mean[i] = v end
|
||||
end
|
||||
return self.mean
|
||||
end
|
||||
|
||||
function Snes:ask_once(asked, noise)
|
||||
asked = asked or zeros(self.dims)
|
||||
noise = noise or {}
|
||||
|
||||
for i=1, self.dims do noise[i] = normal() end
|
||||
noise.shape = {#noise}
|
||||
|
||||
for i, v in ipairs(noise) do asked[i] = self.mean[i] + self.std[i] * v end
|
||||
|
||||
return asked, noise
|
||||
end
|
||||
|
||||
function Snes:ask_twice(asked0, asked1, noise0, noise1)
|
||||
asked0 = asked0 or zeros(self.dims)
|
||||
asked1 = asked1 or zeros(self.dims)
|
||||
noise0 = noise0 or {}
|
||||
noise1 = noise1 or {}
|
||||
|
||||
for i=1, self.dims do noise0[i] = normal() end
|
||||
noise0.shape = {#noise0}
|
||||
|
||||
for i, v in ipairs(noise0) do
|
||||
asked0[i] = self.mean[i] + self.std[i] * v
|
||||
asked1[i] = self.mean[i] - self.std[i] * v
|
||||
end
|
||||
for i, v in ipairs(noise0) do noise1[i] = -v end
|
||||
|
||||
return asked0, asked1, noise0, noise1
|
||||
end
|
||||
|
||||
-- NOTE: duplicated in xnes.lua!
|
||||
function Snes:ask(asked, noise)
|
||||
-- return a list of parameters for the user to score,
|
||||
-- and later pass to :tell().
|
||||
if asked == nil then
|
||||
asked = {}
|
||||
for i=1, self.popsize do asked[i] = zeros(self.dims) end
|
||||
end
|
||||
if noise == nil then
|
||||
noise = {}
|
||||
for i=1, self.popsize do noise[i] = zeros(self.dims) end
|
||||
end
|
||||
|
||||
if self.antithetic then
|
||||
for i=1, self.popsize do
|
||||
self:ask_twice(asked[i+0], asked[i+1], noise[i+0], noise[i+1])
|
||||
end
|
||||
else
|
||||
for i=1, self.popsize do
|
||||
self:ask_once(asked[i], noise[i])
|
||||
end
|
||||
end
|
||||
|
||||
self.noise = noise
|
||||
return asked, noise
|
||||
end
|
||||
|
||||
function Snes:tell(scored, noise)
|
||||
local noise = noise or self.noise
|
||||
assert(noise, "missing noise argument")
|
||||
|
||||
local arg = argsort(scored, function(a, b) return a > b end)
|
||||
|
||||
local g_mean = zeros{self.dims}
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[arg[p]]
|
||||
for i, v in ipairs(g_mean) do
|
||||
g_mean[i] = v + self.utility[p] * noise_p[i]
|
||||
end
|
||||
end
|
||||
|
||||
local g_std = zeros{self.dims}
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[arg[p]]
|
||||
for i, v in ipairs(g_std) do
|
||||
local n = noise_p[i]
|
||||
g_std[i] = v + self.utility[p] * (n * n - 1)
|
||||
end
|
||||
end
|
||||
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(self.std) do
|
||||
self.std[i] = v * exp(self.learning_rate / 2 * g_std[i])
|
||||
end
|
||||
|
||||
-- bookkeeping:
|
||||
self.noise = nil
|
||||
end
|
||||
|
||||
return {
|
||||
make_utility = make_utility,
|
||||
|
||||
Snes = Snes,
|
||||
}
|
Loading…
Reference in a new issue