This commit is contained in:
Connor Olding 2017-01-10 19:06:42 -08:00
parent cf9010d25f
commit 5e06190acc

View File

@ -123,6 +123,7 @@ class Layer:
_layer_counters[kind] += 1
self.name = "{}_{}".format(kind, _layer_counters[kind])
self.size = None # total weight count (if any)
self.unsafe = False # aka gotta go fast mode
def __str__(self):
return self.name
@ -142,7 +143,8 @@ class Layer:
self.parents.append(parent)
def make_shape(self, shape):
assert shape is not None
if not self.unsafe:
assert shape is not None
if self.output_shape is None:
self.output_shape = shape
return shape
@ -150,7 +152,8 @@ class Layer:
# TODO: rename this multi and B crap to something actually relevant.
def multi(self, B):
assert len(B) == 1, self
if not self.unsafe:
assert len(B) == 1, self
return self.F(B[0])
def dmulti(self, dB):
@ -194,27 +197,33 @@ class Layer:
assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape)
def forward(self, lut):
assert len(self.parents) > 0, self
if not self.unsafe:
assert len(self.parents) > 0, self
B = []
for parent in self.parents:
# TODO: skip over irrelevant nodes (if any)
X = lut[parent]
self.validate_input(X)
if not self.unsafe:
self.validate_input(X)
B.append(X)
Y = self.multi(B)
self.validate_output(Y)
if not self.unsafe:
self.validate_output(Y)
return Y
def backward(self, lut):
assert len(self.children) > 0, self
if not self.unsafe:
assert len(self.children) > 0, self
dB = []
for child in self.children:
# TODO: skip over irrelevant nodes (if any)
dY = lut[child]
self.validate_output(dY)
if not self.unsafe:
self.validate_output(dY)
dB.append(dY)
dX = self.dmulti(dB)
self.validate_input(dX)
if not self.unsafe:
self.validate_input(dX)
return dX
# Final Layers
@ -337,13 +346,15 @@ class Dense(Layer):
# Model
class Model:
def __init__(self, x, y):
def __init__(self, x, y, unsafe=False):
assert isinstance(x, Layer), x
assert isinstance(y, Layer), y
self.x = x
self.y = y
self.ordered_nodes = self.traverse([], self.y)
self.make_weights()
for node in self.ordered_nodes:
node.unsafe = unsafe
def make_weights(self):
self.param_count = 0
@ -502,7 +513,7 @@ if __name__ == '__main__':
if last_size != output_samples:
y = y.feed(Dense(output_samples))
model = Model(x, y)
model = Model(x, y, unsafe=False)
node_names = ' '.join([str(node) for node in model.ordered_nodes])
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
@ -558,13 +569,13 @@ if __name__ == '__main__':
if i > 0:
log("restarting", i)
LR *= LRprod**config.LR_restart_advance
#optim.reset()
optim.reset()
assert inputs.shape[0] % config.batch_size == 0, \
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
batch_count = inputs.shape[0] // config.batch_size
for e in range(config.epochs):
indices = np.arange(len(inputs))
indices = np.arange(inputs.shape[0])
np.random.shuffle(indices)
shuffled_inputs = inputs[indices] / x_scale
shuffled_outputs = outputs[indices] / y_scale