-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks (lips) by myself.

local abs = math.abs
local exp = math.exp
local floor = math.floor
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local sqrt = math.sqrt

local Base = require "Base"

local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local zeros = nn.zeros

local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev

local Ars = Base:extend()

--- Pick the indices of the best `top` trials.
-- In antithetic mode `scored` is laid out as (pos, neg) pairs at
-- scored[2k-1], scored[2k]; each pair is ranked by the better of its two
-- rewards, and the returned values are *pair* indices k.
-- Otherwise the returned values index `scored` directly.
-- @tparam {number,...} scored one reward per asked candidate
-- @tparam number top how many indices to keep
-- @tparam boolean antithetic whether `scored` is laid out in pairs
-- @treturn {number,...} at most `top` indices, ranked with `a > b`
--   (presumably best-first; depends on util.argsort)
local function collect_best_indices(scored, top, antithetic)
    local best_rewards
    if antithetic then
        -- select one (the best) reward of each pos/neg pair.
        best_rewards = {}
        for i = 1, #scored, 2 do
            local ind = floor(i / 2) + 1
            local pos = scored[i + 0]
            local neg = scored[i + 1]
            best_rewards[ind] = max(pos, neg)
        end
    else
        best_rewards = scored
    end

    local indices = argsort(best_rewards, function(a, b) return a > b end)
    -- drop everything past `top`, from the end, so the prefix stays a sequence.
    for i = #best_rewards, top + 1, -1 do indices[i] = nil end
    return indices
end

--- A rough Lipschitz-like scale for one antithetic pair (the "lips" tweak).
-- Combines the pair's rewards relative to the unperturbed reward `mid`,
-- divided by twice the deviation of the perturbation `dir`.
-- @tparam {number,...} dir the perturbation used for this pair
-- @tparam number pos reward of params + dir
-- @tparam number neg reward of params - dir
-- @tparam number mid reward of the unperturbed params
-- @treturn number a non-negative scale; nonzero whenever pos ~= neg
--   (becomes inf if `dir` has zero deviation)
local function kinda_lipschitz(dir, pos, neg, mid)
    -- NOTE(review): assumes calc_mean_dev returns (mean, std-dev).
    local _, dev = calc_mean_dev(dir)
    local c0 = neg - mid
    local c1 = pos - mid
    local l0 = abs(3 * c1 + c0)
    local l1 = abs(c1 + 3 * c0)
    return max(l0, l1) / (2 * dev)
end

--- AMSGrad rescaling of `step`, in place.
-- The running moment estimates are kept in the caller-owned `state` table
-- so they persist between calls. Not currently used by Ars.
-- @tparam {number,...} step values to rescale; overwritten with the result
-- @tparam table state holds mom1, mom2, mom2max; created on first use
-- @tparam table cfg reads adam_b1, adam_b2, adam_eps, adam_debias
-- @tparam number epoch_i iteration count for bias correction; must be >= 1
--   when cfg.adam_debias is set (1 - b^0 would divide by zero)
local function amsgrad(step, state, cfg, epoch_i)
    local n = #step
    state.mom1 = state.mom1 or zeros(n)
    state.mom2 = state.mom2 or zeros(n)
    state.mom2max = state.mom2max or zeros(n)
    local mom1, mom2, mom2max = state.mom1, state.mom2, state.mom2max

    local b1, b2, eps = cfg.adam_b1, cfg.adam_b2, cfg.adam_eps
    local b1_t = b1 ^ epoch_i
    local b2_t = b2 ^ epoch_i

    -- NOTE: with LuaJIT, splitting this loop would
    -- almost certainly be faster.
    for i, v in ipairs(step) do
        mom1[i] = b1 * mom1[i] + (1 - b1) * v
        mom2[i] = b2 * mom2[i] + (1 - b2) * v * v
        mom2max[i] = max(mom2[i], mom2max[i])
        if cfg.adam_debias then
            local num = mom1[i] / (1 - b1_t)
            local den = sqrt(mom2max[i] / (1 - b2_t)) + eps
            step[i] = num / den
        else
            step[i] = mom1[i] / (sqrt(mom2max[i]) + eps)
        end
    end
end

--- Initialize the optimizer.
-- @tparam number dims number of parameters
-- @tparam[opt] number popsize directions sampled per iteration
--   (default 4 + 3*floor(ln dims)); doubled internally when antithetic
-- @tparam[opt] number poptop how many of the best directions (pairs, when
--   antithetic) feed the update; defaults to all of them
-- @tparam[opt] number learning_rate defaults to 3/5 * (3 + ln dims) / dims^1.5
-- @tparam[opt=1] number sigma Gaussian noise scale
-- @tparam[opt] boolean antithetic sample mirrored +/- perturbation pairs
function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    self.learning_rate = learning_rate
        or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.sigma = sigma or 1
    self.antithetic = antithetic and true or false

    -- use the resolved popsize: the `popsize` argument may be nil.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize, "poptop must not exceed popsize")

    if self.antithetic then self.popsize = self.popsize * 2 end

    self._params = zeros(self.dims)
