@@ -13,7 +13,6 @@
     Awaitable,
     Callable,
     Dict,
-    Hashable,
     List,
     Mapping,
     Optional,
@@ -23,16 +22,18 @@
 
 import anyio
 import asyncpg  # type: ignore
+from anyio.abc import TaskGroup
 
+from pgjobq.api import CompletionHandle as AbstractCompletionHandle
 from pgjobq.api import JobHandle
 from pgjobq.api import JobHandleStream as AbstractJobHandleStream
 from pgjobq.api import Message
 from pgjobq.api import Queue as AbstractQueue
 from pgjobq.api import QueueStatistics
-from pgjobq.api import SendCompletionHandle as AbstractCompletionHandle
 from pgjobq.sql._functions import (
     ack_message,
     extend_ack_deadlines,
+    get_completed_jobs,
     get_statistics,
     nack_message,
     poll_for_messages,
@@ -136,8 +137,8 @@ def __anext__(self) -> Awaitable[JobHandle]:
 class Queue(AbstractQueue):
     pool: asyncpg.Pool
     queue_name: str
-    completion_callbacks: Dict[str, Dict[Hashable, anyio.Event]]
-    new_job_callbacks: Dict[str, Set[Callable[[], None]]]
+    completion_callbacks: Dict[UUID, Set[anyio.Event]]
+    new_job_callbacks: Set[Callable[[], None]]
     in_flight_jobs: Dict[UUID, Set[JobManager]]
 
     @asynccontextmanager
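
Since each Queue instance is already bound to one queue, the callback registries drop their queue-name level: completion callbacks now map a job UUID to the set of events waiting on it, and new-job callbacks become a flat set of zero-argument callables. A minimal standalone sketch of the new registry shape (the surrounding script is illustrative, not part of the diff):

from collections import defaultdict
from typing import Callable, Dict, Set
from uuid import UUID, uuid4

import anyio

completion_callbacks: Dict[UUID, Set[anyio.Event]] = defaultdict(set)
new_job_callbacks: Set[Callable[[], None]] = set()


async def main() -> None:
    job_id = uuid4()
    done = anyio.Event()
    # a Set of events means several waiters can watch the same job
    completion_callbacks[job_id].add(done)

    # a completion notification for job_id wakes every registered waiter
    for event in completion_callbacks.get(job_id, set()):
        event.set()
    assert done.is_set()


anyio.run(main)
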
@@ -186,7 +187,7 @@ async def get_next_job() -> JobHandle:
 
                 # wait for a new job to be published or the poll interval to expire
                 new_job = anyio.Event()
-                self.new_job_callbacks[self.queue_name].add(new_job.set)  # type: ignore
+                self.new_job_callbacks.add(new_job.set)  # type: ignore
 
                 async def skip_forward_if_timeout() -> None:
                     await anyio.sleep(poll_interval)
@@ -198,9 +199,7 @@ async def skip_forward_if_timeout() -> None:
                         await new_job.wait()
                         gather_tg.cancel_scope.cancel()
                 finally:
-                    self.new_job_callbacks[self.queue_name].discard(
-                        new_job.set  # type: ignore
-                    )
+                    self.new_job_callbacks.discard(new_job.set)  # type: ignore
 
             return unyielded_jobs.pop()
 
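
Both receive-side hunks above implement the same wakeup race: `new_job.set` is registered as a notification callback, a timer task enforces the poll interval, and whichever fires first cancels the task group so the loop can poll again. A self-contained sketch of that pattern (the callback registry and interval are illustrative):

from typing import Callable, Set

import anyio

new_job_callbacks: Set[Callable[[], None]] = set()


async def wait_for_new_job_or_timeout(poll_interval: float) -> None:
    new_job = anyio.Event()
    # a NOTIFY handler would invoke this callback to wake us early
    new_job_callbacks.add(new_job.set)
    try:
        async with anyio.create_task_group() as gather_tg:

            async def skip_forward_if_timeout() -> None:
                await anyio.sleep(poll_interval)
                gather_tg.cancel_scope.cancel()

            gather_tg.start_soon(skip_forward_if_timeout)
            await new_job.wait()
            gather_tg.cancel_scope.cancel()
    finally:
        # bound methods compare equal, so discard removes the callback added above
        new_job_callbacks.discard(new_job.set)


anyio.run(wait_for_new_job_or_timeout, 0.01)
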
@@ -214,31 +213,86 @@ async def skip_forward_if_timeout() -> None:
     def send(
         self, body: bytes, *bodies: bytes, delay: Optional[timedelta] = None
     ) -> AsyncContextManager[AbstractCompletionHandle]:
-        # create the job id application side
-        # so that we can start listening before we send
-        all_bodies = [body, *bodies]
-        job_ids = [uuid4() for _ in range(len(all_bodies))]
-        done_events = {id: anyio.Event() for id in job_ids}
+        @asynccontextmanager
+        async def cm() -> AsyncIterator[AbstractCompletionHandle]:
+            # create the job id application side
+            # so that we can start listening before we send
+            all_bodies = [body, *bodies]
+            job_ids = [uuid4() for _ in range(len(all_bodies))]
+            async with self.wait_for_completion(*job_ids, poll_interval=None) as handle:
+                conn: asyncpg.Connection
+                async with self.pool.acquire() as conn:  # type: ignore
+                    await publish_messages(
+                        conn,
+                        queue_name=self.queue_name,
+                        message_ids=job_ids,
+                        message_bodies=all_bodies,
+                        delay=delay,
+                    )
+                yield handle
+
+        return cm()
 
+    def wait_for_completion(
+        self,
+        job: UUID,
+        *jobs: UUID,
+        poll_interval: Optional[timedelta] = timedelta(seconds=10),
+    ) -> AsyncContextManager[AbstractCompletionHandle]:
         @asynccontextmanager
         async def cm() -> AsyncIterator[AbstractCompletionHandle]:
-            conn: asyncpg.Connection
-            async with self.pool.acquire() as conn:  # type: ignore
-                for job_id, done_event in done_events.items():
-                    self.completion_callbacks[self.queue_name][job_id] = done_event
-                await publish_messages(
-                    conn,
-                    queue_name=self.queue_name,
-                    message_ids=job_ids,
-                    message_bodies=all_bodies,
-                    delay=delay,
-                )
-            handle = JobCompletionHandle(jobs=done_events)
-            try:
-                yield handle
-            finally:
-                for job_id, done_event in done_events.items():
-                    self.completion_callbacks[self.queue_name].pop(job_id)
+            done_events = {id: anyio.Event() for id in (job, *jobs)}
+            for job_id, event in done_events.items():
+                self.completion_callbacks[job_id].add(event)
+
+            def cleanup_done_events() -> None:
+                for job_id in done_events.copy().keys():
+                    if done_events[job_id].is_set():
+                        event = done_events.pop(job_id)
+                        self.completion_callbacks[job_id].discard(event)
+                        if not self.completion_callbacks[job_id]:
+                            self.completion_callbacks.pop(job_id)
+
+            async def poll_for_completion(interval: float) -> None:
+                nonlocal done_events
+                while True:
+                    new_completion = anyio.Event()
+                    # wait for a completion notification or poll interval to expire
+
+                    async def set_new_completion(
+                        event: anyio.Event, tg: TaskGroup
+                    ) -> None:
+                        await event.wait()
+                        new_completion.set()
+                        tg.cancel_scope.cancel()
+
+                    async with anyio.create_task_group() as tg:
+                        for event in done_events.values():
+                            tg.start_soon(set_new_completion, event, tg)
+                        await anyio.sleep(interval)
+
+                    if not new_completion.is_set():
+                        # poll
+                        completed_jobs = await get_completed_jobs(
+                            self.pool, self.queue_name, job_ids=list(done_events.keys())
+                        )
+                        if completed_jobs:
+                            new_completion.set()
+                            for job in completed_jobs:
+                                done_events[job].set()
+                    if new_completion.is_set():
+                        cleanup_done_events()
+
+            try:
+                async with anyio.create_task_group() as tg:
+                    if poll_interval:
+                        tg.start_soon(
+                            poll_for_completion, poll_interval.total_seconds()
+                        )
+                    yield JobCompletionHandle(jobs=done_events.copy())
+                    tg.cancel_scope.cancel()
+            finally:
+                cleanup_done_events()
 
         return cm()
 
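
For orientation, a hypothetical end-to-end use of the split API. The top-level `connect_to_queue` import path, keyword arguments, DSN, and queue name are assumptions, and iterating the handle's per-job events stands in for whatever convenience method `JobCompletionHandle` exposes:

import anyio
import asyncpg

from pgjobq import connect_to_queue  # assumed top-level export


async def main() -> None:
    # assumes a local database where the "test-queue" queue already exists
    pool = await asyncpg.create_pool("postgresql://localhost/postgres")
    async with connect_to_queue(queue_name="test-queue", pool=pool) as queue:
        # send() registers completion listeners *before* publishing,
        # then delegates the waiting to wait_for_completion()
        async with queue.send(b"payload") as handle:
            # the handle was built from {job_id: anyio.Event}; each wait
            # returns only once a consumer acks the corresponding job
            for event in handle.jobs.values():
                await event.wait()
    await pool.close()


anyio.run(main)
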
@@ -267,9 +321,9 @@ async def connect_to_queue(
     Returns:
         AsyncContextManager[AbstractQueue]: A context manager yielding an AbstractQueue
     """
-    completion_callbacks: Dict[str, Dict[Hashable, anyio.Event]] = defaultdict(dict)
-    new_job_callbacks: Dict[str, Set[Callable[[], None]]] = defaultdict(set)
-    in_flight_jobs: Dict[UUID, Set[JobManager]] = {}
+    completion_callbacks: Dict[UUID, Set[anyio.Event]] = defaultdict(set)
+    new_job_callbacks: Set[Callable[[], None]] = set()
+    checked_out_jobs: Dict[UUID, Set[JobManager]] = {}
 
     async def run_cleanup(conn: asyncpg.Connection) -> None:
         while True:
@@ -281,7 +335,7 @@ async def run_cleanup(conn: asyncpg.Connection) -> None:
     async def extend_acks(conn: asyncpg.Connection) -> None:
         while True:
             job_ids = [
-                job.message.id for jobs in in_flight_jobs.values() for job in jobs
+                job.message.id for jobs in checked_out_jobs.values() for job in jobs
             ]
             await extend_ack_deadlines(
                 conn,
@@ -296,40 +350,42 @@ async def process_completion_notification(
         channel: str,
         payload: str,
     ) -> None:
-        queue_name, job_id = payload.split(",")
-        cb = completion_callbacks[queue_name].get(UUID(job_id), None)
-        if cb is not None:
-            cb.set()
+        job_id = payload
+        job_id_key = UUID(job_id)
+        events = completion_callbacks.get(job_id_key, None) or ()
+        for event in events:
+            event.set()
 
     async def process_new_job_notification(
         conn: asyncpg.Connection,
         pid: int,
         channel: str,
         payload: str,
     ) -> None:
-        queue_name, *_ = payload.split(",")
-        for event in new_job_callbacks[queue_name]:
-            event()
+        for cb in new_job_callbacks:
+            cb()
 
     async with AsyncExitStack() as stack:
         cleanup_conn: asyncpg.Connection = await stack.enter_async_context(pool.acquire())  # type: ignore
         ack_conn: asyncpg.Connection = await stack.enter_async_context(pool.acquire())  # type: ignore
+        completion_channel = f"pgjobq.job_completed_{queue_name}"
+        new_job_channel = f"pgjobq.new_job_{queue_name}"
         await cleanup_conn.add_listener(  # type: ignore
-            channel="pgjobq.job_completed",
+            channel=completion_channel,
            callback=process_completion_notification,
        )
        stack.push_async_callback(
            cleanup_conn.remove_listener,  # type: ignore
-            channel="pgjobq.job_completed",
+            channel=completion_channel,
            callback=process_completion_notification,
        )
        await cleanup_conn.add_listener(  # type: ignore
-            channel="pgjobq.new_job",
+            channel=new_job_channel,
            callback=process_new_job_notification,
        )
        stack.push_async_callback(
            cleanup_conn.remove_listener,  # type: ignore
-            channel="pgjobq.new_job",
+            channel=new_job_channel,
            callback=process_new_job_notification,
        )
        async with anyio.create_task_group() as tg:
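
The listener channels are now namespaced per queue, so the NOTIFY payload no longer needs to carry the queue name: for completions it is just the job UUID, matching `process_completion_notification` above. An illustrative round-trip with plain asyncpg (the DSN is an assumption, and the real NOTIFY is fired server-side, not shown in this diff):

from uuid import uuid4

import anyio
import asyncpg


async def main() -> None:
    conn = await asyncpg.connect("postgresql://localhost/postgres")
    # channel naming follows the diff: one channel per queue
    completion_channel = "pgjobq.job_completed_test-queue"

    def on_completion(con, pid, channel, payload) -> None:
        print(f"job {payload} completed on {channel}")

    await conn.add_listener(completion_channel, on_completion)
    # stand-in for the server-side trigger that fires when a job is acked
    await conn.execute("SELECT pg_notify($1, $2)", completion_channel, str(uuid4()))
    await anyio.sleep(0.1)  # let the notification callback run
    await conn.remove_listener(completion_channel, on_completion)
    await conn.close()


anyio.run(main)
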
@@ -341,7 +397,7 @@ async def process_new_job_notification(
                 queue_name=queue_name,
                 completion_callbacks=completion_callbacks,
                 new_job_callbacks=new_job_callbacks,
-                in_flight_jobs=in_flight_jobs,
+                in_flight_jobs=checked_out_jobs,
             )
             try:
                 yield queue