diff --git a/main.lua b/main.lua index a93d48c..d457932 100644 --- a/main.lua +++ b/main.lua @@ -102,6 +102,7 @@ local log_map = { adam_std = 7, weight_mean = 8, weight_std = 9, + test_trial = 10, } local function log_csv(t) @@ -597,8 +598,10 @@ local function learn_from_epoch() --for _, v in ipairs(trial_rewards) do insert(all_rewards, v) end + local current_cost = trial_rewards[0] -- may be nil! + if cfg.unperturbed_trial then - local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0]) + local nth_place = unperturbed_rank(trial_rewards, current_cost) -- a rank of 1 means our gradient is uninformative. print(("test trial: %d out of %d"):format(nth_place, #trial_rewards)) @@ -723,6 +726,7 @@ local function learn_from_epoch() adam_std = momstep_dev, weight_mean = weight_mean, weight_std = weight_std, + test_trial = current_cost or 0, } -- trying a heuristic...