fix shape assertions

This commit is contained in:
Connor Olding 2017-04-10 10:34:58 +00:00
parent a78fc98215
commit 0619163447
2 changed files with 3 additions and 6 deletions

View file

@ -106,8 +106,7 @@ class LayerNorm(Layer):
shape = parent.output_shape shape = parent.output_shape
self.input_shape = shape self.input_shape = shape
self.output_shape = shape self.output_shape = shape
if len(shape) != 1: assert len(shape) == 1, shape
return False
if self.affine: if self.affine:
self.gamma.shape = (shape[0],) self.gamma.shape = (shape[0],)
self.beta.shape = (shape[0],) self.beta.shape = (shape[0],)
@ -159,8 +158,7 @@ class Denses(Layer): # TODO: rename?
def make_shape(self, parent): def make_shape(self, parent):
shape = parent.output_shape shape = parent.output_shape
self.input_shape = shape self.input_shape = shape
if len(shape) != 2: assert len(shape) == 2, shape
return False
assert -len(shape) <= self.axis < len(shape) assert -len(shape) <= self.axis < len(shape)
self.axis = self.axis % len(shape) self.axis = self.axis % len(shape)

View file

@ -585,8 +585,7 @@ class Dense(Layer):
def make_shape(self, parent): def make_shape(self, parent):
shape = parent.output_shape shape = parent.output_shape
self.input_shape = shape self.input_shape = shape
if len(shape) != 1: assert len(shape) == 1, shape
return False
self.coeffs.shape = (shape[0], self.dim) self.coeffs.shape = (shape[0], self.dim)
self.biases.shape = (1, self.dim) self.biases.shape = (1, self.dim)