add mixup variant of batchize
This commit is contained in:
parent
2cfcc9062e
commit
5a07cdac32
1 changed files with 30 additions and 0 deletions
|
@ -63,6 +63,36 @@ def batchize(inputs, outputs, batch_size, shuffle=True):
|
||||||
return gen(), batch_count
|
return gen(), batch_count
|
||||||
|
|
||||||
|
|
||||||
|
def mixup(inputs, outputs, batch_size, p=0.1):
|
||||||
|
# paper: https://arxiv.org/abs/1710.09412
|
||||||
|
|
||||||
|
if p == 0:
|
||||||
|
return batchize(inputs, outputs, batch_size)
|
||||||
|
|
||||||
|
batch_count = np.ceil(len(inputs) / batch_size).astype(int)
|
||||||
|
|
||||||
|
def lerp(a, b, t):
|
||||||
|
t = t.reshape([len(t)] + [1] * (a.ndim - t.ndim))
|
||||||
|
return (1 - t) * a + t * b
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
indices0 = np.arange(len(inputs))
|
||||||
|
indices1 = np.arange(len(inputs))
|
||||||
|
np.random.shuffle(indices0)
|
||||||
|
np.random.shuffle(indices1)
|
||||||
|
|
||||||
|
for b in range(batch_count):
|
||||||
|
bi = b * batch_size
|
||||||
|
ind0 = indices0[bi:bi + batch_size]
|
||||||
|
ind1 = indices1[bi:bi + batch_size]
|
||||||
|
ps = np.random.beta(p, p, size=batch_size)
|
||||||
|
batch_inputs = lerp(inputs[ind0], inputs[ind1], ps)
|
||||||
|
batch_outputs = lerp(outputs[ind0], outputs[ind1], ps)
|
||||||
|
yield batch_inputs, batch_outputs
|
||||||
|
|
||||||
|
return gen(), batch_count
|
||||||
|
|
||||||
|
|
||||||
# more
|
# more
|
||||||
|
|
||||||
_log_was_update = False
|
_log_was_update = False
|
||||||
|
|
Loading…
Add table
Reference in a new issue