Skip to content

Commit 2278679

Browse files
Upgrade unittest equality method (#1132)
### Summary Current implementation of equality check, i.e. `QiskitExperimentsTestCase.json_equiv`, is not readable and scalable because it implements equality check logic for different types in a single method. This PR adds new test module `test/extended_equality.py` which implements new equality check dispatcher `is_equivalent`. Developers no longer need to specify `check_func` in the `assertRoundTripSerializable` and `assertRoundTripPickle` methods unless they define custom class for a specific unittest. This simplifies unittests and improves readability of equality check logic (and test becomes more trustable). This PR adds new software dependency in develop; [multimethod](https://pypi.org/project/multimethod/) Among several similar packages, this is chosen in favor of - its license type (Apache License, Version 2.0) - syntax compatibility with `functools.singledispatch` - support for subscripted generics in `typings`, e.g. `Union` --------- Co-authored-by: Helena Zhang <[email protected]>
1 parent 4038556 commit 2278679

32 files changed

+564
-257
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
---
2+
developer:
3+
- |
4+
Added the :meth:`QiskitExperimentsTestCase.assertEqualExtended` method for generic equality checks
5+
of Qiskit Experiments class instances in unittests. This is a drop-in replacement of
6+
calling the assertTrue with :meth:`QiskitExperimentsTestCase.json_equiv`.
7+
Note that some Qiskit Experiments classes may not officially implement equality check logic,
8+
although objects may be compared during unittests. Extended equality check is used
9+
for such situations.
10+
- |
11+
The following unittest test case methods will be deprecated:
12+
13+
* :meth:`QiskitExperimentsTestCase.json_equiv`
14+
* :meth:`QiskitExperimentsTestCase.ufloat_equiv`
15+
* :meth:`QiskitExperimentsTestCase.analysis_result_equiv`
16+
* :meth:`QiskitExperimentsTestCase.curve_fit_data_equiv`
17+
* :meth:`QiskitExperimentsTestCase.experiment_data_equiv`
18+
19+
One can now use the :func:`~test.extended_equality.is_equivalent` function instead.
20+
This function internally dispatches the logic for equality check.
21+
- |
22+
The default behavior of :meth:`QiskitExperimentsTestCase.assertRoundTripSerializable` and
23+
:meth:`QiskitExperimentsTestCase.assertRoundTripPickle` when `check_func` is not
24+
provided was upgraded. These methods now compare the decoded instance with
25+
:func:`~test.extended_equality.is_equivalent`, rather than
26+
delegating to the native `assertEqual` unittest method.
27+
One writing a unittest for serialization no longer need to explicitly set checker function.

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ qiskit-aer>=0.11.0
1616
pandas>=1.1.5
1717
cvxpy>=1.1.15
1818
pylatexenc
19+
multimethod
1920
scikit-learn
2021
sphinx-copybutton
2122
# Pin versions below because of build errors

test/base.py

+98-146
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,22 @@
1313
Qiskit Experiments test case class
1414
"""
1515

16-
import dataclasses
1716
import json
1817
import pickle
1918
import warnings
2019
from typing import Any, Callable, Optional
2120

22-
import numpy as np
2321
import uncertainties
24-
from lmfit import Model
2522
from qiskit.test import QiskitTestCase
26-
from qiskit_experiments.data_processing import DataAction, DataProcessor
27-
from qiskit_experiments.framework.experiment_data import ExperimentStatus
23+
from qiskit.utils.deprecation import deprecate_func
24+
2825
from qiskit_experiments.framework import (
2926
ExperimentDecoder,
3027
ExperimentEncoder,
3128
ExperimentData,
32-
BaseExperiment,
33-
BaseAnalysis,
3429
)
35-
from qiskit_experiments.visualization import BaseDrawer
36-
from qiskit_experiments.curve_analysis.curve_data import CurveFitResult
30+
from qiskit_experiments.framework.experiment_data import ExperimentStatus
31+
from .extended_equality import is_equivalent
3732

3833

3934
class QiskitExperimentsTestCase(QiskitTestCase):
@@ -76,15 +71,52 @@ def assertExperimentDone(
7671
msg="All threads are executed but status is not DONE. " + experiment_data.errors(),
7772
)
7873

79-
def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] = None):
74+
def assertEqualExtended(
75+
self,
76+
first: Any,
77+
second: Any,
78+
*,
79+
msg: Optional[str] = None,
80+
strict_type: bool = False,
81+
):
82+
"""Extended equality assertion which covers Qiskit Experiments classes.
83+
84+
.. note::
85+
Some Qiskit Experiments class may intentionally avoid implementing
86+
the equality dunder method, or may be used in some unusual situations.
87+
These are mainly caused by to JSON round trip situation, and some custom classes
88+
doesn't guarantee object equality after round trip.
89+
This assertion function forcibly compares input two objects with
90+
the custom equality checker, which is implemented for unittest purpose.
91+
92+
Args:
93+
first: First object to compare.
94+
second: Second object to compare.
95+
msg: Optional. Custom error message issued when first and second object are not equal.
96+
strict_type: Set True to enforce type check before comparison.
97+
"""
98+
default_msg = f"{first} != {second}"
99+
100+
self.assertTrue(
101+
is_equivalent(first, second, strict_type=strict_type),
102+
msg=msg or default_msg,
103+
)
104+
105+
def assertRoundTripSerializable(
106+
self,
107+
obj: Any,
108+
*,
109+
check_func: Optional[Callable] = None,
110+
strict_type: bool = False,
111+
):
80112
"""Assert that an object is round trip serializable.
81113
82114
Args:
83115
obj: the object to be serialized.
84116
check_func: Optional, a custom function ``check_func(a, b) -> bool``
85-
to check equality of the original object with the decoded
86-
object. If None the ``__eq__`` method of the original
87-
object will be used.
117+
to check equality of the original object with the decoded
118+
object. If None :meth:`.assertEqualExtended` is called.
119+
strict_type: Set True to enforce type check before comparison.
88120
"""
89121
try:
90122
encoded = json.dumps(obj, cls=ExperimentEncoder)
@@ -94,20 +126,27 @@ def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] =
94126
decoded = json.loads(encoded, cls=ExperimentDecoder)
95127
except TypeError:
96128
self.fail("JSON deserialization raised unexpectedly.")
97-
if check_func is None:
98-
self.assertEqual(obj, decoded)
99-
else:
129+
130+
if check_func is not None:
100131
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
132+
else:
133+
self.assertEqualExtended(obj, decoded, strict_type=strict_type)
101134

102-
def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None):
135+
def assertRoundTripPickle(
136+
self,
137+
obj: Any,
138+
*,
139+
check_func: Optional[Callable] = None,
140+
strict_type: bool = False,
141+
):
103142
"""Assert that an object is round trip serializable using pickle module.
104143
105144
Args:
106145
obj: the object to be serialized.
107146
check_func: Optional, a custom function ``check_func(a, b) -> bool``
108-
to check equality of the original object with the decoded
109-
object. If None the ``__eq__`` method of the original
110-
object will be used.
147+
to check equality of the original object with the decoded
148+
object. If None :meth:`.assertEqualExtended` is called.
149+
strict_type: Set True to enforce type check before comparison.
111150
"""
112151
try:
113152
encoded = pickle.dumps(obj)
@@ -117,150 +156,63 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None)
117156
decoded = pickle.loads(encoded)
118157
except TypeError:
119158
self.fail("pickle deserialization raised unexpectedly.")
120-
if check_func is None:
121-
self.assertEqual(obj, decoded)
122-
else:
159+
160+
if check_func is not None:
123161
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
162+
else:
163+
self.assertEqualExtended(obj, decoded, strict_type=strict_type)
124164

125165
@classmethod
166+
@deprecate_func(
167+
since="0.6",
168+
additional_msg="Use test.extended_equality.is_equivalent instead.",
169+
pending=True,
170+
package_name="qiskit-experiments",
171+
)
126172
def json_equiv(cls, data1, data2) -> bool:
127173
"""Check if two experiments are equivalent by comparing their configs"""
128-
# pylint: disable = too-many-return-statements
129-
configurable_type = (BaseExperiment, BaseAnalysis, BaseDrawer)
130-
compare_repr = (DataAction, DataProcessor)
131-
list_type = (list, tuple, set)
132-
skipped = tuple()
133-
134-
if isinstance(data1, skipped) and isinstance(data2, skipped):
135-
warnings.warn(f"Equivalence check for data {data1.__class__.__name__} is skipped.")
136-
return True
137-
elif isinstance(data1, configurable_type) and isinstance(data2, configurable_type):
138-
return cls.json_equiv(data1.config(), data2.config())
139-
elif dataclasses.is_dataclass(data1) and dataclasses.is_dataclass(data2):
140-
# not using asdict. this copies all objects.
141-
return cls.json_equiv(data1.__dict__, data2.__dict__)
142-
elif isinstance(data1, dict) and isinstance(data2, dict):
143-
if set(data1) != set(data2):
144-
return False
145-
return all(cls.json_equiv(data1[k], data2[k]) for k in data1.keys())
146-
elif isinstance(data1, np.ndarray) or isinstance(data2, np.ndarray):
147-
return np.allclose(data1, data2)
148-
elif isinstance(data1, list_type) and isinstance(data2, list_type):
149-
return all(cls.json_equiv(e1, e2) for e1, e2 in zip(data1, data2))
150-
elif isinstance(data1, uncertainties.UFloat) and isinstance(data2, uncertainties.UFloat):
151-
return cls.ufloat_equiv(data1, data2)
152-
elif isinstance(data1, Model) and isinstance(data2, Model):
153-
return cls.json_equiv(data1.dumps(), data2.dumps())
154-
elif isinstance(data1, CurveFitResult) and isinstance(data2, CurveFitResult):
155-
return cls.curve_fit_data_equiv(data1, data2)
156-
elif isinstance(data1, compare_repr) and isinstance(data2, compare_repr):
157-
# otherwise compare instance representation
158-
return repr(data1) == repr(data2)
159-
160-
return data1 == data2
174+
return is_equivalent(data1, data2)
161175

162176
@staticmethod
177+
@deprecate_func(
178+
since="0.6",
179+
additional_msg="Use test.extended_equality.is_equivalent instead.",
180+
pending=True,
181+
package_name="qiskit-experiments",
182+
)
163183
def ufloat_equiv(data1: uncertainties.UFloat, data2: uncertainties.UFloat) -> bool:
164184
"""Check if two values with uncertainties are equal. No correlation is considered."""
165-
return data1.n == data2.n and data1.s == data2.s
185+
return is_equivalent(data1, data2)
166186

167187
@classmethod
188+
@deprecate_func(
189+
since="0.6",
190+
additional_msg="Use test.extended_equality.is_equivalent instead.",
191+
pending=True,
192+
package_name="qiskit-experiments",
193+
)
168194
def analysis_result_equiv(cls, result1, result2):
169195
"""Test two analysis results are equivalent"""
170-
# Check basic attributes skipping service which is not serializable
171-
for att in [
172-
"name",
173-
"value",
174-
"extra",
175-
"device_components",
176-
"result_id",
177-
"experiment_id",
178-
"chisq",
179-
"quality",
180-
"verified",
181-
"tags",
182-
"auto_save",
183-
"source",
184-
]:
185-
if not cls.json_equiv(getattr(result1, att), getattr(result2, att)):
186-
return False
187-
return True
196+
return is_equivalent(result1, result2)
188197

189198
@classmethod
199+
@deprecate_func(
200+
since="0.6",
201+
additional_msg="Use test.extended_equality.is_equivalent instead.",
202+
pending=True,
203+
package_name="qiskit-experiments",
204+
)
190205
def curve_fit_data_equiv(cls, data1, data2):
191206
"""Test two curve fit result are equivalent."""
192-
for att in [
193-
"method",
194-
"model_repr",
195-
"success",
196-
"nfev",
197-
"message",
198-
"dof",
199-
"init_params",
200-
"chisq",
201-
"reduced_chisq",
202-
"aic",
203-
"bic",
204-
"params",
205-
"var_names",
206-
"x_data",
207-
"y_data",
208-
"covar",
209-
]:
210-
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
211-
return False
212-
return True
207+
return is_equivalent(data1, data2)
213208

214209
@classmethod
210+
@deprecate_func(
211+
since="0.6",
212+
additional_msg="Use test.extended_equality.is_equivalent instead.",
213+
pending=True,
214+
package_name="qiskit-experiments",
215+
)
215216
def experiment_data_equiv(cls, data1, data2):
216217
"""Check two experiment data containers are equivalent"""
217-
218-
# Check basic attributes
219-
# Skip non-compatible backend
220-
for att in [
221-
"experiment_id",
222-
"experiment_type",
223-
"parent_id",
224-
"tags",
225-
"job_ids",
226-
"figure_names",
227-
"share_level",
228-
"metadata",
229-
]:
230-
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
231-
return False
232-
233-
# Check length of data, results, child_data
234-
# check for child data attribute so this method still works for
235-
# DbExperimentData
236-
if hasattr(data1, "child_data"):
237-
child_data1 = data1.child_data()
238-
else:
239-
child_data1 = []
240-
if hasattr(data2, "child_data"):
241-
child_data2 = data2.child_data()
242-
else:
243-
child_data2 = []
244-
245-
if (
246-
len(data1.data()) != len(data2.data())
247-
or len(data1.analysis_results()) != len(data2.analysis_results())
248-
or len(child_data1) != len(child_data2)
249-
):
250-
return False
251-
252-
# Check data
253-
if not cls.json_equiv(data1.data(), data2.data()):
254-
return False
255-
256-
# Check analysis results
257-
for result1, result2 in zip(data1.analysis_results(), data2.analysis_results()):
258-
if not cls.analysis_result_equiv(result1, result2):
259-
return False
260-
261-
# Check child data
262-
for child1, child2 in zip(child_data1, child_data2):
263-
if not cls.experiment_data_equiv(child1, child2):
264-
return False
265-
266-
return True
218+
return is_equivalent(data1, data2)

test/calibration/test_calibrations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1754,7 +1754,7 @@ def test_serialization(self):
17541754
cals = Calibrations.from_backend(backend, libraries=[library])
17551755
cals.add_parameter_value(0.12345, "amp", 3, "x")
17561756

1757-
self.assertRoundTripSerializable(cals, self.json_equiv)
1757+
self.assertRoundTripSerializable(cals)
17581758

17591759
def test_equality(self):
17601760
"""Test the equal method on calibrations."""

test/curve_analysis/test_baseclass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class TestCurveAnalysis(CurveAnalysisTestCase):
8585
def test_roundtrip_serialize(self):
8686
"""A testcase for serializing analysis instance."""
8787
analysis = CurveAnalysis(models=[ExpressionModel(expr="par0 * x + par1", name="test")])
88-
self.assertRoundTripSerializable(analysis, check_func=self.json_equiv)
88+
self.assertRoundTripSerializable(analysis)
8989

9090
def test_parameters(self):
9191
"""A testcase for getting fit parameters with attribute."""

test/data_processing/test_data_processing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,14 @@ def test_json_single_node(self):
387387
"""Check if the data processor is serializable."""
388388
node = MinMaxNormalize()
389389
processor = DataProcessor("counts", [node])
390-
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
390+
self.assertRoundTripSerializable(processor)
391391

392392
def test_json_multi_node(self):
393393
"""Check if the data processor with multiple nodes is serializable."""
394394
node1 = MinMaxNormalize()
395395
node2 = AverageData(axis=2)
396396
processor = DataProcessor("counts", [node1, node2])
397-
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
397+
self.assertRoundTripSerializable(processor)
398398

399399
def test_json_trained(self):
400400
"""Check if trained data processor is serializable and still work."""
@@ -405,7 +405,7 @@ def test_json_trained(self):
405405
main_axes=np.array([[1, 0]]), scales=[1.0], i_means=[0.0], q_means=[0.0]
406406
)
407407
processor = DataProcessor("memory", data_actions=[node])
408-
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
408+
self.assertRoundTripSerializable(processor)
409409

410410
serialized = json.dumps(processor, cls=ExperimentEncoder)
411411
loaded_processor = json.loads(serialized, cls=ExperimentDecoder)

0 commit comments

Comments
 (0)