diff --git a/main.lua b/main.lua index e517f73..3d377cb 100644 --- a/main.lua +++ b/main.lua @@ -30,7 +30,7 @@ local defer_prints = true local playable_mode = false local start_big = true -local starting_lives = 0 +local starting_lives = 1 -- local init_zeros = false -- instead of he_normal noise or whatever. local frameskip = 4 @@ -41,16 +41,19 @@ local eps_start = 1.0 * frameskip / 64 local eps_stop = 0.1 * eps_start local eps_frames = 4000000 -- -local epoch_trials = 18 -local epoch_top_trials = 9 -- new with ARS. +local epoch_trials = 15 --18 +local epoch_top_trials = 10 --6 -- new with ARS. local unperturbed_trial = true -- do a trial without any noise. local negate_trials = true -- try pairs of normal and negated noise directions. -- ^ note that this now doubles the effective trials. -local deviation = 0.025 +local deviation = 0.025 --0.03 local function approx_cossim(dim) return math.pow(1.521 * dim - 0.521, -0.5026) end -local learning_rate = 0.01 / approx_cossim(7051) +--local learning_rate = 0.01 / approx_cossim(7051) +--local learning_rate = 0.0032 / approx_cossim(66573) +local learning_rate = 0.0056 / approx_cossim(66573) +local weight_decay = 1 - 0.9977 -- local cap_time = 200 --400 local timer_loser = 0 --1/3 @@ -303,6 +306,9 @@ local function make_network(input_size) nn_x:feed(nn_y) nn_ty:feed(nn_y) + nn_y = nn_y:feed(nn.Dense(128)) + nn_y = nn_y:feed(nn.Gelu()) + nn_z = nn_y nn_z = nn_z:feed(nn.Dense(#jp_lut)) nn_z = nn_z:feed(nn.Softmax()) @@ -740,7 +746,7 @@ local function learn_from_epoch() print("full step stddev:", learning_rate * step_dev) for i, v in ipairs(base_params) do - base_params[i] = v + learning_rate * step[i] + base_params[i] = v + learning_rate * step[i] - weight_decay * v end if enable_network then @@ -876,7 +882,7 @@ local function doit(dummy) if dummy == true then -- don't invoke AI this frame. (keep holding the old inputs) - gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') + gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F') return end @@ -922,7 +928,7 @@ local function doit(dummy) if not ingame_paused then reward = reward + reward_delta end --gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F') - gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F') + gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F') if get_state() == 'dead' and state_old ~= 'dead' then --print("dead. lives remaining:", R(0x75A, 0))