 import numpy

 from . import rabit
-from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
+from .core import Booster, XGBoostError
 from .compat import STRING_TYPES


-def _get_callback_context(env):
-    """return whether the current callback context is cv or train"""
-    if env.model is not None and env.cvfolds is None:
-        context = 'train'
-    elif env.model is None and env.cvfolds is not None:
-        context = 'cv'
-    else:
-        raise ValueError("Unexpected input with both model and cvfolds.")
-    return context
-
-
-def _fmt_metric(value, show_stdv=True):
-    """format metric string"""
-    if len(value) == 2:
-        return f"{value[0]}:{value[1]:.5f}"
-    if len(value) == 3:
-        if show_stdv:
-            return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
-        return f"{value[0]}:{value[1]:.5f}"
-    raise ValueError("wrong metric value", value)
-
-
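
For context on the removal above: _fmt_metric turned the (name, mean) tuples produced during training, and the (name, mean, stdv) tuples produced during cross validation, into display strings. A quick sketch of its behavior (the values are made up):

    _fmt_metric(("train-rmse", 0.34567))           # -> "train-rmse:0.34567"
    _fmt_metric(("test-auc", 0.91, 0.02))          # -> "test-auc:0.91000+0.02000"
    _fmt_metric(("test-auc", 0.91, 0.02), False)   # -> "test-auc:0.91000"
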
-def print_evaluation(period=1, show_stdv=True):
-    """Create a callback that prints the evaluation results.
-
-    We print the evaluation results every **period** iterations
-    and on the first and the last iterations.
-
-    Parameters
-    ----------
-    period : int
-        The period to log the evaluation results.
-
-    show_stdv : bool, optional
-        Whether to show the standard deviation if provided.
-
-    Returns
-    -------
-    callback : function
-        A callback that prints the evaluation results every **period** iterations.
-    """
-    def callback(env):
-        """internal function"""
-        if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
-            return
-        i = env.iteration
-        if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
-            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
-            rabit.tracker_print(f"[{i}]\t{msg}\n")
-    return callback
-
-
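
Before this commit, the callback above was attached through the callbacks argument of the training API. A usage sketch against the removed interface (params and dtrain are placeholders):

    import xgboost as xgb

    # Log evaluation results every 10 rounds, plus the first and last round.
    bst = xgb.train(params, dtrain, num_boost_round=100,
                    evals=[(dtrain, 'train')],
                    callbacks=[xgb.callback.print_evaluation(period=10)])
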
-def record_evaluation(eval_result):
-    """Create a callback that records the evaluation history into **eval_result**.
-
-    Parameters
-    ----------
-    eval_result : dict
-        A dictionary used to store the evaluation results.
-
-    Returns
-    -------
-    callback : function
-        The requested callback function.
-    """
-    if not isinstance(eval_result, dict):
-        raise TypeError('eval_result has to be a dictionary')
-    eval_result.clear()
-
-    def init(env):
-        """internal function"""
-        for k, _ in env.evaluation_result_list:
-            pos = k.index('-')
-            key = k[:pos]
-            metric = k[pos + 1:]
-            if key not in eval_result:
-                eval_result[key] = {}
-            if metric not in eval_result[key]:
-                eval_result[key][metric] = []
-
-    def callback(env):
-        """internal function"""
-        if not eval_result:
-            init(env)
-        for k, v in env.evaluation_result_list:
-            pos = k.index('-')
-            key = k[:pos]
-            metric = k[pos + 1:]
-            eval_result[key][metric].append(v)
-    return callback
-
-
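
As the split on '-' above shows, the recorded history is keyed first by data name (the part of 'eval-rmse' before the dash) and then by metric name. A sketch of the removed interface, with illustrative numbers (dtrain and dtest are placeholders):

    history = {}
    bst = xgb.train(params, dtrain, num_boost_round=3,
                    evals=[(dtrain, 'train'), (dtest, 'eval')],
                    callbacks=[xgb.callback.record_evaluation(history)])
    # history would now resemble:
    # {'train': {'rmse': [0.48, 0.35, 0.27]},
    #  'eval':  {'rmse': [0.51, 0.40, 0.33]}}
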
-def reset_learning_rate(learning_rates):
-    """Reset the learning rate after iteration 1.
-
-    NOTE: the initial learning rate still takes effect on the first iteration.
-
-    Parameters
-    ----------
-    learning_rates : list or function
-        List of learning rates, one for each boosting round,
-        or a customized function that calculates eta from the
-        current round number and the total number of boosting rounds
-        (e.g. to implement learning rate decay).
-
-        * list ``l``: ``eta = l[boosting_round]``
-        * function ``f``: ``eta = f(boosting_round, num_boost_round)``
-
-    Returns
-    -------
-    callback : function
-        The requested callback function.
-    """
-    def get_learning_rate(i, n, learning_rates):
-        """helper providing the learning rate"""
-        if isinstance(learning_rates, list):
-            if len(learning_rates) != n:
-                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
-            new_learning_rate = learning_rates[i]
-        else:
-            new_learning_rate = learning_rates(i, n)
-        return new_learning_rate
-
-    def callback(env):
-        """internal function"""
-        context = _get_callback_context(env)
-
-        if context == 'train':
-            bst, i, n = env.model, env.iteration, env.end_iteration
-            bst.set_param(
-                'learning_rate', get_learning_rate(i, n, learning_rates))
-        elif context == 'cv':
-            i, n = env.iteration, env.end_iteration
-            for cvpack in env.cvfolds:
-                bst = cvpack.bst
-                bst.set_param(
-                    'learning_rate', get_learning_rate(i, n, learning_rates))
-
-    callback.before_iteration = False
-    return callback
-
-
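
Both accepted forms of the removed callback, sketched with arbitrary schedules:

    # A fixed schedule: one eta per boosting round.
    xgb.train(params, dtrain, num_boost_round=3,
              callbacks=[xgb.callback.reset_learning_rate([0.3, 0.2, 0.1])])

    # A function of (current_round, total_rounds), e.g. exponential decay.
    xgb.train(params, dtrain, num_boost_round=100,
              callbacks=[xgb.callback.reset_learning_rate(
                  lambda i, n: 0.3 * (0.99 ** i))])
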
-def early_stop(stopping_rounds, maximize=False, verbose=True):
-    """Create a callback that activates early stopping.
-
-    Validation error needs to decrease at least
-    every **stopping_rounds** round(s) to continue training.
-    Requires at least one item in **evals**.
-    If there's more than one, will use the last.
-    Returns the model from the last iteration (not the best one).
-    If early stopping occurs, the model will have two additional fields:
-    ``bst.best_score`` and ``bst.best_iteration``.
-
-    Parameters
-    ----------
-    stopping_rounds : int
-        The number of rounds without improvement after which training stops.
-
-    maximize : bool
-        Whether to maximize the evaluation metric.
-
-    verbose : optional, bool
-        Whether to print messages about early stopping.
-
-    Returns
-    -------
-    callback : function
-        The requested callback function.
-    """
-    state = {}
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        if not env.evaluation_result_list:
-            raise ValueError('For early stopping you need at least one set in evals.')
-        if len(env.evaluation_result_list) > 1 and verbose:
-            msg = ("Multiple eval metrics have been passed: "
-                   "'{0}' will be used for early stopping.\n\n")
-            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
-        maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
-        maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
-        maximize_score = maximize
-        metric_label = env.evaluation_result_list[-1][0]
-        metric = metric_label.split('-', 1)[-1]
-
-        if any(metric.startswith(x) for x in maximize_at_n_metrics):
-            maximize_score = True
-
-        if any(metric.split(":")[0] == x for x in maximize_metrics):
-            maximize_score = True
-
-        if verbose and env.rank == 0:
-            msg = "Will train until {} hasn't improved in {} rounds.\n"
-            rabit.tracker_print(msg.format(metric_label, stopping_rounds))
-
-        state['maximize_score'] = maximize_score
-        state['best_iteration'] = 0
-        if maximize_score:
-            state['best_score'] = float('-inf')
-        else:
-            state['best_score'] = float('inf')
-        # pylint: disable=consider-using-f-string
-        msg = '[%d]\t%s' % (
-            env.iteration,
-            '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
-        )
-        state['best_msg'] = msg
-
-        if bst is not None:
-            if bst.attr('best_score') is not None:
-                state['best_score'] = float(bst.attr('best_score'))
-                state['best_iteration'] = int(bst.attr('best_iteration'))
-                state['best_msg'] = bst.attr('best_msg')
-            else:
-                bst.set_attr(best_iteration=str(state['best_iteration']))
-                bst.set_attr(best_score=str(state['best_score']))
-        else:
-            assert env.cvfolds is not None
-
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
-        score = env.evaluation_result_list[-1][1]
-        best_score = state['best_score']
-        best_iteration = state['best_iteration']
-        maximize_score = state['maximize_score']
-        if (maximize_score and score > best_score) or \
-                (not maximize_score and score < best_score):
-            # pylint: disable=consider-using-f-string
-            msg = '[%d]\t%s' % (
-                env.iteration,
-                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
-            state['best_msg'] = msg
-            state['best_score'] = score
-            state['best_iteration'] = env.iteration
-            # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(best_score=str(state['best_score']),
-                                   best_iteration=str(state['best_iteration']),
-                                   best_msg=state['best_msg'])
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state['best_msg']
-            if verbose and env.rank == 0:
-                msg = "Stopping. Best iteration:\n{}\n\n"
-                rabit.tracker_print(msg.format(best_msg))
-            raise EarlyStopException(best_iteration)
-    return callback
-
-
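
A usage sketch of the removed callback; because the best round is saved through set_attr above, it survives in checkpoints and can be read back from the booster afterwards:

    bst = xgb.train(params, dtrain, num_boost_round=1000,
                    evals=[(dtrain, 'train'), (dtest, 'eval')],
                    callbacks=[xgb.callback.early_stop(stopping_rounds=10)])
    print(bst.attr('best_iteration'), bst.attr('best_score'))
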
 # The new implementation of callback functions.
 # Breaking:
 # - reset learning rate no longer accepts total boosting rounds
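
The replacements referred to by the comment above are TrainingCallback subclasses. A sketch of roughly equivalent usage, assuming the names introduced in 1.3 (EvaluationMonitor, EarlyStopping, LearningRateScheduler); note that the scheduler function now receives only the current round, which is the breaking change called out above:

    callbacks = [
        xgb.callback.EvaluationMonitor(period=10),   # replaces print_evaluation
        xgb.callback.EarlyStopping(rounds=10),       # replaces early_stop
        # replaces reset_learning_rate; no total-rounds argument any more
        xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * (0.99 ** epoch)),
    ]
    bst = xgb.train(params, dtrain, num_boost_round=100,
                    evals=[(dtest, 'eval')], callbacks=callbacks)
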
@@ -741,100 +489,3 @@ def after_iteration(self, model, epoch: int,
             model.save_model(path)
         self._epoch += 1
         return False
-
-
-class LegacyCallbacks:
-    '''Adapter for legacy callback functions.
-
-    .. versionadded:: 1.3.0
-
-    Parameters
-    ----------
-
-    callbacks : Sequence
-        A sequence of legacy callbacks (callbacks that are not instances of
-        TrainingCallback)
-    start_iteration : int
-        Beginning iteration.
-    end_iteration : int
-        End iteration, normally the number of boosting rounds.
-    evals : Sequence
-        Sequence of evaluation dataset tuples.
-    feval : Custom evaluation metric.
-    '''
-    def __init__(self, callbacks, start_iteration, end_iteration,
-                 feval, cvfolds=None):
-        self.callbacks_before_iter = [
-            cb for cb in callbacks
-            if cb.__dict__.get('before_iteration', False)]
-        self.callbacks_after_iter = [
-            cb for cb in callbacks
-            if not cb.__dict__.get('before_iteration', False)]
-
-        self.start_iteration = start_iteration
-        self.end_iteration = end_iteration
-        self.cvfolds = cvfolds
-
-        self.feval = feval
-        assert self.feval is None or callable(self.feval)
-
-        if cvfolds is not None:
-            self.aggregated_cv = None
-
-        super().__init__()
-
-    def before_training(self, model):
-        '''Nothing to do for legacy callbacks'''
-        return model
-
-    def after_training(self, model):
-        '''Nothing to do for legacy callbacks'''
-        return model
-
-    def before_iteration(self, model, epoch, dtrain, evals):
-        '''Called before each iteration.'''
-        for cb in self.callbacks_before_iter:
-            rank = rabit.get_rank()
-            cb(CallbackEnv(model=None if self.cvfolds is not None else model,
-                           cvfolds=self.cvfolds,
-                           iteration=epoch,
-                           begin_iteration=self.start_iteration,
-                           end_iteration=self.end_iteration,
-                           rank=rank,
-                           evaluation_result_list=None))
-        return False
-
-    def after_iteration(self, model, epoch, dtrain, evals):
-        '''Called after each iteration.'''
-        evaluation_result_list = []
-        if self.cvfolds is not None:
-            # dtrain is not used here.
-            scores = model.eval(epoch, self.feval)
-            self.aggregated_cv = _aggcv(scores)
-            evaluation_result_list = self.aggregated_cv
-
-        if evals:
-            # When cv is used, evals are embedded into folds.
-            assert self.cvfolds is None
-            bst_eval_set = model.eval_set(evals, epoch, self.feval)
-            if isinstance(bst_eval_set, STRING_TYPES):
-                msg = bst_eval_set
-            else:
-                msg = bst_eval_set.decode()
-            res = [x.split(':') for x in msg.split()]
-            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
-
-        try:
-            for cb in self.callbacks_after_iter:
-                rank = rabit.get_rank()
-                cb(CallbackEnv(model=None if self.cvfolds is not None else model,
-                               cvfolds=self.cvfolds,
-                               iteration=epoch,
-                               begin_iteration=self.start_iteration,
-                               end_iteration=self.end_iteration,
-                               rank=rank,
-                               evaluation_result_list=evaluation_result_list))
-        except EarlyStopException:
-            return True
-
-        return False
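
With the adapter gone, plain-function callbacks are no longer translated; custom logic is written by subclassing TrainingCallback instead. A minimal sketch of the replacement pattern (the class name and print period are arbitrary):

    class EpochLogger(xgb.callback.TrainingCallback):
        '''Print the latest metric values every tenth round.'''
        def after_iteration(self, model, epoch, evals_log):
            if epoch % 10 == 0:
                latest = {data: {metric: values[-1]
                                 for metric, values in metrics.items()}
                          for data, metrics in evals_log.items()}
                print(epoch, latest)
            return False  # returning True would stop training

    bst = xgb.train(params, dtrain, num_boost_round=100,
                    evals=[(dtest, 'eval')], callbacks=[EpochLogger()])
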