Python Dependency Injection
Dependency Injection (DI) is a software design pattern that lets you pass instances of a service rather than creating them directly within a class or function. The FastAPI framework provides a neat way to pass dependencies using Python’s type hints and the Depends
function. The fast-depends package extracts the FastAPI code and strips out all the web framework-specific code into a small library that can be used in any Python project.
Here is an example of how to use dependency injection using the custom single-file implementation of dependency injection at the bottom of this page.
import typing as t
def get_settings() -> Settings:
return Settings()
@inject
def get_db(
settings: t.Annotated[Settings, Depends(get_settings)],
) -> DatabaseConnection:
db = DatabaseConnection(settings)
return db
@inject
def compute_something(
db: t.Annotated[DatabaseConnection, Depends(get_db)]
):
# Use db connection here
pass
# Call the function with dependencies injected
result = compute_something() # Automatically resolves dependencies
db = get_db() # You can also call the dependency directly
result = compute_something(db=db) # Pass the db directly if needed
The get_*
functions return instances of the required services, and the Depends
function is used to declare dependencies. The inject
decorator enables dependency injection for the function by automatically resolving the dependencies when the function is called. The default behavior is to cache the results of dependencies, so if a dependency is called multiple times, it will return the cached result instead of executing the function again (e.g. Settings and DB connection are created once and reused).
This pattern allows for better separation of concerns (instantiating dependencies outside of the function) and makes unit testing easier. For example, we can create different implementations of the database and settings classes and pass them to the compute_something
function without changing its signature.
Below is the full implementation of dependency injection as a standalone Python module. You can copy it into a depends.py
file and use it in your projects.
Key Components:
- Depends: Marks a parameter as a dependency to be injected
- inject: Decorator that enables dependency injection for a sync or async function
- CustomField: Base class for creating custom parameter extractors
- dependency_provider: Global provider for managing dependency overrides
Features:
- Automatic dependency resolution and injection
- Support for both sync and async functions
- Dependency caching (can be disabled per dependency)
- Type validation and casting using Pydantic
- Context manager support for resource management
- Custom field extractors for complex parameter handling
- Dependency override system for testing and configuration
Table of Contents:
Usage
Simple Dependency Injection
Dependencies can be injected into functions using the Depends
class and the inject
decorator.
def get_database():
return "database_connection"
def get_user(db: str = Depends(get_database)):
return f"user_from_{db}"
@inject
def handler(user: str = Depends(get_user)):
return f"Hello, {user}!"
result = handler() # "Hello, user_from_database_connection!"
print(result)
Async Dependencies
async def get_async_db():
await asyncio.sleep(0.1)
return "async_database"
@inject
async def async_handler(db: str = Depends(get_async_db)):
return f"DB: {db}"
result = await async_handler() # "DB: async_database"
print(result)
Dependency Caching
Caching is enabled by default, meaning that if a dependency is called multiple times within the same request, it will return the cached result instead of executing the function again.
call_count = 0
def expensive_operation():
global call_count
call_count += 1
return f"result_{call_count}"
@inject
def handler(
a: str = Depends(expensive_operation),
b: str = Depends(expensive_operation), # Same result due to caching
):
return f"{a}, {b}"
result = handler() # "result_1, result_1"
print(result) # Output: "result_1, result_1"
Disable Caching
The use_cache parameter can be set to False to disable caching for specific dependencies.
@inject
def handler(
a: str = Depends(expensive_operation, use_cache=False),
b: str = Depends(expensive_operation, use_cache=False),
):
return f"{a}, {b}"
result = handler() # "result_1, result_2"
print(result) # Output: "result_1, result_2"
Custom Fields
class HeaderExtractor(CustomField):
def __init__(self, header_name: str):
super().__init__()
self.header_name = header_name
def use(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
# Extract from some global context
kwargs[self.param_name] = f"header_value_for_{self.header_name}"
return kwargs
@inject
def api_handler(
auth: str = HeaderExtractor("Authorization"),
content_type: str = HeaderExtractor("Content-Type"),
):
return {"auth": auth, "content_type": content_type}
Dependency Overrides
def original_dep():
return "original"
def override_dep():
return "overridden"
@inject
def handler(value: str = Depends(original_dep)):
return value
# Override dependency
dependency_provider.override(original_dep, override_dep)
result = handler() # "overridden"
# Clear overrides
dependency_provider.clear()
result = handler() # "original"
Generator Dependencies (Context Managers)
A dependency can be a generator function, which allows for resource management (like opening and closing database connections).
In this example, the database_session
function is a context manager that opens a database connection and closes it after use.
The dependency-injection machinery handles the context management automatically: the connection is opened before the handler runs and closed afterwards.
def database_session() -> t.Generator[str, None, None]:
print("Opening connection")
yield "db_session"
print("Closing connection")
@inject
def handler(db: str = Depends(database_session)):
return f"Using {db}"
result = handler()
print(result) # "Using db_session"
# Output: Opening connection
# Output: Closing connection
# Returns: "Using db_session"
Type Validation with Pydantic
Arguments can be annotated with Pydantic models for automatic validation and casting. In this example, the get_user_id
function returns a string, but it will be cast to an integer when injected. The injection machinery handles the type casting automatically if cast=True
type from typing_extensions
is used to specify the type and the dependency.
from typing import Annotated
def get_user_id() -> int:
return "123" # Wrong type!
@inject
def handler(user_id: Annotated[int, Depends(get_user_id)]):
return f"User ID: {user_id}"
result = handler() # user_id will be cast to int(123)
Disable Type Casting
If you want to disable type casting for a specific dependency, you can set cast=False
in the inject
decorator.
@inject(cast=False)
def handler(user_id: Annotated[int, Depends(get_user_id)]):
return f"User ID type: {type(user_id)}"
result = handler() # user_id remains as string "123"
Implementation Code
Requirements:
- Python 3.11+
- anyio for async I/O operations
- pydantic for data validation and settings management
- typing_extensions for type annotations
import anyio
import asyncio
import inspect
import functools
import typing as t
from abc import ABC
from copy import deepcopy
from itertools import chain
from pydantic import ConfigDict
from collections import namedtuple
from functools import wraps, partial
from pydantic import BaseModel, create_model
from typing_extensions import Annotated, ParamSpec, get_args, get_origin
from pydantic._internal._typing_extra import try_eval_type as evaluate_forwardref
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
# Generic type variables shared across the module.
P = ParamSpec("P")  # parameter specification of the wrapped callable
T = t.TypeVar("T")  # return type of the wrapped callable
Cls = t.TypeVar("Cls", bound="CustomField")  # fluent-return type for CustomField.set_param_name
# Generated validation models accept arbitrary (non-pydantic) types by default.
default_pydantic_config = {"arbitrary_types_allowed": True}
def get_config_base(config_data: t.Optional[ConfigDict] = None) -> ConfigDict:
    """Return the supplied pydantic config, or the module default when absent/falsy."""
    if config_data:
        return config_data
    return ConfigDict(**default_pydantic_config)
def get_aliases(model: t.Type[BaseModel]) -> t.Tuple[str, ...]:
return tuple(f.alias or name for name, f in model.model_fields.items())
class Depends:
    """Declare that a parameter's value is produced by calling ``dependency``.

    Attributes:
        dependency: Callable invoked to build the value.
        use_cache: Reuse the first result within one resolution pass (default True).
        cast: Validate/cast the produced value with pydantic (default True).
    """

    use_cache: bool
    cast: bool

    def __init__(
        self,
        dependency: t.Callable[..., t.Any],
        *,
        use_cache: bool = True,
        cast: bool = True,
    ) -> None:
        self.dependency = dependency
        self.use_cache = use_cache
        self.cast = cast

    def __repr__(self) -> str:
        name = getattr(self.dependency, "__name__", type(self.dependency).__name__)
        if self.use_cache:
            return f"{self.__class__.__name__}({name})"
        return f"{self.__class__.__name__}({name}, use_cache=False)"
class CustomField(ABC):
    """Base class for user-defined parameter extractors.

    Subclasses override :meth:`use` (or :meth:`use_field` when ``field`` is set)
    to place a value for ``param_name`` into the keyword arguments before the
    target callable runs.
    """

    param_name: t.Optional[str]
    cast: bool
    required: bool

    __slots__ = (
        "cast",
        "param_name",
        "required",
        "field",
    )

    def __init__(
        self,
        *,
        cast: bool = True,
        required: bool = True,
    ) -> None:
        self.cast = cast
        self.required = required
        # Bound later via set_param_name(); not known at construction time.
        self.param_name = None
        self.field = False

    def set_param_name(self: Cls, name: str) -> Cls:
        """Bind this extractor to a parameter name; returns self for chaining."""
        self.param_name = name
        return self

    def use(self, /, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Default extractor: requires a bound param_name, returns kwargs unchanged."""
        assert self.param_name, "You should specify `param_name` before using"
        return kwargs

    def use_field(self, kwargs: t.Dict[str, t.Any]) -> None:
        """Field-style extraction hook; must be overridden when ``field`` is True."""
        raise NotImplementedError("You should implement `use_field` method.")
# Provider for dependency overrides
class Provider:
    """Registry mapping original dependency callables to their overrides.

    Consulted by `inject` at call time so implementations can be swapped
    (e.g. for testing) without touching the decorated functions.
    """

    dependency_overrides: t.Dict[t.Callable[..., t.Any], t.Callable[..., t.Any]]

    def __init__(self) -> None:
        self.dependency_overrides = {}

    def clear(self) -> None:
        """Drop all registered overrides."""
        self.dependency_overrides = {}

    def override(
        self,
        original: t.Callable[..., t.Any],
        override: t.Callable[..., t.Any],
    ) -> None:
        """Replace `original` with `override` until `clear()` is called."""
        self.dependency_overrides[original] = override

    @contextmanager
    def scope(
        self,
        original: t.Callable[..., t.Any],
        override: t.Callable[..., t.Any],
    ) -> t.Iterator[None]:
        """Temporarily replace `original` with `override` for the `with` block.

        The override is removed even when the body raises — the previous
        implementation skipped the cleanup on exceptions, permanently leaking
        the override into subsequent calls.
        """
        self.dependency_overrides[original] = override
        try:
            yield
        finally:
            # pop() rather than del: the body may have cleared/replaced the entry.
            self.dependency_overrides.pop(original, None)
dependency_provider = Provider()
def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool:
    """True when invoking `call` yields a coroutine (async def, or async __call__)."""
    if inspect.isclass(call):
        # Instantiating a class is never treated as awaiting a coroutine.
        return False
    if asyncio.iscoroutinefunction(call):
        return True
    # Callable instances: look at their __call__ method instead.
    return asyncio.iscoroutinefunction(getattr(call, "__call__", None))
def is_gen_callable(call: t.Callable[..., t.Any]) -> bool:
    """True when `call` (or its __call__) is a synchronous generator function."""
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)
    )
def is_async_gen_callable(call: t.Callable[..., t.Any]) -> bool:
    """True when `call` (or its __call__) is an asynchronous generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)
    )
