use locals; fix fitness_shaping and graycode
This commit is contained in:
parent
a1ec797de0
commit
63583789c3
3 changed files with 42 additions and 39 deletions
15
ars.lua
15
ars.lua
|
@ -3,6 +3,7 @@
|
||||||
-- with some tweaks (lips) by myself.
|
-- with some tweaks (lips) by myself.
|
||||||
|
|
||||||
local abs = math.abs
|
local abs = math.abs
|
||||||
|
local exp = math.exp
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local max = math.max
|
local max = math.max
|
||||||
|
@ -13,11 +14,13 @@ local Base = require "Base"
|
||||||
local nn = require "nn"
|
local nn = require "nn"
|
||||||
local normal = nn.normal
|
local normal = nn.normal
|
||||||
local prod = nn.prod
|
local prod = nn.prod
|
||||||
|
local uniform = nn.uniform
|
||||||
local zeros = nn.zeros
|
local zeros = nn.zeros
|
||||||
|
|
||||||
local util = require "util"
|
local util = require "util"
|
||||||
local argsort = util.argsort
|
local argsort = util.argsort
|
||||||
local calc_mean_dev = util.calc_mean_dev
|
local calc_mean_dev = util.calc_mean_dev
|
||||||
|
local normalize_sums = util.normalize_sums
|
||||||
|
|
||||||
local Ars = Base:extend()
|
local Ars = Base:extend()
|
||||||
|
|
||||||
|
@ -62,7 +65,7 @@ function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
|
||||||
assert(self.poptop <= popsize)
|
assert(self.poptop <= popsize)
|
||||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||||
|
|
||||||
self._params = nn.zeros(self.dims)
|
self._params = zeros(self.dims)
|
||||||
|
|
||||||
self.evals = 0
|
self.evals = 0
|
||||||
end
|
end
|
||||||
|
@ -93,14 +96,14 @@ function Ars:ask(graycode)
|
||||||
else
|
else
|
||||||
if graycode ~= nil then
|
if graycode ~= nil then
|
||||||
for j = 1, self.dims do
|
for j = 1, self.dims do
|
||||||
noisy[j] = exp(-precision * nn.uniform())
|
noisy[j] = exp(-precision * uniform())
|
||||||
end
|
end
|
||||||
for j = 1, self.dims do
|
for j = 1, self.dims do
|
||||||
noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j]
|
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
for j = 1, self.dims do
|
for j = 1, self.dims do
|
||||||
noisy[j] = self.sigma * nn.normal()
|
noisy[j] = self.sigma * normal()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -133,10 +136,10 @@ function Ars:tell(scored, unperturbed_score)
|
||||||
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
|
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
|
||||||
-- note: although this normalizes the scale, it's later
|
-- note: although this normalizes the scale, it's later
|
||||||
-- re-normalized differently by reward_dev anyway.
|
-- re-normalized differently by reward_dev anyway.
|
||||||
top_rewards = util.normalize_sums(top_rewards)
|
top_rewards = normalize_sums(top_rewards)
|
||||||
end
|
end
|
||||||
|
|
||||||
local step = nn.zeros(self.dims)
|
local step = zeros(self.dims)
|
||||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||||
if reward_dev == 0 then reward_dev = 1 end
|
if reward_dev == 0 then reward_dev = 1 end
|
||||||
|
|
||||||
|
|
18
main.lua
18
main.lua
|
@ -89,16 +89,16 @@ local emu = emu
|
||||||
local gui = gui
|
local gui = gui
|
||||||
|
|
||||||
local util = require("util")
|
local util = require("util")
|
||||||
local argmax = util.argmax
|
local argmax = util.argmax
|
||||||
local argsort = util.argsort
|
local argsort = util.argsort
|
||||||
local calc_mean_dev = util.calc_mean_dev
|
local calc_mean_dev = util.calc_mean_dev
|
||||||
local clamp = util.clamp
|
local clamp = util.clamp
|
||||||
local copy = util.copy
|
local copy = util.copy
|
||||||
local empty = util.empty
|
local empty = util.empty
|
||||||
local lerp = util.lerp
|
local lerp = util.lerp
|
||||||
local softchoice = util.softchoice
|
local softchoice = util.softchoice
|
||||||
local unperturbed_rank = util.unperturbed_rank
|
local unperturbed_rank = util.unperturbed_rank
|
||||||
local exists = util.exists
|
local exists = util.exists
|
||||||
|
|
||||||
local game = require("smb")
|
local game = require("smb")
|
||||||
game.overlay = cfg.enable_overlay
|
game.overlay = cfg.enable_overlay
|
||||||
|
|
48
util.lua
48
util.lua
|
@ -105,30 +105,6 @@ local function normalize_sums(x, out)
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
local function fitness_shaping(rewards)
|
|
||||||
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
|
||||||
local decreasing = nn.copy(rewards)
|
|
||||||
sort(decreasing, function(a, b) return a > b end)
|
|
||||||
local shaped_returns = {}
|
|
||||||
local lamb = #rewards
|
|
||||||
|
|
||||||
local denom = 0
|
|
||||||
for i, v in ipairs(rewards) do
|
|
||||||
local l = log2(lamb / 2 + 1)
|
|
||||||
local r = log2(nn.indexof(decreasing, v))
|
|
||||||
denom = denom + max(0, l - r)
|
|
||||||
end
|
|
||||||
|
|
||||||
for i, v in ipairs(rewards) do
|
|
||||||
local l = log2(lamb / 2 + 1)
|
|
||||||
local r = log2(nn.indexof(decreasing, v))
|
|
||||||
local numer = max(0, l - r)
|
|
||||||
insert(shaped_returns, numer / denom + 1 / lamb)
|
|
||||||
end
|
|
||||||
|
|
||||||
return shaped_returns
|
|
||||||
end
|
|
||||||
|
|
||||||
local function unperturbed_rank(rewards, unperturbed_reward)
|
local function unperturbed_rank(rewards, unperturbed_reward)
|
||||||
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
||||||
local nth_place = 1
|
local nth_place = 1
|
||||||
|
@ -207,6 +183,30 @@ local function cdf(x)
|
||||||
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
|
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function fitness_shaping(rewards)
|
||||||
|
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
||||||
|
local decreasing = copy(rewards)
|
||||||
|
sort(decreasing, function(a, b) return a > b end)
|
||||||
|
local shaped_returns = {}
|
||||||
|
local lamb = #rewards
|
||||||
|
|
||||||
|
local denom = 0
|
||||||
|
for i, v in ipairs(rewards) do
|
||||||
|
local l = log2(lamb / 2 + 1)
|
||||||
|
local r = log2(indexof(decreasing, v))
|
||||||
|
denom = denom + max(0, l - r)
|
||||||
|
end
|
||||||
|
|
||||||
|
for i, v in ipairs(rewards) do
|
||||||
|
local l = log2(lamb / 2 + 1)
|
||||||
|
local r = log2(indexof(decreasing, v))
|
||||||
|
local numer = max(0, l - r)
|
||||||
|
insert(shaped_returns, numer / denom + 1 / lamb)
|
||||||
|
end
|
||||||
|
|
||||||
|
return shaped_returns
|
||||||
|
end
|
||||||
|
|
||||||
local function weighted_mann_whitney(s0, s1, w0, w1)
|
local function weighted_mann_whitney(s0, s1, w0, w1)
|
||||||
-- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney.
|
-- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney.
|
||||||
if w0 == nil then
|
if w0 == nil then
|
||||||
|
|
Loading…
Add table
Reference in a new issue