#!/usr/bin/env python3

# prepend underscores so these don't leak to __all__.
from hashlib import sha256 as _sha256
from itertools import chain as _chain
from itertools import repeat as _repeat
from os import environ as _environ
from pathlib import Path as _Path
from traceback import print_exception as _print_exception

import ast
import inspect

import numpy as np

# NOTE: the environment variable "BITTEN_CACHE" takes precedence here.
default_cache = "bitten.cache"

optimizer_defaults = dict(tol="weak", depth=1, attempts=1)


def pack(*args):
    """Return positional arguments as a tuple."""
    return args


def lament(*args, **kwargs):
    """Print the arguments to stderr (same signature as print, minus `file`)."""
    from sys import stderr

    print(*args, **kwargs, file=stderr)


def _is_in(needle, *haystack):
    # like the "in" keyword, but uses "is" instead of "==".
    for thing in haystack:
        if thing is needle:
            return True
    return False


def _hashable(a):
    """Turn an array (or anything else) into something usable as a dict key.

    Non-scalar arrays become ``shape + flattened values``; 0-d arrays become
    their Python scalar; anything else is returned unchanged.
    """
    if isinstance(a, np.ndarray):
        if a.ndim > 0:
            # key = tuple(a)  # probably won't work for ndim >= 2
            key = a.shape + tuple(a.ravel())
        else:
            key = a.item()
    else:
        key = a  # just hope for the best
    return key


def _promote_all(*dtypes):
    """Promote any number of dtypes to a single common dtype.

    np.promote_types is not associative (see the __main__ demo), so the
    pairwise reduction order below is part of the result and is kept as-is:
    the first dtype is promoted with every other one, then the same is
    repeated on the resulting list until one dtype remains.
    """
    if len(dtypes) == 1:
        # a single dtype is trivially its own promotion.
        return np.dtype(dtypes[0])
    first, rest = dtypes[0], dtypes[1:]
    new_dtypes = [np.promote_types(first, other) for other in rest]
    while len(new_dtypes) > 1:
        first, rest = new_dtypes[0], new_dtypes[1:]
        new_dtypes = [np.promote_types(first, other) for other in rest]
    return new_dtypes[0]


def _promote_equal(a, b, *etc):
    """Convert all arguments to arrays sharing one promoted dtype.

    Returns a tuple with one array per argument, in the same order.
    """
    # materialize once: these are iterated more than once below.
    arrays = [np.asanyarray(x) for x in (a, b, *etc)]
    if etc:
        promo = _promote_all(*(x.dtype for x in arrays))
        return tuple(x.astype(promo) for x in arrays)

    a, b = arrays
    if a.dtype != b.dtype:
        promo = np.promote_types(a.dtype, b.dtype)
        a = a.astype(promo)
        b = b.astype(promo)
    assert a.dtype == b.dtype
    return a, b


def _remove_docstring(node):
    """Drop the leading docstring statement of a FunctionDef/ClassDef in place."""
    # via: https://stackoverflow.com/a/49998190
    if not isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return
    if node.body:
        first = node.body[0]
        # ast.Str is deprecated (removed in 3.14); on 3.8+ this Constant check
        # matches exactly the nodes isinstance(..., ast.Str) used to match,
        # so existing cache hashes are unaffected.
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            node.body.pop(0)


def _prepare_function(func):
    """Return a bytes fingerprint of *func*'s source, used for cache hashing.

    The source (as returned by inspect.getsource, so decorator lines are
    included) is dedented, parsed, stripped of the function's name and of all
    function/class docstrings, then dumped as an AST string. Renaming the
    function or editing its docstrings therefore does not change the hash.
    """
    # adapted from: https://stackoverflow.com/a/49998190
    # TODO: (optionally) strip the initial decorator from the AST as well.
    func_str = inspect.getsource(func)

    # Remove leading indentation, if any.
    indent = len(func_str[: len(func_str) - len(func_str.lstrip())])
    if indent:
        func_str = "\n".join(line[indent:] for line in func_str.splitlines())

    module = ast.parse(func_str)
    assert len(module.body) == 1 and isinstance(module.body[0], ast.FunctionDef)

    # Clear function name so it doesn't affect the hash.
    func_node = module.body[0]
    func_node.name = ""

    # Clear all the doc strings.
    for node in ast.walk(module):
        _remove_docstring(node)

    # TODO: clear pure constant expressions as well? like a simple string or number on its own line.

    # Convert the ast to a string for hashing.
    ast_str = ast.dump(module, annotate_fields=False)
    return ast_str.encode()


