From 12098ee592630b31113f84845e22d0228f0ef51b Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 12 Jun 2018 05:37:55 +0200 Subject: [PATCH] add normalize_sums utility function --- util.lua | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/util.lua b/util.lua index c6856f3..e03802c 100644 --- a/util.lua +++ b/util.lua @@ -1,5 +1,6 @@ -- TODO: reorganize function order. +local abs = math.abs local assert = assert local exp = math.exp local ipairs = ipairs @@ -93,6 +94,17 @@ local function normalize_wrt(x, s, out) return out end +local function normalize_sums(x, out) + out = out or x + local sum = 0 + for i, v in ipairs(x) do sum = sum + v end + for i, v in ipairs(x) do out[i] = v - sum / #x end + local abssum = 0 + for i, v in ipairs(out) do abssum = abssum + abs(v) end + for i, v in ipairs(out) do out[i] = v / abssum end + return out +end + local function fitness_shaping(rewards) -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py local decreasing = nn.copy(rewards) @@ -207,6 +219,7 @@ return { calc_mean_dev=calc_mean_dev, normalize=normalize, normalize_wrt=normalize_wrt, + normalize_sums=normalize_sums, fitness_shaping=fitness_shaping, unperturbed_rank=unperturbed_rank, copy=copy,