FastAPI + SQLAlchemy example — Dependency Injector 4.46.0 documentation (original) (raw)

This example shows how to use Dependency Injector with FastAPI and SQLAlchemy.

The source code is available on GitHub.

Thanks to @ShvetsovYura for providing the initial example: FastAPI_DI_SqlAlchemy.

Application structure

The application has the following structure:

./
├── webapp/
│   ├── __init__.py
│   ├── application.py
│   ├── containers.py
│   ├── database.py
│   ├── endpoints.py
│   ├── models.py
│   ├── repositories.py
│   ├── services.py
│   └── tests.py
├── config.yml
├── docker-compose.yml
├── Dockerfile
└── requirements.txt

Application factory

The application factory creates the container, wires it with the endpoints module, creates the FastAPI app, and sets up the routes.

Application factory also creates database if it does not exist.

Listing of webapp/application.py:

"""Application module."""

from fastapi import FastAPI

from .containers import Container from . import endpoints

def create_app() -> FastAPI:
    """Build the FastAPI application with its dependency container attached.

    Instantiates the container, asks the container's database object to
    create the schema, then creates the app, stores the container on it and
    mounts the endpoints router.
    """
    di_container = Container()
    di_container.db().create_database()

    application = FastAPI()
    # Kept on the app so other code (e.g. tests) can reach the providers.
    application.container = di_container
    application.include_router(endpoints.router)
    return application

# Module-level application instance, built once when this module is imported.
app = create_app()

Endpoints

Module endpoints contains example endpoints. Endpoints have a dependency on user service. User service is injected using Wiring feature. See webapp/endpoints.py:

"""Endpoints module."""

from typing import Annotated

from fastapi import APIRouter, Depends, Response, status

from dependency_injector.wiring import Provide, inject

from .containers import Container from .repositories import NotFoundError from .services import UserService

# Router on which the endpoint functions of this module are registered.
router = APIRouter()

