add ad-hoc weight-sharing method
parent f507dc10f8
commit e7c12c1f44

1 changed file with 17 additions and 3 deletions

onn_core.py (20 lines changed: 17 additions, 3 deletions)
@@ -132,6 +132,7 @@ class Weights:
         self.init = None
         self.allocator = None
         self.regularizer = None
+        self._allocated = False
 
         self.configure(**kwargs)
 
@@ -146,6 +147,8 @@ class Weights:
         return np.prod(self.shape)
 
     def allocate(self, *args, **kwargs):
+        if self._allocated:
+            raise Exception("attempted to allocate existing weights")
         self.configure(**kwargs)
 
         # intentionally not using isinstance
@@ -159,6 +162,8 @@ class Weights:
         self.f = f.reshape(self.shape)
         self.g = g.reshape(self.shape)
 
+        self._allocated = True
+
     def forward(self):
         if self.regularizer is None:
             return 0.0
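The three hunks above introduce an `_allocated` flag so a Weights object can only be allocated once; a second call to allocate() now fails loudly instead of silently rebinding f and g, presumably so that weights reused via the new share() path below cannot be allocated a second time. A minimal standalone sketch of the same guard pattern (GuardedWeights is a hypothetical stand-in, not the library's Weights class):

    import numpy as np

    class GuardedWeights:
        # pared-down analogue of the _allocated guard in the diff above
        def __init__(self, shape):
            self.shape = shape
            self._allocated = False

        def allocate(self, f, g):
            if self._allocated:
                raise Exception("attempted to allocate existing weights")
            self.f = f.reshape(self.shape)  # view onto forward values
            self.g = g.reshape(self.shape)  # view onto gradients
            self._allocated = True

    w = GuardedWeights((2, 3))
    w.allocate(np.zeros(6), np.zeros(6))    # first allocation succeeds
    # w.allocate(np.zeros(6), np.zeros(6))  # second call would raise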
@@ -482,7 +487,7 @@ class Layer:
         _layer_counters[kind] += 1
         self.name = "{}_{}".format(kind, _layer_counters[kind])
         self.unsafe = False # disables assertions for better performance
-        # TODO: allow weights to be shared across layers.
+        self.shared = False # as in weight sharing
 
     def __str__(self):
         return self.name
@@ -551,6 +556,13 @@ class Layer:
         self.weights[name] = w
         return w
 
+    def share(self, node):
+        self.weights = node.weights # TODO: this should be all it takes.
+        for k, v in self.weights.items():
+            vs = getattr(node, k) # hack: key isn't necessarily attribute name!
+            setattr(self, k, vs)
+        self.shared = True
+
     def clear_grad(self):
         for name, w in self.weights.items():
             w.g[:] = 0
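The new share() method aliases the entire weights dict from another layer, then copies each Weights reference onto self via the attribute-name hack noted in the comment. The essential behavior is aliasing, not copying: both layers end up pointing at the same underlying storage, so an update through one is visible through the other. A self-contained toy illustrating just that aliasing (ToyLayer is illustrative and omits the per-attribute copy; it is not the commit's Layer class):

    class ToyLayer:
        def __init__(self):
            self.weights = {"W": [0.0]}  # stand-in for Weights objects
            self.shared = False

        def share(self, node):
            self.weights = node.weights  # alias the dict, do not copy it
            self.shared = True

    a = ToyLayer()
    b = ToyLayer()
    b.share(a)
    a.weights["W"][0] = 1.0
    assert b.weights["W"][0] == 1.0  # b observes a's update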
@@ -864,13 +876,13 @@ class Model:
         return self.nodes
 
     def make_weights(self):
-        self.param_count = sum((node.size for node in self.nodes))
+        self.param_count = sum((node.size for node in self.nodes if not node.shared))
         self.W = np.zeros(self.param_count, dtype=_f)
         self.dW = np.zeros(self.param_count, dtype=_f)
 
         offset = 0
         for node in self.nodes:
-            if node.size > 0:
+            if node.size > 0 and not node.shared:
                 inner_offset = 0
 
                 def allocate(size):
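With the shared flag in place, make_weights() now both counts and allocates parameters only for owning nodes, so tied layers neither inflate param_count nor receive their own slice of W. A small sketch of the corrected count (Node here is a stand-in for the real node objects):

    class Node:
        def __init__(self, size, shared=False):
            self.size = size
            self.shared = shared

    nodes = [Node(100), Node(100, shared=True), Node(10)]
    param_count = sum(node.size for node in nodes if not node.shared)
    assert param_count == 110  # the shared node's 100 params counted once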
@@ -941,6 +953,7 @@ class Model:
             used[k] = False
 
         nodes = [node for node in self.nodes if node.size > 0]
+        # TODO: support shared weights.
         for node in nodes:
             full_name = str(node).lower()
             for s_name, o_name in node.serialized.items():
@@ -961,6 +974,7 @@ class Model:
         counts = defaultdict(lambda: 0)
 
         nodes = [node for node in self.nodes if node.size > 0]
+        # TODO: support shared weights.
         for node in nodes:
             full_name = str(node).lower()
             grp = f.create_group(full_name)