fix softmax
This commit is contained in:
parent
7831f534c9
commit
c7c657513e
1 changed files with 3 additions and 1 deletions
4
nn.lua
4
nn.lua
|
@ -2,6 +2,7 @@ local ceil = math.ceil
|
||||||
local cos = math.cos
|
local cos = math.cos
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
local floor = math.floor
|
local floor = math.floor
|
||||||
|
local huge = math.huge
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local log = math.log
|
local log = math.log
|
||||||
|
@ -26,6 +27,7 @@ local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) e
|
||||||
-- general utilities
|
-- general utilities
|
||||||
|
|
||||||
local function copy(t, out) -- shallow copy
|
local function copy(t, out) -- shallow copy
|
||||||
|
assert(type(t) == "table")
|
||||||
local out = out or {}
|
local out = out or {}
|
||||||
for k, v in pairs(t) do out[k] = v end
|
for k, v in pairs(t) do out[k] = v end
|
||||||
return out
|
return out
|
||||||
|
@ -559,7 +561,7 @@ function Softmax:forward(X)
|
||||||
if self.bs ~= bs then self:reset_cache(bs) end
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
local alpha = 0
|
local alpha = -huge
|
||||||
local num = {} -- TODO: cache
|
local num = {} -- TODO: cache
|
||||||
local den = 0
|
local den = 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue