diff --git a/config.lua b/config.lua index e8cb96d..b2aca37 100644 --- a/config.lua +++ b/config.lua @@ -15,7 +15,7 @@ local cfg = { deterministic = true, -- use argmax on outputs instead of random sampling. det_epsilon = false, -- take random actions with probability eps. - graycode = true, + graycode = false, epoch_trials = 50, epoch_top_trials = 25, -- new with ARS. unperturbed_trial = true, -- do a trial without any noise. @@ -32,6 +32,8 @@ local cfg = { cap_time = 200, --400 timer_loser = 1/2, decrement_reward = false, -- bad idea, encourages mario to kill himself + + playback_mode = false, } cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials) diff --git a/main.lua b/main.lua index 4ab813f..f52b5c7 100644 --- a/main.lua +++ b/main.lua @@ -422,8 +422,10 @@ end -- learning and evaluation. local function prepare_epoch() - print('preparing epoch '..tostring(epoch_i)..'.') base_params = network:collect() + if cfg.playback_mode then return end + + print('preparing epoch '..tostring(epoch_i)..'.') empty(trial_noise) empty(trial_rewards) @@ -440,7 +442,8 @@ local function prepare_epoch() local noise = nn.zeros(#base_params) -- NOTE: change in implementation: deviation is multiplied here -- and ONLY here now. - if cfg.graycode then + if i % 2 == 0 then -- FIXME: just messing around. + --if cfg.graycode then --local precision = 1 / cfg.deviation --print(cfg.deviation, precision) for j = 1, #base_params do @@ -686,7 +689,7 @@ local function do_reset() if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then if epoch_i > 0 then learn_from_epoch() end - epoch_i = epoch_i + 1 + if not cfg.playback_mode then epoch_i = epoch_i + 1 end prepare_epoch() end @@ -754,6 +757,11 @@ local function init() if res == false then print(err) end end +local function prepare_reset() + if cfg.playback_mode then return end + reset = true +end + local function doit(dummy) local ingame_paused = get_state() == "paused" @@ -765,9 +773,11 @@ local function doit(dummy) if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then timer = timer - 1 end - timer = clamp(timer, 0, max_time) - if cfg.enable_network then - set_timer(timer) + if not cfg.playback_mode then + timer = clamp(timer, 0, max_time) + if cfg.enable_network then + set_timer(timer) + end end local tf0 = total_frames % 1000 @@ -835,17 +845,17 @@ local function doit(dummy) if get_state() == 'dead' and state_old ~= 'dead' then --print("dead. lives remaining:", R(0x75A, 0)) - if R(0x75A, 0) == 0 then reset = true end + if R(0x75A, 0) == 0 then prepare_reset() end end if get_state() == 'lose' then -- this shouldn't happen if we catch the deaths as above. print("ran out of lives.") - reset = true + if not cfg.playback_mode then prepare_reset() end end -- lose a point for every frame paused. --if ingame_paused then reward = reward - 1 end - if ingame_paused then reward = reward - 402; reset = true end + if ingame_paused then reward = reward - 402; prepare_reset() end -- if we've run out of time while the game is paused... -- that's cheating! unpause. @@ -902,7 +912,7 @@ while true do while gcfg.bad_states[get_state()] do -- mash the start button until we have control. joypad_mash('start') - reset = true + prepare_reset() advance() gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')