def _flatten(it):
    """Recursively flatten nested sized iterables into a flat list.

    Strings/bytes and 0-d numpy arrays are treated as atoms: iterating a 0-d
    array raises TypeError, and a 1-character string iterates to itself
    forever.
    """
    res = []
    for value in it:
        is_atom = (
            isinstance(value, (str, bytes))
            or not hasattr(value, "__len__")
            or (isinstance(value, np.ndarray) and value.ndim == 0)
        )
        if is_atom:
            res.append(value)
        else:
            res.extend(_flatten(value))
    return res


def _penalize_v2021_23(penalties, tol=1e-5, scale=1e10):
    n_con = len(penalties)
    unmet = sum(p > tol for p in penalties)
    growth = 4.0 ** (1.0 / n_con)  # "ps"
    penalty = 0.0
    for p in penalties:
        p = p if p > tol else 0.0
        squared = p * p
        penalty = growth * penalty + p + squared + p * squared
    return scale * (unmet + penalty)


def _penalize_v2022_25(penalties, tol=1e-5, scale=1e10):
    n_con = len(penalties)
    growth = 3.0 ** (1.0 / n_con)  # "ps"
    increment = n_con**-0.5  # "pnsi"
    penalty, nominal = 0.0, 0.0  # "pns", "pnsm"
    for p in penalties:
        p = p - tol if p > tol else 0.0
        squared = p * p
        penalty = growth * penalty + increment + p + squared + p * squared
        nominal = growth * nominal + increment
    return scale * (penalty - nominal + 1.0)


penalizers = {
    +1: _penalize_v2021_23,
    +5: _penalize_v2022_25,
}


def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0):
    # raise explicitly so this still fails when running under `python -O`.
    raise AssertionError("deprecated; use _penalize_v2021_23 instead")


def count_unsat(x, constraints, tol=1e-5):
    # raise explicitly so this still fails when running under `python -O`.
    raise AssertionError("deprecated; do it yourself")


class Impure:  # TODO: rename? volatile? aaa the word is on the tip of my tongue
    """Base for objects whose state must be mixed into the cache hash."""

    def needs_hashing(self):
        """Return the values that identify this object for caching."""
        return ()


class Objective(Impure):
    """Base for terms that are summed (by default) into the optimized score."""


class Hyper(Impure):
    """Base for hyperparameter placeholders used as function defaults."""


class Minimize(Objective):
    """Objective: the function's return value for fixed *args."""

    def __init__(self, *args):
        self.args = args

    # TODO: rename to "realize" to amalgamate with the non-Bounds Hypers?
    def compute_with(self, approx, **kw_hypers):
        return approx(*self.args)

    def needs_hashing(self):
        return tuple(self.args)


class Maximize(Objective):
    """Objective: the negated function return value for fixed *args."""

    def __init__(self, *args):
        self.args = args

    def compute_with(self, approx, **kw_hypers):
        return -approx(*self.args)

    def needs_hashing(self):
        return tuple(self.args)


