move losses into Model and refactor methods
parent 910facf98d
commit e22316a4c9
3 changed files with 60 additions and 40 deletions
onn.py (34 changed lines)
@@ -696,8 +696,8 @@ class StochMRitual(Ritual):
     # this probably doesn't make sense for regression problems,
     # let alone small models, but here it is anyway!
 
-    def __init__(self, learner=None, loss=None, mloss=None, gamma=0.5):
-        super().__init__(learner, loss, mloss)
+    def __init__(self, learner=None, gamma=0.5):
+        super().__init__(learner)
         self.gamma = _f(gamma)
 
     def prepare(self, model):
@@ -726,12 +726,12 @@
                 # np.clip(layer.W, -1, 1, out=layer.W)
 
 class NoisyRitual(Ritual):
-    def __init__(self, learner=None, loss=None, mloss=None,
+    def __init__(self, learner=None,
                  input_noise=0, output_noise=0, gradient_noise=0):
         self.input_noise = _f(input_noise)
         self.output_noise = _f(output_noise)
         self.gradient_noise = _f(gradient_noise)
-        super().__init__(learner, loss, mloss)
+        super().__init__(learner)
 
     def learn(self, inputs, outputs):
         # this is pretty crude
@@ -1077,13 +1077,13 @@ def lookup_loss(maybe_name):
         return SomethingElse()
     raise Exception('unknown objective', maybe_name)
 
-def ritual_from_config(config, learner, loss, mloss):
+def ritual_from_config(config, learner):
     if config.ritual == 'default':
-        ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
+        ritual = Ritual(learner=learner)
     elif config.ritual == 'stochm':
-        ritual = StochMRitual(learner=learner, loss=loss, mloss=mloss)
+        ritual = StochMRitual(learner=learner)
     elif config.ritual == 'noisy':
-        ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss,
+        ritual = NoisyRitual(learner=learner,
                              input_noise=1e-1, output_noise=1e-2,
                              gradient_noise=2e-7)
     else:
@@ -1105,7 +1105,10 @@ def model_from_config(config, input_features, output_features, callbacks=None):
     if y.output_shape[0] != output_features:
         y = y.feed(Dense(output_features, init))
 
-    model = Model(x, y, unsafe=config.unsafe)
+    loss = lookup_loss(config.loss)
+    mloss = lookup_loss(config.mloss) if config.mloss else loss
+
+    model = Model(x, y, loss=loss, mloss=mloss, unsafe=config.unsafe)
 
     if config.fn_load is not None:
         log('loading weights', config.fn_load)
@@ -1122,10 +1125,7 @@ def model_from_config(config, input_features, output_features, callbacks=None):
 
     learner = learner_from_config(config, optim, rscb)
 
-    loss = lookup_loss(config.loss)
-    mloss = lookup_loss(config.mloss) if config.mloss else loss
-
-    ritual = ritual_from_config(config, learner, loss, mloss)
+    ritual = ritual_from_config(config, learner)
 
     return model, learner, ritual
 
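With the objectives attached to the model, the config path in onn.py now resolves them once and hands them straight to Model; the ritual only needs a learner. Roughly, the new flow through model_from_config reads like this sketch (config, x, y, optim, and rscb as in the surrounding code):

    loss = lookup_loss(config.loss)
    mloss = lookup_loss(config.mloss) if config.mloss else loss
    model = Model(x, y, loss=loss, mloss=mloss, unsafe=config.unsafe)

    learner = learner_from_config(config, optim, rscb)
    ritual = ritual_from_config(config, learner)  # losses are no longer passed here
    return model, learner, ritual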
@@ -1234,8 +1234,8 @@ def run(program, args=None):
 
     def measure_error():
         def print_error(name, inputs, outputs, comparison=None):
-            predicted = model.forward(inputs)
-            err = ritual.mloss.forward(predicted, outputs)
+            predicted = model.evaluate(inputs)
+            err = model.mloss.forward(predicted, outputs)
             if config.log10_loss:
                 print(name, "{:12.6e}".format(err))
                 if comparison:
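Because the metric loss now lives on the model, error reporting no longer reaches into the ritual. A minimal sketch of the same measurement done outside run(), assuming hypothetical arrays test_inputs and test_outputs and a model built through model_from_config:

    predicted = model.evaluate(test_inputs)              # deterministic by default
    err = model.mloss.forward(predicted, test_outputs)   # metric loss stored on the Model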
@@ -1272,7 +1272,7 @@ def run(program, args=None):
 
     # use plain SGD in warmup to prevent (or possibly cause?) numeric issues
     temp_optim = learner.optim
-    temp_loss = ritual.loss
+    temp_loss = model.loss
     learner.optim = Optimizer(lr=0.001)
     ritual.loss = Absolute() # less likely to blow up; more general
 
@@ -1292,7 +1292,7 @@ def run(program, args=None):
     ritual.reset()
 
     learner.optim = temp_optim
-    ritual.loss = temp_loss
+    model.loss = temp_loss
 
     if training:
         measure_error()
onn_core.py (51 changed lines)
@@ -884,13 +884,17 @@ class Dense(Layer):
 # Models {{{1
 
 class Model:
-    def __init__(self, nodes_in, nodes_out, unsafe=False):
+    def __init__(self, nodes_in, nodes_out, loss=None, mloss=None, unsafe=False):
+        self.loss = loss if loss is not None else SquaredHalved()
+        self.mloss = mloss if mloss is not None else loss
+
         nodes_in = [nodes_in] if isinstance(nodes_in, Layer) else nodes_in
         nodes_out = [nodes_out] if isinstance(nodes_out, Layer) else nodes_out
         assert type(nodes_in) == list, type(nodes_in)
         assert type(nodes_out) == list, type(nodes_out)
         self.nodes_in = nodes_in
         self.nodes_out = nodes_out
 
         self.nodes = traverse_all(self.nodes_in, self.nodes_out)
         self.make_weights()
         for node in self.nodes:
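For reference, the new constructor defaults to SquaredHalved() when no training loss is given, and mloss falls back to the loss argument when it is omitted. A rough sketch of the two ways to build a model under the new signature, with x and y being the input and output Layer nodes:

    model = Model(x, y)                                   # loss defaults to SquaredHalved()
    model = Model(x, y, loss=CategoricalCrossentropy(),
                  mloss=Accuracy(), unsafe=True)          # explicit objective and reported metric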
@@ -928,22 +932,40 @@ class Model:
             assert inner_offset >= node.size, "Layer {} allocated less weights than it said it would".format(node)
             offset += node.size
 
-    def forward(self, X, deterministic=False):
+    def evaluate(self, inputs, deterministic=True):
         values = dict()
         input_node = self.nodes[0]
         output_node = self.nodes[-1]
-        values[input_node] = input_node._propagate(np.expand_dims(X, 0), deterministic)
+        values[input_node] = input_node._propagate(np.expand_dims(inputs, 0), deterministic)
         for node in self.nodes[1:]:
             values[node] = node.propagate(values, deterministic)
         return values[output_node]
 
-    def backward(self, error):
+    def apply(self, error): # TODO: better name?
         values = dict()
         input_node = self.nodes[0]
         output_node = self.nodes[-1]
         values[output_node] = output_node._backpropagate(np.expand_dims(error, 0))
         for node in reversed(self.nodes[:-1]):
             values[node] = node.backpropagate(values)
-        return self.dW
+        return values[input_node]
+
+    def forward(self, inputs, outputs, measure=False, deterministic=False):
+        predicted = self.evaluate(inputs, deterministic=deterministic)
+        if measure:
+            error = self.mloss.forward(predicted, outputs)
+        else:
+            error = self.loss.forward(predicted, outputs)
+        return error, predicted
+
+    def backward(self, predicted, outputs, measure=False):
+        if measure:
+            error = self.mloss.backward(predicted, outputs)
+        else:
+            error = self.loss.backward(predicted, outputs)
+        # input_delta is rarely useful; it's just to match the forward pass.
+        input_delta = self.apply(error)
+        return self.dW, input_delta
 
     def clear_grad(self):
         for node in self.nodes:
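The old forward/backward pair is split here: evaluate() and apply() do the raw graph propagation, while the new forward() and backward() wrap them with the stored loss (or mloss when measure=True). A minimal sketch of one hand-rolled step with the new API, assuming hypothetical arrays inputs and outputs and a model built as above; applying the gradient is still the Learner/Optimizer's job and is left out:

    error, predicted = model.forward(inputs, outputs)        # loss.forward on the predictions
    grads, input_delta = model.backward(predicted, outputs)  # loss.backward, then backprop; grads is model.dW
    # hand grads to whatever optimizer is in use here.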
@@ -1028,11 +1050,8 @@ class Model:
 # Rituals {{{1
 
 class Ritual: # i'm just making up names at this point.
-    def __init__(self, learner=None, loss=None, mloss=None):
-        # TODO: store loss and mloss in Model instead of here.
+    def __init__(self, learner=None):
         self.learner = learner if learner is not None else Learner(Optimizer())
-        self.loss = loss if loss is not None else Squared()
-        self.mloss = mloss if mloss is not None else loss
         self.model = None
 
     def reset(self):
@@ -1041,10 +1060,10 @@ class Ritual: # i'm just making up names at this point.
         self.bn = 0
 
     def learn(self, inputs, outputs):
-        predicted = self.model.forward(inputs)
-        self.model.backward(self.loss.backward(predicted, outputs))
+        error, predicted = self.model.forward(inputs, outputs)
+        self.model.backward(predicted, outputs)
         self.model.regulate()
-        return predicted
+        return error, predicted
 
     def update(self):
         optim = self.learner.optim
@@ -1062,14 +1081,14 @@ class Ritual: # i'm just making up names at this point.
             self.learner.batch(b / batch_count)
 
             if test_only:
-                predicted = self.model.forward(batch_inputs, deterministic=True)
+                predicted = self.model.evaluate(batch_inputs, deterministic=True)
             else:
-                predicted = self.learn(batch_inputs, batch_outputs)
+                error, predicted = self.learn(batch_inputs, batch_outputs)
                 self.model.regulate_forward()
                 self.update()
 
             if loss_logging:
-                batch_loss = self.loss.forward(predicted, batch_outputs)
+                batch_loss = self.model.loss.forward(predicted, batch_outputs)
                 if np.isnan(batch_loss):
                     raise Exception("nan")
                 self.losses.append(batch_loss)
@@ -1077,7 +1096,7 @@ class Ritual: # i'm just making up names at this point.
 
             if mloss_logging:
                 # NOTE: this can use the non-deterministic predictions. fixme?
-                batch_mloss = self.mloss.forward(predicted, batch_outputs)
+                batch_mloss = self.model.mloss.forward(predicted, batch_outputs)
                 if np.isnan(batch_mloss):
                     raise Exception("nan")
                 self.mlosses.append(batch_mloss)
onn_mnist.py (15 changed lines)
@@ -145,7 +145,13 @@ y = y.feed(Dense(mnist_classes, init=init_glorot_uniform,
                  reg_w=final_reg, reg_b=final_reg))
 y = y.feed(output_activation())
 
-model = Model(x, y, unsafe=True)
+if output_activation in (Softmax, Sigmoid):
+    loss = CategoricalCrossentropy()
+else:
+    loss = SquaredHalved()
+mloss = Accuracy()
+
+model = Model(x, y, loss=loss, mloss=mloss, unsafe=True)
 
 def rscb(restart):
     log("restarting", restart)
@@ -176,12 +182,7 @@ else:
     lament('WARNING: no learning rate schedule selected.')
 learner = Learner(optim, epochs=epochs)
 
-loss = CategoricalCrossentropy() if output_activation == Softmax else SquaredHalved()
-mloss = Accuracy()
-
-ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
-#ritual = NoisyRitual(learner=learner, loss=loss, mloss=mloss,
-#                     input_noise=1e-1, output_noise=3.2e-2, gradient_noise=1e-1)
+ritual = Ritual(learner=learner)
 
 model.print_graph()
 log('parameters', model.param_count)
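Taken together, the MNIST script now wires things up roughly like this sketch, with loss, mloss, x, y, and learner chosen earlier in the script as shown above:

    model = Model(x, y, loss=loss, mloss=mloss, unsafe=True)
    ritual = Ritual(learner=learner)   # rituals only drive training; objectives live on the model
    model.print_graph()
    log('parameters', model.param_count)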