diff --git a/ars.lua b/ars.lua index b40e314..0fe6887 100644 --- a/ars.lua +++ b/ars.lua @@ -119,16 +119,21 @@ function Ars:tell(scored, unperturbed_score) self.evals = self.evals + #scored if use_lips then self.evals = self.evals + 1 end - -- FIXME: the non-antithetic case seems to be broken. - local indices = collect_best_indices(scored, self.poptop, self.antithetic) local top_rewards = {} for i = 1, #scored do top_rewards[i] = 0 end - for _, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_rewards[sind + 0] = scored[sind + 0] - top_rewards[sind + 1] = scored[sind + 1] + if self.antithetic then + for _, ind in ipairs(indices) do + local sind = (ind - 1) * 2 + 1 + top_rewards[sind + 0] = scored[sind + 0] + top_rewards[sind + 1] = scored[sind + 1] + end + else + for _, i in ipairs(indices) do top_rewards[i] = scored[i] end + -- note: although this normalizes the scale, it's later + -- re-normalized differently by reward_dev anyway. + top_rewards = util.normalize_sums(top_rewards) end local step = nn.zeros(self.dims)