add test trial logging
This commit is contained in:
parent
deb1ea7de0
commit
ee066154b2
1 changed files with 5 additions and 1 deletions
6
main.lua
6
main.lua
|
@ -102,6 +102,7 @@ local log_map = {
|
||||||
adam_std = 7,
|
adam_std = 7,
|
||||||
weight_mean = 8,
|
weight_mean = 8,
|
||||||
weight_std = 9,
|
weight_std = 9,
|
||||||
|
test_trial = 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
local function log_csv(t)
|
local function log_csv(t)
|
||||||
|
@ -597,8 +598,10 @@ local function learn_from_epoch()
|
||||||
|
|
||||||
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
|
--for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end
|
||||||
|
|
||||||
|
local current_cost = trial_rewards[0] -- may be nil!
|
||||||
|
|
||||||
if cfg.unperturbed_trial then
|
if cfg.unperturbed_trial then
|
||||||
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
local nth_place = unperturbed_rank(trial_rewards, current_cost)
|
||||||
|
|
||||||
-- a rank of 1 means our gradient is uninformative.
|
-- a rank of 1 means our gradient is uninformative.
|
||||||
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
||||||
|
@ -723,6 +726,7 @@ local function learn_from_epoch()
|
||||||
adam_std = momstep_dev,
|
adam_std = momstep_dev,
|
||||||
weight_mean = weight_mean,
|
weight_mean = weight_mean,
|
||||||
weight_std = weight_std,
|
weight_std = weight_std,
|
||||||
|
test_trial = current_cost or 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
-- trying a heuristic...
|
-- trying a heuristic...
|
||||||
|
|
Loading…
Reference in a new issue