-- Lua utility module: small numeric, random-choice, statistics, and table helpers.
-- TODO: reorganize function order.
|
|
|
|
local assert = assert
|
|
local ipairs = ipairs
|
|
local log = math.log
|
|
local max = math.max
|
|
local min = math.min
|
|
local pairs = pairs
|
|
local random = math.random
|
|
local select = select
|
|
local sqrt= math.sqrt
|
|
|
|
--- Reinterpret an unsigned byte value as a two's-complement signed byte.
-- 0..127 are returned unchanged; 128..255 map to -128..-1.
-- @param x number, expected to be an integer in 0..255 (not checked)
-- @return the signed value in -128..127
local function signbyte(x)
  -- Two's complement subtracts 256. The previous `256 - x` produced the
  -- magnitude instead (255 -> 1, 128 -> 128), which loses the sign.
  if x >= 128 then x = x - 256 end
  return x
end
|
|
|
|
--- Exclusive-or over Lua truthiness.
-- Any value other than nil/false counts as true.
-- @return true when exactly one of a, b is truthy, false otherwise
local function boolean_xor(a, b)
  -- `not` maps both operands onto real booleans, so ~= is a pure XOR.
  return (not a) ~= (not b)
end
|
|
|
|
local _invlog2 = 1 / log(2)
|
|
local function log2(x) return log(x) * _invlog2 end
|
|
|
|
--- Restrict x to the range [l, u].
-- The lower bound is applied first and the upper bound second, so when
-- l > u the result is u.
-- @return x limited to [l, u]
local function clamp(x, l, u)
  local at_least_l = max(x, l)
  return min(at_least_l, u)
end
|
|
|
|
--- Linear interpolation from a to b.
-- t is clamped to [0, 1] first, so the result never leaves the segment:
-- t <= 0 yields a, and t >= 1 yields a + (b - a).
local function lerp(a, b, t)
  local weight = clamp(t, 0, 1)
  local span = b - a
  return a + span * weight
end
|
|
|
|
--- Index of the largest of the given numbers.
-- Ties resolve to the earliest index. NaN arguments are never selected.
-- @param ... numbers
-- @return 1-based index of the maximum, or 0 when nothing can be selected
--   (no arguments, or every argument is NaN)
local function argmax(...)
  local max_i = 0
  -- Start from -inf rather than a finite sentinel. The old -999999999 made
  -- argmax return 0 whenever every value was at or below that number.
  local max_v = -math.huge
  for i = 1, select("#", ...) do
    local v = select(i, ...)
    -- The second clause lets the first candidate win even when it is -inf,
    -- which the strict `>` alone would reject. After that, strict `>` keeps
    -- the earliest index on ties.
    if v > max_v or (max_i == 0 and v == max_v) then
      max_i = i
      max_v = v
    end
  end
  return max_i
end
|
|
|
|
--- Draw a random index, choosing index i with probability proportional to
-- the i-th argument.
-- @param ... weights, assumed non-negative (not checked); they need not
--   sum to 1
-- @return 1-based index of the chosen weight, or nil when there are no
--   arguments or every weight is 0
local function softchoice(...)
  local n = select("#", ...)

  local total = 0
  for i = 1, n do total = total + select(i, ...) end

  -- Scale the draw by the actual total instead of assuming it is exactly 1.
  -- Probabilities that sum to slightly less than 1 in floating point would
  -- otherwise let the loop below finish without choosing anything.
  -- The running sum below is accumulated in the same order as `total`, so
  -- it ends exactly at `total`. random() < 1 keeps the scaled draw strictly
  -- below that.
  local t = random() * total

  local psum = 0
  for i = 1, n do
    psum = psum + select(i, ...)
    if t < psum then
      return i
    end
  end
  return nil
end
|
|
|
|
--- Remove every key from t in place.
-- Clearing existing keys during a pairs() traversal is permitted in Lua.
-- @param t table to clear
-- @return the same (now empty) table t
local function empty(t)
  for key in pairs(t) do
    t[key] = nil
  end
  return t
end
|
|
|
|
--- Mean and population standard deviation of a sequence.
-- Each term is divided by the length as it is accumulated; the deviation
-- divides by n rather than n - 1.
-- @param x sequence of numbers
-- @return mean, standard deviation (both 0 for an empty sequence)
local function calc_mean_dev(x)
  local n = #x

  local mean = 0
  for _, value in ipairs(x) do
    mean = mean + value / n
  end

  local variance = 0
  for _, value in ipairs(x) do
    local d = value - mean
    variance = variance + d * d / n
  end

  return mean, sqrt(variance)
end
|
|
|
|
--- Standardize a sequence using its own mean and standard deviation.
-- A deviation of 0 (all values equal, or x empty) is replaced by 1, so the
-- result is only mean-centered instead of dividing by zero.
-- @param x sequence of numbers
-- @param out optional destination sequence; defaults to x (in place)
-- @return out (or x), where entry i is (x[i] - mean) / deviation
local function normalize(x, out)
  local dest = out or x
  local mean, scale = calc_mean_dev(x)
  if scale <= 0 then
    scale = 1
  end
  for i, value in ipairs(x) do
    dest[i] = (value - mean) / scale
  end
  return dest
end
|
|
|
|
--- Standardize x using the mean and standard deviation of a reference
-- sequence s.
-- A reference deviation of 0 is replaced by 1 to avoid dividing by zero.
-- @param x sequence of numbers to transform
-- @param s reference sequence supplying the statistics
-- @param out optional destination sequence; defaults to x (in place)
-- @return out (or x), where entry i is (x[i] - mean(s)) / deviation(s)
local function normalize_wrt(x, s, out)
  local dest = out or x
  local center, spread = calc_mean_dev(s)
  if spread <= 0 then
    spread = 1
  end
  for i, value in ipairs(x) do
    dest[i] = (value - center) / spread
  end
  return dest
end
|
|
|
|
--- Rank-based fitness shaping. Lifted from:
-- https://github.com/atgambardella/pytorch-es/blob/master/train.py
--
-- Each reward is replaced by a utility that depends only on its rank
-- (1 = best) within the batch of n rewards:
--   numer_i = max(0, log2(n/2 + 1) - log2(rank_i))
--   shaped_i = numer_i / sum_j(numer_j) + 1/n
-- Only ranks below n/2 + 1 get a positive first term. Tied rewards share
-- the rank of their first (best) position in the sorted order.
-- @param rewards sequence of numbers (read only; not modified)
-- @return a new sequence of shaped values, index-aligned with rewards
local function fitness_shaping(rewards)
  local lamb = #rewards

  -- Sorted copy, best first. It is built here, using only stdlib names,
  -- instead of the previous nn.copy / sort, which are not defined in this
  -- module.
  local decreasing = {}
  for i, v in ipairs(rewards) do decreasing[i] = v end
  table.sort(decreasing, function(a, b) return a > b end)

  -- rank[v] = first 1-based position of v in the sorted copy. This is the
  -- same value a linear index search returns, but it is built once instead
  -- of once per reward (previously O(n^2)).
  local rank = {}
  for i, v in ipairs(decreasing) do
    if rank[v] == nil then rank[v] = i end
  end

  local l = log2(lamb / 2 + 1) -- loop-invariant

  -- Numerators are summed in reward order, matching the reference.
  local numers = {}
  local denom = 0
  for i, v in ipairs(rewards) do
    local numer = max(0, l - log2(rank[v]))
    numers[i] = numer
    denom = denom + numer
  end

  local shaped_returns = {}
  for i, numer in ipairs(numers) do
    shaped_returns[i] = numer / denom + 1 / lamb
  end
  return shaped_returns
end
|
|
|
|
--- 1-based rank of an unperturbed reward among a batch of rewards.
-- Lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- Only rewards strictly greater than unperturbed_reward push the rank down,
-- so ties do not count against it.
-- @param rewards sequence of numbers
-- @param unperturbed_reward number to rank
-- @return 1 + the count of rewards strictly greater than unperturbed_reward
local function unperturbed_rank(rewards, unperturbed_reward)
  local better = 0
  for _, r in ipairs(rewards) do
    if r > unperturbed_reward then
      better = better + 1
    end
  end
  return better + 1
end
|
|
|
|
--- Shallow copy of a table's key/value pairs.
-- Nested tables are shared, not duplicated.
-- @param t source table (asserted to be a table)
-- @param out optional destination table. When given, t's entries are
--   written into it, existing keys absent from t are left alone, and out
--   itself is returned.
-- @return out, or a new table when out is nil/false
local function copy(t, out)
  assert(type(t) == "table", "copy: t must be a table")
  -- Reassign the parameter rather than redeclaring `local out`, which
  -- shadowed the argument (luacheck W412).
  out = out or {}
  for k, v in pairs(t) do out[k] = v end
  return out
end
|
|
|
|
--- Find a key of t whose value equals a (using ==).
-- The sequence part (keys 1, 2, ... up to the first nil) is searched first,
-- in order, so for arrays the lowest matching index is returned. pairs()
-- alone gives no such guarantee because its order is unspecified. Any
-- remaining keys are then searched with pairs().
-- @param t table to search (asserted to be a table)
-- @param a value to look for
-- @return the matching key, or nil when no value equals a
local function indexof(t, a)
  assert(type(t) == "table", "indexof: t must be a table")
  for i, v in ipairs(t) do
    if v == a then return i end
  end
  for k, v in pairs(t) do
    if v == a then return k end
  end
  return nil
end
|
|
|
|
--- Whether any value in t equals a (using ==).
-- @param t table to search
-- @param a value to look for
-- @return true if indexof(t, a) finds a key, false otherwise
local function contains(t, a)
  local key = indexof(t, a)
  return key ~= nil
end
|
|
|
|
--- Two-way argmax expressed as a boolean.
-- @param t table whose first two entries are compared
-- @return true when t[1] is strictly greater than t[2], false otherwise
local function argmax2(t)
  local first, second = t[1], t[2]
  return second < first
end
|
|
|
|
--- Random two-way choice.
-- @param t table whose first entry is used as the probability of true
-- @return true with probability t[1] (for t[1] within [0, 1]), else false
local function rchoice2(t)
  local p = t[1]
  return random() < p
end
|
|
|
|
--- Random boolean: true when a uniform draw from random() is at most 0.5.
-- @return true or false
local function rbool()
  local r = random()
  return r <= 0.5
end
|
|
|
|
-- Public interface of this module.
return {
  -- scalar helpers
  signbyte = signbyte,
  boolean_xor = boolean_xor,
  log2 = log2,
  clamp = clamp,
  lerp = lerp,

  -- picking an index from varargs
  argmax = argmax,
  softchoice = softchoice,

  empty = empty,

  -- statistics
  calc_mean_dev = calc_mean_dev,
  normalize = normalize,
  normalize_wrt = normalize_wrt,

  -- reward ranking
  fitness_shaping = fitness_shaping,
  unperturbed_rank = unperturbed_rank,

  -- table helpers
  copy = copy,
  indexof = indexof,
  contains = contains,

  -- two-way choices
  argmax2 = argmax2,
  rchoice2 = rchoice2,
  rbool = rbool,
}
|