2018-05-12 13:38:51 -07:00
|
|
|
-- TODO: reorganize function order.
|
|
|
|
|
2018-06-11 20:37:55 -07:00
|
|
|
local abs = math.abs
|
2018-05-12 13:38:51 -07:00
|
|
|
local assert = assert
|
2018-06-10 23:11:23 -07:00
|
|
|
local exp = math.exp
|
2018-05-12 13:38:51 -07:00
|
|
|
local ipairs = ipairs
|
|
|
|
local log = math.log
|
|
|
|
local max = math.max
|
|
|
|
local min = math.min
|
2018-06-10 07:35:28 -07:00
|
|
|
local open = io.open
|
2018-05-12 13:38:51 -07:00
|
|
|
local pairs = pairs
|
2018-06-10 23:11:23 -07:00
|
|
|
local pi = math.pi
|
2018-05-12 13:38:51 -07:00
|
|
|
local random = math.random
|
|
|
|
local select = select
|
2018-06-07 17:46:00 -07:00
|
|
|
local sort = table.sort
|
2018-06-07 17:45:07 -07:00
|
|
|
local sqrt = math.sqrt
|
2018-05-12 13:38:51 -07:00
|
|
|
|
2018-06-20 20:14:45 -07:00
|
|
|
-- Returns the sign of x as a number: 0 when x equals 0, 1 when x is
-- greater than 0, and -1 for everything else.
local function sign(x)
    -- written as explicit branches rather than an and/or chain,
    -- since 0 is truthy in Lua and such chains are easy to misread.
    if x == 0 then
        return 0
    end
    if x > 0 then
        return 1
    end
    return -1
end
|
|
|
|
|
2018-05-12 13:38:51 -07:00
|
|
|
-- Folds a byte-like value: inputs below 128 are returned unchanged,
-- inputs of 128 or more are returned as 256 - x.
-- NOTE(review): this maps 255 to 1 (not -1), so the result is never
-- negative. if a two's-complement signed byte is intended this should
-- be x - 256 -- confirm against callers before changing.
local function signbyte(x)
    if x >= 128 then
        return 256 - x
    end
    return x
end
|
|
|
|
|
|
|
|
-- Logical exclusive-or on truthiness: returns true when exactly one of
-- a and b is truthy, false otherwise. Always returns a real boolean.
local function boolean_xor(a, b)
    -- `not` coerces each operand to a boolean, so the inequality
    -- holds exactly when their truthiness differs.
    return (not a) ~= (not b)
end
|
|
|
|
|
|
|
|
-- precomputed reciprocal of ln(2), so log2 is one multiply per call.
local _invlog2 = 1 / log(2)

-- Base-2 logarithm of x, computed as ln(x) * (1 / ln(2)).
local function log2(x)
    local natural = log(x)
    return natural * _invlog2
end
|
|
|
|
|
|
|
|
-- Restricts x to the range [l, u]: first raised to at least l,
-- then lowered to at most u.
local function clamp(x, l, u)
    local raised = max(x, l)
    return min(raised, u)
end
|
|
|
|
|
|
|
|
-- Linear interpolation from a (at t = 0) to b (at t = 1).
-- t is limited to [0, 1] first, so the result never leaves the
-- segment between a and b.
local function lerp(a, b, t)
    local fraction = min(max(t, 0), 1)
    local span = b - a
    return a + span * fraction
end
|
|
|
|
|
|
|
|
-- Returns the 1-based position of the largest argument.
-- Ties go to the earliest position. Returns 0 when called with no
-- arguments, or when every argument is NaN.
local function argmax(...)
    local max_i = 0
    local max_v
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        -- the first comparable value seeds the running maximum, so no
        -- magic sentinel is needed and arbitrarily small inputs
        -- (including -math.huge) are still found.
        -- `v == v` is false only for NaN, which is skipped.
        if v == v and (max_i == 0 or v > max_v) then
            max_i = i
            max_v = v
        end
    end
    return max_i
end
|
|
|
|
|
|
|
|
-- Samples an index from a list of probabilities passed as arguments.
-- Draws t from random() and returns the first index whose running
-- sum of probabilities exceeds t.
-- If the draw lands at or beyond the total (e.g. the probabilities
-- sum to slightly less than 1 due to rounding), the last index with a
-- positive probability is returned instead of falling through.
-- Returns nil only when there is no positive probability at all.
local function softchoice(...)
    local t = random()
    local psum = 0
    local last_positive = nil
    for i = 1, select("#", ...) do
        local p = select(i, ...)
        psum = psum + p
        if p > 0 then
            last_positive = i
        end
        if t < psum then
            return i
        end
    end
    return last_positive
end
|
|
|
|
|
|
|
|
-- Removes every key from t in place and returns the same table,
-- so an existing table can be reused instead of allocating a new one.
local function empty(t)
    -- assigning nil to existing keys during a pairs traversal is allowed.
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
|
|
|
|
|
|
|
|
-- Computes the mean and the standard deviation of the array x.
-- The deviation divides by the element count itself (no n - 1
-- correction). Returns two values: mean, deviation.
-- An empty array yields 0, 0.
local function calc_mean_dev(x)
    local count = #x

    -- each term is divided before being accumulated.
    local mean = 0
    for _, value in ipairs(x) do
        mean = mean + value / count
    end

    local variance = 0
    for _, value in ipairs(x) do
        local delta = value - mean
        variance = variance + delta * delta / count
    end

    return mean, sqrt(variance)
end
|
|
|
|
|
|
|
|
-- Standardizes x using its own mean and deviation:
-- out[i] = (x[i] - mean) / dev.
-- Results are written to out when given, otherwise x is overwritten
-- in place. A deviation of zero or less is replaced by 1 so constant
-- input does not divide by zero. Returns the table written to.
local function normalize(x, out)
    local target = out or x
    local mean, dev = calc_mean_dev(x)
    local scale = (dev <= 0) and 1 or dev
    for i, value in ipairs(x) do
        target[i] = (value - mean) / scale
    end
    return target
end
|
|
|
|
|
|
|
|
-- Standardizes x with respect to another sample s:
-- out[i] = (x[i] - mean(s)) / dev(s).
-- Results are written to out when given, otherwise x is overwritten
-- in place. A deviation of zero or less is replaced by 1.
-- Returns the table written to.
local function normalize_wrt(x, s, out)
    local target = out or x
    local mean, dev = calc_mean_dev(s)
    local scale = (dev <= 0) and 1 or dev
    for i, value in ipairs(x) do
        target[i] = (value - mean) / scale
    end
    return target
end
|
|
|
|
|
2018-06-11 20:37:55 -07:00
|
|
|
-- Centers x on its mean, then scales it so the absolute values of the
-- result sum to 1: out[i] = (x[i] - mean) / sum(|x[j] - mean|).
-- Results are written to out when given, otherwise x is overwritten
-- in place. Returns the table written to.
-- When every element equals the mean (constant or empty input) the
-- absolute sum is zero; the divisor is then treated as 1 so the result
-- is all zeros rather than NaN, mirroring the guard in normalize.
local function normalize_sums(x, out)
    out = out or x

    local n, sum = 0, 0
    for i, v in ipairs(x) do
        n = i
        sum = sum + v
    end

    -- only the first n slots are touched and measured, so leftover
    -- entries in a longer, reused out table cannot skew the scale.
    local abssum = 0
    for i = 1, n do
        local centered = x[i] - sum / n
        out[i] = centered
        abssum = abssum + abs(centered)
    end

    if abssum == 0 then abssum = 1 end
    for i = 1, n do
        out[i] = out[i] / abssum
    end

    return out
end
|
|
|
|
|
2018-05-12 13:38:51 -07:00
|
|
|
-- Returns the placing of unperturbed_reward among rewards:
-- 1 plus the number of entries strictly greater than it.
-- Equal rewards do not push the rank down.
local function unperturbed_rank(rewards, unperturbed_reward)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local better = 0
    for _, reward in ipairs(rewards) do
        if reward > unperturbed_reward then
            better = better + 1
        end
    end
    return better + 1
end
|
|
|
|
|
|
|
|
-- Shallow copy: every key/value pair of t is assigned into out
-- (a fresh table when out is not given). Nested tables are shared,
-- not duplicated. Keys already in out that t lacks are left alone.
-- Raises an assertion error when t is not a table.
local function copy(t, out)
    assert(type(t) == "table")
    out = out or {}
    for key, value in pairs(t) do
        out[key] = value
    end
    return out
end
|
|
|
|
|
|
|
|
-- Returns a key of t whose value equals a, or nil when no value
-- matches. The table is walked with pairs, so when several values
-- match, which key comes back is not specified.
-- Raises an assertion error when t is not a table.
local function indexof(t, a)
    assert(type(t) == "table")
    for key, value in pairs(t) do
        if value == a then
            return key
        end
    end
    return nil
end
|
|
|
|
|
|
|
|
-- Returns true when some value in t equals a, false otherwise.
-- Delegates the search (and the table-type assertion) to indexof.
local function contains(t, a)
    local key = indexof(t, a)
    return key ~= nil
end
|
|
|
|
|
|
|
|
-- Two-way argmax expressed as a boolean: true when the first element
-- of t is strictly greater than the second, false otherwise.
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
|
|
|
|
|
|
|
|
-- Random two-way choice: returns true when t[1] is strictly greater
-- than a fresh draw from random(), so t[1] acts as the chance of true.
local function rchoice2(t)
    local chance = t[1]
    local roll = random()
    return chance > roll
end
|
|
|
|
|
|
|
|
-- Random boolean: true when a fresh draw from random() is at most 0.5.
local function rbool()
    local roll = random()
    return 0.5 >= roll
end
|
|
|
|
|
2018-06-07 17:46:00 -07:00
|
|
|
-- Returns the indices of t ordered so that t[out[1]], t[out[2]], ...
-- is sorted according to comp (ascending `<` when comp is omitted).
-- t itself is not modified. Indices are written into out when given,
-- otherwise into a new table; that table is returned.
local function argsort(t, comp, out)
    if not comp then
        comp = function(a, b) return a < b end
    end
    if not out then
        out = {}
    end

    for index = 1, #t do
        out[index] = index
    end

    -- compare positions by the values they refer to.
    local function by_value(i, j)
        return comp(t[i], t[j])
    end
    sort(out, by_value)

    return out
end
|
|
|
|
|
2018-06-10 07:35:28 -07:00
|
|
|
-- Returns true when the file named fn can be opened for reading,
-- false otherwise. The handle is closed again before returning.
local function exists(fn)
    local handle = open(fn, "r")
    if not handle then
        return false
    end
    handle:close()
    return true
end
|
|
|
|
|
2018-06-10 23:11:23 -07:00
|
|
|
-- Probability density function of a normal distribution at x.
-- mean defaults to 0 and std to 1. The standard-normal case takes a
-- shortcut with a precomputed 1 / sqrt(2 * pi) constant.
local function pdf(x, mean, std)
    if not mean then mean = 0 end
    if not std then std = 1 end

    local is_standard = (mean == 0 and std == 1)
    if is_standard then
        return 0.39894228040143 * exp(x * x * -0.5)
    end

    local var = std * std
    local offset = x - mean
    local scale = 1 / sqrt(2 * pi * var)
    local exponent = offset * offset / (-2 * var)
    return scale * exp(exponent)
end
|
|
|
|
|
|
|
|
-- A very rough approximation of the cumulative distribution function
-- for a standard normal distribution:
--   0.5 * (1 + sgn(x) * sqrt(1 - exp(-2 / pi * x^2)))
-- absolute error peaks at plus-or-minus 1.654.
-- i don't remember where this is from.
local function cdf(x)
    -- zero counts as the positive side; the root is 0 there anyway.
    local direction = -1
    if x >= 0 then
        direction = 1
    end
    local tail = sqrt(1 - exp(-2 / pi * x * x))
    return 0.5 * (1 + direction * tail)
end
|
|
|
|
|
2018-06-13 13:51:12 -07:00
|
|
|
-- Rank-based fitness shaping: replaces each reward with a utility
-- that depends only on its rank among all rewards.
-- For a reward at 1-based rank r in descending order, with n rewards:
--   numer  = max(0, log2(n / 2 + 1) - log2(r))
--   shaped = numer / sum(numer over all rewards) + 1 / n
-- Equal rewards share the rank of their first occurrence.
-- Returns a new array aligned with rewards; rewards is not modified.
local function fitness_shaping(rewards)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local lamb = #rewards

    local decreasing = copy(rewards)
    sort(decreasing, function(a, b) return a > b end)

    -- rank lookup built once from the sorted array. this replaces a
    -- linear search per reward and makes ties deterministic (first
    -- occurrence), rather than depending on pairs traversal order.
    local rank = {}
    for i = 1, lamb do
        local v = decreasing[i]
        if rank[v] == nil then rank[v] = i end
    end

    local l = log2(lamb / 2 + 1)

    -- first pass: store the unnormalized utilities and their total.
    local shaped_returns = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local numer = max(0, l - log2(rank[v]))
        shaped_returns[i] = numer
        denom = denom + numer
    end

    -- second pass: normalize in place. written by index because no
    -- `insert` alias is defined in this file.
    for i, numer in ipairs(shaped_returns) do
        shaped_returns[i] = numer / denom + 1 / lamb
    end

    return shaped_returns
end
|
|
|
|
|
2018-06-12 16:19:32 -07:00
|
|
|
-- Weighted Mann-Whitney U comparison of samples s0 and s1.
-- w0 and w1 are per-element weights; a nil weight array is replaced
-- by all ones (one weight per sample element).
-- when w0 and w1 are nil, this decomposes(?) to the regular Mann-Whitney.
-- U accumulates w0[i] * w1[j] for every pair with s0[i] > s1[j], and
-- half of that for ties. It is then standardized with
--   mean = sum(w0) * sum(w1) / 2
--   std  = sqrt(mean * (sum(w0) + sum(w1) + 1) / 6)
-- and passed through cdf to obtain p. Returns 1 - p when the sum of s0
-- exceeds the sum of s1, and p otherwise.
-- Asserts that each sample has the same length as its weight array.
local function weighted_mann_whitney(s0, s1, w0, w1)
    -- builds an array of n unit weights.
    local function unit_weights(n)
        local w = {}
        for i = 1, n do w[i] = 1.0 end
        return w
    end

    -- sums the array part of t.
    local function total(t)
        local acc = 0
        for _, v in ipairs(t) do acc = acc + v end
        return acc
    end

    if w0 == nil then w0 = unit_weights(#s0) end
    if w1 == nil then w1 = unit_weights(#s1) end
    assert(#s0 == #w0)
    assert(#s1 == #w1)

    local s0_sum = total(s0)
    local s1_sum = total(s1)
    local w0_sum = total(w0)
    local w1_sum = total(w1)

    local U = 0
    for i = 1, #s0 do
        local a, wa = s0[i], w0[i]
        for j = 1, #s1 do
            local b = s1[j]
            if a > b then
                U = U + wa * w1[j]
            elseif a == b then
                -- ties contribute half of the pair's weight.
                U = U + wa * w1[j] * 0.5
            end
        end
    end

    local mean = w0_sum * w1_sum * 0.5
    local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
    local p = cdf((U - mean) / std)

    if s0_sum > s1_sum then
        return 1 - p
    end
    return p
end
|
|
|
|
|
2018-05-12 13:38:51 -07:00
|
|
|
-- public interface of this module.
return {
    -- scalar helpers
    sign = sign,
    signbyte = signbyte,
    boolean_xor = boolean_xor,
    log2 = log2,
    clamp = clamp,
    lerp = lerp,

    -- selection
    argmax = argmax,
    softchoice = softchoice,
    argmax2 = argmax2,
    rchoice2 = rchoice2,
    rbool = rbool,

    -- tables
    empty = empty,
    copy = copy,
    indexof = indexof,
    contains = contains,
    argsort = argsort,

    -- statistics
    calc_mean_dev = calc_mean_dev,
    normalize = normalize,
    normalize_wrt = normalize_wrt,
    normalize_sums = normalize_sums,
    fitness_shaping = fitness_shaping,
    unperturbed_rank = unperturbed_rank,
    pdf = pdf,
    cdf = cdf,
    weighted_mann_whitney = weighted_mann_whitney,

    -- files
    exists = exists,
}
|