Skip to content

Commit 3c4aa9b

Browse files
authored
[breaking] Remove label encoder deprecated in 1.3. (#7357)
1 parent d05754f commit 3c4aa9b

File tree

7 files changed

+74
-83
lines changed

7 files changed

+74
-83
lines changed

python-package/xgboost/sklearn.py

Lines changed: 17 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1-
# coding: utf-8
2-
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, R0912, C0302
1+
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
32
"""Scikit-Learn Wrapper interface for XGBoost."""
43
import copy
54
import warnings
@@ -278,14 +277,13 @@ def _wrap_evaluation_matrices(
278277
eval_qid: Optional[List[Any]],
279278
create_dmatrix: Callable,
280279
enable_categorical: bool,
281-
label_transform: Callable = lambda x: x,
282280
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
283281
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
284282
285283
"""
286284
train_dmatrix = create_dmatrix(
287285
data=X,
288-
label=label_transform(y),
286+
label=y,
289287
group=group,
290288
qid=qid,
291289
weight=sample_weight,
@@ -333,7 +331,7 @@ def validate_or_none(meta: Optional[List], name: str) -> List:
333331
else:
334332
m = create_dmatrix(
335333
data=valid_X,
336-
label=label_transform(valid_y),
334+
label=valid_y,
337335
weight=sample_weight_eval_set[i],
338336
group=eval_group[i],
339337
qid=eval_qid[i],
@@ -1112,9 +1110,6 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
11121110
['model', 'objective'], extra_parameters='''
11131111
n_estimators : int
11141112
Number of boosting rounds.
1115-
use_label_encoder : bool
1116-
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
1117-
code, we recommend that you set this parameter to False.
11181113
''')
11191114
class XGBClassifier(XGBModel, XGBClassifierBase):
11201115
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@@ -1123,10 +1118,13 @@ def __init__(
11231118
self,
11241119
*,
11251120
objective: _SklObjective = "binary:logistic",
1126-
use_label_encoder: bool = True,
1121+
use_label_encoder: bool = False,
11271122
**kwargs: Any
11281123
) -> None:
1124+
# must match the parameters for `get_params`
11291125
self.use_label_encoder = use_label_encoder
1126+
if use_label_encoder is True:
1127+
raise ValueError("Label encoder was removed in 1.6.")
11301128
super().__init__(objective=objective, **kwargs)
11311129

11321130
@_deprecate_positional_args
@@ -1148,51 +1146,32 @@ def fit(
11481146
callbacks: Optional[List[TrainingCallback]] = None
11491147
) -> "XGBClassifier":
11501148
# pylint: disable = attribute-defined-outside-init,too-many-statements
1151-
can_use_label_encoder = True
1152-
label_encoding_check_error = (
1153-
"The label must consist of integer "
1154-
"labels of form 0, 1, 2, ..., [num_class - 1]."
1155-
)
1156-
label_encoder_deprecation_msg = (
1157-
"The use of label encoder in XGBClassifier is deprecated and will be "
1158-
"removed in a future release. To remove this warning, do the "
1159-
"following: 1) Pass option use_label_encoder=False when constructing "
1160-
"XGBClassifier object; and 2) Encode your labels (y) as integers "
1161-
"starting with 0, i.e. 0, 1, 2, ..., [num_class - 1]."
1162-
)
1163-
11641149
evals_result: TrainingCallback.EvalsLog = {}
1150+
11651151
if _is_cudf_df(y) or _is_cudf_ser(y):
11661152
import cupy as cp # pylint: disable=E0401
11671153

11681154
self.classes_ = cp.unique(y.values)
11691155
self.n_classes_ = len(self.classes_)
1170-
can_use_label_encoder = False
11711156
expected_classes = cp.arange(self.n_classes_)
1172-
if (
1173-
self.classes_.shape != expected_classes.shape
1174-
or not (self.classes_ == expected_classes).all()
1175-
):
1176-
raise ValueError(label_encoding_check_error)
11771157
elif _is_cupy_array(y):
11781158
import cupy as cp # pylint: disable=E0401
11791159

11801160
self.classes_ = cp.unique(y)
11811161
self.n_classes_ = len(self.classes_)
1182-
can_use_label_encoder = False
11831162
expected_classes = cp.arange(self.n_classes_)
1184-
if (
1185-
self.classes_.shape != expected_classes.shape
1186-
or not (self.classes_ == expected_classes).all()
1187-
):
1188-
raise ValueError(label_encoding_check_error)
11891163
else:
11901164
self.classes_ = np.unique(np.asarray(y))
11911165
self.n_classes_ = len(self.classes_)
1192-
if not self.use_label_encoder and (
1193-
not np.array_equal(self.classes_, np.arange(self.n_classes_))
1194-
):
1195-
raise ValueError(label_encoding_check_error)
1166+
expected_classes = np.arange(self.n_classes_)
1167+
if (
1168+
self.classes_.shape != expected_classes.shape
1169+
or not (self.classes_ == expected_classes).all()
1170+
):
1171+
raise ValueError(
1172+
f"Invalid classes inferred from unique values of `y`. "
1173+
f"Expected: {expected_classes}, got {self.classes_}"
1174+
)
11961175

11971176
params = self.get_xgb_params()
11981177

@@ -1211,18 +1190,6 @@ def fit(
12111190
params["objective"] = "multi:softprob"
12121191
params["num_class"] = self.n_classes_
12131192

1214-
if self.use_label_encoder:
1215-
if not can_use_label_encoder:
1216-
raise ValueError('The option use_label_encoder=True is incompatible with inputs ' +
1217-
'of type cuDF or cuPy. Please set use_label_encoder=False when ' +
1218-
'constructing XGBClassifier object. NOTE: ' +
1219-
label_encoder_deprecation_msg)
1220-
warnings.warn(label_encoder_deprecation_msg, UserWarning)
1221-
self._le = XGBoostLabelEncoder().fit(y)
1222-
label_transform = self._le.transform
1223-
else:
1224-
label_transform = lambda x: x
1225-
12261193
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
12271194
train_dmatrix, evals = _wrap_evaluation_matrices(
12281195
missing=self.missing,
@@ -1240,7 +1207,6 @@ def fit(
12401207
eval_qid=None,
12411208
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
12421209
enable_categorical=self.enable_categorical,
1243-
label_transform=label_transform,
12441210
)
12451211

12461212
self._Booster = train(
@@ -1403,9 +1369,6 @@ def evals_result(self) -> TrainingCallback.EvalsLog:
14031369
extra_parameters='''
14041370
n_estimators : int
14051371
Number of trees in random forest to fit.
1406-
use_label_encoder : bool
1407-
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
1408-
code, we recommend that you set this parameter to False.
14091372
''')
14101373
class XGBRFClassifier(XGBClassifier):
14111374
# pylint: disable=missing-docstring
@@ -1416,14 +1379,12 @@ def __init__(
14161379
subsample: float = 0.8,
14171380
colsample_bynode: float = 0.8,
14181381
reg_lambda: float = 1e-5,
1419-
use_label_encoder: bool = True,
14201382
**kwargs: Any
14211383
):
14221384
super().__init__(learning_rate=learning_rate,
14231385
subsample=subsample,
14241386
colsample_bynode=colsample_bynode,
14251387
reg_lambda=reg_lambda,
1426-
use_label_encoder=use_label_encoder,
14271388
**kwargs)
14281389

14291390
def get_xgb_params(self) -> Dict[str, Any]:

tests/python-gpu/test_from_cudf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ def test_cudf_training_with_sklearn():
239239
y_cudf_series = ss(data=y.iloc[:, 0])
240240

241241
for y_obj in [y_cudf, y_cudf_series]:
242-
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
242+
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist')
243243
clf.fit(X_cudf, y_obj, sample_weight=cudf_weights, base_margin=cudf_base_margin,
244244
eval_set=[(X_cudf, y_obj)])
245245
pred = clf.predict(X_cudf)

tests/python-gpu/test_from_cupy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_cupy_training_with_sklearn():
122122
base_margin = np.random.random(50)
123123
cupy_base_margin = cp.array(base_margin)
124124

125-
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
125+
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist")
126126
clf.fit(
127127
X,
128128
y,

tests/python-gpu/test_gpu_basic_models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,20 @@
77
# Don't import the test class, otherwise they will run twice.
88
import test_callback as test_cb # noqa
99
import test_basic_models as test_bm
10+
import testing as tm
1011
rng = np.random.RandomState(1994)
1112

1213

1314
class TestGPUBasicModels:
1415
cpu_test_cb = test_cb.TestCallbacks()
1516
cpu_test_bm = test_bm.TestModels()
1617

17-
def run_cls(self, X, y, deterministic):
18-
cls = xgb.XGBClassifier(tree_method='gpu_hist',
19-
deterministic_histogram=deterministic,
20-
single_precision_histogram=True)
18+
def run_cls(self, X, y):
19+
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
2120
cls.fit(X, y)
2221
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
2322

24-
cls = xgb.XGBClassifier(tree_method='gpu_hist',
25-
deterministic_histogram=deterministic,
26-
single_precision_histogram=True)
23+
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
2724
cls.fit(X, y)
2825
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
2926

@@ -49,19 +46,22 @@ def test_deterministic_gpu_hist(self):
4946
kClasses = 4
5047
# Create large values to force rounding.
5148
X = np.random.randn(kRows, kCols) * 1e4
52-
y = np.random.randint(0, kClasses, size=kRows) * 1e4
49+
y = np.random.randint(0, kClasses, size=kRows)
5350

54-
model_0, model_1 = self.run_cls(X, y, True)
51+
model_0, model_1 = self.run_cls(X, y)
5552
assert model_0 == model_1
5653

54+
@pytest.mark.skipif(**tm.no_sklearn())
5755
def test_invalid_gpu_id(self):
58-
X = np.random.randn(10, 5) * 1e4
59-
y = np.random.randint(0, 2, size=10) * 1e4
56+
from sklearn.datasets import load_digits
57+
X, y = load_digits(return_X_y=True)
6058
# should pass with invalid gpu id
6159
cls1 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999)
6260
cls1.fit(X, y)
6361
# should throw error with fail_on_invalid_gpu_id enabled
64-
cls2 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True)
62+
cls2 = xgb.XGBClassifier(
63+
tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True
64+
)
6565
try:
6666
cls2.fit(X, y)
6767
assert False, "Should have failed with with fail_on_invalid_gpu_id enabled"

tests/python-gpu/test_gpu_pickling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,10 @@ def test_pickled_predictor(self):
146146

147147
os.remove(model_path)
148148

149+
@pytest.mark.skipif(**tm.no_sklearn())
149150
def test_predict_sklearn_pickle(self):
150-
x, y = build_dataset()
151+
from sklearn.datasets import load_digits
152+
x, y = load_digits(return_X_y=True)
151153

152154
kwargs = {'tree_method': 'gpu_hist',
153155
'predictor': 'gpu_predictor',

tests/python-gpu/test_gpu_with_sklearn.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def test_categorical():
5656
X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train"))
5757
clf = xgb.XGBClassifier(
5858
tree_method="gpu_hist",
59-
use_label_encoder=False,
6059
enable_categorical=True,
6160
n_estimators=10,
6261
)
@@ -98,3 +97,36 @@ def check_predt(X, y):
9897

9998
X = cudf.DataFrame(X)
10099
check_predt(X, y)
100+
101+
102+
@pytest.mark.skipif(**tm.no_cupy())
103+
@pytest.mark.skipif(**tm.no_cudf())
104+
def test_classififer():
105+
from sklearn.datasets import load_digits
106+
import cupy as cp
107+
import cudf
108+
109+
X, y = load_digits(return_X_y=True)
110+
y *= 10
111+
112+
clf = xgb.XGBClassifier(tree_method="gpu_hist", n_estimators=1)
113+
114+
# numpy
115+
with pytest.raises(ValueError, match=r"Invalid classes.*"):
116+
clf.fit(X, y)
117+
118+
# cupy
119+
X, y = cp.array(X), cp.array(y)
120+
with pytest.raises(ValueError, match=r"Invalid classes.*"):
121+
clf.fit(X, y)
122+
123+
# cudf
124+
X, y = cudf.DataFrame(X), cudf.DataFrame(y)
125+
with pytest.raises(ValueError, match=r"Invalid classes.*"):
126+
clf.fit(X, y)
127+
128+
# pandas
129+
X, y = load_digits(return_X_y=True, as_frame=True)
130+
y *= 10
131+
with pytest.raises(ValueError, match=r"Invalid classes.*"):
132+
clf.fit(X, y)

tests/python/test_with_sklearn.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ def test_feature_importances_gain():
283283
random_state=0, tree_method="exact",
284284
learning_rate=0.1,
285285
importance_type="gain",
286-
use_label_encoder=False,
287286
).fit(X, y)
288287

289288
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -306,7 +305,6 @@ def test_feature_importances_gain():
306305
tree_method="exact",
307306
learning_rate=0.1,
308307
importance_type="gain",
309-
use_label_encoder=False,
310308
).fit(X, y)
311309
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
312310

@@ -315,14 +313,11 @@ def test_feature_importances_gain():
315313
tree_method="exact",
316314
learning_rate=0.1,
317315
importance_type="gain",
318-
use_label_encoder=False,
319316
).fit(X, y)
320317
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
321318

322319
# no split can be found
323-
cls = xgb.XGBClassifier(
324-
min_child_weight=1000, tree_method="hist", n_estimators=1, use_label_encoder=False
325-
)
320+
cls = xgb.XGBClassifier(min_child_weight=1000, tree_method="hist", n_estimators=1)
326321
cls.fit(X, y)
327322
assert np.all(cls.feature_importances_ == 0)
328323

@@ -497,7 +492,7 @@ def dummy_objective(y_true, y_preds):
497492
X, y
498493
)
499494

500-
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
495+
cls = xgb.XGBClassifier(n_estimators=1)
501496
cls.fit(X, y)
502497

503498
is_called = [False]
@@ -923,7 +918,7 @@ def test_RFECV():
923918
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
924919
n_estimators=10,
925920
objective='binary:logistic',
926-
random_state=0, verbosity=0, use_label_encoder=False)
921+
random_state=0, verbosity=0)
927922
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='roc_auc')
928923
rfecv.fit(X, y)
929924

@@ -934,7 +929,7 @@ def test_RFECV():
934929
n_estimators=10,
935930
objective='multi:softprob',
936931
random_state=0, reg_alpha=0.001, reg_lambda=0.01,
937-
scale_pos_weight=0.5, verbosity=0, use_label_encoder=False)
932+
scale_pos_weight=0.5, verbosity=0)
938933
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss')
939934
rfecv.fit(X, y)
940935

@@ -943,7 +938,7 @@ def test_RFECV():
943938
rfecv = RFECV(estimator=reg)
944939
rfecv.fit(X, y)
945940

946-
cls = xgb.XGBClassifier(use_label_encoder=False)
941+
cls = xgb.XGBClassifier()
947942
rfecv = RFECV(estimator=cls, step=1, cv=3,
948943
scoring='neg_mean_squared_error')
949944
rfecv.fit(X, y)
@@ -1052,8 +1047,9 @@ def test_deprecate_position_arg():
10521047
with pytest.warns(FutureWarning):
10531048
model.fit(X, y, w)
10541049

1055-
with pytest.warns(FutureWarning):
1050+
with pytest.raises(ValueError):
10561051
xgb.XGBRFClassifier(1, use_label_encoder=True)
1052+
10571053
model = xgb.XGBRFClassifier(n_estimators=1)
10581054
with pytest.warns(FutureWarning):
10591055
model.fit(X, y, w)

0 commit comments

Comments
 (0)