improve node traversal and export more stuff
This commit is contained in:
parent
ddc8cae4c6
commit
16b38a7151
2 changed files with 67 additions and 63 deletions
11
main.lua
11
main.lua
|
@ -145,13 +145,6 @@ local ror = bit.ror
|
||||||
|
|
||||||
-- utilities.
|
-- utilities.
|
||||||
|
|
||||||
local function ifind(haystack, needle)
|
|
||||||
for i, v in ipairs(haystack) do
|
|
||||||
if v == needle then return i end
|
|
||||||
end
|
|
||||||
return nil
|
|
||||||
end
|
|
||||||
|
|
||||||
local function boolean_xor(a, b)
|
local function boolean_xor(a, b)
|
||||||
if a and b then return false end
|
if a and b then return false end
|
||||||
if not a and not b then return false end
|
if not a and not b then return false end
|
||||||
|
@ -569,13 +562,13 @@ local function fitness_shaping(rewards)
|
||||||
local denom = 0
|
local denom = 0
|
||||||
for i, v in ipairs(rewards) do
|
for i, v in ipairs(rewards) do
|
||||||
local l = log2(lamb / 2 + 1)
|
local l = log2(lamb / 2 + 1)
|
||||||
local r = log2(ifind(decreasing, v))
|
local r = log2(nn.indexof(decreasing, v))
|
||||||
denom = denom + max(0, l - r)
|
denom = denom + max(0, l - r)
|
||||||
end
|
end
|
||||||
|
|
||||||
for i, v in ipairs(rewards) do
|
for i, v in ipairs(rewards) do
|
||||||
local l = log2(lamb / 2 + 1)
|
local l = log2(lamb / 2 + 1)
|
||||||
local r = log2(ifind(decreasing, v))
|
local r = log2(nn.indexof(decreasing, v))
|
||||||
local numer = max(0, l - r)
|
local numer = max(0, l - r)
|
||||||
insert(shaped_returns, numer / denom + 1 / lamb)
|
insert(shaped_returns, numer / denom + 1 / lamb)
|
||||||
end
|
end
|
||||||
|
|
119
nn.lua
119
nn.lua
|
@ -1,28 +1,36 @@
|
||||||
local print = print
|
|
||||||
local tostring = tostring
|
|
||||||
local ipairs = ipairs
|
|
||||||
local pairs = pairs
|
|
||||||
local uniform = math.random
|
|
||||||
local sqrt = math.sqrt
|
|
||||||
local log = math.log
|
|
||||||
local pi = math.pi
|
|
||||||
local exp = math.exp
|
|
||||||
local min = math.min
|
|
||||||
local max = math.max
|
|
||||||
local cos = math.cos
|
local cos = math.cos
|
||||||
local sin = math.sin
|
local exp = math.exp
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local remove = table.remove
|
local ipairs = ipairs
|
||||||
|
local log = math.log
|
||||||
|
local max = math.max
|
||||||
|
local min = math.min
|
||||||
local open = io.open
|
local open = io.open
|
||||||
|
local pairs = pairs
|
||||||
local bor = bit.bor
|
local pi = math.pi
|
||||||
|
local print = print
|
||||||
|
local remove = table.remove
|
||||||
|
local sin = math.sin
|
||||||
|
local sqrt = math.sqrt
|
||||||
|
local tostring = tostring
|
||||||
|
local uniform = math.random
|
||||||
|
|
||||||
local Base = require("Base")
|
local Base = require("Base")
|
||||||
|
|
||||||
local function contains(t, a)
|
local function copy(t) -- shallow copy
|
||||||
|
local new_t = {}
|
||||||
|
for k, v in pairs(t) do new_t[k] = v end
|
||||||
|
return new_t
|
||||||
|
end
|
||||||
|
|
||||||
|
local function indexof(t, a)
|
||||||
assert(type(t) == "table")
|
assert(type(t) == "table")
|
||||||
for k, v in pairs(t) do if v == a then return true end end
|
for k, v in pairs(t) do if v == a then return k end end
|
||||||
return false
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
|
local function contains(t, a)
|
||||||
|
return indexof(t, a) ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
local function prod(x, ...)
|
local function prod(x, ...)
|
||||||
|
@ -63,12 +71,6 @@ local function init_he_normal(t, fan_in, fan_out)
|
||||||
return t
|
return t
|
||||||
end
|
end
|
||||||
|
|
||||||
local function copy(t) -- shallow copy
|
|
||||||
local new_t = {}
|
|
||||||
for k, v in pairs(t) do new_t[k] = v end
|
|
||||||
return new_t
|
|
||||||
end
|
|
||||||
|
|
||||||
local function allocate(t, out, init)
|
local function allocate(t, out, init)
|
||||||
out = out or {}
|
out = out or {}
|
||||||
local size = t
|
local size = t
|
||||||
|
@ -79,44 +81,50 @@ local function allocate(t, out, init)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local function levelorder(field, node_in, nodes)
|
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||||
-- horribly inefficient.
|
-- i have no idea if this is any algorithm in particular.
|
||||||
nodes = nodes or {}
|
nodes = nodes or {}
|
||||||
local q = {node_in}
|
|
||||||
while #q > 0 do
|
|
||||||
local node = q[1]
|
|
||||||
remove(q, 1)
|
|
||||||
insert(nodes, node)
|
|
||||||
for _, child in ipairs(node[field]) do
|
|
||||||
q[#q+1] = child
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return nodes
|
|
||||||
end
|
|
||||||
|
|
||||||
local function traverse(node_in, node_out, nodes)
|
local seen_up = {}
|
||||||
nodes = nodes or {}
|
local q = {node_out}
|
||||||
local down = levelorder('children', node_in, {})
|
while #q > 0 do
|
||||||
local up = levelorder('parents', node_out, {})
|
local node = remove(q, 1)
|
||||||
local seen = {}
|
seen_up[node] = true
|
||||||
for _, node in ipairs(up) do
|
for _, parent in ipairs(node.parents) do insert(q, parent) end
|
||||||
seen[node] = bor(seen[node] or 0, 1)
|
|
||||||
end
|
end
|
||||||
for _, node in ipairs(down) do
|
|
||||||
seen[node] = bor(seen[node] or 0, 2)
|
if dummy_mode then seen_up[node_in] = true end
|
||||||
if seen[node] == 3 then
|
|
||||||
insert(nodes, node)
|
nodes = {}
|
||||||
|
q = {node_in}
|
||||||
|
while #q > 0 do
|
||||||
|
local node = remove(q, 1)
|
||||||
|
if seen_up[node] then
|
||||||
|
local all_parents_added = true
|
||||||
|
for _, parent in ipairs(node.parents) do
|
||||||
|
if not contains(nodes, parent) then
|
||||||
|
all_parents_added = false
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if not contains(nodes, node) and all_parents_added then
|
||||||
|
insert(nodes, node)
|
||||||
|
end
|
||||||
|
for _, child in ipairs(node.children) do insert(q, child) end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if dummy_mode then remove(nodes, indexof(nodes, node_in)) end
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
end
|
end
|
||||||
|
|
||||||
local function traverse_all(nodes_in, nodes_out, nodes)
|
local function traverse_all(nodes_in, nodes_out, nodes)
|
||||||
local all_in = {children={}}
|
local all_in = {children={}, parents={}}
|
||||||
local all_out = {parents={}}
|
local all_out = {children={}, parents={}}
|
||||||
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
for _, node in ipairs(nodes_in) do insert(all_in.children, node) end
|
||||||
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end
|
||||||
return traverse(all_in, all_out, nodes or {})
|
return traverse(all_in, all_out, nodes or {}, true)
|
||||||
end
|
end
|
||||||
|
|
||||||
local Weights = Base:extend()
|
local Weights = Base:extend()
|
||||||
|
@ -443,14 +451,17 @@ function Model:load(fn)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
uniform = uniform,
|
|
||||||
normal = normal,
|
|
||||||
|
|
||||||
copy = copy,
|
copy = copy,
|
||||||
|
indexof = indexof,
|
||||||
|
contains = contains,
|
||||||
|
prod = prod,
|
||||||
|
normal = normal,
|
||||||
zeros = zeros,
|
zeros = zeros,
|
||||||
init_zeros = init_zeros,
|
init_zeros = init_zeros,
|
||||||
init_he_uniform = init_he_uniform,
|
init_he_uniform = init_he_uniform,
|
||||||
init_he_normal = init_he_normal,
|
init_he_normal = init_he_normal,
|
||||||
|
allocate = allocate,
|
||||||
|
traverse = traverse,
|
||||||
|
|
||||||
Weights = Weights,
|
Weights = Weights,
|
||||||
Layer = Layer,
|
Layer = Layer,
|
||||||
|
|
Loading…
Add table
Reference in a new issue