backyard/logic/cnf.py
2022-06-08 04:27:47 +02:00

486 lines
15 KiB
Python

# Copyright (C) 2019 Connor Olding
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import warnings
# ASCII decimal digit characters, used to validate numeric tokens.
digits = {str(d) for d in range(10)}
# The digit characters plus the minus sign, for (possibly negative) literals.
digits_or_negative = {'-'} | digits
def readcnf(f=None, fp=None):
    """Parse a DIMACS CNF problem and return ``(n_vars, clauses)``.

    Reads from the open text file ``f``. If ``f`` is None, the path ``fp``
    is opened instead and closed again before returning or raising.
    Otherwise ``fp`` is only used to label error messages.

    ``clauses`` is a list of tuples of nonzero ints, one per clause line.
    Every clause must sit on a single line terminated by ``0``. Lines that
    hold only the terminating ``0`` are skipped. Blank lines and ``c``
    comment lines are ignored.

    Raises AssertionError("malformed cnf at <fp>:<line>") on malformed
    input, and AssertionError("empty cnf") when no ``p cnf`` header is seen.
    """
    clauses = []
    initialized = False
    n_vars, n_clauses = None, None
    our_file = False
    if f is None:
        assert fp is not None, 'readcnf requires a file handle or file path'
        f = open(fp, 'r')
        our_file = True
    elif fp is None:
        fp = '<unknown>'
    i = -1

    def failure():
        return 'malformed cnf at {}:{}'.format(fp, i + 1)

    try:
        for i, line in enumerate(f):
            # split() rather than split(' '): blank lines, runs of spaces
            # and tabs must not produce empty tokens (which used to crash
            # on cmd[0] / t[0]).
            tokens = line.split()
            if not tokens:
                continue
            cmd = tokens[0]
            if cmd == 'c':
                continue
            elif cmd == 'p':
                assert not initialized, failure()
                assert len(tokens) == 4, failure()
                cnfstr, n_vars, n_clauses = tokens[1:]
                assert cnfstr == 'cnf', failure()
                assert all(c in digits for c in n_vars), failure()
                assert all(c in digits for c in n_clauses), failure()
                n_vars = int(n_vars)
                n_clauses = int(n_clauses)
                initialized = True
            elif cmd[0] in digits_or_negative:
                assert initialized, failure()
                for t in tokens:
                    # A literal is "0", or an optionally negated run of
                    # digits without a leading zero ("-0", "007" and
                    # "1-2" are all rejected here, not later by int()).
                    body = t[1:] if t[0] == '-' else t
                    assert body and all(c in digits for c in body), failure()
                    assert t == '0' or body[0] != '0', failure()
                terms = tuple(int(t) for t in tokens)
                assert terms[-1] == 0, failure()
                terms = terms[:-1]
                for t in terms:
                    assert t != 0, failure()
                    assert abs(t) <= n_vars, failure()
                if len(terms) > 0:
                    assert len(clauses) < n_clauses, failure()
                    clauses.append(terms)
            else:
                assert False, failure()
    finally:
        # close a file we opened ourselves even when parsing fails
        if our_file:
            f.close()
    assert initialized, 'empty cnf'
    return n_vars, clauses
def readyicescom(names, key, ids):
    """Helper for readcnfcom: record a bracketed index list in ``names``.

    ``ids`` are the whitespace-separated tokens of a list such as
    ``[3 4 5]``. The first token must start with ``[`` and the last must
    end with ``]``; a single token like ``[5]`` carries both. When every
    index is a non-empty run of digits, ``names[key + '.' + str(n)]`` is
    set to the n-th index. Otherwise ``names`` is left untouched.
    """
    int_ids = []
    last = len(ids) - 1
    for i, token in enumerate(ids):
        # Use independent checks (not if/elif) so that a one-element
        # list "[5]" gets both of its brackets stripped.
        if i == 0:
            if not token.startswith('['):
                return
            token = token[1:]
        if i == last:
            if not token.endswith(']'):
                return
            token = token[:-1]
        # an empty token (e.g. a lone "[") must not reach int('')
        if not token or not all(c in digits for c in token):
            return
        int_ids.append(int(token))
    for i, int_id in enumerate(int_ids):
        names[key + '.' + str(i)] = int_id
def readcnfcom(f=None, fp=None):
    """Read variable names from the comment lines of a DIMACS CNF file.

    Reads from the open text file ``f``. If ``f`` is None, the path ``fp``
    is opened instead and closed again before returning or raising.

    Recognized comment lines:
      * ``c <index> = <name>`` and ``c <name> = <index>``. The first form
        is tried first, and the index must be a run of digits.
      * ``c <name> --> <tokens...>``, which is handed to readyicescom.
    Everything else is ignored. The result maps name -> variable index.
    It may be empty or incomplete; anything goes.
    """
    names = {}
    our_file = False
    if f is None:
        assert fp is not None, 'readcnfcom requires a file handle or file path'
        f = open(fp, 'r')
        our_file = True
    try:
        for line in f:
            # split() rather than split(' '): runs of spaces used to leave
            # empty tokens that passed the digit check and crashed int('').
            tokens = line.split()
            if len(tokens) < 4 or tokens[0] != 'c':
                continue
            if tokens[2] == '=' and len(tokens) == 4:
                if all(c in digits for c in tokens[1]):
                    names[tokens[3]] = int(tokens[1])
                elif all(c in digits for c in tokens[3]):
                    names[tokens[1]] = int(tokens[3])
            elif tokens[2] == '-->':
                readyicescom(names, tokens[1], tokens[3:])
    finally:
        # close a file we opened ourselves even if parsing raises
        if our_file:
            f.close()
    return names
def readsolutions(f=None, fp=None):
    """Read any number of solutions from SAT solver output.

    Reads from the open text file ``f``. If ``f`` is None, the path ``fp``
    is opened instead and closed again before returning or raising.
    ``fp`` also labels warnings.

    This is deliberately lenient: your solutions are important, even if
    your solver outputs a bunch of unrelated trash. Every ``s`` line ends
    the solution collected so far (if non-empty). ``v`` lines contribute
    their nonzero integer tokens to the current solution. Non-integer
    tokens are reported via warnings.warn and skipped, and all other lines
    are ignored.

    Returns a list of solutions, each a list of nonzero ints.
    """
    solutions, solution = [], []
    our_file = False
    if f is None:
        assert fp is not None, ('readsolutions'
                                ' requires a file handle or file path')
        f = open(fp, 'r')
        our_file = True
    elif fp is None:
        fp = '<unknown>'
    try:
        # enumerate: the warning below reports the line number, which
        # previously referenced an undefined `i` and raised NameError.
        for i, line in enumerate(f):
            # split() avoids empty tokens from repeated spaces, which
            # would otherwise be reported as invalid tokens.
            tokens = line.split()
            if not tokens:
                continue
            cmd = tokens[0]
            if cmd == 's':
                # NOTE TO SELF: picosat --all prints "s SOLUTIONS x" at last
                if len(solution) > 0:
                    solutions.append(solution)
                solution = []
            elif cmd == 'v':
                for t in tokens[1:]:
                    try:
                        t = int(t)
                    except ValueError:
                        msg = 'invalid token while parsing solution at {}:{} {}'
                        warnings.warn(msg.format(fp, i + 1, t))
                    else:
                        if t != 0:
                            solution.append(t)
    finally:
        # close a file we opened ourselves even if something raises
        if our_file:
            f.close()
    if len(solution) > 0:
        solutions.append(solution)
    return solutions
def bv_exname(name):
    """Split a bit-vector element name: ``'myvar_abc.123'`` -> ``('myvar_abc', 123)``.

    The split happens at the last dot. Returns ``(None, None)`` unless that
    dot exists and is followed by one or more ASCII digits (and nothing
    else).
    """
    head, dot, tail = name.rpartition('.')
    if not dot or not tail or any(ch not in digits for ch in tail):
        return None, None
    return head, int(tail)
def unmapsolution(solution, names):
    """Translate a solver solution back into named values.

    ``solution`` is a sequence of nonzero ints, where a positive literal
    means true and a negative one means false. ``names`` maps name ->
    variable index, as returned by readcnfcom.

    Plain names map to 1 or 0. Names of the form ``<base>.<bit>`` are
    gathered into an integer stored under ``<base>``, with bit k
    contributing ``1 << k`` when true. Only variables present in both
    inputs appear, so the result is only as complete as the conjunction
    of solution and names.
    """
    # variable index -> 1 (true) or 0 (false); zeros are ignored
    truth = {abs(lit): int(lit > 0) for lit in solution if lit != 0}

    scalars = {}
    vectors = {}
    for name, index in names.items():
        if index not in truth:
            continue
        base, bit = bv_exname(name)
        if base is None:
            scalars[name] = truth[index]
        else:
            vectors.setdefault(base, {})[bit] = truth[index]

    # TODO: check for contiguity?
    for base, bitmap in vectors.items():
        scalars[base] = sum(1 << k for k, v in bitmap.items() if v)
    return scalars
class Problem:
    """Incrementally build a CNF formula.

    Variables are positive ints handed out by var()/new_var(). A negative
    int is the negation of that variable. The Python constants True and
    False may be passed wherever a literal is expected, and the gate
    methods fold them away instead of emitting clauses.

    Gate methods (ors, ands, xors, ites, implies) return one of three
    things: a constant, an existing literal, or a fresh temporary variable
    that the added clauses constrain to equal the gate's output.

    Bit-vectors are lists of literals, least significant bit first.
    """

    def __init__(self):
        self.clauses = []   # tuples of nonzero int literals
        self.i = 0          # number of variables allocated so far
        self.names = {}     # variable index -> name
        self._used = set()  # names already claimed through new_var
        self._counts = {}   # temporary category -> temporaries made so far
        # old names for backwards compatibility:
        self.assign_or = self.ors
        self.assign_and = self.ands
        self.assign_xor = self.xors
        self.assign_not = self.nots
        self.assign_ite = self.ites
        self.assign_implies = self.implies
        self.ite = self.ites

    def var(self, name=None):  # TODO: rename?
        """Allocate and return a new variable index, optionally named."""
        self.i += 1
        if name is not None:
            self.names[self.i] = name
        return self.i

    def _new_temp(self, category, bits=None):  # TODO: rename?
        """Allocate a temporary named ``_<category><count>``.

        With ``bits`` given, return a list of that many variables named
        ``_<category><count>.0``, ``_<category><count>.1``, ... instead.
        """
        count = self._counts.get(category, 0)
        varname = f'_{category}{count}'
        if bits is None:
            var_or_vars = self.var(varname)
        else:
            var_or_vars = [self.var(varname + '.' + str(i))
                           for i in range(bits)]
        self._counts[category] = count + 1
        return var_or_vars

    def new_var(self, name, bits=None):  # TODO: rename?
        """Allocate a user-named variable, or a ``bits``-wide bit-vector.

        Bit-vector elements are named ``<name>.0``, ``<name>.1``, ...
        Raises AssertionError if ``name`` was already passed to new_var.
        """
        assert name not in self._used, name
        # Record the name; without this the duplicate check never fires.
        self._used.add(name)
        if bits is None:
            return self.var(name)
        return [self.var(name + '.' + str(i)) for i in range(bits)]

    def vars(self, temps=False):  # TODO: rename?
        """Return a ``{name: index}`` mapping like readcnfcom produces.

        Temporaries (names starting with ``_``) are left out unless
        ``temps`` is true.
        """
        names = {}
        for ind, name in self.names.items():
            if name.startswith('_') and not temps:
                continue
            names[name] = ind
        return names

    def add_clause(self, *args):  # TODO: rename?
        """Append a raw clause; every literal must be a plain int."""
        for a in args:
            assert type(a) is int, type(a)
        self.clauses.append(args)

    def _map(self, clauses, *mapping):
        """Assert ``clauses`` with variable k replaced by ``mapping[k - 1]``."""
        for clause in clauses:
            new_clause = [mapping[t - 1] if t > 0 else
                          self.nots(mapping[-t - 1]) for t in clause]
            self.assert_or(*new_clause)

    def assert_or(self, *args):
        """Require that at least one of ``args`` is true.

        This is a no-op when called with no arguments or when any argument
        is the constant True. False arguments are dropped. If nothing
        remains, raises AssertionError("trivially unsatisfiable").
        """
        if len(args) == 0:
            return
        # don't do `if True in args:` because that matches `1`
        if any(a is True for a in args):
            return
        args = [a for a in args if a is not False]
        assert len(args) > 0, "trivially unsatisfiable"
        if len(args) > 0:  # can only be false when asserts are disabled
            self.add_clause(*args)

    def assert_and(self, *args):
        """Require that every one of ``args`` is true."""
        for a in args:
            self.assert_or(a)

    def assert_implies(self, a, b):
        """Require ``a -> b``."""
        self.assert_or(self.nots(a), b)

    def assert_eq(self, a, b):
        """Require ``a`` and ``b`` to have the same truth value."""
        self.assert_or(self.nots(self.xors(a, b)))

    def nots(self, a):
        """Negate a literal or a bool constant."""
        if type(a) is bool:
            return not a
        else:
            return -a

    def ors(self, *args):
        """Return a literal equal to the disjunction of ``args``."""
        if any(a is True for a in args):
            return True
        args = [a for a in args if a is not False]
        if len(args) == 0:
            return False
        elif len(args) == 1:
            return args[0]
        v = self._new_temp('or')
        for a in args:
            self.add_clause(v, -a)   # any a forces v
        self.add_clause(-v, *args)   # v needs at least one a
        return v

    def ands(self, *args):
        """Return a literal equal to the conjunction of ``args``."""
        if any(a is False for a in args):
            return False
        args = [a for a in args if a is not True]
        if len(args) == 0:
            return True
        elif len(args) == 1:
            return args[0]
        v = self._new_temp('and')
        for a in args:
            self.add_clause(-v, a)   # v forces every a
        self.add_clause(v, *(-a for a in args))  # all a force v
        return v

    def xors(self, a, *args):
        """Return a literal equal to the exclusive-or of all arguments."""
        if len(args) == 0:
            return a
        b = args[0]
        if type(a) is bool:
            if type(b) is bool:
                v = a != b
            else:
                v = -b if a else b
        else:
            if type(b) is bool:
                v = -a if b else a
            else:
                v = self._new_temp('xor')
                self.add_clause(-v, -a, -b)
                self.add_clause(-v, a, b)
                self.add_clause(v, -a, b)
                self.add_clause(v, a, -b)
        for c in args[1:]:
            v = self.xors(v, c)
        return v

    def ites(self, cond, a, b):  # ternary
        """Return a literal equal to ``a if cond else b``."""
        if cond is True:
            return a
        elif cond is False:
            return b
        # Compare types as well as values: True == 1, so a bare `a == b`
        # would treat the constant True as identical to variable 1.
        # NOTE: this optimization might be redundant.
        if type(a) is type(b) and a == b:
            return a
        if a is True:
            return self.ors(cond, b)
        elif a is False:
            return self.ands(-cond, b)
        if b is True:
            return self.ors(-cond, a)
        elif b is False:
            return self.ands(cond, a)
        v = self._new_temp('ite')
        self.add_clause(-v, -cond, a)
        self.add_clause(-v, cond, b)
        self.add_clause(v, -cond, -a)
        self.add_clause(v, cond, -b)
        return v

    def implies(self, a, b):
        """Return a literal equal to ``a -> b``."""
        return self.ors(self.nots(a), b)

    def adds(self, a, b, bits, carry=False, return_carry=False):
        """Ripple-carry add the ``bits``-wide bit-vectors ``a`` and ``b``.

        ``carry`` is the carry-in literal. Returns the list of sum bits.
        With ``return_carry`` truthy, returns ``(sum_bits, carry_out)``.
        With ``return_carry == "only"``, returns just the carry-out and
        builds no sum bits.
        """
        c = []
        for i in range(bits):
            if return_carry != "only":
                t = self.xors(carry, self.xors(a[i], b[i]))
                c.append(t)
            # the final carry is only built when someone asked for it
            if i + 1 != bits or return_carry:
                # majority(a, b, carry), written as xor of the pairwise ands
                carry = self.xors(self.ands(a[i], b[i]),
                                  self.ands(carry, a[i]),
                                  self.ands(carry, b[i]))
        if return_carry == "only":
            return carry
        return (c, carry) if return_carry else c

    def subs(self, a, b, bits, carry=False, return_carry=False):
        """Compute ``a - b`` as ``a + ~b + ~carry``.

        The returned carry is inverted, so it reads as a borrow.
        ``return_carry`` behaves as in adds.
        """
        if return_carry == "only":
            # if i ever implement a post-processing stage with reduction,
            # then this could just be: self.subs(..., return_carry=True)[1]
            carry = self.adds(a, [self.nots(x) for x in b], bits,
                              self.nots(carry), "only")
            return self.nots(carry)
        elif return_carry:
            (c, carry) = self.adds(a, [self.nots(x) for x in b], bits,
                                   self.nots(carry), True)
            return c, self.nots(carry)
        else:
            return self.adds(a, [self.nots(x) for x in b], bits,
                             self.nots(carry), False)

    def adder(self, a, b, bits, return_carry=False, carry_in=False):
        """adds() with the keyword order (return_carry, carry_in)."""
        return self.adds(a, b, bits, carry_in, return_carry)

    def subber(self, a, b, bits, return_carry=False, carry_in=False):
        """subs() with the keyword order (return_carry, carry_in)."""
        return self.subs(a, b, bits, carry_in, return_carry)

    def shifter(self, a, b, padding=False):  # shifts a left by b
        """Shift ``a`` toward higher indices by the bit-vector amount ``b``.

        Vacated positions are filled with ``padding``. Returns a new list.
        """
        c = list(a)
        for j in reversed(range(len(b))):
            amount = 1 << j
            # walk from the top so c[i - amount] is still this stage's input
            for i in reversed(range(len(a))):
                c[i] = self.ites(b[j],
                                 c[i - amount] if i >= amount else padding,
                                 c[i])
        return c

    def rotater(self, a, b):
        """Rotate ``a`` toward higher indices by ``b`` (mod ``len(a)``).

        ``len(a)`` must be a power of two. Only the low log2(len(a)) bits
        of ``b`` are used.
        """
        abits = len(a)
        if abits == 0:
            return a
        bbits = abits.bit_length() - 1
        assert abits == 1 << bbits, abits  # must be power of two
        if abits == 1:
            return a
        c = [False] * abits  # dummy values
        for j in range(bbits):
            amount = 1 << j
            for i in range(abits):
                c[i] = self.ites(b[j], a[(i - amount) % abits], a[i])
            a = list(c)
        return c

    def dimacs(self, file=None, verbosity=1):
        """Print the problem in DIMACS CNF format to ``file``.

        ``file`` defaults to stdout, as with print(). Verbosity controls
        the ``c <index> = <name>`` comments: 0 emits none, 1 emits
        user-named variables only, and 2 or more also emits temporaries.
        """
        print(f'p cnf {self.i} {len(self.clauses)}', file=file)
        if verbosity > 0:
            for ind, name in self.names.items():
                if name.startswith('_') and verbosity < 2:
                    continue
                print(f'c {ind} = {name}', file=file)
        for cls in self.clauses:
            print(*cls, 0, file=file)