smbot/qr2.lua
2019-03-11 07:15:41 +01:00

77 lines
2.1 KiB
Lua

local min = math.min
local sqrt = math.sqrt
local nn = require "nn"
local transpose = nn.transpose
local zeros = nn.zeros
--- Overwrite one column run of `q` with a unit vector orthogonal to the
-- first `count` column runs.
-- `q` stores columns back to back as runs of length `rows` (column j occupies
-- q[(j - 1) * rows + 1 .. j * rows]). The first `count` runs must already be
-- orthonormal and `count` must be less than `rows`. The result is written
-- into the run starting at index `i0`.
local function fill_orthogonal(q, i0, rows, count)
	-- start from the standard basis vector with the largest component outside
	-- the span of the existing columns. its squared residual is
	-- 1 - sum_j q_j[e]^2; these residuals sum to rows - count >= 1, so the
	-- best one is at least 1/rows and the start is never degenerate.
	local best, best_resid = 1, -1
	for e = 1, rows do
		local s = 0
		for j = 1, count do
			local v = q[(j - 1) * rows + e]
			s = s + v * v
		end
		if 1 - s > best_resid then
			best, best_resid = e, 1 - s
		end
	end
	local i1 = i0 + rows - 1
	for k = i0, i1 do q[k] = 0 end
	q[i0 + best - 1] = 1
	-- project out the existing columns. the second pass re-orthogonalizes to
	-- clean up round-off left behind by the first.
	for _ = 1, 2 do
		for j = 1, count do
			local i_to_j = (j - 1) * rows + 1 - i0
			local x = 0
			for k = i0, i1 do x = x + q[k] * q[k + i_to_j] end
			for k = i0, i1 do q[k] = q[k] - x * q[k + i_to_j] end
		end
	end
	local sum = 0
	for k = i0, i1 do sum = sum + q[k] * q[k] end
	local norm = sqrt(sum)
	for k = i0, i1 do q[k] = q[k] / norm end
end

--- QR decomposition of a 2-D array using modified Gram-Schmidt.
-- For `a` of shape {rows, cols}, with small = min(rows, cols), returns `q`
-- (shape {rows, small}, assuming nn.transpose swaps the two axes) and `r`
-- (shape {small, cols}, upper triangular) such that q @ r reconstructs `a`.
--
-- Rank-deficient input is handled. When a column of `a` is (numerically) a
-- combination of earlier columns, its diagonal entry in `r` is 0, and the
-- matching column of `q` is replaced by an arbitrary unit vector orthogonal
-- to the earlier ones. This keeps the columns of `q` orthonormal, so for a
-- square input q @ q.T is the identity. It also lets later columns still be
-- reconstructed when cols > rows. Examples are a zero first column, or
--   [[0, 0, 0, 0]
--    [1, 0, 1, 1]
--    [0, 0, 2, 2]
--    [0, 0, 3, 3]]
-- @param a 2-D nn array.
-- @tparam[opt=1e-12] number tol relative dependence threshold: a column whose
--   residual after projection has norm <= tol * (its norm before projection)
--   is treated as dependent. Pass 0 to treat only exactly-zero residuals so.
-- @return q, r
local function qr(a, tol)
	tol = tol or 1e-12
	assert(#a.shape == 2)
	local rows = a.shape[1]
	local cols = a.shape[2]
	local small = min(rows, cols)
	-- work on the transpose so column i of `a` is the contiguous run
	-- q[(i - 1) * rows + 1 .. i * rows] of the flat storage.
	local q = transpose(a)
	-- r is filled row-major: r[(j - 1) * cols + i] holds entry (j, i).
	local r = zeros{small, cols}
	for i = 1, cols do
		local i0 = (i - 1) * rows + 1
		local i1 = i * rows
		-- norm of the column before projection: the scale for the dependence
		-- test below. only columns that get normalized (i <= small) need it.
		local col_norm = 0
		if i <= small then
			local sum = 0
			for k = i0, i1 do sum = sum + q[k] * q[k] end
			col_norm = sqrt(sum)
		end
		for j = 1, min(i - 1, small) do
			local j0 = (j - 1) * rows + 1
			local j1 = j * rows
			local i_to_j = j0 - i0
			local num = 0
			local den = 0
			for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
			-- column j (j <= small, j < i) was already scaled to unit length,
			-- either by normalization or by fill_orthogonal, so den is ~1.
			for k = j0, j1 do den = den + q[k] * q[k] end
			local x = num / den
			r[(j - 1) * cols + i] = x
			for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
		end
		if i <= small then
			local sum = 0
			for k = i0, i1 do sum = sum + q[k] * q[k] end
			local norm = sqrt(sum)
			if norm <= tol * col_norm then
				-- dependent column: what remains is zero or round-off noise.
				-- normalizing noise would give a vector that is not orthogonal
				-- to the earlier columns, so substitute one that is.
				r[(i - 1) * cols + i] = 0
				fill_orthogonal(q, i0, rows, i - 1)
			else
				for k = i0, i1 do q[k] = q[k] / norm end
				r[(i - 1) * cols + i] = norm
			end
		end
	end
	-- drop the leftover residual runs of columns beyond `small` (present only
	-- when cols > rows), clearing from the top so no holes appear midway.
	for k = #q, small * rows + 1, -1 do q[k] = nil end
	q.shape[1] = small
	return transpose(q), r
end
-- the module's value is the qr function itself.
return qr