fix Swish inits

Connor Olding 2018-03-12 02:15:36 +01:00
parent bd1e80b8de
commit b74e0941dc


@@ -94,6 +94,7 @@ class Swish(Activation):
     # note that Swish generalizes both SiLU and an approximation of GELU.

     def __init__(self, scale=1.0):
+        super().__init__()
         self.scale = _f(scale)

     def forward(self, X):
@@ -108,15 +109,15 @@ class Swish(Activation):
 class Silu(Swish):
     # paper: https://arxiv.org/abs/1702.03118

     def __init__(self):
-        self.scale = _1
+        super().__init__(_1)

-class GeluApprox(Activation):
+class GeluApprox(Swish):
     # paper: https://arxiv.org/abs/1606.08415
     # plot: https://www.desmos.com/calculator/ydzgtccsld

     def __init__(self):
-        self.scale = _f(1.704)
+        super().__init__(_f(1.704))

 class Softmax(Activation):
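
For reference, a minimal self-contained sketch of the class hierarchy this diff produces. It assumes Swish.forward computes x * sigmoid(scale * x) and substitutes numpy for the library's helpers (_f, _1), so it is illustrative rather than the repository's actual code:

    # sketch only: approximates the fixed Swish/Silu/GeluApprox hierarchy
    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    class Swish:
        def __init__(self, scale=1.0):
            # stand-in for self.scale = _f(scale) in the library
            self.scale = np.float32(scale)

        def forward(self, X):
            # Swish: x * sigmoid(scale * x)
            return X * sigmoid(self.scale * X)

    class Silu(Swish):
        # SiLU is Swish with scale fixed at 1
        def __init__(self):
            super().__init__(1.0)

    class GeluApprox(Swish):
        # Swish with scale ~1.704 approximates GELU
        def __init__(self):
            super().__init__(1.704)

With this structure, Silu and GeluApprox no longer assign self.scale directly; they pass their fixed scale through Swish.__init__, which is what the commit fixes.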