diff --git a/ars.lua b/ars.lua index 2e2eddd..fd56b52 100644 --- a/ars.lua +++ b/ars.lua @@ -51,30 +51,6 @@ local function kinda_lipschitz(dir, pos, neg, mid) return max(l0, l1) / (2 * dev) end -local function amsgrad(step) -- in-place! -- TODO: fix this. - if mom1 == nil then mom1 = nn.zeros(#step) end - if mom2 == nil then mom2 = nn.zeros(#step) end - if mom2max == nil then mom2max = nn.zeros(#step) end - - local b1_t = pow(cfg.adam_b1, epoch_i) - local b2_t = pow(cfg.adam_b2, epoch_i) - - -- NOTE: with LuaJIT, splitting this loop would - -- almost certainly be faster. - for i, v in ipairs(step) do - mom1[i] = cfg.adam_b1 * mom1[i] + (1 - cfg.adam_b1) * v - mom2[i] = cfg.adam_b2 * mom2[i] + (1 - cfg.adam_b2) * v * v - mom2max[i] = max(mom2[i], mom2max[i]) - if cfg.adam_debias then - local num = (mom1[i] / (1 - b1_t)) - local den = sqrt(mom2max[i] / (1 - b2_t)) + cfg.adam_eps - step[i] = num / den - else - step[i] = mom1[i] / (sqrt(mom2max[i]) + cfg.adam_eps) - end - end -end - function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic) self.dims = dims self.popsize = popsize or 4 + (3 * floor(log(dims))) @@ -201,6 +177,8 @@ function Ars:tell(scored, unperturbed_score) end self.noise = nil + + return step end return { diff --git a/main.lua b/main.lua index a8a4a95..f4dcef0 100644 --- a/main.lua +++ b/main.lua @@ -17,9 +17,6 @@ local trial_neg = true local trial_params --= {} local trial_rewards = {} local trials_remaining = 0 -local mom1 -- first moments in AMSgrad. -local mom2 -- second moments in AMSgrad. -local mom2max -- running element-wise maximum of mom2. local es -- evolution strategy. 
local trial_frames = 0 @@ -114,10 +111,9 @@ local log_map = { delta_mean = 4, delta_std = 5, step_std = 6, - adam_std = 7, - weight_mean = 8, - weight_std = 9, - test_trial = 10, + weight_mean = 7, + weight_std = 8, + test_trial = 9, } local function log_csv(t) @@ -263,28 +259,16 @@ local function learn_from_epoch() end end + local step if cfg.es == 'ars' and cfg.ars_lips then - es:tell(trial_rewards, current_cost) + step = es:tell(trial_rewards, current_cost) else - es:tell(trial_rewards) + step = es:tell(trial_rewards) end - local step_mean, step_dev = 0, 0 - --[[ TODO local step_mean, step_dev = calc_mean_dev(step) print("step mean:", step_mean) print("step stddev:", step_dev) - --]] - - local momstep_mean, momstep_dev = 0, 0 - --[[ TODO - if cfg.adamant then - amsgrad(step) - momstep_mean, momstep_dev = calc_mean_dev(step) - print("amsgrad mean:", momstep_mean) - print("amsgrad stddev:", momstep_dev) - end - --]] base_params = es:params() @@ -320,7 +304,6 @@ local function learn_from_epoch() delta_mean = delta_mean, delta_std = delta_std, step_std = step_dev, - adam_std = momstep_dev, weight_mean = weight_mean, weight_std = weight_std, test_trial = current_cost or 0, diff --git a/snes.lua b/snes.lua index 1f43b7e..b1e5f32 100644 --- a/snes.lua +++ b/snes.lua @@ -244,8 +244,13 @@ function Snes:tell(scored) end end + local step = {} + for i, v in ipairs(g_mean) do + step[i] = self.std[i] * v + end + for i, v in ipairs(self.mean) do - self.mean[i] = v + self.mean_adapt * self.std[i] * g_mean[i] + self.mean[i] = v + self.mean_adapt * step[i] end local otherwise = {} @@ -257,6 +262,8 @@ function Snes:tell(scored) end self:adapt(asked, otherwise, utility) + + return step end function Snes:adapt(asked, otherwise, qualities) diff --git a/xnes.lua b/xnes.lua index fb04695..caa20a1 100644 --- a/xnes.lua +++ b/xnes.lua @@ -179,9 +179,14 @@ function Xnes:tell(scored, noise) -- finally, update according to the gradients. 
local dotted = dot_mv(self.covars, g_delta) + local step = {} + for i, v in ipairs(dotted) do + step[i] = self.sigma * v + end + for i, v in ipairs(self.mean) do - self.mean[i] = v + self.mean_adapt * self.sigma * dotted[i] + self.mean[i] = v + self.mean_adapt * step[i] end --[[ @@ -201,6 +206,8 @@ function Xnes:tell(scored, noise) --self.sigma = exp(self.log_sigma) --for i, v in ipairs(self.log_covars) do self.covars[i] = exp(v) end self.noise = nil + + return step end return {