2018-01-21 14:04:25 -08:00
|
|
|
import numpy as np
|
|
|
|
|
2018-01-22 11:40:36 -08:00
|
|
|
|
2018-01-21 14:04:25 -08:00
|
|
|
class Weights:
|
|
|
|
# we may or may not contain weights -- or any information, for that matter.
|
|
|
|
|
|
|
|
def __init__(self, **kwargs):
|
2018-01-22 11:40:36 -08:00
|
|
|
self.f = None # forward weights
|
|
|
|
self.g = None # backward weights (gradients)
|
2018-01-21 14:04:25 -08:00
|
|
|
self.shape = None
|
|
|
|
self.init = None
|
|
|
|
self.allocator = None
|
|
|
|
self.regularizer = None
|
|
|
|
self._allocated = False
|
|
|
|
|
|
|
|
self.configure(**kwargs)
|
|
|
|
|
|
|
|
def configure(self, **kwargs):
|
|
|
|
for k, v in kwargs.items():
|
2018-01-22 11:40:36 -08:00
|
|
|
getattr(self, k) # ensures the key already exists
|
2018-01-21 14:04:25 -08:00
|
|
|
setattr(self, k, v)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def size(self):
|
|
|
|
assert self.shape is not None
|
|
|
|
return np.prod(self.shape)
|
|
|
|
|
|
|
|
def allocate(self, *args, **kwargs):
|
|
|
|
if self._allocated:
|
|
|
|
raise Exception("attempted to allocate existing weights")
|
|
|
|
self.configure(**kwargs)
|
|
|
|
|
|
|
|
# intentionally not using isinstance
|
|
|
|
assert type(self.shape) == tuple, self.shape
|
|
|
|
|
|
|
|
f, g = self.allocator(self.size)
|
|
|
|
assert len(f) == self.size, "{} != {}".format(f.shape, self.size)
|
|
|
|
assert len(g) == self.size, "{} != {}".format(g.shape, self.size)
|
|
|
|
f[:] = self.init(self.size, *args)
|
|
|
|
g[:] = self.init(self.size, *args)
|
|
|
|
self.f = f.reshape(self.shape)
|
|
|
|
self.g = g.reshape(self.shape)
|
|
|
|
|
|
|
|
self._allocated = True
|
|
|
|
|
|
|
|
def forward(self):
|
|
|
|
if self.regularizer is None:
|
|
|
|
return 0.0
|
|
|
|
return self.regularizer.forward(self.f)
|
|
|
|
|
|
|
|
def backward(self):
|
|
|
|
if self.regularizer is None:
|
|
|
|
return 0.0
|
|
|
|
return self.regularizer.backward(self.f)
|
|
|
|
|
|
|
|
def update(self):
|
|
|
|
if self.regularizer is None:
|
|
|
|
return
|
|
|
|
self.g += self.regularizer.backward(self.f)
|