refactor shape-handling a little

Connor Olding 2017-04-10 08:26:38 +00:00
parent 1d729b98aa
commit c3e2dd56bf
2 changed files with 24 additions and 21 deletions
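In short: make_shape() now takes the parent layer itself rather than a bare shape tuple, each override reads parent.output_shape on its own and records input_shape, and feed() drives shape inference explicitly before the compatibility check. A minimal sketch of the changed calling convention (child and parent are placeholder names, not identifiers from the repo):

    # before this commit: a bare shape tuple was handed down
    child.make_shape(parent.output_shape)

    # after this commit: the child receives the parent and asks it for the shape
    child.make_shape(parent)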

View file

@@ -101,8 +101,9 @@ class LayerNorm(Layer):
         'beta': 'beta',
     }

-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.features = shape[0]
@@ -164,8 +165,9 @@ class Denses(Layer): # TODO: rename?
         self.axis = int(axis)
         self.size = None

-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 2:
             return False


View file

@@ -290,19 +290,18 @@ class Layer:
     def backward(self, dY):
         raise NotImplementedError("unimplemented", self)

+    def make_shape(self, parent):
+        if self.input_shape == None:
+            self.input_shape = parent.output_shape
+        if self.output_shape == None:
+            self.output_shape = self.input_shape
+
     def do_feed(self, child):
         self.children.append(child)

     def be_fed(self, parent):
         self.parents.append(parent)

-    def make_shape(self, shape):
-        if not self.unsafe:
-            assert shape is not None
-        if self.output_shape is None:
-            self.output_shape = shape
-        return shape
-
     # TODO: better names for these (still)

     def _propagate(self, edges):
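For reference, a self-contained toy of the new default Layer.make_shape behavior: a layer with no shapes of its own inherits them from the parent. Toy, src, and dst are made-up names for illustration, not classes from this repo:

    class Toy:
        # mirrors the new default: inherit any missing shape from the parent
        def __init__(self, input_shape=None, output_shape=None):
            self.input_shape = input_shape
            self.output_shape = output_shape

        def make_shape(self, parent):
            if self.input_shape is None:
                self.input_shape = parent.output_shape
            if self.output_shape is None:
                self.output_shape = self.input_shape

    src = Toy(input_shape=(3,), output_shape=(3,))
    dst = Toy()
    dst.make_shape(src)
    assert dst.input_shape == dst.output_shape == (3,)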
@@ -318,15 +317,11 @@ class Layer:
     # general utility methods:

     def is_compatible(self, parent):
-        if self.input_shape is None:
-            # inherit shape from output
-            shape = self.make_shape(parent.output_shape)
-            if shape is None:
-                return False
-            self.input_shape = shape
         return np.all(self.input_shape == parent.output_shape)

     def feed(self, child):
+        assert self.output_shape is not None, self
+        child.make_shape(self)
         if not child.is_compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
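With that, feed() becomes the single place where shapes get resolved: assert the parent already knows its output_shape, let the child derive its own shapes via make_shape(parent), then verify compatibility. A hypothetical wiring sketch (Input and Dense stand in for this repo's layer classes; exact constructors may differ):

    x = Input((100,))  # some source layer whose output_shape is already known
    d = Dense(10)
    x.feed(d)          # runs d.make_shape(x), then d.is_compatible(x)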
@@ -408,8 +403,9 @@ class Reshape(Layer):
         return dY.reshape(self.batch_size, *self.input_shape)


 class Flatten(Layer):
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         self.output_shape = (np.prod(shape),)
         return shape
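For example, a parent with output_shape (28, 28) now gives Flatten an input_shape of (28, 28) and an output_shape of (np.prod((28, 28)),), i.e. (784,), exactly as before the refactor; only where the shape comes from has changed.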
@@ -534,8 +530,9 @@ class Dense(Layer):
         self.weight_init = init
         self.size = None

-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.nW = self.dim * shape[0]
@@ -734,6 +731,8 @@ class Ritual: # i'm just making up names at this point
                     batch_size, prev_batch_size) # TODO: lift this restriction
                 prev_batch_size = batch_size

+            # same from hereon
+
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)

@@ -786,6 +785,8 @@ class Ritual: # i'm just making up names at this point
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]

+            # same from hereon
+
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)

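The two "# same from hereon" comments mark where the two Ritual loop variants converge on identical per-batch logic; a purely illustrative follow-up (not part of this commit) could hoist that shared tail into a helper, e.g.:

    def _shared_batch_step(self, b, batch_count, test_only):
        # hypothetical extraction of the duplicated tail of both loops
        if not test_only and self.learner.per_batch:
            self.learner.batch(b / batch_count)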