optim/onn/initialization.py

39 lines
809 B
Python
Raw Normal View History

2018-01-21 14:04:25 -08:00
import numpy as np
# note: these are currently only implemented for 2D shapes.
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_zeros(size, ins=None, outs=None):
    """Return an array of zeros with shape ``size``.

    Args:
        size: Shape of the returned array (int or tuple of ints).
        ins: Unused; accepted so every initializer shares one signature.
        outs: Unused; accepted so every initializer shares one signature.

    Returns:
        A NumPy array of zeros with NumPy's default dtype for ``np.zeros``.
    """
    return np.zeros(shape=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_ones(size, ins=None, outs=None):
    """Return an array of ones with shape ``size``.

    Args:
        size: Shape of the returned array (int or tuple of ints).
        ins: Unused; accepted so every initializer shares one signature.
        outs: Unused; accepted so every initializer shares one signature.

    Returns:
        A NumPy array of ones with NumPy's default dtype for ``np.ones``.
    """
    return np.ones(shape=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_he_normal(size, ins, outs):
    """He (Kaiming) normal initialization.

    Draws samples from a zero-mean normal distribution whose variance
    is ``2 / ins``.

    Args:
        size: Shape of the returned array (passed to NumPy as ``size``).
        ins: Number of input units (fan-in); sets the spread.
        outs: Unused; accepted so every initializer shares one signature.

    Returns:
        A NumPy array of shape ``size`` drawn from the global NumPy RNG.
    """
    variance = 2 / ins
    return np.random.normal(0, np.sqrt(variance), size=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_he_uniform(size, ins, outs):
    """He (Kaiming) uniform initialization.

    Draws samples uniformly from ``[-limit, limit)`` where
    ``limit = sqrt(6 / ins)``.

    Args:
        size: Shape of the returned array (passed to NumPy as ``size``).
        ins: Number of input units (fan-in); sets the range.
        outs: Unused; accepted so every initializer shares one signature.

    Returns:
        A NumPy array of shape ``size`` drawn from the global NumPy RNG.
    """
    limit = np.sqrt(6 / ins)
    low, high = -limit, limit
    return np.random.uniform(low, high, size=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_glorot_normal(size, ins, outs):
    """Glorot (Xavier) normal initialization.

    Draws samples from a zero-mean normal distribution whose variance
    is ``2 / (ins + outs)``.

    Args:
        size: Shape of the returned array (passed to NumPy as ``size``).
        ins: Number of input units (fan-in).
        outs: Number of output units (fan-out).

    Returns:
        A NumPy array of shape ``size`` drawn from the global NumPy RNG.
    """
    fan_total = ins + outs
    std = np.sqrt(2 / fan_total)
    return np.random.normal(0, std, size=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def init_glorot_uniform(size, ins, outs):
    """Glorot (Xavier) uniform initialization.

    Draws samples uniformly from ``[-limit, limit)`` where
    ``limit = sqrt(6 / (ins + outs))``.

    Args:
        size: Shape of the returned array (passed to NumPy as ``size``).
        ins: Number of input units (fan-in).
        outs: Number of output units (fan-out).

    Returns:
        A NumPy array of shape ``size`` drawn from the global NumPy RNG.
    """
    fan_total = ins + outs
    limit = np.sqrt(6 / fan_total)
    return np.random.uniform(-limit, limit, size=size)
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
# Additional initializers.
def init_gaussian_unit(size, ins, outs):
    """Normal initialization with variance ``1 / ins`` (LeCun-style).

    Draws samples from a zero-mean normal distribution whose standard
    deviation is ``sqrt(1 / ins)``.

    Args:
        size: Shape of the returned array (passed to NumPy as ``size``).
        ins: Number of input units (fan-in); sets the spread.
        outs: Unused; accepted so every initializer shares one signature.

    Returns:
        A NumPy array of shape ``size`` drawn from the global NumPy RNG.
    """
    variance = 1 / ins
    return np.random.normal(0, np.sqrt(variance), size=size)