tweak config and network
This commit is contained in:
parent
d696bd8c21
commit
5636c7b2ed
1 changed files with 14 additions and 8 deletions
22
main.lua
22
main.lua
|
@ -30,7 +30,7 @@ local defer_prints = true
|
||||||
|
|
||||||
local playable_mode = false
|
local playable_mode = false
|
||||||
local start_big = true
|
local start_big = true
|
||||||
local starting_lives = 0
|
local starting_lives = 1
|
||||||
--
|
--
|
||||||
local init_zeros = false -- instead of he_normal noise or whatever.
|
local init_zeros = false -- instead of he_normal noise or whatever.
|
||||||
local frameskip = 4
|
local frameskip = 4
|
||||||
|
@ -41,16 +41,19 @@ local eps_start = 1.0 * frameskip / 64
|
||||||
local eps_stop = 0.1 * eps_start
|
local eps_stop = 0.1 * eps_start
|
||||||
local eps_frames = 4000000
|
local eps_frames = 4000000
|
||||||
--
|
--
|
||||||
local epoch_trials = 18
|
local epoch_trials = 15 --18
|
||||||
local epoch_top_trials = 9 -- new with ARS.
|
local epoch_top_trials = 10 --6 -- new with ARS.
|
||||||
local unperturbed_trial = true -- do a trial without any noise.
|
local unperturbed_trial = true -- do a trial without any noise.
|
||||||
local negate_trials = true -- try pairs of normal and negated noise directions.
|
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||||
-- ^ note that this now doubles the effective trials.
|
-- ^ note that this now doubles the effective trials.
|
||||||
local deviation = 0.025
|
local deviation = 0.025 --0.03
|
||||||
local function approx_cossim(dim)
|
local function approx_cossim(dim)
|
||||||
return math.pow(1.521 * dim - 0.521, -0.5026)
|
return math.pow(1.521 * dim - 0.521, -0.5026)
|
||||||
end
|
end
|
||||||
local learning_rate = 0.01 / approx_cossim(7051)
|
--local learning_rate = 0.01 / approx_cossim(7051)
|
||||||
|
--local learning_rate = 0.0032 / approx_cossim(66573)
|
||||||
|
local learning_rate = 0.0056 / approx_cossim(66573)
|
||||||
|
local weight_decay = 1 - 0.9977
|
||||||
--
|
--
|
||||||
local cap_time = 200 --400
|
local cap_time = 200 --400
|
||||||
local timer_loser = 0 --1/3
|
local timer_loser = 0 --1/3
|
||||||
|
@ -303,6 +306,9 @@ local function make_network(input_size)
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
nn_ty:feed(nn_y)
|
nn_ty:feed(nn_y)
|
||||||
|
|
||||||
|
nn_y = nn_y:feed(nn.Dense(128))
|
||||||
|
nn_y = nn_y:feed(nn.Gelu())
|
||||||
|
|
||||||
nn_z = nn_y
|
nn_z = nn_y
|
||||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||||
nn_z = nn_z:feed(nn.Softmax())
|
nn_z = nn_z:feed(nn.Softmax())
|
||||||
|
@ -740,7 +746,7 @@ local function learn_from_epoch()
|
||||||
print("full step stddev:", learning_rate * step_dev)
|
print("full step stddev:", learning_rate * step_dev)
|
||||||
|
|
||||||
for i, v in ipairs(base_params) do
|
for i, v in ipairs(base_params) do
|
||||||
base_params[i] = v + learning_rate * step[i]
|
base_params[i] = v + learning_rate * step[i] - weight_decay * v
|
||||||
end
|
end
|
||||||
|
|
||||||
if enable_network then
|
if enable_network then
|
||||||
|
@ -876,7 +882,7 @@ local function doit(dummy)
|
||||||
|
|
||||||
if dummy == true then
|
if dummy == true then
|
||||||
-- don't invoke AI this frame. (keep holding the old inputs)
|
-- don't invoke AI this frame. (keep holding the old inputs)
|
||||||
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -922,7 +928,7 @@ local function doit(dummy)
|
||||||
if not ingame_paused then reward = reward + reward_delta end
|
if not ingame_paused then reward = reward + reward_delta end
|
||||||
|
|
||||||
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
--gui.text(72, 12, ("%+4i"):format(reward_delta), '#FFFFFF', '#0000003F')
|
||||||
gui.text(96, 16, ("%+4i"):format(reward), '#FFFFFF', '#0000003F')
|
gui.text(89, 16, ("%+5i"):format(reward), '#FFFFFF', '#0000003F')
|
||||||
|
|
||||||
if get_state() == 'dead' and state_old ~= 'dead' then
|
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||||
--print("dead. lives remaining:", R(0x75A, 0))
|
--print("dead. lives remaining:", R(0x75A, 0))
|
||||||
|
|
Loading…
Reference in a new issue