refactor shape-handling a little
This commit is contained in:
parent
1d729b98aa
commit
c3e2dd56bf
2 changed files with 24 additions and 21 deletions
10
optim_nn.py
10
optim_nn.py
|
@ -101,8 +101,9 @@ class LayerNorm(Layer):
|
||||||
'beta': 'beta',
|
'beta': 'beta',
|
||||||
}
|
}
|
||||||
|
|
||||||
def make_shape(self, shape):
|
def make_shape(self, parent):
|
||||||
super().make_shape(shape)
|
shape = parent.output_shape
|
||||||
|
self.input_shape = shape
|
||||||
if len(shape) != 1:
|
if len(shape) != 1:
|
||||||
return False
|
return False
|
||||||
self.features = shape[0]
|
self.features = shape[0]
|
||||||
|
@ -164,8 +165,9 @@ class Denses(Layer): # TODO: rename?
|
||||||
self.axis = int(axis)
|
self.axis = int(axis)
|
||||||
self.size = None
|
self.size = None
|
||||||
|
|
||||||
def make_shape(self, shape):
|
def make_shape(self, parent):
|
||||||
super().make_shape(shape)
|
shape = parent.output_shape
|
||||||
|
self.input_shape = shape
|
||||||
if len(shape) != 2:
|
if len(shape) != 2:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -290,19 +290,18 @@ class Layer:
|
||||||
def backward(self, dY):
|
def backward(self, dY):
|
||||||
raise NotImplementedError("unimplemented", self)
|
raise NotImplementedError("unimplemented", self)
|
||||||
|
|
||||||
|
def make_shape(self, parent):
|
||||||
|
if self.input_shape == None:
|
||||||
|
self.input_shape = parent.output_shape
|
||||||
|
if self.output_shape == None:
|
||||||
|
self.output_shape = self.input_shape
|
||||||
|
|
||||||
def do_feed(self, child):
|
def do_feed(self, child):
|
||||||
self.children.append(child)
|
self.children.append(child)
|
||||||
|
|
||||||
def be_fed(self, parent):
|
def be_fed(self, parent):
|
||||||
self.parents.append(parent)
|
self.parents.append(parent)
|
||||||
|
|
||||||
def make_shape(self, shape):
|
|
||||||
if not self.unsafe:
|
|
||||||
assert shape is not None
|
|
||||||
if self.output_shape is None:
|
|
||||||
self.output_shape = shape
|
|
||||||
return shape
|
|
||||||
|
|
||||||
# TODO: better names for these (still)
|
# TODO: better names for these (still)
|
||||||
|
|
||||||
def _propagate(self, edges):
|
def _propagate(self, edges):
|
||||||
|
@ -318,15 +317,11 @@ class Layer:
|
||||||
# general utility methods:
|
# general utility methods:
|
||||||
|
|
||||||
def is_compatible(self, parent):
|
def is_compatible(self, parent):
|
||||||
if self.input_shape is None:
|
|
||||||
# inherit shape from output
|
|
||||||
shape = self.make_shape(parent.output_shape)
|
|
||||||
if shape is None:
|
|
||||||
return False
|
|
||||||
self.input_shape = shape
|
|
||||||
return np.all(self.input_shape == parent.output_shape)
|
return np.all(self.input_shape == parent.output_shape)
|
||||||
|
|
||||||
def feed(self, child):
|
def feed(self, child):
|
||||||
|
assert self.output_shape is not None, self
|
||||||
|
child.make_shape(self)
|
||||||
if not child.is_compatible(self):
|
if not child.is_compatible(self):
|
||||||
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
|
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
|
||||||
raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
|
raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
|
||||||
|
@ -408,8 +403,9 @@ class Reshape(Layer):
|
||||||
return dY.reshape(self.batch_size, *self.input_shape)
|
return dY.reshape(self.batch_size, *self.input_shape)
|
||||||
|
|
||||||
class Flatten(Layer):
|
class Flatten(Layer):
|
||||||
def make_shape(self, shape):
|
def make_shape(self, parent):
|
||||||
super().make_shape(shape)
|
shape = parent.output_shape
|
||||||
|
self.input_shape = shape
|
||||||
self.output_shape = (np.prod(shape),)
|
self.output_shape = (np.prod(shape),)
|
||||||
return shape
|
return shape
|
||||||
|
|
||||||
|
@ -534,8 +530,9 @@ class Dense(Layer):
|
||||||
self.weight_init = init
|
self.weight_init = init
|
||||||
self.size = None
|
self.size = None
|
||||||
|
|
||||||
def make_shape(self, shape):
|
def make_shape(self, parent):
|
||||||
super().make_shape(shape)
|
shape = parent.output_shape
|
||||||
|
self.input_shape = shape
|
||||||
if len(shape) != 1:
|
if len(shape) != 1:
|
||||||
return False
|
return False
|
||||||
self.nW = self.dim * shape[0]
|
self.nW = self.dim * shape[0]
|
||||||
|
@ -734,6 +731,8 @@ class Ritual: # i'm just making up names at this point
|
||||||
batch_size, prev_batch_size) # TODO: lift this restriction
|
batch_size, prev_batch_size) # TODO: lift this restriction
|
||||||
prev_batch_size = batch_size
|
prev_batch_size = batch_size
|
||||||
|
|
||||||
|
# same from hereon
|
||||||
|
|
||||||
if not test_only and self.learner.per_batch:
|
if not test_only and self.learner.per_batch:
|
||||||
self.learner.batch(b / batch_count)
|
self.learner.batch(b / batch_count)
|
||||||
|
|
||||||
|
@ -786,6 +785,8 @@ class Ritual: # i'm just making up names at this point
|
||||||
batch_inputs = inputs[ bi:bi+batch_size]
|
batch_inputs = inputs[ bi:bi+batch_size]
|
||||||
batch_outputs = outputs[bi:bi+batch_size]
|
batch_outputs = outputs[bi:bi+batch_size]
|
||||||
|
|
||||||
|
# same from hereon
|
||||||
|
|
||||||
if not test_only and self.learner.per_batch:
|
if not test_only and self.learner.per_batch:
|
||||||
self.learner.batch(b / batch_count)
|
self.learner.batch(b / batch_count)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue