direct: add soo.py
This commit is contained in:
parent
6c256510d6
commit
f00ecdfde0
1 changed files with 144 additions and 0 deletions
144
direct/soo.py
Normal file
144
direct/soo.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
#!/usr/bin/env python3
|
||||
from collections import namedtuple
|
||||
from queue import PriorityQueue
|
||||
import numpy as np
|
||||
|
||||
Result = namedtuple("Result", ["h", "p", "value"])
|
||||
|
||||
|
||||
def checked_fma(x, m, a):
|
||||
# assumes x, m, and a are integers.
|
||||
# assumes m >= 1 and a >= 0.
|
||||
y = x * m
|
||||
assert y >= x, "integer overflow"
|
||||
z = y + a
|
||||
assert z >= y, "integer overflow"
|
||||
return z
|
||||
|
||||
|
||||
# Simultaneous Optimistic Optimization
|
||||
def soo(obj, origin, sigma, evals, K=3, h_max=None, iters=None):
|
||||
assert K > 1, "K must be at least 2"
|
||||
|
||||
origin = np.array(origin, float)
|
||||
dims = len(origin)
|
||||
h_max = int(31 * np.log(2) / np.log(K) - 1e-8) if h_max is None else h_max
|
||||
|
||||
# Tree (coordinates)
|
||||
Ts = [PriorityQueue() for _ in range(h_max + 1)]
|
||||
|
||||
def coords(h, p):
|
||||
return (np.array(p, int) + 0.5) * (sigma * 2) / K ** np.array(h, int) - sigma
|
||||
|
||||
samples = 0
|
||||
|
||||
def query(h, p):
|
||||
nonlocal samples
|
||||
depth = np.max(h)
|
||||
value = obj(coords(h, p) + origin)
|
||||
Ts[depth].put((float(value), samples, list(h), list(p)))
|
||||
samples += 1
|
||||
return Result(h, p, value)
|
||||
|
||||
# kickstart with middle point at the root node.
|
||||
best_ever = query([0] * dims, [0] * dims)
|
||||
|
||||
history, stopping = [], False
|
||||
for t in range(1_000_000 if iters is None else iters):
|
||||
best = np.inf
|
||||
for depth in range(h_max):
|
||||
if Ts[depth].empty():
|
||||
continue
|
||||
|
||||
# select a node at this depth.
|
||||
tup = Ts[depth].get(block=False)
|
||||
value, _, h, p = tup
|
||||
if value > best:
|
||||
Ts[depth].put(tup, block=False)
|
||||
continue
|
||||
best = value
|
||||
|
||||
# split this node, potentially increasing its overall depth.
|
||||
inc = 0
|
||||
for j in reversed(range(1, dims)):
|
||||
if h[j] < h[j - 1]:
|
||||
inc = j
|
||||
|
||||
if K & 1 == 1: # when odd,
|
||||
# maintain center point but split it differently next time.
|
||||
h[inc] += 1
|
||||
p[inc] = checked_fma(p[inc], K, K // 2)
|
||||
|
||||
new_depth = np.max(h)
|
||||
Ts[new_depth].put(tup, block=False)
|
||||
|
||||
for i in range(K):
|
||||
if i == K // 2:
|
||||
continue
|
||||
p_new = p.copy()
|
||||
p_new[inc] += i - K // 2
|
||||
candidate = query(h.copy(), p_new)
|
||||
if candidate.value < best_ever.value:
|
||||
best_ever = candidate
|
||||
if samples >= evals:
|
||||
stopping = True
|
||||
break
|
||||
|
||||
else: # when even,
|
||||
# discard center point and only use its subdivisions.
|
||||
for i in range(K):
|
||||
h_new, p_new = h.copy(), p.copy()
|
||||
h_new[inc] += 1
|
||||
p_new[inc] = checked_fma(p_new[inc], K, i)
|
||||
candidate = query(h_new, p_new)
|
||||
if candidate.value < best_ever.value:
|
||||
best_ever = candidate
|
||||
if samples >= evals:
|
||||
stopping = True
|
||||
break
|
||||
|
||||
if stopping:
|
||||
break
|
||||
|
||||
history.append(best_ever.value)
|
||||
if stopping:
|
||||
break
|
||||
|
||||
h, p, _ = best_ever
|
||||
|
||||
return coords(h, p) + origin, history
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# get this here: https://github.com/imh/hipsterplot/blob/master/hipsterplot.py
|
||||
from hipsterplot import plot
|
||||
import numpy as np
|
||||
|
||||
def objective2210(x):
|
||||
# another linear transformation of the rosenbrock banana function.
|
||||
assert len(x) > 1, len(x)
|
||||
a, b = x[:-1], x[1:]
|
||||
|
||||
a = a / 4.0 * 2 - 12 / 15
|
||||
b = b / 1.5 * 2 - 43 / 15
|
||||
# solution: 3.60 1.40
|
||||
|
||||
return (
|
||||
np.sum(100 * np.square(np.square(a) + b) + np.square(a - 1))
|
||||
/ 499.0444444444444
|
||||
)
|
||||
|
||||
optimized, history = soo(
|
||||
objective2210,
|
||||
origin=np.array([2.5, 2.5]),
|
||||
sigma=np.array([2.5, 2.5]),
|
||||
evals=2_000,
|
||||
# K=3,
|
||||
# h_max=19, # should be equivalent-ish to min_diag=1e-8 in birect.py
|
||||
K=2,
|
||||
h_max=27, # should be equivalent-ish to min_diag=1e-8 in birect.py
|
||||
)
|
||||
print(" " * 11 + "plot of log10-losses over time")
|
||||
plot(np.log10(history), num_y_chars=23)
|
||||
print("", "soo result:", list(optimized), history[-1], sep="\n")
|
||||
print("", "double checked:", objective2210(optimized), sep="\n")
|
Loading…
Reference in a new issue