add playback_mode

This commit is contained in:
Connor Olding 2018-04-03 18:13:11 +02:00
parent b453438055
commit 2bdd67b721
2 changed files with 23 additions and 11 deletions

View File

@ -15,7 +15,7 @@ local cfg = {
deterministic = true, -- use argmax on outputs instead of random sampling.
det_epsilon = false, -- take random actions with probability eps.
graycode = true,
graycode = false,
epoch_trials = 50,
epoch_top_trials = 25, -- new with ARS.
unperturbed_trial = true, -- do a trial without any noise.
@ -32,6 +32,8 @@ local cfg = {
cap_time = 200, --400
timer_loser = 1/2,
decrement_reward = false, -- bad idea, encourages mario to kill himself
playback_mode = false,
}
cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)

View File

@ -422,8 +422,10 @@ end
-- learning and evaluation.
local function prepare_epoch()
print('preparing epoch '..tostring(epoch_i)..'.')
base_params = network:collect()
if cfg.playback_mode then return end
print('preparing epoch '..tostring(epoch_i)..'.')
empty(trial_noise)
empty(trial_rewards)
@ -440,7 +442,8 @@ local function prepare_epoch()
local noise = nn.zeros(#base_params)
-- NOTE: change in implementation: deviation is multiplied here
-- and ONLY here now.
if cfg.graycode then
if i % 2 == 0 then -- FIXME: just messing around.
--if cfg.graycode then
--local precision = 1 / cfg.deviation
--print(cfg.deviation, precision)
for j = 1, #base_params do
@ -686,7 +689,7 @@ local function do_reset()
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
if epoch_i > 0 then learn_from_epoch() end
epoch_i = epoch_i + 1
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
prepare_epoch()
end
@ -754,6 +757,11 @@ local function init()
if res == false then print(err) end
end
local function prepare_reset()
if cfg.playback_mode then return end
reset = true
end
local function doit(dummy)
local ingame_paused = get_state() == "paused"
@ -765,9 +773,11 @@ local function doit(dummy)
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
timer = timer - 1
end
timer = clamp(timer, 0, max_time)
if cfg.enable_network then
set_timer(timer)
if not cfg.playback_mode then
timer = clamp(timer, 0, max_time)
if cfg.enable_network then
set_timer(timer)
end
end
local tf0 = total_frames % 1000
@ -835,17 +845,17 @@ local function doit(dummy)
if get_state() == 'dead' and state_old ~= 'dead' then
--print("dead. lives remaining:", R(0x75A, 0))
if R(0x75A, 0) == 0 then reset = true end
if R(0x75A, 0) == 0 then prepare_reset() end
end
if get_state() == 'lose' then
-- this shouldn't happen if we catch the deaths as above.
print("ran out of lives.")
reset = true
if not cfg.playback_mode then prepare_reset() end
end
-- lose a point for every frame paused.
--if ingame_paused then reward = reward - 1 end
if ingame_paused then reward = reward - 402; reset = true end
if ingame_paused then reward = reward - 402; prepare_reset() end
-- if we've run out of time while the game is paused...
-- that's cheating! unpause.
@ -902,7 +912,7 @@ while true do
while gcfg.bad_states[get_state()] do
-- mash the start button until we have control.
joypad_mash('start')
reset = true
prepare_reset()
advance()
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')