update example

This commit is contained in:
Connor Olding 2017-09-25 06:28:59 +00:00
parent a760c4841b
commit 5b6fd6259f

View file

@ -37,7 +37,7 @@ numpy scipy h5py sklearn dotmap
#!/usr/bin/env python3 #!/usr/bin/env python3
from onn_core import * from onn_core import *
bs = 500 bs = 500
lr = 0.0005 * np.sqrt(bs) lr = 0.01
reg = L1L2(3.2e-5, 3.2e-4) reg = L1L2(3.2e-5, 3.2e-4)
final_reg = L1L2(3.2e-5, 1e-3) final_reg = L1L2(3.2e-5, 1e-3)
@ -56,11 +56,11 @@ y = y.feed(Dropout(0.05))
y = y.feed(Relu()) y = y.feed(Relu())
y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg)) y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg))
y = y.feed(Softmax()) y = y.feed(Softmax())
model = Model(x, y, unsafe=True) model = Model(x, y, loss=CategoricalCrossentropy(), mloss=Accuracy(), unsafe=True)
optim = Adam() optim = Adam()
learner = SGDR(optim, epochs=20, rate=lr, restarts=2) learner = SGDR(optim, epochs=20, rate=lr, restarts=1)
ritual = Ritual(learner=learner, loss=CategoricalCrossentropy(), mloss=Accuracy()) ritual = Ritual(learner=learner)
ritual.prepare(model) ritual.prepare(model)
while learner.next(): while learner.next():
print("epoch", learner.epoch) print("epoch", learner.epoch)
@ -69,11 +69,11 @@ while learner.next():
def print_error(name, inputs, outputs): def print_error(name, inputs, outputs):
loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both') loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
predicted = ritual.model.forward(inputs, deterministic=True)
print(name + " loss", "{:12.6e}".format(loss)) print(name + " loss", "{:12.6e}".format(loss))
print(name + " accuracy", "{:6.2f}%".format(mloss * 100)) print(name + " accuracy", "{:6.2f}%".format(mloss * 100))
print_error("train", inputs, outputs) print_error("train", inputs, outputs)
print_error("valid", valid_inputs, valid_outputs) print_error("valid", valid_inputs, valid_outputs)
predicted = model.evaluate(inputs) # use this as you will!
``` ```
## contributing ## contributing