fix and rewrite ARS telling (it was super broken!)

This commit is contained in:
Connor Olding 2018-06-21 17:13:35 +02:00
parent 102eefe98c
commit f52fabc549

33
ars.lua
View File

@ -8,6 +8,7 @@
local abs = math.abs local abs = math.abs
local exp = math.exp local exp = math.exp
local floor = math.floor local floor = math.floor
local insert = table.insert
local ipairs = ipairs local ipairs = ipairs
local max = math.max local max = math.max
local print = print local print = print
@ -38,11 +39,10 @@ local function collect_best_indices(scored, top, antithetic)
local best_rewards local best_rewards
if antithetic then if antithetic then
best_rewards = {} best_rewards = {}
for i = 1, #scored, 2 do for i = 1, #scored / 2 do
local ind = floor(i / 2) + 1 local pos = scored[i * 2 - 1]
local pos = scored[i + 0] local neg = scored[i * 2 - 0]
local neg = scored[i + 1] best_rewards[i] = max(pos, neg)
best_rewards[ind] = max(pos, neg)
end end
else else
best_rewards = scored best_rewards = scored
@ -143,18 +143,16 @@ function Ars:tell(scored, unperturbed_score)
local indices = collect_best_indices(scored, self.poptop, self.antithetic) local indices = collect_best_indices(scored, self.poptop, self.antithetic)
local top_rewards = {} local top_rewards = {}
for i = 1, #scored do top_rewards[i] = 0 end
if self.antithetic then if self.antithetic then
for _, ind in ipairs(indices) do for _, ind in ipairs(indices) do
local sind = (ind - 1) * 2 + 1 insert(top_rewards, scored[ind * 2 - 1])
top_rewards[sind + 0] = scored[sind + 0] insert(top_rewards, scored[ind * 2 - 0])
top_rewards[sind + 1] = scored[sind + 1]
end end
else else
-- ARS is built around antithetic sampling, -- ARS is built around antithetic sampling,
-- but we can still do something without. -- but we can still do something without.
-- this is getting to be very similar to SNES however. -- this is getting to be very similar to SNES however.
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end for _, ind in ipairs(indices) do insert(top_rewards, scored[ind]) end
-- note: although this normalizes the scale, it's later -- note: although this normalizes the scale, it's later
-- re-normalized differently by reward_dev anyway. -- re-normalized differently by reward_dev anyway.
top_rewards = normalize_sums(top_rewards) top_rewards = normalize_sums(top_rewards)
@ -165,14 +163,12 @@ function Ars:tell(scored, unperturbed_score)
if reward_dev == 0 then reward_dev = 1 end if reward_dev == 0 then reward_dev = 1 end
if self.antithetic then if self.antithetic then
for i = 1, floor(self.popsize / 2) do for i, ind in ipairs(indices) do
local ind = (i - 1) * 2 + 1 local pos = top_rewards[i * 2 - 1]
local pos = top_rewards[ind + 0] local neg = top_rewards[i * 2 - 0]
local neg = top_rewards[ind + 1]
local reward = pos - neg local reward = pos - neg
if reward ~= 0 then if reward ~= 0 then
local noisy = self.noise[i] local noisy = self.noise[ind * 2 - 1]
if use_lips then if use_lips then
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score) local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
reward = reward / lips / self.sigma reward = reward / lips / self.sigma
@ -186,10 +182,10 @@ function Ars:tell(scored, unperturbed_score)
end end
end end
else else
for i = 1, self.popsize do for i, ind in ipairs(indices) do
local reward = top_rewards[i] / reward_dev local reward = top_rewards[i] / reward_dev
if reward ~= 0 then if reward ~= 0 then
local noisy = self.noise[i] local noisy = self.noise[ind]
for j, v in ipairs(noisy) do for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop step[j] = step[j] + reward * v / self.poptop
@ -215,5 +211,6 @@ function Ars:tell(scored, unperturbed_score)
end end
return { return {
collect_best_indices = collect_best_indices,
Ars = Ars, Ars = Ars,
} }