# backyard/logic/z0.py
# 2022-06-08 04:33:13 +02:00
#
# 641 lines
# 19 KiB
# Python

# z0: a bit-blaster that implements a tiny subset of z3's interface.
import cnf
# The single, module-global CNF problem. Every Bool/BitVec/Unary variable,
# every gate and every clause built by this module is added to it, and
# SolverFor.check() serializes it (or a deep copy of it) for the SAT solver.
gcnf = cnf.Problem()
def Todo(*args, **kwargs):
    """Placeholder for parts of the z3 interface that z0 does not implement.

    Accepts (and ignores) any arguments so it can stand in for any callable,
    e.g. ``Int = Todo``; it always raises ``NotImplementedError("TODO")``.
    """
    message = "TODO"
    raise NotImplementedError(message)
def binreduce(lst, fun):
    """Fold *lst* with the binary function *fun* as a balanced tree.

    Adjacent pairs are combined level by level (an unpaired last element is
    carried up to the next level), so the resulting expression has depth
    about log2(len(lst)) instead of being a linear chain. A one-element list
    yields that element itself.

    Raises:
        ValueError: if *lst* is empty.
    """
    if not lst:
        raise ValueError("binreduce() of an empty sequence")
    while len(lst) > 1:
        paired = [fun(a, b) for a, b in zip(lst[::2], lst[1::2])]
        if len(lst) % 2 == 1:
            # Odd count: the last element has no partner at this level.
            paired.append(lst[-1])
        lst = paired
    return lst[0]
class Goal:
    """A collection of top-level assertions, forwarded to the global CNF."""

    def __init__(self):
        # Every literal (or Python bool constant) that was asserted via add().
        self.used = set()

    def add(self, cond):
        """Assert *cond* (a Python bool or a Bool) into gcnf via assert_or."""
        literal = cond if type(cond) is bool else cond.value
        self.used.add(literal)
        gcnf.assert_or(literal)
class SolverFor:
    """Minimal stand-in for z3's ``SolverFor``; only ``"QF_BV"`` is accepted.

    Constraints live in the module-global ``gcnf``. :meth:`check` writes that
    problem to a temporary DIMACS file and runs an external SAT solver on it;
    the file is kept so :meth:`model` can read it back, and is removed on the
    next check or when the solver object is destroyed.
    """

    # Bound at class-creation time so __del__ can still reach it during
    # interpreter shutdown, when function-level imports may no longer work.
    from os import remove as _remove_file
    _remove_file = staticmethod(_remove_file)

    def __init__(self, name):
        assert name == "QF_BV"
        self.goal = None
        # Path of the most recent DIMACS file written by check(), if any.
        self.fp = None

    def add(self, goal):
        """Record *goal*; adding a second one is not supported yet."""
        if self.goal is not None:
            Todo()
        self.goal = goal

    def check(self, cond=None, args=None, verbose=False):
        """Solve gcnf, optionally with one extra assertion *cond*.

        Args:
            cond: None, a Python bool, or a Bool. When given, it is asserted
                in a deep copy of gcnf so the global problem is unchanged.
            args: extra command-line arguments for the solver; they precede
                the path of the CNF file.
            verbose: write the DIMACS file with verbosity 2 instead of 1.

        The solver executable is taken from ``$SOLVER`` (default ``kissat``);
        solvers whose name ends in ``cadical`` also get ``-q``.

        Returns:
            ``sat`` for exit code 10, ``unsat`` for exit code 20 (the usual
            SAT-solver convention); any other exit code fails an assertion.
        """
        from tempfile import mkstemp
        from os import environ, remove, fdopen
        from subprocess import Popen, PIPE
        if cond is None:
            newcnf = gcnf
        else:
            from copy import deepcopy
            newcnf = deepcopy(gcnf)
            newcnf.assert_or(cond if type(cond) is bool else cond.value)
        solver = environ.get('SOLVER', 'kissat')
        cmd = [solver] if args is None else [solver] + list(args)
        if solver.endswith('cadical'):
            cmd.append('-q')
        # Only the latest CNF file is needed (model() reads it back).
        if self.fp is not None:
            remove(self.fp)
        fd, self.fp = mkstemp()
        with fdopen(fd, 'w') as f:
            newcnf.dimacs(f, verbosity=2 if verbose else 1)
        cmd.append(self.fp)
        p = Popen(cmd, stdout=PIPE, stderr=PIPE)
        self.out, self.err = p.communicate()
        assert p.returncode in (10, 20), (p.returncode, self.out, self.err)
        if p.returncode == 10:
            return sat
        elif p.returncode == 20:
            return unsat

    def model(self):
        """Return the assignments from the last check()'s solver output.

        Reads the variable mapping back from the CNF file written by check()
        (presumably stored in its comments; see cnf.readcnfcom) and decodes
        the first solution in the captured solver stdout.
        """
        from io import StringIO
        mapping = cnf.readcnfcom(fp=self.fp)
        sol = StringIO(self.out.decode("utf-8", errors="ignore"))
        solution = cnf.readsolutions(f=sol)[0]
        assignments = cnf.unmapsolution(solution, mapping)
        return assignments

    def __del__(self):
        """Best-effort removal of the last temporary CNF file."""
        # getattr: __init__ may have failed before self.fp was set.
        fp = getattr(self, "fp", None)
        if fp is None:
            return
        self.fp = None
        try:
            self._remove_file(fp)
        except OSError:
            # Already gone or not removable; nothing useful to do here.
            pass
def Solver():
    """Return a new solver for the only supported logic, QF_BV."""
    logic = "QF_BV"
    return SolverFor(logic)
class Bool:
    """A single boolean literal in the global CNF, mirroring z3's Bool."""

    def __init__(self, name, value=None, public=True):
        assert type(name) == str, type(name)
        self.name = name
        # User-declared Bools get named CNF variables; derived ones temporaries.
        allocate = gcnf.new_var if public else gcnf._new_temp
        if value is None:
            self.value = allocate(name)
        else:
            self.value = value

    @classmethod
    def result(cls, name, value=None):
        """Wrap the literal *value* (a fresh temporary if None) as a non-public Bool."""
        return cls(name, value=value, public=False)

    @classmethod
    def constant(cls, value):
        """Wrap *value* (normally a Python bool) as a constant Bool."""
        return cls("const", value, public=False)

    @staticmethod
    def _literal_of(other):
        """Return a Python bool operand as-is, otherwise the Bool's literal."""
        if other is True or other is False:
            return other
        assert isinstance(other, Bool), type(other)
        return other.value

    def __eq__(self, other):
        """Bool that is true iff self and *other* have the same truth value."""
        same = gcnf.nots(gcnf.xors(self.value, self._literal_of(other)))
        return self.result("==", same)

    def __ne__(self, other):
        """Bool that is true iff self and *other* differ."""
        differ = gcnf.xors(self.value, self._literal_of(other))
        return self.result("!=", differ)
class BitVec:
    """Unsigned fixed-width bit-vector, stored least-significant bit first.

    ``value`` is a list of ``length`` literals (CNF literals or Python bool
    constants). Operators build gates in the global ``gcnf``. A Python ``int``
    operand is converted to a constant of the other operand's width.
    """

    def __init__(self, name, length, value=None, public=True):
        assert type(name) == str, type(name)
        assert type(length) == int, type(length)
        # Named variables for user-visible vectors, temporaries for results.
        new = gcnf.new_var if public else gcnf._new_temp
        self.name = name
        if type(value) is int:
            assert value >= 0 and value.bit_length() <= length, value
            # Bit i of the integer becomes value[i] (LSB first).
            value = [(value >> i) & 1 == 1 for i in range(length)]
        assert value is None or type(value) is list, type(value)
        self.value = new(name, length) if value is None else value
        self.length = length
        assert len(self.value) == self.length

    @classmethod
    def result(cls, name, length, value=None):
        """Build a non-public BitVec from *value* (fresh temporaries if None)."""
        return cls(name, length, value=value, public=False)

    @classmethod
    def constant(cls, value, length):
        """Return the non-negative int *value* as a constant *length*-bit vector."""
        return cls("const", length, value, public=False)

    @classmethod
    def compatible(cls, a, b, opstr):
        """Coerce operands *a* and *b* to bit lists of one common width.

        At most one of them may be a Python int; it becomes a constant of the
        other's width (range-checked by the constructor). Two BitVecs must
        already have equal lengths. Returns ``(a_bits, b_bits, width)``;
        *opstr* is accepted for symmetry but not used.
        """
        if type(a) is int:
            assert isinstance(b, BitVec), type(b)
            a = cls.constant(a, b.length)
            return a.value, b.value, b.length
        elif type(b) is int:
            assert isinstance(a, BitVec), type(a)
            b = cls.constant(b, a.length)
            return a.value, b.value, a.length
        else:
            assert isinstance(a, BitVec), type(a)
            assert isinstance(b, BitVec), type(b)
            assert a.length == b.length, (a.length, b.length)
            return a.value, b.value, a.length

    def __add__(self, other):
        a, b, n = self.compatible(self, other, "+")
        return self.result("+", n, gcnf.adds(a, b, n))

    def __radd__(self, other):
        a, b, n = self.compatible(other, self, "+")
        return self.result("+", n, gcnf.adds(a, b, n))

    def __sub__(self, other):
        a, b, n = self.compatible(self, other, "-")
        return self.result("-", n, gcnf.subs(a, b, n))

    def __rsub__(self, other):
        a, b, n = self.compatible(other, self, "-")
        return self.result("-", n, gcnf.subs(a, b, n))

    def __eq__(self, other):
        """Bool that is true iff every bit position agrees."""
        a, b, n = self.compatible(self, other, "==")
        temp = [gcnf.nots(gcnf.xors(ai, bi)) for ai, bi in zip(a, b)]
        reduced = binreduce(temp, lambda a, b: gcnf.ands(a, b))
        return Bool.result("==", reduced)

    def __ne__(self, other):
        """Bool that is true iff some bit position differs."""
        a, b, n = self.compatible(self, other, "!=")
        temp = [gcnf.xors(ai, bi) for ai, bi in zip(a, b)]
        reduced = binreduce(temp, lambda a, b: gcnf.ors(a, b))
        return Bool.result("!=", reduced)

    # Unsigned comparisons read the carry/borrow of a subtraction.
    # NOTE(review): assumes gcnf.subs(..., return_carry="only") returns the
    # borrow of a - b, and that carry=True subtracts one more (a - b - 1).
    def __lt__(self, other):
        a, b, n = self.compatible(self, other, "<")
        carry = gcnf.subs(a, b, n, return_carry="only")
        return Bool.result("<", carry)

    def __ge__(self, other):
        a, b, n = self.compatible(self, other, ">=")
        carry = gcnf.subs(a, b, n, return_carry="only")
        return Bool.result(">=", gcnf.nots(carry))

    def __le__(self, other):
        a, b, n = self.compatible(self, other, "<=")
        carry = gcnf.subs(a, b, n, return_carry="only", carry=True)
        return Bool.result("<=", carry)

    def __gt__(self, other):
        a, b, n = self.compatible(self, other, ">")
        carry = gcnf.subs(a, b, n, return_carry="only", carry=True)
        return Bool.result(">", gcnf.nots(carry))

    # In-place forms rebind self.value, so every reference to this object
    # observes the updated value.
    def __iadd__(self, other):
        a, b, n = self.compatible(self, other, "+=")
        self.value = gcnf.adds(a, b, n)
        return self

    def __isub__(self, other):
        a, b, n = self.compatible(self, other, "-=")
        self.value = gcnf.subs(a, b, n)
        return self

    def _bitwise(self, lhs, rhs, opstr, gate):
        """Apply the two-input *gate* position-wise to *lhs* and *rhs*."""
        a, b, n = self.compatible(lhs, rhs, opstr)
        return self.result(opstr, n, [gate(ai, bi) for ai, bi in zip(a, b)])

    def __xor__(self, other):
        return self._bitwise(self, other, "^", gcnf.xors)

    def __rxor__(self, other):
        return self._bitwise(other, self, "^", gcnf.xors)

    def __and__(self, other):
        return self._bitwise(self, other, "&", gcnf.ands)

    def __rand__(self, other):
        return self._bitwise(other, self, "&", gcnf.ands)

    def __or__(self, other):
        return self._bitwise(self, other, "|", gcnf.ors)

    def __ror__(self, other):
        return self._bitwise(other, self, "|", gcnf.ors)

    def __pos__(self):
        return self

    def __neg__(self):
        """Two's-complement negation, computed as 0 - self."""
        zero = [False for _ in range(self.length)]
        return self.result("u-", self.length,
                           gcnf.subs(zero, self.value, self.length))

    def __abs__(self):
        # Values are unsigned, so abs() is the identity.
        return self

    def __invert__(self):
        """Bitwise complement."""
        return self.result("u~", self.length,
                           [gcnf.nots(x) for x in self.value])
def booleate(f):
    """Decorator: promote Python ``bool`` arguments to constant Bools.

    The wrapped function then receives only Bool instances; any other
    argument type fails an assertion before *f* runs. ``functools.wraps``
    keeps the wrapped function's name and docstring.
    """
    from functools import wraps

    @wraps(f)
    def g(*args):
        newargs = [Bool.constant(a) if type(a) is bool else a for a in args]
        for arg in newargs:
            assert isinstance(arg, Bool), type(arg)
        return f(*newargs)
    return g
@booleate
def Not(a):
    """Logical negation of a Bool (a Python bool is promoted to a constant)."""
    negated = gcnf.nots(a.value)
    return Bool.result("Not", negated)
@booleate
def And(a, b, *rest):
    """Conjunction of two or more Bools (Python bools become constants).

    Like z3's ``And`` it accepts any number of operands (at least two here).
    Extra operands are folded left to right, so ``And(a, b)`` builds exactly
    the same single gate as a two-argument call always did.
    """
    value = gcnf.ands(a.value, b.value)
    for extra in rest:
        value = gcnf.ands(value, extra.value)
    return Bool.result("And", value)
@booleate
def Xor(a, b):
    """Exclusive or of two Bools (Python bools are promoted to constants)."""
    differ = gcnf.xors(a.value, b.value)
    return Bool.result("Xor", differ)
@booleate
def Or(a, b, *rest):
    """Disjunction of two or more Bools (Python bools become constants).

    Like z3's ``Or`` it accepts any number of operands (at least two here).
    Extra operands are folded left to right, so ``Or(a, b)`` builds exactly
    the same single gate as a two-argument call always did.
    """
    value = gcnf.ors(a.value, b.value)
    for extra in rest:
        value = gcnf.ors(value, extra.value)
    return Bool.result("Or", value)
@booleate
def Implies(a, b):
    """Bool for ``a -> b`` (Python bools are promoted to constants)."""
    implication = gcnf.implies(a.value, b.value)
    return Bool.result("Implies", implication)
def _maybe_add_one(should_add, var):
    """Conditionally shift a literal list up by one position.

    Returns a list one longer than *var*: ``[True] + var`` where *should_add*
    holds, and ``var + [False]`` otherwise (selected per position with
    ``gcnf.ites``). For a unary-encoded *var* this adds one when *should_add*.
    """
    size = len(var)
    return [
        gcnf.ites(should_add,
                  var[pos - 1] if pos > 0 else True,
                  var[pos] if pos < size else False)
        for pos in range(size + 1)
    ]
class Unary:
    """Non-negative integer in unary ("thermometer") encoding.

    ``value[i]`` is true iff the number is greater than ``i``, so the
    literals are monotone (a false entry is never followed by a true one) and
    ``length`` is the largest representable number. Non-constant encodings
    are constrained to be monotone by :meth:`_validate`.
    """

    def __init__(self, name, length, value=None, public=True):
        assert type(name) == str, type(name)
        assert type(length) == int, type(length)
        new = gcnf.new_var if public else gcnf._new_temp
        self.name = name
        if type(value) is int:
            assert value >= 0 and value <= length, value
            value = [i < value for i in range(length)]
        assert value is None or type(value) is list, type(value)
        self.value = new(name, length) if value is None else value
        self.length = length
        self._validate()

    def _validate(self):
        """Check (constants) or enforce (variables) the monotone unary shape."""
        assert len(self.value) == self.length
        if all(type(v) is bool for v in self.value):
            for a, b in zip(self.value, self.value[1:]):
                assert a == b or a, self.value
        else:
            # Clause "not value[i+1] or value[i]" for each adjacent pair.
            for a, b in zip(self.value, self.value[1:]):
                gcnf.assert_or(gcnf.nots(b), a)

    @classmethod
    def result(cls, name, length, value=None):
        """Build a non-public Unary from *value* (fresh temporaries if None)."""
        return cls(name, length, value=value, public=False)

    @classmethod
    def constant(cls, value, length):
        """Return the int *value* as a constant Unary with *length* positions."""
        return cls("const", length, value, public=False)

    @classmethod
    def compatible(cls, a, b, opstr):
        """Coerce operands to literal lists of one common length.

        A Python int operand becomes a constant Unary just long enough to hold
        it. Both lists are padded with constant False up to the longer length;
        the operands themselves are left unmodified. Returns
        ``(a_bits, b_bits, length)``; *opstr* is not used.
        """
        if type(a) is int:
            assert isinstance(b, Unary), type(b)
            a = cls.constant(a, a)
        elif type(b) is int:
            assert isinstance(a, Unary), type(a)
            b = cls.constant(b, b)
        else:
            assert isinstance(a, Unary), type(a)
            assert isinstance(b, Unary), type(b)
        length = max(a.length, b.length)
        return a._padded(length), b._padded(length), length

    def _padded(self, new_length):
        """Return a copy of ``value`` padded with False up to *new_length*."""
        return self.value + [False] * (new_length - len(self.value))

    def extend(self, new_length):
        """Pad this number in place with constant False up to *new_length*.

        The represented number does not change. ``length`` is kept in sync
        with ``value``. Nothing happens if the number is already at least that
        long. Returns ``self``.
        """
        if new_length > self.length:
            self.value = self.value + [False] * (new_length - self.length)
            self.length = new_length
        return self

    def __add__(self, other):
        """Sum with a non-negative int or another Unary."""
        if type(other) is int:
            assert other >= 0, other
            if other == 0:
                return self.result("+", self.length, self.value)
            # Adding a constant k prepends k true positions.
            result = [True] * other + self.value
            # Trailing constant-False positions carry no information...
            while result[-1] is False:
                result.pop()
            # ...but never shrink below self.length (pad back with False).
            length = max(self.length, len(result))
            result += [False] * (length - len(result))
            return self.result("+", length, result)
        a, b, n = self.compatible(self, other, "+")
        # NOTE(review): relies on gcnf.unary_sum building a unary adder.
        c = gcnf.unary_sum(a, b)
        return self.result("+", len(c), c)

    def __radd__(self, other):
        # Unary addition is commutative: k + u == u + k.
        return self.__add__(other)

    def __sub__(self, other):
        """Saturating subtraction of a non-negative int (floors at zero)."""
        if type(other) is int:
            assert other >= 0, other
            if other == 0:
                result = self.value
            elif other >= self.length:
                result = [False] * self.length
            else:
                # Subtracting k drops the k lowest positions.
                result = self.value[other:] + [False] * other
            return self.result("-", self.length, result)
        else:
            Todo()

    def __rsub__(self, other):
        Todo()

    def __eq__(self, other):
        if type(other) is int:
            if other < 0 or other > self.length:
                result = False
            elif other == 0:
                # Zero iff the first position is unset (always so when empty).
                result = gcnf.nots(self.value[0]) if self.length else True
            elif other == self.length:
                result = self.value[other - 1]
            else:
                result = gcnf.ands(self.value[other - 1],
                                   gcnf.nots(self.value[other]))
        else:
            a, b, n = self.compatible(self, other, "==")
            temp = [gcnf.nots(gcnf.xors(ai, bi)) for ai, bi in zip(a, b)]
            result = binreduce(temp, lambda a, b: gcnf.ands(a, b))
        return Bool.result("==", result)

    def __ne__(self, other):
        if type(other) is int:
            if other < 0 or other > self.length:
                result = True
            elif other == 0:
                result = self.value[0] if self.length else False
            elif other == self.length:
                result = gcnf.nots(self.value[other - 1])
            else:
                result = gcnf.ors(gcnf.nots(self.value[other - 1]),
                                  self.value[other])
        else:
            a, b, n = self.compatible(self, other, "!=")
            temp = [gcnf.xors(ai, bi) for ai, bi in zip(a, b)]
            result = binreduce(temp, lambda a, b: gcnf.ors(a, b))
        return Bool.result("!=", result)

    # For two unary numbers, a < b iff some position is set in b but not in a.
    def __lt__(self, other):
        if type(other) is int:
            if other > self.length:
                result = True
            elif other <= 0:
                result = False
            else:
                result = gcnf.nots(self.value[other - 1])
        else:
            a, b, n = self.compatible(self, other, "<")
            temp = [gcnf.ands(gcnf.nots(ai), bi) for ai, bi in zip(a, b)]
            result = binreduce(temp, lambda a, b: gcnf.ors(a, b))
        return Bool.result("<", result)

    def __ge__(self, other):
        if type(other) is int:
            if other > self.length:
                result = False
            elif other <= 0:
                result = True
            else:
                result = self.value[other - 1]
        else:
            a, b, n = self.compatible(self, other, ">=")
            temp = [gcnf.ands(gcnf.nots(ai), bi) for ai, bi in zip(a, b)]
            result = gcnf.nots(binreduce(temp, lambda a, b: gcnf.ors(a, b)))
        return Bool.result(">=", result)

    def __le__(self, other):
        if type(other) is int:
            if other >= self.length:
                result = True
            elif other < 0:
                result = False
            else:
                result = gcnf.nots(self.value[other])
        else:
            a, b, n = self.compatible(self, other, "<=")
            temp = [gcnf.ands(ai, gcnf.nots(bi)) for ai, bi in zip(a, b)]
            result = gcnf.nots(binreduce(temp, lambda a, b: gcnf.ors(a, b)))
        return Bool.result("<=", result)

    def __gt__(self, other):
        if type(other) is int:
            if other >= self.length:
                result = False
            elif other < 0:
                result = True
            else:
                result = self.value[other]
        else:
            a, b, n = self.compatible(self, other, ">")
            temp = [gcnf.ands(ai, gcnf.nots(bi)) for ai, bi in zip(a, b)]
            result = binreduce(temp, lambda a, b: gcnf.ors(a, b))
        return Bool.result(">", result)

    # In-place forms return a new Unary (the operand itself is not modified).
    def __iadd__(self, other):
        return self.__add__(other)

    def __isub__(self, other):
        return self.__sub__(other)

    def __pos__(self):
        return self

    def __neg__(self):
        Todo()

    def __abs__(self):
        return self

    def __invert__(self):
        Todo()

    @classmethod
    def min(cls, a, b):
        """Unary minimum: position-wise AND."""
        a, b, n = cls.compatible(a, b, "min")
        return cls.result("min", n, [gcnf.ands(ai, bi) for ai, bi in zip(a, b)])

    @classmethod
    def max(cls, a, b):
        """Unary maximum: position-wise OR."""
        a, b, n = cls.compatible(a, b, "max")
        return cls.result("max", n, [gcnf.ors(ai, bi) for ai, bi in zip(a, b)])
def same_type(a, b):  # TODO: Unary support?
    """Report whether *a* and *b* may be the two branches of ``If``.

    True for two Bools, or for two BitVecs of equal length; False otherwise.
    """
    if isinstance(a, Bool) and isinstance(b, Bool):
        return True
    both_bitvec = isinstance(a, BitVec) and isinstance(b, BitVec)
    return both_bitvec and a.length == b.length
def If(cond, a, b, length=None):
    """Return *a* where *cond* holds and *b* otherwise (z3's ``If``).

    A plain Python ``True``/``False`` condition short-circuits and returns
    the chosen operand untouched. Otherwise *cond* must be a Bool and the
    operands must be two Bools (Python bools are promoted to constants) or
    two BitVecs of equal width. A Python int operand becomes a BitVec
    constant of the other operand's width, or of *length* when both operands
    are ints.
    """
    if cond is True or cond is False:
        return a if cond else b

    if type(a) is int and type(b) is int:
        if length is None:
            # Two plain ints give no width to infer.
            Todo()
        else:
            for literal in (a, b):
                assert literal >= 0 and literal.bit_length() <= length, literal
            a = BitVec.constant(a, length)
            b = BitVec.constant(b, length)
    elif type(a) is int:
        assert isinstance(b, BitVec), type(b)
        assert a >= 0 and a.bit_length() <= b.length, a
        a = BitVec.constant(a, b.length)
    elif type(b) is int:
        assert isinstance(a, BitVec), type(a)
        assert b >= 0 and b.bit_length() <= a.length, b
        b = BitVec.constant(b, a.length)

    a = Bool.constant(a) if type(a) is bool else a
    b = Bool.constant(b) if type(b) is bool else b
    assert isinstance(cond, Bool), type(cond)
    assert same_type(a, b), (type(a), type(b))

    if isinstance(a, Bool):
        return Bool.result("If", gcnf.ites(cond.value, a.value, b.value))
    chosen = [gcnf.ites(cond.value, ai, bi) for ai, bi in zip(a.value, b.value)]
    return BitVec.result("If", a.length, chosen)
def IfUnary(cond, a, b, length=None):  # length parameter is unused (for now?)
    """Select between two Unary numbers (or two Bools) based on *cond*.

    A plain Python ``True``/``False`` condition short-circuits and returns
    the chosen operand untouched. Non-negative Python ints become Unary
    constants. Python bools become Bool constants. Two Unary operands are
    padded with constant False to a common length without modifying them, so
    neither branch is truncated.
    """
    if cond is True:
        return a
    elif cond is False:
        return b
    if type(a) is int and type(b) is int:
        assert a >= 0, a
        assert b >= 0, b
        n = max(a, b)
        a = Unary.constant(a, n)
        b = Unary.constant(b, n)
    elif type(a) is int:
        assert isinstance(b, Unary), type(b)
        assert a >= 0, a
        a = Unary.constant(a, max(a, b.length))
    elif type(b) is int:
        assert isinstance(a, Unary), type(a)
        assert b >= 0, b
        b = Unary.constant(b, max(a.length, b))
    if type(a) is bool:
        a = Bool.constant(a)
    if type(b) is bool:
        b = Bool.constant(b)
    assert isinstance(cond, Bool), type(cond)
    if isinstance(a, Bool) or isinstance(b, Bool):
        # Mixing a Bool with anything else fails one of these assertions.
        assert isinstance(a, Bool), type(a)
        assert isinstance(b, Bool), type(b)
        return Bool.result("If", gcnf.ites(cond.value, a.value, b.value))
    assert isinstance(a, Unary), type(a)
    assert isinstance(b, Unary), type(b)
    n = max(a.length, b.length)
    a_bits = a.value + [False] * (n - a.length)
    b_bits = b.value + [False] * (n - b.length)
    return Unary.result("If", n,
                        [gcnf.ites(cond.value, ai, bi)
                         for ai, bi in zip(a_bits, b_bits)])
def Totalizer(elements):
    """Wrap ``gcnf.totalizer`` over *elements* as a Unary of len(elements).

    Bool elements are unwrapped to their literals; anything else is passed
    through unchanged. NOTE(review): presumably the result counts how many
    inputs are true; confirm against cnf's totalizer.
    """
    # TODO: assert each element of x is bool or Bool.
    literals = []
    for element in elements:
        literals.append(element.value if isinstance(element, Bool) else element)
    return Unary.result("Totalizer", len(literals), gcnf.totalizer(literals))
class Result:
    """Outcome of a satisfiability check, printed as its name.

    The module creates the ``sat`` and ``unsat`` instances. Both ``str()``
    and ``repr()`` give the bare name, as z3's check results do, so results
    display readably in the REPL and inside containers.
    """

    def __init__(self, result):
        self.result = str(result)

    def __str__(self):
        return self.result

    def __repr__(self):
        return self.result
# The two results SolverFor.check() returns. Result defines no __eq__, so
# comparisons against them are by identity.
sat, unsat = Result("sat"), Result("unsat")
# Parts of z3's interface that z0 does not implement; calling them raises.
Int = Optimize = Todo
# Names exported by ``from z0 import *``. IfUnary, the Unary counterpart of
# If, is included alongside Unary and Totalizer.
__all__ = [
    "binreduce",
    "Goal",
    "Solver",
    "SolverFor",
    "Optimize",
    "Int",
    "Bool",
    "BitVec",
    "Unary",
    "If",
    "IfUnary",
    "Implies",
    "And",
    "Xor",
    "Or",
    "Not",
    "Totalizer",
    "sat",
    "unsat",
]