fix Swish inits

This commit is contained in:
Connor Olding 2018-03-12 02:15:36 +01:00
parent bd1e80b8de
commit b74e0941dc

View File

@@ -94,6 +94,7 @@ class Swish(Activation):
# note that Swish generalizes both SiLU and an approximation of GELU. # note that Swish generalizes both SiLU and an approximation of GELU.
def __init__(self, scale=1.0): def __init__(self, scale=1.0):
super().__init__()
self.scale = _f(scale) self.scale = _f(scale)
def forward(self, X): def forward(self, X):
@@ -108,15 +109,15 @@ class Swish(Activation):
class Silu(Swish): class Silu(Swish):
# paper: https://arxiv.org/abs/1702.03118 # paper: https://arxiv.org/abs/1702.03118
def __init__(self): def __init__(self):
self.scale = _1 super().__init__(_1)
class GeluApprox(Activation): class GeluApprox(Swish):
# paper: https://arxiv.org/abs/1606.08415 # paper: https://arxiv.org/abs/1606.08415
# plot: https://www.desmos.com/calculator/ydzgtccsld # plot: https://www.desmos.com/calculator/ydzgtccsld
def __init__(self): def __init__(self):
self.scale = _f(1.704) super().__init__(_f(1.704))
class Softmax(Activation): class Softmax(Activation):