add mixup variant of batchize
This commit is contained in:
parent
2cfcc9062e
commit
5a07cdac32
1 changed files with 30 additions and 0 deletions
|
@ -63,6 +63,36 @@ def batchize(inputs, outputs, batch_size, shuffle=True):
|
|||
return gen(), batch_count
|
||||
|
||||
|
||||
def mixup(inputs, outputs, batch_size, p=0.1):
|
||||
# paper: https://arxiv.org/abs/1710.09412
|
||||
|
||||
if p == 0:
|
||||
return batchize(inputs, outputs, batch_size)
|
||||
|
||||
batch_count = np.ceil(len(inputs) / batch_size).astype(int)
|
||||
|
||||
def lerp(a, b, t):
|
||||
t = t.reshape([len(t)] + [1] * (a.ndim - t.ndim))
|
||||
return (1 - t) * a + t * b
|
||||
|
||||
def gen():
|
||||
indices0 = np.arange(len(inputs))
|
||||
indices1 = np.arange(len(inputs))
|
||||
np.random.shuffle(indices0)
|
||||
np.random.shuffle(indices1)
|
||||
|
||||
for b in range(batch_count):
|
||||
bi = b * batch_size
|
||||
ind0 = indices0[bi:bi + batch_size]
|
||||
ind1 = indices1[bi:bi + batch_size]
|
||||
ps = np.random.beta(p, p, size=batch_size)
|
||||
batch_inputs = lerp(inputs[ind0], inputs[ind1], ps)
|
||||
batch_outputs = lerp(outputs[ind0], outputs[ind1], ps)
|
||||
yield batch_inputs, batch_outputs
|
||||
|
||||
return gen(), batch_count
|
||||
|
||||
|
||||
# more
|
||||
|
||||
_log_was_update = False
|
||||
|
|
Loading…
Reference in a new issue