Skip to content

Commit da53e59

Browse files
committed
Fix string and standard gate mismatch in commutation checker
This commit reworks the internals of the CommutationChecker to not rely on operation name except for where we do a lookup by name in the commutation library provided (which is the only key available to support custom gates). This fixes the case where a custom gate that overloads the standard gate name, previously the code would assume it to be a standard gate and internally panic when it wasn't. When working with standard gates (or standard instructions) we don't need to rely on string matching because we can rely on the rust data model to do the heavy lifting for us. This commit moves all the explicit handling of standard gates to use the StandardGate type directly and makes this logic more robust. This also removes are usage of the once_cell library in qiskit-accelerate because it was used to create a lazy static hashsets of strings which are no longer needed because static lookup tables replace this when we stopped using string comparisons. Fixes Qiskit#13988
1 parent 81269f7 commit da53e59

File tree

4 files changed

+145
-90
lines changed

4 files changed

+145
-90
lines changed

Cargo.lock

-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/accelerate/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ itertools.workspace = true
2828
qiskit-circuit.workspace = true
2929
thiserror.workspace = true
3030
ndarray-einsum = "0.8.0"
31-
once_cell = "1.20.3"
3231
rustiq-core = "0.0.10"
3332
bytemuck.workspace = true
3433
nalgebra.workspace = true

crates/accelerate/src/commutation_checker.rs

+114-88
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use ndarray::linalg::kron;
1515
use ndarray::Array2;
1616
use num_complex::Complex64;
1717
use num_complex::ComplexFloat;
18-
use once_cell::sync::Lazy;
1918
use qiskit_circuit::bit_data::VarAsKey;
2019
use smallvec::SmallVec;
2120
use std::fmt::Debug;
@@ -33,57 +32,69 @@ use qiskit_circuit::dag_node::DAGOpNode;
3332
use qiskit_circuit::imports::QI_OPERATOR;
3433
use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType};
3534
use qiskit_circuit::operations::{
36-
get_standard_gate_names, Operation, OperationRef, Param, StandardGate,
35+
Operation, OperationRef, Param, StandardGate, STANDARD_GATE_SIZE,
3736
};
3837
use qiskit_circuit::{BitType, Clbit, Qubit};
3938

4039
use crate::gate_metrics;
4140
use crate::unitary_compose;
4241
use crate::QiskitError;
4342

44-
// These gates do not commute with other gates, we do not check them.
45-
static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"];
43+
const fn build_supported_ops() -> [bool; STANDARD_GATE_SIZE] {
44+
let mut lut = [false; STANDARD_GATE_SIZE];
45+
lut[StandardGate::RXXGate as usize] = true;
46+
lut[StandardGate::RYYGate as usize] = true;
47+
lut[StandardGate::RZZGate as usize] = true;
48+
lut[StandardGate::RZXGate as usize] = true;
49+
lut[StandardGate::HGate as usize] = true;
50+
lut[StandardGate::XGate as usize] = true;
51+
lut[StandardGate::YGate as usize] = true;
52+
lut[StandardGate::ZGate as usize] = true;
53+
lut[StandardGate::SXGate as usize] = true;
54+
lut[StandardGate::SXdgGate as usize] = true;
55+
lut[StandardGate::TGate as usize] = true;
56+
lut[StandardGate::TdgGate as usize] = true;
57+
lut[StandardGate::SGate as usize] = true;
58+
lut[StandardGate::SdgGate as usize] = true;
59+
lut[StandardGate::CXGate as usize] = true;
60+
lut[StandardGate::CYGate as usize] = true;
61+
lut[StandardGate::CZGate as usize] = true;
62+
lut[StandardGate::SwapGate as usize] = true;
63+
lut[StandardGate::ISwapGate as usize] = true;
64+
lut[StandardGate::ECRGate as usize] = true;
65+
lut[StandardGate::CCXGate as usize] = true;
66+
lut[StandardGate::CSwapGate as usize] = true;
67+
lut
68+
}
4669

47-
// We keep a hash-set of operations eligible for commutation checking. This is because checking
48-
// eligibility is not for free.
49-
static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
50-
HashSet::from([
51-
"rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx",
52-
"cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap",
53-
])
54-
});
70+
static SUPPORTED_OP: [bool; STANDARD_GATE_SIZE] = build_supported_ops();
5571

5672
// Map rotation gates to their generators (or to ``None`` if we cannot currently efficiently
5773
// represent the generator in Rust and store the commutation relation in the commutation dictionary)
5874
// and their pi-periodicity. Here we mean a gate is n-pi periodic, if for angles that are
5975
// multiples of n*pi, the gate is equal to the identity up to a global phase.
6076
// E.g. RX is generated by X and 2-pi periodic, while CRX is generated by CX and 4-pi periodic.
61-
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
62-
HashMap::from([
63-
("rx", Some(OperationRef::StandardGate(StandardGate::XGate))),
64-
("ry", Some(OperationRef::StandardGate(StandardGate::YGate))),
65-
("rz", Some(OperationRef::StandardGate(StandardGate::ZGate))),
66-
("p", Some(OperationRef::StandardGate(StandardGate::ZGate))),
67-
("u1", Some(OperationRef::StandardGate(StandardGate::ZGate))),
68-
("rxx", None), // None means the gate is in the commutation dictionary
69-
("ryy", None),
70-
("rzx", None),
71-
("rzz", None),
72-
(
73-
"crx",
74-
Some(OperationRef::StandardGate(StandardGate::CXGate)),
75-
),
76-
(
77-
"cry",
78-
Some(OperationRef::StandardGate(StandardGate::CYGate)),
79-
),
80-
(
81-
"crz",
82-
Some(OperationRef::StandardGate(StandardGate::CZGate)),
83-
),
84-
("cp", Some(OperationRef::StandardGate(StandardGate::CZGate))),
85-
])
86-
});
77+
const fn build_supported_rotations() -> [Option<Option<StandardGate>>; STANDARD_GATE_SIZE] {
78+
let mut lut = [None; STANDARD_GATE_SIZE];
79+
lut[StandardGate::RXGate as usize] = Some(Some(StandardGate::XGate));
80+
lut[StandardGate::RYGate as usize] = Some(Some(StandardGate::YGate));
81+
lut[StandardGate::RZGate as usize] = Some(Some(StandardGate::ZGate));
82+
lut[StandardGate::PhaseGate as usize] = Some(Some(StandardGate::ZGate));
83+
lut[StandardGate::U1Gate as usize] = Some(Some(StandardGate::ZGate));
84+
lut[StandardGate::CRXGate as usize] = Some(Some(StandardGate::CXGate));
85+
lut[StandardGate::CRYGate as usize] = Some(Some(StandardGate::CYGate));
86+
lut[StandardGate::CRZGate as usize] = Some(Some(StandardGate::CZGate));
87+
lut[StandardGate::CPhaseGate as usize] = Some(Some(StandardGate::CZGate));
88+
// RXXGate, RYYGate, RZXGate, and RZZGate are supported by the commutation dictionary
89+
lut[StandardGate::RXXGate as usize] = Some(None);
90+
lut[StandardGate::RYYGate as usize] = Some(None);
91+
lut[StandardGate::RZXGate as usize] = Some(None);
92+
lut[StandardGate::RZZGate as usize] = Some(None);
93+
lut
94+
}
95+
96+
static SUPPORTED_ROTATIONS: [Option<Option<StandardGate>>; STANDARD_GATE_SIZE] =
97+
build_supported_rotations();
8798

