This commit is contained in:
Connor Olding 2018-06-30 20:08:49 +02:00
parent dc8969469d
commit b7938a1785
10 changed files with 582 additions and 2 deletions

110
es_test.lua Normal file
View File

@ -0,0 +1,110 @@
local floor = math.floor
local ipairs = ipairs
local log = math.log
local print = print
local ars = require("ars")
local snes = require("snes")
local xnes = require("xnes")
-- try it all out on a dummy problem.
-- squares a number.
local function square(x)
    return x * x
end

-- shifted sphere function used as the dummy objective.
-- every element is compared against its own 1-based index, so the best
-- possible score (zero) is reached at x = {1, 2, ..., #x}, which is
-- arange(dims) + 1 in numpy terms. xNES should find it almost exactly.
-- the sum of squared distances is negated because the optimizers
-- maximize their objective.
local function spherical(x)
    local total = 0
    for index, value in ipairs(x) do
        total = total + square(value - index)
    end
    return -total
end
-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
-- number of ask/tell iterations performed by the main loop below.
local iterations = 3000 --4000
-- dimensionality of the dummy problem.
local dims = 100
-- handed to the optimizer constructors below
-- (presumably the population size -- confirm against ars/snes/xnes).
local popsize = dims + 1
-- initial sigma handed to the optimizer; also the baseline for the
-- log10(sigma / sigma_init) figure printed every 100 iterations.
local sigma_init = 0.5
-- alternative optimizers, kept commented out for quick switching:
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
es.min_refresh = 0.7 -- FIXME: needs a better interface.
if false then -- TODO: delete me
    -- disabled scratch test: exercises ars.collect_best_indices and the
    -- reward scaling on a tiny array, then stops the script.
    -- `scored` is read as consecutive (positive, negative) pairs, one pair
    -- per index -- presumably matching what ARS produces; confirm in ars.lua.
    local nn = require("nn")
    local util = require("util")

    local scored = nn.arange(10)
    local best = ars.collect_best_indices(scored, 3, true)

    for _, ind in ipairs(best) do
        print(ind, ":", scored[ind * 2 - 1], scored[ind * 2 - 0])
    end

    local top_rewards = {}
    for _, ind in ipairs(best) do
        for offset = 1, 0, -1 do
            table.insert(top_rewards, scored[ind * 2 - offset])
        end
    end

    -- this shouldn't make a difference to the final print:
    top_rewards = util.normalize_sums(top_rewards)
    print(nn.pp(top_rewards))

    local _, reward_dev = util.calc_mean_dev(top_rewards)
    print(reward_dev)

    for rank in ipairs(best) do
        local pos = top_rewards[rank * 2 - 1]
        local neg = top_rewards[rank * 2 - 0]
        print((pos - neg) / reward_dev)
    end

    do return end
end
local asked = nil -- for caching purposes.
local noise = nil -- for caching purposes.
local current_cost = spherical(es:params())

for i = 1, iterations do
    -- the class `es` was built from decides how candidates are requested.
    local kind = getmetatable(es).__index
    local is_snes = kind == snes.Snes

    if is_snes then
        asked, noise = es:ask_mix()
    elseif kind == ars.Ars then
        asked, noise = es:ask()
    else
        -- hand the previous tables back to ask() (see "caching" above).
        asked, noise = es:ask(asked, noise)
    end

    -- evaluate every candidate on the dummy objective.
    local scores = {}
    for j, candidate in ipairs(asked) do
        scores[j] = spherical(candidate)
    end

    -- every optimizer is told the same thing for now; the ARS branch
    -- used to carry this note:
    es:tell(scores) --, current_cost) -- use lips

    current_cost = spherical(es:params())

    if i % 100 == 0 then
        local sigma = es.sigma
        if is_snes then
            -- for sNES, report the average of es.std instead of es.sigma.
            local total = 0
            for _, dev in ipairs(es.std) do total = total + dev end
            sigma = total / #es.std
        end
        local inconvergence = sigma / sigma_init
        local report = ("fitness at iteration %i: %.4f (%.4f)"):format(
            i, current_cost, log(inconvergence) / log(10))
        print(report)
    end
end
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))

-- print the final parameters on one line, separated by ", ".
-- a separator follows every value except the one at index es.dim.
local pieces = {}
for i, v in ipairs(es:params()) do
    pieces[#pieces + 1] = ("%.8f"):format(v)
    if i ~= es.dim then
        pieces[#pieces + 1] = ', '
    end
end
local s = table.concat(pieces)
print(s)

62
extra.lua Normal file
View File

@ -0,0 +1,62 @@
-- left-pads tostring(num) with `count` copies of `pad`, then keeps
-- everything from position #num onwards.
-- with a one-character pad the result is therefore always count + 1
-- characters long: shorter inputs gain padding, while inputs longer than
-- count + 1 characters lose their leading characters.
local function strpad(num, count, pad)
    local text = tostring(num)
    local padded = pad:rep(count)..text
    return padded:sub(#text)
end
-- zero-pads `num` via strpad; because strpad yields one character more
-- than the pad count it is given, the pad count passed on is count - 1.
local function add_zeros(num, count)
    local pad_count = count - 1
    return strpad(num, pad_count, '0')
end
-- comparator for table.sort that accepts keys of mixed types.
-- both operands are turned into strings and compared with `<`:
-- numbers are zero-padded with add_zeros(value, 16), everything else
-- goes through tostring.
local function mixed_sorter(a, b)
    local left, right

    if type(a) == 'number' then
        left = add_zeros(a, 16)
    else
        left = tostring(a)
    end

    if type(b) == 'number' then
        right = add_zeros(b, 16)
    else
        right = tostring(b)
    end

    return left < right
end
-- loosely based on http://lua-users.org/wiki/SortedIteration
-- the original didn't make use of closures for who knows why
-- collects every key of `t` into a new list sorted with mixed_sorter.
local function order_keys(t)
    local keys, count = {}, 0
    for key in pairs(t) do
        count = count + 1
        keys[count] = key
    end
    table.sort(keys, mixed_sorter)
    return keys
end
-- ordered replacement for pairs(): returns an iterator that yields
-- key, value for every key of `t`, in the order given by order_keys(t).
-- cache (optional): a table mapping tables to their ordered key lists.
--   when given, a list already stored for `t` is reused as-is (it is not
--   refreshed if `t` has changed since), and a freshly computed list is
--   stored for next time.
local function opairs(t, cache)
    local oi = cache and cache[t] or order_keys(t)
    if cache then
        cache[t] = oi
    end
    local i = 0
    return function()
        i = i + 1
        local key = oi[i]
        -- compare against nil explicitly: `false` is a legal table key,
        -- and a plain truthiness test would end the iteration early.
        if key ~= nil then return key, t[key] end
    end
end
-- resolves a path such as "a.b.c" starting from the global table.
-- the path is split into runs of letters, digits and underscores; every
-- component except the last is looked up with rawget (so metatables are
-- bypassed) and must be a table.
-- returns {parent = <table holding the last component>, key = <last
-- component>}, or nothing when `path` is nil, contains no components,
-- or an intermediate component is not a table.
-- the final key itself is not looked up, so it may not exist yet.
local function traverse(path)
    if not path then return end
    local parent = _G
    local key
    -- gmatch, not gfind: string.gfind is the deprecated Lua 5.0 name and
    -- is missing from Lua 5.2+ and from builds without the compat flag.
    for w in path:gmatch("[%w_]+") do
        if key then
            parent = rawget(parent, key)
            if type(parent) ~= 'table' then return end
        end
        key = w
    end
    if not key then return end
    return {parent=parent, key=key}
end
return {
strpad = strpad,
add_zeros = add_zeros,
mixed_sorter = mixed_sorter,
order_keys = order_keys,
opairs = opairs,
traverse = traverse,
}

54
monitor_tiles.lua Normal file
View File

@ -0,0 +1,54 @@
-- keep track of which blocks are actually seen in the game.
-- play back an all-levels TAS with this script running.
local floor = math.floor
local open = io.open
local pairs = pairs
local print = print
local util = require("util")
local R = memory.readbyteunsigned
local W = memory.writebyte
local function S(addr) return util.signbyte(R(addr)) end
local game = require("smb") -- just for advance()
local serial = require "serialize"
local serialize = serial.serialize
local deserialize = serial.deserialize
local fn = 'seen_tiles.lua'
local seen = deserialize(fn) or {}
-- records a tile kind the first time it is encountered: marks it in
-- `seen`, prints it as two hex digits, and rewrites the seen-tiles file.
-- sx and sy are accepted but not used here.
local function mark_tile(sx, sy, kind)
    if seen[kind] then return end
    seen[kind] = true
    print(("%02X"):format(kind))
    serialize(fn, seen)
end
-- walks a 17x13 window of tiles starting at the current scroll position
-- and passes each tile value to mark_tile.
-- NOTE(review): 0x73F and 0x71A look like the scroll offset and page
-- number, and 0x500 / 0x5D0 like the two halves of the tile buffer --
-- confirm against a RAM map for the game.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
    local tile_scroll_remainder = R(0x73F) % 16

    for y = 0, 12 do
        local row_offset = y * 16
        local sy = row_offset + 40

        for x = 0, 16 do
            -- columns wrap around every 32 tiles; the lower 16 columns
            -- are read from 0x500, the upper 16 from 0x5D0.
            local col = (x + tile_scroll) % 32
            local base = col < 16 and 0x500 or 0x5D0
            local t = R(base + row_offset + (col % 16))

            local sx = x * 16 + 8 - tile_scroll_remainder
            mark_tile(sx, sy, t)
        end
    end
end
-- main loop: scan the currently visible tiles, then hand control back to
-- the game through game.advance() (presumably one emulated frame -- confirm
-- in smb.lua). runs until the script is stopped.
while true do
handle_tiles()
game.advance()
end

View File

@ -72,6 +72,28 @@ make_preset{
sigma_decay = 0.008,
}
-- 'snes2': selects the 'snes' optimizer, with the settings below layered on
-- the 'big-scroll-reduced' parent preset (inheritance presumably handled by
-- make_preset, which is defined outside this view).
make_preset{
name = 'snes2',
parent = 'big-scroll-reduced',
es = 'snes',
deterministic = true,
deviation = 0.01,
negate_trials = false,
epoch_trials = 60,
min_refresh = 2/3,
param_rate = 0.368,
param_decay = 0.0138,
sigma_rate = 0.100,
sigma_decay = 0.0051,
}
-- 'snes3': names 'snes2' as its parent and only changes min_refresh
-- (1/3 here instead of 2/3).
make_preset{
name = 'snes3',
parent = 'snes2',
min_refresh = 1/3,
}
make_preset{
name = 'xnes',
parent = 'big-scroll-reduced',

15
rescale.lua Normal file
View File

@ -0,0 +1,15 @@
-- one-off helper: reads one number per line from params-ars4.txt,
-- multiplies each by 100, and prints the results one per line.
-- lines that tonumber cannot parse are skipped.
local handle = assert(io.open("params-ars4.txt", "r"))
local contents = handle:read("*a")
handle:close()

local scaled = {}
for line in contents:gmatch("[^\r\n]+") do
    local number = tonumber(line)
    if number ~= nil then
        scaled[#scaled + 1] = number * 100
    end
end

for i = 1, #scaled do
    print(scaled[i])
end

183
running.lua Normal file
View File

@ -0,0 +1,183 @@
local huge = math.huge
local ipairs = ipairs
local open = io.open
local sqrt = math.sqrt
local nn = require("nn")
local Base = require("Base")
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
-- http://www.johndcook.com/blog/standard_deviation/
local Stats = Base:extend()
local Normalizer = Base:extend()
-- sets up an empty accumulator: no samples yet, with the running mean
-- (_M) and the running sum of squared deviations (_S) both created by
-- nn.zeros(shape).
function Stats:init(shape)
    local mean = nn.zeros(shape)
    local spread = nn.zeros(shape)
    self._n = 0
    self._M = mean
    self._S = spread
end
-- folds one sample into the running statistics (online mean/variance
-- update, see the johndcook.com link at the top of the file).
-- x must carry a .shape whose element count matches the accumulator's.
function Stats:push(x)
    assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")

    local prev_count = self._n
    local count = prev_count + 1
    self._n = count

    if count == 1 then
        -- the first sample simply becomes the mean; _S stays at zero.
        nn.copy(x, self._M)
        return
    end

    local mean, spread = self._M, self._S
    local delta = {}
    for i, m in ipairs(mean) do
        -- distance from the *old* mean, needed again for _S below.
        local d = x[i] - m
        delta[i] = d
        mean[i] = m + d / count
    end
    for i, s in ipairs(spread) do
        spread[i] = s + delta[i] * delta[i] * prev_count / count
    end
end
-- returns a new table with the per-element sample variance, _S / (n - 1).
-- with exactly one sample the squared mean is returned instead.
function Stats:var()
    local variance = {}
    if self._n == 1 then
        for i, mean in ipairs(self._M) do
            variance[i] = mean * mean
        end
        return variance
    end
    for i, spread in ipairs(self._S) do
        variance[i] = spread / (self._n - 1)
    end
    return variance
end
-- returns a new table with the per-element standard deviation,
-- i.e. the square root of each entry of self:var().
function Stats:dev()
    local deviations = self:var()
    for i, variance in ipairs(deviations) do
        deviations[i] = sqrt(variance)
    end
    return deviations
end
-- shape:  forwarded to Stats and nn.zeros.
-- demean: subtract the running mean in process(); defaults to true.
-- destd:  divide by the running deviation in process(); defaults to true.
-- starts out as an identity transform: mean of zeros, std of ones.
function Normalizer:init(shape, demean, destd)
    self.shape = shape
    -- only an omitted (nil) flag falls back to the default of true.
    self.demean = demean == nil or demean
    self.destd = destd == nil or destd
    self.rs = Stats(shape)
    self.mean = nn.zeros(shape)
    self.std = nn.zeros(shape)
    for i = 1, #self.std do
        self.std[i] = 1
    end
end
-- returns a normalized copy of x using the current mean/std, without
-- touching the running statistics. x itself is left unmodified.
function Normalizer:process(x)
    local out = nn.copy(x)
    local demean, destd = self.demean, self.destd
    for i, v in ipairs(out) do
        if demean then
            v = v - self.mean[i]
        end
        if destd then
            -- the small epsilon guards against dividing by exactly zero.
            v = v / (self.std[i] + 1e-8)
        end
        out[i] = v
    end
    return out
end
-- refreshes self.mean and self.std from the running statistics.
function Normalizer:update()
    local stats = self.rs
    nn.copy(stats._M, self.mean) -- FIXME: HACK
    nn.copy(stats:dev(), self.std)

    -- deviations below 1e-7 are replaced with +inf so that process()
    -- never divides by (nearly) zero; elements with no variance then
    -- normalize to zero.
    local std = self.std
    for i, dev in ipairs(std) do
        if dev < 1e-7 then
            std[i] = huge
        end
    end
end
-- adds x to the running statistics and returns x normalized.
-- update: pass false to skip refreshing mean/std first; nil (omitted)
--   or any other value refreshes them.
function Normalizer:push(x, update)
    self.rs:push(x)
    if update ~= false then
        self:update()
    end
    return self:process(x)
end
-- file name used by save/load when none is given: "stats", then
-- nn.prod(self.shape) zero-padded to seven digits, then ".txt".
function Normalizer:default_filename()
    local size = nn.prod(self.shape)
    return string.format('stats%07i.txt', size)
end
-- writes the running statistics as text, one number per line:
-- the sample count, then every entry of _M, then every entry of _S.
-- fn (optional): destination path; defaults to self:default_filename().
-- raises an error if the file cannot be opened for writing.
function Normalizer:save(fn)
    local path = fn or self:default_filename()
    local f = open(path, 'w')
    if f == nil then error("Failed to save stats to file "..path) end

    local stats = self.rs
    f:write(stats._n, '\n')
    for _, mean in ipairs(stats._M) do
        f:write(mean, '\n')
    end
    for _, spread in ipairs(stats._S) do
        f:write(spread, '\n')
    end
    f:close()
end
-- restores running statistics written by Normalizer:save, then refreshes
-- mean/std from them.
-- expected layout, one number per line: the sample count, followed by
-- nn.prod(shape) values for _M and nn.prod(shape) values for _S.
-- fn (optional): path to read; defaults to self:default_filename().
-- raises an error if the file cannot be opened, if a line is not a
-- number, or if the number of lines does not match the shape.
function Normalizer:load(fn)
    local fn = fn or self:default_filename()
    local f = open(fn, 'r')
    if f == nil then error("Failed to load stats from file "..fn) end

    -- parse everything up front: the handle is closed before any error is
    -- raised, and a malformed file cannot leave self.rs half-updated.
    local values = {}
    local bad_line = nil
    for line in f:lines() do
        local n = tonumber(line)
        if n == nil then
            bad_line = #values + 1
            break
        end
        values[#values + 1] = n
    end
    f:close()

    if bad_line then
        error("Failed reading line "..tostring(bad_line).." of file "..fn)
    end

    -- a file saved for a different shape would otherwise be loaded
    -- partially (too short) or spill past the end of _S (too long).
    local size = nn.prod(self.shape)
    local expected = 1 + size * 2
    if #values ~= expected then
        error("Expected "..tostring(expected).." lines in file "..fn
            ..", found "..tostring(#values))
    end

    local stats = self.rs
    stats._n = values[1]
    for i = 1, size do
        stats._M[i] = values[1 + i]
        stats._S[i] = values[1 + size + i]
    end
    self:update()
end
--[[
-- basic tests
local dims = 20
local rs = Stats(dims)
local x = nn.zeros(dims)
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
print(nn.pp(rs:dev()))
for j = 1, 10000 do
for i = 1, #x do x[i] = nn.normal() end
rs:push(x)
end
print(nn.pp(rs:dev()))
--
local ms = Normalizer(dims)
local exp = math.exp
local y
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
print(nn.pp(y))
for j = 1, 10000 do
for i = 1, #x do x[i] = exp(nn.normal()) end
y = ms:push(x)
end
print(nn.pp(y))
print("mean:")
print(nn.pp(ms.mean))
print("stdev:")
print(nn.pp(ms.std))
--]]
return {
Stats = Stats,
Normalizer = Normalizer,
}

59
seen_tiles.lua Normal file
View File

@ -0,0 +1,59 @@
return {
[0] = true,
[16] = true,
[17] = true,
[18] = true,
[19] = true,
[20] = true,
[21] = true,
[22] = true,
[23] = true,
[24] = true,
[25] = true,
[26] = true,
[27] = true,
[28] = true,
[29] = true,
[30] = true,
[31] = true,
[32] = true,
[33] = true,
[34] = true,
[35] = true,
[36] = true,
[37] = true,
[38] = true,
[81] = true,
[82] = true,
[84] = true,
[85] = true,
[86] = true,
[87] = true,
[88] = true,
[89] = true,
[90] = true,
[91] = true,
[92] = true,
[93] = true,
[94] = true,
[95] = true,
[96] = true,
[97] = true,
[98] = true,
[99] = true,
[100] = true,
[101] = true,
[102] = true,
[103] = true,
[104] = true,
[105] = true,
[107] = true,
[108] = true,
[137] = true,
[192] = true,
[193] = true,
[194] = true,
[195] = true,
[196] = true,
[197] = true,
}

76
serialize.lua Normal file
View File

@ -0,0 +1,76 @@
-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
local type = type
local extra = require "extra"
local opairs = extra.opairs
local tostring = tostring
local open = io.open
local strfmt = string.format
local strrep = string.rep
-- strips a leading UTF-8 byte order mark (EF BB BF) from s, if present.
local function kill_bom(s)
    if s:sub(1, 3) == '\239\187\191' then
        return s:sub(4)
    end
    return s
end
-- converts a value to the text written into the serialized file.
-- returns two values:
--   1. strings quoted with %q, everything else via tostring;
--   2. a boolean that is true only for strings starting with a digit.
local function sanitize(v)
    if type(v) ~= 'string' then
        return tostring(v), false
    end
    local starts_with_digit = v:sub(1, 1):match('%d') ~= nil
    return strfmt('%q', v), starts_with_digit
end
-- words that may not appear as a bare `name = value` key in a table
-- constructor ('goto' is only reserved from Lua 5.2 on, but bracketing
-- it is harmless everywhere).
local reserved_words = {
    ['and'] = true, ['break'] = true, ['do'] = true, ['else'] = true,
    ['elseif'] = true, ['end'] = true, ['false'] = true, ['for'] = true,
    ['function'] = true, ['goto'] = true, ['if'] = true, ['in'] = true,
    ['local'] = true, ['nil'] = true, ['not'] = true, ['or'] = true,
    ['repeat'] = true, ['return'] = true, ['then'] = true, ['true'] = true,
    ['until'] = true, ['while'] = true,
}

-- writes `value` as Lua source text through `writer`, a function that
-- accepts string pieces.
-- tables are written as multi-line constructors, keys in opairs order,
-- indented with one tab per nesting level; anything else is written via
-- sanitize. there is no cycle detection, so self-referencing tables
-- recurse forever.
-- level (optional): current nesting depth, defaults to 1.
local function _serialize(value, writer, level)
    level = level or 1
    if type(value) ~= 'table' then
        -- parentheses keep only sanitize's first result.
        writer((sanitize(value)))
        return
    end

    local indent = strrep('\t', level)
    writer('{\n')
    for key, item in opairs(value) do
        local keyval
        if type(key) == 'string'
        and key:match('^[A-Za-z_][A-Za-z0-9_]*$')
        and not reserved_words[key] then
            -- a plain identifier can be written bare.
            keyval = key
        else
            -- everything else (numbers, booleans, strings with spaces,
            -- punctuation, leading digits or reserved words) needs the
            -- bracketed form to load back correctly.
            keyval = '['..(sanitize(key))..']'
        end
        writer(indent..keyval..' = ')
        _serialize(item, writer, level + 1)
        writer(',\n')
    end
    writer(strrep('\t', level - 1)..'}')
end
-- compiles `script` (after dropping any UTF-8 byte order mark) and runs
-- it, returning whatever the chunk returns.
-- if the script does not compile, a warning with the reason is printed
-- and nil is returned. errors raised while running the chunk propagate.
-- NOTE: this executes the text as Lua code -- only feed it trusted files.
local function _deserialize(script)
    -- loadstring only exists in Lua 5.1 / LuaJIT; from 5.2 on, load
    -- accepts source strings directly.
    local compile = loadstring or load
    local f, err = compile(kill_bom(script))
    if f ~= nil then
        return f()
    else
        print('WARNING: no function to deserialize with: '..tostring(err))
        return nil
    end
end
-- writes `value` to the file at `path` as a Lua chunk of the form
-- "return <value>", so that deserialize can load it back.
-- returns nothing; silently does nothing when the file cannot be opened.
local function serialize(path, value)
    local file = open(path, 'w')
    if file == nil then return end

    local function emit(...)
        file:write(...)
    end

    file:write("return ")
    _serialize(value, emit)
    file:write("\n")
    file:close()
end
-- loads a value previously written by serialize.
-- returns the value produced by the file's chunk, or nothing when the
-- file cannot be opened.
local function deserialize(path)
    local file = open(path, 'r')
    if not file then return end
    local script = file:read('*a')
    -- close before evaluating: the chunk may raise an error, and the
    -- handle must not be left open when it does.
    file:close()
    return _deserialize(script)
end
return {
serialize = serialize,
deserialize = deserialize,
}

View File

@ -3,7 +3,6 @@
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
-- not to be confused with the Super Nintendo Entertainment System.
local abs = math.abs
local assert = assert
local exp = math.exp
local floor = math.floor

View File

@ -115,7 +115,7 @@ local function unperturbed_rank(rewards, unperturbed_reward)
local nth_place = 1
for i, v in ipairs(rewards) do
if v > unperturbed_reward then
nth_place = nth_place + 1
nth_place = nth_place + 1
end
end
return nth_place