From 7462e69c6106d4d264ecce420bf851028a80b4d7 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 11 Mar 2019 07:15:41 +0100 Subject: [PATCH] temp 4 --- ars.lua | 97 ++++--- binser.lua | 749 ++++++++++++++++++++++++++++++++++++++++++++++++ cossim_test.lua | 10 + eig.lua | 290 +++++++++++++++++++ es_test.lua | 21 +- expm.lua | 81 ++++++ expm_test.lua | 31 ++ main.lua | 2 +- nn.lua | 2 + pp_test.lua | 7 + presets.lua | 97 +++++++ qr.lua | 8 +- qr2.lua | 2 +- qr3.lua | 79 +++++ qr_test.lua | 19 +- sign_test.lua | 7 + util.lua | 7 +- xnes.lua | 35 +-- 18 files changed, 1467 insertions(+), 77 deletions(-) create mode 100644 binser.lua create mode 100644 cossim_test.lua create mode 100644 eig.lua create mode 100644 expm.lua create mode 100644 expm_test.lua create mode 100644 pp_test.lua create mode 100644 qr3.lua create mode 100644 sign_test.lua diff --git a/ars.lua b/ars.lua index cb6f6da..b7ed606 100644 --- a/ars.lua +++ b/ars.lua @@ -1,14 +1,11 @@ -- Augmented Random Search -- https://arxiv.org/abs/1803.07055 --- with some tweaks (lipschitz stuff) by myself. --- i also added an option for graycode sampling, --- borrowed from a (1+1) optimizer, --- but i haven't yet found a case where it performs better. local abs = math.abs local exp = math.exp local floor = math.floor local insert = table.insert +local remove = table.remove local ipairs = ipairs local log = math.log local max = math.max @@ -26,6 +23,7 @@ local zeros = nn.zeros local util = require "util" local argsort = util.argsort local calc_mean_dev = util.calc_mean_dev +local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased local normalize_sums = util.normalize_sums local sign = util.sign @@ -56,23 +54,6 @@ local function collect_best_indices(scored, top, antithetic) return indices end -local function kinda_lipschitz(dir, pos, neg, mid) - --[[ - -- based on the local lipschitz constant of a quadratic curve - -- drawn through the 3 sampled points: positive, negative, and unperturbed. - -- it kinda helps? there's probably a better function to base it around. - local _, dev = calc_mean_dev(dir) - local c0 = neg - mid - local c1 = pos - mid - local l0 = abs(3 * c1 + c0) - local l1 = abs(c1 + 3 * c0) - return max(l0, l1) / (2 * dev) - --]] - -- based on a piece-wise linear function of the 3 sampled points. - local _, dev = calc_mean_dev(dir) - return max(abs(pos - mid), abs(neg - mid)) / dev -end - function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic, momentum, beta) self.dims = dims @@ -110,7 +91,7 @@ function Ars:decay(param_decay, sigma_decay) end end -function Ars:ask(graycode) +function Ars:ask() local asked = {} local noise = {} @@ -126,17 +107,8 @@ function Ars:ask(graycode) noisy[j] = -v end else - if graycode ~= nil then - for j = 1, self.dims do - noisy[j] = exp(-precision * uniform()) - end - for j = 1, self.dims do - noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j] - end - else - for j = 1, self.dims do - noisy[j] = self.sigma * normal() - end + for j = 1, self.dims do + noisy[j] = self.sigma * normal() end end @@ -150,9 +122,8 @@ function Ars:ask(graycode) end function Ars:tell(scored, unperturbed_score) - local use_lips = unperturbed_score ~= nil and self.antithetic self.evals = self.evals + #scored - if use_lips then self.evals = self.evals + 1 end + if unperturbed_score ~= nil then self.evals = self.evals + 1 end local indices = collect_best_indices(scored, self.poptop, self.antithetic) @@ -173,7 +144,16 @@ function Ars:tell(scored, unperturbed_score) end local step = zeros(self.dims) - local _, reward_dev = calc_mean_dev(top_rewards) + + local _, reward_dev + if unperturbed_score ~= nil then + -- new stuff: + insert(top_rewards, unperturbed_score) + _, reward_dev = calc_mean_dev_unbiased(top_rewards) + remove(top_rewards) + else + _, reward_dev = calc_mean_dev(top_rewards) + end if reward_dev == 0 then reward_dev = 1 end if self.antithetic then @@ -183,12 +163,15 @@ function Ars:tell(scored, unperturbed_score) local reward = pos - neg if reward ~= 0 then local noisy = self.noise[ind * 2 - 1] - if use_lips then - local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score) - reward = reward / lips / self.sigma - else - reward = reward / reward_dev + reward = reward / reward_dev + + --[[ new stuff: + local sum_of_squares = 0 + for _, v in ipairs(noisy) do + sum_of_squares = sum_of_squares + v * v end + reward = reward / sqrt(sum_of_squares) + -]] local scale = reward / self.poptop * self.beta / 2 for j, v in ipairs(noisy) do @@ -196,7 +179,9 @@ function Ars:tell(scored, unperturbed_score) end end end + else + error("TODO: update with sum of squares stuff") for i, ind in ipairs(indices) do local reward = top_rewards[i] / reward_dev if reward ~= 0 then @@ -210,6 +195,7 @@ function Ars:tell(scored, unperturbed_score) end end + --[[ powersign momentum if self.momentum > 0 then for i, v in ipairs(step) do self.accum[i] = self.momentum * self.accum[i] + v @@ -220,6 +206,35 @@ function Ars:tell(scored, unperturbed_score) for i, v in ipairs(self._params) do self._params[i] = v + self.param_rate * step[i] end + --]] + + -- neumann momentum + if self.momentum > 0 then + local count = self.count or 0 + local period = 10 + local mu = 1 - 1 / (1 + count % period) + mu = self.momentum / (1 - 1 / period) * mu + self.count = count + 1 + -- mu is intentionally 0 for one iteration. + + -- make learning rate invariant to sigma. + for i, v in ipairs(step) do + step[i] = v / self.sigma + end + + -- update neumann iterate. + for i, v in ipairs(self.accum) do + self.accum[i] = mu * v - self.param_rate * step[i] + end + + for i, v in ipairs(self._params) do + self._params[i] = v - mu * self.accum[i] + self.param_rate * step[i] + end + else + for i, v in ipairs(self._params) do + self._params[i] = v + self.param_rate * step[i] + end + end self.noise = nil diff --git a/binser.lua b/binser.lua new file mode 100644 index 0000000..b689f9a --- /dev/null +++ b/binser.lua @@ -0,0 +1,749 @@ +-- binser.lua + +--[[ +Copyright (c) 2016 Calvin Rose + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +]] + +local assert = assert +local error = error +local select = select +local pairs = pairs +local getmetatable = getmetatable +local setmetatable = setmetatable +local type = type +local loadstring = loadstring or load +local concat = table.concat +local char = string.char +local byte = string.byte +local format = string.format +local sub = string.sub +local dump = string.dump +local floor = math.floor +local frexp = math.frexp +local unpack = unpack or table.unpack + +-- Lua 5.3 frexp polyfill +-- From https://github.com/excessive/cpml/blob/master/modules/utils.lua +if not frexp then + local log, abs, floor = math.log, math.abs, math.floor + local log2 = log(2) + frexp = function(x) + if x == 0 then return 0, 0 end + local e = floor(log(abs(x)) / log2 + 1) + return x / 2 ^ e, e + end +end + +local function pack(...) + return {...}, select("#", ...) +end + +local function not_array_index(x, len) + return type(x) ~= "number" or x < 1 or x > len or x ~= floor(x) +end + +local function type_check(x, tp, name) + assert(type(x) == tp, + format("Expected parameter %q to be of type %q.", name, tp)) +end + +local bigIntSupport = false +local isInteger +if math.type then -- Detect Lua 5.3 + local mtype = math.type + bigIntSupport = loadstring[[ + local char = string.char + return function(n) + local nn = n < 0 and -(n + 1) or n + local b1 = nn // 0x100000000000000 + local b2 = nn // 0x1000000000000 % 0x100 + local b3 = nn // 0x10000000000 % 0x100 + local b4 = nn // 0x100000000 % 0x100 + local b5 = nn // 0x1000000 % 0x100 + local b6 = nn // 0x10000 % 0x100 + local b7 = nn // 0x100 % 0x100 + local b8 = nn % 0x100 + if n < 0 then + b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4 + b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8 + end + return char(212, b1, b2, b3, b4, b5, b6, b7, b8) + end]]() + isInteger = function(x) + return mtype(x) == 'integer' + end +else + isInteger = function(x) + return floor(x) == x + end +end + +-- Copyright (C) 2012-2015 Francois Perrad. +-- number serialization code modified from https://github.com/fperrad/lua-MessagePack +-- Encode a number as a big-endian ieee-754 double, big-endian signed 64 bit integer, or a small integer +local function number_to_str(n) + if isInteger(n) then -- int + if n <= 100 and n >= -27 then -- 1 byte, 7 bits of data + return char(n + 27) + elseif n <= 8191 and n >= -8192 then -- 2 bytes, 14 bits of data + n = n + 8192 + return char(128 + (floor(n / 0x100) % 0x100), n % 0x100) + elseif bigIntSupport then + return bigIntSupport(n) + end + end + local sign = 0 + if n < 0.0 then + sign = 0x80 + n = -n + end + local m, e = frexp(n) -- mantissa, exponent + if m ~= m then + return char(203, 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) + elseif m == 1/0 then + if sign == 0 then + return char(203, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) + else + return char(203, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) + end + end + e = e + 0x3FE + if e < 1 then -- denormalized numbers + m = m * 2 ^ (52 + e) + e = 0 + else + m = (m * 2 - 1) * 2 ^ 52 + end + return char(203, + sign + floor(e / 0x10), + (e % 0x10) * 0x10 + floor(m / 0x1000000000000), + floor(m / 0x10000000000) % 0x100, + floor(m / 0x100000000) % 0x100, + floor(m / 0x1000000) % 0x100, + floor(m / 0x10000) % 0x100, + floor(m / 0x100) % 0x100, + m % 0x100) +end + +-- Copyright (C) 2012-2015 Francois Perrad. +-- number deserialization code also modified from https://github.com/fperrad/lua-MessagePack +local function number_from_str(str, index) + local b = byte(str, index) + if not b then error("Expected more bytes of input.") end + if b < 128 then + return b - 27, index + 1 + elseif b < 192 then + local b2 = byte(str, index + 1) + if not b2 then error("Expected more bytes of input.") end + return b2 + 0x100 * (b - 128) - 8192, index + 2 + end + local b1, b2, b3, b4, b5, b6, b7, b8 = byte(str, index + 1, index + 8) + if (not b1) or (not b2) or (not b3) or (not b4) or + (not b5) or (not b6) or (not b7) or (not b8) then + error("Expected more bytes of input.") + end + if b == 212 then + local flip = b1 >= 128 + if flip then -- negative + b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4 + b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8 + end + local n = ((((((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4) * + 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8 + if flip then + return (-n) - 1, index + 9 + else + return n, index + 9 + end + end + if b ~= 203 then + error("Expected number") + end + local sign = b1 > 0x7F and -1 or 1 + local e = (b1 % 0x80) * 0x10 + floor(b2 / 0x10) + local m = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8 + local n + if e == 0 then + if m == 0 then + n = sign * 0.0 + else + n = sign * (m / 2 ^ 52) * 2 ^ -1022 + end + elseif e == 0x7FF then + if m == 0 then + n = sign * (1/0) + else + n = 0.0/0.0 + end + else + n = sign * (1.0 + m / 2 ^ 52) * 2 ^ (e - 0x3FF) + end + return n, index + 9 +end + + +local function newbinser() + + -- unique table key for getting next value + local NEXT = {} + local CTORSTACK = {} + + -- NIL = 202 + -- FLOAT = 203 + -- TRUE = 204 + -- FALSE = 205 + -- STRING = 206 + -- TABLE = 207 + -- REFERENCE = 208 + -- CONSTRUCTOR = 209 + -- FUNCTION = 210 + -- RESOURCE = 211 + -- INT64 = 212 + -- TABLE WITH META = 213 + + local mts = {} + local ids = {} + local serializers = {} + local deserializers = {} + local resources = {} + local resources_by_name = {} + local types = {} + + types["nil"] = function(x, visited, accum) + accum[#accum + 1] = "\202" + end + + function types.number(x, visited, accum) + accum[#accum + 1] = number_to_str(x) + end + + function types.boolean(x, visited, accum) + accum[#accum + 1] = x and "\204" or "\205" + end + + function types.string(x, visited, accum) + local alen = #accum + if visited[x] then + accum[alen + 1] = "\208" + accum[alen + 2] = number_to_str(visited[x]) + else + visited[x] = visited[NEXT] + visited[NEXT] = visited[NEXT] + 1 + accum[alen + 1] = "\206" + accum[alen + 2] = number_to_str(#x) + accum[alen + 3] = x + end + end + + local function check_custom_type(x, visited, accum) + local res = resources[x] + if res then + accum[#accum + 1] = "\211" + types[type(res)](res, visited, accum) + return true + end + local mt = getmetatable(x) + local id = mt and ids[mt] + if id then + local constructing = visited[CTORSTACK] + if constructing[x] then + error("Infinite loop in constructor.") + end + constructing[x] = true + accum[#accum + 1] = "\209" + types[type(id)](id, visited, accum) + local args, len = pack(serializers[id](x)) + accum[#accum + 1] = number_to_str(len) + for i = 1, len do + local arg = args[i] + types[type(arg)](arg, visited, accum) + end + visited[x] = visited[NEXT] + visited[NEXT] = visited[NEXT] + 1 + -- We finished constructing + constructing[x] = nil + return true + end + end + + function types.userdata(x, visited, accum) + if visited[x] then + accum[#accum + 1] = "\208" + accum[#accum + 1] = number_to_str(visited[x]) + else + if check_custom_type(x, visited, accum) then return end + error("Cannot serialize this userdata.") + end + end + + function types.table(x, visited, accum) + if visited[x] then + accum[#accum + 1] = "\208" + accum[#accum + 1] = number_to_str(visited[x]) + else + if check_custom_type(x, visited, accum) then return end + visited[x] = visited[NEXT] + visited[NEXT] = visited[NEXT] + 1 + local xlen = #x + local mt = getmetatable(x) + if mt then + accum[#accum + 1] = "\213" + types.table(mt, visited, accum) + else + accum[#accum + 1] = "\207" + end + accum[#accum + 1] = number_to_str(xlen) + for i = 1, xlen do + local v = x[i] + types[type(v)](v, visited, accum) + end + local key_count = 0 + for k in pairs(x) do + if not_array_index(k, xlen) then + key_count = key_count + 1 + end + end + accum[#accum + 1] = number_to_str(key_count) + for k, v in pairs(x) do + if not_array_index(k, xlen) then + types[type(k)](k, visited, accum) + types[type(v)](v, visited, accum) + end + end + end + end + + types["function"] = function(x, visited, accum) + if visited[x] then + accum[#accum + 1] = "\208" + accum[#accum + 1] = number_to_str(visited[x]) + else + if check_custom_type(x, visited, accum) then return end + visited[x] = visited[NEXT] + visited[NEXT] = visited[NEXT] + 1 + local str = dump(x) + accum[#accum + 1] = "\210" + accum[#accum + 1] = number_to_str(#str) + accum[#accum + 1] = str + end + end + + types.cdata = function(x, visited, accum) + if visited[x] then + accum[#accum + 1] = "\208" + accum[#accum + 1] = number_to_str(visited[x]) + else + if check_custom_type(x, visited, #accum) then return end + error("Cannot serialize this cdata.") + end + end + + types.thread = function() error("Cannot serialize threads.") end + + local function deserialize_value(str, index, visited) + local t = byte(str, index) + if not t then return nil, index end + if t < 128 then + return t - 27, index + 1 + elseif t < 192 then + local b2 = byte(str, index + 1) + if not b2 then error("Expected more bytes of input.") end + return b2 + 0x100 * (t - 128) - 8192, index + 2 + elseif t == 202 then + return nil, index + 1 + elseif t == 203 or t == 212 then + return number_from_str(str, index) + elseif t == 204 then + return true, index + 1 + elseif t == 205 then + return false, index + 1 + elseif t == 206 then + local length, dataindex = number_from_str(str, index + 1) + local nextindex = dataindex + length + if not (length >= 0) then error("Bad string length") end + if #str < nextindex - 1 then error("Expected more bytes of string") end + local substr = sub(str, dataindex, nextindex - 1) + visited[#visited + 1] = substr + return substr, nextindex + elseif t == 207 or t == 213 then + local mt, count, nextindex + local ret = {} + visited[#visited + 1] = ret + nextindex = index + 1 + if t == 213 then + mt, nextindex = deserialize_value(str, nextindex, visited) + if type(mt) ~= "table" then error("Expected table metatable") end + end + count, nextindex = number_from_str(str, nextindex) + for i = 1, count do + local oldindex = nextindex + ret[i], nextindex = deserialize_value(str, nextindex, visited) + if nextindex == oldindex then error("Expected more bytes of input.") end + end + count, nextindex = number_from_str(str, nextindex) + for i = 1, count do + local k, v + local oldindex = nextindex + k, nextindex = deserialize_value(str, nextindex, visited) + if nextindex == oldindex then error("Expected more bytes of input.") end + oldindex = nextindex + v, nextindex = deserialize_value(str, nextindex, visited) + if nextindex == oldindex then error("Expected more bytes of input.") end + if k == nil then error("Can't have nil table keys") end + ret[k] = v + end + if mt then setmetatable(ret, mt) end + return ret, nextindex + elseif t == 208 then + local ref, nextindex = number_from_str(str, index + 1) + return visited[ref], nextindex + elseif t == 209 then + local count + local name, nextindex = deserialize_value(str, index + 1, visited) + count, nextindex = number_from_str(str, nextindex) + local args = {} + for i = 1, count do + local oldindex = nextindex + args[i], nextindex = deserialize_value(str, nextindex, visited) + if nextindex == oldindex then error("Expected more bytes of input.") end + end + if not name or not deserializers[name] then + error(("Cannot deserialize class '%s'"):format(tostring(name))) + end + local ret = deserializers[name](unpack(args)) + visited[#visited + 1] = ret + return ret, nextindex + elseif t == 210 then + local length, dataindex = number_from_str(str, index + 1) + local nextindex = dataindex + length + if not (length >= 0) then error("Bad string length") end + if #str < nextindex - 1 then error("Expected more bytes of string") end + local ret = loadstring(sub(str, dataindex, nextindex - 1)) + visited[#visited + 1] = ret + return ret, nextindex + elseif t == 211 then + local resname, nextindex = deserialize_value(str, index + 1, visited) + if resname == nil then error("Got nil resource name") end + local res = resources_by_name[resname] + if res == nil then + error(("No resources found for name '%s'"):format(tostring(resname))) + end + return res, nextindex + else + error("Could not deserialize type byte " .. t .. ".") + end + end + + local function serialize(...) + local visited = {[NEXT] = 1, [CTORSTACK] = {}} + local accum = {} + for i = 1, select("#", ...) do + local x = select(i, ...) + types[type(x)](x, visited, accum) + end + return concat(accum) + end + + local function make_file_writer(file) + return setmetatable({}, { + __newindex = function(_, _, v) + file:write(v) + end + }) + end + + local function serialize_to_file(path, mode, ...) + local file, err = io.open(path, mode) + assert(file, err) + local visited = {[NEXT] = 1, [CTORSTACK] = {}} + local accum = make_file_writer(file) + for i = 1, select("#", ...) do + local x = select(i, ...) + types[type(x)](x, visited, accum) + end + -- flush the writer + file:flush() + file:close() + end + + local function writeFile(path, ...) + return serialize_to_file(path, "wb", ...) + end + + local function appendFile(path, ...) + return serialize_to_file(path, "ab", ...) + end + + local function deserialize(str, index) + assert(type(str) == "string", "Expected string to deserialize.") + local vals = {} + index = index or 1 + local visited = {} + local len = 0 + local val + while true do + local nextindex + val, nextindex = deserialize_value(str, index, visited) + if nextindex > index then + len = len + 1 + vals[len] = val + index = nextindex + else + break + end + end + return vals, len + end + + local function deserializeN(str, n, index) + assert(type(str) == "string", "Expected string to deserialize.") + n = n or 1 + assert(type(n) == "number", "Expected a number for parameter n.") + assert(n > 0 and floor(n) == n, "N must be a poitive integer.") + local vals = {} + index = index or 1 + local visited = {} + local len = 0 + local val + while len < n do + local nextindex + val, nextindex = deserialize_value(str, index, visited) + if nextindex > index then + len = len + 1 + vals[len] = val + index = nextindex + else + break + end + end + vals[len + 1] = index + return unpack(vals, 1, n + 1) + end + + local function readFile(path) + local file, err = io.open(path, "rb") + assert(file, err) + local str = file:read("*all") + file:close() + return deserialize(str) + end + + -- Resources + + local function registerResource(resource, name) + type_check(name, "string", "name") + assert(not resources[resource], + "Resource already registered.") + assert(not resources_by_name[name], + format("Resource %q already exists.", name)) + resources_by_name[name] = resource + resources[resource] = name + return resource + end + + local function unregisterResource(name) + type_check(name, "string", "name") + assert(resources_by_name[name], format("Resource %q does not exist.", name)) + local resource = resources_by_name[name] + resources_by_name[name] = nil + resources[resource] = nil + return resource + end + + -- Templating + + local function normalize_template(template) + local ret = {} + for i = 1, #template do + ret[i] = template[i] + end + local non_array_part = {} + -- The non-array part of the template (nested templates) have to be deterministic, so they are sorted. + -- This means that inherently non deterministicly sortable keys (tables, functions) should NOT be used + -- in templates. Looking for way around this. + for k in pairs(template) do + if not_array_index(k, #template) then + non_array_part[#non_array_part + 1] = k + end + end + table.sort(non_array_part) + for i = 1, #non_array_part do + local name = non_array_part[i] + ret[#ret + 1] = {name, normalize_template(template[name])} + end + return ret + end + + local function templatepart_serialize(part, argaccum, x, len) + local extras = {} + local extracount = 0 + for k, v in pairs(x) do + extras[k] = v + extracount = extracount + 1 + end + for i = 1, #part do + local name + if type(part[i]) == "table" then + name = part[i][1] + len = templatepart_serialize(part[i][2], argaccum, x[name], len) + else + name = part[i] + len = len + 1 + argaccum[len] = x[part[i]] + end + if extras[name] ~= nil then + extracount = extracount - 1 + extras[name] = nil + end + end + if extracount > 0 then + argaccum[len + 1] = extras + else + argaccum[len + 1] = nil + end + return len + 1 + end + + local function templatepart_deserialize(ret, part, values, vindex) + for i = 1, #part do + local name = part[i] + if type(name) == "table" then + local newret = {} + ret[name[1]] = newret + vindex = templatepart_deserialize(newret, name[2], values, vindex) + else + ret[name] = values[vindex] + vindex = vindex + 1 + end + end + local extras = values[vindex] + if extras then + for k, v in pairs(extras) do + ret[k] = v + end + end + return vindex + 1 + end + + local function template_serializer_and_deserializer(metatable, template) + return function(x) + local argaccum = {} + local len = templatepart_serialize(template, argaccum, x, 0) + return unpack(argaccum, 1, len) + end, function(...) + local ret = {} + local args = {...} + templatepart_deserialize(ret, template, args, 1) + return setmetatable(ret, metatable) + end + end + + -- Used to serialize classes withh custom serializers and deserializers. + -- If no _serialize or _deserialize (or no _template) value is found in the + -- metatable, then the metatable is registered as a resources. + local function register(metatable, name, serialize, deserialize) + if type(metatable) == "table" then + name = name or metatable.name + serialize = serialize or metatable._serialize + deserialize = deserialize or metatable._deserialize + if (not serialize) or (not deserialize) then + if metatable._template then + -- Register as template + local t = normalize_template(metatable._template) + serialize, deserialize = template_serializer_and_deserializer(metatable, t) + else + -- Register the metatable as a resource. This is semantically + -- similar and more flexible (handles cycles). + registerResource(metatable, name) + return + end + end + elseif type(metatable) == "string" then + name = name or metatable + end + type_check(name, "string", "name") + type_check(serialize, "function", "serialize") + type_check(deserialize, "function", "deserialize") + assert((not ids[metatable]) and (not resources[metatable]), + "Metatable already registered.") + assert((not mts[name]) and (not resources_by_name[name]), + ("Name %q already registered."):format(name)) + mts[name] = metatable + ids[metatable] = name + serializers[name] = serialize + deserializers[name] = deserialize + return metatable + end + + local function unregister(item) + local name, metatable + if type(item) == "string" then -- assume name + name, metatable = item, mts[item] + else -- assume metatable + name, metatable = ids[item], item + end + type_check(name, "string", "name") + mts[name] = nil + if (metatable) then + resources[metatable] = nil + ids[metatable] = nil + end + serializers[name] = nil + deserializers[name] = nil + resources_by_name[name] = nil; + return metatable + end + + local function registerClass(class, name) + name = name or class.name + if class.__instanceDict then -- middleclass + register(class.__instanceDict, name) + else -- assume 30log or similar library + register(class, name) + end + return class + end + + return { + -- aliases + s = serialize, + d = deserialize, + dn = deserializeN, + r = readFile, + w = writeFile, + a = appendFile, + + serialize = serialize, + deserialize = deserialize, + deserializeN = deserializeN, + readFile = readFile, + writeFile = writeFile, + appendFile = appendFile, + register = register, + unregister = unregister, + registerResource = registerResource, + unregisterResource = unregisterResource, + registerClass = registerClass, + + newbinser = newbinser + } +end + +return newbinser() diff --git a/cossim_test.lua b/cossim_test.lua new file mode 100644 index 0000000..f203543 --- /dev/null +++ b/cossim_test.lua @@ -0,0 +1,10 @@ +local util = require "util" + +for dims = 1, 100 do + print(dims, util.expect_cossim(dims)) +end + +dims = 1000 +print(dims, util.expect_cossim(dims)) +dims = 10000 +print(dims, util.expect_cossim(dims)) diff --git a/eig.lua b/eig.lua new file mode 100644 index 0000000..9b9f13c --- /dev/null +++ b/eig.lua @@ -0,0 +1,290 @@ +-- blah + +--local globalize = require "strict" +local nn = require "nn" +local util = require "util" + +local sign = util.sign +local abs = math.abs +local reshape = nn.reshape +local sqrt = math.sqrt +local zeros = nn.zeros + +local function tred2(a) + assert(#a.shape == 2) + assert(a.shape[1] == a.shape[2]) + local n = a.shape[1] + local d = zeros(n) -- diagonal + local e = zeros(n) -- off-diagonal (e[1] is a dummy value?) + + for i = n, 2, -1 do + local l = i - 1 + local ind_li = (l - 1) * n + i + local h = 0 + local scale = 0 + + if l > 1 then + for k = 1, l do + local ind_ki = (k - 1) * n + i + scale = scale + abs(a[ind_ki]) + end + + if scale == 0 then + e[i] = a[ind_li] + else + for k = 1, l do + local ind_ki = (k - 1) * n + i + a[ind_ki] = a[ind_ki] / scale + h = h + a[ind_ki] * a[ind_ki] + end + + local f = a[ind_li] + local g = sqrt(h) + if f >= 0 then g = -g end + e[i] = scale * g + h = h - f * g + a[ind_li] = f - g + f = 0 + + for j = 1, l do + local ind_ij = (i - 1) * n + j + local ind_ji = (j - 1) * n + i + a[ind_ij] = a[ind_ji] / h + g = 0 + for k = 1, j do + local ind_kj = (k - 1) * n + j + local ind_ki = (k - 1) * n + i + g = g + a[ind_kj] * a[ind_ki] + end + for k = j + 1, l do + local ind_jk = (j - 1) * n + k + local ind_ki = (k - 1) * n + i + g = g + a[ind_jk] * a[ind_ki] + end + e[j] = g / h + f = f + e[j] * a[ind_ji] + end + + local hh = f / (h + h) + for j = 1, l do + local ind_ji = (j - 1) * n + i + f = a[ind_ji] + g = e[j] - hh * f + e[j] = g + for k = 1, j do + local ind_kj = (k - 1) * n + j + local ind_ki = (k - 1) * n + i + a[ind_kj] = a[ind_kj] - (f * e[k] + g * a[ind_ki]) + end + end + end + else + e[i] = a[ind_li] + end + + d[i] = h + end + + d[1] = 0 + e[1] = 0 + for i = 1, n do + local l = i - 1 + + if d[i] ~= 0 then + for j = 1, l do + local g = 0 + for k = 1, l do + local ind_ki = (k - 1) * n + i + local ind_jk = (j - 1) * n + k + g = g + a[ind_ki] * a[ind_jk] + end + for k = 1, l do + local ind_ik = (i - 1) * n + k + local ind_jk = (j - 1) * n + k + a[ind_jk] = a[ind_jk] - g * a[ind_ik] + end + end + end + + local ind_ii = (i - 1) * n + i + d[i] = a[ind_ii] + a[ind_ii] = 1 + for j = 1, l do + local ind_ij = (i - 1) * n + j + local ind_ji = (j - 1) * n + i + a[ind_ij] = 0 + a[ind_ji] = 0 + end + end + + return d, e +end + +local function pythag(a, b) + --return sqrt(a * a + b * b) + local abs_a = abs(a) + local abs_b = abs(b) + + if abs_a > abs_b then + local temp = abs_b / abs_a + temp = temp * temp + return abs_a * sqrt(1 + temp) + elseif abs_b ~= 0 then + local temp = abs_a / abs_b + temp = temp * temp + return abs_b * sqrt(1 + temp) + end + + return 0 +end + +local function tqli(d, e, z) + assert(#z.shape == 2) + assert(z.shape[1] == z.shape[2]) + local n = z.shape[1] + assert(#d == n) + assert(#e == n) + + local eps = 1.2e-7 + + for i = 2, n do e[i - 1] = e[i] end + e[n] = 0 + + for l = 1, n do + local iter = 0 + local fucky = 0 + + local m + while true do + m = l + while m <= n - 1 do + local dd = abs(d[m]) + abs(d[m + 1]) + if abs(e[m]) + dd == dd then break end + --if abs(e[m]) <= eps * dd then break end + m = m + 1 + end + + fucky = fucky + 1 + if fucky == 100 then print("fucky!"); break end + if fucky == 1000 then error("super fucky!"); break end + + --print(("l: %i, m: %i"):format(l - 1, m - 1)) + + if m == l then break end + + iter = iter + 1 + if iter >= 32 then error("Too many iterations in tqli") end + + local g = (d[l + 1] - d[l]) / (2 * e[l]) + local r = pythag(g, 1) + g = d[m] - d[l] + e[l] / (g + r * sign(g)) + local s = 1 + local c = 1 + local p = 0 + + for i = m - 1, l, -1 do + local f = s * e[i] + local b = c * e[i] + r = pythag(f, g) + e[i + 1] = r + if r == 0 then + d[i + 1] = d[i + 1] - p + e[m] = 0 + break + end + + s = f / r + c = g / r + g = d[i + 1] - p + r = (d[i] - g) * s + 2 * c * b + p = s * r + d[i + 1] = g + p + g = c * r - b + + for k = 1, n do + if true then + local ind = (i - 1) * n + k + f = z[ind + n] + z[ind + n] = s * z[ind] + c * f + z[ind] = c * z[ind] - s * f + else + local ind = (k - 1) * n + i + f = z[ind + 1] + z[ind + 1] = s * z[ind] + c * f + z[ind] = c * z[ind] - s * f + end + end + end + + if r == 0 and i >= l then + -- continue + else + d[l] = d[l] - p + e[l] = g + e[m] = 0 + end + end + end +end + +--[=[ + +local A = { + 4, 1, -2, 2, + 1, 2, 0, 1, + -2, 0, 3, -2, + 2, 1, -2, -1, +} +reshape(A, 4, 4) + +local d, e = tred2(A) + +--[[ +print(nn.pp(A)) +print(nn.pp(d)) +print(nn.pp(e)) +--]] + +--[[ +{ + 0.248069, 0.744208, 0.620174, 0.000000, + 0.702863, -0.578829, 0.413449, 0.000000, + 0.666667, 0.333333, -0.666667, 0.000000, + 0.000000, 0.000000, 0.000000, 1.000000, +} +{ 2.261538, 1.182906, 5.555556, -1.000000,} +{ 0.000000, -0.092308, 0.895806, 3.000000,} +--]] + +--A = nn.transpose(A) +tqli(d, e, A) + +print(nn.pp(A)) +print(nn.pp(d)) +print(nn.pp(e)) + +local D = zeros{4, 4} +for i = 1, 4 do + D[(i - 1) * 4 + i] = math.exp(d[i]) +end +local out = nn.dot(nn.transpose(A), D) +out = nn.dot(out, A) +print(nn.pp(out)) + +--[[ +{ + 703.414032, 1410.125991, 1478.990752, -43.126976, + -1205.204565, 1573.963121, -940.902581, 319.676625, + 14.433478, 3.884400, -10.434671, 6.841011, + 3.529384, 3.087068, -5.116236, -17.850223, +} +{ 2.273819, 1.072834, 6.818268, -2.164921,} +{ -0.000000, 0.000000, 0.000000, 0.000000,} +--]] + +--]=] + +return { + tred2=tred2, + tqli=tqli, +} diff --git a/es_test.lua b/es_test.lua index 264155b..018d499 100644 --- a/es_test.lua +++ b/es_test.lua @@ -28,15 +28,24 @@ end -- i'm just copying settings from hardmaru's simple_es_example.ipynb. local iterations = 3000 --4000 -local dims = 100 -local popsize = dims + 1 + +local dims, popsize +if false then + dims = 100 + popsize = dims + 1 +else + dims = 30 + popsize = 99 +end + local sigma_init = 0.5 --local es = xnes.Xnes(dims, popsize, 0.1, sigma_init) local es = snes.Snes(dims, popsize, 0.1, sigma_init) --local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true) --local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true) --local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5) -es.min_refresh = 0.5 -- FIXME: needs a better interface. + +es.min_refresh = 1.0 -- FIXME: needs a better interface. if typeof(es) == xnes.Xnes or typeof(es) == snes.Snes @@ -55,9 +64,9 @@ then local util = require "util" util.normalize_sums(es.utility) - es.param_rate = 0.39 - es.sigma_rate = 0.39 - es.covar_rate = 0.39 + es.param_rate = 0.2 --39 + es.sigma_rate = 0.05 --39 + es.covar_rate = 0.1 --39 es.adaptive = false end diff --git a/expm.lua b/expm.lua new file mode 100644 index 0000000..e2d4703 --- /dev/null +++ b/expm.lua @@ -0,0 +1,81 @@ +-- here's a really awful way of computing the matrix exponential. +-- we employ the QR algorithm to find eigenpairs of a given symmetric matrix, +-- then run the ordinary exponent function over the eigenvalues. +-- this only works for symmetric matrices! + +local nn = require "nn" +local qr = require "qr" +local util = require "util" + +local copy = util.copy +local dot = nn.dot +local exp = math.exp +local reshape = nn.reshape +local transpose = nn.transpose +local zeros = nn.zeros + +local function expm(mat) + assert(#mat.shape == 2) + assert(mat.shape[1] == mat.shape[2], "expm input must be square") + --assert(stuff(mat), "expm input must be symmetrical") + + local dims = mat.shape[1] + + local vec = zeros(mat.shape) + for i = 1, dims do + local ind = (i - 1) * dims + i -- diagonal + vec[ind] = 1 + end + + local diag = mat + for i = 1, 10 do + local q, r = qr(diag) + vec = dot(vec, q) + diag = dot(r, q) + end + + for y = 1, dims do + for x = 1, dims do + local ind = (y - 1) * dims + x + if x == y then + diag[ind] = exp(diag[ind]) + else + diag[ind] = 0 + end + end + end + + return dot(dot(vec, diag), transpose(vec)) +end + +local eig = require "eig" +local tred2 = eig.tred2 +local tqli = eig.tqli + +local function expm2(mat) + assert(#mat.shape == 2) + assert(mat.shape[1] == mat.shape[2]) + local dims = mat.shape[1] + + -- new version that computes much better (faster?) eigenpairs + local vec = copy(mat) + local d, e = tred2(vec) + tqli(d, e, vec) + + local diag = {} + for y = 1, dims do + for x = 1, dims do + local ind = (y - 1) * dims + x + if x == y then + diag[ind] = exp(d[x]) + else + diag[ind] = 0 + end + end + end + reshape(diag, dims, dims) + + return dot(dot(transpose(vec), diag), vec) +end + +return expm2 diff --git a/expm_test.lua b/expm_test.lua new file mode 100644 index 0000000..c3994c2 --- /dev/null +++ b/expm_test.lua @@ -0,0 +1,31 @@ +local globalize = require "strict" + +local nn = require "nn" +local expm = require "expm" + +local sqrt = math.sqrt + +local dims = 8 +local n = dims * dims +local M = {} +local coeff = dims * dims * sqrt(dims) +for i = 1, n do M[i] = i / coeff end +M.shape = {dims, dims} + +M = nn.dot(M, nn.transpose(M)) + +print(nn.pp(M, "%9.6f")) + +local exp_M = expm(M) + +print(nn.pp(exp_M, "%9.6f")) + +local binser = require "binser" + +--local dat = binser.s(exp_M) +binser.writeFile("expm.bin", exp_M) +local res, len = binser.readFile("expm.bin") +assert(len == 1) +exp_M = res[1] + +print(nn.pp(exp_M, "%9.6f")) diff --git a/main.lua b/main.lua index 2cf4ebe..c9c02de 100644 --- a/main.lua +++ b/main.lua @@ -284,7 +284,7 @@ local function learn_from_epoch() end local step - if cfg.es == 'ars' and cfg.ars_lips then + if cfg.es == 'ars' then --and cfg.ars_lips then step = es:tell(trial_rewards, current_cost) else step = es:tell(trial_rewards) diff --git a/nn.lua b/nn.lua index d702b40..3770536 100644 --- a/nn.lua +++ b/nn.lua @@ -119,6 +119,8 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast) di = di or 1 depth = depth or 0 + if t == nil then return "nil" end + if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end local dim = t.shape[di] diff --git a/pp_test.lua b/pp_test.lua new file mode 100644 index 0000000..82e62eb --- /dev/null +++ b/pp_test.lua @@ -0,0 +1,7 @@ +local nn = require "nn" + +a = {0,1,2,3,4,5,6,7} +print(nn.pp(a, "%9.4f")) +print(nn.pp(nn.reshape(a, 4, 2), "%9.4f")) +print(nn.pp(nn.reshape(a, 2, 4), "%9.4f")) +print(nn.pp(nn.reshape(a, 2, 2, 2), "%9.4f")) diff --git a/presets.lua b/presets.lua index 9a397f9..ffb9434 100644 --- a/presets.lua +++ b/presets.lua @@ -412,6 +412,103 @@ make_preset{ beta = 1.0, } +make_preset{ + name = 'xnes-recon', -- recon-sidered +-- parent = 'big-scroll-hidden', + parent = 'big-scroll-reduced', + + es = 'xnes', + +-- embed = false, +-- reduce_tiles = 5, +-- hidden_size = 54, + + epoch_trials = 20, + epoch_top_trials = 20, + negate_trials = true, + + deterministic = true, + deviation = 0.1, + + param_decay = 0.01, + + param_rate = 0.2, + sigma_rate = 0.05, + covar_rate = 0.1, +} + +make_preset{ + name = 'arse', + parent = 'big-scroll-reduced', + + deterministic = true, + + es = 'ars', + epoch_trials = 5, + epoch_top_trials = 9999, + deviation = 1.0, + param_rate = 1.0, + beta = 1.0, -- fix the default. + + beta = 5.0, -- oops, i had a dumb bug. + param_decay = 0.01, + momentum = 0.0, +} + +make_preset{ + name = 'arse2', + parent = 'arse', + + beta = 5.0, -- oops, i had a dumb bug. + param_decay = 0.001, + embed = false, +} + +make_preset{ + name = 'arse3', + parent = 'arse2', + + epoch_trials = 10, + + -- note: also disabled sum of squares in ars.lua + beta = 1.0, + + --deviation = 0.5, -- after 500 epochs + deviation = 0.7071, -- after 500+521 epochs +} + +make_preset{ + name = 'arse4', + parent = 'arse3', + + hidden = true, + hidden_size = 68, +} + +make_preset{ + name = 'arse5', + parent = 'arse3', + + -- sum of squares still disabled. i probably won't re-enable it really. + --deviation = 1.0, + deviation = 1.414, -- after 80+360+790 epochs + momentum = 0.8, -- neumann momentum, peaks at this value + --param_decay = 0.003, -- after 80 epochs + --param_decay = 0.01, -- after 80+360 epochs + param_decay = 0.0, -- after 80+360+790 epochs +} + +make_preset{ + name = 'arse6', + parent = 'arse3', + + epoch_trials = 20, + param_rate = 0.25, + deviation = 0.1, -- maybe try 0.15 + momentum = 0.9, -- maybe try 0.8 + param_decay = 0.0, +} + -- end of new stuff make_preset{ diff --git a/qr.lua b/qr.lua index d03ef1e..9dc1b23 100644 --- a/qr.lua +++ b/qr.lua @@ -1,9 +1,11 @@ -local min = math.min -local sqrt = math.sqrt - local nn = require "nn" + +local assert = assert local dot = nn.dot +local ipairs = ipairs +local min = math.min local reshape = nn.reshape +local sqrt = math.sqrt local transpose = nn.transpose local zeros = nn.zeros diff --git a/qr2.lua b/qr2.lua index 8d5f98b..0042e8c 100644 --- a/qr2.lua +++ b/qr2.lua @@ -43,7 +43,7 @@ local function qr(a) local den = 0 for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end for k = j0, j1 do den = den + q[k] * q[k] end - print(num, den) + --print(num, den) if den == 0 then den = 1 end -- TODO: should probably just error. local x = num / den diff --git a/qr3.lua b/qr3.lua new file mode 100644 index 0000000..da5f241 --- /dev/null +++ b/qr3.lua @@ -0,0 +1,79 @@ +-- gram-schmidt process + +local nn = require "nn" + +local dot = nn.dot +local reshape = nn.reshape +local sqrt = math.sqrt +local transpose = nn.transpose +local zeros = nn.zeros + +local function qr(mat) + assert(#mat.shape == 2) + local v = transpose(mat) + local rows = v.shape[1] + local cols = v.shape[2] + + local u = zeros(v.shape) + + local w = {} + local y = 1 + + local function load_row() + local start = (y - 1) * cols + for x = 1, cols do w[x] = v[start + x] end + end + + local function push_row() + local sum = 0 + for _, value in ipairs(w) do sum = sum + value * value end + local norm = sqrt(sum) + + local start = (y - 1) * cols + for x, value in ipairs(w) do u[start + x] = value / norm end + + y = y + 1 + end + + load_row() + push_row() + + local sums = {} + + for i = 2, rows do + load_row() + + for x = 1, cols do sums[x] = 0 end + + for j = 1, i - 1 do + local start = (j - 1) * cols + + local dotted = 0 + for x, value in ipairs(w) do + dotted = dotted + value * u[start + x] + end + + --[[ + local scale = 0 + for x = 1, cols do + local value = u[start + x] + scale = scale + value * value + end + print(scale) + dotted = dotted / scale + --]] + + for x, value in ipairs(sums) do + sums[x] = value + dotted * u[start + x] + end + end + + for x, value in ipairs(w) do w[x] = value - sums[x] end + + push_row() + end + + return transpose(u), dot(u, mat) +end + +return qr diff --git a/qr_test.lua b/qr_test.lua index 46faed3..22f9b19 100644 --- a/qr_test.lua +++ b/qr_test.lua @@ -2,6 +2,7 @@ local globalize = require "strict" local nn = require "nn" local qr = require "qr" local qr2 = require "qr2" +local qr3 = require "qr3" local A @@ -24,7 +25,7 @@ elseif false then } A = nn.reshape(A, 5, 3) -else +elseif false then A = { 1, 2, 0, 2, 3, 1, @@ -44,13 +45,27 @@ else A = nn.reshape(A, 5, 3) --A = nn.transpose(A) +else + A = { + 1, 2, -3, + 2, 4, 5, + -3, 5, 6, + } + + A = nn.reshape(A, 3, 3) end print("A") print(nn.pp(A, "%9.4f")) print() -local Q, R = qr2(A) +local Q, R = qr(A) + +print("Q (reference)") +print(nn.pp(Q, "%9.4f")) +print() + +local Q, R = qr3(A) print("Q") print(nn.pp(Q, "%9.4f")) diff --git a/sign_test.lua b/sign_test.lua new file mode 100644 index 0000000..6ba3036 --- /dev/null +++ b/sign_test.lua @@ -0,0 +1,7 @@ +local util = require("util") +print(util.sign(-1)) +print(util.sign(0)) +print(util.sign(1)) +print(util.sign(-0.1)) +print(util.sign(0.0)) +print(util.sign(0.1)) diff --git a/util.lua b/util.lua index cbb61ab..91c717d 100644 --- a/util.lua +++ b/util.lua @@ -209,6 +209,11 @@ local function cdf(x) -- i don't remember where this is from. local sign = x >= 0 and 1 or -1 return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x))) + + -- more accurate (via GELU paper, might be lifted from elsewhere): + --local const = sqrt(2 / pi) + --return 0.5 * (1 + tanh(const * (1 + 0.044715 * x * x) * x)) + --return 0.5 * (1 + tanh(0.7978845608 * (1 + 0.044715 * x * x) * x)) end local function fitness_shaping(rewards) @@ -283,7 +288,7 @@ local function expect_cossim(n) if n >= 128000 then return 1 / sqrt(pi / 2 * n + 1) elseif n >= 80 then - poly = (2.4674010 * n + -2.4673232) * n + 1.2274046 + poly = (2.4674010 * n - 2.4673232) * n + 1.2274046 return 1 / sqrt(sqrt(poly)) end -- fall-through when it's faster just to compute iteratively. diff --git a/xnes.lua b/xnes.lua index ba94127..a504db8 100644 --- a/xnes.lua +++ b/xnes.lua @@ -16,10 +16,13 @@ local unpack = table.unpack or unpack local Base = require "Base" local nn = require "nn" +local dot = nn.dot local dot_mv = nn.dot_mv local normal = nn.normal local zeros = nn.zeros +local expm = require "expm" + local util = require "util" local argsort = util.argsort @@ -35,22 +38,6 @@ local function make_utility(popsize, out) return utility end -local function make_covars(dims, sigma, out) - local covars = out or zeros{dims, dims} - local c = sigma / dims - -- simplified form of the determinant of the matrix we're going to create: - local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1) - -- multiplying by this constant makes the determinant 1: - local m = pow(1 / det, 1 / dims) - - local filler = c * m - for i=1, #covars do covars[i] = filler end - -- diagonals: - for i=1, dims do covars[i + dims * (i - 1)] = m end - - return covars -end - function Xnes:init(dims, popsize, base_rate, sigma, antithetic) -- heuristic borrowed from CMA-ES: self.dims = dims @@ -67,9 +54,11 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic) self.utility = make_utility(self.popsize) self.mean = zeros{dims} - -- note: this is technically the co-standard-deviation. - -- you can imagine the "s" standing for "sqrt" if you like. - self.covars = make_covars(self.dims, self.sigma, self.covars) + self.covars = zeros{dims, dims} + for i=1, dims do + local ind = (i - 1) * dims + i -- diagonal + self.covars[ind] = 1 + end self.evals = 0 end @@ -202,9 +191,12 @@ function Xnes:tell(scored, noise) end self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma) - for i, v in ipairs(self.covars) do - self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i]) + + -- re-use g_covars from before just to scale it. + for i, v in ipairs(g_covars) do + g_covars[i] = self.covar_rate * 0.5 * v end + self.covars = dot(self.covars, expm(g_covars)) -- bookkeeping: self.noise = nil @@ -214,7 +206,6 @@ end return { make_utility = make_utility, - make_covars = make_covars, Xnes = Xnes, }