diff --git a/main.lua b/main.lua index a94cbf4..a9c8d25 100644 --- a/main.lua +++ b/main.lua @@ -71,6 +71,49 @@ local bad_states = { lose = true, } +local jp_lut = { + { -- none + up = false, down = false, left = false, right = false, + select = false, start = false, B = false, A = false, + }, { -- A + up = false, down = false, left = false, right = false, + select = false, start = false, B = false, A = true, + }, { -- L + up = false, down = false, left = true, right = false, + select = false, start = false, B = false, A = false, + }, { -- R + up = false, down = false, left = false, right = true, + select = false, start = false, B = false, A = false, + }, { -- L + B + up = false, down = false, left = true, right = false, + select = false, start = false, B = true, A = false, + }, { -- R + B + up = false, down = false, left = false, right = true, + select = false, start = false, B = true, A = false, + }, { -- L + A + up = false, down = false, left = true, right = false, + select = false, start = false, B = false, A = true, + }, { -- R + A + up = false, down = false, left = false, right = true, + select = false, start = false, B = false, A = true, + }, { -- L + A + B + up = false, down = false, left = true, right = false, + select = false, start = false, B = true, A = true, + }, { -- R + A + B + up = false, down = false, left = false, right = true, + select = false, start = false, B = true, A = true, + }, { -- D + up = false, down = true, left = false, right = false, + select = false, start = false, B = false, A = false, + }, { -- D + A + up = false, down = true, left = false, right = false, + select = false, start = false, B = false, A = true, + }, { -- U + up = true, down = false, left = false, right = false, + select = false, start = false, B = false, A = false, + }, +} + -- state. local epoch_i = 0 @@ -237,24 +280,20 @@ local function make_network(input_size, buttons) nn_y = nn.Merge() nn_x:feed(nn_y) nn_ty:feed(nn_y) - nn_z = {} if true then nn_y = nn_y:feed(nn.Dense(128)) - nn_y = nn_y:feed(nn.Gelu()) + nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Dense(64)) - nn_y = nn_y:feed(nn.Gelu()) + nn_y = nn_y:feed(nn.Relu()) nn_y = nn_y:feed(nn.Dense(48)) - nn_y = nn_y:feed(nn.Gelu()) + nn_y = nn_y:feed(nn.Relu()) end - for i = 1, buttons do - nn_z[i] = nn_y - nn_z[i] = nn_z[i]:feed(nn.Dense(2)) - nn_z[i] = nn_z[i]:feed(nn.Softmax()) - end - - return nn.Model({nn_x, nn_tx}, nn_z) + nn_z = nn_y + nn_z = nn_z:feed(nn.Dense(#jp_lut)) + nn_z = nn_z:feed(nn.Softmax()) + return nn.Model({nn_x, nn_tx}, {nn_z}) end -- and here we go with the game stuff. @@ -792,38 +831,17 @@ local function doit(dummy) if enable_network and get_state() == 'playing' or ingame_paused then total_frames = total_frames + frameskip + -- TODO: reimplement this. local choose = deterministic and argmax2 or rchoice2 local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) - local softmaxed = { - outputs[nn_z[1]], - outputs[nn_z[2]], - outputs[nn_z[3]], - outputs[nn_z[4]], - outputs[nn_z[5]], - outputs[nn_z[6]], - learn_start_select and outputs[nn_z[7]] or dummy_softmax_values, - learn_start_select and outputs[nn_z[8]] or dummy_softmax_values, - } - - jp = { - up = choose(softmaxed[1]), - down = choose(softmaxed[2]), - left = choose(softmaxed[3]), - right = choose(softmaxed[4]), - A = choose(softmaxed[5]), - B = choose(softmaxed[6]), - start = choose(softmaxed[7]), - select = choose(softmaxed[8]), - } - - if det_epsilon then --and not trial_i == 0 then - local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) - for k, v in pairs(jp) do - local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select - if random() < eps and ss_ok then jp[k] = rbool() end - end + local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) + if det_epsilon and random() < eps then + local i = floor(random() * #jp_lut) + 1 + jp = nn.copy(jp_lut[i], jp) + else + jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp) end if force_start then diff --git a/nn.lua b/nn.lua index 32ff0bd..3887684 100644 --- a/nn.lua +++ b/nn.lua @@ -25,10 +25,10 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e -- general utilities -local function copy(t) -- shallow copy - local new_t = {} - for k, v in pairs(t) do new_t[k] = v end - return new_t +local function copy(t, out) -- shallow copy + local out = out or {} + for k, v in pairs(t) do out[k] = v end + return out end local function indexof(t, a)