add minimal example to readme
This commit is contained in:
parent
264c3abd83
commit
9b85b49ee5
1 changed files with 45 additions and 0 deletions
45
README.md
45
README.md
|
@ -37,6 +37,51 @@ python 3.5+
|
|||
|
||||
numpy scipy h5py sklearn dotmap
|
||||
|
||||
## minimal example
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
from optim_nn_core import *
|
||||
bs = 500
|
||||
lr = 0.0005 * np.sqrt(bs)
|
||||
reg = L1L2(3.2e-5, 3.2e-4)
|
||||
final_reg = L1L2(3.2e-5, 1e-3)
|
||||
|
||||
def get_mnist(fn='mnist.npz'):
|
||||
with np.load(fn) as f:
|
||||
return f['X_train'], f['Y_train'], f['X_test'], f['Y_test']
|
||||
inputs, outputs, valid_inputs, valid_outputs = get_mnist()
|
||||
|
||||
x = Input(shape=inputs.shape[1:])
|
||||
y = x
|
||||
y = y.feed(Flatten())
|
||||
y = y.feed(Dense(y.output_shape[0], init=init_he_normal, reg_w=reg, reg_b=reg))
|
||||
y = y.feed(Relu())
|
||||
y = y.feed(Dense(y.output_shape[0], init=init_he_normal, reg_w=reg, reg_b=reg))
|
||||
y = y.feed(Dropout(0.05))
|
||||
y = y.feed(Relu())
|
||||
y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg))
|
||||
y = y.feed(Softmax())
|
||||
model = Model(x, y, unsafe=True)
|
||||
|
||||
optim = Adam()
|
||||
learner = SGDR(optim, epochs=20, rate=lr, restarts=2)
|
||||
ritual = Ritual(learner=learner, loss=CategoricalCrossentropy(), mloss=Accuracy())
|
||||
ritual.prepare(model)
|
||||
while learner.next():
|
||||
print("epoch", learner.epoch)
|
||||
mloss, _ = ritual.train_batched(inputs, outputs, batch_size=bs, return_losses=True)
|
||||
print("train accuracy", "{:6.2f}%".format(mloss * 100))
|
||||
|
||||
def print_error(name, inputs, outputs):
|
||||
loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
|
||||
predicted = ritual.model.forward(inputs, deterministic=True)
|
||||
print(name + " loss", "{:12.6e}".format(loss))
|
||||
print(name + " accuracy", "{:6.2f}%".format(mloss * 100))
|
||||
print_error("train", inputs, outputs)
|
||||
print_error("valid", valid_inputs, valid_outputs)
|
||||
```
|
||||
|
||||
## contributing
|
||||
|
||||
i'm just throwing this code out there,
|
||||
|
|
Loading…
Reference in a new issue