temp
This commit is contained in:
parent
dc8969469d
commit
b7938a1785
10 changed files with 582 additions and 2 deletions
110
es_test.lua
Normal file
110
es_test.lua
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
local floor = math.floor
|
||||||
|
local ipairs = ipairs
|
||||||
|
local log = math.log
|
||||||
|
local print = print
|
||||||
|
|
||||||
|
local ars = require("ars")
|
||||||
|
local snes = require("snes")
|
||||||
|
local xnes = require("xnes")
|
||||||
|
|
||||||
|
-- try it all out on a dummy problem.
|
||||||
|
|
||||||
|
local function square(x) return x * x end
|
||||||
|
|
||||||
|
-- this function's global minimum is arange(dims) + 1.
|
||||||
|
-- xNES should be able to find it almost exactly.
|
||||||
|
local function spherical(x)
|
||||||
|
local sum = 0
|
||||||
|
for i, v in ipairs(x) do sum = sum + square(v - i) end
|
||||||
|
-- we need to negate this to turn it into a maximization problem.
|
||||||
|
return -sum
|
||||||
|
end
|
||||||
|
|
||||||
|
-- i'm just copying settings from hardmaru's simple_es_example.ipynb.
|
||||||
|
local iterations = 3000 --4000
|
||||||
|
local dims = 100
|
||||||
|
local popsize = dims + 1
|
||||||
|
local sigma_init = 0.5
|
||||||
|
--local es = xnes.Xnes(dims, popsize, 0.1, sigma_init)
|
||||||
|
--local es = snes.Snes(dims, popsize, 0.1, sigma_init)
|
||||||
|
--local es = ars.Ars(dims, floor(popsize / 2), floor(popsize / 2), 1.0, sigma_init, true)
|
||||||
|
local es = ars.Ars(dims, popsize, popsize, 1.0, sigma_init, true)
|
||||||
|
es.min_refresh = 0.7 -- FIXME: needs a better interface.
|
||||||
|
|
||||||
|
if false then -- TODO: delete me
|
||||||
|
local nn = require("nn")
|
||||||
|
local util = require("util")
|
||||||
|
local insert = table.insert
|
||||||
|
local scored = nn.arange(10)
|
||||||
|
local indices = ars.collect_best_indices(scored, 3, true)
|
||||||
|
for i, ind in ipairs(indices) do
|
||||||
|
print(ind, ":", scored[ind * 2 - 1], scored[ind * 2 - 0])
|
||||||
|
end
|
||||||
|
local top_rewards = {}
|
||||||
|
for _, ind in ipairs(indices) do
|
||||||
|
insert(top_rewards, scored[ind * 2 - 1])
|
||||||
|
insert(top_rewards, scored[ind * 2 - 0])
|
||||||
|
end
|
||||||
|
-- this shouldn't make a difference to the final print:
|
||||||
|
top_rewards = util.normalize_sums(top_rewards)
|
||||||
|
print(nn.pp(top_rewards))
|
||||||
|
local _, reward_dev = util.calc_mean_dev(top_rewards)
|
||||||
|
print(reward_dev)
|
||||||
|
for i, ind in ipairs(indices) do
|
||||||
|
local pos = top_rewards[i * 2 - 1]
|
||||||
|
local neg = top_rewards[i * 2 - 0]
|
||||||
|
local reward = pos - neg
|
||||||
|
reward = reward / reward_dev
|
||||||
|
print(reward)
|
||||||
|
end
|
||||||
|
do return end
|
||||||
|
end
|
||||||
|
|
||||||
|
local asked = nil -- for caching purposes.
|
||||||
|
local noise = nil -- for caching purposes.
|
||||||
|
local current_cost = spherical(es:params())
|
||||||
|
|
||||||
|
for i=1, iterations do
|
||||||
|
if getmetatable(es).__index == snes.Snes then
|
||||||
|
asked, noise = es:ask_mix()
|
||||||
|
elseif getmetatable(es).__index == ars.Ars then
|
||||||
|
asked, noise = es:ask()
|
||||||
|
else
|
||||||
|
asked, noise = es:ask(asked, noise)
|
||||||
|
end
|
||||||
|
local scores = {}
|
||||||
|
for i, v in ipairs(asked) do
|
||||||
|
scores[i] = spherical(v)
|
||||||
|
end
|
||||||
|
|
||||||
|
if getmetatable(es).__index == ars.Ars then
|
||||||
|
es:tell(scores)--, current_cost) -- use lips
|
||||||
|
else
|
||||||
|
es:tell(scores)
|
||||||
|
end
|
||||||
|
|
||||||
|
current_cost = spherical(es:params())
|
||||||
|
--if i % 100 == 0 then
|
||||||
|
if i % 100 == 0 then
|
||||||
|
local sigma = es.sigma
|
||||||
|
if getmetatable(es).__index == snes.Snes then
|
||||||
|
sigma = 0
|
||||||
|
for i, v in ipairs(es.std) do sigma = sigma + v end
|
||||||
|
sigma = sigma / #es.std
|
||||||
|
end
|
||||||
|
local inconvergence = sigma / sigma_init
|
||||||
|
local fmt = "fitness at iteration %i: %.4f (%.4f)"
|
||||||
|
print(fmt:format(i, current_cost, log(inconvergence) / log(10)))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- note: this metric doesn't include the "fitness at iteration" evaluations,
|
||||||
|
-- because those aren't actually used to step towards the optimum.
|
||||||
|
print(("optimized in %i function evaluations"):format(es.evals))
|
||||||
|
|
||||||
|
local s = ''
|
||||||
|
for i, v in ipairs(es:params()) do
|
||||||
|
s = s..("%.8f"):format(v)
|
||||||
|
if i ~= es.dim then s = s..', ' end
|
||||||
|
end
|
||||||
|
print(s)
|
62
extra.lua
Normal file
62
extra.lua
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
local function strpad(num, count, pad)
|
||||||
|
num = tostring(num)
|
||||||
|
return (pad:rep(count)..num):sub(#num)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function add_zeros(num, count)
|
||||||
|
return strpad(num, count - 1, '0')
|
||||||
|
end
|
||||||
|
|
||||||
|
local function mixed_sorter(a, b)
|
||||||
|
a = type(a) == 'number' and add_zeros(a, 16) or tostring(a)
|
||||||
|
b = type(b) == 'number' and add_zeros(b, 16) or tostring(b)
|
||||||
|
return a < b
|
||||||
|
end
|
||||||
|
|
||||||
|
-- loosely based on http://lua-users.org/wiki/SortedIteration
|
||||||
|
-- the original didn't make use of closures for who knows why
|
||||||
|
local function order_keys(t)
|
||||||
|
local oi = {}
|
||||||
|
for key in pairs(t) do
|
||||||
|
table.insert(oi, key)
|
||||||
|
end
|
||||||
|
table.sort(oi, mixed_sorter)
|
||||||
|
return oi
|
||||||
|
end
|
||||||
|
|
||||||
|
local function opairs(t, cache)
|
||||||
|
local oi = cache and cache[t] or order_keys(t)
|
||||||
|
if cache then
|
||||||
|
cache[t] = oi
|
||||||
|
end
|
||||||
|
local i = 0
|
||||||
|
return function()
|
||||||
|
i = i + 1
|
||||||
|
local key = oi[i]
|
||||||
|
if key then return key, t[key] end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function traverse(path)
|
||||||
|
if not path then return end
|
||||||
|
local parent = _G
|
||||||
|
local key
|
||||||
|
for w in path:gfind("[%w_]+") do
|
||||||
|
if key then
|
||||||
|
parent = rawget(parent, key)
|
||||||
|
if type(parent) ~= 'table' then return end
|
||||||
|
end
|
||||||
|
key = w
|
||||||
|
end
|
||||||
|
if not key then return end
|
||||||
|
return {parent=parent, key=key}
|
||||||
|
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
strpad = strpad,
|
||||||
|
add_zeros = add_zeros,
|
||||||
|
mixed_sorter = mixed_sorter,
|
||||||
|
order_keys = order_keys,
|
||||||
|
opairs = opairs,
|
||||||
|
traverse = traverse,
|
||||||
|
}
|
54
monitor_tiles.lua
Normal file
54
monitor_tiles.lua
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
-- keep track of which blocks are actually seen in the game.
|
||||||
|
-- play back an all-levels TAS with this script running.
|
||||||
|
|
||||||
|
local floor = math.floor
|
||||||
|
local open = io.open
|
||||||
|
local pairs = pairs
|
||||||
|
local print = print
|
||||||
|
|
||||||
|
local util = require("util")
|
||||||
|
local R = memory.readbyteunsigned
|
||||||
|
local W = memory.writebyte
|
||||||
|
local function S(addr) return util.signbyte(R(addr)) end
|
||||||
|
|
||||||
|
local game = require("smb") -- just for advance()
|
||||||
|
|
||||||
|
local serial = require "serialize"
|
||||||
|
local serialize = serial.serialize
|
||||||
|
local deserialize = serial.deserialize
|
||||||
|
|
||||||
|
local fn = 'seen_tiles.lua'
|
||||||
|
local seen = deserialize(fn) or {}
|
||||||
|
|
||||||
|
local function mark_tile(sx, sy, kind)
|
||||||
|
if not seen[kind] then
|
||||||
|
seen[kind] = true
|
||||||
|
print(("%02X"):format(kind))
|
||||||
|
serialize(fn, seen)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function handle_tiles()
|
||||||
|
--local tile_col = R(0x6A0)
|
||||||
|
local tile_scroll = floor(R(0x73F) / 16) + R(0x71A) * 16
|
||||||
|
local tile_scroll_remainder = R(0x73F) % 16
|
||||||
|
for y = 0, 12 do
|
||||||
|
for x = 0, 16 do
|
||||||
|
local col = (x + tile_scroll) % 32
|
||||||
|
local t
|
||||||
|
if col < 16 then
|
||||||
|
t = R(0x500 + y * 16 + (col % 16))
|
||||||
|
else
|
||||||
|
t = R(0x5D0 + y * 16 + (col % 16))
|
||||||
|
end
|
||||||
|
local sx = x * 16 + 8 - tile_scroll_remainder
|
||||||
|
local sy = y * 16 + 40
|
||||||
|
mark_tile(sx, sy, t)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
while true do
|
||||||
|
handle_tiles()
|
||||||
|
game.advance()
|
||||||
|
end
|
22
presets.lua
22
presets.lua
|
@ -72,6 +72,28 @@ make_preset{
|
||||||
sigma_decay = 0.008,
|
sigma_decay = 0.008,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'snes2',
|
||||||
|
parent = 'big-scroll-reduced',
|
||||||
|
|
||||||
|
es = 'snes',
|
||||||
|
deterministic = true,
|
||||||
|
deviation = 0.01,
|
||||||
|
negate_trials = false,
|
||||||
|
epoch_trials = 60,
|
||||||
|
min_refresh = 2/3,
|
||||||
|
param_rate = 0.368,
|
||||||
|
param_decay = 0.0138,
|
||||||
|
sigma_rate = 0.100,
|
||||||
|
sigma_decay = 0.0051,
|
||||||
|
}
|
||||||
|
|
||||||
|
make_preset{
|
||||||
|
name = 'snes3',
|
||||||
|
parent = 'snes2',
|
||||||
|
min_refresh = 1/3,
|
||||||
|
}
|
||||||
|
|
||||||
make_preset{
|
make_preset{
|
||||||
name = 'xnes',
|
name = 'xnes',
|
||||||
parent = 'big-scroll-reduced',
|
parent = 'big-scroll-reduced',
|
||||||
|
|
15
rescale.lua
Normal file
15
rescale.lua
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
local f = assert(io.open("params-ars4.txt", "r"))
|
||||||
|
local data = f:read("*a")
|
||||||
|
f:close()
|
||||||
|
local values = {}
|
||||||
|
for v in data:gmatch("[^\r\n]+") do
|
||||||
|
table.insert(values, tonumber(v))
|
||||||
|
end
|
||||||
|
|
||||||
|
for i, v in ipairs(values) do
|
||||||
|
values[i] = v * 100
|
||||||
|
end
|
||||||
|
|
||||||
|
for i, v in ipairs(values) do
|
||||||
|
print(v)
|
||||||
|
end
|
183
running.lua
Normal file
183
running.lua
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
local huge = math.huge
|
||||||
|
local ipairs = ipairs
|
||||||
|
local open = io.open
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
|
||||||
|
local nn = require("nn")
|
||||||
|
local Base = require("Base")
|
||||||
|
|
||||||
|
-- https://github.com/modestyachts/ARS/blob/master/code/filter.py
|
||||||
|
-- http://www.johndcook.com/blog/standard_deviation/
|
||||||
|
local Stats = Base:extend()
|
||||||
|
local Normalizer = Base:extend()
|
||||||
|
|
||||||
|
function Stats:init(shape)
|
||||||
|
self._n = 0
|
||||||
|
self._M = nn.zeros(shape)
|
||||||
|
self._S = nn.zeros(shape)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Stats:push(x)
|
||||||
|
assert(nn.prod(x.shape) == nn.prod(self._M.shape), "sizes mismatch")
|
||||||
|
local n1 = self._n
|
||||||
|
self._n = self._n + 1
|
||||||
|
if self._n == 1 then
|
||||||
|
nn.copy(x, self._M)
|
||||||
|
else
|
||||||
|
local delta = {}
|
||||||
|
for i, v in ipairs(self._M) do delta[i] = x[i] - v end
|
||||||
|
for i, v in ipairs(self._M) do self._M[i] = v + delta[i] / self._n end
|
||||||
|
for i, v in ipairs(self._S) do self._S[i] = v + delta[i] * delta[i] * n1 / self._n end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function Stats:var()
|
||||||
|
local out = {}
|
||||||
|
if self._n == 1 then
|
||||||
|
for i, v in ipairs(self._M) do out[i] = v * v end
|
||||||
|
else
|
||||||
|
for i, v in ipairs(self._S) do out[i] = v / (self._n - 1) end
|
||||||
|
end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
|
function Stats:dev()
|
||||||
|
local out = self:var()
|
||||||
|
for i, v in ipairs(out) do out[i] = sqrt(v) end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:init(shape, demean, destd)
|
||||||
|
if demean == nil then demean = true end
|
||||||
|
if destd == nil then destd = true end
|
||||||
|
self.shape = shape
|
||||||
|
self.demean = demean
|
||||||
|
self.destd = destd
|
||||||
|
self.rs = Stats(shape)
|
||||||
|
self.mean = nn.zeros(shape)
|
||||||
|
self.std = nn.zeros(shape)
|
||||||
|
for i = 1, #self.std do self.std[i] = 1 end
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:process(x)
|
||||||
|
local out = nn.copy(x)
|
||||||
|
if self.demean then
|
||||||
|
for i, v in ipairs(out) do out[i] = out[i] - self.mean[i] end
|
||||||
|
end
|
||||||
|
if self.destd then
|
||||||
|
for i, v in ipairs(out) do out[i] = out[i] / (self.std[i] + 1e-8) end
|
||||||
|
end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:update()
|
||||||
|
nn.copy(self.rs._M, self.mean) -- FIXME: HACK
|
||||||
|
nn.copy(self.rs:dev(), self.std)
|
||||||
|
-- Set values for std less than 1e-7 to +inf
|
||||||
|
-- to avoid dividing by zero. State elements
|
||||||
|
-- with zero variance are set to zero as a result.
|
||||||
|
for i, v in ipairs(self.std) do
|
||||||
|
if v < 1e-7 then self.std[i] = huge end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:push(x, update)
|
||||||
|
self.rs:push(x)
|
||||||
|
if update == nil or update then self:update() end
|
||||||
|
return self:process(x)
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:default_filename()
|
||||||
|
return ('stats%07i.txt'):format(nn.prod(self.shape))
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:save(fn)
|
||||||
|
local fn = fn or self:default_filename()
|
||||||
|
local f = open(fn, 'w')
|
||||||
|
if f == nil then error("Failed to save stats to file "..fn) end
|
||||||
|
f:write(self.rs._n)
|
||||||
|
f:write('\n')
|
||||||
|
for i, v in ipairs(self.rs._M) do
|
||||||
|
f:write(v)
|
||||||
|
f:write('\n')
|
||||||
|
end
|
||||||
|
for i, v in ipairs(self.rs._S) do
|
||||||
|
f:write(v)
|
||||||
|
f:write('\n')
|
||||||
|
end
|
||||||
|
f:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
function Normalizer:load(fn)
|
||||||
|
local fn = fn or self:default_filename()
|
||||||
|
local f = open(fn, 'r')
|
||||||
|
if f == nil then error("Failed to load stats from file "..fn) end
|
||||||
|
|
||||||
|
local i = 0
|
||||||
|
local split_M = 1
|
||||||
|
local split_S = split_M + nn.prod(self.shape)
|
||||||
|
for line in f:lines() do
|
||||||
|
i = i + 1
|
||||||
|
local n = tonumber(line)
|
||||||
|
if n == nil then
|
||||||
|
error("Failed reading line "..tostring(i).." of file "..fn)
|
||||||
|
end
|
||||||
|
|
||||||
|
if i <= split_M then
|
||||||
|
self.rs._n = n
|
||||||
|
elseif i <= split_S then
|
||||||
|
self.rs._M[i - split_M] = n
|
||||||
|
else
|
||||||
|
self.rs._S[i - split_S] = n
|
||||||
|
end
|
||||||
|
end
|
||||||
|
f:close()
|
||||||
|
|
||||||
|
self:update()
|
||||||
|
end
|
||||||
|
|
||||||
|
--[[
|
||||||
|
|
||||||
|
-- basic tests
|
||||||
|
|
||||||
|
local dims = 20
|
||||||
|
local rs = Stats(dims)
|
||||||
|
local x = nn.zeros(dims)
|
||||||
|
|
||||||
|
for i = 1, #x do x[i] = nn.normal() end
|
||||||
|
rs:push(x)
|
||||||
|
print(nn.pp(rs:dev()))
|
||||||
|
|
||||||
|
for j = 1, 10000 do
|
||||||
|
for i = 1, #x do x[i] = nn.normal() end
|
||||||
|
rs:push(x)
|
||||||
|
end
|
||||||
|
print(nn.pp(rs:dev()))
|
||||||
|
|
||||||
|
--
|
||||||
|
|
||||||
|
local ms = Normalizer(dims)
|
||||||
|
local exp = math.exp
|
||||||
|
local y
|
||||||
|
|
||||||
|
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||||
|
y = ms:push(x)
|
||||||
|
print(nn.pp(y))
|
||||||
|
|
||||||
|
for j = 1, 10000 do
|
||||||
|
for i = 1, #x do x[i] = exp(nn.normal()) end
|
||||||
|
y = ms:push(x)
|
||||||
|
end
|
||||||
|
print(nn.pp(y))
|
||||||
|
|
||||||
|
print("mean:")
|
||||||
|
print(nn.pp(ms.mean))
|
||||||
|
print("stdev:")
|
||||||
|
print(nn.pp(ms.std))
|
||||||
|
|
||||||
|
--]]
|
||||||
|
|
||||||
|
return {
|
||||||
|
Stats = Stats,
|
||||||
|
Normalizer = Normalizer,
|
||||||
|
}
|
59
seen_tiles.lua
Normal file
59
seen_tiles.lua
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
return {
|
||||||
|
[0] = true,
|
||||||
|
[16] = true,
|
||||||
|
[17] = true,
|
||||||
|
[18] = true,
|
||||||
|
[19] = true,
|
||||||
|
[20] = true,
|
||||||
|
[21] = true,
|
||||||
|
[22] = true,
|
||||||
|
[23] = true,
|
||||||
|
[24] = true,
|
||||||
|
[25] = true,
|
||||||
|
[26] = true,
|
||||||
|
[27] = true,
|
||||||
|
[28] = true,
|
||||||
|
[29] = true,
|
||||||
|
[30] = true,
|
||||||
|
[31] = true,
|
||||||
|
[32] = true,
|
||||||
|
[33] = true,
|
||||||
|
[34] = true,
|
||||||
|
[35] = true,
|
||||||
|
[36] = true,
|
||||||
|
[37] = true,
|
||||||
|
[38] = true,
|
||||||
|
[81] = true,
|
||||||
|
[82] = true,
|
||||||
|
[84] = true,
|
||||||
|
[85] = true,
|
||||||
|
[86] = true,
|
||||||
|
[87] = true,
|
||||||
|
[88] = true,
|
||||||
|
[89] = true,
|
||||||
|
[90] = true,
|
||||||
|
[91] = true,
|
||||||
|
[92] = true,
|
||||||
|
[93] = true,
|
||||||
|
[94] = true,
|
||||||
|
[95] = true,
|
||||||
|
[96] = true,
|
||||||
|
[97] = true,
|
||||||
|
[98] = true,
|
||||||
|
[99] = true,
|
||||||
|
[100] = true,
|
||||||
|
[101] = true,
|
||||||
|
[102] = true,
|
||||||
|
[103] = true,
|
||||||
|
[104] = true,
|
||||||
|
[105] = true,
|
||||||
|
[107] = true,
|
||||||
|
[108] = true,
|
||||||
|
[137] = true,
|
||||||
|
[192] = true,
|
||||||
|
[193] = true,
|
||||||
|
[194] = true,
|
||||||
|
[195] = true,
|
||||||
|
[196] = true,
|
||||||
|
[197] = true,
|
||||||
|
}
|
76
serialize.lua
Normal file
76
serialize.lua
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
-- it's simple, dumb, unsafe, incomplete, and it gets the damn job done
|
||||||
|
|
||||||
|
local type = type
|
||||||
|
local extra = require "extra"
|
||||||
|
local opairs = extra.opairs
|
||||||
|
local tostring = tostring
|
||||||
|
local open = io.open
|
||||||
|
local strfmt = string.format
|
||||||
|
local strrep = string.rep
|
||||||
|
|
||||||
|
local function kill_bom(s)
|
||||||
|
if #s >= 3 and s:byte(1)==0xEF and s:byte(2)==0xBB and s:byte(3)==0xBF then
|
||||||
|
return s:sub(4)
|
||||||
|
end
|
||||||
|
return s
|
||||||
|
end
|
||||||
|
|
||||||
|
local function sanitize(v)
|
||||||
|
local force = type(v) == 'string' and v:sub(1, 1):match('%d')
|
||||||
|
force = force and true or false
|
||||||
|
return type(v) == 'string' and strfmt('%q', v) or tostring(v), force
|
||||||
|
end
|
||||||
|
|
||||||
|
local function _serialize(value, writer, level)
|
||||||
|
level = level or 1
|
||||||
|
if type(value) == 'table' then
|
||||||
|
local indent = strrep('\t', level)
|
||||||
|
writer('{\n')
|
||||||
|
for key,value in opairs(value) do
|
||||||
|
local sane, force = sanitize(key)
|
||||||
|
local keyval = (sane == '"'..key..'"' and not force) and key or '['..sane..']'
|
||||||
|
writer(indent..keyval..' = ')
|
||||||
|
_serialize(value, writer, level + 1)
|
||||||
|
writer(',\n')
|
||||||
|
end
|
||||||
|
writer(strrep('\t', level - 1)..'}')
|
||||||
|
else
|
||||||
|
local sane, force = sanitize(value)
|
||||||
|
writer(sane)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function _deserialize(script)
|
||||||
|
local f = loadstring(kill_bom(script))
|
||||||
|
if f ~= nil then
|
||||||
|
return f()
|
||||||
|
else
|
||||||
|
print('WARNING: no function to deserialize with')
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function serialize(path, value)
|
||||||
|
local file = open(path, 'w')
|
||||||
|
if not file then return end
|
||||||
|
file:write("return ")
|
||||||
|
_serialize(value, function(...)
|
||||||
|
file:write(...)
|
||||||
|
end)
|
||||||
|
file:write("\n")
|
||||||
|
file:close()
|
||||||
|
end
|
||||||
|
|
||||||
|
local function deserialize(path)
|
||||||
|
local file = open(path, 'r')
|
||||||
|
if not file then return end
|
||||||
|
local script = file:read('*a')
|
||||||
|
local value = _deserialize(script)
|
||||||
|
file:close()
|
||||||
|
return value
|
||||||
|
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
serialize = serialize,
|
||||||
|
deserialize = deserialize,
|
||||||
|
}
|
1
snes.lua
1
snes.lua
|
@ -3,7 +3,6 @@
|
||||||
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
-- http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
|
||||||
-- not to be confused with the Super Nintendo Entertainment System.
|
-- not to be confused with the Super Nintendo Entertainment System.
|
||||||
|
|
||||||
local abs = math.abs
|
|
||||||
local assert = assert
|
local assert = assert
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
|
|
Loading…
Add table
Reference in a new issue