485 lines
15 KiB
Python
485 lines
15 KiB
Python
# Copyright (C) 2019 Connor Olding
|
|
|
|
# Permission to use, copy, modify, and/or distribute this software for any
|
|
# purpose with or without fee is hereby granted, provided that the above
|
|
# copyright notice and this permission notice appear in all copies.
|
|
|
|
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
|
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
|
|
import warnings
|
|
|
|
# ASCII decimal digit characters, used to vet numeric tokens before int().
digits = {str(d) for d in range(10)}
# The same characters plus the minus sign that may prefix a negative literal.
digits_or_negative = digits.union('-')
|
|
|
|
|
|
def readcnf(f=None, fp=None):
    """Parse a DIMACS CNF formula.

    Pass either ``f``, an already-open text file (any iterable of lines
    works), or ``fp``, a path to open. ``fp`` also names the input in error
    messages. A file opened here is closed before returning, including when
    parsing fails.

    Returns ``(n_vars, clauses)``: the variable count from the ``p cnf``
    header and a list of clauses, each a tuple of nonzero ints. Every clause
    must sit on a single line terminated by ``0``; a bare ``0`` line (empty
    clause) is skipped.

    Raises AssertionError ("malformed cnf at PATH:LINE" or "empty cnf") on
    malformed input. The checks are ``assert`` statements, so they are
    skipped entirely when Python runs with ``-O``.
    """
    clauses = []
    initialized = False
    n_vars, n_clauses = None, None

    our_file = False
    if f is None:
        assert fp is not None, 'readcnf requires a file handle or file path'
        f = open(fp, 'r')
        our_file = True
    elif fp is None:
        fp = '<unknown>'

    i = -1

    def failure():
        return 'malformed cnf at {}:{}'.format(fp, i + 1)

    try:
        for i, line in enumerate(f):
            # split() rather than split(' '): blank lines and runs of
            # spaces/tabs must not produce '' tokens, which used to crash
            # with IndexError on cmd[0] / t[0].
            tokens = line.split()
            if not tokens:
                continue

            cmd = tokens[0]

            if cmd == 'c':
                continue

            elif cmd == 'p':
                assert not initialized, failure()
                assert len(tokens) == 4, failure()

                cnfstr, n_vars, n_clauses = tokens[1:]
                assert cnfstr == 'cnf', failure()
                assert all(c in digits for c in n_vars), failure()
                assert all(c in digits for c in n_clauses), failure()

                n_vars = int(n_vars)
                n_clauses = int(n_clauses)
                initialized = True

            elif cmd[0] in digits_or_negative:
                assert initialized, failure()

                for t in tokens:
                    # An optional leading '-' followed by plain decimal digits.
                    # The minus sign is only allowed first, so tokens like
                    # '1-2' or '--1' are reported as malformed instead of
                    # escaping as a ValueError from int().
                    magnitude = t[1:] if t[0] == '-' else t
                    assert magnitude, failure()
                    assert all(c in digits for c in magnitude), failure()
                    # No leading zeros; a zero is only valid as the bare '0'.
                    assert magnitude[0] != '0' or t == '0', failure()

                assert int(tokens[-1]) == 0, failure()
                terms = tuple(int(t) for t in tokens[:-1])

                for t in terms:
                    assert t != 0, failure()
                    assert abs(t) <= n_vars, failure()

                if len(terms) > 0:
                    assert len(clauses) < n_clauses, failure()
                    clauses.append(terms)

            else:
                assert False, failure()
    finally:
        if our_file:
            f.close()

    assert initialized, 'empty cnf'
    return n_vars, clauses
|
|
|
|
|
|
def readyicescom(names, key, ids):
    """Helper for readcnfcom: parse a bracketed bit list like ``[3 4 5]``.

    ``ids`` are the whitespace-separated tokens following ``-->`` on a
    comment line (presumably the yices bit-blaster comment format, judging
    by the name; verify). The first token must start with '[' and the last
    must end with ']'. A single token such as ``[7]`` carries both.

    On success, sets ``names[key + '.' + str(i)]`` to the i-th id, for
    i = 0, 1, ... If any token is malformed, returns without touching
    ``names``.
    """
    int_ids = []
    last = len(ids) - 1
    for i, tok in enumerate(ids):
        # These are independent ifs, not if/elif, so a one-element list such
        # as '[7]' gets both brackets stripped.
        if i == 0:
            if not tok.startswith('['):
                return
            tok = tok[1:]
        if i == last:
            if not tok.endswith(']'):
                return
            tok = tok[:-1]
        # An empty remainder ('[]', a lone ']') is not an id; without this
        # check int('') would raise ValueError.
        if not tok or not all(c in digits for c in tok):
            return
        int_ids.append(int(tok))

    for i, int_id in enumerate(int_ids):
        names[key + '.' + str(i)] = int_id
|
|
|
|
|
|
def readcnfcom(f=None, fp=None):
    """Collect variable names from the comment lines of a DIMACS CNF file.

    Pass either an open text file ``f`` (any iterable of lines) or a path
    ``fp``. A file opened here is closed before returning, even on error.

    Recognized comment lines (tokens separated by any whitespace):
      ``c <index> = <name>``          (the form Problem.dimacs writes)
      ``c <name> = <index>``
      ``c <name> --> [<i> <j> ...]``  (bit list, handled by readyicescom)

    Everything else is ignored, so the returned ``{name: index}`` dict may
    be empty or missing keys; anything goes.
    """
    names = {}

    our_file = False
    if f is None:
        assert fp is not None, 'readcnfcom requires a file handle or file path'
        f = open(fp, 'r')
        our_file = True

    try:
        for line in f:
            # split() so repeated spaces can't create '' tokens. An empty
            # token passed the all-digits test below and then crashed in
            # int('').
            tokens = line.split()
            if len(tokens) < 4 or tokens[0] != 'c':
                continue
            lhs, op = tokens[1], tokens[2]
            if op == '=' and len(tokens) == 4:
                rhs = tokens[3]
                if all(c in digits for c in lhs):
                    names[rhs] = int(lhs)
                elif all(c in digits for c in rhs):
                    names[lhs] = int(rhs)
            elif op == '-->':
                readyicescom(names, lhs, tokens[3:])
    finally:
        if our_file:
            f.close()

    return names
|
|
|
|
|
|
def readsolutions(f=None, fp=None):
    """Read any number of SAT solutions from solver output.

    Pass either an open text file ``f`` (any iterable of lines) or a path
    ``fp``. ``fp`` also names the input in warnings. A file opened here is
    closed before returning, even on error.

    This is pretty lenient: your solutions are important, even if your
    solver outputs a bunch of unrelated trash. Only ``s`` lines (which start
    a new solution) and ``v`` lines (whose nonzero int tokens are appended
    to the current solution) are interpreted; all other lines are ignored.
    A ``v`` token that is not an int emits a warning naming ``fp:line``.

    Returns a list of solutions, each a list of nonzero ints. Empty
    solutions are dropped.
    """
    solutions, solution = [], []

    our_file = False
    if f is None:
        assert fp is not None, ('readsolutions'
                                ' requires a file handle or file path')
        f = open(fp, 'r')
        our_file = True
    elif fp is None:
        fp = '<unknown>'

    try:
        # enumerate: the warning below reports the line number. It used the
        # name `i` without defining it, raising NameError on the first bad
        # token.
        for i, line in enumerate(f):
            # split() so double spaces/tabs don't become bogus '' tokens.
            tokens = line.split()
            if not tokens:
                continue
            cmd = tokens[0]

            if cmd == 's':
                # NOTE TO SELF: picosat --all prints "s SOLUTIONS x" at last
                if solution:
                    solutions.append(solution)
                solution = []

            elif cmd == 'v':
                for t in tokens[1:]:
                    try:
                        lit = int(t)
                    except ValueError:
                        msg = 'invalid token while parsing solution at {}:{} {}'
                        warnings.warn(msg.format(fp, i + 1, t))
                    else:
                        if lit != 0:
                            solution.append(lit)
    finally:
        if our_file:
            f.close()

    if solution:
        solutions.append(solution)

    return solutions
|
|
|
|
|
|
def bv_exname(name):
    """Split a bit-vector element name into (base name, bit index).

    ``'myvar_abc.123'`` -> ``('myvar_abc', 123)``. The split happens at the
    last '.', and everything after it must be ASCII digits. Otherwise (no
    '.', nothing after it, or a non-digit in the suffix) the result is
    ``(None, None)``.
    """
    head, dot, tail = name.rpartition('.')
    if dot and tail and all(ch in digits for ch in tail):
        return head, int(tail)
    return None, None
|
|
|
|
|
|
def unmapsolution(solution, names):
    """Translate a solver solution back into named values.

    ``solution`` should be a sequence of nonzero integers (positive = true,
    negative = false). ``names`` should be the result of readcnfcom (or
    Problem.vars): a ``{name: variable index}`` mapping.

    Names of the form ``base.N`` are gathered into bit vectors and reported
    as a single integer under ``base`` (bit N weighted 2**N). Every other
    name maps to 0 or 1. The result is only as complete as the conjunction
    of solution and names: names whose variable the solution doesn't
    mention are left out.
    """
    # variable index -> 1/0; a later literal for the same variable wins
    truth = {abs(lit): int(lit > 0) for lit in solution if lit != 0}

    assignments = {}
    vectors = {}
    for name, ind in names.items():
        if ind not in truth:
            continue
        base, bit = bv_exname(name)
        if base is None:
            assignments[name] = truth[ind]
        else:
            vectors.setdefault(base, {})[bit] = truth[ind]

    for base, bitvals in vectors.items():
        # TODO: check for contiguity?
        assignments[base] = sum(1 << k for k, v in bitvals.items() if v)

    return assignments
|
|
|
|
|
|
class Problem:
    """Incrementally build a CNF formula out of boolean gates.

    Variables are positive ints handed out in order by ``var``, so ``self.i``
    is also the variable count. Gate methods take "literals": a nonzero int
    (negative meaning negation) or one of the Python constants
    ``True``/``False``. Constant inputs are folded away where possible, so a
    gate may return a constant or one of its inputs instead of a fresh
    variable. Emitted clauses are tuples of nonzero ints kept in
    ``self.clauses``; ``dimacs`` writes them out.
    """

    def __init__(self):
        self.clauses = []  # tuples of nonzero int literals
        self.i = 0  # last allocated variable index (0 = none yet)
        self.names = {}  # variable index -> name

        self._used = set()  # names already taken via new_var
        self._counts = {}  # temp category -> next number for _new_temp

        # old names for backwards compatibility:
        self.assign_or = self.ors
        self.assign_and = self.ands
        self.assign_xor = self.xors
        self.assign_not = self.nots
        self.assign_ite = self.ites
        self.assign_implies = self.implies
        self.ite = self.ites

    def var(self, name=None):  # TODO: rename?
        """Allocate and return the next variable index, optionally named."""
        self.i += 1
        if name is not None:
            self.names[self.i] = name
        return self.i

    def _alloc(self, name, bits):
        # One variable called `name` when bits is None; otherwise a list of
        # `bits` variables called name.0, name.1, ... (index 0 is the least
        # significant bit, as in adds and unmapsolution).
        if bits is None:
            return self.var(name)
        return [self.var(name + '.' + str(i)) for i in range(bits)]

    def _new_temp(self, category, bits=None):  # TODO: rename?
        """Allocate an auxiliary variable (or bit list) named _<category><n>.

        The leading underscore hides it from ``vars()`` and from ``dimacs``
        output below verbosity 2.
        """
        count = self._counts.get(category, 0)
        self._counts[category] = count + 1
        return self._alloc(f'_{category}{count}', bits)

    def new_var(self, name, bits=None):  # TODO: rename?
        """Allocate a named variable, or a list of ``bits`` named variables.

        Raises AssertionError if ``name`` was already given to new_var.
        """
        assert name not in self._used, name
        # Record the name so the check above catches later duplicates.
        self._used.add(name)
        return self._alloc(name, bits)

    def vars(self, temps=False):  # TODO: rename?
        """Return a ``{name: variable index}`` mapping, like readcnfcom.

        Names starting with '_' (temporaries) are omitted unless ``temps``.
        """
        return {name: ind for ind, name in self.names.items()
                if temps or not name.startswith('_')}

    def add_clause(self, *args):  # TODO: rename?
        """Append one clause (a disjunction of int literals) verbatim."""
        for a in args:
            assert type(a) is int, type(a)
            # 0 terminates a clause in DIMACS output, so it is never a literal.
            assert a != 0, args
        self.clauses.append(args)

    def _map(self, clauses, *mapping):
        # Instantiate clause templates: template literal k (or -k) stands for
        # mapping[k - 1] (or its negation); each clause is then asserted.
        for clause in clauses:
            new_clause = [mapping[t - 1] if t > 0 else
                          self.nots(mapping[-t - 1]) for t in clause]
            self.assert_or(*new_clause)

    def assert_or(self, *args):
        """Require at least one of the literals to hold.

        A ``True`` literal makes the constraint vacuous, and ``False``
        literals are dropped. An empty call adds nothing. Raises
        AssertionError if only ``False`` literals were given.
        """
        if len(args) == 0:
            return
        for a in args:  # don't do `if True in args:` because that matches `1`
            if a is True:
                return
        args = [a for a in args if a is not False]
        assert len(args) > 0, "trivially unsatisfiable"
        self.add_clause(*args)

    def assert_and(self, *args):
        """Require every literal to hold (one unit clause each)."""
        for a in args:
            self.assert_or(a)

    def assert_implies(self, a, b):
        """Require ``a -> b``."""
        self.assert_or(self.nots(a), b)

    def assert_eq(self, a, b):
        """Require ``a == b`` (via the negation of their XOR)."""
        self.assert_or(self.nots(self.xors(a, b)))

    def nots(self, a):
        """Negate a literal: ``not`` for constants, sign flip for ints."""
        if type(a) is bool:
            return not a
        else:
            return -a

    def ors(self, *args):
        """Return a literal equal to the OR of the arguments.

        Folds constants. When a fresh variable v is needed, it is tied to
        the OR by the clauses (v, -a) for each a and (-v, *args).
        """
        for a in args:
            if a is True:
                return True
        args = [a for a in args if a is not False]
        if len(args) == 0:
            return False
        elif len(args) == 1:
            return args[0]
        v = self._new_temp('or')
        for a in args:
            self.add_clause(v, -a)
        self.add_clause(-v, *args)
        return v

    def ands(self, *args):
        """Return a literal equal to the AND of the arguments.

        Folds constants. When a fresh variable v is needed, it is tied to
        the AND by the clauses (-v, a) for each a and (v, -a1, -a2, ...).
        """
        for a in args:
            if a is False:
                return False
        args = [a for a in args if a is not True]
        if len(args) == 0:
            return True
        elif len(args) == 1:
            return args[0]
        v = self._new_temp('and')
        for a in args:
            self.add_clause(-v, a)
        self.add_clause(v, *(-a for a in args))
        return v

    def xors(self, a, *args):
        """Return a literal equal to the XOR of all arguments.

        Constants are folded (XOR with True negates the other side). Two int
        literals get a fresh variable tied to their XOR by four clauses.
        Further arguments are folded in pairwise, left to right.
        """
        if len(args) == 0:
            return a
        b = args[0]
        if type(a) is bool:
            if type(b) is bool:
                v = (a or b) and (not a or not b)
            else:
                v = -b if a else b
        else:
            if type(b) is bool:
                v = -a if b else a
            else:
                v = self._new_temp('xor')
                self.add_clause(-v, -a, -b)
                self.add_clause(-v, a, b)
                self.add_clause(v, -a, b)
                self.add_clause(v, a, -b)
        for a in args[1:]:
            v = self.xors(v, a)
        return v

    def ites(self, cond, a, b):  # ternary
        """Return a literal equal to ``a if cond else b``."""
        if cond is True:
            return a
        elif cond is False:
            return b

        # Compare types as well: True == 1 and False == 0 in Python, and 1 is
        # a real variable index, so a plain `a == b` would turn
        # ites(c, True, 1) into True.
        if type(a) is type(b) and a == b:  # NOTE: might be redundant.
            return a

        if a is True:
            return self.ors(cond, b)
        elif a is False:
            return self.ands(-cond, b)

        if b is True:
            return self.ors(-cond, a)
        elif b is False:
            return self.ands(cond, a)

        v = self._new_temp('ite')
        self.add_clause(-v, -cond, a)
        self.add_clause(-v, cond, b)
        self.add_clause(v, -cond, -a)
        self.add_clause(v, cond, -b)
        return v

    def implies(self, a, b):
        """Return a literal equal to ``a -> b``."""
        return self.ors(self.nots(a), b)

    def adds(self, a, b, bits, carry=False, return_carry=False):
        """Ripple-carry addition of bit lists ``a`` and ``b`` (index 0 = LSB).

        ``carry`` is the carry-in literal. Returns the list of ``bits`` sum
        literals, or ``(sums, carry_out)`` when ``return_carry`` is truthy,
        or only ``carry_out`` (no sum gates built) when it is ``"only"``.
        """
        c = []
        for i in range(bits):
            if return_carry != "only":
                t = self.xors(carry, self.xors(a[i], b[i]))
                c.append(t)
            # the final carry is only built when someone asked for it
            if i + 1 != bits or return_carry:
                # majority(a, b, carry), written as an XOR of pairwise ANDs
                carry = self.xors(self.ands(a[i], b[i]),
                                  self.ands(carry, a[i]),
                                  self.ands(carry, b[i]))
        if return_carry == "only":
            return carry
        return (c, carry) if return_carry else c

    def subs(self, a, b, bits, carry=False, return_carry=False):
        """Compute ``a - b`` as ``a + ~b`` with an inverted carry-in.

        ``carry`` here is a borrow-in, and the returned carry is inverted
        back, so it is a borrow-out. Return shapes match ``adds``.
        """
        if return_carry == "only":
            # if i ever implement a post-processing stage with reduction,
            # then this could just be: self.subs(..., return_carry=True)[1]
            carry = self.adds(a, [self.nots(x) for x in b], bits,
                              self.nots(carry), "only")
            return self.nots(carry)
        elif return_carry:
            (c, carry) = self.adds(a, [self.nots(x) for x in b], bits,
                                   self.nots(carry), True)
            return c, self.nots(carry)
        else:
            return self.adds(a, [self.nots(x) for x in b], bits,
                             self.nots(carry), False)

    def adder(self, a, b, bits, return_carry=False, carry_in=False):
        """``adds`` with the older argument order."""
        return self.adds(a, b, bits, carry_in, return_carry)

    def subber(self, a, b, bits, return_carry=False, carry_in=False):
        """``subs`` with the older argument order."""
        return self.subs(a, b, bits, carry_in, return_carry)

    def shifter(self, a, b, padding=False):  # shifts a left by b
        """Shift bit list ``a`` left by the unsigned amount in bit list ``b``.

        One mux stage per bit of ``b`` (shift by 2**j when b[j] holds).
        Vacated low positions are filled with ``padding``.
        """
        c = list(a)
        for j in reversed(range(len(b))):
            amount = 1 << j
            # walk downward so c[i - amount] still holds this stage's input
            for i in reversed(range(len(a))):
                c[i] = self.ites(b[j],
                                 c[i - amount] if i >= amount else padding,
                                 c[i])
        return c

    def rotater(self, a, b):
        """Rotate bit list ``a`` left by the amount in bit list ``b``.

        ``len(a)`` must be a power of two. Only the low log2(len(a)) bits of
        ``b`` are used, since larger rotations wrap around.
        """
        abits = len(a)
        if abits == 0:
            return a
        bbits = abits.bit_length() - 1
        assert abits == 1 << bbits, abits  # must be power of two
        if abits == 1:
            return a
        c = [False] * abits  # dummy values
        for j in range(bbits):
            amount = 1 << j
            for i in range(abits):
                c[i] = self.ites(b[j], a[(i - amount) % abits], a[i])
            a = list(c)
        return c

    def dimacs(self, file=None, verbosity=1):
        """Write the formula in DIMACS CNF format to ``file`` (stdout if None).

        At verbosity >= 1, named variables are listed as ``c <index> = <name>``
        comments. Temporaries (names starting with '_') are included only at
        verbosity >= 2.
        """
        print(f'p cnf {self.i} {len(self.clauses)}', file=file)
        if verbosity > 0:
            for ind, name in self.names.items():
                if name.startswith('_') and verbosity < 2:
                    continue
                print(f'c {ind} = {name}', file=file)
        for cls in self.clauses:
            print(*cls, 0, file=file)
|