class AbstractError(Objective):
    """Objective comparing predictions on *inputs* against target *outputs*.

    *iwrap* turns the stored inputs into the positional argument tuple passed
    to the function; *owrap* post-processes the function's return value.
    Subclasses implement ``_compute(predictions)``.
    """

    def __init__(self, inputs, outputs, iwrap=pack, owrap=np.asanyarray):
        assert callable(iwrap), type(iwrap)
        assert callable(owrap), type(owrap)
        self.inputs = np.asanyarray(inputs)
        self.outputs = np.asanyarray(outputs)
        self.iwrap = iwrap
        self.owrap = owrap

    def compute_with(self, fun, **kw_hypers):
        args = self.iwrap(self.inputs)
        predictions = fun(*args)
        predictions = self.owrap(predictions)
        assert predictions.shape == self.outputs.shape
        return self._compute(predictions)

    def needs_hashing(self):
        return (self.inputs, self.outputs)


class MeanSquaredError(AbstractError):
    def _compute(self, predictions):
        error = self.outputs - predictions
        return np.mean(np.square(error))


class MeanAbsoluteError(AbstractError):
    def _compute(self, predictions):
        error = self.outputs - predictions
        return np.mean(np.abs(error))


class PeakError(AbstractError):
    def _compute(self, predictions):
        error = self.outputs - predictions
        return np.max(np.abs(error))


class HuberError(AbstractError):
    def _compute(self, predictions, delta=1.0):
        """Mean Huber loss: quadratic for |error| <= delta, linear beyond."""
        # the branch must be chosen on the magnitude of the error; comparing
        # the signed error sent every negative residual to the quadratic arm.
        error = np.abs(self.outputs - predictions)
        losses = np.where(
            error <= delta,
            0.5 * np.square(error),
            delta * (error - 0.5 * delta),
        )
        # reduce to a scalar like the other error objectives.
        return np.mean(losses)


class Accuracy(AbstractError):
    def _compute(self, predictions):
        p, t = predictions, self.outputs
        correct = np.argmax(p, axis=-1) == np.argmax(t, axis=-1)
        return np.mean(correct)


class Crossentropy(AbstractError):
    def _compute(self, predictions, eps=1e-6):
        p, t = np.clip(predictions, eps, 1 - eps), self.outputs
        f = np.sum(-t * np.log(p) - (1 - t) * np.log(1 - p), axis=-1)
        return np.mean(f)


class Constrain(Objective):
    """Objective that penalizes constraint values exceeding *tol*.

    Each constraint is called with the hyperparameter values (positionally)
    and may return a number or a (nested) sequence of numbers; any value
    greater than *tol* counts as a violation. *version* selects the penalty
    function from ``penalizers``.
    """

    def __init__(self, *constraints, tol=1e-5, version=5):
        for cons in constraints:
            assert callable(cons)
        self.constraints = constraints
        self.tol = float(tol)
        self.version = version

    @property
    def version(self):
        return self._version

    @version.setter
    def version(self, version):
        assert version in penalizers, f"unknown version of penalty function: {version}"
        self._version = version
        self._penalize = penalizers[self._version]

    def penalize(self, *x):
        """Return 0.0 if every constraint is met, else the penalty value."""
        penalties = _flatten(cons(*x) for cons in self.constraints)
        if not any(p > self.tol for p in penalties):
            return 0.0
        return self._penalize(penalties, tol=self.tol)

    def compute_with(self, fun, **kw_hypers):
        return self.penalize(*kw_hypers.values())

    def needs_hashing(self):
        return (self.tol,)


class L1(Objective):
    """Objective: *scale* times the sum of absolute hyperparameter values."""

    def __init__(self, scale):
        self.scale = float(scale)

    def compute_with(self, fun, **kw_hypers):
        # flatten so array-valued hypers (BoundArray) can mix with scalars.
        a = np.array(_flatten(kw_hypers.values()))
        return self.scale * np.sum(np.abs(a))

    def needs_hashing(self):
        return (self.scale,)


class L2(Objective):
    """Objective: *scale* times the sum of squared hyperparameter values."""

    def __init__(self, scale):
        self.scale = float(scale)

    def compute_with(self, fun, **kw_hypers):
        # flatten so array-valued hypers (BoundArray) can mix with scalars.
        a = np.array(_flatten(kw_hypers.values()))
        return self.scale * np.sum(np.square(a))

    def needs_hashing(self):
        return (self.scale,)


