temp
This commit is contained in:
parent
dc8969469d
commit
b7938a1785
10 changed files with 582 additions and 2 deletions
110
es_test.lua
Normal file
110
es_test.lua
Normal file
|
@ -0,0 +1,110 @@
|
|||
local floor = math.floor
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local print = print
|
||||
|
||||
local ars = require("ars")
|
||||
local snes = require("snes")
|
||||
local xnes = require("xnes")
|
||||
|
||||
-- try it all out on a dummy problem.
|
||||
|
||||
--- Return `x` multiplied by itself.
local function square(x)
    local squared = x * x
    return squared
end
|
||||
|
||||
--- Negated sum of squared distances between each element and its index.
-- The best possible score is 0, reached when x[i] == i for every i
-- (i.e. arange(dims) + 1). xNES should be able to find it almost exactly.
-- @param x sequence of numbers
-- @return a number <= 0; negated so that optimizers can maximize it.
local function spherical(x)
    local penalty = 0
    for i, v in ipairs(x) do
        local offset = v - i
        penalty = penalty + offset * offset
    end
    return -penalty
end
|
||||
|
||||
-- hyperparameters copied from hardmaru's simple_es_example.ipynb.
local iterations = 3000 --4000
local dims = 100
local popsize = dims + 1
local sigma_init = 0.5

-- exactly one optimizer is active; the others are kept for quick switching.
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
es.min_refresh = 0.7 -- FIXME: needs a better interface.
|
||||
|
||||
-- `asked` and `noise` are handed back to es:ask() by optimizers that can
-- reuse their previous buffers, so they live outside the loop.
local asked = nil -- for caching purposes.
local noise = nil -- for caching purposes.
-- score of the optimizer's starting parameters.
local current_cost = spherical(es:params())
|
||||
|
||||
-- ask the optimizer for candidates, score them, report the scores back,
-- and print progress every 100 iterations.
for i = 1, iterations do
    -- the class of `es` decides which ask() signature applies.
    local class = getmetatable(es).__index
    local is_snes = class == snes.Snes

    if is_snes then
        asked, noise = es:ask_mix()
    elseif class == ars.Ars then
        asked, noise = es:ask()
    else
        asked, noise = es:ask(asked, noise)
    end

    local scores = {}
    for j, candidate in ipairs(asked) do
        scores[j] = spherical(candidate)
    end

    -- every optimizer takes the same arguments here.
    es:tell(scores)--, current_cost) -- use lips

    current_cost = spherical(es:params())

    if i % 100 == 0 then
        local sigma = es.sigma
        if is_snes then
            -- SNES keeps one deviation per dimension; report their mean.
            local total = 0
            for _, dev in ipairs(es.std) do total = total + dev end
            sigma = total / #es.std
        end
        -- log10 of how far sigma has moved from its initial value.
        local inconvergence = sigma / sigma_init
        local fmt = "fitness at iteration %i: %.4f (%.4f)"
        print(fmt:format(i, current_cost, log(inconvergence) / log(10)))
    end
end
|
||||
|
||||
-- note: this metric doesn't include the "fitness at iteration" evaluations,
-- because those aren't actually used to step towards the optimum.
print(("optimized in %i function evaluations"):format(es.evals))

-- print the final parameters as one comma-separated line.
-- table.concat places separators only between items, so the output does not
-- depend on any size field of `es`, and avoids repeated string concatenation.
local formatted = {}
for i, v in ipairs(es:params()) do
    formatted[i] = ("%.8f"):format(v)
