diff --git a/nn.lua b/nn.lua index 82dea0b..3faef88 100644 --- a/nn.lua +++ b/nn.lua @@ -187,6 +187,31 @@ local function cache(bs, shape) return zeros(fullshape) end +local function dot_mv(mat, vec, out) + -- treats matrix as a matrix. + -- treats vec as a column vector, flattened. + assert(#mat.shape == 2) + local d0, d1 = unpack(mat.shape) + assert(d1 == #vec) + + local out_shape = {d0} + if out == nil then + out = zeros(out_shape) + else + assert(d0 == #out, "given output is the wrong size") + end + + for i=1, d0 do + local sum = 0 + for j=1, d1 do + sum = sum + mat[(i - 1) * d1 + j] * vec[j] + end + out[i] = sum + end + + return out +end + local function dot(a, b, ax_a, ax_b, out) ax_a = ax_a or #a.shape - 0 ax_b = ax_b or #b.shape - 1 @@ -806,6 +831,7 @@ return { reshape = reshape, pp = pp, ppi = ppi, + dot_mv = dot_mv, dot = dot, traverse = traverse, traverse_all = traverse_all, diff --git a/xnes.lua b/xnes.lua index cc2c942..cc33599 100644 --- a/xnes.lua +++ b/xnes.lua @@ -16,6 +16,7 @@ local unpack = table.unpack or unpack local Base = require "Base" local nn = require "nn" +local dot_mv = nn.dot_mv local normal = nn.normal local zeros = nn.zeros @@ -24,31 +25,6 @@ local argsort = util.argsort local Xnes = Base:extend() -local function dot_mv(mat, vec, out) - -- treats matrix as a matrix. - -- treats vec as a column vector, flattened. - assert(#mat.shape == 2) - local d0, d1 = unpack(mat.shape) - assert(d1 == #vec) - - local out_shape = {d0} - if out == nil then - out = zeros(out_shape) - else - assert(d0 == #out, "given output is the wrong size") - end - - for i=1, d0 do - local sum = 0 - for j=1, d1 do - sum = sum + mat[(i - 1) * d1 + j] * vec[j] - end - out[i] = sum - end - - return out -end - local function make_utility(popsize, out) local utility = out or {} local temp = log(popsize / 2 + 1) @@ -230,7 +206,6 @@ function Xnes:tell(scored, noise) end return { - dot_mv = dot_mv, make_utility = make_utility, make_covars = make_covars,