smbot/eig.lua
2019-03-11 07:15:41 +01:00

291 lines
7.2 KiB
Lua

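-- Symmetric eigenvalue routines, apparently ported from Numerical Recipes:
-- tred2 (Householder reduction to tridiagonal form) and tqli (QL algorithm
-- with implicit shifts on that tridiagonal form).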
--local globalize = require "strict"
local nn = require "nn"
local util = require "util"
local sign = util.sign
local abs = math.abs
local reshape = nn.reshape
local sqrt = math.sqrt
local zeros = nn.zeros
--- Householder reduction of a real symmetric matrix to tridiagonal form
-- (structure and naming match the Numerical Recipes routine `tred2`).
--
-- @param a square matrix stored as a flat array of n*n numbers with
--   `a.shape == {n, n}` (asserted below); entries are addressed as
--   a[(r - 1) * n + c]. Symmetry is assumed, not checked. `a` is
--   overwritten in place: on return it holds the accumulated orthogonal
--   transformation, which is what gets passed as `z` to `tqli`.
--   NOTE(review): the index arithmetic is the transpose of NR's a[i][k]
--   (NR a[i][k] <-> a[(k - 1) * n + i]). That is harmless for symmetric
--   input, but confirm which of rows/columns callers treat as basis vectors.
-- @return d diagonal of the tridiagonal matrix (length n; allocated with
--   nn.zeros — presumably a flat table with a shape field, TODO confirm)
-- @return e off-diagonal: e[i] couples d[i - 1] and d[i] for i >= 2;
--   e[1] is set to 0.
local function tred2(a)
assert(#a.shape == 2)
assert(a.shape[1] == a.shape[2])
local n = a.shape[1]
local d = zeros(n) -- diagonal (also temporarily holds h per row, see below)
local e = zeros(n) -- off-diagonal; used as scratch during the reduction, e[1] ends as 0
-- Reduce rows n down to 2. Row i only involves its first l = i - 1 entries.
for i = n, 2, -1 do
local l = i - 1
local ind_li = (l - 1) * n + i
local h = 0
local scale = 0
if l > 1 then
-- Scale the row so that forming the sum of squares h cannot under/overflow.
for k = 1, l do
local ind_ki = (k - 1) * n + i
scale = scale + abs(a[ind_ki])
end
if scale == 0 then
-- Row is already zero: skip the transformation.
e[i] = a[ind_li]
else
for k = 1, l do
local ind_ki = (k - 1) * n + i
a[ind_ki] = a[ind_ki] / scale
h = h + a[ind_ki] * a[ind_ki]
end
local f = a[ind_li]
-- g = -sign(f) * sqrt(h) (f == 0 counts as positive); the opposite sign
-- to f avoids cancellation in f - g below.
local g = sqrt(h)
if f >= 0 then g = -g end
e[i] = scale * g
h = h - f * g
-- The scaled row, with this entry replaced by f - g, is the Householder
-- vector u in NR's terms; h is now |u|^2 / 2.
a[ind_li] = f - g
f = 0
-- Store u / h and form p = A u / h in e[1..l] (e used as scratch);
-- f accumulates u . p.
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = a[ind_ji] / h
g = 0
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
g = g + a[ind_kj] * a[ind_ki]
end
for k = j + 1, l do
local ind_jk = (j - 1) * n + k
local ind_ki = (k - 1) * n + i
g = g + a[ind_jk] * a[ind_ki]
end
e[j] = g / h
f = f + e[j] * a[ind_ji]
end
local hh = f / (h + h)
-- q = p - hh * u (stored back into e), then apply the rank-2 update
-- A -= q u^T + u q^T to the part of A still being reduced.
for j = 1, l do
local ind_ji = (j - 1) * n + i
f = a[ind_ji]
g = e[j] - hh * f
e[j] = g
for k = 1, j do
local ind_kj = (k - 1) * n + j
local ind_ki = (k - 1) * n + i
a[ind_kj] = a[ind_kj] - (f * e[k] + g * a[ind_ki])
end
end
end
else
-- l == 1: nothing to annihilate; just record the single off-diagonal entry.
e[i] = a[ind_li]
end
-- d[i] temporarily records h. Nonzero means a transformation was applied to
-- row i, which the accumulation loop below tests.
d[i] = h
end
d[1] = 0
e[1] = 0
-- Accumulate the stored transformations into a, replacing it with the
-- orthogonal matrix, and move the final diagonal entries into d.
for i = 1, n do
local l = i - 1
if d[i] ~= 0 then
for j = 1, l do
local g = 0
for k = 1, l do
local ind_ki = (k - 1) * n + i
local ind_jk = (j - 1) * n + k
g = g + a[ind_ki] * a[ind_jk]
end
for k = 1, l do
local ind_ik = (i - 1) * n + k
local ind_jk = (j - 1) * n + k
a[ind_jk] = a[ind_jk] - g * a[ind_ik]
end
end
end
local ind_ii = (i - 1) * n + i
-- Final value of d[i]: the diagonal entry of the tridiagonal matrix.
d[i] = a[ind_ii]
-- Reset row/column i to the identity before later rows use it.
a[ind_ii] = 1
for j = 1, l do
local ind_ij = (i - 1) * n + j
local ind_ji = (j - 1) * n + i
a[ind_ij] = 0
a[ind_ji] = 0
end
end
return d, e
end
--- Length of the vector (a, b), i.e. sqrt(a^2 + b^2), computed without
-- destructive overflow or underflow. The larger magnitude is factored out,
-- so only a ratio <= 1 is ever squared.
-- @param a number
-- @param b number
-- @return non-negative length; 0 when both inputs are 0
local function pythag(a, b)
	local big, small = abs(a), abs(b)
	-- Written as `not (>)` rather than `<=` so NaN inputs take the swap path.
	if not (big > small) then
		big, small = small, big
	end
	if big == 0 then
		return 0
	end
	local ratio = small / big
	return big * sqrt(1 + ratio * ratio)
end
--- Eigen-decomposition of a symmetric tridiagonal matrix by the QL
-- algorithm with implicit shifts (structure matches Numerical Recipes `tqli`).
--
-- Typically called on the output of `tred2`: `d`/`e` are its diagonal and
-- off-diagonal, and `z` is the matrix it left in its argument. Pass an
-- identity `z` to get the eigenvectors of the tridiagonal matrix itself.
--
-- @param d diagonal, length n. On return it holds the eigenvalues (unsorted).
-- @param e off-diagonal in tred2 layout: e[i] couples d[i - 1] and d[i];
--   e[1] is ignored. Its contents are destroyed.
-- @param z flat n*n matrix with `z.shape == {n, n}`, updated in place by the
--   accumulated rotations. Row k (entries (k - 1) * n + 1 .. k * n) is
--   rotated together with d[k], so it ends up as d[k]'s eigenvector.
-- @raise error "Too many iterations in tqli" if an eigenvalue fails to converge.
local function tqli(d, e, z)
	assert(#z.shape == 2, "tqli: z must be a 2-D matrix")
	assert(z.shape[1] == z.shape[2], "tqli: z must be square")
	local n = z.shape[1]
	assert(#d == n, "tqli: d must have length n")
	assert(#e == n, "tqli: e must have length n")
	-- Renumber the off-diagonal so that e[i] couples d[i] and d[i + 1].
	for i = 2, n do e[i - 1] = e[i] end
	e[n] = 0
	for l = 1, n do
		local iter = 0
		while true do
			-- Find the first negligible off-diagonal element at or after l.
			-- The search stops at m == n because e[n] is always 0.
			local m = l
			while m <= n - 1 do
				local dd = abs(d[m]) + abs(d[m + 1])
				-- Relative test: e[m] is below the rounding level of dd.
				if abs(e[m]) + dd == dd then break end
				m = m + 1
			end
			if m == l then break end -- d[l] has converged
			iter = iter + 1
			if iter >= 32 then error("Too many iterations in tqli") end
			-- Form the shift from the leading 2x2 block.
			local g = (d[l + 1] - d[l]) / (2 * e[l])
			local r = pythag(g, 1)
			-- NR's SIGN(r, g): +|r| when g >= 0. Here r >= 1, so this never
			-- vanishes. Spelled out so that g == 0 (d[l] == d[l + 1]) cannot
			-- produce a zero denominator via a sign() that maps 0 to 0.
			local signed_r = g >= 0 and r or -r
			g = d[m] - d[l] + e[l] / (g + signed_r)
			local s, c, p = 1, 1, 0
			-- Set when a zero rotation length splits the matrix mid-sweep.
			-- (The loop variable `i` is not visible after the loop, so the
			-- flag replaces NR's post-loop `i >= l` test.)
			local deflated = false
			for i = m - 1, l, -1 do
				local f = s * e[i]
				local b = c * e[i]
				r = pythag(f, g)
				e[i + 1] = r
				if r == 0 then
					-- Underflow: split here and redo the sweep for the same l.
					d[i + 1] = d[i + 1] - p
					e[m] = 0
					deflated = true
					break
				end
				s = f / r
				c = g / r
				g = d[i + 1] - p
				r = (d[i] - g) * s + 2 * c * b
				p = s * r
				d[i + 1] = g + p
				g = c * r - b
				-- Apply the same plane rotation to rows i and i + 1 of z.
				local row_i = (i - 1) * n
				local row_next = i * n
				for k = 1, n do
					local zi = z[row_i + k]
					local zn = z[row_next + k]
					z[row_next + k] = s * zi + c * zn
					z[row_i + k] = c * zi - s * zn
				end
			end
			if not deflated then
				d[l] = d[l] - p
				e[l] = g
				e[m] = 0
			end
		end
	end
end
--[=[
local A = {
4, 1, -2, 2,
1, 2, 0, 1,
-2, 0, 3, -2,
2, 1, -2, -1,
}
reshape(A, 4, 4)
local d, e = tred2(A)
--[[
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
--]]
--[[
{
0.248069, 0.744208, 0.620174, 0.000000,
0.702863, -0.578829, 0.413449, 0.000000,
0.666667, 0.333333, -0.666667, 0.000000,
0.000000, 0.000000, 0.000000, 1.000000,
}
{ 2.261538, 1.182906, 5.555556, -1.000000,}
{ 0.000000, -0.092308, 0.895806, 3.000000,}
--]]
--A = nn.transpose(A)
tqli(d, e, A)
print(nn.pp(A))
print(nn.pp(d))
print(nn.pp(e))
local D = zeros{4, 4}
for i = 1, 4 do
D[(i - 1) * 4 + i] = math.exp(d[i])
end
local out = nn.dot(nn.transpose(A), D)
out = nn.dot(out, A)
print(nn.pp(out))
--[[
{
703.414032, 1410.125991, 1478.990752, -43.126976,
-1205.204565, 1573.963121, -940.902581, 319.676625,
14.433478, 3.884400, -10.434671, 6.841011,
3.529384, 3.087068, -5.116236, -17.850223,
}
{ 2.273819, 1.072834, 6.818268, -2.164921,}
{ -0.000000, 0.000000, 0.000000, 0.000000,}
--]]
--]=]
-- Exported API: tred2 reduces a symmetric matrix to tridiagonal form, and
-- tqli then diagonalizes that tridiagonal matrix.
return {
	tred2 = tred2,
	tqli = tqli,
}