-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks of my own (the lipschitz stuff).
-- i also added an option for graycode sampling,
-- borrowed from a (1+1) optimizer,
-- but i haven't yet found a case where it performs better.

local abs = math.abs
local exp = math.exp
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local sqrt = math.sqrt

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local uniform = nn.uniform
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local normalize_sums = util.normalize_sums
local sign = util.sign

local Ars = Base:extend()

-- exp(x) precomputed for x in {-1, 0, 1}: the values that
-- sign(a) * sign(b) is expected to take in Ars:tell's momentum step
-- (assumes util.sign returns -1, 0, or 1).
local exp_lut = {}
exp_lut[-1] = exp(-1)
exp_lut[0] = exp(0)
exp_lut[1] = exp(1)

-- rank samples by score (higher is better) and return the indices of
-- the `top` best, best first (assuming util.argsort orders indices by
-- the given comparator).
-- when `antithetic` is true, `scored` holds interleaved (positive, negative)
-- pairs: pair i covers scored[2i - 1] and scored[2i], is ranked by the better
-- of its two scores, and the returned indices are pair indices.
-- otherwise the returned indices index `scored` directly.
local function collect_best_indices(scored, top, antithetic)
    -- select one (the best) reward of each pos/neg pair.
    local best_rewards
    if antithetic then
        best_rewards = {}
        for i = 1, #scored / 2 do
            local pos = scored[i * 2 - 1]
            local neg = scored[i * 2 - 0]
            best_rewards[i] = max(pos, neg)
        end
    else
        best_rewards = scored
    end

    local indices = argsort(best_rewards, function(a, b) return a > b end)

    -- truncate to the `top` best.
    for i = top + 1, #best_rewards do indices[i] = nil end
    return indices
end

-- estimate a step-size normalizer for one antithetic pair.
-- based on the local lipschitz constant of a quadratic curve
-- drawn through the 3 sampled points: positive, negative, and unperturbed.
-- with t = -1, 0, 1 mapping to neg, mid, pos, the quadratic's slope at
-- t = 1 is (3*c1 + c0) / 2 and at t = -1 is -(c1 + 3*c0) / 2. the slope is
-- linear in t, so the larger magnitude of the two is its maximum on [-1, 1].
-- that is then divided by the spread of `dir` as reported by calc_mean_dev
-- (presumably a standard deviation -- confirm in util).
-- it kinda helps? there's probably a better function to base it around.
-- dir: the perturbation vector; pos/neg/mid: the three scores.
local function kinda_lipschitz(dir, pos, neg, mid)
    local _, dev = calc_mean_dev(dir)
    local c0 = neg - mid
    local c1 = pos - mid
    local l0 = abs(3 * c1 + c0)
    local l1 = abs(c1 + 3 * c0)
    return max(l0, l1) / (2 * dev)
end

-- dims:       number of parameters being optimized.
-- popsize:    number of perturbation directions per generation; defaults
--             to 4 + 3 * floor(log(dims)). doubled when antithetic, since
--             every direction is then sampled as a +/- pair.
-- poptop:     how many of the best directions contribute to a step;
--             defaults to popsize and must not exceed it.
-- base_rate:  learning rate; defaults to
--             3/5 * (3 + log(dims)) / (dims * sqrt(dims)).
-- sigma:      scale of the gaussian perturbations (default 1).
-- antithetic: sample mirrored +/- pairs (default true).
-- momentum:   momentum coefficient for the step (default 0, disabled).
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic, momentum)
    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = base_rate
    self.sigma_rate = base_rate
    self.covar_rate = base_rate
    self.sigma = sigma or 1
    if antithetic == nil then
        self.antithetic = true
    else
        self.antithetic = antithetic
    end
    self.momentum = momentum or 0

    -- use the resolved self.popsize rather than the raw argument,
    -- so the defaults still work when popsize is omitted.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize, "poptop must not exceed popsize")

    if self.antithetic then self.popsize = self.popsize * 2 end

    self._params = zeros(self.dims)

    if self.momentum > 0 then self.accum = zeros(self.dims) end

    self.evals = 0
end

