Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
Connor Olding | 7462e69c61 | ||
Connor Olding | 08f476e6ac | ||
Connor Olding | a1429a6271 | ||
Connor Olding | b7938a1785 |
107
ars.lua
107
ars.lua
|
@ -1,17 +1,16 @@
|
|||
-- Augmented Random Search
|
||||
-- https://arxiv.org/abs/1803.07055
|
||||
-- with some tweaks (lipschitz stuff) by myself.
|
||||
-- i also added an option for graycode sampling,
|
||||
-- borrowed from a (1+1) optimizer,
|
||||
-- but i haven't yet found a case where it performs better.
|
||||
|
||||
local abs = math.abs
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local print = print
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
|
@ -24,6 +23,7 @@ local zeros = nn.zeros
|
|||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
|
||||
local normalize_sums = util.normalize_sums
|
||||
local sign = util.sign
|
||||
|
||||
|
@ -54,29 +54,16 @@ local function collect_best_indices(scored, top, antithetic)
|
|||
return indices
|
||||
end
|
||||
|
||||
-- Rough local Lipschitz estimate used to scale an antithetic reward.
-- The three scores (positive, negative and unperturbed sample) are treated
-- as points on a quadratic curve; the estimate is based on that curve.
-- dir: the noise direction; only its deviation (from calc_mean_dev) is used.
-- pos, neg, mid: scores of the positive, negative and unperturbed samples.
-- it kinda helps? there's probably a better function to base it around.
local function kinda_lipschitz(dir, pos, neg, mid)
    local _, spread = calc_mean_dev(dir)
    -- score differences relative to the unperturbed sample.
    local below = neg - mid
    local above = pos - mid
    local steepest = max(abs(3 * above + below), abs(above + 3 * below))
    return steepest / (2 * spread)
end
|
||||
|
||||
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
|
||||
momentum)
|
||||
momentum, beta)
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
|
||||
self.param_rate = base_rate
|
||||
self.sigma_rate = base_rate
|
||||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic == nil and true or antithetic
|
||||
self.momentum = momentum or 0
|
||||
self.beta = beta or 1.0
|
||||
|
||||
self.poptop = poptop or popsize
|
||||
assert(self.poptop <= popsize)
|
||||
|
@ -104,7 +91,7 @@ function Ars:decay(param_decay, sigma_decay)
|
|||
end
|
||||
end
|
||||
|
||||
function Ars:ask(graycode)
|
||||
function Ars:ask()
|
||||
local asked = {}
|
||||
local noise = {}
|
||||
|
||||
|
@ -119,20 +106,11 @@ function Ars:ask(graycode)
|
|||
for j, v in ipairs(old_noisy) do
|
||||
noisy[j] = -v
|
||||
end
|
||||
else
|
||||
if graycode ~= nil then
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = exp(-precision * uniform())
|
||||
end
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
|
||||
end
|
||||
else
|
||||
for j = 1, self.dims do
|
||||
noisy[j] = self.sigma * normal()
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
for j, v in ipairs(self._params) do
|
||||
asking[j] = v + noisy[j]
|
||||
|
@ -144,9 +122,8 @@ function Ars:ask(graycode)
|
|||
end
|
||||
|
||||
function Ars:tell(scored, unperturbed_score)
|
||||
local use_lips = unperturbed_score ~= nil and self.antithetic
|
||||
self.evals = self.evals + #scored
|
||||
if use_lips then self.evals = self.evals + 1 end
|
||||
if unperturbed_score ~= nil then self.evals = self.evals + 1 end
|
||||
|
||||
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
|
||||
|
||||
|
@ -167,7 +144,16 @@ function Ars:tell(scored, unperturbed_score)
|
|||
end
|
||||
|
||||
local step = zeros(self.dims)
|
||||
local _, reward_dev = calc_mean_dev(top_rewards)
|
||||
|
||||
local _, reward_dev
|
||||
if unperturbed_score ~= nil then
|
||||
-- new stuff:
|
||||
insert(top_rewards, unperturbed_score)
|
||||
_, reward_dev = calc_mean_dev_unbiased(top_rewards)
|
||||
remove(top_rewards)
|
||||
else
|
||||
_, reward_dev = calc_mean_dev(top_rewards)
|
||||
end
|
||||
if reward_dev == 0 then reward_dev = 1 end
|
||||
|
||||
if self.antithetic then
|
||||
|
@ -177,31 +163,39 @@ function Ars:tell(scored, unperturbed_score)
|
|||
local reward = pos - neg
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[ind * 2 - 1]
|
||||
if use_lips then
|
||||
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
|
||||
reward = reward / lips / self.sigma
|
||||
else
|
||||
reward = reward / reward_dev
|
||||
|
||||
--[[ new stuff:
|
||||
local sum_of_squares = 0
|
||||
for _, v in ipairs(noisy) do
|
||||
sum_of_squares = sum_of_squares + v * v
|
||||
end
|
||||
reward = reward / sqrt(sum_of_squares)
|
||||
-]]
|
||||
|
||||
local scale = reward / self.poptop * self.beta / 2
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + reward * v / self.poptop
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
error("TODO: update with sum of squares stuff")
|
||||
for i, ind in ipairs(indices) do
|
||||
local reward = top_rewards[i] / reward_dev
|
||||
if reward ~= 0 then
|
||||
local noisy = self.noise[ind]
|
||||
|
||||
local scale = reward / self.poptop * self.beta
|
||||
for j, v in ipairs(noisy) do
|
||||
step[j] = step[j] + reward * v / self.poptop
|
||||
step[j] = step[j] + scale * v
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--[[ powersign momentum
|
||||
if self.momentum > 0 then
|
||||
for i, v in ipairs(step) do
|
||||
self.accum[i] = self.momentum * self.accum[i] + v
|
||||
|
@ -212,6 +206,35 @@ function Ars:tell(scored, unperturbed_score)
|
|||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
--]]
|
||||
|
||||
-- neumann momentum
|
||||
if self.momentum > 0 then
|
||||
local count = self.count or 0
|
||||
local period = 10
|
||||
local mu = 1 - 1 / (1 + count % period)
|
||||
mu = self.momentum / (1 - 1 / period) * mu
|
||||
self.count = count + 1
|
||||
-- mu is intentionally 0 for one iteration.
|
||||
|
||||
-- make learning rate invariant to sigma.
|
||||
for i, v in ipairs(step) do
|
||||
step[i] = v / self.sigma
|
||||
end
|
||||
|
||||
-- update neumann iterate.
|
||||
for i, v in ipairs(self.accum) do
|
||||
self.accum[i] = mu * v - self.param_rate * step[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v - mu * self.accum[i] + self.param_rate * step[i]
|
||||
end
|
||||
else
|
||||
for i, v in ipairs(self._params) do
|
||||
self._params[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
end
|
||||
|
||||
self.noise = nil
|
||||
|
||||
|
|
749
binser.lua
Normal file
749
binser.lua
Normal file
|
@ -0,0 +1,749 @@
|
|||
-- binser.lua
|
||||
|
||||
--[[
|
||||
Copyright (c) 2016 Calvin Rose
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
]]
|
||||
|
||||
local assert = assert
|
||||
local error = error
|
||||
local select = select
|
||||
local pairs = pairs
|
||||
local getmetatable = getmetatable
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local loadstring = loadstring or load
|
||||
local concat = table.concat
|
||||
local char = string.char
|
||||
local byte = string.byte
|
||||
local format = string.format
|
||||
local sub = string.sub
|
||||
local dump = string.dump
|
||||
local floor = math.floor
|
||||
local frexp = math.frexp
|
||||
local unpack = unpack or table.unpack
|
||||
|
||||
-- Lua 5.3 frexp polyfill
|
||||
-- From https://github.com/excessive/cpml/blob/master/modules/utils.lua
|
||||
if not frexp then
    local log, abs, floor = math.log, math.abs, math.floor
    local log2 = log(2)
    local inf = 1 / 0
    -- Fallback used when math.frexp is unavailable.
    -- Returns m, e such that x == m * 2 ^ e.
    frexp = function(x)
        if x == 0 then return 0, 0 end
        -- NaN and +/-infinity have no finite exponent. Without this guard the
        -- formula below evaluates inf / inf and hands back NaN for infinities,
        -- so pass non-finite values through unchanged with exponent 0.
        if x ~= x or x == inf or x == -inf then return x, 0 end
        local e = floor(log(abs(x)) / log2 + 1)
        return x / 2 ^ e, e
    end
end
|
||||
|
||||
-- Collect the arguments into a table and also return how many there were,
-- so trailing nils are not lost.
local function pack(...)
    local count = select("#", ...)
    return {...}, count
end
|
||||
|
||||
-- True when x cannot be an index into the array part 1..len:
-- it is not a number, lies outside 1..len, or is not a whole number.
local function not_array_index(x, len)
    if type(x) ~= "number" then return true end
    if x < 1 or x > len then return true end
    return x ~= floor(x)
end
|
||||
|
||||
-- Assert that x has Lua type `tp`; `name` is the parameter name quoted in
-- the failure message.
local function type_check(x, tp, name)
    local message = format("Expected parameter %q to be of type %q.", name, tp)
    assert(type(x) == tp, message)
end
|
||||
|
||||
local bigIntSupport = false
|
||||
local isInteger
|
||||
if math.type then -- Detect Lua 5.3
|
||||
local mtype = math.type
|
||||
bigIntSupport = loadstring[[
|
||||
local char = string.char
|
||||
return function(n)
|
||||
local nn = n < 0 and -(n + 1) or n
|
||||
local b1 = nn // 0x100000000000000
|
||||
local b2 = nn // 0x1000000000000 % 0x100
|
||||
local b3 = nn // 0x10000000000 % 0x100
|
||||
local b4 = nn // 0x100000000 % 0x100
|
||||
local b5 = nn // 0x1000000 % 0x100
|
||||
local b6 = nn // 0x10000 % 0x100
|
||||
local b7 = nn // 0x100 % 0x100
|
||||
local b8 = nn % 0x100
|
||||
if n < 0 then
|
||||
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
|
||||
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
|
||||
end
|
||||
return char(212, b1, b2, b3, b4, b5, b6, b7, b8)
|
||||
end]]()
|
||||
isInteger = function(x)
|
||||
return mtype(x) == 'integer'
|
||||
end
|
||||
else
|
||||
isInteger = function(x)
|
||||
return floor(x) == x
|
||||
end
|
||||
end
|
||||
|
||||
-- Copyright (C) 2012-2015 Francois Perrad.
|
||||
-- number serialization code modified from https://github.com/fperrad/lua-MessagePack
|
||||
-- Encode a number as a big-endian ieee-754 double, big-endian signed 64 bit integer, or a small integer
|
||||
-- Encode a number as a byte string, choosing the shortest form:
--   * integers in -27..100       -> 1 byte  (value + 27)
--   * integers in -8192..8191    -> 2 bytes (first byte 128..191)
--   * other integers             -> tag 212 + 8 bytes, when bigIntSupport is set
--   * everything else            -> tag 203 + big-endian ieee-754 double
local function number_to_str(n)
    if isInteger(n) then
        if n >= -27 and n <= 100 then
            -- 1 byte, 7 bits of data
            return char(n + 27)
        elseif n >= -8192 and n <= 8191 then
            -- 2 bytes, 14 bits of data
            local biased = n + 8192
            return char(128 + (floor(biased / 0x100) % 0x100), biased % 0x100)
        elseif bigIntSupport then
            return bigIntSupport(n)
        end
    end
    -- Fall through to the double encoding; work on the magnitude and keep
    -- the sign as the top bit of the first payload byte.
    local sign = 0
    if n < 0.0 then
        sign = 0x80
        n = -n
    end
    local mantissa, exponent = frexp(n)
    if mantissa ~= mantissa then
        -- NaN
        return char(203, 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
    elseif mantissa == 1/0 then
        -- infinity, signed
        if sign == 0 then
            return char(203, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
        else
            return char(203, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
        end
    end
    exponent = exponent + 0x3FE
    if exponent < 1 then
        -- denormalized numbers
        mantissa = mantissa * 2 ^ (52 + exponent)
        exponent = 0
    else
        mantissa = (mantissa * 2 - 1) * 2 ^ 52
    end
    return char(203,
        sign + floor(exponent / 0x10),
        (exponent % 0x10) * 0x10 + floor(mantissa / 0x1000000000000),
        floor(mantissa / 0x10000000000) % 0x100,
        floor(mantissa / 0x100000000) % 0x100,
        floor(mantissa / 0x1000000) % 0x100,
        floor(mantissa / 0x10000) % 0x100,
        floor(mantissa / 0x100) % 0x100,
        mantissa % 0x100)
end
|
||||
|
||||
-- Copyright (C) 2012-2015 Francois Perrad.
|
||||
-- number deserialization code also modified from https://github.com/fperrad/lua-MessagePack
|
||||
-- Decode a number starting at byte `index` of `str`.
-- Returns the number and the index of the first byte after it.
-- Raises an error when the input is truncated or the tag byte is not one of
-- the numeric forms (small int, 2-byte int, 212 = int64, 203 = double).
local function number_from_str(str, index)
    local tag = byte(str, index)
    if not tag then error("Expected more bytes of input.") end
    if tag < 128 then
        -- 1-byte integer
        return tag - 27, index + 1
    elseif tag < 192 then
        -- 2-byte integer
        local low = byte(str, index + 1)
        if not low then error("Expected more bytes of input.") end
        return low + 0x100 * (tag - 128) - 8192, index + 2
    end
    -- Both remaining forms carry an 8-byte payload after the tag.
    local b1, b2, b3, b4, b5, b6, b7, b8 = byte(str, index + 1, index + 8)
    if not (b1 and b2 and b3 and b4 and b5 and b6 and b7 and b8) then
        error("Expected more bytes of input.")
    end
    if tag == 212 then
        -- A set top bit marks a negative value stored with every byte inverted.
        local negative = b1 >= 128
        if negative then
            b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
            b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
        end
        local magnitude = ((((((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4) *
            0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
        if negative then
            return (-magnitude) - 1, index + 9
        else
            return magnitude, index + 9
        end
    end
    if tag ~= 203 then
        error("Expected number")
    end
    -- ieee-754 double: 1 sign bit, 11 exponent bits, 52 mantissa bits.
    local sign = b1 > 0x7F and -1 or 1
    local e = (b1 % 0x80) * 0x10 + floor(b2 / 0x10)
    local m = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
    local n
    if e == 0 then
        if m == 0 then
            n = sign * 0.0
        else
            -- denormalized
            n = sign * (m / 2 ^ 52) * 2 ^ -1022
        end
    elseif e == 0x7FF then
        if m == 0 then
            n = sign * (1/0)
        else
            n = 0.0/0.0
        end
    else
        n = sign * (1.0 + m / 2 ^ 52) * 2 ^ (e - 0x3FF)
    end
    return n, index + 9
end
|
||||
|
||||
|
||||
local function newbinser()
|
||||
|
||||
-- unique table key for getting next value
|
||||
local NEXT = {}
|
||||
local CTORSTACK = {}
|
||||
|
||||
-- NIL = 202
|
||||
-- FLOAT = 203
|
||||
-- TRUE = 204
|
||||
-- FALSE = 205
|
||||
-- STRING = 206
|
||||
-- TABLE = 207
|
||||
-- REFERENCE = 208
|
||||
-- CONSTRUCTOR = 209
|
||||
-- FUNCTION = 210
|
||||
-- RESOURCE = 211
|
||||
-- INT64 = 212
|
||||
-- TABLE WITH META = 213
|
||||
|
||||
local mts = {}
|
||||
local ids = {}
|
||||
local serializers = {}
|
||||
local deserializers = {}
|
||||
local resources = {}
|
||||
local resources_by_name = {}
|
||||
local types = {}
|
||||
|
||||
-- nil is written as the single tag byte 202.
types["nil"] = function(x, visited, accum)
    local slot = #accum + 1
    accum[slot] = "\202"
end
|
||||
|
||||
-- Numbers are appended in the encoding produced by number_to_str.
function types.number(x, visited, accum)
    local encoded = number_to_str(x)
    accum[#accum + 1] = encoded
end
|
||||
|
||||
-- Booleans are a single tag byte: 204 for true, 205 for false.
function types.boolean(x, visited, accum)
    local tag
    if x then
        tag = "\204"
    else
        tag = "\205"
    end
    accum[#accum + 1] = tag
end
|
||||
|
||||
-- Write a string. The first occurrence is emitted as tag 206, its length and
-- its bytes, and is assigned the next reference id; later occurrences are
-- emitted as tag 208 (reference) followed by that id.
function types.string(x, visited, accum)
    local base = #accum
    local ref = visited[x]
    if ref then
        accum[base + 1] = "\208"
        accum[base + 2] = number_to_str(ref)
        return
    end
    local id = visited[NEXT]
    visited[x] = id
    visited[NEXT] = id + 1
    accum[base + 1] = "\206"
    accum[base + 2] = number_to_str(#x)
    accum[base + 3] = x
end
|
||||
|
||||
-- Try to write x through a registered resource or a registered constructor.
-- Returns true when x was written this way; returns nothing otherwise so the
-- caller can fall back to its default handling.
local function check_custom_type(x, visited, accum)
    -- Registered resource: tag 211 followed by the resource's name.
    local resname = resources[x]
    if resname then
        accum[#accum + 1] = "\211"
        types[type(resname)](resname, visited, accum)
        return true
    end
    -- Registered metatable: tag 209, the registered name, the argument
    -- count, then each value returned by the custom serializer.
    local mt = getmetatable(x)
    local id = mt and ids[mt]
    if not id then return end
    local in_progress = visited[CTORSTACK]
    if in_progress[x] then
        error("Infinite loop in constructor.")
    end
    in_progress[x] = true
    accum[#accum + 1] = "\209"
    types[type(id)](id, visited, accum)
    local args, argc = pack(serializers[id](x))
    accum[#accum + 1] = number_to_str(argc)
    for i = 1, argc do
        local arg = args[i]
        types[type(arg)](arg, visited, accum)
    end
    -- The object only receives its reference id after its arguments have
    -- been written.
    visited[x] = visited[NEXT]
    visited[NEXT] = visited[NEXT] + 1
    -- We finished constructing
    in_progress[x] = nil
    return true
end
|
||||
|
||||
-- Write a userdata value: a reference (tag 208) if it was already written,
-- otherwise whatever check_custom_type produces. Unregistered userdata
-- cannot be serialized and raises an error.
function types.userdata(x, visited, accum)
    local ref = visited[x]
    if ref then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(ref)
    elseif not check_custom_type(x, visited, accum) then
        error("Cannot serialize this userdata.")
    end
end
|
||||
|
||||
-- Write a table. Layout for a table seen for the first time:
--   tag 213 + metatable (when it has one) or tag 207 (when it has none),
--   array length, the array values 1..#x,
--   count of remaining keys, then each remaining key/value pair.
-- A table that was already written is emitted as a reference (tag 208).
function types.table(x, visited, accum)
    local ref = visited[x]
    if ref then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(ref)
        return
    end
    if check_custom_type(x, visited, accum) then return end
    -- The id is assigned before the contents are written, so a table that
    -- contains itself is written as a reference.
    visited[x] = visited[NEXT]
    visited[NEXT] = visited[NEXT] + 1
    local xlen = #x
    local mt = getmetatable(x)
    if mt then
        accum[#accum + 1] = "\213"
        types.table(mt, visited, accum)
    else
        accum[#accum + 1] = "\207"
    end
    -- array part
    accum[#accum + 1] = number_to_str(xlen)
    for i = 1, xlen do
        local v = x[i]
        types[type(v)](v, visited, accum)
    end
    -- The number of non-array keys has to be written before the pairs, so
    -- the table is traversed twice: once to count, once to write.
    local key_count = 0
    for k in pairs(x) do
        if not_array_index(k, xlen) then
            key_count = key_count + 1
        end
    end
    accum[#accum + 1] = number_to_str(key_count)
    for k, v in pairs(x) do
        if not_array_index(k, xlen) then
            types[type(k)](k, visited, accum)
            types[type(v)](v, visited, accum)
        end
    end
end
|
||||
|
||||
-- Write a function: a reference (tag 208) if already written, a registered
-- resource/constructor if check_custom_type handles it, otherwise tag 210
-- followed by the length and bytes of string.dump(x).
types["function"] = function(x, visited, accum)
    local ref = visited[x]
    if ref then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(ref)
        return
    end
    if check_custom_type(x, visited, accum) then return end
    visited[x] = visited[NEXT]
    visited[NEXT] = visited[NEXT] + 1
    local bytecode = dump(x)
    accum[#accum + 1] = "\210"
    accum[#accum + 1] = number_to_str(#bytecode)
    accum[#accum + 1] = bytecode
end
|
||||
|
||||
-- Write a cdata value: a reference (tag 208) if it was already written,
-- otherwise whatever check_custom_type produces. Unregistered cdata cannot
-- be serialized and raises an error.
types.cdata = function(x, visited, accum)
    if visited[x] then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(visited[x])
    else
        -- Pass the accumulator itself, as every other type writer does.
        -- Passing its length (#accum) made check_custom_type index a number
        -- as soon as it tried to write a registered cdata value.
        if check_custom_type(x, visited, accum) then return end
        error("Cannot serialize this cdata.")
    end
end
|
||||
|
||||
-- Coroutines (type "thread") are never serializable; always raises.
types.thread = function() error("Cannot serialize threads.") end
|
||||
|
||||
-- Decode one value starting at byte `index` of `str`.
-- `visited` is the list of referenceable values decoded so far; strings,
-- tables, constructed objects and functions are appended to it so that
-- reference records (tag 208) can look them up by position.
-- Returns the value and the index of the first byte after it. At end of
-- input it returns nil and the unchanged index, which callers use to detect
-- that nothing was consumed.
local function deserialize_value(str, index, visited)
    local tag = byte(str, index)
    if not tag then return nil, index end
    if tag < 128 then
        -- 1-byte integer
        return tag - 27, index + 1
    elseif tag < 192 then
        -- 2-byte integer
        local low = byte(str, index + 1)
        if not low then error("Expected more bytes of input.") end
        return low + 0x100 * (tag - 128) - 8192, index + 2
    elseif tag == 202 then
        -- NIL
        return nil, index + 1
    elseif tag == 203 or tag == 212 then
        -- FLOAT / INT64
        return number_from_str(str, index)
    elseif tag == 204 then
        -- TRUE
        return true, index + 1
    elseif tag == 205 then
        -- FALSE
        return false, index + 1
    elseif tag == 206 then
        -- STRING: length, then that many raw bytes
        local length, dataindex = number_from_str(str, index + 1)
        local nextindex = dataindex + length
        if not (length >= 0) then error("Bad string length") end
        if #str < nextindex - 1 then error("Expected more bytes of string") end
        local substr = sub(str, dataindex, nextindex - 1)
        visited[#visited + 1] = substr
        return substr, nextindex
    elseif tag == 207 or tag == 213 then
        -- TABLE / TABLE WITH META
        local mt, count, nextindex
        local ret = {}
        -- Recorded before the contents are read so nested references to
        -- this table resolve.
        visited[#visited + 1] = ret
        nextindex = index + 1
        if tag == 213 then
            mt, nextindex = deserialize_value(str, nextindex, visited)
            if type(mt) ~= "table" then error("Expected table metatable") end
        end
        -- array part
        count, nextindex = number_from_str(str, nextindex)
        for i = 1, count do
            local oldindex = nextindex
            ret[i], nextindex = deserialize_value(str, nextindex, visited)
            if nextindex == oldindex then error("Expected more bytes of input.") end
        end
        -- key/value part
        count, nextindex = number_from_str(str, nextindex)
        for i = 1, count do
            local k, v
            local oldindex = nextindex
            k, nextindex = deserialize_value(str, nextindex, visited)
            if nextindex == oldindex then error("Expected more bytes of input.") end
            oldindex = nextindex
            v, nextindex = deserialize_value(str, nextindex, visited)
            if nextindex == oldindex then error("Expected more bytes of input.") end
            if k == nil then error("Can't have nil table keys") end
            ret[k] = v
        end
        if mt then setmetatable(ret, mt) end
        return ret, nextindex
    elseif tag == 208 then
        -- REFERENCE to an earlier value
        local ref, nextindex = number_from_str(str, index + 1)
        return visited[ref], nextindex
    elseif tag == 209 then
        -- CONSTRUCTOR: registered name, argument count, arguments
        local count
        local name, nextindex = deserialize_value(str, index + 1, visited)
        count, nextindex = number_from_str(str, nextindex)
        local args = {}
        for i = 1, count do
            local oldindex = nextindex
            args[i], nextindex = deserialize_value(str, nextindex, visited)
            if nextindex == oldindex then error("Expected more bytes of input.") end
        end
        if not name or not deserializers[name] then
            error(("Cannot deserialize class '%s'"):format(tostring(name)))
        end
        local ret = deserializers[name](unpack(args))
        visited[#visited + 1] = ret
        return ret, nextindex
    elseif tag == 210 then
        -- FUNCTION: length, then the dumped chunk, which is loaded here
        local length, dataindex = number_from_str(str, index + 1)
        local nextindex = dataindex + length
        if not (length >= 0) then error("Bad string length") end
        if #str < nextindex - 1 then error("Expected more bytes of string") end
        local ret = loadstring(sub(str, dataindex, nextindex - 1))
        visited[#visited + 1] = ret
        return ret, nextindex
    elseif tag == 211 then
        -- RESOURCE: looked up by its registered name
        local resname, nextindex = deserialize_value(str, index + 1, visited)
        if resname == nil then error("Got nil resource name") end
        local res = resources_by_name[resname]
        if res == nil then
            error(("No resources found for name '%s'"):format(tostring(resname)))
        end
        return res, nextindex
    else
        error("Could not deserialize type byte " .. tag .. ".")
    end
end
|
||||
|
||||
-- Serialize every argument, in order, and return the result as one string.
local function serialize(...)
    local accum = {}
    -- Reference ids start at 1; CTORSTACK tracks objects whose custom
    -- serializer is currently running.
    local visited = {[NEXT] = 1, [CTORSTACK] = {}}
    local argc = select("#", ...)
    for i = 1, argc do
        local value = (select(i, ...))
        types[type(value)](value, visited, accum)
    end
    return concat(accum)
end
|
||||
|
||||
-- Build an accumulator stand-in whose assignments are written straight to
-- `file` instead of being stored: the table stays empty, so every
-- `accum[k] = v` goes through __newindex.
local function make_file_writer(file)
    local sink = {}
    local function forward(_proxy, _key, chunk)
        file:write(chunk)
    end
    return setmetatable(sink, { __newindex = forward })
end
|
||||
|
||||
-- Serialize every extra argument, in order, into the file at `path`, opened
-- with the given io.open `mode`. Raises if the file cannot be opened or if
-- any value cannot be serialized.
local function serialize_to_file(path, mode, ...)
    local file, err = io.open(path, mode)
    assert(file, err)
    local visited = {[NEXT] = 1, [CTORSTACK] = {}}
    local accum = make_file_writer(file)
    -- Run the writers under pcall so the handle is always released; before,
    -- an unserializable value left the file open.
    local ok, failure = pcall(function(...)
        for i = 1, select("#", ...) do
            local x = select(i, ...)
            types[type(x)](x, visited, accum)
        end
    end, ...)
    -- flush the writer
    file:flush()
    file:close()
    -- Re-raise with level 0 so the original message is passed on unchanged.
    if not ok then error(failure, 0) end
end
|
||||
|
||||
-- Serialize the given values into the file at `path`, opening it in "wb" mode.
local function writeFile(path, ...)
|
||||
return serialize_to_file(path, "wb", ...)
|
||||
end
|
||||
|
||||
-- Serialize the given values into the file at `path`, opening it in "ab" mode.
local function appendFile(path, ...)
|
||||
return serialize_to_file(path, "ab", ...)
|
||||
end
|
||||
|
||||
-- Decode every value in `str`, starting at `index` (default 1).
-- Returns a table of the values and how many there are; the count is
-- returned separately because decoded values may be nil.
local function deserialize(str, index)
    assert(type(str) == "string", "Expected string to deserialize.")
    local visited = {}
    local vals, len = {}, 0
    index = index or 1
    while true do
        local val, nextindex = deserialize_value(str, index, visited)
        -- No progress means the input is exhausted.
        if not (nextindex > index) then break end
        len = len + 1
        vals[len] = val
        index = nextindex
    end
    return vals, len
end
|
||||
|
||||
-- Decode at most `n` values (default 1) from `str`, starting at `index`
-- (default 1).
-- Returns n + 1 results: the decoded values followed by the index of the
-- next unread byte. When the input holds fewer than n values, the index
-- directly follows the values that were read and the remaining results
-- are nil.
local function deserializeN(str, n, index)
    assert(type(str) == "string", "Expected string to deserialize.")
    n = n or 1
    assert(type(n) == "number", "Expected a number for parameter n.")
    assert(n > 0 and floor(n) == n, "N must be a positive integer.")
    local vals = {}
    index = index or 1
    local visited = {}
    local len = 0
    local val
    while len < n do
        local nextindex
        val, nextindex = deserialize_value(str, index, visited)
        if nextindex > index then
            len = len + 1
            vals[len] = val
            index = nextindex
        else
            -- nothing was consumed: end of input
            break
        end
    end
    vals[len + 1] = index
    return unpack(vals, 1, n + 1)
end
|
||||
|
||||
-- Read the whole file at `path` in binary mode and decode every value in it.
-- Raises if the file cannot be opened. Returns what deserialize returns.
local function readFile(path)
    local file = assert(io.open(path, "rb"))
    local contents = file:read("*all")
    file:close()
    return deserialize(contents)
end
|
||||
|
||||
-- Resources
|
||||
|
||||
-- Register `resource` under `name` so it is serialized by name rather than
-- by value. Both the resource and the name must be unused. Returns the
-- resource.
local function registerResource(resource, name)
    type_check(name, "string", "name")
    assert(not resources[resource], "Resource already registered.")
    local clash = format("Resource %q already exists.", name)
    assert(not resources_by_name[name], clash)
    resources[resource] = name
    resources_by_name[name] = resource
    return resource
end
|
||||
|
||||
-- Remove the resource registered under `name` and return it.
-- Raises if no resource has that name.
local function unregisterResource(name)
    type_check(name, "string", "name")
    local resource = resources_by_name[name]
    assert(resource, format("Resource %q does not exist.", name))
    resources_by_name[name] = nil
    resources[resource] = nil
    return resource
end
|
||||
|
||||
-- Templating
|
||||
|
||||
-- Turn a template into a flat, ordered list: first its array entries, then
-- one {key, normalized nested template} pair per non-array key.
local function normalize_template(template)
    local count = #template
    local ret = {}
    for i = 1, count do
        ret[i] = template[i]
    end
    -- The non-array part of the template (nested templates) has to be
    -- deterministic, so its keys are sorted. This means that keys which
    -- cannot be sorted deterministically (tables, functions) should NOT be
    -- used in templates. Looking for a way around this.
    local keys = {}
    for k in pairs(template) do
        if not_array_index(k, count) then
            keys[#keys + 1] = k
        end
    end
    table.sort(keys)
    for i = 1, #keys do
        local name = keys[i]
        ret[#ret + 1] = {name, normalize_template(template[name])}
    end
    return ret
end
|
||||
|
||||
-- Flatten the fields of `x` named by the normalized template `part` into
-- `argaccum`, starting after position `len`.
-- Fields of x that the template does not mention are gathered into one
-- "extras" table stored in the slot after the templated values (nil when
-- there are none). Returns the last position used, extras slot included.
local function templatepart_serialize(part, argaccum, x, len)
    -- Start with a copy of every field; templated ones are struck off below.
    local extras, extracount = {}, 0
    for k, v in pairs(x) do
        extras[k] = v
        extracount = extracount + 1
    end
    for i = 1, #part do
        local entry = part[i]
        local name
        if type(entry) == "table" then
            -- nested template: {field name, sub-template}
            name = entry[1]
            len = templatepart_serialize(entry[2], argaccum, x[name], len)
        else
            name = entry
            len = len + 1
            argaccum[len] = x[entry]
        end
        if extras[name] ~= nil then
            extracount = extracount - 1
            extras[name] = nil
        end
    end
    len = len + 1
    if extracount > 0 then
        argaccum[len] = extras
    else
        argaccum[len] = nil
    end
    return len
end
|
||||
|
||||
-- Inverse of templatepart_serialize: fill `ret` from the flat list `values`,
-- reading from position `vindex` and following the normalized template
-- `part`. The slot after the templated values may hold an extras table,
-- whose pairs are copied into `ret`. Returns the next unread position.
local function templatepart_deserialize(ret, part, values, vindex)
    for i = 1, #part do
        local entry = part[i]
        if type(entry) == "table" then
            -- nested template: {field name, sub-template}
            local nested = {}
            ret[entry[1]] = nested
            vindex = templatepart_deserialize(nested, entry[2], values, vindex)
        else
            ret[entry] = values[vindex]
            vindex = vindex + 1
        end
    end
    local extras = values[vindex]
    if extras then
        for k, v in pairs(extras) do
            ret[k] = v
        end
    end
    return vindex + 1
end
|
||||
|
||||
-- Build the serializer/deserializer pair for a class described by a
-- normalized `template`. The serializer returns the object's flattened
-- fields; the deserializer rebuilds a table from them and gives it
-- `metatable`.
local function template_serializer_and_deserializer(metatable, template)
    local function to_args(x)
        local argaccum = {}
        local len = templatepart_serialize(template, argaccum, x, 0)
        return unpack(argaccum, 1, len)
    end
    local function from_args(...)
        local ret = {}
        local args = {...}
        templatepart_deserialize(ret, template, args, 1)
        return setmetatable(ret, metatable)
    end
    return to_args, from_args
end
|
||||
|
||||
-- Used to serialize classes with custom serializers and deserializers.
|
||||
-- If no _serialize or _deserialize (or no _template) value is found in the
|
||||
-- metatable, then the metatable is registered as a resource.
|
||||
-- Register a class so its instances are written through a custom
-- serializer/deserializer pair.
-- metatable: the class metatable, or a string that then also serves as the
--   default name.
-- name, serialize, deserialize: optional; for a table they default to the
--   metatable's `name`, `_serialize` and `_deserialize` fields.
-- When a table lacks a serializer or deserializer, a pair is generated from
-- its `_template`; without a template the metatable is registered as a
-- resource instead and nothing is returned.
-- Otherwise returns the metatable. Raises on bad argument types or when the
-- metatable or name is already taken.
local function register(metatable, name, serialize, deserialize)
    local kind = type(metatable)
    if kind == "table" then
        name = name or metatable.name
        serialize = serialize or metatable._serialize
        deserialize = deserialize or metatable._deserialize
        if not (serialize and deserialize) then
            if metatable._template then
                -- Register as template
                local normalized = normalize_template(metatable._template)
                serialize, deserialize =
                    template_serializer_and_deserializer(metatable, normalized)
            else
                -- Register the metatable as a resource. This is semantically
                -- similar and more flexible (handles cycles).
                registerResource(metatable, name)
                return
            end
        end
    elseif kind == "string" then
        name = name or metatable
    end
    type_check(name, "string", "name")
    type_check(serialize, "function", "serialize")
    type_check(deserialize, "function", "deserialize")
    assert((not ids[metatable]) and (not resources[metatable]),
        "Metatable already registered.")
    assert((not mts[name]) and (not resources_by_name[name]),
        ("Name %q already registered."):format(name))
    ids[metatable] = name
    mts[name] = metatable
    serializers[name] = serialize
    deserializers[name] = deserialize
    return metatable
end
|
||||
|
||||
-- Undo a registration. `item` is either the registered name (a string) or
-- the registered metatable. Clears the entry from the class tables and from
-- the resource tables, and returns the metatable (nil when a name was given
-- that has no metatable on record).
local function unregister(item)
    local name, metatable
    if type(item) == "string" then
        -- assume name
        name = item
        metatable = mts[item]
    else
        -- assume metatable
        name = ids[item]
        metatable = item
    end
    type_check(name, "string", "name")
    mts[name] = nil
    serializers[name] = nil
    deserializers[name] = nil
    resources_by_name[name] = nil
    if metatable then
        resources[metatable] = nil
        ids[metatable] = nil
    end
    return metatable
end
|
||||
|
||||
-- Register a class object from a class library. `name` defaults to
-- class.name. Returns the class.
local function registerClass(class, name)
    name = name or class.name
    -- middleclass keeps the instance metatable in __instanceDict; otherwise
    -- assume 30log or a similar library where the class itself is used.
    local target = class.__instanceDict or class
    register(target, name)
    return class
end
|
||||
|
||||
return {
|
||||
-- aliases
|
||||
s = serialize,
|
||||
d = deserialize,
|
||||
dn = deserializeN,
|
||||
r = readFile,
|
||||
w = writeFile,
|
||||
a = appendFile,
|
||||
|
||||
serialize = serialize,
|
||||
deserialize = deserialize,
|
||||
deserializeN = deserializeN,
|
||||
readFile = readFile,
|
||||
writeFile = writeFile,
|
||||
appendFile = appendFile,
|
||||
register = register,
|
||||
unregister = unregister,
|
||||
registerResource = registerResource,
|
||||
unregisterResource = unregisterResource,
|
||||
registerClass = registerClass,
|
||||
|
||||
newbinser = newbinser
|
||||
}
|
||||
end
|
||||
|
||||
return newbinser()
|
15
config.lua
15
config.lua
|
@ -26,14 +26,17 @@ local defaults = {
|
|||
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||
|
||||
-- network layers:
|
||||
embed = true, -- set to false to use a hard-coded tile embedding.
|
||||
reduce_tiles = 0, -- TODO: write description
|
||||
hidden = false, -- use a hidden layer with ReLU/GELU activation.
|
||||
hidden_size = 128,
|
||||
layernorm = false, -- use a LayerNorm layer after said activation.
|
||||
reduce_tiles = false,
|
||||
bias_out = true,
|
||||
|
||||
-- network evaluation (sampling joypad):
|
||||
frameskip = 4,
|
||||
prob_frameskip = 0.0,
|
||||
max_frameskip = 6,
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
deterministic = false, -- use argmax on outputs instead of random sampling.
|
||||
det_epsilon = false, -- take random actions with probability eps.
|
||||
|
@ -41,12 +44,16 @@ local defaults = {
|
|||
-- evolution strategy and non-rate hyperparameters:
|
||||
es = 'ars',
|
||||
ars_lips = false, -- for ARS.
|
||||
epoch_top_trials = 9999, -- for ARS.
|
||||
epoch_top_trials = 9999, -- for ARS, Guided.
|
||||
alpha = 0.5, -- for Guided.
|
||||
beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat.
|
||||
past_grads = 1, -- for Guided. keeps a history of n past steps taken.
|
||||
|
||||
-- sampling:
|
||||
deviation = 1.0,
|
||||
unperturbed_trial = true, -- perform an extra trial without any noise.
|
||||
-- this is good for logging, so i'd recommend it.
|
||||
attempts = 1, -- TODO: document.
|
||||
epoch_trials = 50,
|
||||
graycode = false, -- for ARS.
|
||||
negate_trials = true, -- try pairs of normal and negated noise directions.
|
||||
|
@ -113,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials,
|
|||
"cfg.negate_trials must be true to use cfg.ars_lips")
|
||||
assert(not (cfg.es == 'snes' and cfg.negate_trials),
|
||||
"cfg.negate_trials is not yet compatible with SNES")
|
||||
assert(not (cfg.es == 'guided' and cfg.graycode),
|
||||
"cfg.graycode is not compatible with Guided")
|
||||
assert(cfg.es ~= 'guided' or cfg.negate_trials,
|
||||
"cfg.negate_trials must be true to use Guided")
|
||||
|
||||
return cfg
|
||||
|
|
10
cossim_test.lua
Normal file
10
cossim_test.lua
Normal file
|
@ -0,0 +1,10 @@
|
|||
-- print whatever util.expect_cossim reports for a range of dimensionalities.
local util = require "util"

for dims = 1, 100 do
    print(dims, util.expect_cossim(dims))
end

-- a couple of much larger sizes.
-- these are loop-local so the script doesn't leak a global `dims`.
for _, dims in ipairs{1000, 10000} do
    print(dims, util.expect_cossim(dims))
end
|
290
eig.lua
Normal file
290
eig.lua
Normal file
|
@ -0,0 +1,290 @@
|
|||
-- blah
|
||||
|
||||
--local globalize = require "strict"
|
||||
local nn = require "nn"
|
||||
local util = require "util"
|
||||
|
||||
local sign = util.sign
|
||||
local abs = math.abs
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local zeros = nn.zeros
|
||||
|
||||
-- reduce a square matrix (assumed symmetric) to tridiagonal form, in place.
-- this appears to follow the classic tred2 routine (Householder reduction),
-- with row/column indexing swapped relative to the textbook version.
-- a: flat table with a.shape = {n, n}; element (r, c) lives at (r - 1) * n + c.
--    it is overwritten with the accumulated transformation.
-- returns d (diagonal) and e (off-diagonal, with e[1] set to 0).
local function tred2(a)
    assert(#a.shape == 2)
    assert(a.shape[1] == a.shape[2])
    local n = a.shape[1]
    local d = zeros(n) -- diagonal
    local e = zeros(n) -- off-diagonal; e[1] ends up as 0.

    for i = n, 2, -1 do
        local l = i - 1
        local li = (l - 1) * n + i
        local h, scale = 0, 0

        if l <= 1 then
            e[i] = a[li]
        else
            for k = 1, l do
                scale = scale + abs(a[(k - 1) * n + i])
            end

            if scale == 0 then
                -- nothing to eliminate; skip the transformation.
                e[i] = a[li]
            else
                for k = 1, l do
                    local ki = (k - 1) * n + i
                    local scaled = a[ki] / scale
                    a[ki] = scaled
                    h = h + scaled * scaled
                end

                local f = a[li]
                local g = f >= 0 and -sqrt(h) or sqrt(h)
                e[i] = scale * g
                h = h - f * g
                a[li] = f - g

                f = 0
                for j = 1, l do
                    local ji = (j - 1) * n + i
                    a[(i - 1) * n + j] = a[ji] / h
                    g = 0
                    for k = 1, j do
                        g = g + a[(k - 1) * n + j] * a[(k - 1) * n + i]
                    end
                    for k = j + 1, l do
                        g = g + a[(j - 1) * n + k] * a[(k - 1) * n + i]
                    end
                    e[j] = g / h
                    f = f + e[j] * a[ji]
                end

                local hh = f / (h + h)
                for j = 1, l do
                    f = a[(j - 1) * n + i]
                    g = e[j] - hh * f
                    e[j] = g
                    for k = 1, j do
                        local kj = (k - 1) * n + j
                        a[kj] = a[kj] - (f * e[k] + g * a[(k - 1) * n + i])
                    end
                end
            end
        end

        d[i] = h
    end

    d[1] = 0
    e[1] = 0
    -- second pass: accumulate the transformation and read off the diagonal.
    for i = 1, n do
        local l = i - 1
        local row_i = (i - 1) * n

        if d[i] ~= 0 then
            for j = 1, l do
                local row_j = (j - 1) * n
                local g = 0
                for k = 1, l do
                    g = g + a[(k - 1) * n + i] * a[row_j + k]
                end
                for k = 1, l do
                    a[row_j + k] = a[row_j + k] - g * a[row_i + k]
                end
            end
        end

        local ii = row_i + i
        d[i] = a[ii]
        a[ii] = 1
        for j = 1, l do
            a[row_i + j] = 0
            a[(j - 1) * n + i] = 0
        end
    end

    return d, e
end
|
||||
|
||||
-- compute sqrt(a * a + b * b) by factoring out the larger magnitude,
-- so the intermediate square is of a ratio no greater than 1.
local function pythag(a, b)
    local big, small = abs(a), abs(b)
    if not (big > small) then
        if small == 0 then return 0 end
        big, small = small, big
    end
    local ratio = small / big
    return big * sqrt(1 + ratio * ratio)
end
|
||||
|
||||
-- diagonalize a symmetric tridiagonal matrix by QL iterations with
-- implicit shifts (after the classic tqli routine). everything is in place:
-- d: diagonal on entry, eigenvalues on exit.
-- e: off-diagonal on entry with e[1] unused (as returned by tred2);
--    used as scratch space.
-- z: flat n*n table with z.shape = {n, n}. each rotation is applied to a
--    pair of adjacent rows, so the eigenvectors end up stored as rows.
-- raises an error if an eigenvalue fails to converge within 32 iterations.
local function tqli(d, e, z)
    assert(#z.shape == 2)
    assert(z.shape[1] == z.shape[2])
    local n = z.shape[1]
    assert(#d == n)
    assert(#e == n)

    -- renumber the off-diagonal so that e[i] couples d[i] with d[i + 1].
    for i = 2, n do e[i - 1] = e[i] end
    e[n] = 0

    for l = 1, n do
        local iter = 0

        while true do
            -- look for a negligible off-diagonal element to split the matrix.
            local m = l
            while m <= n - 1 do
                local dd = abs(d[m]) + abs(d[m + 1])
                if abs(e[m]) + dd == dd then break end
                m = m + 1
            end

            if m == l then break end -- d[l] has converged.

            iter = iter + 1
            if iter >= 32 then error("Too many iterations in tqli") end

            -- form the shift.
            local g = (d[l + 1] - d[l]) / (2 * e[l])
            local r = pythag(g, 1)
            -- give r the sign of g, counting zero as positive.
            -- since r >= 1, the denominator can never cancel to zero.
            g = d[m] - d[l] + e[l] / (g + (g >= 0 and r or -r))
            local s = 1
            local c = 1
            local p = 0
            -- set when a rotation degenerates and the sweep is abandoned.
            -- (the loop index is not visible after the loop in Lua,
            -- so the early exit must be tracked explicitly.)
            local underflow = false

            for i = m - 1, l, -1 do
                local f = s * e[i]
                local b = c * e[i]
                r = pythag(f, g)
                e[i + 1] = r
                if r == 0 then
                    -- recover from underflow.
                    d[i + 1] = d[i + 1] - p
                    e[m] = 0
                    underflow = true
                    break
                end

                s = f / r
                c = g / r
                g = d[i + 1] - p
                r = (d[i] - g) * s + 2 * c * b
                p = s * r
                d[i + 1] = g + p
                g = c * r - b

                -- apply the rotation to rows i and i + 1 of z.
                local row = (i - 1) * n
                for k = 1, n do
                    local ind = row + k
                    local below = z[ind + n]
                    z[ind + n] = s * z[ind] + c * below
                    z[ind] = c * z[ind] - s * below
                end
            end

            if not underflow then
                d[l] = d[l] - p
                e[l] = g
                e[m] = 0
            end
        end
    end
end
|
||||
|
||||
--[=[
|
||||
|
||||
local A = {
|
||||
4, 1, -2, 2,
|
||||
1, 2, 0, 1,
|
||||
-2, 0, 3, -2,
|
||||
2, 1, -2, -1,
|
||||
}
|
||||
reshape(A, 4, 4)
|
||||
|
||||
local d, e = tred2(A)
|
||||
|
||||
--[[
|
||||
print(nn.pp(A))
|
||||
print(nn.pp(d))
|
||||
print(nn.pp(e))
|
||||
--]]
|
||||
|
||||
--[[
|
||||
{
|
||||
0.248069, 0.744208, 0.620174, 0.000000,
|
||||
0.702863, -0.578829, 0.413449, 0.000000,
|
||||
0.666667, 0.333333, -0.666667, 0.000000,
|
||||
0.000000, 0.000000, 0.000000, 1.000000,
|
||||
}
|
||||
{ 2.261538, 1.182906, 5.555556, -1.000000,}
|
||||
{ 0.000000, -0.092308, 0.895806, 3.000000,}
|
||||
--]]
|
||||
|
||||
--A = nn.transpose(A)
|
||||
tqli(d, e, A)
|
||||
|
||||
print(nn.pp(A))
|
||||
print(nn.pp(d))
|
||||
print(nn.pp(e))
|
||||
|
||||
local D = zeros{4, 4}
|
||||
for i = 1, 4 do
|
||||
D[(i - 1) * 4 + i] = math.exp(d[i])
|
||||
end
|
||||
local out = nn.dot(nn.transpose(A), D)
|
||||
out = nn.dot(out, A)
|
||||
print(nn.pp(out))
|
||||
|
||||
--[[
|
||||
{
|
||||
703.414032, 1410.125991, 1478.990752, -43.126976,
|
||||
-1205.204565, 1573.963121, -940.902581, 319.676625,
|
||||
14.433478, 3.884400, -10.434671, 6.841011,
|
||||
3.529384, 3.087068, -5.116236, -17.850223,
|
||||
}
|
||||
{ 2.273819, 1.072834, 6.818268, -2.164921,}
|
||||
{ -0.000000, 0.000000, 0.000000, 0.000000,}
|
||||
--]]
|
||||
|
||||
--]=]
|
||||
|
||||
return {
|
||||
tred2=tred2,
|
||||
tqli=tqli,
|
||||
}
|
163
es_test.lua
Normal file
163
es_test.lua
Normal file
|
@ -0,0 +1,163 @@
|
|||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local print = print
|
||||
|
||||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
local guided = require("guided")
|
||||
|
||||
-- try it all out on a dummy problem.
|
||||
|
||||
-- fetch the table that an object's metatable uses for __index lookups.
local function typeof(t)
    local mt = getmetatable(t)
    return mt.__index
end
|
||||
|
||||
-- multiply a value by itself.
local function square(x)
    local squared = x * x
    return squared
end
|
||||
|
||||
-- a shifted sphere function. its optimum is at x[i] = i / #x,
-- i.e. a ramp from 1 / dims up to 1, where the returned value is 0.
-- (an earlier variant used x[i] = i instead.)
-- xNES should be able to find it almost exactly.
local function spherical(x)
    local n = #x
    local total = 0
    for i, v in ipairs(x) do
        local offset = v - i / n
        total = total + offset * offset
    end
    -- negated so that the cost becomes something to maximize.
    return -total
end
|
||||
|
||||
-- these settings are copied from hardmaru's simple_es_example.ipynb.
local iterations = 3000 --4000

-- flip this to try the larger 100-dimensional problem instead.
local big_problem = false
local dims = big_problem and 100 or 30
local popsize = big_problem and dims + 1 or 99

local sigma_init = 0.5

-- exactly one strategy is active; swap the comments to try another.
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)

es.min_refresh = 1.0 -- FIXME: needs a better interface.
|
||||
|
||||
if typeof(es) == xnes.Xnes
or typeof(es) == snes.Snes
then
    -- use IGO recommendations:
    -- only the best fifth of the population receives any utility.
    local pop5 = max(1, floor(es.popsize / 5))

    -- the comparison is inclusive so that exactly pop5 samples are selected.
    -- a strict "<" would select nobody when pop5 is 1,
    -- defeating the max(1, ...) guard above.
    for i = 1, es.popsize do
        es.utility[i] = i <= pop5 and 1 or 0
    end

    local util = require "util"
    util.normalize_sums(es.utility)

    es.param_rate = 0.2 --39
    es.sigma_rate = 0.05 --39
    es.covar_rate = 0.1 --39
    es.adaptive = false
end
|
||||
|
||||
if false then -- TODO: delete me
    -- disabled scratch code that exercises ars.collect_best_indices
    -- and the reward normalization by hand.
    local nn = require("nn")
    local util = require("util")
    local insert = table.insert

    local scored = nn.arange(10)
    local indices = ars.collect_best_indices(scored, 3, true)

    local top_rewards = {}
    for _, ind in ipairs(indices) do
        local pos, neg = scored[ind * 2 - 1], scored[ind * 2 - 0]
        print(ind, ":", pos, neg)
        insert(top_rewards, pos)
        insert(top_rewards, neg)
    end

    -- this shouldn't make a difference to the final print:
    top_rewards = util.normalize_sums(top_rewards)
    print(nn.pp(top_rewards))

    local _, reward_dev = util.calc_mean_dev(top_rewards)
    print(reward_dev)

    for rank in ipairs(indices) do
        local diff = top_rewards[rank * 2 - 1] - top_rewards[rank * 2 - 0]
        print(diff / reward_dev)
    end

    do return end
end
|
||||
|
||||
local asked = nil -- for caching purposes.
local noise = nil -- for caching purposes.
local current_cost = spherical(es:params())

-- ring buffer of the most recent steps, handed back to Guided:ask.
local past_grads = {}
local pgi = 0
local pgn = 10

-- the strategy's class doesn't change, so resolve it once up front.
local es_type = typeof(es)
local is_snes = es_type == snes.Snes
local is_ars = es_type == ars.Ars
local is_guided = es_type == guided.Guided

for iteration = 1, iterations do
    if is_snes and es.min_refresh ~= 1 then
        asked, noise = es:ask_mix()
    elseif is_ars then
        asked, noise = es:ask()
    elseif is_guided then
        asked, noise = es:ask(past_grads)
    else
        asked, noise = es:ask(asked, noise)
    end

    local scores = {}
    for j, candidate in ipairs(asked) do
        scores[j] = spherical(candidate)
    end

    if is_guided then
        local step = es:tell(scores)

        -- overwrite the oldest entries once pgn steps have been stored.
        local capacity = pgn * #step
        for _, v in ipairs(step) do
            past_grads[pgi + 1] = v
            pgi = (pgi + 1) % capacity
        end
        past_grads.shape = {floor(#past_grads / #step), #step}
    else
        -- for ARS, current_cost could be passed as well to use lips.
        es:tell(scores)
    end

    current_cost = spherical(es:params())
    if iteration % 100 == 0 then
        local sigma = es.sigma
        if is_snes then
            -- es.std is a list of deviations; report their mean.
            sigma = 0
            for _, v in ipairs(es.std) do sigma = sigma + v end
            sigma = sigma / #es.std
        end
        local inconvergence = sigma / sigma_init
        local fmt = "fitness at iteration %i: %.4f (%.4f)"
        print(fmt:format(iteration, current_cost, log(inconvergence) / log(10)))
    end
end
|
||||
|
||||
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))

-- print the final parameters as one comma-separated line.
-- joining with table.concat places separators only between values,
-- without depending on a size field of the strategy object.
local formatted = {}
for i, v in ipairs(es:params()) do
    formatted[i] = ("%.8f"):format(v)
end
print(table.concat(formatted, ', '))
|
81
expm.lua
Normal file
81
expm.lua
Normal file
|
@ -0,0 +1,81 @@
|
|||
-- here's a really awful way of computing the matrix exponential.
|
||||
-- we employ the QR algorithm to find eigenpairs of a given symmetric matrix,
|
||||
-- then run the ordinary exponent function over the eigenvalues.
|
||||
-- this only works for symmetric matrices!
|
||||
|
||||
local nn = require "nn"
|
||||
local qr = require "qr"
|
||||
local util = require "util"
|
||||
|
||||
local copy = util.copy
|
||||
local dot = nn.dot
|
||||
local exp = math.exp
|
||||
local reshape = nn.reshape
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
-- matrix exponential by way of a fixed number (10) of QR iterations.
-- mat: flat table with mat.shape = {n, n}. only meaningful for symmetric
-- input, though this is not checked.
local function expm(mat)
    assert(#mat.shape == 2)
    assert(mat.shape[1] == mat.shape[2], "expm input must be square")
    --assert(stuff(mat), "expm input must be symmetrical")

    local dims = mat.shape[1]
    local count = dims * dims
    local stride = dims + 1 -- distance between consecutive diagonal entries.

    -- start the eigenvector accumulator as the identity matrix.
    local vec = zeros(mat.shape)
    for ind = 1, count, stride do
        vec[ind] = 1
    end

    local diag = mat
    for _ = 1, 10 do
        local q, r = qr(diag)
        vec = dot(vec, q)
        diag = dot(r, q)
    end

    -- exponentiate the diagonal and discard everything off of it.
    for ind = 1, count do
        if (ind - 1) % stride == 0 then
            diag[ind] = exp(diag[ind])
        else
            diag[ind] = 0
        end
    end

    return dot(dot(vec, diag), transpose(vec))
end
|
||||
|
||||
local eig = require "eig"
|
||||
local tred2 = eig.tred2
|
||||
local tqli = eig.tqli
|
||||
|
||||
-- matrix exponential using eigenpairs from tred2 + tqli.
-- mat: flat table with mat.shape = {n, n}; it is copied, not modified.
local function expm2(mat)
    assert(#mat.shape == 2)
    assert(mat.shape[1] == mat.shape[2])
    local dims = mat.shape[1]

    -- new version that computes much better (faster?) eigenpairs
    local vec = copy(mat)
    local d, e = tred2(vec)
    tqli(d, e, vec)

    -- build a diagonal matrix holding exp() of each value in d.
    local diag = {}
    local ind = 0
    for row = 1, dims do
        for col = 1, dims do
            ind = ind + 1
            diag[ind] = row == col and exp(d[row]) or 0
        end
    end
    reshape(diag, dims, dims)

    return dot(dot(transpose(vec), diag), vec)
end
|
||||
|
||||
return expm2
|
31
expm_test.lua
Normal file
31
expm_test.lua
Normal file
|
@ -0,0 +1,31 @@
|
|||
local globalize = require "strict"
|
||||
|
||||
local nn = require "nn"
|
||||
local expm = require "expm"
|
||||
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local FMT = "%9.6f"
local function show(mat) print(nn.pp(mat, FMT)) end

-- fill a dims-by-dims matrix with a scaled ramp of 1 to dims * dims.
local dims = 8
local coeff = dims * dims * sqrt(dims)
local M = {shape = {dims, dims}}
for i = 1, dims * dims do M[i] = i / coeff end

-- multiply by the transpose so that the input to expm is symmetric.
M = nn.dot(M, nn.transpose(M))
show(M)

local exp_M = expm(M)
show(exp_M)

-- write the result to disk with binser, read it back, and print it again.
local binser = require "binser"

--local dat = binser.s(exp_M)
binser.writeFile("expm.bin", exp_M)
local res, len = binser.readFile("expm.bin")
assert(len == 1)
exp_M = res[1]

show(exp_M)
|
62
extra.lua
Normal file
62
extra.lua
Normal file
|
@ -0,0 +1,62 @@
|
|||
-- left-pad the string form of num with copies of pad.
-- the result is cut so that it is count + 1 characters long
-- (for a single-character pad).
local function strpad(num, count, pad)
    local text = tostring(num)
    local padded = pad:rep(count) .. text
    return padded:sub(#text)
end
|
||||
|
||||
-- zero-pad num on the left to a width of count characters.
local function add_zeros(num, count)
    local padding = count - 1 -- strpad yields one more character than asked.
    return strpad(num, padding, '0')
end
|
||||
|
||||
-- comparison function for sorting keys of mixed types.
-- numbers are zero-padded to 16 characters so that they order numerically
-- as strings; anything else is compared by its tostring() form.
local function mixed_sorter(a, b)
    local left, right
    if type(a) == 'number' then
        left = add_zeros(a, 16)
    else
        left = tostring(a)
    end
    if type(b) == 'number' then
        right = add_zeros(b, 16)
    else
        right = tostring(b)
    end
    return left < right
end
|
||||
|
||||
-- loosely based on http://lua-users.org/wiki/SortedIteration
-- the original didn't make use of closures for who knows why

-- collect every key of t into a list, sorted with mixed_sorter.
local function order_keys(t)
    local keys = {}
    for key in pairs(t) do
        keys[#keys + 1] = key
    end
    table.sort(keys, mixed_sorter)
    return keys
end
|
||||
|
||||
-- iterate over t in sorted key order, as determined by order_keys.
-- cache: optional table, indexed by t, that stores the sorted key list
--        so that it can be reused by later calls.
local function opairs(t, cache)
    local keys
    if cache then
        keys = cache[t] or order_keys(t)
        cache[t] = keys
    else
        keys = order_keys(t)
    end

    local pos = 0
    return function()
        pos = pos + 1
        local key = keys[pos]
        if not key then return end
        return key, t[key]
    end
end
|
||||
|
||||
-- resolve a dotted path such as "math.floor", starting from the globals.
-- every run of letters, digits and underscores in path counts as a name.
-- returns {parent=<table that holds the last name>, key=<the last name>},
-- or nothing when path is nil, contains no names,
-- or an intermediate name doesn't refer to a table.
-- lookups use rawget, so metamethods are never invoked.
local function traverse(path)
    if not path then return end
    local parent = _G
    local key
    -- gmatch is the current name for what Lua 5.0 called gfind;
    -- the old name is absent from newer interpreters.
    for w in path:gmatch("[%w_]+") do
        if key then
            parent = rawget(parent, key)
            if type(parent) ~= 'table' then return end
        end
        key = w
    end
    if not key then return end
    return {parent=parent, key=key}
end
|
||||
|
||||
return {
|
||||
strpad = strpad,
|
||||
add_zeros = add_zeros,
|
||||
mixed_sorter = mixed_sorter,
|
||||
order_keys = order_keys,
|
||||
opairs = opairs,
|
||||
traverse = traverse,
|
||||
}
|
194
guided.lua
Normal file
194
guided.lua
Normal file
|
@ -0,0 +1,194 @@
|
|||
-- Guided Evolutionary Strategies
|
||||
-- https://arxiv.org/abs/1806.10230
|
||||
|
||||
-- this is just ARS extended to utilize gradients
|
||||
-- approximated from previous iterations.
|
||||
|
||||
-- for simplicity:
|
||||
-- antithetic is always true
|
||||
-- momentum is always 0
|
||||
-- no graycode/lipschitz nonsense
|
||||
|
||||
local floor = math.floor
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local max = math.max
|
||||
local print = print
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local dot_mv = nn.dot_mv
|
||||
local transpose = nn.transpose
|
||||
local normal = nn.normal
|
||||
local prod = nn.prod
|
||||
local uniform = nn.uniform
|
||||
local zeros = nn.zeros
|
||||
|
||||
local qr = require "qr2"
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
|
||||
local Guided = Base:extend()
|
||||
|
||||
-- rank the pos/neg pairs of scored by the better reward of each pair.
-- scored: flat list laid out as {pos1, neg1, pos2, neg2, ...}.
-- top: how many pairs to keep.
-- returns the indices of (at most) the top best pairs, best first.
local function collect_best_indices(scored, top)
    -- select one (the best) reward of each pos/neg pair.
    local pair_best = {}
    for pair = 1, #scored / 2 do
        local neg_at = pair * 2
        pair_best[pair] = max(scored[neg_at - 1], scored[neg_at])
    end

    local function descending(a, b) return a > b end
    local order = argsort(pair_best, descending)

    -- drop everything that ranks below the cutoff.
    for rank = #pair_best, top + 1, -1 do order[rank] = nil end
    return order
end
|
||||
|
||||
-- set up the optimizer state.
-- dims: number of parameters being optimized.
-- popsize: number of noise directions per epoch, before antithetic doubling.
--          defaults to a value derived from log(dims).
-- poptop: how many of the best directions contribute to a step.
--         defaults to popsize.
-- base_rate: parameter learning rate. defaults to a value derived from dims.
-- sigma: scale of random perturbations.
-- alpha: blend between full parameter space and its gradient subspace.
--        1.0 is roughly equivalent to ARS.
-- beta: extra scale factor applied to each step.
function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta)
    -- this file doesn't localize log, so fetch it here for the defaults below.
    local log = math.log

    self.dims = dims
    self.popsize = popsize or 4 + (3 * floor(log(dims)))
    base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
    self.param_rate = base_rate
    self.sigma = sigma or 1.0
    self.alpha = alpha or 0.5
    self.beta = beta or 1.0

    -- compare against the resolved population size rather than the raw
    -- argument, so that leaving popsize unspecified works too.
    self.poptop = poptop or self.popsize
    assert(self.poptop <= self.popsize)
    self.popsize = self.popsize * 2 -- antithetic

    self._params = zeros(self.dims)
    --self.accum = zeros(self.dims) -- momentum

    self.evals = 0
end
|
||||
|
||||
-- get the current parameters, optionally overwriting them first.
-- new_params: optional list of the same length as the current parameters;
--             its values are copied into the existing table.
-- returns the internal parameter table itself (not a copy).
function Guided:params(new_params)
    local current = self._params
    if new_params == nil then return current end

    assert(#current == #new_params, "new parameters have the wrong size")
    for i, v in ipairs(new_params) do current[i] = v end
    return current
end
|
||||
|
||||
-- shrink every parameter towards zero (weight decay).
-- param_decay: decay strength; nothing happens unless it is positive.
-- sigma_decay: accepted but not used here.
function Guided:decay(param_decay, sigma_decay)
    -- FIXME: multiplying by sigma probably isn't correct anymore.
    -- is this correct now?
    if not (param_decay > 0) then return end

    -- scale the decay like a step: sigma / sqrt(dims), times beta,
    -- times the learning rate, divided by sigma squared.
    local rate = self.sigma / sqrt(self.dims) * self.beta
        * self.param_rate / (self.sigma * self.sigma)
    local keep = 1 - param_decay * rate

    local params = self._params
    for i, v in ipairs(params) do
        params[i] = keep * v
    end
end
|
||||
|
||||
-- generate self.popsize candidate parameter vectors in antithetic pairs:
-- odd entries draw fresh noise, even entries negate the preceding noise.
-- grads: optional flat matrix of past steps, with grads.shape[1] rows.
--        when present, noise is a blend (weighted by alpha) of noise over
--        the whole parameter space and noise pushed through U, the first
--        result of qr(transpose(grads)) -- presumably an orthonormal basis
--        of the subspace spanned by the past steps.
-- returns the candidates and their noise; the noise is also kept in
-- self.noise for use by tell.
function Guided:ask(grads)
    local dims, sigma = self.dims, self.sigma
    local asked, noise = {}, {}

    local n_grad = 0
    local gnoise, U, left, right
    if grads ~= nil and #grads > 0 then
        n_grad = grads.shape[1]
        gnoise = zeros(n_grad)

        U = qr(transpose(grads))
        --print(nn.pp(transpose(U), "%9.4f"))

        left = sqrt(self.alpha / dims)
        right = sqrt((1 - self.alpha) / n_grad)
        --print(left, right)
    end

    for i = 1, self.popsize do
        local candidate = zeros(dims)
        local perturb = zeros(dims)
        asked[i] = candidate
        noise[i] = perturb

        if i % 2 == 0 then
            -- mirror the previous direction.
            for j, v in ipairs(noise[i - 1]) do
                perturb[j] = -v
            end
        elseif n_grad == 0 then
            -- no gradient information: plain scaled normal noise.
            local scale = sigma / sqrt(dims)
            for j = 1, dims do
                perturb[j] = scale * normal()
            end
        else
            for j = 1, dims do perturb[j] = normal() end
            for j = 1, n_grad do gnoise[j] = normal() end
            local guided_part = dot_mv(U, gnoise)
            for j, v in ipairs(perturb) do
                perturb[j] = sigma * (left * v + right * guided_part[j])
            end
        end

        for j, v in ipairs(self._params) do
            candidate[j] = v + perturb[j]
        end
    end

    self.noise = noise
    return asked, noise
end
|
||||
|
||||
-- update the parameters from the rewards of the last ask.
-- scored: rewards laid out as {pos1, neg1, pos2, neg2, ...},
--         in the same order as the candidates returned by ask.
-- unperturbed_score: accepted but not used here.
-- returns the step, before it is scaled by param_rate / sigma^2.
function Guided:tell(scored, unperturbed_score)
    self.evals = self.evals + #scored

    local best = collect_best_indices(scored, self.poptop)

    -- flatten the selected pairs into {pos, neg, pos, neg, ...}.
    local top_rewards = {}
    for _, pair in ipairs(best) do
        local neg_at = pair * 2
        top_rewards[#top_rewards + 1] = scored[neg_at - 1]
        top_rewards[#top_rewards + 1] = scored[neg_at]
    end

    local step = zeros(self.dims)
    local _, reward_dev = calc_mean_dev(top_rewards)
    if reward_dev == 0 then reward_dev = 1 end -- avoid dividing by zero.

    -- NOTE(review): contributions are divided by self.poptop even if fewer
    -- pairs were selected -- confirm poptop never exceeds the pair count.
    for rank, pair in ipairs(best) do
        local diff = top_rewards[rank * 2 - 1] - top_rewards[rank * 2]
        if diff ~= 0 then
            local direction = self.noise[pair * 2 - 1]
            -- NOTE: technically this reward divide isn't part of guided search.
            local weight = diff / reward_dev / self.poptop * self.beta / 2
            for j, v in ipairs(direction) do
                step[j] = step[j] + weight * v
            end
        end
    end

    local coeff = self.param_rate / (self.sigma * self.sigma)
    local params = self._params
    for i, v in ipairs(params) do
        params[i] = v + coeff * step[i]
    end

    self.noise = nil

    return step
end
|
||||
|
||||
return {
|
||||
--collect_best_indices = collect_best_indices, -- ars.lua has more features
|
||||
Guided = Guided,
|
||||
}
|
109
main.lua
109
main.lua
|
@ -20,10 +20,17 @@ local trial_rewards = {}
|
|||
local trials_remaining = 0
|
||||
local es -- evolution strategy.
|
||||
|
||||
local attempt_i = 0
|
||||
local sub_rewards = {}
|
||||
|
||||
local past_grads = {} -- for Guided
|
||||
local pgi = 0 -- past_grads_index
|
||||
|
||||
local trial_frames = 0
|
||||
local total_frames = 0
|
||||
local lagless_count = 0
|
||||
local decisions_made = 0
|
||||
local last_decision_frame = -1
|
||||
|
||||
local force_start = false
|
||||
local force_start_old = false
|
||||
|
@ -92,6 +99,7 @@ local util = require("util")
|
|||
local argmax = util.argmax
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
|
||||
local clamp = util.clamp
|
||||
local copy = util.copy
|
||||
local empty = util.empty
|
||||
|
@ -148,13 +156,21 @@ local network
|
|||
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
|
||||
local function make_network(input_size)
|
||||
nn_x = nn.Input({input_size})
|
||||
|
||||
local embed_dim = cfg.embed and 2 or 3
|
||||
|
||||
if cfg.embed then
|
||||
nn_tx = nn.Input({gcfg.tile_count})
|
||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
|
||||
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim))
|
||||
else -- new tile inputs.
|
||||
nn_tx = nn.Input({gcfg.tile_count * 3})
|
||||
nn_ty = nn_tx
|
||||
end
|
||||
|
||||
nn_tz = nn_ty
|
||||
if cfg.reduce_tiles then
|
||||
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
|
||||
if cfg.reduce_tiles > 0 then
|
||||
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim})
|
||||
nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true))
|
||||
nn_tz = nn_tz:feed(nn.Relu())
|
||||
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
|
||||
end
|
||||
|
@ -184,6 +200,7 @@ end
|
|||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
local guided = require("guided")
|
||||
|
||||
local function prepare_epoch()
|
||||
trial_neg = false
|
||||
|
@ -212,6 +229,8 @@ local function prepare_epoch()
|
|||
local dummy
|
||||
if cfg.es == 'ars' then
|
||||
trial_params, dummy = es:ask(precision)
|
||||
elseif cfg.es == 'guided' then
|
||||
trial_params, dummy = es:ask(past_grads)
|
||||
elseif cfg.es == 'snes' then
|
||||
trial_params, dummy = es:ask_mix()
|
||||
else
|
||||
|
@ -222,6 +241,7 @@ local function prepare_epoch()
|
|||
end
|
||||
|
||||
local function load_next_trial()
|
||||
attempt_i = 1
|
||||
if cfg.negate_trials then
|
||||
trial_neg = not trial_neg
|
||||
else
|
||||
|
@ -264,15 +284,23 @@ local function learn_from_epoch()
|
|||
end
|
||||
|
||||
local step
|
||||
if cfg.es == 'ars' and cfg.ars_lips then
|
||||
if cfg.es == 'ars' then --and cfg.ars_lips then
|
||||
step = es:tell(trial_rewards, current_cost)
|
||||
else
|
||||
step = es:tell(trial_rewards)
|
||||
end
|
||||
|
||||
local step_mean, step_dev = calc_mean_dev(step)
|
||||
print("step mean:", step_mean)
|
||||
print("step stddev:", step_dev)
|
||||
print(("step mean: %9.6f"):format(step_mean))
|
||||
print(("step stddev: %9.6f"):format(step_dev))
|
||||
|
||||
if cfg.es == 'guided' and cfg.past_grads > 0 then
|
||||
for _, v in ipairs(step) do
|
||||
past_grads[pgi + 1] = v
|
||||
pgi = (pgi + 1) % (cfg.past_grads * #step)
|
||||
end
|
||||
past_grads.shape = {floor(#past_grads / #step), #step}
|
||||
end
|
||||
|
||||
es:decay(cfg.param_decay, cfg.sigma_decay)
|
||||
|
||||
|
@ -328,31 +356,27 @@ local function joypad_mash(button)
|
|||
joypad.write(1, jp_mash)
|
||||
end
|
||||
|
||||
local function loadlevel(world, level)
|
||||
-- TODO: move to smb.lua. rename to load_level.
|
||||
if world == 0 then world = random(1, 8) end
|
||||
if level == 0 then level = random(1, 4) end
|
||||
emu.poweron()
|
||||
while emu.framecount() < 60 do
|
||||
if emu.framecount() == 32 then
|
||||
local area = game.area_lut[world * 10 + level]
|
||||
game.W(0x75F, world - 1)
|
||||
game.W(0x75C, level - 1)
|
||||
game.W(0x760, area)
|
||||
end
|
||||
if emu.framecount() == 42 then
|
||||
game.W(0x7A0, 0) -- world screen timer (reduces startup time)
|
||||
end
|
||||
joypad_mash('start')
|
||||
emu.frameadvance()
|
||||
end
|
||||
end
|
||||
|
||||
local function do_reset()
|
||||
local state = game.get_state()
|
||||
-- be a little more descriptive.
|
||||
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
|
||||
|
||||
--if cfg.attempts > 1 and attempt_i >= cfg.attempts then
|
||||
attempt_i = attempt_i + 1
|
||||
sub_rewards[#sub_rewards + 1] = reward
|
||||
--print(sub_rewards)
|
||||
|
||||
if #sub_rewards >= cfg.attempts then
|
||||
if cfg.attempts == 1 then
|
||||
reward = sub_rewards[1]
|
||||
else
|
||||
local sub_mean, sub_std = calc_mean_dev(sub_rewards)
|
||||
reward = floor(sub_mean)
|
||||
--local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards)
|
||||
--reward = floor(sub_mean - sub_std)
|
||||
end
|
||||
empty(sub_rewards)
|
||||
|
||||
if trial_i >= 0 then
|
||||
if trial_i == 0 then
|
||||
print('test trial reward:', reward, "("..state..")")
|
||||
|
@ -385,10 +409,11 @@ local function do_reset()
|
|||
prepare_epoch()
|
||||
collectgarbage()
|
||||
if any_random then
|
||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||
state_saved = false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
max_time = 6 * sqrt(10 * (epoch_i - 1)) + 60
|
||||
max_time = clamp(max_time, cfg.min_time, cfg.max_time)
|
||||
|
@ -426,7 +451,9 @@ local function do_reset()
|
|||
trial_frames = 0
|
||||
emu.frameadvance() -- prevents emulator from quirking up.
|
||||
|
||||
if attempt_i > cfg.attempts then
|
||||
load_next_trial()
|
||||
end
|
||||
|
||||
reset = false
|
||||
end
|
||||
|
@ -449,7 +476,7 @@ local function init()
|
|||
if not playing then emu.speedmode("turbo") end
|
||||
|
||||
if not any_random then
|
||||
loadlevel(cfg.starting_world, cfg.starting_level)
|
||||
game.load_level(cfg.starting_world, cfg.starting_level)
|
||||
end
|
||||
|
||||
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
|
||||
|
@ -480,7 +507,10 @@ local function init()
|
|||
elseif cfg.es == 'ars' then
|
||||
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.base_rate, cfg.deviation, cfg.negate_trials,
|
||||
cfg.momentum)
|
||||
cfg.momentum, cfg.beta)
|
||||
elseif cfg.es == 'guided' then
|
||||
es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
|
||||
cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta)
|
||||
else
|
||||
error("Unknown evolution strategy specified: " + tostring(cfg.es))
|
||||
end
|
||||
|
@ -536,6 +566,7 @@ local function doit(dummy)
|
|||
empty(game.sprite_input)
|
||||
empty(game.tile_input)
|
||||
empty(game.extra_input)
|
||||
empty(game.new_input)
|
||||
|
||||
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
|
||||
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
|
||||
|
@ -615,12 +646,18 @@ local function doit(dummy)
|
|||
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
|
||||
nn.reshape(X, 1, gcfg.input_size)
|
||||
nn.reshape(game.tile_input, 1, gcfg.tile_count)
|
||||
nn.reshape(game.new_input, 1, gcfg.tile_count * 3)
|
||||
|
||||
trial_frames = trial_frames + cfg.frameskip
|
||||
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + cfg.frameskip
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||
local outputs
|
||||
if cfg.embed then
|
||||
outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
|
||||
else
|
||||
outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input})
|
||||
end
|
||||
|
||||
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
|
||||
if cfg.det_epsilon and random() < eps then
|
||||
|
@ -676,6 +713,7 @@ while true do
|
|||
if reset then
|
||||
do_reset()
|
||||
lagless_count = 0
|
||||
last_decision_frame = -1
|
||||
end
|
||||
|
||||
if not cfg.enable_network then
|
||||
|
@ -692,8 +730,15 @@ while true do
|
|||
game.W(0x75A, 1)
|
||||
end
|
||||
|
||||
local doot = jp == nil or lagless_count % cfg.frameskip == 0
|
||||
local delta = lagless_count - last_decision_frame
|
||||
local doot = true
|
||||
if jp ~= nil then
|
||||
doot = delta >= cfg.frameskip
|
||||
doot = doot and random() >= cfg.prob_frameskip
|
||||
doot = doot or delta >= cfg.max_frameskip
|
||||
end
|
||||
doit(not doot)
|
||||
if doot then last_decision_frame = lagless_count end
|
||||
|
||||
-- jp might still be nil if we're not ingame or we're not playing.
|
||||
if jp ~= nil then joypad.write(1, jp) end
|
||||
|
|
54
monitor_tiles.lua
Normal file
54
monitor_tiles.lua
Normal file
|
@ -0,0 +1,54 @@
|
|||
-- keep track of which blocks are actually seen in the game.
|
||||
-- play back an all-levels TAS with this script running.
|
||||
|
||||
local floor = math.floor
|
||||
local open = io.open
|
||||
local pairs = pairs
|
||||
local print = print
|
||||
|
||||
local util = require("util")
|
||||
local R = memory.readbyteunsigned
|
||||
local W = memory.writebyte
|
||||
local function S(addr) return util.signbyte(R(addr)) end
|
||||
|
||||
local game = require("smb") -- just for advance()
|
||||
|
||||
local serial = require "serialize"
|
||||
local serialize = serial.serialize
|
||||
local deserialize = serial.deserialize
|
||||
|
||||
local fn = 'seen_tiles.lua'
|
||||
local seen = deserialize(fn) or {}
|
||||
|
||||
local function mark_tile(sx, sy, kind)
    -- records a tile kind the first time it is encountered:
    -- prints its value as two hex digits and rewrites the seen-tiles file.
    -- sx and sy are accepted but not used.
    if seen[kind] then return end
    seen[kind] = true
    print(string.format("%02X", kind))
    serialize(fn, seen)
end
|
||||
|
||||
local function handle_tiles()
    -- visits a 17-wide, 13-tall window of tiles starting at the scroll
    -- column derived from 0x73F and 0x71A, and passes each tile to mark_tile.
    --local tile_col = R(0x6A0)
    local first_col = floor(R(0x73F) / 16) + R(0x71A) * 16
    local fine_x = R(0x73F) % 16
    for row = 0, 12 do
        local row_offset = row * 16
        local sy = row_offset + 40
        for i = 0, 16 do
            -- columns wrap around two 16-wide buffers at 0x500 and 0x5D0.
            local col = (i + first_col) % 32
            local base = col < 16 and 0x500 or 0x5D0
            local kind = R(base + row_offset + col % 16)
            mark_tile(i * 16 + 8 - fine_x, sy, kind)
        end
    end
end
|
||||
|
||||
while true do
|
||||
handle_tiles()
|
||||
game.advance()
|
||||
end
|
48
nn.lua
48
nn.lua
|
@ -1,20 +1,15 @@
|
|||
local assert = assert
|
||||
local ceil = math.ceil
|
||||
local cos = math.cos
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
local huge = math.huge
|
||||
local insert = table.insert
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local open = io.open
|
||||
local pairs = pairs
|
||||
local pi = math.pi
|
||||
local print = print
|
||||
local remove = table.remove
|
||||
local sin = math.sin
|
||||
local sqrt = math.sqrt
|
||||
local tanh = math.tanh
|
||||
local tostring = tostring
|
||||
|
@ -105,19 +100,28 @@ end
|
|||
|
||||
-- ndarray-ish stuff and more involved math
|
||||
|
||||
local function pp_join(sep, fmt, t, a, b)
    -- formats t[a] through t[b] with fmt and joins the results with sep.
    -- a defaults to 1 and b defaults to #t; an empty range yields ''.
    a = a or 1
    b = b or #t
    -- collect the pieces and join them once, rather than growing
    -- a string inside the loop, which is quadratic for long rows.
    local parts = {}
    for i = a, b do
        parts[#parts + 1] = fmt:format(t[i])
    end
    return table.concat(parts, sep)
end
|
||||
|
||||
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
||||
-- pretty-prints an nd-array.
|
||||
fmt = fmt or '%10.7f,'
|
||||
fmt = fmt or '%10.7f'
|
||||
sep = sep or ','
|
||||
ti = ti or 0
|
||||
di = di or 1
|
||||
depth = depth or 0
|
||||
|
||||
if t.shape == nil then
|
||||
local s = '['
|
||||
for i = 1, #t do s = s..fmt:format(t[i]) end
|
||||
return s..']'..sep..'\n'
|
||||
end
|
||||
if t == nil then return "nil" end
|
||||
|
||||
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
|
||||
|
||||
local dim = t.shape[di]
|
||||
|
||||
|
@ -134,11 +138,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
|||
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
|
||||
ti = ti + ti_step
|
||||
end
|
||||
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
|
||||
s = s..indent..']'..sep
|
||||
if islast then s = s..'\n' end
|
||||
else
|
||||
s = s..indent..'['
|
||||
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
|
||||
s = s..']'..sep..'\n'
|
||||
s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n'
|
||||
end
|
||||
return s
|
||||
end
|
||||
|
@ -265,6 +268,20 @@ local function dot(a, b, ax_a, ax_b, out)
|
|||
return out
|
||||
end
|
||||
|
||||
local function transpose(x, out)
    -- transposes a 2d array that is stored flat in row-major order.
    -- the result is written into `out` when given,
    -- otherwise into a fresh array of shape {cols, rows}.
    assert(#x.shape == 2) -- TODO: handle ndarrays like numpy.
    local n_rows, n_cols = x.shape[1], x.shape[2]
    local result = out or zeros{n_cols, n_rows}
    -- walk the output consecutively: each output row is an input column.
    local dst = 0
    for c = 1, n_cols do
        for r = 1, n_rows do
            dst = dst + 1
            result[dst] = x[(r - 1) * n_cols + c]
        end
    end
    return result
end
|
||||
|
||||
-- nodal
|
||||
|
||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||
|
@ -875,6 +892,7 @@ return {
|
|||
ppi = ppi,
|
||||
dot_mv = dot_mv,
|
||||
dot = dot,
|
||||
transpose = transpose,
|
||||
traverse = traverse,
|
||||
traverse_all = traverse_all,
|
||||
|
||||
|
|
7
pp_test.lua
Normal file
7
pp_test.lua
Normal file
|
@ -0,0 +1,7 @@
|
|||
-- prints the same eight values through nn.pp as a plain table
-- and as 4x2, 2x4 and 2x2x2 arrays.
local nn = require "nn"

-- kept local so that running the test does not leak a global.
local a = {0,1,2,3,4,5,6,7}
print(nn.pp(a, "%9.4f"))
print(nn.pp(nn.reshape(a, 4, 2), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 4), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 2, 2), "%9.4f"))
|
393
presets.lua
393
presets.lua
|
@ -33,7 +33,7 @@ make_preset{
|
|||
|
||||
init_zeros = true,
|
||||
|
||||
reduce_tiles = true,
|
||||
reduce_tiles = 5,
|
||||
bias_out = false,
|
||||
|
||||
deterministic = false,
|
||||
|
@ -72,6 +72,28 @@ make_preset{
|
|||
sigma_decay = 0.008,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes2',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'snes',
|
||||
deterministic = true,
|
||||
deviation = 0.01,
|
||||
negate_trials = false,
|
||||
epoch_trials = 60,
|
||||
min_refresh = 2/3,
|
||||
param_rate = 0.368,
|
||||
param_decay = 0.0138,
|
||||
sigma_rate = 0.100,
|
||||
sigma_decay = 0.0051,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes3',
|
||||
parent = 'snes2',
|
||||
min_refresh = 1/3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'xnes',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
@ -120,6 +142,375 @@ make_preset{
|
|||
momentum = 0.5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-vanilla',
|
||||
parent = 'ars',
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
-- momentum = 0.5, -- this is default.
|
||||
param_rate = 1.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip',
|
||||
parent = 'ars',
|
||||
|
||||
frameskip = 1,
|
||||
prob_frameskip = 0.25,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-big',
|
||||
parent = 'ars',
|
||||
|
||||
epoch_top_trials = 75,
|
||||
epoch_trials = 100,
|
||||
momentum = 0.5,
|
||||
param_rate = 1.0,
|
||||
|
||||
--graycode = true,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-huge',
|
||||
parent = 'big-scroll-hidden',
|
||||
|
||||
deterministic = true,
|
||||
deviation = 0.01,
|
||||
epoch_top_trials = 75,
|
||||
epoch_trials = 100,
|
||||
es = 'ars',
|
||||
momentum = 0.5,
|
||||
param_decay = 0.0138,
|
||||
param_rate = 0.5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-stupid',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'ars',
|
||||
epoch_top_trials = 4,
|
||||
deterministic = false,
|
||||
deviation = 0.2,
|
||||
epoch_trials = 4,
|
||||
param_rate = 0.1,
|
||||
param_decay = 0.003,
|
||||
momentum = 0.99,
|
||||
}
|
||||
|
||||
-- new stuff for 2019:
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip-more',
|
||||
parent = 'ars',
|
||||
|
||||
-- old:
|
||||
--frameskip = 2,
|
||||
--prob_frameskip = 0.5,
|
||||
--max_frameskip = 60,
|
||||
-- new:
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-skip-more-3',
|
||||
parent = 'ars-skip-more',
|
||||
|
||||
attempts = 3, -- per trial. score = mean(scores) - stdev(scores)
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes-skip-more-3',
|
||||
parent = 'snes3',
|
||||
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'guided',
|
||||
epoch_top_trials = 20,
|
||||
deterministic = true,
|
||||
deviation = 0.1,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.00368,
|
||||
param_decay = 0.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided2',
|
||||
parent = 'guided',
|
||||
|
||||
past_grads = 2,
|
||||
-- after epoch 50, trying this:
|
||||
--param_rate = 0.05,
|
||||
-- after epoch 50+20, stepping back to this:
|
||||
param_rate = 0.01,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided10',
|
||||
parent = 'guided',
|
||||
|
||||
past_grads = 10,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided69', -- the nice one
|
||||
parent = 'guided',
|
||||
|
||||
deviation = 0.05,
|
||||
epoch_top_trials = 10,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.006,
|
||||
|
||||
past_grads = 4,
|
||||
alpha = 0.25,
|
||||
}
|
||||
|
||||
-- TODO: yet another preset. try building up from 1 trial ARS to something good.
|
||||
make_preset{
|
||||
name = 'redux',
|
||||
|
||||
min_time = 300,
|
||||
max_time = 300,
|
||||
timer_loser = 1/1,
|
||||
|
||||
score_multiplier = 1,
|
||||
|
||||
init_zeros = true,
|
||||
|
||||
deterministic = true,
|
||||
|
||||
es = 'guided',
|
||||
past_grads = 2, -- for Guided.
|
||||
alpha = 0.25, -- for Guided.
|
||||
|
||||
ars_lips = false, -- for ARS.
|
||||
beta = 1.0, -- fix the default.
|
||||
|
||||
epoch_top_trials = 4, -- for ARS, Guided.
|
||||
epoch_trials = 5,
|
||||
attempts = 1, -- TODO: document.
|
||||
|
||||
deviation = 1.0, -- 0.1
|
||||
base_rate = 1.0,
|
||||
param_decay = 0.01,
|
||||
|
||||
graycode = false, -- for ARS.
|
||||
min_refresh = 0.1, -- for SNES.
|
||||
sigma_decay = 0.0, -- for SNES, xNES.
|
||||
momentum = 0.0, -- for ARS.
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'redux_big',
|
||||
parent = 'redux',
|
||||
|
||||
time_inputs = true, -- insert binary inputs of a frame counter.
|
||||
hidden = true, -- use a hidden layer with ReLU/GELU activation.
|
||||
hidden_size = 64,
|
||||
layernorm = true, -- use a LayerNorm layer after said activation.
|
||||
reduce_tiles = false,
|
||||
bias_out = false,
|
||||
|
||||
-- gets stuck pretty quick, so tweak some stuff...
|
||||
epoch_top_trials = 8,
|
||||
epoch_trials = 10,
|
||||
deviation = 1.0,
|
||||
base_rate = 0.15,
|
||||
param_decay = 0.05,
|
||||
past_grads = 4,
|
||||
alpha = 0.25,
|
||||
-- well it doesn't get stuck anymore, but regular redux works much better.
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided-skip-more-3',
|
||||
parent = 'guided',
|
||||
|
||||
--param_rate = 0.00368, -- should probably be this instead...
|
||||
param_rate = 0.01,
|
||||
|
||||
frameskip = 3,
|
||||
prob_frameskip = 0.5,
|
||||
max_frameskip = 5,
|
||||
attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'guided-skip-more-3-again',
|
||||
parent = 'guided-skip-more-3',
|
||||
|
||||
param_rate = 0.08, --0.0316,
|
||||
deviation = 0.5,
|
||||
alpha = 0.1, --0.5,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'crazy',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'guided',
|
||||
epoch_top_trials = 15,
|
||||
deterministic = false,
|
||||
deviation = 1.0,
|
||||
epoch_trials = 15,
|
||||
param_rate = 1.0,
|
||||
param_decay = 0.0,
|
||||
alpha = 0.0316,
|
||||
--attempts = 3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips2',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
--epoch_trials = 10,
|
||||
param_rate = 0.147,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'ars-lips3',
|
||||
parent = 'ars',
|
||||
|
||||
ars_lips = true,
|
||||
param_rate = 0.5,
|
||||
deviation = 0.02, -- added after like 272 epochs
|
||||
param_decay = 0.0276, -- added after like 62 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'hard-embed',
|
||||
parent = 'big-scroll-hidden',
|
||||
|
||||
embed = false,
|
||||
reduce_tiles = 5,
|
||||
hidden_size = 54,
|
||||
|
||||
epoch_top_trials = 20,
|
||||
deterministic = true,
|
||||
deviation = 0.01,
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.368,
|
||||
param_decay = 0.0138,
|
||||
momentum = 0.5,
|
||||
beta = 1.0,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'xnes-recon', -- recon-sidered
|
||||
-- parent = 'big-scroll-hidden',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'xnes',
|
||||
|
||||
-- embed = false,
|
||||
-- reduce_tiles = 5,
|
||||
-- hidden_size = 54,
|
||||
|
||||
epoch_trials = 20,
|
||||
epoch_top_trials = 20,
|
||||
negate_trials = true,
|
||||
|
||||
deterministic = true,
|
||||
deviation = 0.1,
|
||||
|
||||
param_decay = 0.01,
|
||||
|
||||
param_rate = 0.2,
|
||||
sigma_rate = 0.05,
|
||||
covar_rate = 0.1,
|
||||
}
|
||||
|
||||
make_preset{
    name = 'arse',
    parent = 'big-scroll-reduced',

    deterministic = true,

    es = 'ars',
    epoch_trials = 5,
    epoch_top_trials = 9999,
    deviation = 1.0,
    param_rate = 1.0,

    -- beta used to be listed twice here (1.0, then 5.0). the order in which
    -- repeated keys are assigned in a table constructor is undefined,
    -- so only the later value is kept.
    beta = 5.0, -- oops, i had a dumb bug.
    param_decay = 0.01,
    momentum = 0.0,
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse2',
|
||||
parent = 'arse',
|
||||
|
||||
beta = 5.0, -- oops, i had a dumb bug.
|
||||
param_decay = 0.001,
|
||||
embed = false,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse3',
|
||||
parent = 'arse2',
|
||||
|
||||
epoch_trials = 10,
|
||||
|
||||
-- note: also disabled sum of squares in ars.lua
|
||||
beta = 1.0,
|
||||
|
||||
--deviation = 0.5, -- after 500 epochs
|
||||
deviation = 0.7071, -- after 500+521 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse4',
|
||||
parent = 'arse3',
|
||||
|
||||
hidden = true,
|
||||
hidden_size = 68,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse5',
|
||||
parent = 'arse3',
|
||||
|
||||
-- sum of squares still disabled. i probably won't re-enable it really.
|
||||
--deviation = 1.0,
|
||||
deviation = 1.414, -- after 80+360+790 epochs
|
||||
momentum = 0.8, -- neumann momentum, peaks at this value
|
||||
--param_decay = 0.003, -- after 80 epochs
|
||||
--param_decay = 0.01, -- after 80+360 epochs
|
||||
param_decay = 0.0, -- after 80+360+790 epochs
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'arse6',
|
||||
parent = 'arse3',
|
||||
|
||||
epoch_trials = 20,
|
||||
param_rate = 0.25,
|
||||
deviation = 0.1, -- maybe try 0.15
|
||||
momentum = 0.9, -- maybe try 0.8
|
||||
param_decay = 0.0,
|
||||
}
|
||||
|
||||
-- end of new stuff
|
||||
|
||||
make_preset{
|
||||
name = 'play',
|
||||
|
||||
|
|
113
qr.lua
Normal file
113
qr.lua
Normal file
|
@ -0,0 +1,113 @@
|
|||
local nn = require "nn"
|
||||
|
||||
local assert = assert
|
||||
local dot = nn.dot
|
||||
local ipairs = ipairs
|
||||
local min = math.min
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function minor(x, d)
    -- builds an array shaped like x with ones on its first d diagonal
    -- entries and, past row d and column d, the values of x.
    -- every other entry is left as zeros() created it.
    assert(#x.shape == 2)
    assert(d <= x.shape[1] and d <= x.shape[2])

    local m = zeros(x.shape)
    local rows, cols = m.shape[1], m.shape[2]

    -- fill diagonals.
    for i = 1, d do m[(i - 1) * cols + i] = 1 end

    -- copy values.
    for i = d + 1, rows do
        local row_start = (i - 1) * cols
        for j = d + 1, cols do
            m[row_start + j] = x[row_start + j]
        end
    end

    return m
end
|
||||
|
||||
local function norm(a) -- vector norm
    -- square root of the sum of squares of the elements of a.
    local total = 0
    for i = 1, #a do total = total + a[i] * a[i] end
    return sqrt(total)
end
|
||||
|
||||
local function householder(x)
    -- QR decomposition of the 2d array x using householder reflections.
    -- the product of the reflections is accumulated in q;
    -- returns transpose(q) as Q and dot(q, x) as R.
    local rows = x.shape[1]
    local cols = x.shape[2]
    local iters = min(rows - 1, cols)

    local q = nil
    local vec = zeros(rows)
    local z = x

    for k = 1, iters do
        z = minor(z, k - 1)

        -- extract a column.
        for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end

        local a = norm(vec)
        -- negate the norm if the original diagonal is positive.
        local ind = (k - 1) * cols + k
        if x[ind] > 0 then a = -a end

        vec[k] = vec[k] + a

        -- normalize the reflection vector.
        local len = norm(vec)
        if len == 0 then len = 1 end -- FIXME: should probably just raise an error.
        for i, v in ipairs(vec) do vec[i] = v / len end

        -- construct the householder reflection: mat = I - 2 * vec * vec.T
        local mat = zeros{rows, rows}
        for i = 1, rows do
            for j = 1, rows do
                local diag = i == j and 1 or 0
                mat[(i - 1) * rows + j] = diag - 2 * vec[i] * vec[j]
            end
        end

        --print(nn.pp(mat, "%9.3f"))
        if q == nil then q = mat else q = dot(mat, q) end

        z = dot(mat, z)
    end

    if q == nil then
        -- no reflection was built (e.g. x has a single row),
        -- so use the identity instead of passing nil to transpose.
        q = zeros{rows, rows}
        for i = 1, rows do q[(i - 1) * rows + i] = 1 end
    end

    return transpose(q), dot(q, x) -- Q, R
end
|
||||
|
||||
local function qr(x)
    -- a wrapper for the householder method that will return reduced matrices.
    -- when x has more rows than columns, q is cut down to its first `cols`
    -- columns and r to its first `cols` rows; otherwise both are returned as-is.
    assert(#x.shape == 2)

    local q, r = householder(x)

    local rows, cols = x.shape[1], x.shape[2]
    if cols >= rows then return q, r end

    -- trim q in-place. the destination index never passes the source index,
    -- so values are read before they can be overwritten.
    q.shape[2] = cols
    for i = 1, rows do
        for j = 1, cols do
            q[(i - 1) * cols + j] = q[(i - 1) * rows + j]
        end
    end
    for i = rows * cols + 1, #q do q[i] = nil end

    -- trim r in-place.
    r.shape[1] = r.shape[2]
    local keep = r.shape[1] * r.shape[2]
    for i = keep + 1, #r do r[i] = nil end

    return q, r
end
|
||||
|
||||
return qr
|
76
qr2.lua
Normal file
76
qr2.lua
Normal file
|
@ -0,0 +1,76 @@
|
|||
local min = math.min
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require "nn"
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function qr(a)
    -- reduced QR decomposition: each column of a is orthogonalized against
    -- the earlier ones in turn and then normalized. returns Q, R.

    -- FIXME: if first column is exactly zero,
    -- and cols > rows, Q @ R will not reconstruct the input.
    -- this isn't too bad since an input like that is invalid anyway,
    -- but i feel like it should be salvageable.

    -- actually the scope of the problem is much larger than that.
    -- an input like
    --[=[
    [[0, 0, 0, 0]
     [1, 0, 1, 1]
     [0, 0, 2, 2]
     [0, 0, 3, 3]]
    --]=]
    -- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4).
    -- hmm. maybe we can detect this and reverse the matmul to identity if necessary?

    assert(#a.shape == 2)
    local rows, cols = a.shape[1], a.shape[2]
    local small = min(rows, cols)

    -- work on the transpose so that every column of a is contiguous.
    local q = transpose(a)
    local r = zeros{small, cols}

    for i = 1, cols do
        local cur = (i - 1) * rows -- offset of the vector being processed.

        -- remove the component along each earlier vector.
        for j = 1, min(i - 1, small) do
            local prev = (j - 1) * rows

            local num, den = 0, 0
            for k = 1, rows do
                local b = q[prev + k]
                num = num + q[cur + k] * b
                den = den + b * b
            end
            --print(num, den)
            if den == 0 then den = 1 end -- TODO: should probably just error.

            local coef = num / den
            r[(j - 1) * cols + i] = coef
            for k = 1, rows do
                q[cur + k] = q[cur + k] - q[prev + k] * coef
            end
        end

        if i <= small then
            local sum = 0
            for k = 1, rows do sum = sum + q[cur + k] * q[cur + k] end
            local norm = sqrt(sum)
            local diag = (i - 1) * cols + i

            if norm == 0 then
                --norm = 1
                --q[i0 + i - 1] = 1 -- FIXME: not robust.
                r[diag] = 0
            else
                for k = 1, rows do q[cur + k] = q[cur + k] / norm end
                r[diag] = norm
            end
        end
    end

    -- discard the vectors past `small` before transposing back.
    for k = small * rows + 1, #q do q[k] = nil end
    q.shape[1] = small

    return transpose(q), r
end
|
||||
|
||||
return qr
|
79
qr3.lua
Normal file
79
qr3.lua
Normal file
|
@ -0,0 +1,79 @@
|
|||
-- gram-schmidt process
|
||||
|
||||
local nn = require "nn"
|
||||
|
||||
local dot = nn.dot
|
||||
local reshape = nn.reshape
|
||||
local sqrt = math.sqrt
|
||||
local transpose = nn.transpose
|
||||
local zeros = nn.zeros
|
||||
|
||||
local function qr(mat)
    -- QR decomposition of a 2d array by the gram-schmidt process.
    -- u holds the orthonormalized columns of mat as its rows;
    -- returns transpose(u) as Q and dot(u, mat) as R.
    assert(#mat.shape == 2)
    -- work on the transpose so that each column of mat is a contiguous row.
    local v = transpose(mat)
    local rows = v.shape[1]
    local cols = v.shape[2]

    local u = zeros(v.shape)

    local w = {} -- the vector currently being processed.
    local y = 1 -- index of the row to load next and push next.

    local function load_row()
        -- copies row y of v into w.
        local start = (y - 1) * cols
        for x = 1, cols do w[x] = v[start + x] end
    end

    local function push_row()
        -- normalizes w, stores it as row y of u, then advances y.
        local sum = 0
        for _, value in ipairs(w) do sum = sum + value * value end
        local norm = sqrt(sum)

        -- an exactly zero vector cannot be normalized: dividing by its norm
        -- would store NaNs that spread to every later row.
        -- store it unscaled (as zeros) instead.
        if norm == 0 then norm = 1 end

        local start = (y - 1) * cols
        for x, value in ipairs(w) do u[start + x] = value / norm end

        y = y + 1
    end

    load_row()
    push_row()

    local sums = {}

    for i = 2, rows do
        load_row()

        for x = 1, cols do sums[x] = 0 end

        -- accumulate the projections of w onto every earlier row of u.
        for j = 1, i - 1 do
            local start = (j - 1) * cols

            local dotted = 0
            for x, value in ipairs(w) do
                dotted = dotted + value * u[start + x]
            end

            --[[
            local scale = 0
            for x = 1, cols do
                local value = u[start + x]
                scale = scale + value * value
            end
            print(scale)
            dotted = dotted / scale
            --]]

            for x, value in ipairs(sums) do
                sums[x] = value + dotted * u[start + x]
            end
        end

        for x, value in ipairs(w) do w[x] = value - sums[x] end

        push_row()
    end

    return transpose(u), dot(u, mat)
end
|
||||
|
||||
return qr
|
87
qr_test.lua
Normal file
87
qr_test.lua
Normal file
|
@ -0,0 +1,87 @@
|
|||
local globalize = require "strict"
|
||||
local nn = require "nn"
|
||||
local qr = require "qr"
|
||||
local qr2 = require "qr2"
|
||||
local qr3 = require "qr3"
|
||||
|
||||
local A
|
||||
|
||||
if false then
|
||||
A = {
|
||||
12, -51, 4,
|
||||
6, 167, -68,
|
||||
-4, 24, -41,
|
||||
|
||||
-1, 1, 0,
|
||||
2, 0, 3,
|
||||
}
|
||||
A = nn.reshape(A, 5, 3)
|
||||
|
||||
elseif false then
|
||||
A = {
|
||||
0, 1, 2, 3, 4,
|
||||
5, 6, 7, 8, 9,
|
||||
10, 11, 12, 13, 14
|
||||
}
|
||||
A = nn.reshape(A, 5, 3)
|
||||
|
||||
elseif false then
|
||||
A = {
|
||||
1, 2, 0,
|
||||
2, 3, 1,
|
||||
3, 4, 0,
|
||||
4, 5, 1,
|
||||
5, 6, 0,
|
||||
}
|
||||
|
||||
A = {
|
||||
1, 0, 0,
|
||||
2, 0, 1,
|
||||
3, 0, 0,
|
||||
4, 0, 1,
|
||||
5, 0, 0,
|
||||
}
|
||||
|
||||
A = nn.reshape(A, 5, 3)
|
||||
--A = nn.transpose(A)
|
||||
|
||||
else
|
||||
A = {
|
||||
1, 2, -3,
|
||||
2, 4, 5,
|
||||
-3, 5, 6,
|
||||
}
|
||||
|
||||
A = nn.reshape(A, 3, 3)
|
||||
end
|
||||
|
||||
print("A")
|
||||
print(nn.pp(A, "%9.4f"))
|
||||
print()
|
||||
|
||||
local Q, R = qr(A)
|
||||
|
||||
print("Q (reference)")
|
||||
print(nn.pp(Q, "%9.4f"))
|
||||
print()
|
||||
|
||||
local Q, R = qr3(A)
|
||||
|
||||
print("Q")
|
||||
print(nn.pp(Q, "%9.4f"))
|
||||
print()
|
||||
|
||||
print("R")
|
||||
print(nn.pp(R, "%9.4f"))
|
||||
print()
|
||||
|
||||
print("Q @ R")
|
||||
print(nn.pp(nn.dot(Q, R), "%9.4f"))
|
||||
print()
|
||||
|
||||
--print("Q @ Q.T = I")
|
||||
--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f"))
|
||||
--print()
|
||||
|
||||
--A = nn.reshape(A, 5, 3)
|
||||
--Q, R = qr(A)
|
15
rescale.lua
Normal file
15
rescale.lua
Normal file
|
@ -0,0 +1,15 @@
|
|||
-- reads one number per line from params-ars4.txt,
-- multiplies each by 100, and prints the results one per line.
local f = assert(io.open("params-ars4.txt", "r"))
local data = f:read("*a")
f:close()

local values = {}
for line in data:gmatch("[^\r\n]+") do
    values[#values + 1] = tonumber(line)
end

for i = 1, #values do
    values[i] = values[i] * 100
end

for i = 1, #values do
    print(values[i])
end
|
183
running.lua
Normal file
183
running.lua
Normal file
|
@ -0,0 +1,183 @@
|
|||
local huge = math.huge
|
||||
local ipairs = ipairs
|
||||
local open = io.open
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require("nn")
|
||||
local Base = require("Base")
|
||||
|
||||
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
|
||||
-- http://www.johndcook.com/blog/standard_deviation/
|
||||
local Stats = Base:extend()
|
||||
local Normalizer = Base:extend()
|
||||
|
||||
function Stats:init(shape)
    -- starts with no samples; _M and _S are both created by nn.zeros(shape).
    self._M = nn.zeros(shape)
    self._S = nn.zeros(shape)
    self._n = 0
end
|
||||
|
||||
function Stats:push(x)
    -- folds one sample into the running statistics, element by element.
    assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")
    local prev_n = self._n
    local n = prev_n + 1
    self._n = n

    if n == 1 then
        -- the first sample is simply copied into _M.
        nn.copy(x, self._M)
        return
    end

    local M, S = self._M, self._S
    for i, mean in ipairs(M) do
        -- delta is measured against the mean from before this sample.
        local delta = x[i] - mean
        M[i] = mean + delta / n
        S[i] = S[i] + delta * delta * prev_n / n
    end
end
|
||||
|
||||
function Stats:var()
    -- returns a new table holding _S / (_n - 1) for each element.
    -- with exactly one sample, the squares of _M are returned instead.
    local n = self._n
    local out = {}
    if n ~= 1 then
        for i, s in ipairs(self._S) do out[i] = s / (n - 1) end
    else
        for i, m in ipairs(self._M) do out[i] = m * m end
    end
    return out
end
|
||||
|
||||
function Stats:dev()
    -- returns the element-wise square root of var() as a new table.
    local dev = self:var()
    for i = 1, #dev do dev[i] = sqrt(dev[i]) end
    return dev
end
|
||||
|
||||
function Normalizer:init(shape, demean, destd)
    -- demean and destd default to true when omitted (nil);
    -- an explicit false is respected.
    self.shape = shape
    self.demean = demean == nil or demean
    self.destd = destd == nil or destd
    self.rs = Stats(shape)
    self.mean = nn.zeros(shape)

    -- std starts out as all ones.
    local std = nn.zeros(shape)
    for i = 1, #std do std[i] = 1 end
    self.std = std
end
|
||||
|
||||
function Normalizer:process(x)
    -- returns a copy of x with the stored mean subtracted (if demean)
    -- and then divided by the stored std plus 1e-8 (if destd).
    local out = nn.copy(x)
    local demean, destd = self.demean, self.destd
    for i, v in ipairs(out) do
        if demean then v = v - self.mean[i] end
        if destd then v = v / (self.std[i] + 1e-8) end
        out[i] = v
    end
    return out
end
|
||||
|
||||
function Normalizer:update()
    -- refreshes mean and std from the running statistics.
    nn.copy(self.rs._M, self.mean) -- FIXME: HACK
    nn.copy(self.rs:dev(), self.std)
    -- Set values for std less than 1e-7 to +inf
    -- to avoid dividing by zero. State elements
    -- with zero variance are set to zero as a result.
    local std = self.std
    for i = 1, #std do
        if std[i] < 1e-7 then std[i] = huge end
    end
end
|
||||
|
||||
function Normalizer:push(x, update)
    -- adds x to the running statistics and returns it normalized.
    -- mean and std are refreshed first unless update is false.
    self.rs:push(x)
    if update ~= false then self:update() end
    return self:process(x)
end
|
||||
|
||||
function Normalizer:default_filename()
    -- file name containing nn.prod(self.shape), zero-padded to 7 digits.
    local size = nn.prod(self.shape)
    return string.format('stats%07i.txt', size)
end
|
||||
|
||||
function Normalizer:save(fn)
    -- writes the sample count, then each element of _M, then each
    -- element of _S, one value per line. fn defaults to default_filename().
    local path = fn or self:default_filename()
    local f = open(path, 'w')
    if f == nil then error("Failed to save stats to file "..path) end

    local rs = self.rs
    f:write(rs._n, '\n')
    for _, v in ipairs(rs._M) do f:write(v, '\n') end
    for _, v in ipairs(rs._S) do f:write(v, '\n') end
    f:close()
end
|
||||
|
||||
function Normalizer:load(fn)
    -- loads running statistics in the layout written by save()
    -- (count, then _M, then _S, one number per line) and refreshes mean/std.
    -- fn defaults to default_filename().
    -- raises an error if the file cannot be opened, if a line is not a number,
    -- or if the number of values does not match this normalizer's shape.
    -- nothing is modified when an error is raised.
    fn = fn or self:default_filename()
    local f = open(fn, 'r')
    if f == nil then error("Failed to load stats from file "..fn) end

    -- parse everything up front so the handle is always closed before an
    -- error is raised, and so a bad file never leaves half-loaded stats.
    local values = {}
    local bad_line = nil
    for line in f:lines() do
        local n = tonumber(line)
        if n == nil then
            bad_line = #values + 1
            break
        end
        values[#values + 1] = n
    end
    f:close()

    if bad_line ~= nil then
        error("Failed reading line "..tostring(bad_line).." of file "..fn)
    end

    -- one count, then `size` values each for _M and _S.
    local size = nn.prod(self.shape)
    if #values ~= 1 + 2 * size then
        error("Unexpected number of values in file "..fn)
    end

    self.rs._n = values[1]
    for i = 1, size do
        self.rs._M[i] = values[1 + i]
        self.rs._S[i] = values[1 + size + i]
    end

    self:update()
end
|
||||
|
||||
--[[
|
||||
|
||||
-- basic tests
|
||||
|
||||
local dims = 20
|
||||
local rs = Stats(dims)
|
||||
local x = nn.zeros(dims)
|
||||
|
||||
for i = 1, #x do x[i] = nn.normal() end
|
||||
rs:push(x)
|
||||
print(nn.pp(rs:dev()))
|
||||
|
||||
for j = 1, 10000 do
|
||||
for i = 1, #x do x[i] = nn.normal() end
|
||||
rs:push(x)
|
||||
end
|
||||
print(nn.pp(rs:dev()))
|
||||
|
||||
--
|
||||
|
||||
local ms = Normalizer(dims)
|
||||
local exp = math.exp
|
||||
local y
|
||||
|
||||
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||
y = ms:push(x)
|
||||
print(nn.pp(y))
|
||||
|
||||
for j = 1, 10000 do
|
||||
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||
y = ms:push(x)
|
||||
end
|
||||
print(nn.pp(y))
|
||||
|
||||
print("mean:")
|
||||
print(nn.pp(ms.mean))
|
||||
print("stdev:")
|
||||
print(nn.pp(ms.std))
|
||||
|
||||
--]]
|
||||
|
||||
return {
|
||||
Stats = Stats,
|
||||
Normalizer = Normalizer,
|
||||
}
|
59
seen_tiles.lua
Normal file
59
seen_tiles.lua
Normal file
|
@ -0,0 +1,59 @@
|
|||
return {
|
||||
[0] = true,
|
||||
[16] = true,
|
||||
[17] = true,
|
||||
[18] = true,
|
||||
[19] = true,
|
||||
[20] = true,
|
||||
[21] = true,
|
||||
[22] = true,
|
||||
[23] = true,
|
||||
[24] = true,
|
||||
[25] = true,
|
||||
[26] = true,
|
||||
[27] = true,
|
||||
[28] = true,
|
||||
[29] = true,
|
||||
[30] = true,
|
||||
[31] = true,
|
||||
[32] = true,
|
||||
[33] = true,
|
||||
[34] = true,
|
||||
[35] = true,
|
||||
[36] = true,
|
||||
[37] = true,
|
||||
[38] = true,
|
||||
[81] = true,
|
||||
[82] = true,
|
||||
[84] = true,
|
||||
[85] = true,
|
||||
[86] = true,
|
||||
[87] = true,
|
||||
[88] = true,
|
||||
[89] = true,
|
||||
[90] = true,
|
||||
[91] = true,
|
||||
[92] = true,
|
||||
[93] = true,
|
||||
[94] = true,
|
||||
[95] = true,
|
||||
[96] = true,
|
||||
[97] = true,
|
||||
[98] = true,
|
||||
[99] = true,
|
||||
[100] = true,
|
||||
[101] = true,
|
||||
[102] = true,
|
||||
[103] = true,
|
||||
[104] = true,
|
||||
[105] = true,
|
||||
[107] = true,
|
||||
[108] = true,
|
||||
[137] = true,
|
||||
[192] = true,
|
||||
[193] = true,
|
||||
[194] = true,
|
||||
[195] = true,
|
||||
[196] = true,
|
||||
[197] = true,
|
||||
}
|
76
serialize.lua
Normal file
76
serialize.lua
Normal file
|
@ -0,0 +1,76 @@
|
|||
-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
|
||||
|
||||
local type = type
|
||||
local extra = require "extra"
|
||||
local opairs = extra.opairs
|
||||
local tostring = tostring
|
||||
local open = io.open
|
||||
local strfmt = string.format
|
||||
local strrep = string.rep
|
||||
|
||||
local function kill_bom(s)
    -- strip a leading UTF-8 byte order mark (EF BB BF) if there is one.
    if #s < 3 then return s end
    local b1, b2, b3 = s:byte(1, 3)
    if b1 == 0xEF and b2 == 0xBB and b3 == 0xBF then
        return s:sub(4)
    end
    return s
end
|
||||
|
||||
local function sanitize(v)
    -- returns two values: v rendered as Lua source text (strings are quoted
    -- with %q, anything else goes through tostring), and a boolean that is
    -- true only for strings whose first character is a digit.
    if type(v) ~= 'string' then
        return tostring(v), false
    end
    local starts_with_digit = v:sub(1, 1):match('%d') ~= nil
    return strfmt('%q', v), starts_with_digit
end
|
||||
|
||||
-- words that can never be written as a bare `name = value` table key.
local lua_keywords = {
    ['and'] = true, ['break'] = true, ['do'] = true, ['else'] = true,
    ['elseif'] = true, ['end'] = true, ['false'] = true, ['for'] = true,
    ['function'] = true, ['goto'] = true, ['if'] = true, ['in'] = true,
    ['local'] = true, ['nil'] = true, ['not'] = true, ['or'] = true,
    ['repeat'] = true, ['return'] = true, ['then'] = true, ['true'] = true,
    ['until'] = true, ['while'] = true,
}

local function _serialize(value, writer, level)
    -- Write `value` as Lua source text through `writer`, a function that
    -- accepts string pieces. Tables are written recursively, one entry per
    -- line, indented with `level` tabs (level defaults to 1).
    level = level or 1

    if type(value) ~= 'table' then
        writer((sanitize(value)))
        return
    end

    local indent = strrep('\t', level)
    writer('{\n')
    for key, item in opairs(value) do
        local keyval
        if type(key) == 'string'
        and key:match('^[%a_][%w_]*$')
        and not lua_keywords[key] then
            -- a plain identifier can be written bare.
            keyval = key
        else
            -- everything else (numbers, booleans, strings containing spaces
            -- or punctuation, the empty string, reserved words) needs
            -- brackets, or the output would not load back.
            keyval = '['..(sanitize(key))..']'
        end
        writer(indent..keyval..' = ')
        _serialize(item, writer, level + 1)
        writer(',\n')
    end
    writer(strrep('\t', level - 1)..'}')
end
|
||||
|
||||
local function _deserialize(script)
    -- Compile `script` (after dropping any UTF-8 BOM) and return whatever
    -- the resulting chunk returns; returns nil if it fails to compile.
    -- NOTE(review): this runs the text as arbitrary Lua code, so it must
    -- only ever be fed files this program wrote itself.
    local f, err = loadstring(kill_bom(script))
    if f == nil then
        -- include the compile error so a corrupt file can be diagnosed.
        print('WARNING: no function to deserialize with: '..tostring(err))
        return nil
    end
    return f()
end
|
||||
|
||||
local function serialize(path, value)
    -- write `value` to `path` as a Lua chunk of the form "return <value>".
    -- does nothing (and returns nothing) when the file can't be opened.
    local out = open(path, 'w')
    if out then
        local function emit(...)
            out:write(...)
        end
        out:write("return ")
        _serialize(value, emit)
        out:write("\n")
        out:close()
    end
end
|
||||
|
||||
local function deserialize(path)
    -- Read the whole file at `path` and evaluate it with _deserialize.
    -- Returns the deserialized value, or nothing when the file can't be
    -- opened or read.
    local file = open(path, 'r')
    if not file then return end
    local script = file:read('*a')
    -- close before evaluating: the loaded chunk may raise an error, and
    -- the handle shouldn't stay open while (or after) it runs.
    file:close()
    if script == nil then return end
    local value = _deserialize(script)
    return value
end
|
||||
|
||||
return {
|
||||
serialize = serialize,
|
||||
deserialize = deserialize,
|
||||
}
|
7
sign_test.lua
Normal file
7
sign_test.lua
Normal file
|
@ -0,0 +1,7 @@
|
|||
-- prints util.sign for negative, zero, and positive inputs,
-- first with integer literals and then with fractional ones.
local util = require("util")
local samples = {-1, 0, 1, -0.1, 0.0, 0.1}
for _, x in ipairs(samples) do
    print(util.sign(x))
end
|
217
smb.lua
217
smb.lua
|
@ -1,12 +1,16 @@
|
|||
-- disassembly used for reference:
|
||||
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
|
||||
|
||||
local band = bit.band
|
||||
local floor = math.floor
|
||||
local emu = emu
|
||||
local gui = gui
|
||||
|
||||
local util = require("util")
|
||||
|
||||
local band = bit.band
|
||||
local clamp = util.clamp
|
||||
local empty = util.empty
|
||||
local emu = emu
|
||||
local floor = math.floor
|
||||
local gui = gui
|
||||
local insert = table.insert
|
||||
|
||||
local R = memory.readbyteunsigned
|
||||
local W = memory.writebyte
|
||||
local function S(addr) return util.signbyte(R(addr)) end
|
||||
|
@ -76,13 +80,84 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
|
|||
-8, -38,
|
||||
}
|
||||
|
||||
local tile_embedding = {
|
||||
-- handmade trinary encoding.
|
||||
-- we have 57 valid tile types and 27 permutations to work with.
|
||||
[0x00] = { 0, 0, 0}, -- air
|
||||
[0x10] = { 1, -1, 1}, -- vertical pipe (top left) (enterable)
|
||||
[0x11] = { 1, -1, 1}, -- vertical pipe (top right) (enterable)
|
||||
[0x12] = { 0, -1, 1}, -- vertical pipe (top left)
|
||||
[0x13] = { 0, -1, 1}, -- vertical pipe (top right)
|
||||
[0x14] = { 0, -1, -1}, -- vertical pipe (left)
|
||||
[0x15] = { 0, -1, -1}, -- vertical pipe (right)
|
||||
[0x16] = {-1, -1, -1}, --
|
||||
[0x17] = {-1, -1, -1}, --
|
||||
[0x18] = {-1, -1, -1}, --
|
||||
[0x19] = {-1, -1, -1}, --
|
||||
[0x1A] = {-1, -1, -1}, --
|
||||
[0x1B] = {-1, -1, -1}, --
|
||||
[0x1C] = { 0, -1, 0}, -- horizontal pipe (top left)
|
||||
[0x1D] = { 0, -1, 0}, -- horizontal pipe (top)
|
||||
[0x1E] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (top)
|
||||
[0x1F] = { 0, -1, 0}, -- horizontal pipe (bottom left)
|
||||
[0x20] = { 0, -1, 0}, -- horizontal pipe (bottom)
|
||||
[0x21] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (bottom)
|
||||
[0x22] = { 0, 0, 0}, --
|
||||
[0x23] = { 0, 0, 0}, -- block being hit (either breakable or ?)
|
||||
[0x24] = { 0, 0, 0}, --
|
||||
[0x25] = { 0, 0, 1}, -- flagpole
|
||||
[0x26] = { 0, 0, 0}, --
|
||||
[0x51] = { 1, 1, 0}, -- breakable brick block
|
||||
[0x52] = { 1, 1, 0}, -- breakable brick block (again?)
|
||||
[0x54] = { 0, 1, 0}, -- regular ground
|
||||
[0x55] = { 0, 0, 0}, --
|
||||
[0x56] = { 0, 0, 0}, --
|
||||
[0x57] = {-1, 1, 1}, -- star brick block
|
||||
[0x58] = { 1, 1, -1}, -- coin brick block (many coins)
|
||||
[0x59] = { 0, 0, 0}, --
|
||||
[0x5A] = { 0, 0, 0}, --
|
||||
[0x5B] = { 0, 0, 0}, --
|
||||
[0x5C] = { 0, 0, 0}, --
|
||||
[0x5D] = { 1, 1, -1}, -- coin brick block (many coins) (again?)
|
||||
[0x5E] = { 0, 0, 0}, --
|
||||
[0x5F] = { 0, 0, 0}, --
|
||||
[0x60] = {-1, 0, 0}, -- invisible 1-up block
|
||||
[0x61] = { 0, 1, -1}, -- chocolate block (usually used for stairs)
|
||||
[0x62] = { 0, 0, 0}, --
|
||||
[0x63] = { 0, 0, 0}, --
|
||||
[0x64] = { 0, 0, 0}, --
|
||||
[0x65] = { 0, 0, 0}, --
|
||||
[0x66] = { 0, 0, 0}, --
|
||||
[0x67] = { 0, 0, 0}, --
|
||||
[0x68] = { 0, 0, 0}, --
|
||||
[0x69] = { 0, 0, 0}, --
|
||||
[0x6B] = { 0, 0, 0}, --
|
||||
[0x6C] = { 0, 0, 0}, --
|
||||
[0x89] = { 0, 0, 0}, --
|
||||
[0xC0] = {-1, 1, -1}, -- coin ? block
|
||||
[0xC1] = {-1, 1, 0}, -- mushroom ? block
|
||||
[0xC2] = { 0, 0, -1}, -- coin
|
||||
[0xC3] = { 0, 0, 0}, --
|
||||
[0xC4] = { 0, 1, 1}, -- hit block
|
||||
[0xC5] = { 0, 0, 0}, --
|
||||
}
|
||||
|
||||
-- TODO: reinterface to one "input" array visible to main.lua.
|
||||
local sprite_input = {}
|
||||
local tile_input = {}
|
||||
local extra_input = {}
|
||||
local new_input = {}
|
||||
|
||||
local overlay = false
|
||||
|
||||
local function embed_tile(out, t)
    -- Append the 3-value embedding of tile type `t` to a table.
    -- Two call forms are accepted:
    --   embed_tile(t)       appends to new_input (the original form)
    --   embed_tile(out, t)  appends to the table `out`
    if t == nil then
        out, t = new_input, out
    end
    -- tile types missing from the table are treated as air (0x00)
    -- instead of raising an error on a nil index.
    local embedded = tile_embedding[t] or tile_embedding[0x00]
    insert(out, embedded[1])
    insert(out, embedded[2])
    insert(out, embedded[3])
end
|
||||
|
||||
local function get_timer()
    -- combine the bytes at 0x7F8, 0x7F9 and 0x7FA as the hundreds, tens
    -- and ones digits of a single decimal number.
    local hundreds = R(0x7F8)
    local tens = R(0x7F9)
    local ones = R(0x7FA)
    return hundreds * 100 + tens * 10 + ones
end
|
||||
|
@ -130,6 +205,7 @@ end
|
|||
|
||||
local function mark_tile(x, y, t)
|
||||
tile_input[#tile_input+1] = tile_lut[t]
|
||||
embed_tile(t)
|
||||
if t == 0 then return end
|
||||
if overlay then
|
||||
gui.box(x-8, y-8, x+8, y+8)
|
||||
|
@ -289,7 +365,8 @@ local function handle_tiles()
|
|||
extra_input[#extra_input+1] = tile_scroll_remainder
|
||||
-- for y = 0, 12 do
|
||||
-- afaik the bottom row is always a copy of the second to bottom,
|
||||
-- and the top is always air, so drop those from the inputs:
|
||||
-- and the top is always air (except underground!),
|
||||
-- so drop those from the inputs:
|
||||
for y = 1, 11 do
|
||||
for x = 0, 16 do
|
||||
local col = (x + tile_scroll) % 32
|
||||
|
@ -306,6 +383,117 @@ local function handle_tiles()
|
|||
end
|
||||
end
|
||||
|
||||
local function load_level(world, level)
    -- Power-cycle the emulator and run the first 60 frames while
    -- alternately pressing start, writing the requested world/level into
    -- memory along the way. Passing 0 for `world` or `level` picks one at
    -- random (world in 1..8, level in 1..4).
    -- NOTE(review): `random` is not among the locals visible at the top of
    -- this file; fall back to math.random in case it isn't defined.
    local rand = random or math.random
    if world == 0 then world = rand(1, 8) end
    if level == 0 then level = rand(1, 4) end

    emu.poweron()

    local jp_mash = {
        up = false, down = false, left = false, right = false,
        A = false, B = false, select = false, start = false,
    }

    local frame = emu.framecount()
    while frame < 60 do
        if frame == 32 then
            local area = area_lut[world * 10 + level]
            W(0x75F, world - 1)
            W(0x75C, level - 1)
            W(0x760, area)
        end
        if frame == 42 then
            W(0x7A0, 0) -- world screen timer (reduces startup time)
        end
        -- press start on odd frames only.
        jp_mash['start'] = frame % 2 == 1
        joypad.write(1, jp_mash)
        emu.frameadvance()
        frame = emu.framecount()
    end
end
|
||||
|
||||
-- new stuff.
|
||||
|
||||
local function tile_from_xy(x, y)
    -- look up the tile type at the given x/y coordinates.
    -- returns: tile type, tile x, tile y (tile y clamped to 0..12).
    -- rows outside 0..12 report air (0) without touching tile memory.
    local scroll = R(0x73F)
    local tile_scroll = floor(scroll / 16) + R(0x71A) * 16
    local tile_scroll_remainder = scroll % 16

    local tile_x = floor((x + tile_scroll_remainder) / 16)
    local tile_y = floor(y / 16)

    if tile_y < 0 then return 0, tile_x, 0 end
    if tile_y > 12 then return 0, tile_x, 12 end

    -- tiles are stored in two 16-column pages.
    local col = (tile_x + tile_scroll) % 32
    local base = col < 16 and 0x500 or 0x5D0
    local tile_t = R(base + tile_y * 16 + (col % 16))
    return tile_t, tile_x, tile_y
end
|
||||
|
||||
local function new_stuff()
    -- obviously very work in progress.
    -- rebuilds new_input from the current game state and draws debug
    -- boxes and labels for mario and every non-air tile on screen.

    empty(new_input)

    local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)

    -- normalized mario x
    -- normalized mario y
    insert(new_input, (mario_x - 112) / 64)
    insert(new_input, (mario_y - 160) / 96)

    -- type of tile we're standing on
    -- type of tile we're occupying

    gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)

    local mario_tile_t, mario_tile_x, mario_tile_y =
        tile_from_xy(mario_x + 8, mario_y - 8)

    --[[
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
    local sy = mario_tile_y * 16 + 40
    gui.box(sx-8, sy-8, sx+8, sy+8)
    gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
    --]]

    -- the one-argument form of embed_tile appends to new_input itself;
    -- the tile type must be the argument that gets looked up.
    embed_tile(mario_tile_t)

    -- type of tile to the right, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the right, excluding space (large eyeheight)

    -- type of tile to the left, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the left, excluding space (large eyeheight)

    -- type of enemy (nearest down-right from mario's upper left)
    -- normalized enemy x
    -- normalized enemy y

    -- VISUALIZE:

    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    for y = 1, 11 do
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local addr = col < 16 and 0x500 or 0x5D0
            local t = R(addr + y * 16 + (col % 16))
            local sx = x * 16 + 8 - tile_scroll_remainder
            local sy = y * 16 + 40

            if t ~= 0 then
                gui.box(sx-8, sy-8, sx+8, sy+8)
                gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
            end
        end
    end
end
|
||||
|
||||
return {
|
||||
-- TODO: don't expose these; provide interfaces for everything needed.
|
||||
R=R,
|
||||
|
@ -315,6 +503,7 @@ overlay=overlay,
|
|||
|
||||
valid_tiles=valid_tiles,
|
||||
area_lut=area_lut,
|
||||
embed_tile=embed_tile,
|
||||
|
||||
sprite_input=sprite_input,
|
||||
tile_input=tile_input,
|
||||
|
@ -323,16 +512,24 @@ extra_input=extra_input,
|
|||
get_timer=get_timer,
|
||||
get_score=get_score,
|
||||
set_timer=set_timer,
|
||||
mark_sprite=mark_sprite,
|
||||
mark_tile=mark_tile,
|
||||
get_state=get_state,
|
||||
|
||||
getxy=getxy,
|
||||
paused=paused,
|
||||
get_state=get_state,
|
||||
advance=advance,
|
||||
|
||||
mark_sprite=mark_sprite,
|
||||
mark_tile=mark_tile,
|
||||
|
||||
handle_enemies=handle_enemies,
|
||||
handle_fireballs=handle_fireballs,
|
||||
handle_blocks=handle_blocks,
|
||||
handle_hammers=handle_hammers,
|
||||
handle_misc=handle_misc,
|
||||
handle_tiles=handle_tiles,
|
||||
|
||||
advance=advance,
|
||||
load_level=load_level,
|
||||
|
||||
new_stuff=new_stuff,
|
||||
new_input=new_input,
|
||||
}
|
||||
|
|
131
snes.lua
131
snes.lua
|
@ -3,7 +3,6 @@
|
|||
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
||||
-- not to be confused with the Super Nintendo Entertainment System.
|
||||
|
||||
local abs = math.abs
|
||||
local assert = assert
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
|
@ -30,9 +29,12 @@ local normalize_sums = util.normalize_sums
|
|||
local pdf = util.pdf
|
||||
local weighted_mann_whitney = util.weighted_mann_whitney
|
||||
|
||||
local xnes = require "xnes"
|
||||
local make_utility = xnes.make_utility
|
||||
|
||||
local Snes = Base:extend()
|
||||
|
||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
self.popsize = popsize or 4 + (3 * floor(log(dims)))
|
||||
|
@ -42,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
|
|||
self.covar_rate = base_rate
|
||||
self.sigma = sigma or 1
|
||||
self.antithetic = antithetic and true or false
|
||||
self.adaptive = adaptive == nil and true or adaptive
|
||||
|
||||
if self.antithetic then self.popsize = self.popsize * 2 end
|
||||
|
||||
self.utility = make_utility(self.popsize)
|
||||
|
||||
self.rate_init = self.sigma_rate
|
||||
|
||||
self.mean = zeros{dims}
|
||||
|
@ -148,60 +153,84 @@ function Snes:ask_mix(start_anew)
|
|||
|
||||
-- perform importance mixing.
|
||||
|
||||
local mean_old = self.mean
|
||||
local mean_old = self.mean_old or self.mean
|
||||
local mean_new = self.mean
|
||||
local std_old = self.std_old or self.std
|
||||
local std_new = self.std
|
||||
|
||||
self.new_asked = {}
|
||||
self.new_noise = {}
|
||||
|
||||
local marked = {}
|
||||
for p=1, min(#self.old_asked, self.popsize) do
|
||||
local a = self.old_asked[p]
|
||||
|
||||
-- TODO: cache probs?
|
||||
local function compute_probabilities(a)
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
return prob_new, prob_old
|
||||
end
|
||||
|
||||
local all_asked, all_noise, all_score = {}, {}, {}
|
||||
|
||||
for p=1, #self.old_asked do
|
||||
do
|
||||
local pp = floor(uniform() * #self.old_asked) + 1
|
||||
local a = self.old_asked[pp]
|
||||
local prob_new, prob_old = compute_probabilities(a)
|
||||
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
|
||||
if uniform() < accept then
|
||||
--print(("accepted old sample %i with probability %f"):format(p, accept))
|
||||
else
|
||||
-- insert in reverse as not to screw up
|
||||
-- the indices when removing later.
|
||||
insert(marked, 1, p)
|
||||
--print(("accepted old sample %i with probability %f"):format(pp, accept))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, self.old_noise[pp])
|
||||
insert(all_score, self.old_score[pp])
|
||||
end
|
||||
end
|
||||
for _, p in ipairs(marked) do
|
||||
remove(self.old_asked, p)
|
||||
remove(self.old_noise, p)
|
||||
remove(self.old_score, p)
|
||||
end
|
||||
|
||||
while #self.old_asked + #self.new_asked < self.popsize do
|
||||
local a = {}
|
||||
local n = {}
|
||||
do
|
||||
local a, n = {}, {}
|
||||
for i=1, self.dims do n[i] = normal() end
|
||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||
|
||||
-- can't cache here!
|
||||
local prob_new = 0
|
||||
local prob_old = 0
|
||||
for i, v in ipairs(a) do
|
||||
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
|
||||
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
|
||||
end
|
||||
|
||||
local prob_new, prob_old = compute_probabilities(a)
|
||||
local accept = max(1 - prob_old / prob_new, self.min_refresh)
|
||||
if uniform() < accept then
|
||||
--print(("accepted new sample %i with probability %f"):format(#all_asked, accept))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, n)
|
||||
insert(all_score, false)
|
||||
end
|
||||
end
|
||||
|
||||
-- TODO: early stopping, making sure it doesn't affect performance.
|
||||
end
|
||||
|
||||
while #all_asked > self.popsize do
|
||||
local pp = floor(uniform() * #all_asked) + 1
|
||||
--print(("removing sample %i to fit popsize"):format(pp))
|
||||
remove(all_asked, pp)
|
||||
remove(all_noise, pp)
|
||||
remove(all_score, pp)
|
||||
end
|
||||
|
||||
while #all_asked < self.popsize do
|
||||
local a, n = {}, {}
|
||||
for i=1, self.dims do n[i] = normal() end
|
||||
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
|
||||
--print(("unconditionally added new sample %i"):format(#all_asked))
|
||||
insert(all_asked, a)
|
||||
insert(all_noise, n)
|
||||
insert(all_score, false)
|
||||
end
|
||||
|
||||
-- split all_ tables back into old_ and new_.
|
||||
self.old_asked, self.old_noise, self.old_score = {}, {}, {}
|
||||
self.new_asked, self.new_noise = {}, {}
|
||||
for i, score in ipairs(all_score) do
|
||||
local a, n = all_asked[i], all_noise[i]
|
||||
if score ~= false then
|
||||
insert(self.old_asked, a)
|
||||
insert(self.old_noise, n)
|
||||
insert(self.old_score, score)
|
||||
else
|
||||
insert(self.new_asked, a)
|
||||
insert(self.new_noise, n)
|
||||
--print(("accepted new sample %i with probability %f"):format(0, accept))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -211,15 +240,15 @@ end
|
|||
function Snes:tell(scored)
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local asked = self.asked
|
||||
local noise = self.noise
|
||||
local asked = self.mixing and self.new_asked or self.asked
|
||||
local noise = self.mixing and self.new_noise or self.noise
|
||||
if self.mixing then
|
||||
-- note: modifies, in-place, externally exposed tables.
|
||||
for i, v in ipairs(asked) do insert(self.old_asked, v) end
|
||||
for i, v in ipairs(noise) do insert(self.old_noise, v) end
|
||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||
asked = self.old_asked
|
||||
noise = self.old_noise
|
||||
-- note that these modify tables referenced externally in-place.
|
||||
for i, v in ipairs(self.new_asked) do insert(asked, v) end
|
||||
for i, v in ipairs(self.new_noise) do insert(noise, v) end
|
||||
for i, v in ipairs(scored) do insert(self.old_score, v) end
|
||||
scored = self.old_score
|
||||
end
|
||||
assert(asked and noise, ":tell() called before :ask()")
|
||||
|
@ -232,8 +261,9 @@ function Snes:tell(scored)
|
|||
local g_mean = zeros{self.dims}
|
||||
local g_std = zeros{self.dims}
|
||||
|
||||
--[[
|
||||
local utilize = true
|
||||
local utility
|
||||
local utility = self.utility
|
||||
|
||||
if utilize then
|
||||
utility = {}
|
||||
|
@ -243,17 +273,18 @@ function Snes:tell(scored)
|
|||
else
|
||||
utility = normalize_sums(scored, {})
|
||||
end
|
||||
--]]
|
||||
|
||||
for p=1, self.popsize do
|
||||
local noise_p = noise[p]
|
||||
local noise_p = noise[arg[p]]
|
||||
|
||||
for i, v in ipairs(g_mean) do
|
||||
g_mean[i] = v + utility[p] * noise_p[i]
|
||||
g_mean[i] = v + self.utility[p] * noise_p[i]
|
||||
end
|
||||
|
||||
for i, v in ipairs(g_std) do
|
||||
local n = noise_p[i]
|
||||
g_std[i] = v + utility[p] * (n * n - 1)
|
||||
g_std[i] = v + self.utility[p] * (n * n - 1)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -262,7 +293,9 @@ function Snes:tell(scored)
|
|||
step[i] = self.std[i] * v
|
||||
end
|
||||
|
||||
self.mean_old = {}
|
||||
for i, v in ipairs(self.mean) do
|
||||
self.mean_old[i] = v
|
||||
self.mean[i] = v + self.param_rate * step[i]
|
||||
end
|
||||
|
||||
|
@ -274,7 +307,7 @@ function Snes:tell(scored)
|
|||
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
|
||||
end
|
||||
|
||||
self:adapt(asked, otherwise, utility)
|
||||
if self.adaptive then self:adapt(asked, otherwise, self.utility) end
|
||||
|
||||
return step
|
||||
end
|
||||
|
@ -292,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities)
|
|||
weights[p] = prob_big / prob_now
|
||||
end
|
||||
|
||||
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||
--print("p:", p)
|
||||
local u, p = weighted_mann_whitney(qualities, qualities, nil, weights)
|
||||
--print(("u, p: %6.3f, %6.3f"):format(u, p))
|
||||
|
||||
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
|
||||
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
|
||||
print("learning rate -:", self.sigma_rate)
|
||||
--print("learning rate -:", self.sigma_rate)
|
||||
else
|
||||
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
|
||||
print("learning rate +:", self.sigma_rate)
|
||||
--print("learning rate +:", self.sigma_rate)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
56
util.lua
56
util.lua
|
@ -14,6 +14,7 @@ local random = math.random
|
|||
local select = select
|
||||
local sort = table.sort
|
||||
local sqrt = math.sqrt
|
||||
local type = type
|
||||
|
||||
local function sign(x)
|
||||
-- remember that 0 is truthy in Lua.
|
||||
|
@ -83,6 +84,28 @@ local function calc_mean_dev(x)
|
|||
return mean, sqrt(dev)
|
||||
end
|
||||
|
||||
local function calc_mean_dev_unbiased(x)
    -- returns the mean of x and an approximately-unbiased estimate of its
    -- standard deviation. requires at least two elements.
    -- NOTE: this uses an approximation; there is still a little bias.
    local count = #x
    assert(count > 1)

    -- via Gurland, John; Tripathi, Ram C. (1971):
    -- A Simple Approximation for Unbiased Estimation of the Standard Deviation
    local divisor = count - 1.5 + 1 / (8 * (count - 1))

    local mean = 0
    for _, v in ipairs(x) do
        mean = mean + v / count
    end

    local sum_sq = 0
    for _, v in ipairs(x) do
        local delta = v - mean
        sum_sq = sum_sq + delta * delta / divisor
    end

    return mean, sqrt(sum_sq)
end
|
||||
|
||||
local function normalize(x, out)
|
||||
out = out or x
|
||||
local mean, dev = calc_mean_dev(x)
|
||||
|
@ -186,6 +209,11 @@ local function cdf(x)
|
|||
-- i don't remember where this is from.
|
||||
local sign = x >= 0 and 1 or -1
|
||||
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
|
||||
|
||||
-- more accurate (via GELU paper, might be lifted from elsewhere):
|
||||
--local const = sqrt(2 / pi)
|
||||
--return 0.5 * (1 + tanh(const * (1 + 0.044715 * x * x) * x))
|
||||
--return 0.5 * (1 + tanh(0.7978845608 * (1 + 0.044715 * x * x) * x))
|
||||
end
|
||||
|
||||
local function fitness_shaping(rewards)
|
||||
|
@ -246,7 +274,31 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
|
|||
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
|
||||
local p = cdf((U - mean) / std)
|
||||
|
||||
if s0_sum > s1_sum then return 1 - p else return p end
|
||||
local u = U / (w0_sum * w1_sum)
|
||||
if s0_sum > s1_sum then return u, 1 - p else return u, p end
|
||||
end
|
||||
|
||||
local function expect_cossim(n)
    -- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
    -- this is the expected absolute cosine similarity between
    -- two standard normally-distributed random vectors both of size n.
    assert(n > 0)

    -- abs(error) < 1e-8
    if n >= 128000 then
        return 1 / sqrt(pi / 2 * n + 1)
    elseif n >= 80 then
        local poly = (2.4674010 * n - 2.4673232) * n + 1.2274046
        return 1 / sqrt(sqrt(poly))
    end
    -- fall-through when it's faster just to compute iteratively.

    -- these are locals so that calling this doesn't create or overwrite
    -- globals named even/res.
    local even = n % 2 == 0
    local res = even and 2.0 or 1.0
    for i = even and 2 or 1, n - 1, 2 do
        res = res * (i / (i + 1))
    end
    return even and res / pi or res
end
|
||||
|
||||
return {
|
||||
|
@ -260,6 +312,7 @@ return {
|
|||
softchoice=softchoice,
|
||||
empty=empty,
|
||||
calc_mean_dev=calc_mean_dev,
|
||||
calc_mean_dev_unbiased=calc_mean_dev_unbiased,
|
||||
normalize=normalize,
|
||||
normalize_wrt=normalize_wrt,
|
||||
normalize_sums=normalize_sums,
|
||||
|
@ -276,4 +329,5 @@ return {
|
|||
pdf=pdf,
|
||||
cdf=cdf,
|
||||
weighted_mann_whitney=weighted_mann_whitney,
|
||||
expect_cossim=expect_cossim,
|
||||
}
|
||||
|
|
43
xnes.lua
43
xnes.lua
|
@ -16,10 +16,13 @@ local unpack = table.unpack or unpack
|
|||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local dot = nn.dot
|
||||
local dot_mv = nn.dot_mv
|
||||
local normal = nn.normal
|
||||
local zeros = nn.zeros
|
||||
|
||||
local expm = require "expm"
|
||||
|
||||
local util = require "util"
|
||||
local argsort = util.argsort
|
||||
|
||||
|
@ -35,22 +38,6 @@ local function make_utility(popsize, out)
|
|||
return utility
|
||||
end
|
||||
|
||||
local function make_covars(dims, sigma, out)
|
||||
local covars = out or zeros{dims, dims}
|
||||
local c = sigma / dims
|
||||
-- simplified form of the determinant of the matrix we're going to create:
|
||||
local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1)
|
||||
-- multiplying by this constant makes the determinant 1:
|
||||
local m = pow(1 / det, 1 / dims)
|
||||
|
||||
local filler = c * m
|
||||
for i=1, #covars do covars[i] = filler end
|
||||
-- diagonals:
|
||||
for i=1, dims do covars[i + dims * (i - 1)] = m end
|
||||
|
||||
return covars
|
||||
end
|
||||
|
||||
function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
||||
-- heuristic borrowed from CMA-ES:
|
||||
self.dims = dims
|
||||
|
@ -67,9 +54,13 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
|
|||
self.utility = make_utility(self.popsize)
|
||||
|
||||
self.mean = zeros{dims}
|
||||
-- note: this is technically the co-standard-deviation.
|
||||
-- you can imagine the "s" standing for "sqrt" if you like.
|
||||
self.covars = make_covars(self.dims, self.sigma, self.covars)
|
||||
self.covars = zeros{dims, dims}
|
||||
for i=1, dims do
|
||||
local ind = (i - 1) * dims + i -- diagonal
|
||||
self.covars[ind] = 1
|
||||
end
|
||||
|
||||
self.evals = 0
|
||||
end
|
||||
|
||||
function Xnes:params(new_mean)
|
||||
|
@ -153,6 +144,8 @@ function Xnes:tell(scored, noise)
|
|||
local noise = noise or self.noise
|
||||
assert(noise, "missing noise argument")
|
||||
|
||||
self.evals = self.evals + #scored
|
||||
|
||||
local arg = argsort(scored, function(a, b) return a > b end)
|
||||
|
||||
local g_delta = zeros{self.dims}
|
||||
|
@ -173,7 +166,7 @@ function Xnes:tell(scored, noise)
|
|||
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
|
||||
local temp = self.utility[p] * zzt
|
||||
g_covars[ind] = g_covars[ind] + temp
|
||||
traced = traced + temp
|
||||
if i == j then traced = traced + temp end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -181,7 +174,7 @@ function Xnes:tell(scored, noise)
|
|||
local g_sigma = traced / self.dims
|
||||
|
||||
for i=1, self.dims do
|
||||
local ind = (i - 1) * self.dims + i
|
||||
local ind = (i - 1) * self.dims + i -- diagonal
|
||||
g_covars[ind] = g_covars[ind] - g_sigma
|
||||
end
|
||||
|
||||
|
@ -198,9 +191,12 @@ function Xnes:tell(scored, noise)
|
|||
end
|
||||
|
||||
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
|
||||
for i, v in ipairs(self.covars) do
|
||||
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
|
||||
|
||||
-- re-use g_covars from before just to scale it.
|
||||
for i, v in ipairs(g_covars) do
|
||||
g_covars[i] = self.covar_rate * 0.5 * v
|
||||
end
|
||||
self.covars = dot(self.covars, expm(g_covars))
|
||||
|
||||
-- bookkeeping:
|
||||
self.noise = nil
|
||||
|
@ -210,7 +206,6 @@ end
|
|||
|
||||
return {
|
||||
make_utility = make_utility,
|
||||
make_covars = make_covars,
|
||||
|
||||
Xnes = Xnes,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user