diff --git a/es_test.lua b/es_test.lua
new file mode 100644
index 0000000..bf76479
--- /dev/null
+++ b/es_test.lua
@@ -0,0 +1,110 @@
+local floor = math.floor
+local ipairs = ipairs
+local log = math.log
+local print = print
+
+local ars = require("ars")
+local snes = require("snes")
+local xnes = require("xnes")
+
+-- try it all out on a dummy problem.
+
+local function square(x) return x * x end
+
+-- this function's global minimum is arange(dims) + 1.
+-- xNES should be able to find it almost exactly.
+local function spherical(x)
+    local sum = 0
+    for i, v in ipairs(x) do sum = sum + square(v - i) end
+    -- we need to negate this to turn it into a maximization problem.
+    return -sum
+end
+
+-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
+local iterations = 3000 --4000
+local dims = 100
+local popsize = dims + 1
+local sigma_init = 0.5
+--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
+--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
+--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
+local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
+es.min_refresh = 0.7 -- FIXME: needs a better interface.
+
+if false then -- TODO: delete me
+    local nn = require("nn")
+    local util = require("util")
+    local insert = table.insert
+    local scored = nn.arange(10)
+    local indices = ars.collect_best_indices(scored, 3, true)
+    for i, ind in ipairs(indices) do
+        print(ind, ":", scored[ind * 2 - 1], scored[ind * 2 - 0])
+    end
+    local top_rewards = {}
+    for _, ind in ipairs(indices) do
+        insert(top_rewards, scored[ind * 2 - 1])
+        insert(top_rewards, scored[ind * 2 - 0])
+    end
+    -- this shouldn't make a difference to the final print:
+    top_rewards = util.normalize_sums(top_rewards)
+    print(nn.pp(top_rewards))
+    local _, reward_dev = util.calc_mean_dev(top_rewards)
+    print(reward_dev)
+    for i, ind in ipairs(indices) do
+        local pos = top_rewards[i * 2 - 1]
+        local neg = top_rewards[i * 2 - 0]
+        local reward = pos - neg
+        reward = reward / reward_dev
+        print(reward)
+    end
+    do return end
+end
+
+local asked = nil -- for caching purposes.
+local noise = nil -- for caching purposes.
+local current_cost = spherical(es:params())
+
+for i=1, iterations do
+    if getmetatable(es).__index == snes.Snes then
+        asked, noise = es:ask_mix()
+    elseif getmetatable(es).__index == ars.Ars then
+        asked, noise = es:ask()
+    else
+        asked, noise = es:ask(asked, noise)
+    end
+    local scores = {}
+    for j, candidate in ipairs(asked) do
+        scores[j] = spherical(candidate)
+    end
+
+    if getmetatable(es).__index == ars.Ars then
+        es:tell(scores)--, current_cost) -- use lips
+    else
+        es:tell(scores)
+    end
+
+    current_cost = spherical(es:params())
+    --if i % 100 == 0 then
+    if i % 100 == 0 then
+        local sigma = es.sigma
+        if getmetatable(es).__index == snes.Snes then
+            sigma = 0
+            for _, v in ipairs(es.std) do sigma = sigma + v end
+            sigma = sigma / #es.std
+        end
+        local inconvergence = sigma / sigma_init
+        local fmt = "fitness at iteration %i: %.4f (%.4f)"
+        print(fmt:format(i, current_cost, log(inconvergence) / log(10)))
+    end
+end
+
+-- note: this metric doesn't include the "fitness at iteration" evaluations,
+-- because those aren't actually used to step towards the optimum.
+print(("optimized in %i function evaluations"):format(es.evals))
+
+-- build the list in one table.concat call rather than regrowing a string.
+local parts = {}
+for i, v in ipairs(es:params()) do
+    parts[i] = ("%.8f"):format(v)
+end
+print(table.concat(parts, ', '))
diff --git a/extra.lua b/extra.lua
new file mode 100644
index 0000000..c0aa538
--- /dev/null
+++ b/extra.lua
@@ -0,0 +1,76 @@
+-- left-pad `num` (as a string) with `count` copies of `pad`, then drop the
+-- first #num - 1 characters; a one-character pad gives count + 1 characters.
+local function strpad(num, count, pad)
+    num = tostring(num)
+    return (pad:rep(count)..num):sub(#num)
+end
+
+-- format `num` as exactly `count` characters, zero-padded on the left.
+-- inputs longer than `count` lose their leading characters.
+local function add_zeros(num, count)
+    return strpad(num, count - 1, '0')
+end
+
+-- table.sort comparator for keys of mixed types: numbers are zero-padded
+-- to 16 characters, anything else is stringified, then strings are compared.
+local function mixed_sorter(a, b)
+    a = type(a) == 'number' and add_zeros(a, 16) or tostring(a)
+    b = type(b) == 'number' and add_zeros(b, 16) or tostring(b)
+    return a < b
+end
+
+-- loosely based on http://lua-users.org/wiki/SortedIteration
+-- the original didn't make use of closures for who knows why
+-- returns an array of the keys of `t`, sorted with mixed_sorter.
+local function order_keys(t)
+    local oi = {}
+    for key in pairs(t) do
+        table.insert(oi, key)
+    end
+    table.sort(oi, mixed_sorter)
+    return oi
+end
+
+-- like pairs(t), but visits the keys in order_keys order.
+-- `cache` is an optional table used to remember the key order per table.
+local function opairs(t, cache)
+    local oi = cache and cache[t] or order_keys(t)
+    if cache then
+        cache[t] = oi
+    end
+    local i = 0
+    return function()
+        i = i + 1
+        local key = oi[i]
+        -- compare against nil so a key of `false` doesn't end the iteration.
+        if key ~= nil then return key, t[key] end
+    end
+end
+
+-- walk a path of names (e.g. "a.b.c") through raw table lookups from _G.
+-- returns {parent=<table that holds the last name>, key=<last name>}, or
+-- nothing if the path is empty or an intermediate value isn't a table.
+local function traverse(path)
+    if not path then return end
+    local parent = _G
+    local key
+    -- gmatch replaces gfind, its deprecated pre-5.1 name.
+    for w in path:gmatch("[%w_]+") do
+        if key then
+            parent = rawget(parent, key)
+            if type(parent) ~= 'table' then return end
+        end
+        key = w
+    end
+    if not key then return end
+    return {parent=parent, key=key}
+end
+
+return {
+    strpad = strpad,
+    add_zeros = add_zeros,
+    mixed_sorter = mixed_sorter,
+    order_keys = order_keys,
+    opairs = opairs,
+    traverse = traverse,
+}
diff --git a/monitor_tiles.lua b/monitor_tiles.lua
new file mode 100644
index 0000000..111d3b5
--- /dev/null
+++ b/monitor_tiles.lua
@@ -0,0 +1,54 @@
+-- keep track of which blocks are actually seen in the game.
+-- play back an all-levels TAS with this script running.
+
+local floor = math.floor
+local open = io.open
+local pairs = pairs
+local print = print
+
+local util = require("util")
+local R = memory.readbyteunsigned
+local W = memory.writebyte
+local function S(addr) return util.signbyte(R(addr)) end
+
+local game = require("smb") -- just for advance()
+
+local serial = require "serialize"
+local serialize = serial.serialize
+local deserialize = serial.deserialize
+
+local fn = 'seen_tiles.lua'
+local seen = deserialize(fn) or {}
+
+local function mark_tile(sx, sy, kind)
+    if not seen[kind] then
+        seen[kind] = true
+        print(("%02X"):format(kind))
+        serialize(fn, seen)
+    end
+end
+
+local function handle_tiles()
+    --local tile_col = R(0x6A0)
+    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
+    local tile_scroll_remainder = R(0x73F) % 16
+    for y = 0, 12 do
+        for x = 0, 16 do
+            local col = (x + tile_scroll) % 32
+            local t
+            if col < 16 then
+                t = R(0x500 + y * 16 + (col % 16))
+            else
+                t = R(0x5D0 + y * 16 + (col % 16))
+            end
+            local sx = x * 16 + 8 - tile_scroll_remainder
+            local sy = y * 16 + 40
+            mark_tile(sx, sy, t)
+        end
+    end
+end
+
+while true do
+    handle_tiles()
+    game.advance()
+end
diff --git a/presets.lua b/presets.lua
index 47736f1..37d9965 100644
--- a/presets.lua
+++ b/presets.lua
@@ -72,6 +72,28 @@ make_preset{
     sigma_decay = 0.008,
 }
 
+make_preset{
+    name = 'snes2',
+    parent = 'big-scroll-reduced',
+
+    es = 'snes',
+    deterministic = true,
+    deviation = 0.01,
+    negate_trials = false,
+    epoch_trials = 60,
+    min_refresh = 2/3,
+    param_rate = 0.368,
+    param_decay = 0.0138,
+    sigma_rate = 0.100,
+    sigma_decay = 0.0051,
+}
+
+make_preset{
+    name = 'snes3',
+    parent = 'snes2',
+    min_refresh = 1/3,
+}
+
 make_preset{
     name = 'xnes',
     parent = 'big-scroll-reduced',
diff --git a/rescale.lua b/rescale.lua
new file mode 100644
index 0000000..b3976d2
--- /dev/null
+++ b/rescale.lua
@@ -0,0 +1,15 @@
+local f = assert(io.open("params-ars4.txt", "r"))
+local data = f:read("*a")
+f:close()
+local values = {}
+for v in data:gmatch("[^\r\n]+") do
+    table.insert(values, tonumber(v))
+end
+
+for i, v in ipairs(values) do
+    values[i] = v * 100
+end
+
+for i, v in ipairs(values) do
+    print(v)
+end
diff --git a/running.lua b/running.lua
new file mode 100644
index 0000000..969f1ab
--- /dev/null
+++ b/running.lua
@@ -0,0 +1,202 @@
+local huge = math.huge
+local ipairs = ipairs
+local open = io.open
+local sqrt = math.sqrt
+
+local nn = require("nn")
+local Base = require("Base")
+
+-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
+-- http://www.johndcook.com/blog/standard_deviation/
+local Stats = Base:extend()
+local Normalizer = Base:extend()
+
+-- running per-element statistics: _n counts the samples pushed so far,
+-- _M holds the running mean, and _S accumulates squared deviations.
+function Stats:init(shape)
+    self._n = 0
+    self._M = nn.zeros(shape)
+    self._S = nn.zeros(shape)
+end
+
+-- fold one sample into the running mean and squared-deviation sums.
+function Stats:push(x)
+    assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")
+    local n1 = self._n
+    self._n = self._n + 1
+    if self._n == 1 then
+        nn.copy(x, self._M)
+    else
+        local delta = {}
+        for i, v in ipairs(self._M) do delta[i] = x[i] - v end
+        for i, v in ipairs(self._M) do self._M[i] = v + delta[i] / self._n end
+        for i, v in ipairs(self._S) do self._S[i] = v + delta[i] * delta[i] * n1 / self._n end
+    end
+end
+
+-- per-element sample variance, _S / (_n - 1).
+-- with exactly one sample, this returns the squared mean instead.
+function Stats:var()
+    local out = {}
+    if self._n == 1 then
+        for i, v in ipairs(self._M) do out[i] = v * v end
+    else
+        for i, v in ipairs(self._S) do out[i] = v / (self._n - 1) end
+    end
+    return out
+end
+
+-- per-element standard deviation: the square root of var().
+function Stats:dev()
+    local out = self:var()
+    for i, v in ipairs(out) do out[i] = sqrt(v) end
+    return out
+end
+
+-- demean and destd both default to true when omitted (nil).
+-- mean starts at zero and std at one, so process() is a near no-op
+-- until the first update().
+function Normalizer:init(shape, demean, destd)
+    if demean == nil then demean = true end
+    if destd == nil then destd = true end
+    self.shape = shape
+    self.demean = demean
+    self.destd = destd
+    self.rs = Stats(shape)
+    self.mean = nn.zeros(shape)
+    self.std = nn.zeros(shape)
+    for i = 1, #self.std do self.std[i] = 1 end
+end
+
+-- return a normalized copy of x using the current mean and std.
+function Normalizer:process(x)
+    local out = nn.copy(x)
+    if self.demean then
+        for i, v in ipairs(out) do out[i] = out[i] - self.mean[i] end
+    end
+    if self.destd then
+        for i, v in ipairs(out) do out[i] = out[i] / (self.std[i] + 1e-8) end
+    end
+    return out
+end
+
+-- refresh mean and std from the running statistics.
+function Normalizer:update()
+    nn.copy(self.rs._M, self.mean) -- FIXME: HACK
+    nn.copy(self.rs:dev(), self.std)
+    -- Set values for std less than 1e-7 to +inf
+    -- to avoid dividing by zero. State elements
+    -- with zero variance are set to zero as a result.
+    for i, v in ipairs(self.std) do
+        if v < 1e-7 then self.std[i] = huge end
+    end
+end
+
+-- push x into the running statistics and return it normalized.
+-- mean and std are refreshed first unless `update` is false.
+function Normalizer:push(x, update)
+    self.rs:push(x)
+    if update == nil or update then self:update() end
+    return self:process(x)
+end
+
+function Normalizer:default_filename()
+    return ('stats%07i.txt'):format(nn.prod(self.shape))
+end
+
+-- write the sample count, then each element of _M, then each element
+-- of _S, one number per line. `fn` defaults to default_filename().
+function Normalizer:save(fn)
+    fn = fn or self:default_filename()
+    local f = open(fn, 'w')
+    if f == nil then error("Failed to save stats to file "..fn) end
+    -- %.17g keeps enough digits to read a double back exactly;
+    -- writing the number directly would round it to fewer digits.
+    f:write(('%.17g\n'):format(self.rs._n))
+    for _, v in ipairs(self.rs._M) do f:write(('%.17g\n'):format(v)) end
+    for _, v in ipairs(self.rs._S) do f:write(('%.17g\n'):format(v)) end
+    f:close()
+end
+
+-- read back a file written by save() and refresh mean and std.
+-- raises an error if the file can't be opened, if a line isn't a number,
+-- or if the line count doesn't match this normalizer's shape.
+function Normalizer:load(fn)
+    fn = fn or self:default_filename()
+    local f = open(fn, 'r')
+    if f == nil then error("Failed to load stats from file "..fn) end
+
+    local count = nn.prod(self.shape)
+    local nums = {}
+    local i = 0
+    for line in f:lines() do
+        i = i + 1
+        local n = tonumber(line)
+        if n == nil then
+            f:close() -- don't leak the handle when bailing out.
+            error("Failed reading line "..tostring(i).." of file "..fn)
+        end
+        nums[i] = n
+    end
+    f:close()
+
+    local expected = 1 + 2 * count
+    if i ~= expected then
+        error("Expected "..expected.." lines but found "..i.." in file "..fn)
+    end
+
+    -- nothing is committed until the whole file has been validated.
+    self.rs._n = nums[1]
+    for j = 1, count do
+        self.rs._M[j] = nums[1 + j]
+        self.rs._S[j] = nums[1 + count + j]
+    end
+
+    self:update()
+end
+
+--[[
+
+-- basic tests
+
+local dims = 20
+local rs = Stats(dims)
+local x = nn.zeros(dims)
+
+for i = 1, #x do x[i] = nn.normal() end
+rs:push(x)
+print(nn.pp(rs:dev()))
+
+for j = 1, 10000 do
+    for i = 1, #x do x[i] = nn.normal() end
+    rs:push(x)
+end
+print(nn.pp(rs:dev()))
+
+--
+
+local ms = Normalizer(dims)
+local exp = math.exp
+local y
+
+for i = 1, #x do x[i] = exp(nn.normal()) end
+y = ms:push(x)
+print(nn.pp(y))
+
+for j = 1, 10000 do
+    for i = 1, #x do x[i] = exp(nn.normal()) end
+    y = ms:push(x)
+end
+print(nn.pp(y))
+
+print("mean:")
+print(nn.pp(ms.mean))
+print("stdev:")
+print(nn.pp(ms.std))
+
+--]]
+
+return {
+    Stats = Stats,
+    Normalizer = Normalizer,
+}
diff --git a/seen_tiles.lua b/seen_tiles.lua
new file mode 100644
index 0000000..90ad5b7
--- /dev/null
+++ b/seen_tiles.lua
@@ -0,0 +1,59 @@
+return {
+	[0] = true,
+	[16] = true,
+	[17] = true,
+	[18] = true,
+	[19] = true,
+	[20] = true,
+	[21] = true,
+	[22] = true,
+	[23] = true,
+	[24] = true,
+	[25] = true,
+	[26] = true,
+	[27] = true,
+	[28] = true,
+	[29] = true,
+	[30] = true,
+	[31] = true,
+	[32] = true,
+	[33] = true,
+	[34] = true,
+	[35] = true,
+	[36] = true,
+	[37] = true,
+	[38] = true,
+	[81] = true,
+	[82] = true,
+	[84] = true,
+	[85] = true,
+	[86] = true,
+	[87] = true,
+	[88] = true,
+	[89] = true,
+	[90] = true,
+	[91] = true,
+	[92] = true,
+	[93] = true,
+	[94] = true,
+	[95] = true,
+	[96] = true,
+	[97] = true,
+	[98] = true,
+	[99] = true,
+	[100] = true,
+	[101] = true,
+	[102] = true,
+	[103] = true,
+	[104] = true,
+	[105] = true,
+	[107] = true,
+	[108] = true,
+	[137] = true,
+	[192] = true,
+	[193] = true,
+	[194] = true,
+	[195] = true,
+	[196] = true,
+	[197] = true,
+}
diff --git a/serialize.lua b/serialize.lua
new file mode 100644
index 0000000..44f1726
--- /dev/null
+++ b/serialize.lua
@@ -0,0 +1,108 @@
+-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
+
+local type = type
+local extra = require "extra"
+local opairs = extra.opairs
+local tostring = tostring
+local open = io.open
+local strfmt = string.format
+local strrep = string.rep
+local concat = table.concat
+local loadstring = loadstring or load -- `load` takes over in Lua 5.2+.
+
+-- reserved words can't be written as bare `name = value` table keys.
+local keywords = {
+    ['and'] = true, ['break'] = true, ['do'] = true, ['else'] = true,
+    ['elseif'] = true, ['end'] = true, ['false'] = true, ['for'] = true,
+    ['function'] = true, ['goto'] = true, ['if'] = true, ['in'] = true,
+    ['local'] = true, ['nil'] = true, ['not'] = true, ['or'] = true,
+    ['repeat'] = true, ['return'] = true, ['then'] = true, ['true'] = true,
+    ['until'] = true, ['while'] = true,
+}
+
+-- strip a leading UTF-8 byte order mark, if there is one.
+local function kill_bom(s)
+    if #s >= 3 and s:byte(1)==0xEF and s:byte(2)==0xBB and s:byte(3)==0xBF then
+        return s:sub(4)
+    end
+    return s
+end
+
+-- render a scalar as Lua source: strings are quoted with %q,
+-- anything else goes through tostring.
+local function sanitize(v)
+    if type(v) == 'string' then return strfmt('%q', v) end
+    return tostring(v)
+end
+
+-- true if `key` can be written bare (`key = ...`) in a table constructor.
+local function is_identifier(key)
+    return type(key) == 'string'
+        and key:match('^[%a_][%w_]*$') ~= nil
+        and not keywords[key]
+end
+
+-- recursively emit `value` as Lua source through `writer`.
+-- tables are written one entry per line in opairs order, indented with
+-- one tab per nesting `level`. cyclic tables will recurse forever.
+local function _serialize(value, writer, level)
+    level = level or 1
+    if type(value) ~= 'table' then
+        writer(sanitize(value))
+        return
+    end
+    local indent = strrep('\t', level)
+    writer('{\n')
+    for k, v in opairs(value) do
+        -- only identifier-like strings are safe to emit without brackets;
+        -- any other key (numbers, booleans, "foo bar", "end") gets [...].
+        local keyval = is_identifier(k) and k or '['..sanitize(k)..']'
+        writer(indent..keyval..' = ')
+        _serialize(v, writer, level + 1)
+        writer(',\n')
+    end
+    writer(strrep('\t', level - 1)..'}')
+end
+
+-- compile and run `script`, returning whatever it returns.
+-- NOTE: this executes the text as code, so only feed it trusted files.
+local function _deserialize(script)
+    local f, err = loadstring(kill_bom(script))
+    if f == nil then
+        print('WARNING: no function to deserialize with: '..tostring(err))
+        return nil
+    end
+    return f()
+end
+
+-- write `value` to `path` as a script that returns it.
+-- quietly does nothing if the file can't be opened.
+local function serialize(path, value)
+    -- build the whole script first, so that an error while serializing
+    -- can't leave a truncated file behind.
+    local buf = {"return "}
+    _serialize(value, function(s) buf[#buf + 1] = s end)
+    buf[#buf + 1] = "\n"
+
+    local file = open(path, 'w')
+    if not file then return end
+    file:write(concat(buf))
+    file:close()
+end
+
+-- read the script at `path` and return the value it evaluates to,
+-- or nil if the file can't be opened or compiled.
+local function deserialize(path)
+    local file = open(path, 'r')
+    if not file then return end
+    local script = file:read('*a')
+    -- close before evaluating: running the script may raise an error.
+    file:close()
+    local value = _deserialize(script)
+    return value
+end
+
+return {
+    serialize = serialize,
+    deserialize = deserialize,
+}
diff --git a/snes.lua b/snes.lua
index f2532d5..38fd3bb 100644
--- a/snes.lua
+++ b/snes.lua
@@ -3,7 +3,6 @@
 -- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
 -- not to be confused with the Super Nintendo Entertainment System.
 
-local abs = math.abs
 local assert = assert
 local exp = math.exp
 local floor = math.floor
diff --git a/util.lua b/util.lua
index 1b98428..b8b5b89 100644
--- a/util.lua
+++ b/util.lua
@@ -115,7 +115,7 @@ local function unperturbed_rank(rewards, unperturbed_reward)
     local nth_place = 1
     for i, v in ipairs(rewards) do
         if v > unperturbed_reward then
-            nth_place = nth_place + 1
+            nth_place = nth_place + 1
         end
     end
     return nth_place