fix non-antithetic case for ARS

This commit is contained in:
Connor Olding 2018-06-13 22:46:09 +02:00
parent 601d78bfda
commit ac4c534185

17
ars.lua
View file

@ -119,16 +119,21 @@ function Ars:tell(scored, unperturbed_score)
self.evals = self.evals + #scored
if use_lips then self.evals = self.evals + 1 end
-- FIXME: the non-antithetic case seems to be broken.
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
local top_rewards = {}
for i = 1, #scored do top_rewards[i] = 0 end
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = scored[sind + 0]
top_rewards[sind + 1] = scored[sind + 1]
if self.antithetic then
for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1
top_rewards[sind + 0] = scored[sind + 0]
top_rewards[sind + 1] = scored[sind + 1]
end
else
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
-- note: although this normalizes the scale, it's later
-- re-normalized differently by reward_dev anyway.
top_rewards = util.normalize_sums(top_rewards)
end
local step = nn.zeros(self.dims)