2018-06-09 08:56:18 -07:00
|
|
|
-- Augmented Random Search
|
|
|
|
-- https://arxiv.org/abs/1803.07055
|
2018-06-14 13:15:49 -07:00
|
|
|
-- with some tweaks of my own (reward scaling based on a local lipschitz estimate).
|
|
|
|
-- i also added an option for graycode sampling,
|
|
|
|
-- borrowed from a (1+1) optimizer,
|
|
|
|
-- but i haven't yet found a case where it outperforms the default normal sampling.
|
2018-06-09 08:56:18 -07:00
|
|
|
|
|
|
|
local abs = math.abs
|
2018-06-13 13:51:12 -07:00
|
|
|
local exp = math.exp
|
2018-06-09 08:56:18 -07:00
|
|
|
local floor = math.floor
|
|
|
|
local ipairs = ipairs
|
|
|
|
local max = math.max
|
|
|
|
local print = print
|
|
|
|
|
|
|
|
local Base = require "Base"
|
|
|
|
|
|
|
|
local nn = require "nn"
|
|
|
|
local normal = nn.normal
|
|
|
|
local prod = nn.prod
|
2018-06-13 13:51:12 -07:00
|
|
|
local uniform = nn.uniform
|
2018-06-09 08:56:18 -07:00
|
|
|
local zeros = nn.zeros
|
|
|
|
|
|
|
|
local util = require "util"
|
|
|
|
local argsort = util.argsort
|
|
|
|
local calc_mean_dev = util.calc_mean_dev
|
2018-06-13 13:51:12 -07:00
|
|
|
local normalize_sums = util.normalize_sums
|
2018-06-09 08:56:18 -07:00
|
|
|
|
|
|
|
-- the optimizer class, derived from Base through its extend() method.
local Ars = Base.extend(Base)
|
|
|
|
|
|
|
|
local function collect_best_indices(scored, top, antithetic)
    -- rank candidates by reward (highest first) and return the indices of
    -- the `top` best.
    -- with antithetic sampling, scored[2k-1] and scored[2k] belong to the
    -- same direction k; each pair is ranked by the better of its two
    -- rewards, and the returned indices are pair numbers, not positions
    -- in `scored`.
    local ranked = scored
    if antithetic then
        ranked = {}
        local pair = 0
        for i = 1, #scored, 2 do
            pair = pair + 1
            ranked[pair] = max(scored[i + 0], scored[i + 1])
        end
    end

    local function descending(a, b) return a > b end
    local indices = argsort(ranked, descending)

    -- drop everything ranked below `top`.
    for i = #ranked, top + 1, -1 do
        indices[i] = nil
    end

    return indices
end
|
|
|
|
|
|
|
|
local function kinda_lipschitz(dir, pos, neg, mid)
    -- estimate how steep the reward is along the sampled direction `dir`.
    --
    -- fit the quadratic f(x) through the three sampled points
    -- (-1, neg), (0, mid), (1, pos), with x in units of the perturbation.
    -- f' is linear, so the largest |f'| on [-1, 1] is at an endpoint:
    --   2 * |f'(+1)| = |3 * (pos - mid) + (neg - mid)|
    --   2 * |f'(-1)| = |(pos - mid) + 3 * (neg - mid)|
    -- the factor of 2 is folded into the final division, which also divides
    -- by the `dev` of `dir` from calc_mean_dev (presumably its standard
    -- deviation -- confirm in util) to rescale toward parameter units.
    -- it kinda helps? there's probably a better function to base it around.
    local _, spread = calc_mean_dev(dir)
    local rise_neg = neg - mid
    local rise_pos = pos - mid
    local twice_slope_pos = abs(rise_neg + 3 * rise_pos)
    local twice_slope_neg = abs(3 * rise_neg + rise_pos)
    return max(twice_slope_pos, twice_slope_neg) / (spread * 2)
end
|
|
|
|
|
|
|
|
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
    -- set up the optimizer for a parameter vector of length `dims`.
    --
    -- popsize: number of sampled directions per generation
    --   (default 4 + 3 * floor(log(dims))).
    -- poptop: how many of the best directions contribute to each update
    --   (default: all of them). must not exceed popsize.
    -- learning_rate: default 3/5 * (3 + log(dims)) / (dims * sqrt(dims)).
    -- sigma: scale of the normal noise (default 1).
    -- antithetic: when truthy, every direction is evaluated as a (+, -)
    --   pair, so the stored self.popsize is twice the number of directions.

    -- log and sqrt are not among the file-level math localizations;
    -- without this they resolve to (undefined) globals.
    local log, sqrt = math.log, math.sqrt

    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    self.learning_rate = learning_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false

    -- compare against the effective popsize (not the raw argument) so the
    -- defaulted popsize above works too.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize, "poptop must not exceed popsize")

    -- doubled only after the check: poptop counts directions, not samples.
    if self.antithetic then self.popsize = self.popsize * 2 end

    self._params = zeros(self.dims)
    self.evals = 0
end
|
|
|
|
|
|
|
|
function Ars:params(new_params)
    -- get, and optionally overwrite, the current parameter vector.
    -- the returned table is the optimizer's own storage, not a copy; when
    -- `new_params` is given its elements are copied into that storage.
    if new_params == nil then
        return self._params
    end

    local current = self._params
    assert(#current == #new_params, "new parameters have the wrong size")
    for index, value in ipairs(new_params) do
        current[index] = value
    end
    return current
end
|
|
|
|
|
|
|
|
function Ars:ask(graycode)
    -- sample self.popsize perturbed copies of the current parameters.
    --
    -- graycode: optional number. when given, each noise component has a
    --   random sign and the magnitude exp(-graycode * uniform()), i.e.
    --   between exp(-graycode) and 1 if uniform() draws from [0, 1]
    --   (presumably -- confirm against nn.uniform). this magnitude is not
    --   scaled by self.sigma. when nil or false, components are
    --   self.sigma * normal().
    --   NOTE(review): the argument is treated as the sampling precision;
    --   confirm this matches the (1+1) optimizer the idea came from.
    -- returns `asked` (the perturbed parameter vectors) and `noise` (the
    --   perturbations); `noise` is also stored for the following tell().
    local precision = nil
    if graycode then
        assert(type(graycode) == "number",
            "graycode must be a number (the sampling precision)")
        precision = graycode
    end

    local asked = {}
    local noise = {}

    for i = 1, self.popsize do
        local asking = zeros(self.dims)
        local noisy = zeros(self.dims)
        asked[i] = asking
        noise[i] = noisy

        if self.antithetic and i % 2 == 0 then
            -- mirror the previous sample so each pair is (+d, -d).
            for j, v in ipairs(noise[i - 1]) do
                noisy[j] = -v
            end
        elseif precision ~= nil then
            -- magnitudes first, then random signs; two passes keep the
            -- order in which uniform() is consumed.
            for j = 1, self.dims do
                noisy[j] = exp(-precision * uniform())
            end
            for j = 1, self.dims do
                noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
            end
        else
            for j = 1, self.dims do
                noisy[j] = self.sigma * normal()
            end
        end

        for j, v in ipairs(self._params) do
            asking[j] = v + noisy[j]
        end
    end

    self.noise = noise
    return asked, noise
end
|
|
|
|
|
|
|
|
function Ars:tell(scored, unperturbed_score)
    -- update the parameters from the rewards of the candidates returned by
    -- the most recent ask().
    --
    -- scored: one reward per candidate, in the order ask() returned them.
    --   with antithetic sampling, candidates 2k-1 and 2k are the
    --   (+noise, -noise) pair for direction k.
    -- unperturbed_score: optional reward of the unperturbed parameters.
    --   only used with antithetic sampling, where it replaces the
    --   reward-deviation scaling with kinda_lipschitz scaling (and counts
    --   as one extra evaluation).
    -- returns the raw step; self._params is moved by learning_rate * step.
    assert(self.noise ~= nil, "tell() called without a preceding ask()")
    assert(#scored == self.popsize, "expected one score per asked candidate")

    local use_lips = unperturbed_score ~= nil and self.antithetic
    self.evals = self.evals + #scored
    if use_lips then self.evals = self.evals + 1 end

    local indices = collect_best_indices(scored, self.poptop, self.antithetic)

    -- keep the rewards of the selected candidates (whole pairs when
    -- antithetic); every other candidate contributes a reward of 0.
    local top_rewards = {}
    for i = 1, #scored do top_rewards[i] = 0 end
    if self.antithetic then
        for _, ind in ipairs(indices) do
            local sind = (ind - 1) * 2 + 1
            top_rewards[sind + 0] = scored[sind + 0]
            top_rewards[sind + 1] = scored[sind + 1]
        end
    else
        -- ARS is built around antithetic sampling,
        -- but we can still do something without.
        -- this is getting to be very similar to SNES however.
        for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
        -- note: although this normalizes the scale, it's later
        -- re-normalized differently by reward_dev anyway.
        top_rewards = normalize_sums(top_rewards)
    end

    local step = zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero

    if self.antithetic then
        for i = 1, floor(self.popsize / 2) do
            local ind = (i - 1) * 2 + 1
            local pos = top_rewards[ind + 0]
            local neg = top_rewards[ind + 1]
            local reward = pos - neg
            if reward ~= 0 then
                -- pair i's positive direction was sampled at index `ind`
                -- (its mirror sits at ind + 1). indexing by `i` would pair
                -- every reward after the first with the wrong direction.
                local noisy = self.noise[ind]

                if use_lips then
                    local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
                    reward = reward / lips / self.sigma
                else
                    reward = reward / reward_dev
                end

                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    else
        for i = 1, self.popsize do
            local reward = top_rewards[i] / reward_dev
            if reward ~= 0 then
                for j, v in ipairs(self.noise[i]) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    end

    for i, v in ipairs(self._params) do
        self._params[i] = v + self.learning_rate * step[i]
    end

    -- the noise belongs to this generation only; a new ask() is required
    -- before the next tell().
    self.noise = nil

    return step
end
|
|
|
|
|
|
|
|
-- module exports.
local exports = {}
exports.Ars = Ars
return exports
|