.
This commit is contained in:
parent
8337fdb39c
commit
d442241e76
1 changed files with 14 additions and 0 deletions
|
@ -393,6 +393,20 @@ class Reshape(Layer):
|
||||||
assert dY.shape[0] == self.batch_size
|
assert dY.shape[0] == self.batch_size
|
||||||
return dY.reshape(self.batch_size, *self.input_shape)
|
return dY.reshape(self.batch_size, *self.input_shape)
|
||||||
|
|
||||||
|
class Flatten(Layer):
|
||||||
|
def make_shape(self, shape):
|
||||||
|
super().make_shape(shape)
|
||||||
|
self.output_shape = (np.prod(shape),)
|
||||||
|
return shape
|
||||||
|
|
||||||
|
def F(self, X):
|
||||||
|
self.batch_size = X.shape[0]
|
||||||
|
return X.reshape(self.batch_size, *self.output_shape)
|
||||||
|
|
||||||
|
def dF(self, dY):
|
||||||
|
assert dY.shape[0] == self.batch_size
|
||||||
|
return dY.reshape(self.batch_size, *self.input_shape)
|
||||||
|
|
||||||
class Affine(Layer):
|
class Affine(Layer):
|
||||||
def __init__(self, a=1, b=0):
|
def __init__(self, a=1, b=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in a new issue