Skip to content

Commit 280c4dc

Browse files
authored
Implement SparseObservable.apply_layout (#13372)
This is one more usability method to bring `SparseObservable` closer inline with `SparsePauliOp`. The same functionality is relatively easily implementable by the user by iterating through the terms, mapping the indices, and putting the output back into `SparseObservable.from_sparse_list`, but given how heavily we promote the method for `SparsePauliOp`, it probably forms part of the core of user expectations.
1 parent b06c3cf commit 280c4dc

File tree

3 files changed

+302
-4
lines changed

3 files changed

+302
-4
lines changed

crates/accelerate/src/lib.rs

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13+
// This stylistic lint suppression should be in `Cargo.toml`, but we can't do that until we're at an
14+
// MSRV of 1.74 or greater.
15+
#![allow(clippy::comparison_chain)]
16+
1317
use std::env;
1418

1519
use pyo3::import_exception;

crates/accelerate/src/sparse_observable.rs

+120-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
use std::collections::btree_map;
1414

15+
use hashbrown::HashSet;
1516
use num_complex::Complex64;
1617
use num_traits::Zero;
1718
use thiserror::Error;
@@ -263,8 +264,11 @@ impl ::std::convert::TryFrom<u8> for BitTerm {
263264
}
264265
}
265266

266-
/// Error cases stemming from data coherence at the point of entry into `SparseObservable` from raw
267-
/// arrays.
267+
/// Error cases stemming from data coherence at the point of entry into `SparseObservable` from
268+
/// user-provided arrays.
269+
///
270+
/// These most typically appear during [from_raw_parts], but can also be introduced by various
271+
/// remapping arithmetic functions.
268272
///
269273
/// These are generally associated with the Python-space `ValueError` because all of the
270274
/// `TypeError`-related ones are statically forbidden (within Rust) by the language, and conversion
@@ -285,6 +289,10 @@ pub enum CoherenceError {
285289
DecreasingBoundaries,
286290
#[error("the values in `indices` are not term-wise increasing")]
287291
UnsortedIndices,
292+
#[error("the input contains duplicate qubits")]
293+
DuplicateIndices,
294+
#[error("the provided qubit mapping does not account for all contained qubits")]
295+
IndexMapTooSmall,
288296
}
289297
impl From<CoherenceError> for PyErr {
290298
fn from(value: CoherenceError) -> PyErr {
@@ -753,7 +761,9 @@ impl SparseObservable {
753761
let indices = &indices[left..right];
754762
if !indices.is_empty() {
755763
for (index_left, index_right) in indices[..].iter().zip(&indices[1..]) {
756-
if index_left >= index_right {
764+
if index_left == index_right {
765+
return Err(CoherenceError::DuplicateIndices);
766+
} else if index_left > index_right {
757767
return Err(CoherenceError::UnsortedIndices);
758768
}
759769
}
@@ -931,6 +941,42 @@ impl SparseObservable {
931941
Ok(())
932942
}
933943

944+
/// Relabel the `indices` in the operator to new values.
945+
///
946+
/// This fails if any of the new indices are too large, or if any mapping would cause a term to
947+
/// contain duplicates of the same index. It may not detect if multiple qubits are mapped to
948+
/// the same index, if those qubits never appear together in the same term. Such a mapping
949+
/// would not cause data-coherence problems (the output observable will be valid), but is
950+
/// unlikely to be what you intended.
951+
///
952+
/// *Panics* if `new_qubits` is not long enough to map every index used in the operator.
953+
pub fn relabel_qubits_from_slice(&mut self, new_qubits: &[u32]) -> Result<(), CoherenceError> {
954+
for qubit in new_qubits {
955+
if *qubit >= self.num_qubits {
956+
return Err(CoherenceError::BitIndexTooHigh);
957+
}
958+
}
959+
let mut order = btree_map::BTreeMap::new();
960+
for i in 0..self.num_terms() {
961+
let start = self.boundaries[i];
962+
let end = self.boundaries[i + 1];
963+
for j in start..end {
964+
order.insert(new_qubits[self.indices[j] as usize], self.bit_terms[j]);
965+
}
966+
if order.len() != end - start {
967+
return Err(CoherenceError::DuplicateIndices);
968+
}
969+
for (index, dest) in order.keys().zip(&mut self.indices[start..end]) {
970+
*dest = *index;
971+
}
972+
for (bit_term, dest) in order.values().zip(&mut self.bit_terms[start..end]) {
973+
*dest = *bit_term;
974+
}
975+
order.clear();
976+
}
977+
Ok(())
978+
}
979+
934980
/// Return a suitable Python error if two observables do not have equal numbers of qubits.
935981
fn check_equal_qubits(&self, other: &SparseObservable) -> PyResult<()> {
936982
if self.num_qubits != other.num_qubits {
@@ -2020,6 +2066,77 @@ impl SparseObservable {
20202066
}
20212067
out
20222068
}
2069+
2070+
/// Apply a transpiler layout to this :class:`SparseObservable`.
2071+
///
2072+
/// Typically you will have defined your observable in terms of the virtual qubits of the
2073+
/// circuits you will use to prepare states. After transpilation, the virtual qubits are mapped
2074+
/// to particular physical qubits on a device, which may be wider than your circuit. That
2075+
/// mapping can also change over the course of the circuit. This method transforms the input
2076+
/// observable on virtual qubits to an observable that is suitable to apply immediately after
2077+
/// the fully transpiled *physical* circuit.
2078+
///
2079+
/// Args:
2080+
/// layout (TranspileLayout | list[int] | None): The layout to apply. Most uses of this
2081+
/// function should pass the :attr:`.QuantumCircuit.layout` field from a circuit that
2082+
/// was transpiled for hardware. In addition, you can pass a list of new qubit indices.
2083+
/// If given as explicitly ``None``, no remapping is applied (but you can still use
2084+
/// ``num_qubits`` to expand the observable).
2085+
/// num_qubits (int | None): The number of qubits to expand the observable to. If not
2086+
/// supplied, the output will be as wide as the given :class:`.TranspileLayout`, or the
2087+
/// same width as the input if the ``layout`` is given in another form.
2088+
///
2089+
/// Returns:
2090+
/// A new :class:`SparseObservable` with the provided layout applied.
2091+
#[pyo3(signature = (/, layout, num_qubits=None), name = "apply_layout")]
2092+
fn py_apply_layout(&self, layout: Bound<PyAny>, num_qubits: Option<u32>) -> PyResult<Self> {
2093+
let py = layout.py();
2094+
let check_inferred_qubits = |inferred: u32| -> PyResult<u32> {
2095+
if inferred < self.num_qubits {
2096+
return Err(PyValueError::new_err(format!(
2097+
"cannot shrink the qubit count in an observable from {} to {}",
2098+
self.num_qubits, inferred
2099+
)));
2100+
}
2101+
Ok(inferred)
2102+
};
2103+
if layout.is_none() {
2104+
let mut out = self.clone();
2105+
out.num_qubits = check_inferred_qubits(num_qubits.unwrap_or(self.num_qubits))?;
2106+
return Ok(out);
2107+
}
2108+
let (num_qubits, layout) = if layout.is_instance(
2109+
&py.import_bound(intern!(py, "qiskit.transpiler"))?
2110+
.getattr(intern!(py, "TranspileLayout"))?,
2111+
)? {
2112+
(
2113+
check_inferred_qubits(
2114+
layout.getattr(intern!(py, "_output_qubit_list"))?.len()? as u32
2115+
)?,
2116+
layout
2117+
.call_method0(intern!(py, "final_index_layout"))?
2118+
.extract::<Vec<u32>>()?,
2119+
)
2120+
} else {
2121+
(
2122+
check_inferred_qubits(num_qubits.unwrap_or(self.num_qubits))?,
2123+
layout.extract()?,
2124+
)
2125+
};
2126+
if layout.len() < self.num_qubits as usize {
2127+
return Err(CoherenceError::IndexMapTooSmall.into());
2128+
}
2129+
if layout.iter().any(|qubit| *qubit >= num_qubits) {
2130+
return Err(CoherenceError::BitIndexTooHigh.into());
2131+
}
2132+
if layout.iter().collect::<HashSet<_>>().len() != layout.len() {
2133+
return Err(CoherenceError::DuplicateIndices.into());
2134+
}
2135+
let mut out = self.clone();
2136+
out.num_qubits = num_qubits;
2137+
out.relabel_qubits_from_slice(&layout)?;
2138+
Ok(out)
2139+
}
20232140
}
20242141

20252142
impl ::std::ops::Add<&SparseObservable> for SparseObservable {

test/python/quantum_info/test_sparse_observable.py

+178-1
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
1414

1515
import copy
16+
import itertools
1617
import pickle
18+
import random
1719
import unittest
1820

1921
import ddt
2022
import numpy as np
2123

22-
from qiskit.circuit import Parameter
24+
from qiskit import transpile
25+
from qiskit.circuit import Measure, Parameter, library, QuantumCircuit
2326
from qiskit.exceptions import QiskitError
2427
from qiskit.quantum_info import SparseObservable, SparsePauliOp, Pauli
28+
from qiskit.transpiler import Target
2529

2630
from test import QiskitTestCase, combine # pylint: disable=wrong-import-order
2731

@@ -39,6 +43,24 @@ def single_cases():
3943
]
4044

4145

46+
def lnn_target(num_qubits):
47+
"""Create a simple `Target` object with an arbitrary basis-gate set, and open-path
48+
connectivity."""
49+
out = Target()
50+
out.add_instruction(library.RZGate(Parameter("a")), {(q,): None for q in range(num_qubits)})
51+
out.add_instruction(library.SXGate(), {(q,): None for q in range(num_qubits)})
52+
out.add_instruction(Measure(), {(q,): None for q in range(num_qubits)})
53+
out.add_instruction(
54+
library.CXGate(),
55+
{
56+
pair: None
57+
for lower in range(num_qubits - 1)
58+
for pair in [(lower, lower + 1), (lower + 1, lower)]
59+
},
60+
)
61+
return out
62+
63+
4264
class AllowRightArithmetic:
4365
"""Some type that implements only the right-hand-sided arithmatic operations, and allows
4466
`SparseObservable` to pass through them.
@@ -1533,3 +1555,158 @@ def test_clear(self, obs):
15331555
num_qubits = obs.num_qubits
15341556
obs.clear()
15351557
self.assertEqual(obs, SparseObservable.zero(num_qubits))
1558+
1559+
def test_apply_layout_list(self):
1560+
self.assertEqual(
1561+
SparseObservable.zero(5).apply_layout([4, 3, 2, 1, 0]), SparseObservable.zero(5)
1562+
)
1563+
self.assertEqual(
1564+
SparseObservable.zero(3).apply_layout([0, 2, 1], 8), SparseObservable.zero(8)
1565+
)
1566+
self.assertEqual(
1567+
SparseObservable.identity(2).apply_layout([1, 0]), SparseObservable.identity(2)
1568+
)
1569+
self.assertEqual(
1570+
SparseObservable.identity(3).apply_layout([100, 10_000, 3], 100_000_000),
1571+
SparseObservable.identity(100_000_000),
1572+
)
1573+
1574+
terms = [
1575+
("ZYX", (4, 2, 1), 1j),
1576+
("", (), -0.5),
1577+
("+-rl01", (10, 8, 6, 4, 2, 0), 2.0),
1578+
]
1579+
1580+
def map_indices(terms, layout):
1581+
return [
1582+
(terms, tuple(layout[bit] for bit in bits), coeff) for terms, bits, coeff in terms
1583+
]
1584+
1585+
identity = list(range(12))
1586+
self.assertEqual(
1587+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(identity),
1588+
SparseObservable.from_sparse_list(terms, num_qubits=12),
1589+
)
1590+
# We've already tested elsewhere that `SparseObservable.from_sparse_list` produces termwise
1591+
# sorted indices, so these tests also ensure `apply_layout` is maintaining that invariant.
1592+
backwards = list(range(12))[::-1]
1593+
self.assertEqual(
1594+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(backwards),
1595+
SparseObservable.from_sparse_list(map_indices(terms, backwards), num_qubits=12),
1596+
)
1597+
shuffled = [4, 7, 1, 10, 0, 11, 3, 2, 8, 5, 6, 9]
1598+
self.assertEqual(
1599+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(shuffled),
1600+
SparseObservable.from_sparse_list(map_indices(terms, shuffled), num_qubits=12),
1601+
)
1602+
self.assertEqual(
1603+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(shuffled, 100),
1604+
SparseObservable.from_sparse_list(map_indices(terms, shuffled), num_qubits=100),
1605+
)
1606+
expanded = [78, 69, 82, 68, 32, 97, 108, 101, 114, 116, 33]
1607+
self.assertEqual(
1608+
SparseObservable.from_sparse_list(terms, num_qubits=11).apply_layout(expanded, 120),
1609+
SparseObservable.from_sparse_list(map_indices(terms, expanded), num_qubits=120),
1610+
)
1611+
1612+
def test_apply_layout_transpiled(self):
1613+
base = SparseObservable.from_sparse_list(
1614+
[
1615+
("ZYX", (4, 2, 1), 1j),
1616+
("", (), -0.5),
1617+
("+-r", (3, 2, 0), 2.0),
1618+
],
1619+
num_qubits=5,
1620+
)
1621+
1622+
qc = QuantumCircuit(5)
1623+
initial_list = [3, 4, 0, 2, 1]
1624+
no_routing = transpile(
1625+
qc, target=lnn_target(5), initial_layout=initial_list, seed_transpiler=2024_10_25_0
1626+
).layout
1627+
# It's easiest here to test against the `list` form, which we verify separately and
1628+
# explicitly.
1629+
self.assertEqual(base.apply_layout(no_routing), base.apply_layout(initial_list))
1630+
1631+
expanded = transpile(
1632+
qc, target=lnn_target(100), initial_layout=initial_list, seed_transpiler=2024_10_25_1
1633+
).layout
1634+
self.assertEqual(
1635+
base.apply_layout(expanded), base.apply_layout(initial_list, num_qubits=100)
1636+
)
1637+
1638+
qc = QuantumCircuit(5)
1639+
qargs = list(itertools.permutations(range(5), 2))
1640+
random.Random(2024_10_25_2).shuffle(qargs)
1641+
for pair in qargs:
1642+
qc.cx(*pair)
1643+
1644+
routed = transpile(qc, target=lnn_target(5), seed_transpiler=2024_10_25_3).layout
1645+
self.assertEqual(
1646+
base.apply_layout(routed),
1647+
base.apply_layout(routed.final_index_layout(filter_ancillas=True)),
1648+
)
1649+
1650+
routed_expanded = transpile(qc, target=lnn_target(20), seed_transpiler=2024_10_25_3).layout
1651+
self.assertEqual(
1652+
base.apply_layout(routed_expanded),
1653+
base.apply_layout(
1654+
routed_expanded.final_index_layout(filter_ancillas=True), num_qubits=20
1655+
),
1656+
)
1657+
1658+
def test_apply_layout_none(self):
1659+
self.assertEqual(SparseObservable.zero(0).apply_layout(None), SparseObservable.zero(0))
1660+
self.assertEqual(SparseObservable.zero(0).apply_layout(None, 3), SparseObservable.zero(3))
1661+
self.assertEqual(SparseObservable.zero(5).apply_layout(None), SparseObservable.zero(5))
1662+
self.assertEqual(SparseObservable.zero(3).apply_layout(None, 8), SparseObservable.zero(8))
1663+
self.assertEqual(
1664+
SparseObservable.identity(0).apply_layout(None), SparseObservable.identity(0)
1665+
)
1666+
self.assertEqual(
1667+
SparseObservable.identity(0).apply_layout(None, 8), SparseObservable.identity(8)
1668+
)
1669+
self.assertEqual(
1670+
SparseObservable.identity(2).apply_layout(None), SparseObservable.identity(2)
1671+
)
1672+
self.assertEqual(
1673+
SparseObservable.identity(3).apply_layout(None, 100_000_000),
1674+
SparseObservable.identity(100_000_000),
1675+
)
1676+
1677+
terms = [
1678+
("ZYX", (2, 1, 0), 1j),
1679+
("", (), -0.5),
1680+
("+-rl01", (10, 8, 6, 4, 2, 0), 2.0),
1681+
]
1682+
self.assertEqual(
1683+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(None),
1684+
SparseObservable.from_sparse_list(terms, num_qubits=12),
1685+
)
1686+
self.assertEqual(
1687+
SparseObservable.from_sparse_list(terms, num_qubits=12).apply_layout(
1688+
None, num_qubits=200
1689+
),
1690+
SparseObservable.from_sparse_list(terms, num_qubits=200),
1691+
)
1692+
1693+
def test_apply_layout_failures(self):
1694+
obs = SparseObservable.from_list([("IIYI", 2.0), ("IIIX", -1j)])
1695+
with self.assertRaisesRegex(ValueError, "duplicate"):
1696+
obs.apply_layout([0, 0, 1, 2])
1697+
with self.assertRaisesRegex(ValueError, "does not account for all contained qubits"):
1698+
obs.apply_layout([0, 1])
1699+
with self.assertRaisesRegex(ValueError, "less than the number of qubits"):
1700+
obs.apply_layout([0, 2, 4, 6])
1701+
with self.assertRaisesRegex(ValueError, "cannot shrink"):
1702+
obs.apply_layout([0, 1], num_qubits=2)
1703+
with self.assertRaisesRegex(ValueError, "cannot shrink"):
1704+
obs.apply_layout(None, num_qubits=2)
1705+
1706+
qc = QuantumCircuit(3)
1707+
qc.cx(0, 1)
1708+
qc.cx(1, 2)
1709+
qc.cx(2, 0)
1710+
layout = transpile(qc, target=lnn_target(3), seed_transpiler=2024_10_25).layout
1711+
with self.assertRaisesRegex(ValueError, "cannot shrink"):
1712+
obs.apply_layout(layout, num_qubits=2)

0 commit comments

Comments
 (0)