add rough stratified k-folding utility class
This commit is contained in:
parent
65bc9b8a6f
commit
9a45b26b7f
1 changed files with 50 additions and 0 deletions
|
@ -54,3 +54,53 @@ def log(left, right, update=False):
|
|||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
|
||||
class Folding:
|
||||
# NOTE: this class assumes classes are *exactly* evenly distributed.
|
||||
|
||||
def __init__(self, inputs, outputs, folds):
|
||||
# outputs should be one-hot.
|
||||
|
||||
self.folds = int(folds)
|
||||
|
||||
# this temporarily converts one-hot encoding back to integer indices.
|
||||
classes = np.argmax(outputs, axis=-1)
|
||||
|
||||
# we need to do stratified k-folds,
|
||||
# so let's put them in an order that's easy to split
|
||||
# without breaking class distribution.
|
||||
# don't worry, they'll get shuffled again in train_batched.
|
||||
classes = np.argmax(outputs, axis=-1)
|
||||
class_n = np.max(classes) + 1
|
||||
sorted_inputs = np.array([inputs[classes == n]
|
||||
for n in range(class_n)], inputs.dtype)
|
||||
sorted_outputs = np.arange(class_n
|
||||
).repeat(sorted_inputs.shape[1]).reshape(sorted_inputs.shape[:2])
|
||||
|
||||
# now to interleave the classes instead of having them grouped:
|
||||
inputs = np.swapaxes(sorted_inputs, 0, 1
|
||||
).reshape(-1, *sorted_inputs.shape[2:])
|
||||
outputs = np.swapaxes(sorted_outputs, 0, 1
|
||||
).reshape(-1, *sorted_outputs.shape[2:])
|
||||
|
||||
# one final thing: we need to make our outputs one-hot again.
|
||||
self.inputs = inputs
|
||||
self.outputs = onehot(outputs)
|
||||
|
||||
# now we can do stratified folds simply by contiguous slices!
|
||||
self.foldstep = len(self.inputs) // self.folds
|
||||
assert len(self.inputs) % self.foldstep == 0, \
|
||||
"bad number of folds; cannot be stratified"
|
||||
|
||||
def fold(self, i):
|
||||
roll = i * self.foldstep
|
||||
split = (self.folds - 1) * self.foldstep
|
||||
|
||||
train_inputs = np.roll(self.inputs, roll, axis=0)[:split]
|
||||
valid_inputs = np.roll(self.inputs, roll, axis=0)[split:]
|
||||
|
||||
train_outputs = np.roll(self.outputs, roll, axis=0)[:split]
|
||||
valid_outputs = np.roll(self.outputs, roll, axis=0)[split:]
|
||||
|
||||
return train_inputs, train_outputs, valid_inputs, valid_outputs
|
||||
|
|
Loading…
Reference in a new issue