.
This commit is contained in:
parent
8337fdb39c
commit
d442241e76
1 changed files with 14 additions and 0 deletions
|
@ -393,6 +393,20 @@ class Reshape(Layer):
|
|||
assert dY.shape[0] == self.batch_size
|
||||
return dY.reshape(self.batch_size, *self.input_shape)
|
||||
|
||||
class Flatten(Layer):
|
||||
def make_shape(self, shape):
|
||||
super().make_shape(shape)
|
||||
self.output_shape = (np.prod(shape),)
|
||||
return shape
|
||||
|
||||
def F(self, X):
|
||||
self.batch_size = X.shape[0]
|
||||
return X.reshape(self.batch_size, *self.output_shape)
|
||||
|
||||
def dF(self, dY):
|
||||
assert dY.shape[0] == self.batch_size
|
||||
return dY.reshape(self.batch_size, *self.input_shape)
|
||||
|
||||
class Affine(Layer):
|
||||
def __init__(self, a=1, b=0):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in a new issue