diff --git a/onn/utility.py b/onn/utility.py index b6d0712..efd65f8 100644 --- a/onn/utility.py +++ b/onn/utility.py @@ -54,3 +54,53 @@ def log(left, right, update=False): class Dummy: pass + + +class Folding: + # NOTE: this class assumes classes are *exactly* evenly distributed. + + def __init__(self, inputs, outputs, folds): + # outputs should be one-hot. + + self.folds = int(folds) + + # this temporarily converts one-hot encoding back to integer indices. + classes = np.argmax(outputs, axis=-1) + + # we need to do stratified k-folds, + # so let's put them in an order that's easy to split + # without breaking class distribution. + # don't worry, they'll get shuffled again in train_batched. + classes = np.argmax(outputs, axis=-1) + class_n = np.max(classes) + 1 + sorted_inputs = np.array([inputs[classes == n] + for n in range(class_n)], inputs.dtype) + sorted_outputs = np.arange(class_n + ).repeat(sorted_inputs.shape[1]).reshape(sorted_inputs.shape[:2]) + + # now to interleave the classes instead of having them grouped: + inputs = np.swapaxes(sorted_inputs, 0, 1 + ).reshape(-1, *sorted_inputs.shape[2:]) + outputs = np.swapaxes(sorted_outputs, 0, 1 + ).reshape(-1, *sorted_outputs.shape[2:]) + + # one final thing: we need to make our outputs one-hot again. + self.inputs = inputs + self.outputs = onehot(outputs) + + # now we can do stratified folds simply by contiguous slices! + self.foldstep = len(self.inputs) // self.folds + assert len(self.inputs) % self.foldstep == 0, \ + "bad number of folds; cannot be stratified" + + def fold(self, i): + roll = i * self.foldstep + split = (self.folds - 1) * self.foldstep + + train_inputs = np.roll(self.inputs, roll, axis=0)[:split] + valid_inputs = np.roll(self.inputs, roll, axis=0)[split:] + + train_outputs = np.roll(self.outputs, roll, axis=0)[:split] + valid_outputs = np.roll(self.outputs, roll, axis=0)[split:] + + return train_inputs, train_outputs, valid_inputs, valid_outputs