From 5dc9939cf02d8f11771e73de50c599c56f085ff4 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 7 Jun 2022 07:02:35 +0200 Subject: [PATCH] add direct --- direct/README.md | 0 direct/birect.py | 325 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 direct/README.md create mode 100644 direct/birect.py diff --git a/direct/README.md b/direct/README.md new file mode 100644 index 0000000..e69de29 diff --git a/direct/birect.py b/direct/birect.py new file mode 100644 index 0000000..2bdf498 --- /dev/null +++ b/direct/birect.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +# based on hamming_exact3.py, but with all debug stuff removed + +def birect( + obj, + lo, + hi, + *, + min_diag, + min_error=None, + max_evals=None, + max_iters=None, + by_longest=False, + pruning=0, + F=float, +): + assert len(lo) == len(hi), "dimensional mismatch" + + assert not ( + min_error is None and max_evals is None and max_iters is None + ), "at least one stopping condition must be specified" + + # variables prefixed with v_ are to be one-dimensional vectors. [a, b, c] + # variables prefixed with vw_ are to be two-dimensional vectors. [[a, b], [c, d]] + # variables prefixed with w_ are to be one-dimensional vectors of pairs. + # aside: xmin should actually be called v_xmin, but it's not! + + def fun(w_t): + # xs = [l + (h - l) * t for t in w_t] + v_x = [F(num) / F(den) for num, den in w_t] + v_x = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, v_x)] + res = obj(v_x) + return res + + def ab_to_lu(w_a, w_b): + w_l = [(a[0] + a[0] + b[0], (1 << a[1]) * 3) for a, b in zip(w_a, w_b)] + w_u = [(a[0] + b[0] + b[0], (1 << b[1]) * 3) for a, b in zip(w_a, w_b)] + return w_l, w_u + + dims = len(lo) + + w_a, w_b = [(0, 0)] * dims, [(1, 0)] * dims + + w_l, w_u = ab_to_lu(w_a, w_b) + fl = fun(w_l) + fu = fun(w_u) + if fl <= fu: + xmin, fmin = w_l, fl + else: + xmin, fmin = w_u, fu + imin = 0 # index of the minimum -- only one point so far, so it's that one. + + # sample coordinates and their values: + vw_a, vw_b = [w_a], [w_b] + v_fl = [fl] # remember that "l" and "u" are arbitrary shorthand used by the paper, + v_fu = [fu] # and one isn't necessarily above or below the other. + v_active = {0: True} # indices of hyper-rectangles that have yet to be subdivided + v_depth = [0] # increments every split + n = 1 # how many indices are in use + + del w_a, w_b, w_l, w_u, fl, fu # prevent accidental re-use + + def precision_met(): + return min_error is not None and fmin <= min_error + + def no_more_evals(): + return max_evals is not None and n + 1 >= max_evals + + avg = lambda a, b: (a + b) / 2 # TODO: delete me! + diff = lambda a, b: -abs(a - b) / 2 # TODO: delete me! + cycle_funs = [min, max, avg, diff] # TODO: delete me! + cycle_funs = [min, avg, max, diff, max, avg] # TODO: delete me! + cycle_funs = [min] # TODO: delete me! + # interesting. using cycle_funs = [max] seems optimal for objective2210. + + def gather_potential(v_i): + # crappy algorithm for finding the convex hull of the plot where + # x = diameter of hyper-rectangle + # y = minimum loss of the two points (v_fl, v_fu) within it + + cycle_fun = cycle_funs[outer % len(cycle_funs)] # TODO: delete me! + + # TODO: make this faster. use a sorted queue and peek at the best for each depth. + # start by finding the arg-minimum for each set of equal-diameter rects. + bests = {} # mapping of depth to arg-minimum value (i.e. its index) + for i in v_i: + fl, fu = v_fl[i], v_fu[i] + f = cycle_fun(fl, fu) # TODO: use min(fl, fu)! + depth = v_depth[i] + if cycle_fun == diff: # TODO: delete me! + # f = f * (1 << depth) # TODO: delete me! + f = f / diagonal_cache[depth] # TODO: delete me! + if by_longest: + depth = depth // dims * dims + # assert depth < depth_limit + best = bests.get(depth) + # TODO: handle f == best case. + if best is None or f < best[1]: + bests[depth] = (i, f) + + if len(bests) == 1: # nothing to compare + return [i for i, f in bests.values()] + + asc = sorted(bests.items(), key=lambda t: -t[0]) # sort by length, ascending + + # first, remove points that slope downwards. + # this yields a pareto front, which isn't necessarily convex. + old = asc + new = [old[-1]] + smallest = old[-1][1][1] + for i in reversed(range(len(old) - 1)): + f = old[i][1][1] + if f <= smallest: + smallest = f + new.append(old[i]) + new = new[::-1] + + # second, convert depths to lengths. + if by_longest: # TODO: does this branch make any difference? + new = [(longest_cache[depth],) + t for depth, t in new] + else: + new = [(diagonal_cache[depth],) + t for depth, t in new] + + # third, remove points that fall under a previous slope. + old = new + skip = [False] * len(old) + for i in range(len(old)): + if skip[i]: + continue + len0, i0, f0 = old[i] + smallest_slope = None + for j in range(i + 1, len(old)): + if skip[j]: + continue + len1, i1, f1 = old[j] + num = f1 - f0 + den = len1 - len0 + # this factor of 3/2 comes from the denominator; + # each length should be multiplied by 2/3: + # the furthest relative distance from a corner to a center point. + slope = num / den # * F(3 / 2) + if smallest_slope is None: + smallest_slope = slope + elif slope < smallest_slope: + for k in range(i + 1, j): + skip[k] = True + smallest_slope = slope + + new = [entry for entry, skipping in zip(old, skip) if not skipping] + + if pruning: + v_f = sorted(min(fl, fu) for fl, fu in zip(v_fl, v_fu)) + fmedian = v_f[len(v_f) // 2] + + offset = fmin - pruning * (fmedian - fmin) + start = 0 + K_slope = None + for i in range(len(new)): + len0, i0, f0 = new[i] + new_slope = (f0 - offset) / len0 + if K_slope is None or new_slope < K_slope: + # if new_slope >= 0: + K_slope = new_slope + start = i + # if start: print(end=f"[starting at {i} with slope {K_slope:.3f}]") + new = new[start:] + + return [i for len, i, f in new] + + def determine_longest(w_a, w_b): + # the index of the dimension is used as a tie-breaker (considered longer). + # e.g. a box has lengths (2, 3, 3). the index returned is then 1. + # TODO: alternate way of stating that comment: biased towards smaller indices. + longest = 0 + + invlen = None + for i, (a, b) in enumerate(zip(w_a, w_b)): + den_a = 1 << a[1] + den_b = 1 << b[1] + den = max(den_a, den_b) # TODO: always the same, right? + if invlen is None or den < invlen: + invlen = den + longest = i + + return longest + + def split_it(i, which, *, w_a, w_b, d): + nonlocal n, xmin, fmin + new, n = n, n + 1 # designate an index for the new hyper-rectangle + w_new_a = w_a.copy() + w_new_b = w_b.copy() + + den = w_a[d][1] # should be equal to w_b[d][1] as well + num_a = w_a[d][0] + num_b = w_b[d][0] + if which: + w_new_a[d] = (num_b + num_b, den + 1) # swap + w_new_b[d] = (num_a + num_b, den + 1) # slide + else: + w_new_a[d] = (num_a + num_b, den + 1) # slide + w_new_b[d] = (num_a + num_a, den + 1) # swap + + w_l, w_u = ab_to_lu(w_new_a, w_new_b) + fl = fun(w_l) if which else v_fl[i] + fu = v_fu[i] if which else fun(w_u) + vw_a.append(w_new_a) + vw_b.append(w_new_b) + v_fl.append(fl) + v_fu.append(fu) + v_depth.append(v_depth[i] + 1) + if which: + if fl < fmin: + xmin, fmin = w_l, fl + else: + if fu < fmin: + xmin, fmin = w_u, fu + return new + + def split_rectangles(v_i): # returns new indices + v_new = [] + for i in v_i: + w_a = vw_a[i] + w_b = vw_b[i] + d = determine_longest(w_a, w_b) + + v_new.append(split_it(i, 0, w_a=w_a, w_b=w_b, d=d)) + if precision_met() or no_more_evals(): + break + v_new.append(split_it(i, 1, w_a=w_a, w_b=w_b, d=d)) + if precision_met() or no_more_evals(): + break + + assert len(vw_a) == n, "internal error: vw_a has invalid length" + assert len(vw_b) == n, "internal error: vw_b has invalid length" + assert len(v_fl) == n, "internal error: v_fl has invalid length" + assert len(v_fu) == n, "internal error: v_fu has invalid length" + assert len(v_depth) == n, "internal error: v_depth has invalid length" + return v_new + + def precompute_diagonals_by_limit(limit): + diags = [] + w_a = vw_a[0].copy() + w_b = vw_b[0].copy() + for depth in range(limit): + sq_dist = 0 + for a, b in zip(w_a, w_b): + delta = b[0] - a[0] + sq_dist += (delta * delta) << (2 * (limit - a[1])) + diags.append(sq_dist) + + d = depth % dims + a_d = w_a[d] + b_d = w_b[d] + large = max(a_d[0], b_d[0]) + small = min(a_d[0], b_d[0]) + w_a[d] = (large * 2 - 0, a_d[1] + 1) + w_b[d] = (small * 2 + 1, b_d[1] + 1) + return [F(diag) ** F(0.5) / F(1 << limit) for diag in diags] + + def precompute_diagonals_by_length(limit): + diags, longests = [], [] + w_a = vw_a[0].copy() + w_b = vw_b[0].copy() + for depth in range(1_000_000): + bits = max(max(a[1], b[1]) for a, b in zip(w_a, w_b)) + sq_dist, longest = 0, 0 + for a, b in zip(w_a, w_b): + delta = b[0] - a[0] + sq_dist += (delta * delta) << (2 * (bits - a[1])) + longest = max(longest, abs(delta) << (bits - a[1])) + diag = F(sq_dist) ** F(0.5) / F(1 << bits) + if diag < limit: + break + longest = F(longest) / F(1 << bits) + diags.append(diag) + longests.append(longest) + + d = depth % dims + a_d = w_a[d] + b_d = w_b[d] + large = max(a_d[0], b_d[0]) + small = min(a_d[0], b_d[0]) + w_a[d] = (large * 2 - 0, a_d[1] + 1) + w_b[d] = (small * 2 + 1, b_d[1] + 1) + return diags, longests + + diagonal_cache, longest_cache = precompute_diagonals_by_length(min_diag) + depth_limit = len(diagonal_cache) + diagonal_cache = precompute_diagonals_by_limit(depth_limit) + + for outer in range(1_000_000 if max_iters is None else max_iters): + if precision_met() or no_more_evals(): + break + + v_potential = gather_potential(v_active) + v_new = split_rectangles(v_potential) + + for j in v_potential: + del v_active[j] + for j in v_new: + if v_depth[j] < depth_limit: + v_active[j] = True + + tmin = [F(x[0]) / F(x[1]) for x in xmin] + argmin = [l * (1 - t) + h * t for l, h, t in zip(lo, hi, tmin)] + return argmin, fmin + + +if __name__ == "__main__": + from arsd_objectives import objective2210 + import numpy as np + + F = np.float64 + res = birect( + lambda a: objective2210(np.array(a, F)), + [0, 0], + [5, 5], + min_diag=F(1e-8 / 5), + # max_evals=50_000, + max_evals=2_000, + F=F, + by_longest=True, + ) + print("", "birect result:", *res, sep="\n") + print("", "double checked:", objective2210(np.array(res[0], F)), sep="\n")