Skip to content

Commit bd538ca

Browse files
mtreinishmergify[bot]
authored andcommitted
Fix string and standard gate mismatch in commutation checker (#13991)
* Fix string and standard gate mismatch in commutation checker This commit reworks the internals of the CommutationChecker to not rely on operation name except for where we do a lookup by name in the commutation library provided (which is the only key available to support custom gates). This fixes the case where a custom gate that overloads the standard gate name, previously the code would assume it to be a standard gate and internally panic when it wasn't. When working with standard gates (or standard instructions) we don't need to rely on string matching because we can rely on the rust data model to do the heavy lifting for us. This commit moves all the explicit handling of standard gates to use the StandardGate type directly and makes this logic more robust. This also removes are usage of the once_cell library in qiskit-accelerate because it was used to create a lazy static hashsets of strings which are no longer needed because static lookup tables replace this when we stopped using string comparisons. Fixes #13988 * Rename is_commutation_skipped() is_parameterized() (cherry picked from commit abb0cf9) # Conflicts: # crates/accelerate/src/commutation_checker.rs
1 parent 5a10169 commit bd538ca

File tree

4 files changed

+149
-89
lines changed

4 files changed

+149
-89
lines changed

Cargo.lock

-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/accelerate/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ itertools.workspace = true
2828
qiskit-circuit.workspace = true
2929
thiserror.workspace = true
3030
ndarray-einsum = "0.8.0"
31-
once_cell = "1.20.3"
3231
rustiq-core = "0.0.10"
3332
bytemuck.workspace = true
3433
nalgebra.workspace = true

crates/accelerate/src/commutation_checker.rs

+118-87
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ use ndarray::linalg::kron;
1515
use ndarray::Array2;
1616
use num_complex::Complex64;
1717
use num_complex::ComplexFloat;
18+
<<<<<<< HEAD
1819
use once_cell::sync::Lazy;
1920
use qiskit_circuit::bit_data::VarAsKey;
21+
=======
22+
use qiskit_circuit::object_registry::PyObjectAsKey;
23+
>>>>>>> abb0cf9db (Fix string and standard gate mismatch in commutation checker (#13991))
2024
use smallvec::SmallVec;
2125
use std::fmt::Debug;
2226

@@ -33,57 +37,69 @@ use qiskit_circuit::dag_node::DAGOpNode;
3337
use qiskit_circuit::imports::QI_OPERATOR;
3438
use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType};
3539
use qiskit_circuit::operations::{
36-
get_standard_gate_names, Operation, OperationRef, Param, StandardGate,
40+
Operation, OperationRef, Param, StandardGate, STANDARD_GATE_SIZE,
3741
};
3842
use qiskit_circuit::{BitType, Clbit, Qubit};
3943

4044
use crate::gate_metrics;
4145
use crate::unitary_compose;
4246
use crate::QiskitError;
4347

44-
// These gates do not commute with other gates, we do not check them.
45-
static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"];
48+
const fn build_supported_ops() -> [bool; STANDARD_GATE_SIZE] {
49+
let mut lut = [false; STANDARD_GATE_SIZE];
50+
lut[StandardGate::RXXGate as usize] = true;
51+
lut[StandardGate::RYYGate as usize] = true;
52+
lut[StandardGate::RZZGate as usize] = true;
53+
lut[StandardGate::RZXGate as usize] = true;
54+
lut[StandardGate::HGate as usize] = true;
55+
lut[StandardGate::XGate as usize] = true;
56+
lut[StandardGate::YGate as usize] = true;
57+
lut[StandardGate::ZGate as usize] = true;
58+
lut[StandardGate::SXGate as usize] = true;
59+
lut[StandardGate::SXdgGate as usize] = true;
60+
lut[StandardGate::TGate as usize] = true;
61+
lut[StandardGate::TdgGate as usize] = true;
62+
lut[StandardGate::SGate as usize] = true;
63+
lut[StandardGate::SdgGate as usize] = true;
64+
lut[StandardGate::CXGate as usize] = true;
65+
lut[StandardGate::CYGate as usize] = true;
66+
lut[StandardGate::CZGate as usize] = true;
67+
lut[StandardGate::SwapGate as usize] = true;
68+
lut[StandardGate::ISwapGate as usize] = true;
69+
lut[StandardGate::ECRGate as usize] = true;
70+
lut[StandardGate::CCXGate as usize] = true;
71+
lut[StandardGate::CSwapGate as usize] = true;
72+
lut
73+
}
4674

47-
// We keep a hash-set of operations eligible for commutation checking. This is because checking
48-
// eligibility is not for free.
49-
static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
50-
HashSet::from([
51-
"rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx",
52-
"cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap",
53-
])
54-
});
75+
static SUPPORTED_OP: [bool; STANDARD_GATE_SIZE] = build_supported_ops();
5576

5677
// Map rotation gates to their generators (or to ``None`` if we cannot currently efficiently
5778
// represent the generator in Rust and store the commutation relation in the commutation dictionary)
5879
// and their pi-periodicity. Here we mean a gate is n-pi periodic, if for angles that are
5980
// multiples of n*pi, the gate is equal to the identity up to a global phase.
6081
// E.g. RX is generated by X and 2-pi periodic, while CRX is generated by CX and 4-pi periodic.
61-
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
62-
HashMap::from([
63-
("rx", Some(OperationRef::StandardGate(StandardGate::XGate))),
64-
("ry", Some(OperationRef::StandardGate(StandardGate::YGate))),
65-
("rz", Some(OperationRef::StandardGate(StandardGate::ZGate))),
66-
("p", Some(OperationRef::StandardGate(StandardGate::ZGate))),
67-
("u1", Some(OperationRef::StandardGate(StandardGate::ZGate))),
68-
("rxx", None), // None means the gate is in the commutation dictionary
69-
("ryy", None),
70-
("rzx", None),
71-
("rzz", None),
72-
(
73-
"crx",
74-
Some(OperationRef::StandardGate(StandardGate::CXGate)),
75-
),
76-
(
77-
"cry",
78-
Some(OperationRef::StandardGate(StandardGate::CYGate)),
79-
),
80-
(
81-
"crz",
82-
Some(OperationRef::StandardGate(StandardGate::CZGate)),
83-
),
84-
("cp", Some(OperationRef::StandardGate(StandardGate::CZGate))),
85-
])
86-
});
82+
const fn build_supported_rotations() -> [Option<Option<StandardGate>>; STANDARD_GATE_SIZE] {
83+
let mut lut = [None; STANDARD_GATE_SIZE];
84+
lut[StandardGate::RXGate as usize] = Some(Some(StandardGate::XGate));
85+
lut[StandardGate::RYGate as usize] = Some(Some(StandardGate::YGate));
86+
lut[StandardGate::RZGate as usize] = Some(Some(StandardGate::ZGate));
87+
lut[StandardGate::PhaseGate as usize] = Some(Some(StandardGate::ZGate));
88+
lut[StandardGate::U1Gate as usize] = Some(Some(StandardGate::ZGate));
89+
lut[StandardGate::CRXGate as usize] = Some(Some(StandardGate::CXGate));
90+
lut[StandardGate::CRYGate as usize] = Some(Some(StandardGate::CYGate));
91+
lut[StandardGate::CRZGate as usize] = Some(Some(StandardGate::CZGate));
92+
lut[StandardGate::CPhaseGate as usize] = Some(Some(StandardGate::CZGate));
93+
// RXXGate, RYYGate, RZXGate, and RZZGate are supported by the commutation dictionary
94+
lut[StandardGate::RXXGate as usize] = Some(None);
95+
lut[StandardGate::RYYGate as usize] = Some(None);
96+
lut[StandardGate::RZXGate as usize] = Some(None);
97+
lut[StandardGate::RZZGate as usize] = Some(None);
98+
lut
99+
}
100+
101+
static SUPPORTED_ROTATIONS: [Option<Option<StandardGate>>; STANDARD_GATE_SIZE] =
102+
build_supported_rotations();
87103

