.
This commit is contained in:
parent
06eb28b23e
commit
547017a6fc
2 changed files with 49 additions and 20 deletions
11
optim_nn.py
11
optim_nn.py
|
@ -65,6 +65,12 @@ class LayerNorm(Layer):
|
||||||
self.affine = bool(affine)
|
self.affine = bool(affine)
|
||||||
self.size = None
|
self.size = None
|
||||||
|
|
||||||
|
if self.affine:
|
||||||
|
self.serialized = {
|
||||||
|
'gamma': 'gamma',
|
||||||
|
'beta': 'beta',
|
||||||
|
}
|
||||||
|
|
||||||
def make_shape(self, shape):
|
def make_shape(self, shape):
|
||||||
super().make_shape(shape)
|
super().make_shape(shape)
|
||||||
if len(shape) != 1:
|
if len(shape) != 1:
|
||||||
|
@ -116,6 +122,11 @@ class LayerNorm(Layer):
|
||||||
class Denses(Layer): # TODO: rename?
|
class Denses(Layer): # TODO: rename?
|
||||||
# acts as a separate Dense for each row or column. only for 2D arrays.
|
# acts as a separate Dense for each row or column. only for 2D arrays.
|
||||||
|
|
||||||
|
serialized = {
|
||||||
|
'W': 'coeffs',
|
||||||
|
'b': 'biases',
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, dim, init=init_he_uniform, axis=-1):
|
def __init__(self, dim, init=init_he_uniform, axis=-1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = int(dim)
|
self.dim = int(dim)
|
||||||
|
|
|
@ -515,6 +515,11 @@ class Softmax(Layer):
|
||||||
# Parametric Layers {{{1
|
# Parametric Layers {{{1
|
||||||
|
|
||||||
class Dense(Layer):
|
class Dense(Layer):
|
||||||
|
serialized = {
|
||||||
|
'W': 'coeffs',
|
||||||
|
'b': 'biases',
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, dim, init=init_he_uniform):
|
def __init__(self, dim, init=init_he_uniform):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = int(dim)
|
self.dim = int(dim)
|
||||||
|
@ -615,10 +620,9 @@ class Model:
|
||||||
|
|
||||||
def load_weights(self, fn):
|
def load_weights(self, fn):
|
||||||
# seemingly compatible with keras' Dense layers.
|
# seemingly compatible with keras' Dense layers.
|
||||||
# ignores any non-Dense layer types.
|
|
||||||
# TODO: assert file actually exists
|
|
||||||
import h5py
|
import h5py
|
||||||
f = h5py.File(fn)
|
open(fn) # just ensure the file exists (python's error is better)
|
||||||
|
f = h5py.File(fn, 'r')
|
||||||
weights = {}
|
weights = {}
|
||||||
def visitor(name, obj):
|
def visitor(name, obj):
|
||||||
if isinstance(obj, h5py.Dataset):
|
if isinstance(obj, h5py.Dataset):
|
||||||
|
@ -626,28 +630,42 @@ class Model:
|
||||||
f.visititems(visitor)
|
f.visititems(visitor)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
|
used = {}
|
||||||
for i in range(len(denses)):
|
for k in weights.keys():
|
||||||
a, b = i, i + 1
|
used[k] = False
|
||||||
b_name = "dense_{}".format(b)
|
|
||||||
# TODO: write a Dense method instead of assigning directly
|
nodes = [node for node in self.ordered_nodes if node.size is not None]
|
||||||
denses[a].coeffs[:] = weights[b_name+'_W']
|
for node in nodes:
|
||||||
denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
|
full_name = str(node).lower()
|
||||||
|
for s_name, o_name in node.serialized.items():
|
||||||
|
key = full_name + '_' + s_name
|
||||||
|
data = weights[key]
|
||||||
|
target = getattr(node, o_name)
|
||||||
|
target[:] = data
|
||||||
|
used[key] = True
|
||||||
|
|
||||||
|
for k, v in used.items():
|
||||||
|
if not v:
|
||||||
|
lament("WARNING: unused weight", k)
|
||||||
|
|
||||||
def save_weights(self, fn, overwrite=False):
|
def save_weights(self, fn, overwrite=False):
|
||||||
import h5py
|
import h5py
|
||||||
f = h5py.File(fn, 'w')
|
f = h5py.File(fn, 'w')
|
||||||
|
|
||||||
denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
|
counts = defaultdict(lambda: 0)
|
||||||
for i in range(len(denses)):
|
|
||||||
a, b = i, i + 1
|
nodes = [node for node in self.ordered_nodes if node.size is not None]
|
||||||
b_name = "dense_{}".format(b)
|
for node in nodes:
|
||||||
# TODO: write a Dense method instead of assigning directly
|
full_name = str(node).lower()
|
||||||
grp = f.create_group(b_name)
|
grp = f.create_group(full_name)
|
||||||
data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=_f)
|
for s_name, o_name in node.serialized.items():
|
||||||
data[:] = denses[a].coeffs
|
key = full_name + '_' + s_name
|
||||||
data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=_f)
|
target = getattr(node, o_name)
|
||||||
data[:] = denses[a].biases
|
data = grp.create_dataset(key, target.shape, dtype=_f)
|
||||||
|
data[:] = target
|
||||||
|
counts[key] += 1
|
||||||
|
if counts[key] > 1:
|
||||||
|
lament("WARNING: rewrote weight", key)
|
||||||
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue