preliminary batches and backwards passes
also adds negated noise trials because i forgot to commit that earlier
This commit is contained in:
parent
3e3b4d9207
commit
3d64df0574
2 changed files with 298 additions and 55 deletions
18
main.lua
18
main.lua
|
@ -37,12 +37,12 @@ local det_epsilon = true -- take random actions with probability eps.
|
||||||
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
|
local eps_start = 1.0 * 1/60 -- i think this should be * 1/16 for atari ref.
|
||||||
local eps_stop = 0.1 * 1/60 -- "
|
local eps_stop = 0.1 * 1/60 -- "
|
||||||
local eps_frames = 1000000
|
local eps_frames = 1000000
|
||||||
local consider_past_rewards = false
|
|
||||||
local learn_start_select = false
|
local learn_start_select = false
|
||||||
--
|
--
|
||||||
local epoch_trials = 40
|
local epoch_trials = 40
|
||||||
|
local negate_trials = true -- try pairs of normal and negated noise directions.
|
||||||
local unperturbed_trial = true -- do a trial without any noise.
|
local unperturbed_trial = true -- do a trial without any noise.
|
||||||
local learning_rate = 0.3 -- bigger now that i'm shaping trials etc.
|
local learning_rate = 0.3 -- bigger now that i'm stealing code etc.
|
||||||
local deviation = 0.05
|
local deviation = 0.05
|
||||||
--
|
--
|
||||||
local cap_time = 400
|
local cap_time = 400
|
||||||
|
@ -231,7 +231,7 @@ local nn_x
|
||||||
local nn_y
|
local nn_y
|
||||||
local nn_z
|
local nn_z
|
||||||
local function make_network(input_size, buttons)
|
local function make_network(input_size, buttons)
|
||||||
nn_x = nn.Input(input_size)
|
nn_x = nn.Input({input_size})
|
||||||
nn_y = nn_x
|
nn_y = nn_x
|
||||||
nn_z = {}
|
nn_z = {}
|
||||||
if false then
|
if false then
|
||||||
|
@ -526,7 +526,11 @@ local function prepare_epoch()
|
||||||
-- (the os.time() as of here) and calling nn.normal() each trial.
|
-- (the os.time() as of here) and calling nn.normal() each trial.
|
||||||
for i = 1, epoch_trials do
|
for i = 1, epoch_trials do
|
||||||
local noise = nn.zeros(#base_params)
|
local noise = nn.zeros(#base_params)
|
||||||
for j = 1, #base_params do noise[j] = nn.normal() end
|
if negate_trials and i % 2 == 0 then -- every other iteration...
|
||||||
|
for j = 1, #base_params do noise[j] = -trial_noise[i-1][j] end
|
||||||
|
else
|
||||||
|
for j = 1, #base_params do noise[j] = nn.normal() end
|
||||||
|
end
|
||||||
trial_noise[i] = noise
|
trial_noise[i] = noise
|
||||||
end
|
end
|
||||||
trial_i = -1
|
trial_i = -1
|
||||||
|
@ -535,8 +539,6 @@ end
|
||||||
local function load_next_trial()
|
local function load_next_trial()
|
||||||
trial_i = trial_i + 1
|
trial_i = trial_i + 1
|
||||||
local W = nn.copy(base_params)
|
local W = nn.copy(base_params)
|
||||||
local noise = trial_noise[trial_i]
|
|
||||||
local devsqrt = sqrt(deviation)
|
|
||||||
if trial_i == 0 and not unperturbed_trial then
|
if trial_i == 0 and not unperturbed_trial then
|
||||||
trial_i = 1
|
trial_i = 1
|
||||||
end
|
end
|
||||||
|
@ -824,7 +826,7 @@ while true do
|
||||||
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
for i, v in ipairs(sprite_input) do insert(X, v / 256) end
|
||||||
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
for i, v in ipairs(tile_input) do insert(X, v / 256) end
|
||||||
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
for i, v in ipairs(extra_input) do insert(X, v / 256) end
|
||||||
if #X ~= input_size then error("input size should be: "..tostring(#X)) end
|
nn.reshape(X, 1, input_size)
|
||||||
|
|
||||||
if enable_network and get_state() == 'playing' or ingame_paused then
|
if enable_network and get_state() == 'playing' or ingame_paused then
|
||||||
local choose = deterministic and argmax2 or rchoice2
|
local choose = deterministic and argmax2 or rchoice2
|
||||||
|
@ -858,7 +860,7 @@ while true do
|
||||||
select = choose(softmaxed[8]),
|
select = choose(softmaxed[8]),
|
||||||
}
|
}
|
||||||
|
|
||||||
if det_epsilon then
|
if det_epsilon then --and not trial_i == 0 then
|
||||||
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
local eps = lerp(eps_start, eps_stop, total_frames / eps_frames)
|
||||||
for k, v in pairs(jp) do
|
for k, v in pairs(jp) do
|
||||||
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
|
local ss_ok = k ~= 'start' and k ~= 'select' or learn_start_select
|
||||||
|
|
335
nn.lua
335
nn.lua
|
@ -1,5 +1,7 @@
|
||||||
|
local ceil = math.ceil
|
||||||
local cos = math.cos
|
local cos = math.cos
|
||||||
local exp = math.exp
|
local exp = math.exp
|
||||||
|
local floor = math.floor
|
||||||
local insert = table.insert
|
local insert = table.insert
|
||||||
local ipairs = ipairs
|
local ipairs = ipairs
|
||||||
local log = math.log
|
local log = math.log
|
||||||
|
@ -18,6 +20,9 @@ local unpack = table.unpack or unpack
|
||||||
|
|
||||||
local Base = require("Base")
|
local Base = require("Base")
|
||||||
|
|
||||||
|
-- hacks
|
||||||
|
local function helpme() print(debug.traceback('helpme', 2):gsub("\n", "\r\n")) end
|
||||||
|
|
||||||
-- general utilities
|
-- general utilities
|
||||||
|
|
||||||
local function copy(t) -- shallow copy
|
local function copy(t) -- shallow copy
|
||||||
|
@ -54,13 +59,23 @@ local function normal() -- box muller
|
||||||
end
|
end
|
||||||
|
|
||||||
local function zeros(n, out)
|
local function zeros(n, out)
|
||||||
local out = out or {}
|
out = out or {}
|
||||||
|
if type(n) == 'table' then
|
||||||
|
local shape = n
|
||||||
|
n = prod(shape)
|
||||||
|
out.shape = shape
|
||||||
|
end
|
||||||
for i = 1, n do out[i] = 0 end
|
for i = 1, n do out[i] = 0 end
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
local function arange(n, out)
|
local function arange(n, out)
|
||||||
out = out or {}
|
out = out or {}
|
||||||
|
if type(n) == 'table' then
|
||||||
|
local shape = n
|
||||||
|
n = prod(shape)
|
||||||
|
out.shape = shape
|
||||||
|
end
|
||||||
for i = 1, n do out[i] = i - 1 end
|
for i = 1, n do out[i] = i - 1 end
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
@ -92,6 +107,143 @@ local function init_he_normal(t, fan_in, fan_out)
|
||||||
return t
|
return t
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- ndarray-ish stuff and more involved math
|
||||||
|
|
||||||
|
local function pp(t, fmt, sep, ti, di, depth, isfirst, islast)
|
||||||
|
-- pretty-prints an nd-array.
|
||||||
|
fmt = fmt or '%10.7f,'
|
||||||
|
sep = sep or ','
|
||||||
|
ti = ti or 0
|
||||||
|
di = di or 1
|
||||||
|
depth = depth or 0
|
||||||
|
|
||||||
|
if t.shape == nil then
|
||||||
|
local s = '['
|
||||||
|
for i = 1, #t do s = s..fmt:format(t[i]) end
|
||||||
|
return s..']'..sep..'\n'
|
||||||
|
end
|
||||||
|
|
||||||
|
local dim = t.shape[di]
|
||||||
|
|
||||||
|
local ti_step = 1
|
||||||
|
for dj = di + 1, #t.shape do ti_step = ti_step * t.shape[dj] end
|
||||||
|
|
||||||
|
local indent = ''
|
||||||
|
for i = 1, depth do indent = indent..' ' end
|
||||||
|
|
||||||
|
local s = ''
|
||||||
|
if di ~= #t.shape then
|
||||||
|
if isfirst then s = s..indent..'[\n' else s = s..'[\n' end
|
||||||
|
for i = 1, dim do
|
||||||
|
s = s..pp(t, fmt, sep, ti, di + 1, depth + 1, i == 1, i == dim)
|
||||||
|
ti = ti + ti_step
|
||||||
|
end
|
||||||
|
if islast then s = s..indent..']'..sep..'\n' else s = s..indent..']'..sep end
|
||||||
|
else
|
||||||
|
s = s..indent..'['
|
||||||
|
for i = ti + 1, ti + dim do s = s..fmt:format(t[i])..sep end
|
||||||
|
s = s..']'..sep..'\n'
|
||||||
|
end
|
||||||
|
return s
|
||||||
|
end
|
||||||
|
|
||||||
|
local function ppi(t, n, ...)
|
||||||
|
-- TODO: determine maximum number of digits if n is omitted.
|
||||||
|
n = n or 1
|
||||||
|
return pp(t, '%'..tostring(n)..'i', ' ', ...)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function checkshape_helper(shape, isbatch)
|
||||||
|
local s = '{ '
|
||||||
|
if not isbatch then
|
||||||
|
s = s..'n, '
|
||||||
|
end
|
||||||
|
for i, v in ipairs(shape) do
|
||||||
|
if not isbatch or i > 1 then
|
||||||
|
s = s..tostring(v)..(i ~= #shape and ', ' or ' ')
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return s..'}'
|
||||||
|
end
|
||||||
|
|
||||||
|
local function checkshape(batch, shape)
|
||||||
|
assert(type(batch) == 'table', "batch is not an array")
|
||||||
|
assert(batch.shape ~= nil, "batch is missing a shape")
|
||||||
|
if #batch.shape == 1 then
|
||||||
|
error("batch shape is incomplete", 2)
|
||||||
|
end
|
||||||
|
for n=1, #shape do
|
||||||
|
if batch.shape[n+1] ~= shape[n] then
|
||||||
|
local s1 = checkshape_helper(batch.shape, true)
|
||||||
|
local s2 = checkshape_helper(shape, false)
|
||||||
|
error("shapes do not match: "..s1.." ~= "..s2, 2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return batch.shape[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
local function reshape(a, ...)
|
||||||
|
local new_shape = {...}
|
||||||
|
assert(#a == prod(new_shape), "new shape does not fit size")
|
||||||
|
a.shape = new_shape
|
||||||
|
return a
|
||||||
|
end
|
||||||
|
|
||||||
|
local function cache(bs, shape)
|
||||||
|
if bs == nil then return nil end
|
||||||
|
local fullshape = copy(shape)
|
||||||
|
insert(fullshape, bs, 1)
|
||||||
|
return zeros(fullshape)
|
||||||
|
end
|
||||||
|
|
||||||
|
local function dot(a, b, ax_a, ax_b, out)
|
||||||
|
ax_a = ax_a or #a.shape - 0
|
||||||
|
ax_b = ax_b or #b.shape - 1
|
||||||
|
|
||||||
|
assert(a.shape[ax_a] == b.shape[ax_b], "dotted axes do not match")
|
||||||
|
local dim = a.shape[ax_a]
|
||||||
|
|
||||||
|
local out_shape = {}
|
||||||
|
for di = 1, #a.shape do if di ~= ax_a then insert(out_shape, a.shape[di]) end end
|
||||||
|
for di = 1, #b.shape do if di ~= ax_b then insert(out_shape, b.shape[di]) end end
|
||||||
|
|
||||||
|
if out == nil then
|
||||||
|
out = zeros(prod(out_shape))
|
||||||
|
else
|
||||||
|
assert(prod(out_shape) == #out, "given output is the wrong size")
|
||||||
|
end
|
||||||
|
out.shape = out_shape
|
||||||
|
|
||||||
|
local a0 = 1
|
||||||
|
local a1 = 1
|
||||||
|
local b0 = 1
|
||||||
|
local b1 = 1
|
||||||
|
for di = 1, ax_a - 1 do a0 = a0 * a.shape[di] end
|
||||||
|
for di = 1, ax_b - 1 do b0 = b0 * b.shape[di] end
|
||||||
|
for di = ax_a + 1, #a.shape do a1 = a1 * a.shape[di] end
|
||||||
|
for di = ax_b + 1, #b.shape do b1 = b1 * b.shape[di] end
|
||||||
|
|
||||||
|
local o = 1
|
||||||
|
local i_end = a0 * dim - 1
|
||||||
|
local k_end = b0 * dim - 1
|
||||||
|
for i = 0, i_end, dim do for j = 1, a1 do
|
||||||
|
for k = 0, k_end, dim do for m = 1, b1 do
|
||||||
|
local res = 0
|
||||||
|
local x = i + j
|
||||||
|
local y = k + m
|
||||||
|
for d = 1, dim do
|
||||||
|
res = res + a[x] * b[y]
|
||||||
|
x = x + a1
|
||||||
|
y = y + b1
|
||||||
|
end
|
||||||
|
out[o] = res
|
||||||
|
o = o + 1
|
||||||
|
end end
|
||||||
|
end end
|
||||||
|
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
-- nodal
|
-- nodal
|
||||||
|
|
||||||
local function traverse(node_in, node_out, nodes, dummy_mode)
|
local function traverse(node_in, node_out, nodes, dummy_mode)
|
||||||
|
@ -171,17 +323,17 @@ function Layer:init(name)
|
||||||
self.parents = {}
|
self.parents = {}
|
||||||
self.children = {}
|
self.children = {}
|
||||||
self.weights = {}
|
self.weights = {}
|
||||||
--self.size_in = nil
|
--self.shape_in = nil
|
||||||
--self.size_out = nil
|
--self.shape_out = nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Layer:make_shape(parent)
|
function Layer:make_shape(parent)
|
||||||
if self.size_in == nil then self.size_in = parent.size_out end
|
if self.shape_in == nil then self.shape_in = parent.shape_out end
|
||||||
if self.size_out == nil then self.size_out = self.size_in end
|
if self.shape_out == nil then self.shape_out = self.shape_in end
|
||||||
end
|
end
|
||||||
|
|
||||||
function Layer:feed(child)
|
function Layer:feed(child)
|
||||||
assert(self.size_out ~= nil)
|
assert(self.shape_out ~= nil)
|
||||||
child:make_shape(self)
|
child:make_shape(self)
|
||||||
insert(self.children, child)
|
insert(self.children, child)
|
||||||
insert(child.parents, self)
|
insert(child.parents, self)
|
||||||
|
@ -216,12 +368,18 @@ function Layer:init_weights()
|
||||||
for i, w in ipairs(self.weights) do
|
for i, w in ipairs(self.weights) do
|
||||||
--print("allocating weights", i, "of", self.name)
|
--print("allocating weights", i, "of", self.name)
|
||||||
for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK
|
for j, v in ipairs(w) do w[j] = nil end -- FIXME: HACK
|
||||||
w:allocate(self.size_in, self.size_out)
|
w:allocate(prod(self.shape_in), prod(self.shape_out))
|
||||||
end
|
end
|
||||||
|
self:reset_cache()
|
||||||
|
end
|
||||||
|
|
||||||
|
function Layer:reset_cache(bs)
|
||||||
|
self.bs = bs
|
||||||
end
|
end
|
||||||
|
|
||||||
function Layer:_propagate(edges, deterministic)
|
function Layer:_propagate(edges, deterministic)
|
||||||
assert(#edges == 1, #edges) -- override this if you need multiple parents.
|
-- override this if you need multiple parents.
|
||||||
|
assert(#edges == 1, ("%s edges for node %s (expected 1)"):format(#edges, self.name))
|
||||||
if deterministic then
|
if deterministic then
|
||||||
return self:forward_deterministic(edges[1])
|
return self:forward_deterministic(edges[1])
|
||||||
else
|
else
|
||||||
|
@ -237,25 +395,25 @@ function Layer:propagate(values, deterministic)
|
||||||
insert(edges, X)
|
insert(edges, X)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
assert(#edges > 0, #edges)
|
assert(#edges > 0, ("%s edges for node %s (expected >0)"):format(#edges, self.name))
|
||||||
local Y = self:_propagate(edges, deterministic)
|
local Y = self:_propagate(edges, deterministic)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
function Input:init(size)
|
function Input:init(shape)
|
||||||
Layer.init(self, "Input")
|
Layer.init(self, "Input")
|
||||||
assert(type(size) == 'number')
|
assert(type(shape) == 'table')
|
||||||
self.size_in = size
|
self.shape_in = shape
|
||||||
self.size_out = size
|
self.shape_out = shape
|
||||||
end
|
end
|
||||||
|
|
||||||
function Input:forward(X)
|
function Input:forward(X)
|
||||||
assert(#X == self.size_in)
|
checkshape(X, self.shape_in)
|
||||||
return X
|
return X
|
||||||
end
|
end
|
||||||
|
|
||||||
function Input:backward(dY)
|
function Input:backward(dY)
|
||||||
assert(#dY == self.size_out)
|
checkshape(dY, self.shape_out)
|
||||||
return zeros(#dY)
|
return zeros(#dY)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -263,99 +421,177 @@ function Relu:init()
|
||||||
Layer.init(self, "Relu")
|
Layer.init(self, "Relu")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Relu:reset_cache(bs)
|
||||||
|
print("clearing cache:", self.name, bs)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
|
self.dcache = cache(bs, self.shape_in)
|
||||||
|
end
|
||||||
|
|
||||||
function Relu:forward(X)
|
function Relu:forward(X)
|
||||||
assert(#X == self.size_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
self.cache = self.cache or zeros(self.size_out)
|
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
for i = 1, #X do Y[i] = X[i] >= 0 and X[i] or 0 end
|
||||||
|
|
||||||
assert(#Y == self.size_out)
|
checkshape(Y, self.shape_out)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
function Relu:backward(dY)
|
function Relu:backward(dY)
|
||||||
assert(#dY == self.size_out)
|
local bs = checkshape(dY, self.shape_out)
|
||||||
self.dcache = self.dcache or zeros(self.size_in)
|
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
local dX = self.dcache
|
local dX = self.dcache
|
||||||
|
|
||||||
for i = 1, #dY do dX[i] = Y[i] >= 0 and dY[i] or 0 end
|
for i = 1, #dY do dX[i] = Y[i] >= 0 and dY[i] or 0 end
|
||||||
|
|
||||||
assert(#Y == self.size_in)
|
checkshape(dX, self.shape_in)
|
||||||
return Y
|
return dX
|
||||||
end
|
end
|
||||||
|
|
||||||
function Gelu:init()
|
function Gelu:init()
|
||||||
Layer.init(self, "Gelu")
|
Layer.init(self, "Gelu")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Gelu:reset_cache(bs)
|
||||||
|
print("clearing cache:", self.name, bs)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
|
self.cache_a = cache(bs, self.shape_out)
|
||||||
|
self.cache_sig = cache(bs, self.shape_out)
|
||||||
|
self.dcache = cache(bs, self.shape_in)
|
||||||
|
end
|
||||||
|
|
||||||
function Gelu:forward(X)
|
function Gelu:forward(X)
|
||||||
assert(#X == self.size_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
self.cache = self.cache or zeros(self.size_out)
|
if bs ~= self.bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
local a = self.cache_a
|
||||||
|
local sig = self.cache_sig
|
||||||
|
|
||||||
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
|
-- NOTE: approximate form of GELU exploiting similarities to sigmoid curve.
|
||||||
for i = 1, #X do
|
for i = 1, #X do
|
||||||
Y[i] = X[i] / (1 + exp(-1.704 * X[i]))
|
a[i] = 1.704 * X[i]
|
||||||
|
sig[i] = 1 / (1 + exp(-a[i]))
|
||||||
|
Y[i] = X[i] * sig[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
assert(#Y == self.size_out)
|
checkshape(Y, self.shape_out)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Gelu:backward(dY)
|
||||||
|
checkshape(dY, self.shape_out)
|
||||||
|
local Y = self.cache
|
||||||
|
local a = self.cache_a
|
||||||
|
local sig = self.cache_sig
|
||||||
|
local dX = self.dcache
|
||||||
|
|
||||||
|
for i = 1, #dY do
|
||||||
|
dX[i] = dY[i] * sig[i] * (1 + a[i] * (1 - sig[i]))
|
||||||
|
end
|
||||||
|
|
||||||
|
checkshape(dX, self.shape_in)
|
||||||
|
return dX
|
||||||
|
end
|
||||||
|
|
||||||
function Dense:init(dim)
|
function Dense:init(dim)
|
||||||
Layer.init(self, "Dense")
|
Layer.init(self, "Dense")
|
||||||
assert(type(dim) == "number")
|
assert(type(dim) == "number")
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.size_out = dim
|
self.shape_out = {dim}
|
||||||
self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
|
self.coeffs = self:_new_weights(init_he_normal) -- should be normal, but...
|
||||||
self.biases = self:_new_weights(init_zeros)
|
self.biases = self:_new_weights(init_zeros)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:make_shape(parent)
|
function Dense:make_shape(parent)
|
||||||
self.size_in = parent.size_out
|
self.shape_in = parent.shape_out
|
||||||
self.coeffs.shape = {self.size_in, self.dim}
|
self.coeffs.shape = {self.shape_in[#self.shape_in], self.dim}
|
||||||
self.biases.shape = self.dim
|
self.biases.shape = {1, self.dim}
|
||||||
|
end
|
||||||
|
|
||||||
|
function Dense:reset_cache(bs)
|
||||||
|
print("clearing cache:", self.name, bs)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
|
self.cache_x = cache(bs, self.shape_in)
|
||||||
|
self.dcache = cache(bs, self.shape_in)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dense:forward(X)
|
function Dense:forward(X)
|
||||||
assert(#X == self.size_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
self.cache = self.cache or zeros(self.size_out)
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
for i = 1, self.dim do
|
for i = 1, #X do
|
||||||
local res = 0
|
-- only needed for backwards pass.
|
||||||
local c = (i - 1) * #X
|
self.cache_x[i] = X[i]
|
||||||
for j = 1, #X do
|
|
||||||
res = res + X[j] * self.coeffs[c + j]
|
|
||||||
end
|
|
||||||
Y[i] = res + self.biases[i]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
assert(#Y == self.size_out)
|
--dot_1aab(X, self.coeffs, Y)
|
||||||
|
dot(X, self.coeffs, 2, 1, Y)
|
||||||
|
|
||||||
|
for i = 1, self.dim do
|
||||||
|
Y[i] = Y[i] + self.biases[i]
|
||||||
|
end
|
||||||
|
|
||||||
|
checkshape(Y, self.shape_out)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Dense:backward(dY)
|
||||||
|
local X = self.cache_x
|
||||||
|
local dX = self.dcache
|
||||||
|
|
||||||
|
--dot_ta(X, dY, self.coeffs.g)
|
||||||
|
dot(X, dY, 1, 1, self.coeffs.g)
|
||||||
|
|
||||||
|
for b = 1, X.shape[1] do
|
||||||
|
local l = X.shape[2]
|
||||||
|
local j = (b - 1) * l
|
||||||
|
for i = 1, l do self.biases.g[i] = self.biases.g[i] + dY[j-1+i] end
|
||||||
|
end
|
||||||
|
|
||||||
|
--dot_tb(dY, self.coeffs, dX)
|
||||||
|
dot(dY, self.coeffs, 2, 2, dX)
|
||||||
|
|
||||||
|
checkshape(dX, self.shape_in)
|
||||||
|
return dX
|
||||||
|
end
|
||||||
|
|
||||||
function Softmax:init()
|
function Softmax:init()
|
||||||
Layer.init(self, "Softmax")
|
Layer.init(self, "Softmax")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function Softmax:reset_cache(bs)
|
||||||
|
print("clearing cache:", self.name, bs)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
self.cache = cache(bs, self.shape_out)
|
||||||
|
end
|
||||||
|
|
||||||
function Softmax:forward(X)
|
function Softmax:forward(X)
|
||||||
assert(#X == self.size_in)
|
local bs = checkshape(X, self.shape_in)
|
||||||
self.cache = self.cache or zeros(self.size_out)
|
if self.bs ~= bs then self:reset_cache(bs) end
|
||||||
local Y = self.cache
|
local Y = self.cache
|
||||||
|
|
||||||
local alpha = 0
|
local alpha = 0
|
||||||
local num = {} -- TODO: cache
|
local num = {} -- TODO: cache
|
||||||
local den = 0
|
local den = 0
|
||||||
|
|
||||||
for i = 1, #X do alpha = max(alpha, X[i]) end
|
for b = 1, X.shape[1] do
|
||||||
for i = 1, #X do num[i] = exp(X[i] - alpha) end
|
local l = X.shape[2]
|
||||||
for i = 1, #X do den = den + num[i] end
|
local j = (b - 1) * l
|
||||||
for i = 1, #X do Y[i] = num[i] / den end
|
for i = j+1, j+l do alpha = max(alpha, X[i]) end
|
||||||
|
for i = j+1, j+l do num[i] = exp(X[i] - alpha) end
|
||||||
|
for i = j+1, j+l do den = den + num[i] end
|
||||||
|
for i = j+1, j+l do Y[i] = num[i] / den end
|
||||||
|
end
|
||||||
|
|
||||||
assert(#Y == self.size_out)
|
checkshape(Y, self.shape_out)
|
||||||
return Y
|
return Y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -392,6 +628,7 @@ function Model:forward(inputs)
|
||||||
if contains(self.nodes_in, node) then
|
if contains(self.nodes_in, node) then
|
||||||
local X = inputs[node]
|
local X = inputs[node]
|
||||||
assert(X ~= nil, ("missing input for node %s"):format(node.name))
|
assert(X ~= nil, ("missing input for node %s"):format(node.name))
|
||||||
|
assert(X.shape, ("missing shape for node %s"):format(node.name))
|
||||||
values[node] = node:_propagate({X})
|
values[node] = node:_propagate({X})
|
||||||
else
|
else
|
||||||
values[node] = node:propagate(values)
|
values[node] = node:propagate(values)
|
||||||
|
@ -491,6 +728,10 @@ return {
|
||||||
init_zeros = init_zeros,
|
init_zeros = init_zeros,
|
||||||
init_he_uniform = init_he_uniform,
|
init_he_uniform = init_he_uniform,
|
||||||
init_he_normal = init_he_normal,
|
init_he_normal = init_he_normal,
|
||||||
|
reshape = reshape,
|
||||||
|
pp = pp,
|
||||||
|
ppi = ppi,
|
||||||
|
dot = dot,
|
||||||
traverse = traverse,
|
traverse = traverse,
|
||||||
traverse_all = traverse_all,
|
traverse_all = traverse_all,
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue