commit 7e5bb731da
parent 42a66d4d6c

1 changed file with 17 additions and 8 deletions:

 optim_nn.py | 25 +++++++++++++++++--------
--- a/optim_nn.py
+++ b/optim_nn.py
@@ -574,7 +574,7 @@ def run(program, args=None):
         mloss = 'mse',
         ritual = 'default',
         restart_optim = False, # restarts also reset internal state of optimizer
-        warmup = True,
+        warmup = True, # train a couple epochs on gaussian noise and reset
         log10_loss = True, # personally, i'm sick of looking linear loss values!
 
         problem = 3,
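The new comment on warmup and the existing log10_loss flag hint at the same motivation: loss values spanning several orders of magnitude are easier to scan as base-10 exponents. A rough, self-contained illustration (the loss values below are made up):

    import numpy as np

    # hypothetical per-epoch MSE values spanning several orders of magnitude
    losses = np.array([3.2e-2, 8.1e-4, 9.7e-6])

    print(losses)            # hard to compare at a glance
    print(np.log10(losses))  # roughly [-1.49, -3.09, -5.01]: one unit per decade of improvement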
@@ -657,16 +657,27 @@ def run(program, args=None):
 
         # use plain SGD in warmup to prevent (or possibly cause?) numeric issues
         temp_optim = learner.optim
-        learner.optim = Optimizer(alpha=0.01)
+        temp_loss = ritual.loss
+        learner.optim = Optimizer(alpha=0.001)
+        ritual.loss = Absolute() # less likely to blow up; more general
 
-        for _ in range(2):
+        # NOTE: experiment: trying const batches and batch_size
+        bs = 256
+        target = 1 * 1024 * 1024
+        # 4 being sizeof(float)
+        batches = (target / 4 / np.prod(inputs.shape[1:])) // bs * bs
+        ins = [int(batches)] + list( inputs.shape[1:])
+        outs = [int(batches)] + list(outputs.shape[1:])
+
+        for _ in range(4):
             ritual.train_batched(
-                np.random.normal(size=inputs.shape),
-                np.random.normal(size=outputs.shape),
-                config.batch_size)
+                np.random.normal(size=ins),
+                np.random.normal(size=outs),
+                batch_size=bs)
             ritual.reset()
 
         learner.optim = temp_optim
+        ritual.loss = temp_loss
 
         if training:
             measure_error()
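For reference, the arithmetic in the new warmup block: despite its name, `batches` ends up being a sample count, i.e. how many float32 input vectors fit in roughly 1 MiB, floored to a multiple of `bs`. A sketch under an assumed per-sample shape of (28, 28) (the real shapes come from the loaded problem's `inputs` and `outputs`):

    import numpy as np

    bs = 256                   # warmup batch size
    target = 1 * 1024 * 1024   # ~1 MiB budget for the noise array
    feature_shape = (28, 28)   # assumed here; the diff uses inputs.shape[1:]

    # bytes / sizeof(float32) / features per sample, floored to a multiple of bs
    samples = (target / 4 / np.prod(feature_shape)) // bs * bs
    ins = [int(samples)] + list(feature_shape)

    print(int(samples))                      # 256 (262144 / 784 ≈ 334.4, floored)
    print(np.random.normal(size=ins).shape)  # (256, 28, 28)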
@@ -683,8 +694,6 @@ def run(program, args=None):
                 return_losses=True)
             batch_losses += losses
 
-        #log("learning rate", "{:10.8f}".format(learner.rate))
-        #log("average loss", "{:11.7f}".format(avg_loss))
         if config.log10_loss:
             fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
             log("info", fmt.format(learner.epoch + 1, learner.rate, np.log10(avg_loss)))
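And a quick check of what the new log line prints, assuming epoch 5, a learning rate of 0.01, and an average loss of 0.005:

    import numpy as np

    fmt = "epoch {:4.0f}, rate {:10.8f}, log10-loss {:+6.3f}"
    print(fmt.format(5, 0.01, np.log10(0.005)))
    # epoch    5, rate 0.01000000, log10-loss -2.301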