
Commit 45aef75

Move skl eval_metric and early_stopping rounds to model params. (#6751)
A new parameter `custom_metric` is added to `train` and `cv` to distinguish the behaviour from the old `feval`, which is now deprecated. The new `custom_metric` receives transformed predictions when a built-in objective is used. This enables XGBoost to use cost functions from other libraries like scikit-learn directly, without going through the definition of the link function. `eval_metric` and `early_stopping_rounds` in the sklearn interface are moved from `fit` to `__init__` and are now saved as part of the scikit-learn model; the old parameters in `fit` are deprecated. The new `eval_metric` in `__init__` has the same new behaviour as `custom_metric`. Also adds more detailed documentation for the behaviour of custom objectives and metrics.
1 parent 6b074ad commit 45aef75
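
For orientation, a minimal sketch of the resulting scikit-learn interface (the synthetic data, estimator settings, and round counts below are illustrative, not taken from the commit):

    import numpy as np
    import xgboost as xgb
    from sklearn.metrics import mean_absolute_error

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(256, 4)), rng.normal(size=256)

    # Before: reg.fit(X, y, eval_metric=..., early_stopping_rounds=...)
    # After this commit, both settings live on the estimator and are saved with the model.
    reg = xgb.XGBRegressor(
        n_estimators=32,
        eval_metric=mean_absolute_error,  # receives transformed predictions
        early_stopping_rounds=4,
    )
    reg.fit(X, y, eval_set=[(X, y)])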

File tree: 13 files changed, +687 −192 lines

demo/guide-python/custom_rmsle.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
                           dtrain=dtrain,
                           num_boost_round=kBoostRound,
                           obj=squared_log,
-                          feval=rmsle,
+                          custom_metric=rmsle,
                           evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
                           evals_result=results)

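For reference, the metric being wired in above follows the standard custom-metric signature; a sketch consistent with the demo script (the clipping guard is an assumption to keep `log1p` in its domain):

    from typing import Tuple
    import numpy as np
    import xgboost as xgb

    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
        """Root mean squared log error, the metric counterpart of the squared_log objective."""
        y = dtrain.get_label()
        predt[predt < -1] = -1 + 1e-6  # keep log1p's argument above -1
        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))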

demo/guide-python/custom_softmax.py

Lines changed: 11 additions & 2 deletions
@@ -3,6 +3,9 @@
 returns transformed prediction for multi-class objective function. More
 details in comments.

+See https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html for detailed
+tutorial and notes.
+
 '''

 import numpy as np
@@ -95,7 +98,12 @@ def predict(booster: xgb.Booster, X):

 def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
     y = dtrain.get_label()
-    # Like custom objective, the predt is untransformed leaf weight
+    # Like custom objective, the predt is untransformed leaf weight when custom objective
+    # is provided.
+
+    # With the use of `custom_metric` parameter in train function, custom metric receives
+    # raw input only when custom objective is also being used. Otherwise custom metric
+    # will receive transformed prediction.
     assert predt.shape == (kRows, kClasses)
     out = np.zeros(kRows)
     for r in range(predt.shape[0]):
@@ -134,7 +142,7 @@ def main(args):
                                m,
                                num_boost_round=kRounds,
                                obj=softprob_obj,
-                               feval=merror,
+                               custom_metric=merror,
                                evals_result=custom_results,
                                evals=[(m, 'train')])

@@ -143,6 +151,7 @@ def main(args):
     native_results = {}
     # Use the same objective function defined in XGBoost.
     booster_native = xgb.train({'num_class': kClasses,
+                                "objective": "multi:softmax",
                                 'eval_metric': 'merror'},
                                m,
                                num_boost_round=kRounds,
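
The comments above hinge on the raw-versus-transformed distinction. A hedged sketch of inspecting both prediction forms directly (the model, data, and shapes below are illustrative):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 4)
    y = np.random.randint(0, 3, size=64)
    Xy = xgb.DMatrix(X, label=y)

    booster = xgb.train({"num_class": 3, "objective": "multi:softprob"}, Xy, num_boost_round=4)
    margin = booster.predict(Xy, output_margin=True)  # raw, untransformed scores, shape (64, 3)
    prob = booster.predict(Xy)                        # softmax-transformed, each row sums to 1
    np.testing.assert_allclose(prob.sum(axis=1), 1.0, rtol=1e-5)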

doc/tutorials/custom_metric_obj.rst

Lines changed: 180 additions & 18 deletions
@@ -2,6 +2,16 @@
 Custom Objective and Evaluation Metric
 ######################################

+**Contents**
+
+.. contents::
+  :backlinks: none
+  :local:
+
+********
+Overview
+********
+
 XGBoost is designed to be an extensible library. One way to extend it is by providing our
 own objective function for training and corresponding metric for performance monitoring.
 This document introduces implementing a customized elementwise evaluation metric and
@@ -11,12 +21,8 @@ concepts should be readily applicable to other language bindings.
 .. note::

   * The ranking task does not support customized functions.
-  * The customized functions defined here are only applicable to single node training.
-    Distributed environment requires syncing with ``xgboost.rabit``, the interface is
-    subject to change hence beyond the scope of this tutorial.
-  * We also plan to improve the interface for multi-classes objective in the future.

-In the following sections, we will provide a step by step walk through of implementing
+In the following two sections, we will provide a step by step walk through of implementing
 ``Squared Log Error(SLE)`` objective function:

 .. math::
@@ -30,7 +36,10 @@ and its default metric ``Root Mean Squared Log Error(RMSLE)``:
 Although XGBoost has native support for said functions, using it for demonstration
 provides us the opportunity of comparing the result from our own implementation and the
 one from XGBoost internal for learning purposes. After finishing this tutorial, we should
-be able to provide our own functions for rapid experiments.
+be able to provide our own functions for rapid experiments. And at the end, we will
+provide some notes on non-identity link functions along with examples of using custom
+metric and objective with the scikit-learn interface.

 *****************************
 Customized Objective Function
@@ -125,24 +134,177 @@ We will be able to see XGBoost printing something like:

 .. code-block:: none

-  [0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
-  [1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
-  [2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
-  [3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
-  [4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
-  [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
+    [0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
+    [1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
+    [2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
+    [3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
+    [4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
+    [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
 ...

 Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric
 in XGBoost.

 For fully reproducible source code and comparison plots, see `custom_rmsle.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_rmsle.py>`_.

+*********************
+Reverse Link Function
+*********************
+
+When using a builtin objective, the raw prediction is transformed according to the
+objective function. When a custom objective is provided, XGBoost doesn't know its link
+function, so the user is responsible for making the transformation for both the
+objective and the custom evaluation metric. For objectives with an identity link like
+``squared error`` this is trivial, but for other link functions like the log link or
+inverse link the difference is significant.
+
+For the Python package, the behaviour of prediction can be controlled by the
+``output_margin`` parameter of the ``predict`` function. When the ``custom_metric``
+parameter is used without a custom objective, the metric function receives transformed
+predictions since the objective is defined by XGBoost. However, when a custom objective
+is also provided along with that metric, both the objective and the custom metric
+receive raw predictions. The following example compares the two behaviours with a
+multi-class classification model. First we define two Python metric functions
+implementing the same underlying metric: ``merror_with_transform`` is used when a custom
+objective is also supplied; otherwise the simpler ``merror`` is preferred since XGBoost
+can perform the transformation itself.
+
+.. code-block:: python
+
+    import xgboost as xgb
+    import numpy as np
+
+    def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):
+        """Used when a custom objective is supplied."""
+        y = dtrain.get_label()
+        n_classes = predt.size // y.shape[0]
+        # Like custom objective, the predt is untransformed leaf weight when a custom
+        # objective is provided.
+
+        # With the use of the `custom_metric` parameter in the train function, the
+        # custom metric receives raw input only when a custom objective is also used.
+        # Otherwise the custom metric receives transformed predictions.
+        assert predt.shape == (dtrain.num_row(), n_classes)
+        out = np.zeros(dtrain.num_row())
+        for r in range(predt.shape[0]):
+            i = np.argmax(predt[r])
+            out[r] = i
+
+        assert y.shape == out.shape
+
+        errors = np.zeros(dtrain.num_row())
+        errors[y != out] = 1.0
+        return 'PyMError', np.sum(errors) / dtrain.num_row()
+
+The above function is only needed when we want to use a custom objective and XGBoost
+doesn't know how to transform the prediction. The normal implementation for a
+multi-class error function is:
+
+.. code-block:: python
+
+    def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
+        """Used when there's no custom objective."""
+        # No need to transform: XGBoost handles it internally. With the builtin
+        # `multi:softmax` objective used below, `predt` holds the predicted class.
+        y = dtrain.get_label()
+        errors = np.zeros(dtrain.num_row())
+        errors[y != predt] = 1.0
+        return 'PyMError', np.sum(errors) / dtrain.num_row()
+
+
+Next we need the custom softprob objective:
+
+.. code-block:: python
+
+    def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
+        """Loss function. Computing the gradient and approximated hessian (diagonal).
+        Reimplements the `multi:softprob` inside XGBoost.
+        """
+
+        # Full implementation is available in the Python demo script linked below.
+        ...

-******************************
-Multi-class objective function
-******************************
+        return grad, hess
+
+Lastly we can train the model using the ``obj`` and ``custom_metric`` parameters:
+
+.. code-block:: python
+
+    Xy = xgb.DMatrix(X, y)
+    booster = xgb.train(
+        {"num_class": kClasses, "disable_default_eval_metric": True},
+        Xy,
+        num_boost_round=kRounds,
+        obj=softprob_obj,
+        custom_metric=merror_with_transform,
+        evals_result=custom_results,
+        evals=[(Xy, "train")],
+    )
+
+Or if you don't need the custom objective and just want to supply a metric that's not
+available in XGBoost:
+
+.. code-block:: python
+
+    booster = xgb.train(
+        {
+            "num_class": kClasses,
+            "disable_default_eval_metric": True,
+            "objective": "multi:softmax",
+        },
+        Xy,
+        num_boost_round=kRounds,
+        # Use a simpler metric implementation.
+        custom_metric=merror,
+        evals_result=custom_results,
+        evals=[(Xy, "train")],
+    )
+
+We use ``multi:softmax`` to illustrate the difference in transformed predictions. With
+``softprob`` the output prediction array has shape ``(n_samples, n_classes)``, while for
+``softmax`` it's ``(n_samples, )``. A demo for the multi-class objective function is also
+available at `demo/guide-python/custom_softmax.py
+<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_softmax.py>`_.
+
+**********************
+Scikit-Learn Interface
+**********************
+
+The scikit-learn interface of XGBoost has some utilities to improve the integration with
+standard scikit-learn functions. For instance, after XGBoost 1.5.1 users can use cost
+functions (not scoring functions) from scikit-learn out of the box:
+
+.. code-block:: python
+
+    from sklearn.datasets import load_diabetes
+    from sklearn.metrics import mean_absolute_error
+
+    X, y = load_diabetes(return_X_y=True)
+    reg = xgb.XGBRegressor(
+        tree_method="hist",
+        eval_metric=mean_absolute_error,
+    )
+    reg.fit(X, y, eval_set=[(X, y)])
+
+Also, for a custom objective function, users can define the objective without having to
+access ``DMatrix``:
+
+.. code-block:: python
+
+    def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        rows = labels.shape[0]
+        classes = predt.shape[1]
+        grad = np.zeros((rows, classes), dtype=float)
+        hess = np.zeros((rows, classes), dtype=float)
+        eps = 1e-6
+        for r in range(predt.shape[0]):
+            target = labels[r]
+            p = softmax(predt[r, :])  # softmax helper from the demo script
+            for c in range(predt.shape[1]):
+                g = p[c] - 1.0 if c == target else p[c]
+                h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
+                grad[r, c] = g
+                hess[r, c] = h
+
+        grad = grad.reshape((rows * classes, 1))
+        hess = hess.reshape((rows * classes, 1))
+        return grad, hess

-A similar demo for multi-class objective function is also available, see
-`demo/guide-python/custom_softmax.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/custom_softmax.py>`_
-for details.
+
+    clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)
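
To round out the tutorial's final snippet, a usage sketch (assuming `softprob_obj` and its `softmax` helper from the tutorial are in scope; the dataset choice is illustrative):

    import xgboost as xgb
    from sklearn.datasets import load_iris

    X, y = load_iris(return_X_y=True)
    clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj, n_estimators=8)
    clf.fit(X, y)
    print(clf.predict(X)[:5])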

python-package/xgboost/callback.py

Lines changed: 10 additions & 6 deletions
@@ -103,10 +103,13 @@ class CallbackContainer:

     EvalsLog = TrainingCallback.EvalsLog

-    def __init__(self,
-                 callbacks: List[TrainingCallback],
-                 metric: Callable = None,
-                 is_cv: bool = False):
+    def __init__(
+        self,
+        callbacks: List[TrainingCallback],
+        metric: Callable = None,
+        output_margin: bool = True,
+        is_cv: bool = False
+    ) -> None:
         self.callbacks = set(callbacks)
         if metric is not None:
             msg = 'metric must be callable object for monitoring. For ' + \
@@ -115,6 +118,7 @@ def __init__(self,
             assert callable(metric), msg
         self.metric = metric
         self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
+        self._output_margin = output_margin
         self.is_cv = is_cv

         if self.is_cv:
@@ -171,15 +175,15 @@ def _update_history(self, score, epoch):
     def after_iteration(self, model, epoch, dtrain, evals) -> bool:
         '''Function called after training iteration.'''
         if self.is_cv:
-            scores = model.eval(epoch, self.metric)
+            scores = model.eval(epoch, self.metric, self._output_margin)
             scores = _aggcv(scores)
            self.aggregated_cv = scores
             self._update_history(scores, epoch)
         else:
             evals = [] if evals is None else evals
             for _, name in evals:
                 assert name.find('-') == -1, 'Dataset name should not contain `-`'
-            score = model.eval_set(evals, epoch, self.metric)
+            score = model.eval_set(evals, epoch, self.metric, self._output_margin)
             score = score.split()[1:]  # into datasets
             # split up `test-error:0.1234`
             score = [tuple(s.split(':')) for s in score]
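
A hedged reading of this change: the container, not the metric, now decides whether evaluation sees raw margins. A sketch of how a caller such as `train`/`cv` might wire it up (the constructor arguments are from the diff above; the `callable(obj)` heuristic is an assumption about the caller, not shown in this hunk):

    container = CallbackContainer(
        callbacks,
        metric=custom_metric,
        # Assumption: metrics get raw (margin) output only when a custom
        # objective is supplied; otherwise they see transformed predictions.
        output_margin=callable(obj),
    )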

python-package/xgboost/core.py

Lines changed: 16 additions & 10 deletions
@@ -1700,7 +1700,7 @@ def boost(self, dtrain, grad, hess):
                                         c_array(ctypes.c_float, hess),
                                         c_bst_ulong(len(grad))))

-    def eval_set(self, evals, iteration=0, feval=None):
+    def eval_set(self, evals, iteration=0, feval=None, output_margin=True):
         # pylint: disable=invalid-name
         """Evaluate a set of data.

@@ -1728,24 +1728,30 @@ def eval_set(self, evals, iteration=0, feval=None):
         dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
         evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
         msg = ctypes.c_char_p()
-        _check_call(_LIB.XGBoosterEvalOneIter(self.handle,
-                                              ctypes.c_int(iteration),
-                                              dmats, evnames,
-                                              c_bst_ulong(len(evals)),
-                                              ctypes.byref(msg)))
+        _check_call(
+            _LIB.XGBoosterEvalOneIter(
+                self.handle,
+                ctypes.c_int(iteration),
+                dmats,
+                evnames,
+                c_bst_ulong(len(evals)),
+                ctypes.byref(msg),
+            )
+        )
         res = msg.value.decode()  # pylint: disable=no-member
         if feval is not None:
             for dmat, evname in evals:
-                feval_ret = feval(self.predict(dmat, training=False,
-                                               output_margin=True), dmat)
+                feval_ret = feval(
+                    self.predict(dmat, training=False, output_margin=output_margin), dmat
+                )
                 if isinstance(feval_ret, list):
                     for name, val in feval_ret:
                         # pylint: disable=consider-using-f-string
-                        res += '\t%s-%s:%f' % (evname, name, val)
+                        res += "\t%s-%s:%f" % (evname, name, val)
                 else:
                     name, val = feval_ret
                     # pylint: disable=consider-using-f-string
-                    res += '\t%s-%s:%f' % (evname, name, val)
+                    res += "\t%s-%s:%f" % (evname, name, val)
         return res

     def eval(self, data, name='eval', iteration=0):
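
Given the new signature, a sketch of calling `eval_set` so a custom metric receives transformed predictions (assumes a booster trained with the builtin `multi:softmax` objective, and `Xy`/`merror` from the examples above):

    # output_margin=False hands `merror` transformed predictions;
    # the default True keeps the old behaviour of passing raw margins.
    res = booster.eval_set(
        evals=[(Xy, "train")],
        iteration=0,
        feval=merror,
        output_margin=False,
    )
    print(res)  # e.g. "[0]\ttrain-PyMError:0.125"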
