From 601d78bfda487371ca542c03b3b679212fdeeed8 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Wed, 13 Jun 2018 21:54:04 +0200 Subject: [PATCH] add evaluation counting to ARS, cleanup --- ars.lua | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/ars.lua b/ars.lua index fd56b52..b40e314 100644 --- a/ars.lua +++ b/ars.lua @@ -63,6 +63,8 @@ function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic) if self.antithetic then self.popsize = self.popsize * 2 end self._params = nn.zeros(self.dims) + + self.evals = 0 end function Ars:params(new_params) @@ -113,8 +115,13 @@ function Ars:ask(graycode) end function Ars:tell(scored, unperturbed_score) + local use_lips = unperturbed_score ~= nil and self.antithetic + self.evals = self.evals + #scored + if use_lips then self.evals = self.evals + 1 end + + -- FIXME: the non-antithetic case seems to be broken. + local indices = collect_best_indices(scored, self.poptop, self.antithetic) - --print("best trials:", indices) local top_rewards = {} for i = 1, #scored do top_rewards[i] = 0 end @@ -123,16 +130,6 @@ function Ars:tell(scored, unperturbed_score) top_rewards[sind + 0] = scored[sind + 0] top_rewards[sind + 1] = scored[sind + 1] end - --print("top:", top_rewards) - - if self.antithetic then - local top_delta_rewards = {} -- only used for printing. - for i, ind in ipairs(indices) do - local sind = (ind - 1) * 2 + 1 - top_delta_rewards[i] = abs(top_rewards[sind + 0] - top_rewards[sind + 1]) - end - --print("best deltas:", top_delta_rewards) - end local step = nn.zeros(self.dims) local _, reward_dev = calc_mean_dev(top_rewards) @@ -147,7 +144,7 @@ function Ars:tell(scored, unperturbed_score) if reward ~= 0 then local noisy = self.noise[i] - if unperturbed_score ~= nil then + if use_lips then local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score) reward = reward / lips / self.sigma else