.
This commit is contained in:
parent
8baa7a267a
commit
d299520fd9
1 changed files with 34 additions and 10 deletions
44
optim_nn.py
44
optim_nn.py
|
@ -230,6 +230,23 @@ class Affine(Layer):
|
|||
def dF(self, dY):
|
||||
return dY * self.a
|
||||
|
||||
class Sigmoid(Layer): # aka Logistic
|
||||
def F(self, X):
|
||||
from scipy.special import expit as sigmoid
|
||||
self.sig = sigmoid(X)
|
||||
return X * self.sig
|
||||
|
||||
def dF(self, dY):
|
||||
return dY * self.sig * (1 - self.sig)
|
||||
|
||||
class Tanh(Layer):
|
||||
def F(self, X):
|
||||
self.sig = np.tanh(X)
|
||||
return X * self.sig
|
||||
|
||||
def dF(self, dY):
|
||||
return dY * (1 - self.sig * self.sig)
|
||||
|
||||
class Relu(Layer):
|
||||
def F(self, X):
|
||||
self.cond = X >= 0
|
||||
|
@ -239,7 +256,8 @@ class Relu(Layer):
|
|||
return np.where(self.cond, dY, 0)
|
||||
|
||||
class GeluApprox(Layer):
|
||||
# refer to https://www.desmos.com/calculator/ydzgtccsld
|
||||
# paper: https://arxiv.org/abs/1606.08415
|
||||
# plot: https://www.desmos.com/calculator/ydzgtccsld
|
||||
def F(self, X):
|
||||
from scipy.special import expit as sigmoid
|
||||
self.a = 1.704 * X
|
||||
|
@ -385,26 +403,32 @@ if __name__ == '__main__':
|
|||
config = DotMap(
|
||||
fn = 'ml/cie_mlp_min.h5',
|
||||
|
||||
batch_size = 64,
|
||||
|
||||
# multi-residual network parameters
|
||||
res_width = 12,
|
||||
res_depth = 3,
|
||||
res_block = 2, # normally 2
|
||||
res_multi = 4, # normally 1
|
||||
res_block = 2, # normally 2 for plain resnet
|
||||
res_multi = 4, # normally 1 for plain resnet
|
||||
|
||||
# style of resnet
|
||||
# only one is implemented so far
|
||||
parallel_style = 'batchless',
|
||||
activation = 'gelu',
|
||||
|
||||
optim = 'adam',
|
||||
nesterov = False, # only used with SGD or Adam
|
||||
momentum = 0.33, # only used with SGD
|
||||
epochs = 6, # 6
|
||||
|
||||
# learning parameters: SGD with restarts
|
||||
LR = 1e-2,
|
||||
restarts = 3, # 3
|
||||
epochs = 6,
|
||||
LR_halve_every = 2,
|
||||
restarts = 3,
|
||||
LR_restart_advance = 3,
|
||||
|
||||
# misc
|
||||
batch_size = 64,
|
||||
init = 'he_normal',
|
||||
loss = 'mse',
|
||||
|
||||
parallel_style = 'batchless',
|
||||
)
|
||||
|
||||
# toy CIE-2000 data
|
||||
|
@ -429,7 +453,7 @@ if __name__ == '__main__':
|
|||
y = x
|
||||
last_size = input_samples
|
||||
|
||||
activations = dict(relu=Relu, gelu=GeluApprox)
|
||||
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox)
|
||||
activation = activations[config.activation]
|
||||
|
||||
for blah in range(config.res_depth):
|
||||
|
|
Loading…
Reference in a new issue