diff --git a/optim_nn.py b/optim_nn.py index 8f66952..f50dbe7 100755 --- a/optim_nn.py +++ b/optim_nn.py @@ -859,11 +859,11 @@ def run(program, args=None): if config.log10_loss: fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}" - log("info", fmt.format(learner.epoch + 1, learner.rate, np.log10(avg_loss)), + log("info", fmt.format(learner.epoch, learner.rate, np.log10(avg_loss)), update=True) else: fmt = "epoch {:4.0f}, rate {:10.8f}, loss {:12.6e}" - log("info", fmt.format(learner.epoch + 1, learner.rate, avg_loss), + log("info", fmt.format(learner.epoch, learner.rate, avg_loss), update=True) measure_error() diff --git a/optim_nn_core.py b/optim_nn_core.py index 9361f20..47316db 100644 --- a/optim_nn_core.py +++ b/optim_nn_core.py @@ -845,7 +845,8 @@ class Learner: @epoch.setter def epoch(self, new_epoch): self._epoch = int(new_epoch) - self.rate = self.rate_at(self._epoch) + if 0 <= self.epoch <= self.epochs: + self.rate = self.rate_at(self._epoch) @property def rate(self): @@ -860,13 +861,11 @@ class Learner: def next(self): # prepares the next epoch. returns whether or not to continue training. - if self.epoch + 1 >= self.epochs: - return False - if self.started: - self.epoch += 1 - else: + if not self.started: self.started = True - self.epoch = self.epoch # poke property setter just in case + self.epoch += 1 + if self.epoch > self.epochs: + return False return True def batch(self, progress): # TODO: rename @@ -874,7 +873,7 @@ class Learner: # unlike epochs, we do not store batch number as a state. # i.e. calling next() will not respect progress. assert 0 <= progress <= 1 - self.rate = self.rate_at(self._epoch + progress) + self.rate = self.rate_at(self._epoch - 1 + progress) @property def final_rate(self): @@ -921,7 +920,7 @@ class SGDR(Learner): def split_num(self, epoch): shit = [0] + self.splits # hack for i in range(0, len(self.splits)): - if epoch < self.splits[i]: + if epoch < self.splits[i] + 1: sub_epoch = epoch - shit[i] next_restart = self.splits[i] - shit[i] return i, sub_epoch, next_restart diff --git a/optim_nn_mnist.py b/optim_nn_mnist.py old mode 100644 new mode 100755 index ce3e257..fc092ef --- a/optim_nn_mnist.py +++ b/optim_nn_mnist.py @@ -161,14 +161,14 @@ while learner.next(): batch_size=bs, return_losses='both') fmt = "rate {:10.8f}, loss {:12.6e}, accuracy {:6.2f}%" - log("epoch {}".format(learner.epoch + 1), + log("epoch {}".format(learner.epoch), fmt.format(learner.rate, avg_loss, avg_mloss * 100)) batch_losses += losses batch_mlosses += mlosses if measure_every_epoch: - quiet = learner.epoch + 1 != learner.epochs + quiet = learner.epoch != learner.epochs measure_error(quiet=quiet) if not measure_every_epoch: