Skip to content

Commit 0f16f77

Browse files
committed
Revert "Revert "Fix deepcopy/pickle of DAGCircuit variable IO nodes (backport Qiskit#14041) (Qiskit#14043)" (Qiskit#14094)"
This reverts commit 3f094ab.
1 parent 3f094ab commit 0f16f77

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

crates/circuit/src/dag_circuit.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5591,7 +5591,8 @@ impl DAGCircuit {
55915591
} else if wire.is_instance(imports::CLBIT.get_bound(py))? {
55925592
NodeType::ClbitOut(self.clbits.find(wire).unwrap())
55935593
} else {
5594-
NodeType::VarIn(self.vars.find(wire).unwrap())
5594+
let var = PyObjectAsKey::new(wire);
5595+
NodeType::VarOut(self.vars.find(&var).unwrap())
55955596
}
55965597
} else if let Ok(op_node) = b.downcast::<DAGOpNode>() {
55975598
let op_node = op_node.borrow();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a bug in :class:`~.dagcircuit.DAGCircuit` that would cause
5+
output :class:`~.expr.Var` nodes to become input nodes during
6+
``deepcopy`` and pickling.

test/python/dagcircuit/test_dagcircuit.py

+74
Original file line numberDiff line numberDiff line change
@@ -2082,6 +2082,80 @@ def test_present_vars(self):
20822082
right.add_captured_var(a_u8_other)
20832083
self.assertNotEqual(left, right)
20842084

2085+
def test_pickle_vars(self):
2086+
"""Test vars preserved through pickle."""
2087+
a = expr.Var.new("a", types.Bool())
2088+
b = expr.Var.new("b", types.Uint(8))
2089+
2090+
# Check inputs.
2091+
dag = DAGCircuit()
2092+
dag.add_input_var(a)
2093+
2094+
self.assertEqual(dag.num_vars, 1)
2095+
self.assertEqual(dag.num_input_vars, 1)
2096+
2097+
with io.BytesIO() as buf:
2098+
pickle.dump(dag, buf)
2099+
buf.seek(0)
2100+
output = pickle.load(buf)
2101+
2102+
self.assertEqual(output.num_vars, 1)
2103+
self.assertEqual(output.num_input_vars, 1)
2104+
self.assertEqual(output, dag)
2105+
2106+
# Check captures and declarations.
2107+
dag = DAGCircuit()
2108+
dag.add_declared_var(a)
2109+
dag.add_captured_var(b)
2110+
2111+
self.assertEqual(dag.num_vars, 2)
2112+
self.assertEqual(dag.num_captured_vars, 1)
2113+
self.assertEqual(dag.num_declared_vars, 1)
2114+
2115+
with io.BytesIO() as buf:
2116+
pickle.dump(dag, buf)
2117+
buf.seek(0)
2118+
output = pickle.load(buf)
2119+
2120+
self.assertEqual(output.num_vars, 2)
2121+
self.assertEqual(output.num_captured_vars, 1)
2122+
self.assertEqual(output.num_declared_vars, 1)
2123+
self.assertEqual(output, dag)
2124+
2125+
def test_deepcopy_vars(self):
2126+
"""Test vars preserved through deepcopy."""
2127+
a = expr.Var.new("a", types.Bool())
2128+
b = expr.Var.new("b", types.Uint(8))
2129+
2130+
# Check inputs.
2131+
dag = DAGCircuit()
2132+
dag.add_input_var(a)
2133+
2134+
self.assertEqual(dag.num_vars, 1)
2135+
self.assertEqual(dag.num_input_vars, 1)
2136+
2137+
output = copy.deepcopy(dag)
2138+
2139+
self.assertEqual(output.num_vars, 1)
2140+
self.assertEqual(output.num_input_vars, 1)
2141+
self.assertEqual(output, dag)
2142+
2143+
# Check captures and declarations.
2144+
dag = DAGCircuit()
2145+
dag.add_declared_var(a)
2146+
dag.add_captured_var(b)
2147+
2148+
self.assertEqual(dag.num_vars, 2)
2149+
self.assertEqual(dag.num_captured_vars, 1)
2150+
self.assertEqual(dag.num_declared_vars, 1)
2151+
2152+
output = copy.deepcopy(dag)
2153+
2154+
self.assertEqual(output.num_vars, 2)
2155+
self.assertEqual(output.num_captured_vars, 1)
2156+
self.assertEqual(output.num_declared_vars, 1)
2157+
self.assertEqual(output, dag)
2158+
20852159
def test_wires_added_for_simple_classical_vars(self):
20862160
"""Var uses should be represented in the wire structure."""
20872161
a = expr.Var.new("a", types.Bool())

0 commit comments

Comments
 (0)