refactor shape-handling a little

Connor Olding 2017-04-10 08:26:38 +00:00
parent 1d729b98aa
commit c3e2dd56bf
2 changed files with 24 additions and 21 deletions

Changed file 1 of 2

@@ -101,8 +101,9 @@ class LayerNorm(Layer):
             'beta': 'beta',
         }
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.features = shape[0]
@@ -164,8 +165,9 @@ class Denses(Layer): # TODO: rename?
         self.axis = int(axis)
         self.size = None
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 2:
             return False

Changed file 2 of 2

@@ -290,19 +290,18 @@ class Layer:
     def backward(self, dY):
         raise NotImplementedError("unimplemented", self)
+    def make_shape(self, parent):
+        if self.input_shape == None:
+            self.input_shape = parent.output_shape
+        if self.output_shape == None:
+            self.output_shape = self.input_shape
     def do_feed(self, child):
         self.children.append(child)
     def be_fed(self, parent):
         self.parents.append(parent)
-    def make_shape(self, shape):
-        if not self.unsafe:
-            assert shape is not None
-        if self.output_shape is None:
-            self.output_shape = shape
-        return shape
     # TODO: better names for these (still)
     def _propagate(self, edges):
@@ -318,15 +317,11 @@ class Layer:
     # general utility methods:
     def is_compatible(self, parent):
-        if self.input_shape is None:
-            # inherit shape from output
-            shape = self.make_shape(parent.output_shape)
-            if shape is None:
-                return False
-            self.input_shape = shape
         return np.all(self.input_shape == parent.output_shape)
     def feed(self, child):
         assert self.output_shape is not None, self
+        child.make_shape(self)
         if not child.is_compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
@@ -408,8 +403,9 @@ class Reshape(Layer):
         return dY.reshape(self.batch_size, *self.input_shape)
 class Flatten(Layer):
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         self.output_shape = (np.prod(shape),)
         return shape
@@ -534,8 +530,9 @@ class Dense(Layer):
         self.weight_init = init
         self.size = None
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.nW = self.dim * shape[0]
@@ -734,6 +731,8 @@ class Ritual: # i'm just making up names at this point
                 batch_size, prev_batch_size) # TODO: lift this restriction
             prev_batch_size = batch_size
+            # same from hereon
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
@@ -786,6 +785,8 @@ class Ritual: # i'm just making up names at this point
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
+            # same from hereon
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
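
Taken together, the hunks above invert the shape-handling flow: feed() now hands the child the parent layer itself via child.make_shape(self), the child reads parent.output_shape (falling back to the base-class default of inheriting both shapes), and is_compatible() only has to check that the resolved shapes agree. Below is a minimal standalone sketch of that flow, not the library's actual code; the Input layer, the simplified feed() bookkeeping, and the ValueError standing in for LayerIncompatibility are illustrative assumptions.

import numpy as np

class Layer:
    def __init__(self):
        self.parents, self.children = [], []
        self.input_shape = None
        self.output_shape = None

    def make_shape(self, parent):
        # default: inherit both shapes from the parent's output
        if self.input_shape is None:
            self.input_shape = parent.output_shape
        if self.output_shape is None:
            self.output_shape = self.input_shape

    def is_compatible(self, parent):
        return np.all(self.input_shape == parent.output_shape)

    def feed(self, child):
        assert self.output_shape is not None, self
        child.make_shape(self)               # shapes get resolved at wiring time
        if not child.is_compatible(self):
            # the library raises LayerIncompatibility here
            raise ValueError("shape mismatch", self.output_shape, child.input_shape)
        self.children.append(child)
        child.parents.append(self)
        return child

class Input(Layer):
    # hypothetical source layer: fixes an output_shape so wiring can begin
    def __init__(self, shape):
        super().__init__()
        self.input_shape = self.output_shape = tuple(shape)

class Flatten(Layer):
    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        self.output_shape = (int(np.prod(shape)),)

x = Input((4, 4))
y = x.feed(Flatten())
print(y.input_shape, y.output_shape)         # -> (4, 4) (16,)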