direct: add soo.py

This commit is contained in:
Connor Olding 2022-06-13 22:58:34 +02:00
parent 6c256510d6
commit f00ecdfde0

144
direct/soo.py Normal file
View File

@ -0,0 +1,144 @@
#!/usr/bin/env python3
from collections import namedtuple
from queue import PriorityQueue
import numpy as np
Result = namedtuple("Result", ["h", "p", "value"])
def checked_fma(x, m, a):
# assumes x, m, and a are integers.
# assumes m >= 1 and a >= 0.
y = x * m
assert y >= x, "integer overflow"
z = y + a
assert z >= y, "integer overflow"
return z
# Simultaneous Optimistic Optimization
def soo(obj, origin, sigma, evals, K=3, h_max=None, iters=None):
assert K > 1, "K must be at least 2"
origin = np.array(origin, float)
dims = len(origin)
h_max = int(31 * np.log(2) / np.log(K) - 1e-8) if h_max is None else h_max
# Tree (coordinates)
Ts = [PriorityQueue() for _ in range(h_max + 1)]
def coords(h, p):
return (np.array(p, int) + 0.5) * (sigma * 2) / K ** np.array(h, int) - sigma
samples = 0
def query(h, p):
nonlocal samples
depth = np.max(h)
value = obj(coords(h, p) + origin)
Ts[depth].put((float(value), samples, list(h), list(p)))
samples += 1
return Result(h, p, value)
# kickstart with middle point at the root node.
best_ever = query([0] * dims, [0] * dims)
history, stopping = [], False
for t in range(1_000_000 if iters is None else iters):
best = np.inf
for depth in range(h_max):
if Ts[depth].empty():
continue
# select a node at this depth.
tup = Ts[depth].get(block=False)
value, _, h, p = tup
if value > best:
Ts[depth].put(tup, block=False)
continue
best = value
# split this node, potentially increasing its overall depth.
inc = 0
for j in reversed(range(1, dims)):
if h[j] < h[j - 1]:
inc = j
if K & 1 == 1: # when odd,
# maintain center point but split it differently next time.
h[inc] += 1
p[inc] = checked_fma(p[inc], K, K // 2)
new_depth = np.max(h)
Ts[new_depth].put(tup, block=False)
for i in range(K):
if i == K // 2:
continue
p_new = p.copy()
p_new[inc] += i - K // 2
candidate = query(h.copy(), p_new)
if candidate.value < best_ever.value:
best_ever = candidate
if samples >= evals:
stopping = True
break
else: # when even,
# discard center point and only use its subdivisions.
for i in range(K):
h_new, p_new = h.copy(), p.copy()
h_new[inc] += 1
p_new[inc] = checked_fma(p_new[inc], K, i)
candidate = query(h_new, p_new)
if candidate.value < best_ever.value:
best_ever = candidate
if samples >= evals:
stopping = True
break
if stopping:
break
history.append(best_ever.value)
if stopping:
break
h, p, _ = best_ever
return coords(h, p) + origin, history
if __name__ == "__main__":
# get this here: https://github.com/imh/hipsterplot/blob/master/hipsterplot.py
from hipsterplot import plot
import numpy as np
def objective2210(x):
# another linear transformation of the rosenbrock banana function.
assert len(x) > 1, len(x)
a, b = x[:-1], x[1:]
a = a / 4.0 * 2 - 12 / 15
b = b / 1.5 * 2 - 43 / 15
# solution: 3.60 1.40
return (
np.sum(100 * np.square(np.square(a) + b) + np.square(a - 1))
/ 499.0444444444444
)
optimized, history = soo(
objective2210,
origin=np.array([2.5, 2.5]),
sigma=np.array([2.5, 2.5]),
evals=2_000,
# K=3,
# h_max=19, # should be equivalent-ish to min_diag=1e-8 in birect.py
K=2,
h_max=27, # should be equivalent-ish to min_diag=1e-8 in birect.py
)
print(" " * 11 + "plot of log10-losses over time")
plot(np.log10(history), num_y_chars=23)
print("", "soo result:", list(optimized), history[-1], sep="\n")
print("", "double checked:", objective2210(optimized), sep="\n")