Asif Rahman

Python Dependency Injection

Dependency Injection (DI) is a software design pattern that lets you pass instances of a service rather than creating them directly within a class or function. The FastAPI framework provides a neat way to pass dependencies using Python’s type hints and the Depends function. The fast-depends package extracts FastAPI's dependency injection code, with the web framework-specific parts stripped out, into a small library that can be used in any Python project.

Here is an example that uses the custom single-file implementation of dependency injection given at the bottom of this page.

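import typing as t

from depends import Depends, inject  # the depends.py module described below

# Settings and DatabaseConnection are placeholders for your own classes.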

def get_settings() -> Settings:
    return Settings()

@inject
def get_db(
    settings: t.Annotated[Settings, Depends(get_settings)],
) -> DatabaseConnection:
    db = DatabaseConnection(settings)
    return db

@inject
def compute_something(
    db: t.Annotated[DatabaseConnection, Depends(get_db)]
):
    # Use db connection here
    pass

# Call the function with dependencies injected
result = compute_something()  # Automatically resolves dependencies

db = get_db()  # You can also call the dependency directly
result = compute_something(db=db)  # Pass the db directly if needed

The get_* functions return instances of the required services, and the Depends function is used to declare dependencies. The inject decorator enables dependency injection for the function by automatically resolving the dependencies when the function is called. The default behavior is to cache the results of dependencies, so if a dependency is called multiple times, it will return the cached result instead of executing the function again (e.g. Settings and DB connection are created once and reused).

This pattern allows for better separation of concerns (instantiating dependencies outside of the function) and makes unit testing easier. For example, we can create different implementations of the database and settings classes and pass them to the compute_something function without changing its signature.
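As a rough illustration of the testing benefit, the sketch below uses plain argument passing with no library at all. RealDatabase, FakeDatabase, fetch_total, and this version of compute_something are hypothetical names invented for the example; they are not part of the implementation on this page.

```python
class RealDatabase:
    """Stand-in for a class that would talk to a real database."""

    def fetch_total(self) -> int:
        raise RuntimeError("no database available in this sketch")


class FakeDatabase:
    """In-memory replacement with the same method, used only in tests."""

    def fetch_total(self) -> int:
        return 42


def compute_something(db) -> str:
    # The function relies only on the `fetch_total` method, so any object
    # that provides it can be passed in.
    return f"total={db.fetch_total()}"


# In a test, pass the fake instead of the real connection.
print(compute_something(db=FakeDatabase()))  # total=42
```

Because the function receives its database as an argument, the test never needs to patch a module or open a connection.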

Below is the full implementation of dependency injection as a standalone Python module. You can copy it into a depends.py file and use it in your projects.

Usage

Simple Dependency Injection

Dependencies can be injected into functions using the Depends class and the inject decorator.

def get_database():
    return "database_connection"

def get_user(db: str = Depends(get_database)):
    return f"user_from_{db}"

@inject
def handler(user: str = Depends(get_user)):
    return f"Hello, {user}!"

result = handler()  # "Hello, user_from_database_connection!"
print(result)

Async Dependencies

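import asyncio

async def get_async_db():
    await asyncio.sleep(0.1)
    return "async_database"

@inject
async def async_handler(db: str = Depends(get_async_db)):
    return f"DB: {db}"

async def main():
    # `await` is only valid inside a coroutine, so wrap the call.
    result = await async_handler()  # "DB: async_database"
    print(result)

asyncio.run(main())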
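The implementation's run_async helper awaits coroutine functions directly and sends plain functions to a worker thread through anyio. The sketch below shows the same idea using only the standard library: asyncio.to_thread (Python 3.9+) is swapped in for anyio, and call_maybe_async and get_sync_db are hypothetical names.

```python
import asyncio
import inspect


async def call_maybe_async(func, *args, **kwargs):
    """Await coroutine functions directly; run plain functions in a thread."""
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    # asyncio.to_thread keeps a blocking call from stalling the event loop.
    return await asyncio.to_thread(func, *args, **kwargs)


async def get_async_db():
    await asyncio.sleep(0)
    return "async_database"


def get_sync_db():
    return "sync_database"


async def main():
    return [
        await call_maybe_async(get_async_db),
        await call_maybe_async(get_sync_db),
    ]


print(asyncio.run(main()))  # ['async_database', 'sync_database']
```

This is why an async handler can depend on both async and ordinary functions without the caller having to distinguish them.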

Dependency Caching

Caching is enabled by default: if the same dependency appears more than once while resolving a single call, the function runs once and its result is reused instead of executing the function again.

call_count = 0

def expensive_operation():
    global call_count
    call_count += 1
    return f"result_{call_count}"

@inject
def handler(
    a: str = Depends(expensive_operation),
    b: str = Depends(expensive_operation),  # Same result due to caching
):
    return f"{a}, {b}"

result = handler()  # "result_1, result_1"
print(result)  # Output: "result_1, result_1"

Disable Caching

The use_cache parameter can be set to False to disable caching for specific dependencies.

call_count = 0  # reset the counter used by the previous example

@inject
def handler(
    a: str = Depends(expensive_operation, use_cache=False),
    b: str = Depends(expensive_operation, use_cache=False),
):
    return f"{a}, {b}"

result = handler()  # "result_1, result_2"
print(result)  # Output: "result_1, result_2"
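One way to picture both behaviours is a dictionary keyed by the dependency callable, which is how the cache_dependencies argument is used in the implementation code. The sketch below is a simplified stand-in: resolve_cached is a hypothetical helper, and the assumption that the dictionary lives for only one top-level call is mine, since the code that creates it is not shown here.

```python
call_count = 0


def expensive_operation():
    global call_count
    call_count += 1
    return f"result_{call_count}"


def resolve_cached(dependency, cache, use_cache=True):
    """Call `dependency`, reusing a stored result when caching is on."""
    if use_cache and dependency in cache:
        return cache[dependency]
    value = dependency()
    if use_cache:
        cache[dependency] = value
    return value


cache = {}  # assumed to live only for one top-level call
a = resolve_cached(expensive_operation, cache)
b = resolve_cached(expensive_operation, cache)
c = resolve_cached(expensive_operation, cache, use_cache=False)
print(a, b, c)  # result_1 result_1 result_2
```

With caching on, the second lookup never reaches expensive_operation; with use_cache=False the function runs again and the cache is left untouched.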

Custom Fields

class HeaderExtractor(CustomField):
    def __init__(self, header_name: str):
        super().__init__()
        self.header_name = header_name

    def use(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        # Extract from some global context
        kwargs[self.param_name] = f"header_value_for_{self.header_name}"
        return kwargs

@inject
def api_handler(
    auth: str = HeaderExtractor("Authorization"),
    content_type: str = HeaderExtractor("Content-Type"),
):
    return {"auth": auth, "content_type": content_type}
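The flow behind a custom field can be sketched without the library. The version below assumes that the resolver first binds each field to its parameter name with set_param_name and then calls use to rewrite the keyword arguments; that order is inferred from the base class methods, and MiniCustomField is a trimmed, hypothetical stand-in for CustomField.

```python
import typing as t


class MiniCustomField:
    """Trimmed stand-in for the CustomField base class on this page."""

    def __init__(self) -> None:
        self.param_name: t.Optional[str] = None

    def set_param_name(self, name: str) -> "MiniCustomField":
        self.param_name = name
        return self

    def use(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        return kwargs


class HeaderExtractor(MiniCustomField):
    def __init__(self, header_name: str) -> None:
        super().__init__()
        self.header_name = header_name

    def use(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        # Add a value under the parameter name this field is bound to.
        kwargs[self.param_name] = f"header_value_for_{self.header_name}"
        return kwargs


# Assumed flow: bind each field to its parameter name, then let it
# rewrite the keyword arguments before the target function is called.
fields = {
    "auth": HeaderExtractor("Authorization"),
    "content_type": HeaderExtractor("Content-Type"),
}
kwargs: t.Dict[str, t.Any] = {}
for name, field in fields.items():
    kwargs = field.set_param_name(name).use(**kwargs)

print(kwargs)
```

Each field contributes one entry to the keyword arguments, so the decorated function receives auth and content_type as if the caller had passed them.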

Dependency Overrides

def original_dep():
    return "original"

def override_dep():
    return "overridden"

@inject
def handler(value: str = Depends(original_dep)):
    return value

# Override dependency
dependency_provider.override(original_dep, override_dep)
result = handler()  # "overridden"

# Clear overrides
dependency_provider.clear()
result = handler()  # "original"
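The Provider class in the implementation also defines a scope context manager, which is convenient for temporary overrides in tests. The sketch below uses a trimmed copy of that class so it can run on its own. The try/finally is an addition in this sketch so that the override is removed even if the body raises, and resolve is a hypothetical helper that mimics the override lookup.

```python
import typing as t
from contextlib import contextmanager

Dep = t.Callable[..., t.Any]


class Provider:
    """Trimmed copy of the Provider class, for illustration only."""

    def __init__(self) -> None:
        self.dependency_overrides: t.Dict[Dep, Dep] = {}

    @contextmanager
    def scope(self, original: Dep, override: Dep) -> t.Iterator[None]:
        self.dependency_overrides[original] = override
        try:
            yield
        finally:
            # Remove the override even if the body raised.
            self.dependency_overrides.pop(original, None)


provider = Provider()


def original_dep():
    return "original"


def override_dep():
    return "overridden"


def resolve(dep):
    # Hypothetical helper: look up an override, else use the original.
    return provider.dependency_overrides.get(dep, dep)()


with provider.scope(original_dep, override_dep):
    inside = resolve(original_dep)
after = resolve(original_dep)

print(inside, after)  # overridden original
```

Compared with override followed by clear, the context manager restores the original dependency automatically and leaves other overrides in place.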

Generator Dependencies (Context Managers)

A dependency can be a generator function, which allows for resource management (such as opening and closing database connections). In this example, database_session is a generator function that is treated like a context manager: the code before yield opens the connection and the code after yield closes it. The inject decorator handles the context management automatically.

def database_session() -> t.Generator[str, None, None]:
    print("Opening connection")
    yield "db_session"
    print("Closing connection")

@inject
def handler(db: str = Depends(database_session)):
    return f"Using {db}"

result = handler()
# Prints "Opening connection" before the handler body runs,
# then "Closing connection" once the handler has returned.
print(result)  # "Using db_session"
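The implementation's solve_generator_sync wraps the generator function with contextlib.contextmanager and enters it on an ExitStack. The standalone sketch below reproduces that pattern with the standard library only. The events list and this plain handler are hypothetical, and where the real stack is opened and closed is assumed to be around the injected call.

```python
import typing as t
from contextlib import ExitStack, contextmanager

events: t.List[str] = []


def database_session() -> t.Iterator[str]:
    events.append("open")
    yield "db_session"
    events.append("close")


def handler(db: str) -> str:
    events.append(f"use {db}")
    return f"Using {db}"


with ExitStack() as stack:
    # Wrap the generator function as a context manager and enter it;
    # the stack runs the code after `yield` when the block exits.
    db = stack.enter_context(contextmanager(database_session)())
    result = handler(db)

print(result)  # Using db_session
print(events)  # ['open', 'use db_session', 'close']
```

Using a stack rather than a single with statement lets any number of generator dependencies be opened in order and closed in reverse order.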

Type Validation with Pydantic

Arguments can be annotated with types or Pydantic models for automatic validation and casting. In this example, get_user_id is declared to return an int but actually returns a string; the value is cast to an integer when injected. Casting happens when cast=True, which is the default. The Annotated type (from typing on Python 3.9+, or from typing_extensions) is used to attach the dependency to the type.

from typing import Annotated

def get_user_id() -> int:
    return "123"  # Wrong type!

@inject
def handler(user_id: Annotated[int, Depends(get_user_id)]):
    return f"User ID: {user_id}"

result = handler()  # user_id will be cast to int(123)

Disable Type Casting

To disable type casting, pass cast=False to the inject decorator.

@inject(cast=False)
def handler(user_id: Annotated[int, Depends(get_user_id)]):
    return f"User ID type: {type(user_id)}"

result = handler()  # user_id remains as string "123"

Implementation Code

Requirements: anyio, pydantic (v2), and typing_extensions, as used by the imports below.

import anyio
import asyncio
import inspect
import functools
import typing as t
from abc import ABC
from copy import deepcopy
from itertools import chain
from pydantic import ConfigDict
from collections import namedtuple
from functools import wraps, partial
from pydantic import BaseModel, create_model
from typing_extensions import Annotated, ParamSpec, get_args, get_origin
from pydantic._internal._typing_extra import try_eval_type as evaluate_forwardref
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager

P = ParamSpec("P")
T = t.TypeVar("T")
Cls = t.TypeVar("Cls", bound="CustomField")

default_pydantic_config = {"arbitrary_types_allowed": True}


def get_config_base(config_data: t.Optional[ConfigDict] = None) -> ConfigDict:
    return config_data or ConfigDict(**default_pydantic_config)


def get_aliases(model: t.Type[BaseModel]) -> t.Tuple[str, ...]:
    return tuple(f.alias or name for name, f in model.model_fields.items())


class Depends:
    """Mark a parameter as a dependency to be injected."""

    use_cache: bool
    cast: bool

    def __init__(
        self,
        dependency: t.Callable[..., t.Any],
        *,
        use_cache: bool = True,
        cast: bool = True,
    ) -> None:
        self.dependency = dependency
        self.use_cache = use_cache
        self.cast = cast

    def __repr__(self) -> str:
        attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
        cache = "" if self.use_cache else ", use_cache=False"
        return f"{self.__class__.__name__}({attr}{cache})"


class CustomField(ABC):
    """Base class for custom field extractors."""

    param_name: t.Optional[str]
    cast: bool
    required: bool

    __slots__ = (
        "cast",
        "param_name",
        "required",
        "field",
    )

    def __init__(
        self,
        *,
        cast: bool = True,
        required: bool = True,
    ) -> None:
        self.cast = cast
        self.param_name = None
        self.required = required
        self.field = False

    def set_param_name(self: Cls, name: str) -> Cls:
        self.param_name = name
        return self

    def use(self, /, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        assert self.param_name, "You should specify `param_name` before using"
        return kwargs

    def use_field(self, kwargs: t.Dict[str, t.Any]) -> None:
        raise NotImplementedError("You should implement `use_field` method.")


# Provider for dependency overrides


class Provider:
    """Provider for dependency overrides."""

    dependency_overrides: t.Dict[t.Callable[..., t.Any], t.Callable[..., t.Any]]

    def __init__(self) -> None:
        self.dependency_overrides = {}

    def clear(self) -> None:
        self.dependency_overrides = {}

    def override(
        self,
        original: t.Callable[..., t.Any],
        override: t.Callable[..., t.Any],
    ) -> None:
        self.dependency_overrides[original] = override

    @contextmanager
    def scope(
        self,
        original: t.Callable[..., t.Any],
        override: t.Callable[..., t.Any],
    ) -> t.Iterator[None]:
        self.dependency_overrides[original] = override
        try:
            yield
        finally:
            # Always drop the override, even if the body raised.
            self.dependency_overrides.pop(original, None)


dependency_provider = Provider()


def is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool:
    if inspect.isclass(call):
        return False
    if asyncio.iscoroutinefunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)
    return asyncio.iscoroutinefunction(dunder_call)


def is_gen_callable(call: t.Callable[..., t.Any]) -> bool:
    if inspect.isgeneratorfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)
    return inspect.isgeneratorfunction(dunder_call)


def is_async_gen_callable(call: t.Callable[..., t.Any]) -> bool:
    if inspect.isasyncgenfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)
    return inspect.isasyncgenfunction(dunder_call)


async def run_async(
    func: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    if is_coroutine_callable(func):
        return await t.cast(t.Callable[P, t.Awaitable[T]], func)(*args, **kwargs)
    else:
        return await run_in_threadpool(t.cast(t.Callable[P, T], func), *args, **kwargs)


async def run_in_threadpool(func: t.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    if kwargs:
        func = functools.partial(func, **kwargs)
    return await anyio.to_thread.run_sync(func, *args)


def get_typed_annotation(
    annotation: t.Any,
    globalns: t.Dict[str, t.Any],
    locals: t.Dict[str, t.Any],
) -> t.Any:
    if isinstance(annotation, str):
        annotation = t.ForwardRef(annotation)

    if isinstance(annotation, t.ForwardRef):
        # try_eval_type returns a (value, success) tuple; keep only the value.
        annotation = evaluate_forwardref(annotation, globalns, locals)[0]

    if get_origin(annotation) is Annotated and (args := get_args(annotation)):
        solved_args = [get_typed_annotation(x, globalns, locals) for x in args]
        annotation.__origin__, annotation.__metadata__ = solved_args[0], tuple(solved_args[1:])

    return annotation


def collect_outer_stack_locals() -> t.Dict[str, t.Any]:
    """
    Collect local variables from outer stack frames to resolve type annotations.

    This function walks up the call stack and collects all local variables
    from frames outside of this module. This is necessary for resolving
    forward references and string annotations that might reference variables
    defined in the calling code.
    """
    frame = inspect.currentframe()
    frames: t.List[t.Any] = []
    current_filename = __file__ if "__file__" in globals() else None

    while frame is not None:
        frame_filename = frame.f_code.co_filename
        # Skip frames from this module to avoid internal variables
        if current_filename is None or frame_filename != current_filename:
            frames.append(frame)
        frame = frame.f_back

    locals = {}
    for f in frames[::-1]:
        locals.update(f.f_locals)

    return locals


def get_typed_signature(call: t.Callable[..., t.Any]) -> t.Tuple[inspect.Signature, t.Any]:
    signature = inspect.signature(call)
    locals = collect_outer_stack_locals()
    call = inspect.unwrap(call)
    globalns = getattr(call, "__globals__", {})

    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=get_typed_annotation(
                param.annotation,
                globalns,
                locals,
            ),
        )
        for param in signature.parameters.values()
    ]

    return inspect.Signature(typed_params), get_typed_annotation(
        signature.return_annotation,
        globalns,
        locals,
    )


async def solve_generator_async(
    *sub_args: t.Any, call: t.Callable[..., t.Any], stack: AsyncExitStack, **sub_values: t.Any
) -> t.Any:
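    if is_gen_callable(call):
        # Forward positional arguments too, as solve_generator_sync does.
        cm = contextmanager_in_threadpool(contextmanager(call)(*sub_args, **sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(*sub_args, **sub_values)
    else:
        raise TypeError(f"`{call}` is not a generator or async generator function")
    return await stack.enter_async_context(cm)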


def solve_generator_sync(
    *sub_args: t.Any, call: t.Callable[..., t.Any], stack: ExitStack, **sub_values: t.Any
) -> t.Any:
    cm = contextmanager(call)(*sub_args, **sub_values)
    return stack.enter_context(cm)


@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: t.ContextManager[T],
) -> t.AsyncGenerator[T, None]:
    exit_limiter = anyio.CapacityLimiter(1)
    try:
        yield await run_in_threadpool(cm.__enter__)
    except Exception as e:
        ok = bool(await anyio.to_thread.run_sync(cm.__exit__, type(e), e, None, limiter=exit_limiter))
        if not ok:
            raise e
    else:
        await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)


async def async_map(func: t.Callable[..., T], async_iterable: t.AsyncIterable[t.Any]) -> t.AsyncIterable[T]:
    async for i in async_iterable:
        yield func(i)


class solve_wrapper(partial[T]):
    call: t.Callable[..., T]

    def __new__(
        cls,
        func: t.Callable[..., T],
        *args: t.Any,
        **kwargs: t.Any,
    ) -> "solve_wrapper[T]":
        assert len(args) > 0, "Model should be passed as first argument"
        model = args[0]
        self = super().__new__(cls, func, *args, **kwargs)
        self.call = model.call
        return self


# Core Models

PriorityPair = namedtuple("PriorityPair", ("call", "dependencies_number", "dependencies_names"))


class ResponseModel(BaseModel, t.Generic[T]):
    response: T


class CallModel(t.Generic[P, T]):
    """Model representing a callable with dependency injection."""

    call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]]
    is_async: bool
    is_generator: bool
    model: t.Optional[t.Type[BaseModel]]
    response_model: t.Optional[t.Type[ResponseModel[T]]]

    params: t.Dict[str, t.Tuple[t.Any, t.Any]]
    alias_arguments: t.Tuple[str, ...]

    dependencies: t.Dict[str, "CallModel[..., t.Any]"]
    extra_dependencies: t.Iterable["CallModel[..., t.Any]"]
    sorted_dependencies: t.Tuple[t.Tuple["CallModel[..., t.Any]", int], ...]
    custom_fields: t.Dict[str, CustomField]
    keyword_args: t.Tuple[str, ...]
    positional_args: t.Tuple[str, ...]
    var_positional_arg: t.Optional[str]
    var_keyword_arg: t.Optional[str]

    use_cache: bool
    cast: bool

    __slots__ = (
        "call",
        "is_async",
        "is_generator",
        "model",
        "response_model",
        "params",
        "alias_arguments",
        "keyword_args",
        "positional_args",
        "var_positional_arg",
        "var_keyword_arg",
        "dependencies",
        "extra_dependencies",
        "sorted_dependencies",
        "custom_fields",
        "use_cache",
        "cast",
    )

    @property
    def call_name(self) -> str:
        call = inspect.unwrap(self.call)
        return getattr(call, "__name__", type(call).__name__)

    @property
    def flat_params(self) -> t.Dict[str, t.Tuple[t.Any, t.Any]]:
        params = dict(self.params)  # copy so reading the property does not mutate self.params
        for d in (*self.dependencies.values(), *self.extra_dependencies):
            params.update(d.flat_params)
        return params

    @property
    def flat_dependencies(
        self,
    ) -> t.Dict[
        t.Callable[..., t.Any],
        t.Tuple["CallModel[..., t.Any]", t.Tuple[t.Callable[..., t.Any], ...]],
    ]:
        flat: t.Dict[
            t.Callable[..., t.Any],
            t.Tuple[CallModel[..., t.Any], t.Tuple[t.Callable[..., t.Any], ...]],
        ] = {}

        for i in (*self.dependencies.values(), *self.extra_dependencies):
            flat.update(
                {
                    i.call: (
                        i,
                        tuple(j.call for j in i.dependencies.values()),
                    )
                }
            )
            flat.update(i.flat_dependencies)

        return flat

    def __init__(
        self,
        /,
        call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
        model: t.Optional[t.Type[BaseModel]],
        params: t.Dict[str, t.Tuple[t.Any, t.Any]],
        response_model: t.Optional[t.Type[ResponseModel[T]]] = None,
        use_cache: bool = True,
        cast: bool = True,
        is_async: bool = False,
        is_generator: bool = False,
        dependencies: t.Optional[t.Dict[str, "CallModel[..., t.Any]"]] = None,
        extra_dependencies: t.Optional[t.Iterable["CallModel[..., t.Any]"]] = None,
        keyword_args: t.Optional[t.List[str]] = None,
        positional_args: t.Optional[t.List[str]] = None,
        var_positional_arg: t.Optional[str] = None,
        var_keyword_arg: t.Optional[str] = None,
        custom_fields: t.Optional[t.Dict[str, CustomField]] = None,
    ):
        self.call = call
        self.model = model

        if model:
            self.alias_arguments = get_aliases(model)
        else:
            self.alias_arguments = ()

        self.keyword_args = tuple(keyword_args or ())
        self.positional_args = tuple(positional_args or ())
        self.var_positional_arg = var_positional_arg
        self.var_keyword_arg = var_keyword_arg
        self.response_model = response_model
        self.use_cache = use_cache
        self.cast = cast
        self.is_async = is_async or is_coroutine_callable(call) or is_async_gen_callable(call)
        self.is_generator = is_generator or is_gen_callable(call) or is_async_gen_callable(call)

        self.dependencies = dependencies or {}
        self.extra_dependencies = extra_dependencies or ()
        self.custom_fields = custom_fields or {}

        sorted_dep: t.List[CallModel[..., t.Any]] = []
        flat = self.flat_dependencies
        for calls in flat.values():
            _sort_dep(sorted_dep, calls, flat)

        self.sorted_dependencies = tuple((i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache)
        for name in chain(self.dependencies.keys(), self.custom_fields.keys()):
            params.pop(name, None)
        self.params = params

    def _solve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        **kwargs: t.Dict[str, t.Any],
    ) -> t.Generator[
        t.Tuple[t.Sequence[t.Any], t.Dict[str, t.Any], t.Callable[..., t.Any]],
        t.Any,
        T,
    ]:
        if dependency_overrides:
            call = dependency_overrides.get(self.call, self.call)
            assert self.is_async or not is_coroutine_callable(call), (
                f"You cannot use async dependency `{self.call_name}` at sync main"
            )
        else:
            call = self.call

        if self.use_cache and call in cache_dependencies:
            return cache_dependencies[call]

        kw: t.Dict[str, t.Any] = {}

        for arg in self.keyword_args:
            if (v := kwargs.pop(arg, inspect.Parameter.empty)) is not inspect.Parameter.empty:
                kw[arg] = v

        if self.var_keyword_arg is not None:
            kw[self.var_keyword_arg] = kwargs
        else:
            kw.update(kwargs)

        for arg in self.positional_args:
            if args:
                kw[arg], args = args[0], args[1:]
            else:
                break

        keyword_args: t.Iterable[str]
        if self.var_positional_arg is not None:
            kw[self.var_positional_arg] = args
            keyword_args = self.keyword_args
        else:
            keyword_args = self.keyword_args + self.positional_args
            for arg in keyword_args:
                if not self.cast and arg in self.params:
                    kw[arg] = self.params[arg][1]

                if not args:
                    break

                if arg not in self.dependencies:
                    kw[arg], args = args[0], args[1:]

        solved_kw: t.Dict[str, t.Any]
        solved_kw = yield args, kw, call

        args_: t.Sequence[t.Any]
        if self.cast:
            assert self.model, "Cast should be used only with model"
            casted_model = self.model(**solved_kw)

            kwargs_ = {arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in keyword_args}
            if self.var_keyword_arg:
                kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))

            if self.var_positional_arg is not None:
                args_ = [getattr(casted_model, arg, solved_kw.get(arg)) for arg in self.positional_args]
                args_.extend(getattr(casted_model, self.var_positional_arg, ()))
            else:
                args_ = ()
        else:
            kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}

            if self.var_positional_arg is not None:
                args_ = tuple(map(solved_kw.get, self.positional_args))
            else:
                args_ = ()

        response: T
        response = yield args_, kwargs_, call

        if self.cast and not self.is_generator:
            response = self._cast_response(response)

        if self.use_cache:
            cache_dependencies[call] = response

        return response

    def _cast_response(self, /, value: t.Any) -> t.Any:
        if self.response_model is not None:
            return self.response_model(response=value).response
        else:
            return value

    def solve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        stack: ExitStack,
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        nested: bool = False,
        **kwargs: t.Dict[str, t.Any],
    ) -> T:
        cast_gen = self._solve(
            *args,
            cache_dependencies=cache_dependencies,
            dependency_overrides=dependency_overrides,
            **kwargs,
        )
        try:
            args, kwargs, _ = next(cast_gen)
        except StopIteration as e:
            cached_value: T = e.value
            return cached_value

        # Heat cache and solve extra dependencies
        for dep, _ in self.sorted_dependencies:
            dep.solve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )

        # Always get from cache
        for dep in self.extra_dependencies:
            dep.solve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )

        for dep_arg, dep in self.dependencies.items():
            kwargs[dep_arg] = dep.solve(
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )

        for custom in self.custom_fields.values():
            if custom.field:
                custom.use_field(kwargs)
            else:
                kwargs = custom.use(**kwargs)

        final_args, final_kwargs, call = cast_gen.send(kwargs)

        if self.is_generator and nested:
            response = solve_generator_sync(
                *final_args,
                call=call,
                stack=stack,
                **final_kwargs,
            )
        else:
            response = call(*final_args, **final_kwargs)

        try:
            cast_gen.send(response)
        except StopIteration as e:
            value: T = e.value

            if not self.cast or nested or not self.is_generator:
                return value
            else:
                return map(self._cast_response, value)

        raise AssertionError("unreachable")

    async def asolve(
        self,
        /,
        *args: t.Tuple[t.Any, ...],
        stack: AsyncExitStack,
        cache_dependencies: t.Dict[
            t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            T,
        ],
        dependency_overrides: t.Optional[
            t.Dict[
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
                t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
            ]
        ] = None,
        nested: bool = False,
        **kwargs: t.Dict[str, t.Any],
    ) -> T:
        cast_gen = self._solve(
            *args,
            cache_dependencies=cache_dependencies,
            dependency_overrides=dependency_overrides,
            **kwargs,
        )
        try:
            args, kwargs, _ = next(cast_gen)
        except StopIteration as e:
            cached_value: T = e.value
            return cached_value

        # Heat cache and solve extra dependencies
        dep_to_solve: t.List[t.Callable[..., t.Awaitable[t.Any]]] = []
        async with anyio.create_task_group() as tg:
            for dep, subdep in self.sorted_dependencies:
                solve = partial(
                    dep.asolve,
                    *args,
                    stack=stack,
                    cache_dependencies=cache_dependencies,
                    dependency_overrides=dependency_overrides,
                    nested=True,
                    **kwargs,
                )
                if not subdep:
                    tg.start_soon(solve)
                else:
                    dep_to_solve.append(solve)

        for i in dep_to_solve:
            await i()

        # Always get from cache
        for dep in self.extra_dependencies:
            await dep.asolve(
                *args,
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )

        for dep_arg, dep in self.dependencies.items():
            kwargs[dep_arg] = await dep.asolve(
                stack=stack,
                cache_dependencies=cache_dependencies,
                dependency_overrides=dependency_overrides,
                nested=True,
                **kwargs,
            )

        custom_to_solve: t.List[CustomField] = []

        async with anyio.create_task_group() as tg:
            for custom in self.custom_fields.values():
                if custom.field:
                    tg.start_soon(run_async, custom.use_field, kwargs)
                else:
                    custom_to_solve.append(custom)

        for j in custom_to_solve:
            kwargs = await run_async(j.use, **kwargs)

        final_args, final_kwargs, call = cast_gen.send(kwargs)

        if self.is_generator and nested:
            response = await solve_generator_async(
                *final_args,
                call=call,
                stack=stack,
                **final_kwargs,
            )
        else:
            response = await run_async(call, *final_args, **final_kwargs)

        try:
            cast_gen.send(response)
        except StopIteration as e:
            value: T = e.value

            if not self.cast or nested or not self.is_generator:
                return value
            else:
                return async_map(self._cast_response, value)

        raise AssertionError("unreachable")


def _sort_dep(
    collector: t.List["CallModel[..., t.Any]"],
    items: t.Tuple[
        "CallModel[..., t.Any]",
        t.Tuple[t.Callable[..., t.Any], ...],
    ],
    flat: t.Dict[
        t.Callable[..., t.Any],
        t.Tuple[
            "CallModel[..., t.Any]",
            t.Tuple[t.Callable[..., t.Any], ...],
        ],
    ],
) -> None:
    model, calls = items

    if model in collector:
        return

    if not calls:
        position = -1
    else:
        for i in calls:
            sub_model, _ = flat[i]
            if sub_model not in collector:
                _sort_dep(collector, flat[i], flat)

        position = max(collector.index(flat[i][0]) for i in calls)

    collector.insert(position + 1, model)


CUSTOM_ANNOTATIONS = (Depends, CustomField)


def build_call_model(
    call: t.Union[t.Callable[P, T], t.Callable[P, t.Awaitable[T]]],
    *,
    cast: bool = True,
    use_cache: bool = True,
    is_sync: t.Optional[bool] = None,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
) -> CallModel[P, T]:
    """Build a CallModel from a callable."""
    name = getattr(call, "__name__", type(call).__name__)

    is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call)
    if is_sync is None:
        is_sync = not is_call_async
    else:
        assert not (is_sync and is_call_async), f"You cannot use async dependency `{name}` at sync main"

    typed_params, return_annotation = get_typed_signature(call)
    if (is_call_generator := is_gen_callable(call) or is_async_gen_callable(call)) and (
        return_args := get_args(return_annotation)
    ):
        return_annotation = return_args[0]

    class_fields: t.Dict[str, t.Tuple[t.Any, t.Any]] = {}
    dependencies: t.Dict[str, CallModel[..., t.Any]] = {}
    custom_fields: t.Dict[str, CustomField] = {}
    positional_args: t.List[str] = []
    keyword_args: t.List[str] = []
    var_positional_arg: t.Optional[str] = None
    var_keyword_arg: t.Optional[str] = None

    for param_name, param in typed_params.parameters.items():
        dep: t.Optional[Depends] = None
        custom: t.Optional[CustomField] = None

        if param.annotation is inspect.Parameter.empty:
            annotation = t.Any
        elif get_origin(param.annotation) is Annotated:
            annotated_args = get_args(param.annotation)
            type_annotation = annotated_args[0]

            # Separate DI markers (Depends/CustomField) from other metadata.
            custom_annotations = []
            regular_annotations = []
            for arg in annotated_args[1:]:
                if isinstance(arg, CUSTOM_ANNOTATIONS):
                    custom_annotations.append(arg)
                else:
                    regular_annotations.append(arg)

            assert len(custom_annotations) <= 1, (
                f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!"
            )

            next_custom = next(iter(custom_annotations), None)
            if next_custom is not None:
                if isinstance(next_custom, Depends):
                    dep = next_custom
                elif isinstance(next_custom, CustomField):
                    custom = deepcopy(next_custom)
                else:
                    raise AssertionError("unreachable")

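                # Keep the full Annotated type if other metadata remains
                # (e.g. validation constraints); otherwise use the bare type.
                if regular_annotations: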
                    annotation = param.annotation
                else:
                    annotation = type_annotation
            else:
                annotation = param.annotation
        else:
            annotation = param.annotation

        default: t.Any
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            default = ()
            var_positional_arg = param_name
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            default = {}
            var_keyword_arg = param_name
        elif param.default is inspect.Parameter.empty:
            default = Ellipsis
        else:
            default = param.default

        if isinstance(default, Depends):
            if dep:
                raise AssertionError(
                    "You cannot use `Depends` both in `Annotated` and as a default",
                )
            dep, default = default, Ellipsis

        elif isinstance(default, CustomField):
            if custom:
                raise AssertionError(
                    "You cannot use `CustomField` both in `Annotated` and as a default",
                )
            custom, default = default, Ellipsis

        else:
            class_fields[param_name] = (annotation, default)

        if dep:
            if not cast:
                dep.cast = False

            if isinstance(dep.dependency, solve_wrapper):
                dep.dependency = dep.dependency.call

            dependencies[param_name] = build_call_model(
                dep.dependency,
                cast=dep.cast,
                use_cache=dep.use_cache,
                is_sync=is_sync,
                pydantic_config=pydantic_config,
            )

            if dep.cast is True:
                class_fields[param_name] = (annotation, Ellipsis)

            keyword_args.append(param_name)

        elif custom:
            assert not (is_sync and is_coroutine_callable(custom.use)), (
                f"You cannot use async custom field `{type(custom).__name__}` in sync `{name}`"
            )

            custom.set_param_name(param_name)
            custom_fields[param_name] = custom

            if custom.cast is False:
                annotation = t.Any

            if custom.required:
                class_fields[param_name] = (annotation, default)
            else:
                class_fields[param_name] = class_fields.get(param_name, (t.Optional[annotation], None))

            keyword_args.append(param_name)

        else:
            if param.kind is inspect.Parameter.KEYWORD_ONLY:
                keyword_args.append(param_name)
            elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                positional_args.append(param_name)

    func_model = create_model(
        name,
        __config__=get_config_base(pydantic_config),
        **class_fields,
    )

    response_model: t.Optional[t.Type[ResponseModel[T]]] = None
    if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
        response_model = create_model(
            "ResponseModel",
            __config__=get_config_base(pydantic_config),
            response=(return_annotation, Ellipsis),
        )

    return CallModel(
        call=call,
        model=func_model,
        response_model=response_model,
        params=class_fields,
        cast=cast,
        use_cache=use_cache,
        is_async=is_call_async,
        is_generator=is_call_generator,
        dependencies=dependencies,
        custom_fields=custom_fields,
        positional_args=positional_args,
        keyword_args=keyword_args,
        var_positional_arg=var_positional_arg,
        var_keyword_arg=var_keyword_arg,
        extra_dependencies=[
            build_call_model(
                d.dependency,
                cast=d.cast,
                use_cache=d.use_cache,
                is_sync=is_sync,
                pydantic_config=pydantic_config,
            )
            for d in extra_dependencies
        ],
    )


class _InjectWrapper(t.Protocol[P, T]):
    def __call__(
        self,
        func: t.Callable[P, T],
        model: t.Optional[CallModel[P, T]] = None,
    ) -> t.Callable[P, T]: ...


@t.overload
def inject(
    func: None = None,
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> _InjectWrapper[P, T]: ...


@t.overload
def inject(
    func: t.Callable[P, T],
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> t.Callable[P, T]: ...


def inject(
    func: t.Optional[t.Callable[P, T]] = None,
    *,
    cast: bool = True,
    extra_dependencies: t.Sequence[Depends] = (),
    pydantic_config: t.Optional[ConfigDict] = None,
    dependency_overrides_provider: t.Optional[t.Any] = dependency_provider,
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x,
) -> t.Union[t.Callable[P, T], _InjectWrapper[P, T]]:
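    """Decorator to inject dependencies into a function.

    Works both bare (`@inject`) and with options (`@inject(cast=False)`): when
    `func` is None, the configured decorator is returned instead of a wrapped
    function.
    """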
    decorator = _wrap_inject(
        dependency_overrides_provider=dependency_overrides_provider,
        wrap_model=wrap_model,
        extra_dependencies=extra_dependencies,
        cast=cast,
        pydantic_config=pydantic_config,
    )

    if func is None:
        return decorator
    else:
        return decorator(func)


def _wrap_inject(
    dependency_overrides_provider: t.Optional[t.Any],
    wrap_model: t.Callable[[CallModel[P, T]], CallModel[P, T]],
    extra_dependencies: t.Sequence[Depends],
    cast: bool,
    pydantic_config: t.Optional[ConfigDict],
) -> _InjectWrapper[P, T]:
    if (
        dependency_overrides_provider
        and getattr(dependency_overrides_provider, "dependency_overrides", None) is not None
    ):
        overrides = dependency_overrides_provider.dependency_overrides
    else:
        overrides = None

    def func_wrapper(
        func: t.Callable[P, T],
        model: t.Optional[CallModel[P, T]] = None,
    ) -> t.Callable[P, T]:
        if model is None:
            real_model = wrap_model(
                build_call_model(
                    call=func,
                    extra_dependencies=extra_dependencies,
                    cast=cast,
                    pydantic_config=pydantic_config,
                )
            )
        else:
            real_model = model

        if real_model.is_async:
            injected_wrapper: t.Callable[P, T]

            if real_model.is_generator:
                injected_wrapper = solve_wrapper(solve_async_gen, real_model, overrides)
            else:

                @wraps(func)
                async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                    async with AsyncExitStack() as stack:
                        r = await real_model.asolve(
                            *args,
                            stack=stack,
                            dependency_overrides=overrides,
                            # New cache per call; results are not shared across calls.
                            cache_dependencies={},
                            nested=False,
                            **kwargs,
                        )
                        return r

                    raise AssertionError("unreachable")

        else:
            if real_model.is_generator:
                injected_wrapper = solve_wrapper(solve_gen, real_model, overrides)
            else:

                @wraps(func)
                def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                    with ExitStack() as stack:
                        r = real_model.solve(
                            *args,
                            stack=stack,
                            dependency_overrides=overrides,
                            # New cache per call; results are not shared across calls.
                            cache_dependencies={},
                            nested=False,
                            **kwargs,
                        )
                        return r

                    raise AssertionError("unreachable")

        return injected_wrapper

    return func_wrapper


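class solve_async_gen:
    """Async iterator that resolves dependencies on the first `__anext__` call
    and runs their teardown once the wrapped async generator is exhausted."""

    _iter: t.Optional[t.AsyncIterator[t.Any]] = None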

    def __init__(
        self,
        model: "CallModel[..., t.Any]",
        overrides: t.Optional[t.Any],
        *args: t.Any,
        **kwargs: t.Any,
    ):
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __aiter__(self) -> "solve_async_gen":
        self.stack = AsyncExitStack()
        return self

    async def __anext__(self) -> t.Any:
        if self._iter is None:
            stack = self.stack = AsyncExitStack()
            await self.stack.__aenter__()
            self._iter = t.cast(
                t.AsyncIterator[t.Any],
                (
                    await self.call.asolve(
                        *self.args,
                        stack=stack,
                        dependency_overrides=self.overrides,
                        cache_dependencies={},
                        nested=False,
                        **self.kwargs,
                    )
                ).__aiter__(),
            )

        try:
            r = await self._iter.__anext__()
        except StopAsyncIteration:
            # Generator exhausted: run dependency teardown, then re-raise.
            await self.stack.__aexit__(None, None, None)
            raise
        else:
            return r


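class solve_gen:
    """Iterator that resolves dependencies on the first `__next__` call and
    runs their teardown once the wrapped generator is exhausted."""

    _iter: t.Optional[t.Iterator[t.Any]] = None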

    def __init__(
        self,
        model: "CallModel[..., t.Any]",
        overrides: t.Optional[t.Any],
        *args: t.Any,
        **kwargs: t.Any,
    ):
        self.call = model
        self.args = args
        self.kwargs = kwargs
        self.overrides = overrides

    def __iter__(self) -> "solve_gen":
        self.stack = ExitStack()
        return self

    def __next__(self) -> t.Any:
        if self._iter is None:
            stack = self.stack = ExitStack()
            self.stack.__enter__()
            self._iter = t.cast(
                t.Iterator[t.Any],
                iter(
                    self.call.solve(
                        *self.args,
                        stack=stack,
                        dependency_overrides=self.overrides,
                        cache_dependencies={},
                        nested=False,
                        **self.kwargs,
                    )
                ),
            )

        try:
            r = next(self._iter)
        except StopIteration:
            # Generator exhausted: run dependency teardown, then re-raise.
            self.stack.__exit__(None, None, None)
            raise
        else:
            return r
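
To make the moving parts of the listing easier to follow, here is a much smaller sketch of the same idea: read markers out of `Annotated` hints, resolve them recursively, and share results through a cache that lives for one top-level call. It is not the implementation above. `Dep`, `find_deps`, `resolve` and `mini_inject` are hypothetical names invented for this illustration, and the sketch skips pydantic validation, overrides, async and generator support.

```python
import functools
import inspect
import typing as t


class Dep:
    """Hypothetical marker, playing the role of `Depends` in the listing."""

    def __init__(self, provider: t.Callable[..., t.Any]) -> None:
        self.provider = provider


def find_deps(func: t.Callable[..., t.Any]) -> t.Dict[str, Dep]:
    """Map parameter names to the `Dep` marker in their `Annotated` hint."""
    hints = t.get_type_hints(func, include_extras=True)
    found: t.Dict[str, Dep] = {}
    for name in inspect.signature(func).parameters:
        hint = hints.get(name)
        if t.get_origin(hint) is t.Annotated:
            # get_args() returns (real_type, *metadata) for Annotated hints.
            markers = [m for m in t.get_args(hint)[1:] if isinstance(m, Dep)]
            if len(markers) > 1:
                raise TypeError(f"multiple markers on `{name}`")
            if markers:
                found[name] = markers[0]
    return found


def resolve(
    func: t.Callable[..., t.Any],
    cache: t.Dict[t.Callable[..., t.Any], t.Any],
    kwargs: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Any:
    """Call `func`, resolving marked parameters recursively via `cache`."""
    call_kwargs = dict(kwargs or {})
    for name, dep in find_deps(func).items():
        if name in call_kwargs:
            continue  # an explicitly passed value wins over injection
        if dep.provider not in cache:
            cache[dep.provider] = resolve(dep.provider, cache)
        call_kwargs[name] = cache[dep.provider]
    return func(**call_kwargs)


def mini_inject(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
    @functools.wraps(func)
    def wrapper(**kwargs: t.Any) -> t.Any:
        # A fresh cache per top-level call, like `cache_dependencies={}`.
        return resolve(func, cache={}, kwargs=kwargs)

    return wrapper


calls: t.List[str] = []  # records which providers actually ran


def get_settings() -> dict:
    calls.append("settings")
    return {"dsn": "sqlite://"}


def get_db(settings: t.Annotated[dict, Dep(get_settings)]) -> str:
    calls.append("db")
    return "db:" + settings["dsn"]


@mini_inject
def compute(
    db: t.Annotated[str, Dep(get_db)],
    settings: t.Annotated[dict, Dep(get_settings)],
) -> str:
    return db + "|" + settings["dsn"]


result = compute()  # get_settings runs once although it is needed twice
```

Passing `db="manual"` to `compute` skips `get_db` entirely, which mirrors the "pass the db directly if needed" usage shown at the top of the page. Because the cache is created inside `wrapper`, a second call to `compute` runs the providers again.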

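The `solve_gen` class relies on one technique that is easy to miss: the exit stack is opened on the first `next()` call and closed only when the generator finishes, so resources acquired by dependencies stay alive while items are being consumed. The sketch below isolates that technique with the standard library only. `LazyGen`, `open_resource` and `numbers` are hypothetical names, and unlike the listing (which handles only `StopIteration`) this sketch also closes the stack when any other exception escapes, which is an assumption about desirable behaviour rather than something the original does.

```python
import contextlib
import typing as t

events: t.List[str] = []  # records the order of setup, use and teardown


@contextlib.contextmanager
def open_resource() -> t.Iterator[str]:
    events.append("open")
    try:
        yield "resource"
    finally:
        events.append("close")


class LazyGen:
    """Hypothetical stand-in for `solve_gen`: builds the iterator lazily."""

    def __init__(
        self,
        make_iter: t.Callable[[contextlib.ExitStack], t.Iterable[int]],
    ) -> None:
        self._make_iter = make_iter
        self._iter: t.Optional[t.Iterator[int]] = None
        self._stack = contextlib.ExitStack()

    def __iter__(self) -> "LazyGen":
        return self

    def __next__(self) -> int:
        if self._iter is None:
            # Nothing is set up until the first item is requested.
            self._iter = iter(self._make_iter(self._stack))
        try:
            return next(self._iter)
        except BaseException:
            # On exhaustion or error, unwind everything on the stack.
            self._stack.close()
            raise


def numbers(stack: contextlib.ExitStack) -> t.Iterator[int]:
    # The resource is registered on the caller's stack, so it outlives
    # each individual `yield` and is released only when the stack closes.
    res = stack.enter_context(open_resource())
    events.append(f"using {res}")
    yield 1
    yield 2


items = list(LazyGen(numbers))
```

After the loop finishes, `events` shows the resource being opened before the first item and closed only after the last one, rather than around each item.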