From 5a07cdac3218f05840b3b25981f4ec3eae52f21e Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Mon, 11 Feb 2019 20:30:31 +0100 Subject: [PATCH] add mixup variant of batchize --- onn/utility.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/onn/utility.py b/onn/utility.py index 06ab2f9..7e6111c 100644 --- a/onn/utility.py +++ b/onn/utility.py @@ -63,6 +63,36 @@ def batchize(inputs, outputs, batch_size, shuffle=True): return gen(), batch_count +def mixup(inputs, outputs, batch_size, p=0.1): + # paper: https://arxiv.org/abs/1710.09412 + + if p == 0: + return batchize(inputs, outputs, batch_size) + + batch_count = np.ceil(len(inputs) / batch_size).astype(int) + + def lerp(a, b, t): + t = t.reshape([len(t)] + [1] * (a.ndim - t.ndim)) + return (1 - t) * a + t * b + + def gen(): + indices0 = np.arange(len(inputs)) + indices1 = np.arange(len(inputs)) + np.random.shuffle(indices0) + np.random.shuffle(indices1) + + for b in range(batch_count): + bi = b * batch_size + ind0 = indices0[bi:bi + batch_size] + ind1 = indices1[bi:bi + batch_size] + ps = np.random.beta(p, p, size=batch_size) + batch_inputs = lerp(inputs[ind0], inputs[ind1], ps) + batch_outputs = lerp(outputs[ind0], outputs[ind1], ps) + yield batch_inputs, batch_outputs + + return gen(), batch_count + + # more _log_was_update = False