Skip to content

Commit 9c8eb06

Browse files
CryorisikkohamElePT
authored
Improve Parameter handling in SparsePauliOp (#9796)
* add reno * Add assign_parameters and parameter in init * add SPO.parameters and remove utils * fix ParameterValueType typehint * Update qiskit/quantum_info/operators/symplectic/sparse_pauli_op.py Co-authored-by: Ikko Hamamura <[email protected]> * remove trailing print * Elena's comments Co-authored-by: Elena Peña Tapia <[email protected]> * fix line length --------- Co-authored-by: Ikko Hamamura <[email protected]> Co-authored-by: Elena Peña Tapia <[email protected]>
1 parent 7bb4af9 commit 9c8eb06

File tree

7 files changed

+145
-84
lines changed

7 files changed

+145
-84
lines changed

qiskit/algorithms/time_evolvers/trotterization/trotter_qrte.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from qiskit.quantum_info import Pauli, SparsePauliOp
2727
from qiskit.synthesis import ProductFormula, LieTrotter
2828

29-
from qiskit.algorithms.utils.assign_params import _assign_parameters, _get_parameters
30-
3129

3230
class TrotterQRTE(RealTimeEvolver):
3331
"""Quantum Real Time Evolution using Trotterization.
@@ -165,16 +163,25 @@ def evolve(self, evolution_problem: TimeEvolutionProblem) -> TimeEvolutionResult
165163
"The time evolution problem contained ``aux_operators`` but no estimator was "
166164
"provided. The algorithm continues without calculating these quantities. "
167165
)
166+
167+
# ensure the hamiltonian is a sparse pauli op
168168
hamiltonian = evolution_problem.hamiltonian
169169
if not isinstance(hamiltonian, (Pauli, PauliSumOp, SparsePauliOp)):
170170
raise ValueError(
171-
f"TrotterQRTE only accepts Pauli | PauliSumOp, {type(hamiltonian)} provided."
171+
f"TrotterQRTE only accepts Pauli | PauliSumOp | SparsePauliOp, {type(hamiltonian)} "
172+
"provided."
172173
)
174+
if isinstance(hamiltonian, PauliSumOp):
175+
hamiltonian = hamiltonian.primitive * hamiltonian.coeff
176+
elif isinstance(hamiltonian, Pauli):
177+
hamiltonian = SparsePauliOp(hamiltonian)
178+
173179
t_param = evolution_problem.t_param
174-
if t_param is not None and _get_parameters(hamiltonian.coeffs) != ParameterView([t_param]):
180+
free_parameters = hamiltonian.parameters
181+
if t_param is not None and free_parameters != ParameterView([t_param]):
175182
raise ValueError(
176-
"Hamiltonian time parameter does not match evolution_problem.t_param "
177-
"or contains multiple parameters"
183+
f"Hamiltonian time parameters ({free_parameters}) do not match "
184+
f"evolution_problem.t_param ({t_param})."
178185
)
179186

180187
# make sure PauliEvolutionGate does not implement more than one Trotter step
@@ -213,9 +220,9 @@ def evolve(self, evolution_problem: TimeEvolutionProblem) -> TimeEvolutionResult
213220
# evolution for next step
214221
if t_param is not None:
215222
time_value = (n + 1) * dt
216-
bound_coeffs = _assign_parameters(hamiltonian.coeffs, [time_value])
223+
bound_hamiltonian = hamiltonian.assign_parameters([time_value])
217224
single_step_evolution_gate = PauliEvolutionGate(
218-
SparsePauliOp(hamiltonian.paulis, bound_coeffs),
225+
bound_hamiltonian,
219226
dt,
220227
synthesis=self.product_formula,
221228
)

qiskit/algorithms/time_evolvers/variational/solvers/var_qte_linear_solver.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from qiskit.quantum_info import SparsePauliOp
2424
from qiskit.quantum_info.operators.base_operator import BaseOperator
2525

26-
from qiskit.algorithms.utils.assign_params import _assign_parameters
27-
2826
from ..variational_principles import VariationalPrinciple
2927

3028

@@ -115,13 +113,12 @@ def solve_lse(
115113

116114
if self._time_param is not None:
117115
if time_value is not None:
118-
bound_params_array = _assign_parameters(self._hamiltonian.coeffs, [time_value])
119-
hamiltonian = SparsePauliOp(self._hamiltonian.paulis, bound_params_array)
116+
hamiltonian = hamiltonian.assign_parameters([time_value])
120117
else:
121118
raise ValueError(
122-
f"Providing a time_value is required for time-dependant hamiltonians, "
119+
"Providing a time_value is required for time-dependent hamiltonians, "
123120
f"but got time_value = {time_value}. "
124-
f"Please provide a time_value to the solve_lse method."
121+
"Please provide a time_value to the solve_lse method."
125122
)
126123

127124
evolution_grad_lse_rhs = self._var_principle.evolution_gradient(

qiskit/algorithms/utils/assign_params.py

-62
This file was deleted.

qiskit/quantum_info/operators/symplectic/sparse_pauli_op.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,21 @@
1313
N-Qubit Sparse Pauli Operator class.
1414
"""
1515

16+
from __future__ import annotations
17+
1618
from collections import defaultdict
19+
from collections.abc import Mapping, Sequence
1720
from numbers import Number
1821
from typing import Dict, Optional
22+
from copy import deepcopy
1923

2024
import numpy as np
2125
import rustworkx as rx
2226

2327
from qiskit._accelerate.sparse_pauli_op import unordered_unique
28+
from qiskit.circuit.parameter import Parameter
29+
from qiskit.circuit.parameterexpression import ParameterExpression
30+
from qiskit.circuit.parametertable import ParameterView
2431
from qiskit.exceptions import QiskitError
2532
from qiskit.quantum_info.operators.custom_iterator import CustomIterator
2633
from qiskit.quantum_info.operators.linear_op import LinearOp
@@ -112,10 +119,18 @@ def __init__(self, data, coeffs=None, *, ignore_pauli_phase=False, copy=True):
112119

113120
pauli_list = PauliList(data.copy() if copy and hasattr(data, "copy") else data)
114121

115-
dtype = object if isinstance(coeffs, np.ndarray) and coeffs.dtype == object else complex
122+
if isinstance(coeffs, np.ndarray) and coeffs.dtype == object:
123+
dtype = object
124+
elif coeffs is not None:
125+
if not isinstance(coeffs, (np.ndarray, Sequence)):
126+
coeffs = [coeffs]
127+
if any(isinstance(coeff, ParameterExpression) for coeff in coeffs):
128+
dtype = object
129+
else:
130+
dtype = complex
116131

117132
if coeffs is None:
118-
coeffs = np.ones(pauli_list.size, dtype=dtype)
133+
coeffs = np.ones(pauli_list.size, dtype=complex)
119134
else:
120135
coeffs = np.array(coeffs, copy=copy, dtype=dtype)
121136

@@ -997,6 +1012,58 @@ def group_commuting(self, qubit_wise=False):
9971012
groups[color].append(idx)
9981013
return [self[group] for group in groups.values()]
9991014

1015+
@property
1016+
def parameters(self) -> ParameterView:
1017+
r"""Return the free ``Parameter``\s in the coefficients."""
1018+
ret = set()
1019+
for coeff in self.coeffs:
1020+
if isinstance(coeff, ParameterExpression):
1021+
ret |= coeff.parameters
1022+
return ParameterView(ret)
1023+
1024+
def assign_parameters(
1025+
self,
1026+
parameters: Mapping[Parameter, complex | ParameterExpression]
1027+
| Sequence[complex | ParameterExpression],
1028+
inplace: bool = False,
1029+
) -> SparsePauliOp | None:
1030+
r"""Bind the free ``Parameter``\s in the coefficients to provided values.
1031+
1032+
Args:
1033+
parameters: The values to bind the parameters to.
1034+
inplace: If ``False``, a copy of the operator with the bound parameters is returned.
1035+
If ``True`` the operator itself is modified.
1036+
1037+
Returns:
1038+
A copy of the operator with bound parameters, if ``inplace`` is ``False``, otherwise
1039+
``None``.
1040+
"""
1041+
if inplace:
1042+
bound = self
1043+
else:
1044+
bound = deepcopy(self)
1045+
1046+
# turn the parameters to a dictionary
1047+
if isinstance(parameters, Sequence):
1048+
free_parameters = bound.parameters
1049+
if len(parameters) != len(free_parameters):
1050+
raise ValueError(
1051+
f"Mismatching number of values ({len(parameters)}) and parameters "
1052+
f"({len(free_parameters)}). For partial binding please pass a dictionary of "
1053+
"{parameter: value} pairs."
1054+
)
1055+
parameters = dict(zip(free_parameters, parameters))
1056+
1057+
for i, coeff in enumerate(bound.coeffs):
1058+
if isinstance(coeff, ParameterExpression):
1059+
for key in coeff.parameters & parameters.keys():
1060+
coeff = coeff.assign(key, parameters[key])
1061+
if len(coeff.parameters) == 0:
1062+
coeff = complex(coeff)
1063+
bound.coeffs[i] = coeff
1064+
1065+
return None if inplace else bound
1066+
10001067

10011068
# Update docstrings for API docs
10021069
generate_apidocs(SparsePauliOp)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
features:
2+
- |
3+
Natively support the construction of :class:`.SparsePauliOp` objects with
4+
:class:`.ParameterExpression` coefficients, without requiring the explicit construction
5+
of an object-array. Now the following is supported::
6+
7+
from qiskit.circuit import Parameter
8+
from qiskit.quantum_info import SparsePauliOp
9+
10+
x = Parameter("x")
11+
op = SparsePauliOp(["Z", "X"], coeffs=[1, x])
12+
13+
- |
14+
Added the :meth:`.SparsePauliOp.assign_parameters` method and
15+
:attr:`.SparsePauliOp.parameters` attribute to assign and query unbound parameters
16+
inside a :class:`.SparsePauliOp`. This function can for example be used as::
17+
18+
from qiskit.circuit import Parameter
19+
from qiskit.quantum_info import SparsePauliOp
20+
21+
x = Parameter("x")
22+
op = SparsePauliOp(["Z", "X"], coeffs=[1, x])
23+
24+
# free_params will be: ParameterView([x])
25+
free_params = op.parameters
26+
27+
# assign the value 2 to the parameter x
28+
bound = op.assign_parameters([2])

test/python/algorithms/time_evolvers/test_trotter_qrte.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from numpy.testing import assert_raises
2121

2222
from qiskit.algorithms.time_evolvers import TimeEvolutionProblem, TrotterQRTE
23-
from qiskit.algorithms.utils.assign_params import _assign_parameters
2423
from qiskit.primitives import Estimator
2524
from qiskit import QuantumCircuit
2625
from qiskit.circuit.library import ZGate
@@ -245,11 +244,8 @@ def _get_expected_trotter_qrte(operator, time, num_timesteps, init_state, observ
245244
for n in range(num_timesteps):
246245
if t_param is not None:
247246
time_value = (n + 1) * dt
248-
bound_coeffs = _assign_parameters(operator.coeffs, [time_value])
249-
ops = [
250-
Pauli(op).to_matrix() * np.real(coeff)
251-
for op, coeff in SparsePauliOp(operator.paulis, bound_coeffs).to_list()
252-
]
247+
bound = operator.assign_parameters([time_value])
248+
ops = [Pauli(op).to_matrix() * np.real(coeff) for op, coeff in bound.to_list()]
253249
for op in ops:
254250
psi = expm(-1j * op * dt).dot(psi)
255251
observable_results.append(

test/python/quantum_info/operators/symplectic/test_sparse_pauli_op.py

+28
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from qiskit import QiskitError
2323
from qiskit.circuit import Parameter, ParameterVector
24+
from qiskit.circuit.parametertable import ParameterView
2425
from qiskit.quantum_info.operators import Operator, Pauli, PauliList, PauliTable, SparsePauliOp
2526
from qiskit.test import QiskitTestCase
2627

@@ -961,6 +962,33 @@ def test_dot_real(self):
961962
iz = SparsePauliOp("Z", 1j)
962963
self.assertEqual(x.dot(y), iz)
963964

965+
def test_get_parameters(self):
966+
"""Test getting the parameters."""
967+
x, y = Parameter("x"), Parameter("y")
968+
op = SparsePauliOp(["X", "Y", "Z"], coeffs=[1, x, x * y])
969+
970+
with self.subTest(msg="all parameters"):
971+
self.assertEqual(ParameterView([x, y]), op.parameters)
972+
973+
op.assign_parameters({y: 2}, inplace=True)
974+
with self.subTest(msg="after partial binding"):
975+
self.assertEqual(ParameterView([x]), op.parameters)
976+
977+
def test_assign_parameters(self):
978+
"""Test assign parameters."""
979+
x, y = Parameter("x"), Parameter("y")
980+
op = SparsePauliOp(["X", "Y", "Z"], coeffs=[1, x, x * y])
981+
982+
# partial binding inplace
983+
op.assign_parameters({y: 2}, inplace=True)
984+
with self.subTest(msg="partial binding"):
985+
self.assertListEqual(op.coeffs.tolist(), [1, x, 2 * x])
986+
987+
# bind via array
988+
bound = op.assign_parameters([3])
989+
with self.subTest(msg="fully bound"):
990+
self.assertTrue(np.allclose(bound.coeffs.astype(complex), [1, 3, 6]))
991+
964992

965993
if __name__ == "__main__":
966994
unittest.main()

0 commit comments

Comments
 (0)