backyard/bitten/bitten.py
2022-06-07 06:54:47 +02:00

738 lines
22 KiB
Python

#!/usr/bin/env python3
# prepend underscores so these don't leak to __all__.
from hashlib import sha256 as _sha256
from itertools import chain as _chain
from os import environ as _environ
from pathlib import Path as _Path
from traceback import print_exception as _print_exception
import ast
import inspect
import numpy as np
# NOTE: the environment variable "BITTEN_CACHE" takes precedence here.
# Fallback cache file path: _bite uses it when cache=True and
# BITTEN_CACHE is unset or empty.
default_cache = "bitten.cache"
# Keyword arguments that _bite merges into optimizer_kwargs (only for keys
# the caller did not supply) before forwarding them to biteopt.
optimizer_defaults = dict(tol="weak", depth=1, attempts=1)
def pack(*args):
    """Bundle every positional argument into a single tuple, in order."""
    return tuple(args)
def _is_in(needle, *haystack):
# like the "in" keyword, but uses "is" instead of "==".
for thing in haystack:
if thing is needle:
return True
return False
def _hashable(a): # currently only used for caching adjustments
if isinstance(a, np.ndarray):
if a.ndim > 0:
# key = tuple(a) # probably won't work for ndim >= 2
key = a.shape + tuple(a.ravel())
else:
key = a.item()
else:
key = a # just hope for the best
return key
def _promote_all(*dtypes):
# new_dtypes = [np.promote_types(t, _promote_all(*dtypes)) for t in dtypes]
a, dtypes = dtypes[0], dtypes[1:] # TODO: use next(iter(dtypes)) pattern instead?
new_dtypes = [np.promote_types(a, b) for b in dtypes]
while len(new_dtypes) > 1:
a, dtypes = new_dtypes[0], new_dtypes[1:]
new_dtypes = [np.promote_types(a, b) for b in dtypes]
return new_dtypes[0]
def _promote_equal(a, b, *etc):
a = np.asanyarray(a)
b = np.asanyarray(b)
if etc:
etc_arrays = (np.asanyarray(e) for e in etc)
etc_dtypes = (e.dtype for e in etc_arrays)
all_arrays = (a, b, *etc_arrays)
all_dtypes = (a.dtype, b.dtype, *etc_dtypes)
promo = _promote_all(*all_dtypes)
return tuple(a.astype(promo) for a in all_arrays)
if a.dtype != b.dtype:
promo = np.promote_types(a.dtype, b.dtype)
a = a.astype(promo)
b = b.astype(promo)
assert a.dtype == b.dtype
return a, b
def _remove_docstring(node):
# via: https://stackoverflow.com/a/49998190
# remove all the doc strings in a FunctionDef or ClassDef as node.
if not (isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef)):
return
if len(node.body) != 0:
docstr = node.body[0]
if isinstance(docstr, ast.Expr) and isinstance(docstr.value, ast.Str):
node.body.pop(0)
def _prepare_function(func):
    # Build a byte string that identifies the *code* of "func" for hashing:
    # its source is parsed into an AST, the function's name is blanked and
    # docstrings are stripped, so renaming the function or editing its
    # docstrings leaves the returned bytes (and hence the hash) unchanged.
    # adapted from: https://stackoverflow.com/a/49998190
    # TODO: (optionally) strip the initial decorator from the AST as well.
    # TODO: clear pure constant expressions as well? like a simple string or number on its own line.
    source = inspect.getsource(func)

    # Drop the leading indentation (measured on the very first line) from
    # every line, so that nested definitions can be parsed on their own.
    margin = len(source) - len(source.lstrip())
    if margin:
        source = "\n".join(row[margin:] for row in source.splitlines())

    tree = ast.parse(source)
    assert len(tree.body) == 1 and isinstance(tree.body[0], ast.FunctionDef)

    # The name must not influence the hash.
    tree.body[0].name = ""

    # Neither should any docstring, at any nesting level.
    for child in ast.walk(tree):
        _remove_docstring(child)

    # Serialize the AST without field names and hand it back as bytes.
    return ast.dump(tree, annotate_fields=False).encode()
