-- smbot/running.lua
-- 2018-06-30 20:13:54 +02:00
-- 184 lines, 4.2 KiB, Lua

local huge = math.huge
local ipairs = ipairs
local open = io.open
local sqrt = math.sqrt
local nn = require("nn")
local Base = require("Base")
-- Running element-wise mean/variance statistics and an observation
-- normalizer built on top of them.
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
-- http://www.johndcook.com/blog/standard_deviation/
-- Stats: streaming accumulator holding the sample count (_n), the
-- running mean (_M) and the running sum of squared deviations (_S).
local Stats = Base:extend()
-- Normalizer: standardizes observations using a Stats accumulator.
local Normalizer = Base:extend()
--- Create an empty accumulator for element-wise running statistics.
-- @param shape shape passed to nn.zeros for both the running mean and
--   the running sum-of-squared-deviations buffers
function Stats:init(shape)
	self._n = 0
	self._S = nn.zeros(shape)
	self._M = nn.zeros(shape)
end
--- Fold one sample into the running statistics.
-- Uses the incremental update M += d / n, S += d^2 * (n - 1) / n,
-- where d is the sample's deviation from the previous mean.
-- @param x sample whose total element count must equal that of the
--   accumulator (only the total size is checked, not the exact shape)
function Stats:push(x)
	assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")
	local prev_count = self._n
	local count = prev_count + 1
	self._n = count
	if count == 1 then
		-- First sample: the mean becomes the sample itself; S is untouched.
		nn.copy(x, self._M)
		return
	end
	local mean, sq = self._M, self._S
	for i, old_mean in ipairs(mean) do
		local d = x[i] - old_mean
		mean[i] = old_mean + d / count
		sq[i] = sq[i] + d * d * prev_count / count
	end
end
--- Element-wise variance estimate.
-- With more than one sample this is the unbiased estimate S / (n - 1).
-- With zero or one sample there is no usable spread, so it falls back
-- to M^2 (the same `_n > 1` test the ARS filter uses). This also keeps
-- an empty accumulator from dividing by n - 1 = -1, which would yield
-- negative zeros, or negative variances (NaN std) if S were non-zero.
-- @return a fresh plain table of per-element variances
function Stats:var()
	local out = {}
	if self._n > 1 then
		for i, v in ipairs(self._S) do out[i] = v / (self._n - 1) end
	else
		for i, v in ipairs(self._M) do out[i] = v * v end
	end
	return out
end
--- Element-wise standard deviation: square root of Stats:var().
-- @return a fresh plain table of per-element standard deviations
function Stats:dev()
	local variances = self:var()
	for idx, variance in ipairs(variances) do
		variances[idx] = sqrt(variance)
	end
	return variances
end
--- Configure a normalizer for observations of the given shape.
-- @param shape observation shape (forwarded to nn.zeros and Stats)
-- @param demean subtract the running mean in process(); nil means true
-- @param destd divide by the running std in process(); nil means true
function Normalizer:init(shape, demean, destd)
	self.shape = shape
	if demean == nil then self.demean = true else self.demean = demean end
	if destd == nil then self.destd = true else self.destd = destd end
	self.rs = Stats(shape)
	self.mean = nn.zeros(shape)
	-- Unit std until the first update(), so scaling is near-identity.
	local std = nn.zeros(shape)
	for i = 1, #std do std[i] = 1 end
	self.std = std
end
--- Normalize an observation using the cached mean/std.
-- Works on nn.copy(x), so x is presumably left unmodified. Each element
-- becomes (v - mean) when demean is set, then / (std + 1e-8) when destd
-- is set.
-- @param x observation to normalize
-- @return the normalized copy
function Normalizer:process(x)
	local result = nn.copy(x)
	local demean, destd = self.demean, self.destd
	if not (demean or destd) then return result end
	local mean, std = self.mean, self.std
	for i, v in ipairs(result) do
		if demean then v = v - mean[i] end
		if destd then v = v / (std[i] + 1e-8) end
		result[i] = v
	end
	return result
