From 4c97502e84617c96cdbb0770894066d16bd441d2 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 13 Jun 2022 06:15:08 +0200 Subject: [PATCH] direct: update birect.py --- direct/birect.py | 122 ++++++++++++++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 43 deletions(-) diff --git a/direct/birect.py b/direct/birect.py index ac882bb..438694d 100644 --- a/direct/birect.py +++ b/direct/birect.py @@ -1,18 +1,32 @@ #!/usr/bin/env python3 +# Copyright (C) 2022 Connor Olding + +# Permission to use, copy, modify, and/or distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. + +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + def birect( - obj, - lo, - hi, + obj, # objective function to find the minimum value of + lo, # list of lower bounds for each problem dimension + hi, # list of upper bounds for each problem dimension *, - min_diag, - min_error=None, - max_evals=None, - max_iters=None, - by_longest=False, + min_diag, # never subdivide hyper-rectangles below this length + min_error=None, # exit when the objective function achieves this error + max_evals=None, # exit when the objective function has been run this many times + max_iters=None, # exit when the optimization procedure iterates this many times + by_longest=False, # measure by rects by longest edge instead of their diagonal pruning=0, - F=float, + F=float, # can be float, np.float32, np.float64, decimal, or anything float-like ): assert len(lo) == len(hi), "dimensional mismatch" @@ -26,31 +40,43 @@ def birect( # aside: xmin should actually be called v_xmin, but it's not! def fun(w_t): - # xs = [l + (h - l) * t for t in w_t] + # final conversion from exact fraction to possibly-inexact F-type: v_x = [F(num) / F(den) for num, den in w_t] + # linearly interpolate within the bounds of the function: v_x = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, v_x)] res = obj(v_x) return res - def ab_to_lu(w_a, w_b): + def ab_to_lu(w_a, w_b): # converts corner points to midpoints, also denominators + # (2 * a + b) / (den * 3) = point halfway between corner "a" and the center w_l = [(a[0] + a[0] + b[0], (1 << a[1]) * 3) for a, b in zip(w_a, w_b)] + # (a + 2 * b) / (den * 3) = point halfway between corner "b" and the center w_u = [(a[0] + b[0] + b[0], (1 << b[1]) * 3) for a, b in zip(w_a, w_b)] return w_l, w_u - dims = len(lo) + dims = len(lo) # already asserted that len(lo) == len(hi) + # initial corner points: + # note that the denominators are encoded as the exponents of a power of two. + # therefore, coordinate = pair[0] / (1 << pair[1]). w_a, w_b = [(0, 0)] * dims, [(1, 0)] * dims + # initial points halfway between the each of the two corners and the center: + # note that the denominators are identity here. + # therefore, coordinate = pair[0] / pair[1]. w_l, w_u = ab_to_lu(w_a, w_b) + + # initial function evaluations: fl = fun(w_l) fu = fun(w_u) + + # initial minimum of all evaluated points so far: if fl <= fu: xmin, fmin = w_l, fl else: xmin, fmin = w_u, fu - imin = 0 # index of the minimum -- only one point so far, so it's that one. - # sample coordinates and their values: + # construct lists to hold all sample coordinates and their values: vw_a, vw_b = [w_a], [w_b] v_fl = [fl] # remember that "l" and "u" are arbitrary shorthand used by the paper, v_fu = [fu] # and one isn't necessarily above or below the other. @@ -60,19 +86,19 @@ def birect( del w_a, w_b, w_l, w_u, fl, fu # prevent accidental re-use - def precision_met(): + def precision_met(): # returns True when the optimization procedure should exit return min_error is not None and fmin <= min_error - def no_more_evals(): + def no_more_evals(): # returns True when the optimization procedure should exit return max_evals is not None and n + 1 >= max_evals def gather_potential(v_i): - # crappy algorithm for finding the convex hull of the plot, where - # x = diameter of hyper-rectangle, and - # y = minimum loss of the two points (v_fl, v_fu) within it. + # crappy algorithm for finding the convex hull of a line plot where + # the x axis is the diameter of hyper-rectangle, and + # the y axis is the minimum loss of the two points (v_fl, v_fu) within it. - # TODO: make this faster. use a sorted queue and peek at the best for each depth. # start by finding the arg-minimum for each set of equal-diameter rects. + # TODO: make this faster. use a sorted queue and peek at the best for each depth. bests = {} # mapping of depth to arg-minimum value (i.e. its index) for i in v_i: fl, fu = v_fl[i], v_fu[i] @@ -86,7 +112,7 @@ def birect( if best is None or f < best[1]: bests[depth] = (i, f) - if len(bests) == 1: # nothing to compare + if len(bests) == 1: # nothing to compare it to return [i for i, f in bests.values()] asc = sorted(bests.items(), key=lambda t: -t[0]) # sort by length, ascending @@ -170,6 +196,19 @@ def birect( return longest + # Hyper-rectangle subdivision demonstration: (2D example) + # + # Initial Split Once Split Twice Split Both + # ↓b ↓b a↓b a↓ + # ┌───────────┐←b ┌─────╥─────┐ ┌─────╥─────┐ ┌─────╥─────┐←a + # │ │ │ ║ │ │ l ║ │ │ l ║ l │ + # │ ① u │ ⇒ │ u ② ║ u │ ⇒ │ u ③ ║ u │ ⇒ │ u ║ u │ + # │ │ │ ║ │ b→╞═════╣←a │ b→╞═════╬═════╡←a + # │ l │ ⇒ │ l ║ l │ ⇒ │ l ║ l │ ⇒ │ l ║ l │ + # │ │ │ ║ │ │ u ║ │ │ u ║ u │ + # a→└───────────┘ └─────╨─────┘ b→└─────╨─────┘ b→└─────╨─────┘ + # ↑a ↑a ↑a ↑b + def split_it(i, which, *, w_a, w_b, d): nonlocal n, xmin, fmin new, n = n, n + 1 # designate an index for the new hyper-rectangle @@ -223,6 +262,16 @@ def birect( assert len(v_depth) == n, "internal error: v_depth has invalid length" return v_new + def _arbitrary_subdivision(w_a, w_b, d): + # shrink the coordinates as if they were subdivided and a single + # subdivision was selected. which subdivision is chosen doesn't matter. + a_d = w_a[d] + b_d = w_b[d] + large = max(a_d[0], b_d[0]) + small = min(a_d[0], b_d[0]) + w_a[d] = (large * 2 - 0, a_d[1] + 1) + w_b[d] = (small * 2 + 1, b_d[1] + 1) + def precompute_diagonals_by_limit(limit): diags = [] w_a = vw_a[0].copy() @@ -233,17 +282,10 @@ def birect( delta = b[0] - a[0] sq_dist += (delta * delta) << (2 * (limit - a[1])) diags.append(sq_dist) - - d = depth % dims - a_d = w_a[d] - b_d = w_b[d] - large = max(a_d[0], b_d[0]) - small = min(a_d[0], b_d[0]) - w_a[d] = (large * 2 - 0, a_d[1] + 1) - w_b[d] = (small * 2 + 1, b_d[1] + 1) + _arbitrary_subdivision(w_a, w_b, depth % dims) return [F(diag) ** F(0.5) / F(1 << limit) for diag in diags] - def precompute_diagonals_by_length(limit): + def precompute_diagonals_by_length(minlen): diags, longests = [], [] w_a = vw_a[0].copy() w_b = vw_b[0].copy() @@ -255,19 +297,12 @@ def birect( sq_dist += (delta * delta) << (2 * (bits - a[1])) longest = max(longest, abs(delta) << (bits - a[1])) diag = F(sq_dist) ** F(0.5) / F(1 << bits) - if diag < limit: + if diag < minlen: break longest = F(longest) / F(1 << bits) diags.append(diag) longests.append(longest) - - d = depth % dims - a_d = w_a[d] - b_d = w_b[d] - large = max(a_d[0], b_d[0]) - small = min(a_d[0], b_d[0]) - w_a[d] = (large * 2 - 0, a_d[1] + 1) - w_b[d] = (small * 2 + 1, b_d[1] + 1) + _arbitrary_subdivision(w_a, w_b, depth % dims) return diags, longests diagonal_cache, longest_cache = precompute_diagonals_by_length(min_diag) @@ -275,17 +310,18 @@ def birect( diagonal_cache = precompute_diagonals_by_limit(depth_limit) for outer in range(1_000_000 if max_iters is None else max_iters): - if precision_met() or no_more_evals(): + if precision_met() or no_more_evals(): # check stopping conditions break + # perform the actual *di*viding *rect*angles algorithm: v_potential = gather_potential(v_active) v_new = split_rectangles(v_potential) for j in v_potential: - del v_active[j] + del v_active[j] # these were just split a moment ago, so remove them for j in v_new: - if v_depth[j] < depth_limit: - v_active[j] = True + if v_depth[j] < depth_limit: # TODO: is checking this late wasting evals? + v_active[j] = True # these were just created, so add them tmin = [F(x[0]) / F(x[1]) for x in xmin] argmin = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, tmin)]