From e7c12c1f448590867c20130cca317aa675e1945f Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Wed, 2 Aug 2017 11:28:18 +0000 Subject: [PATCH] add ad-hoc weight-sharing method --- onn_core.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/onn_core.py b/onn_core.py index 2228eb9..f1781d4 100644 --- a/onn_core.py +++ b/onn_core.py @@ -132,6 +132,7 @@ class Weights: self.init = None self.allocator = None self.regularizer = None + self._allocated = False self.configure(**kwargs) @@ -146,6 +147,8 @@ class Weights: return np.prod(self.shape) def allocate(self, *args, **kwargs): + if self._allocated: + raise Exception("attempted to allocate existing weights") self.configure(**kwargs) # intentionally not using isinstance @@ -159,6 +162,8 @@ class Weights: self.f = f.reshape(self.shape) self.g = g.reshape(self.shape) + self._allocated = True + def forward(self): if self.regularizer is None: return 0.0 @@ -482,7 +487,7 @@ class Layer: _layer_counters[kind] += 1 self.name = "{}_{}".format(kind, _layer_counters[kind]) self.unsafe = False # disables assertions for better performance - # TODO: allow weights to be shared across layers. + self.shared = False # as in weight sharing def __str__(self): return self.name @@ -551,6 +556,13 @@ class Layer: self.weights[name] = w return w + def share(self, node): + self.weights = node.weights # TODO: this should be all it takes. + for k, v in self.weights.items(): + vs = getattr(node, k) # hack: key isn't necessarily attribute name! + setattr(self, k, vs) + self.shared = True + def clear_grad(self): for name, w in self.weights.items(): w.g[:] = 0 @@ -864,13 +876,13 @@ class Model: return self.nodes def make_weights(self): - self.param_count = sum((node.size for node in self.nodes)) + self.param_count = sum((node.size for node in self.nodes if not node.shared)) self.W = np.zeros(self.param_count, dtype=_f) self.dW = np.zeros(self.param_count, dtype=_f) offset = 0 for node in self.nodes: - if node.size > 0: + if node.size > 0 and not node.shared: inner_offset = 0 def allocate(size): @@ -941,6 +953,7 @@ class Model: used[k] = False nodes = [node for node in self.nodes if node.size > 0] + # TODO: support shared weights. for node in nodes: full_name = str(node).lower() for s_name, o_name in node.serialized.items(): @@ -961,6 +974,7 @@ class Model: counts = defaultdict(lambda: 0) nodes = [node for node in self.nodes if node.size > 0] + # TODO: support shared weights. for node in nodes: full_name = str(node).lower() grp = f.create_group(full_name)