This commit is contained in:
Connor Olding 2017-02-01 22:03:12 -08:00
parent 028209b699
commit 361aefcb29

View file

@ -347,8 +347,17 @@ class Dense(Layer):
super().__init__()
self.dim = ni(dim)
self.output_shape = (dim,)
self.size = None
self.weight_init = init
self.size = None
def make_shape(self, shape):
super().make_shape(shape)
if len(shape) != 1:
return False
self.nW = self.dim * shape[0]
self.nb = self.dim
self.size = self.nW + self.nb
return shape
def init(self, W, dW):
ins, outs = self.input_shape[0], self.output_shape[0]
@ -363,15 +372,6 @@ class Dense(Layer):
self.coeffs.flat = self.weight_init(self.nW, ins, outs)
self.biases.flat = 0
def make_shape(self, shape):
super().make_shape(shape)
if len(shape) != 1:
return False
self.nW = self.dim * shape[0]
self.nb = self.dim
self.size = self.nW + self.nb
return shape
def F(self, X):
self.X = X
Y = X.dot(self.coeffs) \
@ -381,7 +381,53 @@ class Dense(Layer):
def dF(self, dY):
dX = dY.dot(self.coeffs.T)
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = np.sum(dY, axis=0, keepdims=True)
self.dbiases[:] = dY.sum(0, keepdims=True)
return dX
class DenseOneLess(Dense):
def init(self, W, dW):
super().init(W, dW)
ins, outs = self.input_shape[0], self.output_shape[0]
assert ins == outs, (ins, outs)
def F(self, X):
np.fill_diagonal(self.coeffs, 0)
self.X = X
Y = X.dot(self.coeffs) \
+ self.biases
return Y
def dF(self, dY):
dX = dY.dot(self.coeffs.T)
self.dcoeffs[:] = self.X.T.dot(dY)
self.dbiases[:] = dY.sum(0, keepdims=True)
np.fill_diagonal(self.dcoeffs, 0)
return dX
class LayerNorm(Layer): # TODO: inherit Affine instead?
def __init__(self, eps=1e-3, axis=-1):
super().__init__()
self.eps = nf(eps)
self.axis = int(axis)
def F(self, X):
self.center = X - np.mean(X, axis=self.axis, keepdims=True)
#self.var = np.var(X, axis=self.axis, keepdims=True) + self.eps
self.var = np.mean(np.square(self.center), axis=self.axis, keepdims=True) + self.eps
self.std = np.sqrt(self.var) + self.eps
Y = self.center / self.std
return Y
def dF(self, dY):
length = self.input_shape[self.axis]
dstd = dY * (-self.center / self.var)
dvar = dstd * (0.5 / self.std)
dcenter2 = dvar * (1 / length)
dcenter = dY * (1 / self.std)
dcenter += dcenter2 * (2 * self.center)
dX = dcenter - dcenter / length
return dX
# Model
@ -527,6 +573,8 @@ class Ritual:
self.update(model.dW, model.W)
batch_loss = self.measure(residual)
if np.isnan(batch_loss):
raise Exception("nan")
cumsum_loss += batch_loss
if return_losses:
losses.append(batch_loss)
@ -542,6 +590,9 @@ def multiresnet(x, width, depth, block=2, multi=1,
y = x
last_size = x.output_shape[0]
FC = lambda size: Dense(size, init)
#FC = lambda size: DenseOneLess(size, init)
for d in range(depth):
size = width
@ -558,7 +609,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
for i in range(block):
if i > 0:
z = z.feed(activation())
z = z.feed(Dense(size, init))
z = z.feed(FC(size))
z.feed(merger)
y = merger
elif style == 'onelesssum':
@ -575,7 +626,7 @@ def multiresnet(x, width, depth, block=2, multi=1,
for i in range(block):
if i > 0:
z = z.feed(activation())
z = z.feed(Dense(size, init))
z = z.feed(FC(size))
if needs_sum:
z.feed(merger)
if needs_sum:
@ -630,7 +681,7 @@ def run(program, args=[]):
# misc
batch_size = 64,
init = 'he_normal',
loss = SomethingElse(4/3),
loss = SomethingElse(),
mloss = 'mse',
restart_optim = True, # restarts also reset internal state of optimizer
unsafe = True, # aka gotta go fast mode
@ -642,19 +693,9 @@ def run(program, args=[]):
config.pprint()
# toy CIE-2000 data
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, x_scale, y_scale
def read_data(fn):
data = np.load(fn)
try:
inputs, outputs = data['inputs'], data['outputs']
except KeyError:
# because i'm bad at video games.
inputs, outputs = data['arr_0'], data['arr_1']
return inputs, outputs
inputs, outputs = read_data("ml/cie_mlp_data.npz")
valid_inputs, valid_outputs = read_data("ml/cie_mlp_vdata.npz")
from ml.cie_mlp_data import rgbcompare, input_samples, output_samples, \
inputs, outputs, valid_inputs, valid_outputs, \
x_scale, y_scale
# Our Test Model
@ -666,14 +707,22 @@ def run(program, args=[]):
y = multiresnet(y,
config.res_width, config.res_depth,
config.res_block, config.res_multi,
activation=activation, init=init)
activation=activation, init=init,
style=config.parallel_style)
if y.output_shape[0] != output_samples:
y = y.feed(Dense(output_samples, init))
model = Model(x, y, unsafe=config.unsafe)
node_names = ' '.join([str(node) for node in model.ordered_nodes])
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
if 0:
node_names = ' '.join([str(node) for node in model.ordered_nodes])
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
else:
for node in model.ordered_nodes:
children = [str(n) for n in node.children]
if len(children) > 0:
sep = '->'
print(str(node)+sep+('\n'+str(node)+sep).join(children))
log('parameters', model.param_count)
training = config.epochs > 0 and config.restarts >= 0