diff --git a/config.lua b/config.lua index 756af60..01c3d6c 100644 --- a/config.lua +++ b/config.lua @@ -33,6 +33,7 @@ local common_cfg = { -- ^ note that this now doubles the effective trials. time_inputs = true, -- binary inputs of global frame count + ars_lips = false, adamant = false, -- run steps through AMSgrad. cap_time = 300, @@ -50,6 +51,7 @@ local cfg = { epoch_top_trials = 10, learning_rate = 1.0, + ars_lips = true, deviation = 0.1, weight_decay = 0.004, diff --git a/main.lua b/main.lua index d457932..4d619fd 100644 --- a/main.lua +++ b/main.lua @@ -656,11 +656,12 @@ local function learn_from_epoch() end print("best deltas:", top_delta_rewards) - local _, reward_dev = calc_mean_dev(top_rewards) - --print("mean, dev:", _, reward_dev) - if reward_dev == 0 then reward_dev = 1 end - - for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end + if not cfg.ars_lips then + local _, reward_dev = calc_mean_dev(top_rewards) + --print("mean, dev:", _, reward_dev) + if reward_dev == 0 then reward_dev = 1 end + for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end + end -- NOTE: step no longer directly incorporates learning_rate. for i = 1, cfg.epoch_trials do @@ -668,9 +669,26 @@ local function learn_from_epoch() local pos = top_rewards[ind + 0] local neg = top_rewards[ind + 1] local reward = pos - neg - local noise = trial_noise[i] - for j, v in ipairs(noise) do - step[j] = step[j] + reward * v / cfg.epoch_top_trials + if reward ~= 0 then + local noise = trial_noise[i] + + if cfg.ars_lips then + local _, dev = calc_mean_dev(noise) + local c0 = neg - current_cost + local c1 = pos - current_cost + local l0 = abs(3 * c1 + c0) + local l1 = abs(c1 + 3 * c0) + local lips = max(l0, l1) / (2 * dev) + --reward = pos / lips - neg / lips + local old_reward = reward + reward = reward / lips + reward = reward / cfg.deviation -- FIXME: hack? + --print(("trial %i reward: %.0f -> %.5f"):format(i, old_reward, reward)) + end + + for j, v in ipairs(noise) do + step[j] = step[j] + reward * v / cfg.epoch_top_trials + end end end