This commit is contained in:
Connor Olding 2017-02-12 17:41:47 -08:00
parent 77ba4fa11b
commit 6f6a34a6bc

View file

@ -11,6 +11,9 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
def log(left, right):
lament("{:>20}: {}".format(left, right))
class Dummy:
pass
# Loss functions {{{1
class SquaredHalved(Loss):
@ -304,7 +307,7 @@ def toy_data(train_samples, valid_samples, problem=2):
return (inputs, outputs), (valid_inputs, valid_outputs)
def model_from_config(config, input_features, output_features):
def model_from_config(config, input_features, output_features, callbacks):
# Our Test Model
init = inits[config.init]
@ -344,7 +347,7 @@ def model_from_config(config, input_features, output_features):
raise Exception('unknown optimizer', config.optim)
def rscb(restart):
measure_error() # declared later...
callbacks.restart()
log("restarting", restart)
if config.restart_optim:
optim.reset()
@ -472,8 +475,10 @@ def run(program, args=[]):
input_features = inputs.shape[-1]
output_features = outputs.shape[-1]
callbacks = Dummy()
model, learner, ritual, (loss, mloss) = \
model_from_config(config, input_features, output_features)
model_from_config(config, input_features, output_features, callbacks)
# Model Information
@ -515,6 +520,8 @@ def run(program, args=[]):
train_losses.append(train_err)
valid_losses.append(valid_err)
callbacks.restart = measure_error
measure_error()
assert inputs.shape[0] % config.batch_size == 0, \