The squared log error objective function produces NaN values during training if the predicted value x is less than or equal to -1.0, as your documentation correctly notes. However, this is not a mathematical necessity but a consequence of your implementation. You may exploit the identity
log(1 + x) == 0.5 * log((1 + x) * (1 + x))
or as you may prefer
log1p(x) == 0.5 * log1p(x * (2.0 + x))
which yields NaN if and only if x == -1.0 and a meaningful value otherwise. The gradient and the Hessian (and hence the training) also become stable for values x < -1.0. I have included my custom implementation of the squared log error objective (and its associated metric) in the code below. It works well in my case and I would like to share it with you. My implementation also uses the identity
log(a) - log(b) == log(a / b)
But this is not essential. Take it or leave it. Best wishes, Ralf.
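As a quick illustration of the point above (the values here are arbitrary and chosen only to show the behaviour, not taken from the original report), the rewritten term stays finite for predictions below -1.0 while the naive form does not; NumPy emits a RuntimeWarning for the invalid arguments but still returns the values shown in the comments:

import numpy as np

x = np.array([-3.0, -1.5, -1.0, 0.5])

naive = np.log1p(x)                     # nan, nan, -inf, 0.405...
stable = 0.5 * np.log1p(x * (2.0 + x))  # 0.693..., -0.693..., -inf, 0.405...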
"""This module defines custom objectives."""fromabcimportABCfromabcimportabstractmethodimportnumpyasnpimportxgboostasxgbclassObjective(ABC):
""" The interface for a custom objective and its associated metric. """@abstractmethoddefgradient(self, pred: np.ndarray, data: xgb.DMatrix) ->np.ndarray:
""" Returns the gradient of the objective. :param pred: The predicted values. :param data: The predictor values. :return: The gradient. """@abstractmethoddefhessian(self, pred: np.ndarray, data: xgb.DMatrix) ->np.ndarray:
""" Returns the Hessian of the objective. :param pred: The predicted values. :param data: The predictor values. :return: The Hessian. """@abstractmethoddefmetric(
self, pred: np.ndarray, data: xgb.DMatrix
) ->tuple[str, float]:
""" Returns the metric associated with the objective. :param pred: The predicted values. :param data: The predictor values. :return: The name and the value of the metric. """defobj(
self, pred: np.ndarray, data: xgb.DMatrix
) ->tuple[np.ndarray, np.ndarray]:
""" The objective function. :param pred: The predicted values. :param data: The predictor values. :return: The gradient and the Hessian of the objective. """returnself.gradient(pred, data), self.hessian(pred, data)
defle(x: np.ndarray, y: np.ndarray) ->np.ndarray:
"""Returns the logarithmic error terms."""return0.5*np.log(np.square((1.0+x) / (1.0+y)))
defrms(e: np.ndarray, w: np.ndarray) ->np.ndarray:
"""Returns the root (weighted) mean squared error."""returnnp.sqrt(
np.average(np.square(e), weights=wifw.shape==e.shapeelseNone)
)
classSLE(Objective):
""" The squared logarithmic error objective. This objective shall replace the internal XGB squared logarithmic error objective. """defgradient(self, pred: np.ndarray, data: xgb.DMatrix) ->np.ndarray:
returnle(pred, data.get_label()) / (1.0+pred)
defhessian(self, pred: np.ndarray, data: xgb.DMatrix) ->np.ndarray:
return (1.0-le(pred, data.get_label())) /np.square(1.0+pred)
defmetric(
self, pred: np.ndarray, data: xgb.DMatrix
) ->tuple[str, float]:
return (
"rmsle",
rms(le(pred, data.get_label()), data.get_weight()).item(),
)
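For reference: with e = log((1 + x) / (1 + y)) as computed by le, the gradient of the loss 0.5 * e**2 with respect to the prediction x is e / (1 + x) and the Hessian is (1 - e) / (1 + x)**2, which is exactly what SLE.gradient and SLE.hessian return. Below is a minimal sketch of how the class could be plugged into xgboost.train, assuming SLE is in scope; the data, parameter choices, and variable names are my own illustrative assumptions and are not part of the snippet above:

import numpy as np
import xgboost as xgb

# Illustrative data; any non-negative regression target is suitable for RMSLE.
rng = np.random.default_rng(0)
X = rng.uniform(size=(1000, 10))
y = rng.uniform(0.0, 10.0, size=1000)
dtrain = xgb.DMatrix(X, label=y)

sle = SLE()
booster = xgb.train(
    {"disable_default_eval_metric": 1, "tree_method": "hist"},
    dtrain,
    num_boost_round=50,
    obj=sle.obj,               # supplies the gradient and the Hessian
    custom_metric=sle.metric,  # reports "rmsle" on the evaluation sets
    evals=[(dtrain, "train")],
)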