From 3106495704910bfa32d17aeaded953f7dcc1cf93 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Fri, 17 Feb 2017 18:37:04 -0800 Subject: [PATCH] . --- optim_nn.py | 119 ++++++++++++++++++++++++----------------------- optim_nn_core.py | 2 + 2 files changed, 64 insertions(+), 57 deletions(-) diff --git a/optim_nn.py b/optim_nn.py index 91c8c54..655c478 100644 --- a/optim_nn.py +++ b/optim_nn.py @@ -408,32 +408,7 @@ def toy_data(train_samples, valid_samples, problem=2): # Model Creation {{{1 -def model_from_config(config, input_features, output_features, callbacks): - # Our Test Model - - init = inits[config.init] - activation = activations[config.activation] - - x = Input(shape=(input_features,)) - y = x - y = multiresnet(y, - config.res_width, config.res_depth, - config.res_block, config.res_multi, - activation=activation, init=init, - style=config.parallel_style) - if y.output_shape[0] != output_features: - y = y.feed(Dense(output_features, init)) - - model = Model(x, y, unsafe=config.unsafe) - - # - - if config.fn_load is not None: - log('loading weights', config.fn_load) - model.load_weights(config.fn_load) - - # - +def optim_from_config(config): if config.optim == 'adam': assert not config.nesterov, "unimplemented" d1 = config.optim_decay1 if 'optim_decay1' in config else 9.5 @@ -453,14 +428,9 @@ def model_from_config(config, input_features, output_features, callbacks): else: raise Exception('unknown optimizer', config.optim) - def rscb(restart): - callbacks.restart() - log("restarting", restart) - if config.restart_optim: - optim.reset() - - # + return optim +def learner_from_config(config, optim, rscb): if config.learner == 'sgdr': expando = config.expando if 'expando' in config else None learner = SGDR(optim, epochs=config.epochs, rate=config.learn, @@ -484,24 +454,22 @@ def model_from_config(config, input_features, output_features, callbacks): else: raise Exception('unknown learner', config.learner) - # + return learner - def lookup_loss(maybe_name): - if isinstance(maybe_name, Loss): - return maybe_name - elif maybe_name == 'mse': - return Squared() - elif maybe_name == 'mshe': # mushy - return SquaredHalved() - elif maybe_name == 'mae': - return Absolute() - elif maybe_name == 'msee': - return SomethingElse() - raise Exception('unknown objective', maybe_name) - - loss = lookup_loss(config.loss) - mloss = lookup_loss(config.mloss) if config.mloss else loss +def lookup_loss(maybe_name): + if isinstance(maybe_name, Loss): + return maybe_name + elif maybe_name == 'mse': + return Squared() + elif maybe_name == 'mshe': # mushy + return SquaredHalved() + elif maybe_name == 'mae': + return Absolute() + elif maybe_name == 'msee': + return SomethingElse() + raise Exception('unknown objective', maybe_name) +def ritual_from_config(config, learner, loss, mloss): if config.ritual == 'default': ritual = Ritual(learner=learner, loss=loss, mloss=mloss) elif config.ritual == 'stochm': @@ -513,7 +481,44 @@ def model_from_config(config, input_features, output_features, callbacks): else: raise Exception('unknown ritual', config.ritual) - # + return ritual + +def model_from_config(config, input_features, output_features, callbacks): + # Our Test Model + + init = inits[config.init] + activation = activations[config.activation] + + x = Input(shape=(input_features,)) + y = x + y = multiresnet(y, + config.res_width, config.res_depth, + config.res_block, config.res_multi, + activation=activation, init=init, + style=config.parallel_style) + if y.output_shape[0] != output_features: + y = y.feed(Dense(output_features, init)) + + model = Model(x, y, unsafe=config.unsafe) + + if config.fn_load is not None: + log('loading weights', config.fn_load) + model.load_weights(config.fn_load) + + optim = optim_from_config(config) + + def rscb(restart): + callbacks.restart() + log("restarting", restart) + if config.restart_optim: + optim.reset() + + learner = learner_from_config(config, optim, rscb) + + loss = lookup_loss(config.loss) + mloss = lookup_loss(config.mloss) if config.mloss else loss + + ritual = ritual_from_config(config, learner, loss, mloss) return model, learner, ritual @@ -599,13 +604,13 @@ def run(program, args=None): model, learner, ritual = \ model_from_config(config, input_features, output_features, callbacks) - # Model Information + # Model Information {{{2 for node in model.ordered_nodes: children = [str(n) for n in node.children] if children: sep = '->' - print(str(node)+sep+('\n'+str(node)+sep).join(children)) + print(str(node) + sep + ('\n' + str(node) + sep).join(children)) log('parameters', model.param_count) # Training {{{2 @@ -636,17 +641,17 @@ def run(program, args=None): training = config.epochs > 0 and config.restarts >= 0 - if training: - measure_error() - ritual.prepare(model) if training and config.warmup: log("warming", "up") ritual.train_batched( - np.random.normal(0, 1, size=inputs.shape), - np.random.normal(0, 1, size=outputs.shape), + np.random.normal(size=inputs.shape), + np.random.normal(size=outputs.shape), config.batch_size) + ritual.reset() + + if training: measure_error() while training and learner.next(): diff --git a/optim_nn_core.py b/optim_nn_core.py index 3a42042..f314db5 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -563,6 +563,8 @@ class Ritual: # i'm just making up names at this point def reset(self): self.learner.reset(optim=True) + self.en = 0 + self.bn = 0 def measure(self, p, y): return self.mloss.F(p, y)