add Huber loss
This commit is contained in:
parent
eb16377ba6
commit
9bb26b1ec5
2 changed files with 17 additions and 1 deletions
2
onn.py
2
onn.py
|
@ -1102,6 +1102,8 @@ def lookup_loss(maybe_name):
|
||||||
return Absolute()
|
return Absolute()
|
||||||
elif maybe_name == 'msee':
|
elif maybe_name == 'msee':
|
||||||
return SomethingElse()
|
return SomethingElse()
|
||||||
|
elif maybe_name == 'huber':
|
||||||
|
return Huber(delta=0.1)
|
||||||
raise Exception('unknown objective', maybe_name)
|
raise Exception('unknown objective', maybe_name)
|
||||||
|
|
||||||
def ritual_from_config(config, learner):
|
def ritual_from_config(config, learner):
|
||||||
|
|
14
onn_core.py
14
onn_core.py
|
@ -249,6 +249,20 @@ class Absolute(ResidualLoss):
|
||||||
def df(self, r):
|
def df(self, r):
|
||||||
return np.sign(r)
|
return np.sign(r)
|
||||||
|
|
||||||
|
class Huber(ResidualLoss):
|
||||||
|
def __init__(self, delta=1.0):
|
||||||
|
self.delta = _f(delta)
|
||||||
|
|
||||||
|
def f(self, r):
|
||||||
|
return np.where(r <= self.delta,
|
||||||
|
np.square(r) / 2,
|
||||||
|
self.delta * (np.abs(r) - self.delta / 2))
|
||||||
|
|
||||||
|
def df(self, r):
|
||||||
|
return np.where(r <= self.delta,
|
||||||
|
r,
|
||||||
|
self.delta * np.sign(r))
|
||||||
|
|
||||||
# Regularizers {{{1
|
# Regularizers {{{1
|
||||||
|
|
||||||
class Regularizer:
|
class Regularizer:
|
||||||
|
|
Loading…
Add table
Reference in a new issue