fix and rewrite ARS telling (it was super broken!)
This commit is contained in:
parent
102eefe98c
commit
f52fabc549
1 changed files with 15 additions and 18 deletions
33
ars.lua
33
ars.lua
|
@ -8,6 +8,7 @@
|
||||||
local abs = math.abs
|
local abs = math.abs
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local max = math.max
|
local max = math.max
|
||||||
local print = print
|
local print = print
|
||||||
|
@ -38,11 +39,10 @@ local function collect_best_indices(scored, top, antithetic)
|
||||||
local best_rewards
|
local best_rewards
|
||||||
if antithetic then
|
if antithetic then
|
||||||
best_rewards = {}
|
best_rewards = {}
|
||||||
for i = 1, #scored, 2 do
|
for i = 1, #scored / 2 do
|
||||||
local ind = floor(i / 2) + 1
|
local pos = scored[i * 2 - 1]
|
||||||
local pos = scored[i + 0]
|
local neg = scored[i * 2 - 0]
|
||||||
local neg = scored[i + 1]
|
best_rewards[i] = max(pos, neg)
|
||||||
best_rewards[ind] = max(pos, neg)
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
best_rewards = scored
|
best_rewards = scored
|
||||||
|
@ -143,18 +143,16 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||||
|
|
||||||
local top_rewards = {}
|
local top_rewards = {}
|
||||||
for i = 1, #scored do top_rewards[i] = 0 end
|
|
||||||
if self.antithetic then
|
if self.antithetic then
|
||||||
for _, ind in ipairs(indices) do
|
for _, ind in ipairs(indices) do
|
||||||
local sind = (ind - 1) * 2 + 1
|
insert(top_rewards, scored[ind * 2 - 1])
|
||||||
top_rewards[sind + 0] = scored[sind + 0]
|
insert(top_rewards, scored[ind * 2 - 0])
|
||||||
top_rewards[sind + 1] = scored[sind + 1]
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
-- ARS is built around antithetic sampling,
|
-- ARS is built around antithetic sampling,
|
||||||
-- but we can still do something without.
|
-- but we can still do something without.
|
||||||
-- this is getting to be very similar to SNES however.
|
-- this is getting to be very similar to SNES however.
|
||||||
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
|
for _, ind in ipairs(indices) do insert(top_rewards, scored[ind]) end
|
||||||
-- note: although this normalizes the scale, it's later
|
-- note: although this normalizes the scale, it's later
|
||||||
-- re-normalized differently by reward_dev anyway.
|
-- re-normalized differently by reward_dev anyway.
|
||||||
top_rewards = normalize_sums(top_rewards)
|
top_rewards = normalize_sums(top_rewards)
|
||||||
|
@ -165,14 +163,12 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
if reward_dev == 0 then reward_dev = 1 end
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
if self.antithetic then
|
if self.antithetic then
|
||||||
for i = 1, floor(self.popsize / 2) do
|
for i, ind in ipairs(indices) do
|
||||||
local ind = (i - 1) * 2 + 1
|
local pos = top_rewards[i * 2 - 1]
|
||||||
local pos = top_rewards[ind + 0]
|
local neg = top_rewards[i * 2 - 0]
|
||||||
local neg = top_rewards[ind + 1]
|
|
||||||
local reward = pos - neg
|
local reward = pos - neg
|
||||||
if reward ~= 0 then
|
if reward ~= 0 then
|
||||||
local noisy = self.noise[i]
|
local noisy = self.noise[ind * 2 - 1]
|
||||||
|
|
||||||
if use_lips then
|
if use_lips then
|
||||||
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
||||||
reward = reward / lips / self.sigma
|
reward = reward / lips / self.sigma
|
||||||
|
@ -186,10 +182,10 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
for i = 1, self.popsize do
|
for i, ind in ipairs(indices) do
|
||||||
local reward = top_rewards[i] / reward_dev
|
local reward = top_rewards[i] / reward_dev
|
||||||
if reward ~= 0 then
|
if reward ~= 0 then
|
||||||
local noisy = self.noise[i]
|
local noisy = self.noise[ind]
|
||||||
|
|
||||||
for j, v in ipairs(noisy) do
|
for j, v in ipairs(noisy) do
|
||||||
step[j] = step[j] + reward * v / self.poptop
|
step[j] = step[j] + reward * v / self.poptop
|
||||||
|
@ -215,5 +211,6 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
collect_best_indices = collect_best_indices,
|
||||||
Ars = Ars,
|
Ars = Ars,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue