fix non-antithetic case for ARS
This commit is contained in:
parent
601d78bfda
commit
ac4c534185
1 changed file with 11 additions and 6 deletions
9
ars.lua
9
ars.lua
|
@@ -119,17 +119,22 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
self.evals = self.evals + #scored
|
self.evals = self.evals + #scored
|
||||||
if use_lips then self.evals = self.evals + 1 end
|
if use_lips then self.evals = self.evals + 1 end
|
||||||
|
|
||||||
-- FIXME: the non-antithetic case seems to be broken.
|
|
||||||
|
|
||||||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||||
|
|
||||||
local top_rewards = {}
|
local top_rewards = {}
|
||||||
for i = 1, #scored do top_rewards[i] = 0 end
|
for i = 1, #scored do top_rewards[i] = 0 end
|
||||||
|
if self.antithetic then
|
||||||
for _, ind in ipairs(indices) do
|
for _, ind in ipairs(indices) do
|
||||||
local sind = (ind - 1) * 2 + 1
|
local sind = (ind - 1) * 2 + 1
|
||||||
top_rewards[sind + 0] = scored[sind + 0]
|
top_rewards[sind + 0] = scored[sind + 0]
|
||||||
top_rewards[sind + 1] = scored[sind + 1]
|
top_rewards[sind + 1] = scored[sind + 1]
|
||||||
end
|
end
|
||||||
|
else
|
||||||
|
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
|
||||||
|
-- note: although this normalizes the scale, it's later
|
||||||
|
-- re-normalized differently by reward_dev anyway.
|
||||||
|
top_rewards = util.normalize_sums(top_rewards)
|
||||||
|
end
|
||||||
|
|
||||||
local step = nn.zeros(self.dims)
|
local step = nn.zeros(self.dims)
|
||||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||||
|
|
Loading…
Add table
Reference in a new issue