smbot/ars.lua

217 lines
6.3 KiB
Lua
Raw Normal View History

-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks (Lipschitz stuff) by myself.
-- I also added an option for Gray code sampling,
-- borrowed from a (1+1) optimizer,
-- but I haven't yet found a case where it performs better.
local abs = math.abs
local exp = math.exp
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local sqrt = math.sqrt

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local uniform = nn.uniform
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local normalize_sums = util.normalize_sums
local sign = util.sign
local Ars = Base:extend()

-- Multipliers used by the momentum step in Ars:tell, indexed by the
-- product of two signs (-1, 0 or 1): exp(-1), exp(0) and exp(1).
local exp_lut = {
    [-1] = exp(-1),
    [0] = exp(0),
    [1] = exp(1),
}
--- Select the indices of the `top` best-scoring samples.
-- With antithetic sampling, `scored` is laid out as consecutive (+, -)
-- pairs. Each pair is ranked by the larger of its two rewards, and the
-- returned indices refer to pairs: pair k covers scored[2k-1] and scored[2k].
-- Without it, each entry of `scored` is ranked on its own.
-- @param scored list of rewards; higher ranks first.
-- @param top how many indices to keep.
-- @param antithetic whether `scored` holds +/- pairs.
-- @return list of indices, best first, truncated to at most `top` entries.
local function collect_best_indices(scored, top, antithetic)
    local ranked = scored
    if antithetic then
        ranked = {}
        for k = 1, #scored / 2 do
            -- keep the (positive, negative) argument order of the pair.
            ranked[k] = max(scored[2 * k - 1], scored[2 * k])
        end
    end

    local order = argsort(ranked, function(a, b) return a > b end)
    -- discard everything ranked below the first `top` entries.
    for k = #ranked, top + 1, -1 do
        order[k] = nil
    end
    return order
end
--- Rough local Lipschitz estimate along one antithetic direction.
-- Based on the local Lipschitz constant of the quadratic curve drawn
-- through the 3 sampled points: negative (t = -1), unperturbed (t = 0)
-- and positive (t = +1).
-- With c0 = neg - mid and c1 = pos - mid, that quadratic's slope at the
-- endpoints is |f'(+1)| = |3*c1 + c0| / 2 and |f'(-1)| = |c1 + 3*c0| / 2.
-- The larger of the two is taken. That slope is per unit of t, so it is
-- divided by the spread of `dir` (the dev returned by calc_mean_dev) to
-- express it per unit of parameter space.
-- It kinda helps? There's probably a better function to base it around.
-- @param dir noise vector of the positive sample.
-- @param pos reward of the positive sample.
-- @param neg reward of the negative sample.
-- @param mid reward of the unperturbed parameters.
-- @return the estimate. If calc_mean_dev reports a dev of 0, this divides
--   by zero.
local function kinda_lipschitz(dir, pos, neg, mid)
    local _, dev = calc_mean_dev(dir)
    local c0 = neg - mid
    local c1 = pos - mid
    local l0 = abs(3 * c1 + c0)
    local l1 = abs(c1 + 3 * c0)
    return max(l0, l1) / (2 * dev)
end
--- Initialize the optimizer state.
-- @param dims number of parameters being optimized.
-- @param popsize number of search directions per generation. Defaults to
--   4 + 3 * floor(log(dims)). It is doubled internally when antithetic
--   sampling is on, because each direction is evaluated as a +/- pair.
-- @param poptop number of best directions used per update. Defaults to the
--   (undoubled) popsize and must not exceed it.
-- @param base_rate learning rate. Defaults to
--   3/5 * (3 + log(dims)) / (dims * sqrt(dims)).
-- @param sigma scale applied to nn.normal() noise in `ask` (default 1).
-- @param antithetic whether to sample mirrored +/- pairs (default true).
-- @param momentum momentum coefficient for the step (default 0 = off).
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
                  momentum)
    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = base_rate
    -- sigma_rate and covar_rate are not read anywhere in this file.
    self.sigma_rate = base_rate
    self.covar_rate = base_rate
    self.sigma = sigma or 1
    if antithetic == nil then antithetic = true end
    self.antithetic = antithetic
    self.momentum = momentum or 0
    -- Use the defaulted self.popsize, not the raw argument. The raw
    -- `popsize` is nil when the default is taken, which previously left
    -- poptop nil and made the comparison below raise an error.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize, "poptop must not exceed popsize")
    -- Doubled only after poptop is checked: poptop counts directions
    -- (pairs), while popsize from here on counts evaluations.
    if self.antithetic then self.popsize = self.popsize * 2 end
    self._params = zeros(self.dims)
    if self.momentum > 0 then self.accum = zeros(self.dims) end
    self.evals = 0
end
--- Read, and optionally overwrite, the current parameter vector.
-- @param new_params optional list whose values are copied element-wise
--   into the internal vector. It must have the same length.
-- @return the internal parameter table itself (not a copy).
function Ars:params(new_params)
    if new_params == nil then
        return self._params
    end
    local current = self._params
    assert(#current == #new_params, "new parameters have the wrong size")
    for idx, value in ipairs(new_params) do
        current[idx] = value
    end
    return current
end
--- Sample a population of candidate parameter vectors.
-- Each candidate is self._params plus a noise vector. With antithetic
-- sampling, every even-numbered candidate uses the previous candidate's
-- noise negated, so candidates come in (+noise, -noise) pairs.
-- @param graycode optional. When non-nil, noise entries are drawn as
--   +/- exp(-precision * uniform()) instead of sigma * normal(). A numeric
--   graycode is used as that precision.
-- @return asked list of self.popsize candidate vectors.
-- @return noise list of the noise vectors that produced them (also kept in
--   self.noise for the following `tell`).
function Ars:ask(graycode)
    local gray_precision
    if graycode ~= nil then
        -- NOTE(review): the original code read `precision` without it ever
        -- being defined, so graycode sampling failed with an arithmetic-on-
        -- nil error unless a global `precision` existed. A numeric graycode
        -- now supplies it. Any other value still reads the global, as
        -- before. Confirm the intended source of this value.
        if type(graycode) == "number" then
            gray_precision = graycode
        else
            gray_precision = precision -- global lookup, as in the original
        end
        assert(gray_precision ~= nil,
               "graycode sampling requires a numeric precision")
    end

    local asked = {}
    local noise = {}
    for i = 1, self.popsize do
        local asking = zeros(self.dims)
        local noisy = zeros(self.dims)
        asked[i] = asking
        noise[i] = noisy

        if self.antithetic and i % 2 == 0 then
            -- mirror the previous candidate's noise.
            local old_noisy = noise[i - 1]
            for j, v in ipairs(old_noisy) do
                noisy[j] = -v
            end
        elseif gray_precision ~= nil then
            -- Draw all magnitudes first, then all random signs. This keeps
            -- the original order of uniform() calls.
            for j = 1, self.dims do
                noisy[j] = exp(-gray_precision * uniform())
            end
            for j = 1, self.dims do
                noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
            end
        else
            for j = 1, self.dims do
                noisy[j] = self.sigma * normal()
            end
        end

        for j, v in ipairs(self._params) do
            asking[j] = v + noisy[j]
        end
    end
    self.noise = noise
    return asked, noise
end
--- Update the parameters from the rewards of the most recent `ask`.
-- @param scored list of rewards, one per candidate returned by `ask`, in
--   the same order. Higher is better.
-- @param unperturbed_score optional reward of the current, unperturbed
--   parameters. It is only used with antithetic sampling. There it swaps
--   the per-direction normalization from the reward standard deviation to
--   kinda_lipschitz, and it counts as one extra evaluation.
-- @return step the applied update direction, before scaling by param_rate.
function Ars:tell(scored, unperturbed_score)
    local use_lips = unperturbed_score ~= nil and self.antithetic
    self.evals = self.evals + #scored
    if use_lips then self.evals = self.evals + 1 end

    -- indices of the best `poptop` directions (pair indices if antithetic).
    local indices = collect_best_indices(scored, self.poptop, self.antithetic)

    -- gather the rewards that belong to the selected directions.
    local top_rewards = {}
    if self.antithetic then
        for _, ind in ipairs(indices) do
            insert(top_rewards, scored[ind * 2 - 1])
            insert(top_rewards, scored[ind * 2 - 0])
        end
    else
        -- ARS is built around antithetic sampling,
        -- but we can still do something without.
        -- This is getting to be very similar to SNES, however.
        for _, ind in ipairs(indices) do insert(top_rewards, scored[ind]) end
        -- Note: although this normalizes the scale, it's later
        -- re-normalized differently by reward_dev anyway.
        top_rewards = normalize_sums(top_rewards)
    end

    local step = zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    -- avoid dividing by zero when the selected rewards have no spread.
    if reward_dev == 0 then reward_dev = 1 end

    if self.antithetic then
        for i, ind in ipairs(indices) do
            local pos = top_rewards[i * 2 - 1]
            local neg = top_rewards[i * 2 - 0]
            local reward = pos - neg
            if reward ~= 0 then
                -- The negative sample's noise is the mirror of this one,
                -- so only the positive sample's noise is needed.
                local noisy = self.noise[ind * 2 - 1]
                if use_lips then
                    local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
                    reward = reward / lips / self.sigma
                else
                    reward = reward / reward_dev
                end
                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    else
        for i, ind in ipairs(indices) do
            local reward = top_rewards[i] / reward_dev
            if reward ~= 0 then
                local noisy = self.noise[ind]
                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    end

    if self.momentum > 0 then
        for i, v in ipairs(step) do
            self.accum[i] = self.momentum * self.accum[i] + v
            -- Scale by exp(+1) when the step agrees in sign with the
            -- accumulator and by exp(-1) when it disagrees (assuming
            -- util.sign returns -1, 0 or 1).
            step[i] = v * exp_lut[sign(v) * sign(self.accum[i])]
        end
    end

    for i, v in ipairs(self._params) do
        self._params[i] = v + self.param_rate * step[i]
    end

    -- This population's noise has been consumed; the next `ask` sets it again.
    self.noise = nil
    return step
end
-- Public interface of this module.
local exports = {
    Ars = Ars,
    collect_best_indices = collect_best_indices,
}
return exports