21
21
from .metadata import OBPMetadata
22
22
23
23
24
- def plot_accumulated_error (metadata : OBPMetadata , axes : Axes ) -> None :
24
+ def plot_accumulated_error (metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True ) -> None :
25
25
"""Plot the accumulated error.
26
26
27
27
This method populates the provided figure axes with a line-plot of the
@@ -36,7 +36,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
36
36
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
37
37
38
38
.. plot::
39
- :context:
39
+ :context: close-figs
40
40
:include-source:
41
41
42
42
>>> from matplotlib import pyplot as plt
@@ -57,7 +57,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
57
57
Args:
58
58
metadata: the metadata to be visualized.
59
59
axes: the matplotlib axes in which to plot.
60
-
60
+ show_legend: enable/disable showing the legend in the plot.
61
61
"""
62
62
if not np .isinf (metadata .truncation_error_budget .max_error_total ):
63
63
axes .axhline (
@@ -79,10 +79,12 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
79
79
)
80
80
axes .set_xlabel ("backpropagated slice number" )
81
81
axes .set_ylabel ("accumulated error" )
82
- axes . legend ( )
82
+ _set_legend ( axes , show_legend )
83
83
84
84
85
- def plot_left_over_error_budget (metadata : OBPMetadata , axes : Axes ) -> None :
85
+ def plot_left_over_error_budget (
86
+ metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True
87
+ ) -> None :
86
88
"""Plot the left-over error budget.
87
89
88
90
This method populates the provided figure axes with a line-plot of the
@@ -97,7 +99,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
97
99
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
98
100
99
101
.. plot::
100
- :context:
102
+ :context: close-figs
101
103
:include-source:
102
104
103
105
>>> from matplotlib import pyplot as plt
@@ -113,7 +115,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
113
115
Args:
114
116
metadata: the metadata to be visualized.
115
117
axes: the matplotlib axes in which to plot.
116
-
118
+ show_legend: enable/disable showing the legend in the plot.
117
119
"""
118
120
for obs_idx in range (len (metadata .backpropagation_history [0 ].slice_errors )):
119
121
axes .plot (
@@ -126,10 +128,10 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
126
128
)
127
129
axes .set_xlabel ("backpropagated slice number" )
128
130
axes .set_ylabel ("left-over error budget" )
129
- axes . legend ( )
131
+ _set_legend ( axes , show_legend )
130
132
131
133
132
- def plot_slice_errors (metadata : OBPMetadata , axes : Axes ) -> None :
134
+ def plot_slice_errors (metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True ) -> None :
133
135
"""Plot the slice errors.
134
136
135
137
This method populates the provided figure axes with a bar-plot of the truncation error incurred
@@ -144,7 +146,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
144
146
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
145
147
146
148
.. plot::
147
- :context:
149
+ :context: close-figs
148
150
:include-source:
149
151
150
152
>>> from matplotlib import pyplot as plt
@@ -163,7 +165,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
163
165
Args:
164
166
metadata: the metadata to be visualized.
165
167
axes: the matplotlib axes in which to plot.
166
-
168
+ show_legend: enable/disable showing the legend in the plot.
167
169
"""
168
170
num_observables = len (metadata .backpropagation_history [0 ].slice_errors )
169
171
width = 0.8 / num_observables
@@ -181,9 +183,10 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
181
183
axes .set_xlabel ("backpropagated slice number" )
182
184
axes .set_ylabel ("incurred slice error" )
183
185
axes .legend ()
186
+ _set_legend (axes , show_legend )
184
187
185
188
186
- def plot_num_paulis (metadata : OBPMetadata , axes : Axes ) -> None :
189
+ def plot_num_paulis (metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True ) -> None :
187
190
"""Plot the number of Pauli terms.
188
191
189
192
This method populates the provided figure axes with a line-plot of the number of Pauli terms at
@@ -198,7 +201,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
198
201
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
199
202
200
203
.. plot::
201
- :context:
204
+ :context: close-figs
202
205
:include-source:
203
206
204
207
>>> from matplotlib import pyplot as plt
@@ -217,7 +220,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
217
220
Args:
218
221
metadata: the metadata to be visualized.
219
222
axes: the matplotlib axes in which to plot.
220
-
223
+ show_legend: enable/disable showing the legend in the plot.
221
224
"""
222
225
for obs_idx in range (len (metadata .backpropagation_history [0 ].slice_errors )):
223
226
axes .plot (
@@ -227,10 +230,12 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
227
230
)
228
231
axes .set_xlabel ("backpropagated slice number" )
229
232
axes .set_ylabel ("# Pauli terms" )
230
- axes . legend ( )
233
+ _set_legend ( axes , show_legend )
231
234
232
235
233
- def plot_num_truncated_paulis (metadata : OBPMetadata , axes : Axes ) -> None :
236
+ def plot_num_truncated_paulis (
237
+ metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True
238
+ ) -> None :
234
239
"""Plot the number of truncated Pauli terms.
235
240
236
241
This method populates the provided figure axes with a bar-plot of the number of the truncated
@@ -245,7 +250,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
245
250
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
246
251
247
252
.. plot::
248
- :context:
253
+ :context: close-figs
249
254
:include-source:
250
255
251
256
>>> from matplotlib import pyplot as plt
@@ -264,7 +269,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
264
269
Args:
265
270
metadata: the metadata to be visualized.
266
271
axes: the matplotlib axes in which to plot.
267
-
272
+ show_legend: enable/disable showing the legend in the plot.
268
273
"""
269
274
num_observables = len (metadata .backpropagation_history [0 ].slice_errors )
270
275
width = 0.8 / num_observables
@@ -281,10 +286,10 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
281
286
offset += width
282
287
axes .set_xlabel ("backpropagated slice number" )
283
288
axes .set_ylabel ("# truncated Pauli terms" )
284
- axes . legend ( )
289
+ _set_legend ( axes , show_legend )
285
290
286
291
287
- def plot_sum_paulis (metadata : OBPMetadata , axes : Axes ) -> None :
292
+ def plot_sum_paulis (metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True ) -> None :
288
293
"""Plot the total number of all Pauli terms.
289
294
290
295
This method populates the provided figure axes with a line-plot of the total number of all Pauli
@@ -299,7 +304,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
299
304
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
300
305
301
306
.. plot::
302
- :context:
307
+ :context: close-figs
303
308
:include-source:
304
309
305
310
>>> from matplotlib import pyplot as plt
@@ -319,7 +324,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
319
324
Args:
320
325
metadata: the metadata to be visualized.
321
326
axes: the matplotlib axes in which to plot.
322
-
327
+ show_legend: enable/disable showing the legend in the plot.
323
328
"""
324
329
if metadata .operator_budget .max_paulis is not None :
325
330
axes .axhline (
@@ -337,10 +342,10 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
337
342
)
338
343
axes .set_xlabel ("backpropagated slice number" )
339
344
axes .set_ylabel ("total # of Pauli terms" )
340
- axes . legend ( )
345
+ _set_legend ( axes , show_legend )
341
346
342
347
343
- def plot_num_qwc_groups (metadata : OBPMetadata , axes : Axes ) -> None :
348
+ def plot_num_qwc_groups (metadata : OBPMetadata , axes : Axes , * , show_legend : bool = True ) -> None :
344
349
"""Plot the number of qubit-wise commuting Pauli groups.
345
350
346
351
This method populates the provided figure axes with a line-plot of the number of qubit-wise
@@ -355,7 +360,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
355
360
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
356
361
357
362
.. plot::
358
- :context:
363
+ :context: close-figs
359
364
:include-source:
360
365
361
366
>>> from matplotlib import pyplot as plt
@@ -371,7 +376,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
371
376
Args:
372
377
metadata: the metadata to be visualized.
373
378
axes: the matplotlib axes in which to plot.
374
-
379
+ show_legend: enable/disable showing the legend in the plot.
375
380
"""
376
381
if metadata .operator_budget .max_qwc_groups is not None :
377
382
axes .axhline (
@@ -389,4 +394,9 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
389
394
)
390
395
axes .set_xlabel ("backpropagated slice number" )
391
396
axes .set_ylabel ("# of qubit-wise commuting Pauli groups" )
392
- axes .legend ()
397
+ _set_legend (axes , show_legend )
398
+
399
+
400
+ def _set_legend (axes : Axes , show_legend : bool ) -> None :
401
+ if show_legend : # pragma: no cover
402
+ axes .legend ()
0 commit comments