From ddc8cae4c6dd693e8d31f09dd7b38c0548abc02d Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Thu, 7 Sep 2017 19:00:09 +0000 Subject: [PATCH] add unperturbed trial with ranking --- main.lua | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/main.lua b/main.lua index 994ac2f..2db16be 100644 --- a/main.lua +++ b/main.lua @@ -41,6 +41,7 @@ local consider_past_rewards = false local learn_start_select = false -- local epoch_trials = 40 +local unperturbed_trial = true -- do a trial without any noise. local learning_rate = 0.3 -- bigger now that i'm shaping trials etc. local deviation = 0.05 -- @@ -75,7 +76,7 @@ local bad_states = { local epoch_i = 0 local base_params -local trial_i = 0 +local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled. local trial_noise = {} local trial_rewards = {} local trials_remaining = 0 @@ -535,17 +536,25 @@ local function prepare_epoch() for j = 1, #base_params do noise[j] = nn.normal() end trial_noise[i] = noise end - trial_i = 0 + trial_i = -1 end local function load_next_trial() trial_i = trial_i + 1 - print('loading trial', trial_i) local W = nn.copy(base_params) local noise = trial_noise[trial_i] local devsqrt = sqrt(deviation) - for i, v in ipairs(base_params) do - W[i] = v + devsqrt * noise[i] + if trial_i == 0 and not unperturbed_trial then + trial_i = 1 + end + if trial_i > 0 then + print('loading trial', trial_i) + local noise = trial_noise[trial_i] + for i, v in ipairs(base_params) do + W[i] = v + deviation * noise[i] + end + else + print("test trial") end network:distribute(W) end @@ -574,6 +583,17 @@ local function fitness_shaping(rewards) return shaped_returns end +local function unperturbed_rank(rewards, unperturbed_reward) + -- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py + local nth_place = 1 + for i, v in ipairs(rewards) do + if v > unperturbed_reward then + nth_place = nth_place + 1 + end + end + return nth_place +end + local function learn_from_epoch() print() print('rewards:', trial_rewards) @@ -582,12 +602,12 @@ local function learn_from_epoch() insert(all_rewards, v) end - if consider_past_rewards then - normalize_wrt(trial_rewards, all_rewards) - else - normalize(trial_rewards) + if unperturbed_trial then + local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0]) + + -- a rank of 1 means our gradient is uninformative. + print(("test trial: %d out of %d"):format(nth_place, #trial_rewards)) end - --print('normalized:', trial_rewards) local step = nn.zeros(#base_params) local shaped_rewards = fitness_shaping(trial_rewards) @@ -625,7 +645,7 @@ local function do_reset() if state == 'dead' and get_timer() == 0 then state = 'timeup' end print("resetting in state: "..state..". reward:", reward) - if trial_i > 0 then trial_rewards[trial_i] = reward end + if trial_i >= 0 then trial_rewards[trial_i] = reward end if epoch_i == 0 or trial_i == epoch_trials then if epoch_i > 0 then learn_from_epoch() end