fix shape assertions
This commit is contained in:
parent
a78fc98215
commit
0619163447
2 changed files with 3 additions and 6 deletions
|
@ -106,8 +106,7 @@ class LayerNorm(Layer):
|
|||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
self.output_shape = shape
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
assert len(shape) == 1, shape
|
||||
if self.affine:
|
||||
self.gamma.shape = (shape[0],)
|
||||
self.beta.shape = (shape[0],)
|
||||
|
@ -159,8 +158,7 @@ class Denses(Layer): # TODO: rename?
|
|||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
if len(shape) != 2:
|
||||
return False
|
||||
assert len(shape) == 2, shape
|
||||
|
||||
assert -len(shape) <= self.axis < len(shape)
|
||||
self.axis = self.axis % len(shape)
|
||||
|
|
|
@ -585,8 +585,7 @@ class Dense(Layer):
|
|||
def make_shape(self, parent):
|
||||
shape = parent.output_shape
|
||||
self.input_shape = shape
|
||||
if len(shape) != 1:
|
||||
return False
|
||||
assert len(shape) == 1, shape
|
||||
self.coeffs.shape = (shape[0], self.dim)
|
||||
self.biases.shape = (1, self.dim)
|
||||
|
||||
|
|
Loading…
Reference in a new issue