76 lines
2.1 KiB
Lua
76 lines
2.1 KiB
Lua
local min = math.min
|
|
local sqrt = math.sqrt
|
|
|
|
local nn = require "nn"
|
|
local transpose = nn.transpose
|
|
local zeros = nn.zeros
|
|
|
|
--- Overwrite column `i` of `q` with a unit vector orthogonal to columns 1..i-1.
-- Columns 1..i-1 must already be orthonormal and i must be <= rows.
-- Every standard basis vector e_1..e_rows is tried in turn. Its components
-- along the earlier columns are removed (modified Gram-Schmidt), and the
-- candidate with the largest residual is kept. At most rows-1 earlier columns
-- exist, so the squared residuals sum to at least 1 and the best one is at
-- least 1/rows. This keeps the final normalization well conditioned.
-- Column j of the matrix occupies q[(j-1)*rows + 1 .. j*rows].
local function complete_column(q, rows, i)
	local best, best_sq = nil, -1
	for e = 1, rows do
		-- Start from the standard basis vector e_e.
		local v = {}
		for k = 1, rows do v[k] = 0 end
		v[e] = 1
		-- Strip its components along every earlier (orthonormal) column.
		for j = 1, i - 1 do
			local off = (j - 1) * rows
			local dot = 0
			for k = 1, rows do dot = dot + q[off + k] * v[k] end
			for k = 1, rows do v[k] = v[k] - dot * q[off + k] end
		end
		local sq = 0
		for k = 1, rows do sq = sq + v[k] * v[k] end
		if sq > best_sq then best, best_sq = v, sq end
	end
	local norm = sqrt(best_sq)
	local off = (i - 1) * rows
	for k = 1, rows do q[off + k] = best[k] / norm end
end

--- QR decomposition of a 2-D tensor by modified Gram-Schmidt.
--
-- For an input `a` with shape {rows, cols} and small = min(rows, cols),
-- returns Q with shape {rows, small} and R with shape {small, cols}.
-- R is upper trapezoidal: only entries R[j][i] with j <= i are written.
-- Q @ R reconstructs `a` up to rounding.
--
-- Rank-deficient input: column i (i <= small) is treated as dependent on the
-- earlier columns when its residual norm after projection is
-- <= tol * (its norm before projection). For such a column R[i][i] = 0, and
-- Q[:, i] is filled with a unit vector orthogonal to the earlier columns of Q.
-- This keeps Q's columns orthonormal. It also gives later columns a complete
-- basis to project onto, so reconstruction still holds when cols > rows and
-- an early column is zero.
--
-- With the default tol = 0 only exactly-zero residuals are detected. In
-- floating point, a dependent column often leaves a tiny non-zero residual,
-- which is then normalized into a direction that is not reliably orthogonal.
-- Pass a small tol (e.g. 1e-12) for inputs such as
--   [[0, 0, 0, 0]
--    [1, 0, 1, 1]
--    [0, 0, 2, 2]
--    [0, 0, 3, 3]]
-- (zero 2nd column, 4th column equal to the 3rd) to get an orthonormal Q.
--
-- Works on `q = transpose(a)` in place. This assumes nn.transpose returns a
-- fresh tensor whose flat 1-based storage holds column i of `a` at
-- (i-1)*rows + 1 .. i*rows — TODO confirm against nn.
--
-- @param a   2-D tensor (`#a.shape == 2`)
-- @param tol optional non-negative relative tolerance for detecting dependent
--            columns; defaults to 0 (exact test)
-- @return Q, R
local function qr(a, tol)
	assert(#a.shape == 2, "qr: expected a 2-D tensor")
	tol = tol or 0
	assert(type(tol) == "number" and tol >= 0, "qr: tol must be a non-negative number")

	local rows = a.shape[1]
	local cols = a.shape[2]
	local small = min(rows, cols)

	local q = transpose(a)
	local r = zeros{small, cols}

	for i = 1, cols do
		local i0 = (i - 1) * rows + 1
		local i1 = i * rows

		-- Dependence threshold, relative to the column's norm before
		-- projection. It is only computed when needed, so tol = 0 keeps
		-- the exact `norm == 0` test.
		local limit = 0
		if tol > 0 and i <= small then
			local sq = 0
			for k = i0, i1 do sq = sq + q[k] * q[k] end
			limit = tol * sqrt(sq)
		end

		-- Subtract the projection onto each earlier column. The partially
		-- reduced column is reused for the next projection (modified G-S).
		for j = 1, min(i - 1, small) do
			local j0 = (j - 1) * rows + 1
			local j1 = j * rows
			local i_to_j = j0 - i0

			local num = 0
			local den = 0
			for k = i0, i1 do num = num + q[k] * q[k + i_to_j] end
			for k = j0, j1 do den = den + q[k] * q[k] end
			-- Earlier columns are unit vectors by now, so den is ~1. It
			-- should only reach 0 if a normalization overflowed; treat
			-- that column as contributing a zero coefficient.
			if den == 0 then den = 1 end

			local x = num / den
			r[(j - 1) * cols + i] = x
			for k = i0, i1 do q[k] = q[k] - q[k + i_to_j] * x end
		end

		if i <= small then
			local sum = 0
			for k = i0, i1 do sum = sum + q[k] * q[k] end
			local norm = sqrt(sum)

			if norm <= limit then
				-- Column i lies in the span of the earlier columns. With
				-- R[i][i] = 0, any unit vector orthogonal to them is valid
				-- for Q[:, i].
				r[(i - 1) * cols + i] = 0
				complete_column(q, rows, i)
			else
				for k = i0, i1 do q[k] = q[k] / norm end
				r[(i - 1) * cols + i] = norm
			end
		end
	end

	-- Drop the columns beyond `small` (present only when cols > rows).
	for k = small * rows + 1, #q do q[k] = nil end
	q.shape[1] = small

	return transpose(q), r
end
|
|
|
|
-- The module's value is the qr function itself, not a table of functions.
return qr
|