remove some unused arguments
This commit is contained in:
parent
2e74c9160c
commit
058a779f6c
2 changed files with 2 additions and 7 deletions
|
@ -787,10 +787,6 @@ class GeluApprox(Layer):
|
|||
return dY * self.sig * (1 + self.a * (1 - self.sig))
|
||||
|
||||
class Softmax(Layer):
|
||||
def __init__(self, axis=-1):
|
||||
super().__init__()
|
||||
self.axis = int(axis)
|
||||
|
||||
def forward(self, X):
|
||||
alpha = np.max(X, axis=-1, keepdims=True)
|
||||
num = np.exp(X - alpha)
|
||||
|
@ -802,9 +798,8 @@ class Softmax(Layer):
|
|||
return (dY - np.sum(dY * self.sm, axis=-1, keepdims=True)) * self.sm
|
||||
|
||||
class LogSoftmax(Softmax):
|
||||
def __init__(self, axis=-1, eps=1e-6):
|
||||
def __init__(self, eps=1e-6):
|
||||
super().__init__()
|
||||
self.axis = int(axis)
|
||||
self.eps = _f(eps)
|
||||
|
||||
def forward(self, X):
|
||||
|
|
|
@ -200,7 +200,7 @@ logs = DotMap(
|
|||
)
|
||||
|
||||
def measure_error(quiet=False):
|
||||
def print_error(name, inputs, outputs, comparison=None):
|
||||
def print_error(name, inputs, outputs):
|
||||
loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
|
||||
|
||||
if not quiet:
|
||||
|
|
Loading…
Add table
Reference in a new issue