From e799c9b01c78cf02b538c81674d50870b8ffa3cf Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Wed, 8 Jun 2022 04:31:52 +0200 Subject: [PATCH] logic: add z0.py --- logic/z0.py | 641 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 641 insertions(+) create mode 100644 logic/z0.py diff --git a/logic/z0.py b/logic/z0.py new file mode 100644 index 0000000..7d8a86a --- /dev/null +++ b/logic/z0.py @@ -0,0 +1,641 @@ +# z0: a bit-blaster that implements a tiny subset of z3's interface. + +import cnf + +gcnf = cnf.Problem() + +def Todo(*args, **kwargs): + raise NotImplementedError("TODO") + +def binreduce(lst, fun): + if len(lst) == 1: + return lst + assert len(lst) > 1 + + while len(lst) > 1: + new_lst = [] + for a, b in zip(lst[::2], lst[1::2]): + new_lst.append(fun(a, b)) + if len(lst) % 2 == 1: + new_lst.append(lst[-1]) + lst = new_lst + + return lst[0] + +class Goal: + def __init__(self): + self.used = set() + + def add(self, cond): + if type(cond) is bool: + self.used.add(cond) + gcnf.assert_or(cond) + else: + self.used.add(cond.value) + gcnf.assert_or(cond.value) + +class SolverFor: + def __init__(self, name): + assert name == "QF_BV" + self.goal = None + self.fp = None + + def add(self, goal): + if self.goal is not None: + Todo() + self.goal = Goal + + def check(self, cond=None, args=None, verbose=False): + from tempfile import mkstemp + from os import environ, remove, fdopen + from subprocess import Popen, PIPE + + if cond is None: + newcnf = gcnf + else: + from copy import deepcopy + newcnf = deepcopy(gcnf) + if type(cond) is bool: + newcnf.assert_or(cond) + else: + newcnf.assert_or(cond.value) + + solver = environ.get('SOLVER', 'kissat') + args = [solver] if args is None else [solver] + list(args) + if solver.endswith('cadical'): + args.append('-q') + + if self.fp is not None: + remove(self.fp) + fd, self.fp = mkstemp() + with fdopen(fd, 'w') as f: + newcnf.dimacs(f, verbosity=2 if verbose else 1) + + args.append(self.fp) + p = Popen(args, stdout=PIPE, stderr=PIPE) + self.out, self.err = p.communicate() + + assert p.returncode in (10, 20), (p.returncode, self.out, self.err) + if p.returncode == 10: + return sat + elif p.returncode == 20: + return unsat + + def model(self): + from io import StringIO + mapping = cnf.readcnfcom(fp=self.fp) + sol = StringIO(self.out.decode("utf-8", errors="ignore")) + solution = cnf.readsolutions(f=sol)[0] + assignments = cnf.unmapsolution(solution, mapping) + return assignments + + def __del__(self): + # FIXME: this doesn't work when python is shutting down. + #if self.fp is not None: + # from os import remove + # remove(self.fp) + pass + +def Solver(): + return SolverFor("QF_BV") + +class Bool: + def __init__(self, name, value=None, public=True): + assert type(name) == str, type(name) + self.name = name + new = gcnf.new_var if public else gcnf._new_temp + self.value = new(name) if value is None else value + + @classmethod + def result(cls, name, value=None): + return cls(name, value=value, public=False) + + @classmethod + def constant(cls, value): + return cls("const", value, public=False) + + def __eq__(self, other): + if other is True or other is False: + return self.result("==", gcnf.nots(gcnf.xors(self.value, other))) + assert isinstance(other, Bool), type(other) + return self.result("==", gcnf.nots(gcnf.xors(self.value, other.value))) + + def __ne__(self, other): + if other is True or other is False: + return self.result("!=", gcnf.xors(self.value, other)) + assert isinstance(other, Bool), type(other) + return self.result("!=", gcnf.xors(self.value, other.value)) + +class BitVec: + def __init__(self, name, length, value=None, public=True): + assert type(name) == str, type(name) + assert type(length) == int, type(length) + new = gcnf.new_var if public else gcnf._new_temp + self.name = name + if type(value) is int: + assert value >= 0 and value.bit_length() <= length, value + value = [(value >> i) & 1 == 1 for i in range(length)] + assert value is None or type(value) is list, type(value) + self.value = new(name, length) if value is None else value + self.length = length + assert len(self.value) == self.length + + @classmethod + def result(cls, name, length, value=None): + return cls(name, length, value=value, public=False) + + @classmethod + def constant(cls, value, length): + return cls("const", length, value, public=False) + + @classmethod + def compatible(cls, a, b, opstr): + if type(a) is int: + assert isinstance(b, BitVec), type(b) + #assert a >= 0 and a.bit_length() <= b.length, a + a = cls.constant(a, b.length) + return a.value, b.value, b.length + + elif type(b) is int: + assert isinstance(a, BitVec), type(a) + #assert b >= 0 and b.bit_length() <= a.length, b + b = cls.constant(b, a.length) + return a.value, b.value, a.length + + else: + assert isinstance(a, BitVec), type(a) + assert isinstance(b, BitVec), type(b) + assert a.length == b.length, (a.length, b.length) + return a.value, b.value, a.length + + def __add__(self, other): + a, b, n = self.compatible(self, other, "+") + return self.result("+", n, gcnf.adds(a, b, n)) + + def __radd__(self, other): + a, b, n = self.compatible(other, self, "+") + return self.result("+", n, gcnf.adds(a, b, n)) + + def __sub__(self, other): + a, b, n = self.compatible(self, other, "-") + return self.result("-", n, gcnf.subs(a, b, n)) + + def __rsub__(self, other): + a, b, n = self.compatible(other, self, "-") + return self.result("-", n, gcnf.subs(a, b, n)) + + def __eq__(self, other): + a, b, n = self.compatible(self, other, "==") + temp = [gcnf.nots(gcnf.xors(ai, bi)) for ai, bi in zip(a, b)] + reduced = binreduce(temp, lambda a, b: gcnf.ands(a, b)) + return Bool.result("==", reduced) + + def __ne__(self, other): + a, b, n = self.compatible(self, other, "!=") + temp = [gcnf.xors(ai, bi) for ai, bi in zip(a, b)] + reduced = binreduce(temp, lambda a, b: gcnf.ors(a, b)) + return Bool.result("!=", reduced) + + def __lt__(self, other): + a, b, n = self.compatible(self, other, "<") + carry = gcnf.subs(a, b, n, return_carry="only") + return Bool.result("<", carry) + + def __ge__(self, other): + a, b, n = self.compatible(self, other, ">=") + carry = gcnf.subs(a, b, n, return_carry="only") + return Bool.result(">=", gcnf.nots(carry)) + + def __le__(self, other): + a, b, n = self.compatible(self, other, "<=") + carry = gcnf.subs(a, b, n, return_carry="only", carry=True) + return Bool.result("<=", carry) + + def __gt__(self, other): + a, b, n = self.compatible(self, other, ">") + carry = gcnf.subs(a, b, n, return_carry="only", carry=True) + return Bool.result(">", gcnf.nots(carry)) + + def __iadd__(self, other): + a, b, n = self.compatible(self, other, "+=") + self.value = gcnf.adds(a, b, n) + return self + + def __isub__(self, other): + a, b, n = self.compatible(self, other, "-=") + self.value = gcnf.subs(a, b, n) + return self + + def __xor__(self, other): + a, b, n = self.compatible(self, other, "^") + return self.result("^", n, [gcnf.xors(ai, bi) for ai, bi in zip(a, b)]) + + def __pos__(self): + return self + + def __neg__(self): + zero = [False for _ in range(self.length)] + return self.result("u-", self.length, + gcnf.subs(zero, self.value, self.length)) + + def __abs__(self): + return self + + def __invert__(self): + return self.result("u~", self.length, + [gcnf.nots(x) for x in self.value]) + +def booleate(f): + def g(*args): + newargs = [Bool.constant(a) if type(a) is bool else a for a in args] + for arg in newargs: + assert isinstance(arg, Bool), type(arg) + return f(*newargs) + return g + +@booleate +def Not(a): + return Bool.result("Not", gcnf.nots(a.value)) + +@booleate +def And(a, b): + return Bool.result("And", gcnf.ands(a.value, b.value)) + +@booleate +def Xor(a, b): + return Bool.result("Xor", gcnf.xors(a.value, b.value)) + +@booleate +def Or(a, b): + return Bool.result("Or", gcnf.ors(a.value, b.value)) + +@booleate +def Implies(a, b): + return Bool.result("Implies", gcnf.implies(a.value, b.value)) + +def _maybe_add_one(should_add, var): + result = [False] * (len(var) + 1) + for i in range(len(var) + 1): + shifted = var[i - 1] if i > 0 else True + unshifted = var[i] if i < len(var) else False + result[i] = gcnf.ites(should_add, shifted, unshifted) + return result + +class Unary: + def __init__(self, name, length, value=None, public=True): + assert type(name) == str, type(name) + assert type(length) == int, type(length) + new = gcnf.new_var if public else gcnf._new_temp + self.name = name + if type(value) is int: + assert value >= 0 and value <= length, value + value = [i < value for i in range(length)] + assert value is None or type(value) is list, type(value) + self.value = new(name, length) if value is None else value + self.length = length + self._validate() + + def _validate(self): + assert len(self.value) == self.length + # unary numbers must be in unary form. + if all(type(v) is bool for v in self.value): + for a, b in zip(self.value, self.value[1:]): + assert a == b or a, self.value + else: + for a, b in zip(self.value, self.value[1:]): + gcnf.assert_or(gcnf.nots(b), a) + + @classmethod + def result(cls, name, length, value=None): + return cls(name, length, value=value, public=False) + + @classmethod + def constant(cls, value, length): + return cls("const", length, value, public=False) + + @classmethod + def compatible(cls, a, b, opstr): + if type(a) is int: + assert isinstance(b, Unary), type(b) + a = cls.constant(a, a) + + elif type(b) is int: + assert isinstance(a, Unary), type(a) + b = cls.constant(b, b) + + else: + assert isinstance(a, Unary), type(a) + assert isinstance(b, Unary), type(b) + + length = max(a.length, b.length) + a.extend(length) + b.extend(length) + return a.value, b.value, length + + def extend(self, new_length): + count = new_length - self.length + if count <= 0: + return + self.value = self.value + [False] * count + return self + + def __add__(self, other): + if type(other) is int: + assert other >= 0, other + if other == 0: + result = self.value + else: + result = [True] * other + self.value + while result[-1] is False: + result = result[:-1] + return self.result("+", max(self.length, len(result)), result) + + else: + a, b, n = self.compatible(self, other, "+") + if 0: + c = b + for i in range(n): + c = _maybe_add_one(a[i], c) + else: + c = gcnf.unary_sum(a, b) + return self.result("+", len(c), c) + + def __radd__(self, other): + Todo() + + def __sub__(self, other): + if type(other) is int: + assert other >= 0, other + if other == 0: + result = self.value + elif other >= self.length: + result = [False] * self.length + else: + result = self.value[other:] + [False] * other + return self.result("-", self.length, result) + + else: + Todo() + + def __rsub__(self, other): + Todo() + + def __eq__(self, other): + if type(other) is int: + if other < 0 or other > self.length: + result = False + elif other == self.length: + result = self.value[other - 1] + elif other == 0: + result = gcnf.nots(self.value[0]) + else: + result = gcnf.ands(self.value[other - 1], + gcnf.nots(self.value[other])) + else: + a, b, n = self.compatible(self, other, "==") + temp = [gcnf.nots(gcnf.xors(ai, bi)) for ai, bi in zip(a, b)] + result = binreduce(temp, lambda a, b: gcnf.ands(a, b)) + + return Bool.result("==", result) + + def __ne__(self, other): + if type(other) is int: + if other < 0 or other > self.length: + result = True + elif other == self.length: + result = gcnf.nots(self.value[other - 1]) + elif other == 0: + result = self.value[0] + else: + result = gcnf.ors(gcnf.nots(self.value[other - 1]), + self.value[other]) + else: + a, b, n = self.compatible(self, other, "!=") + temp = [gcnf.xors(ai, bi) for ai, bi in zip(a, b)] + result = binreduce(temp, lambda a, b: gcnf.ors(a, b)) + + return Bool.result("!=", result) + + def __lt__(self, other): + if type(other) is int: + if other > self.length: + result = True + elif other <= 0: + result = False + else: + result = gcnf.nots(self.value[other - 1]) + else: + a, b, n = self.compatible(a, b, "<") + temp = [gcnf.ands(gcnf.nots(a), b) for a, b in zip(a, b)] + result = binreduce(temp, lambda a, b: gcnf.ors(a, b)) + + return Bool.result("<", result) + + def __ge__(self, other): + if type(other) is int: + if other > self.length: + result = False + elif other <= 0: + result = True + else: + result = self.value[other - 1] + else: + a, b, n = self.compatible(a, b, ">=") + temp = [gcnf.ands(gcnf.nots(ai), bi) for ai, bi in zip(a, b)] + result = gcnf.nots(binreduce(temp, lambda a, b: gcnf.ors(a, b))) + + return Bool.result(">=", result) + + def __le__(self, other): + if type(other) is int: + if other >= self.length: + result = True + elif other < 0: + result = False + else: + result = gcnf.nots(self.value[other]) + else: + a, b, n = self.compatible(a, b, "<=") + temp = [gcnf.ands(ai, gcnf.nots(bi)) for ai, bi in zip(a, b)] + result = gcnf.nots(binreduce(temp, lambda a, b: gcnf.ors(a, b))) + + return Bool.result("<=", result) + + def __gt__(self, other): + if type(other) is int: + if other >= self.length: + result = False + elif other < 0: + result = True + else: + result = self.value[other] + else: + a, b, n = self.compatible(a, b, ">") + temp = [gcnf.ands(ai, gcnf.nots(bi)) for ai, bi in zip(a, b)] + result = binreduce(temp, lambda a, b: gcnf.ors(a, b)) + + return Bool.result(">", result) + + def __iadd__(self, other): + Todo() + + def __isub__(self, other): + Todo() + + def __pos__(self): + return self + + def __neg__(self): + Todo() + + def __abs__(self): + return self + + def __invert__(self): + Todo() + + @classmethod + def min(cls, a, b): + a, b, n = cls.compatible(a, b, "min") + return cls.result("min", n, [gcnf.ands(ai, bi) for ai, bi in zip(a, b)]) + + @classmethod + def max(cls, a, b): + a, b, n = cls.compatible(a, b, "max") + return cls.result("max", n, [gcnf.ors(ai, bi) for ai, bi in zip(a, b)]) + +def same_type(a, b): # TODO: Unary support? + if isinstance(a, Bool) and isinstance(b, Bool): + return True + if isinstance(a, BitVec) and isinstance(b, BitVec): + return a.length == b.length + return False + +def If(cond, a, b, length=None): + if cond is True: + return a + elif cond is False: + return b + + if type(a) is int and type(b) is int: + if length is None: + Todo() + else: + assert a >= 0 and a.bit_length() <= length, a + assert b >= 0 and b.bit_length() <= length, b + a = BitVec.constant(a, length) + b = BitVec.constant(b, length) + elif type(a) is int: + assert isinstance(b, BitVec), type(b) + assert a >= 0 and a.bit_length() <= b.length, a + a = BitVec.constant(a, b.length) + elif type(b) is int: + assert isinstance(a, BitVec), type(a) + assert b >= 0 and b.bit_length() <= a.length, b + b = BitVec.constant(b, a.length) + + if type(a) is bool: + a = Bool.constant(a) + if type(b) is bool: + b = Bool.constant(b) + + assert isinstance(cond, Bool), type(cond) + assert same_type(a, b), (type(a), type(b)) + + if isinstance(a, Bool): + return Bool.result("If", gcnf.ites(cond.value, a.value, b.value)) + else: + return BitVec.result("If", a.length, + [gcnf.ites(cond.value, ai, bi) + for ai, bi in zip(a.value, b.value)]) + +def IfUnary(cond, a, b, length=None): # length parameter is unused (for now?) + if cond is True: + return a + elif cond is False: + return b + + if type(a) is int and type(b) is int: + assert a >= 0, a + assert b >= 0, b + length = max(a, b) + a = Unary.constant(a, length) + b = Unary.constant(b, length) + elif type(a) is int: + assert isinstance(b, Unary), type(b) + assert a >= 0, a + length = max(a, b.length) + a = Unary.constant(a, length) + b.extend(length) + elif type(b) is int: + assert isinstance(a, Unary), type(a) + assert b >= 0, b + length = max(a.length, b) + b = Unary.constant(b, length) + a.extend(length) + + if type(a) is bool: + a = Bool.constant(a) + if type(b) is bool: + b = Bool.constant(b) + + assert isinstance(cond, Bool), type(cond) + + if isinstance(a, Bool) and isinstance(b, Bool): + return Bool.result("If", gcnf.ites(cond.value, a.value, b.value)) + elif isinstance(a, Bool) or isinstance(b, Bool): + # one of these two assertions should fail: + assert isinstance(a, Bool), type(a) + assert isinstance(b, Bool), type(b) + elif isinstance(a, Unary) or isinstance(b, Unary): + assert isinstance(a, Unary), type(a) + assert isinstance(b, Unary), type(b) + + if isinstance(a, Bool): + return Bool.result("If", gcnf.ites(cond.value, a.value, b.value)) + else: + return Unary.result("If", a.length, + [gcnf.ites(cond.value, ai, bi) + for ai, bi in zip(a.value, b.value)]) + +def Totalizer(elements): + # TODO: assert each element of x is bool or Bool. + elements = [el.value if isinstance(el, Bool) else el for el in elements] + return Unary.result("Totalizer", len(elements), gcnf.totalizer(elements)) + +class Result: + def __init__(self, result): + self.result = str(result) + + def __str__(self): + return self.result + +sat = Result("sat") +unsat = Result("unsat") + +Int = Todo +Optimize = Todo + +__all__ = """ +binreduce + +Goal +Solver +SolverFor +Optimize +Int +Bool +BitVec +Unary + +If +Implies +And +Xor +Or +Not +Totalizer + +sat +unsat +""".strip().split()