diff --git a/direct/soo.py b/direct/soo.py new file mode 100644 index 0000000..9a2f1c7 --- /dev/null +++ b/direct/soo.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +from collections import namedtuple +from queue import PriorityQueue +import numpy as np + +Result = namedtuple("Result", ["h", "p", "value"]) + + +def checked_fma(x, m, a): + # assumes x, m, and a are integers. + # assumes m >= 1 and a >= 0. + y = x * m + assert y >= x, "integer overflow" + z = y + a + assert z >= y, "integer overflow" + return z + + +# Simultaneous Optimistic Optimization +def soo(obj, origin, sigma, evals, K=3, h_max=None, iters=None): + assert K > 1, "K must be at least 2" + + origin = np.array(origin, float) + dims = len(origin) + h_max = int(31 * np.log(2) / np.log(K) - 1e-8) if h_max is None else h_max + + # Tree (coordinates) + Ts = [PriorityQueue() for _ in range(h_max + 1)] + + def coords(h, p): + return (np.array(p, int) + 0.5) * (sigma * 2) / K ** np.array(h, int) - sigma + + samples = 0 + + def query(h, p): + nonlocal samples + depth = np.max(h) + value = obj(coords(h, p) + origin) + Ts[depth].put((float(value), samples, list(h), list(p))) + samples += 1 + return Result(h, p, value) + + # kickstart with middle point at the root node. + best_ever = query([0] * dims, [0] * dims) + + history, stopping = [], False + for t in range(1_000_000 if iters is None else iters): + best = np.inf + for depth in range(h_max): + if Ts[depth].empty(): + continue + + # select a node at this depth. + tup = Ts[depth].get(block=False) + value, _, h, p = tup + if value > best: + Ts[depth].put(tup, block=False) + continue + best = value + + # split this node, potentially increasing its overall depth. + inc = 0 + for j in reversed(range(1, dims)): + if h[j] < h[j - 1]: + inc = j + + if K & 1 == 1: # when odd, + # maintain center point but split it differently next time. + h[inc] += 1 + p[inc] = checked_fma(p[inc], K, K // 2) + + new_depth = np.max(h) + Ts[new_depth].put(tup, block=False) + + for i in range(K): + if i == K // 2: + continue + p_new = p.copy() + p_new[inc] += i - K // 2 + candidate = query(h.copy(), p_new) + if candidate.value < best_ever.value: + best_ever = candidate + if samples >= evals: + stopping = True + break + + else: # when even, + # discard center point and only use its subdivisions. + for i in range(K): + h_new, p_new = h.copy(), p.copy() + h_new[inc] += 1 + p_new[inc] = checked_fma(p_new[inc], K, i) + candidate = query(h_new, p_new) + if candidate.value < best_ever.value: + best_ever = candidate + if samples >= evals: + stopping = True + break + + if stopping: + break + + history.append(best_ever.value) + if stopping: + break + + h, p, _ = best_ever + + return coords(h, p) + origin, history + + +if __name__ == "__main__": + # get this here: https://github.com/imh/hipsterplot/blob/master/hipsterplot.py + from hipsterplot import plot + import numpy as np + + def objective2210(x): + # another linear transformation of the rosenbrock banana function. + assert len(x) > 1, len(x) + a, b = x[:-1], x[1:] + + a = a / 4.0 * 2 - 12 / 15 + b = b / 1.5 * 2 - 43 / 15 + # solution: 3.60 1.40 + + return ( + np.sum(100 * np.square(np.square(a) + b) + np.square(a - 1)) + / 499.0444444444444 + ) + + optimized, history = soo( + objective2210, + origin=np.array([2.5, 2.5]), + sigma=np.array([2.5, 2.5]), + evals=2_000, + # K=3, + # h_max=19, # should be equivalent-ish to min_diag=1e-8 in birect.py + K=2, + h_max=27, # should be equivalent-ish to min_diag=1e-8 in birect.py + ) + print(" " * 11 + "plot of log10-losses over time") + plot(np.log10(history), num_y_chars=23) + print("", "soo result:", list(optimized), history[-1], sep="\n") + print("", "double checked:", objective2210(optimized), sep="\n")