end

--- Get, and optionally overwrite (copying element-wise), the parameters.
-- @tparam[opt] {number,...} new_params must have the same length
-- @treturn {number,...} the internal parameter table (not a copy)
function Ars:params(new_params)
    if new_params ~= nil then
        assert(#self._params == #new_params, "new parameters have the wrong size")
        for i, v in ipairs(new_params) do self._params[i] = v end
    end
    return self._params
end

--- Sample a population of perturbed parameter vectors.
-- For every i, asked[i][j] = params[j] + noise[i][j]. In antithetic mode
-- each even-indexed candidate mirrors the preceding odd one:
-- noise[2k] = -noise[2k-1].
-- @param[opt] graycode when non-nil, each noise component is
--   +/- exp(-precision * uniform) instead of sigma * normal
-- @treturn table asked candidates, one vector per population member
-- @treturn table noise the perturbation of each candidate (also kept in self.noise)
function Ars:ask(graycode)
    local asked = {}
    local noise = {}

    for i = 1, self.popsize do
        local asking = zeros(self.dims)
        local noisy = zeros(self.dims)
        asked[i] = asking

        if self.antithetic and i % 2 == 0 then
            -- mirror the previous candidate's perturbation.
            local mirrored = noise[i - 1]
            for j = 1, self.dims do noisy[j] = -mirrored[j] end
        elseif graycode ~= nil then
            -- NOTE(review): `precision` is not defined in this file and is
            -- read as a global; confirm where it is meant to come from.
            for j = 1, self.dims do
                noisy[j] = exp(-precision * nn.uniform())
            end
            for j = 1, self.dims do
                noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j]
            end
        else
            for j = 1, self.dims do noisy[j] = self.sigma * normal() end
        end

        for j, v in ipairs(self._params) do asking[j] = v + noisy[j] end
        noise[i] = noisy
    end

    self.noise = noise
    return asked, noise
end

--- Update the parameters from the scores of the last `ask`.
-- @tparam {number,...} scored one reward per asked candidate, in the same
--   order (higher is better)
-- @tparam[opt] number unperturbed_score reward of the current params; in
--   antithetic mode, when given, each pair's reward difference is scaled by
--   kinda_lipschitz instead of the reward deviation. Ignored otherwise.
function Ars:tell(scored, unperturbed_score)
    local indices = collect_best_indices(scored, self.poptop, self.antithetic)
    --print("best trials:", indices)

    -- rewards of the selected trials; the rest stay 0 and contribute nothing.
    local top_rewards = {}
    for i = 1, #scored do top_rewards[i] = 0 end
    -- the same selected rewards as a dense list, for the deviation estimate.
    local used_rewards = {}

    for _, ind in ipairs(indices) do
        if self.antithetic then
            -- `ind` is a pair index; its members sit at 2*ind-1 and 2*ind.
            local sind = (ind - 1) * 2 + 1
            top_rewards[sind + 0] = scored[sind + 0]
            top_rewards[sind + 1] = scored[sind + 1]
            used_rewards[#used_rewards + 1] = scored[sind + 0]
            used_rewards[#used_rewards + 1] = scored[sind + 1]
        else
            -- `ind` indexes `scored` directly.
            top_rewards[ind] = scored[ind]
            used_rewards[#used_rewards + 1] = scored[ind]
        end
    end
    --print("top:", top_rewards)

    -- sigma_R in the paper: deviation of the rewards actually used, not of
    -- the zero-padded vector. Fewer than two rewards give no usable spread.
    local reward_dev = 1
    if #used_rewards > 1 then
        local _, dev = calc_mean_dev(used_rewards)
        if dev ~= 0 then reward_dev = dev end
    end

    local step = zeros(self.dims)
    if self.antithetic then
        for i = 1, floor(self.popsize / 2) do
            local ind = (i - 1) * 2 + 1
            local pos = top_rewards[ind + 0]
            local neg = top_rewards[ind + 1]
            local reward = pos - neg
            if reward ~= 0 then
                -- the pair's positive perturbation lives at the odd index.
                local noisy = self.noise[ind]
                if unperturbed_score ~= nil then
                    local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
                    reward = reward / lips / self.sigma
                else
                    reward = reward / reward_dev
                end
                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    else
        for i = 1, self.popsize do
            local reward = top_rewards[i] / reward_dev
            if reward ~= 0 then
                local noisy = self.noise[i]
                for j, v in ipairs(noisy) do
                    step[j] = step[j] + reward * v / self.poptop
                end
            end
        end
    end

    for i, v in ipairs(self._params) do
        self._params[i] = v + self.learning_rate * step[i]
    end

    self.asked = nil
end

return {
    Ars = Ars,
}