fix shape assertions
This commit is contained in:
parent
a78fc98215
commit
0619163447
2 changed files with 3 additions and 6 deletions
|
@ -106,8 +106,7 @@ class LayerNorm(Layer):
|
||||||
shape = parent.output_shape
|
shape = parent.output_shape
|
||||||
self.input_shape = shape
|
self.input_shape = shape
|
||||||
self.output_shape = shape
|
self.output_shape = shape
|
||||||
if len(shape) != 1:
|
assert len(shape) == 1, shape
|
||||||
return False
|
|
||||||
if self.affine:
|
if self.affine:
|
||||||
self.gamma.shape = (shape[0],)
|
self.gamma.shape = (shape[0],)
|
||||||
self.beta.shape = (shape[0],)
|
self.beta.shape = (shape[0],)
|
||||||
|
@ -159,8 +158,7 @@ class Denses(Layer): # TODO: rename?
|
||||||
def make_shape(self, parent):
|
def make_shape(self, parent):
|
||||||
shape = parent.output_shape
|
shape = parent.output_shape
|
||||||
self.input_shape = shape
|
self.input_shape = shape
|
||||||
if len(shape) != 2:
|
assert len(shape) == 2, shape
|
||||||
return False
|
|
||||||
|
|
||||||
assert -len(shape) <= self.axis < len(shape)
|
assert -len(shape) <= self.axis < len(shape)
|
||||||
self.axis = self.axis % len(shape)
|
self.axis = self.axis % len(shape)
|
||||||
|
|
|
@ -585,8 +585,7 @@ class Dense(Layer):
|
||||||
def make_shape(self, parent):
|
def make_shape(self, parent):
|
||||||
shape = parent.output_shape
|
shape = parent.output_shape
|
||||||
self.input_shape = shape
|
self.input_shape = shape
|
||||||
if len(shape) != 1:
|
assert len(shape) == 1, shape
|
||||||
return False
|
|
||||||
self.coeffs.shape = (shape[0], self.dim)
|
self.coeffs.shape = (shape[0], self.dim)
|
||||||
self.biases.shape = (1, self.dim)
|
self.biases.shape = (1, self.dim)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue