move dot_mv to nn
This commit is contained in:
parent
0100934ac4
commit
771650613c
2 changed files with 27 additions and 26 deletions
26
nn.lua
26
nn.lua
|
@ -187,6 +187,31 @@ local function cache(bs, shape)
|
||||||
return zeros(fullshape)
|
return zeros(fullshape)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function dot_mv(mat, vec, out)
|
||||||
|
-- treats matrix as a matrix.
|
||||||
|
-- treats vec as a column vector, flattened.
|
||||||
|
assert(#mat.shape == 2)
|
||||||
|
local d0, d1 = unpack(mat.shape)
|
||||||
|
assert(d1 == #vec)
|
||||||
|
|
||||||
|
local out_shape = {d0}
|
||||||
|
if out == nil then
|
||||||
|
out = zeros(out_shape)
|
||||||
|
else
|
||||||
|
assert(d0 == #out, "given output is the wrong size")
|
||||||
|
end
|
||||||
|
|
||||||
|
for i=1, d0 do
|
||||||
|
local sum = 0
|
||||||
|
for j=1, d1 do
|
||||||
|
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
|
||||||
|
end
|
||||||
|
out[i] = sum
|
||||||
|
end
|
||||||
|
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
local function dot(a, b, ax_a, ax_b, out)
|
local function dot(a, b, ax_a, ax_b, out)
|
||||||
ax_a = ax_a or #a.shape - 0
|
ax_a = ax_a or #a.shape - 0
|
||||||
ax_b = ax_b or #b.shape - 1
|
ax_b = ax_b or #b.shape - 1
|
||||||
|
@ -806,6 +831,7 @@ return {
|
||||||
reshape = reshape,
|
reshape = reshape,
|
||||||
pp = pp,
|
pp = pp,
|
||||||
ppi = ppi,
|
ppi = ppi,
|
||||||
|
dot_mv = dot_mv,
|
||||||
dot = dot,
|
dot = dot,
|
||||||
traverse = traverse,
|
traverse = traverse,
|
||||||
traverse_all = traverse_all,
|
traverse_all = traverse_all,
|
||||||
|
|
27
xnes.lua
27
xnes.lua
|
@ -16,6 +16,7 @@ local unpack = table.unpack or unpack
|
||||||
local Base = require "Base"
|
local Base = require "Base"
|
||||||
|
|
||||||
local nn = require "nn"
|
local nn = require "nn"
|
||||||
|
local dot_mv = nn.dot_mv
|
||||||
local normal = nn.normal
|
local normal = nn.normal
|
||||||
local zeros = nn.zeros
|
local zeros = nn.zeros
|
||||||
|
|
||||||
|
@ -24,31 +25,6 @@ local argsort = util.argsort
|
||||||
|
|
||||||
local Xnes = Base:extend()
|
local Xnes = Base:extend()
|
||||||
|
|
||||||
local function dot_mv(mat, vec, out)
|
|
||||||
-- treats matrix as a matrix.
|
|
||||||
-- treats vec as a column vector, flattened.
|
|
||||||
assert(#mat.shape == 2)
|
|
||||||
local d0, d1 = unpack(mat.shape)
|
|
||||||
assert(d1 == #vec)
|
|
||||||
|
|
||||||
local out_shape = {d0}
|
|
||||||
if out == nil then
|
|
||||||
out = zeros(out_shape)
|
|
||||||
else
|
|
||||||
assert(d0 == #out, "given output is the wrong size")
|
|
||||||
end
|
|
||||||
|
|
||||||
for i=1, d0 do
|
|
||||||
local sum = 0
|
|
||||||
for j=1, d1 do
|
|
||||||
sum = sum + mat[(i - 1) * d1 + j] * vec[j]
|
|
||||||
end
|
|
||||||
out[i] = sum
|
|
||||||
end
|
|
||||||
|
|
||||||
return out
|
|
||||||
end
|
|
||||||
|
|
||||||
local function make_utility(popsize, out)
|
local function make_utility(popsize, out)
|
||||||
local utility = out or {}
|
local utility = out or {}
|
||||||
local temp = log(popsize / 2 + 1)
|
local temp = log(popsize / 2 + 1)
|
||||||
|
@ -230,7 +206,6 @@ function Xnes:tell(scored, noise)
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
dot_mv = dot_mv,
|
|
||||||
make_utility = make_utility,
|
make_utility = make_utility,
|
||||||
make_covars = make_covars,
|
make_covars = make_covars,
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue