111 lines
2.6 KiB
Lua
111 lines
2.6 KiB
Lua
local min = math.min
|
|
local sqrt = math.sqrt
|
|
|
|
local nn = require "nn"
|
|
local dot = nn.dot
|
|
local reshape = nn.reshape
|
|
local transpose = nn.transpose
|
|
local zeros = nn.zeros
|
|
|
|
--- Build the d-th "minor" matrix used by the Householder iteration.
-- The result is a fresh matrix shaped like `x` (row-major, 1-based flat
-- storage). Its leading d x d block is the identity, its trailing
-- (rows - d) x (cols - d) block is copied from `x`, and every other entry is
-- left at zero.
-- @param x a 2-D matrix
-- @param d number of leading rows/columns to replace by the identity
--   (must not exceed either dimension)
-- @return the new matrix
local function minor(x, d)
  assert(#x.shape == 2)
  assert(d <= x.shape[1] and d <= x.shape[2])

  local out = zeros(x.shape)
  local nrows, ncols = out.shape[1], out.shape[2]

  -- ones on the first d diagonal positions.
  for i = 1, d do
    out[(i - 1) * ncols + i] = 1
  end

  -- copy the trailing block unchanged.
  for row = d + 1, nrows do
    local base = (row - 1) * ncols
    for col = d + 1, ncols do
      out[base + col] = x[base + col]
    end
  end

  return out
end
|
|
|
|
--- Euclidean (L2) norm of a sequence of numbers.
-- Iterates with `ipairs`, so it stops at the first nil entry.
-- @param a sequence of numbers
-- @return sqrt of the sum of squares (0 for an empty sequence)
local function norm(a)
  local sum_sq = 0
  for _, component in ipairs(a) do
    sum_sq = sum_sq + component * component
  end
  return sqrt(sum_sq)
end
|
|
|
|
--- Householder QR factorisation of a 2-D matrix.
--
-- `x` is indexed as a row-major, 1-based flat array with
-- `x.shape = {rows, cols}`. For k = 1 .. min(rows - 1, cols), a unit vector
-- `vec` is built so that H_k = I - 2 * vec * vec.T zeroes the sub-diagonal
-- part of column k of the current working matrix. The reflections are
-- accumulated as q = H_n * ... * H_1.
--
-- Sign convention (kept from earlier versions): the k-th pivot of R comes out
-- as +|column| when the ORIGINAL x[k][k] is > 0, and as -|column| otherwise.
--
-- @param x a rows x cols matrix
-- @return Q = transpose(q) (rows x rows) and R = q * x (rows x cols)
local function householder(x)
  local rows = x.shape[1]
  local cols = x.shape[2]
  local iters = min(rows - 1, cols)

  local q = nil
  local vec = zeros(rows)
  local z = x

  for k = 1, iters do
    z = minor(z, k - 1)

    -- extract column k; rows 1 .. k-1 of it are zero in the minor.
    for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end

    -- alpha is added to vec[k]: -|col| when the original diagonal is > 0,
    -- +|col| otherwise (this fixes the sign of R's k-th pivot, see above).
    local alpha = norm(vec)
    if x[(k - 1) * cols + k] > 0 then alpha = -alpha end

    -- vec[k] + alpha cancels catastrophically when vec[k] and alpha have
    -- opposite signs (column already close to a multiple of e_k); the
    -- reflection would then fail to zero the sub-diagonal. Use the
    -- algebraically identical form
    --   vec[k] + alpha = -sigma / (vec[k] - alpha),
    -- where sigma is the squared norm of the column excluding entry k.
    local head = vec[k]
    if head * alpha < 0 then
      local sigma = 0
      for i = 1, rows do
        if i ~= k then sigma = sigma + vec[i] * vec[i] end
      end
      vec[k] = -sigma / (head - alpha)
    else
      vec[k] = head + alpha
    end

    -- normalise. A zero vector means this column needs no reflection
    -- (already reduced, or entirely zero); leaving it zero makes H = I.
    local len = norm(vec)
    if len == 0 then len = 1 end
    for i = 1, rows do vec[i] = vec[i] / len end

    -- construct the householder reflection: mat = I - 2 * vec * vec.T
    local mat = zeros{rows, rows}
    for i = 1, rows do
      for j = 1, rows do
        local diag = i == j and 1 or 0
        mat[(i - 1) * rows + j] = diag - 2 * vec[i] * vec[j]
      end
    end

    if q == nil then q = mat else q = dot(mat, q) end
    z = dot(mat, z)
  end

  if q == nil then
    -- no reflections were needed (e.g. a single row): q is the identity.
    -- Previously this fell through to transpose(nil) and crashed.
    q = zeros{rows, rows}
    for i = 1, rows do q[(i - 1) * rows + i] = 1 end
  end

  return transpose(q), dot(q, x) -- Q, R
end
|
|
|
|
--- Reduced QR decomposition of a 2-D matrix.
-- Wraps `householder`. When `x` has more rows than columns, Q is trimmed in
-- place to rows x cols and R to cols x cols. Otherwise the factors from
-- `householder` are returned unchanged.
-- @param x a rows x cols matrix
-- @return Q, R
local function qr(x)
  assert(#x.shape == 2)

  local q, r = householder(x)
  local rows, cols = x.shape[1], x.shape[2]

  if cols < rows then
    -- trim q in-place: keep the first `cols` entries of every row. The
    -- destination offset never exceeds the source offset, so a forward
    -- copy is safe.
    q.shape[2] = cols
    for i = 1, rows do
      local src = (i - 1) * rows
      local dst = (i - 1) * cols
      for j = 1, cols do
        q[dst + j] = q[src + j]
      end
    end
    for i = rows * cols + 1, #q do q[i] = nil end

    -- trim r in-place: its leading cols x cols block is already stored
    -- first, so only the tail has to be dropped.
    r.shape[1] = r.shape[2]
    for i = r.shape[1] * r.shape[2] + 1, #r do r[i] = nil end
  end

  return q, r
end
|
|
|
|
-- Module export: the (reduced) QR decomposition function.
return qr
|