add unperturbed trial with ranking
This commit is contained in:
parent
9ce1f87ade
commit
ddc8cae4c6
1 changed file with 31 additions and 11 deletions
40
main.lua
40
main.lua
|
@ -41,6 +41,7 @@ local consider_past_rewards = false
|
||||||
local learn_start_select = false
|
local learn_start_select = false
|
||||||
--
|
--
|
||||||
local epoch_trials = 40
|
local epoch_trials = 40
|
||||||
|
local unperturbed_trial = true -- do a trial without any noise.
|
||||||
local learning_rate = 0.3 -- bigger now that i'm shaping trials etc.
|
local learning_rate = 0.3 -- bigger now that i'm shaping trials etc.
|
||||||
local deviation = 0.05
|
local deviation = 0.05
|
||||||
--
|
--
|
||||||
|
@ -75,7 +76,7 @@ local bad_states = {
|
||||||
|
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
local base_params
|
local base_params
|
||||||
local trial_i = 0
|
local trial_i = -1 -- NOTE: trial 0 is an unperturbed trial, if enabled.
|
||||||
local trial_noise = {}
|
local trial_noise = {}
|
||||||
local trial_rewards = {}
|
local trial_rewards = {}
|
||||||
local trials_remaining = 0
|
local trials_remaining = 0
|
||||||
|
@ -535,17 +536,25 @@ local function prepare_epoch()
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
trial_noise[i] = noise
|
trial_noise[i] = noise
|
||||||
end
|
end
|
||||||
trial_i = 0
|
trial_i = -1
|
||||||
end
|
end
|
||||||
|
|
||||||
local function load_next_trial()
|
local function load_next_trial()
|
||||||
trial_i = trial_i + 1
|
trial_i = trial_i + 1
|
||||||
print('loading trial', trial_i)
|
|
||||||
local W = nn.copy(base_params)
|
local W = nn.copy(base_params)
|
||||||
local noise = trial_noise[trial_i]
|
local noise = trial_noise[trial_i]
|
||||||
local devsqrt = sqrt(deviation)
|
local devsqrt = sqrt(deviation)
|
||||||
|
if trial_i == 0 and not unperturbed_trial then
|
||||||
|
trial_i = 1
|
||||||
|
end
|
||||||
|
if trial_i > 0 then
|
||||||
|
print('loading trial', trial_i)
|
||||||
|
local noise = trial_noise[trial_i]
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
W[i] = v + devsqrt * noise[i]
|
W[i] = v + deviation * noise[i]
|
||||||
|
end
|
||||||
|
else
|
||||||
|
print("test trial")
|
||||||
end
|
end
|
||||||
network:distribute(W)
|
network:distribute(W)
|
||||||
end
|
end
|
||||||
|
@ -574,6 +583,17 @@ local function fitness_shaping(rewards)
|
||||||
return shaped_returns
|
return shaped_returns
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Rank the unperturbed trial's reward against the other trial rewards.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- @param rewards sequence of rewards; walked with ipairs, so only the
--   entries at indices 1..n (up to the first nil) are compared.
-- @param unperturbed_reward the reward to place among them.
-- @return 1 plus the number of entries strictly greater than
--   unperturbed_reward (ties do not push the rank down).
local function unperturbed_rank(rewards, unperturbed_reward)
    local beaten_by = 0
    for _, reward in ipairs(rewards) do
        if unperturbed_reward < reward then
            beaten_by = beaten_by + 1
        end
    end
    return beaten_by + 1
end
|
||||||
|
|
||||||
local function learn_from_epoch()
|
local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
print('rewards:', trial_rewards)
|
print('rewards:', trial_rewards)
|
||||||
|
@ -582,12 +602,12 @@ local function learn_from_epoch()
|
||||||
insert(all_rewards, v)
|
insert(all_rewards, v)
|
||||||
end
|
end
|
||||||
|
|
||||||
if consider_past_rewards then
|
if unperturbed_trial then
|
||||||
normalize_wrt(trial_rewards, all_rewards)
|
local nth_place = unperturbed_rank(trial_rewards, trial_rewards[0])
|
||||||
else
|
|
||||||
normalize(trial_rewards)
|
-- a rank of 1 means our gradient is uninformative.
|
||||||
|
print(("test trial: %d out of %d"):format(nth_place, #trial_rewards))
|
||||||
end
|
end
|
||||||
--print('normalized:', trial_rewards)
|
|
||||||
|
|
||||||
local step = nn.zeros(#base_params)
|
local step = nn.zeros(#base_params)
|
||||||
local shaped_rewards = fitness_shaping(trial_rewards)
|
local shaped_rewards = fitness_shaping(trial_rewards)
|
||||||
|
@ -625,7 +645,7 @@ local function do_reset()
|
||||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||||
print("resetting in state: "..state..". reward:", reward)
|
print("resetting in state: "..state..". reward:", reward)
|
||||||
|
|
||||||
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
if trial_i >= 0 then trial_rewards[trial_i] = reward end
|
||||||
|
|
||||||
if epoch_i == 0 or trial_i == epoch_trials then
|
if epoch_i == 0 or trial_i == epoch_trials then
|
||||||
if epoch_i > 0 then learn_from_epoch() end
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
|
|
Loading…
Reference in a new issue