parent 6f6a34a6bc
commit 8c79667904
2 changed files with 77 additions and 11 deletions

optim_nn.py (33 changed lines)
@@ -1,5 +1,8 @@
 #!/usr/bin/env python3

 # BIG TODO: ensure numpy isn't upcasting to float64 *anywhere*.
 # this is gonna take some work.

+# external packages required for full functionality:
+# numpy scipy h5py sklearn dotmap
+
@@ -24,7 +27,7 @@ class SquaredHalved(Loss):
         return r

 class SomethingElse(Loss):
-    # generalizes Absolute and SquaredHalved (|dx| = 1)
+    # generalizes Absolute and SquaredHalved
     # plot: https://www.desmos.com/calculator/fagjg9vuz7
     def __init__(self, a=4/3):
         assert 1 <= a <= 2, "parameter out of range"
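Note on the hunk above: the body of SomethingElse is not part of this diff, so the following is only a plausible reading of the comment. A residual loss of the form f(r) = (a/2)*|r|**(2/a) reduces to SquaredHalved at a = 1 and to Absolute at a = 2, and its derivative has magnitude 1 at |r| = 1 for any a in that range, which would also fit the dropped "(|dx| = 1)" note. Treat f and df below as assumptions, not the file's definitions:

    import numpy as np

    def f(r, a=4/3):
        # a = 1 -> r**2 / 2 (SquaredHalved); a = 2 -> |r| (Absolute)
        return (a / 2) * np.abs(r)**(2 / a)

    def df(r, a=4/3):
        # |df| == 1 at |r| == 1 for any a in [1, 2]
        return np.sign(r) * np.abs(r)**(2 / a - 1)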
@@ -337,7 +340,15 @@ def model_from_config(config, input_features, output_features, callbacks):

     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
-        optim = Adam()
+        d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
+        d2 = config.optim_decay2 if 'optim_decay2' in config else 999.5
+        b1 = np.exp(-1/d1)
+        b2 = np.exp(-1/d2)
+        optim = Adam(b1=b1, b1_t=b1, b2=b2, b2_t=b2)
+    elif config.optim in ('rms', 'rmsprop'):
+        d2 = config.optim_decay2 if 'optim_decay2' in config else 99.5
+        mu = np.exp(-1/d2)
+        optim = RMSprop(mu=mu)
     elif config.optim == 'sgd':
         if config.momentum != 0:
             optim = Momentum(mu=config.momentum, nesterov=config.nesterov)
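Note on the hunk above: optim_decay1/optim_decay2 are time constants given in epochs and converted to decay coefficients with b = exp(-1/d); the fallback values are chosen so that the defaults land on the usual Adam and RMSprop constants. A quick check of that conversion (the helper name here is made up for illustration):

    import numpy as np

    # time constant d (in epochs) -> first-order decay coefficient b
    def decay_from_epochs(d):
        return np.exp(-1 / d)

    print(decay_from_epochs(9.5))    # ~0.900, Adam's usual b1
    print(decay_from_epochs(999.5))  # ~0.999, Adam's usual b2
    print(decay_from_epochs(99.5))   # ~0.990, the RMSprop default mu
    print(decay_from_epochs(2))      # ~0.607, from optim_decay1 = 2 in the config hunk further down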
@@ -413,7 +424,9 @@ def model_from_config(config, input_features, output_features, callbacks):

 def run(program, args=[]):

-    # Config
     np.random.seed(42069)

+    # Config {{{2
+
+    from dotmap import DotMap
     config = DotMap(
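Note on the hunk above: the {{{2 suffix (and the {{{1 markers elsewhere in this commit) are vim fold markers, three opening braces followed by a fold level, so the file collapses into sections under foldmethod=marker; as far as Python is concerned they are plain comment text.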
@@ -432,6 +445,8 @@ def run(program, args=[]):
         activation = 'gelu',

         optim = 'adam',
+        optim_decay1 = 2,   # given in epochs (optional)
+        optim_decay2 = 100, # given in epochs (optional)
         nesterov = False,   # only used with SGD or Adam
         momentum = 0.50,    # only used with SGD
         batch_size = 64,
@@ -467,7 +482,7 @@ def run(program, args=[]):

     config.pprint()

-    # toy data
+    # Toy Data {{{2
     # (our model is probably complete overkill for this, so TODO: better data)

     (inputs, outputs), (valid_inputs, valid_outputs) = \
@@ -493,7 +508,7 @@ def run(program, args=[]):
             print(str(node)+sep+('\n'+str(node)+sep).join(children))
     log('parameters', model.param_count)

-    # Training
+    # Training {{{2

     batch_losses = []
     train_losses = []
@@ -550,9 +565,6 @@ def run(program, args=[]):
         log('saving weights', config.fn_save)
         model.save_weights(config.fn_save, overwrite=True)

-    # Evaluation
-    # TODO: write this portion again
-
     if config.log_fn is not None:
         log('saving losses', config.log_fn)
         np.savez_compressed(config.log_fn,
@@ -560,8 +572,13 @@ def run(program, args=[]):
             train_losses=nfa(train_losses),
             valid_losses=nfa(valid_losses))

+    # Evaluation {{{2
+    # TODO: write this portion again
+
     return 0

+# do main {{{1
+
 if __name__ == '__main__':
     import sys
     sys.exit(run(sys.argv[0], sys.argv[1:]))

@@ -27,6 +27,8 @@ def init_he_uniform(size, ins, outs):
 # Loss functions {{{1

 class Loss:
+    per_batch = False
+
     def mean(self, r):
         return np.average(self.f(r))

@@ -91,7 +93,53 @@ class Momentum(Optimizer):
         else:
             return V

+class RMSprop(Optimizer):
+    # RMSprop generalizes* Adagrad, etc.
+
+    # * RMSprop == Adagrad when
+    #   RMSprop.mu == 1
+
+    def __init__(self, alpha=0.0001, mu=0.99, eps=1e-8):
+        self.alpha = nf(alpha) # learning rate
+        self.mu = nf(mu) # decay term
+        self.eps = nf(eps)
+
+        # one might consider the following equation when specifying mu:
+        #   mu = e**(-1/t)
+        # default: t = -1/ln(0.99) = ~99.5
+        # therefore the default of mu=0.99 means
+        # an input decays to 1/e its original amplitude over 99.5 epochs.
+        # (this is from DSP, so how relevant it is in SGD is debatable)
+
+        self.reset()
+
+    def reset(self):
+        self.g = None
+
+    def compute(self, dW, W):
+        if self.g is None:
+            self.g = np.zeros_like(dW)
+
+        # basically apply a first-order low-pass filter to delta squared
+        self.g[:] = self.mu * self.g + (1 - self.mu) * dW * dW
+        # equivalent (though numerically different?):
+        #self.g += (dW * dW - self.g) * (1 - self.mu)
+
+        # finally sqrt it to complete the running root-mean-square approximation
+        return -self.alpha * dW / np.sqrt(self.g + self.eps)
+
 class Adam(Optimizer):
+    # Adam generalizes* RMSprop, and
+    # adds a decay term to the regular (non-squared) delta, and
+    # does some decay-gain voodoo. (i guess it's compensating
+    # for the filtered deltas starting from zero)
+
+    # * Adam == RMSprop when
+    #   Adam.b1 == 0
+    #   Adam.b2 == RMSprop.mu
+    #   Adam.b1_t == 0
+    #   Adam.b2_t == 0
+
     def __init__(self, alpha=0.001, b1=0.9, b2=0.999, b1_t=0.9, b2_t=0.999, eps=1e-8):
         self.alpha = nf(alpha) # learning rate
         self.b1 = nf(b1) # decay term
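Note on the comment block inside RMSprop above: the time-constant claim is easy to check numerically. With mu = exp(-1/t), the filter state loses a factor of mu per step, so it falls to roughly 1/e of its value after about t steps. A standalone check (nothing here comes from the repository beyond the formula in the comment):

    import numpy as np

    t = 99.5               # time constant in steps
    mu = np.exp(-1 / t)    # ~0.99, the default decay term above

    # once the input is zero, the filter state g = mu*g + (1-mu)*0 simply
    # decays by a factor of mu per step; after ~t steps it is near 1/e
    g = 1.0
    for _ in range(100):
        g *= mu
    print(g, 1 / np.e)     # ~0.366 vs ~0.368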
@@ -110,14 +158,15 @@ class Adam(Optimizer):

     def compute(self, dW, W):
         if self.mt is None:
-            self.mt = np.zeros_like(W)
+            self.mt = np.zeros_like(dW)
         if self.vt is None:
-            self.vt = np.zeros_like(W)
+            self.vt = np.zeros_like(dW)

-        # decay
+        # decay gain
         self.b1_t *= self.b1
         self.b2_t *= self.b2

         # filter
         self.mt[:] = self.b1 * self.mt + (1 - self.b1) * dW
         self.vt[:] = self.b2 * self.vt + (1 - self.b2) * dW * dW
+
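Note on the hunk above: renaming "decay" to "decay gain" points at what the b1_t and b2_t products are for — the mt/vt filters start at zero, so running products of the decay terms are typically used to rescale the filtered values during the first steps. The rest of compute is not visible in this hunk, so the sketch below is a generic, textbook-style Adam step that only borrows the diff's names (mt, vt, b1_t, b2_t); it illustrates that compensation and is not the repository's implementation:

    import numpy as np

    def adam_step(dW, mt, vt, b1_t, b2_t,
                  alpha=0.001, b1=0.9, b2=0.999, eps=1e-8):
        # decay gain: running products of the decay terms
        # (equal to b1**t and b2**t when the products start at 1.0)
        b1_t *= b1
        b2_t *= b2
        # filter the gradient and its square, as in the hunk above
        mt[:] = b1 * mt + (1 - b1) * dW
        vt[:] = b2 * vt + (1 - b2) * dW * dW
        # compensate for the zero-initialized filters, then take the step
        mhat = mt / (1 - b1_t)
        vhat = vt / (1 - b2_t)
        return -alpha * mhat / (np.sqrt(vhat) + eps), b1_t, b2_t

    dW = np.ones(3)
    mt, vt = np.zeros_like(dW), np.zeros_like(dW)
    step, b1_t, b2_t = adam_step(dW, mt, vt, b1_t=1.0, b2_t=1.0)
    print(step)   # ~[-0.001, -0.001, -0.001] even on the very first step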