This commit is contained in:
parent 6f6a34a6bc
commit 8c79667904

2 changed files with 77 additions and 11 deletions

optim_nn.py (33 changed lines)
@@ -1,5 +1,8 @@
 #!/usr/bin/env python3
 
+# BIG TODO: ensure numpy isn't upcasting to float64 *anywhere*.
+# this is gonna take some work.
+
 # external packages required for full functionality:
 # numpy scipy h5py sklearn dotmap
 
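The float64 TODO added here is easy to trip over later in this same commit: np.exp on a Python float returns a float64 scalar, so the b1/b2 values computed in model_from_config arrive as float64 and rely on the optimizers' nf() casts to come back down. A minimal check of that behavior (assuming nf is essentially a float32 cast, which the nf(...) wrapping in the optimizer constructors suggests):

import numpy as np

nf = np.float32                    # assumption: the repo's nf() helper is a float32 cast

b1 = np.exp(-1/2)                  # np.exp on a Python float gives a float64 scalar
print(type(b1))                    # <class 'numpy.float64'>
print(nf(b1).dtype)                # float32 once the optimizer wraps it in nf()

g = np.zeros(4, dtype=np.float32)
acc = np.zeros(4)                  # dtype defaults to float64
print((g + acc).dtype)             # float64: the silent upcast the TODO warns about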
@@ -24,7 +27,7 @@ class SquaredHalved(Loss):
         return r
 
 class SomethingElse(Loss):
-    # generalizes Absolute and SquaredHalved (|dx| = 1)
+    # generalizes Absolute and SquaredHalved
     # plot: https://www.desmos.com/calculator/fagjg9vuz7
     def __init__(self, a=4/3):
         assert 1 <= a <= 2, "parameter out of range"
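The hunk only shows the comment and the parameter check, so for orientation: one family that matches the description is f(r) = (a/2) * |r| ** (2/a), which reduces to SquaredHalved at a=1 and to Absolute at a=2. The sketch below is illustrative only; the class name and the formula are assumptions, not code from the commit.

import numpy as np

class SomethingElseSketch:
    # hypothetical loss that "generalizes Absolute and SquaredHalved":
    #   f(r) = (a/2) * |r| ** (2/a)
    # a=1 -> 0.5*r**2 (SquaredHalved), a=2 -> |r| (Absolute)
    def __init__(self, a=4/3):
        assert 1 <= a <= 2, "parameter out of range"
        self.a = a

    def f(self, r):
        return self.a / 2 * np.abs(r) ** (2 / self.a)

    def df(self, r):
        # derivative of f with respect to the residual r
        return np.sign(r) * np.abs(r) ** (2 / self.a - 1)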
@@ -337,7 +340,15 @@ def model_from_config(config, input_features, output_features, callbacks):
 
     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
-        optim = Adam()
+        d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
+        d2 = config.optim_decay2 if 'optim_decay2' in config else 999.5
+        b1 = np.exp(-1/d1)
+        b2 = np.exp(-1/d2)
+        optim = Adam(b1=b1, b1_t=b1, b2=b2, b2_t=b2)
+    elif config.optim in ('rms', 'rmsprop'):
+        d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5
+        mu = np.exp(-1/d2)
+        optim = RMSprop(mu=mu)
     elif config.optim == 'sgd':
         if config.momentum != 0:
             optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
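Worked numbers for the exp(-1/d) mapping above, using the defaults this commit adds to the config further down (optim_decay1 = 2, optim_decay2 = 100) and the hard-coded fallbacks: a decay horizon of d epochs becomes the filter coefficient exp(-1/d), and the fallbacks land close to the optimizers' stock defaults.

import numpy as np

print(np.exp(-1/2))      # ~0.6065  b1 from optim_decay1 = 2
print(np.exp(-1/100))    # ~0.9900  b2 from optim_decay2 = 100

# fallbacks when the config keys are absent:
print(np.exp(-1/9.5))    # ~0.90   matches Adam's default b1
print(np.exp(-1/999.5))  # ~0.999  matches Adam's default b2
print(np.exp(-1/99.5))   # ~0.99   matches RMSprop's default mu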
@@ -413,7 +424,9 @@ def model_from_config(config, input_features, output_features, callbacks):
 
 
 def run(program, args=[]):
-    # Config
+    np.random.seed(42069)
+
+    # Config {{{2
 
     from dotmap import DotMap
     config = DotMap(
@@ -432,6 +445,8 @@ def run(program, args=[]):
         activation = 'gelu',
 
         optim = 'adam',
+        optim_decay1 = 2,   # given in epochs (optional)
+        optim_decay2 = 100, # given in epochs (optional)
         nesterov = False, # only used with SGD or Adam
         momentum = 0.50, # only used with SGD
         batch_size = 64,
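A side note on why model_from_config tests 'optim_decay1' in config rather than just reading the attribute: with dotmap's default dynamic behavior, reading a key that was never set does not raise, it returns (and stores) an empty DotMap, so membership is the reliable way to treat these two keys as optional. A small sketch assuming stock dotmap:

from dotmap import DotMap

config = DotMap(optim='adam')

# the optional key is absent, so the fallback from model_from_config applies
d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
print(d1)                         # 9.5

# attribute access on a missing key does not raise; it creates an empty DotMap
print(config.optim_decay2)        # DotMap()
print('optim_decay2' in config)   # True now: dynamic access stored the key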
@@ -467,7 +482,7 @@ def run(program, args=[]):
 
     config.pprint()
 
-    # toy data
+    # Toy Data {{{2
     # (our model is probably complete overkill for this, so TODO: better data)
 
     (inputs, outputs), (valid_inputs, valid_outputs) = \
@@ -493,7 +508,7 @@ def run(program, args=[]):
         print(str(node)+sep+('\n'+str(node)+sep).join(children))
     log('parameters', model.param_count)
 
-    # Training
+    # Training {{{2
 
     batch_losses = []
     train_losses = []
@@ -550,9 +565,6 @@ def run(program, args=[]):
     log('saving weights', config.fn_save)
     model.save_weights(config.fn_save, overwrite=True)
 
-    # Evaluation
-    # TODO: write this portion again
-
     if config.log_fn is not None:
         log('saving losses', config.log_fn)
         np.savez_compressed(config.log_fn,
@@ -560,8 +572,13 @@ def run(program, args=[]):
                             train_losses=nfa(train_losses),
                             valid_losses=nfa(valid_losses))
 
+    # Evaluation {{{2
+    # TODO: write this portion again
+
     return 0
 
+# do main {{{1
+
 if __name__ == '__main__':
     import sys
     sys.exit(run(sys.argv[0], sys.argv[1:]))
@@ -27,6 +27,8 @@ def init_he_uniform(size, ins, outs):
 # Loss functions {{{1
 
 class Loss:
+    per_batch = False
+
     def mean(self, r):
         return np.average(self.f(r))
 
@@ -91,7 +93,53 @@ class Momentum(Optimizer):
         else:
             return V
 
+class RMSprop(Optimizer):
+    # RMSprop generalizes* Adagrad, etc.
+
+    # * RMSprop == Adagrad when
+    #   RMSprop.mu == 1
+
+    def __init__(self, alpha=0.0001, mu=0.99, eps=1e-8):
+        self.alpha = nf(alpha) # learning rate
+        self.mu = nf(mu) # decay term
+        self.eps = nf(eps)
+
+        # one might consider the following equation when specifying mu:
+        # mu = e**(-1/t)
+        # default: t = -1/ln(0.99) = ~99.5
+        # therefore the default of mu=0.99 means
+        # an input decays to 1/e its original amplitude over 99.5 epochs.
+        # (this is from DSP, so how relevant it is in SGD is debatable)
+
+        self.reset()
+
+    def reset(self):
+        self.g = None
+
+    def compute(self, dW, W):
+        if self.g is None:
+            self.g = np.zeros_like(dW)
+
+        # basically apply a first-order low-pass filter to delta squared
+        self.g[:] = self.mu * self.g + (1 - self.mu) * dW * dW
+        # equivalent (though numerically different?):
+        #self.g += (dW * dW - self.g) * (1 - self.mu)
+
+        # finally sqrt it to complete the running root-mean-square approximation
+        return -self.alpha * dW / np.sqrt(self.g + self.eps)
+
 class Adam(Optimizer):
+    # Adam generalizes* RMSprop, and
+    # adds a decay term to the regular (non-squared) delta, and
+    # does some decay-gain voodoo. (i guess it's compensating
+    # for the filtered deltas starting from zero)
+
+    # * Adam == RMSprop when
+    #   Adam.b1 == 0
+    #   Adam.b2 == RMSprop.mu
+    #   Adam.b1_t == 0
+    #   Adam.b2_t == 0
+
     def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
         self.alpha = nf(alpha) # learning rate
         self.b1 = nf(b1) # decay term
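The two update forms mentioned in RMSprop.compute are the same first-order low-pass filter written two ways, so any difference between them is only floating-point rounding, and the time-constant comment follows from t = -1/ln(mu). A standalone numeric check (not using the repo's classes):

import numpy as np

rng = np.random.default_rng(0)
mu = np.float32(0.99)
x = rng.standard_normal(1000).astype(np.float32) ** 2   # stand-in for dW*dW

g1 = np.float32(0.0)
g2 = np.float32(0.0)
for v in x:
    g1 = mu * g1 + (1 - mu) * v     # form used in compute()
    g2 = g2 + (v - g2) * (1 - mu)   # the commented-out "equivalent" form
print(abs(g1 - g2))                 # tiny; the forms differ only by rounding

print(-1 / np.log(0.99))            # ~99.5, the time constant quoted in the comment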
@@ -110,14 +158,15 @@ class Adam(Optimizer):
 
     def compute(self, dW, W):
         if self.mt is None:
-            self.mt = np.zeros_like(W)
+            self.mt = np.zeros_like(dW)
         if self.vt is None:
-            self.vt = np.zeros_like(W)
+            self.vt = np.zeros_like(dW)
 
-        # decay
+        # decay gain
         self.b1_t *= self.b1
         self.b2_t *= self.b2
 
+        # filter
         self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
         self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW
 
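The renamed "decay gain" comment refers to the b1_t/b2_t products that standard Adam uses for bias correction: because mt and vt start at zero, they are divided by (1 - b1_t) and (1 - b2_t) before the step, which is also why the Adam == RMSprop comparison above sets b1_t and b2_t to 0. The rest of compute is not shown in this hunk; the following is a sketch of the usual closing step (standard Adam per Kingma and Ba, assumed rather than copied from the file):

import numpy as np

def adam_step(mt, vt, b1_t, b2_t, alpha=0.001, eps=1e-8):
    # bias correction: the running means start at zero, so early estimates
    # are scaled back up by dividing by (1 - b1_t) and (1 - b2_t)
    mhat = mt / (1 - b1_t)
    vhat = vt / (1 - b2_t)
    # RMSprop-style step using the corrected moments
    return -alpha * mhat / (np.sqrt(vhat) + eps)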