fix epoch incrementing
This commit is contained in:
parent
19a9583da9
commit
1d729b98aa
|
@ -859,11 +859,11 @@ def run(program, args=None):
|
||||||
|
|
||||||
if config.log10_loss:
|
if config.log10_loss:
|
||||||
fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
|
fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
|
||||||
log("info", fmt.format(learner.epoch + 1, learner.rate, np.log10(avg_loss)),
|
log("info", fmt.format(learner.epoch, learner.rate, np.log10(avg_loss)),
|
||||||
update=True)
|
update=True)
|
||||||
else:
|
else:
|
||||||
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:12.6e}"
|
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:12.6e}"
|
||||||
log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss),
|
log("info", fmt.format(learner.epoch, learner.rate, avg_loss),
|
||||||
update=True)
|
update=True)
|
||||||
|
|
||||||
measure_error()
|
measure_error()
|
||||||
|
|
|
@ -845,6 +845,7 @@ class Learner:
|
||||||
@epoch.setter
|
@epoch.setter
|
||||||
def epoch(self, new_epoch):
|
def epoch(self, new_epoch):
|
||||||
self._epoch = int(new_epoch)
|
self._epoch = int(new_epoch)
|
||||||
|
if 0 <= self.epoch <= self.epochs:
|
||||||
self.rate = self.rate_at(self._epoch)
|
self.rate = self.rate_at(self._epoch)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -860,13 +861,11 @@ class Learner:
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
# prepares the next epoch. returns whether or not to continue training.
|
# prepares the next epoch. returns whether or not to continue training.
|
||||||
if self.epoch + 1 >= self.epochs:
|
if not self.started:
|
||||||
return False
|
|
||||||
if self.started:
|
|
||||||
self.epoch += 1
|
|
||||||
else:
|
|
||||||
self.started = True
|
self.started = True
|
||||||
self.epoch = self.epoch # poke property setter just in case
|
self.epoch += 1
|
||||||
|
if self.epoch > self.epochs:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def batch(self, progress): # TODO: rename
|
def batch(self, progress): # TODO: rename
|
||||||
|
@ -874,7 +873,7 @@ class Learner:
|
||||||
# unlike epochs, we do not store batch number as a state.
|
# unlike epochs, we do not store batch number as a state.
|
||||||
# i.e. calling next() will not respect progress.
|
# i.e. calling next() will not respect progress.
|
||||||
assert 0 <= progress <= 1
|
assert 0 <= progress <= 1
|
||||||
self.rate = self.rate_at(self._epoch + progress)
|
self.rate = self.rate_at(self._epoch - 1 + progress)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def final_rate(self):
|
def final_rate(self):
|
||||||
|
@ -921,7 +920,7 @@ class SGDR(Learner):
|
||||||
def split_num(self, epoch):
|
def split_num(self, epoch):
|
||||||
shit = [0] + self.splits # hack
|
shit = [0] + self.splits # hack
|
||||||
for i in range(0, len(self.splits)):
|
for i in range(0, len(self.splits)):
|
||||||
if epoch < self.splits[i]:
|
if epoch < self.splits[i] + 1:
|
||||||
sub_epoch = epoch - shit[i]
|
sub_epoch = epoch - shit[i]
|
||||||
next_restart = self.splits[i] - shit[i]
|
next_restart = self.splits[i] - shit[i]
|
||||||
return i, sub_epoch, next_restart
|
return i, sub_epoch, next_restart
|
||||||
|
|
4
optim_nn_mnist.py
Normal file → Executable file
4
optim_nn_mnist.py
Normal file → Executable file
|
@ -161,14 +161,14 @@ while learner.next():
|
||||||
batch_size=bs,
|
batch_size=bs,
|
||||||
return_losses='both')
|
return_losses='both')
|
||||||
fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%"
|
fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%"
|
||||||
log("epoch {}".format(learner.epoch + 1),
|
log("epoch {}".format(learner.epoch),
|
||||||
fmt.format(learner.rate, avg_loss, avg_mloss * 100))
|
fmt.format(learner.rate, avg_loss, avg_mloss * 100))
|
||||||
|
|
||||||
batch_losses += losses
|
batch_losses += losses
|
||||||
batch_mlosses += mlosses
|
batch_mlosses += mlosses
|
||||||
|
|
||||||
if measure_every_epoch:
|
if measure_every_epoch:
|
||||||
quiet = learner.epoch + 1 != learner.epochs
|
quiet = learner.epoch != learner.epochs
|
||||||
measure_error(quiet=quiet)
|
measure_error(quiet=quiet)
|
||||||
|
|
||||||
if not measure_every_epoch:
|
if not measure_every_epoch:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user