This commit is contained in:
Connor Olding 2017-02-27 22:14:58 +00:00
parent 06eb28b23e
commit 547017a6fc
2 changed files with 49 additions and 20 deletions

View file

@ -65,6 +65,12 @@ class LayerNorm(Layer):
self.affine = bool(affine)
self.size = None
if self.affine:
self.serialized = {
'gamma': 'gamma',
'beta': 'beta',
}
def make_shape(self, shape):
super().make_shape(shape)
if len(shape) != 1:
@ -116,6 +122,11 @@ class LayerNorm(Layer):
class Denses(Layer): # TODO: rename?
# acts as a separate Dense for each row or column. only for 2D arrays.
serialized = {
'W': 'coeffs',
'b': 'biases',
}
def __init__(self, dim, init=init_he_uniform, axis=-1):
super().__init__()
self.dim = int(dim)

View file

@ -515,6 +515,11 @@ class Softmax(Layer):
# Parametric Layers {{{1
class Dense(Layer):
serialized = {
'W': 'coeffs',
'b': 'biases',
}
def __init__(self, dim, init=init_he_uniform):
super().__init__()
self.dim = int(dim)
@ -615,10 +620,9 @@ class Model:
def load_weights(self, fn):
# seemingly compatible with keras' Dense layers.
# ignores any non-Dense layer types.
# TODO: assert file actually exists
import h5py
f = h5py.File(fn)
open(fn) # just ensure the file exists (python's error is better)
f = h5py.File(fn, 'r')
weights = {}
def visitor(name, obj):
if isinstance(obj, h5py.Dataset):
@ -626,28 +630,42 @@ class Model:
f.visititems(visitor)
f.close()
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
for i in range(len(denses)):
a, b = i, i + 1
b_name = "dense_{}".format(b)
# TODO: write a Dense method instead of assigning directly
denses[a].coeffs[:] = weights[b_name+'_W']
denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
used = {}
for k in weights.keys():
used[k] = False
nodes = [node for node in self.ordered_nodes if node.size is not None]
for node in nodes:
full_name = str(node).lower()
for s_name, o_name in node.serialized.items():
key = full_name + '_' + s_name
data = weights[key]
target = getattr(node, o_name)
target[:] = data
used[key] = True
for k, v in used.items():
if not v:
lament("WARNING: unused weight", k)
def save_weights(self, fn, overwrite=False):
import h5py
f = h5py.File(fn, 'w')
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
for i in range(len(denses)):
a, b = i, i + 1
b_name = "dense_{}".format(b)
# TODO: write a Dense method instead of assigning directly
grp = f.create_group(b_name)
data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=_f)
data[:] = denses[a].coeffs
data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=_f)
data[:] = denses[a].biases
counts = defaultdict(lambda: 0)
nodes = [node for node in self.ordered_nodes if node.size is not None]
for node in nodes:
full_name = str(node).lower()
grp = f.create_group(full_name)
for s_name, o_name in node.serialized.items():
key = full_name + '_' + s_name
target = getattr(node, o_name)
data = grp.create_dataset(key, target.shape, dtype=_f)
data[:] = target
counts[key] += 1
if counts[key] > 1:
lament("WARNING: rewrote weight", key)
f.close()