smbot/expm.lua
2019-03-11 07:15:41 +01:00

82 lines
1.9 KiB
Lua

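-- A crude way of computing the matrix exponential of a symmetric matrix:
-- find its eigenpairs, run the ordinary exponent function over the
-- eigenvalues, and reassemble the matrix from the eigenvectors.
-- expm finds the eigenpairs with plain QR iteration; expm2, the function
-- this module returns, uses eig.tred2 and eig.tqli instead.
-- This only works for symmetric matrices!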
local nn = require "nn"
local qr = require "qr"
local util = require "util"
local copy = util.copy
local dot = nn.dot
local exp = math.exp
local reshape = nn.reshape
local transpose = nn.transpose
local zeros = nn.zeros
--- Matrix exponential of a symmetric matrix via unshifted QR iteration.
-- Each step factors A_k = Q_k R_k and forms A_{k+1} = R_k Q_k, while the
-- product of the Q factors is accumulated in `vec`. The diagonal of the
-- final A_k is exponentiated (off-diagonal residue is discarded) and the
-- result is reassembled as vec * exp(D) * vec^T.
-- Only valid for symmetric input, and only approximate: a fixed number of
-- QR steps is run, with no convergence test.
-- @param mat square 2-D nn matrix; the diagonal is addressed flat as
--   (i - 1) * dims + i, as the loops below assume. Not modified.
-- @param iterations optional number of QR steps (default 10, the count this
--   function always used before it became a parameter)
-- @return approximation of exp(mat)
local function expm(mat, iterations)
  assert(#mat.shape == 2, "expm input must be a 2-D matrix")
  assert(mat.shape[1] == mat.shape[2], "expm input must be square")
  -- TODO: assert symmetry; the eigen-decomposition below assumes it.
  iterations = iterations or 10
  local dims = mat.shape[1]

  -- eigenvector accumulator starts at the identity
  local vec = zeros(mat.shape)
  for i = 1, dims do
    vec[(i - 1) * dims + i] = 1
  end

  -- diag may still alias mat when iterations == 0, so it is only read below
  local diag = mat
  for _ = 1, iterations do
    local q, r = qr(diag)
    vec = dot(vec, q)
    diag = dot(r, q)
  end

  -- build exp(D) in a fresh matrix instead of overwriting diag in place,
  -- so the caller's matrix can never be clobbered
  local expd = zeros(mat.shape)
  for i = 1, dims do
    local ind = (i - 1) * dims + i
    expd[ind] = exp(diag[ind])
  end
  return dot(dot(vec, expd), transpose(vec))
end
local eig = require "eig"
local tred2 = eig.tred2
local tqli = eig.tqli
--- Matrix exponential of a symmetric matrix via eig.tred2 + eig.tqli.
-- The eigenvalues returned in `d` are exponentiated into a diagonal
-- matrix, and the result is reassembled as vec^T * exp(D) * vec.
-- Only valid for symmetric input.
-- @param mat square 2-D nn matrix; it is copied, not modified
-- @return approximation of exp(mat)
local function expm2(mat)
  assert(#mat.shape == 2, "expm input must be a 2-D matrix")
  assert(mat.shape[1] == mat.shape[2], "expm input must be square")
  local dims = mat.shape[1]

  -- tred2/tqli work on `vec` in place, and it is then used as the
  -- eigenvector matrix, so operate on a copy of the input
  local vec = copy(mat)
  local d, e = tred2(vec)
  tqli(d, e, vec)

  -- dims x dims diagonal matrix holding exp of each eigenvalue
  local diag = {}
  for y = 1, dims do
    for x = 1, dims do
      local ind = (y - 1) * dims + x
      if x == y then
        diag[ind] = exp(d[x])
      else
        diag[ind] = 0
      end
    end
  end
  reshape(diag, dims, dims)

  -- NOTE(review): the transpose is on the left (V^T D V), i.e. eigenvectors
  -- are treated as rows of vec; confirm this matches eig.tqli's layout.
  return dot(dot(transpose(vec), diag), vec)
end
return expm2