add rough stratified k-folding utility class

This commit is contained in:
Connor Olding 2018-03-08 02:41:45 +01:00
parent 65bc9b8a6f
commit 9a45b26b7f

View File

@ -54,3 +54,53 @@ def log(left, right, update=False):
class Dummy:
pass
class Folding:
# NOTE: this class assumes classes are *exactly* evenly distributed.
def __init__(self, inputs, outputs, folds):
# outputs should be one-hot.
self.folds = int(folds)
# this temporarily converts one-hot encoding back to integer indices.
classes = np.argmax(outputs, axis=-1)
# we need to do stratified k-folds,
# so let's put them in an order that's easy to split
# without breaking class distribution.
# don't worry, they'll get shuffled again in train_batched.
classes = np.argmax(outputs, axis=-1)
class_n = np.max(classes) + 1
sorted_inputs = np.array([inputs[classes == n]
for n in range(class_n)], inputs.dtype)
sorted_outputs = np.arange(class_n
).repeat(sorted_inputs.shape[1]).reshape(sorted_inputs.shape[:2])
# now to interleave the classes instead of having them grouped:
inputs = np.swapaxes(sorted_inputs, 0, 1
).reshape(-1, *sorted_inputs.shape[2:])
outputs = np.swapaxes(sorted_outputs, 0, 1
).reshape(-1, *sorted_outputs.shape[2:])
# one final thing: we need to make our outputs one-hot again.
self.inputs = inputs
self.outputs = onehot(outputs)
# now we can do stratified folds simply by contiguous slices!
self.foldstep = len(self.inputs) // self.folds
assert len(self.inputs) % self.foldstep == 0, \
"bad number of folds; cannot be stratified"
def fold(self, i):
roll = i * self.foldstep
split = (self.folds - 1) * self.foldstep
train_inputs = np.roll(self.inputs, roll, axis=0)[:split]
valid_inputs = np.roll(self.inputs, roll, axis=0)[split:]
train_outputs = np.roll(self.outputs, roll, axis=0)[:split]
valid_outputs = np.roll(self.outputs, roll, axis=0)[split:]
return train_inputs, train_outputs, valid_inputs, valid_outputs