tweaks and reordering
This commit is contained in:
parent
f7bee50d12
commit
6b193cac9b
1 changed files with 57 additions and 50 deletions
107
main.lua
107
main.lua
|
@ -24,17 +24,19 @@ function mt.__newindex(t, n, v) error("cannot assign undeclared global '" .. tos
|
|||
function mt.__index(t, n) error("cannot use undeclared global '" .. tostring(n) .. "'", 2) end
|
||||
local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
||||
|
||||
-- configuration and globals.
|
||||
-- configuration.
|
||||
|
||||
--randomseed(11)
|
||||
|
||||
local playable_mode = false
|
||||
--
|
||||
local deterministic = false -- use argmax on outputs instead of random sampling.
|
||||
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||
local det_epsilon = true -- take random actions with probability eps.
|
||||
local eps_start = 0.50
|
||||
local eps_stop = 0.05
|
||||
local eps_frames = 60*60*60
|
||||
-- using parameters from DQN... sorta.
|
||||
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
|
||||
local eps_stop = 0.1 * 1/60 -- "
|
||||
local eps_frames = 1000000
|
||||
local consider_past_rewards = false
|
||||
local learn_start_select = false
|
||||
--
|
||||
|
@ -50,6 +52,27 @@ local enable_network = not playable_mode
|
|||
|
||||
local input_size = 281 -- TODO: let the script figure this out for us.
|
||||
|
||||
local ok_routines = {
|
||||
[0x4] = true, -- sliding down flagpole
|
||||
[0x5] = true, -- end of level auto-walk
|
||||
[0x7] = true, -- start of level auto-walk
|
||||
[0x8] = true, -- normal (in control)
|
||||
[0x9] = true, -- acquiring mushroom
|
||||
[0xA] = true, -- losing big mario
|
||||
[0xB] = true, -- uhh
|
||||
[0xC] = true, -- acquiring fireflower
|
||||
}
|
||||
|
||||
local bad_states = {
|
||||
power = true,
|
||||
waiting_demo = true,
|
||||
playing_demo = true,
|
||||
unknown = true,
|
||||
lose = true,
|
||||
}
|
||||
|
||||
-- state.
|
||||
|
||||
local epoch_i = 0
|
||||
local base_params
|
||||
local trial_i = 0
|
||||
|
@ -83,25 +106,6 @@ local score_old
|
|||
local once = false
|
||||
local reset = true
|
||||
|
||||
local ok_routines = {
|
||||
[0x4] = true, -- sliding down flagpole
|
||||
[0x5] = true, -- end of level auto-walk
|
||||
[0x7] = true, -- start of level auto-walk
|
||||
[0x8] = true, -- normal (in control)
|
||||
[0x9] = true, -- acquiring mushroom
|
||||
[0xA] = true, -- losing big mario
|
||||
[0xB] = true, -- uhh
|
||||
[0xC] = true, -- acquiring fireflower
|
||||
}
|
||||
|
||||
local bad_states = {
|
||||
power = true,
|
||||
waiting_demo = true,
|
||||
playing_demo = true,
|
||||
unknown = true,
|
||||
lose = true,
|
||||
}
|
||||
|
||||
local state_old = ''
|
||||
|
||||
-- localize some stuff.
|
||||
|
@ -508,6 +512,31 @@ local function handle_tiles()
|
|||
end
|
||||
end
|
||||
|
||||
local function prepare_epoch()
|
||||
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
|
||||
base_params = network:collect()
|
||||
empty(trial_noise)
|
||||
empty(trial_rewards)
|
||||
for i = 1, epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||
trial_noise[i] = noise
|
||||
end
|
||||
trial_i = 0
|
||||
end
|
||||
|
||||
local function load_next_trial()
|
||||
trial_i = trial_i + 1
|
||||
print('loading trial', trial_i)
|
||||
local W = nn.copy(base_params)
|
||||
local noise = trial_noise[trial_i]
|
||||
local devsqrt = sqrt(deviation)
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + devsqrt * noise[i]
|
||||
end
|
||||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function learn_from_epoch()
|
||||
print()
|
||||
print('rewards:', trial_rewards)
|
||||
|
@ -571,33 +600,11 @@ local function learn_from_epoch()
|
|||
print()
|
||||
end
|
||||
|
||||
local function prepare_epoch()
|
||||
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
|
||||
base_params = network:collect()
|
||||
empty(trial_noise)
|
||||
empty(trial_rewards)
|
||||
for i = 1, epoch_trials do
|
||||
local noise = nn.zeros(#base_params)
|
||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||
trial_noise[i] = noise
|
||||
end
|
||||
trial_i = 0
|
||||
end
|
||||
|
||||
local function load_next_trial()
|
||||
trial_i = trial_i + 1
|
||||
print('loading trial', trial_i)
|
||||
local W = nn.copy(base_params)
|
||||
local noise = trial_noise[trial_i]
|
||||
local devsqrt = sqrt(deviation)
|
||||
for i, v in ipairs(base_params) do
|
||||
W[i] = v + devsqrt * noise[i]
|
||||
end
|
||||
network:distribute(W)
|
||||
end
|
||||
|
||||
local function do_reset()
|
||||
print("resetting in state: "..get_state()..". reward:", reward)
|
||||
local state = get_state()
|
||||
-- be a little more descriptive.
|
||||
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||
print("resetting in state: "..state..". reward:", reward)
|
||||
|
||||
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
||||
|
||||
|
|
Loading…
Reference in a new issue