add ad-hoc weight-sharing method
parent f507dc10f8
commit e7c12c1f44

1 changed file with 17 additions and 3 deletions

onn_core.py
@@ -132,6 +132,7 @@ class Weights:
         self.init = None
         self.allocator = None
         self.regularizer = None
+        self._allocated = False
 
         self.configure(**kwargs)
 
@@ -146,6 +147,8 @@ class Weights:
         return np.prod(self.shape)
 
     def allocate(self, *args, **kwargs):
+        if self._allocated:
+            raise Exception("attempted to allocate existing weights")
         self.configure(**kwargs)
 
         # intentionally not using isinstance
@@ -159,6 +162,8 @@ class Weights:
         self.f = f.reshape(self.shape)
         self.g = g.reshape(self.shape)
 
+        self._allocated = True
+
     def forward(self):
         if self.regularizer is None:
             return 0.0
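Taken together, the three Weights hunks above add allocate-once semantics: the constructor starts `_allocated` at False, a second call to `allocate` raises, and the flag only flips to True after `f` (the flat parameters) and `g` (their gradients) are bound. A minimal sketch of the intended behavior, assuming `w` is an already-configured Weights instance and `args` stands in for whatever its allocator expects (neither is part of this diff):

    # hypothetical usage; this diff only adds the guard.
    w.allocate(*args)        # first call binds w.f and w.g, then sets _allocated
    try:
        w.allocate(*args)    # second call now fails fast instead of rebinding
    except Exception as e:
        print(e)             # attempted to allocate existing weights
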
@@ -482,7 +487,7 @@ class Layer:
         _layer_counters[kind] += 1
         self.name = "{}_{}".format(kind, _layer_counters[kind])
         self.unsafe = False # disables assertions for better performance
-        # TODO: allow weights to be shared across layers.
+        self.shared = False # as in weight sharing
 
     def __str__(self):
         return self.name
@@ -551,6 +556,13 @@ class Layer:
         self.weights[name] = w
         return w
 
+    def share(self, node):
+        self.weights = node.weights # TODO: this should be all it takes.
+        for k, v in self.weights.items():
+            vs = getattr(node, k) # hack: key isn't necessarily attribute name!
+            setattr(self, k, vs)
+        self.shared = True
+
     def clear_grad(self):
         for name, w in self.weights.items():
             w.g[:] = 0
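`share` is the ad-hoc method the commit message refers to: it aliases the donor node's `weights` dict, then re-points the borrowing layer's own attributes at the donor's Weights objects, since layers also cache each Weights object as an attribute. As the in-line comment admits, the loop assumes every dict key matches an attribute name on the donor, which is not guaranteed. A usage sketch, assuming `Dense` is a Layer subclass in this file (the layer name and size are illustrative, not from the diff):

    a = Dense(128)
    b = Dense(128)   # same shape as the donor, so the weights line up
    b.share(a)       # b.weights now aliases a.weights and b.shared is True
    assert b.weights is a.weights
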
@@ -864,13 +876,13 @@ class Model:
         return self.nodes
 
     def make_weights(self):
-        self.param_count = sum((node.size for node in self.nodes))
+        self.param_count = sum((node.size for node in self.nodes if not node.shared))
         self.W = np.zeros(self.param_count, dtype=_f)
         self.dW = np.zeros(self.param_count, dtype=_f)
 
         offset = 0
         for node in self.nodes:
-            if node.size > 0:
+            if node.size > 0 and not node.shared:
                 inner_offset = 0
 
                 def allocate(size):
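With these two lines changed, `make_weights` sizes the flat parameter vector `W` (and its gradient `dW`) from non-shared nodes only, and the slice-assignment loop skips shared nodes entirely, so a layer that called `share` keeps pointing at its donor's storage rather than receiving a slice of its own. A toy sketch of the accounting with stand-in nodes (not the real Layer class):

    class Node:
        def __init__(self, size, shared=False):
            self.size, self.shared = size, shared

    nodes = [Node(10), Node(10, shared=True), Node(5)]
    param_count = sum(node.size for node in nodes if not node.shared)
    print(param_count)  # 15: the shared node contributes no parameters of its own
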
@@ -941,6 +953,7 @@ class Model:
             used[k] = False
 
         nodes = [node for node in self.nodes if node.size > 0]
+        # TODO: support shared weights.
        for node in nodes:
             full_name = str(node).lower()
             for s_name, o_name in node.serialized.items():
@@ -961,6 +974,7 @@ class Model:
         counts = defaultdict(lambda: 0)
 
         nodes = [node for node in self.nodes if node.size > 0]
+        # TODO: support shared weights.
         for node in nodes:
             full_name = str(node).lower()
             grp = f.create_group(full_name)
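Note that both serialization paths still iterate over every node with `size > 0`, which includes nodes that merely borrowed their weights via `share`; until the two new TODOs are addressed, a shared node and its donor would presumably read and write the same arrays under different group names.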