async def run_async(
    func: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Await `func` directly when it is async; otherwise run it in a worker thread."""
    if not is_coroutine_callable(func):
        # Sync callables must not block the event loop.
        return await run_in_threadpool(t.cast(t.Callable[P, T], func), *args, **kwargs)
    return await t.cast(t.Callable[P, t.Awaitable[T]], func)(*args, **kwargs)
async def run_in_threadpool(func: t.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Execute a sync callable on anyio's worker-thread pool and await its result."""
    # anyio.to_thread.run_sync forwards positional args only, so bind kwargs first.
    target = functools.partial(func, **kwargs) if kwargs else func
    return await anyio.to_thread.run_sync(target, *args)
def get_typed_annotation(
    annotation: t.Any,
    globalns: t.Dict[str, t.Any],
    locals: t.Dict[str, t.Any],
) -> t.Any:
    """Resolve string/ForwardRef annotations against the given namespaces.

    Recurses into ``Annotated[...]`` metadata so DI markers embedded in the
    annotation get resolved too.
    """
    if isinstance(annotation, str):
        annotation = t.ForwardRef(annotation)
    if isinstance(annotation, t.ForwardRef):
        # NOTE(review): pydantic's `try_eval_type` returns a (value, succeeded)
        # tuple in recent versions — confirm the aliased helper here actually
        # returns the bare type for the installed pydantic release.
        annotation = evaluate_forwardref(annotation, globalns, locals)
    if get_origin(annotation) is Annotated and (args := get_args(annotation)):
        solved_args = [get_typed_annotation(x, globalns, locals) for x in args]
        # NOTE(review): mutates the Annotated object in place; typing may cache
        # and share these objects, so the resolution can leak to other users of
        # the same annotation instance — confirm intended.
        annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:])
    return annotation
def collect_outer_stack_locals() -> t.Dict[str, t.Any]:
    """
    Collect local variables from outer stack frames to resolve type annotations.

    Walks up the call stack and merges the locals of every frame living outside
    this module, outermost first — so the closest caller wins on name clashes.
    Needed to resolve forward references and string annotations that mention
    names defined in the calling code.
    """
    this_file = __file__ if "__file__" in globals() else None
    outer_frames: t.List[t.Any] = []
    frame = inspect.currentframe()
    while frame is not None:
        # Frames from this module would leak internal variables; skip them.
        if this_file is None or frame.f_code.co_filename != this_file:
            outer_frames.append(frame)
        frame = frame.f_back
    merged: t.Dict[str, t.Any] = {}
    for outer in reversed(outer_frames):
        merged.update(outer.f_locals)
    return merged
def get_typed_signature(call: t.Callable[..., t.Any]) -> t.Tuple[inspect.Signature, t.Any]:
    """Return `call`'s signature and return annotation with forward refs resolved."""
    signature = inspect.signature(call)
    locals = collect_outer_stack_locals()
    # Unwrap decorators so __globals__ belongs to the original function.
    unwrapped = inspect.unwrap(call)
    globalns = getattr(unwrapped, "__globals__", {})

    def _resolve(annotation: t.Any) -> t.Any:
        return get_typed_annotation(annotation, globalns, locals)

    typed_params = []
    for param in signature.parameters.values():
        typed_params.append(
            inspect.Parameter(
                name=param.name,
                kind=param.kind,
                default=param.default,
                annotation=_resolve(param.annotation),
            )
        )
    return inspect.Signature(typed_params), _resolve(signature.return_annotation)
async def solve_generator_async(
    *sub_args: t.Any, call: t.Callable[..., t.Any], stack: AsyncExitStack, **sub_values: t.Any
) -> t.Any:
    """Enter a (sync or async) generator dependency on `stack`; return its yielded value.

    Sync generators are wrapped in a threadpool-backed context manager so their
    setup/teardown never blocks the event loop.
    """
    if is_gen_callable(call):
        # NOTE: sync generator dependencies historically receive keyword values only.
        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(*sub_args, **sub_values)
    else:
        # Previously fell through with `cm` unbound, surfacing as an opaque
        # UnboundLocalError; fail loudly with a clear message instead.
        raise TypeError(f"`{call!r}` is not a generator function")
    return await stack.enter_async_context(cm)
def solve_generator_sync(
    *sub_args: t.Any, call: t.Callable[..., t.Any], stack: ExitStack, **sub_values: t.Any
) -> t.Any:
    """Enter a sync generator dependency on `stack` and return its yielded value.

    The generator's teardown (code after its `yield`) runs when `stack` exits.
    """
    manager = contextmanager(call)(*sub_args, **sub_values)
    return stack.enter_context(manager)
@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: t.ContextManager[T],
) -> t.AsyncGenerator[T, None]:
    """Adapt a sync context manager for async use, running enter/exit in threads."""
    # Dedicated limiter so __exit__ can always run, even if the shared pool is busy.
    exit_limiter = anyio.CapacityLimiter(1)
    try:
        yield await run_in_threadpool(cm.__enter__)
    except Exception as e:
        # Mirror the sync `with` protocol: a truthy __exit__ suppresses the error.
        # NOTE(review): the traceback argument is not forwarded to __exit__ — confirm acceptable.
        ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter))
        if not ok:
            raise e
    else:
        await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)
async def async_map(func: t.Callable[..., T], async_iterable: t.AsyncIterable[t.Any]) -> t.AsyncIterable[T]:
    """Lazily apply `func` to every item of an async iterable (async analogue of map)."""
    async for item in async_iterable:
        yield func(item)
class solve_wrapper(partial[T]):
    """`functools.partial` that also mirrors the bound model's underlying callable.

    The first bound argument must expose a ``.call`` attribute (a CallModel);
    it is copied onto the wrapper so `build_call_model` can unwrap an
    already-injected dependency back to its raw callable.
    """

    call: t.Callable[..., T]

    def __new__(
        cls,
        func: t.Callable[..., T],
        *args: t.Any,
        **kwargs: t.Any,
    ) -> "solve_wrapper[T]":
        assert len(args) > 0, "Model should be passed as first argument"
        instance = super().__new__(cls, func, *args, **kwargs)
        instance.call = args[0].call
        return instance
# Core Models
# (call, dependencies_number, dependencies_names) triple describing a callable's
# place in the dependency-resolution ordering.
PriorityPair = namedtuple("PriorityPair", ("call", "dependencies_number", "dependencies_names"))
class ResponseModel(BaseModel, t.Generic[T]):
    """Wrapper model used to validate/cast a callable's return value via pydantic."""
    response: T
class CallModel(t.Generic[P, T]):
    """Model representing a callable with dependency injection.

    Bundles a target callable with everything needed to resolve it: the
    pydantic argument/response models, nested `Depends` sub-models, custom
    fields, and argument-classification metadata. `solve` / `asolve` drive
    the actual resolution via the two-phase `_solve` generator protocol.
    """

    # The wrapped callable (sync or async, possibly a generator function).
    call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]]
    is_async: bool
    is_generator: bool
    # Pydantic model validating/casting the call's arguments.
    model: t.Optional[t.Type[BaseModel]]
    # Pydantic model validating/casting the call's return value.
    response_model: t.Optional[t.Type[ResponseModel[T]]]
    # Plain (non-dependency) params: name -> (annotation, default).
    params: t.Dict[str, t.Tuple[t.Any, t.Any]]
    alias_arguments: t.Tuple[str, ...]
    # Params produced by `Depends(...)`: name -> sub CallModel.
    dependencies: t.Dict[str, "CallModel[..., t.Any]"]
    # Dependencies resolved for side effects only (not bound to a parameter).
    extra_dependencies: t.Iterable["CallModel[..., t.Any]"]
    # Topologically ordered cached deps, each paired with its own dep count.
    sorted_dependencies: t.Tuple[t.Tuple["CallModel[..., t.Any]", int], ...]
    custom_fields: t.Dict[str, CustomField]
    keyword_args: t.Tuple[str, ...]
    positional_args: t.Tuple[str, ...]
    var_positional_arg: t.Optional[str]
    var_keyword_arg: t.Optional[str]
    use_cache: bool
    cast: bool

    __slots__ = (
        "call",
        "is_async",
        "is_generator",
        "model",
        "response_model",
        "params",
        "alias_arguments",
        "keyword_args",
        "positional_args",
        "var_positional_arg",
        "var_keyword_arg",
        "dependencies",
        "extra_dependencies",
        "sorted_dependencies",
        "custom_fields",
        "use_cache",
        "cast",
    )

    @property
    def call_name(self) -> str:
        """Readable name of the underlying callable (decorators unwrapped)."""
        call = inspect.unwrap(self.call)
        return getattr(call, "__name__", type(call).__name__)

    @property
    def flat_params(self) -> t.Dict[str, t.Tuple[t.Any, t.Any]]:
        """Plain params of this call plus, recursively, of all its dependencies."""
        # NOTE(review): `params` aliases self.params, so the recursive update
        # mutates this model's own mapping on first access — confirm intended.
        params = self.params
        for d in (*self.dependencies.values(), *self.extra_dependencies):
            params.update(d.flat_params)
        return params

    @property
    def flat_dependencies(
        self,
    ) -> t.Dict[
        t.Callable[..., t.Any],
        t.Tuple["CallModel[..., t.Any]", t.Tuple[t.Callable[..., t.Any], ...]],
    ]:
        """Every transitive dependency: callable -> (model, direct dep callables)."""
        flat: t.Dict[
            t.Callable[..., t.Any],
            t.Tuple[CallModel[..., t.Any], t.Tuple[t.Callable[..., t.Any], ...]],
        ] = {}
        for i in (*self.dependencies.values(), *self.extra_dependencies):
            flat.update(
                {
                    i.call: (
                        i,
                        tuple(j.call for j in i.dependencies.values()),
                    )
                }
            )
            flat.update(i.flat_dependencies)
        return flat

    def __init__(
        self,
        /,
        call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
        model: t.Optional[t.Type[BaseModel]],
        params: t.Dict[str, t.Tuple[t.Any, t.Any]],
        response_model: t.Optional[t.Type[ResponseModel[T]]] = None,
        use_cache: bool = True,
        cast: bool = True,
        is_async: bool = False,
        is_generator: bool = False,
        dependencies: t.Optional[t.Dict[str, "CallModel[..., t.Any]"]] = None,
        extra_dependencies: t.Optional[t.Iterable["CallModel[..., t.Any]"]] = None,
        keyword_args: t.Optional[t.List[str]] = None,
        positional_args: t.Optional[t.List[str]] = None,
        var_positional_arg: t.Optional[str] = None,
        var_keyword_arg: t.Optional[str] = None,
        custom_fields: t.Optional[t.Dict[str, CustomField]] = None,
    ):
        self.call = call
        self.model = model
        if model:
            self.alias_arguments = get_aliases(model)
        else:
            self.alias_arguments = ()
        self.keyword_args = tuple(keyword_args or ())
        self.positional_args = tuple(positional_args or ())
        self.var_positional_arg = var_positional_arg
        self.var_keyword_arg = var_keyword_arg
        self.response_model = response_model
        self.use_cache = use_cache
        self.cast = cast
        # Flags may be forced by the caller or auto-detected from `call`.
        self.is_async = is_async or is_coroutine_callable(call) or is_async_gen_callable(call)
        self.is_generator = is_generator or is_gen_callable(call) or is_async_gen_callable(call)
        self.dependencies = dependencies or {}
        self.extra_dependencies = extra_dependencies or ()
        self.custom_fields = custom_fields or {}
        # Topologically sort cached dependencies so caches can be warmed bottom-up.
        sorted_dep: t.List[CallModel[..., t.Any]] = []
        flat = self.flat_dependencies
        for calls in flat.values():
            _sort_dep(sorted_dep, calls, flat)
        self.sorted_dependencies = tuple((i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache)
        # Params supplied by dependencies/custom fields are not plain params.
        for name in chain(self.dependencies.keys(), self.custom_fields.keys()):
            params.pop(name, None)
        self.params = params

    def _solve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        **kwargs: t.Dict[str, t.Any],
    ) -> t.Generator[
        t.Tuple[t.Sequence[t.Any], t.Dict[str, t.Any], t.Callable[..., t.Any]],
        t.Any,
        T,
    ]:
        """Two-phase resolution protocol shared by `solve` and `asolve`.

        Yields twice: first (raw args, kwargs, call) so the driver can solve
        sub-dependencies and send back the completed kwargs; then the final
        (args, kwargs, call) to invoke, receiving the raw response back.
        Returns (via StopIteration.value) the cast/cached response; returns
        immediately when a cached value exists.
        """
        if dependency_overrides:
            call = dependency_overrides.get(self.call, self.call)
            assert self.is_async or not is_coroutine_callable(call), (
                f"You cannot use async dependency `{self.call_name}` at sync main"
            )
        else:
            call = self.call
        if self.use_cache and call in cache_dependencies:
            return cache_dependencies[call]
        kw: t.Dict[str, t.Any] = {}
        # Pull declared keyword args out of the incoming **kwargs.
        for arg in self.keyword_args:
            if (v := kwargs.pop(arg, inspect.Parameter.empty)) is not inspect.Parameter.empty:
                kw[arg] = v
        if self.var_keyword_arg is not None:
            kw[self.var_keyword_arg] = kwargs
        else:
            kw.update(kwargs)
        # Map leading positional values onto named positional params.
        for arg in self.positional_args:
            if args:
                kw[arg], args = args[0], args[1:]
            else:
                break
        keyword_args: t.Iterable[str]
        if self.var_positional_arg is not None:
            # Remaining positionals flow into *args of the callable.
            kw[self.var_positional_arg] = args
            keyword_args = self.keyword_args
        else:
            keyword_args = self.keyword_args + self.positional_args
            for arg in keyword_args:
                if not self.cast and arg in self.params:
                    # Without casting, fall back to the declared default.
                    kw[arg] = self.params[arg][1]
                if not args:
                    break
                if arg not in self.dependencies:
                    kw[arg], args = args[0], args[1:]
        solved_kw: t.Dict[str, t.Any]
        solved_kw = yield args, kw, call
        args_: t.Sequence[t.Any]
        if self.cast:
            assert self.model, "Cast should be used only with model"
            # Validate/cast everything through the generated pydantic model.
            casted_model = self.model(**solved_kw)
            kwargs_ = {arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in keyword_args}
            if self.var_keyword_arg:
                kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))
            if self.var_positional_arg is not None:
                args_ = [getattr(casted_model, arg, solved_kw.get(arg)) for arg in self.positional_args]
                args_.extend(getattr(casted_model, self.var_positional_arg, ()))
            else:
                args_ = ()
        else:
            kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}
            if self.var_positional_arg is not None:
                args_ = tuple(map(solved_kw.get, self.positional_args))
            else:
                args_ = ()
        response: T
        response = yield args_, kwargs_, call
        # Generator responses are cast lazily, item by item, by the driver.
        if self.cast and not self.is_generator:
            response = self._cast_response(response)
        if self.use_cache:
            cache_dependencies[call] = response
        return response

    def _cast_response(self, /, value: t.Any) -> t.Any:
        """Validate/cast a single return value via the response model, if any."""
        if self.response_model is not None:
            return self.response_model(response=value).response
        else:
            return value

    def solve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        stack: ExitStack,
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        nested: bool = False,
        **kwargs: t.Dict[str, t.Any],
    ) -> T:
        """Synchronously resolve dependencies and invoke the wrapped callable.

        `stack` keeps generator dependencies open until the caller exits it;
        `nested` marks recursive calls made for sub-dependencies.
        """
        cast_gen = self._solve(
            *args,
            cache_dependencies=cache_dependencies,
            dependency_overrides=dependency_overrides,
            **kwargs,
        )
        try:
            args, kwargs, _ = next(cast_gen)
        except StopIteration as e:
            # _solve returned immediately: a cached value was found.
            cached_value: T = e.value
            return cached_value
        # Heat cache and solve extra dependencies
        for dep, _ in self.sorted_dependencies:
            dep.solve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )
        # Always get from cache
        for dep in self.extra_dependencies:
            dep.solve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )
        # Bind each named dependency's result to its parameter.
        for dep_arg, dep in self.dependencies.items():
            kwargs[dep_arg] = dep.solve(
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )
        for custom in self.custom_fields.values():
            if custom.field:
                custom.use_field(kwargs)
            else:
                kwargs = custom.use(**kwargs)
        final_args, final_kwargs, call = cast_gen.send(kwargs)
        if self.is_generator and nested:
            # Nested generator dependency: enter it on the shared stack.
            response = solve_generator_sync(
                *final_args,
                call=call,
                stack=stack,
                **final_kwargs,
            )
        else:
            response = call(*final_args, **final_kwargs)
        try:
            cast_gen.send(response)
        except StopIteration as e:
            value: T = e.value
            if not self.cast or nested or not self.is_generator:
                return value
            else:
                # Top-level generator result: cast each yielded item lazily.
                return map(self._cast_response, value)
        raise AssertionError("unreachable")

    async def asolve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        stack: AsyncExitStack,
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        nested: bool = False,
        **kwargs: t.Dict[str, t.Any],
    ) -> T:
        """Async counterpart of `solve`; independent cached deps run concurrently."""
        cast_gen = self._solve(
            *args,
            cache_dependencies=cache_dependencies,
            dependency_overrides=dependency_overrides,
            **kwargs,
        )
        try:
            args, kwargs, _ = next(cast_gen)
        except StopIteration as e:
            # _solve returned immediately: a cached value was found.
            cached_value: T = e.value
            return cached_value
        # Heat cache and solve extra dependencies
        dep_to_solve: t.List[t.Callable[..., t.Awaitable[t.Any]]] = []
        try:
            async with anyio.create_task_group() as tg:
                for dep, subdep in self.sorted_dependencies:
                    solve = partial(
                        dep.asolve,
                        *args,
                        stack=stack,
                        cache_dependencies=cache_dependencies,
                        dependency_overrides=dependency_overrides,
                        nested=True,
                        **kwargs,
                    )
                    if not subdep:
                        # No sub-dependencies: safe to resolve concurrently.
                        tg.start_soon(solve)
                    else:
                        # Has sub-dependencies: resolve sequentially afterwards.
                        dep_to_solve.append(solve)
        except Exception as e:
            raise e
        for i in dep_to_solve:
            await i()
        # Always get from cache
        for dep in self.extra_dependencies:
            await dep.asolve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )
        # Bind each named dependency's result to its parameter.
        for dep_arg, dep in self.dependencies.items():
            kwargs[dep_arg] = await dep.asolve(
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )
        custom_to_solve: t.List[CustomField] = []
        try:
            async with anyio.create_task_group() as tg:
                for custom in self.custom_fields.values():
                    if custom.field:
                        # Field-style extractors mutate kwargs in place; run concurrently.
                        tg.start_soon(run_async, custom.use_field, kwargs)
                    else:
                        # use() returns a new kwargs mapping, so run sequentially.
                        custom_to_solve.append(custom)
        except Exception as e:
            raise e
        for j in custom_to_solve:
            kwargs = await run_async(j.use, **kwargs)
        final_args, final_kwargs, call = cast_gen.send(kwargs)
        if self.is_generator and nested:
            # Nested generator dependency: enter it on the shared async stack.
            response = await solve_generator_async(
                *final_args,
                call=call,
                stack=stack,
                **final_kwargs,
            )
        else:
            response = await run_async(call, *final_args, **final_kwargs)
        try:
            cast_gen.send(response)
        except StopIteration as e:
            value: T = e.value
            if not self.cast or nested or not self.is_generator:
                return value
            else:
                # Top-level async generator: cast each yielded item lazily.
                return async_map(self._cast_response, value)
        raise AssertionError("unreachable")
def _sort_dep(
collector: t.List["CallModel[..., t.Any]"],
items: t.Tuple[
"CallModel[..., t.Any]",
t.Tuple[t.Callable[..., t.Any], ...],
],
flat: t.Dict[
t.Callable[..., t.Any],
t.Tuple[
"CallModel[..., t.Any]",
t.Tuple[t.Callable[..., t.Any], ...],
],
],
) -> None:
model, calls = items
if model in collector:
return
if not calls:
position = -1
else:
for i in calls:
sub_model, _ = flat[i]
if sub_model not in collector:
_sort_dep(collector, flat[i], flat)
position = max(collector.index(flat[i][0]) for i in calls)
collector.insert(position + 1, model)
CUSTOM_ANNOTATIONS = (Depends, CustomField)
def build_call_model(
    call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
    *,
    cast: bool = True,
    use_cache: bool = True,
    is_sync: t.Optional[bool] = None,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
) -> CallModel[P, T]:
    """Build a CallModel from a callable.

    Inspects the (forward-ref-resolved) signature, classifying each parameter
    as a plain argument, a `Depends` dependency, or a `CustomField`, and
    generates the pydantic models used for argument and response casting.

    Raises AssertionError when an async dependency is declared under a sync
    root, or when a DI marker appears both in `Annotated` and as a default.
    """
    name = getattr(call, "__name__", type(call).__name__)
    is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call)
    if is_sync is None:
        is_sync = not is_call_async
    else:
        assert not (is_sync and is_call_async), f"You cannot use async dependency `{name}` at sync main"
    typed_params, return_annotation = get_typed_signature(call)
    # For generators, the declared yield type is the effective response type.
    if (is_call_generator := is_gen_callable(call) or is_async_gen_callable(call)) and (
        return_args := get_args(return_annotation)
    ):
        return_annotation = return_args[0]
    class_fields: t.Dict[str, t.Tuple[t.Any, t.Any]] = {}
    dependencies: t.Dict[str, CallModel[..., t.Any]] = {}
    custom_fields: t.Dict[str, CustomField] = {}
    positional_args: t.List[str] = []
    keyword_args: t.List[str] = []
    var_positional_arg: t.Optional[str] = None
    var_keyword_arg: t.Optional[str] = None
    for param_name, param in typed_params.parameters.items():
        dep: t.Optional[Depends] = None
        custom: t.Optional[CustomField] = None
        if param.annotation is inspect.Parameter.empty:
            annotation = t.Any
        elif get_origin(param.annotation) is Annotated:
            # Split Annotated metadata into DI markers vs. ordinary metadata.
            annotated_args = get_args(param.annotation)
            type_annotation = annotated_args[0]
            custom_annotations = []
            regular_annotations = []
            for arg in annotated_args[1:]:
                if isinstance(arg, CUSTOM_ANNOTATIONS):
                    custom_annotations.append(arg)
                else:
                    regular_annotations.append(arg)
            assert len(custom_annotations) <= 1, (
                f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!"
            )
            next_custom = next(iter(custom_annotations), None)
            if next_custom is not None:
                if isinstance(next_custom, Depends):
                    dep = next_custom
                elif isinstance(next_custom, CustomField):
                    # Deep-copied so a shared annotation object can serve many params.
                    custom = deepcopy(next_custom)
                else:
                    raise AssertionError("unreachable")
                # Keep the full Annotated type only when other metadata remains.
                if regular_annotations:
                    annotation = param.annotation
                else:
                    annotation = type_annotation
            else:
                annotation = param.annotation
        else:
            annotation = param.annotation
        default: t.Any
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            default = ()
            var_positional_arg = param_name
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            default = {}
            var_keyword_arg = param_name
        elif param.default is inspect.Parameter.empty:
            default = Ellipsis
        else:
            default = param.default
        # DI markers may also arrive as parameter defaults (`x=Depends(f)`).
        if isinstance(default, Depends):
            if dep:
                raise AssertionError(
                    "You can not use both `Depends` with `Annotated` and a default",
                )
            dep, default = default, Ellipsis
        elif isinstance(default, CustomField):
            if custom:
                raise AssertionError(
                    "You can not use both `CustomField` with `Annotated` and a default",
                )
            custom, default = default, Ellipsis
        else:
            class_fields[param_name] = (annotation, default)
        if dep:
            if not cast:
                dep.cast = False
            # Unwrap an already-injected generator wrapper back to its raw callable.
            if isinstance(dep.dependency, solve_wrapper):
                dep.dependency = dep.dependency.call
            dependencies[param_name] = build_call_model(
                dep.dependency,
                cast=dep.cast,
                use_cache=dep.use_cache,
                is_sync=is_sync,
                pydantic_config=pydantic_config,
            )
            if dep.cast is True:
                # Dependency results are validated as required model fields.
                class_fields[param_name] = (annotation, Ellipsis)
            keyword_args.append(param_name)
        elif custom:
            assert not (is_sync and is_coroutine_callable(custom.use)), (
                f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`"
            )
            custom.set_param_name(param_name)
            custom_fields[param_name] = custom
            if custom.cast is False:
                annotation = t.Any
            if custom.required:
                class_fields[param_name] = (annotation, default)
            else:
                class_fields[param_name] = class_fields.get(param_name, (t.Optional[annotation], None))
            keyword_args.append(param_name)
        else:
            if param.kind is param.KEYWORD_ONLY:
                keyword_args.append(param_name)
            elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                positional_args.append(param_name)
    # Pydantic model validating the call's (castable) arguments.
    func_model = create_model(
        name,
        __config__=get_config_base(pydantic_config),
        **class_fields,
    )
    response_model: t.Optional[t.Type[ResponseModel[T]]] = None
    if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
        response_model = create_model(
            "ResponseModel",
            __config__=get_config_base(pydantic_config),
            response=(return_annotation, Ellipsis),
        )
    return CallModel(
        call=call,
        model=func_model,
        response_model=response_model,
        params=class_fields,
        cast=cast,
        use_cache=use_cache,
        is_async=is_call_async,
        is_generator=is_call_generator,
        dependencies=dependencies,
        custom_fields=custom_fields,
        positional_args=positional_args,
        keyword_args=keyword_args,
        var_positional_arg=var_positional_arg,
        var_keyword_arg=var_keyword_arg,
        extra_dependencies=[
            build_call_model(
                d.dependency,
                cast=d.cast,
                use_cache=d.use_cache,
                is_sync=is_sync,
                pydantic_config=pydantic_config,
            )
            for d in extra_dependencies
        ],
    )
class _InjectWrapper(t.Protocol[P, T]):
    """Callable protocol of the decorator returned by `inject(...)`."""

    def __call__(
        self,
        func: t.Callable[P, T],
        model: t.Optional[CallModel[P, T]] = None,
    ) -> t.Callable[P, T]: ...
# Overload: `@inject(...)` with arguments — returns a configured decorator.
@t.overload
def inject(
    func: None,
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> _InjectWrapper[P, T]: ...
# Overload: bare `@inject` — wraps the function immediately.
@t.overload
def inject(
    func: t.Callable[P, T],
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> t.Callable[P, T]: ...
def inject(
    func: t.Optional[t.Callable[P, T]] = None,
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> t.Union[t.Callable[P, T], _InjectWrapper[P, T]]:
    """Decorator to inject dependencies into a function.

    Supports both `@inject` and `@inject(...)` forms: when `func` is None a
    configured decorator is returned; otherwise `func` is wrapped immediately.
    """
    decorator = _wrap_inject(
        dependency_overrides_provider=dependency_overrides_provider,
        wrap_model=wrap_model,
        extra_dependencies=extra_dependencies,
        cast=cast,
        pydantic_config=pydantic_config,
    )
    return decorator if func is None else decorator(func)
def _wrap_inject(
    dependency_overrides_provider: t.Optional[t.Any],
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]],
    extra_dependencies: t.Sequence[Depends],
    cast: bool,
    pydantic_config: t.Optional[ConfigDict],
) -> _InjectWrapper[P, T]:
    """Build the decorator that wires a function to its CallModel.

    The provider's overrides mapping is captured once, at decoration time; the
    mapping object itself is shared, so later `override()` calls still apply.
    """
    if (
        dependency_overrides_provider
        and getattr(dependency_overrides_provider, "dependency_overrides", None) is not None
    ):
        overrides = dependency_overrides_provider.dependency_overrides
    else:
        overrides = None

    def func_wrapper(
        func: t.Callable[P, T],
        model: t.Optional[CallModel[P, T]] = None,
    ) -> t.Callable[P, T]:
        # Build (or accept a pre-built) CallModel for the target function.
        if model is None:
            real_model = wrap_model(
                build_call_model(
                    call=func,
                    extra_dependencies=extra_dependencies,
                    cast=cast,
                    pydantic_config=pydantic_config,
                )
            )
        else:
            real_model = model
        if real_model.is_async:
            injected_wrapper: t.Callable[P, T]
            if real_model.is_generator:
                # Async generators get a lazy iterator adapter instead of a coroutine.
                injected_wrapper = solve_wrapper(solve_async_gen, real_model, overrides)
            else:
                @wraps(func)
                async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                    # Fresh cache and exit stack per call.
                    async with AsyncExitStack() as stack:
                        r = await real_model.asolve(
                            *args,
                            stack=stack,
                            dependency_overrides=overrides,
                            cache_dependencies={},
                            nested=False,
                            **kwargs,
                        )
                        return r
                    # Pacify type checkers: the `async with` always returns above.
                    raise AssertionError("unreachable")
        else:
            if real_model.is_generator:
                injected_wrapper = solve_wrapper(solve_gen, real_model, overrides)
            else:
                @wraps(func)
                def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                    # Fresh cache and exit stack per call.
                    with ExitStack() as stack:
                        r = real_model.solve(
                            *args,
                            stack=stack,
                            dependency_overrides=overrides,
                            cache_dependencies={},
                            nested=False,
                            **kwargs,
                        )
                        return r
                    # Pacify type checkers: the `with` always returns above.
                    raise AssertionError("unreachable")
        return injected_wrapper

    return func_wrapper
class solve_async_gen:
    """Async-iterator adapter for injected async generator functions.

    Constructed per call via `solve_wrapper(solve_async_gen, model, overrides)`;
    dependencies are resolved lazily on the first `__anext__`, and generator
    dependencies stay open on an internal AsyncExitStack until exhaustion.
    """

    # Lazily created on first __anext__.
    _iter: t.Optional[t.AsyncIterator[t.Any]] = None

    def __init__(
        self,
        model: "CallModel[..., t.Any]",
        overrides: t.Optional[t.Any],
        *args: t.Any,
        **kwargs: t.Any,
    ):
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __aiter__(self) -> "solve_async_gen":
        # NOTE(review): this stack is replaced by __anext__'s own; looks redundant — confirm.
        self.stack = AsyncExitStack()
        return self

    async def __anext__(self) -> t.Any:
        if self._iter is None:
            # First step: enter the stack and resolve dependencies.
            stack = self.stack = AsyncExitStack()
            await self.stack.__aenter__()
            self._iter = t.cast(
                t.AsyncIterator[t.Any],
                (
                    await self.call.asolve(
                        *self.args,
                        stack=stack,
                        dependency_overrides=self.overrides,
                        cache_dependencies={},
                        nested=False,
                        **self.kwargs,
                    )
                ).__aiter__(),
            )
        try:
            r = await self._iter.__anext__()
        except StopAsyncIteration as e:
            # Exhausted: close entered context managers before propagating.
            await self.stack.__aexit__(None, None, None)
            raise e
        else:
            return r
class solve_gen:
    """Sync-iterator adapter for injected generator functions.

    Mirror of `solve_async_gen`: constructed per call via `solve_wrapper`;
    dependencies are resolved lazily on the first `__next__`, and generator
    dependencies stay open on an internal ExitStack until exhaustion.
    """

    # Lazily created on first __next__.
    _iter: t.Optional[t.Iterator[t.Any]] = None

    def __init__(
        self,
        model: "CallModel[..., t.Any]",
        overrides: t.Optional[t.Any],
        *args: t.Any,
        **kwargs: t.Any,
    ):
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __iter__(self) -> "solve_gen":
        # NOTE(review): this stack is replaced by __next__'s own; looks redundant — confirm.
        self.stack = ExitStack()
        return self

    def __next__(self) -> t.Any:
        if self._iter is None:
            # First step: enter the stack and resolve dependencies.
            stack = self.stack = ExitStack()
            self.stack.__enter__()
            self._iter = t.cast(
                t.Iterator[t.Any],
                iter(
                    self.call.solve(
                        *self.args,
                        stack=stack,
                        dependency_overrides=self.overrides,
                        cache_dependencies={},
                        nested=False,
                        **self.kwargs,
                    )
                ),
            )
        try:
            r = next(self._iter)
        except StopIteration as e:
            # Exhausted: close entered context managers before propagating.
            self.stack.__exit__(None, None, None)
            raise e
        else:
            return r