class Sparse(Layer):
    """Structured, sparsely-connected layer (work in progress).

    Each of the ``dim`` outputs is a weighted sum of ``con`` input features.
    Which input feeds which connection is fixed once, in ``make_indices``,
    by concatenating random permutations of the input indices; every input
    is therefore used either ``rounds - 1`` or ``rounds`` times, where
    ``rounds = ceil(dim * con / size_in)``.

    Roughly follows https://arxiv.org/abs/1812.01164

    TODO: (re)implement serialization.
    """

    def __init__(self, dim, con, init=init_he_uniform, reg=None):
        """Create the layer.

        dim:  number of output features.
        con:  number of input connections summed into each output.
        init: initializer handed to ``_new_weights`` for the coefficients.
        reg:  optional regularizer handed to ``_new_weights``.
        """
        super().__init__()
        self.dim = int(dim)
        self.con = int(con)
        # use the coerced value so the shape always holds a plain int,
        # even when the caller passed e.g. a float or a numpy scalar.
        self.output_shape = (self.dim,)
        self.coeffs = self._new_weights('coeffs', init=init, regularizer=reg)
        # populated by make_shape() -> make_indices().
        self.indices = None

    def make_shape(self, parent):
        """Adopt the parent's 1-D output shape and build the wiring.

        Sets ``input_shape``, ``size_in`` and the coefficient shape
        ``(con, dim)``, then generates the gather/scatter index tables.
        """
        shape = parent.output_shape
        self.input_shape = shape
        assert len(shape) == 1, shape
        self.coeffs.shape = (self.con, self.dim)
        self.size_in = shape[0]
        self.make_indices(self.size_in, self.con, self.dim)

    def make_indices(self, size_in, connectivity, size_out):
        """Generate the random wiring between inputs and connections.

        size_in:      number of input features.
        connectivity: connections per output.
        size_out:     number of outputs.

        Sets ``self.indices`` (length ``size_out * connectivity``): the input
        feature read by each connection slot, used as a gather in forward().
        Sets ``self.inv_ind`` (length ``rounds * size_in``): for every round
        and every input feature, the slot that read it in that round, used
        in backward() to route gradients back to the inputs.

        Raises ValueError if ``size_in`` or ``size_out * connectivity`` is
        not positive.
        """
        size_in = int(size_in)
        desired = int(size_out) * int(connectivity)
        if size_in <= 0:
            # the previous while-loop never advanced in this case.
            raise ValueError(
                "Sparse layer needs at least one input feature, got %r"
                % (size_in,))
        if desired <= 0:
            raise ValueError(
                "Sparse layer needs dim * con > 0, got %r" % (desired,))

        # number of full permutations required to cover every slot.
        rounds = -(-desired // size_in)

        basic = np.arange(size_in)
        positions = np.arange(size_in)
        indices = []
        inv_ind = []
        for r in range(rounds):
            # keep reshuffling the same array in place so the sequence of
            # RNG draws matches the original implementation.
            np.random.shuffle(basic)
            indices.append(basic.copy())
            # inverse[i] = global slot at which input i sits in this round.
            inverse = np.empty_like(basic)
            inverse[basic] = positions + r * size_in
            inv_ind.append(inverse)

        # the last permutation is cut short when desired is not a multiple
        # of size_in; inv_ind keeps its full length and backward() pads.
        self.indices = np.concatenate(indices)[:desired].copy()
        self.inv_ind = np.concatenate(inv_ind)

    def forward(self, X):
        """Gather the wired inputs and reduce them with the coefficients.

        X is indexed as ``X[:, indices]``, i.e. (batch, size_in).
        Returns an array of shape (batch, dim).
        """
        self.X = X
        # (batch, con * dim) -> (batch, con, dim); kept for backward().
        self.O = X[:, self.indices].reshape(len(X), self.con, self.dim)
        return np.sum(self.O * self.coeffs.f, 1)

    def backward(self, dY):
        """Accumulate the coefficient gradient and return the input gradient.

        dY is the gradient w.r.t. the output of forward(), (batch, dim).
        Returns the gradient w.r.t. X, shape (batch, size_in).
        """
        dY = np.expand_dims(dY, 1)  # (batch, 1, dim)
        self.coeffs.g += np.sum(dY * self.O, 0)
        dO = dY * self.coeffs.f  # (batch, con, dim)

        batch_size = len(dO)
        flat = dO.reshape(batch_size, -1)
        remainder = flat.shape[1] % self.size_in
        if remainder != 0:
            # slots of the truncated final permutation were never read in
            # forward(), so they contribute zero gradient.
            flat = np.pad(flat, ((0, 0), (0, self.size_in - remainder)),
                          mode='constant')
        # one row per permutation round, each laid out in input order.
        routed = flat[:, self.inv_ind].reshape(batch_size, -1, self.size_in)
        return routed.sum(1)