diff --git a/onn_core.py b/onn_core.py index 15a6f0a..0c2f2bf 100644 --- a/onn_core.py +++ b/onn_core.py @@ -840,7 +840,7 @@ class Bias(Layer): shape = parent.output_shape self.input_shape = shape self.output_shape = shape - self.biases.shape = (self.dim,) + self.biases.shape = (shape[-1],) def forward(self, X): return X + self.biases.f