This commit is contained in:
parent 5b2a293b8c
commit cf9010d25f
1 changed file with 24 additions and 22 deletions
optim_nn.py (46 changed lines)
@@ -195,12 +195,10 @@ class Layer:

     def forward(self, lut):
         assert len(self.parents) > 0, self
-        #print(" forwarding", self)
         B = []
         for parent in self.parents:
             # TODO: skip over irrelevant nodes (if any)
             X = lut[parent]
-            #print("collected parent", parent)
             self.validate_input(X)
             B.append(X)
         Y = self.multi(B)
@@ -209,12 +207,10 @@ class Layer:

     def backward(self, lut):
         assert len(self.children) > 0, self
-        #print(" backwarding", self)
         dB = []
         for child in self.children:
             # TODO: skip over irrelevant nodes (if any)
             dY = lut[child]
-            #print(" collected child", child)
             self.validate_output(dY)
             dB.append(dY)
         dX = self.dmulti(dB)
@@ -346,11 +342,7 @@ class Model:
         assert isinstance(y, Layer), y
         self.x = x
         self.y = y
-
         self.ordered_nodes = self.traverse([], self.y)
-        node_names = ' '.join([str(node) for node in self.ordered_nodes])
-        print('{} nodes: {}'.format(len(self.ordered_nodes), node_names))
-
         self.make_weights()

     def make_weights(self):
@@ -358,7 +350,6 @@ class Model:
         for node in self.ordered_nodes:
             if node.size is not None:
                 self.param_count += node.size
-        print(self.param_count)
         self.W = np.zeros(self.param_count, dtype=nf)
         self.dW = np.zeros(self.param_count, dtype=nf)

@@ -369,8 +360,6 @@ class Model:
             node.init(self.W[offset:end], self.dW[offset:end])
             offset += node.size

-        #print(self.W, self.dW)
-
     def traverse(self, nodes, node):
         if node == x:
             return [node]
@@ -425,6 +414,11 @@ class Model:
         raise NotImplementedError("unimplemented", self)

 if __name__ == '__main__':
+    import sys
+    lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
+    def log(left, right):
+        lament("{:>20}: {}".format(left, right))
+
     # Config

     from dotmap import DotMap
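Aside (not part of the commit): the log helper added in the hunk above writes to stderr and right-aligns its label in a 20-character column, which is what keeps the later print-to-log conversions lined up. A minimal, self-contained illustration, with made-up values:

import sys
lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
def log(left, right):
    lament("{:>20}: {}".format(left, right))

log("train loss", "  0.0013000")  # stderr: "          train loss:   0.0013000"
log("parameters", 1234)           # stderr: "          parameters: 1234"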
@@ -446,7 +440,7 @@ if __name__ == '__main__':
         nesterov = False, # only used with SGD or Adam
         momentum = 0.33, # only used with SGD

-        # learning parameters: SGD with restarts
+        # learning parameters: SGD with restarts (kinda)
         LR = 1e-2,
         epochs = 6,
         LR_halve_every = 2,
@@ -510,6 +504,10 @@ if __name__ == '__main__':

     model = Model(x, y)

+    node_names = ' '.join([str(node) for node in model.ordered_nodes])
+    log('{} nodes'.format(len(model.ordered_nodes)), node_names)
+    log('parameters', model.param_count)
+
     training = config.epochs > 0 and config.restarts >= 0

     if not training:
@@ -536,26 +534,31 @@ if __name__ == '__main__':
     LR = config.LR
     LRprod = 0.5**(1/config.LR_halve_every)

+    LRE = LR * (LRprod**config.LR_restart_advance)**config.restarts * LRprod**(config.epochs - 1)
+    log("final learning rate", "{:10.8f}".format(LRE))
+
     # Training

     def measure_loss():
         predicted = model.forward(inputs / x_scale)
         residual = predicted - outputs / y_scale
         err = loss.mean(residual)
-        print("train loss: {:11.7f}".format(err))
-        print("improvement: {:+7.2f}%".format((0.0007031 / err - 1) * 100))
+        log("train loss", "{:11.7f}".format(err))
+        log("improvement", "{:+7.2f}%".format((0.0007031 / err - 1) * 100))

         predicted = model.forward(valid_inputs / x_scale)
         residual = predicted - valid_outputs / y_scale
         err = loss.mean(residual)
-        print("valid loss: {:11.7f}".format(err))
-        print("improvement: {:+7.2f}%".format((0.0007159 / err - 1) * 100))
+        log("valid loss", "{:11.7f}".format(err))
+        log("improvement", "{:+7.2f}%".format((0.0007159 / err - 1) * 100))

     for i in range(config.restarts + 1):
         measure_loss()

         if i > 0:
-            print("restarting")
+            log("restarting", i)
+            LR *= LRprod**config.LR_restart_advance
+            #optim.reset()

         assert inputs.shape[0] % config.batch_size == 0, \
             "inputs is not evenly divisible by batch_size" # TODO: lift this restriction
@@ -567,6 +570,7 @@ if __name__ == '__main__':
             shuffled_outputs = outputs[indices] / y_scale

             optim.alpha = LR * LRprod**e
+            #log("learning rate", "{:10.8f}".format(optim.alpha))

             cumsum_loss = 0
             for b in range(batch_count):
@@ -581,9 +585,7 @@ if __name__ == '__main__':

                 # note: we don't actually need this for training, only monitoring.
                 cumsum_loss += loss.mean(residual)
-            print("avg loss: {:10.6f}".format(cumsum_loss / batch_count))
+            log("average loss", "{:11.7f}".format(cumsum_loss / batch_count))

-        LR *= LRprod**config.LR_restart_advance
-
     measure_loss()

@@ -596,5 +598,5 @@ if __name__ == '__main__':
     b = (64, 128, 192)
     X = np.expand_dims(np.hstack((a, b)), 0) / x_scale
     P = model.forward(X) * y_scale
-    print("truth:", rgbcompare(a, b))
-    print("network:", np.squeeze(P))
+    log("truth", rgbcompare(a, b))
+    log("network", np.squeeze(P))
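Aside (not part of the commit): a self-contained sketch of the learning-rate schedule the hunks above set up. LR, epochs, and LR_halve_every match the visible config; restarts and LR_restart_advance are placeholder values, since their real settings fall outside the shown hunks.

# Illustrative only: the rate halves every LR_halve_every epochs, and each
# restart fast-forwards the schedule by LR_restart_advance epochs' worth of decay.
LR = 1e-2
epochs = 6
LR_halve_every = 2
restarts = 2             # placeholder, not from the diff
LR_restart_advance = 4   # placeholder, not from the diff

LRprod = 0.5**(1 / LR_halve_every)   # per-epoch decay factor

# Final learning rate, computed up front the same way the diff's LRE line does:
LRE = LR * (LRprod**LR_restart_advance)**restarts * LRprod**(epochs - 1)

for i in range(restarts + 1):
    if i > 0:
        LR *= LRprod**LR_restart_advance   # advance the schedule on each restart
    for e in range(epochs):
        alpha = LR * LRprod**e             # what the commit assigns to optim.alpha
        print("restart {} epoch {}: alpha = {:10.8f}".format(i, e, alpha))

print("last alpha vs. LRE:", alpha, LRE)   # the last per-epoch rate equals LRE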