From 6c256510d62e90dd3b24a489151bf773f5f1d593 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 13 Jun 2022 22:02:45 +0200 Subject: [PATCH] bitten: update bitten.py --- bitten/bitten.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/bitten/bitten.py b/bitten/bitten.py index 3d28ebe..eaf22ff 100644 --- a/bitten/bitten.py +++ b/bitten/bitten.py @@ -156,6 +156,20 @@ def _penalize2(constraints, *x, tol=1e-5, scale=1e10, growth=3.0): return scale * (unsat + penalsum) +def _penalize3(constraints, *x, tol=1e-5, scale=1e10, growth=3.0): + # updated from upstream (v2022.19) + penalties = _flatten(cons(*x) for cons in constraints) + n_con = len(penalties) + growth = growth ** (1.0 / n_con) # "ps" + increment = 1.0 / np.sqrt(n_con) # "pnsi" + penalty, nominal = 0.0, 0.0 # "pns", "pnsm" + for p in penalties: + p = max(p - tol, 0.0) + penalty = penalty * growth + increment + p + p * p * p + nominal = nominal * growth + increment + return scale * (penalty - nominal + 1.0) + + def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0): # DEPRECATED return _penalize(constraints, x, tol=tol, scale=scale, growth=growth) @@ -266,14 +280,25 @@ class Crossentropy(AbstractError): class Constrain(Objective): - def __init__(self, *constraints, tol=1e-5): + def __init__(self, *constraints, tol=1e-5, version=1): for cons in constraints: assert callable(cons) self.constraints = constraints self.tol = float(tol) + self.version = version + + @property + def version(self): + return self._version + 1 + + @version.setter + def version(self, version): + assert 1 <= version <= 3, version + self._version = version - 1 + self._penalize = [_penalize, _penalize2, _penalize3][self._version] def penalize(self, *args): - return _penalize(self.constraints, *args, tol=self.tol) + return self._penalize(self.constraints, *args, tol=self.tol) def compute_with(self, fun, **kw_hypers): return self.penalize(*kw_hypers.values())