diff --git a/main.lua b/main.lua index 7d9c018..e517f73 100644 --- a/main.lua +++ b/main.lua @@ -26,30 +26,35 @@ local function globalize(t) for k, v in pairs(t) do rawset(_G, k, v) end end --randomseed(11) +local defer_prints = true + local playable_mode = false local start_big = true local starting_lives = 0 -- -local init_zeros = true -- instead of he_normal noise or whatever. +local init_zeros = false -- instead of he_normal noise or whatever. local frameskip = 4 -- true greedy epsilon has both deterministic and det_epsilon set. -local deterministic = true -- use argmax on outputs instead of random sampling. +local deterministic = false -- use argmax on outputs instead of random sampling. local det_epsilon = false -- take random actions with probability eps. local eps_start = 1.0 * frameskip / 64 local eps_stop = 0.1 * eps_start -local eps_frames = 2000000 +local eps_frames = 4000000 -- -local epoch_trials = 20 -local epoch_top_trials = 10 -- new with ARS. +local epoch_trials = 18 +local epoch_top_trials = 9 -- new with ARS. local unperturbed_trial = true -- do a trial without any noise. local negate_trials = true -- try pairs of normal and negated noise directions. -- ^ note that this now doubles the effective trials. -local learning_rate = 0.01 -local deviation = 0.06 +local deviation = 0.025 +local function approx_cossim(dim) + return math.pow(1.521 * dim - 0.521, -0.5026) +end +local learning_rate = 0.01 / approx_cossim(7051) -- -local cap_time = 100 --400 +local cap_time = 200 --400 local timer_loser = 0 --1/3 -local decrement_reward = true +local decrement_reward = false -- bad idea, encourages mario to kill himself -- local enable_overlay = playable_mode local enable_network = not playable_mode @@ -159,6 +164,7 @@ local once = false local reset = true local state_old = '' +local last_trial_state -- localize some stuff. @@ -222,6 +228,18 @@ local function argmax(...) return max_i end +local function softchoice(...) 
+ local t = random() + local psum, n = 0, select("#", ...) + for i=1, n do + local p = select(i, ...) + psum = psum + p + if t < psum or i == n then + return i + end + end +end + local function argmax2(t) return t[1] > t[2] end @@ -251,15 +269,14 @@ local function calc_mean_dev(x) dev = dev + delta * delta / #x end - return mean, dev + return mean, sqrt(dev) end local function normalize(x, out) out = out or x local mean, dev = calc_mean_dev(x) if dev <= 0 then dev = 1 end - local devs = sqrt(dev) - for i, v in ipairs(x) do out[i] = (v - mean) / devs end + for i, v in ipairs(x) do out[i] = (v - mean) / dev end return out end @@ -267,8 +284,7 @@ local function normalize_wrt(x, s, out) out = out or x local mean, dev = calc_mean_dev(s) if dev <= 0 then dev = 1 end - local devs = sqrt(dev) - for i, v in ipairs(x) do out[i] = (v - mean) / devs end + for i, v in ipairs(x) do out[i] = (v - mean) / dev end return out end @@ -287,10 +303,6 @@ local function make_network(input_size) nn_x:feed(nn_y) nn_ty:feed(nn_y) - nn_y = nn_y:feed(nn.Dense(128)) - --nn_y = nn_y:feed(nn.Gelu()) - nn_y = nn_y:feed(nn.Relu()) - nn_z = nn_y nn_z = nn_z:feed(nn.Dense(#jp_lut)) nn_z = nn_z:feed(nn.Softmax()) @@ -560,7 +572,9 @@ local function prepare_epoch() -- (the os.time() as of here) and calling nn.normal() each trial. for i = 1, epoch_trials do local noise = nn.zeros(#base_params) - for j = 1, #base_params do noise[j] = nn.normal() end + -- NOTE: change in implementation: deviation is multiplied here + -- and ONLY here now. 
+ for j = 1, #base_params do noise[j] = deviation * nn.normal() end trial_noise[i] = noise end trial_i = -1 @@ -577,24 +591,24 @@ local function load_next_pair() if trial_i > 0 then if trial_neg then - print('trial', trial_i, 'positive') + if not defer_prints then print('trial', trial_i, 'positive') end local noise = trial_noise[trial_i] for i, v in ipairs(base_params) do - W[i] = v + deviation * noise[i] + W[i] = v + noise[i] end else trial_i = trial_i - 1 - print('trial', trial_i, 'negative') + if not defer_prints then print('trial', trial_i, 'negative') end local noise = trial_noise[trial_i] for i, v in ipairs(base_params) do - W[i] = v - deviation * noise[i] + W[i] = v - noise[i] end end trial_neg = not trial_neg else - print("test trial") + if not defer_prints then print("test trial") end end network:distribute(W) @@ -611,7 +625,7 @@ local function load_next_trial() print('loading trial', trial_i) local noise = trial_noise[trial_i] for i, v in ipairs(base_params) do - W[i] = v + deviation * noise[i] + W[i] = v + noise[i] end else print("test trial") @@ -702,12 +716,11 @@ local function learn_from_epoch() end print("top:", top_rewards) - local reward_mean, reward_dev = calc_mean_dev(top_rewards) - --print("mean, dev:", reward_mean, reward_dev) + local _, reward_dev = calc_mean_dev(top_rewards) + --print("mean, dev:", _, reward_dev) if reward_dev == 0 then reward_dev = 1 end for i, v in ipairs(top_rewards) do top_rewards[i] = v / reward_dev end - --print("scaled:", top_rewards) -- NOTE: step no longer directly incorporates learning_rate. 
for i = 1, epoch_trials do @@ -717,13 +730,14 @@ local function learn_from_epoch() local reward = pos - neg local noise = trial_noise[i] for j, v in ipairs(noise) do - step[j] = step[j] + (reward * v) / epoch_trials + step[j] = step[j] + reward * v / epoch_top_trials end end local step_mean, step_dev = calc_mean_dev(step) if abs(step_mean) > 1e-3 then print("step mean:", step_mean) end - print("step stddev:", step_dev) + --print("step stddev:", step_dev) + print("full step stddev:", learning_rate * step_dev) for i, v in ipairs(base_params) do base_params[i] = v + learning_rate * step[i] @@ -743,7 +757,27 @@ local function do_reset() local state = get_state() -- be a little more descriptive. if state == 'dead' and get_timer() == 0 then state = 'timeup' end - print("reward:", reward, "("..state..")") + + if trial_i >= 0 and defer_prints then + if trial_i == 0 then + print('test trial reward:', reward, "("..state..")") + elseif negate_trials then + --local dir = trial_neg and "negative" or "positive" + --print('trial', trial_i, dir, 'reward:', reward, "("..state..")") + + if trial_neg then + local pos = trial_rewards[#trial_rewards] + local neg = reward + local fmt = "trial %i rewards: %+i, %+i (%s, %s)" + print(fmt:format(trial_i, pos, neg, last_trial_state, state)) + end + last_trial_state = state + else + print('trial', trial_i, 'reward:', reward, "("..state..")") + end + else + print("reward:", reward, "("..state..")") + end if trial_i >= 0 then if trial_i == 0 or not negate_trials then @@ -777,8 +811,9 @@ local function do_reset() --max_time = min(log(epoch_i) * 10 + 100, cap_time) --max_time = min(8 * sqrt(360 / epoch_trials * (epoch_i - 1)) + 100, cap_time) - --max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time) + max_time = min(6 * sqrt(480 / epoch_trials * (epoch_i - 1)) + 60, cap_time) max_time = ceil(max_time) + --max_time = cap_time if once then savestate.load(startsave) @@ -916,9 +951,6 @@ local function doit(dummy) if 
enable_network and get_state() == 'playing' or ingame_paused then total_frames = total_frames + frameskip - -- TODO: reimplement this. - local choose = deterministic and argmax2 or rchoice2 - local outputs = network:forward({[nn_x]=X, [nn_tx]=tile_input}) local eps = lerp(eps_start, eps_stop, total_frames / eps_frames) @@ -926,7 +958,9 @@ local function doit(dummy) local i = floor(random() * #jp_lut) + 1 jp = nn.copy(jp_lut[i], jp) else - jp = nn.copy(jp_lut[argmax(unpack(outputs[nn_z]))], jp) + local choose = deterministic and argmax or softchoice + local ind = choose(unpack(outputs[nn_z])) + jp = nn.copy(jp_lut[ind], jp) end if force_start then