80 lines
1.6 KiB
Lua
80 lines
1.6 KiB
Lua
-- QR decomposition of a 2-D matrix via the Gram-Schmidt process.
|
|
|
|
local nn = require "nn"
|
|
|
|
local dot = nn.dot
|
|
local reshape = nn.reshape
|
|
local sqrt = math.sqrt
|
|
local transpose = nn.transpose
|
|
local zeros = nn.zeros
|
|
|
|
--- QR decomposition of a 2-D matrix by modified Gram-Schmidt.
--
-- Factors `mat` (m x n) as `Q * R`. The columns of `mat` are orthogonalized
-- left to right into the rows of a work matrix `u` (so `u` plays the role of
-- Q^T), and the function returns `transpose(u)` and `dot(u, mat)`.
-- Assuming `nn.dot` is a matrix product, R = Q^T * mat, which is upper
-- triangular up to rounding.
--
-- Modified Gram-Schmidt subtracts each projection from the running residual
-- immediately. This keeps Q much closer to orthonormal in floating point
-- than the classical variant, which projects the untouched column onto
-- every earlier basis vector.
--
-- A column whose residual is exactly zero (a zero column, or an exact linear
-- combination of earlier columns) leaves the corresponding column of Q at
-- zero instead of dividing by zero. Dividing by zero would fill Q and R
-- with NaN from that column onward.
--
-- NOTE(review): the indexing below assumes nn tensors expose row-major,
-- 1-based flat element access via `t[i]`, and that `zeros` returns a
-- zero-filled tensor of the given shape. Confirm against nn.
--
-- @param mat nn tensor with a two-element `shape`
-- @return Q (same shape as `mat`), R (n x n)
local function qr(mat)
  assert(#mat.shape == 2, "qr: expected a 2-D matrix")

  -- Transposing makes each column of `mat` a contiguous row of `v`.
  local v = transpose(mat)
  local n_vectors = v.shape[1]
  local length = v.shape[2]
  local u = zeros(v.shape)
  local w = {}

  for k = 1, n_vectors do
    local k_start = (k - 1) * length
    for x = 1, length do
      w[x] = v[k_start + x]
    end

    -- Remove the component along each earlier basis vector, one at a time,
    -- always projecting the *current* residual (modified Gram-Schmidt).
    for j = 1, k - 1 do
      local j_start = (j - 1) * length
      local d = 0
      for x = 1, length do
        d = d + w[x] * u[j_start + x]
      end
      for x = 1, length do
        w[x] = w[x] - d * u[j_start + x]
      end
    end

    local sum = 0
    for x = 1, length do
      sum = sum + w[x] * w[x]
    end
    local norm = sqrt(sum)

    -- An exactly zero residual keeps row k of `u` at zero. `~= 0` (rather
    -- than `> 0`) still lets NaN coming from the input propagate visibly.
    if norm ~= 0 then
      for x = 1, length do
        u[k_start + x] = w[x] / norm
      end
    end
  end

  return transpose(u), dot(u, mat)
end
|
|
|
|
-- The module's value is the qr function itself; callers invoke the result
-- of `require` directly as `qr(mat)`.
return qr
|