add Decimate and Undecimate layers

This commit is contained in:
Connor Olding 2017-08-02 06:47:15 +00:00
parent f28e8d3a54
commit 5074dcb2aa

63
onn.py
View file

@ -358,6 +358,69 @@ class AlphaDropout(Layer):
def backward(self, dY):
return dY * self.a * self.mask
class Decimate(Layer):
# simple decimaton layer that drops every other sample from the last axis.
def __init__(self, phase='even'):
super().__init__()
# phase is the set of samples we keep in the forward pass.
assert phase in ('even', 'odd'), phase
self.phase = phase
def make_shape(self, parent):
shape = parent.output_shape
self.input_shape = shape
divy = (shape[-1] + 1) // 2 if self.phase == 'even' else shape[-1] // 2
self.output_shape = tuple(list(shape[:-1]) + [divy])
self.dX = np.zeros(self.input_shape, dtype=_f)
def forward(self, X):
self.batch_size = X.shape[0]
if self.phase == 'even':
return X.ravel()[0::2].reshape(self.batch_size, *self.output_shape)
elif self.phase == 'odd':
return X.ravel()[1::2].reshape(self.batch_size, *self.output_shape)
def backward(self, dY):
assert dY.shape[0] == self.batch_size
dX = np.zeros((self.batch_size, *self.input_shape), dtype=_f)
if self.phase == 'even':
dX.ravel()[0::2] = dY.ravel()
elif self.phase == 'odd':
dX.ravel()[1::2] = dY.ravel()
return dX
class Undecimate(Layer):
# reverse operation of Decimate. not quite interpolation.
def __init__(self, phase='even'):
super().__init__()
# phase is the set of samples we keep in the backward pass.
assert phase in ('even', 'odd'), phase
self.phase = phase
def make_shape(self, parent):
shape = parent.output_shape
self.input_shape = shape
mult = shape[-1] * 2
self.output_shape = tuple(list(shape[:-1]) + [mult])
def forward(self, X):
self.batch_size = X.shape[0]
Y = np.zeros((self.batch_size, *self.output_shape), dtype=_f)
if self.phase == 'even':
Y.ravel()[0::2] = X.ravel()
elif self.phase == 'odd':
Y.ravel()[1::2] = X.ravel()
return Y
def backward(self, dY):
assert dY.shape[0] == self.batch_size
if self.phase == 'even':
return dY.ravel()[0::2].reshape(self.batch_size, *self.input_shape)
elif self.phase == 'odd':
return dY.ravel()[1::2].reshape(self.batch_size, *self.input_shape)
# Activations {{{2
class Selu(Layer):