Skip to content

Commit aeb38c4

Browse files
authored
Merge pull request #910 from MilesCranmer/fix-inv-pickling
fix: pickling of inv
2 parents 13cc76c + 1a8759c commit aeb38c4

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "pysr"
7-
version = "1.5.5"
7+
version = "1.5.6"
88
authors = [
99
{name = "Miles Cranmer", email = "[email protected]"},
1010
]

pysr/export_sympy.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
sympy_mappings = {
1414
"div": lambda x, y: x / y,
15+
"inv": lambda x: 1 / x,
1516
"mult": lambda x, y: x * y,
1617
"sqrt": lambda x: sympy.sqrt(x),
1718
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),

pysr/test/test_main.py

+23
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
load_all_packages,
3232
)
3333
from pysr.export_latex import sympy2latex
34+
from pysr.export_sympy import pysr2sympy
3435
from pysr.feature_selection import _handle_feature_selection, run_feature_selection
3536
from pysr.julia_helpers import init_julia
3637
from pysr.sr import (
@@ -903,6 +904,28 @@ def test_feature_selection_handler(self):
903904
class TestMiscellaneous(unittest.TestCase):
904905
"""Test miscellaneous functions."""
905906

907+
def test_pickle_inv_sympy_expression(self):
908+
"""Test that sympy expressions with the inv operator can be pickled and unpickled correctly."""
909+
expr_str = "inv(x0) + x1"
910+
sympy_expr = pysr2sympy(expr_str, feature_names_in=["x0", "x1"])
911+
912+
# Evaluate the original expression at a test point
913+
test_vals = {sympy.Symbol("x0"): 2.0, sympy.Symbol("x1"): 3.0}
914+
original_result = float(sympy_expr.subs(test_vals))
915+
916+
# Pickle and unpickle the sympy expression
917+
serialized = pkl.dumps(sympy_expr)
918+
deserialized_expr = pkl.loads(serialized)
919+
920+
# Evaluate the unpickled expression at the same test point
921+
unpickled_result = float(deserialized_expr.subs(test_vals))
922+
923+
# Verify the results match
924+
self.assertEqual(original_result, unpickled_result)
925+
926+
# Check that the same operator mapping was used (1/x for inv)
927+
self.assertEqual(original_result, 0.5 + 3.0) # 1/2 + 3 = 3.5
928+
906929
def test_pickle_with_temp_equation_file(self):
907930
"""If we have a temporary equation file, unpickle the estimator."""
908931
model = PySRRegressor(

0 commit comments

Comments
 (0)