backyard/logic/hex_problem.py

447 lines
14 KiB
Python

#!/usr/bin/env python3
# run like: SOLVER="$(which kissat)" python3 hex_problem.py
from z0 import Goal, SolverFor, BitVec, Bool, \
If, Implies, And, Or, Not, sat, unsat
from random import randrange
if False:  # original settings
    symmetry = 3  # one of: 1 (none), 2 (180 degrees), 3 (120 deg), 6 (60 deg)
    mirrors = None  # mirrored symmetries: one flag per non-identity image, or None
    hexrad = 9  # radius (1 is the smallest possible)
    mincost = 1  # can be 0 but it's buggy
    maxcost = 5  # colors are only defined up to 6
    stumps = 1  # 0 to disable, 1 to enable
    corneroffset = 2  # offset from outer edge
    # corner0 is cell (hx=hexrad-1, hy=corneroffset): the uppermost corner of the
    # unrotated board. It has no effect when symmetry=6.
    corner0 = 1  # cost at the uppermost corner; can be 0 to allow any cost
    # corner1 is cell (hx=corneroffset, hy=corneroffset): the upper left corner.
    corner1 = 5  # cost at the upper left corner; can be 0 to allow any cost
    # Nonzero breakpoints also cap every non-stump mintotalcost at
    # max(breakpoint0, breakpoint1).
    breakpoint0 = 10  # exact mintotalcost required at corner0 (0 = unconstrained)
    breakpoint1 = 20  # exact mintotalcost required at corner1 (0 = unconstrained)
    # For each k in 1..onedollartrail some $1 cell has mintotalcost k
    # (with mincost > 1 this uses $mincost cells and mintotalcost k * mincost).
    onedollartrail = 10
    desiredtotal = 150  # per symmetrical fragment, stumps excluded; 0 disables
    special = "falling"  # adds extra assertions (falling/mincount/equalcounts/mtctrail)
    specialvalue = 2  # extra assertion parameter
    rotation = 0  # display rotation in sixths of a turn (3 = 180 degrees)
else:  # alternative settings demonstrating new features
    symmetry = 2  # 180 degrees...
    mirrors = (True,)  # ...replaced by a top/bottom reflection
    hexrad = 7
    mincost = 1
    maxcost = 5
    stumps = 1
    corneroffset = 1
    corner0 = 1
    corner1 = 5
    breakpoint0 = 0  # corner0 mintotalcost left free
    breakpoint1 = 16
    onedollartrail = 3
    desiredtotal = 128
    special = "mincount"  # every cost mincost..maxcost used at least specialvalue times
    specialvalue = 6
    rotation = -1  # one sixth of a turn counterclockwise (taken modulo 6)
bits = 9  # BitVec length, might need to be adjusted with radius and desiredtotal
boxwidth = 6  # box size in terminal cells, borders included
boxheight = 3
###
# Direction bits. A box-canvas cell stores the OR of the directions in which a
# line leaves it, and that mask indexes boxchars. Single-direction masks are
# presumably never produced by box outlines, so they map to '?'.
up = 1
right = 2
down = 4
left = 8
boxchars = " ??└?│┌├?┘─┴┐┤┬┼"
def dummy(*args, **kwargs):
    """Accept any arguments and do nothing (the no-op symmetry for symmetry=1)."""
    return None
def esc(x):
    """Return the ANSI escape sequence ESC [ x m (a Select Graphic Rendition code)."""
    return f"\x1b[{x}m"
def zeros(width, height):
    """Return a height x width grid (list of independent rows) filled with 0."""
    return [[0] * width for _ in range(height)]
def nones(width, height):
    """Return a height x width grid (list of independent rows) filled with None."""
    return [[None] * width for _ in range(height)]
stumpcost = 99  # sentinel cost marking a stump cell (drawn as an empty box)
# Number of cells in a hexagon of radius hexrad: 1, 7, 19, 37, ...
hexcount = 1 + 3 * hexrad * (hexrad - 1)
assert (hexcount - 1) % symmetry == 0, symmetry
# Cells per symmetrical fragment; the center belongs to no fragment.
fragcount = (hexcount - 1) // symmetry
# Neighbouring boxes share their border line, hence "- 1" per box and "+ 1" overall.
canvaswidth = 1 + (boxwidth - 1) * (2 * hexrad - 1)
canvasheight = 1 + (boxheight - 1) * (2 * hexrad - 1)
boxcanvas = zeros(canvaswidth, canvasheight)  # direction masks or character codes
colorcanvas = nones(canvaswidth, canvasheight)  # per-cell ANSI codes, or None
def clearcanvas():
    """Replace boxcanvas with an empty grid of the canvas size.

    Only the line/character layer is reset. colorcanvas is left as is, so colors
    painted by an earlier drawing remain for the next render.
    """
    global boxcanvas
    fresh = zeros(canvaswidth, canvasheight)
    boxcanvas = fresh
def drawbox(x, y, text=None, colors=None):
    """Draw one box outline, plus an optional label, onto the canvas.

    (x, y) are the coordinates of the top-left corner of the box. Outline cells
    are OR-ed with direction bits, so boxes that share a border merge into the
    proper junction characters.

    When `text` is given and the box has an interior row (boxheight > 2), the
    text goes on the middle row, one cell right of the left border. It is
    truncated to the interior width and stored as code points, overwriting those
    cells. `colors`, if given, is a sequence of ANSI SGR codes. They are switched
    on just before the label and reset ([0]) at the box's right border.
    """
    right_x = x + boxwidth - 1
    bottom_y = y + boxheight - 1
    # Corners.
    boxcanvas[y][x] |= down | right
    boxcanvas[y][right_x] |= down | left
    boxcanvas[bottom_y][x] |= up | right
    boxcanvas[bottom_y][right_x] |= up | left
    if text is not None and boxheight > 2:
        assert isinstance(text, str), type(text)
        cy = y + boxheight // 2
        for bx, c in zip(range(1, min(boxwidth - 1, len(text) + 1)), text):
            boxcanvas[cy][x + bx] = ord(c)
        if colors is not None:
            colorcanvas[cy][x + 1] = list(colors)
            colorcanvas[cy][right_x] = [0]
    # Edges. The label row is never a border row, so the OR-ing cannot touch it.
    for bx in range(1, boxwidth - 1):
        boxcanvas[y][x + bx] |= left | right
        boxcanvas[bottom_y][x + bx] |= left | right
    for by in range(1, boxheight - 1):
        boxcanvas[y + by][x] |= down | up
        boxcanvas[y + by][right_x] |= down | up
def rendercanvas():
    """Print the canvas to stdout, one terminal line per canvas row.

    Cells holding a direction mask (< 32) become the matching box-drawing
    character. Larger values are character codes written by drawbox(). Any ANSI
    codes stored in colorcanvas are emitted just before their cell, and every
    line ends with a reset.
    """
    for y in range(canvasheight):
        parts = []
        for x in range(canvaswidth):
            v = boxcanvas[y][x]
            colors = colorcanvas[y][x]
            if colors is not None:
                parts.extend(esc(color) for color in colors)
            parts.append(boxchars[v] if v < 32 else chr(v))
        print("".join(parts) + esc(0))
class Canvas:
    """Context manager: clear the box canvas on entry, print it on exit.

    The ``with`` target is None; drawing goes through drawbox(). The canvas is
    rendered even if the body raises, and the exception is not suppressed.
    """

    def __enter__(self):
        clearcanvas()
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        rendercanvas()
        return None
def hexdist(ax, ay, bx, by):
    """Return the hex-grid distance between cells (ax, ay) and (bx, by).

    In cube coordinates x + y + z = 0. Here y is negated (it grows downwards),
    so the third coordinate is z = y - x. The distance is the largest
    per-axis difference.
    """
    dx = abs(ax - bx)
    dy = abs(ay - by)
    dz = abs((ay - ax) - (by - bx))
    return max(dx, dy, dz)
def symmetry2(contents, radius, mirrors=None):
    """Fill `contents` in place for 2-fold symmetry.

    The source fragment is the rows above the center plus the middle row left
    of the center. The upper rows are copied onto the lower rows, either by a
    180-degree rotation or, when mirrors[0] is true, by a top/bottom
    reflection. The middle row is always completed by the 180-degree rotation.
    """
    oneless = radius - 1
    double = 2 * oneless
    if mirrors is None:
        mirrors = (False,)
    reflect = mirrors[0]
    for y in range(oneless):
        target = contents[double - y]
        for x in range(radius + y):
            col = x - y + oneless if reflect else double - x
            target[col] = contents[y][x]
    middle = contents[oneless]
    for x in range(oneless):
        middle[double - x] = middle[x]
def symmetry3(contents, radius, mirrors=None):
    """Fill `contents` in place for 3-fold (120-degree) symmetry.

    The source fragment is rows 0..radius-2, columns 0..radius-1. Each source
    cell is copied to its two clockwise images. mirrors[0] / mirrors[1] pick,
    per image, the transposed cell contents[x][y] instead of the cell itself
    (in the last column the cell itself is always used).
    """
    oneless = radius - 1
    double = oneless * 2
    if mirrors is None:
        mirrors = (False, False)
    for y in range(oneless):
        for x in range(radius):
            own = contents[y][x]
            transposed = own if x >= oneless else contents[x][y]
            # clockwise images
            contents[x - y + oneless][double - y] = transposed if mirrors[0] else own
            contents[double - x][y - x + oneless] = transposed if mirrors[1] else own
def symmetry6(contents, radius, mirrors=None):
    """Fill `contents` in place for 6-fold (60-degree) symmetry.

    The source fragment is cells (x, y) with y <= x < radius - 1. Each source
    cell is copied to its five clockwise images. mirrors[k] selects, for image
    k, the cell reflected within the fragment's row instead of the cell itself.
    """
    oneless = radius - 1
    double = oneless * 2
    if mirrors is None:
        mirrors = [False] * 5
    for y in range(radius):
        for x in range(y, oneless):
            own = contents[y][x]
            reflected = contents[y][y - x + oneless - 1]
            images = (  # clockwise
                (x, x - y + oneless),
                (x - y + oneless, double - y),
                (double - y, double - x),
                (double - x, y - x + oneless),
                (y - x + oneless, y),
            )
            for k, (row, col) in enumerate(images):
                contents[row][col] = reflected if mirrors[k] else own
def symmetrize(contents, radius, symmetry, mirrors=None):
    """Complete `contents` in place from its fragment for the given symmetry.

    `symmetry` must be 1, 2, 3 or 6. Anything else fails the assertion.
    """
    handlers = {1: dummy, 2: symmetry2, 3: symmetry3, 6: symmetry6}
    assert symmetry in handlers, symmetry
    handlers[symmetry](contents, radius, mirrors)
def rotate(contents, radius, sixths):
    """Rotate the hexagon stored in `contents` in place by `sixths` sixth-turns.

    Each step moves every value to its next clockwise image (the same order
    symmetry6 uses). `sixths` is taken modulo 6, so negative values turn the
    other way; None means no rotation.
    """
    oneless = radius - 1
    double = oneless * 2
    turns = 0 if sixths is None else sixths % 6
    for _ in range(turns):
        for y in range(radius):
            for x in range(y, oneless):
                # Swapping the pivot (y, x) with images 2..6 in turn shifts the
                # whole 6-cycle by one position.
                ring = (
                    (x, x - y + oneless),
                    (x - y + oneless, double - y),
                    (double - y, double - x),
                    (double - x, y - x + oneless),
                    (y - x + oneless, y),
                )
                for row, col in ring:
                    contents[row][col], contents[y][x] = contents[y][x], contents[row][col]
size = 2 * hexrad - 1  # side of the square array that holds the hexagon
half = hexrad // 2
odd = hexrad % 2  # 1 when hexrad is odd, else 0
even = 1 - odd  # complement of odd
def hexiter():
    """Yield (hx, hy) for every cell of the hexagon, column by column.

    Each column scans a window of rows and keeps only the cells within
    hexrad - 1 of the center (per hexdist).
    """
    center = hexrad - 1
    for hx in range(size):
        # these equations maximize the number of boxes in the rectangular canvas.
        first = (hx + odd) // 2 - half
        stop = (hx + even) // 2 - half + size
        yield from ((hx, hy) for hy in range(first, stop)
                    if hexdist(hx, hy, center, center) < hexrad)
def hxy():
    """Yield (hx, hy) for the cells of one symmetrical fragment.

    The center is never included. The fragment shape matches what the
    symmetry2/3/6 helpers copy from. An unsupported `symmetry` yields nothing.
    """
    center = hexrad - 1
    if symmetry == 6:
        for hy in range(hexrad):
            for hx in range(hy, center):
                yield hx, hy
    elif symmetry == 3:
        for hy in range(center):
            for hx in range(hexrad):
                yield hx, hy
    elif symmetry == 2:
        # Full rows above the center, then the middle row up to the center.
        for hy in range(hexrad):
            width = hexrad + hy if hy < center else center
            for hx in range(width):
                yield hx, hy
    elif symmetry == 1:
        for hx, hy in hexiter():
            if (hx, hy) != (center, center):
                yield hx, hy
def findbox(hx, hy):
    """Return the canvas (x, y) of the top-left corner of cell (hx, hy)'s box.

    Columns share their vertical border (hence boxwidth - 1 per column). Each
    column sits boxheight // 2 rows higher than the previous one, so the boxes
    interlock like a hex grid.
    """
    left_x = hx * (boxwidth - 1)
    top_y = (hy + half) * (boxheight - 1) - hx * (boxheight // 2) - even
    return left_x, top_y
# Sanity-check the settings and derive implied values before any constraints
# are built, so inconsistent configurations fail early.
# Normalize `special` exactly as the constraint section does later (lowercase
# ASCII letters only). Otherwise a spelling like "Equal Counts" would skip
# these checks but still select the constraint.
special = "".join(c for c in special.lower() if c in "abcdefghijklmnopqrstuvwxyz")
if corner0 != 0:
    assert mincost <= corner0 <= maxcost, (mincost, corner0, maxcost)
if corner1 != 0:
    assert mincost <= corner1 <= maxcost, (mincost, corner1, maxcost)
assert 0 <= corneroffset < hexrad, (0, corneroffset, hexrad)
egnar = maxcost - mincost + 1  # number of distinct non-stump cost values
if special == "equalcounts":
    # Each cost mincost..maxcost appears specialvalue times per fragment,
    # so the fragment total is specialvalue * denom.
    denom = sum(range(mincost, maxcost + 1))
    # this should catch most issues early.
    if stumps == 0:
        # Without stumps every fragment cell carries one of the egnar costs.
        assert fragcount % egnar == 0, f"{fragcount} indivisible by {egnar}"
        if specialvalue == 0:
            specialvalue = fragcount // egnar
    if desiredtotal == 0 and specialvalue != 0:
        desiredtotal = specialvalue * denom
    if desiredtotal != 0:
        assert desiredtotal % denom == 0, f"{desiredtotal} indivisible by {denom}"
        if specialvalue == 0:
            specialvalue = desiredtotal // denom
        else:
            assert specialvalue * denom == desiredtotal, "(equalcounts) that doesn't add up!"
    if stumps == 0:
        assert specialvalue * egnar == fragcount, "(equalcounts) that doesn't add up!"
    if specialvalue != 0:
        # The trail needs onedollartrail distinct cells of the cheapest cost.
        assert onedollartrail <= specialvalue, "onedollartrail is impossibly large"
if special == "mincount":
    # Lower bound on the fragment total when every cost appears specialvalue times.
    mincount = sum(range(mincost, maxcost + 1)) * specialvalue
    if desiredtotal != 0:
        assert mincount <= desiredtotal, (mincount, desiredtotal)
###
def Min(a, b):
    """Return a solver expression for the smaller of `a` and `b` (ties give `a`)."""
    a_wins = a <= b
    return If(a_wins, a, b)
# One solver variable per fragment cell for its cost ("c<x><y>") and for its
# mintotalcost ("m<x><y>"), with coordinates in hex.
# NOTE(review): names join the hex coordinates without a separator. Once
# coordinates need two hex digits (size > 16), distinct cells could in
# principle get the same name; check before using very large radii.
costs = nones(size, size)
mintotalcosts = nones(size, size)
# The center is a plain 0 in both grids rather than a variable.
costs[hexrad - 1][hexrad - 1] = mintotalcosts[hexrad - 1][hexrad - 1] = 0
for x, y in hxy():
    costs[y][x] = BitVec(f"c{x:X}{y:X}", bits)
for x, y in hxy():
    mintotalcosts[y][x] = BitVec(f"m{x:X}{y:X}", bits)
# Copy the fragment variables into the remaining cells.
for grid in (costs, mintotalcosts):
    symmetrize(grid, hexrad, symmetry, mirrors)
g = Goal()


def neighbormtc(hx, hy):
    """Return the mintotalcost entry of cell (hx, hy), or None if off the board.

    For the center this is the plain int 0 stored there.
    """
    if 0 <= hx < size and 0 <= hy < size:
        if hexdist(hx, hy, hexrad - 1, hexrad - 1) < hexrad:
            return mintotalcosts[hy][hx]
    return None


def orstump(cond, cost):
    """Relax `cond` so that a stump cell (cost == stumpcost) also satisfies it."""
    return Or(cond, cost == stumpcost) if stumps != 0 else cond


# Per fragment cell: mtc == min(neighbour mtc) + cost, plus corner, range and
# breakpoint constraints.
for (hx, hy) in hxy():
    cost = costs[hy][hx]
    mtc = mintotalcosts[hy][hx]
    if isinstance(cost, int):
        continue  # center point is special: a fixed 0, not a solver variable
    neighbors = [
        neighbormtc(hx + 1, hy), neighbormtc(hx + 1, hy + 1), neighbormtc(hx, hy + 1),
        neighbormtc(hx - 1, hy), neighbormtc(hx - 1, hy - 1), neighbormtc(hx, hy - 1),
    ]
    # Start from the largest value a signed `bits`-wide bit-vector can hold
    # (assuming the solver compares signed).
    minimum = (1 << (bits - 1)) - 1
    for n in neighbors:
        if n is not None:
            minimum = Min(n, minimum)
    if stumps != 0:
        # Presumably guards against minimum + stumpcost wrapping around.
        g.add(minimum < mtc)
    g.add(mtc == minimum + cost)
    if hx == corneroffset and hy == corneroffset:
        if corner1 != 0:
            g.add(cost == corner1)
        if breakpoint1 != 0:
            g.add(mtc == breakpoint1)
    elif hx == hexrad - 1 and hy == corneroffset:
        if corner0 != 0:
            g.add(cost == corner0)
        if breakpoint0 != 0:
            g.add(mtc == breakpoint0)
    if mincost != 0:
        g.add(cost >= mincost)
    g.add(orstump(cost <= maxcost, cost))
    # NOTE: this only really matters when offset > 0.
    if breakpoint0 != 0 or breakpoint1 != 0:
        g.add(orstump(mtc <= max(breakpoint0, breakpoint1), cost))
# The fragment's cost total must equal desiredtotal. Stumps carry the
# sentinel stumpcost, which must not count toward it.
if desiredtotal > 0:
    if stumps != 0:
        terms = (If(costs[y][x] == stumpcost, 0, costs[y][x]) for (x, y) in hxy())
    else:
        terms = (costs[y][x] for (x, y) in hxy())
    totalcost = sum(terms)
    g.add(totalcost == desiredtotal)
# The "trail": for k = 1..onedollartrail some fragment cell costs mincostmul
# and has mintotalcost k * mincostmul.
mincostmul = max(mincost, 1)
for calc in range(1, onedollartrail + 1):
    target = calc * mincostmul
    ok = False
    for (hx, hy) in hxy():
        ok = Or(ok, And(costs[hy][hx] == mincostmul, mintotalcosts[hy][hx] == target))
    g.add(ok)
# Optional extra constraint chosen by `special`. The name is compared after
# normalizing it to lowercase ASCII letters (so "Mtc-Trail" means "mtctrail").
alpha = "abcdefghijklmnopqrstuvwxyz"
special = "".join(ch for ch in special.lower() if ch in alpha)
if special == "mtctrail":
    # Every mintotalcost k * mincostmul, k = 1..specialvalue, occurs somewhere.
    for calc in range(1, specialvalue + 1):
        ok = False
        for (hx, hy) in hxy():
            ok = Or(ok, mintotalcosts[hy][hx] == calc * mincostmul)
        g.add(ok)
if special in ("equalcounts", "mincount", "falling"):
    # counts[i - mincost] becomes a solver expression counting the fragment
    # cells whose cost is exactly i.
    buckets = maxcost - mincost + 1
    counts = [0] * buckets
    for (hx, hy) in hxy():
        for i in range(mincost, maxcost + 1):
            ind = i - mincost
            counts[ind] = counts[ind] + If(costs[hy][hx] == i, 1, 0, bits)
    if special == "equalcounts":
        # specialvalue 0 leaves the common count for the solver to choose.
        equalcount = specialvalue if specialvalue != 0 else BitVec("equalcount", bits)
        for ind in range(buckets):
            g.add(counts[ind] == equalcount)
    elif special == "mincount":
        for ind in range(buckets):
            g.add(counts[ind] >= specialvalue)
    else:  # "falling": each cost is used at least specialvalue more than the next
        for ind in range(buckets - 1):
            if specialvalue >= 0:
                g.add(counts[ind] >= counts[ind + 1] + specialvalue)
            else:
                # Same bound, written without adding a negative constant
                # (presumably for the bit-vector arithmetic; confirm).
                g.add(counts[ind] >= counts[ind + 1] - abs(specialvalue))
# Hand every accumulated constraint to a quantifier-free bit-vector solver.
s = SolverFor("QF_BV")
s.add(g)
# NOTE: assuming $SOLVER is picosat, a randomized search can be requested with:
#     seed = randrange(2**32)
#     satisfiable = s.check(args=f"-i 3 -s {seed}".split())
satisfiable = s.check()
print("check:", satisfiable)
# ANSI SGR (background, foreground) codes per cost value. drawbox() stores
# them and rendercanvas() emits them through esc(). Only costs 0..6 and the
# stump sentinel have an entry.
costcolors = {
    stumpcost: (40, 30),  # black on black
    6: (105, 30),  # black on bright magenta
    5: (41, 97),  # bright white on red
    4: (43, 30),  # black on yellow
    3: (42, 30),  # black on green
    2: (46, 30),  # black on cyan
    1: (100, 97),  # bright white on gray (bright black)
    0: (40, 97),  # bright white on black
}
# Draw the solved board twice (costs, then mintotalcosts) and print per-fragment
# statistics; rotation only affects the display.
if satisfiable == sat:
    model = s.model()
    realsum, stumpcount = 0, 0
    counts = [0 for _ in range(mincost, maxcost + 1)]
    rotate(costs, hexrad, rotation)
    rotate(mintotalcosts, hexrad, rotation)
    with Canvas():
        for (hx, hy) in hexiter():
            cell = costs[hy][hx]
            # The center is stored as the plain int 0, not a solver variable.
            iscenter = isinstance(cell, int)
            v = 0 if iscenter else model[cell.name]
            colors = costcolors[v] if maxcost <= 6 else None
            label = " " if v == stumpcost else f" ${v} "
            drawbox(*findbox(hx, hy), label, colors=colors)
            if iscenter:
                # The center belongs to no fragment. Counting it would land in
                # counts[0 - mincost], i.e. the wrong bucket.
                continue
            if stumps != 0 and v == stumpcost:
                stumpcount += 1
            else:
                counts[v - mincost] += 1
                realsum += v
    print("counts per fragment: ",
          ", ".join(f"${i + mincost}: {c // symmetry}" for i, c in enumerate(counts)))
    if stumps != 0:
        print(f"stumps: {stumpcount:3}, {stumpcount // symmetry:3} per fragment")
    print(f"total sum: {realsum:3}, {realsum // symmetry:3} per fragment")
    # NOTE: clearcanvas() leaves colorcanvas untouched, so this map is drawn
    # over the cost colors painted above.
    with Canvas():
        for (hx, hy) in hexiter():
            cell = mintotalcosts[hy][hx]
            v = 0 if isinstance(cell, int) else model[cell.name]
            drawbox(*findbox(hx, hy), f" {v:2} ")
else:
    print("NO SOLUTION")