.
This commit is contained in:
parent
77ba4fa11b
commit
6f6a34a6bc
1 changed files with 10 additions and 3 deletions
13
optim_nn.py
13
optim_nn.py
|
@ -11,6 +11,9 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
|
||||||
def log(left, right):
|
def log(left, right):
|
||||||
lament("{:>20}: {}".format(left, right))
|
lament("{:>20}: {}".format(left, right))
|
||||||
|
|
||||||
|
class Dummy:
|
||||||
|
pass
|
||||||
|
|
||||||
# Loss functions {{{1
|
# Loss functions {{{1
|
||||||
|
|
||||||
class SquaredHalved(Loss):
|
class SquaredHalved(Loss):
|
||||||
|
@ -304,7 +307,7 @@ def toy_data(train_samples, valid_samples, problem=2):
|
||||||
|
|
||||||
return (inputs, outputs), (valid_inputs, valid_outputs)
|
return (inputs, outputs), (valid_inputs, valid_outputs)
|
||||||
|
|
||||||
def model_from_config(config, input_features, output_features):
|
def model_from_config(config, input_features, output_features, callbacks):
|
||||||
# Our Test Model
|
# Our Test Model
|
||||||
|
|
||||||
init = inits[config.init]
|
init = inits[config.init]
|
||||||
|
@ -344,7 +347,7 @@ def model_from_config(config, input_features, output_features):
|
||||||
raise Exception('unknown optimizer', config.optim)
|
raise Exception('unknown optimizer', config.optim)
|
||||||
|
|
||||||
def rscb(restart):
|
def rscb(restart):
|
||||||
measure_error() # declared later...
|
callbacks.restart()
|
||||||
log("restarting", restart)
|
log("restarting", restart)
|
||||||
if config.restart_optim:
|
if config.restart_optim:
|
||||||
optim.reset()
|
optim.reset()
|
||||||
|
@ -472,8 +475,10 @@ def run(program, args=[]):
|
||||||
input_features = inputs.shape[-1]
|
input_features = inputs.shape[-1]
|
||||||
output_features = outputs.shape[-1]
|
output_features = outputs.shape[-1]
|
||||||
|
|
||||||
|
callbacks = Dummy()
|
||||||
|
|
||||||
model, learner, ritual, (loss, mloss) = \
|
model, learner, ritual, (loss, mloss) = \
|
||||||
model_from_config(config, input_features, output_features)
|
model_from_config(config, input_features, output_features, callbacks)
|
||||||
|
|
||||||
# Model Information
|
# Model Information
|
||||||
|
|
||||||
|
@ -515,6 +520,8 @@ def run(program, args=[]):
|
||||||
train_losses.append(train_err)
|
train_losses.append(train_err)
|
||||||
valid_losses.append(valid_err)
|
valid_losses.append(valid_err)
|
||||||
|
|
||||||
|
callbacks.restart = measure_error
|
||||||
|
|
||||||
measure_error()
|
measure_error()
|
||||||
|
|
||||||
assert inputs.shape[0] % config.batch_size == 0, \
|
assert inputs.shape[0] % config.batch_size == 0, \
|
||||||
|
|
Loading…
Reference in a new issue