class Const(Hyper):
    """Hyperparameter fixed to *value* (not optimized)."""

    def __init__(self, value):
        self.value = value

    def realize(self, kw_hypers, *, name, approx=None):
        kw_hypers[name] = self.value
        return True  # realized

    def needs_hashing(self):
        return (self.value,)


class Bound(Hyper):
    """Scalar hyperparameter optimized within [lower, upper]."""

    def __init__(self, lower, upper):
        self.lower, self.upper = _promote_equal(lower, upper)

    @property
    def bounds(self):
        # allowed to be overridden for log-scale, etc.
        return self.lower, self.upper

    def compute(self, a):
        # allowed to be overridden for log-scale, etc.
        return a

    def needs_hashing(self):
        return (self.lower, self.upper)


class BoundArray(Hyper):  # TODO: write and use a Mixin for Array types?
    """Vector hyperparameter of length *dims*, each element in [lower, upper].

    *lower*/*upper* may be scalars (shared by all elements) or sequences.
    """

    def __init__(self, lower, upper, dims, dtype=np.float64):
        self.lower, self.upper = _promote_equal(lower, upper)
        self.dims = int(dims)
        self.dtype = dtype

    @property
    def bounds(self):
        # allowed to be overridden for log-scale, etc.
        return self.lower, self.upper

    def vectorize(self, a):
        # the Array equivalent of self.compute(a)
        return np.array(a, self.dtype, copy=True)

    def needs_hashing(self):
        return (self.lower, self.upper, self.dims)


class Round(Bound):
    """Bound hyperparameter rounded to *decimals* places before use."""

    def __init__(self, lower, upper, decimals=0):
        self.decimals = int(decimals)
        super().__init__(lower, upper)

    def compute(self, a):
        return np.round(a, self.decimals)

    def needs_hashing(self):
        return (self.lower, self.upper, self.decimals)


class Adjustment(Hyper):
    """Hyperparameter set to ``y - f(x)`` using the current approximation.

    Until an approximation is available, *default* is used instead.
    x, y and default are promoted to one common dtype.
    """

    def __init__(self, x, y, default=0):
        self._x, self._y, self._default = _promote_equal(x, y, default)

    @property
    def x(self):
        # allowed to be overridden for log-scale, etc.
        return self._x

    @property
    def y(self):
        # allowed to be overridden for log-scale, etc.
        return self._y

    @property
    def default(self):
        # allowed to be overridden for log-scale, etc.
        return self._default

    def perform(self, result):
        return self.y - result

    def realize(self, kw_hypers, *, name, approx=None):
        if approx is None:
            kw_hypers[name] = self.default
            return False  # not realized
        else:
            result = approx(self.x)  # TODO: abstract better (especially multiple args)
            kw_hypers[name] = self.perform(result)
            return True  # realized

    def needs_hashing(self):
        return (self._x, self._y, self._default)


class Storage:  # TODO: use dict-like interface?
    """Append-only file of ``key,v1,v2,...`` lines mapping keys to float rows."""

    def __init__(self, filepath):
        open(filepath, "ab").close()  # ensure it exists and is readable/writable
        self.filepath = filepath

    def get(self, key=None, default=None):
        """Return the row stored under *key* as a float array, else *default*.

        If a key was written more than once, the last line wins.
        """
        tokens = None
        with open(self.filepath, "rb") as f:
            for line in f:
                stored, _, rest = line.partition(b",")
                if key == stored:
                    text = rest.rstrip(b"\r\n")
                    # an empty row (no values) must parse as an empty array,
                    # not as [b""], which float() rejects.
                    tokens = text.split(b",") if text else []

        if tokens is None:
            return default
        return np.array([float(x) for x in tokens], float)

    def set(self, key, values):
        """Append *values* (converted to float) under the bytes *key*."""
        # repr(float(x)) round-trips exactly; repr() of a NumPy 2 scalar would
        # write "np.float64(...)", which get() cannot parse.
        formatted = ",".join(repr(float(x)) for x in values).encode()
        # FIXME: doesn't overwrite old value, if any.
        with open(self.filepath, "ab") as f:
            f.write(key + b"," + formatted + b"\n")


class Cache:  # TODO: rename to something a little more specific?
    """Accumulates a SHA-256 hash of a function plus its configuration.

    The first object hashed must be the decorated function itself; every
    later object is hashed via ``np.asanyarray(obj).tobytes()``.
    """

    def __init__(self, storage):
        assert isinstance(storage, Storage), type(storage)
        self.storage = storage
        self.hashing = _sha256()
        self._hashed = None
        self._hashed_n = 0
        self._n = 0

    def hash(self, *hashme):
        for h in hashme:
            if self._n == 0 and callable(h):
                data = _prepare_function(h)
            elif callable(h):
                raise Exception(
                    "functions other than the one being decorated cannot be hashed (for now?)"
                )
            else:
                data = np.asanyarray(h).tobytes()
            self.hashing.update(data)
            self._n += 1
            del data

    def finish(self):
        self._hashed = self.hashing.hexdigest().encode()
        self._hashed_n = self._n
        return self._hashed

    @property
    def hashed(self):
        # recompute the digest only if more data was hashed since last time.
        if self._n != self._hashed_n or self._hashed is None:
            return self.finish()
        return self._hashed

    def get(self):
        return self.storage.get(self.hashed)

    def set(self, values):
        return self.storage.set(self.hashed, values)


class HashMe(Impure):
    """Wrap arbitrary objects purely so they get mixed into the cache hash."""

    def __init__(self, *objects):
        self.objects = objects

    def needs_hashing(self):
        return self.objects


def _categorize_objectives(objectives):
    """Split *objectives* into (Objective instances, everything else)."""
    yes, no = [], []
    for subjective in objectives:  # forgive me
        # if isinstance(objective, Constrain): do something?
        if isinstance(subjective, Objective):
            yes.append(subjective)
        else:
            no.append(subjective)
    return yes, no


def _categorize_parameters(params):  # or "organize"?
    """Sort signature parameters into inputs, hyperparameters and positions.

    Returns ``(inputs, hypers, positional)``:
    - inputs: positional parameters without a default (value is
      ``inspect.Parameter.empty``);
    - hypers: name -> Hyper instance, for parameters whose default is a Hyper;
    - positional: name -> index into the function's ``__defaults__`` tuple,
      for hypers that can also be passed positionally.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    inputs, hypers, positional, ind = {}, {}, {}, 0
    for param in params:
        # identity check: "!=" would compare elementwise (and then fail) when
        # the default is a numpy array.
        has_default = param.default is not inspect.Parameter.empty
        ishyper = isinstance(param.default, Hyper)
        if has_default and ishyper:
            if param.kind == inspect.Parameter.POSITIONAL_ONLY:
                raise Exception("hyperparameters must be accessible through keywords")
            hypers[param.name] = param.default
            if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                positional[param.name] = ind
                ind += 1
        elif has_default:
            # keyword-only params always follow positional ones, so counting
            # them here never shifts a positional hyper's index.
            ind += 1
        elif param.kind in positional_kinds:
            inputs[param.name] = param.default
    return inputs, hypers, positional


def _evaluator(objective, budget=1):
    """Stand-in for an optimizer when there is nothing to optimize.

    Calls ``objective(())`` *budget* times and returns the best value found.
    Most optimizers don't fare well when there's literally nothing provided
    for them to optimize.
    """
    from scipybiteopt import OptimizeResult

    x, fun, nfev = (), np.inf, 0
    for _ in range(budget):
        new_fun = objective(x)
        if new_fun < fun:
            fun = new_fun
        nfev += 1
    return OptimizeResult(x=x, fun=fun, nfev=nfev)


def _set_new_defaults(fun, positional, kw_hypers):
    """Replace *fun*'s hyperparameter defaults with the values in *kw_hypers*.

    Instead of creating a new function masquerading as the old, or attempting
    to change the function's signature, simply change the function's default
    values from dummy values to their optimized values. ``__wrapped__`` chains
    are followed to the innermost function.
    """
    unfun = fun
    while hasattr(unfun, "__wrapped__"):
        unfun = unfun.__wrapped__

    defaults = unfun.__defaults__
    kwdefaults = unfun.__kwdefaults__
    new_defaults = list(defaults) if defaults else []

    for k, v in kw_hypers.items():
        if k in positional:
            new_defaults[positional[k]] = v
        else:
            kwdefaults[k] = v

    if defaults:
        unfun.__defaults__ = tuple(new_defaults)
    if kwdefaults:
        unfun.__kwdefaults__ = kwdefaults


def _bite(
    fun,
    *objectives,
    budget=1_000_000,  # maximum allowed evaluations of "fun"
    cache=False,  # True for default filepath, or a path
    multimethod="sum",  # how to handle multiple objectives
    deterministic="unknown",
    optimizer_kwargs=None,
    _debug=False,
    _optim="biteopt",  # ignored for now
):
    """Optimize *fun*'s Hyper defaults against *objectives*; return *fun*.

    The optimized values are written back into *fun*'s defaults, also stored
    on ``fun.__optimized__``; when the optimizer actually ran, its result is
    stored on ``fun.__result__``.

    cache: True uses $BITTEN_CACHE (or ``default_cache``), a str/Path names a
        cache file, a Cache instance is used as-is, False/None disables it.
    multimethod: "sum" or a callable reducing the list of objective values.
    deterministic: True enables memoizing objective values per hyper setting.
    optimizer_kwargs: overrides for ``optimizer_defaults``; not mutated.
    """
    # NOTE: cache only properly supports floats for now.
    from scipybiteopt import biteopt as _biteopt

    if cache is True:  # intentionally not using ==
        env_cache = _environ.get("BITTEN_CACHE", None)
        storage = Storage(env_cache if env_cache else default_cache)
        cache = Cache(storage)
    elif cache is False or cache is None:  # intentionally not using ==
        cache = None
    elif isinstance(cache, (str, _Path)):
        # a path, as promised by the parameter's comment.
        cache = Cache(Storage(cache))
    else:
        assert isinstance(cache, Cache), type(cache)

    # apply defaults without mutating the caller's dict.
    optimizer_kwargs = {**optimizer_defaults, **(optimizer_kwargs or {})}

    assert callable(multimethod) or multimethod == "sum", (
        f"unknown multiobjective method: {multimethod}"
    )
    assert _is_in(deterministic, False, True, "unknown"), deterministic

    sig = inspect.signature(fun)
    params = sig.parameters.values()
    objectives, nonjectives = _categorize_objectives(objectives)
    inputs, hypers, positional = _categorize_parameters(params)

    # one (lower, upper) pair per optimized scalar, in hypers order.
    linear_bounds = []
    for key, hyper in hypers.items():
        if isinstance(hyper, BoundArray):  # TODO: ArrayMixin or whatever
            lower, upper = hyper.bounds
            lower = lower if lower.ndim else _repeat(lower)
            upper = upper if upper.ndim else _repeat(upper)
            new = [(l, u) for _, l, u in zip(range(hyper.dims), lower, upper)]
            assert len(new) == hyper.dims, "length of bounds must equal number of dimensions"
            linear_bounds += new
        elif isinstance(hyper, Bound):
            linear_bounds.append(hyper.bounds)

    memo = {}

    def make_approx(kw_hypers):
        # kw_hypers is unpacked at call time, so later updates are visible.
        def approx(*args):
            return fun(*args, **kw_hypers)
            # try:
            #     return fun(*args, **kw_hypers)
            # except Exception as e:
            #     _print_exception(e)
            #     return -np.inf

        return approx

    def make_keywords(values):
        # map the optimizer's flat vector (plus non-bounded hypers) to kwargs.
        kw_hypers = {}
        other_hypers = []
        remaining = iter(values)
        all_realized = True
        for name, hyper in hypers.items():
            if isinstance(hyper, BoundArray):
                chunk = [next(remaining) for _ in range(hyper.dims)]
                kw_hypers[name] = hyper.vectorize(chunk)
            elif isinstance(hyper, Bound):
                kw_hypers[name] = hyper.compute(next(remaining))
            else:
                other_hypers.append((name, hyper))
                if not hyper.realize(kw_hypers, name=name):
                    all_realized = False

        while not all_realized:
            # TODO: is there a design pattern to accomplish this that makes the user
            # conscious of the possibility of accidentally infinite-looping?
            approx = make_approx(kw_hypers)
            all_realized = True
            for name, hyper in other_hypers:
                if not hyper.realize(kw_hypers, name=name, approx=approx):
                    all_realized = False

        return kw_hypers

    def objective(values):
        kw_hypers = make_keywords(values)
        approx = make_approx(kw_hypers)
        objective_values = []
        for subjective in objectives:  # forgive me
            if deterministic is True:
                # _hashable copes with array-valued hypers, unlike float().
                key = (subjective,) + tuple(
                    _hashable(np.asanyarray(v)) for v in kw_hypers.values()
                )
                if key in memo:
                    value = memo[key]
                else:
                    value = subjective.compute_with(approx, **kw_hypers)
                    memo[key] = value
            else:
                value = subjective.compute_with(approx, **kw_hypers)
            objective_values.append(value)

        if callable(multimethod):
            return multimethod(objective_values)
        elif multimethod == "sum":
            return np.sum(objective_values)
        else:  # huh? only reachable under `python -O`. just do this i guess.
            return objective_values[0]

    cached = None
    if cache is not None:
        cache.hash(fun)
        for obj in _chain(objectives, nonjectives, hypers.values()):
            for h in obj.needs_hashing():
                cache.hash(h)
        cached = cache.get()
        if _debug:
            lament("HASH:", cache.hashed, sep="\n")

    if cached is None:
        if _debug:  # dirty test for exceptions
            center = np.mean(linear_bounds, axis=-1) if linear_bounds else ()
            lament("FIRST VALUE:", objective(center), sep="\n")

        # only hand the optimizer a problem if there is something to bound;
        # Const/Adjustment-only functions have zero optimizable dimensions.
        if linear_bounds:
            res = _biteopt(objective, linear_bounds, iters=budget, **optimizer_kwargs)
        elif deterministic is True:
            res = _evaluator(objective, 1)
        else:
            res = _evaluator(objective, budget)

        if _debug:
            lament("RES:", res, sep="\n")

        optimized = res.x
    else:
        assert len(cached) == len(linear_bounds)
        optimized = cached

    if cache is not None and cached is None:
        cache.set(optimized)

    kw_hypers = make_keywords(optimized)
    _set_new_defaults(fun, positional, kw_hypers)
    fun.__optimized__ = kw_hypers  # secret!
    if cached is None:
        fun.__result__ = res
    return fun


def bite(*objectives, **config):
    """Decorator form of _bite: ``@bite(objective, ..., **config)``."""

    def decorator(fun):
        # this prevents an unnecessary level of indentation.
        return _bite(fun, *objectives, **config)

    return decorator


# this is similar to default behaviour of having no __all__ variable at all,
# but ours ignores modules as well. this allows for `import sys` and such
# without clobbering `from our_module import *`.
__all__ = [
    k for k, v in locals().items() if not inspect.ismodule(v) and not k.startswith("_")
]

if __name__ == "__main__":
    # non-associative example from:
    # https://numpy.org/doc/stable/reference/generated/numpy.promote_types.html
    print("S4 or S6?", _promote_all("S", "i1", "u1"))
    print(np.promote_types("S4", "S6"))