This commit is contained in:
Connor Olding 2017-02-27 02:05:33 +00:00
parent 8337fdb39c
commit d442241e76

View file

@ -393,6 +393,20 @@ class Reshape(Layer):
assert dY.shape[0] == self.batch_size assert dY.shape[0] == self.batch_size
return dY.reshape(self.batch_size, *self.input_shape) return dY.reshape(self.batch_size, *self.input_shape)
class Flatten(Layer):
def make_shape(self, shape):
super().make_shape(shape)
self.output_shape = (np.prod(shape),)
return shape
def F(self, X):
self.batch_size = X.shape[0]
return X.reshape(self.batch_size, *self.output_shape)
def dF(self, dY):
assert dY.shape[0] == self.batch_size
return dY.reshape(self.batch_size, *self.input_shape)
class Affine(Layer): class Affine(Layer):
def __init__(self, a=1, b=0): def __init__(self, a=1, b=0):
super().__init__() super().__init__()