bitten: update bitten.py
This commit is contained in:
parent
cb2a800044
commit
6c256510d6
1 changed files with 27 additions and 2 deletions
|
@ -156,6 +156,20 @@ def _penalize2(constraints, *x, tol=1e-5, scale=1e10, growth=3.0):
|
||||||
return scale * (unsat + penalsum)
|
return scale * (unsat + penalsum)
|
||||||
|
|
||||||
|
|
||||||
|
def _penalize3(constraints, *x, tol=1e-5, scale=1e10, growth=3.0):
|
||||||
|
# updated from upstream (v2022.19)
|
||||||
|
penalties = _flatten(cons(*x) for cons in constraints)
|
||||||
|
n_con = len(penalties)
|
||||||
|
growth = growth ** (1.0 / n_con) # "ps"
|
||||||
|
increment = 1.0 / np.sqrt(n_con) # "pnsi"
|
||||||
|
penalty, nominal = 0.0, 0.0 # "pns", "pnsm"
|
||||||
|
for p in penalties:
|
||||||
|
p = max(p - tol, 0.0)
|
||||||
|
penalty = penalty * growth + increment + p + p * p * p
|
||||||
|
nominal = nominal * growth + increment
|
||||||
|
return scale * (penalty - nominal + 1.0)
|
||||||
|
|
||||||
|
|
||||||
def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0):
|
def penalize(x, constraints, tol=1e-5, *, scale=1e10, growth=4.0):
|
||||||
# DEPRECATED
|
# DEPRECATED
|
||||||
return _penalize(constraints, x, tol=tol, scale=scale, growth=growth)
|
return _penalize(constraints, x, tol=tol, scale=scale, growth=growth)
|
||||||
|
@ -266,14 +280,25 @@ class Crossentropy(AbstractError):
|
||||||
|
|
||||||
|
|
||||||
class Constrain(Objective):
|
class Constrain(Objective):
|
||||||
def __init__(self, *constraints, tol=1e-5):
|
def __init__(self, *constraints, tol=1e-5, version=1):
|
||||||
for cons in constraints:
|
for cons in constraints:
|
||||||
assert callable(cons)
|
assert callable(cons)
|
||||||
self.constraints = constraints
|
self.constraints = constraints
|
||||||
self.tol = float(tol)
|
self.tol = float(tol)
|
||||||
|
self.version = version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self):
|
||||||
|
return self._version + 1
|
||||||
|
|
||||||
|
@version.setter
|
||||||
|
def version(self, version):
|
||||||
|
assert 1 <= version <= 3, version
|
||||||
|
self._version = version - 1
|
||||||
|
self._penalize = [_penalize, _penalize2, _penalize3][self._version]
|
||||||
|
|
||||||
def penalize(self, *args):
|
def penalize(self, *args):
|
||||||
return _penalize(self.constraints, *args, tol=self.tol)
|
return self._penalize(self.constraints, *args, tol=self.tol)
|
||||||
|
|
||||||
def compute_with(self, fun, **kw_hypers):
|
def compute_with(self, fun, **kw_hypers):
|
||||||
return self.penalize(*kw_hypers.values())
|
return self.penalize(*kw_hypers.values())
|
||||||
|
|
Loading…
Add table
Reference in a new issue