smbot/snes.lua
2018-06-10 16:40:20 +02:00

159 lines
4.2 KiB
Lua

-- Separable Natural Evolution Strategies
-- this particular implementation is based on:
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
-- not to be confused with the Super Nintendo Entertainment System.
local assert = assert
local floor = math.floor
local ipairs = ipairs
local log = math.log
local exp = math.exp
local max = math.max
local sqrt = math.sqrt
local Base = require "Base"
local nn = require "nn"
local normal = nn.normal
local zeros = nn.zeros
local util = require "util"
local argsort = util.argsort
local clamp = util.clamp
-- The Snes class, derived from the project's Base class via Base:extend().
local Snes = Base:extend()
-- NOTE: duplicated in xnes.lua!
-- Rank-based fitness-shaping utilities (as in the NES paper):
--   u_i = max(0, log(popsize/2 + 1) - log(i)) / sum_j(...) - 1/popsize
-- Index 1 is the weight for the best-ranked sample; weights decrease with
-- rank, and the returned values sum to (approximately) zero.
-- popsize: number of samples per generation.
-- out: optional table to fill and return instead of allocating a new one.
--      Only entries 1..popsize are written; any trailing entries beyond
--      popsize are cleared so a reused, larger buffer cannot leak stale
--      values into the normalizing sum or the result.
-- Returns the utility table.
local function make_utility(popsize, out)
	local utility = out or {}
	local temp = log(popsize / 2 + 1)
	local sum = 0
	for i = 1, popsize do
		local u = max(0, temp - log(i))
		utility[i] = u
		sum = sum + u
	end
	-- drop stale entries left over from a larger `out` buffer.
	for i = #utility, popsize + 1, -1 do utility[i] = nil end
	for i = 1, popsize do utility[i] = utility[i] / sum - 1 / popsize end
	return utility
end
-- Initialize an SNES optimizer.
-- dims: number of parameters being optimized.
-- popsize: samples per generation; defaults to 4 + 3 * floor(log(dims)).
--          When antithetic is set, the resulting popsize is doubled.
-- learning_rate: step size for the standard-deviation update (see default).
-- sigma: initial standard deviation for every dimension (default 1).
-- antithetic: if truthy, samples are drawn as mirrored (+noise/-noise) pairs.
-- The mean starts at zero; the mean step size (mean_adapt) is 1.0.
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
	self.dims = dims
	-- heuristic borrowed from CMA-ES.
	-- NOTE(review): CMA-ES's default is 4 + floor(3 * log(dims)); this floors
	-- before multiplying, which can give a slightly smaller value. Kept as-is.
	self.popsize = popsize or 4 + (3 * floor(log(dims)))
	-- NOTE(review): this default equals the paper's xNES rate for B,
	-- (9 + 3 log d) / (5 d sqrt d); the paper's SNES default for sigma is
	-- (3 + log d) / (5 sqrt d). Kept as-is — confirm it is intentional.
	self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
	self.sigma = sigma or 1
	self.antithetic = antithetic and true or false
	-- each antithetic draw yields a mirrored pair, so double the sample count.
	if self.antithetic then self.popsize = self.popsize * 2 end
	self.utility = make_utility(self.popsize)
	self.mean = zeros{dims}
	self.std = zeros{dims}
	for i = 1, self.dims do self.std[i] = self.sigma end
	self.mean_adapt = 1.0
end
-- Get the current mean, optionally overwriting it first.
-- new_mean: optional sequence with the same length as the mean; its values
--           are copied into the existing mean table in place.
-- Returns the live mean table (not a copy).
function Snes:params(new_mean)
	if new_mean == nil then return self.mean end
	local mean = self.mean
	assert(#mean == #new_mean, "new parameters have the wrong size")
	for idx, value in ipairs(new_mean) do
		mean[idx] = value
	end
	return mean
end
-- Draw a single sample: asked[i] = mean[i] + std[i] * noise[i], where each
-- noise[i] comes from normal().
-- asked, noise: optional tables to fill in place (allocated when nil).
-- Returns asked, noise; noise.shape is set to {#noise}.
function Snes:ask_once(asked, noise)
	local out = asked or zeros(self.dims)
	local eps = noise or {}
	for d = 1, self.dims do
		eps[d] = normal()
	end
	eps.shape = {#eps}
	local mean, std = self.mean, self.std
	for d, e in ipairs(eps) do
		out[d] = mean[d] + std[d] * e
	end
	return out, eps
end
-- Draw a mirrored (antithetic) pair of samples from one noise vector:
--   asked0 = mean + std * noise0,  asked1 = mean - std * noise0,
-- with noise1 = -noise0.
-- All four arguments are optional buffers to fill in place (allocated when nil).
-- Returns asked0, asked1, noise0, noise1. Both noise tables get a .shape
-- field, matching ask_once.
function Snes:ask_twice(asked0, asked1, noise0, noise1)
	asked0 = asked0 or zeros(self.dims)
	asked1 = asked1 or zeros(self.dims)
	noise0 = noise0 or {}
	noise1 = noise1 or {}
	for i = 1, self.dims do noise0[i] = normal() end
	noise0.shape = {#noise0}
	for i, v in ipairs(noise0) do
		asked0[i] = self.mean[i] + self.std[i] * v
		asked1[i] = self.mean[i] - self.std[i] * v
		noise1[i] = -v
	end
	-- previously only noise0 had its shape set; keep noise1 consistent.
	noise1.shape = {#noise1}
	return asked0, asked1, noise0, noise1
end
-- NOTE: duplicated in xnes.lua!
-- Sample one generation of candidate parameters.
-- Returns a list of popsize parameter vectors for the user to score, plus
-- the noise each was drawn with. Pass the scores to :tell() later; the noise
-- is also cached on self.noise for that call.
-- asked, noise: optional lists of popsize buffers to refill in place
--               (allocated with zeros(self.dims) when nil).
function Snes:ask(asked, noise)
	if asked == nil then
		asked = {}
		for i = 1, self.popsize do asked[i] = zeros(self.dims) end
	end
	if noise == nil then
		noise = {}
		for i = 1, self.popsize do noise[i] = zeros(self.dims) end
	end
	if self.antithetic then
		-- Each ask_twice call fills the mirrored pair (i, i + 1), so step by 2.
		-- Stepping by 1 made every call overwrite the previous call's mirrored
		-- sample (and read one slot past the end), destroying the pairing.
		-- init doubles popsize when antithetic, so it is even here.
		for i = 1, self.popsize, 2 do
			self:ask_twice(asked[i], asked[i + 1], noise[i], noise[i + 1])
		end
	else
		for i = 1, self.popsize do
			self:ask_once(asked[i], noise[i])
		end
	end
	self.noise = noise
	return asked, noise
end
-- Update the search distribution from the scores of the last :ask() batch.
-- scored: sequence of popsize scores, in the same order as the samples from
--         :ask(). Higher scores receive larger utility weights.
-- noise: optional noise list from :ask(); defaults to the one cached there.
-- Updates self.mean and self.std in place and clears the cached noise.
function Snes:tell(scored, noise)
	noise = noise or self.noise
	assert(noise, "missing noise argument")
	assert(#scored == self.popsize, "expected one score per sample")
	-- Descending comparator: presumably arg[1] is the index of the highest
	-- score (assuming util.argsort orders indices by the comparator), which
	-- pairs with utility[1], the largest weight.
	local arg = argsort(scored, function(a, b) return a > b end)
	local utility = self.utility
	local g_mean = zeros{self.dims}
	local g_std = zeros{self.dims}
	-- Accumulate both natural-gradient estimates in one pass over the
	-- population (per-element summation order is the same as two passes).
	for p = 1, self.popsize do
		local u = utility[p]
		local noise_p = noise[arg[p]]
		for i, gm in ipairs(g_mean) do
			local n = noise_p[i]
			g_mean[i] = gm + u * n
			g_std[i] = g_std[i] + u * (n * n - 1)
		end
	end
	for i, v in ipairs(self.mean) do
		self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
	end
	-- multiplicative (exponential) update keeps the std positive.
	for i, v in ipairs(self.std) do
		self.std[i] = v * exp(self.learning_rate / 2 * g_std[i])
	end
	-- bookkeeping:
	self.noise = nil
end
-- Module exports: the utility-shaping helper and the Snes class.
return {
make_utility = make_utility,
Snes = Snes,
}