library: add util.py

This commit is contained in:
Connor Olding 2022-06-07 06:49:54 +02:00
parent 593c746b12
commit ded762da8d

335
library/util.py Normal file
View file

@ -0,0 +1,335 @@
# from collections import namedtuple
from functools import partial
try:
from plots import Plot, CleanPlot, CleanPlotUnity, SquareCenterPlot
except ModuleNotFoundError: # matplotlib might not be installed
pass
# NOTE: these comments were written before autograd died, RIP.
# i'll probably wind up using aesara, or JAX if i'm desperate,
# so this kind of thing will have to be done differently.
# TODO: define gradients by hand,
# import both numpy and autograd.numpy,
# use numpy for forward passes.
# TODO: add gradient-checking stuff.
# import autograd.numpy as np
import numpy as np
invphi = (np.sqrt(5) - 1) / 2 # 1/phi
invphi2 = (3 - np.sqrt(5)) / 2 # 1/phi^2
class rpartial(partial):
"""`rpartial` functions similarly to `functools.partial`, but
prepends additional arguments instead of appending them."""
# https://stackoverflow.com/a/11831662
def __call__(self, *args, **kwargs):
kw = self.keywords.copy()
kw.update(kwargs)
return self.func(*(args + self.args), **kwargs)
class Timer:
def __init__(self, title="Untitled timer", *, verbose=True, cumulative=False):
from time import time
self.title = str(title)
self.verbose = bool(verbose)
self.cumulative = bool(cumulative)
self.time = time
self.reset()
def reset(self):
self.then = None
self.now = None
self.cum = 0.0
def __enter__(self):
self.then = self.time()
def __exit__(self, exc_type, exc, exc_tb):
self.now = self.time()
if self.cumulative:
self.cum += self.now - self.then
if self.verbose:
self.print()
def print(self, f=None):
from sys import stderr
elapsed = self.cum if self.cumulative else self.now - self.then
f = stderr if f is None else f
print(f"{self.title:>30}: {elapsed:8.3f} seconds", file=stderr)
class CumTimer(Timer):
def __init__(self, title="Untitled cumulative timer"):
super().__init__(title, verbose=False, cumulative=True)
def div0(a, b):
"""Return the element-wise division of array `a` by array `b`,
whereby division by zero equals zero."""
a = np.asanyarray(a)
b = np.asanyarray(b)
with np.errstate(divide="ignore", invalid="ignore"):
c = np.true_divide(a, b)
if np.ndim(c) == 0: # if np.isscalar(c):
return c if np.isfinite(c) else c.dtype.type(0)
c[~np.isfinite(c)] = 0 # -inf, inf, and NaN get set
return c
def invsqrt(a):
"""Return the element-wise inverse square-root of the array `a`."""
return np.reciprocal(np.sqrt(a))
def sab(a, axis=None):
"""Return the sum of absolute values of
the array `a` along its axis `axis`."""
return np.sum(np.abs(a), axis=axis)
def ssq(a, axis=None):
"""Return the halved sum-of-squares of
the array `a` along its axis `axis`."""
return 0.5 * np.sum(np.square(a), axis=axis)
def rms(a, axis=None):
"""Return the root-mean-square of
the array `a` along its axis `axis`."""
return np.sqrt(np.mean(np.square(a), axis=axis))
def cossim(x, y):
"""Return the cosine-similarity of
the arrays `x` and `y`.
(The exact behavior of non-vector arguments is undefined for now.)"""
return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
class Regulator: # NOTE: should've been called Regularizer, whoops?
# TODO: support linear constraints.
def __init__(self, *args, r1=0.0, r2=0.0, r2s=0.0, rinf=0.0, ridges=None):
# this way i don't have to worry about arbitrary ordering:
assert len(args) == 0, "all arguments of Regulator must be given by keywords"
self.r1 = float(r1) # L1 aka Lasso
self.r2 = float(r2) # L2 aka Euclidian
self.r2s = float(r2s) # L2-squared aka Ridge
self.rinf = float(rinf) # L-infinity (L1-ball)
# "ridges" are just Ridge penalties with a non-zero bias,
# pulling the solution towards `point` by `coeff`.
# this is used in optimizers such as in arxiv:1801.02982.
# FIXME: but it isn't part of the proximal equation!
self.ridges = []
if ridges is not None:
for coeff, point in ridges:
self.push_ridge(coeff, point)
def push_ridge(self, coeff, point):
pair = (float(coeff), np.array(point, float, copy=False))
self.ridges.append(pair)
def copy(self):
"""creates a shallow copy of the Regulator."""
return Regulator(
r1=self.r1, r2=self.r2, r2s=self.r2s, rinf=self.rinf, ridges=self.ridges
)
def regulate(self, x):
# avoid using += here for autograd.
y = 0
if self.r1 > 0:
y = y + self.r1 * np.sum(np.abs(x))
if self.r2 > 0 or self.r2s > 0:
sum_of_squares = np.sum(np.square(x))
if self.r2 > 0:
y = y + self.r2 * np.sqrt(sum_of_squares)
if self.rinf > 0:
y = y + self.rinf * np.max(np.abs(x))
if self.r2s > 0:
y = y + self.r2s * 0.5 * sum_of_squares
for coeff, point in self.ridges:
y = y + coeff * ssq(x - point)
return y
def solve(self, x, g, rate=1.0): # solves the proximal
if rate > 0:
y = x - rate * g
else:
assert rate == 0, rate
return x
for coeff, point in self.ridges:
y += rate * coeff * point
if self.r1 > 0:
y = np.maximum(np.abs(y) - self.r1 * rate, 0) * np.sign(y)
if self.r2 > 0:
y = np.maximum(1 - rate * self.r2 / np.linalg.norm(y), 0) * y
if self.rinf > 0:
y = np.maximum(1 - rate * self.rinf / np.max(np.abs(y)), 0) * y
denominator = 1
if self.r2s > 0:
denominator += rate * self.r2s
for coeff, point in self.ridges:
denominator += rate * coeff
return y / denominator
class RandomIndices:
def __init__(self, n: int):
self.reset(n)
def reset(self, n=None):
if n is not None:
assert n >= 1, n
self.n = n
self.indices = np.arange(n)
self.i = 0
np.random.shuffle(self.indices)
def draw(self, n):
return zip(range(n), self)
def __next__(self):
ret = self.indices[self.i]
self.i += 1
if self.i == self.n:
self.i = 0
self.reset()
return ret
def __iter__(self):
return self
def estimate_condition(f, g, x, scale=1.0, iters=100):
"""Estimate the convexity and L-smoothness of the function `f`.
`g` should return the gradient of `f`.
Noise is centered around `x` with scale `scale`."""
# i.e. the condition of the hessian.
norm2 = lambda a: np.linalg.norm(a.ravel())
convexity, smoothness = np.inf, 0
fx, gx = f(x), g(x)
for i in range(iters):
y = x + np.random.normal(0, scale, size=x.shape)
fy, gy = f(y), g(y)
if 0:
# f(y) >= f(x) + g(x) @ (y - x) + σ * ssq(x - y)
σ = (fy - fx - np.dot(gx, (y - x))) / ssq(y - x)
if σ < convexity:
convexity = σ
# norm2(g(x) - g(y)) <= L * norm2(x - y)
L = norm2(gx - gy) / norm2(y - x)
# NOTE: i think the sqrt in norm2 can be pulled out like:
# np.sqrt(np.sum(np.square(gx - gy)) / np.sum(np.square(x - y)))
if L > smoothness:
smoothness = L
else:
# (gx - gy) @ (x - y) >= σ * norm2(x - y)**2
blah = (gx - gy) @ (x - y) / np.sum(np.square(x - y))
if blah < convexity:
convexity = blah
# f(y) >= f(x) + dot(g(x), y - x) + σ/2 * sum(square(x - y))
# TODO: is this still correct? it's a different definition...
if blah > smoothness:
smoothness = blah
return convexity, smoothness
def minimum_enclosing_ball(xa, xb, ra, rb):
"""Given the *squared* radii `ra` and `rb` of balls
centered at `xa` and `xb` respectively,
return the center and squared radius of a ball
encompassing their (possibly negative) intersection."""
xa, xb = np.asfarray(xa), np.asfarray(xb)
delnorm2 = np.sum(np.square(xa - xb))
if delnorm2 >= abs(ra - rb):
xc = ((xa + xb) - (ra - rb) / delnorm2) / 2
rc = np.square(delnorm2 + rb - ra) / delnorm2 / 4
return xc, rc
elif delnorm2 < ra - rb:
return xb, rb
else:
return xa, ra
def gss(f, a, b, tol=1e-5): # golden section search
"""Given a function `f` with a single local minimum in
the interval [`a`, `b`], return a subset interval
[`c`, `d`] that contains the minimum with `d - c <= tol`."""
# via https://en.wikipedia.org/wiki/Golden-section_search
a, b = min(a, b), max(a, b)
h = b - a
if h <= tol:
return (a, b)
# required steps to achieve tolerance:
n = int(np.ceil(np.log(tol / h) / np.log(invphi)))
c = a + invphi2 * h
d = a + invphi * h
y_c = f(c)
y_d = f(d)
for k in range(n - 1):
if y_c < y_d:
b, d, y_d = d, c, y_c
h = invphi * h
c = a + invphi2 * h
y_c = f(c)
else:
a, c, y_c = c, d, y_d
h = invphi * h
d = a + invphi * h
y_d = f(d)
if y_c < y_d:
return (a, d)
else:
return (c, b)
def gssm(f, a, b, tol=1e-5): # golden section search (scalar result)
return np.mean(gss(f, a, b, tol=tol))
__all__ = """
Plot
CleanPlot
CleanPlotUnity
SquareCenterPlot
RandomIndices
Regulator
partial
rpartial
invsqrt
sab
ssq
rms
cossim
estimate_condition
minimum_enclosing_ball
gss
""".split()