shuffle by default

This commit is contained in:
Connor Olding 2017-06-17 17:12:59 +00:00
parent cf1b7c1c13
commit a4081606f7
3 changed files with 7 additions and 11 deletions

View file

@ -910,11 +910,6 @@ def run(program, args=None):
measure_error() measure_error()
while training and learner.next(): while training and learner.next():
indices = np.arange(inputs.shape[0])
np.random.shuffle(indices)
shuffled_inputs = inputs[indices]
shuffled_outputs = outputs[indices]
avg_loss, losses = ritual.train_batched( avg_loss, losses = ritual.train_batched(
shuffled_inputs, shuffled_outputs, shuffled_inputs, shuffled_outputs,
config.batch_size, config.batch_size,

View file

@ -946,12 +946,18 @@ class Ritual: # i'm just making up names at this point
return avg_mloss return avg_mloss
def train_batched(self, inputs, outputs, batch_size, def train_batched(self, inputs, outputs, batch_size,
return_losses=False, test_only=False): return_losses=False, test_only=False, shuffle=True):
assert isinstance(return_losses, bool) or return_losses == 'both' assert isinstance(return_losses, bool) or return_losses == 'both'
if not test_only: if not test_only:
self.en += 1 self.en += 1
if shuffle:
indices = np.arange(inputs.shape[0])
np.random.shuffle(indices)
inputs = inputs[indices]
outputs = outputs[indices]
cumsum_loss, cumsum_mloss = _0, _0 cumsum_loss, cumsum_mloss = _0, _0
batch_count = inputs.shape[0] // batch_size batch_count = inputs.shape[0] // batch_size
losses, mlosses = [], [] losses, mlosses = [], []

View file

@ -196,11 +196,6 @@ while learner.next():
if isinstance(node, ActivityRegularizer): if isinstance(node, ActivityRegularizer):
node.reg.lamb = act_t * node.reg.lamb_orig # HACK node.reg.lamb = act_t * node.reg.lamb_orig # HACK
indices = np.arange(inputs.shape[0])
np.random.shuffle(indices)
shuffled_inputs = inputs[indices]
shuffled_outputs = outputs[indices]
avg_loss, avg_mloss, losses, mlosses = ritual.train_batched( avg_loss, avg_mloss, losses, mlosses = ritual.train_batched(
shuffled_inputs, shuffled_outputs, shuffled_inputs, shuffled_outputs,
batch_size=bs, batch_size=bs,