add evaluation counting to ARS, cleanup
This commit is contained in:
parent
6498b4143f
commit
601d78bfda
1 changed files with 9 additions and 12 deletions
21
ars.lua
21
ars.lua
|
@ -63,6 +63,8 @@ function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
|
||||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||||
|
|
||||||
self._params = nn.zeros(self.dims)
|
self._params = nn.zeros(self.dims)
|
||||||
|
|
||||||
|
self.evals = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
function Ars:params(new_params)
|
function Ars:params(new_params)
|
||||||
|
@ -113,8 +115,13 @@ function Ars:ask(graycode)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Ars:tell(scored, unperturbed_score)
|
function Ars:tell(scored, unperturbed_score)
|
||||||
|
local use_lips = unperturbed_score ~= nil and self.antithetic
|
||||||
|
self.evals = self.evals + #scored
|
||||||
|
if use_lips then self.evals = self.evals + 1 end
|
||||||
|
|
||||||
|
-- FIXME: the non-antithetic case seems to be broken.
|
||||||
|
|
||||||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||||
--print("best trials:", indices)
|
|
||||||
|
|
||||||
local top_rewards = {}
|
local top_rewards = {}
|
||||||
for i = 1, #scored do top_rewards[i] = 0 end
|
for i = 1, #scored do top_rewards[i] = 0 end
|
||||||
|
@ -123,16 +130,6 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
top_rewards[sind + 0] = scored[sind + 0]
|
top_rewards[sind + 0] = scored[sind + 0]
|
||||||
top_rewards[sind + 1] = scored[sind + 1]
|
top_rewards[sind + 1] = scored[sind + 1]
|
||||||
end
|
end
|
||||||
--print("top:", top_rewards)
|
|
||||||
|
|
||||||
if self.antithetic then
|
|
||||||
local top_delta_rewards = {} -- only used for printing.
|
|
||||||
for i, ind in ipairs(indices) do
|
|
||||||
local sind = (ind - 1) * 2 + 1
|
|
||||||
top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1])
|
|
||||||
end
|
|
||||||
--print("best deltas:", top_delta_rewards)
|
|
||||||
end
|
|
||||||
|
|
||||||
local step = nn.zeros(self.dims)
|
local step = nn.zeros(self.dims)
|
||||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||||
|
@ -147,7 +144,7 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
if reward ~= 0 then
|
if reward ~= 0 then
|
||||||
local noisy = self.noise[i]
|
local noisy = self.noise[i]
|
||||||
|
|
||||||
if unperturbed_score ~= nil then
|
if use_lips then
|
||||||
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
||||||
reward = reward / lips / self.sigma
|
reward = reward / lips / self.sigma
|
||||||
else
|
else
|
||||||
|
|
Loading…
Add table
Reference in a new issue