diff --git a/optim_nn_core.py b/optim_nn_core.py index f2b8b2f..e2be96d 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -393,6 +393,20 @@ class Reshape(Layer): assert dY.shape[0] == self.batch_size return dY.reshape(self.batch_size, *self.input_shape) +class Flatten(Layer): + def make_shape(self, shape): + super().make_shape(shape) + self.output_shape = (np.prod(shape),) + return shape + + def F(self, X): + self.batch_size = X.shape[0] + return X.reshape(self.batch_size, *self.output_shape) + + def dF(self, dY): + assert dY.shape[0] == self.batch_size + return dY.reshape(self.batch_size, *self.input_shape) + class Affine(Layer): def __init__(self, a=1, b=0): super().__init__()