|
1 | 1 | """Define the PySRRegressor scikit-learn interface."""
|
2 | 2 |
|
| 3 | +from __future__ import annotations |
| 4 | + |
3 | 5 | import copy
|
4 | 6 | import logging
|
5 | 7 | import os
|
|
13 | 15 | from io import StringIO
|
14 | 16 | from multiprocessing import cpu_count
|
15 | 17 | from pathlib import Path
|
16 |
| -from typing import Any, Literal, cast |
| 18 | +from typing import Any, List, Literal, Tuple, Union, cast |
17 | 19 |
|
18 | 20 | import numpy as np
|
19 | 21 | import pandas as pd
|
@@ -94,7 +96,7 @@ def _process_constraints(
|
94 | 96 | )
|
95 | 97 | constraints[op] = (-1, -1)
|
96 | 98 |
|
97 |
| - constraint_tuple = cast(tuple[int, int], constraints[op]) |
| 99 | + constraint_tuple = cast(Tuple[int, int], constraints[op]) |
98 | 100 | if op in ["plus", "sub", "+", "-"]:
|
99 | 101 | if constraint_tuple[0] != constraint_tuple[1]:
|
100 | 102 | raise NotImplementedError(
|
@@ -1313,7 +1315,7 @@ def julia_options_(self):
|
1313 | 1315 | def julia_state_(self):
|
1314 | 1316 | """The deserialized state."""
|
1315 | 1317 | return cast(
|
1316 |
| - tuple[VectorValue, AnyValue] | None, |
| 1318 | + Union[Tuple[VectorValue, AnyValue], None], |
1317 | 1319 | jl_deserialize(self.julia_state_stream_),
|
1318 | 1320 | )
|
1319 | 1321 |
|
@@ -1640,7 +1642,7 @@ def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
|
1640 | 1642 | raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
|
1641 | 1643 | else:
|
1642 | 1644 | raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore
|
1643 |
| - return cast(tuple[ndarray, ndarray], raw_out) |
| 1645 | + return cast(Tuple[ndarray, ndarray], raw_out) |
1644 | 1646 |
|
1645 | 1647 | def _validate_data_X(self, X: Any) -> ndarray:
|
1646 | 1648 | if OLD_SKLEARN:
|
@@ -2629,7 +2631,7 @@ def get_hof(self, search_output=None) -> pd.DataFrame | list[pd.DataFrame]:
|
2629 | 2631 |
|
2630 | 2632 | _validate_export_mappings(self.extra_jax_mappings, self.extra_torch_mappings)
|
2631 | 2633 |
|
2632 |
| - equation_file_contents = cast(list[pd.DataFrame], self.equation_file_contents_) |
| 2634 | + equation_file_contents = cast(List[pd.DataFrame], self.equation_file_contents_) |
2633 | 2635 |
|
2634 | 2636 | ret_outputs = [
|
2635 | 2637 | pd.concat(
|
|
0 commit comments