fix epoch incrementing

This commit is contained in:
Connor Olding 2017-03-22 21:41:24 +00:00
parent 19a9583da9
commit 1d729b98aa
3 changed files with 12 additions and 13 deletions

View file

@ -859,11 +859,11 @@ def run(program, args=None):
if config.log10_loss:
fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
log("info", fmt.format(learner.epoch + 1, learner.rate, np.log10(avg_loss)),
log("info", fmt.format(learner.epoch, learner.rate, np.log10(avg_loss)),
update=True)
else:
fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:12.6e}"
log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss),
log("info", fmt.format(learner.epoch, learner.rate, avg_loss),
update=True)
measure_error()

View file

@ -845,7 +845,8 @@ class Learner:
@epoch.setter
def epoch(self, new_epoch):
self._epoch = int(new_epoch)
self.rate = self.rate_at(self._epoch)
if 0 <= self.epoch <= self.epochs:
self.rate = self.rate_at(self._epoch)
@property
def rate(self):
@ -860,13 +861,11 @@ class Learner:
def next(self):
# prepares the next epoch. returns whether or not to continue training.
if self.epoch + 1 >= self.epochs:
return False
if self.started:
self.epoch += 1
else:
if not self.started:
self.started = True
self.epoch = self.epoch # poke property setter just in case
self.epoch += 1
if self.epoch > self.epochs:
return False
return True
def batch(self, progress): # TODO: rename
@ -874,7 +873,7 @@ class Learner:
# unlike epochs, we do not store batch number as a state.
# i.e. calling next() will not respect progress.
assert 0 <= progress <= 1
self.rate = self.rate_at(self._epoch + progress)
self.rate = self.rate_at(self._epoch - 1 + progress)
@property
def final_rate(self):
@ -921,7 +920,7 @@ class SGDR(Learner):
def split_num(self, epoch):
shit = [0] + self.splits # hack
for i in range(0, len(self.splits)):
if epoch < self.splits[i]:
if epoch < self.splits[i] + 1:
sub_epoch = epoch - shit[i]
next_restart = self.splits[i] - shit[i]
return i, sub_epoch, next_restart

4
optim_nn_mnist.py Normal file → Executable file
View file

@ -161,14 +161,14 @@ while learner.next():
batch_size=bs,
return_losses='both')
fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%"
log("epoch {}".format(learner.epoch + 1),
log("epoch {}".format(learner.epoch),
fmt.format(learner.rate, avg_loss, avg_mloss * 100))
batch_losses += losses
batch_mlosses += mlosses
if measure_every_epoch:
quiet = learner.epoch + 1 != learner.epochs
quiet = learner.epoch != learner.epochs
measure_error(quiet=quiet)
if not measure_every_epoch: