add ad-hoc weight-sharing method

This commit is contained in:
Connor Olding 2017-08-02 11:28:18 +00:00
parent f507dc10f8
commit e7c12c1f44

View File

@ -132,6 +132,7 @@ class Weights:
self.init = None
self.allocator = None
self.regularizer = None
self._allocated = False
self.configure(**kwargs)
@ -146,6 +147,8 @@ class Weights:
return np.prod(self.shape)
def allocate(self, *args, **kwargs):
if self._allocated:
raise Exception("attempted to allocate existing weights")
self.configure(**kwargs)
# intentionally not using isinstance
@ -159,6 +162,8 @@ class Weights:
self.f = f.reshape(self.shape)
self.g = g.reshape(self.shape)
self._allocated = True
def forward(self):
if self.regularizer is None:
return 0.0
@ -482,7 +487,7 @@ class Layer:
_layer_counters[kind] += 1
self.name = "{}_{}".format(kind, _layer_counters[kind])
self.unsafe = False # disables assertions for better performance
# TODO: allow weights to be shared across layers.
self.shared = False # as in weight sharing
def __str__(self):
return self.name
@ -551,6 +556,13 @@ class Layer:
self.weights[name] = w
return w
def share(self, node):
self.weights = node.weights # TODO: this should be all it takes.
for k, v in self.weights.items():
vs = getattr(node, k) # hack: key isn't necessarily attribute name!
setattr(self, k, vs)
self.shared = True
def clear_grad(self):
for name, w in self.weights.items():
w.g[:] = 0
@ -864,13 +876,13 @@ class Model:
return self.nodes
def make_weights(self):
self.param_count = sum((node.size for node in self.nodes))
self.param_count = sum((node.size for node in self.nodes if not node.shared))
self.W = np.zeros(self.param_count, dtype=_f)
self.dW = np.zeros(self.param_count, dtype=_f)
offset = 0
for node in self.nodes:
if node.size > 0:
if node.size > 0 and not node.shared:
inner_offset = 0
def allocate(size):
@ -941,6 +953,7 @@ class Model:
used[k] = False
nodes = [node for node in self.nodes if node.size > 0]
# TODO: support shared weights.
for node in nodes:
full_name = str(node).lower()
for s_name, o_name in node.serialized.items():
@ -961,6 +974,7 @@ class Model:
counts = defaultdict(lambda: 0)
nodes = [node for node in self.nodes if node.size > 0]
# TODO: support shared weights.
for node in nodes:
full_name = str(node).lower()
grp = f.create_group(full_name)