2018-05-12 13:38:51 -07:00
|
|
|
-- TODO: reorganize function order.
|
|
|
|
|
|
|
|
local assert = assert
|
2018-06-10 23:11:23 -07:00
|
|
|
local exp = math.exp
|
2018-05-12 13:38:51 -07:00
|
|
|
local ipairs = ipairs
|
|
|
|
local log = math.log
|
|
|
|
local max = math.max
|
|
|
|
local min = math.min
|
2018-06-10 07:35:28 -07:00
|
|
|
local open = io.open
|
2018-05-12 13:38:51 -07:00
|
|
|
local pairs = pairs
|
2018-06-10 23:11:23 -07:00
|
|
|
local pi = math.pi
|
2018-05-12 13:38:51 -07:00
|
|
|
local random = math.random
|
|
|
|
local select = select
|
2018-06-07 17:46:00 -07:00
|
|
|
local sort = table.sort
|
2018-06-07 17:45:07 -07:00
|
|
|
local sqrt = math.sqrt
|
2018-05-12 13:38:51 -07:00
|
|
|
|
|
|
|
-- folds a byte-sized value into the lower half of its range.
-- x: a number; values below 128 pass through untouched.
-- returns: x itself when x < 128, otherwise 256 - x.
-- NOTE(review): 256 - x is the magnitude of the two's-complement value
-- (255 -> 1), not the negative value itself (255 -> -1).
-- confirm that callers really want the magnitude.
local function signbyte(x)
    if x >= 128 then
        return 256 - x
    end
    return x
end
|
|
|
|
|
|
|
|
-- exclusive-or on the truthiness of two values.
-- a, b: any values; only nil and false count as false.
-- returns: true when exactly one of a and b is truthy, false otherwise.
local function boolean_xor(a, b)
    -- `not` coerces both sides to real booleans, so comparing them
    -- for inequality is exactly "one is truthy and the other is not".
    return (not a) ~= (not b)
end
|
|
|
|
|
|
|
|
-- reciprocal of ln(2), computed once so that log2 multiplies instead of divides.
local _invlog2 = 1 / log(2)

-- base-2 logarithm.
-- x: a number.
-- returns: the natural logarithm of x scaled by 1 / ln(2).
local function log2(x)
    local natural = log(x)
    return natural * _invlog2
end
|
|
|
|
|
|
|
|
-- restricts x to the range [l, u].
-- x: the value to restrict.
-- l: the lower bound, applied first.
-- u: the upper bound, applied second (so u wins if l > u).
-- returns: min(max(x, l), u).
local function clamp(x, l, u)
    local raised = max(x, l)
    return min(raised, u)
end
|
|
|
|
|
|
|
|
-- linear interpolation between a and b.
-- a: the value returned at t <= 0.
-- b: the value returned at t >= 1.
-- t: the blend factor; it is clamped to [0, 1] before use.
-- returns: a + (b - a) * clamp(t, 0, 1).
local function lerp(a, b, t)
    local weight = clamp(t, 0, 1)
    local span = b - a
    return a + span * weight
end
|
|
|
|
|
|
|
|
-- returns the 1-based position of the largest argument.
-- ...: the numbers to compare.
-- returns: the index of the maximum, or 0 when called with no arguments
-- (or when every argument is NaN).
-- ties resolve to the earliest position, because only a strictly greater
-- value replaces the current best.
local function argmax(...)
    local max_i = 0
    local max_v = nil
    for i = 1, select("#", ...) do
        local v = (select(i, ...))
        -- seed from the first comparable value rather than a magic sentinel,
        -- so inputs that are all at or below -999999999 (or -math.huge)
        -- still produce a valid index instead of 0.
        -- v == v is false only for NaN, which never wins a comparison.
        if v == v and (max_v == nil or v > max_v) then
            max_i = i
            max_v = v
        end
    end
    return max_i
end
|
|
|
|
|
|
|
|
-- picks an index at random, weighted by the given probabilities.
-- ...: one probability per choice; they are accumulated left to right and
-- compared against a single uniform draw from random().
-- returns: the first index whose cumulative probability exceeds the draw.
-- when the cumulative sum never exceeds the draw (rounding error, or
-- probabilities that sum to slightly less than 1), the last index with a
-- positive probability is returned instead of silently returning nothing.
-- returns nil only when there is no choice with positive probability.
local function softchoice(...)
    local count = select("#", ...)
    local t = random()
    local psum = 0
    local fallback = nil
    for i = 1, count do
        local p = (select(i, ...))
        psum = psum + p
        if t < psum then
            return i
        end
        -- remember the last choice that could legitimately be selected.
        if p > 0 then
            fallback = i
        end
    end
    return fallback
end
|
|
|
|
|
|
|
|
-- removes every entry from a table in place.
-- t: the table to clear.
-- returns: the same table, now without any keys.
local function empty(t)
    -- assigning nil to existing keys during a pairs traversal is permitted.
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
|
|
|
|
|
|
|
|
-- computes the mean and the standard deviation of a sequence of numbers.
-- x: an array of numbers; it is walked with ipairs and every term is
-- divided by #x as it is accumulated.
-- returns: mean, deviation. the deviation is the square root of the
-- average squared distance from the mean (dividing by #x, not #x - 1).
-- an empty sequence yields 0, 0.
local function calc_mean_dev(x)
    -- first pass: accumulate the mean.
    local mean = 0
    for _, value in ipairs(x) do
        mean = mean + value / #x
    end

    -- second pass: accumulate the average squared offset from the mean.
    local variance = 0
    for _, value in ipairs(x) do
        local offset = value - mean
        variance = variance + offset * offset / #x
    end

    return mean, sqrt(variance)
end
|
|
|
|
|
|
|
|
-- shifts and scales a sequence using its own mean and deviation.
-- x: an array of numbers.
-- out: optional table that receives the results; defaults to x itself,
-- in which case x is overwritten in place.
-- returns: out, with out[i] = (x[i] - mean) / dev for every element of x.
local function normalize(x, out)
    if not out then
        out = x
    end

    local mean, dev = calc_mean_dev(x)
    -- a non-positive deviation would make the division meaningless;
    -- fall back to only subtracting the mean.
    if dev <= 0 then
        dev = 1
    end

    for i, value in ipairs(x) do
        out[i] = (value - mean) / dev
    end
    return out
end
|
|
|
|
|
|
|
|
-- shifts and scales a sequence using the mean and deviation of another one.
-- x: an array of numbers to transform.
-- s: an array of numbers whose mean and deviation are used.
-- out: optional table that receives the results; defaults to x itself,
-- in which case x is overwritten in place.
-- returns: out, with out[i] = (x[i] - mean(s)) / dev(s) for every element of x.
local function normalize_wrt(x, s, out)
    if not out then
        out = x
    end

    local mean, dev = calc_mean_dev(s)
    -- a non-positive deviation would make the division meaningless;
    -- fall back to only subtracting the mean.
    if dev <= 0 then
        dev = 1
    end

    for i, value in ipairs(x) do
        out[i] = (value - mean) / dev
    end
    return out
end
|
|
|
|
|
|
|
|
-- converts raw rewards into rank-based shaped returns.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- rewards: an array of numbers. it is not modified.
-- returns: a new array of the same length where entry i is
--   max(0, log2(n / 2 + 1) - log2(rank_i)) / denom + 1 / n
-- with n = #rewards, rank_i the 1-based position of rewards[i] in the
-- rewards sorted in decreasing order (equal rewards share the best rank),
-- and denom the sum of the max(...) terms over all rewards.
-- an empty input yields an empty table.
local function fitness_shaping(rewards)
    local lamb = #rewards

    -- sorted copy, built locally: this file has no `nn` table and the
    -- copy/indexof helpers are declared further down, out of scope here.
    local decreasing = {}
    for i, v in ipairs(rewards) do
        decreasing[i] = v
    end
    sort(decreasing, function(a, b) return a > b end)

    -- reward value -> best (smallest) rank. walking backwards makes the
    -- earliest position win for duplicates, matching a first-match search,
    -- and replaces a linear scan per element with a table lookup.
    local rank = {}
    for i = #decreasing, 1, -1 do
        rank[decreasing[i]] = i
    end

    -- this term only depends on the population size, so compute it once.
    local l = log2(lamb / 2 + 1)

    local numers = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local numer = max(0, l - log2(rank[v]))
        numers[i] = numer
        denom = denom + numer
    end

    local shaped_returns = {}
    for i, numer in ipairs(numers) do
        shaped_returns[i] = numer / denom + 1 / lamb
    end

    return shaped_returns
end
|
|
|
|
|
|
|
|
-- ranks a reference reward against a list of rewards.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- rewards: an array of numbers.
-- unperturbed_reward: the number to rank.
-- returns: 1 plus the count of rewards strictly greater than
-- unperturbed_reward, so 1 means nothing in rewards beat it.
local function unperturbed_rank(rewards, unperturbed_reward)
    local beaten_by = 0
    for _, reward in ipairs(rewards) do
        if reward > unperturbed_reward then
            beaten_by = beaten_by + 1
        end
    end
    return beaten_by + 1
end
|
|
|
|
|
|
|
|
-- shallow copy: keys and values are copied as-is, so nested tables are shared.
-- t: the table to copy; anything else fails the assertion.
-- out: optional destination table; existing keys in it are kept unless
-- t overwrites them. a fresh table is created when omitted.
-- returns: the destination table.
local function copy(t, out)
    assert(type(t) == "table")
    if not out then
        out = {}
    end
    for key, value in pairs(t) do
        out[key] = value
    end
    return out
end
|
|
|
|
|
|
|
|
-- finds a key whose value equals a.
-- t: the table to search; anything else fails the assertion.
-- a: the value to look for, compared with ==.
-- returns: a key k with t[k] == a, or nil when no value matches.
-- the table is walked with pairs, so which key is reported when several
-- values match depends on traversal order.
local function indexof(t, a)
    assert(type(t) == "table")
    for key, value in pairs(t) do
        if value == a then
            return key
        end
    end
    return nil
end
|
|
|
|
|
|
|
|
-- tells whether any value in a table equals a.
-- t: the table to search.
-- a: the value to look for.
-- returns: true when indexof finds a matching key, false otherwise.
local function contains(t, a)
    local key = indexof(t, a)
    return key ~= nil
end
|
|
|
|
|
|
|
|
-- two-way argmax expressed as a boolean.
-- t: a table whose first two entries are compared.
-- returns: true when t[1] is strictly greater than t[2], false otherwise.
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
|
|
|
|
|
|
|
|
-- random two-way choice expressed as a boolean.
-- t: a table whose first entry is compared against a uniform draw from random().
-- returns: true when t[1] is strictly greater than the draw, false otherwise.
local function rchoice2(t)
    local threshold = t[1]
    local roll = random()
    return threshold > roll
end
|
|
|
|
|
|
|
|
-- random boolean.
-- returns: true when a uniform draw from random() is at most 0.5, false otherwise.
local function rbool()
    local roll = random()
    return roll <= 0.5
end
|
|
|
|
|
2018-06-07 17:46:00 -07:00
|
|
|
-- computes the order of indices that would sort a sequence.
-- t: the array to rank; it is not modified.
-- comp: optional comparison function applied to values of t;
-- defaults to ascending order (a < b).
-- out: optional table that receives the indices; a fresh table is used
-- when omitted. any previous contents are discarded.
-- returns: out, holding the indices 1..#t arranged so that
-- t[out[1]], t[out[2]], ... is ordered according to comp.
local function argsort(t, comp, out)
    comp = comp or function(a, b) return a < b end
    out = out or {}

    local count = #t
    -- a reused buffer may be longer than t. leftover entries would be sorted
    -- together with the fresh indices and looked up in t as stale positions,
    -- so drop them first.
    for i = #out, count + 1, -1 do
        out[i] = nil
    end
    for i = 1, count do
        out[i] = i
    end

    sort(out, function(a, b) return comp(t[a], t[b]) end)
    return out
end
|
|
|
|
|
2018-06-10 07:35:28 -07:00
|
|
|
-- checks whether a file can be opened for reading.
-- fn: the path to try.
-- returns: true when open(fn, "r") yields a handle (which is closed again
-- right away), false otherwise.
local function exists(fn)
    local handle = open(fn, "r")
    if not handle then
        return false
    end
    handle:close()
    return true
end
|
|
|
|
|
2018-06-10 23:11:23 -07:00
|
|
|
-- probability density function for a normal distribution.
-- x: the point at which to evaluate the density.
-- mean: optional mean of the distribution; defaults to 0.
-- std: optional standard deviation of the distribution; defaults to 1.
-- returns: the density at x.
local function pdf(x, mean, std)
    mean = mean or 0
    std = std or 1

    -- standard normal: use the precomputed 1 / sqrt(2 * pi) constant.
    local is_standard = mean == 0 and std == 1
    if is_standard then
        return 0.39894228040143 * exp(x * x * -0.5)
    end

    local var = std * std
    local offset = x - mean
    local scale = 1 / sqrt(2 * pi * var)
    return scale * exp(offset * offset / (-2 * var))
end
|
|
|
|
|
|
|
|
-- a very rough approximation of the
-- cumulative distribution function for a normal distribution.
-- absolute error peaks at plus-or-minus 1.654.
-- i don't remember where this is from.
-- x: the point at which to evaluate the approximation.
-- returns: 0.5 * (1 + sign(x) * sqrt(1 - exp(-2 / pi * x^2))),
-- where sign(x) is 1 for x >= 0 and -1 otherwise.
local function cdf(x)
    local sign = -1
    if x >= 0 then
        sign = 1
    end
    local spread = sqrt(1 - exp(-2 / pi * x * x))
    return 0.5 * (1 + sign * spread)
end
|
|
|
|
|
2018-05-12 13:38:51 -07:00
|
|
|
-- module interface: each helper is exported under its own name.
return {
    signbyte         = signbyte,
    boolean_xor      = boolean_xor,
    log2             = log2,
    clamp            = clamp,
    lerp             = lerp,
    argmax           = argmax,
    softchoice       = softchoice,
    empty            = empty,
    calc_mean_dev    = calc_mean_dev,
    normalize        = normalize,
    normalize_wrt    = normalize_wrt,
    fitness_shaping  = fitness_shaping,
    unperturbed_rank = unperturbed_rank,
    copy             = copy,
    indexof          = indexof,
    contains         = contains,
    argmax2          = argmax2,
    rchoice2         = rchoice2,
    rbool            = rbool,
    argsort          = argsort,
    exists           = exists,
    pdf              = pdf,
    cdf              = cdf,
}
|