This commit is contained in:
Connor Olding 2017-02-17 18:37:04 -08:00
parent fe577eb7f4
commit 3106495704
2 changed files with 64 additions and 57 deletions

View file

@ -408,32 +408,7 @@ def toy_data(train_samples, valid_samples, problem=2):
# Model Creation {{{1
def model_from_config(config, input_features, output_features, callbacks):
# Our Test Model
init = inits[config.init]
activation = activations[config.activation]
x = Input(shape=(input_features,))
y = x
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_features:
y = y.feed(Dense(output_features, init))
model = Model(x, y, unsafe=config.unsafe)
#
if config.fn_load is not None:
log('loading weights', config.fn_load)
model.load_weights(config.fn_load)
#
def optim_from_config(config):
if config.optim == 'adam':
assert not config.nesterov, "unimplemented"
d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
@ -453,14 +428,9 @@ def model_from_config(config, input_features, output_features, callbacks):
else:
raise Exception('unknown optimizer', config.optim)
def rscb(restart):
callbacks.restart()
log("restarting", restart)
if config.restart_optim:
optim.reset()
#
return optim
def learner_from_config(config, optim, rscb):
if config.learner == 'sgdr':
expando = config.expando if 'expando' in config else None
learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
@ -484,24 +454,22 @@ def model_from_config(config, input_features, output_features, callbacks):
else:
raise Exception('unknown learner', config.learner)
#
return learner
def lookup_loss(maybe_name):
if isinstance(maybe_name, Loss):
return maybe_name
elif maybe_name == 'mse':
return Squared()
elif maybe_name == 'mshe': # mushy
return SquaredHalved()
elif maybe_name == 'mae':
return Absolute()
elif maybe_name == 'msee':
return SomethingElse()
raise Exception('unknown objective', maybe_name)
loss = lookup_loss(config.loss)
mloss = lookup_loss(config.mloss) if config.mloss else loss
def lookup_loss(maybe_name):
if isinstance(maybe_name, Loss):
return maybe_name
elif maybe_name == 'mse':
return Squared()
elif maybe_name == 'mshe': # mushy
return SquaredHalved()
elif maybe_name == 'mae':
return Absolute()
elif maybe_name == 'msee':
return SomethingElse()
raise Exception('unknown objective', maybe_name)
def ritual_from_config(config, learner, loss, mloss):
if config.ritual == 'default':
ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
elif config.ritual == 'stochm':
@ -513,7 +481,44 @@ def model_from_config(config, input_features, output_features, callbacks):
else:
raise Exception('unknown ritual', config.ritual)
#
return ritual
def model_from_config(config, input_features, output_features, callbacks):
# Our Test Model
init = inits[config.init]
activation = activations[config.activation]
x = Input(shape=(input_features,))
y = x
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_features:
y = y.feed(Dense(output_features, init))
model = Model(x, y, unsafe=config.unsafe)
if config.fn_load is not None:
log('loading weights', config.fn_load)
model.load_weights(config.fn_load)
optim = optim_from_config(config)
def rscb(restart):
callbacks.restart()
log("restarting", restart)
if config.restart_optim:
optim.reset()
learner = learner_from_config(config, optim, rscb)
loss = lookup_loss(config.loss)
mloss = lookup_loss(config.mloss) if config.mloss else loss
ritual = ritual_from_config(config, learner, loss, mloss)
return model, learner, ritual
@ -599,13 +604,13 @@ def run(program, args=None):
model, learner, ritual = \
model_from_config(config, input_features, output_features, callbacks)
# Model Information
# Model Information {{{2
for node in model.ordered_nodes:
children = [str(n) for n in node.children]
if children:
sep = '->'
print(str(node)+sep+('\n'+str(node)+sep).join(children))
print(str(node) + sep + ('\n' + str(node) + sep).join(children))
log('parameters', model.param_count)
# Training {{{2
@ -636,17 +641,17 @@ def run(program, args=None):
training = config.epochs > 0 and config.restarts >= 0
if training:
measure_error()
ritual.prepare(model)
if training and config.warmup:
log("warming", "up")
ritual.train_batched(
np.random.normal(0, 1, size=inputs.shape),
np.random.normal(0, 1, size=outputs.shape),
np.random.normal(size=inputs.shape),
np.random.normal(size=outputs.shape),
config.batch_size)
ritual.reset()
if training:
measure_error()
while training and learner.next():

View file

@ -563,6 +563,8 @@ class Ritual: # i'm just making up names at this point
def reset(self):
self.learner.reset(optim=True)
self.en = 0
self.bn = 0
def measure(self, p, y):
return self.mloss.F(p, y)