use locals; fix fitness_shaping and graycode

This commit is contained in:
Connor Olding 2018-06-13 22:51:12 +02:00
parent a1ec797de0
commit 63583789c3
3 changed files with 42 additions and 39 deletions

15
ars.lua
View File

@ -3,6 +3,7 @@
-- with some tweaks (lips) by myself.
local abs = math.abs
local exp = math.exp
local floor = math.floor
local ipairs = ipairs
local max = math.max
@ -13,11 +14,13 @@ local Base = require "Base"
local nn = require "nn"
local normal = nn.normal
local prod = nn.prod
local uniform = nn.uniform
local zeros = nn.zeros
local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local normalize_sums = util.normalize_sums
local Ars = Base:extend()
@ -62,7 +65,7 @@ function Ars:init(dims, popsize, poptop, learning_rate, sigma, antithetic)
assert(self.poptop <= popsize)
if self.antithetic then self.popsize = self.popsize * 2 end
self._params = nn.zeros(self.dims)
self._params = zeros(self.dims)
self.evals = 0
end
@ -93,14 +96,14 @@ function Ars:ask(graycode)
else
if graycode ~= nil then
for j = 1, self.dims do
noisy[j] = exp(-precision * nn.uniform())
noisy[j] = exp(-precision * uniform())
end
for j = 1, self.dims do
noisy[j] = nn.uniform() < 0.5 and noisy[j] or -noisy[j]
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
end
else
for j = 1, self.dims do
noisy[j] = self.sigma * nn.normal()
noisy[j] = self.sigma * normal()
end
end
end
@ -133,10 +136,10 @@ function Ars:tell(scored, unperturbed_score)
for _, i in ipairs(indices) do top_rewards[i] = scored[i] end
-- note: although this normalizes the scale, it's later
-- re-normalized differently by reward_dev anyway.
top_rewards = util.normalize_sums(top_rewards)
top_rewards = normalize_sums(top_rewards)
end
local step = nn.zeros(self.dims)
local step = zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
if reward_dev == 0 then reward_dev = 1 end

View File

@ -89,16 +89,16 @@ local emu = emu
local gui = gui
local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
local lerp = util.lerp
local softchoice = util.softchoice
local unperturbed_rank = util.unperturbed_rank
local exists = util.exists
local exists = util.exists
local game = require("smb")
game.overlay = cfg.enable_overlay

View File

@ -105,30 +105,6 @@ local function normalize_sums(x, out)
return out
end
local function fitness_shaping(rewards)
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local decreasing = nn.copy(rewards)
sort(decreasing, function(a, b) return a > b end)
local shaped_returns = {}
local lamb = #rewards
local denom = 0
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(nn.indexof(decreasing, v))
denom = denom + max(0, l - r)
end
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(nn.indexof(decreasing, v))
local numer = max(0, l - r)
insert(shaped_returns, numer / denom + 1 / lamb)
end
return shaped_returns
end
local function unperturbed_rank(rewards, unperturbed_reward)
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local nth_place = 1
@ -207,6 +183,30 @@ local function cdf(x)
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
end
local function fitness_shaping(rewards)
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local decreasing = copy(rewards)
sort(decreasing, function(a, b) return a > b end)
local shaped_returns = {}
local lamb = #rewards
local denom = 0
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(indexof(decreasing, v))
denom = denom + max(0, l - r)
end
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(indexof(decreasing, v))
local numer = max(0, l - r)
insert(shaped_returns, numer / denom + 1 / lamb)
end
return shaped_returns
end
local function weighted_mann_whitney(s0, s1, w0, w1)
-- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney.
if w0 == nil then