fix epoch incrementing

parent 19a9583da9
commit 1d729b98aa

3 changed files with 12 additions and 13 deletions
@@ -859,11 +859,11 @@ def run(program, args=None):
         if config.log10_loss:
             fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
-            log("info", fmt.format(learner.epoch + 1, learner.rate, np.log10(avg_loss)),
+            log("info", fmt.format(learner.epoch, learner.rate, np.log10(avg_loss)),
                 update=True)
         else:
             fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:12.6e}"
-            log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss),
+            log("info", fmt.format(learner.epoch, learner.rate, avg_loss),
                 update=True)

     measure_error()

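In short, the commit makes Learner.epoch 1-indexed: the counter is now incremented before an epoch runs rather than after, so display code no longer needs the "+ 1". A minimal sketch of the numbering convention, with a bare counter standing in for the real Learner:

    # hypothetical illustration of the new 1-indexed epoch convention
    epochs = 3
    epoch = 0                # 0 means "not started yet"
    while epoch < epochs:
        epoch += 1           # increment *before* the epoch runs, as next() now does
        print("epoch {}".format(epoch))  # prints 1, 2, 3 -- no "+ 1" needed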
@@ -845,7 +845,8 @@ class Learner:
     @epoch.setter
     def epoch(self, new_epoch):
         self._epoch = int(new_epoch)
-        self.rate = self.rate_at(self._epoch)
+        if 0 <= self.epoch <= self.epochs:
+            self.rate = self.rate_at(self._epoch)

     @property
     def rate(self):
@@ -860,13 +861,11 @@ class Learner:

     def next(self):
         # prepares the next epoch. returns whether or not to continue training.
-        if self.epoch + 1 >= self.epochs:
-            return False
-        if self.started:
-            self.epoch += 1
-        else:
+        if not self.started:
             self.started = True
-            self.epoch = self.epoch  # poke property setter just in case
+        self.epoch += 1
+        if self.epoch > self.epochs:
+            return False
         return True

     def batch(self, progress):  # TODO: rename
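Together, the setter guard and the rewritten next() implement the 1-indexed scheme: the counter is bumped to epochs + 1 exactly once before training stops, which is why the setter now refuses to evaluate the schedule outside [0, epochs]. A self-contained sketch of the new control flow; MiniLearner and its linear decay are made up for illustration and only mirror the logic in this commit:

    class MiniLearner:
        # illustration only; mirrors the epoch/next() logic from this commit.
        def __init__(self, epochs):
            self.epochs = epochs
            self.started = False
            self._epoch = 0
            self.rate = 0.0

        def rate_at(self, epoch):
            return 0.01 * (1 - epoch / self.epochs)  # hypothetical schedule

        @property
        def epoch(self):
            return self._epoch

        @epoch.setter
        def epoch(self, new_epoch):
            self._epoch = int(new_epoch)
            # guard: next() steps the counter one past epochs before stopping,
            # and the schedule is only defined on [0, epochs].
            if 0 <= self._epoch <= self.epochs:
                self.rate = self.rate_at(self._epoch)

        def next(self):
            # increment first, then bounds-check: epoch takes the values 1..epochs.
            if not self.started:
                self.started = True
            self.epoch += 1
            if self.epoch > self.epochs:
                return False
            return True

    seen = []
    learner = MiniLearner(epochs=3)
    while learner.next():
        seen.append(learner.epoch)
    assert seen == [1, 2, 3]   # 1-indexed; logs can print epoch directly
    assert learner.epoch == 4  # stepped past epochs once; rate was left untouched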
@@ -874,7 +873,7 @@ class Learner:
         # unlike epochs, we do not store batch number as a state.
         # i.e. calling next() will not respect progress.
         assert 0 <= progress <= 1
-        self.rate = self.rate_at(self._epoch + progress)
+        self.rate = self.rate_at(self._epoch - 1 + progress)

     @property
     def final_rate(self):
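Because the counter already names the epoch being run, the within-epoch schedule position becomes _epoch - 1 + progress, which sweeps the same interval the old _epoch + progress did under 0-indexing. A small sketch, assuming a hypothetical rate_at defined on [0, epochs]:

    def rate_at(t, base=0.01, epochs=10):
        # made-up linear decay, standing in for the real schedule
        return base * (1 - t / epochs)

    epoch = 1  # the 1-indexed counter during the first epoch
    for progress in (0.0, 0.5, 1.0):
        # epoch - 1 + progress spans [0, 1] over the first epoch, exactly
        # as epoch + progress did when the counter started at 0.
        print(rate_at(epoch - 1 + progress))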
@@ -921,7 +920,7 @@ class SGDR(Learner):
     def split_num(self, epoch):
         shit = [0] + self.splits  # hack
         for i in range(0, len(self.splits)):
-            if epoch < self.splits[i]:
+            if epoch < self.splits[i] + 1:
                 sub_epoch = epoch - shit[i]
                 next_restart = self.splits[i] - shit[i]
                 return i, sub_epoch, next_restart
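With 1-indexed epochs, an epoch belongs to SGDR segment i while epoch <= splits[i], hence the "+ 1" in the comparison. A standalone sketch of the lookup; splits is assumed here to hold cumulative epoch counts at which restarts occur, e.g. [5, 10] for two 5-epoch segments:

    def split_num(splits, epoch):
        # which SGDR segment does a 1-indexed epoch fall into?
        starts = [0] + splits
        for i in range(len(splits)):
            if epoch < splits[i] + 1:  # i.e. epoch <= splits[i]
                sub_epoch = epoch - starts[i]
                next_restart = splits[i] - starts[i]
                return i, sub_epoch, next_restart

    assert split_num([5, 10], 5) == (0, 5, 5)   # last epoch of segment 0
    assert split_num([5, 10], 6) == (1, 1, 5)   # first epoch of segment 1
    assert split_num([5, 10], 10) == (1, 5, 5)  # last epoch overall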
optim_nn_mnist.py: 4 changes, Normal file → Executable file
@@ -161,14 +161,14 @@ while learner.next():
                                   batch_size=bs,
                                   return_losses='both')
         fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%"
-        log("epoch {}".format(learner.epoch + 1),
+        log("epoch {}".format(learner.epoch),
             fmt.format(learner.rate, avg_loss, avg_mloss * 100))

         batch_losses += losses
         batch_mlosses += mlosses

         if measure_every_epoch:
-            quiet = learner.epoch + 1 != learner.epochs
+            quiet = learner.epoch != learner.epochs
             measure_error(quiet=quiet)

     if not measure_every_epoch:
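The quiet flag follows the same renumbering: the final pass is now the one where learner.epoch == learner.epochs, so only that pass reports the measured error verbosely. A trivial sketch of the condition:

    epochs = 3
    for epoch in (1, 2, 3):          # 1-indexed epochs from the new next()
        quiet = epoch != epochs      # verbose only on the final epoch
        print(epoch, "quiet" if quiet else "verbose")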