# GET /users: hand back whatever the injected user service lists.
@router.get("/users")
@inject
def get_list(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    users = user_service.get_users()
    return users

# GET /users/{user_id}: return the matching user, or an empty 404 response
# when the service reports that no such user exists.
@router.get("/users/{user_id}")
@inject
def get_by_id(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    try:
        user = user_service.get_user_by_id(user_id)
    except NotFoundError:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return user

# POST /users: ask the service to create a user and respond with it (201).
@router.post("/users", status_code=status.HTTP_201_CREATED)
@inject
def add(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    created = user_service.create_user()
    return created

# DELETE /users/{user_id}: delete the user; respond 204 on success and 404
# when the service reports that no such user exists.
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
def remove(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
) -> Response:
    try:
        user_service.delete_user_by_id(user_id)
    except NotFoundError:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return Response(status_code=status.HTTP_204_NO_CONTENT)

# GET /status: fixed payload that needs no injected dependency.
@router.get("/status")
def get_status():
    payload = {"status": "OK"}
    return payload

Container

Declarative container wires example user service, user repository, and utility database class. See webapp/containers.py:

"""Containers module."""

from dependency_injector import containers, providers

from .database import Database from .repositories import UserRepository from .services import UserService

class Container(containers.DeclarativeContainer):
    """Declarative container for the database, user repository and user service."""

    # Relative module name: the sibling ``endpoints`` module is the wiring target.
    wiring_config = containers.WiringConfiguration(modules=[".endpoints"])

    # Settings are read from the YAML file listed here.
    config = providers.Configuration(yaml_files=["config.yml"])

    # Singleton provider: the same Database object is reused on every request for it.
    db = providers.Singleton(Database, db_url=config.db.url)

    # The repository receives the database's ``session`` attribute as its session factory.
    user_repository = providers.Factory(UserRepository, session_factory=db.provided.session)

    user_service = providers.Factory(UserService, user_repository=user_repository)

Services

Module services contains example user service. See webapp/services.py:

"""Services module."""

from uuid import uuid4 from typing import Iterator

from .repositories import UserRepository from .models import User

class UserService:
    """User-facing operations, each delegated to a ``UserRepository``."""

    def __init__(self, user_repository: UserRepository) -> None:
        """Remember the repository that every operation is forwarded to."""
        self._repository: UserRepository = user_repository

    def get_users(self) -> Iterator[User]:
        """Return all users exactly as the repository provides them."""
        users = self._repository.get_all()
        return users

    def get_user_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``; repository errors propagate unchanged."""
        user = self._repository.get_by_id(user_id)
        return user

    def create_user(self) -> User:
        """Create a user with a freshly generated e-mail address and a fixed password."""
        # A new UUID keeps the generated address distinct on every call.
        uid = uuid4()
        email = f"{uid}@email.com"
        return self._repository.add(email=email, password="pwd")

    def delete_user_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id``; repository errors propagate unchanged."""
        outcome = self._repository.delete_by_id(user_id)
        return outcome

Repositories

Module repositories contains example user repository. See webapp/repositories.py:

"""Repositories module."""

from contextlib import AbstractContextManager from typing import Callable, Iterator

from sqlalchemy.orm import Session

from .models import User

class UserRepository:
    """Persistence operations for ``User`` rows; each call opens its own session."""

    def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]]) -> None:
        """Store the callable that opens a session as a context manager."""
        self.session_factory = session_factory

    @staticmethod
    def _fetch_or_raise(db_session: Session, user_id: int) -> User:
        """Return the user with ``user_id`` from ``db_session`` or raise ``UserNotFoundError``."""
        match = db_session.query(User).filter(User.id == user_id).first()
        if not match:
            raise UserNotFoundError(user_id)
        return match

    def get_all(self) -> Iterator[User]:
        """Return every stored user."""
        # NOTE(review): ``Query.all()`` produces a list although the
        # annotation says Iterator — confirm which one callers rely on.
        with self.session_factory() as db_session:
            everyone = db_session.query(User).all()
        return everyone

    def get_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``.

        Raises:
            UserNotFoundError: no row has that id.
        """
        # The lookup stays inside the ``with`` block so a failure still
        # passes through the session context manager.
        with self.session_factory() as db_session:
            return self._fetch_or_raise(db_session, user_id)

    def add(self, email: str, password: str, is_active: bool = True) -> User:
        """Insert a new user, commit, and return the refreshed instance."""
        with self.session_factory() as db_session:
            new_user = User(email=email, hashed_password=password, is_active=is_active)
            db_session.add(new_user)
            db_session.commit()
            # Reload the row after the commit (e.g. to pick up the generated id).
            db_session.refresh(new_user)
            return new_user

    def delete_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id`` and commit.

        Raises:
            UserNotFoundError: no row has that id.
        """
        with self.session_factory() as db_session:
            doomed: User = self._fetch_or_raise(db_session, user_id)
            db_session.delete(doomed)
            db_session.commit()

class NotFoundError(Exception):
    """Raised when an entity cannot be found by its identifier.

    Subclasses override ``entity_name`` with the label used in the message.
    """

    # Generic fallback label so the base class itself can be raised; with
    # only a bare annotation the attribute would not exist and constructing
    # the base class would fail with AttributeError.
    entity_name: str = "Entity"

    def __init__(self, entity_id):
        """Build the message ``"<entity_name> not found, id: <entity_id>"``.

        Args:
            entity_id: Identifier that was looked up; interpolated as-is.
        """
        super().__init__(f"{self.entity_name} not found, id: {entity_id}")

class UserNotFoundError(NotFoundError):
    """Not-found error for users; its message is labelled with "User"."""

    entity_name: str = "User"

Models

Module models contains example SQLAlchemy user model. See webapp/models.py:

"""Models module."""

from sqlalchemy import Column, String, Boolean, Integer

from .database import Base

class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    email = Column(String, unique=True)
    hashed_password = Column(String)
    is_active = Column(Boolean, default=True)

    def __repr__(self):
        """Return a debug string listing all four column values."""
        return (
            f"<User(id={self.id}, "
            f"email=\"{self.email}\", "
            f"hashed_password=\"{self.hashed_password}\", "
            f"is_active={self.is_active})>"
        )

Database

Module database defines declarative base and utility class with engine and session factory. See webapp/database.py:

"""Database module."""

import logging
from contextlib import AbstractContextManager, contextmanager
from typing import Callable, Iterator

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

# Module-level logger named after this module (``__name__``), so its records
# can be filtered per module.
logger = logging.getLogger(__name__)

# Declarative base class; ORM models subclass it and register on its metadata.
Base = declarative_base()

class Database:
    """Own the SQLAlchemy engine and hand out sessions bound to it."""

    def __init__(self, db_url: str) -> None:
        """Create the engine for ``db_url`` and a scoped session factory on top of it.

        Args:
            db_url: Database URL passed to ``create_engine``.
        """
        # echo=True makes the engine log the SQL statements it emits.
        self._engine = create_engine(db_url, echo=True)
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create the tables registered on ``Base.metadata`` using this engine."""
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session for use in a ``with`` block.

        The undecorated function is a generator, so it is annotated with what
        it yields; ``@contextmanager`` turns it into the context-manager
        factory that callers see.

        If the ``with`` body raises, the error is logged, the session is
        rolled back and the exception is re-raised. The session is closed in
        every case.
        """
        session: Session = self._session_factory()
        try:
            yield session
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            session.close()

Tests

Tests use Provider overriding feature to replace repository with a mock. See webapp/tests.py:

"""Tests module."""

from unittest import mock

import pytest from fastapi.testclient import TestClient

from .repositories import UserRepository, UserNotFoundError from .models import User from .application import app

@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the application under test."""
    test_client = TestClient(app)
    yield test_client

def test_get_list(client):
    """GET /users returns the users supplied by the overridden repository."""
    stub_repo = mock.Mock(spec=UserRepository)
    stub_repo.get_all.return_value = [
        User(id=1, email="test1@email.com", hashed_password="pwd", is_active=True),
        User(id=2, email="test2@email.com", hashed_password="pwd", is_active=False),
    ]
    expected = [
        {"id": 1, "email": "test1@email.com", "hashed_password": "pwd", "is_active": True},
        {"id": 2, "email": "test2@email.com", "hashed_password": "pwd", "is_active": False},
    ]

    # The override is active only while the request is being served.
    with app.container.user_repository.override(stub_repo):
        response = client.get("/users")

    assert response.status_code == 200
    assert response.json() == expected

def test_get_by_id(client):
    """GET /users/1 returns the user the repository yields for id 1."""
    stub_repo = mock.Mock(spec=UserRepository)
    stub_repo.get_by_id.return_value = User(
        id=1,
        email="xyz@email.com",
        hashed_password="pwd",
        is_active=True,
    )
    expected = {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True}

    with app.container.user_repository.override(stub_repo):
        response = client.get("/users/1")

    assert response.status_code == 200
    assert response.json() == expected
    stub_repo.get_by_id.assert_called_once_with(1)

def test_get_by_id_404(client):
    """GET /users/1 answers 404 when the repository raises UserNotFoundError."""
    stub_repo = mock.Mock(spec=UserRepository)
    stub_repo.get_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(stub_repo):
        response = client.get("/users/1")

    assert response.status_code == 404

# uuid4 is patched where the services module looks it up, so the generated
# e-mail address is predictable.
@mock.patch("webapp.services.uuid4", return_value="xyz")
def test_add(_, client):
    """POST /users creates a user through the repository and answers 201."""
    stub_repo = mock.Mock(spec=UserRepository)
    stub_repo.add.return_value = User(
        id=1,
        email="xyz@email.com",
        hashed_password="pwd",
        is_active=True,
    )
    expected = {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True}

    with app.container.user_repository.override(stub_repo):
        response = client.post("/users")

    assert response.status_code == 201
    assert response.json() == expected
    stub_repo.add.assert_called_once_with(email="xyz@email.com", password="pwd")

def test_remove(client):
    """DELETE /users/1 answers 204 and deletes id 1 through the repository."""
    stub_repo = mock.Mock(spec=UserRepository)

    with app.container.user_repository.override(stub_repo):
        response = client.delete("/users/1")

    assert response.status_code == 204
    stub_repo.delete_by_id.assert_called_once_with(1)

def test_remove_404(client):
    """DELETE /users/1 answers 404 when the repository raises UserNotFoundError."""
    stub_repo = mock.Mock(spec=UserRepository)
    stub_repo.delete_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(stub_repo):
        response = client.delete("/users/1")

    assert response.status_code == 404

def test_status(client):
    """GET /status answers 200 with the fixed OK payload."""
    response = client.get("/status")

    assert response.status_code == 200
    payload = response.json()
    assert payload == {"status": "OK"}

Sources

The source code is available on GitHub.

Sponsor the project on GitHub: