2018-06-10 07:40:20 -07:00
|
|
|
-- Separable Natural Evolution Strategies (SNES)
|
|
|
|
-- This implementation follows Wierstra et al., "Natural Evolution Strategies" (JMLR 15, 2014):
|
|
|
|
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
|
|
|
-- not to be confused with the Super Nintendo Entertainment System.
|
|
|
|
|
2018-06-12 16:19:32 -07:00
|
|
|
local abs = math.abs
|
2018-06-10 07:40:20 -07:00
|
|
|
local assert = assert
|
2018-06-12 16:19:32 -07:00
|
|
|
local exp = math.exp
|
2018-06-10 07:40:20 -07:00
|
|
|
local floor = math.floor
|
|
|
|
local ipairs = ipairs
|
|
|
|
local log = math.log
|
|
|
|
local max = math.max
|
2018-06-12 16:19:32 -07:00
|
|
|
local min = math.min
|
|
|
|
local sqrt = math.sqrt
|
|
|
|
local insert = table.insert
|
|
|
|
local remove = table.remove
|
2018-06-10 07:40:20 -07:00
|
|
|
|
|
|
|
local Base = require "Base"
|
|
|
|
|
|
|
|
local nn = require "nn"
|
|
|
|
local normal = nn.normal
|
2018-06-12 16:19:32 -07:00
|
|
|
local uniform = nn.uniform
|
2018-06-10 07:40:20 -07:00
|
|
|
local zeros = nn.zeros
|
|
|
|
|
|
|
|
local util = require "util"
|
|
|
|
local argsort = util.argsort
|
2018-06-12 16:19:32 -07:00
|
|
|
local cdf = util.cdf
|
2018-06-10 07:40:20 -07:00
|
|
|
local clamp = util.clamp
|
2018-06-12 16:19:32 -07:00
|
|
|
local normalize_sums = util.normalize_sums
|
|
|
|
local pdf = util.pdf
|
|
|
|
local weighted_mann_whitney = util.weighted_mann_whitney
|
2018-06-10 07:40:20 -07:00
|
|
|
|
|
|
|
-- The optimizer "class", derived from Base via Base:extend().
-- NOTE(review): presumably Base's constructor dispatches to Snes:init();
-- confirm against the Base module.
local Snes = Base:extend()
|
|
|
|
|
2018-06-15 15:24:55 -07:00
|
|
|
-- Set up a separable NES optimizer over `dims` parameters.
--
-- dims         number of parameters (length of the mean/std vectors).
-- popsize      samples per generation; defaults to 4 + 3 * floor(log(dims)),
--              a heuristic borrowed from CMA-ES. Doubled when `antithetic`
--              is set, so every base sample gets a mirrored partner.
-- base_rate    initial learning rate for the standard deviations
--              (sigma_rate; also stored as covar_rate). Defaults to
--              3/5 * (3 + log(dims)) / (dims * sqrt(dims)).
--              NOTE(review): that is the xNES-style rate; the SNES default
--              in the paper's parameter table appears to be
--              (3 + log(dims)) / (5 * sqrt(dims)) -- confirm intent.
-- sigma        initial standard deviation of every dimension (default 1).
-- antithetic   truthy to draw mirrored sample pairs in :ask().
-- min_refresh  minimum refresh rate (a probability) used by the importance
--              mixing in :ask_mix(); default 0.1.
function Snes:init(dims, popsize, base_rate, sigma, antithetic, min_refresh)
    self.dims = dims
    -- heuristic borrowed from CMA-ES:
    self.popsize = popsize or 4 + (3 * floor(log(dims)))

    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = 1.0
    self.sigma_rate = base_rate
    self.covar_rate = base_rate
    -- sigma_rate is adapted by :adapt(); keep the starting value to decay
    -- back towards.
    self.rate_init = self.sigma_rate

    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false
    if self.antithetic then self.popsize = self.popsize * 2 end

    -- Read by :ask_mix(). Without a value here, importance mixing fails
    -- with "attempt to perform arithmetic on a nil value".
    self.min_refresh = min_refresh or 0.1

    self.mean = zeros{dims}
    self.std = zeros{dims}
    for i = 1, self.dims do self.std[i] = self.sigma end

    -- Importance-mixing state shared between :ask_mix() and :tell().
    self.old_asked = {}
    self.old_noise = {}
    self.old_score = {}
    self.new_asked = {}
    self.new_noise = {}

    -- Running count of evaluations reported through :tell().
    self.evals = 0
end
|
|
|
|
|
|
|
|
-- Read, and optionally overwrite, the current search mean.
--
-- new_mean  optional sequence; its length must equal #self.mean. Its values
--           are copied into self.mean element by element, so the identity
--           of the self.mean table is preserved.
--
-- Returns self.mean itself (not a copy).
function Snes:params(new_mean)
    if new_mean == nil then
        return self.mean
    end

    local mean = self.mean
    assert(#mean == #new_mean, "new parameters have the wrong size")
    for idx, value in ipairs(new_mean) do
        mean[idx] = value
    end
    return mean
end
|
|
|
|
|
2018-06-28 02:03:28 -07:00
|
|
|
-- Shrink the search distribution in place.
--
-- sigma_decay  when > 0, every std entry is multiplied by
--              (1 - sigma_rate * sigma_decay).
-- param_decay  when > 0, every mean entry is multiplied by
--              (1 - param_rate * param_decay * std[i]). This uses the std
--              values *after* the sigma decay above.
--
-- Both arguments are compared with `> 0`, so both must be numbers.
function Snes:decay(param_decay, sigma_decay)
    if sigma_decay > 0 then
        local shrink = 1 - self.sigma_rate * sigma_decay
        for i, s in ipairs(self.std) do
            self.std[i] = s * shrink
        end
    end

    if param_decay > 0 then
        local scale = self.param_rate * param_decay
        for i, m in ipairs(self.mean) do
            self.mean[i] = m * (1 - scale * self.std[i])
        end
    end
end
|
|
|
|
|
2018-06-10 07:40:20 -07:00
|
|
|
-- Draw a single sample from the search distribution.
--
-- asked, noise  optional tables to fill in place (fresh tables if nil).
--
-- noise[1..dims] receives standard normal draws. Then, for every entry of
-- noise, asked[i] = mean[i] + std[i] * noise[i].
--
-- Returns asked, noise.
function Snes:ask_once(asked, noise)
    if not asked then asked = {} end
    if not noise then noise = {} end

    local mean, std = self.mean, self.std

    for i = 1, self.dims do
        noise[i] = normal()
    end

    for i, z in ipairs(noise) do
        asked[i] = mean[i] + std[i] * z
    end

    return asked, noise
end
|
|
|
|
|
|
|
|
-- Draw one antithetic (mirrored) pair of samples.
--
-- asked0, asked1  optional vectors to fill in place (zeros(self.dims) if nil).
-- noise0, noise1  optional tables to fill in place ({} if nil).
--
-- noise0 receives dims standard normal draws plus a `shape` field of
-- {#noise0}; noise1 receives the negated draws. The samples are
--   asked0[i] = mean[i] + std[i] * noise0[i]
--   asked1[i] = mean[i] - std[i] * noise0[i]
--
-- Returns asked0, asked1, noise0, noise1.
function Snes:ask_twice(asked0, asked1, noise0, noise1)
    if not asked0 then asked0 = zeros(self.dims) end
    if not asked1 then asked1 = zeros(self.dims) end
    if not noise0 then noise0 = {} end
    if not noise1 then noise1 = {} end

    for i = 1, self.dims do
        noise0[i] = normal()
    end
    noise0.shape = {#noise0}

    for i, z in ipairs(noise0) do
        local offset = self.std[i] * z
        asked0[i] = self.mean[i] + offset
        asked1[i] = self.mean[i] - offset
    end

    -- The second sample's noise is the mirror image of the first.
    for i, z in ipairs(noise0) do
        noise1[i] = -z
    end

    return asked0, asked1, noise0, noise1
end
|
|
|
|
|
|
|
|
-- Draw a full population (self.popsize samples) for the user to score.
-- The scores are later passed to :tell(), in the same order.
--
-- asked, noise  optional lists of self.popsize vectors to fill in place;
--               fresh zero vectors are allocated when nil.
--
-- With antithetic sampling, the samples come in mirrored pairs: for odd i,
-- asked[i] = mean + std * n and asked[i+1] = mean - std * n.
-- (self.popsize is doubled by :init() in that mode, so it is even.)
--
-- Returns asked, noise. Both are also kept as self.asked / self.noise for
-- :tell(). Calling this also leaves importance-mixing mode.
function Snes:ask(asked, noise)
    self.mixing = false

    if asked == nil then
        asked = {}
        for i = 1, self.popsize do asked[i] = zeros(self.dims) end
    end

    if noise == nil then
        noise = {}
        for i = 1, self.popsize do noise[i] = zeros(self.dims) end
    end

    if self.antithetic then
        -- Step by two: each call fills a whole mirrored pair (i, i+1).
        -- Stepping by one let each call overwrite half of the previous
        -- pair, so no mirrored pairs survived. It also pushed the final
        -- call past the end of the population.
        for i = 1, self.popsize, 2 do
            self:ask_twice(asked[i], asked[i + 1], noise[i], noise[i + 1])
        end
    else
        for i = 1, self.popsize do
            self:ask_once(asked[i], noise[i])
        end
    end

    self.asked = asked
    self.noise = noise
    return asked, noise
end
|
|
|
|
|
2018-06-12 16:19:32 -07:00
|
|
|
-- Draw a population using importance mixing. Samples kept from the previous
-- generation (self.old_asked) are reused, so only the returned *new*
-- samples need to be scored and passed to :tell().
--
-- start_anew  truthy to discard the retained previous population.
--
-- How it works:
--   * Each retained old sample (up to popsize of them) survives with
--     probability min(d_new / d_old * (1 - min_refresh), 1).
--   * Fresh samples are then drawn from the current distribution. Each is
--     accepted with probability max(1 - d_old / d_new, min_refresh),
--     until #old + #new == popsize.
-- self.min_refresh must be a number before this is called; otherwise the
-- acceptance arithmetic fails on nil.
--
-- NOTE(review): d_new and d_old are *sums* of per-dimension pdf values,
-- not the joint density of the diagonal Gaussian (which is a product).
-- Confirm this approximation is intended. The "old" mean is also
-- self.mean, which :tell() has already updated in place, so only the
-- change in std is accounted for.
--
-- Returns self.new_asked, self.new_noise (the samples to score).
function Snes:ask_mix(start_anew)
    -- TODO: refactor and merge with :ask()?
    self.mixing = true

    if start_anew then
        self.old_asked = {}
        self.old_noise = {}
        self.old_score = {}
    end

    local mean_old, mean_new = self.mean, self.mean
    local std_old = self.std_old or self.std
    local std_new = self.std

    self.new_asked = {}
    self.new_noise = {}

    -- Per-dimension pdf sums of sample x under the new and old distributions.
    local function densities(x)
        local d_new, d_old = 0, 0
        for i, v in ipairs(x) do
            d_new = d_new + pdf(v, mean_new[i], std_new[i])
            d_old = d_old + pdf(v, mean_old[i], std_old[i])
        end
        return d_new, d_old
    end

    -- Thin out the previous generation. Rejected indices are collected in
    -- ascending order and removed back to front, so that earlier removals
    -- don't shift the later indices.
    local rejected = {}
    for p = 1, min(#self.old_asked, self.popsize) do
        -- TODO: cache probs?
        local d_new, d_old = densities(self.old_asked[p])
        local keep = min(d_new / d_old * (1 - self.min_refresh), 1)
        local kept = uniform() < keep
        if not kept then
            rejected[#rejected + 1] = p
        end
    end
    for j = #rejected, 1, -1 do
        local p = rejected[j]
        remove(self.old_asked, p)
        remove(self.old_noise, p)
        remove(self.old_score, p)
    end

    -- Top the population back up with fresh samples (these can't be cached).
    while #self.old_asked + #self.new_asked < self.popsize do
        local z, x = {}, {}
        for i = 1, self.dims do
            z[i] = normal()
            x[i] = mean_new[i] + std_new[i] * z[i]
        end

        local d_new, d_old = densities(x)
        local keep = max(1 - d_old / d_new, self.min_refresh)
        if uniform() < keep then
            local new_asked, new_noise = self.new_asked, self.new_noise
            new_asked[#new_asked + 1] = x
            new_noise[#new_noise + 1] = z
        end
    end

    return self.new_asked, self.new_noise
end
|
|
|
|
|
|
|
|
-- Update the search distribution from the user's scores.
--
-- scored  list of scores, where higher means better. It is in the same order
--         as the samples from the last :ask(), or, in mixing mode, as the
--         new samples from the last :ask_mix().
--
-- In mixing mode, the new samples, noise and scores are appended in place to
-- self.old_asked / self.old_noise / self.old_score, so the next :ask_mix()
-- can reuse them. Those tables may also be referenced by the caller.
--
-- The samples are ranked by score and converted to rank-based utilities.
-- These form gradient estimates for the mean and for each dimension's std.
-- self.mean and self.std are updated in place, the previous std is saved in
-- self.std_old, and sigma_rate is then adjusted by :adapt().
--
-- Returns the mean step, i.e. std * g_mean, before scaling by param_rate.
function Snes:tell(scored)
    self.evals = self.evals + #scored

    local asked, noise
    if self.mixing then
        -- note that these modify tables referenced externally in-place.
        asked, noise = self.old_asked, self.old_noise
        for _, x in ipairs(self.new_asked) do insert(asked, x) end
        for _, z in ipairs(self.new_noise) do insert(noise, z) end
        local old_score = self.old_score
        for _, s in ipairs(scored) do insert(old_score, s) end
        scored = old_score
    else
        asked, noise = self.asked, self.noise
    end

    assert(asked and noise, ":tell() called before :ask()")
    assert(#asked == #noise and #asked == #scored, "length mismatch")
    assert(#scored == self.popsize)

    -- TODO: use a proper ranking function.
    -- order[k] is the index of the k-th highest score (assuming argsort
    -- returns indices arranged by the comparator).
    local order = argsort(scored, function(x, y) return x > y end)

    -- Rank-based utilities: max(log(popsize/2 + 1) - log(rank), 0), then
    -- passed through normalize_sums().
    -- NOTE(review): the paper's utility formula also subtracts 1/popsize
    -- after normalizing. Confirm that omitting it is intended.
    local use_ranks = true
    local utility
    if use_ranks then
        utility = {}
        local const = log(self.popsize * 0.5 + 1)
        for rank, idx in ipairs(order) do
            utility[idx] = max(const - log(rank), 0)
        end
        normalize_sums(utility)
    else
        utility = normalize_sums(scored, {})
    end

    -- Gradient estimates. Each dimension is accumulated over the
    -- population in sample order.
    local g_mean = zeros{self.dims}
    local g_std = zeros{self.dims}
    for i in ipairs(g_mean) do
        for p = 1, self.popsize do
            g_mean[i] = g_mean[i] + utility[p] * noise[p][i]
        end
    end
    for i in ipairs(g_std) do
        for p = 1, self.popsize do
            local z = noise[p][i]
            g_std[i] = g_std[i] + utility[p] * (z * z - 1)
        end
    end

    -- Move the mean. The step is scaled by the std from before this update.
    local step = {}
    for i, g in ipairs(g_mean) do
        step[i] = self.std[i] * g
    end
    for i, m in ipairs(self.mean) do
        self.mean[i] = m + self.param_rate * step[i]
    end

    -- Multiplicative std update. `otherwise` holds the std that a 1.5x
    -- larger rate (0.75 instead of 0.5) would have produced; :adapt()
    -- compares the two.
    local otherwise = {}
    local std_old = {}
    self.std_old = std_old
    for i, s in ipairs(self.std) do
        std_old[i] = s
        self.std[i] = s * exp(self.sigma_rate * 0.5 * g_std[i])
        otherwise[i] = s * exp(self.sigma_rate * 0.75 * g_std[i])
    end

    self:adapt(asked, otherwise, utility)

    return step
end
|
|
|
|
|
2018-06-12 16:19:32 -07:00
|
|
|
-- Adjust self.sigma_rate using a weighted Mann-Whitney test.
--
-- asked      the evaluated samples (self.popsize of them).
-- otherwise  per-dimension std from a larger tentative step.
-- qualities  per-sample utilities, indexed like `asked`.
--
-- Each sample is weighted by the ratio of its per-dimension pdf sums under
-- (mean, otherwise) versus (mean, std). Depending on the test result:
--   * p < 1/2 - 1/(3 * (dims + 1)): sigma_rate decays towards rate_init.
--   * otherwise: sigma_rate grows by 10%, capped at 1.
-- The new rate is printed to stdout in either case.
-- NOTE(review): `qualities` is passed as both samples to
-- weighted_mann_whitney; confirm that is the helper's intended calling
-- convention.
function Snes:adapt(asked, otherwise, qualities)
    local weights = {}
    for p = 1, self.popsize do
        local dens_now, dens_big = 0, 0
        for i, v in ipairs(asked[p]) do
            dens_now = dens_now + pdf(v, self.mean[i], self.std[i])
            dens_big = dens_big + pdf(v, self.mean[i], otherwise[i])
        end
        weights[p] = dens_big / dens_now
    end

    local p = weighted_mann_whitney(qualities, qualities, nil, weights)
    --print("p:", p)

    local threshold = 0.5 - 1 / (3 * (self.dims + 1))
    local rate = self.sigma_rate
    local label
    if p < threshold then
        rate = 0.9 * rate + 0.1 * self.rate_init
        label = "learning rate -:"
    else
        rate = min(1.1 * rate, 1)
        label = "learning rate +:"
    end
    self.sigma_rate = rate
    print(label, rate)
end
|
|
|
|
|
|
|
|
-- Module exports: the SNES optimizer class.
return {Snes = Snes}
|