Fix and rewrite ARS telling (Ars:tell) — it was super broken!
This commit is contained in:
parent
102eefe98c
commit
f52fabc549
1 changed file with 15 additions and 18 deletions
33
ars.lua
33
ars.lua
|
@ -8,6 +8,7 @@
|
|||
local abs = math.abs
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local max = math.max
|
||||
local print = print
|
||||
|
@ -38,11 +39,10 @@ local function collect_best_indices(scored, top, antithetic)
|
|||
local best_rewards
|
||||
if antithetic then
|
||||
best_rewards = {}
|
||||
for i = 1, #scored, 2 do
|
||||
local ind = floor(i / 2) + 1
|
||||
local pos = scored[i + 0]
|
||||
local neg = scored[i + 1]
|
||||
best_rewards[ind] = max(pos, neg)
|
||||
for i = 1, #scored / 2 do
|
||||
local pos = scored[i * 2 - 1]
|
||||
local neg = scored[i * 2 - 0]
|
||||
best_rewards[i] = max(pos, neg)
|
||||
end
|
||||
else
|
||||
best_rewards = scored
|
||||
|
@ -143,18 +143,16 @@ function Ars:tell(scored, unperturbed_score)
|
|||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||
|
||||
local top_rewards = {}
|
||||
for i = 1, #scored do top_rewards[i] = 0 end
|
||||
if self.antithetic then
|
||||
for _, ind in ipairs(indices) do
|
||||
local sind = (ind - 1) * 2 + 1
|
||||
top_rewards[sind + 0] = scored[sind + 0]
|
||||
top_rewards[sind + 1] = scored[sind + 1]
|
||||
insert(top_rewards, scored[ind * 2 - 1])
|
||||
insert(top_rewards, scored[ind * 2 - 0])
|
||||
end
|
||||
else
|
||||
-- ARS is built around antithetic sampling,
|
||||
-- but we can still do something without.
|
||||
-- this is getting to be very similar to SNES however.
|
||||
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
|
||||
for _, ind in ipairs(indices) do insert(top_rewards, scored[ind]) end
|
||||
-- note: although this normalizes the scale, it's later
|
||||
-- re-normalized differently by reward_dev anyway.
|
||||
top_rewards = normalize_sums(top_rewards)
|
||||
|
@ -165,14 +163,12 @@ function Ars:tell(scored, unperturbed_score)
|
|||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
if self.antithetic then
|
||||
for i = 1, floor(self.popsize / 2) do
|
||||
local ind = (i - 1) * 2 + 1
|
||||
local pos = top_rewards[ind + 0]
|
||||
local neg = top_rewards[ind + 1]
|
||||
for i, ind in ipairs(indices) do
|
||||
local pos = top_rewards[i * 2 - 1]
|
||||
local neg = top_rewards[i * 2 - 0]
|
||||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[i]
|
||||
|
||||
local noisy = self.noise[ind * 2 - 1]
|
||||
if use_lips then
|
||||
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
||||
reward = reward / lips / self.sigma
|
||||
|
@ -186,10 +182,10 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
end
|
||||
else
|
||||
for i = 1, self.popsize do
|
||||
for i, ind in ipairs(indices) do
|
||||
local reward = top_rewards[i] / reward_dev
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[i]
|
||||
local noisy = self.noise[ind]
|
||||
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + reward * v / self.poptop
|
||||
|
@ -215,5 +211,6 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
|
||||
return {
|
||||
collect_best_indices = collect_best_indices,
|
||||
Ars = Ars,
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue