direct: update birect.py
This commit is contained in:
parent
993d3fde72
commit
4c97502e84
1 changed files with 79 additions and 43 deletions
122
direct/birect.py
122
direct/birect.py
|
@ -1,18 +1,32 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (C) 2022 Connor Olding
|
||||||
|
|
||||||
|
# Permission to use, copy, modify, and/or distribute this software for any
|
||||||
|
# purpose with or without fee is hereby granted, provided that the above
|
||||||
|
# copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
|
||||||
def birect(
|
def birect(
|
||||||
obj,
|
obj, # objective function to find the minimum value of
|
||||||
lo,
|
lo, # list of lower bounds for each problem dimension
|
||||||
hi,
|
hi, # list of upper bounds for each problem dimension
|
||||||
*,
|
*,
|
||||||
min_diag,
|
min_diag, # never subdivide hyper-rectangles below this length
|
||||||
min_error=None,
|
min_error=None, # exit when the objective function achieves this error
|
||||||
max_evals=None,
|
max_evals=None, # exit when the objective function has been run this many times
|
||||||
max_iters=None,
|
max_iters=None, # exit when the optimization procedure iterates this many times
|
||||||
by_longest=False,
|
by_longest=False, # measure rects by their longest edge instead of their diagonal
|
||||||
pruning=0,
|
pruning=0,
|
||||||
F=float,
|
F=float, # can be float, np.float32, np.float64, decimal, or anything float-like
|
||||||
):
|
):
|
||||||
assert len(lo) == len(hi), "dimensional mismatch"
|
assert len(lo) == len(hi), "dimensional mismatch"
|
||||||
|
|
||||||
|
@ -26,31 +40,43 @@ def birect(
|
||||||
# aside: xmin should actually be called v_xmin, but it's not!
|
# aside: xmin should actually be called v_xmin, but it's not!
|
||||||
|
|
||||||
def fun(w_t):
|
def fun(w_t):
|
||||||
# xs = [l + (h - l) * t for t in w_t]
|
# final conversion from exact fraction to possibly-inexact F-type:
|
||||||
v_x = [F(num) / F(den) for num, den in w_t]
|
v_x = [F(num) / F(den) for num, den in w_t]
|
||||||
|
# linearly interpolate within the bounds of the function:
|
||||||
v_x = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, v_x)]
|
v_x = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, v_x)]
|
||||||
res = obj(v_x)
|
res = obj(v_x)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def ab_to_lu(w_a, w_b):
|
def ab_to_lu(w_a, w_b): # converts corner points to midpoints, also denominators
|
||||||
|
# (2 * a + b) / (den * 3) = point one third of the way from corner "a" to corner "b"
|
||||||
w_l = [(a[0] + a[0] + b[0], (1 << a[1]) * 3) for a, b in zip(w_a, w_b)]
|
w_l = [(a[0] + a[0] + b[0], (1 << a[1]) * 3) for a, b in zip(w_a, w_b)]
|
||||||
|
# (a + 2 * b) / (den * 3) = point two thirds of the way from corner "a" to corner "b"
|
||||||
w_u = [(a[0] + b[0] + b[0], (1 << b[1]) * 3) for a, b in zip(w_a, w_b)]
|
w_u = [(a[0] + b[0] + b[0], (1 << b[1]) * 3) for a, b in zip(w_a, w_b)]
|
||||||
return w_l, w_u
|
return w_l, w_u
|
||||||
|
|
||||||
dims = len(lo)
|
dims = len(lo) # already asserted that len(lo) == len(hi)
|
||||||
|
|
||||||
|
# initial corner points:
|
||||||
|
# note that the denominators are encoded as the exponents of a power of two.
|
||||||
|
# therefore, coordinate = pair[0] / (1 << pair[1]).
|
||||||
w_a, w_b = [(0, 0)] * dims, [(1, 0)] * dims
|
w_a, w_b = [(0, 0)] * dims, [(1, 0)] * dims
|
||||||
|
|
||||||
|
# initial sample points one third and two thirds of the way along the diagonal between the two corners:
|
||||||
|
# note that the denominators are stored literally here, not as exponents of a power of two.
|
||||||
|
# therefore, coordinate = pair[0] / pair[1].
|
||||||
w_l, w_u = ab_to_lu(w_a, w_b)
|
w_l, w_u = ab_to_lu(w_a, w_b)
|
||||||
|
|
||||||
|
# initial function evaluations:
|
||||||
fl = fun(w_l)
|
fl = fun(w_l)
|
||||||
fu = fun(w_u)
|
fu = fun(w_u)
|
||||||
|
|
||||||
|
# initial minimum of all evaluated points so far:
|
||||||
if fl <= fu:
|
if fl <= fu:
|
||||||
xmin, fmin = w_l, fl
|
xmin, fmin = w_l, fl
|
||||||
else:
|
else:
|
||||||
xmin, fmin = w_u, fu
|
xmin, fmin = w_u, fu
|
||||||
imin = 0 # index of the minimum -- only one point so far, so it's that one.
|
|
||||||
|
|
||||||
# sample coordinates and their values:
|
# construct lists to hold all sample coordinates and their values:
|
||||||
vw_a, vw_b = [w_a], [w_b]
|
vw_a, vw_b = [w_a], [w_b]
|
||||||
v_fl = [fl] # remember that "l" and "u" are arbitrary shorthand used by the paper,
|
v_fl = [fl] # remember that "l" and "u" are arbitrary shorthand used by the paper,
|
||||||
v_fu = [fu] # and one isn't necessarily above or below the other.
|
v_fu = [fu] # and one isn't necessarily above or below the other.
|
||||||
|
@ -60,19 +86,19 @@ def birect(
|
||||||
|
|
||||||
del w_a, w_b, w_l, w_u, fl, fu # prevent accidental re-use
|
del w_a, w_b, w_l, w_u, fl, fu # prevent accidental re-use
|
||||||
|
|
||||||
def precision_met():
|
def precision_met(): # returns True when the optimization procedure should exit
|
||||||
return min_error is not None and fmin <= min_error
|
return min_error is not None and fmin <= min_error
|
||||||
|
|
||||||
def no_more_evals():
|
def no_more_evals(): # returns True when the optimization procedure should exit
|
||||||
return max_evals is not None and n + 1 >= max_evals
|
return max_evals is not None and n + 1 >= max_evals
|
||||||
|
|
||||||
def gather_potential(v_i):
|
def gather_potential(v_i):
|
||||||
# crappy algorithm for finding the convex hull of the plot, where
|
# crappy algorithm for finding the convex hull of a line plot where
|
||||||
# x = diameter of hyper-rectangle, and
|
# the x axis is the diameter of the hyper-rectangle, and
|
||||||
# y = minimum loss of the two points (v_fl, v_fu) within it.
|
# the y axis is the minimum loss of the two points (v_fl, v_fu) within it.
|
||||||
|
|
||||||
# TODO: make this faster. use a sorted queue and peek at the best for each depth.
|
|
||||||
# start by finding the arg-minimum for each set of equal-diameter rects.
|
# start by finding the arg-minimum for each set of equal-diameter rects.
|
||||||
|
# TODO: make this faster. use a sorted queue and peek at the best for each depth.
|
||||||
bests = {} # mapping of depth to arg-minimum value (i.e. its index)
|
bests = {} # mapping of depth to arg-minimum value (i.e. its index)
|
||||||
for i in v_i:
|
for i in v_i:
|
||||||
fl, fu = v_fl[i], v_fu[i]
|
fl, fu = v_fl[i], v_fu[i]
|
||||||
|
@ -86,7 +112,7 @@ def birect(
|
||||||
if best is None or f < best[1]:
|
if best is None or f < best[1]:
|
||||||
bests[depth] = (i, f)
|
bests[depth] = (i, f)
|
||||||
|
|
||||||
if len(bests) == 1: # nothing to compare
|
if len(bests) == 1: # nothing to compare it to
|
||||||
return [i for i, f in bests.values()]
|
return [i for i, f in bests.values()]
|
||||||
|
|
||||||
asc = sorted(bests.items(), key=lambda t: -t[0]) # sort by length, ascending
|
asc = sorted(bests.items(), key=lambda t: -t[0]) # sort by length, ascending
|
||||||
|
@ -170,6 +196,19 @@ def birect(
|
||||||
|
|
||||||
return longest
|
return longest
|
||||||
|
|
||||||
|
# Hyper-rectangle subdivision demonstration: (2D example)
|
||||||
|
#
|
||||||
|
# Initial Split Once Split Twice Split Both
|
||||||
|
# ↓b ↓b a↓b a↓
|
||||||
|
# ┌───────────┐←b ┌─────╥─────┐ ┌─────╥─────┐ ┌─────╥─────┐←a
|
||||||
|
# │ │ │ ║ │ │ l ║ │ │ l ║ l │
|
||||||
|
# │ ① u │ ⇒ │ u ② ║ u │ ⇒ │ u ③ ║ u │ ⇒ │ u ║ u │
|
||||||
|
# │ │ │ ║ │ b→╞═════╣←a │ b→╞═════╬═════╡←a
|
||||||
|
# │ l │ ⇒ │ l ║ l │ ⇒ │ l ║ l │ ⇒ │ l ║ l │
|
||||||
|
# │ │ │ ║ │ │ u ║ │ │ u ║ u │
|
||||||
|
# a→└───────────┘ └─────╨─────┘ b→└─────╨─────┘ b→└─────╨─────┘
|
||||||
|
# ↑a ↑a ↑a ↑b
|
||||||
|
|
||||||
def split_it(i, which, *, w_a, w_b, d):
|
def split_it(i, which, *, w_a, w_b, d):
|
||||||
nonlocal n, xmin, fmin
|
nonlocal n, xmin, fmin
|
||||||
new, n = n, n + 1 # designate an index for the new hyper-rectangle
|
new, n = n, n + 1 # designate an index for the new hyper-rectangle
|
||||||
|
@ -223,6 +262,16 @@ def birect(
|
||||||
assert len(v_depth) == n, "internal error: v_depth has invalid length"
|
assert len(v_depth) == n, "internal error: v_depth has invalid length"
|
||||||
return v_new
|
return v_new
|
||||||
|
|
||||||
|
def _arbitrary_subdivision(w_a, w_b, d):
|
||||||
|
# shrink the coordinates as if they were subdivided and a single
|
||||||
|
# subdivision was selected. which subdivision is chosen doesn't matter.
|
||||||
|
a_d = w_a[d]
|
||||||
|
b_d = w_b[d]
|
||||||
|
large = max(a_d[0], b_d[0])
|
||||||
|
small = min(a_d[0], b_d[0])
|
||||||
|
w_a[d] = (large * 2 - 0, a_d[1] + 1)
|
||||||
|
w_b[d] = (small * 2 + 1, b_d[1] + 1)
|
||||||
|
|
||||||
def precompute_diagonals_by_limit(limit):
|
def precompute_diagonals_by_limit(limit):
|
||||||
diags = []
|
diags = []
|
||||||
w_a = vw_a[0].copy()
|
w_a = vw_a[0].copy()
|
||||||
|
@ -233,17 +282,10 @@ def birect(
|
||||||
delta = b[0] - a[0]
|
delta = b[0] - a[0]
|
||||||
sq_dist += (delta * delta) << (2 * (limit - a[1]))
|
sq_dist += (delta * delta) << (2 * (limit - a[1]))
|
||||||
diags.append(sq_dist)
|
diags.append(sq_dist)
|
||||||
|
_arbitrary_subdivision(w_a, w_b, depth % dims)
|
||||||
d = depth % dims
|
|
||||||
a_d = w_a[d]
|
|
||||||
b_d = w_b[d]
|
|
||||||
large = max(a_d[0], b_d[0])
|
|
||||||
small = min(a_d[0], b_d[0])
|
|
||||||
w_a[d] = (large * 2 - 0, a_d[1] + 1)
|
|
||||||
w_b[d] = (small * 2 + 1, b_d[1] + 1)
|
|
||||||
return [F(diag) ** F(0.5) / F(1 << limit) for diag in diags]
|
return [F(diag) ** F(0.5) / F(1 << limit) for diag in diags]
|
||||||
|
|
||||||
def precompute_diagonals_by_length(limit):
|
def precompute_diagonals_by_length(minlen):
|
||||||
diags, longests = [], []
|
diags, longests = [], []
|
||||||
w_a = vw_a[0].copy()
|
w_a = vw_a[0].copy()
|
||||||
w_b = vw_b[0].copy()
|
w_b = vw_b[0].copy()
|
||||||
|
@ -255,19 +297,12 @@ def birect(
|
||||||
sq_dist += (delta * delta) << (2 * (bits - a[1]))
|
sq_dist += (delta * delta) << (2 * (bits - a[1]))
|
||||||
longest = max(longest, abs(delta) << (bits - a[1]))
|
longest = max(longest, abs(delta) << (bits - a[1]))
|
||||||
diag = F(sq_dist) ** F(0.5) / F(1 << bits)
|
diag = F(sq_dist) ** F(0.5) / F(1 << bits)
|
||||||
if diag < limit:
|
if diag < minlen:
|
||||||
break
|
break
|
||||||
longest = F(longest) / F(1 << bits)
|
longest = F(longest) / F(1 << bits)
|
||||||
diags.append(diag)
|
diags.append(diag)
|
||||||
longests.append(longest)
|
longests.append(longest)
|
||||||
|
_arbitrary_subdivision(w_a, w_b, depth % dims)
|
||||||
d = depth % dims
|
|
||||||
a_d = w_a[d]
|
|
||||||
b_d = w_b[d]
|
|
||||||
large = max(a_d[0], b_d[0])
|
|
||||||
small = min(a_d[0], b_d[0])
|
|
||||||
w_a[d] = (large * 2 - 0, a_d[1] + 1)
|
|
||||||
w_b[d] = (small * 2 + 1, b_d[1] + 1)
|
|
||||||
return diags, longests
|
return diags, longests
|
||||||
|
|
||||||
diagonal_cache, longest_cache = precompute_diagonals_by_length(min_diag)
|
diagonal_cache, longest_cache = precompute_diagonals_by_length(min_diag)
|
||||||
|
@ -275,17 +310,18 @@ def birect(
|
||||||
diagonal_cache = precompute_diagonals_by_limit(depth_limit)
|
diagonal_cache = precompute_diagonals_by_limit(depth_limit)
|
||||||
|
|
||||||
for outer in range(1_000_000 if max_iters is None else max_iters):
|
for outer in range(1_000_000 if max_iters is None else max_iters):
|
||||||
if precision_met() or no_more_evals():
|
if precision_met() or no_more_evals(): # check stopping conditions
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# perform the actual *di*viding *rect*angles algorithm:
|
||||||
v_potential = gather_potential(v_active)
|
v_potential = gather_potential(v_active)
|
||||||
v_new = split_rectangles(v_potential)
|
v_new = split_rectangles(v_potential)
|
||||||
|
|
||||||
for j in v_potential:
|
for j in v_potential:
|
||||||
del v_active[j]
|
del v_active[j] # these were just split a moment ago, so remove them
|
||||||
for j in v_new:
|
for j in v_new:
|
||||||
if v_depth[j] < depth_limit:
|
if v_depth[j] < depth_limit: # TODO: is checking this late wasting evals?
|
||||||
v_active[j] = True
|
v_active[j] = True # these were just created, so add them
|
||||||
|
|
||||||
tmin = [F(x[0]) / F(x[1]) for x in xmin]
|
tmin = [F(x[0]) / F(x[1]) for x in xmin]
|
||||||
argmin = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, tmin)]
|
argmin = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, tmin)]
|
||||||
|
|
Loading…
Add table
Reference in a new issue