-- (source listing metadata: 82 lines, 1.9 KiB, Lua)
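-- Here's a really awful way of computing the matrix exponential.
-- We use the QR algorithm to find the eigenpairs of a given symmetric matrix,
-- then apply the ordinary exponential function to the eigenvalues:
-- if A = V * D * V^T with D diagonal, then exp(A) = V * exp(D) * V^T.
-- This only works for symmetric matrices!
-- A second version, expm2, obtains the eigenpairs with tred2/tqli instead;
-- expm2 is the function this module returns.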
local nn = require "nn"
|
|
local qr = require "qr"
|
|
local util = require "util"
|
|
|
|
local copy = util.copy
|
|
local dot = nn.dot
|
|
local exp = math.exp
|
|
local reshape = nn.reshape
|
|
local transpose = nn.transpose
|
|
local zeros = nn.zeros
|
|
|
|
--- Matrix exponential of a symmetric matrix via unshifted QR iteration.
-- Repeatedly factors A_k = Q_k * R_k and forms A_{k+1} = R_k * Q_k while
-- accumulating V = Q_1 * Q_2 * ... * Q_k. For a symmetric input, A_k tends
-- towards a diagonal matrix of eigenvalues and the columns of V towards the
-- matching eigenvectors, so exp(A) is approximated by V * exp(D) * V^T.
--
-- Unshifted QR converges slowly, and not at all when two eigenvalues have
-- equal magnitude (e.g. 1 and -1), so accuracy depends on `iterations`.
--
-- @param mat square matrix: a flat table indexed as (row - 1) * dims + col
--   with a `shape` field (presumably an `nn` tensor -- TODO confirm).
--   Symmetry is assumed but not checked.
-- @param[opt=10] iterations number of QR iterations to run (must be >= 1).
-- @return approximation of exp(mat).
local function expm(mat, iterations)
  assert(#mat.shape == 2, "expm input must be a 2-D matrix")
  assert(mat.shape[1] == mat.shape[2], "expm input must be square")
  -- TODO: assert(stuff(mat), "expm input must be symmetrical")

  iterations = iterations or 10
  assert(iterations >= 1, "expm iterations must be at least 1")

  local dims = mat.shape[1]

  -- vec starts as the identity and accumulates the product of the Q factors.
  local vec = zeros(mat.shape)
  for i = 1, dims do
    vec[(i - 1) * dims + i] = 1 -- diagonal
  end

  local diag = mat
  for _ = 1, iterations do
    local q, r = qr(diag)
    vec = dot(vec, q)
    diag = dot(r, q)
  end

  -- Build exp(D) in a fresh matrix rather than overwriting `diag`, so the
  -- caller's `mat` can never be modified: only the diagonal of the (nearly
  -- diagonal) final iterate is kept, every other entry stays zero.
  local expd = zeros(mat.shape)
  for i = 1, dims do
    local ind = (i - 1) * dims + i
    expd[ind] = exp(diag[ind])
  end

  return dot(dot(vec, expd), transpose(vec))
end
local eig = require "eig"
|
|
local tred2 = eig.tred2
|
|
local tqli = eig.tqli
|
|
|
|
--- Matrix exponential of a symmetric matrix via eig.tred2 + eig.tqli.
-- The eigenvalues `d` and eigenvectors (left in `vec`) come from the `eig`
-- module, presumably ports of the Numerical Recipes routines of the same
-- names (TODO confirm). exp is applied to each eigenvalue and the matrix is
-- reassembled from the eigenvectors.
--
-- @param mat square matrix: a flat table with a `shape` field. Symmetry is
--   assumed but not checked. The decomposition is run on copy(mat).
-- @return exp(mat).
local function expm2(mat)
  assert(#mat.shape == 2, "expm2 input must be a 2-D matrix")
  assert(mat.shape[1] == mat.shape[2], "expm2 input must be square")
  local dims = mat.shape[1]

  -- new version that computes much better (faster?) eigenpairs:
  -- tred2 works on the copy and returns d and e, then tqli consumes d and e
  -- and writes the eigenvectors into vec.
  local vec = copy(mat)
  local d, e = tred2(vec)
  tqli(d, e, vec)

  -- exp(D): exponentiated eigenvalues on the diagonal, zeros elsewhere.
  -- Filled in index order so the table is a proper sequence.
  local diag = {}
  for y = 1, dims do
    for x = 1, dims do
      diag[(y - 1) * dims + x] = (x == y) and exp(d[x]) or 0
    end
  end
  reshape(diag, dims, dims)

  -- NOTE(review): this computes V^T * exp(D) * V, i.e. it assumes the
  -- eigenvectors end up in the ROWS of vec. Numerical Recipes' tqli stores
  -- them in columns, which would need V * exp(D) * V^T (the form expm uses);
  -- confirm the convention of the local eig module.
  return dot(dot(transpose(vec), diag), vec)
end
-- The module's value is expm2 (the tred2/tqli version); the QR-based expm
-- defined above is not exported.
return expm2
|