From e24c3d31a45d173c5194d5b28e149f14879a5e98 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Fri, 8 Jun 2018 13:46:38 +0200 Subject: [PATCH] use argsort --- main.lua | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/main.lua b/main.lua index 37668f9..a64a837 100644 --- a/main.lua +++ b/main.lua @@ -85,6 +85,7 @@ local gui = gui local util = require("util") local argmax = util.argmax +local argsort = util.argsort local calc_mean_dev = util.calc_mean_dev local clamp = util.clamp local copy = util.copy @@ -258,9 +259,7 @@ local function collect_best_indices() best_rewards = copy(trial_rewards) end - local indices = {} - for i = 1, #best_rewards do indices[i] = i end - sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end) + local indices = argsort(best_rewards, function(a, b) return a > b end) for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end return indices