88104
fn get_bits<T>(bits1: &Bound<PyTuple>, bits2: &Bound<PyTuple>) -> PyResult<(Vec<T>, Vec<T>)>
89105
where
@@ -132,8 +148,6 @@ impl CommutationChecker {
132148
gates: Option<HashSet<String>>,
133149
) -> Self {
134150
// Initialize sets before they are used in the commutation checker
135-
Lazy::force(&SUPPORTED_OP);
136-
Lazy::force(&SUPPORTED_ROTATIONS);
137151
CommutationChecker {
138152
library: CommutationLibrary::new(standard_gate_commutations),
139153
cache: HashMap::new(),
@@ -287,14 +301,24 @@ impl CommutationChecker {
287301

288302
// if we have rotation gates, we attempt to map them to their generators, for example
289303
// RX -> X or CPhase -> CZ
290-
let (op1, params1, trivial1) = map_rotation(op1, params1, tol);
304+
let (op1_gate, params1, trivial1) = map_rotation(op1, params1, tol);
291305
if trivial1 {
292306
return Ok(true);
293307
}
294-
let (op2, params2, trivial2) = map_rotation(op2, params2, tol);
308+
let op1 = if let Some(gate) = op1_gate {
309+
&OperationRef::StandardGate(gate)
310+
} else {
311+
op1
312+
};
313+
let (op2_gate, params2, trivial2) = map_rotation(op2, params2, tol);
295314
if trivial2 {
296315
return Ok(true);
297316
}
317+
let op2 = if let Some(gate) = op2_gate {
318+
&OperationRef::StandardGate(gate)
319+
} else {
320+
op2
321+
};
298322

299323
if let Some(gates) = &self.gates {
300324
if !gates.is_empty() && (!gates.contains(op1.name()) || !gates.contains(op2.name())) {
@@ -339,14 +363,15 @@ impl CommutationChecker {
339363
// the cache for
340364
// * gates we know are in the cache (SUPPORTED_OPS), or
341365
// * standard gates with float params (otherwise we cannot cache them)
342-
let standard_gates = get_standard_gate_names();
343-
let is_cachable = |name: &str, params: &[Param]| {
344-
SUPPORTED_OP.contains(name)
345-
|| (standard_gates.contains(&name)
346-
&& params.iter().all(|p| matches!(p, Param::Float(_))))
366+
let is_cachable = |op: &OperationRef, params: &[Param]| {
367+
if let Some(gate) = op.standard_gate() {
368+
SUPPORTED_OP[gate as usize] || params.iter().all(|p| matches!(p, Param::Float(_)))
369+
} else {
370+
false
371+
}
347372
};
348-
let check_cache = is_cachable(first_op.name(), first_params)
349-
&& is_cachable(second_op.name(), second_params);
373+
let check_cache =
374+
is_cachable(first_op, first_params) && is_cachable(second_op, second_params);
350375

351376
if !check_cache {
352377
return self.commute_matmul(
@@ -544,11 +569,25 @@ fn commutation_precheck(
544569
return Some(false);
545570
}
546571

547-
if SUPPORTED_OP.contains(op1.name()) && SUPPORTED_OP.contains(op2.name()) {
548-
return None;
572+
if let Some(gate_1) = op1.standard_gate() {
573+
if let Some(gate_2) = op2.standard_gate() {
574+
if SUPPORTED_OP[gate_1 as usize] && SUPPORTED_OP[gate_2 as usize] {
575+
return None;
576+
}
577+
}
578+
}
579+
580+
if matches!(
581+
op1,
582+
OperationRef::StandardInstruction(_) | OperationRef::Instruction(_)
583+
) || matches!(
584+
op2,
585+
OperationRef::StandardInstruction(_) | OperationRef::Instruction(_)
586+
) {
587+
return Some(false);
549588
}
550589

551-
if is_commutation_skipped(op1, params1) || is_commutation_skipped(op2, params2) {
590+
if is_parameterized(params1) || is_parameterized(params2) {
552591
return Some(false);
553592
}
554593

@@ -580,15 +619,10 @@ fn matrix_via_operator(py: Python, py_obj: &PyObject) -> PyResult<Array2<Complex
580619
.to_owned())
581620
}
582621

583-
fn is_commutation_skipped<T>(op: &T, params: &[Param]) -> bool
584-
where
585-
T: Operation,
586-
{
587-
op.directive()
588-
|| SKIPPED_NAMES.contains(&op.name())
589-
|| params
590-
.iter()
591-
.any(|x| matches!(x, Param::ParameterExpression(_)))
622+
fn is_parameterized(params: &[Param]) -> bool {
623+
params
624+
.iter()
625+
.any(|x| matches!(x, Param::ParameterExpression(_)))
592626
}
593627

594628
/// Check if a given operation can be mapped onto a generator.
@@ -604,36 +638,33 @@ fn map_rotation<'a>(
604638
op: &'a OperationRef<'a>,
605639
params: &'a [Param],
606640
tol: f64,
607-
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
608-
let name = op.name();
609-
610-
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
611-
// If the rotation angle is below the tolerance, the gate is assumed to
612-
// commute with everything, and we simply return the operation with the flag that
613-
// it commutes trivially.
614-
if let Param::Float(angle) = params[0] {
615-
let gate = op
616-
.standard_gate()
617-
.expect("Supported gates are standard gates");
618-
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
619-
.expect("All rotation should be covered at this point");
620-
let gate_fidelity = tr_over_dim.abs().powi(2);
621-
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
622-
if (1. - process_fidelity).abs() <= tol {
623-
return (op, params, true);
641+
) -> (Option<StandardGate>, &'a [Param], bool) {
642+
if let Some(gate) = op.standard_gate() {
643+
if let Some(generator) = SUPPORTED_ROTATIONS[gate as usize] {
644+
// If the rotation angle is below the tolerance, the gate is assumed to
645+
// commute with everything, and we simply return the operation with the flag that
646+
// it commutes trivially.
647+
if let Param::Float(angle) = params[0] {
648+
let (tr_over_dim, dim) = gate_metrics::rotation_trace_and_dim(gate, angle)
649+
.expect("All rotation should be covered at this point");
650+
let gate_fidelity = tr_over_dim.abs().powi(2);
651+
let process_fidelity = (dim * gate_fidelity + 1.) / (dim + 1.);
652+
if (1. - process_fidelity).abs() <= tol {
653+
return (Some(gate), params, true);
654+
};
624655
};
625-
};
626656

627-
// Otherwise we need to cover two cases -- either a generator is given, in which case
628-
// we return it, or we don't have a generator yet, but we know we have the operation
629-
// stored in the commutation library. For example, RXX does not have a generator in Rust
630-
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
631-
// can strip the parameters and just return the gate.
632-
if let Some(gate) = generator {
633-
return (gate, &[], false);
634-
};
657+
// Otherwise we need to cover two cases -- either a generator is given, in which case
658+
// we return it, or we don't have a generator yet, but we know we have the operation
659+
// stored in the commutation library. For example, RXX does not have a generator in Rust
660+
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
661+
// can strip the parameters and just return the gate.
662+
if let Some(gate) = generator {
663+
return (Some(gate), &[], false);
664+
};
665+
}
635666
}
636-
(op, params, false)
667+
(None, params, false)
637668
}
638669

639670
fn get_relative_placement(

test/python/transpiler/test_commutative_cancellation.py

+31
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,37 @@ def test_no_intransitive_cancellation(self):
772772
new_circuit = passmanager.run(circ)
773773
self.assertEqual(new_circuit, circ)
774774

775+
def test_overloaded_standard_gate_name(self):
776+
"""Validate the pass works with custom gates using overloaded names
777+
778+
See: https://github.com/Qiskit/qiskit/issues/13988 for more details.
779+
"""
780+
qasm_str = """OPENQASM 2.0;
781+
include "qelib1.inc";
782+
gate ryy(param0) q0,q1
783+
{
784+
rx(pi/2) q0;
785+
rx(pi/2) q1;
786+
cx q0,q1;
787+
rz(0.37801308) q1;
788+
cx q0,q1;
789+
rx(-pi/2) q0;
790+
rx(-pi/2) q1;
791+
}
792+
qreg q0[2];
793+
creg c0[2];
794+
z q0[0];
795+
ryy(1.2182379) q0[0],q0[1];
796+
z q0[0];
797+
measure q0[0] -> c0[0];
798+
measure q0[1] -> c0[1];
799+
"""
800+
qc = QuantumCircuit.from_qasm_str(qasm_str)
801+
cancellation_pass = CommutativeCancellation()
802+
res = cancellation_pass(qc)
803+
# We don't cancel any gates with a custom rzz gate
804+
self.assertEqual(res.count_ops()["z"], 2)
805+
775806

776807
if __name__ == "__main__":
777808
unittest.main()

0 commit comments

Comments
 (0)