improve node traversal and export more stuff

This commit is contained in:
Connor Olding 2017-09-07 19:04:36 +00:00
parent ddc8cae4c6
commit 16b38a7151
2 changed files with 67 additions and 63 deletions

View file

@ -145,13 +145,6 @@ local ror = bit.ror
-- utilities.
local function ifind(haystack, needle)
for i, v in ipairs(haystack) do
if v == needle then return i end
end
return nil
end
local function boolean_xor(a, b)
if a and b then return false end
if not a and not b then return false end
@ -569,13 +562,13 @@ local function fitness_shaping(rewards)
local denom = 0
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(ifind(decreasing, v))
local r = log2(nn.indexof(decreasing, v))
denom = denom + max(0, l - r)
end
for i, v in ipairs(rewards) do
local l = log2(lamb / 2 + 1)
local r = log2(ifind(decreasing, v))
local r = log2(nn.indexof(decreasing, v))
local numer = max(0, l - r)
insert(shaped_returns, numer / denom + 1 / lamb)
end

119
nn.lua
View file

@ -1,28 +1,36 @@
local print = print
local tostring = tostring
local ipairs = ipairs
local pairs = pairs
local uniform = math.random
local sqrt = math.sqrt
local log = math.log
local pi = math.pi
local exp = math.exp
local min = math.min
local max = math.max
local cos = math.cos
local sin = math.sin
local exp = math.exp
local insert = table.insert
local remove = table.remove
local ipairs = ipairs
local log = math.log
local max = math.max
local min = math.min
local open = io.open
local bor = bit.bor
local pairs = pairs
local pi = math.pi
local print = print
local remove = table.remove
local sin = math.sin
local sqrt = math.sqrt
local tostring = tostring
local uniform = math.random
local Base = require("Base")
local function contains(t, a)
local function copy(t) -- shallow copy
local new_t = {}
for k, v in pairs(t) do new_t[k] = v end
return new_t
end
local function indexof(t, a)
assert(type(t) == "table")
for k, v in pairs(t) do if v == a then return true end end
return false
for k, v in pairs(t) do if v == a then return k end end
return nil
end
local function contains(t, a)
return indexof(t, a) ~= nil
end
local function prod(x, ...)
@ -63,12 +71,6 @@ local function init_he_normal(t, fan_in, fan_out)
return t
end
local function copy(t) -- shallow copy
local new_t = {}
for k, v in pairs(t) do new_t[k] = v end
return new_t
end
local function allocate(t, out, init)
out = out or {}
local size = t
@ -79,44 +81,50 @@ local function allocate(t, out, init)
end
end
local function levelorder(field, node_in, nodes)
-- horribly inefficient.
local function traverse(node_in, node_out, nodes, dummy_mode)
-- i have no idea if this is any algorithm in particular.
nodes = nodes or {}
local q = {node_in}
while #q > 0 do
local node = q[1]
remove(q, 1)
insert(nodes, node)
for _, child in ipairs(node[field]) do
q[#q+1] = child
end
end
return nodes
end
local function traverse(node_in, node_out, nodes)
nodes = nodes or {}
local down = levelorder('children', node_in, {})
local up = levelorder('parents', node_out, {})
local seen = {}
for _, node in ipairs(up) do
seen[node] = bor(seen[node] or 0, 1)
local seen_up = {}
local q = {node_out}
while #q > 0 do
local node = remove(q, 1)
seen_up[node] = true
for _, parent in ipairs(node.parents) do insert(q, parent) end
end
for _, node in ipairs(down) do
seen[node] = bor(seen[node] or 0, 2)
if seen[node] == 3 then
insert(nodes, node)
if dummy_mode then seen_up[node_in] = true end
nodes = {}
q = {node_in}
while #q > 0 do
local node = remove(q, 1)
if seen_up[node] then
local all_parents_added = true
for _, parent in ipairs(node.parents) do
if not contains(nodes, parent) then
all_parents_added = false
break
end
end
if not contains(nodes, node) and all_parents_added then
insert(nodes, node)
end
for _, child in ipairs(node.children) do insert(q, child) end
end
end
if dummy_mode then remove(nodes, indexof(nodes, node_in)) end
return nodes
end
local function traverse_all(nodes_in, nodes_out, nodes)
local all_in = {children={}}
local all_out = {parents={}}
local all_in = {children={}, parents={}}
local all_out = {children={}, parents={}}
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
return traverse(all_in, all_out, nodes or {})
return traverse(all_in, all_out, nodes or {}, true)
end
local Weights = Base:extend()
@ -443,14 +451,17 @@ function Model:load(fn)
end
return {
uniform = uniform,
normal = normal,
copy = copy,
indexof = indexof,
contains = contains,
prod = prod,
normal = normal,
zeros = zeros,
init_zeros = init_zeros,
init_he_uniform = init_he_uniform,
init_he_normal = init_he_normal,
allocate = allocate,
traverse = traverse,
Weights = Weights,
Layer = Layer,