Skip to content

Commit 1ed566b

Browse files
authored
Add async support in EngineClient, EngineSampler, etc. (#5219)
Review@ @mpharrigan
1 parent 4e36288 commit 1ed566b

File tree

7 files changed

+485
-351
lines changed

7 files changed

+485
-351
lines changed

cirq-google/cirq_google/engine/engine.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import string
3030
from typing import Dict, Iterable, List, Optional, Sequence, Set, TypeVar, Union, TYPE_CHECKING
3131

32+
import duet
3233
import google.auth
3334
from google.protobuf import any_pb2
3435

@@ -493,7 +494,7 @@ def run_calibration(
493494
)
494495

495496
@util.deprecated_gate_set_parameter
496-
def create_program(
497+
async def create_program_async(
497498
self,
498499
program: cirq.AbstractCircuit,
499500
program_id: Optional[str] = None,
@@ -524,7 +525,7 @@ def create_program(
524525
if not program_id:
525526
program_id = _make_random_id('prog-')
526527

527-
new_program_id, new_program = self.context.client.create_program(
528+
new_program_id, new_program = await self.context.client.create_program_async(
528529
self.project_id,
529530
program_id,
530531
code=self.context._serialize_program(program, gate_set),
@@ -536,8 +537,10 @@ def create_program(
536537
self.project_id, new_program_id, self.context, new_program
537538
)
538539

540+
create_program = duet.sync(create_program_async)
541+
539542
@util.deprecated_gate_set_parameter
540-
def create_batch_program(
543+
async def create_batch_program_async(
541544
self,
542545
programs: Sequence[cirq.AbstractCircuit],
543546
program_id: Optional[str] = None,
@@ -574,7 +577,7 @@ def create_batch_program(
574577
for program in programs:
575578
gate_set.serialize(program, msg=batch.programs.add())
576579

577-
new_program_id, new_program = self.context.client.create_program(
580+
new_program_id, new_program = await self.context.client.create_program_async(
578581
self.project_id,
579582
program_id,
580583
code=util.pack_any(batch),
@@ -586,8 +589,10 @@ def create_batch_program(
586589
self.project_id, new_program_id, self.context, new_program, result_type=ResultType.Batch
587590
)
588591

592+
create_batch_program = duet.sync(create_batch_program_async)
593+
589594
@util.deprecated_gate_set_parameter
590-
def create_calibration_program(
595+
async def create_calibration_program_async(
591596
self,
592597
layers: List['cirq_google.CalibrationLayer'],
593598
program_id: Optional[str] = None,
@@ -632,7 +637,7 @@ def create_calibration_program(
632637
arg_to_proto(layer.args[arg], out=new_layer.args[arg])
633638
gate_set.serialize(layer.program, msg=new_layer.layer)
634639

635-
new_program_id, new_program = self.context.client.create_program(
640+
new_program_id, new_program = await self.context.client.create_program_async(
636641
self.project_id,
637642
program_id,
638643
code=util.pack_any(calibration),
@@ -648,6 +653,8 @@ def create_calibration_program(
648653
result_type=ResultType.Calibration,
649654
)
650655

656+
create_calibration_program = duet.sync(create_calibration_program_async)
657+
651658
def get_program(self, program_id: str) -> engine_program.EngineProgram:
652659
"""Returns an EngineProgram for an existing Quantum Engine program.
653660
@@ -659,7 +666,7 @@ def get_program(self, program_id: str) -> engine_program.EngineProgram:
659666
"""
660667
return engine_program.EngineProgram(self.project_id, program_id, self.context)
661668

662-
def list_programs(
669+
async def list_programs_async(
663670
self,
664671
created_before: Optional[Union[datetime.datetime, datetime.date]] = None,
665672
created_after: Optional[Union[datetime.datetime, datetime.date]] = None,
@@ -681,7 +688,7 @@ def list_programs(
681688
"""
682689

683690
client = self.context.client
684-
response = client.list_programs(
691+
response = await client.list_programs_async(
685692
self.project_id,
686693
created_before=created_before,
687694
created_after=created_after,
@@ -697,7 +704,9 @@ def list_programs(
697704
for p in response
698705
]
699706

700-
def list_jobs(
707+
list_programs = duet.sync(list_programs_async)
708+
709+
async def list_jobs_async(
701710
self,
702711
created_before: Optional[Union[datetime.datetime, datetime.date]] = None,
703712
created_after: Optional[Union[datetime.datetime, datetime.date]] = None,
@@ -730,7 +739,7 @@ def list_jobs(
730739
`quantum.ExecutionStatus.State` enum for accepted values.
731740
"""
732741
client = self.context.client
733-
response = client.list_jobs(
742+
response = await client.list_jobs_async(
734743
self.project_id,
735744
None,
736745
created_before=created_before,
@@ -749,7 +758,9 @@ def list_jobs(
749758
for j in response
750759
]
751760

752-
def list_processors(self) -> List[engine_processor.EngineProcessor]:
761+
list_jobs = duet.sync(list_jobs_async)
762+
763+
async def list_processors_async(self) -> List[engine_processor.EngineProcessor]:
753764
"""Returns a list of Processors that the user has visibility to in the
754765
current Engine project. The names of these processors are used to
755766
identify devices when scheduling jobs and gathering calibration metrics.
@@ -758,14 +769,16 @@ def list_processors(self) -> List[engine_processor.EngineProcessor]:
758769
A list of EngineProcessors to access status, device and calibration
759770
information.
760771
"""
761-
response = self.context.client.list_processors(self.project_id)
772+
response = await self.context.client.list_processors_async(self.project_id)
762773
return [
763774
engine_processor.EngineProcessor(
764775
self.project_id, engine_client._ids_from_processor_name(p.name)[1], self.context, p
765776
)
766777
for p in response
767778
]
768779

780+
list_processors = duet.sync(list_processors_async)
781+
769782
def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor:
770783
"""Returns an EngineProcessor for a Quantum Engine processor.
771784

cirq-google/cirq_google/engine/engine_job.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414
"""A helper for jobs that have been created on the Quantum Engine."""
1515
import datetime
16-
import time
1716

1817
from typing import Dict, Iterator, List, Optional, overload, Sequence, Tuple, TYPE_CHECKING
1918

19+
import duet
2020
from google.protobuf import any_pb2
2121

2222
import cirq
@@ -107,20 +107,25 @@ def program(self) -> 'engine_program.EngineProgram':
107107

108108
return engine_program.EngineProgram(self.project_id, self.program_id, self.context)
109109

110+
async def _get_job_async(self, return_run_context: bool = False) -> quantum.QuantumJob:
111+
return await self.context.client.get_job_async(
112+
self.project_id, self.program_id, self.job_id, return_run_context
113+
)
114+
115+
_get_job = duet.sync(_get_job_async)
116+
110117
def _inner_job(self) -> quantum.QuantumJob:
111118
if self._job is None:
112-
self._job = self.context.client.get_job(
113-
self.project_id, self.program_id, self.job_id, False
114-
)
119+
self._job = self._get_job()
115120
return self._job
116121

117-
def _refresh_job(self) -> quantum.QuantumJob:
122+
async def _refresh_job_async(self) -> quantum.QuantumJob:
118123
if self._job is None or self._job.execution_status.state not in TERMINAL_STATES:
119-
self._job = self.context.client.get_job(
120-
self.project_id, self.program_id, self.job_id, False
121-
)
124+
self._job = await self._get_job_async()
122125
return self._job
123126

127+
_refresh_job = duet.sync(_refresh_job_async)
128+
124129
def create_time(self) -> 'datetime.datetime':
125130
"""Returns when the job was created."""
126131
return self._inner_job().create_time
@@ -224,10 +229,7 @@ def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]:
224229
A tuple of the repetition count and list of sweeps.
225230
"""
226231
if self._job is None or self._job.run_context is None:
227-
self._job = self.context.client.get_job(
228-
self.project_id, self.program_id, self.job_id, True
229-
)
230-
232+
self._job = self._get_job(return_run_context=True)
231233
return _deserialize_run_context(self._job.run_context)
232234

233235
def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]':
@@ -260,42 +262,26 @@ def delete(self) -> None:
260262
"""Deletes the job and result, if any."""
261263
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)
262264

263-
def batched_results(self) -> Sequence[Sequence[EngineResult]]:
265+
async def batched_results_async(self) -> Sequence[Sequence[EngineResult]]:
264266
"""Returns the job results, blocking until the job is complete.
265267
266268
This method is intended for batched jobs. Instead of flattening
267269
results into a single list, this will return a Sequence[Result]
268270
for each circuit in the batch.
269271
"""
270-
self.results()
272+
await self.results_async()
271273
if self._batched_results is None:
272274
raise ValueError('batched_results called for a non-batch result.')
273275
return self._batched_results
274276

275-
def _wait_for_result(self):
276-
job = self._refresh_job()
277-
total_seconds_waited = 0.0
278-
timeout = self.context.timeout
279-
while True:
280-
if timeout and total_seconds_waited >= timeout:
281-
break
282-
if job.execution_status.state in TERMINAL_STATES:
283-
break
284-
time.sleep(0.5)
285-
total_seconds_waited += 0.5
286-
job = self._refresh_job()
287-
_raise_on_failure(job)
288-
response = self.context.client.get_job_results(
289-
self.project_id, self.program_id, self.job_id
290-
)
291-
return response.result
277+
batched_results = duet.sync(batched_results_async)
292278

293-
def results(self) -> Sequence[EngineResult]:
279+
async def results_async(self) -> Sequence[EngineResult]:
294280
"""Returns the job results, blocking until the job is complete."""
295281
import cirq_google.engine.engine as engine_base
296282

297283
if self._results is None:
298-
result = self._wait_for_result()
284+
result = await self._await_result_async()
299285
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
300286
if (
301287
result_type == 'cirq.google.api.v1.Result'
@@ -317,7 +303,22 @@ def results(self) -> Sequence[EngineResult]:
317303
raise ValueError(f'invalid result proto version: {result_type}')
318304
return self._results
319305

320-
def calibration_results(self) -> Sequence[CalibrationResult]:
306+
results = duet.sync(results_async)
307+
308+
async def _await_result_async(self) -> quantum.QuantumResult:
309+
async with duet.timeout_scope(self.context.timeout):
310+
while True:
311+
job = await self._refresh_job_async()
312+
if job.execution_status.state in TERMINAL_STATES:
313+
break
314+
await duet.sleep(0.5)
315+
_raise_on_failure(job)
316+
response = await self.context.client.get_job_results_async(
317+
self.project_id, self.program_id, self.job_id
318+
)
319+
return response.result
320+
321+
async def calibration_results_async(self) -> Sequence[CalibrationResult]:
321322
"""Returns the results of a run_calibration() call.
322323
323324
This function will fail if any other type of results were returned
@@ -326,7 +327,7 @@ def calibration_results(self) -> Sequence[CalibrationResult]:
326327
import cirq_google.engine.engine as engine_base
327328

328329
if self._calibration_results is None:
329-
result = self._wait_for_result()
330+
result = await self._await_result_async()
330331
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
331332
if result_type != 'cirq.google.api.v2.FocusedCalibrationResult':
332333
raise ValueError(f'Did not find calibration results, instead found: {result_type}')
@@ -343,6 +344,8 @@ def calibration_results(self) -> Sequence[CalibrationResult]:
343344
self._calibration_results = cal_results
344345
return self._calibration_results
345346

347+
calibration_results = duet.sync(calibration_results_async)
348+
346349
def _get_job_results_v1(self, result: v1.program_pb2.Result) -> Sequence[EngineResult]:
347350
# coverage: ignore
348351
job_id = self.id()

0 commit comments

Comments
 (0)