end
print(table.concat(formatted, ', '))
|
62
extra.lua
Normal file
62
extra.lua
Normal file
|
@ -0,0 +1,62 @@
|
|||
--- Left-pad the string form of `num` with `count` copies of `pad`, then drop
-- the first (#num - 1) characters of the result.
-- With a single-character `pad` and a short enough `num`, the result is
-- always count + 1 characters long.
-- @param num value to pad (converted with tostring)
-- @param count number of `pad` repetitions to prepend
-- @param pad string to repeat
local function strpad(num, count, pad)
    local text = tostring(num)
    local padded = pad:rep(count) .. text
    return padded:sub(#text)
end
|
||||
|
||||
--- Zero-pad `num` via strpad.
-- strpad yields one character more than the count it is given, so
-- `count - 1` is passed to get a result `count` characters wide.
local function add_zeros(num, count)
    local repeats = count - 1
    return strpad(num, repeats, '0')
end
|
||||
|
||||
--- Comparator for table.sort that accepts keys of mixed types.
-- Numbers are compared by their 16-wide zero-padded text, everything else by
-- tostring(), so the comparison is always string < string.
local function mixed_sorter(a, b)
    local lhs, rhs
    if type(a) == 'number' then
        lhs = add_zeros(a, 16)
    else
        lhs = tostring(a)
    end
    if type(b) == 'number' then
        rhs = add_zeros(b, 16)
    else
        rhs = tostring(b)
    end
    return lhs < rhs
end
|
||||
|
||||
-- loosely based on http://lua-users.org/wiki/SortedIteration
-- (the original there did not use closures).

--- Collect every key of `t` into a new array sorted with mixed_sorter.
local function order_keys(t)
    local keys = {}
    for k in pairs(t) do
        keys[#keys + 1] = k
    end
    table.sort(keys, mixed_sorter)
    return keys
end
|
||||
|
||||
--- Iterate over `t` in the key order produced by order_keys.
-- @param t table to traverse
-- @param cache optional table; when given, the sorted key list is looked up
--   under cache[t] and stored back there, so later calls skip the sort.
-- @return an iterator function yielding key, value pairs
local function opairs(t, cache)
    local keys
    if cache then
        keys = cache[t] or order_keys(t)
        cache[t] = keys
    else
        keys = order_keys(t)
    end

    local cursor = 0
    return function()
        cursor = cursor + 1
        local key = keys[cursor]
        if key then
            return key, t[key]
        end
    end
end
|
||||
|
||||
--- Resolve a path of names such as "a.b.c", starting from the global table.
-- The path is split on anything that is not a letter, digit or underscore.
-- Every name except the last is followed with rawget, so metamethods are
-- never triggered.
-- @param path string of names, or nil
-- @return {parent = <table holding the last name>, key = <last name>},
--   or nothing when `path` is nil, contains no names, or an intermediate
--   name does not refer to a table.
local function traverse(path)
    if not path then return end
    local parent = _G
    local key
    -- string.gfind is the Lua 5.0 spelling: it is gone in 5.2+ and only
    -- present in 5.1 builds with LUA_COMPAT_GFIND. gmatch is the same
    -- iterator under its current name.
    for w in path:gmatch("[%w_]+") do
        if key then
            -- descend into the previous name before remembering this one.
            parent = rawget(parent, key)
            if type(parent) ~= 'table' then return end
        end
        key = w
    end
    if not key then return end
    return {parent=parent, key=key}
end
|
||||
|
||||
return {
|
||||
strpad = strpad,
|
||||
add_zeros = add_zeros,
|
||||
mixed_sorter = mixed_sorter,
|
||||
order_keys = order_keys,
|
||||
opairs = opairs,
|
||||
traverse = traverse,
|
||||
}
|
54
monitor_tiles.lua
Normal file
54
monitor_tiles.lua
Normal file
|
@ -0,0 +1,54 @@
|
|||
-- keep track of which blocks are actually seen in the game.
|
||||
-- play back an all-levels TAS with this script running.
|
||||
|
||||
local floor = math.floor
|
||||
local open = io.open
|
||||
local pairs = pairs
|
||||
local print = print
|
||||
|
||||
local util = require("util")
|
||||
local R = memory.readbyteunsigned
|
||||
local W = memory.writebyte
|
||||
-- Read the byte at `addr` and pass it through util.signbyte.
local function S(addr)
    local raw = R(addr)
    return util.signbyte(raw)
end
|
||||
|
||||
local game = require("smb") -- just for advance()
|
||||
|
||||
local serial = require "serialize"
|
||||
local serialize = serial.serialize
|
||||
local deserialize = serial.deserialize
|
||||
|
||||
local fn = 'seen_tiles.lua'
|
||||
local seen = deserialize(fn) or {}
|
||||
|
||||
-- Record tile id `kind` the first time it is encountered: print it as two
-- hex digits and rewrite the `seen` table to disk. Already-known ids do
-- nothing. `sx` and `sy` are accepted but not used.
local function mark_tile(sx, sy, kind)
    if seen[kind] then return end
    seen[kind] = true
    print(("%02X"):format(kind))
    serialize(fn, seen)
end
|
||||
|
||||
-- Visit a 13-row by 17-column window of tiles and report each tile id to
-- mark_tile. The window's starting column comes from the bytes at 0x73F and
-- 0x71A (presumably horizontal scroll and page; confirm against a RAM map).
-- Columns wrap at 32: columns 0-15 are read from the table at 0x500 and
-- columns 16-31 from the table at 0x5D0, both laid out 16 bytes per row.
local function handle_tiles()
    --local tile_col = R(0x6A0)
    local scroll = R(0x73F)
    local tile_scroll = floor(scroll / 16) + R(0x71A) * 16
    local tile_scroll_remainder = scroll % 16

    for y = 0, 12 do
        local row_offset = y * 16
        local sy = row_offset + 40
        for x = 0, 16 do
            local col = (x + tile_scroll) % 32
            local base = 0x5D0
            if col < 16 then base = 0x500 end
            local t = R(base + row_offset + (col % 16))
            local sx = x * 16 + 8 - tile_scroll_remainder
            mark_tile(sx, sy, t)
        end
    end
end
|
||||
|
||||
-- scan the tiles, then advance the game, forever.
repeat
    handle_tiles()
    game.advance()
until false
|
22
presets.lua
22
presets.lua
|
@ -72,6 +72,28 @@ make_preset{
|
|||
sigma_decay = 0.008,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes2',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
||||
es = 'snes',
|
||||
deterministic = true,
|
||||
deviation = 0.01,
|
||||
negate_trials = false,
|
||||
epoch_trials = 60,
|
||||
min_refresh = 2/3,
|
||||
param_rate = 0.368,
|
||||
param_decay = 0.0138,
|
||||
sigma_rate = 0.100,
|
||||
sigma_decay = 0.0051,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'snes3',
|
||||
parent = 'snes2',
|
||||
min_refresh = 1/3,
|
||||
}
|
||||
|
||||
make_preset{
|
||||
name = 'xnes',
|
||||
parent = 'big-scroll-reduced',
|
||||
|
|
15
rescale.lua
Normal file
15
rescale.lua
Normal file
|
@ -0,0 +1,15 @@
|
|||
-- Read one number per line from params-ars4.txt and print each one
-- multiplied by 100. Lines that tonumber() cannot parse are skipped.
local f = assert(io.open("params-ars4.txt", "r"))
local data = f:read("*a")
f:close()

local scaled = {}
for line in data:gmatch("[^\r\n]+") do
    local value = tonumber(line)
    if value ~= nil then
        scaled[#scaled + 1] = value * 100
    end
end

for i = 1, #scaled do
    print(scaled[i])
end
|
183
running.lua
Normal file
183
running.lua
Normal file
|
@ -0,0 +1,183 @@
|
|||
local huge = math.huge
|
||||
local ipairs = ipairs
|
||||
local open = io.open
|
||||
local sqrt = math.sqrt
|
||||
|
||||
local nn = require("nn")
|
||||
local Base = require("Base")
|
||||
|
||||
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
|
||||
-- http://www.johndcook.com/blog/standard_deviation/
|
||||
local Stats = Base:extend()
|
||||
local Normalizer = Base:extend()
|
||||
|
||||
--- Start with no samples and zeroed accumulators.
-- @param shape passed straight to nn.zeros for both accumulators
function Stats:init(shape)
    self._n = 0               -- number of samples pushed so far
    self._M = nn.zeros(shape) -- running mean, per element
    self._S = nn.zeros(shape) -- running sum of squared deviations, per element
end
|
||||
|
||||
--- Fold one sample into the running mean and squared-deviation sums.
-- The first sample is copied into the mean as-is. Later samples apply the
-- incremental update described at
-- http://www.johndcook.com/blog/standard_deviation/
-- @param x sample with the same number of elements as the accumulators
function Stats:push(x)
    assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")

    local prev_n = self._n
    local n = prev_n + 1
    self._n = n

    if n == 1 then
        nn.copy(x, self._M)
        return
    end

    local M, S = self._M, self._S
    for i, mean in ipairs(M) do
        -- deviation from the mean *before* this sample is included.
        local d = x[i] - mean
        M[i] = mean + d / n
        S[i] = S[i] + d * d * prev_n / n
    end
end
|
||||
|
||||
--- Per-element variance estimate as a new table.
-- With exactly one sample the squared mean is returned instead (this mirrors
-- the ARS filter.py this file is based on); otherwise S / (n - 1).
function Stats:var()
    local out = {}
    local n = self._n

    if n == 1 then
        for i, mean in ipairs(self._M) do
            out[i] = mean * mean
        end
        return out
    end

    for i, s in ipairs(self._S) do
        out[i] = s / (n - 1)
    end
    return out
end
|
||||
|
||||
--- Per-element standard deviation: the square root of each value of var().
-- Returns the freshly allocated table from var(), modified in place.
function Stats:dev()
    local result = self:var()
    for i, variance in ipairs(result) do
        result[i] = sqrt(variance)
    end
    return result
end
|
||||
|
||||
--- Set up a normalizer backed by a running Stats accumulator.
-- @param shape passed to Stats and nn.zeros
-- @param demean whether process() subtracts the mean (nil means true)
-- @param destd whether process() divides by the deviation (nil means true)
function Normalizer:init(shape, demean, destd)
    -- only nil selects the default, so an explicit false is respected.
    if demean == nil then demean = true end
    if destd == nil then destd = true end

    self.shape = shape
    self.demean = demean
    self.destd = destd
    self.rs = Stats(shape)

    -- until update() runs, normalization is the identity: mean 0, std 1.
    self.mean = nn.zeros(shape)
    local std = nn.zeros(shape)
    for i = 1, #std do
        std[i] = 1
    end
    self.std = std
end
|
||||
|
||||
--- Return a normalized copy of `x` using the stored mean and std.
-- Mean subtraction and std division are each applied only when the matching
-- flag is set. 1e-8 is added to std before dividing.
function Normalizer:process(x)
    local out = nn.copy(x)
    local demean, destd = self.demean, self.destd
    for i, v in ipairs(out) do
        if demean then
            v = v - self.mean[i]
        end
        if destd then
            v = v / (self.std[i] + 1e-8)
        end
        out[i] = v
    end
    return out
end
|
||||
|
||||
--- Refresh self.mean and self.std from the running statistics.
function Normalizer:update()
    local rs = self.rs
    nn.copy(rs._M, self.mean) -- FIXME: HACK
    nn.copy(rs:dev(), self.std)

    -- A deviation below 1e-7 is replaced with +inf to avoid dividing by
    -- zero. process() divides by std, so elements with zero variance come
    -- out as zero.
    local std = self.std
    for i, dev in ipairs(std) do
        if dev < 1e-7 then
            std[i] = huge
        end
    end
end
|
||||
|
||||
--- Add `x` to the running statistics and return its normalized copy.
-- @param x sample to record and normalize
-- @param update pass false to leave mean/std untouched; nil or any truthy
--   value refreshes them before normalizing.
function Normalizer:push(x, update)
    self.rs:push(x)
    if update ~= false then
        self:update()
    end
    return self:process(x)
end
|
||||
|
||||
--- File name used by save() and load() when none is given.
-- It embeds the total element count of the shape, zero-padded to 7 digits.
function Normalizer:default_filename()
    local count = nn.prod(self.shape)
    local name = ('stats%07i.txt'):format(count)
    return name
end
|
||||
|
||||
--- Write the running statistics to a text file, one number per line:
-- the sample count, then every element of _M, then every element of _S.
-- @param fn optional path; defaults to self:default_filename()
-- Raises an error when the file cannot be opened for writing.
function Normalizer:save(fn)
    fn = fn or self:default_filename()
    local f = open(fn, 'w')
    if f == nil then error("Failed to save stats to file "..fn) end

    -- file:write() formats numbers with "%.14g" on Lua 5.1, which drops
    -- digits. "%.17g" is enough for a double to survive tonumber() on load.
    local function put(v)
        f:write(('%.17g'):format(v), '\n')
    end

    f:write(self.rs._n, '\n')
    for _, v in ipairs(self.rs._M) do put(v) end
    for _, v in ipairs(self.rs._S) do put(v) end
    f:close()
end
|
||||
|
||||
--- Read statistics written by save() and refresh mean/std from them.
-- Expected layout, one number per line: the sample count, then
-- nn.prod(self.shape) values for _M, then the same number for _S.
-- @param fn optional path; defaults to self:default_filename()
-- Raises an error when the file cannot be opened, a line is not a number,
-- or the number of lines does not match the shape. On error the existing
-- statistics are left untouched.
function Normalizer:load(fn)
    fn = fn or self:default_filename()
    local f = open(fn, 'r')
    if f == nil then error("Failed to load stats from file "..fn) end

    -- parse everything before raising, so the handle is always closed and
    -- self.rs is never left half-updated.
    local values = {}
    local bad_line = nil
    for line in f:lines() do
        local n = tonumber(line)
        if n == nil then
            bad_line = #values + 1
            break
        end
        values[#values + 1] = n
    end
    f:close()

    if bad_line then
        error("Failed reading line "..tostring(bad_line).." of file "..fn)
    end

    local size = nn.prod(self.shape)
    local expected = 1 + 2 * size
    if #values ~= expected then
        error("Expected "..tostring(expected).." lines in file "..fn..
              " but found "..tostring(#values))
    end

    local rs = self.rs
    rs._n = values[1]
    for i = 1, size do
        rs._M[i] = values[1 + i]
        rs._S[i] = values[1 + size + i]
    end

    self:update()
end
|
||||
|
||||
--[[
|
||||
|
||||
-- basic tests
|
||||
|
||||
local dims = 20
|
||||
local rs = Stats(dims)
|
||||
local x = nn.zeros(dims)
|
||||
|
||||
for i = 1, #x do x[i] = nn.normal() end
|
||||
rs:push(x)
|
||||
print(nn.pp(rs:dev()))
|
||||
|
||||
for j = 1, 10000 do
|
||||
for i = 1, #x do x[i] = nn.normal() end
|
||||
rs:push(x)
|
||||
end
|
||||
print(nn.pp(rs:dev()))
|
||||
|
||||
--
|
||||
|
||||
local ms = Normalizer(dims)
|
||||
local exp = math.exp
|
||||
local y
|
||||
|
||||
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||
y = ms:push(x)
|
||||
print(nn.pp(y))
|
||||
|
||||
for j = 1, 10000 do
|
||||
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||
y = ms:push(x)
|
||||
end
|
||||
print(nn.pp(y))
|
||||
|
||||
print("mean:")
|
||||
print(nn.pp(ms.mean))
|
||||
print("stdev:")
|
||||
print(nn.pp(ms.std))
|
||||
|
||||
--]]
|
||||
|
||||
return {
|
||||
Stats = Stats,
|
||||
Normalizer = Normalizer,
|
||||
}
|
59
seen_tiles.lua
Normal file
59
seen_tiles.lua
Normal file
|
@ -0,0 +1,59 @@
|
|||
return {
|
||||
[0] = true,
|
||||
[16] = true,
|
||||
[17] = true,
|
||||
[18] = true,
|
||||
[19] = true,
|
||||
[20] = true,
|
||||
[21] = true,
|
||||
[22] = true,
|
||||
[23] = true,
|
||||
[24] = true,
|
||||
[25] = true,
|
||||
[26] = true,
|
||||
[27] = true,
|
||||
[28] = true,
|
||||
[29] = true,
|
||||
[30] = true,
|
||||
[31] = true,
|
||||
[32] = true,
|
||||
[33] = true,
|
||||
[34] = true,
|
||||
[35] = true,
|
||||
[36] = true,
|
||||
[37] = true,
|
||||
[38] = true,
|
||||
[81] = true,
|
||||
[82] = true,
|
||||
[84] = true,
|
||||
[85] = true,
|
||||
[86] = true,
|
||||
[87] = true,
|
||||
[88] = true,
|
||||
[89] = true,
|
||||
[90] = true,
|
||||
[91] = true,
|
||||
[92] = true,
|
||||
[93] = true,
|
||||
[94] = true,
|
||||
[95] = true,
|
||||
[96] = true,
|
||||
[97] = true,
|
||||
[98] = true,
|
||||
[99] = true,
|
||||
[100] = true,
|
||||
[101] = true,
|
||||
[102] = true,
|
||||
[103] = true,
|
||||
[104] = true,
|
||||
[105] = true,
|
||||
[107] = true,
|
||||
[108] = true,
|
||||
[137] = true,
|
||||
[192] = true,
|
||||
[193] = true,
|
||||
[194] = true,
|
||||
[195] = true,
|
||||
[196] = true,
|
||||
[197] = true,
|
||||
}
|
76
serialize.lua
Normal file
76
serialize.lua
Normal file
|
@ -0,0 +1,76 @@
|
|||
-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
|
||||
|
||||
local type = type
|
||||
local extra = require "extra"
|
||||
local opairs = extra.opairs
|
||||
local tostring = tostring
|
||||
local open = io.open
|
||||
local strfmt = string.format
|
||||
local strrep = string.rep
|
||||
|
||||
--- Strip a leading UTF-8 byte order mark (EF BB BF) from `s`, if present.
local function kill_bom(s)
    local has_bom = s:sub(1, 3) == "\239\187\191"
    if has_bom then
        return s:sub(4)
    end
    return s
end
|
||||
|
||||
--- Turn a scalar into Lua source text.
-- @param v value to render
-- @return the text (strings are quoted with %q, everything else goes
--   through tostring), and a boolean that is true only for strings whose
--   first character is a digit.
local function sanitize(v)
    if type(v) ~= 'string' then
        return tostring(v), false
    end
    local starts_with_digit = v:find('^%d') ~= nil
    return strfmt('%q', v), starts_with_digit
end
|
||||
|
||||
-- Words that cannot be used as bare field names in a table constructor.
-- `goto` is reserved from Lua 5.2 on; quoting it is harmless on 5.1.
local reserved_words = {}
for word in ([[and break do else elseif end false for function goto if in
local nil not or repeat return then true until while]]):gmatch('%a+') do
    reserved_words[word] = true
end

--- Write `value` as Lua source through `writer`.
-- Tables are written recursively as constructors, one tab-indented field
-- per line, in opairs order. Anything else is written via sanitize().
-- A key is written bare only when it is a string that is a valid Lua name
-- and not a reserved word; every other key is written in [brackets], so the
-- output can always be loaded back.
-- @param value value to write
-- @param writer function called with each piece of output text
-- @param level nesting depth used for indentation (defaults to 1)
local function _serialize(value, writer, level)
    level = level or 1

    if type(value) ~= 'table' then
        -- extra parentheses drop sanitize()'s second return value.
        writer((sanitize(value)))
        return
    end

    local indent = strrep('\t', level)
    writer('{\n')
    for key, item in opairs(value) do
        local keyval
        if type(key) == 'string'
            and key:find('^[%a_][%w_]*$')
            and not reserved_words[key]
        then
            keyval = key
        else
            keyval = '['..sanitize(key)..']'
        end
        writer(indent..keyval..' = ')
        _serialize(item, writer, level + 1)
        writer(',\n')
    end
    writer(strrep('\t', level - 1)..'}')
end
|
||||
|
||||
--- Compile `script` as a Lua chunk, run it, and return what it returns.
-- A leading UTF-8 BOM is removed first. When the text does not compile, a
-- warning with the compiler's message is printed and nil is returned.
-- NOTE(review): this executes the file's contents as code, so it must only
-- be used on files this program wrote itself.
local function _deserialize(script)
    -- loadstring exists on Lua 5.1/LuaJIT only; from 5.2 on, load() accepts
    -- a string chunk directly.
    local compile = loadstring or load
    local f, err = compile(kill_bom(script))
    if f ~= nil then
        return f()
    end
    print('WARNING: no function to deserialize with: '..tostring(err))
    return nil
end
|
||||
|
||||
--- Write `value` to `path` as a Lua chunk of the form "return <value>".
-- @param path file to overwrite
-- @param value value to write (see _serialize)
-- @return true on success, or nil plus io.open's message when the file
--   cannot be opened. Errors raised while serializing are re-raised after
--   the file has been closed.
local function serialize(path, value)
    local file, err = open(path, 'w')
    if not file then return nil, err end

    file:write("return ")
    -- pcall guarantees the handle is closed even if _serialize raises.
    local ok, failure = pcall(_serialize, value, function(...)
        file:write(...)
    end)
    if ok then file:write("\n") end
    file:close()

    if not ok then error(failure, 0) end
    return true
end
|
||||
|
||||
--- Load the value stored in `path` by serialize().
-- @param path file to read
-- @return the value the file's chunk returns, or nothing when the file
--   cannot be opened.
local function deserialize(path)
    local file = open(path, 'r')
    if not file then return end

    local script = file:read('*a')
    -- close before running the chunk, so the handle is released even if
    -- the loaded code raises.
    file:close()

    local value = _deserialize(script)
    return value
end
|
||||
|
||||
return {
|
||||
serialize = serialize,
|
||||
deserialize = deserialize,
|
||||
}
|
1
snes.lua
1
snes.lua
|
@ -3,7 +3,6 @@
|
|||
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
||||
-- not to be confused with the Super Nintendo Entertainment System.
|
||||
|
||||
local abs = math.abs
|
||||
local assert = assert
|
||||
local exp = math.exp
|
||||
local floor = math.floor
|
||||
|
|
Loading…
Add table
Reference in a new issue