@@ -19,6 +19,7 @@
 """Cost model based on xgboost"""
 import multiprocessing
 import logging
+from typing import Dict
 from collections import defaultdict
 
 import numpy as np
@@ -28,6 +29,14 @@
 from ..feature import get_per_store_features_from_measure_pairs, get_per_store_features_from_states
 from ..measure_record import RecordReader
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass
+
+
 xgb = None
 
 logger = logging.getLogger("auto_scheduler")
@@ -198,7 +207,7 @@ def update(self, inputs, results):
             num_boost_round=10000,
             obj=pack_sum_square_error,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=50,
                     metric="tr-p-rmse",
                     fevals=[
@@ -539,125 +548,144 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
-            else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
 
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
+
+
+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
+
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
 
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        """Run after each iteration. Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
+            else:
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
         for key in keys:
            v = res_dict[key]
            eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True
 
-    return callback
+        return False
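
Behavioral note on the end of this hunk: the old code stopped training by raising xgboost.core.EarlyStopException, which recent xgboost no longer provides; under the >= 1.3 API a callback stops training by returning True from after_iteration (the env.rank == 0 guard is dropped because the new entry point exposes no rank). A minimal sketch of that contract, with StopAfter as a made-up name:

```python
from xgboost.callback import TrainingCallback  # xgboost >= 1.3


class StopAfter(TrainingCallback):
    """Hypothetical callback: halt training after a fixed round."""

    def __init__(self, last_round):
        self.last_round = last_round

    def after_iteration(self, model, epoch, evals_log):
        # Returning True replaces the old `raise EarlyStopException(...)`;
        # False (or None) lets training continue.
        return epoch >= self.last_round
```
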