add LayerNorm layer

This commit is contained in:
Connor Olding 2018-05-07 05:55:58 +02:00
parent feaf86dc6b
commit deb1ea7de0
3 changed files with 87 additions and 35 deletions

View file

@ -12,48 +12,62 @@ local function intmap(x)
return math.pow(10, x / 2)
end
local cfg = {
log_fn = 'log.csv', -- can be nil to disable logging.
local common_cfg = {
defer_prints = true,
playable_mode = false,
start_big = false, --true
starting_lives = 0, --1
playback_mode = false,
start_big = false,
starting_lives = 0,
init_zeros = true, -- instead of he_normal noise or whatever.
frameskip = 4,
-- true greedy epsilon has both deterministic and det_epsilon set.
deterministic = false, -- use argmax on outputs instead of random sampling.
det_epsilon = false, -- take random actions with probability eps.
layernorm = false,
init_zeros = true, -- instead of he_normal noise or whatever.
graycode = false,
epoch_trials = 64 * (7/8),
epoch_top_trials = 40 * (7/8), -- new with ARS.
unperturbed_trial = true, -- do a trial without any noise.
negate_trials = true, -- try pairs of normal and negated noise directions.
time_inputs = true, -- binary inputs of global frame count
-- ^ note that this now doubles the effective trials.
deviation = intmap(-3),
learning_rate = intmap(-4),
weight_decay = intmap(-6),
time_inputs = true, -- binary inputs of global frame count
adamant = true, -- run steps through AMSgrad.
adam_b1 = math.pow(10, -1 / 15),
adam_b2 = math.pow(10, -1 / 100),
adam_eps = intmap(-8),
adam_debias = false,
adamant = false, -- run steps through AMSgrad.
cap_time = 222, --400
cap_time = 300,
timer_loser = 1/2,
decrement_reward = false, -- bad idea, encourages mario to kill himself
decrement_reward = false, -- bad idea, encourages mario to run into goombas.
}
playback_mode = false,
local cfg = {
log_fn = 'log.csv', -- can be nil to disable logging.
params_fn = nil, -- can be nil to generate based on param count.
deterministic = true,
epoch_trials = 20,
epoch_top_trials = 10,
learning_rate = 1.0,
deviation = 0.1,
weight_decay = 0.004,
cap_time = 300,
starting_lives = 1,
}
-- TODO: so, uhh..
-- what happens when playback_mode is true but unperturbed_trial is false?
setmetatable(cfg, {
__index = function(t, n)
if common_cfg[n] ~= nil then return common_cfg[n] end
if n == 'params_fn' then return nil end
error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
end
})
cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
cfg.eps_start = 1.0 * cfg.frameskip / 64
@ -62,8 +76,7 @@ cfg.eps_frames = 1000000
cfg.enable_overlay = cfg.playable_mode
cfg.enable_network = not cfg.playable_mode
return setmetatable(cfg, {
__index = function(t, n)
error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
end
})
assert(not cfg.ars_lips or cfg.unperturbed_trial,
"cfg.unperturbed_trial must be true to use cfg.ars_lips")
return cfg

View file

@ -227,6 +227,7 @@ local function make_network(input_size)
nn_ty:feed(nn_y)
nn_y = nn_y:feed(nn.Dense(128))
if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
if cfg.deterministic then
nn_y = nn_y:feed(nn.Relu())
else

58
nn.lua
View file

@ -311,6 +311,7 @@ local Gelu = Layer:extend()
local Dense = Layer:extend()
local Softmax = Layer:extend()
local Embed = Layer:extend()
local LayerNorm = Layer:extend()
function Weights:init(weight_init)
self.weight_init = weight_init
@ -618,6 +619,42 @@ function Embed:forward(X)
return Y
end
function LayerNorm:init(eps)
Layer.init(self, "LayerNorm")
if eps == nil then eps = 1e-5 end
assert(type(eps) == "number")
self.eps = eps
end
function LayerNorm:reset_cache(bs)
self.bs = bs
self.cache = cache(bs, self.shape_out)
end
function LayerNorm:forward(X)
local bs = checkshape(X, self.shape_in)
if self.bs ~= bs then self:reset_cache(bs) end
local mean = 0
for i, v in ipairs(X) do
mean = mean + v / #X
end
local var = 0
for i, v in ipairs(X) do
local delta = v - mean
self.cache[i] = delta
var = var + delta * delta / #X
end
local std = sqrt(var + self.eps)
for i, v in ipairs(self.cache) do
self.cache[i] = v / std
end
return self.cache
end
function Model:init(nodes_in, nodes_out)
assert(#nodes_in > 0, #nodes_in)
assert(#nodes_out > 0, #nodes_out)
@ -764,14 +801,15 @@ return {
traverse = traverse,
traverse_all = traverse_all,
Weights = Weights,
Layer = Layer,
Model = Model,
Input = Input,
Merge = Merge,
Relu = Relu,
Gelu = Gelu,
Dense = Dense,
Softmax = Softmax,
Embed = Embed,
Weights = Weights,
Layer = Layer,
Model = Model,
Input = Input,
Merge = Merge,
Relu = Relu,
Gelu = Gelu,
Dense = Dense,
Softmax = Softmax,
Embed = Embed,
LayerNorm = LayerNorm,
}