.
This commit is contained in:
parent
77ba4fa11b
commit
6f6a34a6bc
1 changed files with 10 additions and 3 deletions
13
optim_nn.py
13
optim_nn.py
|
@ -11,6 +11,9 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
|
|||
def log(left, right):
|
||||
lament("{:>20}: {}".format(left, right))
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
# Loss functions {{{1
|
||||
|
||||
class SquaredHalved(Loss):
|
||||
|
@ -304,7 +307,7 @@ def toy_data(train_samples, valid_samples, problem=2):
|
|||
|
||||
return (inputs, outputs), (valid_inputs, valid_outputs)
|
||||
|
||||
def model_from_config(config, input_features, output_features):
|
||||
def model_from_config(config, input_features, output_features, callbacks):
|
||||
# Our Test Model
|
||||
|
||||
init = inits[config.init]
|
||||
|
@ -344,7 +347,7 @@ def model_from_config(config, input_features, output_features):
|
|||
raise Exception('unknown optimizer', config.optim)
|
||||
|
||||
def rscb(restart):
|
||||
measure_error() # declared later...
|
||||
callbacks.restart()
|
||||
log("restarting", restart)
|
||||
if config.restart_optim:
|
||||
optim.reset()
|
||||
|
@ -472,8 +475,10 @@ def run(program, args=[]):
|
|||
input_features = inputs.shape[-1]
|
||||
output_features = outputs.shape[-1]
|
||||
|
||||
callbacks = Dummy()
|
||||
|
||||
model, learner, ritual, (loss, mloss) = \
|
||||
model_from_config(config, input_features, output_features)
|
||||
model_from_config(config, input_features, output_features, callbacks)
|
||||
|
||||
# Model Information
|
||||
|
||||
|
@ -515,6 +520,8 @@ def run(program, args=[]):
|
|||
train_losses.append(train_err)
|
||||
valid_losses.append(valid_err)
|
||||
|
||||
callbacks.restart = measure_error
|
||||
|
||||
measure_error()
|
||||
|
||||
assert inputs.shape[0] % config.batch_size == 0, \
|
||||
|
|
Loading…
Reference in a new issue