.
This commit is contained in:
parent
8baa7a267a
commit
d299520fd9
1 changed files with 34 additions and 10 deletions
44
optim_nn.py
44
optim_nn.py
|
@ -230,6 +230,23 @@ class Affine(Layer):
|
||||||
def dF(self, dY):
|
def dF(self, dY):
|
||||||
return dY * self.a
|
return dY * self.a
|
||||||
|
|
||||||
|
class Sigmoid(Layer): # aka Logistic
|
||||||
|
def F(self, X):
|
||||||
|
from scipy.special import expit as sigmoid
|
||||||
|
self.sig = sigmoid(X)
|
||||||
|
return X * self.sig
|
||||||
|
|
||||||
|
def dF(self, dY):
|
||||||
|
return dY * self.sig * (1 - self.sig)
|
||||||
|
|
||||||
|
class Tanh(Layer):
|
||||||
|
def F(self, X):
|
||||||
|
self.sig = np.tanh(X)
|
||||||
|
return X * self.sig
|
||||||
|
|
||||||
|
def dF(self, dY):
|
||||||
|
return dY * (1 - self.sig * self.sig)
|
||||||
|
|
||||||
class Relu(Layer):
|
class Relu(Layer):
|
||||||
def F(self, X):
|
def F(self, X):
|
||||||
self.cond = X >= 0
|
self.cond = X >= 0
|
||||||
|
@ -239,7 +256,8 @@ class Relu(Layer):
|
||||||
return np.where(self.cond, dY, 0)
|
return np.where(self.cond, dY, 0)
|
||||||
|
|
||||||
class GeluApprox(Layer):
|
class GeluApprox(Layer):
|
||||||
# refer to https://www.desmos.com/calculator/ydzgtccsld
|
# paper: https://arxiv.org/abs/1606.08415
|
||||||
|
# plot: https://www.desmos.com/calculator/ydzgtccsld
|
||||||
def F(self, X):
|
def F(self, X):
|
||||||
from scipy.special import expit as sigmoid
|
from scipy.special import expit as sigmoid
|
||||||
self.a = 1.704 * X
|
self.a = 1.704 * X
|
||||||
|
@ -385,26 +403,32 @@ if __name__ == '__main__':
|
||||||
config = DotMap(
|
config = DotMap(
|
||||||
fn = 'ml/cie_mlp_min.h5',
|
fn = 'ml/cie_mlp_min.h5',
|
||||||
|
|
||||||
batch_size = 64,
|
# multi-residual network parameters
|
||||||
|
|
||||||
res_width = 12,
|
res_width = 12,
|
||||||
res_depth = 3,
|
res_depth = 3,
|
||||||
res_block = 2, # normally 2
|
res_block = 2, # normally 2 for plain resnet
|
||||||
res_multi = 4, # normally 1
|
res_multi = 4, # normally 1 for plain resnet
|
||||||
|
|
||||||
|
# style of resnet
|
||||||
|
# only one is implemented so far
|
||||||
|
parallel_style = 'batchless',
|
||||||
activation = 'gelu',
|
activation = 'gelu',
|
||||||
|
|
||||||
optim = 'adam',
|
optim = 'adam',
|
||||||
nesterov = False, # only used with SGD or Adam
|
nesterov = False, # only used with SGD or Adam
|
||||||
momentum = 0.33, # only used with SGD
|
momentum = 0.33, # only used with SGD
|
||||||
epochs = 6, # 6
|
|
||||||
|
# learning parameters: SGD with restarts
|
||||||
LR = 1e-2,
|
LR = 1e-2,
|
||||||
restarts = 3, # 3
|
epochs = 6,
|
||||||
LR_halve_every = 2,
|
LR_halve_every = 2,
|
||||||
|
restarts = 3,
|
||||||
LR_restart_advance = 3,
|
LR_restart_advance = 3,
|
||||||
|
|
||||||
|
# misc
|
||||||
|
batch_size = 64,
|
||||||
init = 'he_normal',
|
init = 'he_normal',
|
||||||
loss = 'mse',
|
loss = 'mse',
|
||||||
|
|
||||||
parallel_style = 'batchless',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# toy CIE-2000 data
|
# toy CIE-2000 data
|
||||||
|
@ -429,7 +453,7 @@ if __name__ == '__main__':
|
||||||
y = x
|
y = x
|
||||||
last_size = input_samples
|
last_size = input_samples
|
||||||
|
|
||||||
activations = dict(relu=Relu, gelu=GeluApprox)
|
activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox)
|
||||||
activation = activations[config.activation]
|
activation = activations[config.activation]
|
||||||
|
|
||||||
for blah in range(config.res_depth):
|
for blah in range(config.res_depth):
|
||||||
|
|
Loading…
Reference in a new issue