This commit is contained in:
parent
06eb28b23e
commit
547017a6fc
2 changed files with 49 additions and 20 deletions
optim_nn.py
@@ -65,6 +65,12 @@ class LayerNorm(Layer):
         self.affine = bool(affine)
         self.size = None
 
+        if self.affine:
+            self.serialized = {
+                'gamma': 'gamma',
+                'beta': 'beta',
+            }
+
     def make_shape(self, shape):
         super().make_shape(shape)
         if len(shape) != 1:
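The serialized dict introduced here is this commit's save/load convention: each entry maps a dataset-name suffix in the HDF5 file to the name of the attribute holding that array. LayerNorm sets it per-instance, and only when affine, so non-affine instances expose no weights. A minimal sketch of the convention as the rewritten loader below consumes it (ExampleNode is illustrative, not from the code):

import numpy as np

class ExampleNode:
    # suffix in the HDF5 file -> attribute on the node
    serialized = {'gamma': 'gamma', 'beta': 'beta'}
    def __init__(self):
        self.gamma = np.ones(4)
        self.beta = np.zeros(4)

node = ExampleNode()
for s_name, o_name in node.serialized.items():
    print(s_name, getattr(node, o_name))  # the pairs load/save_weights iterate over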
@@ -116,6 +122,11 @@ class LayerNorm(Layer):
 class Denses(Layer): # TODO: rename?
     # acts as a separate Dense for each row or column. only for 2D arrays.
 
+    serialized = {
+        'W': 'coeffs',
+        'b': 'biases',
+    }
+
     def __init__(self, dim, init=init_he_uniform, axis=-1):
         super().__init__()
         self.dim = int(dim)
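Denses gets the same class-level mapping as Dense below, so the generic load/save paths treat both uniformly. A rough sketch of the behavior its comment describes, assuming axis selects which dimension gets its own Dense (the forward pass is outside this diff, so the shapes here are an assumption):

import numpy as np

x = np.random.randn(4, 3)              # 2D input only, per the comment
W = np.random.randn(4, 3, 2)           # a separate (3, 2) weight matrix per row
b = np.zeros((4, 2))
y = np.einsum('ij,ijk->ik', x, W) + b  # one independent Dense per row -> (4, 2)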
optim_nn.py
@@ -515,6 +515,11 @@ class Softmax(Layer):
 # Parametric Layers {{{1
 
 class Dense(Layer):
+    serialized = {
+        'W': 'coeffs',
+        'b': 'biases',
+    }
+
     def __init__(self, dim, init=init_he_uniform):
         super().__init__()
         self.dim = int(dim)
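With Dense declaring serialized = {'W': 'coeffs', 'b': 'biases'} at class level, the rewritten load/save below derives each dataset key as str(node).lower() + '_' + suffix. A sketch of the keys two Dense nodes would map to, assuming node names like 'Dense_1' (Layer's actual __str__ is outside this diff):

# hypothetical node names; the real ones come from str(node)
for full_name in ('dense_1', 'dense_2'):
    for s_name in ('W', 'b'):
        print(full_name + '_' + s_name)  # dense_1_W, dense_1_b, dense_2_W, dense_2_b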
@@ -615,10 +620,9 @@ class Model:
 
     def load_weights(self, fn):
         # seemingly compatible with keras' Dense layers.
-        # ignores any non-Dense layer types.
-        # TODO: assert file actually exists
         import h5py
-        f = h5py.File(fn)
+        open(fn)  # just ensure the file exists (python's error is better)
+        f = h5py.File(fn, 'r')
         weights = {}
         def visitor(name, obj):
             if isinstance(obj, h5py.Dataset):
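The visitor body is cut off by the hunk boundary. Since the loop below indexes weights by flat keys like dense_1_W, a plausible reading is that it strips the HDF5 group path and keeps the dataset's array; this is a guess at the elided code, not part of the commit:

def visitor(name, obj):
    if isinstance(obj, h5py.Dataset):
        # assumption: keep only the dataset's own name, drop the group path;
        # `weights` and `np` come from the enclosing load_weights scope
        weights[name.split('/')[-1]] = np.array(obj)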
@@ -626,28 +630,42 @@ class Model:
         f.visititems(visitor)
         f.close()
 
-        denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
-        for i in range(len(denses)):
-            a, b = i, i + 1
-            b_name = "dense_{}".format(b)
-            # TODO: write a Dense method instead of assigning directly
-            denses[a].coeffs[:] = weights[b_name+'_W']
-            denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
+        used = {}
+        for k in weights.keys():
+            used[k] = False
+
+        nodes = [node for node in self.ordered_nodes if node.size is not None]
+        for node in nodes:
+            full_name = str(node).lower()
+            for s_name, o_name in node.serialized.items():
+                key = full_name + '_' + s_name
+                data = weights[key]
+                target = getattr(node, o_name)
+                target[:] = data
+                used[key] = True
+
+        for k, v in used.items():
+            if not v:
+                lament("WARNING: unused weight", k)
 
     def save_weights(self, fn, overwrite=False):
         import h5py
         f = h5py.File(fn, 'w')
 
-        denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
-        for i in range(len(denses)):
-            a, b = i, i + 1
-            b_name = "dense_{}".format(b)
-            # TODO: write a Dense method instead of assigning directly
-            grp = f.create_group(b_name)
-            data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=_f)
-            data[:] = denses[a].coeffs
-            data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=_f)
-            data[:] = denses[a].biases
+        counts = defaultdict(lambda: 0)
+
+        nodes = [node for node in self.ordered_nodes if node.size is not None]
+        for node in nodes:
+            full_name = str(node).lower()
+            grp = f.create_group(full_name)
+            for s_name, o_name in node.serialized.items():
+                key = full_name + '_' + s_name
+                target = getattr(node, o_name)
+                data = grp.create_dataset(key, target.shape, dtype=_f)
+                data[:] = target
+                counts[key] += 1
+                if counts[key] > 1:
+                    lament("WARNING: rewrote weight", key)
 
         f.close()
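After this change, any node that defines serialized (Dense, Denses, affine LayerNorm) round-trips through the same two code paths, and load_weights warns about datasets it never consumed instead of silently ignoring non-Dense layers. A hypothetical round trip, assuming a constructed model whose graph is already resolved:

model.save_weights('weights.h5')  # one HDF5 group per node, one dataset per serialized entry
model.load_weights('weights.h5')  # lament()s any dataset left unused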