Commit 7249904

[AutoScheduler][AutoTVM] Enable xgboost >= 1.7.x new changes (#14036)
1 parent 87bb8b1 commit 7249904

4 files changed: 238 additions & 171 deletions

docs/install/from_source.rst

Lines changed: 2 additions & 2 deletions

@@ -347,7 +347,7 @@ like ``virtualenv``.
 
    .. code:: bash
 
-      pip3 install --user tornado psutil 'xgboost<1.6.0' cloudpickle
+      pip3 install --user tornado psutil 'xgboost>=1.1.0' cloudpickle
 
 Note on M1 macs, you may have trouble installing xgboost / scipy. scipy and xgboost requires some additional dependencies to be installed,
 including openblas and its dependencies. Use the following commands to install scipy and xgboost with the required dependencies and
@@ -363,7 +363,7 @@ configuration. A workaround for this is to do the following commands:
 
     pip install scipy --no-use-pep517
 
-    pip install 'xgboost<1.6.0'
+    pip install 'xgboost>=1.1.0'
 
 Install Contrib Libraries
 -------------------------
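
A quick post-install sanity check that the resolved xgboost meets the relaxed floor (a minimal sketch; the printed version is only an example of what a recent install might report):

import xgboost

# The commit lifts the '<1.6.0' ceiling, so any xgboost >= 1.1.0 should work.
major, minor = (int(x) for x in xgboost.__version__.split(".")[:2])
assert (major, minor) >= (1, 1), f"xgboost {xgboost.__version__} is below the supported floor"
print(xgboost.__version__)  # e.g. 1.7.3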

python/gen_requirements.py

Lines changed: 1 addition & 1 deletion

@@ -276,7 +276,7 @@
     ("torch", None),
     ("torchvision", None),
     ("tornado", None),
-    ("xgboost", ">=1.1.0,<1.6.0"),  # From PR #4953 & Issue #12009
+    ("xgboost", ">=1.1.0"),  # From PR #4953 & Issue #12009
 ]
 
 ################################################################################
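
The constraint string here is a standard PEP 440 version specifier. A sketch of what the relaxation changes, checked with the third-party packaging library (an assumption for illustration; gen_requirements.py has its own constraint handling):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=1.1.0")  # the relaxed xgboost constraint
print(Version("1.7.3") in spec)  # True: 1.7.x releases are no longer excluded
print(Version("1.5.2") in spec)  # True: previously supported versions still pass
print(Version("1.0.0") in spec)  # False: below the floor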

python/tvm/auto_scheduler/cost_model/xgb_model.py

Lines changed: 115 additions & 87 deletions

@@ -19,6 +19,7 @@
 """Cost model based on xgboost"""
 import multiprocessing
 import logging
+from typing import Dict
 from collections import defaultdict
 
 import numpy as np
@@ -28,6 +29,14 @@
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass
+
+
 xgb = None
 
 logger = logging.getLogger("auto_scheduler")
@@ -198,7 +207,7 @@ def update(self, inputs, results):
             num_boost_round=10000,
             obj=pack_sum_square_error,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=50,
                     metric="tr-p-rmse",
                     fevals=[
@@ -539,125 +548,144 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
-            else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
 
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
+
+
+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
+
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
 
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        """Run after each iteration. Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
+            else:
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
        for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True
 
-    return callback
+        return False
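
The rewrite above replaces the pre-1.3 function-style callback protocol (an env namedtuple plus EarlyStopException, both dropped in recent xgboost releases) with the TrainingCallback API, where after_iteration returning True stops training. A self-contained sketch of that protocol on synthetic data (the class name and data are hypothetical, not part of this commit):

import numpy as np
import xgboost as xgb

class StopAfter(xgb.callback.TrainingCallback):
    """Stop training after a fixed number of boosting rounds."""

    def __init__(self, max_rounds):
        self.max_rounds = max_rounds

    def after_iteration(self, model, epoch, evals_log):
        # Returning True requests an early stop; this replaces the old
        # `raise EarlyStopException(...)` mechanism removed by this commit.
        return epoch + 1 >= self.max_rounds

X = np.random.rand(64, 4)
y = X.sum(axis=1)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(
    {"objective": "reg:squarederror"},
    dtrain,
    num_boost_round=100,
    callbacks=[StopAfter(5)],
)
print(bst.num_boosted_rounds())  # 5 (Booster.num_boosted_rounds needs xgboost >= 1.4)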
