update example
This commit is contained in:
parent
a760c4841b
commit
5b6fd6259f
1 changed files with 5 additions and 5 deletions
10
README.md
10
README.md
|
@ -37,7 +37,7 @@ numpy scipy h5py sklearn dotmap
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from onn_core import *
|
from onn_core import *
|
||||||
bs = 500
|
bs = 500
|
||||||
lr = 0.0005 * np.sqrt(bs)
|
lr = 0.01
|
||||||
reg = L1L2(3.2e-5, 3.2e-4)
|
reg = L1L2(3.2e-5, 3.2e-4)
|
||||||
final_reg = L1L2(3.2e-5, 1e-3)
|
final_reg = L1L2(3.2e-5, 1e-3)
|
||||||
|
|
||||||
|
@ -56,11 +56,11 @@ y = y.feed(Dropout(0.05))
|
||||||
y = y.feed(Relu())
|
y = y.feed(Relu())
|
||||||
y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg))
|
y = y.feed(Dense(10, init=init_glorot_uniform, reg_w=final_reg, reg_b=final_reg))
|
||||||
y = y.feed(Softmax())
|
y = y.feed(Softmax())
|
||||||
model = Model(x, y, unsafe=True)
|
model = Model(x, y, loss=CategoricalCrossentropy(), mloss=Accuracy(), unsafe=True)
|
||||||
|
|
||||||
optim = Adam()
|
optim = Adam()
|
||||||
learner = SGDR(optim, epochs=20, rate=lr, restarts=2)
|
learner = SGDR(optim, epochs=20, rate=lr, restarts=1)
|
||||||
ritual = Ritual(learner=learner, loss=CategoricalCrossentropy(), mloss=Accuracy())
|
ritual = Ritual(learner=learner)
|
||||||
ritual.prepare(model)
|
ritual.prepare(model)
|
||||||
while learner.next():
|
while learner.next():
|
||||||
print("epoch", learner.epoch)
|
print("epoch", learner.epoch)
|
||||||
|
@ -69,11 +69,11 @@ while learner.next():
|
||||||
|
|
||||||
def print_error(name, inputs, outputs):
|
def print_error(name, inputs, outputs):
|
||||||
loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
|
loss, mloss, _, _ = ritual.test_batched(inputs, outputs, bs, return_losses='both')
|
||||||
predicted = ritual.model.forward(inputs, deterministic=True)
|
|
||||||
print(name + " loss", "{:12.6e}".format(loss))
|
print(name + " loss", "{:12.6e}".format(loss))
|
||||||
print(name + " accuracy", "{:6.2f}%".format(mloss * 100))
|
print(name + " accuracy", "{:6.2f}%".format(mloss * 100))
|
||||||
print_error("train", inputs, outputs)
|
print_error("train", inputs, outputs)
|
||||||
print_error("valid", valid_inputs, valid_outputs)
|
print_error("valid", valid_inputs, valid_outputs)
|
||||||
|
predicted = model.evaluate(inputs) # use this as you will!
|
||||||
```
|
```
|
||||||
|
|
||||||
## contributing
|
## contributing
|
||||||
|
|
Loading…
Add table
Reference in a new issue