.
This commit is contained in:
parent
02e03ad85e
commit
f12e408c7e
123
optim_nn.py
123
optim_nn.py
|
@ -98,7 +98,7 @@ class Momentum(Optimizer):
|
||||||
self.dWprev = np.copy(dW)
|
self.dWprev = np.copy(dW)
|
||||||
|
|
||||||
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
|
V = self.mu * self.dWprev - self.alpha * (dW + W * self.lamb)
|
||||||
self.dWprev = V
|
self.dWprev[:] = V
|
||||||
if self.nesterov:
|
if self.nesterov:
|
||||||
return self.mu * V - self.alpha * (dW + W * self.lamb)
|
return self.mu * V - self.alpha * (dW + W * self.lamb)
|
||||||
else:
|
else:
|
||||||
|
@ -412,7 +412,7 @@ class Model:
|
||||||
offset += node.size
|
offset += node.size
|
||||||
|
|
||||||
def traverse(self, nodes, node):
|
def traverse(self, nodes, node):
|
||||||
if node == x:
|
if node == self.x:
|
||||||
return [node]
|
return [node]
|
||||||
for parent in node.parents:
|
for parent in node.parents:
|
||||||
if parent not in nodes:
|
if parent not in nodes:
|
||||||
|
@ -465,6 +465,60 @@ class Model:
|
||||||
def save_weights(self, fn, overwrite=False):
|
def save_weights(self, fn, overwrite=False):
|
||||||
raise NotImplementedError("unimplemented", self)
|
raise NotImplementedError("unimplemented", self)
|
||||||
|
|
||||||
|
class Ritual:
|
||||||
|
def __init__(self,
|
||||||
|
optim=None,
|
||||||
|
learn_rate=1e-3, learn_anneal=1, learn_advance=0,
|
||||||
|
loss=None, mloss=None):
|
||||||
|
self.optim = optim if optim is not None else SGD()
|
||||||
|
self.loss = loss if loss is not None else Squared()
|
||||||
|
self.mloss = mloss if mloss is not None else loss
|
||||||
|
self.learn_rate = nf(learn_rate)
|
||||||
|
self.learn_anneal = nf(learn_anneal)
|
||||||
|
self.learn_advance = nf(learn_advance)
|
||||||
|
|
||||||
|
def measure(self, residual):
|
||||||
|
return self.mloss.mean(residual)
|
||||||
|
|
||||||
|
def derive(self, residual):
|
||||||
|
return self.loss.dmean(residual)
|
||||||
|
|
||||||
|
def update(self, dW, W):
|
||||||
|
self.optim.update(dW, W)
|
||||||
|
|
||||||
|
def prepare(self, epoch):
|
||||||
|
self.optim.alpha = self.learn_rate * self.learn_anneal**epoch
|
||||||
|
|
||||||
|
def restart(self, optim=False):
|
||||||
|
self.learn_rate *= self.learn_anneal**self.learn_advance
|
||||||
|
if optim:
|
||||||
|
self.optim.reset()
|
||||||
|
|
||||||
|
def train_batched(self, model, inputs, outputs, batch_size, return_losses=False):
|
||||||
|
cumsum_loss = 0
|
||||||
|
batch_count = inputs.shape[0] // batch_size
|
||||||
|
losses = []
|
||||||
|
for b in range(batch_count):
|
||||||
|
bi = b * batch_size
|
||||||
|
batch_inputs = inputs[ bi:bi+batch_size]
|
||||||
|
batch_outputs = outputs[bi:bi+batch_size]
|
||||||
|
|
||||||
|
predicted = model.forward(batch_inputs)
|
||||||
|
residual = predicted - batch_outputs
|
||||||
|
|
||||||
|
model.backward(self.derive(residual))
|
||||||
|
self.update(model.dW, model.W)
|
||||||
|
|
||||||
|
batch_loss = self.measure(residual)
|
||||||
|
cumsum_loss += batch_loss
|
||||||
|
if return_losses:
|
||||||
|
losses.append(batch_loss)
|
||||||
|
avg_loss = cumsum_loss / batch_count
|
||||||
|
if return_losses:
|
||||||
|
return avg_loss, losses
|
||||||
|
else:
|
||||||
|
return avg_loss
|
||||||
|
|
||||||
def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
|
def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
|
||||||
y = x
|
y = x
|
||||||
last_size = x.output_shape[0]
|
last_size = x.output_shape[0]
|
||||||
|
@ -516,7 +570,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def run(program, args=[]):
|
||||||
import sys
|
import sys
|
||||||
lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
|
lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
|
||||||
def log(left, right):
|
def log(left, right):
|
||||||
|
@ -527,6 +581,7 @@ if __name__ == '__main__':
|
||||||
from dotmap import DotMap
|
from dotmap import DotMap
|
||||||
config = DotMap(
|
config = DotMap(
|
||||||
fn = 'ml/cie_mlp_min.h5',
|
fn = 'ml/cie_mlp_min.h5',
|
||||||
|
log_fn = 'losses.npz',
|
||||||
|
|
||||||
# multi-residual network parameters
|
# multi-residual network parameters
|
||||||
res_width = 49,
|
res_width = 49,
|
||||||
|
@ -553,7 +608,6 @@ if __name__ == '__main__':
|
||||||
batch_size = 64,
|
batch_size = 64,
|
||||||
init = 'he_normal',
|
init = 'he_normal',
|
||||||
loss = SomethingElse(4/3),
|
loss = SomethingElse(4/3),
|
||||||
log_fn = 'losses.npz',
|
|
||||||
mloss = 'mse',
|
mloss = 'mse',
|
||||||
restart_optim = True, # restarts also reset internal state of optimizer
|
restart_optim = True, # restarts also reset internal state of optimizer
|
||||||
unsafe = True, # aka gotta go fast mode
|
unsafe = True, # aka gotta go fast mode
|
||||||
|
@ -632,11 +686,14 @@ if __name__ == '__main__':
|
||||||
loss = lookup_loss(config.loss)
|
loss = lookup_loss(config.loss)
|
||||||
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
||||||
|
|
||||||
LR = config.LR
|
anneal = 0.5**(1/config.LR_halve_every)
|
||||||
LRprod = 0.5**(1/config.LR_halve_every)
|
ritual = Ritual(optim=optim,
|
||||||
|
learn_rate=config.LR, learn_anneal=anneal,
|
||||||
|
learn_advance=config.LR_restart_advance,
|
||||||
|
loss=loss, mloss=mloss)
|
||||||
|
|
||||||
LRE = LR * (LRprod**config.LR_restart_advance)**config.restarts * LRprod**(config.epochs - 1)
|
learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
|
||||||
log("final learning rate", "{:10.8f}".format(LRE))
|
log("final learning rate", "{:10.8f}".format(learn_end))
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
||||||
|
@ -645,24 +702,21 @@ if __name__ == '__main__':
|
||||||
valid_losses = []
|
valid_losses = []
|
||||||
|
|
||||||
def measure_error():
|
def measure_error():
|
||||||
# log("weight mean", "{:11.7f}".format(np.mean(model.W)))
|
|
||||||
# log("weight var", "{:11.7f}".format(np.var(model.W)))
|
|
||||||
|
|
||||||
def print_error(name, inputs, outputs, comparison=None):
|
def print_error(name, inputs, outputs, comparison=None):
|
||||||
predicted = model.forward(inputs)
|
predicted = model.forward(inputs)
|
||||||
residual = predicted - outputs
|
residual = predicted - outputs
|
||||||
err = mloss.mean(residual)
|
err = ritual.measure(residual)
|
||||||
log(name + " loss", "{:11.7f}".format(err))
|
log(name + " loss", "{:11.7f}".format(err))
|
||||||
if comparison:
|
if comparison:
|
||||||
log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
|
log("improvement", "{:+7.2f}%".format((comparison / err - 1) * 100))
|
||||||
return err
|
return err
|
||||||
|
|
||||||
train_err = print_error("train",
|
train_err = print_error("train",
|
||||||
inputs / x_scale, outputs / y_scale,
|
inputs / x_scale, outputs / y_scale,
|
||||||
config.train_compare)
|
config.train_compare)
|
||||||
valid_err = print_error("valid",
|
valid_err = print_error("valid",
|
||||||
valid_inputs / x_scale, valid_outputs / y_scale,
|
valid_inputs / x_scale, valid_outputs / y_scale,
|
||||||
config.valid_compare)
|
config.valid_compare)
|
||||||
train_losses.append(train_err)
|
train_losses.append(train_err)
|
||||||
valid_losses.append(valid_err)
|
valid_losses.append(valid_err)
|
||||||
|
|
||||||
|
@ -671,38 +725,25 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
if i > 0:
|
if i > 0:
|
||||||
log("restarting", i)
|
log("restarting", i)
|
||||||
LR *= LRprod**config.LR_restart_advance
|
ritual.restart(optim=config.restart_optim)
|
||||||
if config.restart_optim:
|
|
||||||
optim.reset()
|
|
||||||
|
|
||||||
assert inputs.shape[0] % config.batch_size == 0, \
|
assert inputs.shape[0] % config.batch_size == 0, \
|
||||||
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
||||||
batch_count = inputs.shape[0] // config.batch_size
|
|
||||||
for e in range(config.epochs):
|
for e in range(config.epochs):
|
||||||
indices = np.arange(inputs.shape[0])
|
indices = np.arange(inputs.shape[0])
|
||||||
np.random.shuffle(indices)
|
np.random.shuffle(indices)
|
||||||
shuffled_inputs = inputs[indices] / x_scale
|
shuffled_inputs = inputs[indices] / x_scale
|
||||||
shuffled_outputs = outputs[indices] / y_scale
|
shuffled_outputs = outputs[indices] / y_scale
|
||||||
|
|
||||||
optim.alpha = LR * LRprod**e
|
ritual.prepare(e)
|
||||||
#log("learning rate", "{:10.8f}".format(optim.alpha))
|
#log("learning rate", "{:10.8f}".format(ritual.optim.alpha))
|
||||||
|
|
||||||
cumsum_loss = 0
|
avg_loss, losses = ritual.train_batched(model,
|
||||||
for b in range(batch_count):
|
shuffled_inputs, shuffled_outputs,
|
||||||
bi = b * config.batch_size
|
config.batch_size,
|
||||||
batch_inputs = shuffled_inputs[ bi:bi+config.batch_size]
|
return_losses=True)
|
||||||
batch_outputs = shuffled_outputs[bi:bi+config.batch_size]
|
log("average loss", "{:11.7f}".format(avg_loss))
|
||||||
|
batch_losses += losses
|
||||||
predicted = model.forward(batch_inputs)
|
|
||||||
residual = predicted - batch_outputs
|
|
||||||
dW = model.backward(loss.dmean(residual))
|
|
||||||
optim.update(dW, model.W)
|
|
||||||
|
|
||||||
# note: we don't actually need this for training, only monitoring.
|
|
||||||
batch_loss = mloss.mean(residual)
|
|
||||||
cumsum_loss += batch_loss
|
|
||||||
batch_losses.append(batch_loss)
|
|
||||||
#log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
|
|
||||||
|
|
||||||
measure_error()
|
measure_error()
|
||||||
|
|
||||||
|
@ -723,3 +764,9 @@ if __name__ == '__main__':
|
||||||
batch_losses=nfa(batch_losses),
|
batch_losses=nfa(batch_losses),
|
||||||
train_losses=nfa(train_losses),
|
train_losses=nfa(train_losses),
|
||||||
valid_losses=nfa(valid_losses))
|
valid_losses=nfa(valid_losses))
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
sys.exit(run(sys.argv[0], sys.argv[1:]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user