improve node traversal and export more stuff
This commit is contained in:
parent
ddc8cae4c6
commit
16b38a7151
2 changed files with 67 additions and 63 deletions
11
main.lua
11
main.lua
|
@ -145,13 +145,6 @@ local ror = bit.ror
|
|||
|
||||
-- utilities.
|
||||
|
||||
local function ifind(haystack, needle)
|
||||
for i, v in ipairs(haystack) do
|
||||
if v == needle then return i end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function boolean_xor(a, b)
|
||||
if a and b then return false end
|
||||
if not a and not b then return false end
|
||||
|
@ -569,13 +562,13 @@ local function fitness_shaping(rewards)
|
|||
local denom = 0
|
||||
for i, v in ipairs(rewards) do
|
||||
local l = log2(lamb / 2 + 1)
|
||||
local r = log2(ifind(decreasing, v))
|
||||
local r = log2(nn.indexof(decreasing, v))
|
||||
denom = denom + max(0, l - r)
|
||||
end
|
||||
|
||||
for i, v in ipairs(rewards) do
|
||||
local l = log2(lamb / 2 + 1)
|
||||
local r = log2(ifind(decreasing, v))
|
||||
local r = log2(nn.indexof(decreasing, v))
|
||||
local numer = max(0, l - r)
|
||||
insert(shaped_returns, numer / denom + 1 / lamb)
|
||||
end
|
||||
|
|
119
nn.lua
119
nn.lua
|
@ -1,28 +1,36 @@
|
|||
local print = print
|
||||
local tostring = tostring
|
||||
local ipairs = ipairs
|
||||
local pairs = pairs
|
||||
local uniform = math.random
|
||||
local sqrt = math.sqrt
|
||||
local log = math.log
|
||||
local pi = math.pi
|
||||
local exp = math.exp
|
||||
local min = math.min
|
||||
local max = math.max
|
||||
local cos = math.cos
|
||||
local sin = math.sin
|
||||
local exp = math.exp
|
||||
local insert = table.insert
|
||||
local remove = table.remove
|
||||
local ipairs = ipairs
|
||||
local log = math.log
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local open = io.open
|
||||
|
||||
local bor = bit.bor
|
||||
local pairs = pairs
|
||||
local pi = math.pi
|
||||
local print = print
|
||||
local remove = table.remove
|
||||
local sin = math.sin
|
||||
local sqrt = math.sqrt
|
||||
local tostring = tostring
|
||||
local uniform = math.random
|
||||
|
||||
local Base = require("Base")
|
||||
|
||||
local function contains(t, a)
|
||||
local function copy(t) -- shallow copy
|
||||
local new_t = {}
|
||||
for k, v in pairs(t) do new_t[k] = v end
|
||||
return new_t
|
||||
end
|
||||
|
||||
local function indexof(t, a)
|
||||
assert(type(t) == "table")
|
||||
for k, v in pairs(t) do if v == a then return true end end
|
||||
return false
|
||||
for k, v in pairs(t) do if v == a then return k end end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function contains(t, a)
|
||||
return indexof(t, a) ~= nil
|
||||
end
|
||||
|
||||
local function prod(x, ...)
|
||||
|
@ -63,12 +71,6 @@ local function init_he_normal(t, fan_in, fan_out)
|
|||
return t
|
||||
end
|
||||
|
||||
local function copy(t) -- shallow copy
|
||||
local new_t = {}
|
||||
for k, v in pairs(t) do new_t[k] = v end
|
||||
return new_t
|
||||
end
|
||||
|
||||
local function allocate(t, out, init)
|
||||
out = out or {}
|
||||
local size = t
|
||||
|
@ -79,44 +81,50 @@ local function allocate(t, out, init)
|
|||
end
|
||||
end
|
||||
|
||||
local function levelorder(field, node_in, nodes)
|
||||
-- horribly inefficient.
|
||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||
-- i have no idea if this is any algorithm in particular.
|
||||
nodes = nodes or {}
|
||||
local q = {node_in}
|
||||
while #q > 0 do
|
||||
local node = q[1]
|
||||
remove(q, 1)
|
||||
insert(nodes, node)
|
||||
for _, child in ipairs(node[field]) do
|
||||
q[#q+1] = child
|
||||
end
|
||||
end
|
||||
return nodes
|
||||
end
|
||||
|
||||
local function traverse(node_in, node_out, nodes)
|
||||
nodes = nodes or {}
|
||||
local down = levelorder('children', node_in, {})
|
||||
local up = levelorder('parents', node_out, {})
|
||||
local seen = {}
|
||||
for _, node in ipairs(up) do
|
||||
seen[node] = bor(seen[node] or 0, 1)
|
||||
local seen_up = {}
|
||||
local q = {node_out}
|
||||
while #q > 0 do
|
||||
local node = remove(q, 1)
|
||||
seen_up[node] = true
|
||||
for _, parent in ipairs(node.parents) do insert(q, parent) end
|
||||
end
|
||||
for _, node in ipairs(down) do
|
||||
seen[node] = bor(seen[node] or 0, 2)
|
||||
if seen[node] == 3 then
|
||||
insert(nodes, node)
|
||||
|
||||
if dummy_mode then seen_up[node_in] = true end
|
||||
|
||||
nodes = {}
|
||||
q = {node_in}
|
||||
while #q > 0 do
|
||||
local node = remove(q, 1)
|
||||
if seen_up[node] then
|
||||
local all_parents_added = true
|
||||
for _, parent in ipairs(node.parents) do
|
||||
if not contains(nodes, parent) then
|
||||
all_parents_added = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if not contains(nodes, node) and all_parents_added then
|
||||
insert(nodes, node)
|
||||
end
|
||||
for _, child in ipairs(node.children) do insert(q, child) end
|
||||
end
|
||||
end
|
||||
|
||||
if dummy_mode then remove(nodes, indexof(nodes, node_in)) end
|
||||
|
||||
return nodes
|
||||
end
|
||||
|
||||
local function traverse_all(nodes_in, nodes_out, nodes)
|
||||
local all_in = {children={}}
|
||||
local all_out = {parents={}}
|
||||
local all_in = {children={}, parents={}}
|
||||
local all_out = {children={}, parents={}}
|
||||
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
||||
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
||||
return traverse(all_in, all_out, nodes or {})
|
||||
return traverse(all_in, all_out, nodes or {}, true)
|
||||
end
|
||||
|
||||
local Weights = Base:extend()
|
||||
|
@ -443,14 +451,17 @@ function Model:load(fn)
|
|||
end
|
||||
|
||||
return {
|
||||
uniform = uniform,
|
||||
normal = normal,
|
||||
|
||||
copy = copy,
|
||||
indexof = indexof,
|
||||
contains = contains,
|
||||
prod = prod,
|
||||
normal = normal,
|
||||
zeros = zeros,
|
||||
init_zeros = init_zeros,
|
||||
init_he_uniform = init_he_uniform,
|
||||
init_he_normal = init_he_normal,
|
||||
allocate = allocate,
|
||||
traverse = traverse,
|
||||
|
||||
Weights = Weights,
|
||||
Layer = Layer,
|
||||
|
|
Loading…
Reference in a new issue