-- Small numeric and table helpers, exported as a plain table of functions.

local assert = assert
local huge = math.huge
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local pairs = pairs
local random = math.random
local select = select
local sort = table.sort
local sqrt = math.sqrt
local type = type

-- Folds a byte value: inputs of 128 and above become 256 - x, anything
-- smaller is returned unchanged.
-- NOTE(review): a two's-complement reinterpretation would be x - 256 (a
-- negative result); this returns the positive distance from 256 instead.
-- Behavior kept as-is -- confirm against callers before changing.
local function signbyte(x)
    if x >= 128 then x = 256 - x end
    return x
end

-- Logical exclusive-or on truthiness: true when exactly one of a, b is truthy.
local function boolean_xor(a, b)
    if a and b then return false end
    if not a and not b then return false end
    return true
end

local _invlog2 = 1 / log(2)

-- Base-2 logarithm of x.
local function log2(x)
    return log(x) * _invlog2
end

-- Restricts x to the range [l, u].
local function clamp(x, l, u)
    return min(max(x, l), u)
end

-- Linear interpolation from a (t = 0) to b (t = 1); t is clamped to [0, 1].
local function lerp(a, b, t)
    return a + (b - a) * clamp(t, 0, 1)
end

-- Returns the 1-based position of the largest argument; the first one wins
-- on ties. Returns 0 when called with no arguments.
local function argmax(...)
    local max_i = 0
    -- start from -infinity so that any finite value can be selected.
    local max_v = -huge
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        if v > max_v then
            max_i = i
            max_v = v
        end
    end
    return max_i
end

-- Draws one uniform random number and returns the first index at which the
-- running sum of the arguments exceeds it. Returns nil when the running sum
-- never exceeds the draw (including when called with no arguments).
local function softchoice(...)
    local t = random()
    local psum = 0
    for i = 1, select("#", ...) do
        local p = select(i, ...)
        psum = psum + p
        if t < psum then
            return i
        end
    end
end

-- Removes every key from t in place and returns t.
local function empty(t)
    for k, _ in pairs(t) do t[k] = nil end
    return t
end

-- Returns a shallow copy of t. When out is given, the entries are written
-- into out (existing keys in out that t lacks are left alone) and out is
-- returned instead of a fresh table.
local function copy(t, out)
    assert(type(t) == "table")
    out = out or {}
    for k, v in pairs(t) do out[k] = v end
    return out
end

-- Returns a key of t whose value equals a, or nil when there is none.
-- Traversal uses pairs, so which key is returned for duplicated values
-- is not specified.
local function indexof(t, a)
    assert(type(t) == "table")
    for k, v in pairs(t) do
        if v == a then return k end
    end
    return nil
end

-- True when some value in t equals a.
local function contains(t, a)
    return indexof(t, a) ~= nil
end

-- Returns the mean and the population standard deviation (divides by #x)
-- of the sequence x.
local function calc_mean_dev(x)
    local mean = 0
    for i, v in ipairs(x) do
        mean = mean + v / #x
    end

    local dev = 0
    for i, v in ipairs(x) do
        local delta = v - mean
        dev = dev + delta * delta / #x
    end

    return mean, sqrt(dev)
end

-- Shifts and scales x by its own mean and deviation. Writes into out, which
-- defaults to x itself (in-place), and returns it. A deviation of zero or
-- less is replaced by 1 to avoid dividing by zero.
local function normalize(x, out)
    out = out or x
    local mean, dev = calc_mean_dev(x)
    if dev <= 0 then dev = 1 end
    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
    return out
end

-- Like normalize, but the mean and deviation are taken from s rather
-- than from x.
local function normalize_wrt(x, s, out)
    out = out or x
    local mean, dev = calc_mean_dev(s)
    if dev <= 0 then dev = 1 end
    for i, v in ipairs(x) do out[i] = (v - mean) / dev end
    return out
end

-- Rank-based fitness shaping. Returns a new sequence with one shaped value
-- per entry of rewards; rewards itself is not modified.
local function fitness_shaping(rewards)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local decreasing = copy(rewards)
    sort(decreasing, function(a, b) return a > b end)

    -- map each value to the first position it holds in the sorted copy.
    -- walking backwards lets the earliest index win for tied values.
    local rank = {}
    for i = #decreasing, 1, -1 do
        rank[decreasing[i]] = i
    end

    local lamb = #rewards
    local l = log2(lamb / 2 + 1) -- loop-invariant.

    local numers = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local numer = max(0, l - log2(rank[v]))
        numers[i] = numer
        denom = denom + numer
    end

    local shaped_returns = {}
    for i, numer in ipairs(numers) do
        shaped_returns[i] = numer / denom + 1 / lamb
    end
    return shaped_returns
end

-- Returns the placing of unperturbed_reward among rewards: one plus the
-- number of entries strictly greater than it.
local function unperturbed_rank(rewards, unperturbed_reward)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local nth_place = 1
    for i, v in ipairs(rewards) do
        if v > unperturbed_reward then
            nth_place = nth_place + 1
        end
    end
    return nth_place
end

-- True when the first element of t is greater than the second.
local function argmax2(t)
    return t[1] > t[2]
end

-- True when t[1] is greater than a fresh uniform random draw.
local function rchoice2(t)
    return t[1] > random()
end

-- Random boolean: true when a uniform random draw is at most 0.5.
local function rbool()
    return 0.5 >= random()
end

return {
    signbyte=signbyte,
    boolean_xor=boolean_xor,
    log2=log2,
    clamp=clamp,
    lerp=lerp,
    argmax=argmax,
    softchoice=softchoice,
    empty=empty,
    calc_mean_dev=calc_mean_dev,
    normalize=normalize,
    normalize_wrt=normalize_wrt,
    fitness_shaping=fitness_shaping,
    unperturbed_rank=unperturbed_rank,
    copy=copy,
    indexof=indexof,
    contains=contains,
    argmax2=argmax2,
    rchoice2=rchoice2,
    rbool=rbool,
}