-- get the current parameters, optionally overwriting them first by
-- copying new_params element-wise into the existing table.
-- returns the internal parameter table itself (not a copy).
function Ars:params(new_params)
    if new_params ~= nil then
        assert(#self._params == #new_params, "new parameters have the wrong size")
        for i, v in ipairs(new_params) do self._params[i] = v end
    end
    return self._params
end

-- shrink the parameters toward zero by a factor of
-- (1 - param_rate * param_decay * sigma). a nil or non-positive
-- param_decay does nothing.
-- sigma_decay is accepted but currently unused.
function Ars:decay(param_decay, sigma_decay)
    param_decay = param_decay or 0
    if param_decay > 0 then
        for i, v in ipairs(self._params) do
            self._params[i] = v * (1 - self.param_rate * param_decay * self.sigma)
        end
    end
end

-- sample a population of perturbed parameter vectors.
-- with antithetic sampling, every even-indexed sample mirrors the noise
-- of the sample before it.
-- graycode:  when truthy, sample noise as +/- exp(-precision * u),
--            with u drawn from nn.uniform (presumably in [0, 1)),
--            instead of sigma-scaled gaussian noise.
-- precision: required when graycode is set; a larger value spreads the
--            noise magnitudes down toward exp(-precision).
-- returns the perturbed parameter vectors and their noise vectors.
-- the noise is also kept in self.noise for the following call to tell.
function Ars:ask(graycode, precision)
    if graycode then
        assert(precision ~= nil, "graycode sampling requires a precision")
    end

    local asked = {}
    local noise = {}

    for i = 1, self.popsize do
        local asking = zeros(self.dims)
        local noisy = zeros(self.dims)
        asked[i] = asking
        noise[i] = noisy

        if self.antithetic and i % 2 == 0 then
            local old_noisy = noise[i - 1]
            for j, v in ipairs(old_noisy) do
                noisy[j] = -v
            end
        elseif graycode then
            for j = 1, self.dims do
                noisy[j] = exp(-precision * uniform())
            end
            for j = 1, self.dims do
                noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
            end
        else
            for j = 1, self.dims do
                noisy[j] = self.sigma * normal()
            end
        end

        for j, v in ipairs(self._params) do
            asking[j] = v + noisy[j]
        end
    end

    self.noise = noise
    return asked, noise
end

-- update the parameters from the scores of the most recent ask.
-- scored:            one reward per sample, in the order ask returned them
--                    (higher is better).
-- unperturbed_score: optional reward of the current, unperturbed
--                    parameters. with antithetic sampling, it switches
--                    step scaling from reward_dev to kinda_lipschitz.
-- returns the step vector, before it was scaled by param_rate.
function Ars:tell(scored, unperturbed_score)
    assert(self.noise ~= nil, "tell called without a preceding ask")

    local use_lips = unperturbed_score ~= nil and self.antithetic

    self.evals = self.evals + #scored
    if use_lips then self.evals = self.evals + 1 end

    local indices = collect_best_indices(scored, self.poptop, self.antithetic)

    local top_rewards = {}
    if self.antithetic then
        for _, ind in ipairs(indices) do
            insert(top_rewards, scored[ind * 2 - 1])
            insert(top_rewards, scored[ind * 2 - 0])
        end
    else
        -- ARS is built around antithetic sampling,
        -- but we can still do something without.
        -- this is getting to be very similar to SNES however.
        for _, ind in ipairs(indices) do
            insert(top_rewards, scored[ind])
        end
        -- note: although this normalizes the scale, it's later
        -- re-normalized differently by reward_dev anyway.
        top_rewards = normalize_sums(top_rewards)
    end

    local step = zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    -- avoid dividing by zero when all top rewards are identical.
    if reward_dev == 0 then reward_dev = 1 end

    if self.antithetic then
        for i, ind in ipairs(indices) do
            local pos = top_rewards[i * 2 - 1]
            local neg = top_rewards[i * 2 - 0]
            local reward = pos - neg
            if reward ~= 0 then
                -- the positive member of the pair carries the direction.
                local noisy = self.noise[ind * 2 - 1]

                if use_lips then
                    local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
                    reward = reward / lips / self.sigma
                else
                    reward = reward / reward_dev
                end

                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    else
        for i, ind in ipairs(indices) do
            local reward = top_rewards[i] / reward_dev
            if reward ~= 0 then
                local noisy = self.noise[ind]

                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    end

    if self.momentum > 0 then
        for i, v in ipairs(step) do
            self.accum[i] = self.momentum * self.accum[i] + v
            -- grow the step by e when it agrees in sign with the
            -- accumulator, shrink it by 1/e when it disagrees.
            step[i] = v * exp_lut[sign(v) * sign(self.accum[i])]
        end
    end

    for i, v in ipairs(self._params) do
        self._params[i] = v + self.param_rate * step[i]
    end

    -- the noise belongs to this generation only.
    self.noise = nil

    return step
end

return {
    collect_best_indices = collect_best_indices,
    Ars = Ars,
}