From 9b85b49ee51af0b698f0a3e212e2aacc65a5467f Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sun, 18 Jun 2017 01:58:40 +0000 Subject: [PATCH] add minimal example to readme --- README.md | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/README.md b/README.md index effe3fb..57474f6 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,51 @@ python 3.5+ numpy scipy h5py sklearn dotmap +## minimal example + +```python +#!/usr/bin/env python3 +from optim_nn_core import * +bs = 500 +lr = 0.0005 * np.sqrt(bs) +reg = L1L2(3.2e-5, 3.2e-4) +final_reg = L1L2(3.2e-5, 1e-3) + +def get_mnist(fn='mnist.npz'): + with np.load(fn) as f: + return f['X_train'], f['Y_train'], f['X_test'], f['Y_test'] +inputs, outputs, valid_inputs, valid_outputs = get_mnist() + +x = Input(shape=inputs.shape[1:]) +y = x +y = y.feed(Flatten()) +y = y.feed(Dense(y.output_shape[0], init=init_he_normal, reg_w=reg, reg_b=reg)) +y = y.feed(Relu()) +y = y.feed(Dense(y.output_shape[0], init=init_he_normal, reg_w=reg, reg_b=reg)) +y = y.feed(Dropout(0.05)) +y = y.feed(Relu()) +y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg)) +y = y.feed(Softmax()) +model = Model(x, y, unsafe=True) + +optim = Adam() +learner = SGDR(optim, epochs=20, rate=lr, restarts=2) +ritual = Ritual(learner=learner, loss=CategoricalCrossentropy(), mloss=Accuracy()) +ritual.prepare(model) +while learner.next(): + print("epoch", learner.epoch) + mloss, _ = ritual.train_batched(inputs, outputs, batch_size=bs, return_losses=True) + print("train accuracy", "{:6.2f}%".format(mloss * 100)) + +def print_error(name, inputs, outputs): + loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both') + predicted = ritual.model.forward(inputs, deterministic=True) + print(name + " loss", "{:12.6e}".format(loss)) + print(name + " accuracy", "{:6.2f}%".format(mloss * 100)) +print_error("train", inputs, outputs) +print_error("valid", valid_inputs, valid_outputs) +``` + ## contributing i'm just throwing this code out there,