overhaul SNES (importance sampling, adaptation sampling, etc)
This commit is contained in:
parent
7bb9c79367
commit
5c64fcf395
4 changed files with 208 additions and 35 deletions
|
@ -36,6 +36,7 @@ local common_cfg = {
|
|||
mean_adapt = 1.0, -- for xNES
|
||||
weight_decay = 0.0,
|
||||
sigma_decay = 0.0,
|
||||
min_refresh = 0.2,
|
||||
|
||||
es = 'ars',
|
||||
ars_lips = false,
|
||||
|
|
8
main.lua
8
main.lua
|
@ -199,6 +199,7 @@ local function prepare_epoch()
|
|||
elseif cfg.es == 'snes' then
|
||||
local sigma_mean, sigma_dev = calc_mean_dev(es.std)
|
||||
--print("sigma:", sigma_mean, sigma_dev)
|
||||
print("sigma 50%:", sigma_mean)
|
||||
print("sigma 95%:", sigma_mean + sigma_dev * 1.64485)
|
||||
end
|
||||
|
||||
|
@ -211,6 +212,8 @@ local function prepare_epoch()
|
|||
local dummy
|
||||
if cfg.es == 'ars' then
|
||||
trial_params, dummy = es:ask(precision)
|
||||
elseif cfg.es == 'snes' then
|
||||
trial_params, dummy = es:ask_mix()
|
||||
else
|
||||
trial_params, dummy = es:ask()
|
||||
end
|
||||
|
@ -288,12 +291,12 @@ local function learn_from_epoch()
|
|||
if cfg.es == 'snes' then
|
||||
if cfg.sigma_decay > 0 then
|
||||
for i, v in ipairs(es.std) do
|
||||
es.std[i] = v * (1 - cfg.sigma_decay)
|
||||
es.std[i] = v * (1 - cfg.learning_rate * cfg.sigma_decay)
|
||||
end
|
||||
end
|
||||
if cfg.weight_decay > 0 then
|
||||
for i, v in ipairs(base_params) do
|
||||
base_params[i] = v * (1 - cfg.weight_decay * es.std[i])
|
||||
base_params[i] = v * (1 - cfg.mean_adapt * cfg.weight_decay * es.std[i])
|
||||
end
|
||||
end
|
||||
else
|
||||
|
@ -502,6 +505,7 @@ local function init()
|
|||
cfg.learning_rate, cfg.deviation, cfg.negate_trials)
|
||||
-- TODO: clean this up into an interface:
|
||||
es.mean_adapt = cfg.mean_adapt
|
||||
es.min_refresh = cfg.min_refresh
|
||||
|
||||
if exists(std_fn) then
|
||||
local f = assert(open(std_fn, "r"))
|
||||
|
|
196
snes.lua
196
snes.lua
|
@ -3,36 +3,35 @@
|
|||
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
||||
-- not to be confused with the Super Nintendo Entertainment System.
|
||||
|
||||
local abs = math.abs
|
||||
local assert = assert
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local exp = math.exp
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local sqrt = math.sqrt
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local normal = nn.normal
|
||||
local uniform = nn.uniform
|
||||
local zeros = nn.zeros
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local cdf = util.cdf
|
||||
local clamp = util.clamp
|
||||
local normalize_sums = util.normalize_sums
|
||||
local pdf = util.pdf
|
||||
local weighted_mann_whitney = util.weighted_mann_whitney
|
||||
|
||||
local Snes = Base:extend()
|
||||
|
||||
-- NOTE: duplicated in xnes.lua!
|
||||
local function make_utility(popsize, out)
|
||||
local utility = out or {}
|
||||
local temp = log(popsize / 2 + 1)
|
||||
for i=1, popsize do utility[i] = max(0, temp - log(i)) end
|
||||
local sum = 0
|
||||
for _, v in ipairs(utility) do sum = sum + v end
|
||||
for i, v in ipairs(utility) do utility[i] = v / sum - 1 / popsize end
|
||||
return utility
|
||||
end
|
||||
|
||||
function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
|
@ -43,13 +42,21 @@ function Snes:init(dims, popsize, learning_rate, sigma, antithetic)
|
|||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.utility = make_utility(self.popsize)
|
||||
self.rate_init = self.learning_rate
|
||||
|
||||
self.mean = zeros{dims}
|
||||
self.std = zeros{dims}
|
||||
for i=1, self.dims do self.std[i] = self.sigma end
|
||||
|
||||
self.mean_adapt = 1.0
|
||||
|
||||
self.old_asked = {}
|
||||
self.old_noise = {}
|
||||
self.old_score = {}
|
||||
self.new_asked = {}
|
||||
self.new_noise = {}
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
||||
function Snes:params(new_mean)
|
||||
|
@ -61,12 +68,10 @@ function Snes:params(new_mean)
|
|||
end
|
||||
|
||||
function Snes:ask_once(asked, noise)
|
||||
asked = asked or zeros(self.dims)
|
||||
asked = asked or {}
|
||||
noise = noise or {}
|
||||
|
||||
for i=1, self.dims do noise[i] = normal() end
|
||||
noise.shape = {#noise}
|
||||
|
||||
for i, v in ipairs(noise) do asked[i] = self.mean[i] + self.std[i] * v end
|
||||
|
||||
return asked, noise
|
||||
|
@ -90,10 +95,10 @@ function Snes:ask_twice(asked0, asked1, noise0, noise1)
|
|||
return asked0, asked1, noise0, noise1
|
||||
end
|
||||
|
||||
-- NOTE: duplicated in xnes.lua!
|
||||
function Snes:ask(asked, noise)
|
||||
-- return a list of parameters for the user to score,
|
||||
-- and later pass to :tell().
|
||||
self.mixing = false
|
||||
if asked == nil then
|
||||
asked = {}
|
||||
for i=1, self.popsize do asked[i] = zeros(self.dims) end
|
||||
|
@ -113,30 +118,129 @@ function Snes:ask(asked, noise)
|
|||
end
|
||||
end
|
||||
|
||||
self.asked = asked
|
||||
self.noise = noise
|
||||
return asked, noise
|
||||
end
|
||||
|
||||
function Snes:tell(scored, noise)
|
||||
local noise = noise or self.noise
|
||||
assert(noise, "missing noise argument")
|
||||
function Snes:ask_mix(start_anew)
|
||||
-- TODO: refactor and merge with :ask()?
|
||||
self.mixing = true
|
||||
if start_anew then
|
||||
self.old_asked = {}
|
||||
self.old_noise = {}
|
||||
self.old_score = {}
|
||||
end
|
||||
|
||||
-- perform importance mixing.
|
||||
|
||||
local mean_old = self.mean
|
||||
local mean_new = self.mean
|
||||
local std_old = self.std_old or self.std
|
||||
local std_new = self.std
|
||||
|
||||
self.new_asked = {}
|
||||
self.new_noise = {}
|
||||
|
||||
local marked = {}
|
||||
for p=1, min(#self.old_asked, self.popsize) do
|
||||
local a = self.old_asked[p]
|
||||
local n = self.old_noise[p]
|
||||
|
||||
-- TODO: cache probs?
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
|
||||
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
||||
if uniform() < accept then
|
||||
--print(("accepted old sample %i with probability %f"):format(p, accept))
|
||||
else
|
||||
-- insert in reverse as not to screw up
|
||||
-- the indices when removing later.
|
||||
insert(marked, 1, p)
|
||||
end
|
||||
end
|
||||
for _, p in ipairs(marked) do
|
||||
remove(self.old_asked, p)
|
||||
remove(self.old_noise, p)
|
||||
remove(self.old_score, p)
|
||||
end
|
||||
|
||||
while #self.old_asked + #self.new_asked < self.popsize do
|
||||
local a = {}
|
||||
local n = {}
|
||||
for i=1, self.dims do n[i] = normal() end
|
||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||
|
||||
-- can't cache here!
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
|
||||
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
||||
if uniform() < accept then
|
||||
insert(self.new_asked, a)
|
||||
insert(self.new_noise, n)
|
||||
--print(("accepted new sample %i with probability %f"):format(0, accept))
|
||||
end
|
||||
end
|
||||
|
||||
return self.new_asked, self.new_noise
|
||||
end
|
||||
|
||||
function Snes:tell(scored)
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local asked = self.asked
|
||||
local noise = self.noise
|
||||
if self.mixing then
|
||||
asked = self.old_asked
|
||||
noise = self.old_noise
|
||||
-- note that these modify tables referenced externally in-place.
|
||||
for i, v in ipairs(self.new_asked) do insert(asked, v) end
|
||||
for i, v in ipairs(self.new_noise) do insert(noise, v) end
|
||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||
scored = self.old_score
|
||||
end
|
||||
assert(asked and noise, ":tell() called before :ask()")
|
||||
assert(#asked == #noise and #asked == #scored, "length mismatch")
|
||||
assert(#scored == self.popsize)
|
||||
|
||||
-- TODO: use a proper ranking function.
|
||||
local arg = argsort(scored, function(a, b) return a > b end)
|
||||
|
||||
local g_mean = zeros{self.dims}
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[arg[p]]
|
||||
for i, v in ipairs(g_mean) do
|
||||
g_mean[i] = v + self.utility[p] * noise_p[i]
|
||||
end
|
||||
local g_std = zeros{self.dims}
|
||||
|
||||
local utilize = true
|
||||
local utility
|
||||
|
||||
if utilize then
|
||||
utility = {}
|
||||
local const = log(self.popsize * 0.5 + 1)
|
||||
for i, v in ipairs(arg) do utility[v] = max(const - log(i), 0) end
|
||||
normalize_sums(utility)
|
||||
else
|
||||
utility = normalize_sums(scored, {})
|
||||
end
|
||||
|
||||
local g_std = zeros{self.dims}
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[arg[p]]
|
||||
local noise_p = noise[p]
|
||||
|
||||
for i, v in ipairs(g_mean) do
|
||||
g_mean[i] = v + utility[p] * noise_p[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(g_std) do
|
||||
local n = noise_p[i]
|
||||
g_std[i] = v + self.utility[p] * (n * n - 1)
|
||||
g_std[i] = v + utility[p] * (n * n - 1)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -144,16 +248,42 @@ function Snes:tell(scored, noise)
|
|||
self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i]
|
||||
end
|
||||
|
||||
local otherwise = {}
|
||||
self.std_old = {}
|
||||
for i, v in ipairs(self.std) do
|
||||
self.std[i] = v * exp(self.learning_rate / 2 * g_std[i])
|
||||
self.std_old[i] = v
|
||||
self.std[i] = v * exp(self.learning_rate * 0.5 * g_std[i])
|
||||
otherwise[i] = v * exp(self.learning_rate * 0.75 * g_std[i])
|
||||
end
|
||||
|
||||
-- bookkeeping:
|
||||
self.noise = nil
|
||||
self:adapt(asked, otherwise, utility)
|
||||
end
|
||||
|
||||
function Snes:adapt(asked, otherwise, qualities)
|
||||
local weights = {}
|
||||
for p=1, self.popsize do
|
||||
local asked_p = asked[p]
|
||||
local prob_now = 0
|
||||
local prob_big = 0
|
||||
for i, v in ipairs(asked_p) do
|
||||
prob_now = prob_now + pdf(v, self.mean[i], self.std[i])
|
||||
prob_big = prob_big + pdf(v, self.mean[i], otherwise[i])
|
||||
end
|
||||
weights[p] = prob_big / prob_now
|
||||
end
|
||||
|
||||
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||
--print("p:", p)
|
||||
|
||||
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
||||
self.learning_rate = 0.9 * self.learning_rate + 0.1 * self.rate_init
|
||||
print("learning rate -:", self.learning_rate)
|
||||
else
|
||||
self.learning_rate = min(1.1 * self.learning_rate, 1)
|
||||
print("learning rate +:", self.learning_rate)
|
||||
end
|
||||
end
|
||||
|
||||
return {
|
||||
make_utility = make_utility,
|
||||
|
||||
Snes = Snes,
|
||||
}
|
||||
|
|
38
util.lua
38
util.lua
|
@ -207,6 +207,43 @@ local function cdf(x)
|
|||
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
|
||||
end
|
||||
|
||||
local function weighted_mann_whitney(s0, s1, w0, w1)
|
||||
-- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney.
|
||||
if w0 == nil then
|
||||
w0 = {}
|
||||
for i=1, #s0 do w0[i] = 1.0 end
|
||||
end
|
||||
if w1 == nil then
|
||||
w1 = {}
|
||||
for i=1, #s1 do w1[i] = 1.0 end
|
||||
end
|
||||
assert(#s0 == #w0)
|
||||
assert(#s1 == #w1)
|
||||
|
||||
local s0_sum, s1_sum, w0_sum, w1_sum = 0, 0, 0, 0
|
||||
for i, v in ipairs(s0) do s0_sum = s0_sum + v end
|
||||
for i, v in ipairs(s1) do s1_sum = s1_sum + v end
|
||||
for i, v in ipairs(w0) do w0_sum = w0_sum + v end
|
||||
for i, v in ipairs(w1) do w1_sum = w1_sum + v end
|
||||
|
||||
local U = 0
|
||||
for i=1, #s0 do
|
||||
for j=1, #s1 do
|
||||
if s0[i] > s1[j] then
|
||||
U = U + w0[i] * w1[j]
|
||||
elseif s0[i] == s1[j] then
|
||||
U = U + w0[i] * w1[j] * 0.5
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local mean = w0_sum * w1_sum * 0.5
|
||||
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
|
||||
local p = cdf((U - mean) / std)
|
||||
|
||||
if s0_sum > s1_sum then return 1 - p else return p end
|
||||
end
|
||||
|
||||
return {
|
||||
signbyte=signbyte,
|
||||
boolean_xor=boolean_xor,
|
||||
|
@ -232,4 +269,5 @@ return {
|
|||
exists=exists,
|
||||
pdf=pdf,
|
||||
cdf=cdf,
|
||||
weighted_mann_whitney=weighted_mann_whitney,
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue