This commit is contained in:
parent fe577eb7f4
commit 3106495704

2 changed files with 64 additions and 57 deletions
optim_nn.py (93 changes)
@@ -408,32 +408,7 @@ def toy_data(train_samples, valid_samples, problem=2):
 
 # Model Creation {{{1
 
-def model_from_config(config, input_features, output_features, callbacks):
-    # Our Test Model
-
-    init = inits[config.init]
-    activation = activations[config.activation]
-
-    x = Input(shape=(input_features,))
-    y = x
-    y = multiresnet(y,
-                    config.res_width, config.res_depth,
-                    config.res_block, config.res_multi,
-                    activation=activation, init=init,
-                    style=config.parallel_style)
-    if y.output_shape[0] != output_features:
-        y = y.feed(Dense(output_features, init))
-
-    model = Model(x, y, unsafe=config.unsafe)
-
-    #
-
-    if config.fn_load is not None:
-        log('loading weights', config.fn_load)
-        model.load_weights(config.fn_load)
-
-    #
-
+def optim_from_config(config):
     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
         d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
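Note on the `d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5` idiom: it relies on the config object supporting both attribute access and `in` membership. A minimal sketch of such an object, as an assumption about how `config` behaves rather than the repo's actual class:

    class Config:
        # hypothetical stand-in: attribute access plus `in` membership,
        # which is what `config.optim_decay1` / `'optim_decay1' in config` imply
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def __contains__(self, key):
            return key in self.__dict__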
@@ -453,14 +428,9 @@ def model_from_config(config, input_features, output_features, callbacks):
     else:
         raise Exception('unknown optimizer', config.optim)
 
-    def rscb(restart):
-        callbacks.restart()
-        log("restarting", restart)
-        if config.restart_optim:
-            optim.reset()
-
-    #
-
+    return optim
+
+def learner_from_config(config, optim, rscb):
     if config.learner == 'sgdr':
         expando = config.expando if 'expando' in config else None
         learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
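The SGDR learner presumably follows Loshchilov & Hutter's cosine-annealing-with-warm-restarts schedule (the `expando` option looks like the usual period-growth knob). A sketch of one annealing period under that assumption; the repo's SGDR class may differ in details:

    import math

    def sgdr_rate(epoch, period, rate_max, rate_min=0.0):
        # cosine decay from rate_max to rate_min, restarting every `period` epochs
        t = epoch % period
        return rate_min + 0.5 * (rate_max - rate_min) \
            * (1 + math.cos(math.pi * t / period))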
@@ -484,7 +454,7 @@ def model_from_config(config, input_features, output_features, callbacks):
     else:
         raise Exception('unknown learner', config.learner)
 
-    #
+    return learner
 
 def lookup_loss(maybe_name):
     if isinstance(maybe_name, Loss):
@@ -499,9 +469,7 @@ def model_from_config(config, input_features, output_features, callbacks):
         return SomethingElse()
     raise Exception('unknown objective', maybe_name)
 
-    loss = lookup_loss(config.loss)
-    mloss = lookup_loss(config.mloss) if config.mloss else loss
-
+def ritual_from_config(config, learner, loss, mloss):
     if config.ritual == 'default':
         ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
     elif config.ritual == 'stochm':
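Note: `mloss` is presumably the "measured" loss used for reporting rather than training (compare `Ritual.measure` in the final hunk, which evaluates `self.mloss.F(p, y)`); it falls back to the training loss whenever `config.mloss` is unset.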
@@ -513,7 +481,44 @@ def model_from_config(config, input_features, output_features, callbacks):
     else:
         raise Exception('unknown ritual', config.ritual)
 
-    #
-
+    return ritual
+
+def model_from_config(config, input_features, output_features, callbacks):
+    # Our Test Model
+
+    init = inits[config.init]
+    activation = activations[config.activation]
+
+    x = Input(shape=(input_features,))
+    y = x
+    y = multiresnet(y,
+                    config.res_width, config.res_depth,
+                    config.res_block, config.res_multi,
+                    activation=activation, init=init,
+                    style=config.parallel_style)
+    if y.output_shape[0] != output_features:
+        y = y.feed(Dense(output_features, init))
+
+    model = Model(x, y, unsafe=config.unsafe)
+
+    if config.fn_load is not None:
+        log('loading weights', config.fn_load)
+        model.load_weights(config.fn_load)
+
+    optim = optim_from_config(config)
+
+    def rscb(restart):
+        callbacks.restart()
+        log("restarting", restart)
+        if config.restart_optim:
+            optim.reset()
+
+    learner = learner_from_config(config, optim, rscb)
+
+    loss = lookup_loss(config.loss)
+    mloss = lookup_loss(config.mloss) if config.mloss else loss
+
+    ritual = ritual_from_config(config, learner, loss, mloss)
+
     return model, learner, ritual
@@ -599,7 +604,7 @@ def run(program, args=None):
     model, learner, ritual = \
         model_from_config(config, input_features, output_features, callbacks)
 
-    # Model Information
+    # Model Information {{{2
 
     for node in model.ordered_nodes:
         children = [str(n) for n in node.children]
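(The `{{{1` / `{{{2` suffixes on these comments are Vim fold markers: they nest the file's sections into collapsible folds and carry no runtime meaning.)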
@@ -636,17 +641,17 @@ def run(program, args=None):
 
     training = config.epochs > 0 and config.restarts >= 0
 
-    if training:
-        measure_error()
-
     ritual.prepare(model)
 
     if training and config.warmup:
         log("warming", "up")
         ritual.train_batched(
-            np.random.normal(0, 1, size=inputs.shape),
-            np.random.normal(0, 1, size=outputs.shape),
+            np.random.normal(size=inputs.shape),
+            np.random.normal(size=outputs.shape),
             config.batch_size)
+        ritual.reset()
 
+    if training:
         measure_error()
 
     while training and learner.next():
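The warmup change drops the explicit `0, 1` arguments; NumPy's `normal` defaults to `loc=0.0, scale=1.0`, so the two spellings draw from the same distribution:

    import numpy as np

    a = np.random.normal(0, 1, size=(2, 3))  # explicit mean and std
    b = np.random.normal(size=(2, 3))        # identical defaults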
@@ -563,6 +563,8 @@ class Ritual: # i'm just making up names at this point
 
     def reset(self):
         self.learner.reset(optim=True)
+        self.en = 0
+        self.bn = 0
 
     def measure(self, p, y):
         return self.mloss.F(p, y)
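Assuming `en` and `bn` are epoch and batch counters (the names suggest it, though this diff does not show where they are read), a minimal self-contained sketch of the bookkeeping the change implies; the increment sites are hypothetical:

    class RitualSketch:
        # hypothetical stand-in for Ritual; only the counter logic is shown
        def __init__(self):
            self.reset()

        def reset(self):
            self.en = 0  # epochs seen since reset (assumed meaning)
            self.bn = 0  # batches seen since reset (assumed meaning)

        def train_batched(self, n_batches):
            self.en += 1          # one call per epoch (assumed)
            self.bn += n_batches  # batches processed this epoch (assumed)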