use argsort
This commit is contained in:
parent
d33bdfea62
commit
e24c3d31a4
1 changed files with 2 additions and 3 deletions
5
main.lua
5
main.lua
|
@ -85,6 +85,7 @@ local gui = gui
|
|||
|
||||
local util = require("util")
|
||||
local argmax = util.argmax
|
||||
local argsort = util.argsort
|
||||
local calc_mean_dev = util.calc_mean_dev
|
||||
local clamp = util.clamp
|
||||
local copy = util.copy
|
||||
|
@ -258,9 +259,7 @@ local function collect_best_indices()
|
|||
best_rewards = copy(trial_rewards)
|
||||
end
|
||||
|
||||
local indices = {}
|
||||
for i = 1, #best_rewards do indices[i] = i end
|
||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
||||
local indices = argsort(best_rewards, function(a, b) return a > b end)
|
||||
|
||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||
return indices
|
||||
|
|
Loading…
Reference in a new issue