diff --git a/nn.lua b/nn.lua index 85b9a10..49e7a63 100644 --- a/nn.lua +++ b/nn.lua @@ -2,6 +2,7 @@ local ceil = math.ceil local cos = math.cos local exp = math.exp local floor = math.floor +local huge = math.huge local insert = table.insert local ipairs = ipairs local log = math.log @@ -26,6 +27,7 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e -- general utilities local function copy(t, out) -- shallow copy + assert(type(t) == "table") local out = out or {} for k, v in pairs(t) do out[k] = v end return out @@ -559,7 +561,7 @@ function Softmax:forward(X) if self.bs ~= bs then self:reset_cache(bs) end local Y = self.cache - local alpha = 0 + local alpha = -huge local num = {} -- TODO: cache local den = 0