use fitness shaping
This commit is contained in:
parent
6b193cac9b
commit
9ce1f87ade
1 changed files with 45 additions and 26 deletions
71
main.lua
71
main.lua
|
@ -40,9 +40,9 @@ local eps_frames = 1000000
|
||||||
local consider_past_rewards = false
|
local consider_past_rewards = false
|
||||||
local learn_start_select = false
|
local learn_start_select = false
|
||||||
--
|
--
|
||||||
local epoch_trials = 40 -- 24
|
local epoch_trials = 40
|
||||||
local learning_rate = 1e-3
|
local learning_rate = 0.3 -- bigger now that i'm shaping trials etc.
|
||||||
local deviation = 1e-2 -- 4e-3
|
local deviation = 0.05
|
||||||
--
|
--
|
||||||
local cap_time = 400
|
local cap_time = 400
|
||||||
local timer_loser = 1/3
|
local timer_loser = 1/3
|
||||||
|
@ -127,6 +127,7 @@ local randomseed = math.randomseed
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local remove = table.remove
|
local remove = table.remove
|
||||||
local unpack = table.unpack or unpack
|
local unpack = table.unpack or unpack
|
||||||
|
local sort = table.sort
|
||||||
local R = memory.readbyteunsigned
|
local R = memory.readbyteunsigned
|
||||||
local S = memory.readbyte --signed
|
local S = memory.readbyte --signed
|
||||||
local W = memory.writebyte
|
local W = memory.writebyte
|
||||||
|
@ -143,12 +144,22 @@ local ror = bit.ror
|
||||||
|
|
||||||
-- utilities.
|
-- utilities.
|
||||||
|
|
||||||
|
--- Linear search: return the index of the first sequence element equal to
--- needle, or nil when it is not present. Stops at the first nil hole,
--- matching ipairs semantics.
local function ifind(haystack, needle)
	local index = 1
	while haystack[index] ~= nil do
		if haystack[index] == needle then
			return index
		end
		index = index + 1
	end
	return nil
end
|
||||||
|
|
||||||
--- Logical exclusive-or over truthiness: true when exactly one of the two
--- arguments is truthy, false otherwise. Always returns a real boolean.
local function boolean_xor(a, b)
	-- coerce each argument to a proper boolean, then compare.
	local p = not not a
	local q = not not b
	return p ~= q
end
|
||||||
|
|
||||||
|
local _invlog2 = 1 / log(2)
|
||||||
|
local function log2(x) return log(x) * _invlog2 end
|
||||||
|
|
||||||
local function clamp(x, l, u) return min(max(x, l), u) end
|
local function clamp(x, l, u) return min(max(x, l), u) end
|
||||||
|
|
||||||
local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
|
local function lerp(a, b, t) return a + (b - a) * clamp(t, 0, 1) end
|
||||||
|
@ -517,6 +528,8 @@ local function prepare_epoch()
|
||||||
base_params = network:collect()
|
base_params = network:collect()
|
||||||
empty(trial_noise)
|
empty(trial_noise)
|
||||||
empty(trial_rewards)
|
empty(trial_rewards)
|
||||||
|
-- TODO: save memory. generate noise as needed by saving the seed
|
||||||
|
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
|
@ -537,6 +550,30 @@ local function load_next_trial()
|
||||||
network:distribute(W)
|
network:distribute(W)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Rank-based fitness shaping for evolution strategies: map raw trial
--- rewards to normalized utilities that depend only on each reward's rank,
--- not its magnitude. Returns a new table; `rewards` is not modified.
-- lifted from: https://github.com/atgambardella/pytorch-es/blob/master/train.py
-- NOTE: ties share the rank of their first (best) occurrence, since ifind
-- returns the first matching index in the sorted list.
local function fitness_shaping(rewards)
	local decreasing = nn.copy(rewards)
	sort(decreasing, function(a, b) return a > b end)

	local lamb = #rewards
	-- loop-invariant: was recomputed on every iteration of both loops.
	local l = log2(lamb / 2 + 1)

	-- compute each reward's raw utility exactly once (the original did this
	-- work twice per reward: once for the denominator, once for the numerator).
	-- ifind is O(n) per lookup, so this pass is O(n^2) overall.
	local utility = {}
	local denom = 0
	for i, v in ipairs(rewards) do
		local r = log2(ifind(decreasing, v))
		local u = max(0, l - r)
		utility[i] = u
		denom = denom + u
	end

	-- denom > 0 whenever lamb >= 1: the top-ranked reward has r = log2(1) = 0,
	-- so it contributes l > 0.
	local shaped_returns = {}
	for i = 1, lamb do
		insert(shaped_returns, utility[i] / denom + 1 / lamb)
	end

	return shaped_returns
end
|
||||||
|
|
||||||
local function learn_from_epoch()
|
local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
print('rewards:', trial_rewards)
|
print('rewards:', trial_rewards)
|
||||||
|
@ -552,37 +589,19 @@ local function learn_from_epoch()
|
||||||
end
|
end
|
||||||
--print('normalized:', trial_rewards)
|
--print('normalized:', trial_rewards)
|
||||||
|
|
||||||
local reward_mean, reward_dev = calc_mean_dev(trial_rewards)
|
|
||||||
|
|
||||||
local step = nn.zeros(#base_params)
|
local step = nn.zeros(#base_params)
|
||||||
|
local shaped_rewards = fitness_shaping(trial_rewards)
|
||||||
|
|
||||||
|
local altogether = learning_rate / (epoch_trials * deviation)
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local reward = trial_rewards[i]
|
local reward = shaped_rewards[i]
|
||||||
local noise = trial_noise[i]
|
local noise = trial_noise[i]
|
||||||
for j, v in ipairs(noise) do
|
for j, v in ipairs(noise) do
|
||||||
step[j] = step[j] + reward * v
|
step[j] = step[j] + altogether * (reward * v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local magnitude = learning_rate / deviation
|
|
||||||
--print('stepping with magnitude', magnitude)
|
|
||||||
-- throw the division from the averaging in there too.
|
|
||||||
local altogether = magnitude / epoch_trials
|
|
||||||
for i, v in ipairs(step) do
|
|
||||||
step[i] = altogether * v
|
|
||||||
end
|
|
||||||
|
|
||||||
local step_mean, step_dev = calc_mean_dev(step)
|
local step_mean, step_dev = calc_mean_dev(step)
|
||||||
if step_dev < 1e-8 then
|
|
||||||
-- we didn't get anywhere. step in a random direction.
|
|
||||||
print("stepping randomly.")
|
|
||||||
local noise = trial_noise[1]
|
|
||||||
local devsqrt = sqrt(deviation)
|
|
||||||
for i, v in ipairs(step) do
|
|
||||||
step[i] = devsqrt * noise[i]
|
|
||||||
end
|
|
||||||
|
|
||||||
step_mean, step_dev = calc_mean_dev(step)
|
|
||||||
end
|
|
||||||
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end
|
||||||
print("step stddev:", step_dev)
|
print("step stddev:", step_dev)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue