13
13
Qiskit Experiments test case class
14
14
"""
15
15
16
- import dataclasses
17
16
import json
18
17
import pickle
19
18
import warnings
20
19
from typing import Any , Callable , Optional
21
20
22
- import numpy as np
23
21
import uncertainties
24
- from lmfit import Model
25
22
from qiskit .test import QiskitTestCase
26
- from qiskit_experiments . data_processing import DataAction , DataProcessor
27
- from qiskit_experiments . framework . experiment_data import ExperimentStatus
23
+ from qiskit . utils . deprecation import deprecate_func
24
+
28
25
from qiskit_experiments .framework import (
29
26
ExperimentDecoder ,
30
27
ExperimentEncoder ,
31
28
ExperimentData ,
32
- BaseExperiment ,
33
- BaseAnalysis ,
34
29
)
35
- from qiskit_experiments .visualization import BaseDrawer
36
- from qiskit_experiments . curve_analysis . curve_data import CurveFitResult
30
+ from qiskit_experiments .framework . experiment_data import ExperimentStatus
31
+ from . extended_equality import is_equivalent
37
32
38
33
39
34
class QiskitExperimentsTestCase (QiskitTestCase ):
@@ -76,15 +71,52 @@ def assertExperimentDone(
76
71
msg = "All threads are executed but status is not DONE. " + experiment_data .errors (),
77
72
)
78
73
79
- def assertRoundTripSerializable (self , obj : Any , check_func : Optional [Callable ] = None ):
74
+ def assertEqualExtended (
75
+ self ,
76
+ first : Any ,
77
+ second : Any ,
78
+ * ,
79
+ msg : Optional [str ] = None ,
80
+ strict_type : bool = False ,
81
+ ):
82
+ """Extended equality assertion which covers Qiskit Experiments classes.
83
+
84
+ .. note::
85
+ Some Qiskit Experiments class may intentionally avoid implementing
86
+ the equality dunder method, or may be used in some unusual situations.
87
+ These are mainly caused by to JSON round trip situation, and some custom classes
88
+ doesn't guarantee object equality after round trip.
89
+ This assertion function forcibly compares input two objects with
90
+ the custom equality checker, which is implemented for unittest purpose.
91
+
92
+ Args:
93
+ first: First object to compare.
94
+ second: Second object to compare.
95
+ msg: Optional. Custom error message issued when first and second object are not equal.
96
+ strict_type: Set True to enforce type check before comparison.
97
+ """
98
+ default_msg = f"{ first } != { second } "
99
+
100
+ self .assertTrue (
101
+ is_equivalent (first , second , strict_type = strict_type ),
102
+ msg = msg or default_msg ,
103
+ )
104
+
105
+ def assertRoundTripSerializable (
106
+ self ,
107
+ obj : Any ,
108
+ * ,
109
+ check_func : Optional [Callable ] = None ,
110
+ strict_type : bool = False ,
111
+ ):
80
112
"""Assert that an object is round trip serializable.
81
113
82
114
Args:
83
115
obj: the object to be serialized.
84
116
check_func: Optional, a custom function ``check_func(a, b) -> bool``
85
- to check equality of the original object with the decoded
86
- object. If None the ``__eq__`` method of the original
87
- object will be used .
117
+ to check equality of the original object with the decoded
118
+ object. If None :meth:`.assertEqualExtended` is called.
119
+ strict_type: Set True to enforce type check before comparison .
88
120
"""
89
121
try :
90
122
encoded = json .dumps (obj , cls = ExperimentEncoder )
@@ -94,20 +126,27 @@ def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] =
94
126
decoded = json .loads (encoded , cls = ExperimentDecoder )
95
127
except TypeError :
96
128
self .fail ("JSON deserialization raised unexpectedly." )
97
- if check_func is None :
98
- self .assertEqual (obj , decoded )
99
- else :
129
+
130
+ if check_func is not None :
100
131
self .assertTrue (check_func (obj , decoded ), msg = f"{ obj } != { decoded } " )
132
+ else :
133
+ self .assertEqualExtended (obj , decoded , strict_type = strict_type )
101
134
102
- def assertRoundTripPickle (self , obj : Any , check_func : Optional [Callable ] = None ):
135
+ def assertRoundTripPickle (
136
+ self ,
137
+ obj : Any ,
138
+ * ,
139
+ check_func : Optional [Callable ] = None ,
140
+ strict_type : bool = False ,
141
+ ):
103
142
"""Assert that an object is round trip serializable using pickle module.
104
143
105
144
Args:
106
145
obj: the object to be serialized.
107
146
check_func: Optional, a custom function ``check_func(a, b) -> bool``
108
- to check equality of the original object with the decoded
109
- object. If None the ``__eq__`` method of the original
110
- object will be used .
147
+ to check equality of the original object with the decoded
148
+ object. If None :meth:`.assertEqualExtended` is called.
149
+ strict_type: Set True to enforce type check before comparison .
111
150
"""
112
151
try :
113
152
encoded = pickle .dumps (obj )
@@ -117,150 +156,63 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None)
117
156
decoded = pickle .loads (encoded )
118
157
except TypeError :
119
158
self .fail ("pickle deserialization raised unexpectedly." )
120
- if check_func is None :
121
- self .assertEqual (obj , decoded )
122
- else :
159
+
160
+ if check_func is not None :
123
161
self .assertTrue (check_func (obj , decoded ), msg = f"{ obj } != { decoded } " )
162
+ else :
163
+ self .assertEqualExtended (obj , decoded , strict_type = strict_type )
124
164
125
165
@classmethod
166
+ @deprecate_func (
167
+ since = "0.6" ,
168
+ additional_msg = "Use test.extended_equality.is_equivalent instead." ,
169
+ pending = True ,
170
+ package_name = "qiskit-experiments" ,
171
+ )
126
172
def json_equiv (cls , data1 , data2 ) -> bool :
127
173
"""Check if two experiments are equivalent by comparing their configs"""
128
- # pylint: disable = too-many-return-statements
129
- configurable_type = (BaseExperiment , BaseAnalysis , BaseDrawer )
130
- compare_repr = (DataAction , DataProcessor )
131
- list_type = (list , tuple , set )
132
- skipped = tuple ()
133
-
134
- if isinstance (data1 , skipped ) and isinstance (data2 , skipped ):
135
- warnings .warn (f"Equivalence check for data { data1 .__class__ .__name__ } is skipped." )
136
- return True
137
- elif isinstance (data1 , configurable_type ) and isinstance (data2 , configurable_type ):
138
- return cls .json_equiv (data1 .config (), data2 .config ())
139
- elif dataclasses .is_dataclass (data1 ) and dataclasses .is_dataclass (data2 ):
140
- # not using asdict. this copies all objects.
141
- return cls .json_equiv (data1 .__dict__ , data2 .__dict__ )
142
- elif isinstance (data1 , dict ) and isinstance (data2 , dict ):
143
- if set (data1 ) != set (data2 ):
144
- return False
145
- return all (cls .json_equiv (data1 [k ], data2 [k ]) for k in data1 .keys ())
146
- elif isinstance (data1 , np .ndarray ) or isinstance (data2 , np .ndarray ):
147
- return np .allclose (data1 , data2 )
148
- elif isinstance (data1 , list_type ) and isinstance (data2 , list_type ):
149
- return all (cls .json_equiv (e1 , e2 ) for e1 , e2 in zip (data1 , data2 ))
150
- elif isinstance (data1 , uncertainties .UFloat ) and isinstance (data2 , uncertainties .UFloat ):
151
- return cls .ufloat_equiv (data1 , data2 )
152
- elif isinstance (data1 , Model ) and isinstance (data2 , Model ):
153
- return cls .json_equiv (data1 .dumps (), data2 .dumps ())
154
- elif isinstance (data1 , CurveFitResult ) and isinstance (data2 , CurveFitResult ):
155
- return cls .curve_fit_data_equiv (data1 , data2 )
156
- elif isinstance (data1 , compare_repr ) and isinstance (data2 , compare_repr ):
157
- # otherwise compare instance representation
158
- return repr (data1 ) == repr (data2 )
159
-
160
- return data1 == data2
174
+ return is_equivalent (data1 , data2 )
161
175
162
176
@staticmethod
177
+ @deprecate_func (
178
+ since = "0.6" ,
179
+ additional_msg = "Use test.extended_equality.is_equivalent instead." ,
180
+ pending = True ,
181
+ package_name = "qiskit-experiments" ,
182
+ )
163
183
def ufloat_equiv (data1 : uncertainties .UFloat , data2 : uncertainties .UFloat ) -> bool :
164
184
"""Check if two values with uncertainties are equal. No correlation is considered."""
165
- return data1 . n == data2 . n and data1 . s == data2 . s
185
+ return is_equivalent ( data1 , data2 )
166
186
167
187
@classmethod
188
+ @deprecate_func (
189
+ since = "0.6" ,
190
+ additional_msg = "Use test.extended_equality.is_equivalent instead." ,
191
+ pending = True ,
192
+ package_name = "qiskit-experiments" ,
193
+ )
168
194
def analysis_result_equiv (cls , result1 , result2 ):
169
195
"""Test two analysis results are equivalent"""
170
- # Check basic attributes skipping service which is not serializable
171
- for att in [
172
- "name" ,
173
- "value" ,
174
- "extra" ,
175
- "device_components" ,
176
- "result_id" ,
177
- "experiment_id" ,
178
- "chisq" ,
179
- "quality" ,
180
- "verified" ,
181
- "tags" ,
182
- "auto_save" ,
183
- "source" ,
184
- ]:
185
- if not cls .json_equiv (getattr (result1 , att ), getattr (result2 , att )):
186
- return False
187
- return True
196
+ return is_equivalent (result1 , result2 )
188
197
189
198
@classmethod
199
+ @deprecate_func (
200
+ since = "0.6" ,
201
+ additional_msg = "Use test.extended_equality.is_equivalent instead." ,
202
+ pending = True ,
203
+ package_name = "qiskit-experiments" ,
204
+ )
190
205
def curve_fit_data_equiv (cls , data1 , data2 ):
191
206
"""Test two curve fit result are equivalent."""
192
- for att in [
193
- "method" ,
194
- "model_repr" ,
195
- "success" ,
196
- "nfev" ,
197
- "message" ,
198
- "dof" ,
199
- "init_params" ,
200
- "chisq" ,
201
- "reduced_chisq" ,
202
- "aic" ,
203
- "bic" ,
204
- "params" ,
205
- "var_names" ,
206
- "x_data" ,
207
- "y_data" ,
208
- "covar" ,
209
- ]:
210
- if not cls .json_equiv (getattr (data1 , att ), getattr (data2 , att )):
211
- return False
212
- return True
207
+ return is_equivalent (data1 , data2 )
213
208
214
209
@classmethod
210
+ @deprecate_func (
211
+ since = "0.6" ,
212
+ additional_msg = "Use test.extended_equality.is_equivalent instead." ,
213
+ pending = True ,
214
+ package_name = "qiskit-experiments" ,
215
+ )
215
216
def experiment_data_equiv (cls , data1 , data2 ):
216
217
"""Check two experiment data containers are equivalent"""
217
-
218
- # Check basic attributes
219
- # Skip non-compatible backend
220
- for att in [
221
- "experiment_id" ,
222
- "experiment_type" ,
223
- "parent_id" ,
224
- "tags" ,
225
- "job_ids" ,
226
- "figure_names" ,
227
- "share_level" ,
228
- "metadata" ,
229
- ]:
230
- if not cls .json_equiv (getattr (data1 , att ), getattr (data2 , att )):
231
- return False
232
-
233
- # Check length of data, results, child_data
234
- # check for child data attribute so this method still works for
235
- # DbExperimentData
236
- if hasattr (data1 , "child_data" ):
237
- child_data1 = data1 .child_data ()
238
- else :
239
- child_data1 = []
240
- if hasattr (data2 , "child_data" ):
241
- child_data2 = data2 .child_data ()
242
- else :
243
- child_data2 = []
244
-
245
- if (
246
- len (data1 .data ()) != len (data2 .data ())
247
- or len (data1 .analysis_results ()) != len (data2 .analysis_results ())
248
- or len (child_data1 ) != len (child_data2 )
249
- ):
250
- return False
251
-
252
- # Check data
253
- if not cls .json_equiv (data1 .data (), data2 .data ()):
254
- return False
255
-
256
- # Check analysis results
257
- for result1 , result2 in zip (data1 .analysis_results (), data2 .analysis_results ()):
258
- if not cls .analysis_result_equiv (result1 , result2 ):
259
- return False
260
-
261
- # Check child data
262
- for child1 , child2 in zip (child_data1 , child_data2 ):
263
- if not cls .experiment_data_equiv (child1 , child2 ):
264
- return False
265
-
266
- return True
218
+ return is_equivalent (data1 , data2 )
0 commit comments