Deprecate LabelEncoder in XGBClassifier; Enable cuDF/cuPy inputs in XGBClassifier #6269

Merged: 8 commits, Oct 26, 2020
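
In short, this change stops XGBClassifier from routing labels through scikit-learn's LabelEncoder and lets y be a cuDF DataFrame/Series or a cuPy array. A minimal sketch of the headline use case, separate from the diff below, assuming a CUDA build of xgboost with cudf and cupy installed:

import cudf
import cupy as cp
import xgboost as xgb

# Features and labels stay on the GPU end to end.
X = cudf.DataFrame({'f0': cp.random.randn(100), 'f1': cp.random.randn(100)})
y = cudf.Series((cp.random.randn(100) > 0).astype('int8'))  # already 0/1

clf = xgb.XGBClassifier(tree_method='gpu_hist', use_label_encoder=False)
clf.fit(X, y)  # the old code path ran y through a host-side LabelEncoder
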
71 changes: 60 additions & 11 deletions python-package/xgboost/sklearn.py
@@ -7,6 +7,7 @@
import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .training import train
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array

# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
@@ -345,7 +346,7 @@ def get_xgb_params(self):
params = self.get_params()
# Parameters that should not go into native learner.
wrapper_specific = {
'importance_type', 'kwargs', 'missing', 'n_estimators'}
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
filtered = dict()
for k, v in params.items():
if k not in wrapper_specific:
@@ -430,6 +431,9 @@ def load_model(self, fname):
if k == 'classes_':
self.classes_ = np.array(v)
continue
if k == 'use_label_encoder':
self.use_label_encoder = bool(v)
continue
if k == 'type' and type(self).__name__ != v:
msg = 'Current model type: {}, '.format(type(self).__name__) + \
'type of model in file: {}'.format(v)
@@ -763,21 +767,53 @@ def intercept_(self):
['model', 'objective'], extra_parameters='''
n_estimators : int
Number of boosting rounds.
use_label_encoder : bool
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
we recommend that you set this parameter to False.
''')
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
def __init__(self, objective="binary:logistic", **kwargs):
def __init__(self, objective="binary:logistic", use_label_encoder=True, **kwargs):
self.use_label_encoder = use_label_encoder
super().__init__(objective=objective, **kwargs)

def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, feature_weights=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
# pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements

can_use_label_encoder = True
label_encoding_check_error = (
'The label must consist of integer labels of form 0, 1, 2, ..., [num_class - 1].')
label_encoder_deprecation_msg = (
'The use of label encoder in XGBClassifier is deprecated and will be ' +
'removed in a future release. To remove this warning, do the ' +
'following: 1) Pass option use_label_encoder=False when constructing ' +
'XGBClassifier object; and 2) Encode your labels (y) as integers ' +
'starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].')

evals_result = {}
self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_)
if _is_cudf_df(y) or _is_cudf_ser(y):
import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y.values)
self.n_classes_ = len(self.classes_)
can_use_label_encoder = False
if not cp.array_equal(self.classes_, cp.arange(self.n_classes_)):
raise ValueError(label_encoding_check_error)
elif _is_cupy_array(y):
import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y)
self.n_classes_ = len(self.classes_)
can_use_label_encoder = False
if not cp.array_equal(self.classes_, cp.arange(self.n_classes_)):
raise ValueError(label_encoding_check_error)
else:
self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_)
if not self.use_label_encoder and (
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
raise ValueError(label_encoding_check_error)

xgb_options = self.get_xgb_params()

@@ -801,8 +837,18 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
else:
xgb_options.update({"eval_metric": eval_metric})

self._le = XGBoostLabelEncoder().fit(y)
training_labels = self._le.transform(y)
if self.use_label_encoder:
if not can_use_label_encoder:
raise ValueError('The option use_label_encoder=True is incompatible with inputs ' +
'of type cuDF or cuPy. Please set use_label_encoder=False when ' +
'constructing XGBClassifier object. NOTE: ' +
label_encoder_deprecation_msg)
warnings.warn(label_encoder_deprecation_msg, UserWarning)
self._le = XGBoostLabelEncoder().fit(y)
label_transform = self._le.transform
else:
label_transform = (lambda x: x)
training_labels = label_transform(y)

if eval_set is not None:
if sample_weight_eval_set is None:
@@ -811,7 +857,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
assert len(sample_weight_eval_set) == len(eval_set)
evals = list(
DMatrix(eval_set[i][0],
label=self._le.transform(eval_set[i][1]),
label=label_transform(eval_set[i][1]),
missing=self.missing, weight=sample_weight_eval_set[i],
nthread=self.n_jobs)
for i in range(len(eval_set))
@@ -919,9 +965,7 @@ def predict(self, data, output_margin=False, ntree_limit=None,

if hasattr(self, '_le'):
return self._le.inverse_transform(column_indexes)
warnings.warn(
'Label encoder is not defined. Returning class probability.')
return class_probs
return column_indexes

def predict_proba(self, data, ntree_limit=None, validate_features=False,
base_margin=None):
@@ -1012,6 +1056,9 @@ def evals_result(self):
extra_parameters='''
n_estimators : int
Number of trees in random forest to fit.
use_label_encoder : bool
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
we recommend that you set this parameter to False.
''')
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
@@ -1020,11 +1067,13 @@ def __init__(self,
subsample=0.8,
colsample_bynode=0.8,
reg_lambda=1e-5,
use_label_encoder=True,
**kwargs):
super().__init__(learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
use_label_encoder=use_label_encoder,
**kwargs)

def get_xgb_params(self):
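
Taken together, the sklearn.py changes define the migration path the deprecation message describes: construct the classifier with use_label_encoder=False and pass fit() labels that are already consecutive integers starting at 0. A small sketch of the new contract, with synthetic CPU data (not part of the diff):

import numpy as np
import xgboost as xgb

X = np.random.randn(100, 5)
y = np.random.randint(0, 3, size=100)  # labels already encoded as 0, 1, 2

# New-style call: no LabelEncoder, no deprecation warning.
clf = xgb.XGBClassifier(use_label_encoder=False)
clf.fit(X, y)

# With use_label_encoder=False, fit() now checks that the classes are
# exactly 0 .. n_classes_ - 1 and rejects anything else.
try:
    xgb.XGBClassifier(use_label_encoder=False).fit(X, y + 1)  # 1, 2, 3
except ValueError as err:
    print(err)  # the label_encoding_check_error message above

Leaving use_label_encoder at its default of True keeps the old behavior but now emits the UserWarning defined in fit().
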
28 changes: 28 additions & 0 deletions tests/python-gpu/test_from_cudf.py
@@ -172,6 +172,34 @@ def test_cudf_metainfo_device_dmatrix(self):
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)


@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
def test_cudf_training_with_sklearn():
from cudf import DataFrame as df
from cudf import Series as ss
import pandas as pd
np.random.seed(1)
X = pd.DataFrame(np.random.randn(50, 10))
y = pd.DataFrame((np.random.randn(50) > 0).astype(np.int8))
weights = np.random.random(50) + 1.0
cudf_weights = df.from_pandas(pd.DataFrame(weights))
base_margin = np.random.random(50)
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))

X_cudf = df.from_pandas(X)
y_cudf = df.from_pandas(y)
y_cudf_series = ss(data=y.iloc[:, 0])

for y_obj in [y_cudf, y_cudf_series]:
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
clf.fit(X_cudf, y_obj, sample_weight=cudf_weights, base_margin=cudf_base_margin,
eval_set=[(X_cudf, y_obj)])
pred = clf.predict(X_cudf)
assert np.array_equal(np.unique(pred), np.array([0, 1]))


class IterForDMatrixTest(xgb.core.DataIter):
'''A data iterator for XGBoost DMatrix.

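
The cuDF/cuPy branches in fit() never move labels to the host; they validate them in place with cupy. The same rule, rendered as a standalone helper (the function name is hypothetical; only cupy is assumed):

import cupy as cp

def check_gpu_labels(y):
    # Mirrors the check fit() applies to device labels: they must be
    # exactly 0, 1, ..., num_class - 1, because no encoder will remap them.
    classes = cp.unique(y)
    n_classes = len(classes)
    if not cp.array_equal(classes, cp.arange(n_classes)):
        raise ValueError('The label must consist of integer labels of '
                         'form 0, 1, 2, ..., [num_class - 1].')
    return classes, n_classes

check_gpu_labels(cp.array([0, 1, 1, 2]))  # ok: (array([0, 1, 2]), 3)
check_gpu_labels(cp.array([1, 2, 3]))     # raises ValueError
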
19 changes: 19 additions & 0 deletions tests/python-gpu/test_from_cupy.py
@@ -108,6 +108,25 @@ def _test_cupy_metainfo(DMatrixT):
dmat_cupy.get_uint_info('group_ptr'))


@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_sklearn())
def test_cupy_training_with_sklearn():
import cupy as cp
np.random.seed(1)
cp.random.seed(1)
X = cp.random.randn(50, 10, dtype='float32')
y = (cp.random.randn(50, dtype='float32') > 0).astype('int8')
weights = np.random.random(50) + 1
cupy_weights = cp.array(weights)
base_margin = np.random.random(50)
cupy_base_margin = cp.array(base_margin)

clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
clf.fit(X, y, sample_weight=cupy_weights, base_margin=cupy_base_margin, eval_set=[(X, y)])
pred = clf.predict(X)
assert np.array_equal(np.unique(pred), np.array([0, 1]))


class TestFromCupy:
'''Tests for constructing DMatrix from data structure conforming Apache
Arrow specification.'''
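
A behavioral note on the predict() change in sklearn.py above: when no label encoder was fitted (the use_label_encoder=False path), predict() now returns the raw class indexes instead of warning and returning probabilities, so it agrees with an argmax over predict_proba(). A quick sketch of that equivalence on synthetic multiclass data (not from the diff):

import numpy as np
import xgboost as xgb

X = np.random.randn(80, 4)
y = np.random.randint(0, 3, size=80)  # three classes, encoded 0..2

clf = xgb.XGBClassifier(use_label_encoder=False).fit(X, y)

pred = clf.predict(X)         # column indexes, since self._le was never set
proba = clf.predict_proba(X)  # shape (80, 3)
assert np.array_equal(pred, np.argmax(proba, axis=1))
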
26 changes: 12 additions & 14 deletions tests/python/test_with_sklearn.py
@@ -706,19 +706,17 @@ def save_load_model(model_path):
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold

digits = load_digits(2)
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
xgb_model = xgb.XGBClassifier(use_label_encoder=False).fit(X[train_index], y[train_index])
xgb_model.save_model(model_path)
xgb_model = xgb.XGBClassifier()
xgb_model = xgb.XGBClassifier(use_label_encoder=False)
xgb_model.load_model(model_path)
assert isinstance(xgb_model.classes_, np.ndarray)
assert isinstance(xgb_model._Booster, xgb.Booster)
assert isinstance(xgb_model._le, XGBoostLabelEncoder)
assert isinstance(xgb_model._le.classes_, np.ndarray)
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
@@ -750,7 +748,7 @@ def test_save_load_model():
from sklearn.datasets import load_digits
with TemporaryDirectory() as tempdir:
model_path = os.path.join(tempdir, 'digits.model.json')
digits = load_digits(2)
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
booster = xgb.train({'tree_method': 'hist',
@@ -761,7 +759,7 @@ def test_save_load_model():
booster.save_model(model_path)
cls = xgb.XGBClassifier()
cls.load_model(model_path)
predt_1 = cls.predict(X)
predt_1 = cls.predict_proba(X)[:, 1]
assert np.allclose(predt_0, predt_1)

cls = xgb.XGBModel()
@@ -778,10 +776,10 @@ def test_RFECV():

# Regression
X, y = load_boston(return_X_y=True)
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
n_estimators=10,
objective='reg:squarederror',
random_state=0, verbosity=0)
bst = xgb.XGBRegressor(booster='gblinear', learning_rate=0.1,
n_estimators=10,
objective='reg:squarederror',
random_state=0, verbosity=0)
rfecv = RFECV(
estimator=bst, step=1, cv=3, scoring='neg_mean_squared_error')
rfecv.fit(X, y)
@@ -791,7 +789,7 @@ def test_RFECV():
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
n_estimators=10,
objective='binary:logistic',
random_state=0, verbosity=0)
random_state=0, verbosity=0, use_label_encoder=False)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='roc_auc')
rfecv.fit(X, y)

@@ -802,7 +800,7 @@ def test_RFECV():
n_estimators=10,
objective='multi:softprob',
random_state=0, reg_alpha=0.001, reg_lambda=0.01,
scale_pos_weight=0.5, verbosity=0)
scale_pos_weight=0.5, verbosity=0, use_label_encoder=False)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss')
rfecv.fit(X, y)

@@ -811,7 +809,7 @@ def test_RFECV():
rfecv = RFECV(estimator=reg)
rfecv.fit(X, y)

cls = xgb.XGBClassifier()
cls = xgb.XGBClassifier(use_label_encoder=False)
rfecv = RFECV(estimator=cls, step=1, cv=3,
scoring='neg_mean_squared_error')
rfecv.fit(X, y)
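
The load_model() branches added at the top of this diff are what the updated save_load_model test exercises: classes_ and use_label_encoder are restored from the saved file rather than from a pickled LabelEncoder. A condensed round-trip sketch (the file name is illustrative, and it assumes save_model() serializes these scikit-learn attributes, as the load_model() branch implies):

import os
import tempfile

import numpy as np
import xgboost as xgb

X = np.random.randn(50, 4)
y = (np.random.randn(50) > 0).astype(np.int8)

clf = xgb.XGBClassifier(use_label_encoder=False).fit(X, y)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'clf.json')
    clf.save_model(path)
    restored = xgb.XGBClassifier()
    restored.load_model(path)

assert restored.use_label_encoder is False            # round-tripped as bool
assert np.array_equal(restored.classes_, clf.classes_)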