291 lines
7.2 KiB
Lua
291 lines
7.2 KiB
Lua
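-- Symmetric eigen-decomposition helpers adapted from the Numerical Recipes
-- routines of the same names: tred2 (Householder reduction to tridiagonal
-- form) and tqli (QL algorithm with implicit shifts).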
|
|
|
|
--local globalize = require "strict"
|
|
local nn = require "nn"
|
|
local util = require "util"
|
|
|
|
local sign = util.sign
|
|
local abs = math.abs
|
|
local reshape = nn.reshape
|
|
local sqrt = math.sqrt
|
|
local zeros = nn.zeros
|
|
|
|
--- Householder reduction of a real symmetric matrix to tridiagonal form.
--
-- Lua port of the Numerical Recipes `tred2` routine. `a` is accessed through
-- flat 1-based offsets `(row - 1) * n + col`. Every subscript is the transpose
-- of the Numerical Recipes one, so for a symmetric input the reduction is the
-- same, and the accumulated transformation left in `a` is stored transposed
-- (one row per basis vector). That is the layout `tqli` rotates.
--
-- @param a square 2-D nn array (`#a.shape == 2`, `a.shape[1] == a.shape[2]`);
--   overwritten in place with the transformation.
-- @return d diagonal of the tridiagonal matrix (length n).
-- @return e off-diagonal (length n): `e[i]` is the element at (i-1, i);
--   `e[1]` is set to 0.
local function tred2(a)
  assert(#a.shape == 2)
  assert(a.shape[1] == a.shape[2])
  local n = a.shape[1]

  -- Flat 1-based offset of element (row, col) in `a`.
  local function at(row, col)
    return (row - 1) * n + col
  end

  local d = zeros(n) -- diagonal
  local e = zeros(n) -- off-diagonal; e[1] is forced to 0 after phase 1

  -- Phase 1: for i = n down to 2, reflect away a[1..i-2][i] and record the
  -- surviving off-diagonal element in e[i] and the reflector norm h in d[i].
  for i = n, 2, -1 do
    local last = i - 1
    local h, scale = 0, 0

    if last <= 1 then
      e[i] = a[at(last, i)]
    else
      for k = 1, last do
        scale = scale + abs(a[at(k, i)])
      end

      if scale == 0 then
        -- Nothing to reflect: the column segment is already zero.
        e[i] = a[at(last, i)]
      else
        -- Normalise by `scale` before squaring; h accumulates the squared norm.
        for k = 1, last do
          local p = at(k, i)
          a[p] = a[p] / scale
          h = h + a[p] * a[p]
        end

        local f = a[at(last, i)]
        local g = f >= 0 and -sqrt(h) or sqrt(h)
        e[i] = scale * g
        h = h - f * g
        a[at(last, i)] = f - g

        f = 0
        for j = 1, last do
          local vj = at(j, i)
          -- Stash u/h in row i; phase 2 reads it back.
          a[at(i, j)] = a[vj] / h
          g = 0
          for k = 1, j do
            g = g + a[at(k, j)] * a[at(k, i)]
          end
          for k = j + 1, last do
            g = g + a[at(j, k)] * a[at(k, i)]
          end
          e[j] = g / h
          f = f + e[j] * a[vj]
        end

        local hh = f / (h + h)
        for j = 1, last do
          f = a[at(j, i)]
          g = e[j] - hh * f
          e[j] = g
          for k = 1, j do
            local p = at(k, j)
            a[p] = a[p] - (f * e[k] + g * a[at(k, i)])
          end
        end
      end
    end

    d[i] = h
  end

  d[1] = 0
  e[1] = 0

  -- Phase 2: accumulate the transformation into `a` and read off the diagonal.
  for i = 1, n do
    local last = i - 1

    if d[i] ~= 0 then
      for j = 1, last do
        local g = 0
        for k = 1, last do
          g = g + a[at(k, i)] * a[at(j, k)]
        end
        for k = 1, last do
          local p = at(j, k)
          a[p] = a[p] - g * a[at(i, k)]
        end
      end
    end

    local diag = at(i, i)
    d[i] = a[diag]
    a[diag] = 1
    for j = 1, last do
      a[at(i, j)] = 0
      a[at(j, i)] = 0
    end
  end

  return d, e
end
|
|
|
|
--- Overflow-safe hypotenuse, sqrt(a^2 + b^2) (Numerical Recipes `pythag`).
--
-- Divides the smaller magnitude by the larger before squaring, so the value
-- under the square root stays between 1 and 2 instead of squaring a or b
-- directly.
--
-- @tparam number a
-- @tparam number b
-- @treturn number sqrt(a^2 + b^2); exactly 0 when both arguments are 0.
local function pythag(a, b)
  local big, small = abs(a), abs(b)
  if not (big > small) then
    big, small = small, big
  end

  if big == 0 then
    return 0
  end

  local ratio = small / big
  return big * sqrt(1 + ratio * ratio)
end
|
|
|
|
--- Eigen-decomposition of a symmetric tridiagonal matrix by the QL algorithm
-- with implicit shifts (Lua port of the Numerical Recipes `tqli` routine).
--
-- All three arguments are modified in place; nothing is returned.
--
-- @param d diagonal, length n (e.g. from `tred2`). On return it holds the
--   eigenvalues, in no particular order.
-- @param e off-diagonal, length n, in the layout `tred2` returns (`e[i]` is
--   the element at (i-1, i); `e[1]` is ignored). Overwritten.
-- @param z n x n nn array, indexed flat as `(row - 1) * n + col`. Every plane
--   rotation is applied to rows i and i+1. If `z` starts as the matrix `tred2`
--   leaves in its argument, row k ends up as the eigenvector for d[k]. Start
--   from the identity to get the eigenvectors of the tridiagonal matrix itself.
-- @raise "Too many iterations in tqli" if an eigenvalue is not isolated
--   within 31 QL sweeps.
local function tqli(d, e, z)
  assert(#z.shape == 2)
  assert(z.shape[1] == z.shape[2])
  local n = z.shape[1]
  assert(#d == n)
  assert(#e == n)

  -- Renumber the off-diagonal so that e[i] is the element at (i, i+1).
  for i = 2, n do e[i - 1] = e[i] end
  e[n] = 0

  for l = 1, n do
    local iter = 0

    while true do
      -- Find the first negligible off-diagonal element at or after l. The
      -- block l..m is then decoupled from the rest of the matrix.
      local m = l
      while m <= n - 1 do
        local dd = abs(d[m]) + abs(d[m + 1])
        if abs(e[m]) + dd == dd then break end
        m = m + 1
      end

      -- A 1x1 block: d[l] has converged.
      if m == l then break end

      iter = iter + 1
      if iter >= 32 then error("Too many iterations in tqli") end

      -- Form the implicit shift.
      local g = (d[l + 1] - d[l]) / (2 * e[l])
      local r = pythag(g, 1)
      -- NR's SIGN(r, g): +r when g >= 0 (including g == 0), otherwise -r.
      -- Since r >= 1, the denominator below can never be 0.
      local signed_r = (g >= 0) and r or -r
      g = d[m] - d[l] + e[l] / (g + signed_r)

      local s, c, p = 1, 1, 0
      -- Set when a rotation underflows (r == 0). Tracked explicitly because
      -- the `for` loop variable `i` is not visible after the loop ends.
      local underflow = false

      for i = m - 1, l, -1 do
        local f = s * e[i]
        local b = c * e[i]
        r = pythag(f, g)
        e[i + 1] = r
        if r == 0 then
          -- Recover from underflow, then re-test convergence for this l.
          d[i + 1] = d[i + 1] - p
          e[m] = 0
          underflow = true
          break
        end

        s = f / r
        c = g / r
        g = d[i + 1] - p
        r = (d[i] - g) * s + 2 * c * b
        p = s * r
        d[i + 1] = g + p
        g = c * r - b

        -- Apply the same rotation to rows i and i+1 of z.
        local row_i = (i - 1) * n
        local row_next = row_i + n
        for k = 1, n do
          f = z[row_next + k]
          z[row_next + k] = s * z[row_i + k] + c * f
          z[row_i + k] = c * z[row_i + k] - s * f
        end
      end

      -- After an underflow, just loop again (NR's `continue`). Otherwise
      -- finish the sweep.
      if not underflow then
        d[l] = d[l] - p
        e[l] = g
        e[m] = 0
      end
    end
  end
end
|
|
|
|
--[=[
|
|
|
|
local A = {
|
|
4, 1, -2, 2,
|
|
1, 2, 0, 1,
|
|
-2, 0, 3, -2,
|
|
2, 1, -2, -1,
|
|
}
|
|
reshape(A, 4, 4)
|
|
|
|
local d, e = tred2(A)
|
|
|
|
--[[
|
|
print(nn.pp(A))
|
|
print(nn.pp(d))
|
|
print(nn.pp(e))
|
|
--]]
|
|
|
|
--[[
|
|
{
|
|
0.248069, 0.744208, 0.620174, 0.000000,
|
|
0.702863, -0.578829, 0.413449, 0.000000,
|
|
0.666667, 0.333333, -0.666667, 0.000000,
|
|
0.000000, 0.000000, 0.000000, 1.000000,
|
|
}
|
|
{ 2.261538, 1.182906, 5.555556, -1.000000,}
|
|
{ 0.000000, -0.092308, 0.895806, 3.000000,}
|
|
--]]
|
|
|
|
--A = nn.transpose(A)
|
|
tqli(d, e, A)
|
|
|
|
print(nn.pp(A))
|
|
print(nn.pp(d))
|
|
print(nn.pp(e))
|
|
|
|
local D = zeros{4, 4}
|
|
for i = 1, 4 do
|
|
D[(i - 1) * 4 + i] = math.exp(d[i])
|
|
end
|
|
local out = nn.dot(nn.transpose(A), D)
|
|
out = nn.dot(out, A)
|
|
print(nn.pp(out))
|
|
|
|
--[[
|
|
{
|
|
703.414032, 1410.125991, 1478.990752, -43.126976,
|
|
-1205.204565, 1573.963121, -940.902581, 319.676625,
|
|
14.433478, 3.884400, -10.434671, 6.841011,
|
|
3.529384, 3.087068, -5.116236, -17.850223,
|
|
}
|
|
{ 2.273819, 1.072834, 6.818268, -2.164921,}
|
|
{ -0.000000, 0.000000, 0.000000, 0.000000,}
|
|
--]]
|
|
|
|
--]=]
|
|
|
|
-- Public API of this module (pythag stays private).
return { tred2 = tred2, tqli = tqli }
|