def _flatten(it):
# TODO: make this more robust (and more picky? and faster?).
res = []
for value in it:
if hasattr(value, "__len__"):
res += _flatten(value)
else:
res.append(value)
return res
def _penalize(constraints, *x, tol=1e-5, scale=1e10, growth=4.0):
    """Turn constraint violations into one large additive penalty.

    Every constraint is called as ``cons(*x)`` and may return a single value
    or a nested collection of values; the results are flattened into one
    list of penalties. A penalty counts as violated when it exceeds ``tol``.

    Returns ``scale * (number_of_violations + weighted_sum)``, where each
    violated penalty ``p`` contributes ``p + p**2 + p**3`` to the weighted
    sum. With no penalties at all, the result is zero.
    """
    # NOTE: maybe growth shouldn't be configurable unless
    #       it also affects the order of the penalty equation (currently 3).
    # NOTE: different ordering of constraints can result in different solutions.
    # NOTE: this function doesn't use numpy, so you can copypaste it elsewhere.
    penalties = _flatten(cons(*x) for cons in constraints)
    if not penalties:
        # nothing to violate; this also avoids dividing by zero below.
        return scale * 0.0
    # spread the total growth factor evenly across all the penalties.
    growth = growth ** (1.0 / len(penalties))
    unsat = sum(p > tol for p in penalties)
    penalsum = 0.0
    for p in penalties:
        penalsum *= growth  # always happens so each penalty gets its own stratum
        if p > tol:
            penalsum += p + p * p + p * p * p
    return scale * (unsat + penalsum)
def _penalize2(constraints, *x, tol=1e-5, scale=1e10, growth=3.0):
    """Variant of the constraint penalty that omits the quadratic term.

    Every constraint is called as ``cons(*x)``; the (possibly nested)
    results are flattened into one list of penalties. A penalty counts as
    violated when it exceeds ``tol``.

    Returns ``scale * (number_of_violations + weighted_sum)``, where each
    violated penalty ``p`` contributes ``p + p**3`` to the weighted sum.
    With no penalties at all, the result is zero.
    """
    # updated from upstream (v2022.11)
    penalties = _flatten(cons(*x) for cons in constraints)
    if not penalties:
        # nothing to violate; this also avoids dividing by zero below.
        return scale * 0.0
    # spread the total growth factor evenly across all the penalties.
    growth = growth ** (1.0 / len(penalties))
    unsat = sum(p > tol for p in penalties)
    penalsum = 0.0
    for p in penalties:
        penalsum *= growth  # always happens so each penalty gets its own stratum
        if p > tol:
            penalsum += p + p * p * p
    return scale * (unsat + penalsum)
def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0):
    # DEPRECATED
    # Legacy argument order. Note that "x" is forwarded as ONE positional
    # argument, so each constraint is called as cons(x), not cons(*x).
    options = dict(tol=tol, scale=scale, growth=growth)
    return _penalize(constraints, x, **options)
def count_unsat(x, constraints, tol=1e-5):
    # DEPRECATED
    # Count how many constraints, each called as cons(*x), exceed "tol".
    results = [cons(*x) for cons in constraints]
    violated = 0
    for result in results:
        violated = violated + (result > tol)
    return violated
class Impure:  # TODO: rename? volatile? aaa the word is on the tip of my tongue
    """Base for objects that can feed extra data into the cache hash."""

    def needs_hashing(self):
        """Return the values to include in the hash; nothing by default."""
        return tuple()
class Objective(Impure):
    """Marker base class for anything that yields a value to be optimized."""
class Hyper(Impure):
    """Marker base class for hyperparameters given as default argument values."""
