-- TODO: reorganize function order.

-- Small collection of numeric, statistical, and table helpers.
-- Everything is a file-local function; the public surface is the table
-- returned at the bottom of the file.

local abs = math.abs
local assert = assert
local exp = math.exp
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pi = math.pi
local random = math.random
local select = select
local sort = table.sort
local sqrt = math.sqrt

-- Returns -1, 0, or 1 according to the sign of x.
local function sign(x)
    -- remember that 0 is truthy in Lua.
    return x == 0 and 0 or x > 0 and 1 or -1
end

-- Maps x >= 128 to 256 - x; smaller values are returned unchanged.
-- NOTE(review): a two's-complement byte conversion would be x - 256;
-- behavior is kept as-is because callers may rely on it -- confirm intent.
local function signbyte(x)
    if x >= 128 then
        x = 256 - x
    end
    return x
end

-- Logical exclusive-or on Lua truthiness; always returns a real boolean.
local function boolean_xor(a, b)
    if a and b then return false end
    if not a and not b then return false end
    return true
end

local _invlog2 = 1 / log(2)

-- Base-2 logarithm.
local function log2(x)
    return log(x) * _invlog2
end

-- Restricts x to the closed interval [l, u].
local function clamp(x, l, u)
    return min(max(x, l), u)
end

-- Linear interpolation from a to b; t is clamped to [0, 1] first.
local function lerp(a, b, t)
    return a + (b - a) * clamp(t, 0, 1)
end

-- Returns the 1-based position of the largest argument (the first one on
-- ties), or 0 when called with no arguments.
local function argmax(...)
    local max_i = 0
    local max_v = nil
    for i = 1, select("#", ...) do
        local v = select(i, ...)
        -- seed from the first argument instead of a magic sentinel so that
        -- arbitrarily negative inputs are handled correctly.
        if max_v == nil or v > max_v then
            max_i = i
            max_v = v
        end
    end
    return max_i
end

-- Treats the arguments as probabilities and samples an index: draws a
-- uniform random number and returns the first index whose running sum
-- exceeds it. Returns nothing when the running sum never exceeds the draw
-- (e.g. the arguments sum to less than 1).
local function softchoice(...)
    local t = random()
    local psum = 0
    for i = 1, select("#", ...) do
        local p = select(i, ...)
        psum = psum + p
        if t < psum then
            return i
        end
    end
end

-- Removes every key from t in place and returns t.
local function empty(t)
    for k, _ in pairs(t) do
        t[k] = nil
    end
    return t
end

-- Returns the mean and the population standard deviation (divides by #x)
-- of the sequence x. An empty sequence yields 0, 0.
local function calc_mean_dev(x)
    local n = #x
    local mean = 0
    for _, v in ipairs(x) do
        mean = mean + v / n
    end
    local dev = 0
    for _, v in ipairs(x) do
        local delta = v - mean
        dev = dev + delta * delta / n
    end
    return mean, sqrt(dev)
end

-- Standardizes x using the mean and deviation of another sequence s.
-- Writes into out (defaults to x, i.e. in place) and returns it.
-- A non-positive deviation is replaced by 1 to avoid dividing by zero.
local function normalize_wrt(x, s, out)
    out = out or x
    local mean, dev = calc_mean_dev(s)
    if dev <= 0 then
        dev = 1
    end
    for i, v in ipairs(x) do
        out[i] = (v - mean) / dev
    end
    return out
end

-- Standardizes x using its own mean and deviation.
-- Writes into out (defaults to x, i.e. in place) and returns it.
local function normalize(x, out)
    return normalize_wrt(x, x, out)
end

-- Centers x on its mean, then scales so the absolute values sum to 1.
-- Writes into out (defaults to x, i.e. in place) and returns it.
-- When every centered value is 0 the zeros are left as they are.
local function normalize_sums(x, out)
    out = out or x
    local n = #x

    local sum = 0
    for i = 1, n do
        sum = sum + x[i]
    end

    local mean = sum / n
    local abssum = 0
    for i = 1, n do
        local centered = x[i] - mean
        out[i] = centered
        abssum = abssum + abs(centered)
    end

    -- guard against 0 / 0 (NaN) for constant input.
    if abssum > 0 then
        for i = 1, n do
            out[i] = out[i] / abssum
        end
    end
    return out
end

-- Returns 1 plus the number of rewards strictly greater than
-- unperturbed_reward.
local function unperturbed_rank(rewards, unperturbed_reward)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local nth_place = 1
    for _, v in ipairs(rewards) do
        if v > unperturbed_reward then
            nth_place = nth_place + 1
        end
    end
    return nth_place
end

-- Shallow copy of t into out (a new table when out is omitted).
-- Returns the destination table.
local function copy(t, out)
    assert(type(t) == "table")
    out = out or {}
    for k, v in pairs(t) do
        out[k] = v
    end
    return out
end

-- Returns a key of t whose value equals a, or nil when there is none.
-- Traversal uses pairs, so which key wins among duplicates is unspecified.
local function indexof(t, a)
    assert(type(t) == "table")
    for k, v in pairs(t) do
        if v == a then
            return k
        end
    end
    return nil
end

-- True when some value in t equals a.
local function contains(t, a)
    return indexof(t, a) ~= nil
end

-- True when the first element of t is greater than the second.
local function argmax2(t)
    return t[1] > t[2]
end

-- True when t[1] exceeds a uniform random draw.
local function rchoice2(t)
    return t[1] > random()
end

-- Random boolean.
local function rbool()
    return 0.5 >= random()
end

-- Returns the indices of t ordered so that t[out[1]], t[out[2]], ...
-- is sorted by comp (ascending by default). out may be supplied to reuse
-- a table; entries past #t are cleared so they cannot leak into the sort.
local function argsort(t, comp, out)
    comp = comp or function(a, b) return a < b end
    out = out or {}
    local n = #t
    for i = 1, n do
        out[i] = i
    end
    for i = #out, n + 1, -1 do
        out[i] = nil
    end
    sort(out, function(a, b) return comp(t[a], t[b]) end)
    return out
end

-- True when the file named fn can be opened for reading.
local function exists(fn)
    local f = open(fn, "r")
    if f then
        f:close()
        return true
    else
        return false
    end
end

-- Probability density function for a normal distribution.
-- mean defaults to 0 and std defaults to 1.
local function pdf(x, mean, std)
    mean = mean or 0
    std = std or 1
    if mean == 0 and std == 1 then
        -- fast path: 0.3989... is 1 / sqrt(2 * pi).
        return 0.39894228040143 * exp(x * x * -0.5)
    end
    local var = std * std
    return 1 / sqrt(2 * pi * var) * exp((x - mean) * (x - mean) / (-2 * var))
end

local function cdf(x)
    -- a very rough approximation of the
    -- cumulative distribution function for a normal distribution.
    -- absolute error peaks at plus-or-minus 1.654.
    -- i don't remember where this is from.
    local s = x >= 0 and 1 or -1
    return 0.5 * (1 + s * sqrt(1 - exp(-2 / pi * x * x)))
end

-- Rank-based reward shaping. Each reward is ranked by its first position
-- in the descending sort (so ties share a rank) and mapped to
-- max(0, log2(n / 2 + 1) - log2(rank)), normalized by the sum of those
-- terms, plus 1 / n. Returns a new sequence aligned with rewards.
local function fitness_shaping(rewards)
    -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
    local decreasing = copy(rewards)
    sort(decreasing, function(a, b) return a > b end)

    -- first position of each distinct value in the descending order.
    local first_rank = {}
    for i = 1, #decreasing do
        local v = decreasing[i]
        if first_rank[v] == nil then
            first_rank[v] = i
        end
    end

    local lamb = #rewards
    local l = log2(lamb / 2 + 1)

    local numers = {}
    local denom = 0
    for i, v in ipairs(rewards) do
        local numer = max(0, l - log2(first_rank[v]))
        numers[i] = numer
        denom = denom + numer
    end

    local shaped_returns = {}
    for i = 1, #numers do
        shaped_returns[i] = numers[i] / denom + 1 / lamb
    end
    return shaped_returns
end

-- Weighted Mann-Whitney U test on samples s0 and s1 with optional
-- per-sample weights w0 and w1 (same lengths as s0 and s1).
-- U is normal-approximated and passed through cdf; the result is p when
-- sum(s0) <= sum(s1) and 1 - p otherwise.
-- NOTE(review): when either weight total is 0 (e.g. an empty sample) the
-- deviation is 0 and the result is NaN -- callers should avoid that case.
local function weighted_mann_whitney(s0, s1, w0, w1)
    -- when w0 and w1 are nil, this reduces to the regular Mann-Whitney.
    if w0 == nil then
        w0 = {}
        for i = 1, #s0 do w0[i] = 1.0 end
    end
    if w1 == nil then
        w1 = {}
        for i = 1, #s1 do w1[i] = 1.0 end
    end
    assert(#s0 == #w0)
    assert(#s1 == #w1)

    local s0_sum, s1_sum, w0_sum, w1_sum = 0, 0, 0, 0
    for _, v in ipairs(s0) do s0_sum = s0_sum + v end
    for _, v in ipairs(s1) do s1_sum = s1_sum + v end
    for _, v in ipairs(w0) do w0_sum = w0_sum + v end
    for _, v in ipairs(w1) do w1_sum = w1_sum + v end

    -- each (i, j) pair contributes its weight product when s0[i] wins,
    -- and half of it on a tie.
    local U = 0
    for i = 1, #s0 do
        for j = 1, #s1 do
            if s0[i] > s1[j] then
                U = U + w0[i] * w1[j]
            elseif s0[i] == s1[j] then
                U = U + w0[i] * w1[j] * 0.5
            end
        end
    end

    local mean = w0_sum * w1_sum * 0.5
    local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
    local p = cdf((U - mean) / std)
    if s0_sum > s1_sum then
        return 1 - p
    else
        return p
    end
end

return {
    sign=sign,
    signbyte=signbyte,
    boolean_xor=boolean_xor,
    log2=log2,
    clamp=clamp,
    lerp=lerp,
    argmax=argmax,
    softchoice=softchoice,
    empty=empty,
    calc_mean_dev=calc_mean_dev,
    normalize=normalize,
    normalize_wrt=normalize_wrt,
    normalize_sums=normalize_sums,
    fitness_shaping=fitness_shaping,
    unperturbed_rank=unperturbed_rank,
    copy=copy,
    indexof=indexof,
    contains=contains,
    argmax2=argmax2,
    rchoice2=rchoice2,
    rbool=rbool,
    argsort=argsort,
    exists=exists,
    pdf=pdf,
    cdf=cdf,
    weighted_mann_whitney=weighted_mann_whitney,
}