diff --git a/main.lua b/main.lua index 7ed2dc8..69b3231 100644 --- a/main.lua +++ b/main.lua @@ -24,17 +24,19 @@ function mt.__newindex(t, n, v) error("cannot assign undeclared global '" .. tos function mt.__index(t, n) error("cannot use undeclared global '" .. tostring(n) .. "'", 2) end local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end --- configuration and globals. +-- configuration. --randomseed(11) local playable_mode = false -- -local deterministic = false -- use argmax on outputs instead of random sampling. +-- true greedy epsilon has both deterministic and det_epsilon set. +local deterministic = true -- use argmax on outputs instead of random sampling. local det_epsilon = true -- take random actions with probability eps. -local eps_start = 0.50 -local eps_stop = 0.05 -local eps_frames = 60*60*60 +-- using parameters from DQN... sorta. +local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref. +local eps_stop = 0.1 * 1/60 -- " +local eps_frames = 1000000 local consider_past_rewards = false local learn_start_select = false -- @@ -50,6 +52,27 @@ local enable_network = not playable_mode local input_size = 281 -- TODO: let the script figure this out for us. +local ok_routines = { + [0x4] = true, -- sliding down flagpole + [0x5] = true, -- end of level auto-walk + [0x7] = true, -- start of level auto-walk + [0x8] = true, -- normal (in control) + [0x9] = true, -- acquiring mushroom + [0xA] = true, -- losing big mario + [0xB] = true, -- uhh + [0xC] = true, -- acquiring fireflower +} + +local bad_states = { + power = true, + waiting_demo = true, + playing_demo = true, + unknown = true, + lose = true, +} + +-- state. + local epoch_i = 0 local base_params local trial_i = 0 @@ -83,25 +106,6 @@ local score_old local once = false local reset = true -local ok_routines = { - [0x4] = true, -- sliding down flagpole - [0x5] = true, -- end of level auto-walk - [0x7] = true, -- start of level auto-walk - [0x8] = true, -- normal (in control) - [0x9] = true, -- acquiring mushroom - [0xA] = true, -- losing big mario - [0xB] = true, -- uhh - [0xC] = true, -- acquiring fireflower -} - -local bad_states = { - power = true, - waiting_demo = true, - playing_demo = true, - unknown = true, - lose = true, -} - local state_old = '' -- localize some stuff. @@ -508,6 +512,31 @@ local function handle_tiles() end end +local function prepare_epoch() + print('preparing epoch '..tostring(epoch_i)..'. this might take a while.') + base_params = network:collect() + empty(trial_noise) + empty(trial_rewards) + for i = 1, epoch_trials do + local noise = nn.zeros(#base_params) + for j = 1, #base_params do noise[j] = nn.normal() end + trial_noise[i] = noise + end + trial_i = 0 +end + +local function load_next_trial() + trial_i = trial_i + 1 + print('loading trial', trial_i) + local W = nn.copy(base_params) + local noise = trial_noise[trial_i] + local devsqrt = sqrt(deviation) + for i, v in ipairs(base_params) do + W[i] = v + devsqrt * noise[i] + end + network:distribute(W) +end + local function learn_from_epoch() print() print('rewards:', trial_rewards) @@ -571,33 +600,11 @@ local function learn_from_epoch() print() end -local function prepare_epoch() - print('preparing epoch '..tostring(epoch_i)..'. this might take a while.') - base_params = network:collect() - empty(trial_noise) - empty(trial_rewards) - for i = 1, epoch_trials do - local noise = nn.zeros(#base_params) - for j = 1, #base_params do noise[j] = nn.normal() end - trial_noise[i] = noise - end - trial_i = 0 -end - -local function load_next_trial() - trial_i = trial_i + 1 - print('loading trial', trial_i) - local W = nn.copy(base_params) - local noise = trial_noise[trial_i] - local devsqrt = sqrt(deviation) - for i, v in ipairs(base_params) do - W[i] = v + devsqrt * noise[i] - end - network:distribute(W) -end - local function do_reset() - print("resetting in state: "..get_state()..". reward:", reward) + local state = get_state() + -- be a little more descriptive. + if state == 'dead' and get_timer() == 0 then state = 'timeup' end + print("resetting in state: "..state..". reward:", reward) if trial_i > 0 then trial_rewards[trial_i] = reward end