parent f12e408c7e
commit 028209b699
optim_nn.py (84 changed lines)
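
Summary: this commit makes weight initialization pluggable (Dense and multiresnet now accept an init function, chosen by name through the new inits dict), implements Model.save_weights as the h5py counterpart to load_weights, splits the single fn config entry into fn_load/fn_save, renames the LR* hyperparameters to learn*, and fixes load_weights to copy values into the existing weight arrays instead of rebinding them.
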
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -343,11 +343,12 @@ class GeluApprox(Layer):
         return dY * self.sig * (1 + self.a * (1 - self.sig))
 
 class Dense(Layer):
-    def __init__(self, dim):
+    def __init__(self, dim, init=init_he_uniform):
         super().__init__()
         self.dim = ni(dim)
         self.output_shape = (dim,)
         self.size = None
+        self.weight_init = init
 
     def init(self, W, dW):
         ins, outs = self.input_shape[0], self.output_shape[0]
@@ -359,7 +360,7 @@ class Dense(Layer):
         self.dcoeffs = self.dW[:self.nW].reshape(ins, outs)
         self.dbiases = self.dW[self.nW:].reshape(1, outs)
 
-        self.coeffs.flat = init_he_uniform(self.nW, ins, outs)
+        self.coeffs.flat = self.weight_init(self.nW, ins, outs)
         self.biases.flat = 0
 
     def make_shape(self, shape):
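
Note: an init function is called as self.weight_init(self.nW, ins, outs) and must return nW flat weights. For reference, a minimal sketch of what the two He-style initializers named in this commit conventionally compute (assumed from the He et al. 2015 scheme, not taken from this file):

    import numpy as np

    def init_he_normal(size, ins, outs):
        # normal with variance 2/fan_in
        return np.random.normal(0, np.sqrt(2 / ins), size)

    def init_he_uniform(size, ins, outs):
        # uniform on [-sqrt(6/fan_in), +sqrt(6/fan_in)],
        # which also has variance 2/fan_in
        return np.random.uniform(-np.sqrt(6 / ins), np.sqrt(6 / ins), size)
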
@@ -444,7 +445,9 @@ class Model:
         return self.dW
 
     def load_weights(self, fn):
-        # seemingly compatible with keras models at the moment
+        # seemingly compatible with keras' Dense layers.
+        # ignores any non-Dense layer types.
+        # TODO: assert file actually exists
         import h5py
         f = h5py.File(fn)
         weights = {}
@@ -459,11 +462,25 @@ class Model:
             a, b = i, i + 1
             b_name = "dense_{}".format(b)
             # TODO: write a Dense method instead of assigning directly
-            denses[a].coeffs = weights[b_name+'_W']
-            denses[a].biases = np.expand_dims(weights[b_name+'_b'], 0)
+            denses[a].coeffs[:] = weights[b_name+'_W']
+            denses[a].biases[:] = np.expand_dims(weights[b_name+'_b'], 0)
 
     def save_weights(self, fn, overwrite=False):
-        raise NotImplementedError("unimplemented", self)
+        import h5py
+        f = h5py.File(fn, 'w')
+
+        denses = [node for node in self.ordered_nodes if isinstance(node, Dense)]
+        for i in range(len(denses)):
+            a, b = i, i + 1
+            b_name = "dense_{}".format(b)
+            # TODO: write a Dense method instead of assigning directly
+            grp = f.create_group(b_name)
+            data = grp.create_dataset(b_name+'_W', denses[a].coeffs.shape, dtype=nf)
+            data[:] = denses[a].coeffs
+            data = grp.create_dataset(b_name+'_b', denses[a].biases.shape, dtype=nf)
+            data[:] = denses[a].biases
+
+        f.close()
 
 class Ritual:
     def __init__(self,
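
Note: the added [:] in load_weights is load-bearing. Mirroring dcoeffs/dbiases above, coeffs and biases are presumably reshaped views into the model's flat weight vector, so plain attribute assignment would rebind the name and detach the layer from that buffer, while slice assignment copies the loaded values into it. A plain-numpy illustration of the difference:

    import numpy as np

    W = np.zeros(6)            # flat parameter vector, as in Dense.init
    coeffs = W.reshape(2, 3)   # a view that shares W's memory

    coeffs[:] = np.arange(6).reshape(2, 3)  # in-place copy: W is updated
    assert W[5] == 5

    coeffs = np.ones((2, 3))   # rebinding: W no longer sees any changes
    assert W[0] == 0
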
@@ -519,7 +536,9 @@ class Ritual:
         else:
             return avg_loss
 
-def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batchless'):
+def multiresnet(x, width, depth, block=2, multi=1,
+                activation=Relu, style='batchless',
+                init=init_he_normal):
     y = x
     last_size = x.output_shape[0]
 
@@ -527,7 +546,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
         size = width
 
         if last_size != size:
-            y = y.feed(Dense(size))
+            y = y.feed(Dense(size, init))
 
         if style == 'batchless':
             skip = y
@@ -539,7 +558,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
                 for i in range(block):
                     if i > 0:
                         z = z.feed(activation())
-                    z = z.feed(Dense(size))
+                    z = z.feed(Dense(size, init))
                 z.feed(merger)
             y = merger
         elif style == 'onelesssum':
@@ -556,7 +575,7 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
                 for i in range(block):
                     if i > 0:
                         z = z.feed(activation())
-                    z = z.feed(Dense(size))
+                    z = z.feed(Dense(size, init))
                     if needs_sum:
                         z.feed(merger)
             if needs_sum:
@@ -570,6 +589,9 @@ def multiresnet(x, width, depth, block=2, multi=1, activation=Relu, style='batch
 
     return y
 
+inits = dict(he_normal=init_he_normal, he_uniform=init_he_uniform)
+activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
+
 def run(program, args=[]):
     import sys
     lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
@@ -580,7 +602,8 @@ def run(program, args=[]):
 
     from dotmap import DotMap
     config = DotMap(
-        fn = 'ml/cie_mlp_min.h5',
+        fn_load = None,
+        fn_save = 'optim_nn.h5',
         log_fn = 'losses.npz',
 
         # multi-residual network parameters
@@ -598,11 +621,11 @@ def run(program, args=[]):
         momentum = 0.33, # only used with SGD
 
         # learning parameters: SGD with restarts (kinda)
-        LR = 1e-2,
+        learn = 1e-2,
         epochs = 24,
-        LR_halve_every = 16,
+        learn_halve_every = 16,
         restarts = 2,
-        LR_restart_advance = 16,
+        learn_restart_advance = 16,
 
         # misc
         batch_size = 64,
@@ -635,16 +658,17 @@ def run(program, args=[]):
 
     # Our Test Model
 
+    init = inits[config.init]
+    activation = activations[config.activation]
+
     x = Input(shape=(input_samples,))
     y = x
-    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, elu=Elu, gelu=GeluApprox)
-    activation = activations[config.activation]
     y = multiresnet(y,
                     config.res_width, config.res_depth,
                     config.res_block, config.res_multi,
-                    activation=activation)
+                    activation=activation, init=init)
     if y.output_shape[0] != output_samples:
-        y = y.feed(Dense(output_samples))
+        y = y.feed(Dense(output_samples, init))
 
     model = Model(x, y, unsafe=config.unsafe)
 
@@ -654,14 +678,9 @@ def run(program, args=[]):
 
     training = config.epochs > 0 and config.restarts >= 0
 
-    if not training:
-        assert config.res_width == 12
-        assert config.res_depth == 3
-        assert config.res_block == 2
-        assert config.res_multi == 4
-        assert config.activation == 'relu'
-        assert config.parallel_style == 'batchless'
-        model.load_weights(config.fn)
+    if config.fn_load is not None:
+        log('loading weights', config.fn_load)
+        model.load_weights(config.fn_load)
 
     if config.optim == 'adam':
         assert not config.nesterov, "unimplemented"
@@ -686,13 +705,13 @@ def run(program, args=[]):
     loss = lookup_loss(config.loss)
     mloss = lookup_loss(config.mloss) if config.mloss else loss
 
-    anneal = 0.5**(1/config.LR_halve_every)
+    anneal = 0.5**(1/config.learn_halve_every)
     ritual = Ritual(optim=optim,
-                    learn_rate=config.LR, learn_anneal=anneal,
-                    learn_advance=config.LR_restart_advance,
+                    learn_rate=config.learn, learn_anneal=anneal,
+                    learn_advance=config.learn_restart_advance,
                     loss=loss, mloss=mloss)
 
-    learn_end = config.LR * (anneal**config.LR_restart_advance)**config.restarts * anneal**(config.epochs - 1)
+    learn_end = config.learn * (anneal**config.learn_restart_advance)**config.restarts * anneal**(config.epochs - 1)
     log("final learning rate", "{:10.8f}".format(learn_end))
 
     # Training
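
Note: with the defaults above (learn = 1e-2, learn_halve_every = 16, restarts = 2, learn_restart_advance = 16, epochs = 24), anneal = 0.5**(1/16), so learn_end = 1e-2 * (anneal**16)**2 * anneal**23 = 1e-2 * 0.25 * 0.5**(23/16), roughly 0.00092310.
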
@@ -747,8 +766,9 @@ def run(program, args=[]):
 
     measure_error()
 
-    #if training:
-    #    model.save_weights(config.fn, overwrite=True)
+    if config.fn_save is not None:
+        log('saving weights', config.fn_save)
+        model.save_weights(config.fn_save, overwrite=True)
 
     # Evaluation
 
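
Note: with these defaults every run now writes its final weights to optim_nn.h5, and a later run can resume from them by setting fn_load. A minimal sketch of inspecting the saved file (standard h5py calls; the keras-1-style layout of one dense_N group holding dense_N_W and dense_N_b datasets matches save_weights above):

    import h5py

    # list each saved Dense layer and its weight/bias shapes
    with h5py.File('optim_nn.h5', 'r') as f:
        for name, grp in f.items():
            print(name, grp[name + '_W'].shape, grp[name + '_b'].shape)
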