From 4c113eed155373761235628eb23e10855bb80645 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 7 Jun 2022 06:54:47 +0200 Subject: [PATCH] add bitten --- bitten/README.md | 0 bitten/bitten.py | 738 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 738 insertions(+) create mode 100644 bitten/README.md create mode 100644 bitten/bitten.py diff --git a/bitten/README.md b/bitten/README.md new file mode 100644 index 0000000..e69de29 diff --git a/bitten/bitten.py b/bitten/bitten.py new file mode 100644 index 0000000..3d28ebe --- /dev/null +++ b/bitten/bitten.py @@ -0,0 +1,738 @@ +#!/usr/bin/env python3 +# prepend underscores so these don't leak to __all__. +from hashlib import sha256 as _sha256 +from itertools import chain as _chain +from os import environ as _environ +from pathlib import Path as _Path +from traceback import print_exception as _print_exception +import ast +import inspect +import numpy as np + +# NOTE: the environment variable "BITTEN_CACHE" takes precedence here. +default_cache = "bitten.cache" +optimizer_defaults = dict(tol="weak", depth=1, attempts=1) + + +def pack(*args): + """Return positional arguments as a tuple.""" + return args + + +def _is_in(needle, *haystack): + # like the "in" keyword, but uses "is" instead of "==". + for thing in haystack: + if thing is needle: + return True + return False + + +def _hashable(a): # currently only used for caching adjustments + if isinstance(a, np.ndarray): + if a.ndim > 0: + # key = tuple(a) # probably won't work for ndim >= 2 + key = a.shape + tuple(a.ravel()) + else: + key = a.item() + else: + key = a # just hope for the best + return key + + +def _promote_all(*dtypes): + # new_dtypes = [np.promote_types(t, _promote_all(*dtypes)) for t in dtypes] + a, dtypes = dtypes[0], dtypes[1:] # TODO: use next(iter(dtypes)) pattern instead? + new_dtypes = [np.promote_types(a, b) for b in dtypes] + while len(new_dtypes) > 1: + a, dtypes = new_dtypes[0], new_dtypes[1:] + new_dtypes = [np.promote_types(a, b) for b in dtypes] + return new_dtypes[0] + + +def _promote_equal(a, b, *etc): + a = np.asanyarray(a) + b = np.asanyarray(b) + + if etc: + etc_arrays = (np.asanyarray(e) for e in etc) + etc_dtypes = (e.dtype for e in etc_arrays) + all_arrays = (a, b, *etc_arrays) + all_dtypes = (a.dtype, b.dtype, *etc_dtypes) + promo = _promote_all(*all_dtypes) + return tuple(a.astype(promo) for a in all_arrays) + + if a.dtype != b.dtype: + promo = np.promote_types(a.dtype, b.dtype) + a = a.astype(promo) + b = b.astype(promo) + assert a.dtype == b.dtype + return a, b + + +def _remove_docstring(node): + # via: https://stackoverflow.com/a/49998190 + # remove all the doc strings in a FunctionDef or ClassDef as node. + + if not (isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef)): + return + + if len(node.body) != 0: + docstr = node.body[0] + if isinstance(docstr, ast.Expr) and isinstance(docstr.value, ast.Str): + node.body.pop(0) + + +def _prepare_function(func): + # TODO: describe what this does and why. + # adapted from: https://stackoverflow.com/a/49998190 + # TODO: (optionally) strip the initial decorator from the AST as well. + + func_str = inspect.getsource(func) + + # Remove leading indentation, if any. + indent = len(func_str[: len(func_str) - len(func_str.lstrip())]) + if indent: + func_str = "\n".join(line[indent:] for line in func_str.splitlines()) + + module = ast.parse(func_str) + + assert len(module.body) == 1 and isinstance(module.body[0], ast.FunctionDef) + + # Clear function name so it doesn't affect the hash. + func_node = module.body[0] + func_node.name = "" + + # Clear all the doc strings. + for node in ast.walk(module): + _remove_docstring(node) + + # TODO: clear pure constant expressions as well? like a simple string or number on its own line. + + # Convert the ast to a string for hashing. + ast_str = ast.dump(module, annotate_fields=False) + return ast_str.encode() + + +def _flatten(it): + # TODO: make this more robust (and more picky? and faster?). + res = [] + for value in it: + if hasattr(value, "__len__"): + res += _flatten(value) + else: + res.append(value) + return res + + +def _penalize(constraints, *x, tol=1e-5, scale=1e10, growth=4.0): + # NOTE: maybe growth shouldn't be configurable unless + # it also affects the order of the penalty equation (currently 3). + # NOTE: different ordering of constraints can result in different solutions. + # NOTE: this function doesn't use numpy, so you can copypaste it elsewhere. + # growth = growth ** (1.0 / len(constraints)) + # penalties = [cons(*x) for cons in constraints] + penalties = _flatten(cons(*x) for cons in constraints) + growth = growth ** (1.0 / len(penalties)) + unsat = sum(p > tol for p in penalties) + # cubic = [p + p * p + p * p * p for p in penalties] + penalsum = 0.0 + for p in penalties: + penalsum *= growth # always happens so each penalty gets its own stratum + if p > tol: + penalsum += p + p * p + p * p * p + return scale * (unsat + penalsum) + + +def _penalize2(constraints, *x, tol=1e-5, scale=1e10, growth=3.0): + # updated from upstream (v2022.11) + penalties = _flatten(cons(*x) for cons in constraints) + growth = growth ** (1.0 / len(penalties)) + unsat = sum(p > tol for p in penalties) + penalsum = 0.0 + for p in penalties: + penalsum *= growth # always happens so each penalty gets its own stratum + if p > tol: + penalsum += p + p * p * p + return scale * (unsat + penalsum) + + +def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0): + # DEPRECATED + return _penalize(constraints, x, tol=tol, scale=scale, growth=growth) + + +def count_unsat(x, constraints, tol=1e-5): + # DEPRECATED + penalties = [cons(*x) for cons in constraints] + return sum(p > tol for p in penalties) + + +class Impure: # TODO: rename? volatile? aaa the word is on the tip of my tongue + def needs_hashing(self): + return () + + +class Objective(Impure): + pass + + +class Hyper(Impure): + pass + + +class Minimize(Objective): + def __init__(self, *args): + self.args = args + + # TODO: rename to "realize" to amalgamate with the non-Bounds Hypers? + def compute_with(self, approx, **kw_hypers): + return approx(*self.args) + + def needs_hashing(self): + return tuple(self.args) + + +class Maximize(Objective): + def __init__(self, *args): + self.args = args + + def compute_with(self, approx, **kw_hypers): + return -approx(*self.args) + + def needs_hashing(self): + return tuple(self.args) + + +class AbstractError(Objective): + def __init__(self, inputs, outputs, iwrap=pack, owrap=np.asanyarray): + assert callable(iwrap), type(iwrap) + assert callable(owrap), type(owrap) + self.inputs = np.asanyarray(inputs) + self.outputs = np.asanyarray(outputs) + self.iwrap = iwrap + self.owrap = owrap + + def compute_with(self, fun, **kw_hypers): + args = self.iwrap(self.inputs) + predictions = fun(*args) + predictions = self.owrap(predictions) + assert predictions.shape == self.outputs.shape + return self._compute(predictions) + + def needs_hashing(self): + return (self.inputs, self.outputs) + + +class MeanSquaredError(AbstractError): + def _compute(self, predictions): + error = self.outputs - predictions + return np.mean(np.square(error)) + + +class MeanAbsoluteError(AbstractError): + def _compute(self, predictions): + error = self.outputs - predictions + return np.mean(np.abs(error)) + + +class PeakError(AbstractError): + def _compute(self, predictions): + error = self.outputs - predictions + return np.max(np.abs(error)) + + +class HuberError(AbstractError): + def _compute(self, predictions, delta=1.0): + error = self.outputs - predictions + return np.where( + error <= delta, + 0.5 * np.square(error), + delta * (np.abs(error) - 0.5 * delta), + ) + + +class Accuracy(AbstractError): + def _compute(self, predictions): + p, t = predictions, self.outputs + correct = np.argmax(p, axis=-1) == np.argmax(t, axis=-1) + return np.mean(correct) + + +class Crossentropy(AbstractError): + def crossentropy(self, predictions, eps=1e-6): + p, t = np.clip(predictions, eps, 1 - eps), self.outputs + f = np.sum(-t * np.log(p) - (1 - t) * np.log(1 - p), axis=-1) + return np.mean(f) + + +class Constrain(Objective): + def __init__(self, *constraints, tol=1e-5): + for cons in constraints: + assert callable(cons) + self.constraints = constraints + self.tol = float(tol) + + def penalize(self, *args): + return _penalize(self.constraints, *args, tol=self.tol) + + def compute_with(self, fun, **kw_hypers): + return self.penalize(*kw_hypers.values()) + + def needs_hashing(self): + return (self.tol,) + + +class L1(Objective): + def __init__(self, scale): + self.scale = float(scale) + + def compute_with(self, fun, **kw_hypers): + a = np.array(list(kw_hypers.values())) + return self.scale * np.sum(np.abs(a)) + + def needs_hashing(self): + return (self.scale,) + + +class L2(Objective): + def __init__(self, scale): + self.scale = float(scale) + + def compute_with(self, fun, **kw_hypers): + a = np.array(list(kw_hypers.values())) + return self.scale * np.sum(np.square(a)) + + def needs_hashing(self): + return (self.scale,) + + +class Const(Hyper): + def __init__(self, value): + self.value = value + + # def compute(self, a): + # return a + + def realize(self, kw_hypers, *, name, approx=None): + kw_hypers[name] = self.value + return True # realized + + def needs_hashing(self): + return (self.value,) + + +class Bound(Hyper): + def __init__(self, lower, upper): + self.lower, self.upper = _promote_equal(lower, upper) + + @property + def bounds(self): # allowed to be overridden for log-scale, etc. + return self.lower, self.upper + + def compute(self, a): # allowed to be overridden for log-scale, etc. + return a + + def needs_hashing(self): + return (self.lower, self.upper) + + +class BoundArray(Hyper): # TODO: write and use a Mixin for Array types? + def __init__(self, lower, upper, dims, dtype=np.float64): + self.lower, self.upper = _promote_equal(lower, upper) + self.dims = int(dims) + self.dtype = dtype + + @property + def bounds(self): # allowed to be overridden for log-scale, etc. + return self.lower, self.upper + + def vectorize(self, a): # the Array equivalent of self.compute(a) + return np.array(a, self.dtype, copy=True) + + def needs_hashing(self): + return (self.lower, self.upper, self.dims) + + +class Round(Bound): + def __init__(self, lower, upper, decimals=0): + self.decimals = int(decimals) + super().__init__(lower, upper) + + def compute(self, a): + return np.round(a, self.decimals) + + def needs_hashing(self): + return (self.lower, self.upper, self.decimals) + + +class Adjustment(Hyper): + def __init__(self, x, y, default=0): + self._x, self._y, self._default = _promote_equal(x, y, default) + + @property + def x(self): # allowed to be overridden for log-scale, etc. + return self._x + + @property + def y(self): # allowed to be overridden for log-scale, etc. + return self._y + + @property + def default(self): # allowed to be overridden for log-scale, etc. + return self._default + + def perform(self, result): + return self.y - result + + def realize(self, kw_hypers, *, name, approx=None): + if approx is None: + kw_hypers[name] = self.default + return False # not realized + else: + result = approx(self.x) # TODO: abstract better (especially multiple args) + kw_hypers[name] = self.perform(result) + return True # realized + + def needs_hashing(self): + return (self._x, self._y, self._default) + + +class Storage: # TODO: use dict-like interface? + def __init__(self, filepath): + open(filepath, "ab").close() # ensure it exists and is readable/writable + self.filepath = filepath + + def get(self, key=None, default=None): + tokens = None + with open(self.filepath, "rb") as f: + for line in f: + stored, _, rest = line.partition(b",") + if key == stored: + tokens = rest.rstrip(b"\r\n").split(b",") + + cached = default + if tokens is not None: + cached = np.array([float(x) for x in tokens], float) + + return cached + + def set(self, key, values): + formatted = ",".join(repr(x) for x in values).encode() + # FIXME: doesn't overwrite old value, if any. + with open(self.filepath, "ab") as f: + f.write(key + b"," + formatted + b"\n") + + +class Cache: # TODO: rename to something a little more specific? + def __init__(self, storage): + assert isinstance(storage, Storage), type(storage) + self.storage = storage + self.hashing = _sha256() + self._hashed = None + self._hashed_n = 0 + self._n = 0 + + def hash(self, *hashme): + for h in hashme: + if self._n == 0 and callable(h): + data = _prepare_function(h) + elif callable(h): + raise Exception( + "functions other than the one being decorated cannot be hashed (for now?)" + ) + else: + data = np.asanyarray(h).tobytes() + + self.hashing.update(data) + self._n += 1 + del data + + def finish(self): + self._hashed = self.hashing.hexdigest().encode() + self._hashed_n = self._n + return self._hashed + + @property + def hashed(self): + if self._n != self._hashed_n or self._hashed is None: + return self.finish() + return self._hashed + + def get(self): + return self.storage.get(self.hashed) + + def set(self, values): + return self.storage.set(self.hashed, values) + + +class HashMe(Impure): + def __init__(self, *objects): + self.objects = objects + + def needs_hashing(self): + return self.objects + + +def _categorize_objectives(objectives): + yes, no = [], [] + for subjective in objectives: # forgive me + # if isinstance(objective, Constrain): do something? + if isinstance(subjective, Objective): + yes.append(subjective) + else: + no.append(subjective) + return yes, no + + +def _categorize_parameters(params): # or "organize"? + positional_kinds = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + + inputs, hypers, positional, ind = {}, {}, {}, 0 + for param in params: + nodefault = param.default != inspect.Parameter.empty + ishyper = isinstance(param.default, Hyper) + if nodefault and ishyper: + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + raise Exception("hyperparameters must be accessible through keywords") + hypers[param.name] = param.default + if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + positional[param.name] = ind + ind += 1 + elif nodefault: + ind += 1 + elif param.kind in positional_kinds: + inputs[param.name] = param.default + + return inputs, hypers, positional + + +def _evaluator(objective, budget=1): + # most optimizers don't fare well when there's + # literally nothing provided for them to optimize. + # this evaluator is a substitute for an optimizer in the case + # that there are no hyperparameters, but possibly still parameters. + from scipybiteopt import OptimizeResult + + x, fun, nfev = (), np.inf, 0 + for i in range(budget): + new_fun = objective(x) + if new_fun < fun: + fun = new_fun + nfev += 1 + return OptimizeResult(x=x, fun=fun, nfev=nfev) + + +def _set_new_defaults(fun, positional, kw_hypers): + # instead of creating a new function masquerading as the old, + # or attempting to change the function's signature, + # simply change the function's default values + # from dummy values to their optimized values. + unfun = fun + while hasattr(unfun, "__wrapped__"): + unfun = unfun.__wrapped__ + + defaults = unfun.__defaults__ + kwdefaults = unfun.__kwdefaults__ + if defaults: + new_defaults = list(defaults) + + for k, v in kw_hypers.items(): + if k in positional: + ind = positional[k] + new_defaults[ind] = v + else: + kwdefaults[k] = v + + if defaults: + unfun.__defaults__ = tuple(new_defaults) + if kwdefaults: + unfun.__kwdefaults__ = kwdefaults + + # return unfun + + +def _bite( + fun, + *objectives, + budget=1_000_000, # maximum allowed evaluations of "fun" + cache=False, # True for default filepath, or a path + multimethod="sum", # how to handle multiple objectives + deterministic="unknown", + optimizer_kwargs=None, + _debug=False, + _optim="biteopt", # ignored for now +): + # NOTE: cache only properly supports floats for now. + from scipybiteopt import biteopt as _biteopt + + if cache is True: # intentionally not using == + env_cache = _environ.get("BITTEN_CACHE", None) + storage = Storage(env_cache if env_cache else default_cache) + cache = Cache(storage) + elif cache is False or cache is None: # intentionally not using == + cache = None + else: + assert isinstance(cache, Cache), type(cache) + + # apply defaults. + if optimizer_kwargs is None: + optimizer_kwargs = {} + for k, v in optimizer_defaults.items(): + if k not in optimizer_kwargs: + optimizer_kwargs[k] = v + + assert _is_in(multimethod, "sum"), f"unknown multiobjective method: {multimethod}" + + assert _is_in(deterministic, False, True, "unknown"), deterministic + + sig = inspect.signature(fun) + params = sig.parameters.values() + + objectives, nonjectives = _categorize_objectives(objectives) + inputs, hypers, positional = _categorize_parameters(params) + + linear_bounds = [] + for key, hyper in hypers.items(): + if isinstance(hyper, BoundArray): # TODO: ArrayMixin or whatever + linear_bounds += [hyper.bounds] * hyper.dims + elif isinstance(hyper, Bound): + linear_bounds.append(hyper.bounds) + + memo = {} + memo_secret = object() # dummy object for uniqueness + + def make_approx(kw_hypers): + # print(kw_hypers) + def approx(*args): + return fun(*args, **kw_hypers) + # try: + # return fun(*args, **kw_hypers) + # except Exception as e: + # _print_exception(e) + # return -np.inf + + return approx + + def make_keywords(params): + kw_hypers = {} + other_hypers = [] + remaining = iter(params) + all_realized = True + for name, hyper in hypers.items(): + if isinstance(hyper, BoundArray): + params = [next(remaining) for _ in range(hyper.dims)] + kw_hypers[name] = hyper.vectorize(params) + elif isinstance(hyper, Bound): + param = next(remaining) + kw_hypers[name] = hyper.compute(param) + else: + other_hypers.append((name, hyper)) + realized = hyper.realize(kw_hypers, name=name) + if not realized: + all_realized = False + + while not all_realized: + # TODO: is there a design pattern to accomplish this that makes the user + # conscious of the possibility of accidentally infinite-looping? + approx = make_approx(kw_hypers) + all_realized = True + + for (name, hyper) in other_hypers: + realized = hyper.realize(kw_hypers, name=name, approx=approx) + if not realized: + all_realized = False + + return kw_hypers + + def objective(params): + kw_hypers = make_keywords(params) + approx = make_approx(kw_hypers) + + objective_values = [] + for subjective in objectives: # forgive me + if deterministic is True: + key = (subjective,) + tuple(float(x) for x in kw_hypers.values()) + if key in memo: + value = memo[key] + # print("MEMO:", key, value) + else: + value = subjective.compute_with(approx, **kw_hypers) + memo[key] = value + else: + value = subjective.compute_with(approx, **kw_hypers) + objective_values.append(value) + + if multimethod == "sum": + return np.sum(objective_values) + elif callable(multimethod): + return multimethod(objective_values) + else: + # huh? just do this i guess. + return objective_values[0] + + cached = None + if cache: + cache.hash(fun) + for obj in _chain(objectives, nonjectives, hypers.values()): + for h in obj.needs_hashing(): + cache.hash(h) + cached = cache.get() + if _debug: + print("HASH:", cache.hashed, sep="\n") + + if cached is None: + if _debug: # dirty test for exceptions + center = np.mean(linear_bounds, axis=-1) if hypers else () + print("FIRST VALUE:", objective(center), sep="\n") + + if hypers: + res = _biteopt(objective, linear_bounds, iters=budget, **optimizer_kwargs) + else: + if deterministic is True: + res = _evaluator(objective, 1) + else: + res = _evaluator(objective, budget) + + if _debug: + print("RES:", res, sep="\n") + + optimized = res.x + else: + assert len(cached) == len(linear_bounds) + optimized = cached + + if cache and cached is None: + cache.set(optimized) + + kw_hypers = make_keywords(optimized) + _set_new_defaults(fun, positional, kw_hypers) + fun.__optimized__ = kw_hypers # secret! + if cached is None: + fun.__result__ = res + + return fun + + +def bite(*objectives, **config): + def decorator(fun): + # this prevents an unnecessary level of indentation. + return _bite(fun, *objectives, **config) + + return decorator + + +# this is similar to default behaviour of having no __all__ variable at all, +# but ours ignores modules as well. this allows for `import sys` and such +# without clobbering `from our_module import *`. +__all__ = [ + k for k, v in locals().items() if not inspect.ismodule(v) and not k.startswith("_") +] + +if __name__ == "__main__": + # non-associative example from: + # https://numpy.org/doc/stable/reference/generated/numpy.promote_types.html + print("S4 or S6?", _promote_all("S", "i1", "u1")) + print(np.promote_types("S4", "S6"))