refactor shape-handling a little
This commit is contained in:
parent
1d729b98aa
commit
c3e2dd56bf
2 changed files with 24 additions and 21 deletions
10
optim_nn.py
10
optim_nn.py
|
@@ -101,8 +101,9 @@ class LayerNorm(Layer):
|
|||
'beta': 'beta',
|
||||
}
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
self.features = shape[0]
|
||||
|
@@ -164,8 +165,9 @@ class Denses(Layer): # TODO: rename?
|
|||
self.axis = int(axis)
|
||||
self.size = None
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
if len(shape) != 2:
|
||||
return False
|
||||
|
||||
|
|
|
@@ -290,19 +290,18 @@ class Layer:
|
|||
def backward(self, dY):
|
||||
raise NotImplementedError("unimplemented", self)
|
||||
|
||||
def make_shape(self, parent):
|
||||
if self.input_shape == None:
|
||||
self.input_shape = parent.output_shape
|
||||
if self.output_shape == None:
|
||||
self.output_shape = self.input_shape
|
||||
|
||||
def do_feed(self, child):
|
||||
self.children.append(child)
|
||||
|
||||
def be_fed(self, parent):
|
||||
self.parents.append(parent)
|
||||
|
||||
def make_shape(self, shape):
|
||||
if not self.unsafe:
|
||||
assert shape is not None
|
||||
if self.output_shape is None:
|
||||
self.output_shape = shape
|
||||
return shape
|
||||
|
||||
# TODO: better names for these (still)
|
||||
|
||||
def _propagate(self, edges):
|
||||
|
@@ -318,15 +317,11 @@ class Layer:
|
|||
# general utility methods:
|
||||
|
||||
def is_compatible(self, parent):
|
||||
if self.input_shape is None:
|
||||
# inherit shape from output
|
||||
shape = self.make_shape(parent.output_shape)
|
||||
if shape is None:
|
||||
return False
|
||||
self.input_shape = shape
|
||||
return np.all(self.input_shape == parent.output_shape)
|
||||
|
||||
def feed(self, child):
|
||||
assert self.output_shape is not None, self
|
||||
child.make_shape(self)
|
||||
if not child.is_compatible(self):
|
||||
fmt = "{} is incompatible with {}: shape mismatch: {} vs. {}"
|
||||
raise LayerIncompatibility(fmt.format(self, child, self.output_shape, child.input_shape))
|
||||
|
@@ -408,8 +403,9 @@ class Reshape(Layer):
|
|||
return dY.reshape(self.batch_size, *self.input_shape)
|
||||
|
||||
class Flatten(Layer):
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
self.output_shape = (np.prod(shape),)
|
||||
return shape
|
||||
|
||||
|
@@ -534,8 +530,9 @@ class Dense(Layer):
|
|||
self.weight_init = init
|
||||
self.size = None
|
||||
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
self.nW = self.dim * shape[0]
|
||||
|
@@ -734,6 +731,8 @@ class Ritual: # i'm just making up names at this point
|
|||
batch_size, prev_batch_size) # TODO: lift this restriction
|
||||
prev_batch_size = batch_size
|
||||
|
||||
# same from hereon
|
||||
|
||||
if not test_only and self.learner.per_batch:
|
||||
self.learner.batch(b / batch_count)
|
||||
|
||||
|
@@ -786,6 +785,8 @@ class Ritual: # i'm just making up names at this point
|
|||
batch_inputs = inputs[ bi:bi+batch_size]
|
||||
batch_outputs = outputs[bi:bi+batch_size]
|
||||
|
||||
# same from hereon
|
||||
|
||||
if not test_only and self.learner.per_batch:
|
||||
self.learner.batch(b / batch_count)
|
||||
|
||||
|
|
Loading…
Reference in a new issue