smbot/qr.lua
2019-03-11 07:15:41 +01:00

114 lines
2.7 KiB
Lua

local nn = require "nn"
local assert = assert
local dot = nn.dot
local ipairs = ipairs
local min = math.min
local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose
local zeros = nn.zeros
local function minor(x, d)
    -- Build a matrix shaped like x whose leading d-by-d corner is an identity
    -- and whose trailing block (rows and columns past d) is copied from x.
    -- All remaining entries are left as zeros() produced them.
    -- x is read as a flat, row-major, 1-based array described by x.shape.
    assert(#x.shape == 2)
    assert(d <= x.shape[1] and d <= x.shape[2])

    local out = zeros(x.shape)
    local height = out.shape[1]
    local width = out.shape[2]

    -- consecutive diagonal entries sit width + 1 slots apart in flat storage;
    -- stopping at d * width visits exactly the first d of them.
    for flat = 1, d * width, width + 1 do
        out[flat] = 1
    end

    -- carry over the block below and to the right of the identity corner.
    for row = d + 1, height do
        local base = (row - 1) * width
        for col = d + 1, width do
            out[base + col] = x[base + col]
        end
    end

    return out
end
local function norm(a)
    -- Euclidean length of the vector a: the square root of the sum of its
    -- squared entries, visiting the entries in ipairs order.
    local squares = 0
    for _, entry in ipairs(a) do
        squares = squares + entry * entry
    end
    return sqrt(squares)
end
local function householder(x)
    -- Full QR decomposition of the 2-D matrix x using Householder reflections.
    -- x is read as a flat, row-major, 1-based array with x.shape = {rows, cols}.
    -- Returns transpose(q) and dot(q, x), i.e. Q (rows x rows) and R
    -- (rows x cols), where q is the accumulated product of the reflections.
    local rows = x.shape[1]
    local cols = x.shape[2]
    local iters = min(rows - 1, cols)
    local q = nil
    local vec = zeros(rows)
    local z = x

    for k = 1, iters do
        z = minor(z, k - 1)

        -- extract column k of the working matrix.
        for i = 1, rows do vec[i] = z[k + (i - 1) * cols] end

        -- negate the norm if the diagonal entry of the original matrix is
        -- strictly positive.
        -- NOTE(review): the sign comes from x rather than from z, and it is
        -- the opposite of the usual cancellation-avoiding convention; changing
        -- it would flip signs in Q and R, so confirm with callers first.
        local a = norm(vec)
        if x[(k - 1) * cols + k] > 0 then a = -a end
        vec[k] = vec[k] + a

        -- normalize the reflection vector. a zero vector is left as-is, which
        -- makes mat below the identity, so this step becomes a no-op.
        -- (the original code carried a FIXME suggesting an error here instead.)
        local len = norm(vec)
        if len == 0 then len = 1 end
        for i, v in ipairs(vec) do vec[i] = v / len end

        -- construct the householder reflection: mat = I - 2 * vec * vec.T
        local mat = zeros{rows, rows}
        for i = 1, rows do
            for j = 1, rows do
                local diag = i == j and 1 or 0
                mat[(i - 1) * rows + j] = diag - 2 * vec[i] * vec[j]
            end
        end

        if q == nil then q = mat else q = dot(mat, q) end
        z = dot(mat, z)
    end

    -- with a single row (or no columns) no reflection is ever built; fall
    -- back to the identity so that transpose() and dot() below do not
    -- receive nil. this yields Q = I and R = a copy of x.
    if q == nil then
        q = zeros{rows, rows}
        for i = 1, rows * rows, rows + 1 do q[i] = 1 end
    end

    return transpose(q), dot(q, x) -- Q, R
end
local function qr(x)
    -- Reduced QR decomposition built on top of householder().
    -- When x has more rows than columns, Q is cut down to its first cols
    -- columns and R to its leading square block, both in place; otherwise
    -- the results of householder() are returned untouched.
    assert(#x.shape == 2)
    local q, r = householder(x)
    local rows, cols = x.shape[1], x.shape[2]

    if rows > cols then
        -- compact q: keep the first cols entries of every row. the target
        -- slot never lies past the source slot, so copying forward is safe.
        q.shape[2] = cols
        for i = 1, rows do
            local src = (i - 1) * rows
            local dst = (i - 1) * cols
            for j = 1, cols do
                q[dst + j] = q[src + j]
            end
        end
        -- drop the leftover tail, walking downward so no hole is ever
        -- followed by live entries.
        for i = #q, rows * cols + 1, -1 do q[i] = nil end

        -- shrink r to a square whose side is its column count.
        local side = r.shape[2]
        r.shape[1] = side
        for i = #r, side * side + 1, -1 do r[i] = nil end
    end

    return q, r
end
-- the module's value is the qr function itself, so require() on this file
-- yields a callable rather than a table of functions.
return qr