select outputs from array instead of binary combinations

This commit is contained in:
Connor Olding 2017-09-08 10:43:32 +00:00
parent 88dcd203a1
commit d384635000
2 changed files with 61 additions and 43 deletions

View file

@ -71,6 +71,49 @@ local bad_states = {
lose = true, lose = true,
} }
local jp_lut = {
{ -- none
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- A
up = false, down = false, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- L
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = false,
}, { -- R
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = false,
}, { -- L + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = false,
}, { -- R + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = false,
}, { -- L + A
up = false, down = false, left = true, right = false,
select = false, start = false, B = false, A = true,
}, { -- R + A
up = false, down = false, left = false, right = true,
select = false, start = false, B = false, A = true,
}, { -- L + A + B
up = false, down = false, left = true, right = false,
select = false, start = false, B = true, A = true,
}, { -- R + A + B
up = false, down = false, left = false, right = true,
select = false, start = false, B = true, A = true,
}, { -- D
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = false,
}, { -- D + A
up = false, down = true, left = false, right = false,
select = false, start = false, B = false, A = true,
}, { -- U
up = true, down = false, left = false, right = false,
select = false, start = false, B = false, A = false,
},
}
-- state. -- state.
local epoch_i = 0 local epoch_i = 0
@ -237,24 +280,20 @@ local function make_network(input_size, buttons)
nn_y = nn.Merge() nn_y = nn.Merge()
nn_x:feed(nn_y) nn_x:feed(nn_y)
nn_ty:feed(nn_y) nn_ty:feed(nn_y)
nn_z = {}
if true then if true then
nn_y = nn_y:feed(nn.Dense(128)) nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(64)) nn_y = nn_y:feed(nn.Dense(64))
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Relu())
nn_y = nn_y:feed(nn.Dense(48)) nn_y = nn_y:feed(nn.Dense(48))
nn_y = nn_y:feed(nn.Gelu()) nn_y = nn_y:feed(nn.Relu())
end end
for i = 1, buttons do nn_z = nn_y
nn_z[i] = nn_y nn_z = nn_z:feed(nn.Dense(#jp_lut))
nn_z[i] = nn_z[i]:feed(nn.Dense(2)) nn_z = nn_z:feed(nn.Softmax())
nn_z[i] = nn_z[i]:feed(nn.Softmax()) return nn.Model({nn_x, nn_tx}, {nn_z})
end
return nn.Model({nn_x, nn_tx}, nn_z)
end end
-- and here we go with the game stuff. -- and here we go with the game stuff.
@ -792,38 +831,17 @@ local function doit(dummy)
if enable_network and get_state() == 'playing' or ingame_paused then if enable_network and get_state() == 'playing' or ingame_paused then
total_frames = total_frames + frameskip total_frames = total_frames + frameskip
-- TODO: reimplement this.
local choose = deterministic and argmax2 or rchoice2 local choose = deterministic and argmax2 or rchoice2
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
local softmaxed = {
outputs[nn_z[1]],
outputs[nn_z[2]],
outputs[nn_z[3]],
outputs[nn_z[4]],
outputs[nn_z[5]],
outputs[nn_z[6]],
learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
}
jp = {
up = choose(softmaxed[1]),
down = choose(softmaxed[2]),
left = choose(softmaxed[3]),
right = choose(softmaxed[4]),
A = choose(softmaxed[5]),
B = choose(softmaxed[6]),
start = choose(softmaxed[7]),
select = choose(softmaxed[8]),
}
if det_epsilon then --and not trial_i == 0 then
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
for k, v in pairs(jp) do if det_epsilon and random() < eps then
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select local i = floor(random() * #jp_lut) + 1
if random() < eps and ss_ok then jp[k] = rbool() end jp = nn.copy(jp_lut[i], jp)
end else
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp)
end end
if force_start then if force_start then

8
nn.lua
View file

@ -25,10 +25,10 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e
-- general utilities -- general utilities
local function copy(t) -- shallow copy local function copy(t, out) -- shallow copy
local new_t = {} local out = out or {}
for k, v in pairs(t) do new_t[k] = v end for k, v in pairs(t) do out[k] = v end
return new_t return out
end end
local function indexof(t, a) local function indexof(t, a)