add normalize_sums utility function
This commit is contained in:
parent
0d28db0fc4
commit
12098ee592
1 changed files with 13 additions and 0 deletions
13
util.lua
13
util.lua
|
@ -1,5 +1,6 @@
|
|||
-- TODO: reorganize function order.
|
||||
|
||||
local abs = math.abs
|
||||
local assert = assert
|
||||
local exp = math.exp
|
||||
local ipairs = ipairs
|
||||
|
@ -93,6 +94,17 @@ local function normalize_wrt(x, s, out)
|
|||
return out
|
||||
end
|
||||
|
||||
local function normalize_sums(x, out)
|
||||
out = out or x
|
||||
local sum = 0
|
||||
for i, v in ipairs(x) do sum = sum + v end
|
||||
for i, v in ipairs(x) do out[i] = v - sum / #x end
|
||||
local abssum = 0
|
||||
for i, v in ipairs(out) do abssum = abssum + abs(v) end
|
||||
for i, v in ipairs(out) do out[i] = v / abssum end
|
||||
return out
|
||||
end
|
||||
|
||||
local function fitness_shaping(rewards)
|
||||
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
|
||||
local decreasing = nn.copy(rewards)
|
||||
|
@ -207,6 +219,7 @@ return {
|
|||
calc_mean_dev=calc_mean_dev,
|
||||
normalize=normalize,
|
||||
normalize_wrt=normalize_wrt,
|
||||
normalize_sums=normalize_sums,
|
||||
fitness_shaping=fitness_shaping,
|
||||
unperturbed_rank=unperturbed_rank,
|
||||
copy=copy,
|
||||
|
|
Loading…
Add table
Reference in a new issue