8899
fn get_bits<T>(bits1: &Bound<PyTuple>, bits2: &Bound<PyTuple>) -> PyResult<(Vec<T>, Vec<T>)>
89100
where
@@ -132,8 +143,6 @@ impl CommutationChecker {
132143
gates: Option<HashSet<String>>,
133144
) -> Self {
134145
// Initialize sets before they are used in the commutation checker
135-
Lazy::force(&SUPPORTED_OP);
136-
Lazy::force(&SUPPORTED_ROTATIONS);
137146
CommutationChecker {
138147
library: CommutationLibrary::new(standard_gate_commutations),
139148
cache: HashMap::new(),
@@ -287,14 +296,24 @@ impl CommutationChecker {
287296

288297
// if we have rotation gates, we attempt to map them to their generators, for example
289298
// RX -> X or CPhase -> CZ
290-
let (op1, params1, trivial1) = map_rotation(op1, params1, tol);
299+
let (op1_gate, params1, trivial1) = map_rotation(op1, params1, tol);
291300
if trivial1 {
292301
return Ok(true);
293302
}
294-
let (op2, params2, trivial2) = map_rotation(op2, params2, tol);
303+
let op1 = if let Some(gate) = op1_gate {
304+
&OperationRef::StandardGate(gate)
305+
} else {
306+
op1
307+
};
308+
let (op2_gate, params2, trivial2) = map_rotation(op2, params2, tol);
295309
if trivial2 {
296310
return Ok(true);
297311
}
312+
let op2 = if let Some(gate) = op2_gate {
313+
&OperationRef::StandardGate(gate)
314+
} else {
315+
op2
316+
};
298317

299318
if let Some(gates) = &self.gates {
300319
if !gates.is_empty() && (!gates.contains(op1.name()) || !gates.contains(op2.name())) {
@@ -339,14 +358,15 @@ impl CommutationChecker {
339358
// the cache for
340359
// * gates we know are in the cache (SUPPORTED_OPS), or
341360
// * standard gates with float params (otherwise we cannot cache them)
342-
let standard_gates = get_standard_gate_names();
343-
let is_cachable = |name: &str, params: &[Param]| {
344-
SUPPORTED_OP.contains(name)
345-
|| (standard_gates.contains(&name)
346-
&& params.iter().all(|p| matches!(p, Param::Float(_))))
361+
let is_cachable = |op: &OperationRef, params: &[Param]| {
362+
if let Some(gate) = op.standard_gate() {
363+
SUPPORTED_OP[gate as usize] || params.iter().all(|p| matches!(p, Param::Float(_)))
364+
} else {
365+
false
366+
}
347367
};
348-
let check_cache = is_cachable(first_op.name(), first_params)
349-
&& is_cachable(second_op.name(), second_params);
368+
let check_cache =
369+
is_cachable(first_op, first_params) && is_cachable(second_op, second_params);
350370

351371
if !check_cache {
352372
return self.commute_matmul(
@@ -544,11 +564,25 @@ fn commutation_precheck(
544564
return Some(false);
545565
}
546566

547-
if SUPPORTED_OP.contains(op1.name()) && SUPPORTED_OP.contains(op2.name()) {
548-
return None;
567+
if let Some(gate_1) = op1.standard_gate() {
568+
if let Some(gate_2) = op2.standard_gate() {
569+
if SUPPORTED_OP[gate_1 as usize] && SUPPORTED_OP[gate_2 as usize] {
570+
return None;
571+
}
572+
}
573+
}
574+
575+
if matches!(
576+
op1,
577+
OperationRef::StandardInstruction(_) | OperationRef::Instruction(_)
578+
) || matches!(
579+
op2,
580+
OperationRef::StandardInstruction(_) | OperationRef::Instruction(_)
581+
) {
582+
return Some(false);
549583
}
550584

551-
if is_commutation_skipped(op1, params1) || is_commutation_skipped(op2, params2) {
585+
if is_commutation_skipped(params1) || is_commutation_skipped(params2) {
552586
return Some(false);
553587
}
554588

@@ -580,15 +614,10 @@ fn matrix_via_operator(py: Python, py_obj: &PyObject) -> PyResult<Array2<Complex
580614
.to_owned())
581615
}
582616

583-
fn is_commutation_skipped<T>(op: &T, params: &[Param]) -> bool
584-
where
585-
T: Operation,
586-
{
587-
op.directive()
588-
|| SKIPPED_NAMES.contains(&op.name())
589-
|| params
590-
.iter()
591-
.any(|x| matches!(x, Param::ParameterExpression(_)))
617+
fn is_commutation_skipped(params: &[Param]) -> bool {
618+
params
619+
.iter()
620+
.any(|x| matches!(x, Param::ParameterExpression(_)))
592621
}
593622

594623
/// Check if a given operation can be mapped onto a generator.
@@ -604,36 +633,33 @@ fn map_rotation<'a>(
604633
op: &'a OperationRef<'a>,
605634
params: &'a [Param],
606635
tol: f64,
607-
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
608-
let name = op.name();
609-
610-
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
611-
// If the rotation angle is below the tolerance, the gate is assumed to
612-
// commute with everything, and we simply return the operation with the flag that
613-
// it commutes trivially.
614-
if let Param::Float(angle) = params[0] {
615-
let gate = op
616-
.standard_gate()
617-
.expect("Supported gates are standard gates");
618-
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
619-
.expect("All rotation should be covered at this point");
620-
let gate_fidelity = tr_over_dim.abs().powi(2);
621-
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
622-
if (1. - process_fidelity).abs() <= tol {
623-
return (op, params, true);
636+
) -> (Option<StandardGate>, &'a [Param], bool) {
637+
if let Some(gate) = op.standard_gate() {
638+
if let Some(generator) = SUPPORTED_ROTATIONS[gate as usize] {
639+
// If the rotation angle is below the tolerance, the gate is assumed to
640+
// commute with everything, and we simply return the operation with the flag that
641+
// it commutes trivially.
642+
if let Param::Float(angle) = params[0] {
643+
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
644+
.expect("All rotation should be covered at this point");
645+
let gate_fidelity = tr_over_dim.abs().powi(2);
646+
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
647+
if (1. - process_fidelity).abs() <= tol {
648+
return (Some(gate), params, true);
649+
};
624650
};
625-
};
626651

627-
// Otherwise we need to cover two cases -- either a generator is given, in which case
628-
// we return it, or we don't have a generator yet, but we know we have the operation
629-
// stored in the commutation library. For example, RXX does not have a generator in Rust
630-
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
631-
// can strip the parameters and just return the gate.
632-
if let Some(gate) = generator {
633-
return (gate, &[], false);
634-
};
652+
// Otherwise we need to cover two cases -- either a generator is given, in which case
653+
// we return it, or we don't have a generator yet, but we know we have the operation
654+
// stored in the commutation library. For example, RXX does not have a generator in Rust
655+
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
656+
// can strip the parameters and just return the gate.
657+
if let Some(gate) = generator {
658+
return (Some(gate), &[], false);
659+
};
660+
}
635661
}
636-
(op, params, false)
662+
(None, params, false)
637663
}
638664

639665
fn get_relative_placement(

test/python/transpiler/test_commutative_cancellation.py

+31
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,37 @@ def test_no_intransitive_cancellation(self):
772772
new_circuit = passmanager.run(circ)
773773
self.assertEqual(new_circuit, circ)
774774

775+
def test_overloaded_standard_gate_name(self):
776+
"""Validate the pass works with custom gates using overloaded names
777+
778+
See: https://github.com/Qiskit/qiskit/issues/13988 for more details.
779+
"""
780+
qasm_str = """OPENQASM 2.0;
781+
include "qelib1.inc";
782+
gate ryy(param0) q0,q1
783+
{
784+
rx(pi/2) q0;
785+
rx(pi/2) q1;
786+
cx q0,q1;
787+
rz(0.37801308) q1;
788+
cx q0,q1;
789+
rx(-pi/2) q0;
790+
rx(-pi/2) q1;
791+
}
792+
qreg q0[2];
793+
creg c0[2];
794+
z q0[0];
795+
ryy(1.2182379) q0[0],q0[1];
796+
z q0[0];
797+
measure q0[0] -> c0[0];
798+
measure q0[1] -> c0[1];
799+
"""
800+
qc = QuantumCircuit.from_qasm_str(qasm_str)
801+
cancellation_pass = CommutativeCancellation()
802+
res = cancellation_pass(qc)
803+
# We don't cancel any gates with a custom rzz gate
804+
self.assertEqual(res.count_ops()["z"], 2)
805+
775806

776807
if __name__ == "__main__":
777808
unittest.main()

0 commit comments

Comments
 (0)