smbot/util.lua

187 lines
4.1 KiB
Lua
Raw Normal View History

2018-05-12 13:38:51 -07:00
-- TODO: reorganize function order.
local assert = assert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local pairs = pairs
local random = math.random
local select = select
2018-06-07 17:46:00 -07:00
local sort = table.sort
2018-06-07 17:45:07 -07:00
local sqrt = math.sqrt
2018-05-12 13:38:51 -07:00
--- Interpret an unsigned byte (0..255) as a two's-complement signed value.
-- 0..127 are returned unchanged; 128..255 map to -128..-1.
-- (Computing 256 - x instead would return the magnitude and lose the sign,
-- e.g. 255 -> 1 rather than -1.)
-- @tparam number x unsigned byte value
-- @treturn number signed value
local function signbyte(x)
    if x >= 128 then x = x - 256 end
    return x
end
--- Exclusive-or on Lua truthiness (only nil and false are falsy).
-- Always returns a real boolean: true when exactly one argument is truthy.
local function boolean_xor(a, b)
    return (not a) ~= (not b)
end
-- Reciprocal of ln(2), computed once so log2 is one log plus one multiply.
local _invlog2 = 1 / log(2)
--- Base-2 logarithm: ln(x) scaled by 1 / ln(2).
local function log2(x)
    local ln_x = log(x)
    return ln_x * _invlog2
end
--- Limit x to [l, u]. The lower bound is applied first, then the upper
-- bound, so u wins if l > u.
local function clamp(x, l, u)
    local at_least_l = max(x, l)
    return min(at_least_l, u)
end
--- Linear interpolation from a (at t = 0) to b (at t = 1).
-- t is clamped to [0, 1] first, so the result does not extrapolate.
local function lerp(a, b, t)
    local weight = clamp(t, 0, 1)
    return a + (b - a) * weight
end
--- Index of the largest of the vararg values.
-- Ties resolve to the earliest index (strict > comparison). NaN values are
-- never selected. Returns 0 when called with no arguments.
-- Starts from -math.huge rather than a finite sentinel, so arbitrarily large
-- negative values are still selectable.
-- @param ... numbers to compare
-- @treturn number 1-based index of the maximum, or 0 if no candidate
local function argmax(...)
    local best_i = 0
    local best_v = -math.huge
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        -- the `best_i == 0` clause lets an all -inf argument list still pick 1
        if v > best_v or (best_i == 0 and v == best_v) then
            best_i = i
            best_v = v
        end
    end
    return best_i
end
--- Sample an index from a discrete distribution passed as varargs.
-- Draws t uniformly from [0, 1) and returns the first i whose cumulative
-- probability exceeds t. If the cumulative sum never exceeds t (e.g.
-- probabilities that sum to slightly under 1 because of rounding), the last
-- index with a positive probability is returned instead of nil.
-- NOTE(review): assumes non-negative probabilities summing to ~1 — confirm
-- with callers; for clearly unnormalized input the tail mass now goes to the
-- last positive entry.
-- @param ... probabilities
-- @return chosen 1-based index, or nil if no probability is positive
local function softchoice(...)
    local t = random()
    local psum = 0
    local last_positive = nil
    for i = 1, select("#", ...) do
        local p = select(i, ...)
        psum = psum + p
        if t < psum then
            return i
        end
        if p > 0 then last_positive = i end
    end
    return last_positive
end
--- Remove every key from t in place and return the same table.
-- (Assigning nil to existing fields during a pairs traversal is allowed.)
local function empty(t)
    for key in pairs(t) do
        t[key] = nil
    end
    return t
end
--- Mean and population standard deviation (divides by #x) of sequence x.
-- Each term is divided by #x as it is accumulated. An empty sequence yields
-- 0, 0.
-- @tparam {number,...} x
-- @treturn number mean
-- @treturn number standard deviation
local function calc_mean_dev(x)
    local n = #x
    local mean = 0
    for _, v in ipairs(x) do
        mean = mean + v / n
    end
    local variance = 0
    for _, v in ipairs(x) do
        local d = v - mean
        variance = variance + d * d / n
    end
    return mean, sqrt(variance)
end
--- Standardize x using its own mean and standard deviation (calc_mean_dev).
-- A non-positive deviation is replaced by 1 to avoid dividing by zero.
-- @tparam {number,...} x input sequence
-- @tparam[opt] table out destination; defaults to x (in-place)
-- @treturn table the destination table
local function normalize(x, out)
    local dst = out or x
    local mean, dev = calc_mean_dev(x)
    local scale = dev
    if scale <= 0 then scale = 1 end
    for i, v in ipairs(x) do
        dst[i] = (v - mean) / scale
    end
    return dst
end
--- Standardize x using the mean and standard deviation of a reference
-- sequence s (via calc_mean_dev) rather than of x itself.
-- A non-positive deviation is replaced by 1 to avoid dividing by zero.
-- @tparam {number,...} x values to transform
-- @tparam {number,...} s reference sample for the statistics
-- @tparam[opt] table out destination; defaults to x (in-place)
-- @treturn table the destination table
local function normalize_wrt(x, s, out)
    local dst = out or x
    local mean, dev = calc_mean_dev(s)
    local scale = dev
    if scale <= 0 then scale = 1 end
    for i, v in ipairs(x) do
        dst[i] = (v - mean) / scale
    end
    return dst
end
--- Rank-based fitness shaping (utility transform) for evolution strategies.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- Each reward gets utility max(0, log(n/2 + 1) - log(rank)), where rank is
-- its 1-based position in descending order (equal rewards share the first
-- such position). Utilities are normalized to sum to 1, then 1/n is added.
-- @tparam {number,...} rewards sequence of rewards (not modified)
-- @treturn {number,...} shaped values, in the same order as rewards
local function fitness_shaping(rewards)
    local n = #rewards

    -- Descending copy, built here rather than through copy/indexof: those
    -- are defined later in this file, and `nn` is never required here.
    local decreasing = {}
    for i, v in ipairs(rewards) do decreasing[i] = v end
    sort(decreasing, function(a, b) return a > b end)

    -- value -> first 1-based position in descending order (O(n) lookups
    -- instead of a linear search per reward)
    local rank_of = {}
    for i, v in ipairs(decreasing) do
        if rank_of[v] == nil then rank_of[v] = i end
    end

    -- The log base cancels in numer / denom, so natural log gives the same
    -- shaped values as log2 (up to float rounding).
    local l = log(n / 2 + 1)
    local utility = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local u = max(0, l - log(rank_of[v]))
        utility[i] = u
        denom = denom + u
    end

    local shaped_returns = {}
    for i, u in ipairs(utility) do
        shaped_returns[i] = u / denom + 1 / n
    end
    return shaped_returns
end
--- 1-based place of the unperturbed reward among the perturbed rewards:
-- 1 plus the number of rewards strictly greater than it.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
local function unperturbed_rank(rewards, unperturbed_reward)
    local beaten_by = 0
    for _, r in ipairs(rewards) do
        if r > unperturbed_reward then
            beaten_by = beaten_by + 1
        end
    end
    return beaten_by + 1
end
--- Shallow copy: every key/value of t is assigned into out.
-- Nested tables are shared, not duplicated.
-- @tparam table t source (asserted to be a table)
-- @tparam[opt] table out destination; a new table when omitted
-- @treturn table out
local function copy(t, out) -- shallow copy
    assert(type(t) == "table")
    -- assign the parameter instead of redeclaring a shadowing `local out`
    out = out or {}
    for k, v in pairs(t) do out[k] = v end
    return out
end
--- Key under which value a is stored in t, or nil if absent.
-- The sequence part is scanned first with ipairs, so for arrays the first
-- (lowest) matching index is returned; pairs order alone is unspecified.
-- Remaining (hash) keys are then searched with pairs.
-- @tparam table t table to search (asserted to be a table)
-- @param a value to find (compared with ==)
-- @return matching key, or nil
local function indexof(t, a)
    assert(type(t) == "table")
    for i, v in ipairs(t) do
        if v == a then return i end
    end
    for k, v in pairs(t) do
        if v == a then return k end
    end
    return nil
end
--- True when value a is stored under some key of t (see indexof).
local function contains(t, a)
    if indexof(t, a) == nil then
        return false
    end
    return true
end
--- Two-way argmax as a boolean: true when t[1] is strictly greater than t[2].
local function argmax2(t)
    local first, second = t[1], t[2]
    return second < first
end
--- Random two-way choice: true with probability t[1] (compared against a
-- uniform draw from math.random()).
local function rchoice2(t)
    local p = t[1]
    return random() < p
end
--- Random boolean: true when a uniform draw is at most 0.5.
local function rbool()
    local draw = random()
    return draw <= 0.5
end
2018-06-07 17:46:00 -07:00
--- Permutation of indices 1..#t that orders t according to comp.
-- Equal elements (neither comp(a, b) nor comp(b, a)) keep ascending index
-- order, so the result is deterministic even though table.sort is unstable.
-- @tparam table t sequence of values
-- @tparam[opt] function comp strict "less than"; defaults to `<`
-- @tparam[opt] table out table to fill; entries beyond #t are cleared
-- @treturn table out
local function argsort(t, comp, out)
    comp = comp or function(a, b) return a < b end
    out = out or {}
    local n = #t
    for i = 1, n do out[i] = i end
    -- Drop leftovers from an earlier, longer use of `out`; otherwise they
    -- would be sorted too and index t past its end.
    for i = #out, n + 1, -1 do out[i] = nil end
    sort(out, function(a, b)
        local va, vb = t[a], t[b]
        if comp(va, vb) then return true end
        if comp(vb, va) then return false end
        return a < b
    end)
    return out
end
2018-05-12 13:38:51 -07:00
-- Public interface of this module.
return {
    -- numeric helpers
    signbyte = signbyte,
    boolean_xor = boolean_xor,
    log2 = log2,
    clamp = clamp,
    lerp = lerp,
    -- selection and random choice
    argmax = argmax,
    softchoice = softchoice,
    argmax2 = argmax2,
    rchoice2 = rchoice2,
    rbool = rbool,
    argsort = argsort,
    -- statistics and reward processing
    calc_mean_dev = calc_mean_dev,
    normalize = normalize,
    normalize_wrt = normalize_wrt,
    fitness_shaping = fitness_shaping,
    unperturbed_rank = unperturbed_rank,
    -- table helpers
    empty = empty,
    copy = copy,
    indexof = indexof,
    contains = contains,
}