fix code i forgot to test, plus some tweaks
This commit is contained in:
parent
7bd5518650
commit
112e263056
2 changed files with 3 additions and 4 deletions
5
onn.py
5
onn.py
|
@ -1014,11 +1014,10 @@ def run(program, args=None):
|
||||||
#optim = 'ftml',
|
#optim = 'ftml',
|
||||||
#optim_decay1 = 2,
|
#optim_decay1 = 2,
|
||||||
#optim_decay2 = 100,
|
#optim_decay2 = 100,
|
||||||
#nesterov = False,
|
|
||||||
optim = 'adam', # note: most features only implemented for Adam
|
optim = 'adam', # note: most features only implemented for Adam
|
||||||
optim_decay1 = 24, # first momentum given in epochs (optional)
|
optim_decay1 = 24, # first momentum given in epochs (optional)
|
||||||
optim_decay2 = 100, # second momentum given in epochs (optional)
|
optim_decay2 = 100, # second momentum given in epochs (optional)
|
||||||
nesterov = True,
|
nesterov = True, # not available for all optimizers.
|
||||||
batch_size = 64,
|
batch_size = 64,
|
||||||
|
|
||||||
# learning parameters
|
# learning parameters
|
||||||
|
@ -1092,7 +1091,7 @@ def run(program, args=None):
|
||||||
def measure_error():
|
def measure_error():
|
||||||
def print_error(name, inputs, outputs, comparison=None):
|
def print_error(name, inputs, outputs, comparison=None):
|
||||||
predicted = model.forward(inputs)
|
predicted = model.forward(inputs)
|
||||||
err = ritual.measure(predicted, outputs)
|
err = ritual.mloss.forward(predicted, outputs)
|
||||||
if config.log10_loss:
|
if config.log10_loss:
|
||||||
print(name, "{:12.6e}".format(err))
|
print(name, "{:12.6e}".format(err))
|
||||||
if comparison:
|
if comparison:
|
||||||
|
|
|
@ -1025,7 +1025,7 @@ class Ritual: # i'm just making up names at this point.
|
||||||
avg_loss = self.cumsum_loss / _f(batch_count)
|
avg_loss = self.cumsum_loss / _f(batch_count)
|
||||||
return avg_loss, avg_mloss, self.losses, self.mlosses
|
return avg_loss, avg_mloss, self.losses, self.mlosses
|
||||||
elif return_losses:
|
elif return_losses:
|
||||||
return avg_mloss, mlosses
|
return avg_mloss, self.mlosses
|
||||||
return avg_mloss
|
return avg_mloss
|
||||||
|
|
||||||
def test_batched(self, inputs, outputs, *args, **kwargs):
|
def test_batched(self, inputs, outputs, *args, **kwargs):
|
||||||
|
|
Loading…
Add table
Reference in a new issue