.

This commit is contained in:
parent f12e408c7e
commit 028209b699

1 changed file with 52 additions and 32 deletions:
optim_nn.py (84 lines changed)
@@ -343,11 +343,12 @@ class GeluApprox(Layer):
         return dY * self.sig * (1 + self.a * (1 - self.sig))
 
 class Dense(Layer):
-    def __init__(self, dim):
+    def __init__(self, dim, init=init_he_uniform):
         super().__init__()
         self.dim = ni(dim)
         self.output_shape = (dim,)
         self.size = None
+        self.weight_init = init
 
     def init(self, W, dW):
         ins, outs = self.input_shape[0], self.output_shape[0]
@@ -359,7 +360,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)
 
-        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
+        self.coeffs.flat = self.weight_init(self.nW, ins, outs)
         self.biases.flat = 0
 
     def make_shape(self, shape):
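Note: together these two hunks define the initializer contract: any callable of the form init(size, ins, outs) returning size flat values works, since Dense.init calls self.weight_init(self.nW, ins, outs). A minimal sketch of a compatible custom initializer (hypothetical, not part of this commit):

    import numpy as np

    def init_gaussian(size, ins, outs):
        # fan-in-scaled Gaussian; illustrative only
        return np.random.normal(0, 1 / np.sqrt(ins), size=size)

    layer = Dense(32, init_gaussian)  # hypothetical layer width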
@@ -444,7 +445,9 @@ class Model:
         return self.dW
 
     def load_weights(self, fn):
-        # seemingly compatible with keras models at the moment
+        # seemingly compatible with keras' Dense layers.
+        # ignores any non-Dense layer types.
+        # TODO: assert file actually exists
         import h5py
         f = h5py.File(fn)
         weights = {}
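Note: load_weights expects a keras-1-style HDF5 layout, one group per Dense layer holding dense_N_W and dense_N_b datasets (see the next hunk). A read-only inspection sketch, assuming h5py is available and using the default save path from this commit:

    import h5py

    with h5py.File('optim_nn.h5', 'r') as f:
        # print every dataset path and its shape
        f.visititems(lambda name, obj: print(name, getattr(obj, 'shape', '')))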
@@ -459,11 +462,25 @@ class Model:
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
             # TODO: write a Dense method instead of assigning directly
-            denses[a].coeffs = weights[b_name+'_W']
-            denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)
+            denses[a].coeffs[:] = weights[b_name+'_W']
+            denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
 
     def save_weights(self, fn, overwrite=False):
-        raise NotImplementedError("unimplemented", self)
+        import h5py
+        f = h5py.File(fn, 'w')
+
+        denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
+        for i in range(len(denses)):
+            a, b = i, i + 1
+            b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
+            grp = f.create_group(b_name)
+            data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=nf)
+            data[:] = denses[a].coeffs
+            data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=nf)
+            data[:] = denses[a].biases
+
+        f.close()
 
 class Ritual:
     def __init__(self,
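Note: the switch from coeffs = ... to coeffs[:] = ... is load-bearing. Dense.coeffs is a view into the model's flat parameter buffer W (compare self.dcoeffs = self.dW[:self.nW].reshape(ins, outs) earlier), so rebinding the attribute would silently detach the layer from W, while slice-assignment writes through. A standalone numpy sketch of the distinction:

    import numpy as np

    W = np.zeros(6)               # stand-in for the shared parameter vector
    coeffs = W[:4].reshape(2, 2)  # a view, like Dense.coeffs
    coeffs[:] = 1                 # in-place write: W sees the update
    print(W)                      # [1. 1. 1. 1. 0. 0.]
    coeffs = np.full((2, 2), 9)   # rebinding: W is unaffected from here on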
@@ -519,7 +536,9 @@ class Ritual:
         else:
             return avg_loss
 
-def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+def multiresnet(x, width, depth, block=2, multi=1,
+                activation=Relu, style='batchless',
+                init=init_he_normal):
     y = x
     last_size = x.output_shape[0]
 
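Note: the two new defaults differ: Dense alone defaults to init_he_uniform, while multiresnet threads init=init_he_normal into every Dense it creates. A usage sketch with hypothetical sizes:

    x = Input(shape=(128,))                        # hypothetical input width
    y1 = x.feed(Dense(12))                         # He-uniform by default
    y2 = multiresnet(x, 12, 3, block=2, multi=4)   # He-normal throughout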
@@ -527,7 +546,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
        size = width
 
        if last_size != size:
-           y = y.feed(Dense(size))
+           y = y.feed(Dense(size, init))
 
        if style == 'batchless':
            skip = y
@@ -539,7 +558,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
                for i in range(block):
                    if i > 0:
                        z = z.feed(activation())
-                   z = z.feed(Dense(size))
+                   z = z.feed(Dense(size, init))
                z.feed(merger)
            y = merger
        elif style == 'onelesssum':
@@ -556,7 +575,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
                for i in range(block):
                    if i > 0:
                        z = z.feed(activation())
-                   z = z.feed(Dense(size))
+                   z = z.feed(Dense(size, init))
                if needs_sum:
                    z.feed(merger)
            if needs_sum:
@@ -570,6 +589,9 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
 
     return y
 
+inits = dict(he_normal=init_he_normal, he_uniform=init_he_uniform)
+activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
+
 def run(program, args=[]):
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
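Note: these module-level tables map config strings to callables, so an unknown name fails fast with a KeyError. For example:

    init = inits['he_uniform']        # -> init_he_uniform
    activation = activations['gelu']  # -> GeluApprox
    layer = activation()              # activations are instantiated per use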
@@ -580,7 +602,8 @@ def run(program, args=[]):
 
     from dotmap import DotMap
     config = DotMap(
-        fn = 'ml/cie_mlp_min.h5',
+        fn_load = None,
+        fn_save = 'optim_nn.h5',
         log_fn = 'losses.npz',
 
         # multi-residual network parameters
@@ -598,11 +621,11 @@ def run(program, args=[]):
         momentum = 0.33, # only used with SGD
 
         # learning parameters: SGD with restarts (kinda)
-        LR = 1e-2,
+        learn = 1e-2,
         epochs = 24,
-        LR_halve_every = 16,
+        learn_halve_every = 16,
         restarts = 2,
-        LR_restart_advance = 16,
+        learn_restart_advance = 16,
 
         # misc
         batch_size = 64,
@@ -635,16 +658,17 @@ def run(program, args=[]):
 
     # Our Test Model
 
+    init = inits[config.init]
+    activation = activations[config.activation]
+
     x = Input(shape=(input_samples,))
     y = x
-    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
-    activation = activations[config.activation]
     y = multiresnet(y,
                     config.res_width, config.res_depth,
                     config.res_block, config.res_multi,
-                    activation=activation)
+                    activation=activation, init=init)
     if y.output_shape[0] != output_samples:
-        y = y.feed(Dense(output_samples))
+        y = y.feed(Dense(output_samples, init))
 
     model = Model(x, y, unsafe=config.unsafe)
 
@@ -654,14 +678,9 @@ def run(program, args=[]):
 
     training = config.epochs > 0 and config.restarts >= 0
 
-    if not training:
-        assert config.res_width == 12
-        assert config.res_depth == 3
-        assert config.res_block == 2
-        assert config.res_multi == 4
-        assert config.activation == 'relu'
-        assert config.parallel_style == 'batchless'
-        model.load_weights(config.fn)
+    if config.fn_load is not None:
+        log('loading weights', config.fn_load)
+        model.load_weights(config.fn_load)
 
     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
@@ -686,13 +705,13 @@ def run(program, args=[]):
     loss = lookup_loss(config.loss)
     mloss = lookup_loss(config.mloss) if config.mloss else loss
 
-    anneal = 0.5**(1/config.LR_halve_every)
+    anneal = 0.5**(1/config.learn_halve_every)
     ritual = Ritual(optim=optim,
-                    learn_rate=config.LR, learn_anneal=anneal,
-                    learn_advance=config.LR_restart_advance,
+                    learn_rate=config.learn, learn_anneal=anneal,
+                    learn_advance=config.learn_restart_advance,
                     loss=loss, mloss=mloss)
 
-    learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
+    learn_end = config.learn * (anneal**config.learn_restart_advance)**config.restarts * anneal**(config.epochs - 1)
     log("final learning rate", "{:10.8f}".format(learn_end))
 
     # Training
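Note: with the defaults in this commit (learn=1e-2, learn_halve_every=16, learn_restart_advance=16, restarts=2, epochs=24), the final rate works out as follows; a standalone check:

    learn, halve_every, advance, restarts, epochs = 1e-2, 16, 16, 2, 24
    anneal = 0.5 ** (1 / halve_every)   # per-epoch decay, ~0.957603
    # restarts contribute (anneal**16)**2 == 0.25; the final run
    # contributes anneal**(epochs - 1) == 0.5**(23/16), ~0.369207
    learn_end = learn * (anneal ** advance) ** restarts * anneal ** (epochs - 1)
    print("{:10.8f}".format(learn_end))  # 0.00092302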
@@ -747,8 +766,9 @@ def run(program, args=[]):
 
     measure_error()
 
-    #if training:
-    #    model.save_weights(config.fn, overwrite=True)
+    if config.fn_save is not None:
+        log('saving weights', config.fn_save)
+        model.save_weights(config.fn_save, overwrite=True)
 
     # Evaluation
 