backyard/library/util.py

335 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# from collections import namedtuple
from functools import partial
try:
from plots import Plot, CleanPlot, CleanPlotUnity, SquareCenterPlot
except ModuleNotFoundError: # matplotlib might not be installed
pass
# NOTE: these comments were written before autograd died, RIP.
# i'll probably wind up using aesara, or JAX if i'm desperate,
# so this kind of thing will have to be done differently.
# TODO: define gradients by hand,
# import both numpy and autograd.numpy,
# use numpy for forward passes.
# TODO: add gradient-checking stuff.
# import autograd.numpy as np
import numpy as np
# Golden-ratio constants; `gss` uses them to place its two interior probes.
invphi = (np.sqrt(5) - 1) / 2 # 1/phi, with phi = (1 + sqrt(5)) / 2
invphi2 = (3 - np.sqrt(5)) / 2 # 1/phi^2
class rpartial(partial):
    """`rpartial` functions similarly to `functools.partial`, but the
    stored positional arguments are placed *after* the call-time
    positional arguments instead of before them.

    Stored keyword arguments behave as in `functools.partial`:
    they act as defaults that call-time keywords may override."""

    # https://stackoverflow.com/a/11831662

    def __call__(self, *args, **kwargs):
        # merge the stored keywords with the call-time ones; the latter win.
        # (previously the merged dict was built but never passed along,
        # so keywords given at construction were silently dropped.)
        kw = self.keywords.copy()
        kw.update(kwargs)
        return self.func(*(args + self.args), **kw)
class Timer:
    """Context manager that measures the wall-clock duration of a block.

    `title` labels the report line. When `verbose` is true, the elapsed
    time is printed on exit. When `cumulative` is true, the durations of
    successive `with` blocks are summed in `self.cum` and that sum is
    what gets reported."""

    def __init__(self, title="Untitled timer", *, verbose=True, cumulative=False):
        from time import time

        self.title = str(title)
        self.verbose = bool(verbose)
        self.cumulative = bool(cumulative)
        self.time = time
        self.reset()

    def reset(self):
        """Forget both timestamps and clear the cumulative total."""
        self.then = None
        self.now = None
        self.cum = 0.0

    def __enter__(self):
        self.then = self.time()
        # return the timer so that `with Timer(...) as t:` binds it.
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        self.now = self.time()
        if self.cumulative:
            self.cum += self.now - self.then
        if self.verbose:
            self.print()

    def print(self, f=None):
        """Write the elapsed (or cumulative) time to the file `f`,
        which defaults to standard error."""
        from sys import stderr

        elapsed = self.cum if self.cumulative else self.now - self.then
        f = stderr if f is None else f
        # honor the requested file instead of always writing to stderr.
        print(f"{self.title:>30}: {elapsed:8.3f} seconds", file=f)
class CumTimer(Timer):
    """A `Timer` preconfigured to be silent and cumulative:
    nothing is printed on exit, and elapsed time is summed across uses."""

    def __init__(self, title="Untitled cumulative timer"):
        options = dict(verbose=False, cumulative=True)
        super().__init__(title, **options)
def div0(a, b):
    """Divide array `a` by array `b` element-wise, substituting zero
    wherever the quotient is not finite (division by zero, 0/0, etc.)."""
    numer, denom = np.asanyarray(a), np.asanyarray(b)
    # silence the warnings that x/0 and 0/0 would otherwise emit.
    with np.errstate(divide="ignore", invalid="ignore"):
        quotient = np.true_divide(numer, denom)
    if np.ndim(quotient) != 0:
        # -inf, inf, and NaN all get replaced in-place.
        not_finite = ~np.isfinite(quotient)
        quotient[not_finite] = 0
        return quotient
    # scalars cannot be assigned into, so build a zero of the same type.
    if np.isfinite(quotient):
        return quotient
    return quotient.dtype.type(0)
def invsqrt(a):
    """Compute 1/sqrt(`a`) for every element of the array `a`."""
    root = np.sqrt(a)
    return np.reciprocal(root)
def sab(a, axis=None):
    """Sum the absolute values of the array `a`,
    reducing over `axis` (all axes when `axis` is None)."""
    magnitudes = np.abs(a)
    return np.sum(magnitudes, axis=axis)
def ssq(a, axis=None):
    """Compute half of the sum of squared entries of the array `a`,
    reducing over `axis` (all axes when `axis` is None)."""
    squares = np.square(a)
    return 0.5 * np.sum(squares, axis=axis)
def rms(a, axis=None):
    """Compute the root-mean-square of the array `a`,
    reducing over `axis` (all axes when `axis` is None)."""
    mean_square = np.mean(np.square(a), axis=axis)
    return np.sqrt(mean_square)
def cossim(x, y):
    """Compute the cosine-similarity between the arrays `x` and `y`:
    their dot product divided by the product of their norms.
    (The exact behavior of non-vector arguments is undefined for now.)"""
    inner = np.dot(x, y)
    scale = np.linalg.norm(x) * np.linalg.norm(y)
    return inner / scale
class Regulator:  # NOTE: should've been called Regularizer, whoops?
    """A bundle of norm penalties, with both an evaluation of the total
    penalty (`regulate`) and a gradient step followed by shrinkage
    for those penalties (`solve`).

    All coefficients must be given by keyword:
    `r1` (L1), `r2` (L2), `r2s` (halved L2-squared), `rinf` (L-infinity),
    and `ridges`, an iterable of `(coeff, point)` pairs."""

    # TODO: support linear constraints.

    def __init__(self, *args, r1=0.0, r2=0.0, r2s=0.0, rinf=0.0, ridges=None):
        # this way i don't have to worry about arbitrary ordering:
        assert len(args) == 0, "all arguments of Regulator must be given by keywords"
        self.r1 = float(r1)  # L1 aka Lasso
        self.r2 = float(r2)  # L2 aka Euclidian
        self.r2s = float(r2s)  # L2-squared aka Ridge
        self.rinf = float(rinf)  # L-infinity (L1-ball)
        # "ridges" are just Ridge penalties with a non-zero bias,
        # pulling the solution towards `point` by `coeff`.
        # this is used in optimizers such as in arxiv:1801.02982.
        # FIXME: but it isn't part of the proximal equation!
        self.ridges = []
        if ridges is not None:
            for coeff, point in ridges:
                self.push_ridge(coeff, point)

    def push_ridge(self, coeff, point):
        """Append a ridge penalty of strength `coeff` centered at `point`."""
        # np.asarray only copies when it has to. the previous
        # np.array(..., copy=False) raises under NumPy 2 whenever
        # a copy is unavoidable, e.g. when `point` is a list.
        pair = (float(coeff), np.asarray(point, dtype=float))
        self.ridges.append(pair)

    def copy(self):
        """creates a shallow copy of the Regulator."""
        return Regulator(
            r1=self.r1, r2=self.r2, r2s=self.r2s, rinf=self.rinf, ridges=self.ridges
        )

    def regulate(self, x):
        """Return the total penalty of all enabled terms evaluated at `x`."""
        # avoid using += here for autograd.
        y = 0
        if self.r1 > 0:
            y = y + self.r1 * np.sum(np.abs(x))
        if self.r2 > 0 or self.r2s > 0:
            # shared by the L2 and the L2-squared terms.
            sum_of_squares = np.sum(np.square(x))
        if self.r2 > 0:
            y = y + self.r2 * np.sqrt(sum_of_squares)
        if self.rinf > 0:
            y = y + self.rinf * np.max(np.abs(x))
        if self.r2s > 0:
            y = y + self.r2s * 0.5 * sum_of_squares
        for coeff, point in self.ridges:
            y = y + coeff * ssq(x - point)
        return y

    def solve(self, x, g, rate=1.0):  # solves the proximal
        """Take a gradient step of size `rate` from `x` along `-g`,
        then apply the shrinkage of every enabled penalty.

        A `rate` of zero returns `x` untouched; a negative `rate`
        fails the assertion."""
        if rate > 0:
            y = x - rate * g
        else:
            assert rate == 0, rate
            return x
        for coeff, point in self.ridges:
            # out-of-place, so integer or broadcast inputs don't break.
            y = y + rate * coeff * point
        if self.r1 > 0:
            # soft-thresholding.
            y = np.maximum(np.abs(y) - self.r1 * rate, 0) * np.sign(y)
        if self.r2 > 0:
            # shrink the whole vector towards the origin.
            y = np.maximum(1 - rate * self.r2 / np.linalg.norm(y), 0) * y
        if self.rinf > 0:
            # NOTE(review): mirrors the L2 shrinkage with the max-norm;
            # presumably an approximation of the L-infinity prox -- confirm.
            y = np.maximum(1 - rate * self.rinf / np.max(np.abs(y)), 0) * y
        denominator = 1
        if self.r2s > 0:
            denominator += rate * self.r2s
        for coeff, _ in self.ridges:
            denominator += rate * coeff
        return y / denominator
class RandomIndices:
    """Endless iterator over the integers 0..n-1 in shuffled order.
    The order is reshuffled every time all `n` indices have been handed out."""

    def __init__(self, n: int):
        self.reset(n)

    def reset(self, n=None):
        """Restart from a freshly shuffled order.
        Passing `n` also changes how many indices there are."""
        if n is not None:
            assert n >= 1, n
            self.n = n
            self.indices = np.arange(n)
        self.i = 0
        np.random.shuffle(self.indices)

    def draw(self, n):
        """Return an iterator of up to `n` pairs `(k, index)`, k counting from 0."""
        counter = range(n)
        return zip(counter, self)

    def __iter__(self):
        return self

    def __next__(self):
        picked = self.indices[self.i]
        self.i += 1
        if self.i == self.n:
            # exhausted this pass; reset() rewinds and reshuffles.
            self.reset()
        return picked
def estimate_condition(f, g, x, scale=1.0, iters=100):
    """Estimate the convexity and L-smoothness of the function `f`.
    `g` should return the gradient of `f`.
    Noise is centered around `x` with scale `scale`.

    For each of `iters` normally-distributed points `y` around `x`,
    the ratio `(g(x) - g(y)) @ (x - y) / sum(square(x - y))` is computed;
    the smallest and largest ratios seen are returned as the pair
    `(convexity, smoothness)`."""
    # i.e. the condition of the hessian.
    lowest, highest = np.inf, 0
    # f is still evaluated at every point, although only the gradients
    # enter the estimate. an alternative estimate based on the values,
    #   f(y) >= f(x) + g(x) @ (y - x) + σ * ssq(x - y)
    #   norm2(g(x) - g(y)) <= L * norm2(x - y)
    # used to live here behind a disabled branch.
    f(x)
    grad_x = g(x)
    for _ in range(iters):
        y = x + np.random.normal(0, scale, size=x.shape)
        f(y)
        grad_y = g(y)
        offset = x - y
        # (gx - gy) @ (x - y) >= σ * norm2(x - y)**2
        ratio = (grad_x - grad_y) @ offset / np.sum(np.square(offset))
        lowest = min(lowest, ratio)
        # TODO: is this still correct? it's a different definition...
        highest = max(highest, ratio)
    return lowest, highest
def minimum_enclosing_ball(xa, xb, ra, rb):
    """Given the *squared* radii `ra` and `rb` of balls
    centered at `xa` and `xb` respectively,
    return the center and squared radius of a ball
    encompassing their (possibly negative) intersection.

    A negative squared radius is returned when the balls are disjoint.
    When one ball's center lies deep enough inside the other ball,
    the smaller ball itself is returned."""
    # np.asfarray no longer exists in NumPy 2; this is its equivalent.
    xa, xb = np.asarray(xa, dtype=float), np.asarray(xb, dtype=float)
    delnorm2 = np.sum(np.square(xa - xb))
    # the `> 0` guard keeps coincident centers with equal radii
    # away from the division below; they fall through to `xa, ra`.
    if delnorm2 > 0 and delnorm2 >= abs(ra - rb):
        # both boundaries meet on a sphere whose center lies on the
        # segment between xa and xb, shifted towards the smaller ball.
        xc = (xa + xb) / 2 - (ra - rb) / (2 * delnorm2) * (xa - xb)
        # squared radius of that sphere: rb minus the squared
        # distance from xb to xc.
        rc = rb - np.square(delnorm2 + rb - ra) / (4 * delnorm2)
        return xc, rc
    elif delnorm2 < ra - rb:
        return xb, rb
    else:
        return xa, ra
def gss(f, a, b, tol=1e-5):  # golden section search
    """Narrow down the location of the minimum of `f`, a function with
    a single local minimum in the interval [`a`, `b`]. The result is
    a subset interval [`c`, `d`] that contains the minimum
    with `d - c <= tol`."""
    # via https://en.wikipedia.org/wiki/Golden-section_search
    lo, hi = min(a, b), max(a, b)
    width = hi - lo
    if width <= tol:
        return (lo, hi)
    # required steps to achieve tolerance:
    steps = int(np.ceil(np.log(tol / width) / np.log(invphi)))
    left = lo + invphi2 * width
    right = lo + invphi * width
    f_left, f_right = f(left), f(right)
    for _ in range(steps - 1):
        # either way, the bracket shrinks by the same factor.
        width = invphi * width
        if f_left < f_right:
            # the minimum is in [lo, right]; reuse `left` as the new `right`.
            hi, right, f_right = right, left, f_left
            left = lo + invphi2 * width
            f_left = f(left)
        else:
            # the minimum is in [left, hi]; reuse `right` as the new `left`.
            lo, left, f_left = left, right, f_right
            right = lo + invphi * width
            f_right = f(right)
    return (lo, right) if f_left < f_right else (left, hi)
def gssm(f, a, b, tol=1e-5):  # golden section search (scalar result)
    """Run `gss` and collapse the resulting interval to its midpoint."""
    bracket = gss(f, a, b, tol=tol)
    return np.mean(bracket)
# the plotting helpers are only importable when matplotlib is installed,
# so they are exported only if that import succeeded. naming them
# unconditionally makes `from ... import *` raise AttributeError otherwise.
__all__ = [
    name
    for name in ("Plot", "CleanPlot", "CleanPlotUnity", "SquareCenterPlot")
    if name in globals()
] + [
    "RandomIndices",
    "Regulator",
    "partial",
    "rpartial",
    "invsqrt",
    "sab",
    "ssq",
    "rms",
    "cossim",
    "estimate_condition",
    "minimum_enclosing_ball",
    "gss",
]