From c3e2dd56bfcbf5c454502248f4547e745b71d777 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Mon, 10 Apr 2017 08:26:38 +0000
Subject: [PATCH] refactor shape-handling a little

---
 optim_nn.py      | 10 ++++++----
 optim_nn_core.py | 35 ++++++++++++++++++-----------------
 2 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/optim_nn.py b/optim_nn.py
index f50dbe7..eddd8e5 100755
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -101,8 +101,9 @@ class LayerNorm(Layer):
         'beta': 'beta',
     }
 
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.features = shape[0]
@@ -164,8 +165,9 @@ class Denses(Layer): # TODO: rename?
         self.axis = int(axis)
         self.size = None
 
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 2:
             return False
 
diff --git a/optim_nn_core.py b/optim_nn_core.py
index 47316db..f39ee60 100644
--- a/optim_nn_core.py
+++ b/optim_nn_core.py
@@ -290,19 +290,18 @@ class Layer:
     def backward(self, dY):
         raise NotImplementedError("unimplemented", self)
 
+    def make_shape(self, parent):
+        if self.input_shape == None:
+            self.input_shape = parent.output_shape
+        if self.output_shape == None:
+            self.output_shape = self.input_shape
+
     def do_feed(self, child):
         self.children.append(child)
 
     def be_fed(self, parent):
         self.parents.append(parent)
 
-    def make_shape(self, shape):
-        if not self.unsafe:
-            assert shape is not None
-        if self.output_shape is None:
-            self.output_shape = shape
-        return shape
-
     # TODO: better names for these (still)
 
     def _propagate(self, edges):
@@ -318,15 +317,11 @@ class Layer:
     # general utility methods:
 
     def is_compatible(self, parent):
-        if self.input_shape is None:
-            # inherit shape from output
-            shape = self.make_shape(parent.output_shape)
-            if shape is None:
-                return False
-            self.input_shape = shape
         return np.all(self.input_shape == parent.output_shape)
 
     def feed(self, child):
+        assert self.output_shape is not None, self
+        child.make_shape(self)
         if not child.is_compatible(self):
             fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
             raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
@@ -408,8 +403,9 @@ class Reshape(Layer):
         return dY.reshape(self.batch_size, *self.input_shape)
 
 class Flatten(Layer):
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         self.output_shape = (np.prod(shape),)
         return shape
 
@@ -534,8 +530,9 @@ class Dense(Layer):
         self.weight_init = init
         self.size = None
 
-    def make_shape(self, shape):
-        super().make_shape(shape)
+    def make_shape(self, parent):
+        shape = parent.output_shape
+        self.input_shape = shape
         if len(shape) != 1:
             return False
         self.nW = self.dim * shape[0]
@@ -734,6 +731,8 @@ class Ritual: # i'm just making up names at this point
                     batch_size, prev_batch_size)  # TODO: lift this restriction
                 prev_batch_size = batch_size
 
+            # same from hereon
+
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
 
@@ -786,6 +785,8 @@ class Ritual: # i'm just making up names at this point
             batch_inputs  = inputs[ bi:bi+batch_size]
             batch_outputs = outputs[bi:bi+batch_size]
 
+            # same from hereon
+
             if not test_only and self.learner.per_batch:
                 self.learner.batch(b / batch_count)
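
Note for orientation: below is a minimal, self-contained sketch of the shape negotiation this patch introduces. After the refactor, a parent's feed() asserts that its own output_shape is already settled, asks the child to derive its shapes via make_shape(parent), and only then runs the now side-effect-free is_compatible() check. The Input class and the simplified Dense here are illustrative stand-ins, not the actual optim_nn / optim_nn_core implementations.

# illustrative sketch only: simplified stand-ins for the real Layer classes,
# showing the post-patch flow (feed -> child.make_shape -> is_compatible).

import numpy as np

class LayerIncompatibility(Exception):
    pass

class Layer:
    def __init__(self):
        self.input_shape = None
        self.output_shape = None
        self.parents = []
        self.children = []

    def make_shape(self, parent):
        # default behavior: inherit the parent's output shape verbatim.
        if self.input_shape is None:
            self.input_shape = parent.output_shape
        if self.output_shape is None:
            self.output_shape = self.input_shape

    def is_compatible(self, parent):
        # pure check; shape inference already happened in make_shape.
        return np.all(self.input_shape == parent.output_shape)

    def feed(self, child):
        # the parent's shape must be settled before the child derives its own.
        assert self.output_shape is not None, self
        child.make_shape(self)
        if not child.is_compatible(self):
            fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
            raise LayerIncompatibility(fmt.format(
                self, child, self.output_shape, child.input_shape))
        self.children.append(child)
        child.parents.append(self)
        return child

class Input(Layer):
    # hypothetical source layer whose shape is fixed up front.
    def __init__(self, shape):
        super().__init__()
        self.input_shape = tuple(shape)
        self.output_shape = tuple(shape)

class Dense(Layer):
    # simplified: only the shape bookkeeping, no weights.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        if len(shape) != 1:
            return False
        self.output_shape = (self.dim,)

# wiring layers together now drives shape inference parent-to-child:
x = Input((784,))
y = Dense(10)
x.feed(y)
print(y.input_shape, y.output_shape)  # -> (784,) (10,)

As the diff itself shows, the effect of the refactor is that is_compatible() no longer mutates the child; all shape inference is concentrated in the explicit child.make_shape(self) call made from feed().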