select outputs from array instead of binary combinations
This commit is contained in:
parent
88dcd203a1
commit
d384635000
2 changed files with 61 additions and 43 deletions
94
main.lua
94
main.lua
|
@ -71,6 +71,49 @@ local bad_states = {
|
||||||
lose = true,
|
lose = true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
local jp_lut = {
|
||||||
|
{ -- none
|
||||||
|
up = false, down = false, left = false, right = false,
|
||||||
|
select = false, start = false, B = false, A = false,
|
||||||
|
}, { -- A
|
||||||
|
up = false, down = false, left = false, right = false,
|
||||||
|
select = false, start = false, B = false, A = true,
|
||||||
|
}, { -- L
|
||||||
|
up = false, down = false, left = true, right = false,
|
||||||
|
select = false, start = false, B = false, A = false,
|
||||||
|
}, { -- R
|
||||||
|
up = false, down = false, left = false, right = true,
|
||||||
|
select = false, start = false, B = false, A = false,
|
||||||
|
}, { -- L + B
|
||||||
|
up = false, down = false, left = true, right = false,
|
||||||
|
select = false, start = false, B = true, A = false,
|
||||||
|
}, { -- R + B
|
||||||
|
up = false, down = false, left = false, right = true,
|
||||||
|
select = false, start = false, B = true, A = false,
|
||||||
|
}, { -- L + A
|
||||||
|
up = false, down = false, left = true, right = false,
|
||||||
|
select = false, start = false, B = false, A = true,
|
||||||
|
}, { -- R + A
|
||||||
|
up = false, down = false, left = false, right = true,
|
||||||
|
select = false, start = false, B = false, A = true,
|
||||||
|
}, { -- L + A + B
|
||||||
|
up = false, down = false, left = true, right = false,
|
||||||
|
select = false, start = false, B = true, A = true,
|
||||||
|
}, { -- R + A + B
|
||||||
|
up = false, down = false, left = false, right = true,
|
||||||
|
select = false, start = false, B = true, A = true,
|
||||||
|
}, { -- D
|
||||||
|
up = false, down = true, left = false, right = false,
|
||||||
|
select = false, start = false, B = false, A = false,
|
||||||
|
}, { -- D + A
|
||||||
|
up = false, down = true, left = false, right = false,
|
||||||
|
select = false, start = false, B = false, A = true,
|
||||||
|
}, { -- U
|
||||||
|
up = true, down = false, left = false, right = false,
|
||||||
|
select = false, start = false, B = false, A = false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
-- state.
|
-- state.
|
||||||
|
|
||||||
local epoch_i = 0
|
local epoch_i = 0
|
||||||
|
@ -237,24 +280,20 @@ local function make_network(input_size, buttons)
|
||||||
nn_y = nn.Merge()
|
nn_y = nn.Merge()
|
||||||
nn_x:feed(nn_y)
|
nn_x:feed(nn_y)
|
||||||
nn_ty:feed(nn_y)
|
nn_ty:feed(nn_y)
|
||||||
nn_z = {}
|
|
||||||
|
|
||||||
if true then
|
if true then
|
||||||
nn_y = nn_y:feed(nn.Dense(128))
|
nn_y = nn_y:feed(nn.Dense(128))
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
nn_y = nn_y:feed(nn.Dense(64))
|
nn_y = nn_y:feed(nn.Dense(64))
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
nn_y = nn_y:feed(nn.Dense(48))
|
nn_y = nn_y:feed(nn.Dense(48))
|
||||||
nn_y = nn_y:feed(nn.Gelu())
|
nn_y = nn_y:feed(nn.Relu())
|
||||||
end
|
end
|
||||||
|
|
||||||
for i = 1, buttons do
|
nn_z = nn_y
|
||||||
nn_z[i] = nn_y
|
nn_z = nn_z:feed(nn.Dense(#jp_lut))
|
||||||
nn_z[i] = nn_z[i]:feed(nn.Dense(2))
|
nn_z = nn_z:feed(nn.Softmax())
|
||||||
nn_z[i] = nn_z[i]:feed(nn.Softmax())
|
return nn.Model({nn_x, nn_tx}, {nn_z})
|
||||||
end
|
|
||||||
|
|
||||||
return nn.Model({nn_x, nn_tx}, nn_z)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
-- and here we go with the game stuff.
|
-- and here we go with the game stuff.
|
||||||
|
@ -792,38 +831,17 @@ local function doit(dummy)
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
total_frames = total_frames + frameskip
|
total_frames = total_frames + frameskip
|
||||||
|
|
||||||
|
-- TODO: reimplement this.
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
local choose = deterministic and argmax2 or rchoice2
|
||||||
|
|
||||||
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input})
|
||||||
|
|
||||||
local softmaxed = {
|
|
||||||
outputs[nn_z[1]],
|
|
||||||
outputs[nn_z[2]],
|
|
||||||
outputs[nn_z[3]],
|
|
||||||
outputs[nn_z[4]],
|
|
||||||
outputs[nn_z[5]],
|
|
||||||
outputs[nn_z[6]],
|
|
||||||
learn_start_select and outputs[nn_z[7]] or dummy_softmax_values,
|
|
||||||
learn_start_select and outputs[nn_z[8]] or dummy_softmax_values,
|
|
||||||
}
|
|
||||||
|
|
||||||
jp = {
|
|
||||||
up = choose(softmaxed[1]),
|
|
||||||
down = choose(softmaxed[2]),
|
|
||||||
left = choose(softmaxed[3]),
|
|
||||||
right = choose(softmaxed[4]),
|
|
||||||
A = choose(softmaxed[5]),
|
|
||||||
B = choose(softmaxed[6]),
|
|
||||||
start = choose(softmaxed[7]),
|
|
||||||
select = choose(softmaxed[8]),
|
|
||||||
}
|
|
||||||
|
|
||||||
if det_epsilon then --and not trial_i == 0 then
|
|
||||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||||
for k, v in pairs(jp) do
|
if det_epsilon and random() < eps then
|
||||||
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
|
local i = floor(random() * #jp_lut) + 1
|
||||||
if random() < eps and ss_ok then jp[k] = rbool() end
|
jp = nn.copy(jp_lut[i], jp)
|
||||||
end
|
else
|
||||||
|
jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp)
|
||||||
end
|
end
|
||||||
|
|
||||||
if force_start then
|
if force_start then
|
||||||
|
|
8
nn.lua
8
nn.lua
|
@ -25,10 +25,10 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e
|
||||||
|
|
||||||
-- general utilities
|
-- general utilities
|
||||||
|
|
||||||
local function copy(t) -- shallow copy
|
local function copy(t, out) -- shallow copy
|
||||||
local new_t = {}
|
local out = out or {}
|
||||||
for k, v in pairs(t) do new_t[k] = v end
|
for k, v in pairs(t) do out[k] = v end
|
||||||
return new_t
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
local function indexof(t, a)
|
local function indexof(t, a)
|
||||||
|
|
Loading…
Reference in a new issue