Skip to content

Commit bfaf638

Browse files
committed
Move sklearn `eval_metric` and `early_stopping_rounds` to model parameters.
These two parameters are now model parameters that can be set in the constructor or with the `set_params` method. Adds docs and tests. Fix return. Non-breaking.
1 parent 7593fa9 commit bfaf638

File tree

5 files changed

+181
-98
lines changed

5 files changed

+181
-98
lines changed

python-package/xgboost/dask.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,8 +1676,8 @@ async def _fit_async(
16761676
obj: Optional[Callable] = _objective_decorator(self.objective)
16771677
else:
16781678
obj = None
1679-
model, metric, params = self._configure_fit(
1680-
booster=xgb_model, eval_metric=eval_metric, params=params
1679+
model, metric, params, early_stopping_rounds = self._configure_fit(
1680+
xgb_model, eval_metric, params, early_stopping_rounds
16811681
)
16821682
results = await self.client.sync(
16831683
_train_async,
@@ -1778,8 +1778,8 @@ async def _fit_async(
17781778
obj: Optional[Callable] = _objective_decorator(self.objective)
17791779
else:
17801780
obj = None
1781-
model, metric, params = self._configure_fit(
1782-
booster=xgb_model, eval_metric=eval_metric, params=params
1781+
model, metric, params, early_stopping_rounds = self._configure_fit(
1782+
xgb_model, eval_metric, params, early_stopping_rounds
17831783
)
17841784
results = await self.client.sync(
17851785
_train_async,
@@ -1903,9 +1903,9 @@ def _argmax(x: Any) -> Any:
19031903
""",
19041904
["estimators", "model"],
19051905
end_note="""
1906-
Note
1907-
----
1908-
For dask implementation, group is not supported, use qid instead.
1906+
.. note::
1907+
1908+
For dask implementation, group is not supported, use qid instead.
19091909
""",
19101910
)
19111911
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@@ -1963,8 +1963,8 @@ async def _fit_async(
19631963
raise ValueError(
19641964
"Custom evaluation metric is not yet supported for XGBRanker."
19651965
)
1966-
model, metric, params = self._configure_fit(
1967-
booster=xgb_model, eval_metric=eval_metric, params=params
1966+
model, metric, params, early_stopping_rounds = self._configure_fit(
1967+
xgb_model, eval_metric, params, early_stopping_rounds
19681968
)
19691969
results = await self.client.sync(
19701970
_train_async,

python-package/xgboost/sklearn.py

Lines changed: 139 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
9090
return inner
9191

9292

93+
def _metric_decorator(func: Callable) -> Metric:
94+
"""Decorate a metric function from sklearn.
95+
96+
Converts an objective function using the typical sklearn metrics signature so that it
97+
is compatible with ``xgboost.training.train``
98+
99+
"""
100+
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> float:
101+
y_true = dmatrix.get_label()
102+
return func.__name__, func(y_true, y_score)
103+
return inner
104+
105+
93106
__estimator_doc = '''
94107
n_estimators : int
95108
Number of gradient boosted trees. Equivalent to number of boosting
@@ -184,6 +197,46 @@ def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
184197
Experimental support for categorical data. Do not set to true unless you are
185198
interested in development. Only valid when `gpu_hist` and dataframe are used.
186199
200+
eval_metric : Optional[Union[str, List[str], Callable]]
201+
Metric used for monitoring the training result and early stopping. It can be a
202+
string or list of strings as names of predefined metrics in XGBoost (See
203+
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
204+
user defined metric that looks like `sklearn.metrics`.
205+
206+
Unlike scikit-learn `scoring` parameter, when a callable object is provided, it's
207+
assumed to be a cost function and by default XGBoost will minimize the result
208+
during early stopping.
209+
210+
For advanced usage of early stopping, like directly choosing to maximize instead of
211+
minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
212+
213+
.. versionadded:: 1.5.1
214+
215+
.. note::
216+
217+
This parameter replaces `eval_metric` in
218+
:py:meth:`fit` method.
219+
220+
early_stopping_rounds : Optional[int]
221+
Activates early stopping. Validation metric needs to improve at least once in
222+
every **early_stopping_rounds** round(s) to continue training. Requires at least
223+
one item in **eval_set** in :py:meth:`xgboost.sklearn.XGBModel.fit`.
224+
225+
The method returns the model from the last iteration (not the best one). If
226+
there's more than one item in **eval_set**, the last entry will be used for early
227+
stopping. If there's more than one metric in **eval_metric**, the last metric
228+
will be used for early stopping.
229+
230+
If early stopping occurs, the model will have three additional fields:
231+
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
232+
233+
.. versionadded:: 1.5.1
234+
235+
.. note::
236+
237+
This parameter replaces `early_stopping_rounds` in
238+
:py:meth:`fit` method.
239+
187240
kwargs : dict, optional
188241
Keyword arguments for XGBoost Booster object. Full documentation of
189242
parameters can be found here:
@@ -399,6 +452,8 @@ def __init__(
399452
validate_parameters: Optional[bool] = None,
400453
predictor: Optional[str] = None,
401454
enable_categorical: bool = False,
455+
eval_metric=None,
456+
early_stopping_rounds=None,
402457
**kwargs: Any
403458
) -> None:
404459
if not SKLEARN_INSTALLED:
@@ -435,6 +490,8 @@ def __init__(
435490
self.validate_parameters = validate_parameters
436491
self.predictor = predictor
437492
self.enable_categorical = enable_categorical
493+
self.eval_metric = eval_metric
494+
self.early_stopping_rounds = early_stopping_rounds
438495
if kwargs:
439496
self.kwargs = kwargs
440497

@@ -545,10 +602,15 @@ def get_xgb_params(self) -> Dict[str, Any]:
545602
params = self.get_params()
546603
# Parameters that should not go into native learner.
547604
wrapper_specific = {
548-
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
605+
'importance_type',
606+
'kwargs',
607+
'missing',
608+
'n_estimators',
609+
'use_label_encoder',
549610
"enable_categorical"
611+
"early_stopping_rounds"
550612
}
551-
filtered = {}
613+
filtered = dict()
552614
for k, v in params.items():
553615
if k not in wrapper_specific and not callable(v):
554616
filtered[k] = v
@@ -636,15 +698,32 @@ def _configure_fit(
636698
booster: Optional[Union[Booster, "XGBModel", str]],
637699
eval_metric: Optional[Union[Callable, str, List[str]]],
638700
params: Dict[str, Any],
639-
) -> Tuple[Optional[Union[Booster, str]], Optional[Metric], Dict[str, Any]]:
640-
# pylint: disable=protected-access, no-self-use
641-
if isinstance(booster, XGBModel):
701+
early_stopping_rounds: Optional[int],
702+
) -> Tuple[Optional[Union[Booster, str, "XGBModel"]], Optional[Metric], Dict[str, Any], Optional[int]]:
703+
# pylint: disable=protected-access
704+
model = booster
705+
if hasattr(model, "_Booster"):
642706
# Handle the case when xgb_model is a sklearn model object
643-
model: Optional[Union[Booster, str]] = booster._Booster
644-
else:
645-
model = booster
707+
model = model._Booster
708+
709+
if eval_metric is not None:
710+
warnings.warn(
711+
"eval_metric for `fit` method is deprecated, use `eval_metric` in "
712+
"constructor or `set_params` instead.",
713+
UserWarning,
714+
)
646715

716+
# configure callable evaluation metric
647717
feval = eval_metric if callable(eval_metric) else None
718+
if self.eval_metric is not None and feval is not None:
719+
warnings.warn(
720+
"Overriding `eval_metric` from `fit` with `eval_metric` from parameter",
721+
UserWarning
722+
)
723+
if callable(self.eval_metric):
724+
feval = _metric_decorator(self.eval_metric)
725+
726+
# configure string/list evaluation metric
648727
if eval_metric is not None:
649728
if callable(eval_metric):
650729
eval_metric = None
@@ -656,7 +735,26 @@ def _configure_fit(
656735
" current tree method yet."
657736
)
658737

659-
return model, feval, params
738+
# configure early_stopping_rounds
739+
if early_stopping_rounds is not None:
740+
warnings.warn(
741+
"`early_stopping_rounds` is deprecated, use `early_stopping_rounds` "
742+
"in constructor or `set_params` instead.",
743+
UserWarning,
744+
)
745+
if (
746+
self.early_stopping_rounds is not None
747+
and self.early_stopping_rounds != early_stopping_rounds
748+
):
749+
raise ValueError("2 different `early_stopping_rounds` are provided.")
750+
751+
early_stopping_rounds = (
752+
self.early_stopping_rounds
753+
if self.early_stopping_rounds is not None
754+
else early_stopping_rounds
755+
)
756+
757+
return model, feval, params, early_stopping_rounds
660758

661759
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
662760
if evals_result:
@@ -704,31 +802,10 @@ def fit(
704802
A list of (X, y) tuple pairs to use as validation sets, for which
705803
metrics will be computed.
706804
Validation metrics will help us track the performance of the model.
707-
eval_metric :
708-
If a str, should be a built-in evaluation metric to use. See doc/parameter.rst.
709-
710-
If a list of str, should be the list of multiple built-in evaluation metrics
711-
to use.
712-
713-
If callable, a custom evaluation metric. The call signature is
714-
``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such
715-
that you may need to call the ``get_label`` method. It must return a str,
716-
value pair where the str is a name for the evaluation and value is the value
717-
of the evaluation function. The callable custom objective is always minimized.
718-
early_stopping_rounds :
719-
Activates early stopping. Validation metric needs to improve at least once in
720-
every **early_stopping_rounds** round(s) to continue training.
721-
Requires at least one item in **eval_set**.
722-
723-
The method returns the model from the last iteration (not the best one).
724-
If there's more than one item in **eval_set**, the last entry will be used
725-
for early stopping.
726-
727-
If there's more than one metric in **eval_metric**, the last metric will be
728-
used for early stopping.
729-
730-
If early stopping occurs, the model will have three additional fields:
731-
``clf.best_score``, ``clf.best_iteration``.
805+
eval_metric : str, list of str, or callable, optional
806+
Deprecated, use `eval_metric` in constructor or `set_params` instead.
807+
early_stopping_rounds : int
808+
Deprecated, use `early_stopping_rounds` in constructor or `set_params` instead.
732809
verbose :
733810
If `verbose` and an evaluation set is used, writes the evaluation metric
734811
measured on the validation set to stderr.
@@ -785,7 +862,9 @@ def fit(
785862
else:
786863
obj = None
787864

788-
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
865+
model, feval, params, early_stopping_rounds = self._configure_fit(
866+
xgb_model, eval_metric, params, early_stopping_rounds
867+
)
789868
self._Booster = train(
790869
params,
791870
train_dmatrix,
@@ -1223,7 +1302,9 @@ def fit(
12231302
else:
12241303
label_transform = lambda x: x
12251304

1226-
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
1305+
model, feval, params, early_stopping_rounds = self._configure_fit(
1306+
xgb_model, eval_metric, params, early_stopping_rounds
1307+
)
12271308
train_dmatrix, evals = _wrap_evaluation_matrices(
12281309
missing=self.missing,
12291310
X=X,
@@ -1359,8 +1440,9 @@ def evals_result(self) -> TrainingCallback.EvalsLog:
13591440
13601441
If **eval_set** is passed to the `fit` function, you can call
13611442
``evals_result()`` to get evaluation results for all passed **eval_sets**.
1362-
When **eval_metric** is also passed to the `fit` function, the
1363-
**evals_result** will contain the **eval_metrics** passed to the `fit` function.
1443+
1444+
When **eval_metric** is also passed as a parameter, the **evals_result** will
1445+
contain the results for that **eval_metric**.
13641446
13651447
Returns
13661448
-------
@@ -1371,13 +1453,14 @@ def evals_result(self) -> TrainingCallback.EvalsLog:
13711453
13721454
.. code-block:: python
13731455
1374-
param_dist = {'objective':'binary:logistic', 'n_estimators':2}
1456+
param_dist = {
1457+
'objective':'binary:logistic', 'n_estimators':2, 'eval_metric':'logloss'
1458+
}
13751459
13761460
clf = xgb.XGBClassifier(**param_dist)
13771461
13781462
clf.fit(X_train, y_train,
13791463
eval_set=[(X_train, y_train), (X_test, y_test)],
1380-
eval_metric='logloss',
13811464
verbose=True)
13821465
13831466
evals_result = clf.evals_result()
@@ -1388,6 +1471,7 @@ def evals_result(self) -> TrainingCallback.EvalsLog:
13881471
13891472
{'validation_0': {'logloss': ['0.604835', '0.531479']},
13901473
'validation_1': {'logloss': ['0.41965', '0.17686']}}
1474+
13911475
"""
13921476
if self.evals_result_:
13931477
evals_result = self.evals_result_
@@ -1534,15 +1618,15 @@ def fit(
15341618
'Implementation of the Scikit-Learn API for XGBoost Ranking.',
15351619
['estimators', 'model'],
15361620
end_note='''
1537-
Note
1538-
----
1539-
A custom objective function is currently not supported by XGBRanker.
1540-
Likewise, a custom metric function is not supported either.
1621+
.. note::
1622+
1623+
A custom objective function is currently not supported by XGBRanker.
1624+
Likewise, a custom metric function is not supported either.
15411625
1542-
Note
1543-
----
1544-
Query group information is required for ranking tasks by either using the `group`
1545-
parameter or `qid` parameter in `fit` method.
1626+
.. note::
1627+
1628+
Query group information is required for ranking tasks by either using the
1629+
`group` parameter or `qid` parameter in `fit` method.
15461630
15471631
Before fitting the model, your data need to be sorted by query group. When fitting
15481632
the model, you need to provide an additional array that contains the size of each
@@ -1644,22 +1728,10 @@ def fit(
16441728
eval_qid :
16451729
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
16461730
pair in **eval_set**.
1647-
eval_metric :
1648-
If a str, should be a built-in evaluation metric to use. See
1649-
doc/parameter.rst.
1650-
If a list of str, should be the list of multiple built-in evaluation metrics
1651-
to use. The custom evaluation metric is not yet supported for the ranker.
1652-
early_stopping_rounds :
1653-
Activates early stopping. Validation metric needs to improve at least once in
1654-
every **early_stopping_rounds** round(s) to continue training. Requires at
1655-
least one item in **eval_set**.
1656-
The method returns the model from the last iteration (not the best one). If
1657-
there's more than one item in **eval_set**, the last entry will be used for
1658-
early stopping.
1659-
If there's more than one metric in **eval_metric**, the last metric will be
1660-
used for early stopping.
1661-
If early stopping occurs, the model will have three additional fields:
1662-
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
1731+
eval_metric : str, list of str, optional
1732+
Deprecated, use `eval_metric` in constructor or `set_params` instead. The custom evaluation metric is not yet supported for the ranker.
1733+
early_stopping_rounds : int
1734+
Deprecated, use `early_stopping_rounds` in constructor or `set_params` instead.
16631735
verbose :
16641736
If `verbose` and an evaluation set is used, writes the evaluation metric
16651737
measured on the validation set to stderr.
@@ -1724,7 +1796,9 @@ def fit(
17241796
evals_result: TrainingCallback.EvalsLog = {}
17251797
params = self.get_xgb_params()
17261798

1727-
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
1799+
model, feval, params, early_stopping_rounds = self._configure_fit(
1800+
xgb_model, eval_metric, params, early_stopping_rounds
1801+
)
17281802
if callable(feval):
17291803
raise ValueError(
17301804
'Custom evaluation metric is not yet supported for XGBRanker.'

0 commit comments

Comments
 (0)