Compare commits

...

4 Commits
master ... temp

Author SHA1 Message Date
Connor Olding 7462e69c61 temp 4 2019-03-11 07:15:41 +01:00
Connor Olding 08f476e6ac temp 3 2019-02-26 21:53:38 +01:00
Connor Olding a1429a6271 temp 2 2018-08-13 11:58:45 +02:00
Connor Olding b7938a1785 temp 2018-06-30 20:13:54 +02:00
28 changed files with 3314 additions and 211 deletions

107
ars.lua
View File

@ -1,17 +1,16 @@
-- Augmented Random Search
-- https://arxiv.org/abs/1803.07055
-- with some tweaks (lipschitz stuff) by myself.
-- i also added an option for graycode sampling,
-- borrowed from a (1+1) optimizer,
-- but i haven't yet found a case where it performs better.
local abs = math.abs
local exp = math.exp
local floor = math.floor
local insert = table.insert
local remove = table.remove
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local sqrt = math.sqrt
local Base = require "Base"
@ -24,6 +23,7 @@ local zeros = nn.zeros
local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
local normalize_sums = util.normalize_sums
local sign = util.sign
@ -54,29 +54,16 @@ local function collect_best_indices(scored, top, antithetic)
return indices
end
-- Rough local Lipschitz estimate for one antithetic pair.
-- A quadratic is imagined through the three sampled scores (negative,
-- unperturbed, positive); the larger of its two end-point slope magnitudes
-- is scaled by the deviation of the noise direction `dir`.
--   dir: noise vector that was added (pos) and subtracted (neg)
--   pos, neg, mid: scores of the positive, negative and unperturbed samples
-- Note: nothing guards against a zero deviation in `dir`.
local function kinda_lipschitz(dir, pos, neg, mid)
    local _, spread = calc_mean_dev(dir)
    local from_neg = neg - mid
    local from_pos = pos - mid
    local steepest = max(abs(3 * from_pos + from_neg),
                         abs(from_pos + 3 * from_neg))
    return steepest / (2 * spread)
end
function Ars:init(dims, popsize, poptop, base_rate, sigma, antithetic,
momentum)
momentum, beta)
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
base_rate = base_rate or 3/5 * (3 + log(dims)) / (dims * sqrt(dims))
self.param_rate = base_rate
self.sigma_rate = base_rate
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic == nil and true or antithetic
self.momentum = momentum or 0
self.beta = beta or 1.0
self.poptop = poptop or popsize
assert(self.poptop <= popsize)
@ -104,7 +91,7 @@ function Ars:decay(param_decay, sigma_decay)
end
end
function Ars:ask(graycode)
function Ars:ask()
local asked = {}
local noise = {}
@ -120,17 +107,8 @@ function Ars:ask(graycode)
noisy[j] = -v
end
else
if graycode ~= nil then
for j = 1, self.dims do
noisy[j] = exp(-precision * uniform())
end
for j = 1, self.dims do
noisy[j] = uniform() < 0.5 and noisy[j] or -noisy[j]
end
else
for j = 1, self.dims do
noisy[j] = self.sigma * normal()
end
for j = 1, self.dims do
noisy[j] = self.sigma * normal()
end
end
@ -144,9 +122,8 @@ function Ars:ask(graycode)
end
function Ars:tell(scored, unperturbed_score)
local use_lips = unperturbed_score ~= nil and self.antithetic
self.evals = self.evals + #scored
if use_lips then self.evals = self.evals + 1 end
if unperturbed_score ~= nil then self.evals = self.evals + 1 end
local indices = collect_best_indices(scored, self.poptop, self.antithetic)
@ -167,7 +144,16 @@ function Ars:tell(scored, unperturbed_score)
end
local step = zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
local _, reward_dev
if unperturbed_score ~= nil then
-- new stuff:
insert(top_rewards, unperturbed_score)
_, reward_dev = calc_mean_dev_unbiased(top_rewards)
remove(top_rewards)
else
_, reward_dev = calc_mean_dev(top_rewards)
end
if reward_dev == 0 then reward_dev = 1 end
if self.antithetic then
@ -177,31 +163,39 @@ function Ars:tell(scored, unperturbed_score)
local reward = pos - neg
if reward ~= 0 then
local noisy = self.noise[ind * 2 - 1]
if use_lips then
local lips = kinda_lipschitz(noisy, pos, neg, unperturbed_score)
reward = reward / lips / self.sigma
else
reward = reward / reward_dev
end
reward = reward / reward_dev
--[[ new stuff:
local sum_of_squares = 0
for _, v in ipairs(noisy) do
sum_of_squares = sum_of_squares + v * v
end
reward = reward / sqrt(sum_of_squares)
-]]
local scale = reward / self.poptop * self.beta / 2
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
step[j] = step[j] + scale * v
end
end
end
else
error("TODO: update with sum of squares stuff")
for i, ind in ipairs(indices) do
local reward = top_rewards[i] / reward_dev
if reward ~= 0 then
local noisy = self.noise[ind]
local scale = reward / self.poptop * self.beta
for j, v in ipairs(noisy) do
step[j] = step[j] + reward * v / self.poptop
step[j] = step[j] + scale * v
end
end
end
end
--[[ powersign momentum
if self.momentum > 0 then
for i, v in ipairs(step) do
self.accum[i] = self.momentum * self.accum[i] + v
@ -212,6 +206,35 @@ function Ars:tell(scored, unperturbed_score)
for i, v in ipairs(self._params) do
self._params[i] = v + self.param_rate * step[i]
end
--]]
-- neumann momentum
if self.momentum > 0 then
local count = self.count or 0
local period = 10
local mu = 1 - 1 / (1 + count % period)
mu = self.momentum / (1 - 1 / period) * mu
self.count = count + 1
-- mu is intentionally 0 for one iteration.
-- make learning rate invariant to sigma.
for i, v in ipairs(step) do
step[i] = v / self.sigma
end
-- update neumann iterate.
for i, v in ipairs(self.accum) do
self.accum[i] = mu * v - self.param_rate * step[i]
end
for i, v in ipairs(self._params) do
self._params[i] = v - mu * self.accum[i] + self.param_rate * step[i]
end
else
for i, v in ipairs(self._params) do
self._params[i] = v + self.param_rate * step[i]
end
end
self.noise = nil

749
binser.lua Normal file
View File

@ -0,0 +1,749 @@
-- binser.lua
--[[
Copyright (c) 2016 Calvin Rose
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
]]
local assert = assert
local error = error
local select = select
local pairs = pairs
local getmetatable = getmetatable
local setmetatable = setmetatable
local type = type
local loadstring = loadstring or load
local concat = table.concat
local char = string.char
local byte = string.byte
local format = string.format
local sub = string.sub
local dump = string.dump
local floor = math.floor
local frexp = math.frexp
local unpack = unpack or table.unpack
-- Lua 5.3 frexp polyfill
-- Originally from https://github.com/excessive/cpml/blob/master/modules/utils.lua
-- Only installed when math.frexp is unavailable.
-- Returns m, e such that x == m * 2 ^ e with 0.5 <= |m| < 1 for finite,
-- non-zero x. Zero, infinities and NaN are returned unchanged with exponent 0
-- (number_to_str relies on seeing inf/NaN in the mantissa).
if not frexp then
    local log, abs, floor = math.log, math.abs, math.floor
    local log2 = log(2)
    local huge = math.huge
    frexp = function(x)
        -- x ~= x catches NaN; these values have no meaningful exponent.
        if x == 0 or x ~= x or x == huge or x == -huge then return x, 0 end
        local e = floor(log(abs(x)) / log2 + 1)
        -- Divide by 2 ^ e in two halves: 2 ^ e itself overflows to inf for
        -- the largest doubles and underflows to 0 for the smallest.
        local h = floor(e / 2)
        local m = x / 2 ^ h / 2 ^ (e - h)
        -- log() is inexact near powers of two, so the estimated exponent can
        -- be off by one; nudge the mantissa back into [0.5, 1).
        while abs(m) >= 1 do m, e = m / 2, e + 1 end
        while abs(m) < 0.5 do m, e = m * 2, e - 1 end
        return m, e
    end
end
-- Collects varargs into a table and also returns the true argument count,
-- so trailing/embedded nils are not lost.
local function pack(...)
    local count = select("#", ...)
    return {...}, count
end
-- True when `x` is NOT an integral number in the range 1..len,
-- i.e. when a key does not belong to the array part of length `len`.
local function not_array_index(x, len)
    if type(x) ~= "number" then
        return true
    end
    if x < 1 or x > len then
        return true
    end
    return x ~= floor(x)
end
-- Raises an error unless type(x) == tp. `name` is the parameter name used in
-- the message. The message is formatted up front, exactly as before, so a
-- non-string `name` still fails inside format().
local function type_check(x, tp, name)
    local message = format("Expected parameter %q to be of type %q.", name, tp)
    if type(x) ~= tp then
        -- level 0: no position prefix, matching what assert() raises.
        error(message, 0)
    end
end
-- Integer support detection.
-- bigIntSupport stays false when math.type is missing; otherwise it becomes a
-- function that encodes an integer as type byte 212 followed by 8 bytes, most
-- significant first. Negative values are stored as the byte-wise complement
-- of -(n + 1).
-- isInteger(x) decides whether number_to_str may use an integer encoding.
local bigIntSupport = false
local isInteger
if math.type then -- Detect Lua 5.3
local mtype = math.type
-- The encoder is compiled from a string, presumably so that the `//`
-- operator does not break parsing of this file on older Lua versions.
bigIntSupport = loadstring[[
local char = string.char
return function(n)
local nn = n < 0 and -(n + 1) or n
local b1 = nn // 0x100000000000000
local b2 = nn // 0x1000000000000 % 0x100
local b3 = nn // 0x10000000000 % 0x100
local b4 = nn // 0x100000000 % 0x100
local b5 = nn // 0x1000000 % 0x100
local b6 = nn // 0x10000 % 0x100
local b7 = nn // 0x100 % 0x100
local b8 = nn % 0x100
if n < 0 then
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
end
return char(212, b1, b2, b3, b4, b5, b6, b7, b8)
end]]()
-- With math.type available, only values of the integer subtype qualify.
isInteger = function(x)
return mtype(x) == 'integer'
end
else
-- Without math.type, any number with no fractional part qualifies.
isInteger = function(x)
return floor(x) == x
end
end
-- Copyright (C) 2012-2015 Francois Perrad.
-- number serialization code modified from https://github.com/fperrad/lua-MessagePack
-- Encode a number as a big-endian ieee-754 double, big-endian signed 64 bit integer, or a small integer
--
-- Encodings produced:
--   * integer in [-27, 100]     -> 1 byte  (n + 27, i.e. 0..127)
--   * integer in [-8192, 8191]  -> 2 bytes (first byte 128..191)
--   * other integers            -> bigIntSupport(n) when available (type 212)
--   * everything else           -> type byte 203 + 8 bytes of IEEE-754 double
-- Returns the encoded bytes as a string.
local function number_to_str(n)
    if isInteger(n) then -- int
        if n <= 100 and n >= -27 then -- 1 byte, 7 bits of data
            return char(n + 27)
        elseif n <= 8191 and n >= -8192 then -- 2 bytes, 14 bits of data
            n = n + 8192
            return char(128 + (floor(n / 0x100) % 0x100), n % 0x100)
        elseif bigIntSupport then
            return bigIntSupport(n)
        end
        -- no 64-bit integer support: fall through to the double encoding.
    end
    local sign = 0
    if n < 0.0 then
        sign = 0x80
        n = -n
    end
    local m, e = frexp(n) -- mantissa, exponent
    if m ~= m then
        -- NaN
        return char(203, 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
    elseif m == 1/0 then
        -- +/- infinity
        if sign == 0 then
            return char(203, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
        else
            return char(203, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
        end
    elseif n == 0 then
        -- Zero must be special-cased: frexp(0) is (0, 0), which the
        -- normalized branch below would turn into the bit pattern of 0.25.
        -- This is reachable for float zero when isInteger() uses math.type.
        return char(203, sign, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
    end
    e = e + 0x3FE -- bias the exponent
    if e < 1 then -- denormalized numbers
        m = m * 2 ^ (52 + e)
        e = 0
    else
        -- drop the implicit leading 1 and scale to a 52-bit integer.
        m = (m * 2 - 1) * 2 ^ 52
    end
    return char(203,
                sign + floor(e / 0x10),
                (e % 0x10) * 0x10 + floor(m / 0x1000000000000),
                floor(m / 0x10000000000) % 0x100,
                floor(m / 0x100000000) % 0x100,
                floor(m / 0x1000000) % 0x100,
                floor(m / 0x10000) % 0x100,
                floor(m / 0x100) % 0x100,
                m % 0x100)
end
-- Copyright (C) 2012-2015 Francois Perrad.
-- number deserialization code also modified from https://github.com/fperrad/lua-MessagePack
-- Decodes one number from `str` starting at byte position `index`.
-- Returns the number and the index of the first byte after it.
-- Raises an error if the input ends early or the type byte is not a number.
local function number_from_str(str, index)
local b = byte(str, index)
if not b then error("Expected more bytes of input.") end
if b < 128 then
-- 1-byte small integer, stored with an offset of 27.
return b - 27, index + 1
elseif b < 192 then
-- 2-byte integer: 6 bits in the first byte, 8 in the second, offset 8192.
local b2 = byte(str, index + 1)
if not b2 then error("Expected more bytes of input.") end
return b2 + 0x100 * (b - 128) - 8192, index + 2
end
-- Both remaining encodings carry exactly 8 payload bytes.
local b1, b2, b3, b4, b5, b6, b7, b8 = byte(str, index + 1, index + 8)
if (not b1) or (not b2) or (not b3) or (not b4) or
(not b5) or (not b6) or (not b7) or (not b8) then
error("Expected more bytes of input.")
end
if b == 212 then
-- 64-bit integer; a set top bit means the bytes are complemented.
local flip = b1 >= 128
if flip then -- negative
b1, b2, b3, b4 = 0xFF - b1, 0xFF - b2, 0xFF - b3, 0xFF - b4
b5, b6, b7, b8 = 0xFF - b5, 0xFF - b6, 0xFF - b7, 0xFF - b8
end
local n = ((((((b1 * 0x100 + b2) * 0x100 + b3) * 0x100 + b4) *
0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
if flip then
return (-n) - 1, index + 9
else
return n, index + 9
end
end
if b ~= 203 then
error("Expected number")
end
-- Type 203: split the 8 bytes into sign bit, 11-bit exponent, 52-bit mantissa.
local sign = b1 > 0x7F and -1 or 1
local e = (b1 % 0x80) * 0x10 + floor(b2 / 0x10)
local m = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8
local n
if e == 0 then
-- zero or denormalized value
if m == 0 then
n = sign * 0.0
else
n = sign * (m / 2 ^ 52) * 2 ^ -1022
end
elseif e == 0x7FF then
-- infinity (mantissa 0) or NaN
if m == 0 then
n = sign * (1/0)
else
n = 0.0/0.0
end
else
-- normalized value with the implicit leading 1 restored
n = sign * (1.0 + m / 2 ^ 52) * 2 ^ (e - 0x3FF)
end
return n, index + 9
end
local function newbinser()
-- unique table key for getting next value
local NEXT = {}
local CTORSTACK = {}
-- NIL = 202
-- FLOAT = 203
-- TRUE = 204
-- FALSE = 205
-- STRING = 206
-- TABLE = 207
-- REFERENCE = 208
-- CONSTRUCTOR = 209
-- FUNCTION = 210
-- RESOURCE = 211
-- INT64 = 212
-- TABLE WITH META = 213
local mts = {}
local ids = {}
local serializers = {}
local deserializers = {}
local resources = {}
local resources_by_name = {}
local types = {}
-- Serializers for the simple value types. Each appends its encoding to
-- `accum`; `visited` is unused here but kept for a uniform signature.
-- nil is the single type byte 202.
types["nil"] = function(x, visited, accum)
accum[#accum + 1] = "\202"
end
-- Numbers are encoded by number_to_str.
function types.number(x, visited, accum)
accum[#accum + 1] = number_to_str(x)
end
-- Booleans are one type byte: 204 for true, 205 for false.
function types.boolean(x, visited, accum)
accum[#accum + 1] = x and "\204" or "\205"
end
-- Serializes a string. The first occurrence is written as type byte 206,
-- its length and its bytes, and is assigned the next reference id; later
-- occurrences are written as a reference (type byte 208 + id).
function types.string(x, visited, accum)
    local base = #accum
    local seen = visited[x]
    if seen then
        accum[base + 1] = "\208"
        accum[base + 2] = number_to_str(seen)
        return
    end
    local id = visited[NEXT]
    visited[x] = id
    visited[NEXT] = id + 1
    accum[base + 1] = "\206"
    accum[base + 2] = number_to_str(#x)
    accum[base + 3] = x
end
-- Tries to serialize `x` as a registered resource or as an instance of a
-- registered metatable. Returns true when it wrote something to `accum`;
-- returns nothing when `x` is neither, so the caller must handle it.
local function check_custom_type(x, visited, accum)
-- Registered resource: type byte 211 followed by the resource's name.
local res = resources[x]
if res then
accum[#accum + 1] = "\211"
types[type(res)](res, visited, accum)
return true
end
-- Registered metatable: type byte 209, the registered id, the argument
-- count, then each value returned by the custom serializer.
local mt = getmetatable(x)
local id = mt and ids[mt]
if id then
-- Guard against a serializer whose output contains `x` itself.
local constructing = visited[CTORSTACK]
if constructing[x] then
error("Infinite loop in constructor.")
end
constructing[x] = true
accum[#accum + 1] = "\209"
types[type(id)](id, visited, accum)
local args, len = pack(serializers[id](x))
accum[#accum + 1] = number_to_str(len)
for i = 1, len do
local arg = args[i]
types[type(arg)](arg, visited, accum)
end
-- The reference id is assigned only after the arguments are written;
-- deserialize_value records the constructed object at the same point.
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
-- We finished constructing
constructing[x] = nil
return true
end
end
-- Serializes userdata. An already-seen value becomes a reference (208 + id);
-- otherwise it must be a registered resource or custom type, or an error is
-- raised.
function types.userdata(x, visited, accum)
    local ref = visited[x]
    if ref then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(ref)
        return
    end
    if not check_custom_type(x, visited, accum) then
        error("Cannot serialize this userdata.")
    end
end
-- Serializes a table. Already-seen tables become a reference (208 + id);
-- registered resources / custom types are delegated to check_custom_type.
-- Otherwise the layout is: type byte (207 plain, 213 followed by the
-- metatable), array length, array values, count of remaining keys, then
-- key/value pairs for every key outside 1..#x.
function types.table(x, visited, accum)
if visited[x] then
accum[#accum + 1] = "\208"
accum[#accum + 1] = number_to_str(visited[x])
else
if check_custom_type(x, visited, accum) then return end
-- Assign the reference id before descending so cycles resolve to it.
visited[x] = visited[NEXT]
visited[NEXT] = visited[NEXT] + 1
local xlen = #x
local mt = getmetatable(x)
if mt then
accum[#accum + 1] = "\213"
types.table(mt, visited, accum)
else
accum[#accum + 1] = "\207"
end
accum[#accum + 1] = number_to_str(xlen)
for i = 1, xlen do
local v = x[i]
types[type(v)](v, visited, accum)
end
-- First pass counts the non-array keys, second pass writes them; both
-- passes must see the same keys, so x must not change in between.
local key_count = 0
for k in pairs(x) do
if not_array_index(k, xlen) then
key_count = key_count + 1
end
end
accum[#accum + 1] = number_to_str(key_count)
for k, v in pairs(x) do
if not_array_index(k, xlen) then
types[type(k)](k, visited, accum)
types[type(v)](v, visited, accum)
end
end
end
end
-- Serializes a function. Already-seen functions become a reference
-- (208 + id); registered resources / custom types are delegated to
-- check_custom_type. Anything else is written as type byte 210, the length
-- of its string.dump() output, and that output.
types["function"] = function(x, visited, accum)
    local ref = visited[x]
    if ref then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(ref)
        return
    end
    if check_custom_type(x, visited, accum) then
        return
    end
    local id = visited[NEXT]
    visited[x] = id
    visited[NEXT] = id + 1
    local bytecode = dump(x)
    accum[#accum + 1] = "\210"
    accum[#accum + 1] = number_to_str(#bytecode)
    accum[#accum + 1] = bytecode
end
-- Serializes cdata. Already-seen values become a reference (208 + id);
-- otherwise the value must be a registered resource or custom type, or an
-- error is raised.
types.cdata = function(x, visited, accum)
    if visited[x] then
        accum[#accum + 1] = "\208"
        accum[#accum + 1] = number_to_str(visited[x])
    else
        -- Pass the accumulator itself: check_custom_type appends to it.
        -- (Passing its length made check_custom_type index a number.)
        if check_custom_type(x, visited, accum) then return end
        error("Cannot serialize this cdata.")
    end
end
-- Threads are never serializable; always raises.
function types.thread()
    error("Cannot serialize threads.")
end
-- Decodes one value from `str` starting at byte `index`.
-- `visited` is the list of referenceable values decoded so far (strings,
-- tables, constructed objects, functions), indexed by reference id.
-- Returns the value and the index of the next unread byte. At end of input
-- it returns nil and the unchanged index, which callers use to detect that
-- nothing was consumed.
local function deserialize_value(str, index, visited)
local t = byte(str, index)
if not t then return nil, index end
if t < 128 then
-- 1-byte small integer
return t - 27, index + 1
elseif t < 192 then
-- 2-byte integer
local b2 = byte(str, index + 1)
if not b2 then error("Expected more bytes of input.") end
return b2 + 0x100 * (t - 128) - 8192, index + 2
elseif t == 202 then
-- nil
return nil, index + 1
elseif t == 203 or t == 212 then
-- double or 64-bit integer
return number_from_str(str, index)
elseif t == 204 then
return true, index + 1
elseif t == 205 then
return false, index + 1
elseif t == 206 then
-- string: length, then raw bytes
local length, dataindex = number_from_str(str, index + 1)
local nextindex = dataindex + length
if not (length >= 0) then error("Bad string length") end
if #str < nextindex - 1 then error("Expected more bytes of string") end
local substr = sub(str, dataindex, nextindex - 1)
visited[#visited + 1] = substr
return substr, nextindex
elseif t == 207 or t == 213 then
-- table (213: preceded by its metatable)
local mt, count, nextindex
local ret = {}
-- Registered before its contents are read so nested references to this
-- table resolve to it.
visited[#visited + 1] = ret
nextindex = index + 1
if t == 213 then
mt, nextindex = deserialize_value(str, nextindex, visited)
if type(mt) ~= "table" then error("Expected table metatable") end
end
-- array part
count, nextindex = number_from_str(str, nextindex)
for i = 1, count do
local oldindex = nextindex
ret[i], nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
end
-- key/value part
count, nextindex = number_from_str(str, nextindex)
for i = 1, count do
local k, v
local oldindex = nextindex
k, nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
oldindex = nextindex
v, nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
if k == nil then error("Can't have nil table keys") end
ret[k] = v
end
if mt then setmetatable(ret, mt) end
return ret, nextindex
elseif t == 208 then
-- reference to a previously decoded value
local ref, nextindex = number_from_str(str, index + 1)
return visited[ref], nextindex
elseif t == 209 then
-- custom type: registered name, argument count, arguments
local count
local name, nextindex = deserialize_value(str, index + 1, visited)
count, nextindex = number_from_str(str, nextindex)
local args = {}
for i = 1, count do
local oldindex = nextindex
args[i], nextindex = deserialize_value(str, nextindex, visited)
if nextindex == oldindex then error("Expected more bytes of input.") end
end
if not name or not deserializers[name] then
error(("Cannot deserialize class '%s'"):format(tostring(name)))
end
local ret = deserializers[name](unpack(args))
-- Registered only after construction, mirroring check_custom_type.
visited[#visited + 1] = ret
return ret, nextindex
elseif t == 210 then
-- function: length, then dumped chunk
local length, dataindex = number_from_str(str, index + 1)
local nextindex = dataindex + length
if not (length >= 0) then error("Bad string length") end
if #str < nextindex - 1 then error("Expected more bytes of string") end
-- NOTE(review): this loads a chunk taken straight from the input. The Lua
-- manual warns that maliciously crafted binary chunks can crash the
-- interpreter, so only deserialize data from trusted sources.
local ret = loadstring(sub(str, dataindex, nextindex - 1))
visited[#visited + 1] = ret
return ret, nextindex
elseif t == 211 then
-- resource: looked up by its registered name
local resname, nextindex = deserialize_value(str, index + 1, visited)
if resname == nil then error("Got nil resource name") end
local res = resources_by_name[resname]
if res == nil then
error(("No resources found for name '%s'"):format(tostring(resname)))
end
return res, nextindex
else
error("Could not deserialize type byte " .. t .. ".")
end
end
-- Serializes every argument, in order, into a single string.
-- All arguments share one `visited` table, so values repeated across
-- arguments are written once and referenced afterwards.
local function serialize(...)
    local accum = {}
    local visited = {[NEXT] = 1, [CTORSTACK] = {}}
    local argc = select("#", ...)
    for i = 1, argc do
        local value = select(i, ...)
        types[type(value)](value, visited, accum)
    end
    return concat(accum)
end
-- Returns a stand-in for the `accum` table that forwards every assigned
-- value to file:write(). The proxy itself stays empty, so each
-- `accum[#accum + 1] = v` assignment goes through __newindex.
local function make_file_writer(file)
    local sink = {}
    local function forward(_, _, chunk)
        file:write(chunk)
    end
    return setmetatable(sink, {__newindex = forward})
end
-- Serializes every extra argument to the file at `path`, opened with `mode`
-- ("wb" / "ab" in the callers below). Values are streamed to the file as
-- they are encoded rather than collected in memory.
-- Raises if the file cannot be opened or if a value cannot be serialized;
-- in the latter case the file is still closed before the error propagates.
local function serialize_to_file(path, mode, ...)
    local file, err = io.open(path, mode)
    assert(file, err)
    local visited = {[NEXT] = 1, [CTORSTACK] = {}}
    local accum = make_file_writer(file)
    local argc = select("#", ...)
    -- Run the encoding under pcall so a failing serializer cannot leak the
    -- open file handle.
    local ok, failure = pcall(function(...)
        for i = 1, argc do
            local x = select(i, ...)
            types[type(x)](x, visited, accum)
        end
    end, ...)
    if ok then
        -- flush the writer
        file:flush()
    end
    file:close()
    if not ok then
        -- level 0: rethrow the original message without a new position.
        error(failure, 0)
    end
end
-- Serializes the given values to `path`, opening it in "wb" mode.
local function writeFile(path, ...)
return serialize_to_file(path, "wb", ...)
end
-- Serializes the given values to `path`, opening it in "ab" (append) mode.
local function appendFile(path, ...)
return serialize_to_file(path, "ab", ...)
end
-- Decodes every value in `str`, starting at `index` (default 1).
-- Returns a table of the values and how many were decoded; the count is
-- returned separately because decoded values may be nil.
local function deserialize(str, index)
    assert(type(str) == "string", "Expected string to deserialize.")
    local results, count = {}, 0
    local visited = {}
    local cursor = index or 1
    while true do
        local value, after = deserialize_value(str, cursor, visited)
        -- deserialize_value leaves the index untouched at end of input.
        if not (after > cursor) then
            break
        end
        count = count + 1
        results[count] = value
        cursor = after
    end
    return results, count
end
-- Decodes at most `n` values (default 1) from `str`, starting at `index`
-- (default 1). Returns the decoded values followed by the index of the next
-- unread byte, padded with nil up to n + 1 results. Note that when fewer
-- than `n` values are available, the next index sits directly after the
-- last decoded value rather than in position n + 1.
local function deserializeN(str, n, index)
    assert(type(str) == "string", "Expected string to deserialize.")
    n = n or 1
    assert(type(n) == "number", "Expected a number for parameter n.")
    assert(n > 0 and floor(n) == n, "N must be a positive integer.")
    local vals = {}
    index = index or 1
    local visited = {}
    local len = 0
    local val
    while len < n do
        local nextindex
        val, nextindex = deserialize_value(str, index, visited)
        if nextindex > index then
            len = len + 1
            vals[len] = val
            index = nextindex
        else
            -- nothing consumed: end of input
            break
        end
    end
    vals[len + 1] = index
    return unpack(vals, 1, n + 1)
end
-- Reads the whole file at `path` in binary mode and deserializes it.
-- Returns whatever deserialize() returns; raises if the file cannot be
-- opened.
local function readFile(path)
    local file = assert(io.open(path, "rb"))
    local contents = file:read("*all")
    file:close()
    return deserialize(contents)
end
-- Resources
-- Registers `resource` under the string `name` so it is serialized by name
-- instead of by value. Raises if either the resource or the name is already
-- registered. Returns the resource.
local function registerResource(resource, name)
    type_check(name, "string", "name")
    if resources[resource] then
        error("Resource already registered.", 0)
    end
    if resources_by_name[name] then
        error(format("Resource %q already exists.", name), 0)
    end
    resources_by_name[name] = resource
    resources[resource] = name
    return resource
end
-- Removes the resource registered under `name` and returns it.
-- Raises if no resource has that name.
local function unregisterResource(name)
    type_check(name, "string", "name")
    local resource = resources_by_name[name]
    if not resource then
        error(format("Resource %q does not exist.", name), 0)
    end
    resources_by_name[name] = nil
    resources[resource] = nil
    return resource
end
-- Templating
-- Turns a template into a purely positional list: the array entries are
-- copied as-is, then every non-array key is appended as
-- {key, normalized nested template}.
local function normalize_template(template)
    local size = #template
    local normalized = {}
    for i = 1, size do
        normalized[i] = template[i]
    end
    -- Nested templates must come out in a deterministic order, so their keys
    -- are sorted. Keys that cannot be ordered (tables, functions) should
    -- therefore NOT be used in templates.
    local keys = {}
    for key in pairs(template) do
        if not_array_index(key, size) then
            keys[#keys + 1] = key
        end
    end
    table.sort(keys)
    for i = 1, #keys do
        local key = keys[i]
        normalized[#normalized + 1] = {key, normalize_template(template[key])}
    end
    return normalized
end
-- Flattens the fields of `x` named by the normalized template `part` into
-- `argaccum`, starting after position `len`. Nested template entries recurse
-- into the corresponding sub-table of `x`. After the templated fields, one
-- extra slot holds a table of all fields of `x` not named by the template
-- (or nil when there are none). Returns the new length, including that slot.
local function templatepart_serialize(part, argaccum, x, len)
-- Start with a copy of every field; templated ones are removed below.
local extras = {}
local extracount = 0
for k, v in pairs(x) do
extras[k] = v
extracount = extracount + 1
end
for i = 1, #part do
local name
if type(part[i]) == "table" then
-- nested template: {name, sub-template}
name = part[i][1]
len = templatepart_serialize(part[i][2], argaccum, x[name], len)
else
name = part[i]
len = len + 1
argaccum[len] = x[part[i]]
end
if extras[name] ~= nil then
extracount = extracount - 1
extras[name] = nil
end
end
if extracount > 0 then
argaccum[len + 1] = extras
else
argaccum[len + 1] = nil
end
return len + 1
end
-- Inverse of templatepart_serialize: fills `ret` from the positional list
-- `values`, starting at `vindex`, following the normalized template `part`.
-- The slot after the templated fields may hold a table of extra fields,
-- which are copied into `ret`. Returns the index after that slot.
local function templatepart_deserialize(ret, part, values, vindex)
    for i = 1, #part do
        local entry = part[i]
        if type(entry) ~= "table" then
            ret[entry] = values[vindex]
            vindex = vindex + 1
        else
            -- nested template: {name, sub-template}
            local nested = {}
            ret[entry[1]] = nested
            vindex = templatepart_deserialize(nested, entry[2], values, vindex)
        end
    end
    local extras = values[vindex]
    if extras then
        for key, value in pairs(extras) do
            ret[key] = value
        end
    end
    return vindex + 1
end
-- Builds the serializer/deserializer pair for a templated metatable.
-- The serializer flattens an instance into positional values following the
-- normalized `template`; the deserializer rebuilds a table from those values
-- and attaches `metatable` to it.
local function template_serializer_and_deserializer(metatable, template)
    local function to_args(x)
        local argaccum = {}
        local len = templatepart_serialize(template, argaccum, x, 0)
        return unpack(argaccum, 1, len)
    end
    local function from_args(...)
        local instance = {}
        local args = {...}
        templatepart_deserialize(instance, template, args, 1)
        return setmetatable(instance, metatable)
    end
    return to_args, from_args
end
-- Registers a class (metatable) with custom serializer and deserializer
-- functions. For a table, missing arguments fall back to the metatable's
-- `name`, `_serialize` and `_deserialize` fields. If a serializer or
-- deserializer is still missing, a `_template` field is used to generate
-- both; without a template the metatable is registered as a resource
-- instead, and nothing is returned. Otherwise returns the metatable.
local function register(metatable, name, serialize, deserialize)
    local kind = type(metatable)
    if kind == "table" then
        name = name or metatable.name
        serialize = serialize or metatable._serialize
        deserialize = deserialize or metatable._deserialize
        if not (serialize and deserialize) then
            if not metatable._template then
                -- Register the metatable as a resource. This is semantically
                -- similar and more flexible (handles cycles).
                registerResource(metatable, name)
                return
            end
            -- Register as template
            local normalized = normalize_template(metatable._template)
            serialize, deserialize =
                template_serializer_and_deserializer(metatable, normalized)
        end
    elseif kind == "string" then
        name = name or metatable
    end
    type_check(name, "string", "name")
    type_check(serialize, "function", "serialize")
    type_check(deserialize, "function", "deserialize")
    assert(not (ids[metatable] or resources[metatable]),
           "Metatable already registered.")
    assert(not (mts[name] or resources_by_name[name]),
           ("Name %q already registered."):format(name))
    mts[name] = metatable
    ids[metatable] = name
    serializers[name] = serialize
    deserializers[name] = deserialize
    return metatable
end
-- Removes a registration made with register(). `item` is either the
-- registered name (a string) or the registered metatable.
-- register() may fall back to storing the metatable as a resource, so both
-- the custom-type tables (mts/ids) and the resource tables are consulted;
-- otherwise such a registration could not be removed by metatable, and
-- removing it by name would leave a stale entry in `resources`.
-- Returns the metatable (or resource) that was registered, if any.
local function unregister(item)
    local name, metatable
    if type(item) == "string" then -- assume name
        name = item
        metatable = mts[item]
        if metatable == nil then
            metatable = resources_by_name[item]
        end
    else -- assume metatable
        metatable = item
        name = ids[item] or resources[item]
    end
    type_check(name, "string", "name")
    mts[name] = nil
    if (metatable) then
        resources[metatable] = nil
        ids[metatable] = nil
    end
    serializers[name] = nil
    deserializers[name] = nil
    resources_by_name[name] = nil
    return metatable
end
-- Registers a class from an OOP library under `name` (default: class.name).
-- When the class has an `__instanceDict` field (middleclass), that table is
-- registered; otherwise the class itself is (30log or similar).
-- Returns the class.
local function registerClass(class, name)
    name = name or class.name
    local target = class
    if class.__instanceDict then -- middleclass
        target = class.__instanceDict
    end
    register(target, name)
    return class
end
return {
-- aliases
s = serialize,
d = deserialize,
dn = deserializeN,
r = readFile,
w = writeFile,
a = appendFile,
serialize = serialize,
deserialize = deserialize,
deserializeN = deserializeN,
readFile = readFile,
writeFile = writeFile,
appendFile = appendFile,
register = register,
unregister = unregister,
registerResource = registerResource,
unregisterResource = unregisterResource,
registerClass = registerClass,
newbinser = newbinser
}
end
return newbinser()

View File

@ -26,14 +26,17 @@ local defaults = {
time_inputs = true, -- insert binary inputs of a frame counter.
-- network layers:
embed = true, -- set to false to use a hard-coded tile embedding.
reduce_tiles = 0, -- TODO: write description
hidden = false, -- use a hidden layer with ReLU/GELU activation.
hidden_size = 128,
layernorm = false, -- use a LayerNorm layer after said activation.
reduce_tiles = false,
bias_out = true,
-- network evaluation (sampling joypad):
frameskip = 4,
prob_frameskip = 0.0,
max_frameskip = 6,
-- true greedy epsilon has both deterministic and det_epsilon set.
deterministic = false, -- use argmax on outputs instead of random sampling.
det_epsilon = false, -- take random actions with probability eps.
@ -41,12 +44,16 @@ local defaults = {
-- evolution strategy and non-rate hyperparameters:
es = 'ars',
ars_lips = false, -- for ARS.
epoch_top_trials = 9999, -- for ARS.
epoch_top_trials = 9999, -- for ARS, Guided.
alpha = 0.5, -- for Guided.
beta = 2.0, -- for ARS, Guided. should be 1, but defaults to 2 for compat.
past_grads = 1, -- for Guided. keeps a history of n past steps taken.
-- sampling:
deviation = 1.0,
unperturbed_trial = true, -- perform an extra trial without any noise.
-- this is good for logging, so i'd recommend it.
attempts = 1, -- TODO: document.
epoch_trials = 50,
graycode = false, -- for ARS.
negate_trials = true, -- try pairs of normal and negated noise directions.
@ -113,5 +120,9 @@ assert(not cfg.ars_lips or cfg.negate_trials,
"cfg.negate_trials must be true to use cfg.ars_lips")
assert(not (cfg.es == 'snes' and cfg.negate_trials),
"cfg.negate_trials is not yet compatible with SNES")
assert(not (cfg.es == 'guided' and cfg.graycode),
"cfg.graycode is not compatible with Guided")
assert(cfg.es ~= 'guided' or cfg.negate_trials,
"cfg.negate_trials must be true to use Guided")
return cfg

10
cossim_test.lua Normal file
View File

@ -0,0 +1,10 @@
-- Prints util.expect_cossim(dims) for dims = 1..100, then for 1000 and 10000.
local util = require "util"
for dims = 1, 100 do
    print(dims, util.expect_cossim(dims))
end
-- Loop over the two large sizes instead of assigning to `dims`, which
-- outside the loop above would have created a global.
for _, dims in ipairs{1000, 10000} do
    print(dims, util.expect_cossim(dims))
end

290
eig.lua Normal file
View File

@ -0,0 +1,290 @@
-- blah
--local globalize = require "strict"
local nn = require "nn"
local util = require "util"
local sign = util.sign
local abs = math.abs
local reshape = nn.reshape
local sqrt = math.sqrt
local zeros = nn.zeros
-- Operates in place on the square matrix `a`, a flat table with a 2-element
-- `shape` field. Index variables are named after the element they address:
-- ind_ki is (k - 1) * n + i, and so on.
-- Returns two length-n vectors d and e, and overwrites `a`.
-- NOTE(review): this looks like a port of Numerical Recipes' tred2
-- (Householder reduction of a symmetric matrix to tridiagonal form); if so,
-- d is the diagonal, e the off-diagonal with e[1] = 0, and `a` ends up
-- holding the orthogonal transformation. Confirm against the reference.
local function tred2(a)
assert(#a.shape == 2)
assert(a.shape[1] == a.shape[2])
local n = a.shape[1]
local d = zeros(n) -- diagonal
local e = zeros(n) -- off-diagonal (e[1] is a dummy value?)
-- First pass: work from the last index down to 2.
for i = n, 2, -1 do
local l = i - 1
local ind_li = (l - 1) * n + i
local h = 0
local scale = 0
if l > 1 then
for k = 1, l do
local ind_ki = (k - 1) * n + i
scale = scale + abs(a[ind_ki])
end
if scale == 0 then
-- nothing to eliminate for this index
e[i] = a[ind_li]
else
-- scale the entries and accumulate their sum of squares in h
for k = 1, l do
local ind_ki = (k - 1) * n + i
a[ind_ki] = a[ind_ki] / scale
h = h + a[ind_ki] * a[ind_ki]
end
local f = a[ind_li]
local g = sqrt(h)
-- give g the opposite sign of f
if f >= 0 then g = -g end
e[i] = scale * g
h = h - f * g
a[ind_li] = f - g
f = 0
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = a[ind_ji] / h
g = 0
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
g = g + a[ind_kj] * a[ind_ki]
end
for k = j + 1, l do
local ind_jk = (j - 1) * n + k
local ind_ki = (k - 1) * n + i
g = g + a[ind_jk] * a[ind_ki]
end
-- e is used as scratch space here
e[j] = g / h
f = f + e[j] * a[ind_ji]
end
local hh = f / (h + h)
for j = 1, l do
local ind_ji = (j - 1) * n + i
f = a[ind_ji]
g = e[j] - hh * f
e[j] = g
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
a[ind_kj] = a[ind_kj] - (f * e[k] + g * a[ind_ki])
end
end
end
else
e[i] = a[ind_li]
end
-- d temporarily stores h; the second pass tests it against zero
d[i] = h
end
d[1] = 0
e[1] = 0
-- Second pass: update `a` in place and move its diagonal into d.
for i = 1, n do
local l = i - 1
if d[i] ~= 0 then
for j = 1, l do
local g = 0
for k = 1, l do
local ind_ki = (k - 1) * n + i
local ind_jk = (j - 1) * n + k
g = g + a[ind_ki] * a[ind_jk]
end
for k = 1, l do
local ind_ik = (i - 1) * n + k
local ind_jk = (j - 1) * n + k
a[ind_jk] = a[ind_jk] - g * a[ind_ik]
end
end
end
local ind_ii = (i - 1) * n + i
d[i] = a[ind_ii]
-- set the diagonal entry to 1 and clear entries (i, j) and (j, i), j < i
a[ind_ii] = 1
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = 0
a[ind_ji] = 0
end
end
return d, e
end
-- Computes sqrt(a^2 + b^2) by factoring out the larger magnitude, so the
-- squared term is a ratio no greater than 1 and cannot overflow.
-- Returns 0 when both inputs are 0.
local function pythag(a, b)
    local hi, lo = abs(a), abs(b)
    if not (hi > lo) then
        hi, lo = lo, hi
    end
    if hi == 0 then
        return 0
    end
    local ratio = lo / hi
    return hi * sqrt(1 + ratio * ratio)
end
-- Iterates in place on the vectors d and e (as returned by tred2) and on the
-- square matrix z, a flat table with a 2-element `shape` field.
--   d: length-n vector, updated in place
--   e: length-n vector, updated in place; e[1] is discarded
--   z: n x n matrix; for each rotation the entries z[(i - 1) * n + k] and
--      z[i * n + k] are combined, for k = 1..n
-- Raises "Too many iterations in tqli" if an index needs 32 or more passes.
-- NOTE(review): this looks like a port of Numerical Recipes' tqli (QL
-- iteration with implicit shifts on a tridiagonal matrix); if so, d should
-- hold the eigenvalues on return and z the eigenvectors. Confirm against
-- the reference.
local function tqli(d, e, z)
    assert(#z.shape == 2)
    assert(z.shape[1] == z.shape[2])
    local n = z.shape[1]
    assert(#d == n)
    assert(#e == n)

    -- shift e down by one so e[i] sits next to d[i] and d[i + 1].
    for i = 2, n do e[i - 1] = e[i] end
    e[n] = 0

    for l = 1, n do
        local iter = 0
        local m
        while true do
            -- find the first m >= l whose e[m] is negligible next to its
            -- neighbouring d entries (adding it changes nothing).
            m = l
            while m <= n - 1 do
                local dd = abs(d[m]) + abs(d[m + 1])
                if abs(e[m]) + dd == dd then break end
                m = m + 1
            end
            if m == l then break end

            -- The former debug pass counter was removed: it could never
            -- reach its thresholds because this limit triggers first.
            iter = iter + 1
            if iter >= 32 then error("Too many iterations in tqli") end

            -- form the shift
            local g = (d[l + 1] - d[l]) / (2 * e[l])
            local r = pythag(g, 1)
            -- NOTE(review): the reference uses SIGN(r, g), which is +r for
            -- g == 0; confirm that util.sign(0) does not return 0 here.
            g = d[m] - d[l] + e[l] / (g + r * sign(g))
            local s = 1
            local c = 1
            local p = 0
            -- Set when the rotation loop bails out early. The loop variable
            -- is not visible after the loop in Lua, so the old `i >= l`
            -- test read a nil global and raised whenever r was 0.
            local deflated = false
            for i = m - 1, l, -1 do
                local f = s * e[i]
                local b = c * e[i]
                r = pythag(f, g)
                e[i + 1] = r
                if r == 0 then
                    -- recover from underflow and restart this index
                    d[i + 1] = d[i + 1] - p
                    e[m] = 0
                    deflated = true
                    break
                end
                s = f / r
                c = g / r
                g = d[i + 1] - p
                r = (d[i] - g) * s + 2 * c * b
                p = s * r
                d[i + 1] = g + p
                g = c * r - b
                -- apply the rotation to z (the disabled alternative used
                -- the transposed index (k - 1) * n + i).
                for k = 1, n do
                    local ind = (i - 1) * n + k
                    f = z[ind + n]
                    z[ind + n] = s * z[ind] + c * f
                    z[ind] = c * z[ind] - s * f
                end
            end
            if not deflated then
                d[l] = d[l] - p
                e[l] = g
                e[m] = 0
            end
        end
    end
end
--[=[
local A = {
4, 1, -2, 2,
1, 2, 0, 1,
-2, 0, 3, -2,
2, 1, -2, -1,
}
reshape(A, 4, 4)
local d, e = tred2(A)
--[[
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
--]]
--[[
{
0.248069, 0.744208, 0.620174, 0.000000,
0.702863, -0.578829, 0.413449, 0.000000,
0.666667, 0.333333, -0.666667, 0.000000,
0.000000, 0.000000, 0.000000, 1.000000,
}
{ 2.261538, 1.182906, 5.555556, -1.000000,}
{ 0.000000, -0.092308, 0.895806, 3.000000,}
--]]
--A = nn.transpose(A)
tqli(d, e, A)
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
local D = zeros{4, 4}
for i = 1, 4 do
D[(i - 1) * 4 + i] = math.exp(d[i])
end
local out = nn.dot(nn.transpose(A), D)
out = nn.dot(out, A)
print(nn.pp(out))
--[[
{
703.414032, 1410.125991, 1478.990752, -43.126976,
-1205.204565, 1573.963121, -940.902581, 319.676625,
14.433478, 3.884400, -10.434671, 6.841011,
3.529384, 3.087068, -5.116236, -17.850223,
}
{ 2.273819, 1.072834, 6.818268, -2.164921,}
{ -0.000000, 0.000000, 0.000000, 0.000000,}
--]]
--]=]
return {
tred2=tred2,
tqli=tqli,
}

163
es_test.lua Normal file
View File

@ -0,0 +1,163 @@
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local print = print
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
local guided = require("guided")
-- try it all out on a dummy problem.
-- returns the __index field of t's metatable.
-- used below to compare an optimizer instance against tables such as snes.Snes.
local function typeof(t)
    local mt = getmetatable(t)
    return mt.__index
end
local function square(x)
    return x * x
end

-- dummy objective: the negated sum of squared distances between x and a
-- target point whose i-th coordinate is i / #x.
-- the best possible score is therefore 0, reached at x[i] = i / #x.
-- (an older variant, with the target at x[i] = i, is no longer used.)
local function spherical(x)
    local n = #x
    local total = 0
    for i, v in ipairs(x) do
        local delta = v - i / n
        total = total + square(delta)
    end
    -- the optimizers maximize, so flip the sign of the cost.
    return -total
end
-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
local iterations = 3000 --4000
local dims, popsize
if false then
dims = 100
popsize = dims + 1
else
dims = 30
popsize = 99
end
local sigma_init = 0.5
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
--local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
--local es = guided.Guided(dims, popsize, popsize, 1.0, sigma_init, 0.5)
es.min_refresh = 1.0 -- FIXME: needs a better interface.
if typeof(es) == xnes.Xnes
or typeof(es) == snes.Snes
then
    -- use IGO recommendations: give equal weight to the top fifth of the
    -- population and zero weight to everything else.
    -- assumes es.utility is indexed by rank, best first -- TODO confirm.
    local pop5 = max(1, floor(es.popsize / 5))
    for i = 1, es.popsize do
        -- `<=` rather than `<`: this selects exactly pop5 samples, and at
        -- least one even for tiny populations (which is what the max(1, ...)
        -- above is guarding). with `<`, pop5 == 1 selected nothing at all,
        -- leaving an all-zero utility vector to be normalized.
        es.utility[i] = i <= pop5 and 1 or 0
    end
    local util = require "util"
    util.normalize_sums(es.utility)

    es.param_rate = 0.2 --39
    es.sigma_rate = 0.05 --39
    es.covar_rate = 0.1 --39
    es.adaptive = false
end
if false then -- TODO: delete me
local nn = require("nn")
local util = require("util")
local insert = table.insert
local scored = nn.arange(10)
local indices = ars.collect_best_indices(scored, 3, true)
for i, ind in ipairs(indices) do
print(ind, ":", scored[ind * 2 - 1], scored[ind * 2 - 0])
end
local top_rewards = {}
for _, ind in ipairs(indices) do
insert(top_rewards, scored[ind * 2 - 1])
insert(top_rewards, scored[ind * 2 - 0])
end
-- this shouldn't make a difference to the final print:
top_rewards = util.normalize_sums(top_rewards)
print(nn.pp(top_rewards))
local _, reward_dev = util.calc_mean_dev(top_rewards)
print(reward_dev)
for i, ind in ipairs(indices) do
local pos = top_rewards[i * 2 - 1]
local neg = top_rewards[i * 2 - 0]
local reward = pos - neg
reward = reward / reward_dev
print(reward)
end
do return end
end
-- main optimization loop: repeatedly ask the optimizer for candidates,
-- score each candidate with spherical(), and tell the optimizer the scores.
local asked = nil -- for caching purposes.
local noise = nil -- for caching purposes.
local current_cost = spherical(es:params())
-- ring buffer of the most recent steps returned by Guided:tell, stored flat.
-- pgi is the next write offset; pgn is how many steps are retained.
local past_grads = {}
local pgi = 0
local pgn = 10
for i=1, iterations do
-- each optimizer has a slightly different ask() signature.
if typeof(es) == snes.Snes and es.min_refresh ~= 1 then
asked, noise = es:ask_mix()
elseif typeof(es) == ars.Ars then
asked, noise = es:ask()
elseif typeof(es) == guided.Guided then
asked, noise = es:ask(past_grads)
else
-- the previous tables are handed back in (see "for caching purposes").
asked, noise = es:ask(asked, noise)
end
local scores = {}
-- note: this i shadows the iteration counter within this inner loop only.
for i, v in ipairs(asked) do
scores[i] = spherical(v)
end
if typeof(es) == ars.Ars then
es:tell(scores)--, current_cost) -- use lips
elseif typeof(es) == guided.Guided then
local step = es:tell(scores)
-- append the step to the ring buffer, wrapping after pgn steps.
for _, v in ipairs(step) do
past_grads[pgi + 1] = v
pgi = (pgi + 1) % (pgn * #step)
end
-- describe the buffer as a (steps stored so far) x (#step) matrix.
past_grads.shape = {floor(#past_grads / #step), #step}
else
es:tell(scores)
end
current_cost = spherical(es:params())
-- report progress every 100 iterations.
if i % 100 == 0 then
local sigma = es.sigma
if typeof(es) == snes.Snes then
-- for Snes, report the mean of es.std instead of es.sigma.
sigma = 0
for i, v in ipairs(es.std) do sigma = sigma + v end
sigma = sigma / #es.std
end
-- the parenthesized figure is log10(sigma / sigma_init).
local inconvergence = sigma / sigma_init
local fmt = "fitness at iteration %i: %.4f (%.4f)"
print(fmt:format(i, current_cost, log(inconvergence) / log(10)))
end
end
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))
-- print the final parameters as one comma-separated line.
-- the separator is placed by table.concat, so it depends only on the length
-- of the parameter vector. the old check compared against es.dim, but Ars
-- and Guided store `dims`, so it never matched and left a trailing ", ".
local s
do
    local pieces = {}
    for i, v in ipairs(es:params()) do
        pieces[i] = ("%.8f"):format(v)
    end
    s = table.concat(pieces, ', ')
end
print(s)

81
expm.lua Normal file
View File

@ -0,0 +1,81 @@
-- here's a really awful way of computing the matrix exponential.
-- we employ the QR algorithm to find eigenpairs of a given symmetric matrix,
-- then run the ordinary exponent function over the eigenvalues.
-- this only works for symmetric matrices!
local nn = require "nn"
local qr = require "qr"
local util = require "util"
local copy = util.copy
local dot = nn.dot
local exp = math.exp
local reshape = nn.reshape
local transpose = nn.transpose
local zeros = nn.zeros
-- matrix exponential of a square matrix via repeated QR steps.
-- mat: flat table with mat.shape = {n, n}. per the note at the top of this
--      file, this only works for symmetric matrices (not checked here).
-- returns dot(dot(vec, diag), transpose(vec)), where vec accumulates the q
-- factors and diag holds exp() of the iterated matrix's diagonal.
-- note: this version is not what the module exports; expm2 is returned instead.
local function expm(mat)
assert(#mat.shape == 2)
assert(mat.shape[1] == mat.shape[2], "expm input must be square")
--assert(stuff(mat), "expm input must be symmetrical")
local dims = mat.shape[1]
-- start vec as the identity matrix.
local vec = zeros(mat.shape)
for i = 1, dims do
local ind = (i - 1) * dims + i -- diagonal
vec[ind] = 1
end
local diag = mat
-- a fixed 10 iterations of: factor diag into q, r; replace diag by r * q.
-- there is no convergence check.
for i = 1, 10 do
local q, r = qr(diag)
vec = dot(vec, q)
diag = dot(r, q)
end
-- exponentiate the diagonal and discard whatever off-diagonal residue is left.
for y = 1, dims do
for x = 1, dims do
local ind = (y - 1) * dims + x
if x == y then
diag[ind] = exp(diag[ind])
else
diag[ind] = 0
end
end
end
return dot(dot(vec, diag), transpose(vec))
end
local eig = require "eig"
local tred2 = eig.tred2
local tqli = eig.tqli
-- matrix exponential of a square matrix, using eigenpairs from eig.lua.
-- mat: flat table with mat.shape = {n, n}; it is copied, not modified.
--      per the note at the top of this file, it must be symmetric.
-- returns transpose(vec) * diag * vec, where diag carries exp() of each
-- value in d on its diagonal.
local function expm2(mat)
    assert(#mat.shape == 2)
    assert(mat.shape[1] == mat.shape[2])
    local n = mat.shape[1]

    -- tred2 and tqli work in-place on the copy.
    -- presumably d ends up holding the eigenvalues and vec the matching
    -- eigenvectors -- TODO confirm against eig.lua.
    local vec = copy(mat)
    local d, e = tred2(vec)
    tqli(d, e, vec)

    -- build an n x n matrix that is zero everywhere but the diagonal.
    local diag = {}
    for ind = 1, n * n do
        diag[ind] = 0
    end
    for i = 1, n do
        diag[(i - 1) * n + i] = exp(d[i])
    end
    reshape(diag, n, n)

    local left = dot(transpose(vec), diag)
    return dot(left, vec)
end
return expm2

31
expm_test.lua Normal file
View File

@ -0,0 +1,31 @@
-- smoke test for expm: builds a matrix of the form M * M^T, prints it and
-- its matrix exponential, then round-trips the result through binser.
local globalize = require "strict"
local nn = require "nn"
local expm = require "expm"
local sqrt = math.sqrt
local dims = 8
local n = dims * dims
-- fill M with 1..n, scaled down by dims^2 * sqrt(dims).
local M = {}
local coeff = dims * dims * sqrt(dims)
for i = 1, n do M[i] = i / coeff end
M.shape = {dims, dims}
-- M * M^T, so the input handed to expm is symmetric.
M = nn.dot(M, nn.transpose(M))
print(nn.pp(M, "%9.6f"))
local exp_M = expm(M)
print(nn.pp(exp_M, "%9.6f"))
-- write the result to disk, read it back, and print what was read.
local binser = require "binser"
--local dat = binser.s(exp_M)
binser.writeFile("expm.bin", exp_M)
local res, len = binser.readFile("expm.bin")
assert(len == 1)
exp_M = res[1]
print(nn.pp(exp_M, "%9.6f"))

62
extra.lua Normal file
View File

@ -0,0 +1,62 @@
-- left-pads tostring(num) with `count` copies of `pad`, then keeps
-- everything from position #num onward.
-- with a single-character pad the result is always count + 1 characters
-- long; inputs longer than that lose their leading characters.
local function strpad(num, count, pad)
    local text = tostring(num)
    local padded = pad:rep(count) .. text
    return padded:sub(#text)
end
-- zero-pads num via strpad.
-- strpad yields one character more than the count it is given,
-- hence the count - 1 here.
local function add_zeros(num, count)
    local pad_count = count - 1
    return strpad(num, pad_count, '0')
end
-- comparator for table.sort that accepts a mix of numbers and other values.
-- numbers are compared by their zero-padded (16 wide) string form,
-- everything else by tostring().
local function mixed_sorter(a, b)
    local function sort_key(v)
        if type(v) == 'number' then
            return add_zeros(v, 16)
        end
        return tostring(v)
    end
    return sort_key(a) < sort_key(b)
end
-- loosely based on http://lua-users.org/wiki/SortedIteration
-- the original didn't make use of closures for who knows why
-- returns a new array holding every key of t, ordered by mixed_sorter.
local function order_keys(t)
    local keys = {}
    for key in pairs(t) do
        keys[#keys + 1] = key
    end
    table.sort(keys, mixed_sorter)
    return keys
end
-- ordered variant of pairs(): iterates t's keys in order_keys() order.
-- cache: optional table mapping t to its ordered key list. when given,
--        an existing entry is reused, and the list used is stored back.
-- returns a closure that yields key, value pairs until the keys run out.
local function opairs(t, cache)
    local keys
    if cache then
        keys = cache[t] or order_keys(t)
        cache[t] = keys
    else
        keys = order_keys(t)
    end

    local pos = 0
    return function()
        pos = pos + 1
        local key = keys[pos]
        if not key then return end
        return key, t[key]
    end
end
-- resolves a dotted path such as "a.b.c" against the global table.
-- path: string of identifier-like words ([%w_]+) with any separators.
-- returns {parent=<table holding the last key>, key=<last word>}, or
-- nothing if path is nil, has no words, or an intermediate value
-- isn't a table. the final key itself is not looked up.
local function traverse(path)
    if not path then return end
    local parent = _G
    local key
    -- string.gmatch replaces string.gfind, which is the Lua 5.0 name:
    -- it only survives as a compatibility alias in 5.1 and is gone in 5.2+.
    for w in path:gmatch("[%w_]+") do
        if key then
            -- descend one level; rawget avoids triggering metamethods.
            parent = rawget(parent, key)
            if type(parent) ~= 'table' then return end
        end
        key = w
    end
    if not key then return end
    return {parent=parent, key=key}
end
return {
strpad = strpad,
add_zeros = add_zeros,
mixed_sorter = mixed_sorter,
order_keys = order_keys,
opairs = opairs,
traverse = traverse,
}

194
guided.lua Normal file
View File

@ -0,0 +1,194 @@
-- Guided Evolutionary Strategies
-- https://arxiv.org/abs/1806.10230
-- this is just ARS extended to utilize gradients
-- approximated from previous iterations.
-- for simplicity:
-- antithetic is always true
-- momentum is always 0
-- no graycode/lipschitz nonsense
local floor = math.floor
local insert = table.insert
local ipairs = ipairs
local max = math.max
local print = print
local sqrt = math.sqrt
local Base = require "Base"
local nn = require "nn"
local dot_mv = nn.dot_mv
local transpose = nn.transpose
local normal = nn.normal
local prod = nn.prod
local uniform = nn.uniform
local zeros = nn.zeros
local qr = require "qr2"
local util = require "util"
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local Guided = Base:extend()
-- scored: flat array of rewards laid out as pos/neg pairs,
--         i.e. pair i occupies scored[2i - 1] and scored[2i].
-- top: how many pairs to keep.
-- returns the indices of the `top` pairs, as ordered by argsort using a
-- descending comparator on the better reward of each pair.
local function collect_best_indices(scored, top)
    local pair_count = #scored / 2

    -- represent each pair by whichever of its two rewards is higher.
    local best_rewards = {}
    for i = 1, pair_count do
        local neg_ind = i * 2
        best_rewards[i] = max(scored[neg_ind - 1], scored[neg_ind])
    end

    local function descending(a, b) return a > b end
    local indices = argsort(best_rewards, descending)

    -- drop everything past the first `top` entries.
    for i = #best_rewards, top + 1, -1 do
        indices[i] = nil
    end
    return indices
end
function Guided:init(dims, popsize, poptop, base_rate, sigma, alpha, beta)
    -- dims: number of parameters being optimized.
    -- popsize: number of perturbation pairs per iteration (doubled below).
    --          defaults to 4 + 3 * floor(ln(dims)).
    -- poptop: how many of the best pairs contribute to a step.
    --         defaults to all of them.
    -- base_rate: parameter learning rate; has a dims-based default.
    -- sigma: scale of random perturbations.
    -- alpha: blend between full parameter space and its gradient subspace.
    --        1.0 is roughly equivalent to ARS.
    -- beta: extra factor applied to the step in tell() and to decay().
    self.dims = dims

    -- math.log is called directly: this file never made a local `log`,
    -- so the bare name was a nil global and either default below would error.
    local n_pairs = popsize or 4 + (3 * floor(math.log(dims)))
    base_rate = base_rate or 3/5 * (3 + math.log(dims)) / (dims * sqrt(dims))
    self.param_rate = base_rate
    self.sigma = sigma or 1.0
    self.alpha = alpha or 0.5
    self.beta = beta or 1.0

    -- compare against the resolved pair count rather than the raw argument,
    -- which may be nil when the default population size is in use.
    self.poptop = poptop or n_pairs
    assert(self.poptop <= n_pairs, "poptop must not exceed popsize")
    self.popsize = n_pairs * 2 -- antithetic

    self._params = zeros(self.dims)
    --self.accum = zeros(self.dims) -- momentum

    self.evals = 0
end
-- getter/setter for the current parameter vector.
-- new_params: optional array; when given, its values are copied into the
--             existing vector (it must have the same length).
-- returns the internal parameter table itself, not a copy.
function Guided:params(new_params)
    local current = self._params
    if new_params == nil then return current end

    assert(#current == #new_params, "new parameters have the wrong size")
    for i, v in ipairs(new_params) do
        current[i] = v
    end
    return current
end
-- shrinks every parameter towards zero (weight decay).
-- param_decay: decay strength; nothing happens unless it is > 0.
-- sigma_decay: accepted but not used by this method.
function Guided:decay(param_decay, sigma_decay)
-- FIXME: multiplying by sigma probably isn't correct anymore.
-- is this correct now?
if param_decay > 0 then
-- the three lines below amount to
-- beta * param_rate / (sigma * sqrt(dims)).
local scale = self.sigma / sqrt(self.dims)
scale = scale * self.beta
scale = scale * self.param_rate / (self.sigma * self.sigma)
-- turn that into the multiplicative shrink factor.
scale = 1 - param_decay * scale
for i, v in ipairs(self._params) do
self._params[i] = scale * v
end
end
end
-- generates self.popsize candidate parameter vectors as antithetic pairs.
-- grads: optional flat array of past steps with grads.shape[1] = how many
--        are stored. when nil or empty, plain scaled gaussian noise is used.
-- returns asked (the candidates) and noise (the perturbation behind each one).
-- the noise is also kept in self.noise for the following tell().
function Guided:ask(grads)
local asked = {}
local noise = {}
local n_grad = 0
local gnoise, U, dummy, left, right
if grads ~= nil and #grads > 0 then
n_grad = grads.shape[1]
gnoise = zeros(n_grad)
-- qr here is the "qr2" module. presumably U is an orthonormal basis of
-- the subspace spanned by the stored steps -- TODO confirm.
U, dummy = qr(transpose(grads))
--print(nn.pp(transpose(U), "%9.4f"))
-- weights of the full-space and subspace noise components.
left = sqrt(self.alpha / self.dims)
right = sqrt((1 - self.alpha) / n_grad)
--print(left, right)
end
for i = 1, self.popsize do
local asking = zeros(self.dims)
local noisy = zeros(self.dims)
asked[i] = asking
noise[i] = noisy
if i % 2 == 0 then
-- even entries mirror the preceding odd entry (antithetic sampling).
local old_noisy = noise[i - 1]
for j, v in ipairs(old_noisy) do
noisy[j] = -v
end
elseif n_grad == 0 then
-- no gradient information: isotropic noise scaled by sigma / sqrt(dims).
local scale = self.sigma / sqrt(self.dims)
for j = 1, self.dims do
noisy[j] = scale * normal()
end
else
-- mix full-space noise with noise drawn inside the subspace given by U.
for j = 1, self.dims do noisy[j] = normal() end
for j = 1, n_grad do gnoise[j] = normal() end
local noisier = dot_mv(U, gnoise)
for j, v in ipairs(noisy) do
noisy[j] = self.sigma * (left * v + right * noisier[j])
end
end
-- candidate = current parameters + perturbation.
for j, v in ipairs(self._params) do
asking[j] = v + noisy[j]
end
end
self.noise = noise
return asked, noise
end
-- updates the parameters from the scores of the last ask().
-- scored: rewards in the same order as the candidates returned by ask(),
--         i.e. pos/neg pairs at indices 2i - 1 and 2i.
-- unperturbed_score: accepted but not used by this method.
-- returns the step (before the param_rate / sigma^2 scaling is applied).
-- requires self.noise from ask(); it is cleared afterwards.
function Guided:tell(scored, unperturbed_score)
self.evals = self.evals + #scored
local indices = collect_best_indices(scored, self.poptop)
-- gather both rewards of every selected pair, in selection order.
local top_rewards = {}
for _, ind in ipairs(indices) do
insert(top_rewards, scored[ind * 2 - 1])
insert(top_rewards, scored[ind * 2 - 0])
end
local step = zeros(self.dims)
local _, reward_dev = calc_mean_dev(top_rewards)
-- avoid dividing by zero when the deviation comes back as 0.
if reward_dev == 0 then reward_dev = 1 end
for i, ind in ipairs(indices) do
local pos = top_rewards[i * 2 - 1]
local neg = top_rewards[i * 2 - 0]
local reward = pos - neg
if reward ~= 0 then
-- the positive perturbation of this pair; the negative one is its mirror.
local noisy = self.noise[ind * 2 - 1]
-- NOTE: technically this reward divide isn't part of guided search.
reward = reward / reward_dev
local scale = reward / self.poptop * self.beta / 2
for j, v in ipairs(noisy) do
step[j] = step[j] + scale * v
end
end
end
local coeff = self.param_rate / (self.sigma * self.sigma)
for i, v in ipairs(self._params) do
self._params[i] = v + coeff * step[i]
end
self.noise = nil
return step
end
return {
--collect_best_indices = collect_best_indices, -- ars.lua has more features
Guided = Guided,
}

171
main.lua
View File

@ -20,10 +20,17 @@ local trial_rewards = {}
local trials_remaining = 0
local es -- evolution strategy.
local attempt_i = 0
local sub_rewards = {}
local past_grads = {} -- for Guided
local pgi = 0 -- past_grads_index
local trial_frames = 0
local total_frames = 0
local lagless_count = 0
local decisions_made = 0
local last_decision_frame = -1
local force_start = false
local force_start_old = false
@ -92,6 +99,7 @@ local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local calc_mean_dev_unbiased = util.calc_mean_dev_unbiased
local clamp = util.clamp
local copy = util.copy
local empty = util.empty
@ -148,13 +156,21 @@ local network
local nn_x, nn_tx, nn_ty, nn_tz, nn_y, nn_z
local function make_network(input_size)
nn_x = nn.Input({input_size})
nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, 2))
local embed_dim = cfg.embed and 2 or 3
if cfg.embed then
nn_tx = nn.Input({gcfg.tile_count})
nn_ty = nn_tx:feed(nn.Embed(#game.valid_tiles, embed_dim))
else -- new tile inputs.
nn_tx = nn.Input({gcfg.tile_count * 3})
nn_ty = nn_tx
end
nn_tz = nn_ty
if cfg.reduce_tiles then
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * 2})
nn_tz = nn_tz:feed(nn.DenseBroadcast(5, true))
if cfg.reduce_tiles > 0 then
nn_tz = nn_tz:feed(nn.Reshape{11, 17 * embed_dim})
nn_tz = nn_tz:feed(nn.DenseBroadcast(cfg.reduce_tiles, true))
nn_tz = nn_tz:feed(nn.Relu())
-- note: due to a quirk in Merge, we don't need to flatten nn_tz.
end
@ -184,6 +200,7 @@ end
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
local guided = require("guided")
local function prepare_epoch()
trial_neg = false
@ -212,6 +229,8 @@ local function prepare_epoch()
local dummy
if cfg.es == 'ars' then
trial_params, dummy = es:ask(precision)
elseif cfg.es == 'guided' then
trial_params, dummy = es:ask(past_grads)
elseif cfg.es == 'snes' then
trial_params, dummy = es:ask_mix()
else
@ -222,6 +241,7 @@ local function prepare_epoch()
end
local function load_next_trial()
attempt_i = 1
if cfg.negate_trials then
trial_neg = not trial_neg
else
@ -264,15 +284,23 @@ local function learn_from_epoch()
end
local step
if cfg.es == 'ars' and cfg.ars_lips then
if cfg.es == 'ars' then --and cfg.ars_lips then
step = es:tell(trial_rewards, current_cost)
else
step = es:tell(trial_rewards)
end
local step_mean, step_dev = calc_mean_dev(step)
print("step mean:", step_mean)
print("step stddev:", step_dev)
print(("step mean: %9.6f"):format(step_mean))
print(("step stddev: %9.6f"):format(step_dev))
if cfg.es == 'guided' and cfg.past_grads > 0 then
for _, v in ipairs(step) do
past_grads[pgi + 1] = v
pgi = (pgi + 1) % (cfg.past_grads * #step)
end
past_grads.shape = {floor(#past_grads / #step), #step}
end
es:decay(cfg.param_decay, cfg.sigma_decay)
@ -328,65 +356,62 @@ local function joypad_mash(button)
joypad.write(1, jp_mash)
end
local function loadlevel(world, level)
-- TODO: move to smb.lua. rename to load_level.
if world == 0 then world = random(1, 8) end
if level == 0 then level = random(1, 4) end
emu.poweron()
while emu.framecount() < 60 do
if emu.framecount() == 32 then
local area = game.area_lut[world * 10 + level]
game.W(0x75F, world - 1)
game.W(0x75C, level - 1)
game.W(0x760, area)
end
if emu.framecount() == 42 then
game.W(0x7A0, 0) -- world screen timer (reduces startup time)
end
joypad_mash('start')
emu.frameadvance()
end
end
local function do_reset()
local state = game.get_state()
-- be a little more descriptive.
if state == 'dead' and game.get_timer() == 0 then state = 'timeup' end
if trial_i >= 0 then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif cfg.negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
--if cfg.attempts > 1 and attempt_i >= cfg.attempts then
attempt_i = attempt_i + 1
sub_rewards[#sub_rewards + 1] = reward
--print(sub_rewards)
if trial_neg then
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
if #sub_rewards >= cfg.attempts then
if cfg.attempts == 1 then
reward = sub_rewards[1]
else
local sub_mean, sub_std = calc_mean_dev(sub_rewards)
reward = floor(sub_mean)
--local sub_mean, sub_std = calc_mean_dev_unbiased(sub_rewards)
--reward = floor(sub_mean - sub_std)
end
empty(sub_rewards)
if trial_i >= 0 then
if trial_i == 0 then
print('test trial reward:', reward, "("..state..")")
elseif cfg.negate_trials then
--local dir = trial_neg and "negative" or "positive"
--print('trial', trial_i, dir, 'reward:', reward, "("..state..")")
if trial_neg then
local pos = trial_rewards[#trial_rewards]
local neg = reward
local fmt = "trial %i rewards: %+i, %+i (%s, %s)"
print(fmt:format(floor(trial_i / 2),
pos, neg, last_trial_state, state))
end
last_trial_state = state
else
print('trial', trial_i, 'reward:', reward, "("..state..")")
end
if trial_i == 0 or not cfg.negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
last_trial_state = state
else
print('trial', trial_i, 'reward:', reward, "("..state..")")
end
if trial_i == 0 or not cfg.negate_trials then
trial_rewards[trial_i] = reward
else
trial_rewards[#trial_rewards + 1] = reward
end
end
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
collectgarbage()
if any_random then
loadlevel(cfg.starting_world, cfg.starting_level)
state_saved = false
if epoch_i == 0 or (trial_i == #trial_params and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
collectgarbage()
if any_random then
game.load_level(cfg.starting_world, cfg.starting_level)
state_saved = false
end
end
end
@ -426,7 +451,9 @@ local function do_reset()
trial_frames = 0
emu.frameadvance() -- prevents emulator from quirking up.
load_next_trial()
if attempt_i > cfg.attempts then
load_next_trial()
end
reset = false
end
@ -449,7 +476,7 @@ local function init()
if not playing then emu.speedmode("turbo") end
if not any_random then
loadlevel(cfg.starting_world, cfg.starting_level)
game.load_level(cfg.starting_world, cfg.starting_level)
end
params_fn = cfg.params_fn or ('network%07i.txt'):format(network.n_param)
@ -480,7 +507,10 @@ local function init()
elseif cfg.es == 'ars' then
es = ars.Ars(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.base_rate, cfg.deviation, cfg.negate_trials,
cfg.momentum)
cfg.momentum, cfg.beta)
elseif cfg.es == 'guided' then
es = guided.Guided(network.n_param, cfg.epoch_trials, cfg.epoch_top_trials,
cfg.base_rate, cfg.deviation, cfg.alpha, cfg.beta)
else
error("Unknown evolution strategy specified: " + tostring(cfg.es))
end
@ -536,6 +566,7 @@ local function doit(dummy)
empty(game.sprite_input)
empty(game.tile_input)
empty(game.extra_input)
empty(game.new_input)
local controllable = game.R(0x757) == 0 and game.R(0x758) == 0
local x, y = game.getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
@ -615,12 +646,18 @@ local function doit(dummy)
for i, v in ipairs(game.extra_input) do insert(X, v / 256) end
nn.reshape(X, 1, gcfg.input_size)
nn.reshape(game.tile_input, 1, gcfg.tile_count)
nn.reshape(game.new_input, 1, gcfg.tile_count * 3)
trial_frames = trial_frames + cfg.frameskip
if cfg.enable_network and game.get_state() == 'playing' or ingame_paused then
total_frames = total_frames + cfg.frameskip
local outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
local outputs
if cfg.embed then
outputs = network:forward({[nn_x]=X, [nn_tx]=game.tile_input})
else
outputs = network:forward({[nn_x]=X, [nn_tx]=game.new_input})
end
local eps = lerp(cfg.eps_start, cfg.eps_stop, total_frames / cfg.eps_frames)
if cfg.det_epsilon and random() < eps then
@ -676,6 +713,7 @@ while true do
if reset then
do_reset()
lagless_count = 0
last_decision_frame = -1
end
if not cfg.enable_network then
@ -692,8 +730,15 @@ while true do
game.W(0x75A, 1)
end
local doot = jp == nil or lagless_count % cfg.frameskip == 0
local delta = lagless_count - last_decision_frame
local doot = true
if jp ~= nil then
doot = delta >= cfg.frameskip
doot = doot and random() >= cfg.prob_frameskip
doot = doot or delta >= cfg.max_frameskip
end
doit(not doot)
if doot then last_decision_frame = lagless_count end
-- jp might still be nil if we're not ingame or we're not playing.
if jp ~= nil then joypad.write(1, jp) end

54
monitor_tiles.lua Normal file
View File

@ -0,0 +1,54 @@
-- keep track of which blocks are actually seen in the game.
-- play back an all-levels TAS with this script running.
local floor = math.floor
local open = io.open
local pairs = pairs
local print = print
local util = require("util")
local R = memory.readbyteunsigned
local W = memory.writebyte
local function S(addr) return util.signbyte(R(addr)) end
local game = require("smb") -- just for advance()
local serial = require "serialize"
local serialize = serial.serialize
local deserialize = serial.deserialize
local fn = 'seen_tiles.lua'
local seen = deserialize(fn) or {}
-- records a tile kind the first time it is encountered:
-- prints it as two hex digits and re-serializes the whole `seen` table.
-- sx, sy: accepted but not used here.
local function mark_tile(sx, sy, kind)
    if seen[kind] then return end
    seen[kind] = true
    print(("%02X"):format(kind))
    serialize(fn, seen)
end
-- walks a 17 x 13 grid of tiles starting at the current scroll position
-- and passes each tile to mark_tile.
-- columns wrap modulo 32: columns 0-15 are read from the buffer at 0x500,
-- columns 16-31 from the buffer at 0x5D0, each stored as rows of 16 bytes.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    for y = 0, 12 do
        local row_offset = y * 16
        local sy = row_offset + 40
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local base = col < 16 and 0x500 or 0x5D0
            local t = R(base + row_offset + (col % 16))
            local sx = x * 16 + 8 - tile_scroll_remainder
            mark_tile(sx, sy, t)
        end
    end
end
while true do
handle_tiles()
game.advance()
end

48
nn.lua
View File

@ -1,20 +1,15 @@
local assert = assert
local ceil = math.ceil
local cos = math.cos
local exp = math.exp
local floor = math.floor
local huge = math.huge
local insert = table.insert
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local pairs = pairs
local pi = math.pi
local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tanh = math.tanh
local tostring = tostring
@ -105,19 +100,28 @@ end
-- ndarray-ish stuff and more involved math
-- formats t[a] through t[b] with fmt and joins the results with sep.
-- a defaults to 1 and b to #t. an empty range yields an empty string.
local function pp_join(sep, fmt, t, a, b)
    local first = a or 1
    local last = b or #t
    local parts = {}
    for i = first, last do
        parts[#parts + 1] = fmt:format(t[i])
    end
    return table.concat(parts, sep)
end
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
-- pretty-prints an nd-array.
fmt = fmt or '%10.7f,'
fmt = fmt or '%10.7f'
sep = sep or ','
ti = ti or 0
di = di or 1
depth = depth or 0
if t.shape == nil then
local s = '['
for i = 1, #t do s = s..fmt:format(t[i]) end
return s..']'..sep..'\n'
end
if t == nil then return "nil" end
if t.shape == nil then return '['..pp_join(sep, fmt, t)..']'..sep..'\n' end
local dim = t.shape[di]
@ -134,11 +138,10 @@ local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
ti = ti + ti_step
end
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
s = s..indent..']'..sep
if islast then s = s..'\n' end
else
s = s..indent..'['
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
s = s..']'..sep..'\n'
s = s..indent..'['..pp_join(sep, fmt, t, ti + 1, ti + dim)..']'..sep..'\n'
end
return s
end
@ -265,6 +268,20 @@ local function dot(a, b, ax_a, ax_b, out)
return out
end
-- transposes a 2-d array stored as a flat, row-major table.
-- x: flat table with x.shape = {rows, cols}.
-- out: optional destination table; a fresh zeros{cols, rows} otherwise.
--      its shape field is not touched here.
-- returns the destination table.
local function transpose(x, out)
    assert(#x.shape == 2) -- TODO: handle ndarrays like numpy.
    local rows, cols = x.shape[1], x.shape[2]
    local y = out or zeros{cols, rows}
    -- read x consecutively; element (i, j) lands at (j, i) in y.
    local src = 0
    for i = 1, rows do
        for j = 1, cols do
            src = src + 1
            y[(j - 1) * rows + i] = x[src]
        end
    end
    return y
end
-- nodal
local function traverse(node_in, node_out, nodes, dummy_mode)
@ -875,6 +892,7 @@ return {
ppi = ppi,
dot_mv = dot_mv,
dot = dot,
transpose = transpose,
traverse = traverse,
traverse_all = traverse_all,

7
pp_test.lua Normal file
View File

@ -0,0 +1,7 @@
-- prints the same 8 values through nn.pp under several shapes.
local nn = require "nn"
-- declared local: the bare assignment used to create a global `a`.
local a = {0,1,2,3,4,5,6,7}
print(nn.pp(a, "%9.4f"))
print(nn.pp(nn.reshape(a, 4, 2), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 4), "%9.4f"))
print(nn.pp(nn.reshape(a, 2, 2, 2), "%9.4f"))

View File

@ -33,7 +33,7 @@ make_preset{
init_zeros = true,
reduce_tiles = true,
reduce_tiles = 5,
bias_out = false,
deterministic = false,
@ -72,6 +72,28 @@ make_preset{
sigma_decay = 0.008,
}
make_preset{
name = 'snes2',
parent = 'big-scroll-reduced',
es = 'snes',
deterministic = true,
deviation = 0.01,
negate_trials = false,
epoch_trials = 60,
min_refresh = 2/3,
param_rate = 0.368,
param_decay = 0.0138,
sigma_rate = 0.100,
sigma_decay = 0.0051,
}
make_preset{
name = 'snes3',
parent = 'snes2',
min_refresh = 1/3,
}
make_preset{
name = 'xnes',
parent = 'big-scroll-reduced',
@ -120,6 +142,375 @@ make_preset{
momentum = 0.5,
}
make_preset{
name = 'ars-vanilla',
parent = 'ars',
}
make_preset{
name = 'ars-lips',
parent = 'ars',
ars_lips = true,
-- momentum = 0.5, -- this is default.
param_rate = 1.0,
}
make_preset{
name = 'ars-skip',
parent = 'ars',
frameskip = 1,
prob_frameskip = 0.25,
}
make_preset{
name = 'ars-big',
parent = 'ars',
epoch_top_trials = 75,
epoch_trials = 100,
momentum = 0.5,
param_rate = 1.0,
--graycode = true,
}
make_preset{
name = 'ars-huge',
parent = 'big-scroll-hidden',
deterministic = true,
deviation = 0.01,
epoch_top_trials = 75,
epoch_trials = 100,
es = 'ars',
momentum = 0.5,
param_decay = 0.0138,
param_rate = 0.5,
}
make_preset{
name = 'ars-stupid',
parent = 'big-scroll-reduced',
es = 'ars',
epoch_top_trials = 4,
deterministic = false,
deviation = 0.2,
epoch_trials = 4,
param_rate = 0.1,
param_decay = 0.003,
momentum = 0.99,
}
-- new stuff for 2019:
make_preset{
name = 'ars-skip-more',
parent = 'ars',
-- old:
--frameskip = 2,
--prob_frameskip = 0.5,
--max_frameskip = 60,
-- new:
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
}
make_preset{
name = 'ars-skip-more-3',
parent = 'ars-skip-more',
attempts = 3, -- per trial. main.lua currently scores a trial as floor(mean(scores)); the mean - stdev variant is commented out there.
}
make_preset{
name = 'snes-skip-more-3',
parent = 'snes3',
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
attempts = 3,
}
make_preset{
name = 'guided',
parent = 'big-scroll-reduced',
es = 'guided',
epoch_top_trials = 20,
deterministic = true,
deviation = 0.1,
epoch_trials = 20,
param_rate = 0.00368,
param_decay = 0.0,
}
make_preset{
name = 'guided2',
parent = 'guided',
past_grads = 2,
-- after epoch 50, trying this:
--param_rate = 0.05,
-- after epoch 50+20, stepping back to this:
param_rate = 0.01,
}
make_preset{
name = 'guided10',
parent = 'guided',
past_grads = 10,
}
make_preset{
name = 'guided69', -- the nice one
parent = 'guided',
deviation = 0.05,
epoch_top_trials = 10,
epoch_trials = 20,
param_rate = 0.006,
past_grads = 4,
alpha = 0.25,
}
-- TODO: yet another preset. try building up from 1 trial ARS to something good.
make_preset{
name = 'redux',
min_time = 300,
max_time = 300,
timer_loser = 1/1,
score_multiplier = 1,
init_zeros = true,
deterministic = true,
es = 'guided',
past_grads = 2, -- for Guided.
alpha = 0.25, -- for Guided.
ars_lips = false, -- for ARS.
beta = 1.0, -- fix the default.
epoch_top_trials = 4, -- for ARS, Guided.
epoch_trials = 5,
attempts = 1, -- TODO: document.
deviation = 1.0, -- 0.1
base_rate = 1.0,
param_decay = 0.01,
graycode = false, -- for ARS.
min_refresh = 0.1, -- for SNES.
sigma_decay = 0.0, -- for SNES, xNES.
momentum = 0.0, -- for ARS.
}
make_preset{
name = 'redux_big',
parent = 'redux',
time_inputs = true, -- insert binary inputs of a frame counter.
hidden = true, -- use a hidden layer with ReLU/GELU activation.
hidden_size = 64,
layernorm = true, -- use a LayerNorm layer after said activation.
reduce_tiles = false,
bias_out = false,
-- gets stuck pretty quick, so tweak some stuff...
epoch_top_trials = 8,
epoch_trials = 10,
deviation = 1.0,
base_rate = 0.15,
param_decay = 0.05,
past_grads = 4,
alpha = 0.25,
-- well it doesn't get stuck anymore, but regular redux works much better.
}
make_preset{
name = 'guided-skip-more-3',
parent = 'guided',
--param_rate = 0.00368, -- should probably be this instead...
param_rate = 0.01,
frameskip = 3,
prob_frameskip = 0.5,
max_frameskip = 5,
attempts = 3,
}
make_preset{
name = 'guided-skip-more-3-again',
parent = 'guided-skip-more-3',
param_rate = 0.08, --0.0316,
deviation = 0.5,
alpha = 0.1, --0.5,
}
make_preset{
name = 'crazy',
parent = 'big-scroll-reduced',
es = 'guided',
epoch_top_trials = 15,
deterministic = false,
deviation = 1.0,
epoch_trials = 15,
param_rate = 1.0,
param_decay = 0.0,
alpha = 0.0316,
--attempts = 3,
}
make_preset{
name = 'ars-lips2',
parent = 'ars',
ars_lips = true,
--epoch_trials = 10,
param_rate = 0.147,
}
make_preset{
name = 'ars-lips3',
parent = 'ars',
ars_lips = true,
param_rate = 0.5,
deviation = 0.02, -- added after like 272 epochs
param_decay = 0.0276, -- added after like 62 epochs
}
make_preset{
name = 'hard-embed',
parent = 'big-scroll-hidden',
embed = false,
reduce_tiles = 5,
hidden_size = 54,
epoch_top_trials = 20,
deterministic = true,
deviation = 0.01,
epoch_trials = 20,
param_rate = 0.368,
param_decay = 0.0138,
momentum = 0.5,
beta = 1.0,
}
make_preset{
name = 'xnes-recon', -- recon-sidered
-- parent = 'big-scroll-hidden',
parent = 'big-scroll-reduced',
es = 'xnes',
-- embed = false,
-- reduce_tiles = 5,
-- hidden_size = 54,
epoch_trials = 20,
epoch_top_trials = 20,
negate_trials = true,
deterministic = true,
deviation = 0.1,
param_decay = 0.01,
param_rate = 0.2,
sigma_rate = 0.05,
covar_rate = 0.1,
}
make_preset{
    name = 'arse',
    parent = 'big-scroll-reduced',

    deterministic = true,
    es = 'ars',
    epoch_trials = 5,
    epoch_top_trials = 9999,
    deviation = 1.0,
    param_rate = 1.0,
    -- beta used to be listed twice in this constructor (1.0, then 5.0).
    -- Lua leaves the order of constructor assignments undefined, so which
    -- one won was not guaranteed. only the later value is kept; it is
    -- presumably the one that took effect in practice -- TODO confirm.
    beta = 5.0, -- oops, i had a dumb bug.
    param_decay = 0.01,
    momentum = 0.0,
}
make_preset{
name = 'arse2',
parent = 'arse',
beta = 5.0, -- oops, i had a dumb bug.
param_decay = 0.001,
embed = false,
}
make_preset{
name = 'arse3',
parent = 'arse2',
epoch_trials = 10,
-- note: also disabled sum of squares in ars.lua
beta = 1.0,
--deviation = 0.5, -- after 500 epochs
deviation = 0.7071, -- after 500+521 epochs
}
make_preset{
name = 'arse4',
parent = 'arse3',
hidden = true,
hidden_size = 68,
}
make_preset{
name = 'arse5',
parent = 'arse3',
-- sum of squares still disabled. i probably won't re-enable it really.
--deviation = 1.0,
deviation = 1.414, -- after 80+360+790 epochs
momentum = 0.8, -- neumann momentum, peaks at this value
--param_decay = 0.003, -- after 80 epochs
--param_decay = 0.01, -- after 80+360 epochs
param_decay = 0.0, -- after 80+360+790 epochs
}
make_preset{
name = 'arse6',
parent = 'arse3',
epoch_trials = 20,
param_rate = 0.25,
deviation = 0.1, -- maybe try 0.15
momentum = 0.9, -- maybe try 0.8
param_decay = 0.0,
}
-- end of new stuff
make_preset{
name = 'play',

113
qr.lua Normal file
View File

@ -0,0 +1,113 @@
local nn = require "nn"
local assert = assert
local dot = nn.dot
local ipairs = ipairs
local min = math.min
local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose
local zeros = nn.zeros
-- returns a matrix shaped like x whose first d diagonal entries are 1,
-- whose entries at rows > d and columns > d are copied from x,
-- and whose remaining entries are whatever zeros() filled in.
local function minor(x, d)
    assert(#x.shape == 2)
    assert(d <= x.shape[1] and d <= x.shape[2])
    local m = zeros(x.shape)
    local rows, cols = m.shape[1], m.shape[2]
    -- fill diagonals.
    for i = 1, d do m[(i - 1) * cols + i] = 1 end
    -- copy values.
    for i = d + 1, rows do
        local base = (i - 1) * cols
        for j = d + 1, cols do m[base + j] = x[base + j] end
    end
    return m
end
-- euclidean length of a flat sequence of numbers.
local function norm(a)
    local total = 0
    for i = 1, #a do
        local v = a[i]
        total = total + v * v
    end
    return sqrt(total)
end
-- QR decomposition by successive householder reflections.
-- x is a 2D matrix stored flat in row-major order with its size in x.shape.
-- returns Q (the transpose of the accumulated reflections) and R (dot(q, x)).
-- NOTE(review): when rows == 1, iters is 0, the loop never runs and q is
-- still nil at the return statement.
local function householder(x)
local rows = x.shape[1]
local cols = x.shape[2]
local iters = min(rows - 1, cols)
local q = nil
local vec = zeros(rows)
local z = x
for k = 1, iters do
-- blank out the first k - 1 rows and columns (identity on the diagonal).
z = minor(z, k - 1)
-- extract a column.
for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end
local a = norm(vec)
-- negate the norm if the original diagonal is positive.
-- note that this tests x (the input), not z (the working matrix).
local ind = (k - 1) * cols + k
if x[ind] > 0 then a = -a end
vec[k] = vec[k] + a
-- normalize vec; this second "a" shadows the first one.
local a = norm(vec)
if a == 0 then a = 1 end -- FIXME: should probably just raise an error.
for i, v in ipairs(vec) do vec[i] = v / a end
-- construct the householder reflection: mat = I - 2 * vec * vec.T
local mat = zeros{rows, rows}
for i = 1, rows do
for j = 1, rows do
local ind = (i - 1) * rows + j
local diag = i == j and 1 or 0
mat[ind] = diag - 2 * vec[i] * vec[j]
end
end
--print(nn.pp(mat, "%9.3f"))
-- accumulate the reflections; the newest one is applied on the left.
if q == nil then q = mat else q = dot(mat, q) end
z = dot(mat, z)
end
return transpose(q), dot(q, x) -- Q, R
end
-- QR decomposition of a 2D matrix x (flat, row-major, sized by x.shape).
-- when x has more rows than columns, the full matrices from householder()
-- are cut down: q keeps its first cols columns, r keeps its first cols rows.
local function qr(x)
-- a wrapper for the householder method that will return reduced matrices.
assert(#x.shape == 2)
local q, r = householder(x)
local rows = x.shape[1]
local cols = x.shape[2]
-- nothing to trim unless there are more rows than columns.
if cols >= rows then return q, r end
-- trim q in-place.
-- q is read with a row stride of "rows" and rewritten with a stride of "cols".
-- the write index never passes the read index, so nothing unread is clobbered.
q.shape[2] = cols
local ind = 1
for i = 1, rows do
for j = 1, cols do
--ind = (i - 1) * cols + j
q[ind] = q[(i - 1) * rows + j]
ind = ind + 1
end
end
-- drop the leftover entries past the new size.
for i = rows * cols + 1, #q do q[i] = nil end
-- trim r in-place.
-- only the row count changes, so the kept rows are already in place.
r.shape[1] = r.shape[2]
for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end
return q, r
end
return qr

76
qr2.lua Normal file
View File

@ -0,0 +1,76 @@
local min = math.min
local sqrt = math.sqrt
local nn = require "nn"
local transpose = nn.transpose
local zeros = nn.zeros
-- QR decomposition by orthogonalizing the columns of a one after another.
-- a is a 2D matrix stored flat in row-major order with its size in a.shape.
-- q is worked on transposed, so each column of a is one contiguous run.
-- returns Q (rows x small) and R (small x cols), where small = min(rows, cols).
local function qr(a)
-- FIXME: if first column is exactly zero,
-- and cols > rows, Q @ R will not reconstruct the input.
-- this isn't too bad since an input like that is invalid anyway,
-- but i feel like it should be salvageable.
-- actually the scope of the problem is much larger than that.
-- an input like
--[=[
[[0, 0, 0, 0]
[1, 0, 1, 1]
[0, 0, 2, 2]
[0, 0, 3, 3]]
--]=]
-- will cause a lot of problems. for example, Q @ Q.T won't equal eye(4).
-- hmm. maybe we can detect this and reverse the matmul to identity if necessary?
assert(#a.shape == 2)
local rows = a.shape[1]
local cols = a.shape[2]
local small = min(rows, cols)
local q = transpose(a)
local r = zeros{small, cols}
for i = 1, cols do
-- i0..i1 is the run of q that holds column i of a.
local i0 = (i - 1) * rows + 1
local i1 = i * rows
-- subtract the projection onto each earlier column j.
for j = 1, min(i - 1, small) do
local j0 = (j - 1) * rows + 1
local j1 = j * rows
-- offset from an element of column i to the same element of column j.
local i_to_j = j0 - i0
local num = 0
local den = 0
for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
for k = j0, j1 do den = den + q[k] * q[k] end
--print(num, den)
if den == 0 then den = 1 end -- TODO: should probably just error.
local x = num / den
r[(j - 1) * cols + i] = x
for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
end
-- normalize column i; its length becomes the diagonal entry of r.
if i <= small then
local sum = 0
for k = i0, i1 do sum = sum + q[k] * q[k] end
local norm = sqrt(sum)
if norm == 0 then
--norm = 1
--q[i0 + i - 1] = 1 -- FIXME: not robust.
r[(i - 1) * cols + i] = 0
else
for k = i0, i1 do q[k] = q[k] / norm end
r[(i - 1) * cols + i] = norm
end
end
end
-- discard the columns past "small" before transposing back.
for k = small * rows + 1, #q do q[k] = nil end
q.shape[1] = small
return transpose(q), r
end
return qr

79
qr3.lua Normal file
View File

@ -0,0 +1,79 @@
-- gram-schmidt process
local nn = require "nn"
local dot = nn.dot
local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose
local zeros = nn.zeros
-- QR decomposition via the gram-schmidt process.
-- mat is a 2D matrix stored flat in row-major order with its size in mat.shape.
-- returns Q (the orthonormalized columns of mat) and R, computed as Q.T @ mat.
local function qr(mat)
    assert(#mat.shape == 2)
    -- work on the transpose so that each column of mat is a contiguous row.
    local v = transpose(mat)
    local rows = v.shape[1]
    local cols = v.shape[2]
    local u = zeros(v.shape)
    local w = {} -- the row currently being orthogonalized.
    local y = 1 -- index of the row that w was loaded from.

    local function load_row()
        local start = (y - 1) * cols
        for x = 1, cols do w[x] = v[start + x] end
    end

    -- normalize w, store it as row y of u, and advance to the next row.
    local function push_row()
        local sum = 0
        for _, value in ipairs(w) do sum = sum + value * value end
        local norm = sqrt(sum)
        local start = (y - 1) * cols
        if norm == 0 then
            -- an exactly-zero vector can't be normalized. dividing anyway
            -- would write NaNs that spread to every later row,
            -- so store zeros instead.
            for x = 1, cols do u[start + x] = 0 end
        else
            for x, value in ipairs(w) do u[start + x] = value / norm end
        end
        y = y + 1
    end

    load_row()
    push_row()

    local sums = {}
    for i = 2, rows do
        load_row()
        for x = 1, cols do sums[x] = 0 end
        -- accumulate the projections of w onto every finished row of u.
        for j = 1, i - 1 do
            local start = (j - 1) * cols
            local dotted = 0
            for x, value in ipairs(w) do
                dotted = dotted + value * u[start + x]
            end
            for x, value in ipairs(sums) do
                sums[x] = value + dotted * u[start + x]
            end
        end
        -- remove them all at once, then normalize what's left.
        for x, value in ipairs(w) do w[x] = value - sums[x] end
        push_row()
    end

    return transpose(u), dot(u, mat)
end
return qr

87
qr_test.lua Normal file
View File

@ -0,0 +1,87 @@
local globalize = require "strict"
local nn = require "nn"
local qr = require "qr"
local qr2 = require "qr2"
local qr3 = require "qr3"
-- pick the matrix to decompose.
-- only the final "else" branch runs; flip a "false" to try another input.
local A
if false then
-- 5x3 example.
A = {
12, -51, 4,
6, 167, -68,
-4, 24, -41,
-1, 1, 0,
2, 0, 3,
}
A = nn.reshape(A, 5, 3)
elseif false then
-- NOTE(review): written out as 3 lines of 5 values but reshaped to 5x3.
-- confirm which layout was intended.
A = {
0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14
}
A = nn.reshape(A, 5, 3)
elseif false then
-- this first table is overwritten immediately by the one below it.
A = {
1, 2, 0,
2, 3, 1,
3, 4, 0,
4, 5, 1,
5, 6, 0,
}
-- 5x3 input whose middle column is entirely zero.
A = {
1, 0, 0,
2, 0, 1,
3, 0, 0,
4, 0, 1,
5, 0, 0,
}
A = nn.reshape(A, 5, 3)
--A = nn.transpose(A)
else
-- 3x3 example.
A = {
1, 2, -3,
2, 4, 5,
-3, 5, 6,
}
A = nn.reshape(A, 3, 3)
end
-- print a titled matrix followed by a blank line.
local function show(title, m)
    print(title)
    print(nn.pp(m, "%9.4f"))
    print()
end

show("A", A)

-- the householder implementation serves as the reference; only its Q is shown.
local Q_ref = qr(A)
show("Q (reference)", Q_ref)

-- the implementation under test.
local Q, R = qr3(A)
show("Q", Q)
show("R", R)
show("Q @ R", nn.dot(Q, R))

--print("Q @ Q.T = I")
--print(nn.pp(nn.dot(Q, nn.transpose(Q)), "%9.4f"))
--print()
--A = nn.reshape(A, 5, 3)
--Q, R = qr(A)

15
rescale.lua Normal file
View File

@ -0,0 +1,15 @@
-- read params-ars4.txt (one number per line), multiply every value by 100,
-- and print the results one per line.
local f = assert(io.open("params-ars4.txt", "r"))
local data = f:read("*a")
f:close()

local values = {}
for line in data:gmatch("[^\r\n]+") do
    values[#values + 1] = tonumber(line)
end

for i = 1, #values do
    print(values[i] * 100)
end

183
running.lua Normal file
View File

@ -0,0 +1,183 @@
local huge = math.huge
local ipairs = ipairs
local open = io.open
local sqrt = math.sqrt
local nn = require("nn")
local Base = require("Base")
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
-- http://www.johndcook.com/blog/standard_deviation/
local Stats = Base:extend()
local Normalizer = Base:extend()
-- start with no samples; _M and _S are the accumulators updated by :push().
function Stats:init(shape)
    self._n = 0
    self._M, self._S = nn.zeros(shape), nn.zeros(shape)
end
-- fold one sample x into the running statistics.
-- the first sample is copied straight into _M; later samples update
-- _M and _S incrementally, element by element.
function Stats:push(x)
    assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")
    local prev = self._n
    local count = prev + 1
    self._n = count
    if count == 1 then
        nn.copy(x, self._M)
        return
    end
    local M, S = self._M, self._S
    for i, mean in ipairs(M) do
        local delta = x[i] - mean
        M[i] = mean + delta / count
        S[i] = S[i] + delta * delta * prev / count
    end
end
-- per-element variance as a new table.
-- with exactly one sample this returns the square of that sample instead.
function Stats:var()
    local out = {}
    if self._n == 1 then
        for i, mean in ipairs(self._M) do out[i] = mean * mean end
        return out
    end
    local denom = self._n - 1
    for i, s in ipairs(self._S) do out[i] = s / denom end
    return out
end
-- per-element standard deviation: the square root of :var().
function Stats:dev()
    local out = self:var()
    for i = 1, #out do out[i] = sqrt(out[i]) end
    return out
end
-- demean and destd both default to true when omitted (nil).
-- mean starts as zeros() and std as all ones.
function Normalizer:init(shape, demean, destd)
    self.shape = shape
    self.demean = demean == nil or demean
    self.destd = destd == nil or destd
    self.rs = Stats(shape)
    self.mean = nn.zeros(shape)

    local std = nn.zeros(shape)
    for i = 1, #std do std[i] = 1 end
    self.std = std
end
-- returns a normalized copy of x using the currently stored mean and std.
-- the mean is subtracted when self.demean is set, and the result is divided
-- by (std + 1e-8) when self.destd is set.
function Normalizer:process(x)
    local out = nn.copy(x)
    local mean, std = self.mean, self.std
    local demean, destd = self.demean, self.destd
    for i = 1, #out do
        local v = out[i]
        if demean then v = v - mean[i] end
        if destd then v = v / (std[i] + 1e-8) end
        out[i] = v
    end
    return out
end
-- refresh self.mean and self.std from the running statistics.
function Normalizer:update()
    nn.copy(self.rs._M, self.mean) -- FIXME: HACK
    nn.copy(self.rs:dev(), self.std)
    -- any std below 1e-7 becomes +inf so that dividing by it yields zero
    -- rather than blowing up; elements with no variance normalize to zero.
    local std = self.std
    for i = 1, #std do
        if std[i] < 1e-7 then std[i] = huge end
    end
end
-- record x in the running statistics and return it normalized.
-- mean and std are refreshed first unless update is explicitly false.
function Normalizer:push(x, update)
    self.rs:push(x)
    if update ~= false then self:update() end
    return self:process(x)
end
-- the default filename embeds the total element count, zero-padded to 7 digits.
function Normalizer:default_filename()
    local size = nn.prod(self.shape)
    return ('stats%07i.txt'):format(size)
end
-- write the running statistics to a text file, one number per line:
-- the sample count, then every element of _M, then every element of _S.
-- raises an error if the file can't be opened for writing.
function Normalizer:save(fn)
    local path = fn or self:default_filename()
    local f = open(path, 'w')
    if f == nil then error("Failed to save stats to file "..path) end

    local rs = self.rs
    f:write(rs._n, '\n')
    for _, v in ipairs(rs._M) do f:write(v, '\n') end
    for _, v in ipairs(rs._S) do f:write(v, '\n') end
    f:close()
end
-- read running statistics back from a file written by :save().
-- the expected layout is one number per line: the sample count,
-- then prod(shape) values for _M, then prod(shape) values for _S.
-- raises an error if the file can't be opened, if a line isn't a number,
-- or if the file doesn't hold exactly that many lines.
-- nothing in self is modified unless the whole file is valid.
function Normalizer:load(fn)
    local fn = fn or self:default_filename()
    local f = open(fn, 'r')
    if f == nil then error("Failed to load stats from file "..fn) end

    local values = {}
    local count = 0
    for line in f:lines() do
        count = count + 1
        local n = tonumber(line)
        if n == nil then
            -- don't leak the file handle when bailing out.
            f:close()
            error("Failed reading line "..tostring(count).." of file "..fn)
        end
        values[count] = n
    end
    f:close()

    local size = nn.prod(self.shape)
    if count ~= 1 + 2 * size then
        error("Unexpected number of lines in file "..fn)
    end

    self.rs._n = values[1]
    for i = 1, size do
        self.rs._M[i] = values[1 + i]
        self.rs._S[i] = values[1 + size + i]
    end

    self:update()
end
--[[
-- basic tests
local dims = 20
local rs = Stats(dims)
local x = nn.zeros(dims)
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
print(nn.pp(rs:dev()))
for j = 1, 10000 do
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
end
print(nn.pp(rs:dev()))
--
local ms = Normalizer(dims)
local exp = math.exp
local y
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
print(nn.pp(y))
for j = 1, 10000 do
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
end
print(nn.pp(y))
print("mean:")
print(nn.pp(ms.mean))
print("stdev:")
print(nn.pp(ms.std))
--]]
return {
Stats = Stats,
Normalizer = Normalizer,
}

59
seen_tiles.lua Normal file
View File

@ -0,0 +1,59 @@
return {
[0] = true,
[16] = true,
[17] = true,
[18] = true,
[19] = true,
[20] = true,
[21] = true,
[22] = true,
[23] = true,
[24] = true,
[25] = true,
[26] = true,
[27] = true,
[28] = true,
[29] = true,
[30] = true,
[31] = true,
[32] = true,
[33] = true,
[34] = true,
[35] = true,
[36] = true,
[37] = true,
[38] = true,
[81] = true,
[82] = true,
[84] = true,
[85] = true,
[86] = true,
[87] = true,
[88] = true,
[89] = true,
[90] = true,
[91] = true,
[92] = true,
[93] = true,
[94] = true,
[95] = true,
[96] = true,
[97] = true,
[98] = true,
[99] = true,
[100] = true,
[101] = true,
[102] = true,
[103] = true,
[104] = true,
[105] = true,
[107] = true,
[108] = true,
[137] = true,
[192] = true,
[193] = true,
[194] = true,
[195] = true,
[196] = true,
[197] = true,
}

76
serialize.lua Normal file
View File

@ -0,0 +1,76 @@
-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
local type = type
local extra = require "extra"
local opairs = extra.opairs
local tostring = tostring
local open = io.open
local strfmt = string.format
local strrep = string.rep
-- strip a leading UTF-8 byte order mark (EF BB BF) from s, if there is one.
local function kill_bom(s)
    if s:sub(1, 3) == '\239\187\191' then
        return s:sub(4)
    end
    return s
end
-- words that can't be used as bare field names in a table constructor.
local lua_keywords = {
    ['and'] = true, ['break'] = true, ['do'] = true, ['else'] = true,
    ['elseif'] = true, ['end'] = true, ['false'] = true, ['for'] = true,
    ['function'] = true, ['goto'] = true, ['if'] = true, ['in'] = true,
    ['local'] = true, ['nil'] = true, ['not'] = true, ['or'] = true,
    ['repeat'] = true, ['return'] = true, ['then'] = true, ['true'] = true,
    ['until'] = true, ['while'] = true,
}

-- returns the lua source text for v, plus a "force" flag.
-- strings come back quoted with %q; anything else goes through tostring.
-- force is true for any string that isn't safe to write as a bare key:
-- one that isn't a plain identifier (leading digit, spaces, hyphens, empty)
-- or one that is a reserved word. it is always false for non-strings.
local function sanitize(v)
    if type(v) ~= 'string' then
        return tostring(v), false
    end
    local plain = v:match('^[%a_][%w_]*$') ~= nil and not lua_keywords[v]
    return strfmt('%q', v), not plain
end
-- write value as lua source text by calling writer with string pieces.
-- tables are written recursively, one "key = value," per line,
-- indented with one tab per nesting level (level defaults to 1).
-- anything that isn't a table is written via sanitize: strings are quoted
-- and everything else goes through tostring.
local function _serialize(value, writer, level)
    level = level or 1
    if type(value) ~= 'table' then
        -- the extra parentheses keep only sanitize's first return value.
        writer((sanitize(value)))
        return
    end

    local indent = strrep('\t', level)
    writer('{\n')
    for key, item in opairs(value) do
        local sane, force = sanitize(key)
        local keyval
        -- a key is only written bare when it's a string that %q left
        -- untouched and sanitize didn't flag. the type check comes first
        -- so that non-string keys (e.g. booleans) are never concatenated.
        if type(key) == 'string' and sane == '"'..key..'"' and not force then
            keyval = key
        else
            keyval = '['..sane..']'
        end
        writer(indent..keyval..' = ')
        _serialize(item, writer, level + 1)
        writer(',\n')
    end
    writer(strrep('\t', level - 1)..'}')
end
-- compile script (minus any byte order mark) and return whatever it returns.
-- when the script fails to compile, a warning and the compiler's message
-- are printed and nil is returned.
-- SECURITY: the script is executed as lua code with full access to the
-- environment, so this must only ever be given trusted input.
local function _deserialize(script)
    local f, err = loadstring(kill_bom(script))
    if f ~= nil then
        return f()
    end
    print('WARNING: no function to deserialize with')
    -- say why, so that a corrupt file can actually be diagnosed.
    if err ~= nil then print(err) end
    return nil
end
-- write value to the file at path as a lua chunk of the form "return <value>".
-- silently does nothing if the file can't be opened.
local function serialize(path, value)
    local file = open(path, 'w')
    if not file then return end

    local function emit(...)
        file:write(...)
    end

    file:write("return ")
    _serialize(value, emit)
    file:write("\n")
    file:close()
end
-- load the value stored in the file at path by serialize().
-- returns nothing if the file can't be opened.
local function deserialize(path)
    local file = open(path, 'r')
    if not file then return end
    local script = file:read('*a')
    -- close the file before running the script, so that
    -- an error raised by the script can't leave the handle open.
    file:close()
    local value = _deserialize(script)
    return value
end
return {
serialize = serialize,
deserialize = deserialize,
}

7
sign_test.lua Normal file
View File

@ -0,0 +1,7 @@
-- print util.sign for negative, zero, and positive inputs,
-- first as whole numbers and then as fractions.
local util = require("util")
for _, x in ipairs{-1, 0, 1, -0.1, 0.0, 0.1} do
    print(util.sign(x))
end

217
smb.lua
View File

@ -1,12 +1,16 @@
-- disassembly used for reference:
-- https://gist.githubusercontent.com/1wErt3r/4048722/raw/59e88c0028a58c6d7b9156749230ccac647bc7d4/SMBDIS.ASM
local band = bit.band
local floor = math.floor
local emu = emu
local gui = gui
local util = require("util")
local band = bit.band
local clamp = util.clamp
local empty = util.empty
local emu = emu
local floor = math.floor
local gui = gui
local insert = table.insert
local R = memory.readbyteunsigned
local W = memory.writebyte
local function S(addr) return util.signbyte(R(addr)) end
@ -76,13 +80,84 @@ local rotation_offsets = { -- FIXME: not all of these are pixel-perfect.
-8, -38,
}
local tile_embedding = {
-- handmade trinary encoding.
-- we have 57 valid tile types and 27 permutations to work with.
[0x00] = { 0, 0, 0}, -- air
[0x10] = { 1, -1, 1}, -- vertical pipe (top left) (enterable)
[0x11] = { 1, -1, 1}, -- vertical pipe (top right) (enterable)
[0x12] = { 0, -1, 1}, -- vertical pipe (top left)
[0x13] = { 0, -1, 1}, -- vertical pipe (top right)
[0x14] = { 0, -1, -1}, -- vertical pipe (left)
[0x15] = { 0, -1, -1}, -- vertical pipe (right)
[0x16] = {-1, -1, -1}, --
[0x17] = {-1, -1, -1}, --
[0x18] = {-1, -1, -1}, --
[0x19] = {-1, -1, -1}, --
[0x1A] = {-1, -1, -1}, --
[0x1B] = {-1, -1, -1}, --
[0x1C] = { 0, -1, 0}, -- horizontal pipe (top left)
[0x1D] = { 0, -1, 0}, -- horizontal pipe (top)
[0x1E] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (top)
[0x1F] = { 0, -1, 0}, -- horizontal pipe (bottom left)
[0x20] = { 0, -1, 0}, -- horizontal pipe (bottom)
[0x21] = { 0, -1, 0}, -- horizontal pipe joining vertical pipe (bottom)
[0x22] = { 0, 0, 0}, --
[0x23] = { 0, 0, 0}, -- block being hit (either breakable or ?)
[0x24] = { 0, 0, 0}, --
[0x25] = { 0, 0, 1}, -- flagpole
[0x26] = { 0, 0, 0}, --
[0x51] = { 1, 1, 0}, -- breakable brick block
[0x52] = { 1, 1, 0}, -- breakable brick block (again?)
[0x54] = { 0, 1, 0}, -- regular ground
[0x55] = { 0, 0, 0}, --
[0x56] = { 0, 0, 0}, --
[0x57] = {-1, 1, 1}, -- star brick block
[0x58] = { 1, 1, -1}, -- coin brick block (many coins)
[0x59] = { 0, 0, 0}, --
[0x5A] = { 0, 0, 0}, --
[0x5B] = { 0, 0, 0}, --
[0x5C] = { 0, 0, 0}, --
[0x5D] = { 1, 1, -1}, -- coin brick block (many coins) (again?)
[0x5E] = { 0, 0, 0}, --
[0x5F] = { 0, 0, 0}, --
[0x60] = {-1, 0, 0}, -- invisible 1-up block
[0x61] = { 0, 1, -1}, -- chocolate block (usually used for stairs)
[0x62] = { 0, 0, 0}, --
[0x63] = { 0, 0, 0}, --
[0x64] = { 0, 0, 0}, --
[0x65] = { 0, 0, 0}, --
[0x66] = { 0, 0, 0}, --
[0x67] = { 0, 0, 0}, --
[0x68] = { 0, 0, 0}, --
[0x69] = { 0, 0, 0}, --
[0x6B] = { 0, 0, 0}, --
[0x6C] = { 0, 0, 0}, --
[0x89] = { 0, 0, 0}, --
[0xC0] = {-1, 1, -1}, -- coin ? block
[0xC1] = {-1, 1, 0}, -- mushroom ? block
[0xC2] = { 0, 0, -1}, -- coin
[0xC3] = { 0, 0, 0}, --
[0xC4] = { 0, 1, 1}, -- hit block
[0xC5] = { 0, 0, 0}, --
}
-- TODO: reinterface to one "input" array visible to main.lua.
local sprite_input = {}
local tile_input = {}
local extra_input = {}
local new_input = {}
local overlay = false
-- append the 3-value embedding of tile type t to new_input.
-- raises an error naming the tile when it has no entry in tile_embedding.
local function embed_tile(t)
    local out = new_input
    local embedded = tile_embedding[t]
    if embedded == nil then
        -- fail with a message that says which tile is missing,
        -- instead of an anonymous "attempt to index a nil value".
        error("no embedding for tile type "..tostring(t), 2)
    end
    insert(out, embedded[1])
    insert(out, embedded[2])
    insert(out, embedded[3])
end
-- combine the three bytes at 0x7F8..0x7FA as the decimal digits of one number.
local function get_timer()
    local hundreds, tens, ones = R(0x7F8), R(0x7F9), R(0x7FA)
    return hundreds * 100 + tens * 10 + ones
end
@ -130,6 +205,7 @@ end
local function mark_tile(x, y, t)
tile_input[#tile_input+1] = tile_lut[t]
embed_tile(t)
if t == 0 then return end
if overlay then
gui.box(x-8, y-8, x+8, y+8)
@ -289,7 +365,8 @@ local function handle_tiles()
extra_input[#extra_input+1] = tile_scroll_remainder
-- for y = 0, 12 do
-- afaik the bottom row is always a copy of the second to bottom,
-- and the top is always air, so drop those from the inputs:
-- and the top is always air (except underground!),
-- so drop those from the inputs:
for y = 1, 11 do
for x = 0, 16 do
local col = (x + tile_scroll) % 32
@ -306,6 +383,117 @@ local function handle_tiles()
end
end
-- power-cycle the emulator and start the given world and level.
-- passing 0 for world picks one of 1..8 at random;
-- passing 0 for level picks one of 1..4 at random.
-- runs the emulator up to frame 60 before returning.
local function load_level(world, level)
    -- math.random is spelled out because this file has no local "random".
    if world == 0 then world = math.random(1, 8) end
    if level == 0 then level = math.random(1, 4) end
    emu.poweron()
    local jp_mash = {
        up = false, down = false, left = false, right = false,
        A = false, B = false, select = false, start = false,
    }
    while emu.framecount() < 60 do
        if emu.framecount() == 32 then
            -- poke the chosen world, level, and area into memory.
            local area = area_lut[world * 10 + level]
            W(0x75F, world - 1)
            W(0x75C, level - 1)
            W(0x760, area)
        end
        if emu.framecount() == 42 then
            W(0x7A0, 0) -- world screen timer (reduces startup time)
        end
        -- press start on every odd frame.
        jp_mash['start'] = emu.framecount() % 2 == 1
        joypad.write(1, jp_mash)
        emu.frameadvance()
    end
end
-- new stuff.
-- look up the tile at pixel position (x, y).
-- returns the tile type followed by the tile's column and row.
-- rows outside 0..12 are clamped into that range and reported as air (0).
local function tile_from_xy(x, y)
    local scroll = R(0x73F)
    local tile_scroll = floor(scroll / 16) + R(0x71A) * 16
    local tile_scroll_remainder = scroll % 16
    local tile_x = floor((x + tile_scroll_remainder) / 16)
    local tile_y = floor(y / 16)

    if tile_y < 0 then return 0, tile_x, 0 end
    if tile_y > 12 then return 0, tile_x, 12 end

    -- tile memory is split in two pages of 16 columns each.
    local col = (tile_x + tile_scroll) % 32
    local addr = col < 16 and 0x500 or 0x5D0
    local tile_t = R(addr + tile_y * 16 + (col % 16))
    return tile_t, tile_x, tile_y
end
-- rebuild new_input from scratch: mario's normalized position followed by
-- the embedding of one tile near him. also draws a box around mario and
-- labels every non-air tile on screen with its type in hexadecimal.
local function new_stuff()
    -- obviously very work in progress.
    empty(new_input)

    local mario_x, mario_y = getxy(0, 0x86, 0xCE, 0x6D, 0xB5)
    -- normalized mario x
    -- normalized mario y
    insert(new_input, (mario_x - 112) / 64)
    insert(new_input, (mario_y - 160) / 96)

    -- type of tile we're standing on
    -- type of tile we're occupying
    gui.box(mario_x, mario_y, mario_x + 16, mario_y + 32)
    local mario_tile_t, mario_tile_x, mario_tile_y =
        tile_from_xy(mario_x + 8, mario_y - 8)
    --[[
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    local sx = mario_tile_x * 16 + 8 - tile_scroll_remainder
    local sy = mario_tile_y * 16 + 40
    gui.box(sx-8, sy-8, sx+8, sy+8)
    gui.text(sx-5, sy-3, ("%02X"):format(mario_tile_t), '#FFFFFF', '#00000000')
    --]]
    -- embed_tile takes only the tile type and always appends to new_input.
    -- passing new_input as its first argument made it look up a table.
    embed_tile(mario_tile_t)

    -- type of tile to the right, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the right, excluding space (large eyeheight)
    -- type of tile to the left, excluding space (small eyeheight)
    -- how tall non-space extends upward from that tile
    -- how tall non-space extends downward from that tile
    -- type of tile to the left, excluding space (large eyeheight)
    -- type of enemy (nearest down-right from mario's upper left)
    -- normalized enemy x
    -- normalized enemy y

    -- VISUALIZE:
    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16
    for y = 1, 11 do
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local addr = col < 16 and 0x500 or 0x5D0
            local t = R(addr + y * 16 + (col % 16))
            local sx = x * 16 + 8 - tile_scroll_remainder
            local sy = y * 16 + 40
            if t ~= 0 then
                gui.box(sx-8, sy-8, sx+8, sy+8)
                gui.text(sx-5, sy-3, ("%02X"):format(t), '#FFFFFF', '#00000000')
            end
        end
    end
end
return {
-- TODO: don't expose these; provide interfaces for everything needed.
R=R,
@ -315,6 +503,7 @@ overlay=overlay,
valid_tiles=valid_tiles,
area_lut=area_lut,
embed_tile=embed_tile,
sprite_input=sprite_input,
tile_input=tile_input,
@ -323,16 +512,24 @@ extra_input=extra_input,
get_timer=get_timer,
get_score=get_score,
set_timer=set_timer,
mark_sprite=mark_sprite,
mark_tile=mark_tile,
get_state=get_state,
getxy=getxy,
paused=paused,
get_state=get_state,
advance=advance,
mark_sprite=mark_sprite,
mark_tile=mark_tile,
handle_enemies=handle_enemies,
handle_fireballs=handle_fireballs,
handle_blocks=handle_blocks,
handle_hammers=handle_hammers,
handle_misc=handle_misc,
handle_tiles=handle_tiles,
advance=advance,
load_level=load_level,
new_stuff=new_stuff,
new_input=new_input,
}

137
snes.lua
View File

@ -3,7 +3,6 @@
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
-- not to be confused with the Super Nintendo Entertainment System.
local abs = math.abs
local assert = assert
local exp = math.exp
local floor = math.floor
@ -30,9 +29,12 @@ local normalize_sums = util.normalize_sums
local pdf = util.pdf
local weighted_mann_whitney = util.weighted_mann_whitney
local xnes = require "xnes"
local make_utility = xnes.make_utility
local Snes = Base:extend()
function Snes:init(dims, popsize, base_rate, sigma, antithetic)
function Snes:init(dims, popsize, base_rate, sigma, antithetic, adaptive)
-- heuristic borrowed from CMA-ES:
self.dims = dims
self.popsize = popsize or 4 + (3 * floor(log(dims)))
@ -42,9 +44,12 @@ function Snes:init(dims, popsize, base_rate, sigma, antithetic)
self.covar_rate = base_rate
self.sigma = sigma or 1
self.antithetic = antithetic and true or false
self.adaptive = adaptive == nil and true or adaptive
if self.antithetic then self.popsize = self.popsize * 2 end
self.utility = make_utility(self.popsize)
self.rate_init = self.sigma_rate
self.mean = zeros{dims}
@ -148,60 +153,84 @@ function Snes:ask_mix(start_anew)
-- perform importance mixing.
local mean_old = self.mean
local mean_old = self.mean_old or self.mean
local mean_new = self.mean
local std_old = self.std_old or self.std
local std_new = self.std
self.new_asked = {}
self.new_noise = {}
local marked = {}
for p=1, min(#self.old_asked, self.popsize) do
local a = self.old_asked[p]
-- TODO: cache probs?
local function compute_probabilities(a)
local prob_new = 0
local prob_old = 0
for i, v in ipairs(a) do
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
end
return prob_new, prob_old
end
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
if uniform() < accept then
--print(("accepted old sample %i with probability %f"):format(p, accept))
else
-- insert in reverse as not to screw up
-- the indices when removing later.
insert(marked, 1, p)
local all_asked, all_noise, all_score = {}, {}, {}
for p=1, #self.old_asked do
do
local pp = floor(uniform() * #self.old_asked) + 1
local a = self.old_asked[pp]
local prob_new, prob_old = compute_probabilities(a)
local accept = min(prob_new / prob_old * (1 - self.min_refresh), 1)
if uniform() < accept then
--print(("accepted old sample %i with probability %f"):format(pp, accept))
insert(all_asked, a)
insert(all_noise, self.old_noise[pp])
insert(all_score, self.old_score[pp])
end
end
end
for _, p in ipairs(marked) do
remove(self.old_asked, p)
remove(self.old_noise, p)
remove(self.old_score, p)
do
local a, n = {}, {}
for i=1, self.dims do n[i] = normal() end
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
local prob_new, prob_old = compute_probabilities(a)
local accept = max(1 - prob_old / prob_new, self.min_refresh)
if uniform() < accept then
--print(("accepted new sample %i with probability %f"):format(#all_asked, accept))
insert(all_asked, a)
insert(all_noise, n)
insert(all_score, false)
end
end
-- TODO: early stopping, making sure it doesn't affect performance.
end
while #self.old_asked + #self.new_asked < self.popsize do
local a = {}
local n = {}
while #all_asked > self.popsize do
local pp = floor(uniform() * #all_asked) + 1
--print(("removing sample %i to fit popsize"):format(pp))
remove(all_asked, pp)
remove(all_noise, pp)
remove(all_score, pp)
end
while #all_asked < self.popsize do
local a, n = {}, {}
for i=1, self.dims do n[i] = normal() end
for i, v in ipairs(n) do a[i] = mean_new[i] + std_new[i] * v end
--print(("unconditionally added new sample %i"):format(#all_asked))
insert(all_asked, a)
insert(all_noise, n)
insert(all_score, false)
end
-- can't cache here!
local prob_new = 0
local prob_old = 0
for i, v in ipairs(a) do
prob_new = prob_new + pdf(v, mean_new[i], std_new[i])
prob_old = prob_old + pdf(v, mean_old[i], std_old[i])
end
local accept = max(1 - prob_old / prob_new, self.min_refresh)
if uniform() < accept then
-- split all_ tables back into old_ and new_.
self.old_asked, self.old_noise, self.old_score = {}, {}, {}
self.new_asked, self.new_noise = {}, {}
for i, score in ipairs(all_score) do
local a, n = all_asked[i], all_noise[i]
if score ~= false then
insert(self.old_asked, a)
insert(self.old_noise, n)
insert(self.old_score, score)
else
insert(self.new_asked, a)
insert(self.new_noise, n)
--print(("accepted new sample %i with probability %f"):format(0, accept))
end
end
@ -211,15 +240,15 @@ end
function Snes:tell(scored)
self.evals = self.evals + #scored
local asked = self.asked
local noise = self.noise
local asked = self.mixing and self.new_asked or self.asked
local noise = self.mixing and self.new_noise or self.noise
if self.mixing then
-- note: modifies, in-place, externally exposed tables.
for i, v in ipairs(asked) do insert(self.old_asked, v) end
for i, v in ipairs(noise) do insert(self.old_noise, v) end
for i, v in ipairs(scored) do insert(self.old_score, v) end
asked = self.old_asked
noise = self.old_noise
-- note that these modify tables referenced externally in-place.
for i, v in ipairs(self.new_asked) do insert(asked, v) end
for i, v in ipairs(self.new_noise) do insert(noise, v) end
for i, v in ipairs(scored) do insert(self.old_score, v) end
scored = self.old_score
end
assert(asked and noise, ":tell() called before :ask()")
@ -232,8 +261,9 @@ function Snes:tell(scored)
local g_mean = zeros{self.dims}
local g_std = zeros{self.dims}
--[[
local utilize = true
local utility
local utility = self.utility
if utilize then
utility = {}
@ -243,17 +273,18 @@ function Snes:tell(scored)
else
utility = normalize_sums(scored, {})
end
--]]
for p=1, self.popsize do
local noise_p = noise[p]
local noise_p = noise[arg[p]]
for i, v in ipairs(g_mean) do
g_mean[i] = v + utility[p] * noise_p[i]
g_mean[i] = v + self.utility[p] * noise_p[i]
end
for i, v in ipairs(g_std) do
local n = noise_p[i]
g_std[i] = v + utility[p] * (n * n - 1)
g_std[i] = v + self.utility[p] * (n * n - 1)
end
end
@ -262,7 +293,9 @@ function Snes:tell(scored)
step[i] = self.std[i] * v
end
self.mean_old = {}
for i, v in ipairs(self.mean) do
self.mean_old[i] = v
self.mean[i] = v + self.param_rate * step[i]
end
@ -274,7 +307,7 @@ function Snes:tell(scored)
otherwise[i] = v * exp(self.sigma_rate * 0.75 * g_std[i])
end
self:adapt(asked, otherwise, utility)
if self.adaptive then self:adapt(asked, otherwise, self.utility) end
return step
end
@ -292,15 +325,15 @@ function Snes:adapt(asked, otherwise, qualities)
weights[p] = prob_big / prob_now
end
local p = weighted_mann_whitney(qualities, qualities, nil, weights)
--print("p:", p)
local u, p = weighted_mann_whitney(qualities, qualities, nil, weights)
--print(("u, p: %6.3f, %6.3f"):format(u, p))
if p < 0.5 - 1 / (3 * (self.dims + 1)) then
self.sigma_rate = 0.9 * self.sigma_rate + 0.1 * self.rate_init
print("learning rate -:", self.sigma_rate)
--print("learning rate -:", self.sigma_rate)
else
self.sigma_rate = min(1.1 * self.sigma_rate, 1)
print("learning rate +:", self.sigma_rate)
--print("learning rate +:", self.sigma_rate)
end
end

View File

@ -14,6 +14,7 @@ local random = math.random
local select = select
local sort = table.sort
local sqrt = math.sqrt
local type = type
local function sign(x)
-- remember that 0 is truthy in Lua.
@ -83,6 +84,28 @@ local function calc_mean_dev(x)
return mean, sqrt(dev)
end
-- returns the mean of x and a nearly-unbiased estimate of
-- its standard deviation. x must hold at least two values.
local function calc_mean_dev_unbiased(x)
    -- NOTE: this uses an approximation; there is still a little bias.
    local n = #x
    assert(n > 1)

    local mean = 0
    for i = 1, n do mean = mean + x[i] / n end

    -- via Gurland, John; Tripathi, Ram C. (1971):
    -- A Simple Approximation for Unbiased Estimation of the Standard Deviation
    local divisor = n - 1.5 + 1 / (8 * (n - 1))

    local total = 0
    for i = 1, n do
        local delta = x[i] - mean
        total = total + delta * delta / divisor
    end

    return mean, sqrt(total)
end
local function normalize(x, out)
out = out or x
local mean, dev = calc_mean_dev(x)
@ -115,7 +138,7 @@ local function unperturbed_rank(rewards, unperturbed_reward)
local nth_place = 1
for i, v in ipairs(rewards) do
if v > unperturbed_reward then
nth_place = nth_place + 1
nth_place = nth_place + 1
end
end
return nth_place
@ -186,6 +209,11 @@ local function cdf(x)
-- i don't remember where this is from.
local sign = x >= 0 and 1 or -1
return 0.5 * (1 + sign * sqrt(1 - exp(-2 / pi * x * x)))
-- more accurate (via GELU paper, might be lifted from elsewhere):
--local const = sqrt(2 / pi)
--return 0.5 * (1 + tanh(const * (1 + 0.044715 * x * x) * x))
--return 0.5 * (1 + tanh(0.7978845608 * (1 + 0.044715 * x * x) * x))
end
local function fitness_shaping(rewards)
@ -246,7 +274,31 @@ local function weighted_mann_whitney(s0, s1, w0, w1)
local std = sqrt(mean * (w0_sum + w1_sum + 1) / 6)
local p = cdf((U - mean) / std)
if s0_sum > s1_sum then return 1 - p else return p end
local u = U / (w0_sum * w1_sum)
if s0_sum > s1_sum then return u, 1 - p else return u, p end
end
local function expect_cossim(n)
    -- returns gamma(n / 2) / gamma((n + 1) / 2) / sqrt(pi) for positive integers.
    -- this is the expected absolute cosine similarity between
    -- two standard normally-distributed random vectors both of size n.
    assert(n > 0)

    -- every temporary below is declared local; they used to leak as globals.

    -- abs(error) < 1e-8
    if n >= 128000 then
        return 1 / sqrt(pi / 2 * n + 1)
    elseif n >= 80 then
        local poly = (2.4674010 * n - 2.4673232) * n + 1.2274046
        return 1 / sqrt(sqrt(poly))
    end

    -- fall-through when it's faster just to compute iteratively.
    local even = n % 2 == 0
    local res = even and 2.0 or 1.0
    for i = even and 2 or 1, n - 1, 2 do
        res = res * (i / (i + 1))
    end
    return even and res / pi or res
end
return {
@ -260,6 +312,7 @@ return {
softchoice=softchoice,
empty=empty,
calc_mean_dev=calc_mean_dev,
calc_mean_dev_unbiased=calc_mean_dev_unbiased,
normalize=normalize,
normalize_wrt=normalize_wrt,
normalize_sums=normalize_sums,
@ -276,4 +329,5 @@ return {
pdf=pdf,
cdf=cdf,
weighted_mann_whitney=weighted_mann_whitney,
expect_cossim=expect_cossim,
}

View File

@ -16,10 +16,13 @@ local unpack = table.unpack or unpack
local Base = require "Base"
local nn = require "nn"
local dot = nn.dot
local dot_mv = nn.dot_mv
local normal = nn.normal
local zeros = nn.zeros
local expm = require "expm"
local util = require "util"
local argsort = util.argsort
@ -35,22 +38,6 @@ local function make_utility(popsize, out)
return utility
end
local function make_covars(dims, sigma, out)
local covars = out or zeros{dims, dims}
local c = sigma / dims
-- simplified form of the determinant of the matrix we're going to create:
local det = pow(1 - c, dims - 1) * (c * (dims - 1) + 1)
-- multiplying by this constant makes the determinant 1:
local m = pow(1 / det, 1 / dims)
local filler = c * m
for i=1, #covars do covars[i] = filler end
-- diagonals:
for i=1, dims do covars[i + dims * (i - 1)] = m end
return covars
end
function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
-- heuristic borrowed from CMA-ES:
self.dims = dims
@ -67,9 +54,13 @@ function Xnes:init(dims, popsize, base_rate, sigma, antithetic)
self.utility = make_utility(self.popsize)
self.mean = zeros{dims}
-- note: this is technically the co-standard-deviation.
-- you can imagine the "s" standing for "sqrt" if you like.
self.covars = make_covars(self.dims, self.sigma, self.covars)
self.covars = zeros{dims, dims}
for i=1, dims do
local ind = (i - 1) * dims + i -- diagonal
self.covars[ind] = 1
end
self.evals = 0
end
function Xnes:params(new_mean)
@ -153,6 +144,8 @@ function Xnes:tell(scored, noise)
local noise = noise or self.noise
assert(noise, "missing noise argument")
self.evals = self.evals + #scored
local arg = argsort(scored, function(a, b) return a > b end)
local g_delta = zeros{self.dims}
@ -173,7 +166,7 @@ function Xnes:tell(scored, noise)
local zzt = noise_p[i] * noise_p[j] - (i == j and 1 or 0)
local temp = self.utility[p] * zzt
g_covars[ind] = g_covars[ind] + temp
traced = traced + temp
if i == j then traced = traced + temp end
end
end
end
@ -181,7 +174,7 @@ function Xnes:tell(scored, noise)
local g_sigma = traced / self.dims
for i=1, self.dims do
local ind = (i - 1) * self.dims + i
local ind = (i - 1) * self.dims + i -- diagonal
g_covars[ind] = g_covars[ind] - g_sigma
end
@ -198,9 +191,12 @@ function Xnes:tell(scored, noise)
end
self.sigma = self.sigma * exp(self.sigma_rate * 0.5 * g_sigma)
for i, v in ipairs(self.covars) do
self.covars[i] = v * exp(self.covar_rate * 0.5 * g_covars[i])
-- re-use g_covars from before just to scale it.
for i, v in ipairs(g_covars) do
g_covars[i] = self.covar_rate * 0.5 * v
end
self.covars = dot(self.covars, expm(g_covars))
-- bookkeeping:
self.noise = nil
@ -210,7 +206,6 @@ end
return {
make_utility = make_utility,
make_covars = make_covars,
Xnes = Xnes,
}