move dot_mv to nn

This commit is contained in:
Connor Olding 2018-06-10 16:34:20 +02:00
parent 0100934ac4
commit 771650613c
2 changed files with 27 additions and 26 deletions

26
nn.lua
View file

@ -187,6 +187,31 @@ local function cache(bs, shape)
return zeros(fullshape)
end
local function dot_mv(mat, vec, out)
-- treats matrix as a matrix.
-- treats vec as a column vector, flattened.
assert(#mat.shape == 2)
local d0, d1 = unpack(mat.shape)
assert(d1 == #vec)
local out_shape = {d0}
if out == nil then
out = zeros(out_shape)
else
assert(d0 == #out, "given output is the wrong size")
end
for i=1, d0 do
local sum = 0
for j=1, d1 do
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
end
out[i] = sum
end
return out
end
local function dot(a, b, ax_a, ax_b, out)
ax_a = ax_a or #a.shape - 0
ax_b = ax_b or #b.shape - 1
@ -806,6 +831,7 @@ return {
reshape = reshape,
pp = pp,
ppi = ppi,
dot_mv = dot_mv,
dot = dot,
traverse = traverse,
traverse_all = traverse_all,

View file

@ -16,6 +16,7 @@ local unpack = table.unpack or unpack
local Base = require "Base"
local nn = require "nn"
local dot_mv = nn.dot_mv
local normal = nn.normal
local zeros = nn.zeros
@ -24,31 +25,6 @@ local argsort = util.argsort
local Xnes = Base:extend()
local function dot_mv(mat, vec, out)
-- treats matrix as a matrix.
-- treats vec as a column vector, flattened.
assert(#mat.shape == 2)
local d0, d1 = unpack(mat.shape)
assert(d1 == #vec)
local out_shape = {d0}
if out == nil then
out = zeros(out_shape)
else
assert(d0 == #out, "given output is the wrong size")
end
for i=1, d0 do
local sum = 0
for j=1, d1 do
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
end
out[i] = sum
end
return out
end
local function make_utility(popsize, out)
local utility = out or {}
local temp = log(popsize / 2 + 1)
@ -230,7 +206,6 @@ function Xnes:tell(scored, noise)
end
return {
dot_mv = dot_mv,
make_utility = make_utility,
make_covars = make_covars,