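Fix Swish inits: Silu and GeluApprox now pass their scale to Swish.__init__ via super() instead of assigning self.scale directly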
This commit is contained in:
parent
bd1e80b8de
commit
b74e0941dc
1 changed file with 4 additions and 3 deletions
|
@ -94,6 +94,7 @@ class Swish(Activation):
|
|||
# note that Swish generalizes both SiLU and an approximation of GELU.
|
||||
|
||||
def __init__(self, scale=1.0):
|
||||
super().__init__()
|
||||
self.scale = _f(scale)
|
||||
|
||||
def forward(self, X):
|
||||
|
@ -108,15 +109,15 @@ class Swish(Activation):
|
|||
class Silu(Swish):
|
||||
# paper: https://arxiv.org/abs/1702.03118
|
||||
def __init__(self):
|
||||
self.scale = _1
|
||||
super().__init__(_1)
|
||||
|
||||
|
||||
class GeluApprox(Activation):
|
||||
class GeluApprox(Swish):
|
||||
# paper: https://arxiv.org/abs/1606.08415
|
||||
# plot: https://www.desmos.com/calculator/ydzgtccsld
|
||||
|
||||
def __init__(self):
|
||||
self.scale = _f(1.704)
|
||||
super().__init__(_f(1.704))
|
||||
|
||||
|
||||
class Softmax(Activation):
|
||||
|
|
Loading…
Reference in a new issue