move dot_mv to nn
This commit is contained in:
parent
0100934ac4
commit
771650613c
2 changed files with 27 additions and 26 deletions
26
nn.lua
26
nn.lua
|
@ -187,6 +187,31 @@ local function cache(bs, shape)
|
|||
return zeros(fullshape)
|
||||
end
|
||||
|
||||
local function dot_mv(mat, vec, out)
|
||||
-- treats matrix as a matrix.
|
||||
-- treats vec as a column vector, flattened.
|
||||
assert(#mat.shape == 2)
|
||||
local d0, d1 = unpack(mat.shape)
|
||||
assert(d1 == #vec)
|
||||
|
||||
local out_shape = {d0}
|
||||
if out == nil then
|
||||
out = zeros(out_shape)
|
||||
else
|
||||
assert(d0 == #out, "given output is the wrong size")
|
||||
end
|
||||
|
||||
for i=1, d0 do
|
||||
local sum = 0
|
||||
for j=1, d1 do
|
||||
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
|
||||
end
|
||||
out[i] = sum
|
||||
end
|
||||
|
||||
return out
|
||||
end
|
||||
|
||||
local function dot(a, b, ax_a, ax_b, out)
|
||||
ax_a = ax_a or #a.shape - 0
|
||||
ax_b = ax_b or #b.shape - 1
|
||||
|
@ -806,6 +831,7 @@ return {
|
|||
reshape = reshape,
|
||||
pp = pp,
|
||||
ppi = ppi,
|
||||
dot_mv = dot_mv,
|
||||
dot = dot,
|
||||
traverse = traverse,
|
||||
traverse_all = traverse_all,
|
||||
|
|
27
xnes.lua
27
xnes.lua
|
@ -16,6 +16,7 @@ local unpack = table.unpack or unpack
|
|||
local Base = require "Base"
|
||||
|
||||
local nn = require "nn"
|
||||
local dot_mv = nn.dot_mv
|
||||
local normal = nn.normal
|
||||
local zeros = nn.zeros
|
||||
|
||||
|
@ -24,31 +25,6 @@ local argsort = util.argsort
|
|||
|
||||
local Xnes = Base:extend()
|
||||
|
||||
local function dot_mv(mat, vec, out)
|
||||
-- treats matrix as a matrix.
|
||||
-- treats vec as a column vector, flattened.
|
||||
assert(#mat.shape == 2)
|
||||
local d0, d1 = unpack(mat.shape)
|
||||
assert(d1 == #vec)
|
||||
|
||||
local out_shape = {d0}
|
||||
if out == nil then
|
||||
out = zeros(out_shape)
|
||||
else
|
||||
assert(d0 == #out, "given output is the wrong size")
|
||||
end
|
||||
|
||||
for i=1, d0 do
|
||||
local sum = 0
|
||||
for j=1, d1 do
|
||||
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
|
||||
end
|
||||
out[i] = sum
|
||||
end
|
||||
|
||||
return out
|
||||
end
|
||||
|
||||
local function make_utility(popsize, out)
|
||||
local utility = out or {}
|
||||
local temp = log(popsize / 2 + 1)
|
||||
|
@ -230,7 +206,6 @@ function Xnes:tell(scored, noise)
|
|||
end
|
||||
|
||||
return {
|
||||
dot_mv = dot_mv,
|
||||
make_utility = make_utility,
|
||||
make_covars = make_covars,
|
||||
|
||||
|
|
Loading…
Reference in a new issue