.
This commit is contained in:
parent
cf9010d25f
commit
5e06190acc
1 changed files with 23 additions and 12 deletions
35
optim_nn.py
35
optim_nn.py
|
@ -123,6 +123,7 @@ class Layer:
|
|||
_layer_counters[kind] += 1
|
||||
self.name = "{}_{}".format(kind, _layer_counters[kind])
|
||||
self.size = None # total weight count (if any)
|
||||
self.unsafe = False # aka gotta go fast mode
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
@ -142,7 +143,8 @@ class Layer:
|
|||
self.parents.append(parent)
|
||||
|
||||
def make_shape(self, shape):
|
||||
assert shape is not None
|
||||
if not self.unsafe:
|
||||
assert shape is not None
|
||||
if self.output_shape is None:
|
||||
self.output_shape = shape
|
||||
return shape
|
||||
|
@ -150,7 +152,8 @@ class Layer:
|
|||
# TODO: rename this multi and B crap to something actually relevant.
|
||||
|
||||
def multi(self, B):
|
||||
assert len(B) == 1, self
|
||||
if not self.unsafe:
|
||||
assert len(B) == 1, self
|
||||
return self.F(B[0])
|
||||
|
||||
def dmulti(self, dB):
|
||||
|
@ -194,27 +197,33 @@ class Layer:
|
|||
assert Y.shape[1:] == self.output_shape, (str(self), Y.shape[1:], self.output_shape)
|
||||
|
||||
def forward(self, lut):
|
||||
assert len(self.parents) > 0, self
|
||||
if not self.unsafe:
|
||||
assert len(self.parents) > 0, self
|
||||
B = []
|
||||
for parent in self.parents:
|
||||
# TODO: skip over irrelevant nodes (if any)
|
||||
X = lut[parent]
|
||||
self.validate_input(X)
|
||||
if not self.unsafe:
|
||||
self.validate_input(X)
|
||||
B.append(X)
|
||||
Y = self.multi(B)
|
||||
self.validate_output(Y)
|
||||
if not self.unsafe:
|
||||
self.validate_output(Y)
|
||||
return Y
|
||||
|
||||
def backward(self, lut):
|
||||
assert len(self.children) > 0, self
|
||||
if not self.unsafe:
|
||||
assert len(self.children) > 0, self
|
||||
dB = []
|
||||
for child in self.children:
|
||||
# TODO: skip over irrelevant nodes (if any)
|
||||
dY = lut[child]
|
||||
self.validate_output(dY)
|
||||
if not self.unsafe:
|
||||
self.validate_output(dY)
|
||||
dB.append(dY)
|
||||
dX = self.dmulti(dB)
|
||||
self.validate_input(dX)
|
||||
if not self.unsafe:
|
||||
self.validate_input(dX)
|
||||
return dX
|
||||
|
||||
# Final Layers
|
||||
|
@ -337,13 +346,15 @@ class Dense(Layer):
|
|||
# Model
|
||||
|
||||
class Model:
|
||||
def __init__(self, x, y):
|
||||
def __init__(self, x, y, unsafe=False):
|
||||
assert isinstance(x, Layer), x
|
||||
assert isinstance(y, Layer), y
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.ordered_nodes = self.traverse([], self.y)
|
||||
self.make_weights()
|
||||
for node in self.ordered_nodes:
|
||||
node.unsafe = unsafe
|
||||
|
||||
def make_weights(self):
|
||||
self.param_count = 0
|
||||
|
@ -502,7 +513,7 @@ if __name__ == '__main__':
|
|||
if last_size != output_samples:
|
||||
y = y.feed(Dense(output_samples))
|
||||
|
||||
model = Model(x, y)
|
||||
model = Model(x, y, unsafe=False)
|
||||
|
||||
node_names = ' '.join([str(node) for node in model.ordered_nodes])
|
||||
log('{} nodes'.format(len(model.ordered_nodes)), node_names)
|
||||
|
@ -558,13 +569,13 @@ if __name__ == '__main__':
|
|||
if i > 0:
|
||||
log("restarting", i)
|
||||
LR *= LRprod**config.LR_restart_advance
|
||||
#optim.reset()
|
||||
optim.reset()
|
||||
|
||||
assert inputs.shape[0] % config.batch_size == 0, \
|
||||
"inputs is not evenly divisible by batch_size" # TODO: lift this restriction
|
||||
batch_count = inputs.shape[0] // config.batch_size
|
||||
for e in range(config.epochs):
|
||||
indices = np.arange(len(inputs))
|
||||
indices = np.arange(inputs.shape[0])
|
||||
np.random.shuffle(indices)
|
||||
shuffled_inputs = inputs[indices] / x_scale
|
||||
shuffled_outputs = outputs[indices] / y_scale
|
||||
|
|
Loading…
Reference in a new issue