Commit 7d5d7ea
Add Queue.wait_for_completion()
1 parent 16935a1

5 files changed: +143 -62 lines changed

Diff for: pgjobq/_queue.py (+102 -46)

@@ -13,7 +13,6 @@
     Awaitable,
     Callable,
     Dict,
-    Hashable,
     List,
     Mapping,
     Optional,
@@ -23,16 +22,18 @@
 
 import anyio
 import asyncpg  # type: ignore
+from anyio.abc import TaskGroup
 
+from pgjobq.api import CompletionHandle as AbstractCompletionHandle
 from pgjobq.api import JobHandle
 from pgjobq.api import JobHandleStream as AbstractJobHandleStream
 from pgjobq.api import Message
 from pgjobq.api import Queue as AbstractQueue
 from pgjobq.api import QueueStatistics
-from pgjobq.api import SendCompletionHandle as AbstractCompletionHandle
 from pgjobq.sql._functions import (
     ack_message,
     extend_ack_deadlines,
+    get_completed_jobs,
     get_statistics,
     nack_message,
     poll_for_messages,
@@ -136,8 +137,8 @@ def __anext__(self) -> Awaitable[JobHandle]:
 class Queue(AbstractQueue):
     pool: asyncpg.Pool
     queue_name: str
-    completion_callbacks: Dict[str, Dict[Hashable, anyio.Event]]
-    new_job_callbacks: Dict[str, Set[Callable[[], None]]]
+    completion_callbacks: Dict[UUID, Set[anyio.Event]]
+    new_job_callbacks: Set[Callable[[], None]]
     in_flight_jobs: Dict[UUID, Set[JobManager]]
 
     @asynccontextmanager
@@ -186,7 +187,7 @@ async def get_next_job() -> JobHandle:
 
                 # wait for a new job to be published or the poll interval to expire
                 new_job = anyio.Event()
-                self.new_job_callbacks[self.queue_name].add(new_job.set)  # type: ignore
+                self.new_job_callbacks.add(new_job.set)  # type: ignore
 
                 async def skip_forward_if_timeout() -> None:
                     await anyio.sleep(poll_interval)
@@ -198,9 +199,7 @@ async def skip_forward_if_timeout() -> None:
                         await new_job.wait()
                         gather_tg.cancel_scope.cancel()
                 finally:
-                    self.new_job_callbacks[self.queue_name].discard(
-                        new_job.set  # type: ignore
-                    )
+                    self.new_job_callbacks.discard(new_job.set)  # type: ignore
 
                 return unyielded_jobs.pop()
 
@@ -214,31 +213,86 @@ async def skip_forward_if_timeout() -> None:
     def send(
         self, body: bytes, *bodies: bytes, delay: Optional[timedelta] = None
     ) -> AsyncContextManager[AbstractCompletionHandle]:
-        # create the job id application side
-        # so that we can start listening before we send
-        all_bodies = [body, *bodies]
-        job_ids = [uuid4() for _ in range(len(all_bodies))]
-        done_events = {id: anyio.Event() for id in job_ids}
+        @asynccontextmanager
+        async def cm() -> AsyncIterator[AbstractCompletionHandle]:
+            # create the job id application side
+            # so that we can start listening before we send
+            all_bodies = [body, *bodies]
+            job_ids = [uuid4() for _ in range(len(all_bodies))]
+            async with self.wait_for_completion(*job_ids, poll_interval=None) as handle:
+                conn: asyncpg.Connection
+                async with self.pool.acquire() as conn:  # type: ignore
+                    await publish_messages(
+                        conn,
+                        queue_name=self.queue_name,
+                        message_ids=job_ids,
+                        message_bodies=all_bodies,
+                        delay=delay,
+                    )
+                yield handle
+
+        return cm()
 
+    def wait_for_completion(
+        self,
+        job: UUID,
+        *jobs: UUID,
+        poll_interval: Optional[timedelta] = timedelta(seconds=10),
+    ) -> AsyncContextManager[AbstractCompletionHandle]:
         @asynccontextmanager
         async def cm() -> AsyncIterator[AbstractCompletionHandle]:
-            conn: asyncpg.Connection
-            async with self.pool.acquire() as conn:  # type: ignore
-                for job_id, done_event in done_events.items():
-                    self.completion_callbacks[self.queue_name][job_id] = done_event
-                await publish_messages(
-                    conn,
-                    queue_name=self.queue_name,
-                    message_ids=job_ids,
-                    message_bodies=all_bodies,
-                    delay=delay,
-                )
-            handle = JobCompletionHandle(jobs=done_events)
-            try:
-                yield handle
-            finally:
-                for job_id, done_event in done_events.items():
-                    self.completion_callbacks[self.queue_name].pop(job_id)
+            done_events = {id: anyio.Event() for id in (job, *jobs)}
+            for job_id, event in done_events.items():
+                self.completion_callbacks[job_id].add(event)
+
+            def cleanup_done_events() -> None:
+                for job_id in done_events.copy().keys():
+                    if done_events[job_id].is_set():
+                        event = done_events.pop(job_id)
+                        self.completion_callbacks[job_id].discard(event)
+                        if not self.completion_callbacks[job_id]:
+                            self.completion_callbacks.pop(job_id)
+
+            async def poll_for_completion(interval: float) -> None:
+                nonlocal done_events
+                while True:
+                    new_completion = anyio.Event()
+                    # wait for a completion notification or poll interval to expire
+
+                    async def set_new_completion(
+                        job_id: UUID, event: anyio.Event, tg: TaskGroup
+                    ) -> None:
+                        await event.wait()
+                        new_completion.set()
+                        tg.cancel_scope.cancel()
+
+                    async with anyio.create_task_group() as tg:
+                        for event in done_events.values():
+                            tg.start_soon(set_new_completion, event, tg)
+                        await anyio.sleep(interval)
+
+                    if not new_completion.is_set():
+                        # poll
+                        completed_jobs = await get_completed_jobs(
+                            self.pool, self.queue_name, job_ids=list(done_events.keys())
+                        )
+                        if completed_jobs:
+                            new_completion.set()
+                            for job in completed_jobs:
+                                done_events[job].set()
+                    if new_completion.is_set():
+                        cleanup_done_events()
+
+            try:
+                async with anyio.create_task_group() as tg:
+                    if poll_interval:
+                        tg.start_soon(
+                            poll_for_completion, poll_interval.total_seconds()
+                        )
+                    yield JobCompletionHandle(jobs=done_events.copy())
+                    tg.cancel_scope.cancel()
+            finally:
+                cleanup_done_events()
 
         return cm()
 
@@ -267,9 +321,9 @@ async def connect_to_queue(
     Returns:
         AsyncContextManager[AbstractQueue]: A context manager yielding an AbstractQueue
    """
-    completion_callbacks: Dict[str, Dict[Hashable, anyio.Event]] = defaultdict(dict)
-    new_job_callbacks: Dict[str, Set[Callable[[], None]]] = defaultdict(set)
-    in_flight_jobs: Dict[UUID, Set[JobManager]] = {}
+    completion_callbacks: Dict[UUID, Set[anyio.Event]] = defaultdict(set)
+    new_job_callbacks: Set[Callable[[], None]] = set()
+    checked_out_jobs: Dict[UUID, Set[JobManager]] = {}
 
     async def run_cleanup(conn: asyncpg.Connection) -> None:
         while True:
@@ -281,7 +335,7 @@ async def run_cleanup(conn: asyncpg.Connection) -> None:
 
     async def extend_acks(conn: asyncpg.Connection) -> None:
         while True:
             job_ids = [
-                job.message.id for jobs in in_flight_jobs.values() for job in jobs
+                job.message.id for jobs in checked_out_jobs.values() for job in jobs
             ]
             await extend_ack_deadlines(
                 conn,
@@ -296,40 +350,42 @@ async def process_completion_notification(
         channel: str,
         payload: str,
     ) -> None:
-        queue_name, job_id = payload.split(",")
-        cb = completion_callbacks[queue_name].get(UUID(job_id), None)
-        if cb is not None:
-            cb.set()
+        job_id = payload
+        job_id_key = UUID(job_id)
+        events = completion_callbacks.get(job_id_key, None) or ()
+        for event in events:
+            event.set()
 
     async def process_new_job_notification(
         conn: asyncpg.Connection,
         pid: int,
         channel: str,
         payload: str,
     ) -> None:
-        queue_name, *_ = payload.split(",")
-        for event in new_job_callbacks[queue_name]:
-            event()
+        for cb in new_job_callbacks:
+            cb()
 
     async with AsyncExitStack() as stack:
         cleanup_conn: asyncpg.Connection = await stack.enter_async_context(pool.acquire())  # type: ignore
         ack_conn: asyncpg.Connection = await stack.enter_async_context(pool.acquire())  # type: ignore
+        completion_channel = f"pgjobq.job_completed_{queue_name}"
+        new_job_channel = f"pgjobq.new_job_{queue_name}"
         await cleanup_conn.add_listener(  # type: ignore
-            channel="pgjobq.job_completed",
+            channel=completion_channel,
             callback=process_completion_notification,
         )
         stack.push_async_callback(
            cleanup_conn.remove_listener,  # type: ignore
-            channel="pgjobq.job_completed",
+            channel=completion_channel,
             callback=process_completion_notification,
         )
         await cleanup_conn.add_listener(  # type: ignore
-            channel="pgjobq.new_job",
+            channel=new_job_channel,
             callback=process_new_job_notification,
         )
         stack.push_async_callback(
             cleanup_conn.remove_listener,  # type: ignore
-            channel="pgjobq.new_job",
+            channel=new_job_channel,
             callback=process_new_job_notification,
         )
         async with anyio.create_task_group() as tg:
@@ -341,7 +397,7 @@ async def process_new_job_notification(
             queue_name=queue_name,
             completion_callbacks=completion_callbacks,
             new_job_callbacks=new_job_callbacks,
-            in_flight_jobs=in_flight_jobs,
+            in_flight_jobs=checked_out_jobs,
         )
         try:
             yield queue
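
For orientation, here is a minimal usage sketch of the queue API after this change. The queue methods and `await handle()` are taken from this commit's diffs and updated test; the `connect_to_queue` import path, pool setup, and DSN are illustrative assumptions, not part of the commit:

    import anyio
    import asyncpg

    from pgjobq import connect_to_queue  # assumed import path

    async def main() -> None:
        pool = await asyncpg.create_pool("postgresql://localhost/postgres")  # assumed DSN
        async with connect_to_queue("test-queue", pool) as queue:
            # send() now builds its completion handle via wait_for_completion()
            # with poll_interval=None, i.e. purely notification-driven
            async with queue.send(b'{"foo":"bar"}') as handle:
                await handle()  # wait for the job to complete, as in the updated test
            # wait_for_completion() also works standalone against known job IDs,
            # polling every 10s as a fallback in case a NOTIFY is missed
            async with queue.wait_for_completion(*handle.jobs.keys()) as handle:
                for event in handle.jobs.values():
                    await event.wait()
        await pool.close()

    anyio.run(main)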

Diff for: pgjobq/api.py (+20 -2)

@@ -48,7 +48,7 @@ def acquire(self) -> AsyncContextManager[Message]:
         ...
 
 
-class SendCompletionHandle(Protocol):
+class CompletionHandle(Protocol):
     @property
     def jobs(self) -> Mapping[UUID, anyio.Event]:
         """Completion events for each published job"""
@@ -102,7 +102,7 @@ def receive(
     @abstractmethod
     def send(
         self, body: bytes, *bodies: bytes, delay: Optional[timedelta] = None
-    ) -> AsyncContextManager[SendCompletionHandle]:
+    ) -> AsyncContextManager[CompletionHandle]:
         """Put jobs on the queue.
 
         You _must_ enter the context manager but awaiting the completion
@@ -116,6 +116,24 @@ def send(
         """
         pass  # pragma: no cover
 
+    @abstractmethod
+    def wait_for_completion(
+        self,
+        job: UUID,
+        *jobs: UUID,
+        poll_interval: timedelta = timedelta(seconds=10),
+    ) -> AsyncContextManager[CompletionHandle]:
+        """Wait for a job or group of jobs to complete
+
+        Args:
+            job (UUID): job ID as returned by Queue.send()
+            poll_interval (timedelta, optional): interval to poll for completion. Defaults to 10 seconds.
+
+        Returns:
+            AsyncContextManager[CompletionHandle]: A context manager that returns a completion handle.
+        """
+        pass
+
     @abstractmethod
     async def get_statistics(self) -> QueueStatistics:
         """Gather statistics from the queue.

Diff for: pgjobq/sql/_functions.py (+12 -4)

@@ -54,7 +54,7 @@ def __init__(self, *, queue_name: str) -> None:
         unnest($3::bytea[])
     FROM queue_info
     RETURNING (
-        SELECT 'true'::bool FROM pg_notify('pgjobq.new_job', $1)
+        SELECT 'true'::bool FROM pg_notify('pgjobq.new_job_' || $1, '')
     );  -- NULL if the queue doesn't exist
 """
 
@@ -105,7 +105,7 @@ async def publish_messages(
 ), updated_queue_info AS (
     UPDATE pgjobq.queues
     SET
-        undelivered_message_count = undelivered_message_count - (SELECT sum(first_delivery) FROM selected_messages)
+        undelivered_message_count = undelivered_message_count - (SELECT COALESCE(sum(first_delivery), 0) FROM selected_messages)
     WHERE id = (SELECT id FROM queue_info)
 )
 UPDATE pgjobq.messages
@@ -156,7 +156,7 @@ async def poll_for_messages(
 WHERE pgjobq.messages.id = (SELECT id FROM msg)
 RETURNING (
     SELECT
-        pg_notify('pgjobq.job_completed', $1 || ',' || CAST($2::uuid AS text))
+        pg_notify('pgjobq.job_completed_' || $1, CAST($2::uuid AS text))
 ) AS notified;
 """
 
@@ -175,7 +175,7 @@ async def ack_message(
 -- which check to make sure the message is still available before extending
 SET available_at = now() - '1 second'::interval
 WHERE queue_id = (SELECT id FROM pgjobq.queues WHERE name = $1) AND id = $2
-RETURNING (SELECT pg_notify('pgjobq.new_job', $1));
+RETURNING (SELECT pg_notify('pgjobq.new_job_' || $1, ''));
 """
 
@@ -248,3 +248,11 @@ async def get_statistics(
     if record is None:
         raise QueueDoesNotExist(queue_name=queue_name)
     return record
+
+
+async def get_completed_jobs(
+    conn: PoolOrConnection,
+    queue_name: str,
+    job_ids: List[UUID],
+) -> Sequence[UUID]:
+    ...
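
Note that this commit leaves the body of get_completed_jobs() as a stub (`...`) even though _queue.py already calls it during fallback polling. For illustration only, one plausible implementation, assuming a job counts as completed once its row no longer exists in pgjobq.messages; this is a guess, not the author's code:

    async def get_completed_jobs(
        conn: PoolOrConnection,
        queue_name: str,
        job_ids: List[UUID],
    ) -> Sequence[UUID]:
        # hypothetical: return the subset of job_ids with no remaining message row
        records = await conn.fetch(
            """
            SELECT ids.id
            FROM unnest($2::uuid[]) AS ids(id)
            WHERE NOT EXISTS (
                SELECT 1
                FROM pgjobq.messages
                JOIN pgjobq.queues ON pgjobq.messages.queue_id = pgjobq.queues.id
                WHERE pgjobq.queues.name = $1
                AND pgjobq.messages.id = ids.id
            );
            """,
            queue_name,
            job_ids,
        )
        return [record["id"] for record in records]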

Diff for: pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pgjobq"
-version = "0.3.0"
+version = "0.4.0"
 description = "PostgreSQL backed job queues"
 authors = ["Adrian Garcia Badaracco <[email protected]>"]
 license = "MIT"

Diff for: tests/test_queue.py (+8 -9)

@@ -282,19 +282,18 @@ async def test_new_message_notification_triggers_poll(
 
     async with anyio.create_task_group() as tg:
 
-        async def worker() -> None:
+        async def worker(*, task_status: TaskStatus) -> None:
             async with queue.receive(poll_interval=60) as job_handle_stream:
-                await job_handle_stream.receive()
-                rcv_times.append(time())
-                return
+                task_status.started()
+                async with (await job_handle_stream.receive()).acquire():
+                    rcv_times.append(time())
 
-        tg.start_soon(worker)
-        # wait for the worker to start polling
-        await anyio.sleep(0.05)
+        await tg.start(worker)
 
-        async with queue.send(b'{"foo":"bar"}'):
+        async with queue.send(b'{"foo":"bar"}') as handle:
             send_times.append(time())
-            pass
+            await handle()
+            print(1)
 
     assert len(send_times) == len(rcv_times)
     # not deterministic
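
The reworked test swaps the sleep-based startup wait for anyio's task-status handshake: tg.start() does not return until the child calls task_status.started(), which removes the 0.05-second race between starting the worker and sending. A minimal standalone illustration of that anyio pattern:

    import anyio
    from anyio.abc import TaskStatus

    async def worker(*, task_status: TaskStatus = anyio.TASK_STATUS_IGNORED) -> None:
        # signal the parent that we are ready before doing real work
        task_status.started()
        await anyio.sleep(1)

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            await tg.start(worker)  # resumes only after task_status.started()

    anyio.run(main)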
