This commit is contained in:
Connor Olding 2017-02-17 18:37:04 -08:00
parent fe577eb7f4
commit 3106495704
2 changed files with 64 additions and 57 deletions

View file

@ -408,32 +408,7 @@ def toy_data(train_samples, valid_samples, problem=2):
# Model Creation {{{1 # Model Creation {{{1
def model_from_config(config, input_features, output_features, callbacks): def optim_from_config(config):
# Our Test Model
init = inits[config.init]
activation = activations[config.activation]
x = Input(shape=(input_features,))
y = x
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_features:
y = y.feed(Dense(output_features, init))
model = Model(x, y, unsafe=config.unsafe)
#
if config.fn_load is not None:
log('loading weights', config.fn_load)
model.load_weights(config.fn_load)
#
if config.optim == 'adam': if config.optim == 'adam':
assert not config.nesterov, "unimplemented" assert not config.nesterov, "unimplemented"
d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5 d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5
@ -453,14 +428,9 @@ def model_from_config(config, input_features, output_features, callbacks):
else: else:
raise Exception('unknown optimizer', config.optim) raise Exception('unknown optimizer', config.optim)
def rscb(restart): return optim
callbacks.restart()
log("restarting", restart)
if config.restart_optim:
optim.reset()
#
def learner_from_config(config, optim, rscb):
if config.learner == 'sgdr': if config.learner == 'sgdr':
expando = config.expando if 'expando' in config else None expando = config.expando if 'expando' in config else None
learner = SGDR(optim, epochs=config.epochs, rate=config.learn, learner = SGDR(optim, epochs=config.epochs, rate=config.learn,
@ -484,9 +454,9 @@ def model_from_config(config, input_features, output_features, callbacks):
else: else:
raise Exception('unknown learner', config.learner) raise Exception('unknown learner', config.learner)
# return learner
def lookup_loss(maybe_name): def lookup_loss(maybe_name):
if isinstance(maybe_name, Loss): if isinstance(maybe_name, Loss):
return maybe_name return maybe_name
elif maybe_name == 'mse': elif maybe_name == 'mse':
@ -499,9 +469,7 @@ def model_from_config(config, input_features, output_features, callbacks):
return SomethingElse() return SomethingElse()
raise Exception('unknown objective', maybe_name) raise Exception('unknown objective', maybe_name)
loss = lookup_loss(config.loss) def ritual_from_config(config, learner, loss, mloss):
mloss = lookup_loss(config.mloss) if config.mloss else loss
if config.ritual == 'default': if config.ritual == 'default':
ritual = Ritual(learner=learner, loss=loss, mloss=mloss) ritual = Ritual(learner=learner, loss=loss, mloss=mloss)
elif config.ritual == 'stochm': elif config.ritual == 'stochm':
@ -513,7 +481,44 @@ def model_from_config(config, input_features, output_features, callbacks):
else: else:
raise Exception('unknown ritual', config.ritual) raise Exception('unknown ritual', config.ritual)
# return ritual
def model_from_config(config, input_features, output_features, callbacks):
# Our Test Model
init = inits[config.init]
activation = activations[config.activation]
x = Input(shape=(input_features,))
y = x
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_features:
y = y.feed(Dense(output_features, init))
model = Model(x, y, unsafe=config.unsafe)
if config.fn_load is not None:
log('loading weights', config.fn_load)
model.load_weights(config.fn_load)
optim = optim_from_config(config)
def rscb(restart):
callbacks.restart()
log("restarting", restart)
if config.restart_optim:
optim.reset()
learner = learner_from_config(config, optim, rscb)
loss = lookup_loss(config.loss)
mloss = lookup_loss(config.mloss) if config.mloss else loss
ritual = ritual_from_config(config, learner, loss, mloss)
return model, learner, ritual return model, learner, ritual
@ -599,13 +604,13 @@ def run(program, args=None):
model, learner, ritual = \ model, learner, ritual = \
model_from_config(config, input_features, output_features, callbacks) model_from_config(config, input_features, output_features, callbacks)
# Model Information # Model Information {{{2
for node in model.ordered_nodes: for node in model.ordered_nodes:
children = [str(n) for n in node.children] children = [str(n) for n in node.children]
if children: if children:
sep = '->' sep = '->'
print(str(node)+sep+('\n'+str(node)+sep).join(children)) print(str(node) + sep + ('\n' + str(node) + sep).join(children))
log('parameters', model.param_count) log('parameters', model.param_count)
# Training {{{2 # Training {{{2
@ -636,17 +641,17 @@ def run(program, args=None):
training = config.epochs > 0 and config.restarts >= 0 training = config.epochs > 0 and config.restarts >= 0
if training:
measure_error()
ritual.prepare(model) ritual.prepare(model)
if training and config.warmup: if training and config.warmup:
log("warming", "up") log("warming", "up")
ritual.train_batched( ritual.train_batched(
np.random.normal(0, 1, size=inputs.shape), np.random.normal(size=inputs.shape),
np.random.normal(0, 1, size=outputs.shape), np.random.normal(size=outputs.shape),
config.batch_size) config.batch_size)
ritual.reset()
if training:
measure_error() measure_error()
while training and learner.next(): while training and learner.next():

View file

@ -563,6 +563,8 @@ class Ritual: # i'm just making up names at this point
def reset(self): def reset(self):
self.learner.reset(optim=True) self.learner.reset(optim=True)
self.en = 0
self.bn = 0
def measure(self, p, y): def measure(self, p, y):
return self.mloss.F(p, y) return self.mloss.F(p, y)