use argsort

This commit is contained in:
Connor Olding 2018-06-08 13:46:38 +02:00
parent d33bdfea62
commit e24c3d31a4

View file

@ -85,6 +85,7 @@ local gui = gui
local util = require("util")
local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp
local copy = util.copy
@ -258,9 +259,7 @@ local function collect_best_indices()
best_rewards = copy(trial_rewards)
end
local indices = {}
for i = 1, #best_rewards do indices[i] = i end
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
local indices = argsort(best_rewards, function(a, b) return a > b end)
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
return indices