tweaks and reordering

This commit is contained in:
Connor Olding 2017-09-07 18:41:44 +00:00
parent f7bee50d12
commit 6b193cac9b

107
main.lua
View File

@ -24,17 +24,19 @@ function mt.__newindex(t, n, v) error("cannot assign undeclared global '" .. tos
function mt.__index(t, n) error("cannot use undeclared global '" .. tostring(n) .. "'", 2) end
local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
-- configuration and globals.
-- configuration.
--randomseed(11)
local playable_mode = false
--
local deterministic = false -- use argmax on outputs instead of random sampling.
-- true greedy epsilon has both deterministic and det_epsilon set.
local deterministic = true -- use argmax on outputs instead of random sampling.
local det_epsilon = true -- take random actions with probability eps.
local eps_start = 0.50
local eps_stop = 0.05
local eps_frames = 60*60*60
-- using parameters from DQN... sorta.
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
local eps_stop = 0.1 * 1/60 -- "
local eps_frames = 1000000
local consider_past_rewards = false
local learn_start_select = false
--
@ -50,6 +52,27 @@ local enable_network = not playable_mode
local input_size = 281 -- TODO: let the script figure this out for us.
local ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
}
local bad_states = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
}
-- state.
local epoch_i = 0
local base_params
local trial_i = 0
@ -83,25 +106,6 @@ local score_old
local once = false
local reset = true
local ok_routines = {
[0x4] = true, -- sliding down flagpole
[0x5] = true, -- end of level auto-walk
[0x7] = true, -- start of level auto-walk
[0x8] = true, -- normal (in control)
[0x9] = true, -- acquiring mushroom
[0xA] = true, -- losing big mario
[0xB] = true, -- uhh
[0xC] = true, -- acquiring fireflower
}
local bad_states = {
power = true,
waiting_demo = true,
playing_demo = true,
unknown = true,
lose = true,
}
local state_old = ''
-- localize some stuff.
@ -508,6 +512,31 @@ local function handle_tiles()
end
end
local function prepare_epoch()
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
base_params = network:collect()
empty(trial_noise)
empty(trial_rewards)
for i = 1, epoch_trials do
local noise = nn.zeros(#base_params)
for j = 1, #base_params do noise[j] = nn.normal() end
trial_noise[i] = noise
end
trial_i = 0
end
local function load_next_trial()
trial_i = trial_i + 1
print('loading trial', trial_i)
local W = nn.copy(base_params)
local noise = trial_noise[trial_i]
local devsqrt = sqrt(deviation)
for i, v in ipairs(base_params) do
W[i] = v + devsqrt * noise[i]
end
network:distribute(W)
end
local function learn_from_epoch()
print()
print('rewards:', trial_rewards)
@ -571,33 +600,11 @@ local function learn_from_epoch()
print()
end
local function prepare_epoch()
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
base_params = network:collect()
empty(trial_noise)
empty(trial_rewards)
for i = 1, epoch_trials do
local noise = nn.zeros(#base_params)
for j = 1, #base_params do noise[j] = nn.normal() end
trial_noise[i] = noise
end
trial_i = 0
end
local function load_next_trial()
trial_i = trial_i + 1
print('loading trial', trial_i)
local W = nn.copy(base_params)
local noise = trial_noise[trial_i]
local devsqrt = sqrt(deviation)
for i, v in ipairs(base_params) do
W[i] = v + devsqrt * noise[i]
end
network:distribute(W)
end
local function do_reset()
print("resetting in state: "..get_state()..". reward:", reward)
local state = get_state()
-- be a little more descriptive.
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
print("resetting in state: "..state..". reward:", reward)
if trial_i > 0 then trial_rewards[trial_i] = reward end