38
38
from collections .abc import Mapping
39
39
from collections .abc import Sequence
40
40
41
+ from linkd import compose
41
42
from linkd import conditions
42
43
from linkd import container
43
44
from linkd import context as context_
58
59
R = t .TypeVar ("R" )
59
60
T = t .TypeVar ("T" )
60
61
AsyncFnT = t .TypeVar ("AsyncFnT" , bound = t .Callable [..., Coroutine [t .Any , t .Any , t .Any ]])
62
+ DependencyExprOrComposed : t .TypeAlias = t .Union [conditions .DependencyExpression [t .Any ], type [compose .Compose ]]
63
+
64
+ LOGGER = logging .getLogger (__name__ )
61
65
62
66
DI_ENABLED : t .Final [bool ] = os .environ .get ("LINKD_DI_DISABLED" , "false" ).lower () != "true"
63
67
DI_CONTAINER : contextvars .ContextVar [container .Container | None ] = contextvars .ContextVar (
64
68
"linkd_container" , default = None
65
69
)
66
- LOGGER = logging .getLogger (__name__ )
67
70
68
71
INJECTED : t .Final [t .Any ] = utils .Marker ("INJECTED" )
69
72
"""
@@ -288,17 +291,39 @@ async def close(self) -> None:
288
291
CANNOT_INJECT : t .Final [t .Any ] = utils .Marker ("CANNOT_INJECT" )
289
292
290
293
291
- def _parse_injectable_params (func : Callable [..., t .Any ]) -> tuple [list [tuple [str , t .Any ]], dict [str , t .Any ]]:
292
- positional_or_keyword_params : list [tuple [str , t .Any ]] = []
293
- keyword_only_params : dict [str , t .Any ] = {}
294
+ def _parse_composed_dependencies (cls : type [compose .Compose ]) -> dict [str , conditions .DependencyExpression [t .Any ]]:
295
+ if (existing := getattr (cls , compose ._DEPS_ATTR , None )) is not None :
296
+ return existing
297
+
298
+ actual_class = getattr (cls , compose ._ACTUAL_ATTR , None )
299
+ if actual_class is None :
300
+ raise TypeError (f"class { cls } is not a composed dependency" )
301
+
302
+ actual_class = t .cast ("type[t.Any]" , actual_class )
303
+ hints = t .get_type_hints (
304
+ actual_class , localns = {m : sys .modules [m ] for m in utils .ANNOTATION_PARSE_LOCAL_INCLUDE_MODULES }
305
+ )
306
+ return {
307
+ name : conditions .DependencyExpression .create (annotation )
308
+ for name , annotation in hints .items ()
309
+ if name in getattr (cls , "__slots__" )
310
+ }
311
+
312
+
313
+ def _parse_injectable_params (
314
+ func : Callable [..., t .Any ],
315
+ ) -> tuple [list [tuple [str , DependencyExprOrComposed ]], dict [str , DependencyExprOrComposed ]]:
316
+ positional_or_keyword_params : list [tuple [str , DependencyExprOrComposed ]] = []
317
+ keyword_only_params : dict [str , DependencyExprOrComposed ] = {}
294
318
295
319
parameters = inspect .signature (
296
320
func , locals = {m : sys .modules [m ] for m in utils .ANNOTATION_PARSE_LOCAL_INCLUDE_MODULES }, eval_str = True
297
321
).parameters
298
322
for parameter in parameters .values ():
323
+ annotation = parameter .annotation
299
324
if (
300
325
# If the parameter has no annotation
301
- parameter . annotation is inspect .Parameter .empty
326
+ annotation is inspect .Parameter .empty
302
327
# If the parameter is not positional-or-keyword or keyword-only
303
328
or parameter .kind
304
329
in (inspect .Parameter .POSITIONAL_ONLY , inspect .Parameter .VAR_POSITIONAL , inspect .Parameter .VAR_KEYWORD )
@@ -309,12 +334,17 @@ def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str
309
334
positional_or_keyword_params .append ((parameter .name , CANNOT_INJECT ))
310
335
continue
311
336
312
- expr = conditions .DependencyExpression .create (parameter .annotation )
337
+ if compose ._is_compose_class (annotation ):
338
+ setattr (annotation , compose ._DEPS_ATTR , _parse_composed_dependencies (annotation ))
339
+
340
+ item = (
341
+ annotation if compose ._is_compose_class (annotation ) else conditions .DependencyExpression .create (annotation )
342
+ )
313
343
if parameter .kind is inspect .Parameter .POSITIONAL_OR_KEYWORD :
314
- positional_or_keyword_params .append ((parameter .name , expr ))
344
+ positional_or_keyword_params .append ((parameter .name , item ))
315
345
else :
316
346
# It has to be a keyword-only parameter
317
- keyword_only_params [parameter .name ] = expr
347
+ keyword_only_params [parameter .name ] = item
318
348
319
349
return positional_or_keyword_params , keyword_only_params
320
350
@@ -378,7 +408,7 @@ def _codegen_dependency_func(
378
408
) -> DependencyResolverFunctionT :
379
409
pos_or_kw , kw_only = _parse_injectable_params (self ._func )
380
410
381
- exec_globals : dict [str , conditions . DependencyExpression [ t . Any ] ] = {}
411
+ exec_globals : dict [str , DependencyExprOrComposed ] = {}
382
412
383
413
def gen_random_name () -> str :
384
414
while True :
@@ -389,30 +419,46 @@ def gen_random_name() -> str:
389
419
# this can never happen but pycharm is being stupid
390
420
return ""
391
421
422
+ def resolver (dependency : DependencyExprOrComposed , refname : str ) -> t .Any :
423
+ if not compose ._is_compose_class (dependency ):
424
+ return f"await { refname } .resolve(container)"
425
+
426
+ init_params : list [str ] = []
427
+ subdeps = t .cast (
428
+ "dict[str, conditions.DependencyExpression[t.Any]]" , getattr (dependency , compose ._DEPS_ATTR )
429
+ )
430
+ for subdep_name , subdep in subdeps .items ():
431
+ exec_globals [ident := gen_random_name ()] = subdep
432
+ init_params .append (f"{ subdep_name } =await { ident } .resolve(container)" )
433
+
434
+ return f"{ refname } ({ ',' .join (init_params )} )"
435
+
392
436
fn_lines = ["arglen = len(args)" , "new_kwargs = {}; new_kwargs.update(kwargs)" ]
393
437
394
438
for i , tup in enumerate (pos_or_kw ):
395
- name , type_expr = tup
396
- if type_expr is CANNOT_INJECT :
439
+ name , dep = tup
440
+ if dep is CANNOT_INJECT :
397
441
continue
398
442
399
- exec_globals [n := gen_random_name ()] = type_expr
443
+ exec_globals [n := gen_random_name ()] = dep
444
+
400
445
fn_lines .append (
401
- f"if '{ name } ' not in new_kwargs and arglen < ({ i + 1 } - offset): new_kwargs['{ name } '] = await { n } .resolve(container) " # noqa: E501
446
+ f"if '{ name } ' not in new_kwargs and arglen < ({ i + 1 } - offset): new_kwargs['{ name } '] = { resolver ( dep , n ) } " # noqa: E501
402
447
)
403
448
404
- for name , type_expr in kw_only .items ():
405
- if type_expr is CANNOT_INJECT :
449
+ for name , dep in kw_only .items ():
450
+ if dep is CANNOT_INJECT :
406
451
continue
407
452
408
- exec_globals [n := gen_random_name ()] = type_expr
409
- fn_lines .append (f"if '{ name } ' not in new_kwargs: new_kwargs['{ name } '] = await { n } .resolve(container) " )
453
+ exec_globals [n := gen_random_name ()] = dep
454
+ fn_lines .append (f"if '{ name } ' not in new_kwargs: new_kwargs['{ name } '] = { resolver ( dep , n ) } " )
410
455
411
456
fn_lines .append ("return new_kwargs" )
412
457
413
- fn = "async def resolve_dependencies(container, offset, args, kwargs):\n " + "\n " .join (
458
+ fn = "async def resolve_dependencies(container,offset,args,kwargs):\n " + "\n " .join (
414
459
textwrap .indent (line , " " ) for line in fn_lines
415
460
)
461
+
416
462
exec (fn , exec_globals , (generated_locals := {}))
417
463
return generated_locals ["resolve_dependencies" ] # type: ignore[reportReturnType]
418
464
0 commit comments