use argsort
This commit is contained in:
parent
d33bdfea62
commit
e24c3d31a4
1 changed files with 2 additions and 3 deletions
5
main.lua
5
main.lua
|
@ -85,6 +85,7 @@ local gui = gui
|
||||||
|
|
||||||
local util = require("util")
|
local util = require("util")
|
||||||
local argmax = util.argmax
|
local argmax = util.argmax
|
||||||
|
local argsort = util.argsort
|
||||||
local calc_mean_dev = util.calc_mean_dev
|
local calc_mean_dev = util.calc_mean_dev
|
||||||
local clamp = util.clamp
|
local clamp = util.clamp
|
||||||
local copy = util.copy
|
local copy = util.copy
|
||||||
|
@ -258,9 +259,7 @@ local function collect_best_indices()
|
||||||
best_rewards = copy(trial_rewards)
|
best_rewards = copy(trial_rewards)
|
||||||
end
|
end
|
||||||
|
|
||||||
local indices = {}
|
local indices = argsort(best_rewards, function(a, b) return a > b end)
|
||||||
for i = 1, #best_rewards do indices[i] = i end
|
|
||||||
sort(indices, function(a, b) return best_rewards[a] > best_rewards[b] end)
|
|
||||||
|
|
||||||
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
for i = cfg.epoch_top_trials + 1, #best_rewards do indices[i] = nil end
|
||||||
return indices
|
return indices
|
||||||
|
|
Loading…
Reference in a new issue