diff --git a/optim_nn.py b/optim_nn.py index a57fe95..9e601aa 100644 --- a/optim_nn.py +++ b/optim_nn.py @@ -11,6 +11,9 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs) def log(left, right): lament("{:>20}: {}".format(left, right)) +class Dummy: + pass + # Loss functions {{{1 class SquaredHalved(Loss): @@ -304,7 +307,7 @@ def toy_data(train_samples, valid_samples, problem=2): return (inputs, outputs), (valid_inputs, valid_outputs) -def model_from_config(config, input_features, output_features): +def model_from_config(config, input_features, output_features, callbacks): # Our Test Model init = inits[config.init] @@ -344,7 +347,7 @@ def model_from_config(config, input_features, output_features): raise Exception('unknown optimizer', config.optim) def rscb(restart): - measure_error() # declared later... + callbacks.restart() log("restarting", restart) if config.restart_optim: optim.reset() @@ -472,8 +475,10 @@ def run(program, args=[]): input_features = inputs.shape[-1] output_features = outputs.shape[-1] + callbacks = Dummy() + model, learner, ritual, (loss, mloss) = \ - model_from_config(config, input_features, output_features) + model_from_config(config, input_features, output_features, callbacks) # Model Information @@ -515,6 +520,8 @@ def run(program, args=[]): train_losses.append(train_err) valid_losses.append(valid_err) + callbacks.restart = measure_error + measure_error() assert inputs.shape[0] % config.batch_size == 0, \