use argsort

This commit is contained in:
Connor Olding 2018-06-08 13:46:38 +02:00
parent d33bdfea62
commit e24c3d31a4

View file

@ -85,6 +85,7 @@ local gui = gui
local util = require("util") local util = require("util")
local argmax = util.argmax local argmax = util.argmax
local argsort = util.argsort
local calc_mean_dev = util.calc_mean_dev local calc_mean_dev = util.calc_mean_dev
local clamp = util.clamp local clamp = util.clamp
local copy = util.copy local copy = util.copy
@ -258,9 +259,7 @@ local function collect_best_indices()
best_rewards = copy(trial_rewards) best_rewards = copy(trial_rewards)
end end
local indices = {} local indices = argsort(best_rewards, function(a, b) return a > b end)
for i = 1, #best_rewards do indices[i] = i end
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
return indices return indices