.
This commit is contained in:
parent
fe577eb7f4
commit
3106495704
2 changed files with 64 additions and 57 deletions
119
optim_nn.py
119
optim_nn.py
|
@ -408,32 +408,7 @@ def toy_data(train_samples, valid_samples, problem=2):
|
|||
|
||||
# Model Creation {{{1
|
||||
|
||||
def model_from_config(config, input_features, output_features, callbacks):
|
||||
# Our Test Model
|
||||
|
||||
init = inits[config.init]
|
||||
activation = activations[config.activation]
|
||||
|
||||
x = Input(shape=(input_features,))
|
||||
y = x
|
||||
y = multiresnet(y,
|
||||
config.res_width, config.res_depth,
|
||||
config.res_block, config.res_multi,
|
||||
activation=activation, init=init,
|
||||
style=config.parallel_style)
|
||||
if y.output_shape[0] != output_features:
|
||||
y = y.feed(Dense(output_features, init))
|
||||
|
||||
model = Model(x, y, unsafe=config.unsafe)
|
||||
|
||||
#
|
||||
|
||||
if config.fn_load is not None:
|
||||
log('loading weights', config.fn_load)
|
||||
model.load_weights(config.fn_load)
|
||||
|
||||
#
|
||||
|
||||
def optim_from_config(config):
|
||||
if config.optim == 'adam':
|
||||
assert not config.nesterov, "unimplemented"
|
||||
d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
|
||||
|
@ -453,14 +428,9 @@ def model_from_config(config, input_features, output_features, callbacks):
|
|||
else:
|
||||
raise Exception('unknown optimizer', config.optim)
|
||||
|
||||
def rscb(restart):
|
||||
callbacks.restart()
|
||||
log("restarting", restart)
|
||||
if config.restart_optim:
|
||||
optim.reset()
|
||||
|
||||
#
|
||||
return optim
|
||||
|
||||
def learner_from_config(config, optim, rscb):
|
||||
if config.learner == 'sgdr':
|
||||
expando = config.expando if 'expando' in config else None
|
||||
learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
|
||||
|
@ -484,24 +454,22 @@ def model_from_config(config, input_features, output_features, callbacks):
|
|||
else:
|
||||
raise Exception('unknown learner', config.learner)
|
||||
|
||||
#
|
||||
return learner
|
||||
|
||||
def lookup_loss(maybe_name):
|
||||
if isinstance(maybe_name, Loss):
|
||||
return maybe_name
|
||||
elif maybe_name == 'mse':
|
||||
return Squared()
|
||||
elif maybe_name == 'mshe': # mushy
|
||||
return SquaredHalved()
|
||||
elif maybe_name == 'mae':
|
||||
return Absolute()
|
||||
elif maybe_name == 'msee':
|
||||
return SomethingElse()
|
||||
raise Exception('unknown objective', maybe_name)
|
||||
|
||||
loss = lookup_loss(config.loss)
|
||||
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
||||
def lookup_loss(maybe_name):
|
||||
if isinstance(maybe_name, Loss):
|
||||
return maybe_name
|
||||
elif maybe_name == 'mse':
|
||||
return Squared()
|
||||
elif maybe_name == 'mshe': # mushy
|
||||
return SquaredHalved()
|
||||
elif maybe_name == 'mae':
|
||||
return Absolute()
|
||||
elif maybe_name == 'msee':
|
||||
return SomethingElse()
|
||||
raise Exception('unknown objective', maybe_name)
|
||||
|
||||
def ritual_from_config(config, learner, loss, mloss):
|
||||
if config.ritual == 'default':
|
||||
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
|
||||
elif config.ritual == 'stochm':
|
||||
|
@ -513,7 +481,44 @@ def model_from_config(config, input_features, output_features, callbacks):
|
|||
else:
|
||||
raise Exception('unknown ritual', config.ritual)
|
||||
|
||||
#
|
||||
return ritual
|
||||
|
||||
def model_from_config(config, input_features, output_features, callbacks):
|
||||
# Our Test Model
|
||||
|
||||
init = inits[config.init]
|
||||
activation = activations[config.activation]
|
||||
|
||||
x = Input(shape=(input_features,))
|
||||
y = x
|
||||
y = multiresnet(y,
|
||||
config.res_width, config.res_depth,
|
||||
config.res_block, config.res_multi,
|
||||
activation=activation, init=init,
|
||||
style=config.parallel_style)
|
||||
if y.output_shape[0] != output_features:
|
||||
y = y.feed(Dense(output_features, init))
|
||||
|
||||
model = Model(x, y, unsafe=config.unsafe)
|
||||
|
||||
if config.fn_load is not None:
|
||||
log('loading weights', config.fn_load)
|
||||
model.load_weights(config.fn_load)
|
||||
|
||||
optim = optim_from_config(config)
|
||||
|
||||
def rscb(restart):
|
||||
callbacks.restart()
|
||||
log("restarting", restart)
|
||||
if config.restart_optim:
|
||||
optim.reset()
|
||||
|
||||
learner = learner_from_config(config, optim, rscb)
|
||||
|
||||
loss = lookup_loss(config.loss)
|
||||
mloss = lookup_loss(config.mloss) if config.mloss else loss
|
||||
|
||||
ritual = ritual_from_config(config, learner, loss, mloss)
|
||||
|
||||
return model, learner, ritual
|
||||
|
||||
|
@ -599,13 +604,13 @@ def run(program, args=None):
|
|||
model, learner, ritual = \
|
||||
model_from_config(config, input_features, output_features, callbacks)
|
||||
|
||||
# Model Information
|
||||
# Model Information {{{2
|
||||
|
||||
for node in model.ordered_nodes:
|
||||
children = [str(n) for n in node.children]
|
||||
if children:
|
||||
sep = '->'
|
||||
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
||||
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
|
||||
log('parameters', model.param_count)
|
||||
|
||||
# Training {{{2
|
||||
|
@ -636,17 +641,17 @@ def run(program, args=None):
|
|||
|
||||
training = config.epochs > 0 and config.restarts >= 0
|
||||
|
||||
if training:
|
||||
measure_error()
|
||||
|
||||
ritual.prepare(model)
|
||||
|
||||
if training and config.warmup:
|
||||
log("warming", "up")
|
||||
ritual.train_batched(
|
||||
np.random.normal(0, 1, size=inputs.shape),
|
||||
np.random.normal(0, 1, size=outputs.shape),
|
||||
np.random.normal(size=inputs.shape),
|
||||
np.random.normal(size=outputs.shape),
|
||||
config.batch_size)
|
||||
ritual.reset()
|
||||
|
||||
if training:
|
||||
measure_error()
|
||||
|
||||
while training and learner.next():
|
||||
|
|
|
@ -563,6 +563,8 @@ class Ritual: # i'm just making up names at this point
|
|||
|
||||
def reset(self):
|
||||
self.learner.reset(optim=True)
|
||||
self.en = 0
|
||||
self.bn = 0
|
||||
|
||||
def measure(self, p, y):
|
||||
return self.mloss.F(p, y)
|
||||
|
|
Loading…
Reference in a new issue