From f52fabc54946b6d6e10506b777ffbdf59bfff1b4 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 21 Jun 2018 17:13:35 +0200 Subject: [PATCH] fix and rewrite ARS telling (it was super broken!) --- ars.lua | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/ars.lua b/ars.lua index 28b6176..6828b35 100644 --- a/ars.lua +++ b/ars.lua @@ -8,6 +8,7 @@ local abs = math.abs local exp = math.exp local floor = math.floor +local insert = table.insert local ipairs = ipairs local max = math.max local print = print @@ -38,11 +39,10 @@ local function collect_best_indices(scored, top, antithetic) local best_rewards if antithetic then best_rewards = {} - for i = 1, #scored, 2 do - local ind = floor(i / 2) + 1 - local pos = scored[i + 0] - local neg = scored[i + 1] - best_rewards[ind] = max(pos, neg) + for i = 1, #scored / 2 do + local pos = scored[i * 2 - 1] + local neg = scored[i * 2 - 0] + best_rewards[i] = max(pos, neg) end else best_rewards = scored @@ -143,18 +143,16 @@ function Ars:tell(scored, unperturbed_score) local indices = collect_best_indices(scored, self.poptop, self.antithetic) local top_rewards = {} - for i = 1, #scored do top_rewards[i] = 0 end if self.antithetic then for _, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_rewards[sind + 0] = scored[sind + 0] - top_rewards[sind + 1] = scored[sind + 1] + insert(top_rewards, scored[ind * 2 - 1]) + insert(top_rewards, scored[ind * 2 - 0]) end else -- ARS is built around antithetic sampling, -- but we can still do something without. -- this is getting to be very similar to SNES however. - for _, i in ipairs(indices) do top_rewards[i] = scored[i] end + for _, ind in ipairs(indices) do insert(top_rewards, scored[ind]) end -- note: although this normalizes the scale, it's later -- re-normalized differently by reward_dev anyway. top_rewards = normalize_sums(top_rewards) @@ -165,14 +163,12 @@ function Ars:tell(scored, unperturbed_score) if reward_dev == 0 then reward_dev = 1 end if self.antithetic then - for i = 1, floor(self.popsize / 2) do - local ind = (i - 1) * 2 + 1 - local pos = top_rewards[ind + 0] - local neg = top_rewards[ind + 1] + for i, ind in ipairs(indices) do + local pos = top_rewards[i * 2 - 1] + local neg = top_rewards[i * 2 - 0] local reward = pos - neg if reward ~= 0 then - local noisy = self.noise[i] - + local noisy = self.noise[ind * 2 - 1] if use_lips then local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score) reward = reward / lips / self.sigma @@ -186,10 +182,10 @@ function Ars:tell(scored, unperturbed_score) end end else - for i = 1, self.popsize do + for i, ind in ipairs(indices) do local reward = top_rewards[i] / reward_dev if reward ~= 0 then - local noisy = self.noise[i] + local noisy = self.noise[ind] for j, v in ipairs(noisy) do step[j] = step[j] + reward * v / self.poptop @@ -215,5 +211,6 @@ function Ars:tell(scored, unperturbed_score) end return { + collect_best_indices = collect_best_indices, Ars = Ars, }