Skip to content

Commit 62bbf57

Browse files
boonwaremrossinek
andauthored
feat: option to disable legend (#20)
* feat: option to disable legend * refactor: ensure param is keyword-only * docs: add release note * docs: fix RST formatting * docs: close previous figures * fix: suppress missing coverage --------- Co-authored-by: Max Rossmannek <[email protected]>
1 parent 9675e1b commit 62bbf57

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

qiskit_addon_obp/utils/visualization.py

+37-27
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .metadata import OBPMetadata
2222

2323

24-
def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
24+
def plot_accumulated_error(metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True) -> None:
2525
"""Plot the accumulated error.
2626
2727
This method populates the provided figure axes with a line-plot of the
@@ -36,7 +36,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
3636
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
3737
3838
.. plot::
39-
:context:
39+
:context: close-figs
4040
:include-source:
4141
4242
>>> from matplotlib import pyplot as plt
@@ -57,7 +57,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
5757
Args:
5858
metadata: the metadata to be visualized.
5959
axes: the matplotlib axes in which to plot.
60-
60+
show_legend: enable/disable showing the legend in the plot.
6161
"""
6262
if not np.isinf(metadata.truncation_error_budget.max_error_total):
6363
axes.axhline(
@@ -79,10 +79,12 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None:
7979
)
8080
axes.set_xlabel("backpropagated slice number")
8181
axes.set_ylabel("accumulated error")
82-
axes.legend()
82+
_set_legend(axes, show_legend)
8383

8484

85-
def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
85+
def plot_left_over_error_budget(
86+
metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True
87+
) -> None:
8688
"""Plot the left-over error budget.
8789
8890
This method populates the provided figure axes with a line-plot of the
@@ -97,7 +99,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
9799
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
98100
99101
.. plot::
100-
:context:
102+
:context: close-figs
101103
:include-source:
102104
103105
>>> from matplotlib import pyplot as plt
@@ -113,7 +115,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
113115
Args:
114116
metadata: the metadata to be visualized.
115117
axes: the matplotlib axes in which to plot.
116-
118+
show_legend: enable/disable showing the legend in the plot.
117119
"""
118120
for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)):
119121
axes.plot(
@@ -126,10 +128,10 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None:
126128
)
127129
axes.set_xlabel("backpropagated slice number")
128130
axes.set_ylabel("left-over error budget")
129-
axes.legend()
131+
_set_legend(axes, show_legend)
130132

131133

132-
def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
134+
def plot_slice_errors(metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True) -> None:
133135
"""Plot the slice errors.
134136
135137
This method populates the provided figure axes with a bar-plot of the truncation error incurred
@@ -144,7 +146,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
144146
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
145147
146148
.. plot::
147-
:context:
149+
:context: close-figs
148150
:include-source:
149151
150152
>>> from matplotlib import pyplot as plt
@@ -163,7 +165,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
163165
Args:
164166
metadata: the metadata to be visualized.
165167
axes: the matplotlib axes in which to plot.
166-
168+
show_legend: enable/disable showing the legend in the plot.
167169
"""
168170
num_observables = len(metadata.backpropagation_history[0].slice_errors)
169171
width = 0.8 / num_observables
@@ -181,9 +183,10 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None:
181183
axes.set_xlabel("backpropagated slice number")
182184
axes.set_ylabel("incurred slice error")
183185
axes.legend()
186+
_set_legend(axes, show_legend)
184187

185188

186-
def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
189+
def plot_num_paulis(metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True) -> None:
187190
"""Plot the number of Pauli terms.
188191
189192
This method populates the provided figure axes with a line-plot of the number of Pauli terms at
@@ -198,7 +201,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
198201
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
199202
200203
.. plot::
201-
:context:
204+
:context: close-figs
202205
:include-source:
203206
204207
>>> from matplotlib import pyplot as plt
@@ -217,7 +220,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
217220
Args:
218221
metadata: the metadata to be visualized.
219222
axes: the matplotlib axes in which to plot.
220-
223+
show_legend: enable/disable showing the legend in the plot.
221224
"""
222225
for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)):
223226
axes.plot(
@@ -227,10 +230,12 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None:
227230
)
228231
axes.set_xlabel("backpropagated slice number")
229232
axes.set_ylabel("# Pauli terms")
230-
axes.legend()
233+
_set_legend(axes, show_legend)
231234

232235

233-
def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
236+
def plot_num_truncated_paulis(
237+
metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True
238+
) -> None:
234239
"""Plot the number of truncated Pauli terms.
235240
236241
This method populates the provided figure axes with a bar-plot of the number of the truncated
@@ -245,7 +250,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
245250
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
246251
247252
.. plot::
248-
:context:
253+
:context: close-figs
249254
:include-source:
250255
251256
>>> from matplotlib import pyplot as plt
@@ -264,7 +269,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
264269
Args:
265270
metadata: the metadata to be visualized.
266271
axes: the matplotlib axes in which to plot.
267-
272+
show_legend: enable/disable showing the legend in the plot.
268273
"""
269274
num_observables = len(metadata.backpropagation_history[0].slice_errors)
270275
width = 0.8 / num_observables
@@ -281,10 +286,10 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None:
281286
offset += width
282287
axes.set_xlabel("backpropagated slice number")
283288
axes.set_ylabel("# truncated Pauli terms")
284-
axes.legend()
289+
_set_legend(axes, show_legend)
285290

286291

287-
def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
292+
def plot_sum_paulis(metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True) -> None:
288293
"""Plot the total number of all Pauli terms.
289294
290295
This method populates the provided figure axes with a line-plot of the total number of all Pauli
@@ -299,7 +304,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
299304
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
300305
301306
.. plot::
302-
:context:
307+
:context: close-figs
303308
:include-source:
304309
305310
>>> from matplotlib import pyplot as plt
@@ -319,7 +324,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
319324
Args:
320325
metadata: the metadata to be visualized.
321326
axes: the matplotlib axes in which to plot.
322-
327+
show_legend: enable/disable showing the legend in the plot.
323328
"""
324329
if metadata.operator_budget.max_paulis is not None:
325330
axes.axhline(
@@ -337,10 +342,10 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None:
337342
)
338343
axes.set_xlabel("backpropagated slice number")
339344
axes.set_ylabel("total # of Pauli terms")
340-
axes.legend()
345+
_set_legend(axes, show_legend)
341346

342347

343-
def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
348+
def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes, *, show_legend: bool = True) -> None:
344349
"""Plot the number of qubit-wise commuting Pauli groups.
345350
346351
This method populates the provided figure axes with a line-plot of the number of qubit-wise
@@ -355,7 +360,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
355360
>>> metadata = OBPMetadata.from_json("docs/_static/dummy_visualization_metadata.json")
356361
357362
.. plot::
358-
:context:
363+
:context: close-figs
359364
:include-source:
360365
361366
>>> from matplotlib import pyplot as plt
@@ -371,7 +376,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
371376
Args:
372377
metadata: the metadata to be visualized.
373378
axes: the matplotlib axes in which to plot.
374-
379+
show_legend: enable/disable showing the legend in the plot.
375380
"""
376381
if metadata.operator_budget.max_qwc_groups is not None:
377382
axes.axhline(
@@ -389,4 +394,9 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None:
389394
)
390395
axes.set_xlabel("backpropagated slice number")
391396
axes.set_ylabel("# of qubit-wise commuting Pauli groups")
392-
axes.legend()
397+
_set_legend(axes, show_legend)
398+
399+
400+
def _set_legend(axes: Axes, show_legend: bool) -> None:
401+
if show_legend: # pragma: no cover
402+
axes.legend()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
A new parameter ``show_legend`` has been added to each function in the
5+
:mod:`qiskit_addon_obp.utils.visualization` module that can show or hide the
6+
legend on a plot. The legend is shown by default. This can be useful when
7+
the legend becomes long and obstructs the plot.

0 commit comments

Comments
 (0)