diff --git a/optim_nn.py b/optim_nn.py
index 87acc4f..2d3c478 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -343,11 +343,12 @@ class GeluApprox(Layer):
         return dY * self.sig * (1 + self.a * (1 - self.sig))
 
 class Dense(Layer):
-    def __init__(self, dim):
+    def __init__(self, dim, init=init_he_uniform):
         super().__init__()
         self.dim = ni(dim)
         self.output_shape = (dim,)
         self.size = None
+        self.weight_init = init
 
     def init(self, W, dW):
         ins, outs = self.input_shape[0], self.output_shape[0]
@@ -359,7 +360,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)
 
-        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
+        self.coeffs.flat = self.weight_init(self.nW, ins, outs)
         self.biases.flat = 0
 
     def make_shape(self, shape):
@@ -444,7 +445,9 @@ class Model:
         return self.dW
 
     def load_weights(self, fn):
-        # seemingly compatible with keras models at the moment
+        # seemingly compatible with keras' Dense layers.
+        # ignores any non-Dense layer types.
+        # TODO: assert file actually exists
         import h5py
         f = h5py.File(fn)
         weights = {}
@@ -459,11 +462,25 @@ class Model:
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
             # TODO: write a Dense method instead of assigning directly
-            denses[a].coeffs = weights[b_name+'_W']
-            denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)
+            denses[a].coeffs[:] = weights[b_name+'_W']
+            denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
 
     def save_weights(self, fn, overwrite=False):
-        raise NotImplementedError("unimplemented", self)
+        import h5py
+        f = h5py.File(fn, 'w')
+
+        denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
+        for i in range(len(denses)):
+            a, b = i, i + 1
+            b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
+            grp = f.create_group(b_name)
+            data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=nf)
+            data[:] = denses[a].coeffs
+            data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=nf)
+            data[:] = denses[a].biases
+
+        f.close()
 
 class Ritual:
     def __init__(self,
@@ -519,7 +536,9 @@ class Ritual:
         else:
             return avg_loss
 
-def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+def multiresnet(x, width, depth, block=2, multi=1,
+                activation=Relu, style='batchless',
+                init=init_he_normal):
     y = x
     last_size = x.output_shape[0]
 
@@ -527,7 +546,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
         size = width
 
         if last_size != size:
-            y = y.feed(Dense(size))
+            y = y.feed(Dense(size, init))
 
         if style == 'batchless':
             skip = y
@@ -539,7 +558,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
                 for i in range(block):
                     if i > 0:
                         z = z.feed(activation())
-                    z = z.feed(Dense(size))
+                    z = z.feed(Dense(size, init))
                 z.feed(merger)
             y = merger
         elif style == 'onelesssum':
@@ -556,7 +575,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
                 for i in range(block):
                     if i > 0:
                         z = z.feed(activation())
-                    z = z.feed(Dense(size))
+                    z = z.feed(Dense(size, init))
                 if needs_sum:
                     z.feed(merger)
             if needs_sum:
@@ -570,6 +589,9 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
 
     return y
 
+inits = dict(he_normal=init_he_normal, he_uniform=init_he_uniform)
+activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
+
 def run(program, args=[]):
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
@@ -580,7 +602,8 @@ def run(program, args=[]):
 
     from dotmap import DotMap
     config = DotMap(
-        fn = 'ml/cie_mlp_min.h5',
+        fn_load = None,
+        fn_save = 'optim_nn.h5',
         log_fn = 'losses.npz',
 
         # multi-residual network parameters
@@ -598,11 +621,11 @@ def run(program, args=[]):
         momentum = 0.33, # only used with SGD
 
         # learning parameters: SGD with restarts (kinda)
-        LR = 1e-2,
+        learn = 1e-2,
         epochs = 24,
-        LR_halve_every = 16,
+        learn_halve_every = 16,
         restarts = 2,
-        LR_restart_advance = 16,
+        learn_restart_advance = 16,
 
         # misc
         batch_size = 64,
@@ -635,16 +658,17 @@ def run(program, args=[]):
 
     # Our Test Model
 
+    init = inits[config.init]
+    activation = activations[config.activation]
+
     x = Input(shape=(input_samples,))
     y = x
-    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
-    activation = activations[config.activation]
     y = multiresnet(y,
                     config.res_width, config.res_depth,
                     config.res_block, config.res_multi,
-                    activation=activation)
+                    activation=activation, init=init)
     if y.output_shape[0] != output_samples:
-        y = y.feed(Dense(output_samples))
+        y = y.feed(Dense(output_samples, init))
 
     model = Model(x, y, unsafe=config.unsafe)
 
@@ -654,14 +678,9 @@ def run(program, args=[]):
 
     training = config.epochs > 0 and config.restarts >= 0
 
-    if not training:
-        assert config.res_width == 12
-        assert config.res_depth == 3
-        assert config.res_block == 2
-        assert config.res_multi == 4
-        assert config.activation == 'relu'
-        assert config.parallel_style == 'batchless'
-        model.load_weights(config.fn)
+    if config.fn_load is not None:
+        log('loading weights', config.fn_load)
+        model.load_weights(config.fn_load)
 
     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
@@ -686,13 +705,13 @@ def run(program, args=[]):
     loss = lookup_loss(config.loss)
     mloss = lookup_loss(config.mloss) if config.mloss else loss
 
-    anneal = 0.5**(1/config.LR_halve_every)
+    anneal = 0.5**(1/config.learn_halve_every)
     ritual = Ritual(optim=optim,
-                    learn_rate=config.LR, learn_anneal=anneal,
-                    learn_advance=config.LR_restart_advance,
+                    learn_rate=config.learn, learn_anneal=anneal,
+                    learn_advance=config.learn_restart_advance,
                     loss=loss, mloss=mloss)
 
-    learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
+    learn_end = config.learn * (anneal**config.learn_restart_advance)**config.restarts * anneal**(config.epochs - 1)
     log("final learning rate", "{:10.8f}".format(learn_end))
 
     # Training
@@ -747,8 +766,9 @@ def run(program, args=[]):
 
     measure_error()
 
-    #if training:
-    #    model.save_weights(config.fn, overwrite=True)
+    if config.fn_save is not None:
+        log('saving weights', config.fn_save)
+        model.save_weights(config.fn_save, overwrite=True)
 
     # Evaluation
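
With this patch the weight initializer becomes a constructor argument on Dense, and run() resolves both the initializer and the activation by name through the new module-level inits and activations tables. A minimal sketch of the resulting call pattern, assuming the names defined in optim_nn.py; the concrete sizes and string keys below are illustrative, not taken from the patch:

    # pick implementations by the same string keys the config uses
    init = inits['he_normal']           # -> init_he_normal
    activation = activations['gelu']    # -> GeluApprox

    x = Input(shape=(784,))             # hypothetical input width
    y = multiresnet(x, 12, 3, block=2, multi=4,
                    activation=activation, init=init)
    y = y.feed(Dense(10, init))         # per-layer choice of initializer
    model = Model(x, y, unsafe=False)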
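Weight persistence also becomes symmetric: save_weights() writes one HDF5 group per Dense layer (dense_1, dense_2, ...) holding the same b_name+'_W' and b_name+'_b' datasets that load_weights() reads back, so a run with fn_save set can later be resumed by pointing fn_load at the same file. As written, save_weights() always opens the file in 'w' mode, so the overwrite argument is accepted but not enforced. A rough sketch of the round trip, assuming a model built as above:

    model.save_weights('optim_nn.h5', overwrite=True)

    # in a later run, after constructing an identical architecture:
    model.load_weights('optim_nn.h5')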