smbot/qr3.lua
2019-03-11 07:15:41 +01:00

80 lines
1.6 KiB
Lua

-- QR decomposition of a 2-D tensor via Gram-Schmidt orthogonalization
local nn = require "nn"
local dot = nn.dot
local reshape = nn.reshape
local sqrt = math.sqrt
local transpose = nn.transpose
local zeros = nn.zeros
--- QR decomposition of a 2-D tensor by modified Gram-Schmidt.
-- Each column of `mat` is orthogonalized against the previously produced
-- unit vectors and then normalized. The unit vectors become the columns of
-- Q, and R is formed as Q^T * mat.
-- Modified (rather than classical) Gram-Schmidt is used: each projection is
-- removed from the running residual, not from the original column. The two
-- agree in exact arithmetic, but the modified form keeps Q far closer to
-- orthogonal in floating point.
-- A column whose residual norm is exactly zero (e.g. an all-zero column)
-- leaves the corresponding column of Q as zeros instead of filling it with
-- NaN from 0/0.
-- NOTE(review): tensor elements are accessed with flat, 1-based, row-major
-- offsets ((row - 1) * ncols + col); this assumes `nn` tensors support that
-- indexing and that nn.transpose swaps the two axes -- confirm against nn.
-- NOTE(review): results are only meaningful when `mat` has at least as many
-- rows as columns; with more columns than rows the surplus vectors cannot be
-- orthogonal to the earlier ones.
-- @param mat 2-D tensor (asserted below)
-- @return Q: transpose of the matrix whose rows are the orthonormal vectors
-- @return R: those orthonormal vectors (as rows) multiplied by `mat`
local function qr(mat)
  assert(#mat.shape == 2)
  -- Transpose so that each column of `mat` is a contiguous row of `v`.
  local v = transpose(mat)
  local rows = v.shape[1] -- number of vectors to orthonormalize
  local cols = v.shape[2] -- length of each vector
  local u = zeros(v.shape) -- row k receives the k-th orthonormal vector
  local w = {} -- scratch residual vector, reused for every row
  for i = 1, rows do
    local base = (i - 1) * cols
    for x = 1, cols do
      w[x] = v[base + x]
    end
    -- Remove the component along every previously accepted unit vector,
    -- measuring each projection against the *current* residual.
    for j = 1, i - 1 do
      local start = (j - 1) * cols
      local dotted = 0
      for x = 1, cols do
        dotted = dotted + w[x] * u[start + x]
      end
      for x = 1, cols do
        w[x] = w[x] - dotted * u[start + x]
      end
    end
    local sum = 0
    for x = 1, cols do
      sum = sum + w[x] * w[x]
    end
    local norm = sqrt(sum)
    -- Exact zero would give 0/0; leave the row of `u` at its initial zeros.
    -- (`~=` rather than `>` so that NaN inputs still propagate visibly.)
    if norm ~= 0 then
      for x = 1, cols do
        u[base + x] = w[x] / norm
      end
    end
  end
  return transpose(u), dot(u, mat)
end
return qr