diff --git a/optim_nn.py b/optim_nn.py
index 0031d79..7d50c05 100644
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -230,6 +230,23 @@ class Affine(Layer):
     def dF(self, dY):
         return dY * self.a
 
+class Sigmoid(Layer): # aka Logistic
+    def F(self, X):
+        from scipy.special import expit as sigmoid
+        self.sig = sigmoid(X)
+        return self.sig
+
+    def dF(self, dY):
+        return dY * self.sig * (1 - self.sig)
+
+class Tanh(Layer):
+    def F(self, X):
+        self.sig = np.tanh(X)
+        return self.sig
+
+    def dF(self, dY):
+        return dY * (1 - self.sig * self.sig)
+
 class Relu(Layer):
     def F(self, X):
         self.cond = X >= 0
@@ -239,7 +256,8 @@ class Relu(Layer):
         return np.where(self.cond, dY, 0)
 
 class GeluApprox(Layer):
-    # refer to https://www.desmos.com/calculator/ydzgtccsld
+    # paper: https://arxiv.org/abs/1606.08415
+    # plot: https://www.desmos.com/calculator/ydzgtccsld
     def F(self, X):
         from scipy.special import expit as sigmoid
         self.a = 1.704 * X
@@ -385,26 +403,32 @@ if __name__ == '__main__':
 
     config = DotMap(
         fn = 'ml/cie_mlp_min.h5',
-        batch_size = 64,
-
+        # multi-residual network parameters
         res_width = 12,
         res_depth = 3,
-        res_block = 2, # normally 2
-        res_multi = 4, # normally 1
+        res_block = 2, # normally 2 for plain resnet
+        res_multi = 4, # normally 1 for plain resnet
+
+        # style of resnet
+        # only one is implemented so far
+        parallel_style = 'batchless',
 
         activation = 'gelu',
         optim = 'adam',
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
-        epochs = 6, # 6
+
+        # learning parameters: SGD with restarts
         LR = 1e-2,
-        restarts = 3, # 3
+        epochs = 6,
         LR_halve_every = 2,
+        restarts = 3,
         LR_restart_advance = 3,
+
+        # misc
+        batch_size = 64,
         init = 'he_normal',
         loss = 'mse',
-
-        parallel_style = 'batchless',
     )
 
     # toy CIE-2000 data
@@ -429,7 +453,7 @@ if __name__ == '__main__':
     y = x
     last_size = input_samples
 
-    activations = dict(relu=Relu, gelu=GeluApprox)
+    activations = dict(sigmoid=Sigmoid, tanh=Tanh, relu=Relu, gelu=GeluApprox)
     activation = activations[config.activation]
 
     for blah in range(config.res_depth):
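
As a sanity check on the new layers, each `dF` can be compared against a central finite difference of `F`, since both activations are elementwise. Below is a minimal standalone sketch of that check; the bare `Layer` stub and the `check_grad` helper are hypothetical stand-ins for illustration, not part of this patch:

```python
import numpy as np
from scipy.special import expit as sigmoid

class Layer:  # hypothetical stub standing in for the repo's base class
    pass

class Sigmoid(Layer):  # as added in this patch
    def F(self, X):
        self.sig = sigmoid(X)
        return self.sig

    def dF(self, dY):
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
        return dY * self.sig * (1 - self.sig)

class Tanh(Layer):  # as added in this patch
    def F(self, X):
        self.sig = np.tanh(X)
        return self.sig

    def dF(self, dY):
        # d/dx tanh(x) = 1 - tanh(x)^2
        return dY * (1 - self.sig * self.sig)

def check_grad(layer, eps=1e-6):
    # hypothetical helper: elementwise central difference vs. dF
    X = np.random.randn(4, 3)
    layer.F(X)
    analytic = layer.dF(np.ones_like(X))
    numeric = (layer.F(X + eps) - layer.F(X - eps)) / (2 * eps)
    return np.allclose(analytic, numeric, atol=1e-5)

print(check_grad(Sigmoid()))  # expect True
print(check_grad(Tanh()))     # expect True
```

With the `activations` dict extended, setting `activation = 'sigmoid'` or `activation = 'tanh'` in the config selects the new layers.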