183 lines
4.2 KiB
Lua
183 lines
4.2 KiB
Lua
local huge = math.huge
|
|
local ipairs = ipairs
|
|
local open = io.open
|
|
local sqrt = math.sqrt
|
|
|
|
local nn = require("nn")
|
|
local Base = require("Base")
|
|
|
|
-- Modeled on the observation filter of ARS:
--   https://github.com/modestyachts/ARS/blob/master/code/filter.py
-- The running mean / variance in Stats:push follow the online update described at:
--   http://www.johndcook.com/blog/standard_deviation/

-- Both classes derive from Base. They are instantiated by calling the class
-- table (e.g. `Stats(shape)`), which presumably dispatches to :init — see Base.
local Stats = Base:extend()     -- running count / mean / sum of squared deviations
local Normalizer = Base:extend() -- (x - mean) / std using a Stats accumulator
|
|
|
|
--- Create an empty running-statistics accumulator.
-- @param shape shape forwarded to `nn.zeros` for both internal buffers
--   (its format is defined by the nn module).
function Stats:init(shape)
  -- Number of samples folded in so far.
  self._n = 0
  -- _M: running mean; _S: running sum of squared deviations from the mean.
  self._M = nn.zeros(shape)
  self._S = nn.zeros(shape)
end
|
|
|
|
--- Fold one sample into the running statistics.
-- Per element, with n the new count and d = x - mean_old:
--   mean <- mean + d / n
--   S    <- S + d * d * (n - 1) / n
-- @param x sample exposing a `.shape` field and elements x[1], x[2], ...;
--   its element count must equal that of the running mean.
function Stats:push(x)
  assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")

  local prev_count = self._n
  local count = prev_count + 1
  self._n = count

  local mean, sq_dev = self._M, self._S

  if count == 1 then
    -- First sample: it *is* the mean; the deviation sums stay at their current values.
    nn.copy(x, mean)
    return
  end

  -- All deviations are taken against the mean *before* this update.
  local diff = {}
  for idx, m in ipairs(mean) do
    diff[idx] = x[idx] - m
  end

  for idx, m in ipairs(mean) do
    mean[idx] = m + diff[idx] / count
  end

  for idx, s in ipairs(sq_dev) do
    sq_dev[idx] = s + diff[idx] * diff[idx] * prev_count / count
  end
end
|
|
|
|
--- Per-element variance estimate of the samples pushed so far.
-- With two or more samples this is the sample variance S / (n - 1).
-- With fewer than two samples it falls back to mean^2, as ARS filter.py does.
-- Guarding on n > 1 (rather than n == 1) matters for n == 0: the old code
-- divided S by (0 - 1) = -1 there, which gives -0.0 for a fresh accumulator
-- and negative values (hence NaN after sqrt in :dev) if S is nonzero.
-- @return a new plain Lua table of per-element variances
function Stats:var()
  local out = {}
  local n = self._n
  if n > 1 then
    for i, s in ipairs(self._S) do out[i] = s / (n - 1) end
  else
    -- Zero or one sample: no spread can be estimated from the deviations.
    for i, m in ipairs(self._M) do out[i] = m * m end
  end
  return out
end
|
|
|
|
--- Per-element standard deviation: the square root of each entry of :var().
-- @return a new plain Lua table of standard deviations
function Stats:dev()
  local variances = self:var()
  -- :var() fills a fresh contiguous 1..k sequence, so #variances covers it.
  for k = 1, #variances do
    variances[k] = sqrt(variances[k])
  end
  return variances
end
|
|
|
|
--- Set up a normalizer for observations of the given shape.
-- @param shape shape forwarded to Stats and nn.zeros
-- @param demean whether :process subtracts the mean (nil is treated as true)
-- @param destd whether :process divides by the std (nil is treated as true)
function Normalizer:init(shape, demean, destd)
  self.shape = shape
  if demean == nil then self.demean = true else self.demean = demean end
  if destd == nil then self.destd = true else self.destd = destd end

  self.rs = Stats(shape)

  -- Until the first :update, mean is all zeros and std is all ones.
  self.mean = nn.zeros(shape)
  local ones = nn.zeros(shape)
  for k = 1, #ones do ones[k] = 1 end
  self.std = ones
end
|
|
|
|
--- Normalize a sample without touching the running statistics.
-- Works on a copy of x: optionally subtracts self.mean, then optionally
-- divides by (self.std + 1e-8); the epsilon guards against division by zero.
-- @param x sample to normalize (left unmodified)
-- @return the normalized copy
function Normalizer:process(x)
  local result = nn.copy(x)
  local mean, std = self.mean, self.std

  if self.demean then
    for k, value in ipairs(result) do
      result[k] = value - mean[k]
    end
  end

  if self.destd then
    for k, value in ipairs(result) do
      result[k] = value / (std[k] + 1e-8)
    end
  end

  return result
end
|
|
|
|
--- Refresh self.mean and self.std from the running statistics in self.rs.
function Normalizer:update()
  -- Reads Stats' internal mean buffer directly. FIXME: HACK
  nn.copy(self.rs._M, self.mean)
  nn.copy(self.rs:dev(), self.std)

  -- A std below 1e-7 is replaced by +inf: :process then divides by infinity
  -- instead of by ~0, so (finite) zero-variance elements come out as 0.
  local std = self.std
  for k, s in ipairs(std) do
    if s < 1e-7 then
      std[k] = huge
    end
  end
end
|
|
|
|
--- Record a sample, optionally refresh mean/std, and return it normalized.
-- @param x sample to record and normalize
-- @param update pass `false` to skip :update(); any other value
--   (including nil) refreshes mean/std before normalizing
-- @return result of :process(x)
function Normalizer:push(x, update)
  self.rs:push(x)
  -- Only an explicit false disables the refresh.
  if update ~= false then
    self:update()
  end
  return self:process(x)
end
|
|
|
|
--- Default stats file name, derived from the element count of the shape.
-- @return e.g. "stats0000020.txt" for 20 elements (zero-padded to 7 digits)
function Normalizer:default_filename()
  local element_count = nn.prod(self.shape)
  return string.format('stats%07i.txt', element_count)
end
|
|
|
|
--- Write the running statistics to a plain-text file, one number per line:
-- the sample count, then every element of the running mean (rs._M), then
-- every element of the running sum of squared deviations (rs._S).
-- Mean and deviation values are written with "%.17g". With the default
-- luaconf.h, file:write renders floats with only "%.14g", which would lose
-- precision on a save/load round trip. 17 significant digits identify a
-- double exactly.
-- @param fn path to write; defaults to self:default_filename()
-- @raise error if the file cannot be opened for writing
function Normalizer:save(fn)
  fn = fn or self:default_filename()
  local f, err = open(fn, 'w')
  if f == nil then
    error("Failed to save stats to file "..fn..": "..tostring(err))
  end

  local rs = self.rs
  -- The count is written exactly as before, so its textual form is unchanged.
  f:write(rs._n, '\n')

  local line_fmt = '%.17g\n'
  for _, v in ipairs(rs._M) do
    f:write(line_fmt:format(v))
  end
  for _, v in ipairs(rs._S) do
    f:write(line_fmt:format(v))
  end

  f:close()
end
|
|
|
|
--- Restore running statistics written by Normalizer:save, then refresh
-- self.mean / self.std via :update().
-- Expected format: one number per line. First the sample count, then
-- nn.prod(self.shape) mean values, then as many squared-deviation sums.
-- The file is fully parsed and validated before any state is modified, so a
-- malformed file leaves the current statistics untouched. The file handle is
-- always closed before an error is raised.
-- @param fn path to read; defaults to self:default_filename()
-- @raise error if the file cannot be opened, a line is not a number, or
--   the number of values does not match the normalizer's shape
function Normalizer:load(fn)
  fn = fn or self:default_filename()
  local f, err = open(fn, 'r')
  if f == nil then
    error("Failed to load stats from file "..fn..": "..tostring(err))
  end

  -- Parse into a temporary list first.
  local values = {}
  local bad_line = nil
  for line in f:lines() do
    local n = tonumber(line)
    if n == nil then
      bad_line = #values + 1
      break
    end
    values[#values + 1] = n
  end
  f:close()

  if bad_line ~= nil then
    error("Failed reading line "..tostring(bad_line).." of file "..fn)
  end

  -- Extra or missing lines would otherwise resize or partially fill the
  -- buffers silently.
  local size = nn.prod(self.shape)
  local expected = 1 + 2 * size
  if #values ~= expected then
    error("Expected "..tostring(expected).." values in file "..fn
      ..", found "..tostring(#values))
  end

  -- Commit: line 1 -> count, next `size` lines -> mean, remaining -> S.
  local rs = self.rs
  rs._n = values[1]
  for i = 1, size do
    rs._M[i] = values[1 + i]
    rs._S[i] = values[1 + size + i]
  end

  self:update()
end
|
|
|
|
--[[
|
|
|
|
-- basic tests
|
|
|
|
local dims = 20
|
|
local rs = Stats(dims)
|
|
local x = nn.zeros(dims)
|
|
|
|
for i = 1, #x do x[i] = nn.normal() end
|
|
rs:push(x)
|
|
print(nn.pp(rs:dev()))
|
|
|
|
for j = 1, 10000 do
|
|
for i = 1, #x do x[i] = nn.normal() end
|
|
rs:push(x)
|
|
end
|
|
print(nn.pp(rs:dev()))
|
|
|
|
--
|
|
|
|
local ms = Normalizer(dims)
|
|
local exp = math.exp
|
|
local y
|
|
|
|
for i = 1, #x do x[i] = exp(nn.normal()) end
|
|
y = ms:push(x)
|
|
print(nn.pp(y))
|
|
|
|
for j = 1, 10000 do
|
|
for i = 1, #x do x[i] = exp(nn.normal()) end
|
|
y = ms:push(x)
|
|
end
|
|
print(nn.pp(y))
|
|
|
|
print("mean:")
|
|
print(nn.pp(ms.mean))
|
|
print("stdev:")
|
|
print(nn.pp(ms.std))
|
|
|
|
--]]
|
|
|
|
-- Public API of this module.
local exports = {
  Stats = Stats,
  Normalizer = Normalizer,
}

return exports
|