From d442241e76508b8cc069dbfa50d79e0220e1d2db Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 27 Feb 2017 02:05:33 +0000 Subject: [PATCH] . --- optim_nn_core.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/optim_nn_core.py b/optim_nn_core.py index f2b8b2f..e2be96d 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -393,6 +393,20 @@ class Reshape(Layer): assert dY.shape[0] == self.batch_size return dY.reshape(self.batch_size, *self.input_shape) +class Flatten(Layer): + def make_shape(self, shape): + super().make_shape(shape) + self.output_shape = (np.prod(shape),) + return shape + + def F(self, X): + self.batch_size = X.shape[0] + return X.reshape(self.batch_size, *self.output_shape) + + def dF(self, dY): + assert dY.shape[0] == self.batch_size + return dY.reshape(self.batch_size, *self.input_shape) + class Affine(Layer): def __init__(self, a=1, b=0): super().__init__()