smbot/util.lua

274 lines
6.5 KiB
Lua
Raw Normal View History

2018-05-12 13:38:51 -07:00
-- TODO: reorganize function order.
2018-06-11 20:37:55 -07:00
local abs = math.abs
2018-05-12 13:38:51 -07:00
local assert = assert
2018-06-10 23:11:23 -07:00
local exp = math.exp
2018-05-12 13:38:51 -07:00
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
2018-06-10 07:35:28 -07:00
local open = io.open
2018-05-12 13:38:51 -07:00
local pairs = pairs
2018-06-10 23:11:23 -07:00
local pi = math.pi
2018-05-12 13:38:51 -07:00
local random = math.random
local select = select
2018-06-07 17:46:00 -07:00
local sort = table.sort
2018-06-07 17:45:07 -07:00
local sqrt = math.sqrt
2018-05-12 13:38:51 -07:00
-- Folds a byte-sized value: anything below 128 is returned as-is,
-- anything from 128 upwards is replaced by 256 - x.
-- NOTE(review): a two's-complement reinterpretation would be x - 256
-- (a negative result); this yields the magnitude instead. confirm that
-- callers really want the unsigned magnitude.
local function signbyte(x)
    if x < 128 then
        return x
    end
    return 256 - x
end
-- Exclusive-or on Lua truthiness: true when exactly one of a and b is
-- truthy (i.e. neither nil nor false), false otherwise.
local function boolean_xor(a, b)
    -- `not` collapses each operand to a real boolean before comparing.
    return (not a) ~= (not b)
end
-- 1 / ln(2), computed once so log2 only needs a multiplication.
local _invlog2 = 1 / log(2)

-- Base-2 logarithm of x.
local function log2(x)
    local natural = log(x)
    return natural * _invlog2
end
-- Restricts x to the interval [l, u]: first raised to at least l,
-- then lowered to at most u.
local function clamp(x, l, u)
    local raised = max(x, l)
    return min(raised, u)
end
-- Linear interpolation from a (t = 0) to b (t = 1).
-- t is clamped to [0, 1] first, so the result never leaves the a..b span.
local function lerp(a, b, t)
    local span = b - a
    local frac = clamp(t, 0, 1)
    return a + span * frac
end
-- Returns the 1-based position of the largest argument.
-- Ties resolve to the earliest position. Returns 0 when called with no
-- arguments (or when no argument compares greater than -infinity, e.g.
-- all arguments are NaN).
local function argmax(...)
    local count = select("#", ...)
    local max_i = 0
    -- start from -infinity rather than an arbitrary large negative number,
    -- so that inputs of any magnitude are handled.
    local max_v = -math.huge
    for i = 1, count do
        local v = select(i, ...)
        -- the second clause lets a literal -infinity win when nothing has
        -- been selected yet; NaN still never compares as a maximum.
        if v > max_v or (max_i == 0 and v == max_v) then
            max_i = i
            max_v = v
        end
    end
    return max_i
end
-- Samples an index, treating each argument as the probability mass of
-- its position: one uniform draw from random() is compared against the
-- running total of the arguments, and the first position whose running
-- total exceeds the draw is returned.
-- If the total never exceeds the draw (floating-point rounding, or masses
-- that sum to less than 1), the last position is returned instead of nil.
-- Returns nothing when called with no arguments.
-- NOTE(review): assumes the arguments are non-negative and meant to sum
-- to 1 -- confirm against callers.
local function softchoice(...)
    local t = random()
    local count = select("#", ...)
    local psum = 0
    for i = 1, count do
        psum = psum + select(i, ...)
        if t < psum then
            return i
        end
    end
    -- fell off the end: fall back to the final choice.
    if count > 0 then
        return count
    end
end
-- Removes every key from t in place and returns the same (now empty) table.
local function empty(t)
    -- assigning nil to existing keys during traversal is permitted.
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
-- Computes the mean and the population standard deviation (dividing by
-- #x, not #x - 1) of the sequence x.
-- Returns two values: mean, deviation. An empty sequence yields 0, 0.
local function calc_mean_dev(x)
    local n = #x

    -- each term is pre-divided by n while accumulating.
    local mean = 0
    for _, value in ipairs(x) do
        mean = mean + value / n
    end

    local variance = 0
    for _, value in ipairs(x) do
        local offset = value - mean
        variance = variance + offset * offset / n
    end

    return mean, sqrt(variance)
end
-- Standardizes x: subtracts its mean and divides by its standard
-- deviation (as reported by calc_mean_dev).
-- Results are written to out when given, otherwise x is overwritten in
-- place. Returns the table that was written to.
local function normalize(x, out)
    local dest = out or x
    local mean, dev = calc_mean_dev(x)
    -- a non-positive deviation would divide by zero; fall back to 1.
    if dev <= 0 then
        dev = 1
    end
    for i, value in ipairs(x) do
        dest[i] = (value - mean) / dev
    end
    return dest
end
-- Standardizes x with respect to another sample s: the mean and standard
-- deviation come from s (via calc_mean_dev) and are applied to x.
-- Results are written to out when given, otherwise x is overwritten in
-- place. Returns the table that was written to.
local function normalize_wrt(x, s, out)
    local dest = out or x
    local mean, dev = calc_mean_dev(s)
    -- a non-positive deviation would divide by zero; fall back to 1.
    if dev <= 0 then
        dev = 1
    end
    for i, value in ipairs(x) do
        dest[i] = (value - mean) / dev
    end
    return dest
end
2018-06-11 20:37:55 -07:00
-- Centers x on zero (subtracting its mean) and then rescales so that the
-- absolute values of the result sum to 1.
-- Results are written to out when given, otherwise x is overwritten in
-- place. Returns the table that was written to.
-- When every centered value is zero (e.g. constant input) the rescale is
-- skipped, so the result is all zeros rather than NaN.
local function normalize_sums(x, out)
    out = out or x

    local sum = 0
    for _, v in ipairs(x) do sum = sum + v end
    -- only read inside the loop below, so an empty x never uses 0 / 0.
    local mean = sum / #x
    for i, v in ipairs(x) do out[i] = v - mean end

    local abssum = 0
    for _, v in ipairs(out) do abssum = abssum + abs(v) end
    -- same guard normalize() applies to a zero deviation.
    if abssum <= 0 then abssum = 1 end
    for i, v in ipairs(out) do out[i] = v / abssum end

    return out
end
2018-05-12 13:38:51 -07:00
-- Rank-based fitness shaping: replaces each reward with a utility that
-- depends only on its rank among all rewards.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
--
-- For a reward of rank r (1 = largest) out of n rewards, the raw utility
-- is max(0, log(n / 2 + 1) - log(r)); the returned value is that utility
-- divided by the sum of all utilities, plus 1 / n.
-- Equal rewards share the best rank among them.
-- Returns a new sequence parallel to rewards; rewards is not modified.
local function fitness_shaping(rewards)
    local lamb = #rewards

    -- sorted copy, largest first, so that position == rank.
    local decreasing = {}
    for i = 1, lamb do decreasing[i] = rewards[i] end
    sort(decreasing, function(a, b) return a > b end)

    -- reward value -> rank; keep the first (best) rank for duplicates.
    local rank = {}
    for i = 1, lamb do
        local v = decreasing[i]
        if rank[v] == nil then rank[v] = i end
    end

    -- the logarithm base cancels in utility / denom, so the natural log
    -- gives the same shaped values as base 2 would.
    local l = log(lamb / 2 + 1)
    local utilities = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local u = max(0, l - log(rank[v]))
        utilities[i] = u
        denom = denom + u
    end

    local shaped_returns = {}
    for i, u in ipairs(utilities) do
        shaped_returns[i] = u / denom + 1 / lamb
    end
    return shaped_returns
end
-- Returns the place unperturbed_reward would take among rewards:
-- 1 plus the number of rewards strictly greater than it.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local function unperturbed_rank(rewards, unperturbed_reward)
    local better = 0
    for _, reward in ipairs(rewards) do
        if reward > unperturbed_reward then
            better = better + 1
        end
    end
    return better + 1
end
-- Shallow copy: every key/value pair of t is assigned into out (or into
-- a fresh table when out is not given). Nested tables are shared, not
-- duplicated. Raises if t is not a table. Returns the destination table.
local function copy(t, out)
    assert(type(t) == "table")
    local dest = out
    if not dest then
        dest = {}
    end
    for key, value in pairs(t) do
        dest[key] = value
    end
    return dest
end
-- Searches t for a value equal to a and returns the key it was found
-- under, or nil when no value matches. Raises if t is not a table.
-- Traversal uses pairs, so with duplicate values which key is returned
-- is not specified.
local function indexof(t, a)
    assert(type(t) == "table")
    for key, value in pairs(t) do
        if value == a then
            return key
        end
    end
    return nil
end
-- True when some value in t equals a (as determined by indexof),
-- false otherwise.
local function contains(t, a)
    local key = indexof(t, a)
    return key ~= nil
end
-- Two-way comparison returned as a boolean: true when the first element
-- of t is strictly greater than the second.
local function argmax2(t)
    local first, second = t[1], t[2]
    return first > second
end
-- Random boolean: true when the first element of t is strictly greater
-- than a fresh uniform draw from random().
local function rchoice2(t)
    local threshold = t[1]
    return threshold > random()
end
-- Random boolean: true when a uniform draw from random() is at most 0.5.
local function rbool()
    local roll = random()
    return 0.5 >= roll
end
2018-06-07 17:46:00 -07:00
-- Returns the indices of t ordered so that t[out[1]], t[out[2]], ...
-- is sorted according to comp (ascending `<` when comp is omitted).
-- t itself is left untouched.
-- out, when given, is reused as the result table: it is truncated to #t
-- entries and overwritten. Returns the result table.
local function argsort(t, comp, out)
    comp = comp or function(a, b) return a < b end
    out = out or {}
    local n = #t
    -- drop stale trailing entries from a reused buffer; otherwise sort
    -- would see indices that do not exist in t.
    for i = #out, n + 1, -1 do out[i] = nil end
    for i = 1, n do out[i] = i end
    sort(out, function(a, b) return comp(t[a], t[b]) end)
    return out
end
2018-06-10 07:35:28 -07:00
-- Reports whether the file named fn can be opened for reading.
-- The handle is closed again immediately; only a boolean is returned.
local function exists(fn)
    local handle = open(fn, "r")
    if not handle then
        return false
    end
    handle:close()
    return true
end
2018-06-10 23:11:23 -07:00
-- Probability density function of a normal distribution, evaluated at x.
-- mean defaults to 0 and std (the standard deviation) defaults to 1.
local function pdf(x, mean, std)
    local mu = mean or 0
    local sigma = std or 1

    -- standard normal: use the precomputed constant 1 / sqrt(2 * pi).
    if mu == 0 and sigma == 1 then
        return 0.39894228040143 * exp(x * x * -0.5)
    end

    local var = sigma * sigma
    local offset = x - mu
    local scale = 1 / sqrt(2 * pi * var)
    return scale * exp(offset * offset / (-2 * var))
end
-- A very rough closed-form approximation of the cumulative distribution
-- function of the standard normal distribution.
-- Per the original author's note, the absolute error peaks at
-- plus-or-minus 1.654; the origin of the formula is not recorded.
local function cdf(x)
    -- distance of the result from 0.5, before scaling; always >= 0.
    local magnitude = sqrt(1 - exp(-2 / pi * x * x))
    if x >= 0 then
        return 0.5 * (1 + magnitude)
    end
    return 0.5 * (1 - magnitude)
end
-- Weighted variant of the Mann-Whitney U test between samples s0 and s1.
-- w0 / w1 are optional per-element weights and must match their sample
-- in length; a missing weight table is filled with 1.0, which (per the
-- original author's note) should reduce this to the regular Mann-Whitney.
--
-- U accumulates w0[i] * w1[j] for every pair with s0[i] > s1[j], and half
-- of that for ties. U is then standardized and passed through cdf().
-- Returns p when sum(s0) <= sum(s1), and 1 - p otherwise.
local function weighted_mann_whitney(s0, s1, w0, w1)
    if w0 == nil then
        w0 = {}
        for i = 1, #s0 do w0[i] = 1.0 end
    end
    if w1 == nil then
        w1 = {}
        for j = 1, #s1 do w1[j] = 1.0 end
    end
    assert(#s0 == #w0)
    assert(#s1 == #w1)

    -- plain running total over a sequence.
    local function total(list)
        local acc = 0
        for _, value in ipairs(list) do acc = acc + value end
        return acc
    end
    local s0_sum = total(s0)
    local s1_sum = total(s1)
    local w0_sum = total(w0)
    local w1_sum = total(w1)

    local U = 0
    for i = 1, #s0 do
        local a, wa = s0[i], w0[i]
        for j = 1, #s1 do
            local b = s1[j]
            if a > b then
                U = U + wa * w1[j]
            elseif a == b then
                -- ties count for half.
                U = U + wa * w1[j] * 0.5
            end
        end
    end

    -- mean and spread used to standardize U before the normal approximation.
    local mean = w0_sum * w1_sum * 0.5
    local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
    local p = cdf((U - mean) / std)

    if s0_sum > s1_sum then
        return 1 - p
    end
    return p
end
2018-05-12 13:38:51 -07:00
-- Public interface of the module: every helper above, exported under
-- its own name.
return {
    -- scalar helpers
    signbyte = signbyte,
    boolean_xor = boolean_xor,
    log2 = log2,
    clamp = clamp,
    lerp = lerp,

    -- selection
    argmax = argmax,
    softchoice = softchoice,

    -- tables and statistics
    empty = empty,
    calc_mean_dev = calc_mean_dev,
    normalize = normalize,
    normalize_wrt = normalize_wrt,
    normalize_sums = normalize_sums,
    fitness_shaping = fitness_shaping,
    unperturbed_rank = unperturbed_rank,
    copy = copy,
    indexof = indexof,
    contains = contains,

    -- two-way choices
    argmax2 = argmax2,
    rchoice2 = rchoice2,
    rbool = rbool,

    argsort = argsort,
    exists = exists,

    -- normal distribution and testing
    pdf = pdf,
    cdf = cdf,
    weighted_mann_whitney = weighted_mann_whitney,
}