library: add util.py
This commit is contained in:
parent
593c746b12
commit
ded762da8d
1 changed files with 335 additions and 0 deletions
335
library/util.py
Normal file
335
library/util.py
Normal file
|
@ -0,0 +1,335 @@
|
|||
# from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
try:
|
||||
from plots import Plot, CleanPlot, CleanPlotUnity, SquareCenterPlot
|
||||
except ModuleNotFoundError: # matplotlib might not be installed
|
||||
pass
|
||||
|
||||
# NOTE: these comments were written before autograd died, RIP.
|
||||
# i'll probably wind up using aesara, or JAX if i'm desperate,
|
||||
# so this kind of thing will have to be done differently.
|
||||
# TODO: define gradients by hand,
|
||||
# import both numpy and autograd.numpy,
|
||||
# use numpy for forward passes.
|
||||
# TODO: add gradient-checking stuff.
|
||||
# import autograd.numpy as np
|
||||
import numpy as np
|
||||
|
||||
|
||||
invphi = (np.sqrt(5) - 1) / 2 # 1/phi
|
||||
invphi2 = (3 - np.sqrt(5)) / 2 # 1/phi^2
|
||||
|
||||
|
||||
class rpartial(partial):
|
||||
"""`rpartial` functions similarly to `functools.partial`, but
|
||||
prepends additional arguments instead of appending them."""
|
||||
|
||||
# https://stackoverflow.com/a/11831662
|
||||
def __call__(self, *args, **kwargs):
|
||||
kw = self.keywords.copy()
|
||||
kw.update(kwargs)
|
||||
return self.func(*(args + self.args), **kwargs)
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self, title="Untitled timer", *, verbose=True, cumulative=False):
|
||||
from time import time
|
||||
|
||||
self.title = str(title)
|
||||
self.verbose = bool(verbose)
|
||||
self.cumulative = bool(cumulative)
|
||||
self.time = time
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.then = None
|
||||
self.now = None
|
||||
self.cum = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.then = self.time()
|
||||
|
||||
def __exit__(self, exc_type, exc, exc_tb):
|
||||
self.now = self.time()
|
||||
if self.cumulative:
|
||||
self.cum += self.now - self.then
|
||||
if self.verbose:
|
||||
self.print()
|
||||
|
||||
def print(self, f=None):
|
||||
from sys import stderr
|
||||
|
||||
elapsed = self.cum if self.cumulative else self.now - self.then
|
||||
f = stderr if f is None else f
|
||||
print(f"{self.title:>30}: {elapsed:8.3f} seconds", file=stderr)
|
||||
|
||||
|
||||
class CumTimer(Timer):
|
||||
def __init__(self, title="Untitled cumulative timer"):
|
||||
super().__init__(title, verbose=False, cumulative=True)
|
||||
|
||||
|
||||
def div0(a, b):
|
||||
"""Return the element-wise division of array `a` by array `b`,
|
||||
whereby division by zero equals zero."""
|
||||
a = np.asanyarray(a)
|
||||
b = np.asanyarray(b)
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
c = np.true_divide(a, b)
|
||||
if np.ndim(c) == 0: # if np.isscalar(c):
|
||||
return c if np.isfinite(c) else c.dtype.type(0)
|
||||
c[~np.isfinite(c)] = 0 # -inf, inf, and NaN get set
|
||||
return c
|
||||
|
||||
|
||||
def invsqrt(a):
|
||||
"""Return the element-wise inverse square-root of the array `a`."""
|
||||
return np.reciprocal(np.sqrt(a))
|
||||
|
||||
|
||||
def sab(a, axis=None):
|
||||
"""Return the sum of absolute values of
|
||||
the array `a` along its axis `axis`."""
|
||||
return np.sum(np.abs(a), axis=axis)
|
||||
|
||||
|
||||
def ssq(a, axis=None):
|
||||
"""Return the halved sum-of-squares of
|
||||
the array `a` along its axis `axis`."""
|
||||
return 0.5 * np.sum(np.square(a), axis=axis)
|
||||
|
||||
|
||||
def rms(a, axis=None):
|
||||
"""Return the root-mean-square of
|
||||
the array `a` along its axis `axis`."""
|
||||
return np.sqrt(np.mean(np.square(a), axis=axis))
|
||||
|
||||
|
||||
def cossim(x, y):
|
||||
"""Return the cosine-similarity of
|
||||
the arrays `x` and `y`.
|
||||
(The exact behavior of non-vector arguments is undefined for now.)"""
|
||||
return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
|
||||
|
||||
|
||||
class Regulator: # NOTE: should've been called Regularizer, whoops?
|
||||
# TODO: support linear constraints.
|
||||
|
||||
def __init__(self, *args, r1=0.0, r2=0.0, r2s=0.0, rinf=0.0, ridges=None):
|
||||
# this way i don't have to worry about arbitrary ordering:
|
||||
assert len(args) == 0, "all arguments of Regulator must be given by keywords"
|
||||
|
||||
self.r1 = float(r1) # L1 aka Lasso
|
||||
self.r2 = float(r2) # L2 aka Euclidian
|
||||
self.r2s = float(r2s) # L2-squared aka Ridge
|
||||
self.rinf = float(rinf) # L-infinity (L1-ball)
|
||||
|
||||
# "ridges" are just Ridge penalties with a non-zero bias,
|
||||
# pulling the solution towards `point` by `coeff`.
|
||||
# this is used in optimizers such as in arxiv:1801.02982.
|
||||
# FIXME: but it isn't part of the proximal equation!
|
||||
self.ridges = []
|
||||
if ridges is not None:
|
||||
for coeff, point in ridges:
|
||||
self.push_ridge(coeff, point)
|
||||
|
||||
def push_ridge(self, coeff, point):
|
||||
pair = (float(coeff), np.array(point, float, copy=False))
|
||||
self.ridges.append(pair)
|
||||
|
||||
def copy(self):
|
||||
"""creates a shallow copy of the Regulator."""
|
||||
return Regulator(
|
||||
r1=self.r1, r2=self.r2, r2s=self.r2s, rinf=self.rinf, ridges=self.ridges
|
||||
)
|
||||
|
||||
def regulate(self, x):
|
||||
# avoid using += here for autograd.
|
||||
y = 0
|
||||
if self.r1 > 0:
|
||||
y = y + self.r1 * np.sum(np.abs(x))
|
||||
if self.r2 > 0 or self.r2s > 0:
|
||||
sum_of_squares = np.sum(np.square(x))
|
||||
if self.r2 > 0:
|
||||
y = y + self.r2 * np.sqrt(sum_of_squares)
|
||||
if self.rinf > 0:
|
||||
y = y + self.rinf * np.max(np.abs(x))
|
||||
|
||||
if self.r2s > 0:
|
||||
y = y + self.r2s * 0.5 * sum_of_squares
|
||||
for coeff, point in self.ridges:
|
||||
y = y + coeff * ssq(x - point)
|
||||
return y
|
||||
|
||||
def solve(self, x, g, rate=1.0): # solves the proximal
|
||||
if rate > 0:
|
||||
y = x - rate * g
|
||||
else:
|
||||
assert rate == 0, rate
|
||||
return x
|
||||
for coeff, point in self.ridges:
|
||||
y += rate * coeff * point
|
||||
|
||||
if self.r1 > 0:
|
||||
y = np.maximum(np.abs(y) - self.r1 * rate, 0) * np.sign(y)
|
||||
if self.r2 > 0:
|
||||
y = np.maximum(1 - rate * self.r2 / np.linalg.norm(y), 0) * y
|
||||
if self.rinf > 0:
|
||||
y = np.maximum(1 - rate * self.rinf / np.max(np.abs(y)), 0) * y
|
||||
|
||||
denominator = 1
|
||||
if self.r2s > 0:
|
||||
denominator += rate * self.r2s
|
||||
for coeff, point in self.ridges:
|
||||
denominator += rate * coeff
|
||||
return y / denominator
|
||||
|
||||
|
||||
class RandomIndices:
|
||||
def __init__(self, n: int):
|
||||
self.reset(n)
|
||||
|
||||
def reset(self, n=None):
|
||||
if n is not None:
|
||||
assert n >= 1, n
|
||||
self.n = n
|
||||
self.indices = np.arange(n)
|
||||
self.i = 0
|
||||
np.random.shuffle(self.indices)
|
||||
|
||||
def draw(self, n):
|
||||
return zip(range(n), self)
|
||||
|
||||
def __next__(self):
|
||||
ret = self.indices[self.i]
|
||||
self.i += 1
|
||||
if self.i == self.n:
|
||||
self.i = 0
|
||||
self.reset()
|
||||
return ret
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
|
||||
def estimate_condition(f, g, x, scale=1.0, iters=100):
|
||||
"""Estimate the convexity and L-smoothness of the function `f`.
|
||||
`g` should return the gradient of `f`.
|
||||
Noise is centered around `x` with scale `scale`."""
|
||||
|
||||
# i.e. the condition of the hessian.
|
||||
|
||||
norm2 = lambda a: np.linalg.norm(a.ravel())
|
||||
convexity, smoothness = np.inf, 0
|
||||
|
||||
fx, gx = f(x), g(x)
|
||||
for i in range(iters):
|
||||
y = x + np.random.normal(0, scale, size=x.shape)
|
||||
fy, gy = f(y), g(y)
|
||||
|
||||
if 0:
|
||||
# f(y) >= f(x) + g(x) @ (y - x) + σ * ssq(x - y)
|
||||
σ = (fy - fx - np.dot(gx, (y - x))) / ssq(y - x)
|
||||
if σ < convexity:
|
||||
convexity = σ
|
||||
|
||||
# norm2(g(x) - g(y)) <= L * norm2(x - y)
|
||||
L = norm2(gx - gy) / norm2(y - x)
|
||||
# NOTE: i think the sqrt in norm2 can be pulled out like:
|
||||
# np.sqrt(np.sum(np.square(gx - gy)) / np.sum(np.square(x - y)))
|
||||
if L > smoothness:
|
||||
smoothness = L
|
||||
else:
|
||||
# (gx - gy) @ (x - y) >= σ * norm2(x - y)**2
|
||||
blah = (gx - gy) @ (x - y) / np.sum(np.square(x - y))
|
||||
if blah < convexity:
|
||||
convexity = blah
|
||||
# f(y) >= f(x) + dot(g(x), y - x) + σ/2 * sum(square(x - y))
|
||||
# TODO: is this still correct? it's a different definition...
|
||||
if blah > smoothness:
|
||||
smoothness = blah
|
||||
|
||||
return convexity, smoothness
|
||||
|
||||
|
||||
def minimum_enclosing_ball(xa, xb, ra, rb):
|
||||
"""Given the *squared* radii `ra` and `rb` of balls
|
||||
centered at `xa` and `xb` respectively,
|
||||
return the center and squared radius of a ball
|
||||
encompassing their (possibly negative) intersection."""
|
||||
|
||||
xa, xb = np.asfarray(xa), np.asfarray(xb)
|
||||
delnorm2 = np.sum(np.square(xa - xb))
|
||||
if delnorm2 >= abs(ra - rb):
|
||||
xc = ((xa + xb) - (ra - rb) / delnorm2) / 2
|
||||
rc = np.square(delnorm2 + rb - ra) / delnorm2 / 4
|
||||
return xc, rc
|
||||
elif delnorm2 < ra - rb:
|
||||
return xb, rb
|
||||
else:
|
||||
return xa, ra
|
||||
|
||||
|
||||
def gss(f, a, b, tol=1e-5): # golden section search
|
||||
"""Given a function `f` with a single local minimum in
|
||||
the interval [`a`, `b`], return a subset interval
|
||||
[`c`, `d`] that contains the minimum with `d - c <= tol`."""
|
||||
# via https://en.wikipedia.org/wiki/Golden-section_search
|
||||
|
||||
a, b = min(a, b), max(a, b)
|
||||
h = b - a
|
||||
if h <= tol:
|
||||
return (a, b)
|
||||
|
||||
# required steps to achieve tolerance:
|
||||
n = int(np.ceil(np.log(tol / h) / np.log(invphi)))
|
||||
|
||||
c = a + invphi2 * h
|
||||
d = a + invphi * h
|
||||
y_c = f(c)
|
||||
y_d = f(d)
|
||||
|
||||
for k in range(n - 1):
|
||||
if y_c < y_d:
|
||||
b, d, y_d = d, c, y_c
|
||||
h = invphi * h
|
||||
c = a + invphi2 * h
|
||||
y_c = f(c)
|
||||
else:
|
||||
a, c, y_c = c, d, y_d
|
||||
h = invphi * h
|
||||
d = a + invphi * h
|
||||
y_d = f(d)
|
||||
|
||||
if y_c < y_d:
|
||||
return (a, d)
|
||||
else:
|
||||
return (c, b)
|
||||
|
||||
|
||||
def gssm(f, a, b, tol=1e-5): # golden section search (scalar result)
|
||||
return np.mean(gss(f, a, b, tol=tol))
|
||||
|
||||
|
||||
__all__ = """
|
||||
Plot
|
||||
CleanPlot
|
||||
CleanPlotUnity
|
||||
SquareCenterPlot
|
||||
|
||||
RandomIndices
|
||||
Regulator
|
||||
|
||||
partial
|
||||
rpartial
|
||||
invsqrt
|
||||
sab
|
||||
ssq
|
||||
rms
|
||||
cossim
|
||||
|
||||
estimate_condition
|
||||
minimum_enclosing_ball
|
||||
gss
|
||||
""".split()
|
Loading…
Reference in a new issue