.
This commit is contained in:
parent
028209b699
commit
361aefcb29
109
optim_nn.py
109
optim_nn.py
|
@ -347,8 +347,17 @@ class Dense(Layer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = ni(dim)
|
self.dim = ni(dim)
|
||||||
self.output_shape = (dim,)
|
self.output_shape = (dim,)
|
||||||
self.size = None
|
|
||||||
self.weight_init = init
|
self.weight_init = init
|
||||||
|
self.size = None
|
||||||
|
|
||||||
|
def make_shape(self, shape):
|
||||||
|
super().make_shape(shape)
|
||||||
|
if len(shape) != 1:
|
||||||
|
return False
|
||||||
|
self.nW = self.dim * shape[0]
|
||||||
|
self.nb = self.dim
|
||||||
|
self.size = self.nW + self.nb
|
||||||
|
return shape
|
||||||
|
|
||||||
def init(self, W, dW):
|
def init(self, W, dW):
|
||||||
ins, outs = self.input_shape[0], self.output_shape[0]
|
ins, outs = self.input_shape[0], self.output_shape[0]
|
||||||
|
@ -363,15 +372,6 @@ class Dense(Layer):
|
||||||
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
|
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
|
||||||
self.biases.flat = 0
|
self.biases.flat = 0
|
||||||
|
|
||||||
def make_shape(self, shape):
|
|
||||||
super().make_shape(shape)
|
|
||||||
if len(shape) != 1:
|
|
||||||
return False
|
|
||||||
self.nW = self.dim * shape[0]
|
|
||||||
self.nb = self.dim
|
|
||||||
self.size = self.nW + self.nb
|
|
||||||
return shape
|
|
||||||
|
|
||||||
def F(self, X):
|
def F(self, X):
|
||||||
self.X = X
|
self.X = X
|
||||||
Y = X.dot(self.coeffs) \
|
Y = X.dot(self.coeffs) \
|
||||||
|
@ -381,7 +381,53 @@ class Dense(Layer):
|
||||||
def dF(self, dY):
|
def dF(self, dY):
|
||||||
dX = dY.dot(self.coeffs.T)
|
dX = dY.dot(self.coeffs.T)
|
||||||
self.dcoeffs[:] = self.X.T.dot(dY)
|
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||||
self.dbiases[:] = np.sum(dY, axis=0, keepdims=True)
|
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||||
|
return dX
|
||||||
|
|
||||||
|
class DenseOneLess(Dense):
|
||||||
|
def init(self, W, dW):
|
||||||
|
super().init(W, dW)
|
||||||
|
ins, outs = self.input_shape[0], self.output_shape[0]
|
||||||
|
assert ins == outs, (ins, outs)
|
||||||
|
|
||||||
|
def F(self, X):
|
||||||
|
np.fill_diagonal(self.coeffs, 0)
|
||||||
|
self.X = X
|
||||||
|
Y = X.dot(self.coeffs) \
|
||||||
|
+ self.biases
|
||||||
|
return Y
|
||||||
|
|
||||||
|
def dF(self, dY):
|
||||||
|
dX = dY.dot(self.coeffs.T)
|
||||||
|
self.dcoeffs[:] = self.X.T.dot(dY)
|
||||||
|
self.dbiases[:] = dY.sum(0, keepdims=True)
|
||||||
|
np.fill_diagonal(self.dcoeffs, 0)
|
||||||
|
return dX
|
||||||
|
|
||||||
|
class LayerNorm(Layer): # TODO: inherit Affine instead?
|
||||||
|
def __init__(self, eps=1e-3, axis=-1):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = nf(eps)
|
||||||
|
self.axis = int(axis)
|
||||||
|
|
||||||
|
def F(self, X):
|
||||||
|
self.center = X - np.mean(X, axis=self.axis, keepdims=True)
|
||||||
|
#self.var = np.var(X, axis=self.axis, keepdims=True) + self.eps
|
||||||
|
self.var = np.mean(np.square(self.center), axis=self.axis, keepdims=True) + self.eps
|
||||||
|
self.std = np.sqrt(self.var) + self.eps
|
||||||
|
Y = self.center / self.std
|
||||||
|
return Y
|
||||||
|
|
||||||
|
def dF(self, dY):
|
||||||
|
length = self.input_shape[self.axis]
|
||||||
|
|
||||||
|
dstd = dY * (-self.center / self.var)
|
||||||
|
dvar = dstd * (0.5 / self.std)
|
||||||
|
dcenter2 = dvar * (1 / length)
|
||||||
|
dcenter = dY * (1 / self.std)
|
||||||
|
dcenter += dcenter2 * (2 * self.center)
|
||||||
|
dX = dcenter - dcenter / length
|
||||||
|
|
||||||
return dX
|
return dX
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
|
@ -527,6 +573,8 @@ class Ritual:
|
||||||
self.update(model.dW, model.W)
|
self.update(model.dW, model.W)
|
||||||
|
|
||||||
batch_loss = self.measure(residual)
|
batch_loss = self.measure(residual)
|
||||||
|
if np.isnan(batch_loss):
|
||||||
|
raise Exception("nan")
|
||||||
cumsum_loss += batch_loss
|
cumsum_loss += batch_loss
|
||||||
if return_losses:
|
if return_losses:
|
||||||
losses.append(batch_loss)
|
losses.append(batch_loss)
|
||||||
|
@ -542,6 +590,9 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
||||||
y = x
|
y = x
|
||||||
last_size = x.output_shape[0]
|
last_size = x.output_shape[0]
|
||||||
|
|
||||||
|
FC = lambda size: Dense(size, init)
|
||||||
|
#FC = lambda size: DenseOneLess(size, init)
|
||||||
|
|
||||||
for d in range(depth):
|
for d in range(depth):
|
||||||
size = width
|
size = width
|
||||||
|
|
||||||
|
@ -558,7 +609,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
||||||
for i in range(block):
|
for i in range(block):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
z = z.feed(activation())
|
z = z.feed(activation())
|
||||||
z = z.feed(Dense(size, init))
|
z = z.feed(FC(size))
|
||||||
z.feed(merger)
|
z.feed(merger)
|
||||||
y = merger
|
y = merger
|
||||||
elif style == 'onelesssum':
|
elif style == 'onelesssum':
|
||||||
|
@ -575,7 +626,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
|
||||||
for i in range(block):
|
for i in range(block):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
z = z.feed(activation())
|
z = z.feed(activation())
|
||||||
z = z.feed(Dense(size, init))
|
z = z.feed(FC(size))
|
||||||
if needs_sum:
|
if needs_sum:
|
||||||
z.feed(merger)
|
z.feed(merger)
|
||||||
if needs_sum:
|
if needs_sum:
|
||||||
|
@ -630,7 +681,7 @@ def run(program, args=[]):
|
||||||
# misc
|
# misc
|
||||||
batch_size = 64,
|
batch_size = 64,
|
||||||
init = 'he_normal',
|
init = 'he_normal',
|
||||||
loss = SomethingElse(4/3),
|
loss = SomethingElse(),
|
||||||
mloss = 'mse',
|
mloss = 'mse',
|
||||||
restart_optim = True, # restarts also reset internal state of optimizer
|
restart_optim = True, # restarts also reset internal state of optimizer
|
||||||
unsafe = True, # aka gotta go fast mode
|
unsafe = True, # aka gotta go fast mode
|
||||||
|
@ -642,19 +693,9 @@ def run(program, args=[]):
|
||||||
config.pprint()
|
config.pprint()
|
||||||
|
|
||||||
# toy CIE-2000 data
|
# toy CIE-2000 data
|
||||||
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale
|
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, \
|
||||||
|
inputs, outputs, valid_inputs, valid_outputs, \
|
||||||
def read_data(fn):
|
x_scale, y_scale
|
||||||
data = np.load(fn)
|
|
||||||
try:
|
|
||||||
inputs, outputs = data['inputs'], data['outputs']
|
|
||||||
except KeyError:
|
|
||||||
# because i'm bad at video games.
|
|
||||||
inputs, outputs = data['arr_0'], data['arr_1']
|
|
||||||
return inputs, outputs
|
|
||||||
|
|
||||||
inputs, outputs = read_data("ml/cie_mlp_data.npz")
|
|
||||||
valid_inputs, valid_outputs = read_data("ml/cie_mlp_vdata.npz")
|
|
||||||
|
|
||||||
# Our Test Model
|
# Our Test Model
|
||||||
|
|
||||||
|
@ -666,14 +707,22 @@ def run(program, args=[]):
|
||||||
y = multiresnet(y,
|
y = multiresnet(y,
|
||||||
config.res_width, config.res_depth,
|
config.res_width, config.res_depth,
|
||||||
config.res_block, config.res_multi,
|
config.res_block, config.res_multi,
|
||||||
activation=activation, init=init)
|
activation=activation, init=init,
|
||||||
|
style=config.parallel_style)
|
||||||
if y.output_shape[0] != output_samples:
|
if y.output_shape[0] != output_samples:
|
||||||
y = y.feed(Dense(output_samples, init))
|
y = y.feed(Dense(output_samples, init))
|
||||||
|
|
||||||
model = Model(x, y, unsafe=config.unsafe)
|
model = Model(x, y, unsafe=config.unsafe)
|
||||||
|
|
||||||
node_names = ' '.join([str(node) for node in model.ordered_nodes])
|
if 0:
|
||||||
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
|
node_names = ' '.join([str(node) for node in model.ordered_nodes])
|
||||||
|
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
|
||||||
|
else:
|
||||||
|
for node in model.ordered_nodes:
|
||||||
|
children = [str(n) for n in node.children]
|
||||||
|
if len(children) > 0:
|
||||||
|
sep = '->'
|
||||||
|
print(str(node)+sep+('\n'+str(node)+sep).join(children))
|
||||||
log('parameters', model.param_count)
|
log('parameters', model.param_count)
|
||||||
|
|
||||||
training = config.epochs > 0 and config.restarts >= 0
|
training = config.epochs > 0 and config.restarts >= 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user