end
--- Refresh the cached mean/std from the running statistics.
-- Any std entry below 1e-7 is replaced with +inf. When destd is enabled,
-- process() then sends those (near-constant) elements to zero for
-- finite inputs instead of dividing by ~0.
function Normalizer:update()
	nn.copy(self.rs._M, self.mean) -- FIXME: HACK (reads Stats' private _M)
	nn.copy(self.rs:dev(), self.std)
	local std = self.std
	for i, sd in ipairs(std) do
		if sd < 1e-7 then std[i] = huge end
	end
end
--- Record an observation and return its normalized form.
-- @param x observation to add to the running statistics
-- @param update pass false to skip refreshing mean/std before
--   normalizing; any other value (including nil) refreshes them
-- @return the result of self:process(x)
function Normalizer:push(x, update)
	self.rs:push(x)
	if update ~= false then
		self:update()
	end
	return self:process(x)
end
--- Default stats file name, keyed by the observation's element count
-- (zero-padded to 7 digits, e.g. "stats0000020.txt" for 20 elements).
function Normalizer:default_filename()
	local template = 'stats%07i.txt'
	local size = nn.prod(self.shape)
	return template:format(size)
end
--- Write the running statistics to a text file, one number per line:
-- the sample count, then every element of the running mean, then every
-- element of the running sum of squared deviations.
-- @param fn path to write; defaults to self:default_filename()
-- @raise error if the file cannot be opened or written
function Normalizer:save(fn)
	fn = fn or self:default_filename()
	local f = open(fn, 'w')
	if f == nil then error("Failed to save stats to file "..fn) end
	-- "%.17g" round-trips any IEEE double exactly. file:write's default
	-- number formatting ("%.14g") silently loses precision between
	-- save() and load().
	local fmt = '%.17g'
	local lines = {fmt:format(self.rs._n)}
	for _, v in ipairs(self.rs._M) do lines[#lines + 1] = fmt:format(v) end
	for _, v in ipairs(self.rs._S) do lines[#lines + 1] = fmt:format(v) end
	local ok, err = f:write(table.concat(lines, '\n'), '\n')
	-- Close before raising so a failed write does not leak the handle.
	f:close()
	if not ok then
		error("Failed to save stats to file "..fn..": "..tostring(err))
	end
end
--- Load running statistics written by Normalizer:save, then refresh
-- the cached mean/std via self:update().
-- Expected layout: one number per line -- the sample count, then
-- prod(shape) mean values, then prod(shape) sum-of-squares values.
-- The whole file is parsed and validated before any state is touched,
-- so a bad file leaves this normalizer unchanged.
-- @param fn path to read; defaults to self:default_filename()
-- @raise error if the file cannot be opened, a line is not a number,
--   or the number of values does not match the normalizer's shape
function Normalizer:load(fn)
	fn = fn or self:default_filename()
	local f = open(fn, 'r')
	if f == nil then error("Failed to load stats from file "..fn) end
	local size = nn.prod(self.shape)
	local values = {}
	local bad_line
	for line in f:lines() do
		local n = tonumber(line)
		if n == nil then
			bad_line = #values + 1
			break
		end
		values[#values + 1] = n
	end
	-- Close before any error is raised so the handle is never leaked.
	f:close()
	if bad_line then
		error("Failed reading line "..tostring(bad_line).." of file "..fn)
	end
	local expected = 1 + 2 * size
	if #values ~= expected then
		error("Stats file "..fn.." has "..tostring(#values)
			.." values, expected "..tostring(expected))
	end
	self.rs._n = values[1]
	for i = 1, size do
		self.rs._M[i] = values[1 + i]
		self.rs._S[i] = values[1 + size + i]
	end
	self:update()
end
--[[
-- basic tests
local dims = 20
local rs = Stats(dims)
local x = nn.zeros(dims)
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
print(nn.pp(rs:dev()))
for j = 1, 10000 do
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
end
print(nn.pp(rs:dev()))
--
local ms = Normalizer(dims)
local exp = math.exp
local y
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
print(nn.pp(y))
for j = 1, 10000 do
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
end
print(nn.pp(y))
print("mean:")
print(nn.pp(ms.mean))
print("stdev:")
print(nn.pp(ms.std))
--]]
-- Public interface of this module: the running-statistics accumulator
-- and the observation normalizer built on it.
return {
	Normalizer = Normalizer,
	Stats = Stats,
}