add preliminary Sparse layer
parent 4a5084df48
commit 5cd4e8d1c1
1 changed file with 59 additions and 0 deletions
@@ -265,3 +265,62 @@ class CosineDense(Dense):
        dX = ddot @ self.coeffs.f.T + dX_norm / self.X_norm * self.X

        return dX


class Sparse(Layer):
    # (WIP)
    # roughly implements a structured, sparsely-connected layer.
    # paper: https://arxiv.org/abs/1812.01164

    # TODO: (re)implement serialization.

    def __init__(self, dim, con, init=init_he_uniform, reg=None):
        super().__init__()
        self.dim = int(dim)
        self.con = int(con)
        self.output_shape = (dim,)
        self.coeffs = self._new_weights('coeffs', init=init, regularizer=reg)
        self.indices = None

    def make_shape(self, parent):
        shape = parent.output_shape
        self.input_shape = shape
        assert len(shape) == 1, shape
        self.coeffs.shape = (self.con, self.dim)
        self.size_in = shape[0]
        self.make_indices(self.size_in, self.con, self.dim)

    def make_indices(self, size_in, connectivity, size_out):
        basic = np.arange(size_in)
        indices = []
        inv_ind = []
        count = 0
        desired = size_out * connectivity
        # TODO: replace with a for loop.
        while count < desired:
            np.random.shuffle(basic)
            indices.append(basic.copy())
            inverse = np.zeros_like(basic)
            inverse[basic] = np.arange(len(basic)) + count
            inv_ind.append(inverse)
            count += len(basic)
        self.indices = np.concatenate(indices)[:desired].copy()
        self.inv_ind = np.concatenate(inv_ind)

    def forward(self, X):
        self.X = X
        self.O = X[:,self.indices].reshape(len(X), self.con, self.dim)
        return np.sum(self.O * self.coeffs.f, 1)

    def backward(self, dY):
        dY = np.expand_dims(dY, 1)
        self.coeffs.g += np.sum(dY * self.O, 0)
        dO = dY * self.coeffs.f

        x = dO
        batch_size = len(x)
        x = x.reshape(batch_size, -1)
        if x.shape[1] % self.size_in != 0:
            x = np.pad(x, ((0, 0), (0, self.size_in - x.shape[1] % self.size_in)))
        x = x[:, self.inv_ind].reshape(batch_size, -1, self.size_in)
        return x.sum(1)
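For context on what the new layer computes: forward() gathers con randomly-chosen input units for each of the dim output units and takes a weighted sum over the connectivity axis. Below is a minimal standalone NumPy sketch of that computation; the repository's Layer/weights machinery is omitted and every name in it is illustrative, not part of the commit.

import numpy as np

size_in, con, dim, batch = 8, 3, 4, 2   # input width, connectivity, output width, batch size

# build a flat index array the same way make_indices does:
# concatenated shuffles of arange(size_in), truncated to con * dim entries
parts, count = [], 0
while count < con * dim:
    parts.append(np.random.permutation(size_in))
    count += size_in
indices = np.concatenate(parts)[:con * dim]

coeffs = np.random.randn(con, dim)      # one weight per (connection, output unit)
X = np.random.randn(batch, size_in)

# gather the selected inputs and reduce over the connectivity axis
O = X[:, indices].reshape(batch, con, dim)
Y = np.sum(O * coeffs, 1)               # shape (batch, dim), like forward() returns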
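The less obvious piece is backward(): instead of scatter-adding each gradient entry onto the input it was read from, it zero-pads the flattened gradient, permutes it with the inverse indices prepared by make_indices, and sums size_in-wide blocks. A quick standalone check, under the same illustrative assumptions as the sketch above, that this block trick agrees with a direct scatter-add:

import numpy as np

size_in, con, dim, batch = 10, 3, 4, 5
desired = con * dim

# indices and inverse indices, built the same way make_indices does
basic = np.arange(size_in)
idx_parts, inv_parts, count = [], [], 0
while count < desired:
    np.random.shuffle(basic)
    idx_parts.append(basic.copy())
    inverse = np.zeros_like(basic)
    inverse[basic] = np.arange(size_in) + count
    inv_parts.append(inverse)
    count += size_in
indices = np.concatenate(idx_parts)[:desired]
inv_ind = np.concatenate(inv_parts)

dO = np.random.randn(batch, con, dim)   # gradient w.r.t. the gathered values

# backward() as in the commit: flatten, zero-pad to a multiple of size_in,
# permute by the inverse indices, then sum the size_in-wide blocks
x = dO.reshape(batch, -1)
if x.shape[1] % size_in != 0:
    x = np.pad(x, ((0, 0), (0, size_in - x.shape[1] % size_in)))
dX = x[:, inv_ind].reshape(batch, -1, size_in).sum(1)

# reference: scatter-add every gradient entry onto the input it was read from
dX_ref = np.zeros((batch, size_in))
for k, j in enumerate(indices):
    dX_ref[:, j] += dO.reshape(batch, -1)[:, k]

print(np.allclose(dX, dX_ref))          # expected: True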