diff --git a/ars.lua b/ars.lua index 99e3c46..2e2eddd 100644 --- a/ars.lua +++ b/ars.lua @@ -105,12 +105,13 @@ function Ars:ask(graycode) local asking = zeros(self.dims) local noisy = zeros(self.dims) asked[i] = asking + noise[i] = noisy if self.antithetic and i % 2 == 0 then - for j, v in ipairs(self._params) do - asking[i] = v - noisy[j] + local old_noisy = noise[i - 1] + for j, v in ipairs(old_noisy) do + noisy[j] = -v end - else if graycode ~= nil then for j = 1, self.dims do @@ -124,13 +125,11 @@ function Ars:ask(graycode) noisy[j] = self.sigma * nn.normal() end end - - for j, v in ipairs(self._params) do - asking[j] = v + noisy[j] - end end - noise[i] = noisy + for j, v in ipairs(self._params) do + asking[j] = v + noisy[j] + end end self.noise = noise