tweak config and network

This commit is contained in:
Connor Olding 2018-03-31 18:40:35 +02:00
parent d696bd8c21
commit 5636c7b2ed

View File

@ -30,7 +30,7 @@ local defer_prints = true
local playable_mode = false
local start_big = true
local starting_lives = 0
local starting_lives = 1
--
local init_zeros = false -- instead of he_normal noise or whatever.
local frameskip = 4
@ -41,16 +41,19 @@ local eps_start = 1.0 * frameskip / 64
local eps_stop = 0.1 * eps_start
local eps_frames = 4000000
--
local epoch_trials = 18
local epoch_top_trials = 9 -- new with ARS.
local epoch_trials = 15 --18
local epoch_top_trials = 10 --6 -- new with ARS.
local unperturbed_trial = true -- do a trial without any noise.
local negate_trials = true -- try pairs of normal and negated noise directions.
-- ^ note that this now doubles the effective trials.
local deviation = 0.025
local deviation = 0.025 --0.03
local function approx_cossim(dim)
return math.pow(1.521 * dim - 0.521, -0.5026)
end
local learning_rate = 0.01 / approx_cossim(7051)
--local learning_rate = 0.01 / approx_cossim(7051)
--local learning_rate = 0.0032 / approx_cossim(66573)
local learning_rate = 0.0056 / approx_cossim(66573)
local weight_decay = 1 - 0.9977
--
local cap_time = 200 --400
local timer_loser = 0 --1/3
@ -303,6 +306,9 @@ local function make_network(input_size)
nn_x:feed(nn_y)
nn_ty:feed(nn_y)
nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu())
nn_z = nn_y
nn_z = nn_z:feed(nn.Dense(#jp_lut))
nn_z = nn_z:feed(nn.Softmax())
@ -740,7 +746,7 @@ local function learn_from_epoch()
print("full step stddev:", learning_rate * step_dev)
for i, v in ipairs(base_params) do
base_params[i] = v + learning_rate * step[i]
base_params[i] = v + learning_rate * step[i] - weight_decay * v
end
if enable_network then
@ -876,7 +882,7 @@ local function doit(dummy)
if dummy == true then
-- don't invoke AI this frame. (keep holding the old inputs)
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
return
end
@ -922,7 +928,7 @@ local function doit(dummy)
if not ingame_paused then reward = reward + reward_delta end
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
if get_state() == 'dead' and state_old ~= 'dead' then
--print("dead. lives remaining:", R(0x75A, 0))