add Huber loss
This commit is contained in:
parent
eb16377ba6
commit
9bb26b1ec5
2 changed files with 17 additions and 1 deletions
2
onn.py
2
onn.py
|
@ -1102,6 +1102,8 @@ def lookup_loss(maybe_name):
|
|||
return Absolute()
|
||||
elif maybe_name == 'msee':
|
||||
return SomethingElse()
|
||||
elif maybe_name == 'huber':
|
||||
return Huber(delta=0.1)
|
||||
raise Exception('unknown objective', maybe_name)
|
||||
|
||||
def ritual_from_config(config, learner):
|
||||
|
|
14
onn_core.py
14
onn_core.py
|
@ -249,6 +249,20 @@ class Absolute(ResidualLoss):
|
|||
def df(self, r):
|
||||
return np.sign(r)
|
||||
|
||||
class Huber(ResidualLoss):
|
||||
def __init__(self, delta=1.0):
|
||||
self.delta = _f(delta)
|
||||
|
||||
def f(self, r):
|
||||
return np.where(r <= self.delta,
|
||||
np.square(r) / 2,
|
||||
self.delta * (np.abs(r) - self.delta / 2))
|
||||
|
||||
def df(self, r):
|
||||
return np.where(r <= self.delta,
|
||||
r,
|
||||
self.delta * np.sign(r))
|
||||
|
||||
# Regularizers {{{1
|
||||
|
||||
class Regularizer:
|
||||
|
|
Loading…
Add table
Reference in a new issue