tweaks and reordering
This commit is contained in:
parent
f7bee50d12
commit
6b193cac9b
1 changed files with 57 additions and 50 deletions
107
main.lua
107
main.lua
|
@ -24,17 +24,19 @@ function mt.__newindex(t, n, v) error("cannot assign undeclared global '" .. tos
|
||||||
function mt.__index(t, n) error("cannot use undeclared global '" .. tostring(n) .. "'", 2) end
|
function mt.__index(t, n) error("cannot use undeclared global '" .. tostring(n) .. "'", 2) end
|
||||||
local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end
|
||||||
|
|
||||||
-- configuration and globals.
|
-- configuration.
|
||||||
|
|
||||||
--randomseed(11)
|
--randomseed(11)
|
||||||
|
|
||||||
local playable_mode = false
|
local playable_mode = false
|
||||||
--
|
--
|
||||||
local deterministic = false -- use argmax on outputs instead of random sampling.
|
-- true greedy epsilon has both deterministic and det_epsilon set.
|
||||||
|
local deterministic = true -- use argmax on outputs instead of random sampling.
|
||||||
local det_epsilon = true -- take random actions with probability eps.
|
local det_epsilon = true -- take random actions with probability eps.
|
||||||
local eps_start = 0.50
|
-- using parameters from DQN... sorta.
|
||||||
local eps_stop = 0.05
|
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
|
||||||
local eps_frames = 60*60*60
|
local eps_stop = 0.1 * 1/60 -- "
|
||||||
|
local eps_frames = 1000000
|
||||||
local consider_past_rewards = false
|
local consider_past_rewards = false
|
||||||
local learn_start_select = false
|
local learn_start_select = false
|
||||||
--
|
--
|
||||||
|
@ -50,6 +52,27 @@ local enable_network = not playable_mode
|
||||||
|
|
||||||
local input_size = 281 -- TODO: let the script figure this out for us.
|
local input_size = 281 -- TODO: let the script figure this out for us.
|
||||||
|
|
||||||
|
local ok_routines = {
|
||||||
|
[0x4] = true, -- sliding down flagpole
|
||||||
|
[0x5] = true, -- end of level auto-walk
|
||||||
|
[0x7] = true, -- start of level auto-walk
|
||||||
|
[0x8] = true, -- normal (in control)
|
||||||
|
[0x9] = true, -- acquiring mushroom
|
||||||
|
[0xA] = true, -- losing big mario
|
||||||
|
[0xB] = true, -- uhh
|
||||||
|
[0xC] = true, -- acquiring fireflower
|
||||||
|
}
|
||||||
|
|
||||||
|
local bad_states = {
|
||||||
|
power = true,
|
||||||
|
waiting_demo = true,
|
||||||
|
playing_demo = true,
|
||||||
|
unknown = true,
|
||||||
|
lose = true,
|
||||||
|
}
|
||||||
|
|
||||||
|
-- state.
|
||||||
|
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
local base_params
|
local base_params
|
||||||
local trial_i = 0
|
local trial_i = 0
|
||||||
|
@ -83,25 +106,6 @@ local score_old
|
||||||
local once = false
|
local once = false
|
||||||
local reset = true
|
local reset = true
|
||||||
|
|
||||||
local ok_routines = {
|
|
||||||
[0x4] = true, -- sliding down flagpole
|
|
||||||
[0x5] = true, -- end of level auto-walk
|
|
||||||
[0x7] = true, -- start of level auto-walk
|
|
||||||
[0x8] = true, -- normal (in control)
|
|
||||||
[0x9] = true, -- acquiring mushroom
|
|
||||||
[0xA] = true, -- losing big mario
|
|
||||||
[0xB] = true, -- uhh
|
|
||||||
[0xC] = true, -- acquiring fireflower
|
|
||||||
}
|
|
||||||
|
|
||||||
local bad_states = {
|
|
||||||
power = true,
|
|
||||||
waiting_demo = true,
|
|
||||||
playing_demo = true,
|
|
||||||
unknown = true,
|
|
||||||
lose = true,
|
|
||||||
}
|
|
||||||
|
|
||||||
local state_old = ''
|
local state_old = ''
|
||||||
|
|
||||||
-- localize some stuff.
|
-- localize some stuff.
|
||||||
|
@ -508,6 +512,31 @@ local function handle_tiles()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function prepare_epoch()
|
||||||
|
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
|
||||||
|
base_params = network:collect()
|
||||||
|
empty(trial_noise)
|
||||||
|
empty(trial_rewards)
|
||||||
|
for i = 1, epoch_trials do
|
||||||
|
local noise = nn.zeros(#base_params)
|
||||||
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
|
trial_noise[i] = noise
|
||||||
|
end
|
||||||
|
trial_i = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
local function load_next_trial()
|
||||||
|
trial_i = trial_i + 1
|
||||||
|
print('loading trial', trial_i)
|
||||||
|
local W = nn.copy(base_params)
|
||||||
|
local noise = trial_noise[trial_i]
|
||||||
|
local devsqrt = sqrt(deviation)
|
||||||
|
for i, v in ipairs(base_params) do
|
||||||
|
W[i] = v + devsqrt * noise[i]
|
||||||
|
end
|
||||||
|
network:distribute(W)
|
||||||
|
end
|
||||||
|
|
||||||
local function learn_from_epoch()
|
local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
print('rewards:', trial_rewards)
|
print('rewards:', trial_rewards)
|
||||||
|
@ -571,33 +600,11 @@ local function learn_from_epoch()
|
||||||
print()
|
print()
|
||||||
end
|
end
|
||||||
|
|
||||||
local function prepare_epoch()
|
|
||||||
print('preparing epoch '..tostring(epoch_i)..'. this might take a while.')
|
|
||||||
base_params = network:collect()
|
|
||||||
empty(trial_noise)
|
|
||||||
empty(trial_rewards)
|
|
||||||
for i = 1, epoch_trials do
|
|
||||||
local noise = nn.zeros(#base_params)
|
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
|
||||||
trial_noise[i] = noise
|
|
||||||
end
|
|
||||||
trial_i = 0
|
|
||||||
end
|
|
||||||
|
|
||||||
local function load_next_trial()
|
|
||||||
trial_i = trial_i + 1
|
|
||||||
print('loading trial', trial_i)
|
|
||||||
local W = nn.copy(base_params)
|
|
||||||
local noise = trial_noise[trial_i]
|
|
||||||
local devsqrt = sqrt(deviation)
|
|
||||||
for i, v in ipairs(base_params) do
|
|
||||||
W[i] = v + devsqrt * noise[i]
|
|
||||||
end
|
|
||||||
network:distribute(W)
|
|
||||||
end
|
|
||||||
|
|
||||||
local function do_reset()
|
local function do_reset()
|
||||||
print("resetting in state: "..get_state()..". reward:", reward)
|
local state = get_state()
|
||||||
|
-- be a little more descriptive.
|
||||||
|
if state == 'dead' and get_timer() == 0 then state = 'timeup' end
|
||||||
|
print("resetting in state: "..state..". reward:", reward)
|
||||||
|
|
||||||
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
if trial_i > 0 then trial_rewards[trial_i] = reward end
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue