add playback_mode
This commit is contained in:
parent
b453438055
commit
2bdd67b721
2 changed files with 23 additions and 11 deletions
|
@ -15,7 +15,7 @@ local cfg = {
|
||||||
deterministic = true, -- use argmax on outputs instead of random sampling.
|
deterministic = true, -- use argmax on outputs instead of random sampling.
|
||||||
det_epsilon = false, -- take random actions with probability eps.
|
det_epsilon = false, -- take random actions with probability eps.
|
||||||
|
|
||||||
graycode = true,
|
graycode = false,
|
||||||
epoch_trials = 50,
|
epoch_trials = 50,
|
||||||
epoch_top_trials = 25, -- new with ARS.
|
epoch_top_trials = 25, -- new with ARS.
|
||||||
unperturbed_trial = true, -- do a trial without any noise.
|
unperturbed_trial = true, -- do a trial without any noise.
|
||||||
|
@ -32,6 +32,8 @@ local cfg = {
|
||||||
cap_time = 200, --400
|
cap_time = 200, --400
|
||||||
timer_loser = 1/2,
|
timer_loser = 1/2,
|
||||||
decrement_reward = false, -- bad idea, encourages mario to kill himself
|
decrement_reward = false, -- bad idea, encourages mario to kill himself
|
||||||
|
|
||||||
|
playback_mode = false,
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
|
cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
|
||||||
|
|
30
main.lua
30
main.lua
|
@ -422,8 +422,10 @@ end
|
||||||
-- learning and evaluation.
|
-- learning and evaluation.
|
||||||
|
|
||||||
local function prepare_epoch()
|
local function prepare_epoch()
|
||||||
print('preparing epoch '..tostring(epoch_i)..'.')
|
|
||||||
base_params = network:collect()
|
base_params = network:collect()
|
||||||
|
if cfg.playback_mode then return end
|
||||||
|
|
||||||
|
print('preparing epoch '..tostring(epoch_i)..'.')
|
||||||
empty(trial_noise)
|
empty(trial_noise)
|
||||||
empty(trial_rewards)
|
empty(trial_rewards)
|
||||||
|
|
||||||
|
@ -440,7 +442,8 @@ local function prepare_epoch()
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
-- NOTE: change in implementation: deviation is multiplied here
|
-- NOTE: change in implementation: deviation is multiplied here
|
||||||
-- and ONLY here now.
|
-- and ONLY here now.
|
||||||
if cfg.graycode then
|
if i % 2 == 0 then -- FIXME: just messing around.
|
||||||
|
--if cfg.graycode then
|
||||||
--local precision = 1 / cfg.deviation
|
--local precision = 1 / cfg.deviation
|
||||||
--print(cfg.deviation, precision)
|
--print(cfg.deviation, precision)
|
||||||
for j = 1, #base_params do
|
for j = 1, #base_params do
|
||||||
|
@ -686,7 +689,7 @@ local function do_reset()
|
||||||
|
|
||||||
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
|
if epoch_i == 0 or (trial_i == cfg.epoch_trials and trial_neg) then
|
||||||
if epoch_i > 0 then learn_from_epoch() end
|
if epoch_i > 0 then learn_from_epoch() end
|
||||||
epoch_i = epoch_i + 1
|
if not cfg.playback_mode then epoch_i = epoch_i + 1 end
|
||||||
prepare_epoch()
|
prepare_epoch()
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -754,6 +757,11 @@ local function init()
|
||||||
if res == false then print(err) end
|
if res == false then print(err) end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Request a reset by setting the `reset` flag, unless playback mode is enabled.
-- `reset` is not declared local here, so it refers to a variable defined
-- outside this function (an upvalue or a global; not visible in this view).
-- Takes no arguments and returns nothing.
local function prepare_reset()
|
||||||
|
if cfg.playback_mode then return end -- playback mode: leave `reset` untouched.
|
||||||
|
reset = true
|
||||||
|
end
|
||||||
|
|
||||||
local function doit(dummy)
|
local function doit(dummy)
|
||||||
local ingame_paused = get_state() == "paused"
|
local ingame_paused = get_state() == "paused"
|
||||||
|
|
||||||
|
@ -765,9 +773,11 @@ local function doit(dummy)
|
||||||
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
if ingame_paused or random() > 1 - cfg.timer_loser and R(0x1D) == 0 and R(0x57) == 0 then
|
||||||
timer = timer - 1
|
timer = timer - 1
|
||||||
end
|
end
|
||||||
timer = clamp(timer, 0, max_time)
|
if not cfg.playback_mode then
|
||||||
if cfg.enable_network then
|
timer = clamp(timer, 0, max_time)
|
||||||
set_timer(timer)
|
if cfg.enable_network then
|
||||||
|
set_timer(timer)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local tf0 = total_frames % 1000
|
local tf0 = total_frames % 1000
|
||||||
|
@ -835,17 +845,17 @@ local function doit(dummy)
|
||||||
|
|
||||||
if get_state() == 'dead' and state_old ~= 'dead' then
|
if get_state() == 'dead' and state_old ~= 'dead' then
|
||||||
--print("dead. lives remaining:", R(0x75A, 0))
|
--print("dead. lives remaining:", R(0x75A, 0))
|
||||||
if R(0x75A, 0) == 0 then reset = true end
|
if R(0x75A, 0) == 0 then prepare_reset() end
|
||||||
end
|
end
|
||||||
if get_state() == 'lose' then
|
if get_state() == 'lose' then
|
||||||
-- this shouldn't happen if we catch the deaths as above.
|
-- this shouldn't happen if we catch the deaths as above.
|
||||||
print("ran out of lives.")
|
print("ran out of lives.")
|
||||||
reset = true
|
if not cfg.playback_mode then prepare_reset() end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- lose a point for every frame paused.
|
-- lose a point for every frame paused.
|
||||||
--if ingame_paused then reward = reward - 1 end
|
--if ingame_paused then reward = reward - 1 end
|
||||||
if ingame_paused then reward = reward - 402; reset = true end
|
if ingame_paused then reward = reward - 402; prepare_reset() end
|
||||||
|
|
||||||
-- if we've run out of time while the game is paused...
|
-- if we've run out of time while the game is paused...
|
||||||
-- that's cheating! unpause.
|
-- that's cheating! unpause.
|
||||||
|
@ -902,7 +912,7 @@ while true do
|
||||||
while gcfg.bad_states[get_state()] do
|
while gcfg.bad_states[get_state()] do
|
||||||
-- mash the start button until we have control.
|
-- mash the start button until we have control.
|
||||||
joypad_mash('start')
|
joypad_mash('start')
|
||||||
reset = true
|
prepare_reset()
|
||||||
|
|
||||||
advance()
|
advance()
|
||||||
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
gui.text(4, 12, get_state(), '#FFFFFF', '#0000003F')
|
||||||
|
|
Loading…
Reference in a new issue