.
commit cf9010d25f, parent 5b2a293b8c
1 changed file with 24 additions and 22 deletions

optim_nn.py (46 lines changed)
@@ -195,12 +195,10 @@ class Layer:
     def forward(self, lut):
         assert len(self.parents) > 0, self
-        #print(" forwarding", self)
         B = []
         for parent in self.parents:
             # TODO: skip over irrelevant nodes (if any)
             X = lut[parent]
-            #print("collected parent", parent)
             self.validate_input(X)
             B.append(X)
         Y = self.multi(B)
 
@@ -209,12 +207,10 @@ class Layer:
     def backward(self, lut):
         assert len(self.children) > 0, self
-        #print(" backwarding", self)
         dB = []
         for child in self.children:
             # TODO: skip over irrelevant nodes (if any)
             dY = lut[child]
-            #print(" collected child", child)
             self.validate_output(dY)
             dB.append(dY)
         dX = self.dmulti(dB)
 
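
Aside: forward and backward above share a pull pattern. Each node gathers the values its parents (or, for gradients, its children) have already written into a lookup table keyed by node. A minimal standalone sketch of the same idea, with illustrative names, not part of this commit:

    # sketch of the lut-based pull traversal; assumes ordered_nodes is
    # topologically sorted so parents are evaluated before their children.
    def run_forward(ordered_nodes, lut):
        for node in ordered_nodes:
            if node.parents:  # input nodes are seeded into lut by the caller
                B = [lut[parent] for parent in node.parents]
                lut[node] = node.multi(B)  # combine parent outputs into this node's output
        return lut
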
@@ -346,11 +342,7 @@ class Model:
         assert isinstance(y, Layer), y
         self.x = x
         self.y = y
-
         self.ordered_nodes = self.traverse([], self.y)
-        node_names = ' '.join([str(node) for node in self.ordered_nodes])
-        print('{} nodes: {}'.format(len(self.ordered_nodes), node_names))
-
         self.make_weights()
 
     def make_weights(self):
@@ -358,7 +350,6 @@ class Model:
         for node in self.ordered_nodes:
             if node.size is not None:
                 self.param_count += node.size
-        print(self.param_count)
         self.W = np.zeros(self.param_count, dtype=nf)
         self.dW = np.zeros(self.param_count, dtype=nf)
 
@@ -369,8 +360,6 @@ class Model:
             node.init(self.W[offset:end], self.dW[offset:end])
             offset += node.size
 
-        #print(self.W, self.dW)
-
     def traverse(self, nodes, node):
         if node == x:
             return [node]
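
Aside: make_weights sizes one flat buffer, then node.init receives a slice of W and dW. Since NumPy basic slices are views, node-local writes land directly in the flat arrays the optimizer updates. A quick check of that property, illustrative and not part of the commit:

    import numpy as np

    W = np.zeros(10, dtype=np.float32)
    w_node = W[3:7]      # what a node receives: a view, not a copy
    w_node[:] = 1.0      # the node writing its parameters in place
    assert W[3] == 1.0 and W[7] == 0.0  # the flat buffer sees the writes
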
@@ -425,6 +414,11 @@ class Model:
         raise NotImplementedError("unimplemented", self)
 
 if __name__ == '__main__':
+    import sys
+    lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
+    def log(left, right):
+        lament("{:>20}: {}".format(left, right))
+
     # Config
 
     from dotmap import DotMap
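
Aside: the new log helper right-aligns its label in a 20-column field and writes to stderr through lament, keeping diagnostics out of stdout. For example:

    log("train loss", "{:11.7f}".format(0.0007031))
    # stderr: '          train loss:   0.0007031'
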
@@ -446,7 +440,7 @@ if __name__ == '__main__':
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD
 
-        # learning parameters: SGD with restarts
+        # learning parameters: SGD with restarts (kinda)
         LR = 1e-2,
         epochs = 6,
         LR_halve_every = 2,
@@ -510,6 +504,10 @@ if __name__ == '__main__':
 
     model = Model(x, y)
 
+    node_names = ' '.join([str(node) for node in model.ordered_nodes])
+    log('{} nodes'.format(len(model.ordered_nodes)), node_names)
+    log('parameters', model.param_count)
+
     training = config.epochs > 0 and config.restarts >= 0
 
     if not training:
@@ -536,26 +534,31 @@ if __name__ == '__main__':
     LR = config.LR
     LRprod = 0.5**(1/config.LR_halve_every)
 
+    LRE = LR * (LRprod**config.LR_restart_advance)**config.restarts * LRprod**(config.epochs - 1)
+    log("final learning rate", "{:10.8f}".format(LRE))
+
     # Training
 
     def measure_loss():
         predicted = model.forward(inputs / x_scale)
         residual = predicted - outputs / y_scale
         err = loss.mean(residual)
-        print("train loss: {:11.7f}".format(err))
-        print("improvement: {:+7.2f}%".format((0.0007031 / err - 1) * 100))
+        log("train loss", "{:11.7f}".format(err))
+        log("improvement", "{:+7.2f}%".format((0.0007031 / err - 1) * 100))
 
         predicted = model.forward(valid_inputs / x_scale)
         residual = predicted - valid_outputs / y_scale
         err = loss.mean(residual)
-        print("valid loss: {:11.7f}".format(err))
-        print("improvement: {:+7.2f}%".format((0.0007159 / err - 1) * 100))
+        log("valid loss", "{:11.7f}".format(err))
+        log("improvement", "{:+7.2f}%".format((0.0007159 / err - 1) * 100))
 
     for i in range(config.restarts + 1):
         measure_loss()
 
         if i > 0:
-            print("restarting")
+            log("restarting", i)
+            LR *= LRprod**config.LR_restart_advance
+            #optim.reset()
 
         assert inputs.shape[0] % config.batch_size == 0, \
             "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
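
Aside: the LRE expression added above predicts the step size the loops will actually use on the last epoch of the last restart: every epoch multiplies by LRprod (so the rate halves every LR_halve_every epochs), and every restart advances the schedule by LR_restart_advance epochs' worth. A small check under hypothetical values (restarts and LR_restart_advance are not shown in this diff), not part of the commit:

    LR0, epochs, LR_halve_every = 1e-2, 6, 2   # values from the config above
    restarts, LR_restart_advance = 2, 3        # hypothetical values
    LRprod = 0.5**(1/LR_halve_every)

    lr, alpha = LR0, None
    for i in range(restarts + 1):
        if i > 0:
            lr *= LRprod**LR_restart_advance   # a restart advances the schedule
        for e in range(epochs):
            alpha = lr * LRprod**e             # what optim.alpha is set to
    LRE = LR0 * (LRprod**LR_restart_advance)**restarts * LRprod**(epochs - 1)
    assert abs(alpha - LRE) < 1e-15            # closed form matches the loop
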
@@ -567,6 +570,7 @@ if __name__ == '__main__':
         shuffled_outputs = outputs[indices] / y_scale
 
         optim.alpha = LR * LRprod**e
+        #log("learning rate", "{:10.8f}".format(optim.alpha))
 
         cumsum_loss = 0
         for b in range(batch_count):
@@ -581,9 +585,7 @@ if __name__ == '__main__':
 
             # note: we don't actually need this for training, only monitoring.
             cumsum_loss += loss.mean(residual)
-        print("avg loss: {:10.6f}".format(cumsum_loss / batch_count))
-
-        LR *= LRprod**config.LR_restart_advance
+        log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))
 
     measure_loss()
 
@@ -596,5 +598,5 @@ if __name__ == '__main__':
     b = (64, 128, 192)
     X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
     P = model.forward(X) * y_scale
-    print("truth:", rgbcompare(a, b))
-    print("network:", np.squeeze(P))
+    log("truth", rgbcompare(a, b))
+    log("network", np.squeeze(P))