class Minimize(Objective):
    """Objective whose value is the function's own output for fixed arguments."""

    def __init__(self, *args):
        # positional arguments that will be passed to the function.
        self.args = args

    # TODO: rename to "realize" to amalgamate with the non-Bounds Hypers?
    def compute_with(self, approx, **kw_hypers):
        """Call ``approx`` with the stored arguments and return its result."""
        result = approx(*self.args)
        return result

    def needs_hashing(self):
        """The stored arguments take part in the cache hash."""
        return (*self.args,)
class Maximize(Objective):
    """Objective whose value is the negated output of the function."""

    def __init__(self, *args):
        # positional arguments that will be passed to the function.
        self.args = args

    def compute_with(self, approx, **kw_hypers):
        """Call ``approx`` with the stored arguments and negate the result."""
        result = approx(*self.args)
        return -result

    def needs_hashing(self):
        """The stored arguments take part in the cache hash."""
        return (*self.args,)
class AbstractError(Objective):
    """Shared machinery for objectives that compare predictions to targets.

    Subclasses supply ``_compute(predictions)``, which receives the wrapped
    predictions and returns the error value.
    """

    def __init__(self, inputs, outputs, iwrap=pack, owrap=np.asanyarray):
        for wrapper in (iwrap, owrap):
            assert callable(wrapper), type(wrapper)
        self.inputs = np.asanyarray(inputs)
        self.outputs = np.asanyarray(outputs)
        # iwrap maps the stored inputs to the positional arguments of the
        # function; owrap converts whatever the function returns.
        self.iwrap = iwrap
        self.owrap = owrap

    def compute_with(self, fun, **kw_hypers):
        """Run ``fun`` on the inputs and score its output against the targets."""
        predictions = self.owrap(fun(*self.iwrap(self.inputs)))
        assert predictions.shape == self.outputs.shape
        return self._compute(predictions)

    def needs_hashing(self):
        """Inputs and targets both take part in the cache hash."""
        return (self.inputs, self.outputs)
class MeanSquaredError(AbstractError):
    """Mean of the squared differences between targets and predictions."""

    def _compute(self, predictions):
        residual = np.subtract(self.outputs, predictions)
        return np.mean(np.square(residual))
class MeanAbsoluteError(AbstractError):
    """Mean of the absolute differences between targets and predictions."""

    def _compute(self, predictions):
        residual = np.subtract(self.outputs, predictions)
        return np.mean(np.abs(residual))
class PeakError(AbstractError):
    """Largest absolute difference between targets and predictions."""

    def _compute(self, predictions):
        residual = np.subtract(self.outputs, predictions)
        return np.max(np.abs(residual))
class HuberError(AbstractError):
    """Element-wise Huber loss between targets and predictions.

    Errors whose magnitude is at most ``delta`` are penalized quadratically,
    larger ones linearly. The result has one loss per element; it is not
    reduced to a scalar here.
    """

    def _compute(self, predictions, delta=1.0):
        error = self.outputs - predictions
        magnitude = np.abs(error)
        # The branch must be picked by magnitude: comparing the signed error
        # against delta sent every large negative error down the quadratic
        # branch.
        return np.where(
            magnitude <= delta,
            0.5 * np.square(error),
            delta * (magnitude - 0.5 * delta),
        )
class Accuracy(AbstractError):
    """Fraction of rows where predictions and targets peak at the same index.

    The peak is located with argmax along the last axis.
    """

    def _compute(self, predictions):
        predicted = np.argmax(predictions, axis=-1)
        expected = np.argmax(self.outputs, axis=-1)
        return np.mean(predicted == expected)
class Crossentropy(AbstractError):
    """Mean crossentropy between targets and clipped predictions."""

    def crossentropy(self, predictions, eps=1e-6):
        """Return the crossentropy of ``predictions`` against the targets.

        Predictions are clipped to ``[eps, 1 - eps]`` so that neither
        logarithm receives zero. The per-element terms are summed over the
        last axis and then averaged.
        """
        p, t = np.clip(predictions, eps, 1 - eps), self.outputs
        f = np.sum(-t * np.log(p) - (1 - t) * np.log(1 - p), axis=-1)
        return np.mean(f)

    def _compute(self, predictions):
        # compute_with() (inherited) calls self._compute; without this hook
        # the class could not be used as an objective at all.
        return self.crossentropy(predictions)
class Constrain(Objective):
    """Objective that penalizes hyperparameter values violating constraints.

    Each constraint is a callable; it receives the realized hyperparameter
    values as positional arguments, in the order of the keyword mapping.
    """

    def __init__(self, *constraints, tol=1e-5):
        assert all(callable(cons) for cons in constraints)
        self.constraints = constraints
        self.tol = float(tol)

    def penalize(self, *args):
        """Return the penalty for ``args`` under the stored constraints."""
        return _penalize(self.constraints, *args, tol=self.tol)

    def compute_with(self, fun, **kw_hypers):
        """Penalize the current hyperparameter values; ``fun`` is not called."""
        hyper_values = kw_hypers.values()
        return self.penalize(*hyper_values)

    def needs_hashing(self):
        # only the tolerance is hashed; the constraint callables are not.
        return (self.tol,)
class L1(Objective):
    """Scaled sum of the absolute values of all hyperparameter values."""

    def __init__(self, scale):
        self.scale = float(scale)

    def compute_with(self, fun, **kw_hypers):
        """Return ``scale * sum(|value|)`` over every hyperparameter element.

        ``fun`` is not called.
        """
        # Flatten each value separately before joining them: stacking the
        # raw values fails when scalar and array hyperparameters are mixed.
        pieces = [np.ravel(v) for v in kw_hypers.values()]
        a = np.concatenate(pieces) if pieces else np.zeros(0)
        return self.scale * np.sum(np.abs(a))

    def needs_hashing(self):
        """Only the scale takes part in the cache hash."""
        return (self.scale,)
class L2(Objective):
    """Scaled sum of the squares of all hyperparameter values."""

    def __init__(self, scale):
        self.scale = float(scale)

    def compute_with(self, fun, **kw_hypers):
        """Return ``scale * sum(value**2)`` over every hyperparameter element.

        ``fun`` is not called.
        """
        # Flatten each value separately before joining them: stacking the
        # raw values fails when scalar and array hyperparameters are mixed.
        pieces = [np.ravel(v) for v in kw_hypers.values()]
        a = np.concatenate(pieces) if pieces else np.zeros(0)
        return self.scale * np.sum(np.square(a))

    def needs_hashing(self):
        """Only the scale takes part in the cache hash."""
        return (self.scale,)
class Const(Hyper):
    """Hyperparameter that always takes one fixed value."""

    def __init__(self, value):
        self.value = value

    def realize(self, kw_hypers, *, name, approx=None):
        """Store the fixed value under ``name``; ``approx`` is never needed."""
        kw_hypers[name] = self.value
        realized = True
        return realized

    def needs_hashing(self):
        """The fixed value takes part in the cache hash."""
        return (self.value,)
class Bound(Hyper):
    """Scalar hyperparameter searched between a lower and an upper limit."""

    def __init__(self, lower, upper):
        # both limits end up as arrays of one common dtype.
        limits = _promote_equal(lower, upper)
        self.lower, self.upper = limits

    @property
    def bounds(self):  # allowed to be overridden for log-scale, etc.
        """The ``(lower, upper)`` pair handed to the optimizer."""
        return (self.lower, self.upper)

    def compute(self, a):  # allowed to be overridden for log-scale, etc.
        """Map the optimizer's raw value to the actual value (identity here)."""
        return a

    def needs_hashing(self):
        """Both limits take part in the cache hash."""
        return (self.lower, self.upper)
class BoundArray(Hyper):  # TODO: write and use a Mixin for Array types?
    """Array hyperparameter: ``dims`` values sharing one pair of limits."""

    def __init__(self, lower, upper, dims, dtype=np.float64):
        # both limits end up as arrays of one common dtype.
        limits = _promote_equal(lower, upper)
        self.lower, self.upper = limits
        self.dims = int(dims)
        self.dtype = dtype

    @property
    def bounds(self):  # allowed to be overridden for log-scale, etc.
        """The ``(lower, upper)`` pair applied to each element."""
        return (self.lower, self.upper)

    def vectorize(self, a):  # the Array equivalent of self.compute(a)
        """Copy the raw values into a fresh array of the configured dtype."""
        return np.array(a, dtype=self.dtype, copy=True)

    def needs_hashing(self):
        """Limits and element count take part in the cache hash."""
        return (self.lower, self.upper, self.dims)
class Round(Bound):
    """Bounded hyperparameter whose value is rounded before use."""

    def __init__(self, lower, upper, decimals=0):
        self.decimals = int(decimals)
        super().__init__(lower, upper)

    def compute(self, a):
        """Round the raw value to the configured number of decimals."""
        return np.round(a, decimals=self.decimals)

    def needs_hashing(self):
        """Limits and rounding precision take part in the cache hash."""
        return (self.lower, self.upper, self.decimals)
class Adjustment(Hyper):
    """Hyperparameter derived from the function's own output.

    Its value is ``y - approx(x)``: the difference between a target ``y``
    and what the function currently returns at ``x``. Until the function can
    be evaluated, ``default`` stands in.
    """

    def __init__(self, x, y, default=0):
        promoted = _promote_equal(x, y, default)
        self._x, self._y, self._default = promoted

    @property
    def x(self):  # allowed to be overridden for log-scale, etc.
        return self._x

    @property
    def y(self):  # allowed to be overridden for log-scale, etc.
        return self._y

    @property
    def default(self):  # allowed to be overridden for log-scale, etc.
        return self._default

    def perform(self, result):
        """Return the difference between the target and ``result``."""
        return self.y - result

    def realize(self, kw_hypers, *, name, approx=None):
        """Write this hyperparameter's value into ``kw_hypers[name]``.

        Returns True when the value was computed from ``approx``, and False
        when no ``approx`` was available and the default was stored instead.
        """
        if approx is not None:
            # TODO: abstract better (especially multiple args)
            kw_hypers[name] = self.perform(approx(self.x))
            return True  # realized
        kw_hypers[name] = self.default
        return False  # not realized

    def needs_hashing(self):
        """x, y and the default all take part in the cache hash."""
        return (self._x, self._y, self._default)
class Storage:  # TODO: use dict-like interface?
    """Append-only key/value store backed by one file.

    Every line holds one entry: the key, a comma, then the values as
    comma-separated floats. Keys are bytes and must not contain a comma.
    """

    def __init__(self, filepath):
        open(filepath, "ab").close()  # ensure it exists and is readable/writable
        self.filepath = filepath

    def get(self, key=None, default=None):
        """Return the values stored under ``key`` as a float array.

        When the key occurs more than once, the last entry wins. When it
        does not occur at all, ``default`` is returned. An entry written
        with no values yields an empty array.
        """
        tokens = None
        with open(self.filepath, "rb") as f:
            for line in f:
                stored, _, rest = line.partition(b",")
                if key == stored:
                    rest = rest.rstrip(b"\r\n")
                    # an empty payload means "zero values", not one empty token.
                    tokens = rest.split(b",") if rest else []
        if tokens is None:
            return default
        return np.array([float(x) for x in tokens], float)

    def set(self, key, values):
        """Append an entry mapping ``key`` (bytes) to ``values``."""
        # Go through float() before repr(): NumPy 2 scalars repr as
        # "np.float64(1.5)", which get() would be unable to parse.
        formatted = ",".join(repr(float(x)) for x in values).encode()
        # FIXME: doesn't overwrite old value, if any.
        with open(self.filepath, "ab") as f:
            f.write(key + b"," + formatted + b"\n")
class Cache:  # TODO: rename to something a little more specific?
    """Incrementally hashes inputs and uses the digest as a Storage key."""

    def __init__(self, storage):
        assert isinstance(storage, Storage), type(storage)
        self.storage = storage
        self.hashing = _sha256()
        self._hashed = None  # most recent digest, if any
        self._hashed_n = 0  # how many items that digest covers
        self._n = 0  # how many items have been hashed so far

    def hash(self, *hashme):
        """Feed each item into the running hash.

        A callable is only accepted as the very first item ever hashed;
        everything else is converted to an array and hashed by its raw bytes.
        """
        for item in hashme:
            if callable(item):
                if self._n != 0:
                    raise Exception(
                        "functions other than the one being decorated cannot be hashed (for now?)"
                    )
                payload = _prepare_function(item)
            else:
                payload = np.asanyarray(item).tobytes()
            self.hashing.update(payload)
            self._n += 1

    def finish(self):
        """Compute, remember and return the digest of everything hashed so far."""
        digest = self.hashing.hexdigest().encode()
        self._hashed = digest
        self._hashed_n = self._n
        return digest

    @property
    def hashed(self):
        """The current digest, recomputed only if more items were hashed since."""
        stale = self._hashed is None or self._n != self._hashed_n
        return self.finish() if stale else self._hashed

    def get(self):
        """Look up the values stored under the current digest."""
        return self.storage.get(self.hashed)

    def set(self, values):
        """Store ``values`` under the current digest."""
        return self.storage.set(self.hashed, values)
class HashMe(Impure):
    """Carrier for arbitrary extra objects that should affect the cache hash."""

    def __init__(self, *objects):
        self.objects = objects

    def needs_hashing(self):
        """Hand back the wrapped objects exactly as they were given."""
        wrapped = self.objects
        return wrapped
def _categorize_objectives(objectives):
yes, no = [], []
for subjective in objectives: # forgive me
# if isinstance(objective, Constrain): do something?
if isinstance(subjective, Objective):
yes.append(subjective)
else:
no.append(subjective)
return yes, no
def _categorize_parameters(params): # or "organize"?
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
inputs, hypers, positional, ind = {}, {}, {}, 0
for param in params:
nodefault = param.default != inspect.Parameter.empty
ishyper = isinstance(param.default, Hyper)
if nodefault and ishyper:
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
raise Exception("hyperparameters must be accessible through keywords")
hypers[param.name] = param.default
if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
positional[param.name] = ind
ind += 1
elif nodefault:
ind += 1
elif param.kind in positional_kinds:
inputs[param.name] = param.default
return inputs, hypers, positional
def _evaluator(objective, budget=1):
    # Stand-in for a real optimizer when there are no hyperparameters to
    # search over (most optimizers don't fare well when there's literally
    # nothing for them to optimize). The objective is simply called "budget"
    # times with an empty parameter vector, and the smallest value seen is
    # reported in the same result type that the optimizer would return.
    from scipybiteopt import OptimizeResult

    point = ()
    best, evaluations = np.inf, 0
    for _ in range(budget):
        candidate = objective(point)
        if candidate < best:
            best = candidate
        evaluations += 1
    return OptimizeResult(x=point, fun=best, nfev=evaluations)
def _set_new_defaults(fun, positional, kw_hypers):
# instead of creating a new function masquerading as the old,
# or attempting to change the function's signature,
# simply change the function's default values
# from dummy values to their optimized values.
unfun = fun
while hasattr(unfun, "__wrapped__"):
unfun = unfun.__wrapped__
defaults = unfun.__defaults__
kwdefaults = unfun.__kwdefaults__
if defaults:
new_defaults = list(defaults)
for k, v in kw_hypers.items():
if k in positional:
ind = positional[k]
new_defaults[ind] = v
else:
kwdefaults[k] = v
if defaults:
unfun.__defaults__ = tuple(new_defaults)
if kwdefaults:
unfun.__kwdefaults__ = kwdefaults
# return unfun
def _bite(
    fun,
    *objectives,
    budget=1_000_000,  # maximum allowed evaluations of "fun"
    cache=False,  # True for the default cache file, or a Cache instance
    multimethod="sum",  # how to handle multiple objectives
    deterministic="unknown",
    optimizer_kwargs=None,
    _debug=False,
    _optim="biteopt",  # ignored for now
):
    """Optimize the Hyper-valued defaults of ``fun`` and bake in the results.

    Parameters of ``fun`` whose default is a Hyper are treated as
    hyperparameters. Bound and BoundArray hypers are searched by biteopt so
    that the combined value of ``objectives`` is minimized; other hypers are
    filled in through their ``realize`` method. Afterwards, the defaults of
    ``fun`` are overwritten with the resulting values.

    When a cache is in use, the optimized values are looked up under (and
    stored under) a hash of the function's code plus whatever the
    objectives, the non-objective positional arguments and the hypers report
    through ``needs_hashing()``.

    Returns ``fun`` itself, with ``__optimized__`` set to the realized
    keyword values and, when the optimizer actually ran, ``__result__`` set
    to the optimizer's result object.
    """
    # NOTE: cache only properly supports floats for now.
    from scipybiteopt import biteopt as _biteopt

    # normalize "cache" to either a Cache instance or None.
    if cache is True:  # intentionally not using ==
        env_cache = _environ.get("BITTEN_CACHE", None)
        storage = Storage(env_cache if env_cache else default_cache)
        cache = Cache(storage)
    elif cache is False or cache is None:  # intentionally not using ==
        cache = None
    else:
        assert isinstance(cache, Cache), type(cache)

    # apply defaults.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    for k, v in optimizer_defaults.items():
        if k not in optimizer_kwargs:
            optimizer_kwargs[k] = v

    # NOTE(review): only "sum" passes this assertion, so the callable and
    # fallback branches of "multimethod" further down are currently
    # unreachable while assertions are enabled.
    assert _is_in(multimethod, "sum"), f"unknown multiobjective method: {multimethod}"
    assert _is_in(deterministic, False, True, "unknown"), deterministic

    sig = inspect.signature(fun)
    params = sig.parameters.values()
    objectives, nonjectives = _categorize_objectives(objectives)
    inputs, hypers, positional = _categorize_parameters(params)

    # one (lower, upper) pair per scalar that the optimizer searches over,
    # in the same order that make_keywords() consumes them.
    linear_bounds = []
    for key, hyper in hypers.items():
        if isinstance(hyper, BoundArray):  # TODO: ArrayMixin or whatever
            linear_bounds += [hyper.bounds] * hyper.dims
        elif isinstance(hyper, Bound):
            linear_bounds.append(hyper.bounds)

    # memo caches objective values when deterministic is True.
    memo = {}
    memo_secret = object()  # dummy object for uniqueness (not used below)

    def make_approx(kw_hypers):
        # bind the given hyperparameter values to "fun" as keyword arguments.
        # print(kw_hypers)
        def approx(*args):
            return fun(*args, **kw_hypers)
            # try:
            # return fun(*args, **kw_hypers)
            # except Exception as e:
            # _print_exception(e)
            # return -np.inf

        return approx

    def make_keywords(params):
        # translate the optimizer's flat parameter vector into a mapping of
        # hyperparameter name -> value.
        kw_hypers = {}
        other_hypers = []
        remaining = iter(params)
        all_realized = True
        for name, hyper in hypers.items():
            if isinstance(hyper, BoundArray):
                # an array hyper consumes "dims" consecutive parameters.
                params = [next(remaining) for _ in range(hyper.dims)]
                kw_hypers[name] = hyper.vectorize(params)
            elif isinstance(hyper, Bound):
                param = next(remaining)
                kw_hypers[name] = hyper.compute(param)
            else:
                # first pass without "approx": the hyper either realizes
                # itself right away or stores a placeholder value.
                other_hypers.append((name, hyper))
                realized = hyper.realize(kw_hypers, name=name)
                if not realized:
                    all_realized = False
        # repeat, now with a callable "approx", until every hyper reports
        # that it has been realized.
        while not all_realized:
            # TODO: is there a design pattern to accomplish this that makes the user
            # conscious of the possibility of accidentally infinite-looping?
            approx = make_approx(kw_hypers)
            all_realized = True
            for (name, hyper) in other_hypers:
                realized = hyper.realize(kw_hypers, name=name, approx=approx)
                if not realized:
                    all_realized = False
        return kw_hypers

    def objective(params):
        # the function handed to the optimizer: combine all objective values
        # for one candidate parameter vector.
        kw_hypers = make_keywords(params)
        approx = make_approx(kw_hypers)
        objective_values = []
        for subjective in objectives:  # forgive me
            if deterministic is True:
                # NOTE(review): float(x) presumably fails for array-valued
                # hypers (e.g. BoundArray) — confirm before relying on
                # deterministic=True with those.
                key = (subjective,) + tuple(float(x) for x in kw_hypers.values())
                if key in memo:
                    value = memo[key]
                    # print("MEMO:", key, value)
                else:
                    value = subjective.compute_with(approx, **kw_hypers)
                    memo[key] = value
            else:
                value = subjective.compute_with(approx, **kw_hypers)
            objective_values.append(value)
        if multimethod == "sum":
            return np.sum(objective_values)
        elif callable(multimethod):
            return multimethod(objective_values)
        else:
            # huh? just do this i guess.
            return objective_values[0]

    # build the cache key and look for previously optimized values.
    cached = None
    if cache:
        cache.hash(fun)
        for obj in _chain(objectives, nonjectives, hypers.values()):
            for h in obj.needs_hashing():
                cache.hash(h)
        cached = cache.get()
        if _debug:
            print("HASH:", cache.hashed, sep="\n")

    if cached is None:
        if _debug:  # dirty test for exceptions
            center = np.mean(linear_bounds, axis=-1) if hypers else ()
            print("FIRST VALUE:", objective(center), sep="\n")
        if hypers:
            res = _biteopt(objective, linear_bounds, iters=budget, **optimizer_kwargs)
        else:
            # nothing to search over: just evaluate the objective. a single
            # evaluation is enough when the function is deterministic.
            if deterministic is True:
                res = _evaluator(objective, 1)
            else:
                res = _evaluator(objective, budget)
        if _debug:
            print("RES:", res, sep="\n")
        optimized = res.x
    else:
        assert len(cached) == len(linear_bounds)
        optimized = cached

    # only write to the cache when the values were freshly computed.
    if cache and cached is None:
        cache.set(optimized)

    kw_hypers = make_keywords(optimized)
    _set_new_defaults(fun, positional, kw_hypers)
    fun.__optimized__ = kw_hypers  # secret!
    if cached is None:
        fun.__result__ = res
    return fun
def bite(*objectives, **config):
    # Decorator factory: @bite(objectives..., option=...) runs _bite on the
    # decorated function with the given objectives and options, and the
    # decorated name is bound to whatever _bite returns.
    # (keeping the real work in _bite prevents an unnecessary level of indentation.)
    return lambda fun: _bite(fun, *objectives, **config)
# this is similar to default behaviour of having no __all__ variable at all,
# but ours ignores modules as well. this allows for `import sys` and such
# without clobbering `from our_module import *`.
# every module-level name defined above this point is exported, unless it is
# bound to a module or starts with an underscore.
__all__ = [
    k for k, v in locals().items() if not inspect.ismodule(v) and not k.startswith("_")
]
if __name__ == "__main__":
    # Small demonstration, only run when this file is executed as a script.
    # non-associative example from:
    # https://numpy.org/doc/stable/reference/generated/numpy.promote_types.html
    # the first line shows which string dtype our reduction order settles on;
    # the second prints the promotion of the two candidate dtypes.
    # NOTE(review): string/integer promotion may be rejected by newer numpy
    # releases — confirm this demo still runs on the targeted version.
    print("S4 or S6?", _promote_all("S", "i1", "u1"))
    print(np.promote_types("S4", "S6"))