library: add finite_difference.py
This commit is contained in:
parent
7dec19baf4
commit
0285213be8
1 changed files with 69 additions and 0 deletions
69
library/finite_difference.py
Normal file
69
library/finite_difference.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python3
|
||||
# implementing: https://web.media.mit.edu/~crtaylor/calculator.html
|
||||
|
||||
import numpy as np
|
||||
|
||||
# from math import factorial
|
||||
_factorials = ( # all the factorials that fit into 53 bits
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
6,
|
||||
24,
|
||||
120,
|
||||
720,
|
||||
5040,
|
||||
40320,
|
||||
362880,
|
||||
3628800,
|
||||
39916800,
|
||||
479001600,
|
||||
6227020800,
|
||||
87178291200,
|
||||
1307674368000,
|
||||
20922789888000,
|
||||
355687428096000,
|
||||
6402373705728000,
|
||||
)
|
||||
|
||||
__noisy = False
|
||||
|
||||
|
||||
def fd(offsets, degree):
|
||||
if degree < 0:
|
||||
raise RuntimeError("the degree must be greater than or equal to zero")
|
||||
elif degree > 18:
|
||||
raise RuntimeError("degrees greater than 18 are unsupported")
|
||||
if len(offsets) <= degree:
|
||||
raise RuntimeError("the number of offsets must be greater than the degree")
|
||||
mat = np.array([[o**n for o in offsets] for n in range(len(offsets))])
|
||||
vec = np.array([_factorials[i] if i == degree else 0 for i in range(len(offsets))])
|
||||
if __noisy:
|
||||
print("mat:", mat, "vec:", vec, sep="\n")
|
||||
res = np.linalg.lstsq(mat, vec, rcond=None)[0]
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
__noisy = True
|
||||
assert np.allclose(fd([-2, 0, 2], 1), [-1 / 4, 0, 1 / 4])
|
||||
|
||||
# + A * f(x - 2 * h)
|
||||
# + B * f(x - 1 * h)
|
||||
# + C * f(x)
|
||||
# + D * f(x + 1 * h)
|
||||
# + E * f(x + 2 * h)
|
||||
assert np.allclose(fd([-2, -1, 0, 1, 2], 4), [1, -4, 6, -4, 1])
|
||||
|
||||
print("", "#" * 40, "", sep="\n")
|
||||
|
||||
ans = fd([-2, 7], 1)
|
||||
print("ans:", ans)
|
||||
|
||||
ans = fd([-2, 0, 7], 1)
|
||||
print("ans:", ans)
|
||||
|
||||
ans = fd([-0.2, 0.0, 0.7], 1)
|
||||
print("ans:", ans)
|
||||
|
||||
__all__ = ("fd",)
|
Loading…
Reference in a new issue