diff --git a/config.lua b/config.lua index dce1815..accbf81 100644 --- a/config.lua +++ b/config.lua @@ -83,5 +83,7 @@ cfg.enable_network = not cfg.playable_mode assert(not cfg.ars_lips or cfg.unperturbed_trial, "cfg.unperturbed_trial must be true to use cfg.ars_lips") +assert(not cfg.ars_lips or cfg.negate_trials, + "cfg.negate_trials must be true to use cfg.ars_lips") return cfg diff --git a/main.lua b/main.lua index add137f..37668f9 100644 --- a/main.lua +++ b/main.lua @@ -14,9 +14,9 @@ local trial_neg = true local trial_noise = {} local trial_rewards = {} local trials_remaining = 0 -local mom1 -- first moments in AMSgrad. -local mom2 -- second moments in AMSgrad. -local mom2max -- running element-wise maximum of mom2. +local mom1 -- first moments in AMSgrad. +local mom2 -- second moments in AMSgrad. +local mom2max -- running element-wise maximum of mom2. local trial_frames = 0 local total_frames = 0 diff --git a/nn.lua b/nn.lua index eac6dbf..bb3e828 100644 --- a/nn.lua +++ b/nn.lua @@ -67,9 +67,8 @@ local function arange(n, out) return out end -local function allocate(t, out, init) +local function allocate(size, out, init) out = out or {} - local size = t if init ~= nil then return init(zeros(size, out)) else diff --git a/util.lua b/util.lua index a5d5c7f..735c047 100644 --- a/util.lua +++ b/util.lua @@ -8,7 +8,7 @@ local min = math.min local pairs = pairs local random = math.random local select = select -local sqrt= math.sqrt +local sqrt = math.sqrt local function signbyte(x) if x >= 128 then x = 256 - x end