select outputs from array instead of binary combinations
This commit is contained in:
parent
88dcd203a1
commit
d384635000
2 changed files with 61 additions and 43 deletions
94
main.lua
94
main.lua
|
@ -71,6 +71,49 @@ local bad_states = {
|
|||
lose = true,
|
||||
}
|
||||
|
||||
local jp_lut = {
|
||||
{ -- none
|
||||
up = false, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- A
|
||||
up = false, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- L
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- R
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- L + B
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = true, A = false,
|
||||
}, { -- R + B
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = true, A = false,
|
||||
}, { -- L + A
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- R + A
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- L + A + B
|
||||
up = false, down = false, left = true, right = false,
|
||||
select = false, start = false, B = true, A = true,
|
||||
}, { -- R + A + B
|
||||
up = false, down = false, left = false, right = true,
|
||||
select = false, start = false, B = true, A = true,
|
||||
}, { -- D
|
||||
up = false, down = true, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
}, { -- D + A
|
||||
up = false, down = true, left = false, right = false,
|
||||
select = false, start = false, B = false, A = true,
|
||||
}, { -- U
|
||||
up = true, down = false, left = false, right = false,
|
||||
select = false, start = false, B = false, A = false,
|
||||
},
|
||||
}
|
||||
|
||||
-- state.
|
||||
|
||||
local epoch_i = 0
|
||||
|
@ -237,24 +280,20 @@ local function make_network(input_size, buttons)
|
|||
nn_y = nn.Merge()
|
||||
nn_x:feed(nn_y)
|
||||
nn_ty:feed(nn_y)
|
||||
nn_z = {}
|
||||
|
||||
if true then
|
||||
nn_y = nn_y:feed(nn.Dense(128))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(64))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
nn_y = nn_y:feed(nn.Dense(48))
|
||||
nn_y = nn_y:feed(nn.Gelu())
|
||||
nn_y = nn_y:feed(nn.Relu())
|
||||
end
|
||||
|
||||
for i = 1, buttons do
|
||||
nn_z[i] = nn_y
|
||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
||||
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
||||
end
|
||||
|
||||
return nn.Model({nn_x, nn_tx}, nn_z)
|
||||
nn_z = nn_y
|
||||
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||
nn_z = nn_z:feed(nn.Softmax())
|
||||
return nn.Model({nn_x, nn_tx}, {nn_z})
|
||||
end
|
||||
|
||||
-- and here we go with the game stuff.
|
||||
|
@ -792,38 +831,17 @@ local function doit(dummy)
|
|||
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||
total_frames = total_frames + frameskip
|
||||
|
||||
-- TODO: reimplement this.
|
||||
local choose = deterministic and argmax2 or rchoice2
|
||||
|
||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||
|
||||
local softmaxed = {
|
||||
outputs[nn_z[1]],
|
||||
outputs[nn_z[2]],
|
||||
outputs[nn_z[3]],
|
||||
outputs[nn_z[4]],
|
||||
outputs[nn_z[5]],
|
||||
outputs[nn_z[6]],
|
||||
learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
|
||||
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
|
||||
}
|
||||
|
||||
jp = {
|
||||
up = choose(softmaxed[1]),
|
||||
down = choose(softmaxed[2]),
|
||||
left = choose(softmaxed[3]),
|
||||
right = choose(softmaxed[4]),
|
||||
A = choose(softmaxed[5]),
|
||||
B = choose(softmaxed[6]),
|
||||
start = choose(softmaxed[7]),
|
||||
select = choose(softmaxed[8]),
|
||||
}
|
||||
|
||||
if det_epsilon then --and not trial_i == 0 then
|
||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||
for k, v in pairs(jp) do
|
||||
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
|
||||
if random() < eps and ss_ok then jp[k] = rbool() end
|
||||
end
|
||||
if det_epsilon and random() < eps then
|
||||
local i = floor(random() * #jp_lut) + 1
|
||||
jp = nn.copy(jp_lut[i], jp)
|
||||
else
|
||||
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp)
|
||||
end
|
||||
|
||||
if force_start then
|
||||
|
|
8
nn.lua
8
nn.lua
|
@ -25,10 +25,10 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e
|
|||
|
||||
-- general utilities
|
||||
|
||||
local function copy(t) -- shallow copy
|
||||
local new_t = {}
|
||||
for k, v in pairs(t) do new_t[k] = v end
|
||||
return new_t
|
||||
local function copy(t, out) -- shallow copy
|
||||
local out = out or {}
|
||||
for k, v in pairs(t) do out[k] = v end
|
||||
return out
|
||||
end
|
||||
|
||||
local function indexof(t, a)
|
||||
|
|
Loading…
Reference in a new issue