add LayerNorm layer

commit deb1ea7de0, parent feaf86dc6b
3 changed files with 87 additions and 35 deletions
config.lua: 63 changed lines

@@ -12,48 +12,62 @@ local function intmap(x)
     return math.pow(10, x / 2)
 end
 
-local cfg = {
-    log_fn = 'log.csv', -- can be nil to disable logging.
+local common_cfg = {
     defer_prints = true,
 
     playable_mode = false,
-    start_big = false, --true
-    starting_lives = 0, --1
+    playback_mode = false,
+    start_big = false,
+    starting_lives = 0,
 
-    init_zeros = true, -- instead of he_normal noise or whatever.
     frameskip = 4,
     -- true greedy epsilon has both deterministic and det_epsilon set.
     deterministic = false, -- use argmax on outputs instead of random sampling.
     det_epsilon = false, -- take random actions with probability eps.
+    layernorm = false,
 
+    init_zeros = true, -- instead of he_normal noise or whatever.
     graycode = false,
-    epoch_trials = 64 * (7/8),
-    epoch_top_trials = 40 * (7/8), -- new with ARS.
     unperturbed_trial = true, -- do a trial without any noise.
     negate_trials = true, -- try pairs of normal and negated noise directions.
+    time_inputs = true, -- binary inputs of global frame count
     -- ^ note that this now doubles the effective trials.
-    deviation = intmap(-3),
-    learning_rate = intmap(-4),
-    weight_decay = intmap(-6),
-    time_inputs = true, -- binary inputs of global frame count
 
-    adamant = true, -- run steps through AMSgrad.
     adam_b1 = math.pow(10, -1 / 15),
     adam_b2 = math.pow(10, -1 / 100),
     adam_eps = intmap(-8),
     adam_debias = false,
+    adamant = false, -- run steps through AMSgrad.
 
-    cap_time = 222, --400
+    cap_time = 300,
     timer_loser = 1/2,
-    decrement_reward = false, -- bad idea, encourages mario to kill himself
+    decrement_reward = false, -- bad idea, encourages mario to run into goombas.
 }
 
-playback_mode = false,
+local cfg = {
+    log_fn = 'log.csv', -- can be nil to disable logging.
+    params_fn = nil, -- can be nil to generate based on param count.
+
+    deterministic = true,
+
+    epoch_trials = 20,
+    epoch_top_trials = 10,
+    learning_rate = 1.0,
+
+    deviation = 0.1,
+    weight_decay = 0.004,
+
+    cap_time = 300,
+    starting_lives = 1,
+}
+
+-- TODO: so, uhh..
+-- what happens when playback_mode is true but unperturbed_trial is false?
+
+setmetatable(cfg, {
+    __index = function(t, n)
+        if common_cfg[n] ~= nil then return common_cfg[n] end
+        if n == 'params_fn' then return nil end
+        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
+    end
+})
+
+cfg.epoch_top_trials = math.min(cfg.epoch_trials, cfg.epoch_top_trials)
 
 cfg.eps_start = 1.0 * cfg.frameskip / 64

@@ -62,8 +76,7 @@ cfg.eps_frames = 1000000
 cfg.enable_overlay = cfg.playable_mode
 cfg.enable_network = not cfg.playable_mode
 
-return setmetatable(cfg, {
-    __index = function(t, n)
-        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
-    end
-})
+assert(not cfg.ars_lips or cfg.unperturbed_trial,
+       "cfg.unperturbed_trial must be true to use cfg.ars_lips")
+
+return cfg
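The lookup chain this refactor sets up deserves a note: run-specific overrides live in cfg, shared defaults in common_cfg, and reading a key declared in neither raises an error instead of silently returning nil. A minimal standalone sketch of that behavior (toy tables and keys, not the repo's actual config):

local common_cfg = { frameskip = 4, layernorm = false }
local cfg = { learning_rate = 1.0 }

setmetatable(cfg, {
    __index = function(t, n)
        if common_cfg[n] ~= nil then return common_cfg[n] end
        if n == 'params_fn' then return nil end
        error("cannot use undeclared config '" .. tostring(n) .. "'", 2)
    end
})

print(cfg.learning_rate) --> 1.0 (own key; the metatable is never consulted)
print(cfg.frameskip)     --> 4 (falls back to common_cfg)
print(cfg.params_fn)     --> nil (the one key allowed to be nil)
--print(cfg.typo)        --> error: cannot use undeclared config 'typo'

The params_fn carve-out is needed because assigning nil to a Lua table key is the same as never storing it, so `params_fn = nil` in the cfg literal does not actually declare the key and __index still fires for it.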
main.lua: 1 changed line

@@ -227,6 +227,7 @@ local function make_network(input_size)
     nn_ty:feed(nn_y)
 
     nn_y = nn_y:feed(nn.Dense(128))
+    if cfg.layernorm then nn_y = nn_y:feed(nn.LayerNorm()) end
     if cfg.deterministic then
         nn_y = nn_y:feed(nn.Relu())
     else
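When cfg.layernorm is set, the normalization lands between the Dense layer and its activation, i.e. pre-activation. A sketch of the layer order this produces in make_network, assuming the deterministic path (cfg.deterministic = true):

nn_y = nn_y:feed(nn.Dense(128))
nn_y = nn_y:feed(nn.LayerNorm()) -- standardize the pre-activations
nn_y = nn_y:feed(nn.Relu())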
nn.lua: 58 changed lines

@@ -311,6 +311,7 @@ local Gelu = Layer:extend()
 local Dense = Layer:extend()
 local Softmax = Layer:extend()
 local Embed = Layer:extend()
+local LayerNorm = Layer:extend()
 
 function Weights:init(weight_init)
     self.weight_init = weight_init

@@ -618,6 +619,42 @@ function Embed:forward(X)
     return Y
 end
 
+function LayerNorm:init(eps)
+    Layer.init(self, "LayerNorm")
+    if eps == nil then eps = 1e-5 end
+    assert(type(eps) == "number")
+    self.eps = eps
+end
+
+function LayerNorm:reset_cache(bs)
+    self.bs = bs
+    self.cache = cache(bs, self.shape_out)
+end
+
+function LayerNorm:forward(X)
+    local bs = checkshape(X, self.shape_in)
+    if self.bs ~= bs then self:reset_cache(bs) end
+
+    local mean = 0
+    for i, v in ipairs(X) do
+        mean = mean + v / #X
+    end
+
+    local var = 0
+    for i, v in ipairs(X) do
+        local delta = v - mean
+        self.cache[i] = delta
+        var = var + delta * delta / #X
+    end
+
+    local std = sqrt(var + self.eps)
+    for i, v in ipairs(self.cache) do
+        self.cache[i] = v / std
+    end
+
+    return self.cache
+end
+
 function Model:init(nodes_in, nodes_out)
     assert(#nodes_in > 0, #nodes_in)
     assert(#nodes_out > 0, #nodes_out)

@@ -764,14 +801,15 @@ return {
     traverse = traverse,
     traverse_all = traverse_all,
 
-    Weights = Weights,
-    Layer = Layer,
-    Model = Model,
-    Input = Input,
-    Merge = Merge,
-    Relu = Relu,
-    Gelu = Gelu,
-    Dense = Dense,
-    Softmax = Softmax,
-    Embed = Embed,
+    Weights = Weights,
+    Layer = Layer,
+    Model = Model,
+    Input = Input,
+    Merge = Merge,
+    Relu = Relu,
+    Gelu = Gelu,
+    Dense = Dense,
+    Softmax = Softmax,
+    Embed = Embed,
+    LayerNorm = LayerNorm,
 }
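For reference, the forward pass is plain layer normalization with no learned gain or bias: center by the mean, scale by sqrt(variance + eps). No backward pass is defined, and none is needed, since ARS-style training only ever runs the network forward. A standalone sketch of the same computation (assumed equivalent to the class above, not the repo's code), with a worked example:

-- normalize one vector the way LayerNorm:forward does.
local function layernorm(X, eps)
    eps = eps or 1e-5
    local mean = 0
    for _, v in ipairs(X) do mean = mean + v / #X end
    local var = 0
    for _, v in ipairs(X) do var = var + (v - mean) * (v - mean) / #X end
    local std = math.sqrt(var + eps)
    local Y = {}
    for i, v in ipairs(X) do Y[i] = (v - mean) / std end
    return Y
end

-- {1, 2, 3, 4} has mean 2.5 and variance 1.25, so std is about 1.1180:
for _, y in ipairs(layernorm({1, 2, 3, 4})) do io.write(("%.4f "):format(y)) end
print() --> -1.3416 -0.4472 0.4472 1.3416 (approximately)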