diff --git a/onn/parametric.py b/onn/parametric.py index 65008a2..24c4dca 100644 --- a/onn/parametric.py +++ b/onn/parametric.py @@ -62,6 +62,33 @@ class Dense(Layer): return dY @ self.coeffs.f.T +class DenseUnbiased(Layer): + serialized = { + 'W': 'coeffs', + } + + def __init__(self, dim, init=init_he_uniform, reg_w=None): + super().__init__() + self.dim = int(dim) + self.output_shape = (dim,) + self.coeffs = self._new_weights('coeffs', init=init, + regularizer=reg_w) + + def make_shape(self, parent): + shape = parent.output_shape + self.input_shape = shape + assert len(shape) == 1, shape + self.coeffs.shape = (shape[0], self.dim) + + def forward(self, X): + self.X = X + return X @ self.coeffs.f + + def backward(self, dY): + self.coeffs.g += self.X.T @ dY + return dY @ self.coeffs.f.T + + # more class Conv1Dper(Layer):