This commit is contained in:
Connor Olding 2019-03-11 07:15:41 +01:00
parent 08f476e6ac
commit 7462e69c61
18 changed files with 1467 additions and 77 deletions

97
ars.lua
View File

@ -1,14 +1,11 @@
-- Augmented Random Search -- Augmented Random Search
-- https://arxiv.org/abs/1803.07055 -- https://arxiv.org/abs/1803.07055
-- with some tweaks (lipschitz stuff) by myself.
-- i also added an option for graycode sampling,
-- borrowed from a (1+1) optimizer,
-- but i haven't yet found a case where it performs better.
local abs = math.abs local abs = math.abs
local exp = math.exp local exp = math.exp
local floor = math.floor local floor = math.floor
local insert = table.insert local insert = table.insert
local remove = table.remove
local ipairs = ipairs local ipairs = ipairs
local log = math.log local log = math.log
local max = math.max local max = math.max
@ -26,6 +23,7 @@ local zeros = nn.zeros
local util = require "util" local util = require "util"
local argsort = util.argsort local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev local calc_mean_dev = util.calc_mean_dev
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
local normalize_sums = util.normalize_sums local normalize_sums = util.normalize_sums
local sign = util.sign local sign = util.sign
@ -56,23 +54,6 @@ local function collect_best_indices(scored, top, antithetic)
return indices return indices
end end
local function kinda_lipschitz(dir, pos, neg, mid)
--[[
-- based on the local lipschitz constant of a quadratic curve
-- drawn through the 3 sampled points: positive, negative, and unperturbed.
-- it kinda helps? there's probably a better function to base it around.
local _, dev = calc_mean_dev(dir)
local c0 = neg - mid
local c1 = pos - mid
local l0 = abs(3 * c1 + c0)
local l1 = abs(c1 + 3 * c0)
return max(l0, l1) / (2 * dev)
--]]
-- based on a piece-wise linear function of the 3 sampled points.
local _, dev = calc_mean_dev(dir)
return max(abs(pos - mid), abs(neg - mid)) / dev
end
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic, function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
momentum, beta) momentum, beta)
self.dims = dims self.dims = dims
@ -110,7 +91,7 @@ function Ars:decay(param_decay, sigma_decay)
end end
end end
function Ars:ask(graycode) function Ars:ask()
local asked = {} local asked = {}
local noise = {} local noise = {}
@ -126,17 +107,8 @@ function Ars:ask(graycode)
noisy[j] = -v noisy[j] = -v
end end
else else
if graycode ~= nil then for j = 1, self.dims do
for j = 1, self.dims do noisy[j] = self.sigma * normal()
noisy[j] = exp(-precision * uniform())
end
for j = 1, self.dims do
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
end
else
for j = 1, self.dims do
noisy[j] = self.sigma * normal()
end
end end
end end
@ -150,9 +122,8 @@ function Ars:ask(graycode)
end end
function Ars:tell(scored, unperturbed_score) function Ars:tell(scored, unperturbed_score)
local use_lips = unperturbed_score ~= nil and self.antithetic
self.evals = self.evals + #scored self.evals = self.evals + #scored
if use_lips then self.evals = self.evals + 1 end if unperturbed_score ~= nil then self.evals = self.evals + 1 end
local indices = collect_best_indices(scored, self.poptop, self.antithetic) local indices = collect_best_indices(scored, self.poptop, self.antithetic)
@ -173,7 +144,16 @@ function Ars:tell(scored, unperturbed_score)
end end
local step = zeros(self.dims) local step = zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
local _, reward_dev
if unperturbed_score ~= nil then
-- new stuff:
insert(top_rewards, unperturbed_score)
_, reward_dev = calc_mean_dev_unbiased(top_rewards)
remove(top_rewards)
else
_, reward_dev = calc_mean_dev(top_rewards)
end
if reward_dev == 0 then reward_dev = 1 end if reward_dev == 0 then reward_dev = 1 end
if self.antithetic then if self.antithetic then
@ -183,12 +163,15 @@ function Ars:tell(scored, unperturbed_score)
local reward = pos - neg local reward = pos - neg
if reward ~= 0 then if reward ~= 0 then
local noisy = self.noise[ind * 2 - 1] local noisy = self.noise[ind * 2 - 1]
if use_lips then reward = reward / reward_dev
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
reward = reward / lips / self.sigma --[[ new stuff:
else local sum_of_squares = 0
reward = reward / reward_dev for _, v in ipairs(noisy) do
sum_of_squares = sum_of_squares + v * v
end end
reward = reward / sqrt(sum_of_squares)
-]]
local scale = reward / self.poptop * self.beta / 2 local scale = reward / self.poptop * self.beta / 2
for j, v in ipairs(noisy) do for j, v in ipairs(noisy) do
@ -196,7 +179,9 @@ function Ars:tell(scored, unperturbed_score)
end end
end end
end end
else else
error("TODO: update with sum of squares stuff")
for i, ind in ipairs(indices) do for i, ind in ipairs(indices) do
local reward = top_rewards[i] / reward_dev local reward = top_rewards[i] / reward_dev
if reward ~= 0 then if reward ~= 0 then
@ -210,6 +195,7 @@ function Ars:tell(scored, unperturbed_score)
end end
end end
--[[ powersign momentum
if self.momentum > 0 then if self.momentum > 0 then
for i, v in ipairs(step) do for i, v in ipairs(step) do
self.accum[i] = self.momentum * self.accum[i] + v self.accum[i] = self.momentum * self.accum[i] + v
@ -220,6 +206,35 @@ function Ars:tell(scored, unperturbed_score)
for i, v in ipairs(self._params) do for i, v in ipairs(self._params) do
self._params[i] = v + self.param_rate * step[i] self._params[i] = v + self.param_rate * step[i]
end end
--]]
-- neumann momentum
if self.momentum > 0 then
local count = self.count or 0
local period = 10
local mu = 1 - 1 / (1 + count % period)
mu = self.momentum / (1 - 1 / period) * mu
self.count = count + 1
-- mu is intentionally 0 for one iteration.
-- make learning rate invariant to sigma.
for i, v in ipairs(step) do
step[i] = v / self.sigma
end
-- update neumann iterate.
for i, v in ipairs(self.accum) do
self.accum[i] = mu * v - self.param_rate * step[i]
end
for i, v in ipairs(self._params) do
self._params[i] = v - mu * self.accum[i] + self.param_rate * step[i]
end
else
for i, v in ipairs(self._params) do
self._params[i] = v + self.param_rate * step[i]
end
end
self.noise = nil self.noise = nil

749
binser.lua Normal file
View File

@ -0,0 +1,749 @@
-- binser.lua
--[[
Copyright (c) 2016 Calvin Rose
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
]]
local assert = assert
local error = error
local select = select
local pairs = pairs
local getmetatable = getmetatable
local setmetatable = setmetatable
local type = type
local loadstring = loadstring or load
local concat = table.concat
local char = string.char
local byte = string.byte
local format = string.format
local sub = string.sub
local dump = string.dump
local floor = math.floor
local frexp = math.frexp
local unpack = unpack or table.unpack
-- Lua 5.3 frexp polyfill
-- From https://github.com/excessive/cpml/blob/master/modules/utils.lua
if not frexp then
local log, abs, floor = math.log, math.abs, math.floor
local log2 = log(2)
frexp = function(x)
if x == 0 then return 0, 0 end
local e = floor(log(abs(x)) / log2 + 1)
return x / 2 ^ e, e
end
end
local function pack(...)
return {...}, select("#", ...)
end
local function not_array_index(x, len)
return type(x) ~= "number" or x < 1 or x > len or x ~= floor(x)
end
local function type_check(x, tp, name)
assert(type(x) == tp,
format("Expected parameter %q to be of type %q.", name, tp))
end
local bigIntSupport = false
local isInteger
if math.type then -- Detect Lua 5.3
local mtype = math.type
bigIntSupport = loadstring[[
local char = string.char
return function(n)
local nn = n < 0 and -(n + 1) or n
local b1 = nn // 0x100000000000000
local b2 = nn // 0x1000000000000 % 0x100
local b3 = nn // 0x10000000000 % 0x100
local b4 = nn // 0x100000000 % 0x100
local b5 = nn // 0x1000000 % 0x100
local b6 = nn // 0x10000 % 0x100
local b7 = nn // 0x100 % 0x100
local b8 = nn % 0x100
if n < 0 then
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
end
return char(212, b1, b2, b3, b4, b5, b6, b7, b8)
end]]()
isInteger = function(x)
return mtype(x) == 'integer'
end
else
isInteger = function(x)
return floor(x) == x
end
end
-- Copyright (C) 2012-2015 Francois Perrad.
-- number serialization code modified from https://github.com/fperrad/lua-MessagePack
-- Encode a number as a big-endian ieee-754 double, big-endian signed 64 bit integer, or a small integer
local function number_to_str(n)
if isInteger(n) then -- int
if n <= 100 and n >= -27 then -- 1 byte, 7 bits of data
return char(n + 27)
elseif n <= 8191 and n >= -8192 then -- 2 bytes, 14 bits of data
n = n + 8192
return char(128 + (floor(n / 0x100) % 0x100), n % 0x100)
elseif bigIntSupport then
return bigIntSupport(n)
end
end
local sign = 0
if n < 0.0 then
sign = 0x80
n = -n
end
local m, e = frexp(n) -- mantissa, exponent
if m ~= m then
return char(203, 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
elseif m == 1/0 then
if sign == 0 then
return char(203, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
else
return char(203, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
end
end
e = e + 0x3FE
if e < 1 then -- denormalized numbers
m = m * 2 ^ (52 + e)
e = 0
else
m = (m * 2 - 1) * 2 ^ 52
end
return char(203,
sign + floor(e / 0x10),
(e % 0x10) * 0x10 + floor(m / 0x1000000000000),
floor(m / 0x10000000000) % 0x100,
floor(m / 0x100000000) % 0x100,
floor(m / 0x1000000) % 0x100,
floor(m / 0x10000) % 0x100,
floor(m / 0x100) % 0x100,
m % 0x100)
end
-- Copyright (C) 2012-2015 Francois Perrad.
-- number deserialization code also modified from https://github.com/fperrad/lua-MessagePack
local function number_from_str(str, index)
local b = byte(str, index)
if not b then error("Expected more bytes of input.") end
if b < 128 then
return b - 27, index + 1
elseif b < 192 then
local b2 = byte(str, index + 1)
if not b2 then error("Expected more bytes of input.") end
return b2 + 0x100 * (b - 128) - 8192, index + 2
end
local b1, b2, b3, b4, b5, b6, b7, b8 = byte(str, index + 1, index + 8)
if (not b1) or (not b2) or (not b3) or (not b4) or
(not b5) or (not b6) or (not b7) or (not b8) then
error("Expected more bytes of input.")
end
if b == 212 then
local flip = b1 >= 128
if flip then -- negative
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
end
local n = ((((((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4) *
0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
if flip then
return (-n) - 1, index + 9
else
return n, index + 9
end
end
if b ~= 203 then
error("Expected number")
end
local sign = b1 > 0x7F and -1 or 1
local e = (b1 % 0x80) * 0x10 + floor(b2 / 0x10)
local m = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
local n
if e == 0 then
if m == 0 then
n = sign * 0.0
else
n = sign * (m / 2 ^ 52) * 2 ^ -1022
end
elseif e == 0x7FF then
if m == 0 then
n = sign * (1/0)
else
n = 0.0/0.0
end
else
n = sign * (1.0 + m / 2 ^ 52) * 2 ^ (e - 0x3FF)
end
return n, index + 9
end
local function newbinser()
-- unique table key for getting next value
local NEXT = {}
local CTORSTACK = {}
-- NIL = 202
-- FLOAT = 203
-- TRUE = 204
-- FALSE = 205
-- STRING = 206
-- TABLE = 207
-- REFERENCE = 208
-- CONSTRUCTOR = 209
-- FUNCTION = 210
-- RESOURCE = 211
-- INT64 = 212
-- TABLE WITH META = 213
local mts = {}
local ids = {}
local serializers = {}
local deserializers = {}
local resources = {}
local resources_by_name = {}
local types = {}
types["nil"] = function(x, visited, accum)
accum[#accum + 1] = "\202"
end
function types.number(x, visited, accum)
accum[#accum + 1] = number_to_str(x)
end
function types.boolean(x, visited, accum)
accum[#accum + 1] = x and "\204" or "\205"
end
function types.string(x, visited, accum)
local alen = #accum
if visited[x] then
accum[alen + 1] = "\208"
accum[alen + 2] = number_to_str(visited[x])
else
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
accum[alen + 1] = "\206"
accum[alen + 2] = number_to_str(#x)
accum[alen + 3] = x
end
end
local function check_custom_type(x, visited, accum)
local res = resources[x]
if res then
accum[#accum + 1] = "\211"
types[type(res)](res, visited, accum)
return true
end
local mt = getmetatable(x)
local id = mt and ids[mt]
if id then
local constructing = visited[CTORSTACK]
if constructing[x] then
error("Infinite loop in constructor.")
end
constructing[x] = true
accum[#accum + 1] = "\209"
types[type(id)](id, visited, accum)
local args, len = pack(serializers[id](x))
accum[#accum + 1] = number_to_str(len)
for i = 1, len do
local arg = args[i]
types[type(arg)](arg, visited, accum)
end
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
-- We finished constructing
constructing[x] = nil
return true
end
end
function types.userdata(x, visited, accum)
if visited[x] then
accum[#accum + 1] = "\208"
accum[#accum + 1] = number_to_str(visited[x])
else
if check_custom_type(x, visited, accum) then return end
error("Cannot serialize this userdata.")
end
end
function types.table(x, visited, accum)
if visited[x] then
accum[#accum + 1] = "\208"
accum[#accum + 1] = number_to_str(visited[x])
else
if check_custom_type(x, visited, accum) then return end
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
local xlen = #x
local mt = getmetatable(x)
if mt then
accum[#accum + 1] = "\213"
types.table(mt, visited, accum)
else
accum[#accum + 1] = "\207"
end
accum[#accum + 1] = number_to_str(xlen)
for i = 1, xlen do
local v = x[i]
types[type(v)](v, visited, accum)
end
local key_count = 0
for k in pairs(x) do
if not_array_index(k, xlen) then
key_count = key_count + 1
end
end
accum[#accum + 1] = number_to_str(key_count)
for k, v in pairs(x) do
if not_array_index(k, xlen) then
types[type(k)](k, visited, accum)
types[type(v)](v, visited, accum)
end
end
end
end
types["function"] = function(x, visited, accum)
if visited[x] then
accum[#accum + 1] = "\208"
accum[#accum + 1] = number_to_str(visited[x])
else
if check_custom_type(x, visited, accum) then return end
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
local str = dump(x)
accum[#accum + 1] = "\210"
accum[#accum + 1] = number_to_str(#str)
accum[#accum + 1] = str
end
end
types.cdata = function(x, visited, accum)
if visited[x] then
accum[#accum + 1] = "\208"
accum[#accum + 1] = number_to_str(visited[x])
else
if check_custom_type(x, visited, #accum) then return end
error("Cannot serialize this cdata.")
end
end
types.thread = function() error("Cannot serialize threads.") end
local function deserialize_value(str, index, visited)
local t = byte(str, index)
if not t then return nil, index end
if t < 128 then
return t - 27, index + 1
elseif t < 192 then
local b2 = byte(str, index + 1)
if not b2 then error("Expected more bytes of input.") end
return b2 + 0x100 * (t - 128) - 8192, index + 2
elseif t == 202 then
return nil, index + 1
elseif t == 203 or t == 212 then
return number_from_str(str, index)
elseif t == 204 then
return true, index + 1
elseif t == 205 then
return false, index + 1
elseif t == 206 then
local length, dataindex = number_from_str(str, index + 1)
local nextindex = dataindex + length
if not (length >= 0) then error("Bad string length") end
if #str < nextindex - 1 then error("Expected more bytes of string") end
local substr = sub(str, dataindex, nextindex - 1)
visited[#visited + 1] = substr
return substr, nextindex
elseif t == 207 or t == 213 then
local mt, count, nextindex
local ret = {}
visited[#visited + 1] = ret
nextindex = index + 1
if t == 213 then
mt, nextindex = deserialize_value(str, nextindex, visited)
if type(mt) ~= "table" then error("Expected table metatable") end
end
count, nextindex = number_from_str(str, nextindex)
for i = 1, count do
local oldindex = nextindex
ret[i], nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
end
count, nextindex = number_from_str(str, nextindex)
for i = 1, count do
local k, v
local oldindex = nextindex
k, nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
oldindex = nextindex
v, nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
if k == nil then error("Can't have nil table keys") end
ret[k] = v
end
if mt then setmetatable(ret, mt) end
return ret, nextindex
elseif t == 208 then
local ref, nextindex = number_from_str(str, index + 1)
return visited[ref], nextindex
elseif t == 209 then
local count
local name, nextindex = deserialize_value(str, index + 1, visited)
count, nextindex = number_from_str(str, nextindex)
local args = {}
for i = 1, count do
local oldindex = nextindex
args[i], nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
end
if not name or not deserializers[name] then
error(("Cannot deserialize class '%s'"):format(tostring(name)))
end
local ret = deserializers[name](unpack(args))
visited[#visited + 1] = ret
return ret, nextindex
elseif t == 210 then
local length, dataindex = number_from_str(str, index + 1)
local nextindex = dataindex + length
if not (length >= 0) then error("Bad string length") end
if #str < nextindex - 1 then error("Expected more bytes of string") end
local ret = loadstring(sub(str, dataindex, nextindex - 1))
visited[#visited + 1] = ret
return ret, nextindex
elseif t == 211 then
local resname, nextindex = deserialize_value(str, index + 1, visited)
if resname == nil then error("Got nil resource name") end
local res = resources_by_name[resname]
if res == nil then
error(("No resources found for name '%s'"):format(tostring(resname)))
end
return res, nextindex
else
error("Could not deserialize type byte " .. t .. ".")
end
end
local function serialize(...)
local visited = {[NEXT] = 1, [CTORSTACK] = {}}
local accum = {}
for i = 1, select("#", ...) do
local x = select(i, ...)
types[type(x)](x, visited, accum)
end
return concat(accum)
end
local function make_file_writer(file)
return setmetatable({}, {
__newindex = function(_, _, v)
file:write(v)
end
})
end
local function serialize_to_file(path, mode, ...)
local file, err = io.open(path, mode)
assert(file, err)
local visited = {[NEXT] = 1, [CTORSTACK] = {}}
local accum = make_file_writer(file)
for i = 1, select("#", ...) do
local x = select(i, ...)
types[type(x)](x, visited, accum)
end
-- flush the writer
file:flush()
file:close()
end
local function writeFile(path, ...)
return serialize_to_file(path, "wb", ...)
end
local function appendFile(path, ...)
return serialize_to_file(path, "ab", ...)
end
local function deserialize(str, index)
assert(type(str) == "string", "Expected string to deserialize.")
local vals = {}
index = index or 1
local visited = {}
local len = 0
local val
while true do
local nextindex
val, nextindex = deserialize_value(str, index, visited)
if nextindex > index then
len = len + 1
vals[len] = val
index = nextindex
else
break
end
end
return vals, len
end
local function deserializeN(str, n, index)
assert(type(str) == "string", "Expected string to deserialize.")
n = n or 1
assert(type(n) == "number", "Expected a number for parameter n.")
assert(n > 0 and floor(n) == n, "N must be a poitive integer.")
local vals = {}
index = index or 1
local visited = {}
local len = 0
local val
while len < n do
local nextindex
val, nextindex = deserialize_value(str, index, visited)
if nextindex > index then
len = len + 1
vals[len] = val
index = nextindex
else
break
end
end
vals[len + 1] = index
return unpack(vals, 1, n + 1)
end
local function readFile(path)
local file, err = io.open(path, "rb")
assert(file, err)
local str = file:read("*all")
file:close()
return deserialize(str)
end
-- Resources
local function registerResource(resource, name)
type_check(name, "string", "name")
assert(not resources[resource],
"Resource already registered.")
assert(not resources_by_name[name],
format("Resource %q already exists.", name))
resources_by_name[name] = resource
resources[resource] = name
return resource
end
local function unregisterResource(name)
type_check(name, "string", "name")
assert(resources_by_name[name], format("Resource %q does not exist.", name))
local resource = resources_by_name[name]
resources_by_name[name] = nil
resources[resource] = nil
return resource
end
-- Templating
local function normalize_template(template)
local ret = {}
for i = 1, #template do
ret[i] = template[i]
end
local non_array_part = {}
-- The non-array part of the template (nested templates) have to be deterministic, so they are sorted.
-- This means that inherently non deterministicly sortable keys (tables, functions) should NOT be used
-- in templates. Looking for way around this.
for k in pairs(template) do
if not_array_index(k, #template) then
non_array_part[#non_array_part + 1] = k
end
end
table.sort(non_array_part)
for i = 1, #non_array_part do
local name = non_array_part[i]
ret[#ret + 1] = {name, normalize_template(template[name])}
end
return ret
end
local function templatepart_serialize(part, argaccum, x, len)
local extras = {}
local extracount = 0
for k, v in pairs(x) do
extras[k] = v
extracount = extracount + 1
end
for i = 1, #part do
local name
if type(part[i]) == "table" then
name = part[i][1]
len = templatepart_serialize(part[i][2], argaccum, x[name], len)
else
name = part[i]
len = len + 1
argaccum[len] = x[part[i]]
end
if extras[name] ~= nil then
extracount = extracount - 1
extras[name] = nil
end
end
if extracount > 0 then
argaccum[len + 1] = extras
else
argaccum[len + 1] = nil
end
return len + 1
end
local function templatepart_deserialize(ret, part, values, vindex)
for i = 1, #part do
local name = part[i]
if type(name) == "table" then
local newret = {}
ret[name[1]] = newret
vindex = templatepart_deserialize(newret, name[2], values, vindex)
else
ret[name] = values[vindex]
vindex = vindex + 1
end
end
local extras = values[vindex]
if extras then
for k, v in pairs(extras) do
ret[k] = v
end
end
return vindex + 1
end
local function template_serializer_and_deserializer(metatable, template)
return function(x)
local argaccum = {}
local len = templatepart_serialize(template, argaccum, x, 0)
return unpack(argaccum, 1, len)
end, function(...)
local ret = {}
local args = {...}
templatepart_deserialize(ret, template, args, 1)
return setmetatable(ret, metatable)
end
end
-- Used to serialize classes withh custom serializers and deserializers.
-- If no _serialize or _deserialize (or no _template) value is found in the
-- metatable, then the metatable is registered as a resources.
local function register(metatable, name, serialize, deserialize)
if type(metatable) == "table" then
name = name or metatable.name
serialize = serialize or metatable._serialize
deserialize = deserialize or metatable._deserialize
if (not serialize) or (not deserialize) then
if metatable._template then
-- Register as template
local t = normalize_template(metatable._template)
serialize, deserialize = template_serializer_and_deserializer(metatable, t)
else
-- Register the metatable as a resource. This is semantically
-- similar and more flexible (handles cycles).
registerResource(metatable, name)
return
end
end
elseif type(metatable) == "string" then
name = name or metatable
end
type_check(name, "string", "name")
type_check(serialize, "function", "serialize")
type_check(deserialize, "function", "deserialize")
assert((not ids[metatable]) and (not resources[metatable]),
"Metatable already registered.")
assert((not mts[name]) and (not resources_by_name[name]),
("Name %q already registered."):format(name))
mts[name] = metatable
ids[metatable] = name
serializers[name] = serialize
deserializers[name] = deserialize
return metatable
end
local function unregister(item)
local name, metatable
if type(item) == "string" then -- assume name
name, metatable = item, mts[item]
else -- assume metatable
name, metatable = ids[item], item
end
type_check(name, "string", "name")
mts[name] = nil
if (metatable) then
resources[metatable] = nil
ids[metatable] = nil
end
serializers[name] = nil
deserializers[name] = nil
resources_by_name[name] = nil;
return metatable
end
local function registerClass(class, name)
name = name or class.name
if class.__instanceDict then -- middleclass
register(class.__instanceDict, name)
else -- assume 30log or similar library
register(class, name)
end
return class
end
return {
-- aliases
s = serialize,
d = deserialize,
dn = deserializeN,
r = readFile,
w = writeFile,
a = appendFile,
serialize = serialize,
deserialize = deserialize,
deserializeN = deserializeN,
readFile = readFile,
writeFile = writeFile,
appendFile = appendFile,
register = register,
unregister = unregister,
registerResource = registerResource,
unregisterResource = unregisterResource,
registerClass = registerClass,
newbinser = newbinser
}
end
return newbinser()

10
cossim_test.lua Normal file
View File

@ -0,0 +1,10 @@
local util = require "util"
for dims = 1, 100 do
print(dims, util.expect_cossim(dims))
end
dims = 1000
print(dims, util.expect_cossim(dims))
dims = 10000
print(dims, util.expect_cossim(dims))

290
eig.lua Normal file
View File

@ -0,0 +1,290 @@
-- blah
--local globalize = require "strict"
local nn = require "nn"
local util = require "util"
local sign = util.sign
local abs = math.abs
local reshape = nn.reshape
local sqrt = math.sqrt
local zeros = nn.zeros
local function tred2(a)
assert(#a.shape == 2)
assert(a.shape[1] == a.shape[2])
local n = a.shape[1]
local d = zeros(n) -- diagonal
local e = zeros(n) -- off-diagonal (e[1] is a dummy value?)
for i = n, 2, -1 do
local l = i - 1
local ind_li = (l - 1) * n + i
local h = 0
local scale = 0
if l > 1 then
for k = 1, l do
local ind_ki = (k - 1) * n + i
scale = scale + abs(a[ind_ki])
end
if scale == 0 then
e[i] = a[ind_li]
else
for k = 1, l do
local ind_ki = (k - 1) * n + i
a[ind_ki] = a[ind_ki] / scale
h = h + a[ind_ki] * a[ind_ki]
end
local f = a[ind_li]
local g = sqrt(h)
if f >= 0 then g = -g end
e[i] = scale * g
h = h - f * g
a[ind_li] = f - g
f = 0
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = a[ind_ji] / h
g = 0
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
g = g + a[ind_kj] * a[ind_ki]
end
for k = j + 1, l do
local ind_jk = (j - 1) * n + k
local ind_ki = (k - 1) * n + i
g = g + a[ind_jk] * a[ind_ki]
end
e[j] = g / h
f = f + e[j] * a[ind_ji]
end
local hh = f / (h + h)
for j = 1, l do
local ind_ji = (j - 1) * n + i
f = a[ind_ji]
g = e[j] - hh * f
e[j] = g
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
a[ind_kj] = a[ind_kj] - (f * e[k] + g * a[ind_ki])
end
end
end
else
e[i] = a[ind_li]
end
d[i] = h
end
d[1] = 0
e[1] = 0
for i = 1, n do
local l = i - 1
if d[i] ~= 0 then
for j = 1, l do
local g = 0
for k = 1, l do
local ind_ki = (k - 1) * n + i
local ind_jk = (j - 1) * n + k
g = g + a[ind_ki] * a[ind_jk]
end
for k = 1, l do
local ind_ik = (i - 1) * n + k
local ind_jk = (j - 1) * n + k
a[ind_jk] = a[ind_jk] - g * a[ind_ik]
end
end
end
local ind_ii = (i - 1) * n + i
d[i] = a[ind_ii]
a[ind_ii] = 1
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = 0
a[ind_ji] = 0
end
end
return d, e
end
local function pythag(a, b)
--return sqrt(a * a + b * b)
local abs_a = abs(a)
local abs_b = abs(b)
if abs_a > abs_b then
local temp = abs_b / abs_a
temp = temp * temp
return abs_a * sqrt(1 + temp)
elseif abs_b ~= 0 then
local temp = abs_a / abs_b
temp = temp * temp
return abs_b * sqrt(1 + temp)
end
return 0
end
local function tqli(d, e, z)
assert(#z.shape == 2)
assert(z.shape[1] == z.shape[2])
local n = z.shape[1]
assert(#d == n)
assert(#e == n)
local eps = 1.2e-7
for i = 2, n do e[i - 1] = e[i] end
e[n] = 0
for l = 1, n do
local iter = 0
local fucky = 0
local m
while true do
m = l
while m <= n - 1 do
local dd = abs(d[m]) + abs(d[m + 1])
if abs(e[m]) + dd == dd then break end
--if abs(e[m]) <= eps * dd then break end
m = m + 1
end
fucky = fucky + 1
if fucky == 100 then print("fucky!"); break end
if fucky == 1000 then error("super fucky!"); break end
--print(("l: %i, m: %i"):format(l - 1, m - 1))
if m == l then break end
iter = iter + 1
if iter >= 32 then error("Too many iterations in tqli") end
local g = (d[l + 1] - d[l]) / (2 * e[l])
local r = pythag(g, 1)
g = d[m] - d[l] + e[l] / (g + r * sign(g))
local s = 1
local c = 1
local p = 0
for i = m - 1, l, -1 do
local f = s * e[i]
local b = c * e[i]
r = pythag(f, g)
e[i + 1] = r
if r == 0 then
d[i + 1] = d[i + 1] - p
e[m] = 0
break
end
s = f / r
c = g / r
g = d[i + 1] - p
r = (d[i] - g) * s + 2 * c * b
p = s * r
d[i + 1] = g + p
g = c * r - b
for k = 1, n do
if true then
local ind = (i - 1) * n + k
f = z[ind + n]
z[ind + n] = s * z[ind] + c * f
z[ind] = c * z[ind] - s * f
else
local ind = (k - 1) * n + i
f = z[ind + 1]
z[ind + 1] = s * z[ind] + c * f
z[ind] = c * z[ind] - s * f
end
end
end
if r == 0 and i >= l then
-- continue
else
d[l] = d[l] - p
e[l] = g
e[m] = 0
end
end
end
end
--[=[
local A = {
4, 1, -2, 2,
1, 2, 0, 1,
-2, 0, 3, -2,
2, 1, -2, -1,
}
reshape(A, 4, 4)
local d, e = tred2(A)
--[[
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
--]]
--[[
{
0.248069, 0.744208, 0.620174, 0.000000,
0.702863, -0.578829, 0.413449, 0.000000,
0.666667, 0.333333, -0.666667, 0.000000,
0.000000, 0.000000, 0.000000, 1.000000,
}
{ 2.261538, 1.182906, 5.555556, -1.000000,}
{ 0.000000, -0.092308, 0.895806, 3.000000,}
--]]
--A = nn.transpose(A)
tqli(d, e, A)
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
local D = zeros{4, 4}
for i = 1, 4 do
D[(i - 1) * 4 + i] = math.exp(d[i])
end
local out = nn.dot(nn.transpose(A), D)
out = nn.dot(out, A)
print(nn.pp(out))
--[[
{
703.414032, 1410.125991, 1478.990752, -43.126976,
-1205.204565, 1573.963121, -940.902581, 319.676625,
14.433478, 3.884400, -10.434671, 6.841011,
3.529384, 3.087068, -5.116236, -17.850223,
}
{ 2.273819, 1.072834, 6.818268, -2.164921,}
{ -0.000000, 0.000000, 0.000000, 0.000000,}
--]]
--]=]
return {
tred2=tred2,
tqli=tqli,
}

View File

@ -28,15 +28,24 @@ end
-- i'm just copying settings from hardmaru's simple_es_example.ipynb. -- i'm just copying settings from hardmaru's simple_es_example.ipynb.
local iterations = 3000 --4000 local iterations = 3000 --4000
local dims = 100
local popsize = dims + 1 local dims, popsize
if false then
dims = 100
popsize = dims + 1
else
dims = 30
popsize = 99
end
local sigma_init = 0.5 local sigma_init = 0.5
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init) --local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
local es = snes.Snes(dims, popsize, 0.1, sigma_init) local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true) --local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true) --local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5) --local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
es.min_refresh = 0.5 -- FIXME: needs a better interface.
es.min_refresh = 1.0 -- FIXME: needs a better interface.
if typeof(es) == xnes.Xnes if typeof(es) == xnes.Xnes
or typeof(es) == snes.Snes or typeof(es) == snes.Snes
@ -55,9 +64,9 @@ then
local util = require "util" local util = require "util"
util.normalize_sums(es.utility) util.normalize_sums(es.utility)
es.param_rate = 0.39 es.param_rate = 0.2 --39
es.sigma_rate = 0.39 es.sigma_rate = 0.05 --39
es.covar_rate = 0.39 es.covar_rate = 0.1 --39
es.adaptive = false es.adaptive = false
end end

81
expm.lua Normal file
View File

@ -0,0 +1,81 @@
-- here's a really awful way of computing the matrix exponential.
-- we employ the QR algorithm to find eigenpairs of a given symmetric matrix,
-- then run the ordinary exponent function over the eigenvalues.
-- this only works for symmetric matrices!
local nn = require "nn"
local qr = require "qr"
local util = require "util"
local copy = util.copy
local dot = nn.dot
local exp = math.exp
local reshape = nn.reshape
local transpose = nn.transpose
local zeros = nn.zeros
local function expm(mat)
assert(#mat.shape == 2)
assert(mat.shape[1] == mat.shape[2], "expm input must be square")
--assert(stuff(mat), "expm input must be symmetrical")
local dims = mat.shape[1]
local vec = zeros(mat.shape)
for i = 1, dims do
local ind = (i - 1) * dims + i -- diagonal
vec[ind] = 1
end
local diag = mat
for i = 1, 10 do
local q, r = qr(diag)
vec = dot(vec, q)
diag = dot(r, q)
end
for y = 1, dims do
for x = 1, dims do
local ind = (y - 1) * dims + x
if x == y then
diag[ind] = exp(diag[ind])
else
diag[ind] = 0
end
end
end
return dot(dot(vec, diag), transpose(vec))
end
local eig = require "eig"
local tred2 = eig.tred2
local tqli = eig.tqli
local function expm2(mat)
assert(#mat.shape == 2)
assert(mat.shape[1] == mat.shape[2])
local dims = mat.shape[1]
-- new version that computes much better (faster?) eigenpairs
local vec = copy(mat)
local d, e = tred2(vec)
tqli(d, e, vec)
local diag = {}
for y = 1, dims do
for x = 1, dims do
local ind = (y - 1) * dims + x
if x == y then
diag[ind] = exp(d[x])
else
diag[ind] = 0
end
end
end
reshape(diag, dims, dims)
return dot(dot(transpose(vec), diag), vec)
end
return expm2

31
expm_test.lua Normal file
View File

@ -0,0 +1,31 @@
local globalize = require "strict"
local nn = require "nn"
local expm = require "expm"
local sqrt = math.sqrt
local dims = 8
local n = dims * dims
local M = {}
local coeff = dims * dims * sqrt(dims)
for i = 1, n do M[i] = i / coeff end
M.shape = {dims, dims}
M = nn.dot(M, nn.transpose(M))
print(nn.pp(M, "%9.6f"))
local exp_M = expm(M)
print(nn.pp(exp_M, "%9.6f"))
local binser = require "binser"
--local dat = binser.s(exp_M)
binser.writeFile("expm.bin", exp_M)
local res, len = binser.readFile("expm.bin")
assert(len == 1)
exp_M = res[1]
print(nn.pp(exp_M, "%9.6f"))

View File

@ -284,7 +284,7 @@ local function learn_from_epoch()
end end
local step local step
if cfg.es == 'ars' and cfg.ars_lips then if cfg.es == 'ars' then --and cfg.ars_lips then
step = es:tell(trial_rewards, current_cost) step = es:tell(trial_rewards, current_cost)
else else
step = es:tell(trial_rewards) step = es:tell(trial_rewards)

2
nn.lua
View File

@ -119,6 +119,8 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
di = di or 1 di = di or 1
depth = depth or 0 depth = depth or 0
if t == nil then return "nil" end
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
local dim = t.shape[di] local dim = t.shape[di]

7
pp_test.lua Normal file
View File

@ -0,0 +1,7 @@
local nn = require "nn"
a = {0,1,2,3,4,5,6,7}
print(nn.pp(a, "%9.4f"))
print(nn.pp(nn.reshape(a, 4, 2), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 4), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 2, 2), "%9.4f"))

View File

@ -412,6 +412,103 @@ make_preset{
beta = 1.0, beta = 1.0,
} }
make_preset{
name = 'xnes-recon', -- recon-sidered
-- parent = 'big-scroll-hidden',
parent = 'big-scroll-reduced',
es = 'xnes',
-- embed = false,
-- reduce_tiles = 5,
-- hidden_size = 54,
epoch_trials = 20,
epoch_top_trials = 20,
negate_trials = true,
deterministic = true,
deviation = 0.1,
param_decay = 0.01,
param_rate = 0.2,
sigma_rate = 0.05,
covar_rate = 0.1,
}
make_preset{
name = 'arse',
parent = 'big-scroll-reduced',
deterministic = true,
es = 'ars',
epoch_trials = 5,
epoch_top_trials = 9999,
deviation = 1.0,
param_rate = 1.0,
beta = 1.0, -- fix the default.
beta = 5.0, -- oops, i had a dumb bug.
param_decay = 0.01,
momentum = 0.0,
}
make_preset{
name = 'arse2',
parent = 'arse',
beta = 5.0, -- oops, i had a dumb bug.
param_decay = 0.001,
embed = false,
}
make_preset{
name = 'arse3',
parent = 'arse2',
epoch_trials = 10,
-- note: also disabled sum of squares in ars.lua
beta = 1.0,
--deviation = 0.5, -- after 500 epochs
deviation = 0.7071, -- after 500+521 epochs
}
make_preset{
name = 'arse4',
parent = 'arse3',
hidden = true,
hidden_size = 68,
}
make_preset{
name = 'arse5',
parent = 'arse3',
-- sum of squares still disabled. i probably won't re-enable it really.
--deviation = 1.0,
deviation = 1.414, -- after 80+360+790 epochs
momentum = 0.8, -- neumann momentum, peaks at this value
--param_decay = 0.003, -- after 80 epochs
--param_decay = 0.01, -- after 80+360 epochs
param_decay = 0.0, -- after 80+360+790 epochs
}
make_preset{
name = 'arse6',
parent = 'arse3',
epoch_trials = 20,
param_rate = 0.25,
deviation = 0.1, -- maybe try 0.15
momentum = 0.9, -- maybe try 0.8
param_decay = 0.0,
}
-- end of new stuff -- end of new stuff
make_preset{ make_preset{

8
qr.lua
View File

@ -1,9 +1,11 @@
local min = math.min
local sqrt = math.sqrt
local nn = require "nn" local nn = require "nn"
local assert = assert
local dot = nn.dot local dot = nn.dot
local ipairs = ipairs
local min = math.min
local reshape = nn.reshape local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose local transpose = nn.transpose
local zeros = nn.zeros local zeros = nn.zeros

View File

@ -43,7 +43,7 @@ local function qr(a)
local den = 0 local den = 0
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
for k = j0, j1 do den = den + q[k] * q[k] end for k = j0, j1 do den = den + q[k] * q[k] end
print(num, den) --print(num, den)
if den == 0 then den = 1 end -- TODO: should probably just error. if den == 0 then den = 1 end -- TODO: should probably just error.
local x = num / den local x = num / den

79
qr3.lua Normal file
View File

@ -0,0 +1,79 @@
-- gram-schmidt process
local nn = require "nn"
local dot = nn.dot
local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose
local zeros = nn.zeros
local function qr(mat)
assert(#mat.shape == 2)
local v = transpose(mat)
local rows = v.shape[1]
local cols = v.shape[2]
local u = zeros(v.shape)
local w = {}
local y = 1
local function load_row()
local start = (y - 1) * cols
for x = 1, cols do w[x] = v[start + x] end
end
local function push_row()
local sum = 0
for _, value in ipairs(w) do sum = sum + value * value end
local norm = sqrt(sum)
local start = (y - 1) * cols
for x, value in ipairs(w) do u[start + x] = value / norm end
y = y + 1
end
load_row()
push_row()
local sums = {}
for i = 2, rows do
load_row()
for x = 1, cols do sums[x] = 0 end
for j = 1, i - 1 do
local start = (j - 1) * cols
local dotted = 0
for x, value in ipairs(w) do
dotted = dotted + value * u[start + x]
end
--[[
local scale = 0
for x = 1, cols do
local value = u[start + x]
scale = scale + value * value
end
print(scale)
dotted = dotted / scale
--]]
for x, value in ipairs(sums) do
sums[x] = value + dotted * u[start + x]
end
end
for x, value in ipairs(w) do w[x] = value - sums[x] end
push_row()
end
return transpose(u), dot(u, mat)
end
return qr

View File

@ -2,6 +2,7 @@ local globalize = require "strict"
local nn = require "nn" local nn = require "nn"
local qr = require "qr" local qr = require "qr"
local qr2 = require "qr2" local qr2 = require "qr2"
local qr3 = require "qr3"
local A local A
@ -24,7 +25,7 @@ elseif false then
} }
A = nn.reshape(A, 5, 3) A = nn.reshape(A, 5, 3)
else elseif false then
A = { A = {
1, 2, 0, 1, 2, 0,
2, 3, 1, 2, 3, 1,
@ -44,13 +45,27 @@ else
A = nn.reshape(A, 5, 3) A = nn.reshape(A, 5, 3)
--A = nn.transpose(A) --A = nn.transpose(A)
else
A = {
1, 2, -3,
2, 4, 5,
-3, 5, 6,
}
A = nn.reshape(A, 3, 3)
end end
print("A") print("A")
print(nn.pp(A, "%9.4f")) print(nn.pp(A, "%9.4f"))
print() print()
local Q, R = qr2(A) local Q, R = qr(A)
print("Q (reference)")
print(nn.pp(Q, "%9.4f"))
print()
local Q, R = qr3(A)
print("Q") print("Q")
print(nn.pp(Q, "%9.4f")) print(nn.pp(Q, "%9.4f"))

7
sign_test.lua Normal file
View File

@ -0,0 +1,7 @@
local util = require("util")
print(util.sign(-1))
print(util.sign(0))
print(util.sign(1))
print(util.sign(-0.1))
print(util.sign(0.0))
print(util.sign(0.1))

View File

@ -209,6 +209,11 @@ local function cdf(x)
-- i don't remember where this is from. -- i don't remember where this is from.
local sign = x >= 0 and 1 or -1 local sign = x >= 0 and 1 or -1
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x))) return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
-- more accurate (via GELU paper, might be lifted from elsewhere):
--local const = sqrt(2 / pi)
--return 0.5 * (1 + tanh(const * (1 + 0.044715 * x * x) * x))
--return 0.5 * (1 + tanh(0.7978845608 * (1 + 0.044715 * x * x) * x))
end end
local function fitness_shaping(rewards) local function fitness_shaping(rewards)
@ -283,7 +288,7 @@ local function expect_cossim(n)
if n >= 128000 then if n >= 128000 then
return 1 / sqrt(pi / 2 * n + 1) return 1 / sqrt(pi / 2 * n + 1)
elseif n >= 80 then elseif n >= 80 then
poly = (2.4674010 * n + -2.4673232) * n + 1.2274046 poly = (2.4674010 * n - 2.4673232) * n + 1.2274046
return 1 / sqrt(sqrt(poly)) return 1 / sqrt(sqrt(poly))
end end
-- fall-through when it's faster just to compute iteratively. -- fall-through when it's faster just to compute iteratively.

View File

@ -16,10 +16,13 @@ local unpack = table.unpack or unpack
local Base = require "Base" local Base = require "Base"
local nn = require "nn" local nn = require "nn"
local dot = nn.dot
local dot_mv = nn.dot_mv local dot_mv = nn.dot_mv
local normal = nn.normal local normal = nn.normal
local zeros = nn.zeros local zeros = nn.zeros
local expm = require "expm"
local util = require "util" local util = require "util"
local argsort = util.argsort local argsort = util.argsort
@ -35,22 +38,6 @@ local function make_utility(popsize, out)
return utility return utility
end end
local function make_covars(dims, sigma, out)
local covars = out or zeros{dims, dims}
local c = sigma / dims
-- simplified form of the determinant of the matrix we're going to create:
local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1)
-- multiplying by this constant makes the determinant 1:
local m = pow(1 / det, 1 / dims)
local filler = c * m
for i=1, #covars do covars[i] = filler end
-- diagonals:
for i=1, dims do covars[i + dims * (i - 1)] = m end
return covars
end
function Xnes:init(dims, popsize, base_rate, sigma, antithetic) function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
-- heuristic borrowed from CMA-ES: -- heuristic borrowed from CMA-ES:
self.dims = dims self.dims = dims
@ -67,9 +54,11 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
self.utility = make_utility(self.popsize) self.utility = make_utility(self.popsize)
self.mean = zeros{dims} self.mean = zeros{dims}
-- note: this is technically the co-standard-deviation. self.covars = zeros{dims, dims}
-- you can imagine the "s" standing for "sqrt" if you like. for i=1, dims do
self.covars = make_covars(self.dims, self.sigma, self.covars) local ind = (i - 1) * dims + i -- diagonal
self.covars[ind] = 1
end
self.evals = 0 self.evals = 0
end end
@ -202,9 +191,12 @@ function Xnes:tell(scored, noise)
end end
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma) self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
for i, v in ipairs(self.covars) do
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i]) -- re-use g_covars from before just to scale it.
for i, v in ipairs(g_covars) do
g_covars[i] = self.covar_rate * 0.5 * v
end end
self.covars = dot(self.covars, expm(g_covars))
-- bookkeeping: -- bookkeeping:
self.noise = nil self.noise = nil
@ -214,7 +206,6 @@ end
return { return {
make_utility = make_utility, make_utility = make_utility,
make_covars = make_covars,
Xnes = Xnes, Xnes = Xnes,
} }