temp 4
This commit is contained in:
parent
08f476e6ac
commit
7462e69c61
18 changed files with 1467 additions and 77 deletions
97
ars.lua
97
ars.lua
|
@ -1,14 +1,11 @@
|
|||
-- Augmented Random Search
|
||||
-- https://arxiv.org/abs/1803.07055
|
||||
-- with some tweaks (lipschitz stuff) by myself.
|
||||
-- i also added an option for graycode sampling,
|
||||
-- borrowed from a (1+1) optimizer,
|
||||
-- but i haven't yet found a case where it performs better.
|
||||
|
||||
local abs = math.abs
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
|
@ -26,6 +23,7 @@ local zeros = nn.zeros
|
|||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
|
||||
local normalize_sums = util.normalize_sums
|
||||
local sign = util.sign
|
||||
|
||||
|
@ -56,23 +54,6 @@ local function collect_best_indices(scored, top, antithetic)
|
|||
return indices
|
||||
end
|
||||
|
||||
local function kinda_lipschitz(dir, pos, neg, mid)
|
||||
--[[
|
||||
-- based on the local lipschitz constant of a quadratic curve
|
||||
-- drawn through the 3 sampled points: positive, negative, and unperturbed.
|
||||
-- it kinda helps? there's probably a better function to base it around.
|
||||
local _, dev = calc_mean_dev(dir)
|
||||
local c0 = neg - mid
|
||||
local c1 = pos - mid
|
||||
local l0 = abs(3 * c1 + c0)
|
||||
local l1 = abs(c1 + 3 * c0)
|
||||
return max(l0, l1) / (2 * dev)
|
||||
--]]
|
||||
-- based on a piece-wise linear function of the 3 sampled points.
|
||||
local _, dev = calc_mean_dev(dir)
|
||||
return max(abs(pos - mid), abs(neg - mid)) / dev
|
||||
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
||||
momentum, beta)
|
||||
self.dims = dims
|
||||
|
@ -110,7 +91,7 @@ function Ars:decay(param_decay, sigma_decay)
|
|||
end
|
||||
end
|
||||
|
||||
function Ars:ask(graycode)
|
||||
function Ars:ask()
|
||||
local asked = {}
|
||||
local noise = {}
|
||||
|
||||
|
@ -126,17 +107,8 @@ function Ars:ask(graycode)
|
|||
noisy[j] = -v
|
||||
end
|
||||
else
|
||||
if graycode ~= nil then
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = exp(-precision * uniform())
|
||||
end
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
|
||||
end
|
||||
else
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = self.sigma * normal()
|
||||
end
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = self.sigma * normal()
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -150,9 +122,8 @@ function Ars:ask(graycode)
|
|||
end
|
||||
|
||||
function Ars:tell(scored, unperturbed_score)
|
||||
local use_lips = unperturbed_score ~= nil and self.antithetic
|
||||
self.evals = self.evals + #scored
|
||||
if use_lips then self.evals = self.evals + 1 end
|
||||
if unperturbed_score ~= nil then self.evals = self.evals + 1 end
|
||||
|
||||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||
|
||||
|
@ -173,7 +144,16 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
|
||||
local step = zeros(self.dims)
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
|
||||
local _, reward_dev
|
||||
if unperturbed_score ~= nil then
|
||||
-- new stuff:
|
||||
insert(top_rewards, unperturbed_score)
|
||||
_, reward_dev = calc_mean_dev_unbiased(top_rewards)
|
||||
remove(top_rewards)
|
||||
else
|
||||
_, reward_dev = calc_mean_dev(top_rewards)
|
||||
end
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
if self.antithetic then
|
||||
|
@ -183,12 +163,15 @@ function Ars:tell(scored, unperturbed_score)
|
|||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[ind * 2 - 1]
|
||||
if use_lips then
|
||||
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
||||
reward = reward / lips / self.sigma
|
||||
else
|
||||
reward = reward / reward_dev
|
||||
reward = reward / reward_dev
|
||||
|
||||
--[[ new stuff:
|
||||
local sum_of_squares = 0
|
||||
for _, v in ipairs(noisy) do
|
||||
sum_of_squares = sum_of_squares + v * v
|
||||
end
|
||||
reward = reward / sqrt(sum_of_squares)
|
||||
-]]
|
||||
|
||||
local scale = reward / self.poptop * self.beta / 2
|
||||
for j, v in ipairs(noisy) do
|
||||
|
@ -196,7 +179,9 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
else
|
||||
error("TODO: update with sum of squares stuff")
|
||||
for i, ind in ipairs(indices) do
|
||||
local reward = top_rewards[i] / reward_dev
|
||||
if reward ~= 0 then
|
||||
|
@ -210,6 +195,7 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
end
|
||||
|
||||
--[[ powersign momentum
|
||||
if self.momentum > 0 then
|
||||
for i, v in ipairs(step) do
|
||||
self.accum[i] = self.momentum * self.accum[i] + v
|
||||
|
@ -220,6 +206,35 @@ function Ars:tell(scored, unperturbed_score)
|
|||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
--]]
|
||||
|
||||
-- neumann momentum
|
||||
if self.momentum > 0 then
|
||||
local count = self.count or 0
|
||||
local period = 10
|
||||
local mu = 1 - 1 / (1 + count % period)
|
||||
mu = self.momentum / (1 - 1 / period) * mu
|
||||
self.count = count + 1
|
||||
-- mu is intentionally 0 for one iteration.
|
||||
|
||||
-- make learning rate invariant to sigma.
|
||||
for i, v in ipairs(step) do
|
||||
step[i] = v / self.sigma
|
||||
end
|
||||
|
||||
-- update neumann iterate.
|
||||
for i, v in ipairs(self.accum) do
|
||||
self.accum[i] = mu * v - self.param_rate * step[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v - mu * self.accum[i] + self.param_rate * step[i]
|
||||
end
|
||||
else
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
end
|
||||
|
||||
self.noise = nil
|
||||
|
||||
|
|
749
binser.lua
Normal file
749
binser.lua
Normal file
|
@ -0,0 +1,749 @@
|
|||
-- binser.lua
|
||||
|
||||
--[[
|
||||
Copyright (c) 2016 Calvin Rose
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
]]
|
||||
|
||||
local assert = assert
|
||||
local error = error
|
||||
local select = select
|
||||
local pairs = pairs
|
||||
local getmetatable = getmetatable
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local loadstring = loadstring or load
|
||||
local concat = table.concat
|
||||
local char = string.char
|
||||
local byte = string.byte
|
||||
local format = string.format
|
||||
local sub = string.sub
|
||||
local dump = string.dump
|
||||
local floor = math.floor
|
||||
local frexp = math.frexp
|
||||
local unpack = unpack or table.unpack
|
||||
|
||||
-- Lua 5.3 frexp polyfill
|
||||
-- From https://github.com/excessive/cpml/blob/master/modules/utils.lua
|
||||
if not frexp then
|
||||
local log, abs, floor = math.log, math.abs, math.floor
|
||||
local log2 = log(2)
|
||||
frexp = function(x)
|
||||
if x == 0 then return 0, 0 end
|
||||
local e = floor(log(abs(x)) / log2 + 1)
|
||||
return x / 2 ^ e, e
|
||||
end
|
||||
end
|
||||
|
||||
local function pack(...)
|
||||
return {...}, select("#", ...)
|
||||
end
|
||||
|
||||
local function not_array_index(x, len)
|
||||
return type(x) ~= "number" or x < 1 or x > len or x ~= floor(x)
|
||||
end
|
||||
|
||||
local function type_check(x, tp, name)
|
||||
assert(type(x) == tp,
|
||||
format("Expected parameter %q to be of type %q.", name, tp))
|
||||
end
|
||||
|
||||
local bigIntSupport = false
|
||||
local isInteger
|
||||
if math.type then -- Detect Lua 5.3
|
||||
local mtype = math.type
|
||||
bigIntSupport = loadstring[[
|
||||
local char = string.char
|
||||
return function(n)
|
||||
local nn = n < 0 and -(n + 1) or n
|
||||
local b1 = nn // 0x100000000000000
|
||||
local b2 = nn // 0x1000000000000 % 0x100
|
||||
local b3 = nn // 0x10000000000 % 0x100
|
||||
local b4 = nn // 0x100000000 % 0x100
|
||||
local b5 = nn // 0x1000000 % 0x100
|
||||
local b6 = nn // 0x10000 % 0x100
|
||||
local b7 = nn // 0x100 % 0x100
|
||||
local b8 = nn % 0x100
|
||||
if n < 0 then
|
||||
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
|
||||
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
|
||||
end
|
||||
return char(212, b1, b2, b3, b4, b5, b6, b7, b8)
|
||||
end]]()
|
||||
isInteger = function(x)
|
||||
return mtype(x) == 'integer'
|
||||
end
|
||||
else
|
||||
isInteger = function(x)
|
||||
return floor(x) == x
|
||||
end
|
||||
end
|
||||
|
||||
-- Copyright (C) 2012-2015 Francois Perrad.
|
||||
-- number serialization code modified from https://github.com/fperrad/lua-MessagePack
|
||||
-- Encode a number as a big-endian ieee-754 double, big-endian signed 64 bit integer, or a small integer
|
||||
local function number_to_str(n)
|
||||
if isInteger(n) then -- int
|
||||
if n <= 100 and n >= -27 then -- 1 byte, 7 bits of data
|
||||
return char(n + 27)
|
||||
elseif n <= 8191 and n >= -8192 then -- 2 bytes, 14 bits of data
|
||||
n = n + 8192
|
||||
return char(128 + (floor(n / 0x100) % 0x100), n % 0x100)
|
||||
elseif bigIntSupport then
|
||||
return bigIntSupport(n)
|
||||
end
|
||||
end
|
||||
local sign = 0
|
||||
if n < 0.0 then
|
||||
sign = 0x80
|
||||
n = -n
|
||||
end
|
||||
local m, e = frexp(n) -- mantissa, exponent
|
||||
if m ~= m then
|
||||
return char(203, 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
|
||||
elseif m == 1/0 then
|
||||
if sign == 0 then
|
||||
return char(203, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
|
||||
else
|
||||
return char(203, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
|
||||
end
|
||||
end
|
||||
e = e + 0x3FE
|
||||
if e < 1 then -- denormalized numbers
|
||||
m = m * 2 ^ (52 + e)
|
||||
e = 0
|
||||
else
|
||||
m = (m * 2 - 1) * 2 ^ 52
|
||||
end
|
||||
return char(203,
|
||||
sign + floor(e / 0x10),
|
||||
(e % 0x10) * 0x10 + floor(m / 0x1000000000000),
|
||||
floor(m / 0x10000000000) % 0x100,
|
||||
floor(m / 0x100000000) % 0x100,
|
||||
floor(m / 0x1000000) % 0x100,
|
||||
floor(m / 0x10000) % 0x100,
|
||||
floor(m / 0x100) % 0x100,
|
||||
m % 0x100)
|
||||
end
|
||||
|
||||
-- Copyright (C) 2012-2015 Francois Perrad.
|
||||
-- number deserialization code also modified from https://github.com/fperrad/lua-MessagePack
|
||||
local function number_from_str(str, index)
|
||||
local b = byte(str, index)
|
||||
if not b then error("Expected more bytes of input.") end
|
||||
if b < 128 then
|
||||
return b - 27, index + 1
|
||||
elseif b < 192 then
|
||||
local b2 = byte(str, index + 1)
|
||||
if not b2 then error("Expected more bytes of input.") end
|
||||
return b2 + 0x100 * (b - 128) - 8192, index + 2
|
||||
end
|
||||
local b1, b2, b3, b4, b5, b6, b7, b8 = byte(str, index + 1, index + 8)
|
||||
if (not b1) or (not b2) or (not b3) or (not b4) or
|
||||
(not b5) or (not b6) or (not b7) or (not b8) then
|
||||
error("Expected more bytes of input.")
|
||||
end
|
||||
if b == 212 then
|
||||
local flip = b1 >= 128
|
||||
if flip then -- negative
|
||||
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
|
||||
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
|
||||
end
|
||||
local n = ((((((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4) *
|
||||
0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
|
||||
if flip then
|
||||
return (-n) - 1, index + 9
|
||||
else
|
||||
return n, index + 9
|
||||
end
|
||||
end
|
||||
if b ~= 203 then
|
||||
error("Expected number")
|
||||
end
|
||||
local sign = b1 > 0x7F and -1 or 1
|
||||
local e = (b1 % 0x80) * 0x10 + floor(b2 / 0x10)
|
||||
local m = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
|
||||
local n
|
||||
if e == 0 then
|
||||
if m == 0 then
|
||||
n = sign * 0.0
|
||||
else
|
||||
n = sign * (m / 2 ^ 52) * 2 ^ -1022
|
||||
end
|
||||
elseif e == 0x7FF then
|
||||
if m == 0 then
|
||||
n = sign * (1/0)
|
||||
else
|
||||
n = 0.0/0.0
|
||||
end
|
||||
else
|
||||
n = sign * (1.0 + m / 2 ^ 52) * 2 ^ (e - 0x3FF)
|
||||
end
|
||||
return n, index + 9
|
||||
end
|
||||
|
||||
|
||||
local function newbinser()
|
||||
|
||||
-- unique table key for getting next value
|
||||
local NEXT = {}
|
||||
local CTORSTACK = {}
|
||||
|
||||
-- NIL = 202
|
||||
-- FLOAT = 203
|
||||
-- TRUE = 204
|
||||
-- FALSE = 205
|
||||
-- STRING = 206
|
||||
-- TABLE = 207
|
||||
-- REFERENCE = 208
|
||||
-- CONSTRUCTOR = 209
|
||||
-- FUNCTION = 210
|
||||
-- RESOURCE = 211
|
||||
-- INT64 = 212
|
||||
-- TABLE WITH META = 213
|
||||
|
||||
local mts = {}
|
||||
local ids = {}
|
||||
local serializers = {}
|
||||
local deserializers = {}
|
||||
local resources = {}
|
||||
local resources_by_name = {}
|
||||
local types = {}
|
||||
|
||||
types["nil"] = function(x, visited, accum)
|
||||
accum[#accum + 1] = "\202"
|
||||
end
|
||||
|
||||
function types.number(x, visited, accum)
|
||||
accum[#accum + 1] = number_to_str(x)
|
||||
end
|
||||
|
||||
function types.boolean(x, visited, accum)
|
||||
accum[#accum + 1] = x and "\204" or "\205"
|
||||
end
|
||||
|
||||
function types.string(x, visited, accum)
|
||||
local alen = #accum
|
||||
if visited[x] then
|
||||
accum[alen + 1] = "\208"
|
||||
accum[alen + 2] = number_to_str(visited[x])
|
||||
else
|
||||
visited[x] = visited[NEXT]
|
||||
visited[NEXT] = visited[NEXT] + 1
|
||||
accum[alen + 1] = "\206"
|
||||
accum[alen + 2] = number_to_str(#x)
|
||||
accum[alen + 3] = x
|
||||
end
|
||||
end
|
||||
|
||||
local function check_custom_type(x, visited, accum)
|
||||
local res = resources[x]
|
||||
if res then
|
||||
accum[#accum + 1] = "\211"
|
||||
types[type(res)](res, visited, accum)
|
||||
return true
|
||||
end
|
||||
local mt = getmetatable(x)
|
||||
local id = mt and ids[mt]
|
||||
if id then
|
||||
local constructing = visited[CTORSTACK]
|
||||
if constructing[x] then
|
||||
error("Infinite loop in constructor.")
|
||||
end
|
||||
constructing[x] = true
|
||||
accum[#accum + 1] = "\209"
|
||||
types[type(id)](id, visited, accum)
|
||||
local args, len = pack(serializers[id](x))
|
||||
accum[#accum + 1] = number_to_str(len)
|
||||
for i = 1, len do
|
||||
local arg = args[i]
|
||||
types[type(arg)](arg, visited, accum)
|
||||
end
|
||||
visited[x] = visited[NEXT]
|
||||
visited[NEXT] = visited[NEXT] + 1
|
||||
-- We finished constructing
|
||||
constructing[x] = nil
|
||||
return true
|
||||
end
|
||||
end
|
||||
|
||||
function types.userdata(x, visited, accum)
|
||||
if visited[x] then
|
||||
accum[#accum + 1] = "\208"
|
||||
accum[#accum + 1] = number_to_str(visited[x])
|
||||
else
|
||||
if check_custom_type(x, visited, accum) then return end
|
||||
error("Cannot serialize this userdata.")
|
||||
end
|
||||
end
|
||||
|
||||
function types.table(x, visited, accum)
|
||||
if visited[x] then
|
||||
accum[#accum + 1] = "\208"
|
||||
accum[#accum + 1] = number_to_str(visited[x])
|
||||
else
|
||||
if check_custom_type(x, visited, accum) then return end
|
||||
visited[x] = visited[NEXT]
|
||||
visited[NEXT] = visited[NEXT] + 1
|
||||
local xlen = #x
|
||||
local mt = getmetatable(x)
|
||||
if mt then
|
||||
accum[#accum + 1] = "\213"
|
||||
types.table(mt, visited, accum)
|
||||
else
|
||||
accum[#accum + 1] = "\207"
|
||||
end
|
||||
accum[#accum + 1] = number_to_str(xlen)
|
||||
for i = 1, xlen do
|
||||
local v = x[i]
|
||||
types[type(v)](v, visited, accum)
|
||||
end
|
||||
local key_count = 0
|
||||
for k in pairs(x) do
|
||||
if not_array_index(k, xlen) then
|
||||
key_count = key_count + 1
|
||||
end
|
||||
end
|
||||
accum[#accum + 1] = number_to_str(key_count)
|
||||
for k, v in pairs(x) do
|
||||
if not_array_index(k, xlen) then
|
||||
types[type(k)](k, visited, accum)
|
||||
types[type(v)](v, visited, accum)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
types["function"] = function(x, visited, accum)
|
||||
if visited[x] then
|
||||
accum[#accum + 1] = "\208"
|
||||
accum[#accum + 1] = number_to_str(visited[x])
|
||||
else
|
||||
if check_custom_type(x, visited, accum) then return end
|
||||
visited[x] = visited[NEXT]
|
||||
visited[NEXT] = visited[NEXT] + 1
|
||||
local str = dump(x)
|
||||
accum[#accum + 1] = "\210"
|
||||
accum[#accum + 1] = number_to_str(#str)
|
||||
accum[#accum + 1] = str
|
||||
end
|
||||
end
|
||||
|
||||
types.cdata = function(x, visited, accum)
|
||||
if visited[x] then
|
||||
accum[#accum + 1] = "\208"
|
||||
accum[#accum + 1] = number_to_str(visited[x])
|
||||
else
|
||||
if check_custom_type(x, visited, #accum) then return end
|
||||
error("Cannot serialize this cdata.")
|
||||
end
|
||||
end
|
||||
|
||||
types.thread = function() error("Cannot serialize threads.") end
|
||||
|
||||
local function deserialize_value(str, index, visited)
|
||||
local t = byte(str, index)
|
||||
if not t then return nil, index end
|
||||
if t < 128 then
|
||||
return t - 27, index + 1
|
||||
elseif t < 192 then
|
||||
local b2 = byte(str, index + 1)
|
||||
if not b2 then error("Expected more bytes of input.") end
|
||||
return b2 + 0x100 * (t - 128) - 8192, index + 2
|
||||
elseif t == 202 then
|
||||
return nil, index + 1
|
||||
elseif t == 203 or t == 212 then
|
||||
return number_from_str(str, index)
|
||||
elseif t == 204 then
|
||||
return true, index + 1
|
||||
elseif t == 205 then
|
||||
return false, index + 1
|
||||
elseif t == 206 then
|
||||
local length, dataindex = number_from_str(str, index + 1)
|
||||
local nextindex = dataindex + length
|
||||
if not (length >= 0) then error("Bad string length") end
|
||||
if #str < nextindex - 1 then error("Expected more bytes of string") end
|
||||
local substr = sub(str, dataindex, nextindex - 1)
|
||||
visited[#visited + 1] = substr
|
||||
return substr, nextindex
|
||||
elseif t == 207 or t == 213 then
|
||||
local mt, count, nextindex
|
||||
local ret = {}
|
||||
visited[#visited + 1] = ret
|
||||
nextindex = index + 1
|
||||
if t == 213 then
|
||||
mt, nextindex = deserialize_value(str, nextindex, visited)
|
||||
if type(mt) ~= "table" then error("Expected table metatable") end
|
||||
end
|
||||
count, nextindex = number_from_str(str, nextindex)
|
||||
for i = 1, count do
|
||||
local oldindex = nextindex
|
||||
ret[i], nextindex = deserialize_value(str, nextindex, visited)
|
||||
if nextindex == oldindex then error("Expected more bytes of input.") end
|
||||
end
|
||||
count, nextindex = number_from_str(str, nextindex)
|
||||
for i = 1, count do
|
||||
local k, v
|
||||
local oldindex = nextindex
|
||||
k, nextindex = deserialize_value(str, nextindex, visited)
|
||||
if nextindex == oldindex then error("Expected more bytes of input.") end
|
||||
oldindex = nextindex
|
||||
v, nextindex = deserialize_value(str, nextindex, visited)
|
||||
if nextindex == oldindex then error("Expected more bytes of input.") end
|
||||
if k == nil then error("Can't have nil table keys") end
|
||||
ret[k] = v
|
||||
end
|
||||
if mt then setmetatable(ret, mt) end
|
||||
return ret, nextindex
|
||||
elseif t == 208 then
|
||||
local ref, nextindex = number_from_str(str, index + 1)
|
||||
return visited[ref], nextindex
|
||||
elseif t == 209 then
|
||||
local count
|
||||
local name, nextindex = deserialize_value(str, index + 1, visited)
|
||||
count, nextindex = number_from_str(str, nextindex)
|
||||
local args = {}
|
||||
for i = 1, count do
|
||||
local oldindex = nextindex
|
||||
args[i], nextindex = deserialize_value(str, nextindex, visited)
|
||||
if nextindex == oldindex then error("Expected more bytes of input.") end
|
||||
end
|
||||
if not name or not deserializers[name] then
|
||||
error(("Cannot deserialize class '%s'"):format(tostring(name)))
|
||||
end
|
||||
local ret = deserializers[name](unpack(args))
|
||||
visited[#visited + 1] = ret
|
||||
return ret, nextindex
|
||||
elseif t == 210 then
|
||||
local length, dataindex = number_from_str(str, index + 1)
|
||||
local nextindex = dataindex + length
|
||||
if not (length >= 0) then error("Bad string length") end
|
||||
if #str < nextindex - 1 then error("Expected more bytes of string") end
|
||||
local ret = loadstring(sub(str, dataindex, nextindex - 1))
|
||||
visited[#visited + 1] = ret
|
||||
return ret, nextindex
|
||||
elseif t == 211 then
|
||||
local resname, nextindex = deserialize_value(str, index + 1, visited)
|
||||
if resname == nil then error("Got nil resource name") end
|
||||
local res = resources_by_name[resname]
|
||||
if res == nil then
|
||||
error(("No resources found for name '%s'"):format(tostring(resname)))
|
||||
end
|
||||
return res, nextindex
|
||||
else
|
||||
error("Could not deserialize type byte " .. t .. ".")
|
||||
end
|
||||
end
|
||||
|
||||
local function serialize(...)
|
||||
local visited = {[NEXT] = 1, [CTORSTACK] = {}}
|
||||
local accum = {}
|
||||
for i = 1, select("#", ...) do
|
||||
local x = select(i, ...)
|
||||
types[type(x)](x, visited, accum)
|
||||
end
|
||||
return concat(accum)
|
||||
end
|
||||
|
||||
local function make_file_writer(file)
|
||||
return setmetatable({}, {
|
||||
__newindex = function(_, _, v)
|
||||
file:write(v)
|
||||
end
|
||||
})
|
||||
end
|
||||
|
||||
local function serialize_to_file(path, mode, ...)
|
||||
local file, err = io.open(path, mode)
|
||||
assert(file, err)
|
||||
local visited = {[NEXT] = 1, [CTORSTACK] = {}}
|
||||
local accum = make_file_writer(file)
|
||||
for i = 1, select("#", ...) do
|
||||
local x = select(i, ...)
|
||||
types[type(x)](x, visited, accum)
|
||||
end
|
||||
-- flush the writer
|
||||
file:flush()
|
||||
file:close()
|
||||
end
|
||||
|
||||
local function writeFile(path, ...)
|
||||
return serialize_to_file(path, "wb", ...)
|
||||
end
|
||||
|
||||
local function appendFile(path, ...)
|
||||
return serialize_to_file(path, "ab", ...)
|
||||
end
|
||||
|
||||
local function deserialize(str, index)
|
||||
assert(type(str) == "string", "Expected string to deserialize.")
|
||||
local vals = {}
|
||||
index = index or 1
|
||||
local visited = {}
|
||||
local len = 0
|
||||
local val
|
||||
while true do
|
||||
local nextindex
|
||||
val, nextindex = deserialize_value(str, index, visited)
|
||||
if nextindex > index then
|
||||
len = len + 1
|
||||
vals[len] = val
|
||||
index = nextindex
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
return vals, len
|
||||
end
|
||||
|
||||
local function deserializeN(str, n, index)
|
||||
assert(type(str) == "string", "Expected string to deserialize.")
|
||||
n = n or 1
|
||||
assert(type(n) == "number", "Expected a number for parameter n.")
|
||||
assert(n > 0 and floor(n) == n, "N must be a poitive integer.")
|
||||
local vals = {}
|
||||
index = index or 1
|
||||
local visited = {}
|
||||
local len = 0
|
||||
local val
|
||||
while len < n do
|
||||
local nextindex
|
||||
val, nextindex = deserialize_value(str, index, visited)
|
||||
if nextindex > index then
|
||||
len = len + 1
|
||||
vals[len] = val
|
||||
index = nextindex
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
vals[len + 1] = index
|
||||
return unpack(vals, 1, n + 1)
|
||||
end
|
||||
|
||||
local function readFile(path)
|
||||
local file, err = io.open(path, "rb")
|
||||
assert(file, err)
|
||||
local str = file:read("*all")
|
||||
file:close()
|
||||
return deserialize(str)
|
||||
end
|
||||
|
||||
-- Resources
|
||||
|
||||
local function registerResource(resource, name)
|
||||
type_check(name, "string", "name")
|
||||
assert(not resources[resource],
|
||||
"Resource already registered.")
|
||||
assert(not resources_by_name[name],
|
||||
format("Resource %q already exists.", name))
|
||||
resources_by_name[name] = resource
|
||||
resources[resource] = name
|
||||
return resource
|
||||
end
|
||||
|
||||
local function unregisterResource(name)
|
||||
type_check(name, "string", "name")
|
||||
assert(resources_by_name[name], format("Resource %q does not exist.", name))
|
||||
local resource = resources_by_name[name]
|
||||
resources_by_name[name] = nil
|
||||
resources[resource] = nil
|
||||
return resource
|
||||
end
|
||||
|
||||
-- Templating
|
||||
|
||||
local function normalize_template(template)
|
||||
local ret = {}
|
||||
for i = 1, #template do
|
||||
ret[i] = template[i]
|
||||
end
|
||||
local non_array_part = {}
|
||||
-- The non-array part of the template (nested templates) have to be deterministic, so they are sorted.
|
||||
-- This means that inherently non deterministicly sortable keys (tables, functions) should NOT be used
|
||||
-- in templates. Looking for way around this.
|
||||
for k in pairs(template) do
|
||||
if not_array_index(k, #template) then
|
||||
non_array_part[#non_array_part + 1] = k
|
||||
end
|
||||
end
|
||||
table.sort(non_array_part)
|
||||
for i = 1, #non_array_part do
|
||||
local name = non_array_part[i]
|
||||
ret[#ret + 1] = {name, normalize_template(template[name])}
|
||||
end
|
||||
return ret
|
||||
end
|
||||
|
||||
local function templatepart_serialize(part, argaccum, x, len)
|
||||
local extras = {}
|
||||
local extracount = 0
|
||||
for k, v in pairs(x) do
|
||||
extras[k] = v
|
||||
extracount = extracount + 1
|
||||
end
|
||||
for i = 1, #part do
|
||||
local name
|
||||
if type(part[i]) == "table" then
|
||||
name = part[i][1]
|
||||
len = templatepart_serialize(part[i][2], argaccum, x[name], len)
|
||||
else
|
||||
name = part[i]
|
||||
len = len + 1
|
||||
argaccum[len] = x[part[i]]
|
||||
end
|
||||
if extras[name] ~= nil then
|
||||
extracount = extracount - 1
|
||||
extras[name] = nil
|
||||
end
|
||||
end
|
||||
if extracount > 0 then
|
||||
argaccum[len + 1] = extras
|
||||
else
|
||||
argaccum[len + 1] = nil
|
||||
end
|
||||
return len + 1
|
||||
end
|
||||
|
||||
local function templatepart_deserialize(ret, part, values, vindex)
|
||||
for i = 1, #part do
|
||||
local name = part[i]
|
||||
if type(name) == "table" then
|
||||
local newret = {}
|
||||
ret[name[1]] = newret
|
||||
vindex = templatepart_deserialize(newret, name[2], values, vindex)
|
||||
else
|
||||
ret[name] = values[vindex]
|
||||
vindex = vindex + 1
|
||||
end
|
||||
end
|
||||
local extras = values[vindex]
|
||||
if extras then
|
||||
for k, v in pairs(extras) do
|
||||
ret[k] = v
|
||||
end
|
||||
end
|
||||
return vindex + 1
|
||||
end
|
||||
|
||||
local function template_serializer_and_deserializer(metatable, template)
|
||||
return function(x)
|
||||
local argaccum = {}
|
||||
local len = templatepart_serialize(template, argaccum, x, 0)
|
||||
return unpack(argaccum, 1, len)
|
||||
end, function(...)
|
||||
local ret = {}
|
||||
local args = {...}
|
||||
templatepart_deserialize(ret, template, args, 1)
|
||||
return setmetatable(ret, metatable)
|
||||
end
|
||||
end
|
||||
|
||||
-- Used to serialize classes withh custom serializers and deserializers.
|
||||
-- If no _serialize or _deserialize (or no _template) value is found in the
|
||||
-- metatable, then the metatable is registered as a resources.
|
||||
local function register(metatable, name, serialize, deserialize)
|
||||
if type(metatable) == "table" then
|
||||
name = name or metatable.name
|
||||
serialize = serialize or metatable._serialize
|
||||
deserialize = deserialize or metatable._deserialize
|
||||
if (not serialize) or (not deserialize) then
|
||||
if metatable._template then
|
||||
-- Register as template
|
||||
local t = normalize_template(metatable._template)
|
||||
serialize, deserialize = template_serializer_and_deserializer(metatable, t)
|
||||
else
|
||||
-- Register the metatable as a resource. This is semantically
|
||||
-- similar and more flexible (handles cycles).
|
||||
registerResource(metatable, name)
|
||||
return
|
||||
end
|
||||
end
|
||||
elseif type(metatable) == "string" then
|
||||
name = name or metatable
|
||||
end
|
||||
type_check(name, "string", "name")
|
||||
type_check(serialize, "function", "serialize")
|
||||
type_check(deserialize, "function", "deserialize")
|
||||
assert((not ids[metatable]) and (not resources[metatable]),
|
||||
"Metatable already registered.")
|
||||
assert((not mts[name]) and (not resources_by_name[name]),
|
||||
("Name %q already registered."):format(name))
|
||||
mts[name] = metatable
|
||||
ids[metatable] = name
|
||||
serializers[name] = serialize
|
||||
deserializers[name] = deserialize
|
||||
return metatable
|
||||
end
|
||||
|
||||
local function unregister(item)
|
||||
local name, metatable
|
||||
if type(item) == "string" then -- assume name
|
||||
name, metatable = item, mts[item]
|
||||
else -- assume metatable
|
||||
name, metatable = ids[item], item
|
||||
end
|
||||
type_check(name, "string", "name")
|
||||
mts[name] = nil
|
||||
if (metatable) then
|
||||
resources[metatable] = nil
|
||||
ids[metatable] = nil
|
||||
end
|
||||
serializers[name] = nil
|
||||
deserializers[name] = nil
|
||||
resources_by_name[name] = nil;
|
||||
return metatable
|
||||
end
|
||||
|
||||
local function registerClass(class, name)
|
||||
name = name or class.name
|
||||
if class.__instanceDict then -- middleclass
|
||||
register(class.__instanceDict, name)
|
||||
else -- assume 30log or similar library
|
||||
register(class, name)
|
||||
end
|
||||
return class
|
||||
end
|
||||
|
||||
return {
|
||||
-- aliases
|
||||
s = serialize,
|
||||
d = deserialize,
|
||||
dn = deserializeN,
|
||||
r = readFile,
|
||||
w = writeFile,
|
||||
a = appendFile,
|
||||
|
||||
serialize = serialize,
|
||||
deserialize = deserialize,
|
||||
deserializeN = deserializeN,
|
||||
readFile = readFile,
|
||||
writeFile = writeFile,
|
||||
appendFile = appendFile,
|
||||
register = register,
|
||||
unregister = unregister,
|
||||
registerResource = registerResource,
|
||||
unregisterResource = unregisterResource,
|
||||
registerClass = registerClass,
|
||||
|
||||
newbinser = newbinser
|
||||
}
|
||||
end
|
||||
|
||||
return newbinser()
|
10
cossim_test.lua
Normal file
10
cossim_test.lua
Normal file
|
@ -0,0 +1,10 @@
|
|||
local util = require "util"
|
||||
|
||||
for dims = 1, 100 do
|
||||
print(dims, util.expect_cossim(dims))
|
||||
end
|
||||
|
||||
dims = 1000
|
||||
print(dims, util.expect_cossim(dims))
|
||||
dims = 10000
|
||||
print(dims, util.expect_cossim(dims))
|
290
eig.lua
Normal file
290
eig.lua
Normal file
|
@ -0,0 +1,290 @@
|
|||
-- blah
|
||||
|
||||
--local globalize = require "strict"
|
||||
local nn = require "nn"
|
||||
local util = require "util"
|
||||
|
||||
local sign = util.sign
|
||||
local abs = math.abs
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function tred2(a)
|
||||
assert(#a.shape == 2)
|
||||
assert(a.shape[1] == a.shape[2])
|
||||
local n = a.shape[1]
|
||||
local d = zeros(n) -- diagonal
|
||||
local e = zeros(n) -- off-diagonal (e[1] is a dummy value?)
|
||||
|
||||
for i = n, 2, -1 do
|
||||
local l = i - 1
|
||||
local ind_li = (l - 1) * n + i
|
||||
local h = 0
|
||||
local scale = 0
|
||||
|
||||
if l > 1 then
|
||||
for k = 1, l do
|
||||
local ind_ki = (k - 1) * n + i
|
||||
scale = scale + abs(a[ind_ki])
|
||||
end
|
||||
|
||||
if scale == 0 then
|
||||
e[i] = a[ind_li]
|
||||
else
|
||||
for k = 1, l do
|
||||
local ind_ki = (k - 1) * n + i
|
||||
a[ind_ki] = a[ind_ki] / scale
|
||||
h = h + a[ind_ki] * a[ind_ki]
|
||||
end
|
||||
|
||||
local f = a[ind_li]
|
||||
local g = sqrt(h)
|
||||
if f >= 0 then g = -g end
|
||||
e[i] = scale * g
|
||||
h = h - f * g
|
||||
a[ind_li] = f - g
|
||||
f = 0
|
||||
|
||||
for j = 1, l do
|
||||
local ind_ij = (i - 1) * n + j
|
||||
local ind_ji = (j - 1) * n + i
|
||||
a[ind_ij] = a[ind_ji] / h
|
||||
g = 0
|
||||
for k = 1, j do
|
||||
local ind_kj = (k - 1) * n + j
|
||||
local ind_ki = (k - 1) * n + i
|
||||
g = g + a[ind_kj] * a[ind_ki]
|
||||
end
|
||||
for k = j + 1, l do
|
||||
local ind_jk = (j - 1) * n + k
|
||||
local ind_ki = (k - 1) * n + i
|
||||
g = g + a[ind_jk] * a[ind_ki]
|
||||
end
|
||||
e[j] = g / h
|
||||
f = f + e[j] * a[ind_ji]
|
||||
end
|
||||
|
||||
local hh = f / (h + h)
|
||||
for j = 1, l do
|
||||
local ind_ji = (j - 1) * n + i
|
||||
f = a[ind_ji]
|
||||
g = e[j] - hh * f
|
||||
e[j] = g
|
||||
for k = 1, j do
|
||||
local ind_kj = (k - 1) * n + j
|
||||
local ind_ki = (k - 1) * n + i
|
||||
a[ind_kj] = a[ind_kj] - (f * e[k] + g * a[ind_ki])
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
e[i] = a[ind_li]
|
||||
end
|
||||
|
||||
d[i] = h
|
||||
end
|
||||
|
||||
d[1] = 0
|
||||
e[1] = 0
|
||||
for i = 1, n do
|
||||
local l = i - 1
|
||||
|
||||
if d[i] ~= 0 then
|
||||
for j = 1, l do
|
||||
local g = 0
|
||||
for k = 1, l do
|
||||
local ind_ki = (k - 1) * n + i
|
||||
local ind_jk = (j - 1) * n + k
|
||||
g = g + a[ind_ki] * a[ind_jk]
|
||||
end
|
||||
for k = 1, l do
|
||||
local ind_ik = (i - 1) * n + k
|
||||
local ind_jk = (j - 1) * n + k
|
||||
a[ind_jk] = a[ind_jk] - g * a[ind_ik]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local ind_ii = (i - 1) * n + i
|
||||
d[i] = a[ind_ii]
|
||||
a[ind_ii] = 1
|
||||
for j = 1, l do
|
||||
local ind_ij = (i - 1) * n + j
|
||||
local ind_ji = (j - 1) * n + i
|
||||
a[ind_ij] = 0
|
||||
a[ind_ji] = 0
|
||||
end
|
||||
end
|
||||
|
||||
return d, e
|
||||
end
|
||||
|
||||
local function pythag(a, b)
|
||||
--return sqrt(a * a + b * b)
|
||||
local abs_a = abs(a)
|
||||
local abs_b = abs(b)
|
||||
|
||||
if abs_a > abs_b then
|
||||
local temp = abs_b / abs_a
|
||||
temp = temp * temp
|
||||
return abs_a * sqrt(1 + temp)
|
||||
elseif abs_b ~= 0 then
|
||||
local temp = abs_a / abs_b
|
||||
temp = temp * temp
|
||||
return abs_b * sqrt(1 + temp)
|
||||
end
|
||||
|
||||
return 0
|
||||
end
|
||||
|
||||
local function tqli(d, e, z)
|
||||
assert(#z.shape == 2)
|
||||
assert(z.shape[1] == z.shape[2])
|
||||
local n = z.shape[1]
|
||||
assert(#d == n)
|
||||
assert(#e == n)
|
||||
|
||||
local eps = 1.2e-7
|
||||
|
||||
for i = 2, n do e[i - 1] = e[i] end
|
||||
e[n] = 0
|
||||
|
||||
for l = 1, n do
|
||||
local iter = 0
|
||||
local fucky = 0
|
||||
|
||||
local m
|
||||
while true do
|
||||
m = l
|
||||
while m <= n - 1 do
|
||||
local dd = abs(d[m]) + abs(d[m + 1])
|
||||
if abs(e[m]) + dd == dd then break end
|
||||
--if abs(e[m]) <= eps * dd then break end
|
||||
m = m + 1
|
||||
end
|
||||
|
||||
fucky = fucky + 1
|
||||
if fucky == 100 then print("fucky!"); break end
|
||||
if fucky == 1000 then error("super fucky!"); break end
|
||||
|
||||
--print(("l: %i, m: %i"):format(l - 1, m - 1))
|
||||
|
||||
if m == l then break end
|
||||
|
||||
iter = iter + 1
|
||||
if iter >= 32 then error("Too many iterations in tqli") end
|
||||
|
||||
local g = (d[l + 1] - d[l]) / (2 * e[l])
|
||||
local r = pythag(g, 1)
|
||||
g = d[m] - d[l] + e[l] / (g + r * sign(g))
|
||||
local s = 1
|
||||
local c = 1
|
||||
local p = 0
|
||||
|
||||
for i = m - 1, l, -1 do
|
||||
local f = s * e[i]
|
||||
local b = c * e[i]
|
||||
r = pythag(f, g)
|
||||
e[i + 1] = r
|
||||
if r == 0 then
|
||||
d[i + 1] = d[i + 1] - p
|
||||
e[m] = 0
|
||||
break
|
||||
end
|
||||
|
||||
s = f / r
|
||||
c = g / r
|
||||
g = d[i + 1] - p
|
||||
r = (d[i] - g) * s + 2 * c * b
|
||||
p = s * r
|
||||
d[i + 1] = g + p
|
||||
g = c * r - b
|
||||
|
||||
for k = 1, n do
|
||||
if true then
|
||||
local ind = (i - 1) * n + k
|
||||
f = z[ind + n]
|
||||
z[ind + n] = s * z[ind] + c * f
|
||||
z[ind] = c * z[ind] - s * f
|
||||
else
|
||||
local ind = (k - 1) * n + i
|
||||
f = z[ind + 1]
|
||||
z[ind + 1] = s * z[ind] + c * f
|
||||
z[ind] = c * z[ind] - s * f
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if r == 0 and i >= l then
|
||||
-- continue
|
||||
else
|
||||
d[l] = d[l] - p
|
||||
e[l] = g
|
||||
e[m] = 0
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--[=[
|
||||
|
||||
local A = {
|
||||
4, 1, -2, 2,
|
||||
1, 2, 0, 1,
|
||||
-2, 0, 3, -2,
|
||||
2, 1, -2, -1,
|
||||
}
|
||||
reshape(A, 4, 4)
|
||||
|
||||
local d, e = tred2(A)
|
||||
|
||||
--[[
|
||||
print(nn.pp(A))
|
||||
print(nn.pp(d))
|
||||
print(nn.pp(e))
|
||||
--]]
|
||||
|
||||
--[[
|
||||
{
|
||||
0.248069, 0.744208, 0.620174, 0.000000,
|
||||
0.702863, -0.578829, 0.413449, 0.000000,
|
||||
0.666667, 0.333333, -0.666667, 0.000000,
|
||||
0.000000, 0.000000, 0.000000, 1.000000,
|
||||
}
|
||||
{ 2.261538, 1.182906, 5.555556, -1.000000,}
|
||||
{ 0.000000, -0.092308, 0.895806, 3.000000,}
|
||||
--]]
|
||||
|
||||
--A = nn.transpose(A)
|
||||
tqli(d, e, A)
|
||||
|
||||
print(nn.pp(A))
|
||||
print(nn.pp(d))
|
||||
print(nn.pp(e))
|
||||
|
||||
local D = zeros{4, 4}
|
||||
for i = 1, 4 do
|
||||
D[(i - 1) * 4 + i] = math.exp(d[i])
|
||||
end
|
||||
local out = nn.dot(nn.transpose(A), D)
|
||||
out = nn.dot(out, A)
|
||||
print(nn.pp(out))
|
||||
|
||||
--[[
|
||||
{
|
||||
703.414032, 1410.125991, 1478.990752, -43.126976,
|
||||
-1205.204565, 1573.963121, -940.902581, 319.676625,
|
||||
14.433478, 3.884400, -10.434671, 6.841011,
|
||||
3.529384, 3.087068, -5.116236, -17.850223,
|
||||
}
|
||||
{ 2.273819, 1.072834, 6.818268, -2.164921,}
|
||||
{ -0.000000, 0.000000, 0.000000, 0.000000,}
|
||||
--]]
|
||||
|
||||
--]=]
|
||||
|
||||
return {
|
||||
tred2=tred2,
|
||||
tqli=tqli,
|
||||
}
|
21
es_test.lua
21
es_test.lua
|
@ -28,15 +28,24 @@ end
|
|||
|
||||
-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
|
||||
local iterations = 3000 --4000
|
||||
local dims = 100
|
||||
local popsize = dims + 1
|
||||
|
||||
local dims, popsize
|
||||
if false then
|
||||
dims = 100
|
||||
popsize = dims + 1
|
||||
else
|
||||
dims = 30
|
||||
popsize = 99
|
||||
end
|
||||
|
||||
local sigma_init = 0.5
|
||||
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
|
||||
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
||||
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
|
||||
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
||||
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
|
||||
es.min_refresh = 0.5 -- FIXME: needs a better interface.
|
||||
|
||||
es.min_refresh = 1.0 -- FIXME: needs a better interface.
|
||||
|
||||
if typeof(es) == xnes.Xnes
|
||||
or typeof(es) == snes.Snes
|
||||
|
@ -55,9 +64,9 @@ then
|
|||
local util = require "util"
|
||||
util.normalize_sums(es.utility)
|
||||
|
||||
es.param_rate = 0.39
|
||||
es.sigma_rate = 0.39
|
||||
es.covar_rate = 0.39
|
||||
es.param_rate = 0.2 --39
|
||||
es.sigma_rate = 0.05 --39
|
||||
es.covar_rate = 0.1 --39
|
||||
es.adaptive = false
|
||||
end
|
||||
|
||||
|
|
81
expm.lua
Normal file
81
expm.lua
Normal file
|
@ -0,0 +1,81 @@
|
|||
-- here's a really awful way of computing the matrix exponential.
|
||||
-- we employ the QR algorithm to find eigenpairs of a given symmetric matrix,
|
||||
-- then run the ordinary exponent function over the eigenvalues.
|
||||
-- this only works for symmetric matrices!
|
||||
|
||||
local nn = require "nn"
|
||||
local qr = require "qr"
|
||||
local util = require "util"
|
||||
|
||||
local copy = util.copy
|
||||
local dot = nn.dot
|
||||
local exp = math.exp
|
||||
local reshape = nn.reshape
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function expm(mat)
|
||||
assert(#mat.shape == 2)
|
||||
assert(mat.shape[1] == mat.shape[2], "expm input must be square")
|
||||
--assert(stuff(mat), "expm input must be symmetrical")
|
||||
|
||||
local dims = mat.shape[1]
|
||||
|
||||
local vec = zeros(mat.shape)
|
||||
for i = 1, dims do
|
||||
local ind = (i - 1) * dims + i -- diagonal
|
||||
vec[ind] = 1
|
||||
end
|
||||
|
||||
local diag = mat
|
||||
for i = 1, 10 do
|
||||
local q, r = qr(diag)
|
||||
vec = dot(vec, q)
|
||||
diag = dot(r, q)
|
||||
end
|
||||
|
||||
for y = 1, dims do
|
||||
for x = 1, dims do
|
||||
local ind = (y - 1) * dims + x
|
||||
if x == y then
|
||||
diag[ind] = exp(diag[ind])
|
||||
else
|
||||
diag[ind] = 0
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return dot(dot(vec, diag), transpose(vec))
|
||||
end
|
||||
|
||||
local eig = require "eig"
|
||||
local tred2 = eig.tred2
|
||||
local tqli = eig.tqli
|
||||
|
||||
local function expm2(mat)
|
||||
assert(#mat.shape == 2)
|
||||
assert(mat.shape[1] == mat.shape[2])
|
||||
local dims = mat.shape[1]
|
||||
|
||||
-- new version that computes much better (faster?) eigenpairs
|
||||
local vec = copy(mat)
|
||||
local d, e = tred2(vec)
|
||||
tqli(d, e, vec)
|
||||
|
||||
local diag = {}
|
||||
for y = 1, dims do
|
||||
for x = 1, dims do
|
||||
local ind = (y - 1) * dims + x
|
||||
if x == y then
|
||||
diag[ind] = exp(d[x])
|
||||
else
|
||||
diag[ind] = 0
|
||||
end
|
||||
end
|
||||
end
|
||||
reshape(diag, dims, dims)
|
||||
|
||||
return dot(dot(transpose(vec), diag), vec)
|
||||
end
|
||||
|
||||
return expm2
|
31
expm_test.lua
Normal file
31
expm_test.lua
Normal file
|
@ -0,0 +1,31 @@
|
|||
local globalize = require "strict"
|
||||
|
||||
local nn = require "nn"
|
||||
local expm = require "expm"
|
||||
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local dims = 8
|
||||
local n = dims * dims
|
||||
local M = {}
|
||||
local coeff = dims * dims * sqrt(dims)
|
||||
for i = 1, n do M[i] = i / coeff end
|
||||
M.shape = {dims, dims}
|
||||
|
||||
M = nn.dot(M, nn.transpose(M))
|
||||
|
||||
print(nn.pp(M, "%9.6f"))
|
||||
|
||||
local exp_M = expm(M)
|
||||
|
||||
print(nn.pp(exp_M, "%9.6f"))
|
||||
|
||||
local binser = require "binser"
|
||||
|
||||
--local dat = binser.s(exp_M)
|
||||
binser.writeFile("expm.bin", exp_M)
|
||||
local res, len = binser.readFile("expm.bin")
|
||||
assert(len == 1)
|
||||
exp_M = res[1]
|
||||
|
||||
print(nn.pp(exp_M, "%9.6f"))
|
2
main.lua
2
main.lua
|
@ -284,7 +284,7 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step
|
||||
if cfg.es == 'ars' and cfg.ars_lips then
|
||||
if cfg.es == 'ars' then --and cfg.ars_lips then
|
||||
step = es:tell(trial_rewards, current_cost)
|
||||
else
|
||||
step = es:tell(trial_rewards)
|
||||
|
|
2
nn.lua
2
nn.lua
|
@ -119,6 +119,8 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
|||
di = di or 1
|
||||
depth = depth or 0
|
||||
|
||||
if t == nil then return "nil" end
|
||||
|
||||
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
|
||||
|
||||
local dim = t.shape[di]
|
||||
|
|
7
pp_test.lua
Normal file
7
pp_test.lua
Normal file
|
@ -0,0 +1,7 @@
|
|||
local nn = require "nn"
|
||||
|
||||
a = {0,1,2,3,4,5,6,7}
|
||||
print(nn.pp(a, "%9.4f"))
|
||||
print(nn.pp(nn.reshape(a, 4, 2), "%9.4f"))
|
||||
print(nn.pp(nn.reshape(a, 2, 4), "%9.4f"))
|
||||
print(nn.pp(nn.reshape(a, 2, 2, 2), "%9.4f"))
|
97
presets.lua
97
presets.lua
|
@ -412,6 +412,103 @@ make_preset{
|
|||
beta = 1.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'xnes-recon', -- recon-sidered
|
||||
-- parent = 'big-scroll-hidden',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'xnes',
|
||||
|
||||
-- embed = false,
|
||||
-- reduce_tiles = 5,
|
||||
-- hidden_size = 54,
|
||||
|
||||
epoch_trials = 20,
|
||||
epoch_top_trials = 20,
|
||||
negate_trials = true,
|
||||
|
||||
deterministic = true,
|
||||
deviation = 0.1,
|
||||
|
||||
param_decay = 0.01,
|
||||
|
||||
param_rate = 0.2,
|
||||
sigma_rate = 0.05,
|
||||
covar_rate = 0.1,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
deterministic = true,
|
||||
|
||||
es = 'ars',
|
||||
epoch_trials = 5,
|
||||
epoch_top_trials = 9999,
|
||||
deviation = 1.0,
|
||||
param_rate = 1.0,
|
||||
beta = 1.0, -- fix the default.
|
||||
|
||||
beta = 5.0, -- oops, i had a dumb bug.
|
||||
param_decay = 0.01,
|
||||
momentum = 0.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse2',
|
||||
parent = 'arse',
|
||||
|
||||
beta = 5.0, -- oops, i had a dumb bug.
|
||||
param_decay = 0.001,
|
||||
embed = false,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse3',
|
||||
parent = 'arse2',
|
||||
|
||||
epoch_trials = 10,
|
||||
|
||||
-- note: also disabled sum of squares in ars.lua
|
||||
beta = 1.0,
|
||||
|
||||
--deviation = 0.5, -- after 500 epochs
|
||||
deviation = 0.7071, -- after 500+521 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse4',
|
||||
parent = 'arse3',
|
||||
|
||||
hidden = true,
|
||||
hidden_size = 68,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse5',
|
||||
parent = 'arse3',
|
||||
|
||||
-- sum of squares still disabled. i probably won't re-enable it really.
|
||||
--deviation = 1.0,
|
||||
deviation = 1.414, -- after 80+360+790 epochs
|
||||
momentum = 0.8, -- neumann momentum, peaks at this value
|
||||
--param_decay = 0.003, -- after 80 epochs
|
||||
--param_decay = 0.01, -- after 80+360 epochs
|
||||
param_decay = 0.0, -- after 80+360+790 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse6',
|
||||
parent = 'arse3',
|
||||
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.25,
|
||||
deviation = 0.1, -- maybe try 0.15
|
||||
momentum = 0.9, -- maybe try 0.8
|
||||
param_decay = 0.0,
|
||||
}
|
||||
|
||||
-- end of new stuff
|
||||
|
||||
make_preset{
|
||||
|
|
8
qr.lua
8
qr.lua
|
@ -1,9 +1,11 @@
|
|||
local min = math.min
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require "nn"
|
||||
|
||||
local assert = assert
|
||||
local dot = nn.dot
|
||||
local ipairs = ipairs
|
||||
local min = math.min
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
|
|
2
qr2.lua
2
qr2.lua
|
@ -43,7 +43,7 @@ local function qr(a)
|
|||
local den = 0
|
||||
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
|
||||
for k = j0, j1 do den = den + q[k] * q[k] end
|
||||
print(num, den)
|
||||
--print(num, den)
|
||||
if den == 0 then den = 1 end -- TODO: should probably just error.
|
||||
|
||||
local x = num / den
|
||||
|
|
79
qr3.lua
Normal file
79
qr3.lua
Normal file
|
@ -0,0 +1,79 @@
|
|||
-- gram-schmidt process
|
||||
|
||||
local nn = require "nn"
|
||||
|
||||
local dot = nn.dot
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function qr(mat)
|
||||
assert(#mat.shape == 2)
|
||||
local v = transpose(mat)
|
||||
local rows = v.shape[1]
|
||||
local cols = v.shape[2]
|
||||
|
||||
local u = zeros(v.shape)
|
||||
|
||||
local w = {}
|
||||
local y = 1
|
||||
|
||||
local function load_row()
|
||||
local start = (y - 1) * cols
|
||||
for x = 1, cols do w[x] = v[start + x] end
|
||||
end
|
||||
|
||||
local function push_row()
|
||||
local sum = 0
|
||||
for _, value in ipairs(w) do sum = sum + value * value end
|
||||
local norm = sqrt(sum)
|
||||
|
||||
local start = (y - 1) * cols
|
||||
for x, value in ipairs(w) do u[start + x] = value / norm end
|
||||
|
||||
y = y + 1
|
||||
end
|
||||
|
||||
load_row()
|
||||
push_row()
|
||||
|
||||
local sums = {}
|
||||
|
||||
for i = 2, rows do
|
||||
load_row()
|
||||
|
||||
for x = 1, cols do sums[x] = 0 end
|
||||
|
||||
for j = 1, i - 1 do
|
||||
local start = (j - 1) * cols
|
||||
|
||||
local dotted = 0
|
||||
for x, value in ipairs(w) do
|
||||
dotted = dotted + value * u[start + x]
|
||||
end
|
||||
|
||||
--[[
|
||||
local scale = 0
|
||||
for x = 1, cols do
|
||||
local value = u[start + x]
|
||||
scale = scale + value * value
|
||||
end
|
||||
print(scale)
|
||||
dotted = dotted / scale
|
||||
--]]
|
||||
|
||||
for x, value in ipairs(sums) do
|
||||
sums[x] = value + dotted * u[start + x]
|
||||
end
|
||||
end
|
||||
|
||||
for x, value in ipairs(w) do w[x] = value - sums[x] end
|
||||
|
||||
push_row()
|
||||
end
|
||||
|
||||
return transpose(u), dot(u, mat)
|
||||
end
|
||||
|
||||
return qr
|
19
qr_test.lua
19
qr_test.lua
|
@ -2,6 +2,7 @@ local globalize = require "strict"
|
|||
local nn = require "nn"
|
||||
local qr = require "qr"
|
||||
local qr2 = require "qr2"
|
||||
local qr3 = require "qr3"
|
||||
|
||||
local A
|
||||
|
||||
|
@ -24,7 +25,7 @@ elseif false then
|
|||
}
|
||||
A = nn.reshape(A, 5, 3)
|
||||
|
||||
else
|
||||
elseif false then
|
||||
A = {
|
||||
1, 2, 0,
|
||||
2, 3, 1,
|
||||
|
@ -44,13 +45,27 @@ else
|
|||
A = nn.reshape(A, 5, 3)
|
||||
--A = nn.transpose(A)
|
||||
|
||||
else
|
||||
A = {
|
||||
1, 2, -3,
|
||||
2, 4, 5,
|
||||
-3, 5, 6,
|
||||
}
|
||||
|
||||
A = nn.reshape(A, 3, 3)
|
||||
end
|
||||
|
||||
print("A")
|
||||
print(nn.pp(A, "%9.4f"))
|
||||
print()
|
||||
|
||||
local Q, R = qr2(A)
|
||||
local Q, R = qr(A)
|
||||
|
||||
print("Q (reference)")
|
||||
print(nn.pp(Q, "%9.4f"))
|
||||
print()
|
||||
|
||||
local Q, R = qr3(A)
|
||||
|
||||
print("Q")
|
||||
print(nn.pp(Q, "%9.4f"))
|
||||
|
|
7
sign_test.lua
Normal file
7
sign_test.lua
Normal file
|
@ -0,0 +1,7 @@
|
|||
local util = require("util")
|
||||
print(util.sign(-1))
|
||||
print(util.sign(0))
|
||||
print(util.sign(1))
|
||||
print(util.sign(-0.1))
|
||||
print(util.sign(0.0))
|
||||
print(util.sign(0.1))
|
7
util.lua
7
util.lua
|
@ -209,6 +209,11 @@ local function cdf(x)
|
|||
-- i don't remember where this is from.
|
||||
local sign = x >= 0 and 1 or -1
|
||||
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
|
||||
|
||||
-- more accurate (via GELU paper, might be lifted from elsewhere):
|
||||
--local const = sqrt(2 / pi)
|
||||
--return 0.5 * (1 + tanh(const * (1 + 0.044715 * x * x) * x))
|
||||
--return 0.5 * (1 + tanh(0.7978845608 * (1 + 0.044715 * x * x) * x))
|
||||
end
|
||||
|
||||
local function fitness_shaping(rewards)
|
||||
|
@ -283,7 +288,7 @@ local function expect_cossim(n)
|
|||
if n >= 128000 then
|
||||
return 1 / sqrt(pi / 2 * n + 1)
|
||||
elseif n >= 80 then
|
||||
poly = (2.4674010 * n + -2.4673232) * n + 1.2274046
|
||||
poly = (2.4674010 * n - 2.4673232) * n + 1.2274046
|
||||
return 1 / sqrt(sqrt(poly))
|
||||
end
|
||||
-- fall-through when it's faster just to compute iteratively.
|
||||
|
|
35
xnes.lua
35
xnes.lua
|
@ -16,10 +16,13 @@ local unpack = table.unpack or unpack
|
|||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local dot = nn.dot
|
||||
local dot_mv = nn.dot_mv
|
||||
local normal = nn.normal
|
||||
local zeros = nn.zeros
|
||||
|
||||
local expm = require "expm"
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
|
||||
|
@ -35,22 +38,6 @@ local function make_utility(popsize, out)
|
|||
return utility
|
||||
end
|
||||
|
||||
local function make_covars(dims, sigma, out)
|
||||
local covars = out or zeros{dims, dims}
|
||||
local c = sigma / dims
|
||||
-- simplified form of the determinant of the matrix we're going to create:
|
||||
local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1)
|
||||
-- multiplying by this constant makes the determinant 1:
|
||||
local m = pow(1 / det, 1 / dims)
|
||||
|
||||
local filler = c * m
|
||||
for i=1, #covars do covars[i] = filler end
|
||||
-- diagonals:
|
||||
for i=1, dims do covars[i + dims * (i - 1)] = m end
|
||||
|
||||
return covars
|
||||
end
|
||||
|
||||
function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
|
@ -67,9 +54,11 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
|||
self.utility = make_utility(self.popsize)
|
||||
|
||||
self.mean = zeros{dims}
|
||||
-- note: this is technically the co-standard-deviation.
|
||||
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||
self.covars = zeros{dims, dims}
|
||||
for i=1, dims do
|
||||
local ind = (i - 1) * dims + i -- diagonal
|
||||
self.covars[ind] = 1
|
||||
end
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
@ -202,9 +191,12 @@ function Xnes:tell(scored, noise)
|
|||
end
|
||||
|
||||
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
|
||||
for i, v in ipairs(self.covars) do
|
||||
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
|
||||
|
||||
-- re-use g_covars from before just to scale it.
|
||||
for i, v in ipairs(g_covars) do
|
||||
g_covars[i] = self.covar_rate * 0.5 * v
|
||||
end
|
||||
self.covars = dot(self.covars, expm(g_covars))
|
||||
|
||||
-- bookkeeping:
|
||||
self.noise = nil
|
||||
|
@ -214,7 +206,6 @@ end
|
|||
|
||||
return {
|
||||
make_utility = make_utility,
|
||||
make_covars = make_covars,
|
||||
|
||||
Xnes = Xnes,
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue