From 547017a6fce6d5ea9e87f685c7e881f854f89e1b Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 27 Feb 2017 22:14:58 +0000 Subject: [PATCH] . --- optim_nn.py | 11 +++++++++ optim_nn_core.py | 58 +++++++++++++++++++++++++++++++----------------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/optim_nn.py b/optim_nn.py index 99e8071..23d3230 100644 --- a/optim_nn.py +++ b/optim_nn.py @@ -65,6 +65,12 @@ class LayerNorm(Layer): self.affine = bool(affine) self.size = None + if self.affine: + self.serialized = { + 'gamma': 'gamma', + 'beta': 'beta', + } + def make_shape(self, shape): super().make_shape(shape) if len(shape) != 1: @@ -116,6 +122,11 @@ class LayerNorm(Layer): class Denses(Layer): # TODO: rename? # acts as a separate Dense for each row or column. only for 2D arrays. + serialized = { + 'W': 'coeffs', + 'b': 'biases', + } + def __init__(self, dim, init=init_he_uniform, axis=-1): super().__init__() self.dim = int(dim) diff --git a/optim_nn_core.py b/optim_nn_core.py index 9c441e3..f65b7d4 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -515,6 +515,11 @@ class Softmax(Layer): # Parametric Layers {{{1 class Dense(Layer): + serialized = { + 'W': 'coeffs', + 'b': 'biases', + } + def __init__(self, dim, init=init_he_uniform): super().__init__() self.dim = int(dim) @@ -615,10 +620,9 @@ class Model: def load_weights(self, fn): # seemingly compatible with keras' Dense layers. - # ignores any non-Dense layer types. - # TODO: assert file actually exists import h5py - f = h5py.File(fn) + open(fn) # just ensure the file exists (python's error is better) + f = h5py.File(fn, 'r') weights = {} def visitor(name, obj): if isinstance(obj, h5py.Dataset): @@ -626,28 +630,42 @@ class Model: f.visititems(visitor) f.close() - denses = [node for node in self.ordered_nodes if isinstance(node, Dense)] - for i in range(len(denses)): - a, b = i, i + 1 - b_name = "dense_{}".format(b) - # TODO: write a Dense method instead of assigning directly - denses[a].coeffs[:] = weights[b_name+'_W'] - denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0) + used = {} + for k in weights.keys(): + used[k] = False + + nodes = [node for node in self.ordered_nodes if node.size is not None] + for node in nodes: + full_name = str(node).lower() + for s_name, o_name in node.serialized.items(): + key = full_name + '_' + s_name + data = weights[key] + target = getattr(node, o_name) + target[:] = data + used[key] = True + + for k, v in used.items(): + if not v: + lament("WARNING: unused weight", k) def save_weights(self, fn, overwrite=False): import h5py f = h5py.File(fn, 'w') - denses = [node for node in self.ordered_nodes if isinstance(node, Dense)] - for i in range(len(denses)): - a, b = i, i + 1 - b_name = "dense_{}".format(b) - # TODO: write a Dense method instead of assigning directly - grp = f.create_group(b_name) - data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=_f) - data[:] = denses[a].coeffs - data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=_f) - data[:] = denses[a].biases + counts = defaultdict(lambda: 0) + + nodes = [node for node in self.ordered_nodes if node.size is not None] + for node in nodes: + full_name = str(node).lower() + grp = f.create_group(full_name) + for s_name, o_name in node.serialized.items(): + key = full_name + '_' + s_name + target = getattr(node, o_name) + data = grp.create_dataset(key, target.shape, dtype=_f) + data[:] = target + counts[key] += 1 + if counts[key] > 1: + lament("WARNING: rewrote weight", key) f.close()