diff --git a/main.lua b/main.lua index 2db16be..db32810 100644 --- a/main.lua +++ b/main.lua @@ -145,13 +145,6 @@ local ror = bit.ror -- utilities. -local function ifind(haystack, needle) - for i, v in ipairs(haystack) do - if v == needle then return i end - end - return nil -end - local function boolean_xor(a, b) if a and b then return false end if not a and not b then return false end @@ -569,13 +562,13 @@ local function fitness_shaping(rewards) local denom = 0 for i, v in ipairs(rewards) do local l = log2(lamb / 2 + 1) - local r = log2(ifind(decreasing, v)) + local r = log2(nn.indexof(decreasing, v)) denom = denom + max(0, l - r) end for i, v in ipairs(rewards) do local l = log2(lamb / 2 + 1) - local r = log2(ifind(decreasing, v)) + local r = log2(nn.indexof(decreasing, v)) local numer = max(0, l - r) insert(shaped_returns, numer / denom + 1 / lamb) end diff --git a/nn.lua b/nn.lua index 996ef98..76d18b1 100644 --- a/nn.lua +++ b/nn.lua @@ -1,28 +1,36 @@ -local print = print -local tostring = tostring -local ipairs = ipairs -local pairs = pairs -local uniform = math.random -local sqrt = math.sqrt -local log = math.log -local pi = math.pi -local exp = math.exp -local min = math.min -local max = math.max local cos = math.cos -local sin = math.sin +local exp = math.exp local insert = table.insert -local remove = table.remove +local ipairs = ipairs +local log = math.log +local max = math.max +local min = math.min local open = io.open - -local bor = bit.bor +local pairs = pairs +local pi = math.pi +local print = print +local remove = table.remove +local sin = math.sin +local sqrt = math.sqrt +local tostring = tostring +local uniform = math.random local Base = require("Base") -local function contains(t, a) +local function copy(t) -- shallow copy + local new_t = {} + for k, v in pairs(t) do new_t[k] = v end + return new_t +end + +local function indexof(t, a) assert(type(t) == "table") - for k, v in pairs(t) do if v == a then return true end end - return false + for k, v in pairs(t) do if v == a then return k end end + return nil +end + +local function contains(t, a) + return indexof(t, a) ~= nil end local function prod(x, ...) @@ -63,12 +71,6 @@ local function init_he_normal(t, fan_in, fan_out) return t end -local function copy(t) -- shallow copy - local new_t = {} - for k, v in pairs(t) do new_t[k] = v end - return new_t -end - local function allocate(t, out, init) out = out or {} local size = t @@ -79,44 +81,50 @@ local function allocate(t, out, init) end end -local function levelorder(field, node_in, nodes) - -- horribly inefficient. +local function traverse(node_in, node_out, nodes, dummy_mode) + -- i have no idea if this is any algorithm in particular. nodes = nodes or {} - local q = {node_in} - while #q > 0 do - local node = q[1] - remove(q, 1) - insert(nodes, node) - for _, child in ipairs(node[field]) do - q[#q+1] = child - end - end - return nodes -end -local function traverse(node_in, node_out, nodes) - nodes = nodes or {} - local down = levelorder('children', node_in, {}) - local up = levelorder('parents', node_out, {}) - local seen = {} - for _, node in ipairs(up) do - seen[node] = bor(seen[node] or 0, 1) + local seen_up = {} + local q = {node_out} + while #q > 0 do + local node = remove(q, 1) + seen_up[node] = true + for _, parent in ipairs(node.parents) do insert(q, parent) end end - for _, node in ipairs(down) do - seen[node] = bor(seen[node] or 0, 2) - if seen[node] == 3 then - insert(nodes, node) + + if dummy_mode then seen_up[node_in] = true end + + nodes = {} + q = {node_in} + while #q > 0 do + local node = remove(q, 1) + if seen_up[node] then + local all_parents_added = true + for _, parent in ipairs(node.parents) do + if not contains(nodes, parent) then + all_parents_added = false + break + end + end + if not contains(nodes, node) and all_parents_added then + insert(nodes, node) + end + for _, child in ipairs(node.children) do insert(q, child) end end end + + if dummy_mode then remove(nodes, indexof(nodes, node_in)) end + return nodes end local function traverse_all(nodes_in, nodes_out, nodes) - local all_in = {children={}} - local all_out = {parents={}} + local all_in = {children={}, parents={}} + local all_out = {children={}, parents={}} for _, node in ipairs(nodes_in) do insert(all_in.children, node) end for _, node in ipairs(nodes_out) do insert(all_out.parents, node) end - return traverse(all_in, all_out, nodes or {}) + return traverse(all_in, all_out, nodes or {}, true) end local Weights = Base:extend() @@ -443,14 +451,17 @@ function Model:load(fn) end return { - uniform = uniform, - normal = normal, - copy = copy, + indexof = indexof, + contains = contains, + prod = prod, + normal = normal, zeros = zeros, init_zeros = init_zeros, init_he_uniform = init_he_uniform, init_he_normal = init_he_normal, + allocate = allocate, + traverse = traverse, Weights = Weights, Layer = Layer,