add mixup variant of batchize

This commit is contained in:
Connor Olding 2019-02-11 20:30:31 +01:00
parent 2cfcc9062e
commit 5a07cdac32

View file

@ -63,6 +63,36 @@ def batchize(inputs, outputs, batch_size, shuffle=True):
return gen(), batch_count
def mixup(inputs, outputs, batch_size, p=0.1):
# paper: https://arxiv.org/abs/1710.09412
if p == 0:
return batchize(inputs, outputs, batch_size)
batch_count = np.ceil(len(inputs) / batch_size).astype(int)
def lerp(a, b, t):
t = t.reshape([len(t)] + [1] * (a.ndim - t.ndim))
return (1 - t) * a + t * b
def gen():
indices0 = np.arange(len(inputs))
indices1 = np.arange(len(inputs))
np.random.shuffle(indices0)
np.random.shuffle(indices1)
for b in range(batch_count):
bi = b * batch_size
ind0 = indices0[bi:bi + batch_size]
ind1 = indices1[bi:bi + batch_size]
ps = np.random.beta(p, p, size=batch_size)
batch_inputs = lerp(inputs[ind0], inputs[ind1], ps)
batch_outputs = lerp(outputs[ind0], outputs[ind1], ps)
yield batch_inputs, batch_outputs
return gen(), batch_count
# more
_log_was_update = False