Testing is not optional: it is a fundamental part of professional software development. FastAPI, built on top of Starlette and Pydantic, provides first-class testing support that makes writing comprehensive tests straightforward. In this tutorial, we will walk through the core techniques for testing FastAPI applications, from simple route tests to integration tests with databases, authentication, mocking, and CI/CD pipelines.
By the end of this guide, you will have a complete understanding of how to build a robust test suite that gives you confidence in your FastAPI application’s correctness, performance, and reliability.
Before writing a single test, it is important to understand why we test and what strategies guide our testing decisions.
Testing serves multiple critical purposes in software development:
| Benefit | Description |
|---|---|
| Bug Prevention | Catch errors before they reach production |
| Confidence in Refactoring | Change code without fear of breaking existing functionality |
| Documentation | Tests describe how the system should behave |
| Design Improvement | Writing testable code leads to better architecture |
| Regression Prevention | Ensure fixed bugs stay fixed |
| Deployment Confidence | Automated tests enable continuous deployment |
The testing pyramid is a strategy that guides how many tests of each type you should write:
"""
The Testing Pyramid for FastAPI Applications
            /\
           /  \            E2E Tests (Few)
          /    \           - Full browser/API integration tests
         /------\          - Slow, expensive, brittle
        /        \
       /  Integ.  \        Integration Tests (Some)
      /   Tests    \       - Database, external services
     /--------------\      - Moderate speed
    /                \
   /    Unit Tests    \    Unit Tests (Many)
  /--------------------\   - Fast, isolated, focused
                           - Test individual functions/classes
"""
# Unit Test Example
def test_calculate_discount():
"""Tests a pure function in isolation."""
assert calculate_discount(100, 10) == 90.0
assert calculate_discount(100, 0) == 100.0
assert calculate_discount(50, 50) == 25.0
# Integration Test Example
def test_create_user_in_database(db_session):
"""Tests interaction with a real database."""
user = User(name="Alice", email="alice@example.com")
db_session.add(user)
db_session.commit()
saved = db_session.query(User).filter_by(email="alice@example.com").first()
assert saved is not None
assert saved.name == "Alice"
# E2E Test Example
def test_full_user_registration_flow(client):
"""Tests the complete registration workflow."""
# Register
response = client.post("/auth/register", json={
"name": "Alice",
"email": "alice@example.com",
"password": "SecurePass123!"
})
assert response.status_code == 201
# Login
response = client.post("/auth/login", json={
"email": "alice@example.com",
"password": "SecurePass123!"
})
assert response.status_code == 200
token = response.json()["access_token"]
# Access protected resource
response = client.get("/users/me", headers={
"Authorization": f"Bearer {token}"
})
assert response.status_code == 200
assert response.json()["email"] == "alice@example.com"
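The unit test above calls a calculate_discount helper that the example never defines. Here is a minimal sketch of what it might look like, assuming the second argument is a percentage and that out-of-range inputs should be rejected (both are assumptions, not taken from the source):

```python
def calculate_discount(price: float, percent: float) -> float:
    """Hypothetical helper: return `price` reduced by `percent` percent."""
    if price < 0:
        raise ValueError("price must be non-negative")
    if not 0 <= percent <= 100:
        raise ValueError("percent must be between 0 and 100")
    # Round to cents so results compare cleanly against literals like 90.0
    return round(price * (1 - percent / 100), 2)


# The same assertions as test_calculate_discount above
assert calculate_discount(100, 10) == 90.0
assert calculate_discount(100, 0) == 100.0
assert calculate_discount(50, 50) == 25.0
```

Because the function is pure (no I/O, no globals), the test needs no fixtures, which is exactly what makes unit tests the cheap, fast base of the pyramid.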
Test-Driven Development (TDD) follows a simple cycle: Red (write a failing test), Green (make it pass), Refactor (improve the code). Here is how TDD works with FastAPI:
"""
TDD Cycle with FastAPI:
1. RED - Write a test for a feature that doesn't exist yet
2. GREEN - Write the minimum code to make the test pass
3. REFACTOR - Clean up while keeping tests green
"""
# Step 1: RED - Write the failing test first
# tests/test_items.py
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_get_items_returns_empty_list():
"""We want GET /items to return an empty list initially."""
response = client.get("/items")
assert response.status_code == 200
assert response.json() == []
# Running this test will FAIL because we haven't created the endpoint yet.
# Step 2: GREEN - Write minimal code to pass
# app/main.py
from fastapi import FastAPI
app = FastAPI()
@app.get("/items")
def get_items():
return []
# Now the test passes!
# Step 3: REFACTOR - Improve the implementation
# app/main.py
from fastapi import FastAPI
from typing import List
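from app.schemas.item import ItemResponse as Item  # Pydantic schema: response_model needs a Pydantic type, not the SQLAlchemy model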
app = FastAPI()
items_db: List[Item] = []
@app.get("/items", response_model=List[Item])
def get_items():
"""Retrieve all items."""
return items_db
A well-organized test setup is the foundation for a maintainable test suite. Let us configure everything properly from the start.
# Core testing dependencies
pip install pytest httpx
# Additional testing utilities
pip install pytest-asyncio pytest-cov pytest-mock
# For database testing
pip install sqlalchemy aiosqlite
# For performance testing
pip install locust
# Install all at once
pip install pytest httpx pytest-asyncio pytest-cov pytest-mock sqlalchemy aiosqlite locust
| Package | Purpose |
|---|---|
| pytest | Primary testing framework with powerful fixtures and assertions |
| httpx | Async HTTP client, used for the async test client |
| pytest-asyncio | Enables async test functions with pytest |
| pytest-cov | Test coverage reporting |
| pytest-mock | Simplified mocking with pytest fixtures |
| locust | Load testing and performance benchmarking |
my_fastapi_project/
├── app/
│   ├── __init__.py
│   ├── main.py               # FastAPI app instance
│   ├── config.py             # Configuration settings
│   ├── database.py           # Database connection
│   ├── models/
│   │   ├── __init__.py
│   │   ├── user.py           # SQLAlchemy User model
│   │   └── item.py           # SQLAlchemy Item model
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── user.py           # Pydantic User schemas
│   │   └── item.py           # Pydantic Item schemas
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── users.py          # User endpoints
│   │   ├── items.py          # Item endpoints
│   │   └── auth.py           # Authentication endpoints
│   ├── services/
│   │   ├── __init__.py
│   │   ├── user_service.py   # User business logic
│   │   └── email_service.py  # Email sending service
│   └── dependencies.py       # Shared dependencies
├── tests/
│   ├── __init__.py
│   ├── conftest.py           # Shared fixtures (CRITICAL FILE)
│   ├── test_main.py          # App-level tests
│   ├── unit/
│   │   ├── __init__.py
│   │   ├── test_schemas.py   # Pydantic schema tests
│   │   └── test_services.py  # Business logic tests
│   ├── integration/
│   │   ├── __init__.py
│   │   ├── test_users.py     # User endpoint tests
│   │   ├── test_items.py     # Item endpoint tests
│   │   └── test_auth.py      # Auth endpoint tests
│   ├── e2e/
│   │   ├── __init__.py
│   │   └── test_workflows.py # End-to-end workflow tests
│   └── performance/
│       └── locustfile.py     # Load testing scripts
├── pytest.ini                # Pytest configuration
├── .coveragerc               # Coverage configuration
└── requirements-test.txt     # Test dependencies
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts = -v --tb=short --strict-markers
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks integration tests
e2e: marks end-to-end tests
unit: marks unit tests
# .coveragerc
[run]
source = app
omit =
app/__init__.py
app/config.py
tests/*
[report]
show_missing = true
fail_under = 80
exclude_lines =
pragma: no cover
def __repr__
if __name__ == .__main__.
raise NotImplementedError
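Each exclude_lines entry is a regular expression, not a literal string, which is why the __main__ pattern uses "." in place of the quotes: a dot matches either ' or ". The sketch below approximates that matching with Python's re.search; is_excluded is a hypothetical helper, and coverage.py's real behavior goes further, because when an excluded line introduces a block (such as def __repr__), the whole block is excluded too.

```python
import re

# The patterns from the [report] section of .coveragerc above
EXCLUDE_LINES = [
    r"pragma: no cover",
    r"def __repr__",
    r"if __name__ == .__main__.",
    r"raise NotImplementedError",
]


def is_excluded(source_line: str) -> bool:
    """Return True if any exclusion regex matches anywhere in the line."""
    return any(re.search(pattern, source_line) for pattern in EXCLUDE_LINES)


for line in [
    'if __name__ == "__main__":',
    "if __name__ == '__main__':",
    "    def __repr__(self):",
    "    return total  # pragma: no cover",
    "    return total",
]:
    print(is_excluded(line), line)
```

Both quote styles of the __main__ guard match, while an ordinary return statement still counts toward the fail_under = 80 threshold.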
The conftest.py file is the heart of your test configuration. Pytest automatically discovers and uses fixtures defined here across all test files.
# tests/conftest.py
"""
Central test configuration and shared fixtures.
Fixtures defined here are automatically available to all test files
without requiring explicit imports.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.database import Base, get_db
# ---------------------------------------------------------------------------
# Test Database Configuration
# ---------------------------------------------------------------------------
# Use an in-memory SQLite database for tests: fast and isolated.
# StaticPool (below) makes every session share the same in-memory connection.
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
# ---------------------------------------------------------------------------
# Database Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="function")
def db_session():
"""
Create a fresh database session for each test.
Tables are created before the test and dropped after,
ensuring complete isolation between tests.
"""
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
"""
Create a TestClient with the test database session injected.
This overrides the real database dependency so that all
endpoint tests use the test database.
"""
    def override_get_db():
        # Hand the test session to every request. The db_session fixture
        # owns the session's lifecycle and closes it during teardown.
        yield db_session
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
# Clean up overrides after the test
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Authentication Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def auth_headers(client):
"""
Register a test user and return auth headers with a valid JWT token.
"""
# Register a user
client.post("/auth/register", json={
"name": "Test User",
"email": "test@example.com",
"password": "TestPass123!"
})
# Log in and get token
response = client.post("/auth/login", json={
"email": "test@example.com",
"password": "TestPass123!"
})
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def admin_headers(client):
"""
Create an admin user and return auth headers.
"""
client.post("/auth/register", json={
"name": "Admin User",
"email": "admin@example.com",
"password": "AdminPass123!",
"role": "admin"
})
response = client.post("/auth/login", json={
"email": "admin@example.com",
"password": "AdminPass123!"
})
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Sample Data Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_user_data():
"""Return valid user registration data."""
return {
"name": "Alice Johnson",
"email": "alice@example.com",
"password": "SecurePass123!"
}
@pytest.fixture
def sample_item_data():
"""Return valid item creation data."""
return {
"title": "Test Item",
"description": "A test item for unit testing",
"price": 29.99,
"quantity": 10
}
@pytest.fixture
def multiple_items_data():
"""Return a list of items for batch testing."""
return [
{"title": f"Item {i}", "description": f"Description {i}",
"price": 10.0 * i, "quantity": i * 5}
for i in range(1, 6)
]
Place the conftest.py file in the tests/ directory (or in a subdirectory, to scope its fixtures to that part of the suite). Pytest discovers it automatically, so you should never import from conftest.py directly.
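Fixtures such as db_session and client rely on pytest's generator protocol: everything before yield is setup, and everything after it (including a finally block) runs as teardown, even when the test fails. The sketch below drives a plain generator by hand to mimic that behavior. The names fake_db_session and run_with_fixture are hypothetical, and pytest's real machinery is considerably more involved.

```python
from typing import Callable, Iterator, List


def fake_db_session(events: List[str]) -> Iterator[dict]:
    """Stand-in for the db_session fixture: setup, yield, teardown."""
    events.append("create_all")
    session = {"open": True}
    try:
        yield session
    finally:
        session["open"] = False
        events.append("drop_all")


def run_with_fixture(test: Callable[[dict], None], events: List[str]) -> None:
    """Roughly what pytest does: advance to the yield, run the test, resume."""
    gen = fake_db_session(events)
    session = next(gen)  # setup runs up to the yield
    try:
        test(session)
    finally:
        try:
            next(gen)  # resume after the yield so teardown runs
        except StopIteration:
            pass


def failing_test(session: dict) -> None:
    assert session["open"]
    raise AssertionError("simulated test failure")


events: List[str] = []
try:
    run_with_fixture(failing_test, events)
except AssertionError:
    pass
print(events)  # ['create_all', 'drop_all']: teardown ran despite the failure
```

This is why the drop_all call in db_session and the dependency_overrides.clear() call in client are reliable cleanup points.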
FastAPI’s TestClient is Starlette’s TestClient, which in current versions is built on HTTPX and exposes a familiar requests-style interface. It lets you call your endpoints synchronously without running a server.
# app/main.py
from typing import Optional
from fastapi import FastAPI
app = FastAPI(title="My API", version="1.0.0")
@app.get("/")
def read_root():
    return {"message": "Hello, World!"}
@app.get("/health")
def health_check():
    return {"status": "healthy", "version": "1.0.0"}
@app.get("/items/{item_id}")
def get_item(item_id: int, q: Optional[str] = None):
    result = {"item_id": item_id}
    if q:
        result["query"] = q
    return result
# tests/test_main.py
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
class TestRootEndpoint:
"""Tests for the root endpoint."""
def test_root_returns_200(self):
"""GET / should return 200 OK."""
response = client.get("/")
assert response.status_code == 200
def test_root_returns_hello_message(self):
"""GET / should return a hello message."""
response = client.get("/")
data = response.json()
assert data["message"] == "Hello, World!"
def test_root_content_type_is_json(self):
"""GET / should return JSON content type."""
response = client.get("/")
assert response.headers["content-type"] == "application/json"
class TestHealthEndpoint:
"""Tests for the health check endpoint."""
def test_health_returns_200(self):
response = client.get("/health")
assert response.status_code == 200
def test_health_returns_healthy_status(self):
response = client.get("/health")
data = response.json()
assert data["status"] == "healthy"
def test_health_includes_version(self):
response = client.get("/health")
data = response.json()
assert "version" in data
assert data["version"] == "1.0.0"
class TestGetItem:
"""Tests for the GET /items/{item_id} endpoint."""
def test_get_item_returns_correct_id(self):
response = client.get("/items/42")
assert response.status_code == 200
assert response.json()["item_id"] == 42
def test_get_item_with_query_parameter(self):
response = client.get("/items/42?q=search_term")
data = response.json()
assert data["item_id"] == 42
assert data["query"] == "search_term"
def test_get_item_without_query_parameter(self):
response = client.get("/items/42")
data = response.json()
assert "query" not in data
def test_get_item_invalid_id_type(self):
"""Passing a string where int is expected should return 422."""
response = client.get("/items/not_a_number")
assert response.status_code == 422
# tests/test_http_methods.py
"""
Demonstrates testing all HTTP methods with TestClient.
"""
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
class TestHTTPMethods:
"""Test all standard HTTP methods."""
def test_get_request(self):
"""Test a GET request."""
response = client.get("/items")
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_post_request_with_json(self):
"""Test a POST request with JSON body."""
payload = {
"title": "New Item",
"description": "A brand new item",
"price": 19.99,
"quantity": 5
}
response = client.post("/items", json=payload)
assert response.status_code == 201
data = response.json()
assert data["title"] == "New Item"
assert data["price"] == 19.99
assert "id" in data # Should have an auto-generated ID
def test_put_request_updates_resource(self):
"""Test a PUT request to update a resource."""
# First, create an item
create_response = client.post("/items", json={
"title": "Original",
"description": "Original description",
"price": 10.00,
"quantity": 1
})
item_id = create_response.json()["id"]
# Then update it
update_payload = {
"title": "Updated",
"description": "Updated description",
"price": 15.00,
"quantity": 2
}
response = client.put(f"/items/{item_id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["title"] == "Updated"
assert data["price"] == 15.00
def test_patch_request_partial_update(self):
"""Test a PATCH request for partial updates."""
# Create an item
create_response = client.post("/items", json={
"title": "Original",
"description": "Original description",
"price": 10.00,
"quantity": 1
})
item_id = create_response.json()["id"]
# Partially update (only the price)
response = client.patch(
f"/items/{item_id}",
json={"price": 25.00}
)
assert response.status_code == 200
assert response.json()["price"] == 25.00
assert response.json()["title"] == "Original" # Unchanged
def test_delete_request(self):
"""Test a DELETE request."""
# Create an item
create_response = client.post("/items", json={
"title": "To Delete",
"description": "This will be deleted",
"price": 5.00,
"quantity": 1
})
item_id = create_response.json()["id"]
# Delete it
response = client.delete(f"/items/{item_id}")
assert response.status_code == 204
# Verify it is gone
get_response = client.get(f"/items/{item_id}")
assert get_response.status_code == 404
class TestRequestHeaders:
"""Test sending custom headers."""
def test_custom_headers(self):
"""Send custom headers with a request."""
headers = {
"X-Custom-Header": "test-value",
"Accept-Language": "en-US"
}
response = client.get("/items", headers=headers)
assert response.status_code == 200
def test_content_type_header(self):
"""Verify correct content-type in response."""
response = client.get("/items")
assert "application/json" in response.headers["content-type"]
class TestQueryParameters:
"""Test various query parameter patterns."""
def test_single_query_param(self):
response = client.get("/items?skip=0")
assert response.status_code == 200
def test_multiple_query_params(self):
response = client.get("/items?skip=0&limit=10")
assert response.status_code == 200
def test_query_params_with_params_dict(self):
"""Use the params argument for cleaner code."""
response = client.get("/items", params={
"skip": 0,
"limit": 10,
"search": "test"
})
assert response.status_code == 200
def test_query_param_validation(self):
"""Test that invalid query params are rejected."""
response = client.get("/items?limit=-1")
assert response.status_code == 422
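The params= argument used in test_query_params_with_params_dict is a convenience: the client URL-encodes the dictionary into the same query string you would otherwise write by hand. Python's standard library shows the equivalent encoding below. TestClient uses HTTPX's own encoder, so this is an illustration rather than its exact code path, but for simple values like these it should produce the same query string.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

params = {"skip": 0, "limit": 10, "search": "test"}

# Encode the dict into a query string
query = urlencode(params)
print(query)  # skip=0&limit=10&search=test

# Parsing the hand-written URL gives back the same values (as strings)
parsed = parse_qs(urlsplit("/items?skip=0&limit=10&search=test").query)
print(parsed)  # {'skip': ['0'], 'limit': ['10'], 'search': ['test']}

# Letting the encoder escape special characters avoids hand-escaping bugs
escaped = urlencode({"search": "red shoes & socks"})
print(escaped)  # search=red+shoes+%26+socks
```

The last case is the practical argument for params=: a raw "&" typed into a hand-built URL would silently split the search term into a second parameter.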
# tests/test_responses.py
"""
Comprehensive examples of inspecting response objects.
"""
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_response_status_codes():
"""Check various status code scenarios."""
# Success
assert client.get("/items").status_code == 200
# Created
response = client.post("/items", json={
"title": "New", "description": "Test",
"price": 1.0, "quantity": 1
})
assert response.status_code == 201
# Not Found
assert client.get("/items/99999").status_code == 404
# Validation Error
assert client.post("/items", json={}).status_code == 422
def test_response_json_body():
"""Parse and inspect JSON response body."""
response = client.get("/items")
# Parse JSON
data = response.json()
# Check structure
assert isinstance(data, list)
# Check individual items if any exist
if len(data) > 0:
item = data[0]
assert "id" in item
assert "title" in item
assert "price" in item
def test_response_headers():
"""Inspect response headers."""
response = client.get("/items")
# Check content type
assert "application/json" in response.headers["content-type"]
# Check custom headers if your app sets them
# assert response.headers.get("X-Request-ID") is not None
def test_response_text():
"""Access raw response text."""
response = client.get("/items")
# Raw text representation
text = response.text
assert isinstance(text, str)
# Useful for debugging
print(f"Response body: {text}")
def test_response_cookies():
"""Check cookies set by the response."""
response = client.post("/auth/login", json={
"email": "test@example.com",
"password": "TestPass123!"
})
# Access cookies
cookies = response.cookies
# If your app sets cookies:
# assert "session_id" in cookies
def test_response_timing():
"""Measure response time (useful for performance assertions)."""
    import time
    start = time.perf_counter()  # monotonic, high-resolution timer
    response = client.get("/health")
    elapsed = time.perf_counter() - start
assert response.status_code == 200
assert elapsed < 1.0 # Should respond within 1 second
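A single timing measurement is noisy, so a one-shot elapsed < 1.0 check can flake on a busy CI runner. One option, sketched below with a hypothetical median_latency helper, is to time several calls with time.perf_counter() and assert on the median, which ignores an occasional slow outlier.

```python
import statistics
import time
from typing import Callable


def median_latency(call: Callable[[], object], runs: int = 5) -> float:
    """Invoke `call` `runs` times and return the median duration in seconds."""
    if runs < 1:
        raise ValueError("runs must be at least 1")
    durations = []
    for _ in range(runs):
        start = time.perf_counter()
        call()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


# In a test this would wrap the request, e.g. lambda: client.get("/health");
# a cheap local computation stands in for it here.
latency = median_latency(lambda: sum(range(10_000)), runs=7)
assert latency < 1.0
```

Treat thresholds like this as smoke checks; real performance measurement belongs in a dedicated load test.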
While TestClient works synchronously, FastAPI endpoints are often asynchronous. For testing async code directly (such as async dependencies, database queries, or background tasks), you need async test support.
# Install the required packages
pip install pytest-asyncio httpx
# pytest.ini: enable auto mode for async tests
[pytest]
asyncio_mode = auto
# tests/test_async.py
"""
Async testing with httpx AsyncClient.
The AsyncClient communicates with your FastAPI app using ASGI,
which means it can test truly async behavior.
"""
import pytest
from httpx import AsyncClient, ASGITransport
from app.main import app
@pytest.fixture
async def async_client():
"""Create an async test client."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
class TestAsyncEndpoints:
"""Test endpoints using async client."""
async def test_read_root(self, async_client):
"""Test GET / asynchronously."""
response = await async_client.get("/")
assert response.status_code == 200
assert response.json()["message"] == "Hello, World!"
async def test_create_item_async(self, async_client):
"""Test POST /items asynchronously."""
response = await async_client.post("/items", json={
"title": "Async Item",
"description": "Created in async test",
"price": 42.00,
"quantity": 3
})
assert response.status_code == 201
assert response.json()["title"] == "Async Item"
async def test_concurrent_requests(self, async_client):
"""Test multiple concurrent requests."""
import asyncio
# Send 10 requests concurrently
tasks = [
async_client.get(f"/items/{i}")
for i in range(1, 11)
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
# All should complete without errors
for response in responses:
assert not isinstance(response, Exception)
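With return_exceptions=True, asyncio.gather does not stop at the first failure. Exceptions are returned as values in the results list, in the same order the awaitables were passed (not the order they finished), which is why test_concurrent_requests checks each result with isinstance. A self-contained illustration, with a hypothetical fake_request coroutine standing in for the HTTP calls:

```python
import asyncio


async def fake_request(i: int) -> int:
    """Stand-in for an HTTP call; request 3 fails."""
    await asyncio.sleep(0.01 * (5 - i))  # later requests finish first
    if i == 3:
        raise RuntimeError(f"request {i} failed")
    return i * 10


async def main() -> list:
    tasks = [fake_request(i) for i in range(1, 6)]
    return await asyncio.gather(*tasks, return_exceptions=True)


results = asyncio.run(main())
for i, result in enumerate(results, start=1):
    status = "error" if isinstance(result, Exception) else "ok"
    print(i, status, result)
```

Without return_exceptions=True, gather would raise the RuntimeError from request 3 immediately, and the test would never see the other results.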
# app/dependencies.py
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./app.db"
async_engine = create_async_engine(ASYNC_DATABASE_URL)
AsyncSessionLocal = sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""Yield an async database session."""
async with AsyncSessionLocal() as session:
yield session
# tests/test_async_db.py
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from app.main import app
from app.database import Base
from app.dependencies import get_async_db
# Test async engine using in-memory SQLite
TEST_ASYNC_DATABASE_URL = "sqlite+aiosqlite://"
test_async_engine = create_async_engine(TEST_ASYNC_DATABASE_URL)
TestAsyncSession = sessionmaker(
test_async_engine, class_=AsyncSession, expire_on_commit=False
)
@pytest.fixture(autouse=True)
async def setup_test_db():
"""Create tables before each test, drop after."""
async with test_async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def async_db_session():
"""Provide a test async database session."""
async with TestAsyncSession() as session:
yield session
@pytest.fixture
async def async_client(async_db_session):
"""Async client with overridden database dependency."""
async def override_get_async_db():
yield async_db_session
app.dependency_overrides[get_async_db] = override_get_async_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
class TestAsyncDatabaseOperations:
"""Test async database operations."""
async def test_create_and_retrieve_user(self, async_client):
"""Test creating a user with async database."""
# Create
response = await async_client.post("/users", json={
"name": "Async User",
"email": "async@example.com",
"password": "AsyncPass123!"
})
assert response.status_code == 201
user_id = response.json()["id"]
# Retrieve
response = await async_client.get(f"/users/{user_id}")
assert response.status_code == 200
assert response.json()["name"] == "Async User"
| Scenario | Use Sync (TestClient) | Use Async (AsyncClient) |
|---|---|---|
| Simple endpoint testing | Yes | Optional |
| Testing async dependencies | No | Yes |
| Testing async database operations | No | Yes |
| Testing WebSockets | Yes (built-in support) | No |
| Testing concurrent behavior | No | Yes |
| Testing background tasks | Partial | Yes |
| Testing SSE/streaming | Partial | Yes |
Use TestClient for most tests. Reach for AsyncClient only when you specifically need to test async behavior, concurrent operations, or async dependencies.
Route testing is where you verify that your API endpoints accept correct input, return proper responses, and handle edge cases gracefully. Let us build a comprehensive set of route tests.
# app/routers/items.py
"""
Item router with full CRUD operations.
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.item import Item as ItemModel
from app.schemas.item import ItemCreate, ItemUpdate, ItemResponse
router = APIRouter(prefix="/items", tags=["items"])
@router.get("/", response_model=List[ItemResponse])
def list_items(
skip: int = Query(0, ge=0, description="Number of items to skip"),
limit: int = Query(10, ge=1, le=100, description="Max items to return"),
search: Optional[str] = Query(None, min_length=1, max_length=100),
db: Session = Depends(get_db),
):
"""List items with pagination and optional search."""
query = db.query(ItemModel)
if search:
query = query.filter(ItemModel.title.contains(search))
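    # Order explicitly: without ORDER BY, SQL does not guarantee row order,
    # and the pagination tests depend on a stable sequence.
    items = query.order_by(ItemModel.id).offset(skip).limit(limit).all()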
return items
@router.get("/{item_id}", response_model=ItemResponse)
def get_item(item_id: int, db: Session = Depends(get_db)):
"""Get a single item by ID."""
item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Item with id {item_id} not found"
)
return item
@router.post("/", response_model=ItemResponse, status_code=status.HTTP_201_CREATED)
def create_item(item_data: ItemCreate, db: Session = Depends(get_db)):
"""Create a new item."""
item = ItemModel(**item_data.model_dump())
db.add(item)
db.commit()
db.refresh(item)
return item
@router.put("/{item_id}", response_model=ItemResponse)
def update_item(item_id: int, item_data: ItemUpdate, db: Session = Depends(get_db)):
"""Update an existing item (full replacement)."""
item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Item with id {item_id} not found"
)
for field, value in item_data.model_dump().items():
setattr(item, field, value)
db.commit()
db.refresh(item)
return item
@router.delete("/{item_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_item(item_id: int, db: Session = Depends(get_db)):
"""Delete an item."""
item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Item with id {item_id} not found"
)
db.delete(item)
db.commit()
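The pagination tests that follow assert exact counts and titles, so it helps to pin down what offset(skip).limit(limit) returns. The pure-Python sketch below mirrors those semantics on a list, together with the bounds declared by the Query parameters (skip >= 0, 1 <= limit <= 100). paginate is a hypothetical helper, not part of the router, and it assumes rows come back in insertion (ID) order.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def paginate(items: Sequence[T], skip: int = 0, limit: int = 10) -> List[T]:
    """Mimic OFFSET/LIMIT with the same bounds as the router's Query params."""
    if skip < 0:
        raise ValueError("skip must be >= 0")  # the endpoint would return 422
    if not 1 <= limit <= 100:
        raise ValueError("limit must be in [1, 100]")  # the endpoint would return 422
    return list(items[skip:skip + limit])


titles = [f"Item {i}" for i in range(10)]

print(paginate(titles[:5], skip=3))       # ['Item 3', 'Item 4']
print(paginate(titles, skip=2, limit=3))  # ['Item 2', 'Item 3', 'Item 4']
print(paginate(titles, skip=20))          # []
```

The first two calls reproduce the expectations in test_list_items_pagination_skip and test_list_items_pagination_combined. The bounds also point to the boundary values worth testing: skip=-1 versus 0, and limit=0, 1, 100 and 101.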
# tests/integration/test_items.py
"""
Comprehensive tests for the items router.
Covers all CRUD operations, edge cases, and error scenarios.
"""
import pytest
class TestListItems:
"""Tests for GET /items."""
def test_list_items_empty(self, client):
"""Returns empty list when no items exist."""
response = client.get("/items")
assert response.status_code == 200
assert response.json() == []
def test_list_items_returns_all(self, client, sample_item_data):
"""Returns all created items."""
# Create 3 items
for i in range(3):
data = {**sample_item_data, "title": f"Item {i}"}
client.post("/items", json=data)
response = client.get("/items")
assert response.status_code == 200
assert len(response.json()) == 3
def test_list_items_pagination_skip(self, client, sample_item_data):
"""Skip parameter works correctly."""
for i in range(5):
data = {**sample_item_data, "title": f"Item {i}"}
client.post("/items", json=data)
response = client.get("/items?skip=3")
assert response.status_code == 200
assert len(response.json()) == 2 # 5 total - 3 skipped
def test_list_items_pagination_limit(self, client, sample_item_data):
"""Limit parameter caps the results."""
for i in range(5):
data = {**sample_item_data, "title": f"Item {i}"}
client.post("/items", json=data)
response = client.get("/items?limit=2")
assert response.status_code == 200
assert len(response.json()) == 2
def test_list_items_pagination_combined(self, client, sample_item_data):
"""Skip and limit work together."""
for i in range(10):
data = {**sample_item_data, "title": f"Item {i}"}
client.post("/items", json=data)
response = client.get("/items?skip=2&limit=3")
assert response.status_code == 200
data = response.json()
assert len(data) == 3
assert data[0]["title"] == "Item 2"
def test_list_items_search(self, client):
"""Search filter returns matching items only."""
client.post("/items", json={
"title": "Python Tutorial",
"description": "Learn Python",
"price": 29.99, "quantity": 1
})
client.post("/items", json={
"title": "Java Tutorial",
"description": "Learn Java",
"price": 39.99, "quantity": 1
})
response = client.get("/items?search=Python")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["title"] == "Python Tutorial"
def test_list_items_search_no_results(self, client, sample_item_data):
"""Search with no matches returns empty list."""
client.post("/items", json=sample_item_data)
response = client.get("/items?search=nonexistent")
assert response.status_code == 200
assert response.json() == []
def test_list_items_invalid_skip(self, client):
"""Negative skip value returns 422."""
response = client.get("/items?skip=-1")
assert response.status_code == 422
def test_list_items_invalid_limit(self, client):
"""Limit exceeding max returns 422."""
response = client.get("/items?limit=101")
assert response.status_code == 422
def test_list_items_limit_zero(self, client):
"""Zero limit returns 422."""
response = client.get("/items?limit=0")
assert response.status_code == 422
class TestGetItem:
"""Tests for GET /items/{item_id}."""
def test_get_existing_item(self, client, sample_item_data):
"""Successfully retrieve an existing item."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
response = client.get(f"/items/{item_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == item_id
assert data["title"] == sample_item_data["title"]
assert data["price"] == sample_item_data["price"]
def test_get_nonexistent_item(self, client):
"""Returns 404 for non-existent item."""
response = client.get("/items/99999")
assert response.status_code == 404
assert "not found" in response.json()["detail"].lower()
def test_get_item_invalid_id_type(self, client):
"""Returns 422 for invalid ID type."""
response = client.get("/items/abc")
assert response.status_code == 422
def test_get_item_response_structure(self, client, sample_item_data):
"""Verify the response has all expected fields."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
response = client.get(f"/items/{item_id}")
data = response.json()
expected_fields = {"id", "title", "description", "price", "quantity"}
assert expected_fields.issubset(set(data.keys()))
class TestCreateItem:
"""Tests for POST /items."""
def test_create_item_success(self, client, sample_item_data):
"""Successfully create a new item."""
response = client.post("/items", json=sample_item_data)
assert response.status_code == 201
data = response.json()
assert data["title"] == sample_item_data["title"]
assert data["description"] == sample_item_data["description"]
assert data["price"] == sample_item_data["price"]
assert "id" in data
def test_create_item_auto_generates_id(self, client, sample_item_data):
"""Created item gets an auto-generated ID."""
response = client.post("/items", json=sample_item_data)
assert "id" in response.json()
assert isinstance(response.json()["id"], int)
def test_create_item_persists(self, client, sample_item_data):
"""Created item can be retrieved afterward."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
get_response = client.get(f"/items/{item_id}")
assert get_response.status_code == 200
assert get_response.json()["title"] == sample_item_data["title"]
def test_create_multiple_items(self, client, multiple_items_data):
"""Create multiple items and verify count."""
for item_data in multiple_items_data:
response = client.post("/items", json=item_data)
assert response.status_code == 201
list_response = client.get("/items")
assert len(list_response.json()) == len(multiple_items_data)
def test_create_item_missing_required_field(self, client):
"""Missing required fields return 422."""
response = client.post("/items", json={"title": "Incomplete"})
assert response.status_code == 422
def test_create_item_empty_body(self, client):
"""Empty body returns 422."""
response = client.post("/items", json={})
assert response.status_code == 422
class TestUpdateItem:
"""Tests for PUT /items/{item_id}."""
def test_update_item_success(self, client, sample_item_data):
"""Successfully update an existing item."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
update_data = {
"title": "Updated Title",
"description": "Updated description",
"price": 99.99,
"quantity": 20
}
response = client.put(f"/items/{item_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["title"] == "Updated Title"
assert data["price"] == 99.99
def test_update_nonexistent_item(self, client):
"""Updating non-existent item returns 404."""
update_data = {
"title": "Ghost", "description": "Does not exist",
"price": 0.0, "quantity": 0
}
response = client.put("/items/99999", json=update_data)
assert response.status_code == 404
def test_update_persists_changes(self, client, sample_item_data):
"""Changes are persisted after update."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
client.put(f"/items/{item_id}", json={
"title": "Persisted",
"description": "This should persist",
"price": 55.55,
"quantity": 5
})
get_response = client.get(f"/items/{item_id}")
assert get_response.json()["title"] == "Persisted"
class TestDeleteItem:
"""Tests for DELETE /items/{item_id}."""
def test_delete_item_success(self, client, sample_item_data):
"""Successfully delete an existing item."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
response = client.delete(f"/items/{item_id}")
assert response.status_code == 204
def test_delete_item_removes_from_database(self, client, sample_item_data):
"""Deleted item can no longer be retrieved."""
create_response = client.post("/items", json=sample_item_data)
item_id = create_response.json()["id"]
client.delete(f"/items/{item_id}")
get_response = client.get(f"/items/{item_id}")
assert get_response.status_code == 404
def test_delete_nonexistent_item(self, client):
"""Deleting non-existent item returns 404."""
response = client.delete("/items/99999")
assert response.status_code == 404
def test_delete_reduces_count(self, client, sample_item_data):
"""Deleting an item reduces the total count."""
# Create 3 items
ids = []
for i in range(3):
data = {**sample_item_data, "title": f"Item {i}"}
resp = client.post("/items", json=data)
ids.append(resp.json()["id"])
assert len(client.get("/items").json()) == 3
# Delete one
client.delete(f"/items/{ids[0]}")
assert len(client.get("/items").json()) == 2
FastAPI uses Pydantic for request validation. Testing validation ensures that your API properly rejects bad input and returns informative error messages.
# app/schemas/item.py
from pydantic import BaseModel, Field, field_validator
from typing import Optional
class ItemCreate(BaseModel):
"""Schema for creating a new item."""
title: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = Field(None, max_length=1000)
price: float = Field(..., gt=0, description="Price must be positive")
quantity: int = Field(..., ge=0, description="Quantity cannot be negative")
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v):
if v.strip() == "":
raise ValueError("Title cannot be blank or whitespace only")
return v.strip()
@field_validator("price")
@classmethod
def price_must_have_two_decimals(cls, v):
if round(v, 2) != v:
raise ValueError("Price must have at most 2 decimal places")
return v
class ItemUpdate(BaseModel):
"""Schema for updating an item."""
title: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = Field(None, max_length=1000)
price: float = Field(..., gt=0)
quantity: int = Field(..., ge=0)
class ItemResponse(BaseModel):
"""Schema for item responses."""
id: int
title: str
description: Optional[str]
price: float
quantity: int
model_config = {"from_attributes": True}
# tests/unit/test_schemas.py
"""
Test Pydantic schema validation independently of the API.
These are pure unit tests — no HTTP calls needed.
"""
import pytest
from pydantic import ValidationError
from app.schemas.item import ItemCreate, ItemUpdate, ItemResponse
class TestItemCreateValidation:
"""Test ItemCreate schema validation rules."""
def test_valid_item(self):
"""Valid data creates an item successfully."""
item = ItemCreate(
title="Test Item",
description="A valid item",
price=29.99,
quantity=10
)
assert item.title == "Test Item"
assert item.price == 29.99
def test_title_required(self):
"""Missing title raises ValidationError."""
with pytest.raises(ValidationError) as exc_info:
ItemCreate(
description="No title",
price=10.0,
quantity=1
)
errors = exc_info.value.errors()
assert any(e["loc"] == ("title",) for e in errors)
def test_title_min_length(self):
"""Empty string title is rejected."""
with pytest.raises(ValidationError):
ItemCreate(title="", price=10.0, quantity=1)
def test_title_max_length(self):
"""Title exceeding 200 chars is rejected."""
with pytest.raises(ValidationError):
ItemCreate(
title="x" * 201,
price=10.0,
quantity=1
)
def test_title_whitespace_only(self):
"""Whitespace-only title is rejected by custom validator."""
with pytest.raises(ValidationError) as exc_info:
ItemCreate(title=" ", price=10.0, quantity=1)
assert "blank" in str(exc_info.value).lower()
def test_title_stripped(self):
"""Title is stripped of leading/trailing whitespace."""
item = ItemCreate(
title=" Trimmed Title ",
price=10.0,
quantity=1
)
assert item.title == "Trimmed Title"
def test_price_required(self):
"""Missing price raises ValidationError."""
with pytest.raises(ValidationError):
ItemCreate(title="No Price", quantity=1)
def test_price_must_be_positive(self):
"""Zero or negative price is rejected."""
with pytest.raises(ValidationError):
ItemCreate(title="Free Item", price=0, quantity=1)
with pytest.raises(ValidationError):
ItemCreate(title="Negative", price=-10.0, quantity=1)
def test_price_too_many_decimals(self):
"""Price with more than 2 decimal places is rejected."""
with pytest.raises(ValidationError):
ItemCreate(title="Precise", price=10.999, quantity=1)
def test_quantity_cannot_be_negative(self):
"""Negative quantity is rejected."""
with pytest.raises(ValidationError):
ItemCreate(title="Negative Qty", price=10.0, quantity=-1)
def test_quantity_zero_allowed(self):
"""Zero quantity is allowed (out of stock)."""
item = ItemCreate(title="Out of Stock", price=10.0, quantity=0)
assert item.quantity == 0
def test_description_optional(self):
"""Description is optional and defaults to None."""
item = ItemCreate(title="No Description", price=10.0, quantity=1)
assert item.description is None
def test_description_max_length(self):
"""Description exceeding 1000 chars is rejected."""
with pytest.raises(ValidationError):
ItemCreate(
title="Long Desc",
description="x" * 1001,
price=10.0,
quantity=1
)
class TestValidationErrorResponses:
    """Test that the API returns proper validation error responses."""

    def test_missing_field_error_format(self, client):
        """Validation errors include field location and message."""
        response = client.post("/items", json={
            "description": "Missing title and price"
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        assert isinstance(errors, list)
        assert len(errors) > 0
        # Each error should have location, message, and type
        for error in errors:
            assert "loc" in error
            assert "msg" in error
            assert "type" in error

    def test_invalid_type_error(self, client):
        """Sending wrong type returns descriptive error."""
        response = client.post("/items", json={
            "title": "Test",
            "price": "not_a_number",
            "quantity": 1
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        # loc looks like ["body", "price"]; match on the field name itself
        price_errors = [e for e in errors if e["loc"][-1] == "price"]
        assert len(price_errors) > 0

    def test_multiple_validation_errors(self, client):
        """Multiple invalid fields return one error per field."""
        response = client.post("/items", json={
            "title": "",
            "price": -5,
            "quantity": -1
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        # Pydantic reports every failing field, not just the first one
        invalid_fields = {e["loc"][-1] for e in errors}
        assert {"title", "price", "quantity"} <= invalid_fields

    def test_extra_fields_ignored(self, client):
        """Extra fields in request body are silently ignored."""
        response = client.post("/items", json={
            "title": "Valid Item",
            "description": "Valid",
            "price": 10.0,
            "quantity": 1,
            "extra_field": "should be ignored",
            "another_extra": 42
        })
        assert response.status_code == 201
        assert "extra_field" not in response.json()
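Substring checks against an error's loc can match the wrong field (a hypothetical unit_price field also contains "price"). Grouping the 422 detail list by the last element of loc keeps assertions precise. A minimal sketch: errors_by_field is a hypothetical helper, and the sample payload is hand-written in the shape FastAPI returns with Pydantic v2 (each error carries type, loc, and msg), with the messages abbreviated.

```python
from collections import defaultdict


def errors_by_field(detail: list[dict]) -> dict[str, list[str]]:
    """Hypothetical helper: group a 422 `detail` list by field name.

    The field name is taken as the last element of `loc`,
    e.g. ["body", "price"] -> "price".
    """
    grouped: dict[str, list[str]] = defaultdict(list)
    for error in detail:
        field = str(error["loc"][-1])
        grouped[field].append(error["type"])
    return dict(grouped)


# Hand-written sample shaped like a FastAPI validation response body
# ("msg" values abbreviated for illustration).
sample_detail = [
    {"type": "string_too_short", "loc": ["body", "title"], "msg": "..."},
    {"type": "greater_than", "loc": ["body", "price"], "msg": "..."},
    {"type": "greater_than_equal", "loc": ["body", "quantity"], "msg": "..."},
]

grouped = errors_by_field(sample_detail)
print(sorted(grouped))  # ['price', 'quantity', 'title']
```

In an endpoint test this becomes a one-liner such as assert "price" in errors_by_field(response.json()["detail"]).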
Database testing verifies that your application correctly interacts with the database. The key challenge is test isolation — each test should start with a clean database and not affect other tests.
# app/database.py
"""
Production database configuration.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase
SQLALCHEMY_DATABASE_URL = "postgresql://user:pass@localhost/mydb"
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class Base(DeclarativeBase):
pass
def get_db():
"""Dependency that provides a database session."""
db = SessionLocal()
try:
yield db
finally:
db.close()
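get_db wraps its yield in try/finally so the session is closed even when the request fails. The sketch below is a simplified model of how a framework drives such a generator, not FastAPI's actual implementation: advance it once to get the value, run the handler, then close the generator, which raises GeneratorExit at the yield and runs the finally block. FakeSession, get_fake_db, and run_with_dependency are hypothetical names.

```python
from typing import Callable, Iterator


class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy session."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def get_fake_db() -> Iterator[FakeSession]:
    # Same shape as get_db: create, yield, always close.
    db = FakeSession()
    try:
        yield db
    finally:
        db.close()


def run_with_dependency(dep: Callable[[], Iterator[FakeSession]], handler):
    """Hypothetical, simplified model of consuming a yield dependency."""
    gen = dep()
    value = next(gen)  # runs the code before `yield`
    try:
        return handler(value), value
    finally:
        gen.close()  # raises GeneratorExit at `yield`, so `finally` runs


# The handler sees an open session; afterwards the session is closed.
result, session = run_with_dependency(get_fake_db, lambda db: db.closed)
print(result, session.closed)  # False True
```

This is also why the test override in conftest.py can itself be a generator that simply yields the shared db_session.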
# app/models/item.py
"""SQLAlchemy Item model."""
from sqlalchemy import Column, Integer, String, Float
from app.database import Base
class Item(Base):
__tablename__ = "items"
id = Column(Integer, primary_key=True, index=True)
title = Column(String(200), nullable=False)
description = Column(String(1000), nullable=True)
price = Column(Float, nullable=False)
quantity = Column(Integer, nullable=False, default=0)
# app/models/user.py
"""SQLAlchemy User model."""
from sqlalchemy import Column, Integer, String, Boolean, DateTime
from sqlalchemy.sql import func
from app.database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
email = Column(String(255), unique=True, nullable=False, index=True)
hashed_password = Column(String(255), nullable=False)
is_active = Column(Boolean, default=True)
role = Column(String(20), default="user")
created_at = Column(DateTime(timezone=True), server_default=func.now())
# tests/conftest.py — Database fixtures with transaction rollback
"""
Two approaches to database test isolation:
1. Drop/Create tables (simple but slower)
2. Transaction rollback (fast but more complex)
Below we show the rollback approach for maximum speed.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)


# pysqlite's own transaction handling breaks SAVEPOINT, so let SQLAlchemy
# emit BEGIN itself (workaround from the SQLAlchemy SQLite dialect docs).
@event.listens_for(engine, "connect")
def _disable_pysqlite_begin(dbapi_connection, connection_record):
    dbapi_connection.isolation_level = None


@event.listens_for(engine, "begin")
def _emit_begin(conn):
    conn.exec_driver_sql("BEGIN")


TestingSessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine
)


# Create tables once for the entire test session
@pytest.fixture(scope="session", autouse=True)
def create_tables():
    """Create all tables once before running tests."""
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


@pytest.fixture(scope="function")
def db_session():
    """
    Transaction rollback pattern (SQLAlchemy 2.0):
    1. Begin an outer transaction on a dedicated connection
    2. Bind a session to it with join_transaction_mode="create_savepoint",
       so every session.commit() only releases a SAVEPOINT
    3. Run the test
    4. Roll back the outer transaction (undoing all changes)
    This is faster than drop/create because it avoids DDL operations.
    """
    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(
        bind=connection, join_transaction_mode="create_savepoint"
    )

    yield session

    # Roll back everything, including data the app "committed"
    session.close()
    transaction.rollback()
    connection.close()


@pytest.fixture(scope="function")
def client(db_session):
    """TestClient with rolled-back database session."""
    def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
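Stripped of SQLAlchemy, the rollback pattern is an outer transaction with a savepoint inside it. The stdlib sqlite3 sketch below (an in-memory database and a made-up items table, for illustration only) shows why data the application "commits" still disappears at teardown: releasing a savepoint inside an open transaction does not make anything permanent.

```python
import sqlite3

# isolation_level=None turns off the module's implicit transaction handling,
# so BEGIN / SAVEPOINT / ROLLBACK below run exactly as written.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, title TEXT NOT NULL)")


def run_isolated_test(conn: sqlite3.Connection) -> int:
    """Outer transaction + savepoint, mirroring the db_session fixture."""
    conn.execute("BEGIN")                       # fixture: connection.begin()
    conn.execute("SAVEPOINT test_sp")           # session's savepoint
    conn.execute("INSERT INTO items (title) VALUES ('created by test')")
    conn.execute("RELEASE SAVEPOINT test_sp")   # the app's "commit"
    seen = conn.execute("SELECT COUNT(*) FROM items").fetchone()[0]
    conn.execute("ROLLBACK")                    # fixture teardown
    return seen


seen_during_test = run_isolated_test(conn)
remaining = conn.execute("SELECT COUNT(*) FROM items").fetchone()[0]
print(seen_during_test, remaining)  # 1 0
```

The row is visible while the test runs and gone afterwards, which is exactly what test_database_isolation_between_tests relies on.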
# tests/integration/test_database.py
"""
Tests that verify database operations work correctly.
"""
import pytest
from sqlalchemy.exc import IntegrityError

from app.models.item import Item
from app.models.user import User


class TestItemDatabaseOperations:
    """Test Item model database interactions."""

    def test_create_item_in_database(self, db_session):
        """Directly create and query an item from the database."""
        item = Item(
            title="DB Test Item",
            description="Created directly in DB",
            price=19.99,
            quantity=5
        )
        db_session.add(item)
        db_session.commit()
        # Query it back
        saved = db_session.query(Item).filter_by(title="DB Test Item").first()
        assert saved is not None
        assert saved.price == 19.99
        assert saved.quantity == 5

    def test_update_item_in_database(self, db_session):
        """Update an item and verify changes persist."""
        item = Item(title="Original", price=10.0, quantity=1)
        db_session.add(item)
        db_session.commit()
        # Update
        item.title = "Modified"
        item.price = 25.0
        db_session.commit()
        # Verify
        updated = db_session.query(Item).filter_by(id=item.id).first()
        assert updated.title == "Modified"
        assert updated.price == 25.0

    def test_delete_item_from_database(self, db_session):
        """Delete an item and verify it is gone."""
        item = Item(title="To Delete", price=5.0, quantity=1)
        db_session.add(item)
        db_session.commit()
        item_id = item.id
        db_session.delete(item)
        db_session.commit()
        deleted = db_session.query(Item).filter_by(id=item_id).first()
        assert deleted is None

    def test_query_items_with_filter(self, db_session):
        """Query items with filter conditions."""
        items = [
            Item(title="Cheap Item", price=5.0, quantity=100),
            Item(title="Mid Item", price=25.0, quantity=50),
            Item(title="Expensive Item", price=100.0, quantity=10),
        ]
        db_session.add_all(items)
        db_session.commit()
        # Find items under $30
        cheap_items = db_session.query(Item).filter(Item.price < 30.0).all()
        assert len(cheap_items) == 2
        # Find items with quantity over 20
        stocked = db_session.query(Item).filter(Item.quantity > 20).all()
        assert len(stocked) == 2

    def test_database_isolation_between_tests(self, db_session):
        """
        Verify that each test starts with a clean database.
        If isolation works correctly, this test should find
        zero items (regardless of what previous tests created).
        """
        count = db_session.query(Item).count()
        assert count == 0


class TestUserDatabaseOperations:
    """Test User model database interactions."""

    def test_create_user(self, db_session):
        """Create a user and verify all fields."""
        user = User(
            name="Alice",
            email="alice@example.com",
            hashed_password="hashed_abc123",
            role="user"
        )
        db_session.add(user)
        db_session.commit()
        saved = db_session.query(User).filter_by(email="alice@example.com").first()
        assert saved is not None
        assert saved.name == "Alice"
        assert saved.is_active is True  # Default value
        assert saved.role == "user"

    def test_email_uniqueness(self, db_session):
        """Duplicate email should raise an IntegrityError."""
        user1 = User(
            name="Alice",
            email="duplicate@example.com",
            hashed_password="hash1"
        )
        db_session.add(user1)
        db_session.commit()
        user2 = User(
            name="Bob",
            email="duplicate@example.com",
            hashed_password="hash2"
        )
        db_session.add(user2)
        with pytest.raises(IntegrityError):
            db_session.commit()
        # Reset the session so it is usable again after the failed flush
        db_session.rollback()

    def test_user_default_values(self, db_session):
        """Verify default values are set correctly."""
        user = User(
            name="Default User",
            email="default@example.com",
            hashed_password="hash"
        )
        db_session.add(user)
        db_session.commit()
        assert user.is_active is True
        assert user.role == "user"
        assert user.created_at is not None
# tests/factories.py
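"""
Factory fixtures for creating test data.
Factories encapsulate object creation logic, making tests
cleaner and reducing duplication.
pytest only collects fixtures from conftest.py files and registered plugins,
so make these available by adding pytest_plugins = ["tests.factories"]
to the top-level conftest.py (or by importing the fixtures there).
"""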
import pytest
from app.models.item import Item
from app.models.user import User
@pytest.fixture
def item_factory(db_session):
"""
Factory fixture for creating Item instances.
Usage:
def test_something(item_factory):
item = item_factory(title="Custom Title", price=25.0)
"""
created_items = []
def _create_item(
title="Test Item",
description="Test Description",
price=10.0,
quantity=5,
):
item = Item(
title=title,
description=description,
price=price,
quantity=quantity,
)
db_session.add(item)
db_session.commit()
db_session.refresh(item)
created_items.append(item)
return item
yield _create_item
@pytest.fixture
def user_factory(db_session):
"""Factory fixture for creating User instances."""
counter = 0
def _create_user(
name=None,
email=None,
password="hashed_default",
role="user",
is_active=True,
):
nonlocal counter
counter += 1
user = User(
name=name or f"User {counter}",
email=email or f"user{counter}@example.com",
hashed_password=password,
role=role,
is_active=is_active,
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user
yield _create_user
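The factory idea itself does not depend on a database. A plain-Python sketch, with a hypothetical ItemData dataclass standing in for the Item model, shows the three ingredients used above: defaults, per-call overrides, and a counter so each object gets unique values.

```python
from dataclasses import dataclass
from itertools import count


@dataclass
class ItemData:
    """Hypothetical stand-in for the SQLAlchemy Item model."""
    title: str
    price: float = 10.0
    quantity: int = 5


def make_item_factory():
    """Return a factory closure plus the list of objects it has created."""
    sequence = count(1)
    created: list[ItemData] = []

    def _create(**overrides) -> ItemData:
        n = next(sequence)
        # Keyword overrides win over the generated default title
        fields = {"title": f"Item {n}", **overrides}
        item = ItemData(**fields)
        created.append(item)
        return item

    return _create, created


create_item, created_items = make_item_factory()
a = create_item()
b = create_item(title="Custom", price=25.0)
c = create_item()
print([i.title for i in created_items])  # ['Item 1', 'Custom', 'Item 3']
```

Returning the created list alongside the factory mirrors the created_items bookkeeping in item_factory and makes cleanup or inspection easy.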
# tests/integration/test_with_factories.py
"""Using factory fixtures for cleaner tests."""
class TestWithFactories:
"""Demonstrate factory fixture usage."""
def test_list_items_with_factory(self, client, item_factory):
"""Use the factory to create test data."""
item_factory(title="Item A", price=10.0)
item_factory(title="Item B", price=20.0)
item_factory(title="Item C", price=30.0)
response = client.get("/items")
assert len(response.json()) == 3
def test_search_with_factory(self, client, item_factory):
"""Factory makes it easy to set up specific scenarios."""
item_factory(title="Python Guide", price=29.99)
item_factory(title="Java Guide", price=39.99)
item_factory(title="Python Advanced", price=49.99)
response = client.get("/items?search=Python")
assert len(response.json()) == 2
def test_user_roles_with_factory(self, client, user_factory):
"""Create users with different roles."""
admin = user_factory(name="Admin", role="admin")
user = user_factory(name="Regular", role="user")
assert admin.role == "admin"
assert user.role == "user"
FastAPI’s dependency injection system is one of its most powerful features, and app.dependency_overrides makes it trivially easy to swap real dependencies with test doubles.
# app/dependencies.py
"""
Application dependencies that can be overridden in tests.
"""
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.user import User
security = HTTPBearer()
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db),
) -> User:
"""Decode JWT token and return the current user."""
token = credentials.credentials
# In production, this decodes and validates the JWT
payload = decode_jwt_token(token)  # JWT helper (not shown); raises on an invalid token
user = db.query(User).filter(User.id == payload["sub"]).first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
)
return user
def get_current_admin(
user: User = Depends(get_current_user),
) -> User:
"""Ensure the current user has admin role."""
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required"
)
return user
class EmailService:
"""Service for sending emails."""
def send_welcome_email(self, email: str, name: str) -> bool:
"""Send a welcome email to a new user."""
# In production, this calls an email API
return True
def send_password_reset(self, email: str, token: str) -> bool:
"""Send a password reset email."""
return True
def get_email_service() -> EmailService:
"""Dependency that provides the email service."""
return EmailService()
# tests/test_dependency_overrides.py
"""
Demonstrates various ways to override dependencies in tests.
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.dependencies import (
get_current_user,
get_current_admin,
get_email_service,
EmailService,
)
from app.models.user import User
# ---------------------------------------------------------------------------
# Approach 1: Simple function override
# ---------------------------------------------------------------------------
class TestWithSimpleOverride:
"""Override a dependency with a simple function."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up and tear down dependency overrides."""
# Create a fake user to return
fake_user = User(
id=1,
name="Test User",
email="test@example.com",
hashed_password="fake",
role="user",
is_active=True,
)
def override_get_current_user():
return fake_user
app.dependency_overrides[get_current_user] = override_get_current_user
yield
app.dependency_overrides.clear()
def test_protected_endpoint(self):
"""Access a protected endpoint without real authentication."""
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200
assert response.json()["email"] == "test@example.com"
# ---------------------------------------------------------------------------
# Approach 2: Parameterized override fixture
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_current_user():
"""Fixture that creates a mock user and overrides the dependency."""
def _create_override(
user_id=1,
name="Test User",
email="test@example.com",
role="user",
):
user = User(
id=user_id,
name=name,
email=email,
hashed_password="fake",
role=role,
is_active=True,
)
def override():
return user
app.dependency_overrides[get_current_user] = override
return user
yield _create_override
app.dependency_overrides.clear()
class TestWithParameterizedOverride:
"""Use parameterized overrides for different scenarios."""
def test_regular_user_access(self, mock_current_user):
"""Test as a regular user."""
mock_current_user(role="user")
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200
def test_admin_endpoint_as_user(self, mock_current_user):
"""Regular user cannot access admin endpoints."""
mock_current_user(role="user")
client = TestClient(app)
response = client.get("/admin/dashboard")
assert response.status_code == 403
def test_admin_endpoint_as_admin(self, mock_current_user):
"""Admin user can access admin endpoints."""
mock_current_user(role="admin")
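# Optional: get_current_admin depends on get_current_user, so the role="admin" override above already passes the check; overriding it directly skips the role check entirely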
def override_admin():
return User(
id=1, name="Admin", email="admin@example.com",
hashed_password="fake", role="admin", is_active=True
)
app.dependency_overrides[get_current_admin] = override_admin
client = TestClient(app)
response = client.get("/admin/dashboard")
assert response.status_code == 200
# ---------------------------------------------------------------------------
# Approach 3: Override with a mock class
# ---------------------------------------------------------------------------
class MockEmailService(EmailService):
"""Mock email service that records calls instead of sending."""
def __init__(self):
self.sent_emails = []
def send_welcome_email(self, email: str, name: str) -> bool:
self.sent_emails.append({
"type": "welcome",
"email": email,
"name": name,
})
return True
def send_password_reset(self, email: str, token: str) -> bool:
self.sent_emails.append({
"type": "password_reset",
"email": email,
"token": token,
})
return True
@pytest.fixture
def mock_email_service():
"""Override the email service with a mock."""
mock_service = MockEmailService()
def override():
return mock_service
app.dependency_overrides[get_email_service] = override
yield mock_service
app.dependency_overrides.clear()
class TestEmailServiceOverride:
"""Test email-related functionality with mock service."""
def test_registration_sends_welcome_email(
self, client, mock_email_service
):
"""Registering a user should send a welcome email."""
client.post("/auth/register", json={
"name": "New User",
"email": "new@example.com",
"password": "Pass123!"
})
# Verify the welcome email was "sent"
assert len(mock_email_service.sent_emails) == 1
assert mock_email_service.sent_emails[0]["type"] == "welcome"
assert mock_email_service.sent_emails[0]["email"] == "new@example.com"
def test_password_reset_sends_email(
self, client, mock_email_service
):
"""Password reset should send a reset email."""
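# Create the user first: each test starts with an empty database
client.post("/auth/register", json={
"name": "Existing User",
"email": "existing@example.com",
"password": "Pass123!"
})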
client.post("/auth/password-reset", json={
"email": "existing@example.com"
})
reset_emails = [
e for e in mock_email_service.sent_emails
if e["type"] == "password_reset"
]
assert len(reset_emails) == 1
Always call app.dependency_overrides.clear() after your tests to prevent overrides from leaking between tests. Using a fixture with proper teardown (after yield) ensures this happens automatically.
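Because clear() wipes every override, including ones another active fixture still relies on (such as the get_db override installed by client), a narrower option is to undo only the key you set. A minimal sketch, assuming you pass in app.dependency_overrides; override_dependency is a hypothetical helper, and a plain dict stands in for the app's overrides here so the block runs without FastAPI.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator

_MISSING = object()


@contextmanager
def override_dependency(
    overrides: dict,
    dependency: Callable[..., Any],
    replacement: Callable[..., Any],
) -> Iterator[None]:
    """Hypothetical helper: set overrides[dependency], then restore prior state."""
    previous = overrides.get(dependency, _MISSING)
    overrides[dependency] = replacement
    try:
        yield
    finally:
        if previous is _MISSING:
            overrides.pop(dependency, None)
        else:
            overrides[dependency] = previous


# Plain functions stand in for real dependencies in this demo.
def get_db():
    return "real db"


def get_current_user():
    return "real user"


overrides = {get_db: lambda: "test db"}  # e.g. installed by a client fixture
with override_dependency(overrides, get_current_user, lambda: "fake user"):
    assert overrides[get_current_user]() == "fake user"

print(get_current_user in overrides, overrides[get_db]())  # False test db
```

Inside a test this reads as with override_dependency(app.dependency_overrides, get_current_user, fake): ..., and the client fixture's own override survives.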
Real applications interact with external services: APIs, email providers, payment gateways, and cloud storage. In tests, you should mock these services to avoid real network calls, costs, and flakiness.
# app/services/weather_service.py
"""External weather API service."""
import httpx
class WeatherService:
"""Fetches weather data from an external API."""
BASE_URL = "https://api.weatherapi.com/v1"
def __init__(self, api_key: str):
self.api_key = api_key
def get_current_weather(self, city: str) -> dict:
"""Fetch current weather for a city."""
response = httpx.get(
f"{self.BASE_URL}/current.json",
params={"key": self.api_key, "q": city}
)
response.raise_for_status()
return response.json()
def get_forecast(self, city: str, days: int = 3) -> dict:
"""Fetch weather forecast."""
response = httpx.get(
f"{self.BASE_URL}/forecast.json",
params={"key": self.api_key, "q": city, "days": days}
)
response.raise_for_status()
return response.json()
# tests/test_mocking.py
"""
Comprehensive mocking examples using unittest.mock.
"""
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
import httpx
from app.services.weather_service import WeatherService
class TestWeatherServiceMocking:
"""Mock external HTTP calls in the weather service."""
@patch("app.services.weather_service.httpx.get")
def test_get_current_weather(self, mock_get):
"""Mock the HTTP call to return fake weather data."""
# Configure the mock to return a fake response
mock_response = MagicMock()
mock_response.json.return_value = {
"location": {"name": "London"},
"current": {
"temp_c": 15.0,
"condition": {"text": "Partly cloudy"}
}
}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
# Call the service
service = WeatherService(api_key="fake-key")
result = service.get_current_weather("London")
# Verify the result
assert result["location"]["name"] == "London"
assert result["current"]["temp_c"] == 15.0
# Verify the HTTP call was made correctly
mock_get.assert_called_once()
call_args = mock_get.call_args
assert call_args.kwargs["params"]["q"] == "London"
@patch("app.services.weather_service.httpx.get")
def test_weather_api_error(self, mock_get):
"""Test handling of API errors."""
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=MagicMock(status_code=500)
)
mock_get.return_value = mock_response
service = WeatherService(api_key="fake-key")
with pytest.raises(httpx.HTTPStatusError):
service.get_current_weather("London")
@patch("app.services.weather_service.httpx.get")
def test_weather_network_error(self, mock_get):
"""Test handling of network errors."""
mock_get.side_effect = httpx.ConnectError("Connection refused")
service = WeatherService(api_key="fake-key")
with pytest.raises(httpx.ConnectError):
service.get_current_weather("London")
@patch("app.services.weather_service.httpx.get")
def test_forecast_calls_correct_endpoint(self, mock_get):
"""Verify the forecast method calls the correct URL."""
mock_response = MagicMock()
mock_response.json.return_value = {"forecast": {"forecastday": []}}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
service = WeatherService(api_key="test-key")
service.get_forecast("Paris", days=5)
# Check that the correct URL and parameters were used
call_args = mock_get.call_args
assert call_args.args[0].endswith("/forecast.json")
assert call_args.kwargs["params"]["days"] == 5
class TestMockingPatterns:
"""Common mocking patterns for FastAPI tests."""
def test_mock_with_context_manager(self):
"""Use patch as a context manager."""
with patch("app.services.weather_service.httpx.get") as mock_get:
mock_get.return_value = MagicMock(
json=MagicMock(return_value={"data": "test"}),
raise_for_status=MagicMock()
)
service = WeatherService(api_key="key")
result = service.get_current_weather("Tokyo")
assert result == {"data": "test"}
def test_mock_with_side_effect_list(self):
"""Return different values on consecutive calls."""
with patch("app.services.weather_service.httpx.get") as mock_get:
responses = [
MagicMock(
json=MagicMock(return_value={"temp": 20}),
raise_for_status=MagicMock()
),
MagicMock(
json=MagicMock(return_value={"temp": 25}),
raise_for_status=MagicMock()
),
]
mock_get.side_effect = responses
service = WeatherService(api_key="key")
first = service.get_current_weather("London")
second = service.get_current_weather("Paris")
assert first["temp"] == 20
assert second["temp"] == 25
def test_assert_call_count(self):
"""Verify how many times a mock was called."""
with patch("app.services.weather_service.httpx.get") as mock_get:
mock_get.return_value = MagicMock(
json=MagicMock(return_value={}),
raise_for_status=MagicMock()
)
service = WeatherService(api_key="key")
service.get_current_weather("A")
service.get_current_weather("B")
service.get_current_weather("C")
assert mock_get.call_count == 3
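A plain MagicMock accepts any call, so a test can keep passing after the real signature changes. create_autospec (or patch(..., autospec=True)) copies the original signature and raises TypeError on mismatched calls. A small sketch; fetch_weather is a hypothetical function standing in for something like httpx.get.

```python
from unittest.mock import MagicMock, create_autospec


def fetch_weather(city: str, *, units: str = "metric") -> dict:
    """Hypothetical function we want to replace in tests."""
    raise RuntimeError("would hit the network")


# Silently accepted, although the real function would reject these arguments
loose = MagicMock()
loose("London", "imperial", extra=True)

# The autospecced mock enforces fetch_weather's signature
strict = create_autospec(fetch_weather, return_value={"temp_c": 15.0})
assert strict("London", units="imperial") == {"temp_c": 15.0}
strict.assert_called_once_with("London", units="imperial")

try:
    strict("London", "imperial")  # `units` is keyword-only in the real signature
except TypeError:
    rejected = True
else:
    rejected = False

print(rejected)  # True
```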
# tests/test_pytest_mock.py
"""
pytest-mock provides a cleaner interface for mocking.
The 'mocker' fixture is automatically available.
"""
import pytest
import httpx
class TestWithPytestMock:
"""Use the mocker fixture for cleaner mocking."""
def test_mock_external_call(self, mocker):
"""Mock using the mocker fixture."""
mock_get = mocker.patch("app.services.weather_service.httpx.get")
mock_get.return_value.json.return_value = {
"location": {"name": "Berlin"},
"current": {"temp_c": 22.0}
}
mock_get.return_value.raise_for_status = mocker.MagicMock()
from app.services.weather_service import WeatherService
service = WeatherService(api_key="key")
result = service.get_current_weather("Berlin")
assert result["current"]["temp_c"] == 22.0
def test_spy_on_method(self, mocker):
"""Spy on a method to verify it was called without changing behavior."""
from app.services.weather_service import WeatherService
spy = mocker.spy(WeatherService, "get_current_weather")
# This would still make the real call, so we also mock httpx
mocker.patch("app.services.weather_service.httpx.get").return_value = (
mocker.MagicMock(
json=mocker.MagicMock(return_value={"data": "test"}),
raise_for_status=mocker.MagicMock()
)
)
service = WeatherService(api_key="key")
service.get_current_weather("London")
# Verify the method was called
spy.assert_called_once_with(service, "London")
def test_mock_email_sending(self, mocker, client):
"""Mock email sending in an endpoint test."""
mock_send = mocker.patch(
"app.dependencies.EmailService.send_welcome_email",
return_value=True
)
response = client.post("/auth/register", json={
"name": "New User",
"email": "new@example.com",
"password": "Pass123!"
})
# Verify the email was "sent"
mock_send.assert_called_once_with("new@example.com", "New User")
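def test_mock_datetime(self, mocker):
"""Mock the current time for time-dependent tests."""
from datetime import datetime
# Patch `datetime` where it is looked up: this assumes user_service does
# `from datetime import datetime`, so its module attribute is what gets replaced
mock_dt = mocker.patch("app.services.user_service.datetime")
mock_dt.utcnow.return_value = datetime(2025, 6, 15, 12, 0, 0)
# Code in user_service that calls datetime.utcnow() now gets 2025-06-15 12:00:00;
# call the time-dependent function under test here and assert on its result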
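pytest-mock's `mocker.spy` records calls to a method while still running the real implementation. Without pytest-mock, roughly the same effect can be had with the standard library's `unittest.mock.patch.object`, using `autospec=True` (which keeps the real signature and records `self`) and `side_effect` set to the original function. The `PriceCalculator` class below is hypothetical and exists only to have something to spy on; treat this as a sketch of the pattern, not as how pytest-mock is guaranteed to behave in every version.

```python
from unittest.mock import patch


class PriceCalculator:
    """Hypothetical service used only to demonstrate spying."""

    def apply_discount(self, price: float, percent: float) -> float:
        return round(price * (1 - percent / 100), 2)

    def checkout(self, price: float) -> float:
        # Internal call we want to observe without changing its result
        return self.apply_discount(price, 10)


def run_spy_demo():
    original = PriceCalculator.apply_discount
    # autospec=True keeps the real signature and records `self`;
    # side_effect=original forwards each call to the real method
    with patch.object(
        PriceCalculator, "apply_discount", autospec=True, side_effect=original
    ) as spy:
        calc = PriceCalculator()
        total = calc.checkout(200.0)
    # The mock keeps its call history after the patch is undone
    spy.assert_called_once_with(calc, 200.0, 10)
    return total, spy.call_count


total, calls = run_spy_demo()
print(total, calls)  # 180.0 1
```

Because `side_effect` returns the real result, the code under test behaves exactly as it would unpatched; only the call history is added.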
# app/services/async_api_service.py
"""Async external API service."""
import httpx
class AsyncAPIService:
"""Async service for external API calls."""
async def fetch_user_data(self, user_id: int) -> dict:
"""Fetch user data from external API."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"https://api.example.com/users/{user_id}"
)
response.raise_for_status()
return response.json()
# tests/test_async_mocking.py
"""Mocking async external calls."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
class TestAsyncMocking:
"""Test async mocking patterns."""
@pytest.mark.asyncio  # not needed if pytest-asyncio runs with asyncio_mode = "auto"
async def test_mock_async_http_call(self):
"""Mock an async HTTP call."""
with patch("app.services.async_api_service.httpx.AsyncClient") as MockClient:
# Set up the mock
mock_response = MagicMock()
mock_response.json.return_value = {
"id": 1,
"name": "Mocked User"
}
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client_instance.__aenter__.return_value = mock_client_instance
mock_client_instance.__aexit__.return_value = None
MockClient.return_value = mock_client_instance
from app.services.async_api_service import AsyncAPIService
service = AsyncAPIService()
result = await service.fetch_user_data(1)
assert result["name"] == "Mocked User"
mock_client_instance.get.assert_called_once()
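Patching `httpx.AsyncClient` by import path works, but the same mock setup is easier to reuse when the client factory is passed in as a parameter. The sketch below assumes a hypothetical `fetch_user_name` function that accepts such a factory (the `AsyncAPIService` above does not). It highlights two details that often trip people up: `__aenter__` must return the client mock, and `assert_awaited_once_with` (unlike `assert_called_once_with`) also fails if a coroutine was created but never awaited.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def fetch_user_name(client_factory, user_id: int) -> str:
    """Hypothetical variant of fetch_user_data that accepts a client factory."""
    async with client_factory() as client:
        response = await client.get(f"https://api.example.com/users/{user_id}")
        response.raise_for_status()
        return response.json()["name"]


def build_mock_client_factory(payload: dict):
    """Return (factory, client) mocks that mimic httpx.AsyncClient usage."""
    response = MagicMock()  # response methods are synchronous in httpx
    response.json.return_value = payload
    client = AsyncMock()  # child attributes such as get() are AsyncMocks too
    client.get.return_value = response
    # `async with factory() as client` binds whatever __aenter__ returns
    client.__aenter__.return_value = client
    factory = MagicMock(return_value=client)
    return factory, client


factory, client = build_mock_client_factory({"id": 1, "name": "Mocked User"})
name = asyncio.run(fetch_user_name(factory, 1))
client.get.assert_awaited_once_with("https://api.example.com/users/1")
print(name)  # Mocked User
```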
Testing authentication ensures that your security layer works correctly: valid credentials grant access, invalid credentials are rejected, and roles are enforced.
# app/routers/auth.py
"""Authentication router."""
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.database import get_db
from app.models.user import User
from app.schemas.user import UserCreate, UserLogin, TokenResponse
router = APIRouter(prefix="/auth", tags=["auth"])
SECRET_KEY = "your-secret-key-here"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
def create_access_token(data: dict, expires_delta: timedelta | None = None):
"""Create a JWT access token."""
to_encode = data.copy()
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def hash_password(password: str) -> str:
"""Hash a password."""
return pwd_context.hash(password)
@router.post("/register", status_code=status.HTTP_201_CREATED)
def register(user_data: UserCreate, db: Session = Depends(get_db)):
"""Register a new user."""
# Check if email already exists
existing = db.query(User).filter(User.email == user_data.email).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already registered"
)
user = User(
name=user_data.name,
email=user_data.email,
hashed_password=hash_password(user_data.password),
role=getattr(user_data, "role", "user"),
)
db.add(user)
db.commit()
db.refresh(user)
return {"id": user.id, "name": user.name, "email": user.email}
@router.post("/login", response_model=TokenResponse)
def login(credentials: UserLogin, db: Session = Depends(get_db)):
"""Authenticate and return a JWT token."""
user = db.query(User).filter(User.email == credentials.email).first()
if not user or not verify_password(credentials.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is deactivated"
)
token = create_access_token(
data={"sub": str(user.id), "role": user.role},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
return {"access_token": token, "token_type": "bearer"}
# tests/integration/test_auth.py
"""
Comprehensive authentication tests.
"""
import pytest
from jose import jwt
from datetime import datetime, timedelta
# Helper to create test tokens
def create_test_token(
user_id: int = 1,
role: str = "user",
expired: bool = False,
secret: str = "your-secret-key-here",
):
"""Create a JWT token for testing."""
expire = datetime.utcnow() + (
timedelta(minutes=-5) if expired else timedelta(minutes=30)
)
payload = {
"sub": str(user_id),
"role": role,
"exp": expire,
}
return jwt.encode(payload, secret, algorithm="HS256")
class TestRegistration:
"""Tests for POST /auth/register."""
def test_register_success(self, client, sample_user_data):
"""Successfully register a new user."""
response = client.post("/auth/register", json=sample_user_data)
assert response.status_code == 201
data = response.json()
assert data["name"] == sample_user_data["name"]
assert data["email"] == sample_user_data["email"]
assert "id" in data
# Password should NOT be in the response
assert "password" not in data
assert "hashed_password" not in data
def test_register_duplicate_email(self, client, sample_user_data):
"""Registering with an existing email returns 409."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/register", json=sample_user_data)
assert response.status_code == 409
assert "already registered" in response.json()["detail"].lower()
def test_register_invalid_email(self, client):
"""Invalid email format is rejected."""
response = client.post("/auth/register", json={
"name": "Test",
"email": "not-an-email",
"password": "Pass123!"
})
assert response.status_code == 422
def test_register_weak_password(self, client):
"""Weak password is rejected (if validation is implemented)."""
response = client.post("/auth/register", json={
"name": "Test",
"email": "test@example.com",
"password": "123"
})
assert response.status_code == 422
def test_register_missing_fields(self, client):
"""Missing required fields return 422."""
response = client.post("/auth/register", json={
"email": "test@example.com"
})
assert response.status_code == 422
class TestLogin:
"""Tests for POST /auth/login."""
def test_login_success(self, client, sample_user_data):
"""Successfully log in with correct credentials."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
def test_login_returns_valid_jwt(self, client, sample_user_data):
"""The returned token is a valid JWT."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
token = response.json()["access_token"]
# Decode and verify the token
payload = jwt.decode(token, "your-secret-key-here", algorithms=["HS256"])
assert "sub" in payload
assert "exp" in payload
assert "role" in payload
def test_login_wrong_password(self, client, sample_user_data):
"""Wrong password returns 401."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": "WrongPassword123!"
})
assert response.status_code == 401
def test_login_nonexistent_user(self, client):
"""Login with non-existent email returns 401."""
response = client.post("/auth/login", json={
"email": "ghost@example.com",
"password": "Pass123!"
})
assert response.status_code == 401
def test_login_deactivated_account(self, client, db_session, sample_user_data):
"""Deactivated account cannot log in."""
# Register and then deactivate
client.post("/auth/register", json=sample_user_data)
from app.models.user import User
user = db_session.query(User).filter_by(
email=sample_user_data["email"]
).first()
user.is_active = False
db_session.commit()
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
assert response.status_code == 403
class TestProtectedEndpoints:
"""Tests for accessing protected endpoints."""
def test_access_with_valid_token(self, client, auth_headers):
"""Valid token grants access."""
response = client.get("/users/me", headers=auth_headers)
assert response.status_code == 200
def test_access_without_token(self, client):
"""No token returns 401 or 403."""
response = client.get("/users/me")
assert response.status_code in (401, 403)
def test_access_with_invalid_token(self, client):
"""A malformed token is rejected with 401 or 403."""
headers = {"Authorization": "Bearer invalid.token.here"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
def test_access_with_expired_token(self, client):
"""An expired token is rejected with 401 or 403."""
expired_token = create_test_token(expired=True)
headers = {"Authorization": f"Bearer {expired_token}"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
def test_access_with_wrong_secret(self, client):
"""Token signed with wrong secret is rejected."""
bad_token = create_test_token(secret="wrong-secret")
headers = {"Authorization": f"Bearer {bad_token}"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
class TestRoleBasedAccess:
"""Tests for role-based authorization."""
def test_admin_access_admin_endpoint(self, client, admin_headers):
"""Admin can access admin-only endpoints."""
response = client.get("/admin/users", headers=admin_headers)
assert response.status_code == 200
def test_user_cannot_access_admin_endpoint(self, client, auth_headers):
"""Regular user cannot access admin endpoints."""
response = client.get("/admin/users", headers=auth_headers)
assert response.status_code == 403
def test_role_in_token_payload(self, client, sample_user_data):
"""Token payload contains the correct role."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
token = response.json()["access_token"]
payload = jwt.decode(token, "your-secret-key-here", algorithms=["HS256"])
assert payload["role"] == "user"
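The expired-token and wrong-secret tests rely on two checks made when a JWT is decoded: the HMAC-SHA256 signature is recomputed with the server's secret, and the `exp` claim is compared with the current time. The sketch below implements just those two checks with the standard library so the mechanism is visible. It is illustrative only, not a substitute for python-jose, and it skips checks a real library performs (for example validating the header's `alg`). The names `encode_hs256`, `decode_hs256`, and `TokenError` are hypothetical.

```python
import base64
import binascii
import hashlib
import hmac
import json
import time
from typing import Optional


class TokenError(Exception):
    """Raised when a token fails verification."""


def _b64url_encode(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    # JWT strips "=" padding; restore it before decoding
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: bytes, secret: str) -> bytes:
    return hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()


def encode_hs256(payload: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        _b64url_encode(json.dumps(part, separators=(",", ":")).encode())
        for part in (header, payload)
    ]
    signing_input = ".".join(segments).encode()
    return ".".join(segments + [_b64url_encode(_sign(signing_input, secret))])


def decode_hs256(token: str, secret: str, now: Optional[float] = None) -> dict:
    parts = token.split(".")
    if len(parts) != 3:
        raise TokenError("malformed token")
    header_b64, payload_b64, signature_b64 = parts
    try:
        signature = _b64url_decode(signature_b64)
    except (binascii.Error, ValueError):
        raise TokenError("malformed signature") from None
    expected = _sign(f"{header_b64}.{payload_b64}".encode(), secret)
    # A different secret yields a different HMAC, so the token is rejected
    if not hmac.compare_digest(expected, signature):
        raise TokenError("signature mismatch")
    payload = json.loads(_b64url_decode(payload_b64))
    # exp is a Unix timestamp; the token is valid only strictly before it
    current = time.time() if now is None else now
    if "exp" in payload and current >= payload["exp"]:
        raise TokenError("token expired")
    return payload
```

Passing `now` explicitly keeps expiry checks deterministic in tests instead of depending on the wall clock.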
Testing file upload endpoints requires sending multipart form data. FastAPI’s TestClient makes this straightforward.
# app/routers/files.py
"""File upload endpoints."""
import os
import shutil
from typing import List
from fastapi import APIRouter, File, UploadFile, HTTPException, status
router = APIRouter(prefix="/files", tags=["files"])
UPLOAD_DIR = "uploads"
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".pdf", ".txt"}
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
@router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
"""Upload a single file."""
# Validate file extension
ext = os.path.splitext(file.filename)[1].lower()
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type '{ext}' not allowed. Allowed: {sorted(ALLOWED_EXTENSIONS)}"
)
# Read and check file size
contents = await file.read()
if len(contents) > MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File too large. Max size: {MAX_FILE_SIZE} bytes"
)
# Save the file
os.makedirs(UPLOAD_DIR, exist_ok=True)
# basename() strips directory parts such as "../" from the client-supplied name
file_path = os.path.join(UPLOAD_DIR, os.path.basename(file.filename))
with open(file_path, "wb") as f:
f.write(contents)
return {
"filename": file.filename,
"size": len(contents),
"content_type": file.content_type,
}
@router.post("/upload-multiple")
async def upload_multiple_files(files: List[UploadFile] = File(...)):
"""Upload multiple files at once."""
results = []
for file in files:
contents = await file.read()
results.append({
"filename": file.filename,
"size": len(contents),
"content_type": file.content_type,
})
return {"uploaded": len(results), "files": results}
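The upload endpoint mixes validation with file I/O, so every branch has to be reached through a multipart request, and the 413 branch needs a payload over 10 MB. One option, sketched below under the assumption that you are willing to refactor, is to move the rules into a pure function with an overridable `max_size`, so each branch can be unit-tested cheaply. `validate_upload` and `UploadRejected` are hypothetical names; the endpoint would call the helper and turn `UploadRejected` into an `HTTPException` with the same status code.

```python
import os

ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".pdf", ".txt"}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB, same limit as the router


class UploadRejected(Exception):
    """Carries the HTTP status the endpoint would return."""

    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def validate_upload(filename: str, size: int, max_size: int = MAX_FILE_SIZE) -> str:
    """Return a safe filename or raise UploadRejected (hypothetical helper)."""
    # Drop directory components a client might send, e.g. "../../x.txt"
    safe_name = os.path.basename(filename or "")
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise UploadRejected(400, f"File type '{ext}' not allowed")
    if size > max_size:
        raise UploadRejected(413, f"File too large. Max size: {max_size} bytes")
    return safe_name


print(validate_upload("../../etc/report.PDF", 1024))  # report.PDF
```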
# tests/integration/test_files.py
"""
Tests for file upload endpoints.
"""
import io
import os
import pytest
import shutil
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
UPLOAD_DIR = "uploads"
@pytest.fixture(autouse=True)
def cleanup_uploads():
"""Remove uploaded files after each test."""
yield
if os.path.exists(UPLOAD_DIR):
shutil.rmtree(UPLOAD_DIR)
class TestSingleFileUpload:
"""Tests for POST /files/upload."""
def test_upload_text_file(self):
"""Upload a text file successfully."""
file_content = b"Hello, this is a test file."
response = client.post(
"/files/upload",
files={"file": ("test.txt", file_content, "text/plain")}
)
assert response.status_code == 200
data = response.json()
assert data["filename"] == "test.txt"
assert data["size"] == len(file_content)
assert data["content_type"] == "text/plain"
def test_upload_image_file(self):
"""Upload a PNG image file."""
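# PNG signature plus the IHDR chunk of a 1x1 image (not a complete, decodable file);
# the endpoint only checks the extension, so this is enough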
png_header = (
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01'
b'\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde'
)
response = client.post(
"/files/upload",
files={"file": ("image.png", png_header, "image/png")}
)
assert response.status_code == 200
assert response.json()["filename"] == "image.png"
def test_upload_pdf_file(self):
"""Upload a PDF file."""
pdf_content = b"%PDF-1.4 fake pdf content"
response = client.post(
"/files/upload",
files={"file": ("document.pdf", pdf_content, "application/pdf")}
)
assert response.status_code == 200
def test_upload_disallowed_extension(self):
"""Uploading a .exe file should be rejected."""
response = client.post(
"/files/upload",
files={"file": ("malware.exe", b"bad content", "application/octet-stream")}
)
assert response.status_code == 400
assert "not allowed" in response.json()["detail"]
def test_upload_file_saves_to_disk(self):
"""Verify the uploaded file is actually saved."""
file_content = b"Persistent content"
client.post(
"/files/upload",
files={"file": ("saved.txt", file_content, "text/plain")}
)
file_path = os.path.join(UPLOAD_DIR, "saved.txt")
assert os.path.exists(file_path)
with open(file_path, "rb") as f:
assert f.read() == file_content
def test_upload_no_file(self):
"""Request without a file should return 422."""
response = client.post("/files/upload")
assert response.status_code == 422
def test_upload_using_io_bytes(self):
"""Upload using io.BytesIO for in-memory files."""
file_obj = io.BytesIO(b"BytesIO content")
response = client.post(
"/files/upload",
files={"file": ("bytesio.txt", file_obj, "text/plain")}
)
assert response.status_code == 200
class TestMultipleFileUpload:
"""Tests for POST /files/upload-multiple."""
def test_upload_multiple_files(self):
"""Upload multiple files at once."""
files = [
("files", ("file1.txt", b"Content 1", "text/plain")),
("files", ("file2.txt", b"Content 2", "text/plain")),
("files", ("file3.txt", b"Content 3", "text/plain")),
]
response = client.post("/files/upload-multiple", files=files)
assert response.status_code == 200
data = response.json()
assert data["uploaded"] == 3
assert len(data["files"]) == 3
def test_upload_mixed_file_types(self):
"""Upload files of different types."""
files = [
("files", ("doc.txt", b"text content", "text/plain")),
("files", ("image.jpg", b"fake jpg", "image/jpeg")),
]
response = client.post("/files/upload-multiple", files=files)
assert response.status_code == 200
assert response.json()["uploaded"] == 2
FastAPI supports WebSocket endpoints, and the TestClient provides built-in WebSocket testing capabilities.
# app/routers/websocket.py
"""WebSocket endpoints."""
from typing import List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
router = APIRouter()
class ConnectionManager:
"""Manages active WebSocket connections."""
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
"""WebSocket endpoint for real-time communication."""
await manager.connect(websocket)
try:
# Send welcome message
await websocket.send_json({
"type": "connected",
"client_id": client_id,
"message": f"Welcome, {client_id}!"
})
while True:
# Receive message from client
data = await websocket.receive_text()
# Echo back with processing
response = {
"type": "message",
"client_id": client_id,
"content": data,
"echo": f"Server received: {data}"
}
await websocket.send_json(response)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client {client_id} disconnected")
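Because `ConnectionManager` only calls `accept()` and `send_text()` on each socket, it can be unit-tested without any network connection by passing `AsyncMock` objects in place of `WebSocket` instances. The sketch below copies the manager's logic (dropping the FastAPI import so it runs standalone) and checks that a disconnected client is removed and excluded from the broadcast. The `scenario` function is a hypothetical test harness.

```python
import asyncio
from typing import List
from unittest.mock import AsyncMock


class ConnectionManager:
    """Same logic as app/routers/websocket.py, minus the FastAPI type hints."""

    def __init__(self):
        self.active_connections: List = []

    async def connect(self, websocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)


async def scenario():
    manager = ConnectionManager()
    alice, bob = AsyncMock(), AsyncMock()  # stand-ins for WebSocket objects
    await manager.connect(alice)
    await manager.connect(bob)
    manager.disconnect(alice)
    await manager.broadcast("Client alice disconnected")
    return manager, alice, bob


manager, alice, bob = asyncio.run(scenario())
alice.accept.assert_awaited_once()
bob.send_text.assert_awaited_once_with("Client alice disconnected")
alice.send_text.assert_not_awaited()
print(len(manager.active_connections))  # 1
```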
# tests/integration/test_websocket.py
"""
Tests for WebSocket endpoints.
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
class TestWebSocketEndpoint:
"""Tests for the WebSocket /ws/{client_id} endpoint."""
def test_websocket_connection(self):
"""Client can connect to the WebSocket."""
with client.websocket_connect("/ws/test-user") as websocket:
# Should receive a welcome message
data = websocket.receive_json()
assert data["type"] == "connected"
assert data["client_id"] == "test-user"
assert "Welcome" in data["message"]
def test_websocket_echo(self):
"""Server echoes back received messages."""
with client.websocket_connect("/ws/echo-user") as websocket:
# Skip welcome message
websocket.receive_json()
# Send a message
websocket.send_text("Hello, Server!")
# Receive the echo
response = websocket.receive_json()
assert response["type"] == "message"
assert response["content"] == "Hello, Server!"
assert "Server received" in response["echo"]
def test_websocket_multiple_messages(self):
"""Can send and receive multiple messages."""
with client.websocket_connect("/ws/multi-user") as websocket:
websocket.receive_json() # Skip welcome
messages = ["Hello", "How are you?", "Goodbye"]
for msg in messages:
websocket.send_text(msg)
response = websocket.receive_json()
assert response["content"] == msg
def test_websocket_json_communication(self):
"""Can send and receive JSON data."""
with client.websocket_connect("/ws/json-user") as websocket:
websocket.receive_json() # Skip welcome
# Send text, receive JSON
websocket.send_text("test message")
response = websocket.receive_json()
assert isinstance(response, dict)
assert "type" in response
assert "content" in response
def test_websocket_client_id_in_response(self):
"""The client_id appears in all responses."""
client_id = "user-123"
with client.websocket_connect(f"/ws/{client_id}") as websocket:
welcome = websocket.receive_json()
assert welcome["client_id"] == client_id
websocket.send_text("test")
response = websocket.receive_json()
assert response["client_id"] == client_id
def test_websocket_disconnect(self):
"""Disconnection is handled gracefully."""
# The 'with' block handles disconnection automatically
with client.websocket_connect("/ws/disconnect-user") as websocket:
data = websocket.receive_json()
assert data["type"] == "connected"
# After the with block, the connection is closed
# No exceptions should be raised
Performance testing ensures your API can handle expected load and identifies bottlenecks before they reach production.
# Install locust
pip install locust
# tests/performance/locustfile.py
"""
Load testing configuration for FastAPI application.
Run with: locust -f tests/performance/locustfile.py --host=http://localhost:8000
"""
import uuid
from locust import HttpUser, task, between, tag
class APIUser(HttpUser):
"""Simulates a typical API user."""
# Spawn 10 regular users for every AdminUser (weight 1)
weight = 10
# Wait 1-3 seconds between tasks
wait_time = between(1, 3)
def on_start(self):
"""Run once when a simulated user starts."""
# Register and log in with a unique email per simulated user;
# runner.user_count can repeat (for example across distributed workers)
email = f"loadtest_{uuid.uuid4().hex}@test.com"
self.client.post("/auth/register", json={
"name": "Load Test User",
"email": email,
"password": "LoadTest123!"
})
response = self.client.post("/auth/login", json={
"email": email,
"password": "LoadTest123!"
})
if response.status_code == 200:
self.token = response.json().get("access_token", "")
self.headers = {"Authorization": f"Bearer {self.token}"}
else:
self.headers = {}
@task(5)
@tag("read")
def list_items(self):
"""GET /items — most common operation (weight: 5)."""
self.client.get("/items", headers=self.headers)
@task(3)
@tag("read")
def get_single_item(self):
"""GET /items/{id} — common operation (weight: 3)."""
self.client.get("/items/1", headers=self.headers)
@task(1)
@tag("write")
def create_item(self):
"""POST /items — less common (weight: 1)."""
self.client.post("/items", json={
"title": "Load Test Item",
"description": "Created during load testing",
"price": 9.99,
"quantity": 1
}, headers=self.headers)
@task(1)
@tag("read")
def health_check(self):
"""GET /health — monitoring endpoint."""
self.client.get("/health")
class AdminUser(HttpUser):
"""Simulates an admin user — less frequent, heavier operations."""
wait_time = between(5, 10)
weight = 1 # 1 admin per 10 regular users
def on_start(self):
"""Login as admin."""
response = self.client.post("/auth/login", json={
"email": "admin@test.com",
"password": "AdminPass123!"
})
if response.status_code == 200:
self.token = response.json().get("access_token", "")
self.headers = {"Authorization": f"Bearer {self.token}"}
else:
self.headers = {}
@task
def list_all_users(self):
"""GET /admin/users — admin-only endpoint."""
self.client.get("/admin/users", headers=self.headers)
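The `@task(n)` weights are relative: Locust picks a task with probability equal to its weight divided by the sum of weights for that user class. Working this out for `APIUser` shows the request mix the test will generate, which is worth checking against real traffic before trusting the results. The snippet is plain arithmetic and does not use Locust itself.

```python
from fractions import Fraction

# Task weights copied from the @task(...) decorators on APIUser
task_weights = {
    "list_items": 5,
    "get_single_item": 3,
    "create_item": 1,
    "health_check": 1,
}

total = sum(task_weights.values())
# Probability that a given task execution is this task: weight / total
shares = {name: Fraction(weight, total) for name, weight in task_weights.items()}

for name, share in shares.items():
    print(f"{name:16} {float(share):.0%}")
# list_items 50%, get_single_item 30%, create_item 10%, health_check 10%
```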
# Run Locust with web UI
locust -f tests/performance/locustfile.py --host=http://localhost:8000
# Run headless (for CI/CD)
locust -f tests/performance/locustfile.py \
--host=http://localhost:8000 \
--headless \
--users 100 \
--spawn-rate 10 \
--run-time 60s \
--csv=results/load_test
# Run only read-tagged tests
locust -f tests/performance/locustfile.py \
--host=http://localhost:8000 \
--tags read \
--headless \
--users 50 \
--spawn-rate 5 \
--run-time 30s
# tests/performance/test_benchmarks.py
"""
Benchmark tests to measure endpoint performance.
Install: pip install pytest-benchmark
Run: pytest tests/performance/test_benchmarks.py --benchmark-only
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_health_endpoint_performance(benchmark):
"""Benchmark the health check endpoint."""
result = benchmark(client.get, "/health")
assert result.status_code == 200
def test_list_items_performance(benchmark):
"""Benchmark listing items."""
result = benchmark(client.get, "/items")
assert result.status_code == 200
def test_create_item_performance(benchmark):
"""Benchmark creating an item."""
payload = {
"title": "Benchmark Item",
"description": "Performance test",
"price": 10.0,
"quantity": 1
}
result = benchmark(client.post, "/items", json=payload)
assert result.status_code == 201
@pytest.mark.slow
def test_response_time_under_threshold(client):
"""Verify critical endpoints respond within acceptable time."""
import time
endpoints = [
("GET", "/health", None),
("GET", "/items", None),
("POST", "/items", {
"title": "Speed Test",
"description": "Testing speed",
"price": 10.0,
"quantity": 1
}),
]
max_response_time = 0.5 # 500ms threshold
for method, path, payload in endpoints:
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
if method == "GET":
response = client.get(path)
else:
response = client.post(path, json=payload)
elapsed = time.perf_counter() - start
assert response.status_code < 400, f"{method} {path} returned {response.status_code}"
assert elapsed < max_response_time, (
f"{method} {path} took {elapsed:.3f}s "
f"(threshold: {max_response_time}s)"
)
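A single timed request is noisy: the first call may pay for imports or connection setup, and a garbage-collection pause can spike any one sample. A common mitigation, sketched here with a hypothetical `median_seconds` helper, is to discard a warm-up call and compare the median of several runs against the threshold.

```python
import statistics
import time
from typing import Callable


def median_seconds(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Median wall-clock time of fn() over several runs (hypothetical helper)."""
    for _ in range(warmup):
        fn()  # discard warm-up calls: imports, caches, connection setup
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


elapsed = median_seconds(lambda: sum(range(10_000)))
print(f"median: {elapsed * 1000:.3f} ms")
```

Inside a test, the same helper could wrap an endpoint call, for example `median_seconds(lambda: client.get("/health"))`, and the result be compared with the 0.5 s threshold.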
Test coverage measures what percentage of your code is executed by tests. While 100% coverage does not guarantee bug-free code, it highlights untested code paths.
# Install pytest-cov
pip install pytest-cov
# Run tests with coverage
pytest --cov=app tests/
# Generate HTML coverage report
pytest --cov=app --cov-report=html tests/
# Generate XML report (for CI/CD)
pytest --cov=app --cov-report=xml tests/
# Show missing lines in terminal
pytest --cov=app --cov-report=term-missing tests/
# Fail if coverage is below threshold
pytest --cov=app --cov-fail-under=80 tests/
# Combine multiple report formats
pytest --cov=app \
--cov-report=term-missing \
--cov-report=html:coverage_html \
--cov-report=xml:coverage.xml \
tests/
# pyproject.toml — Coverage configuration
[tool.coverage.run]
source = ["app"]
omit = [
"app/__init__.py",
"app/config.py",
"tests/*",
"*/migrations/*",
]
[tool.coverage.report]
show_missing = true
fail_under = 80
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if __name__ == .__main__.",
"raise NotImplementedError",
'^\s*pass\s*$',  # anchored so it cannot match lines such as "def verify_password(...)"
"except ImportError",
]
[tool.coverage.html]
directory = "htmlcov"
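Coverage.py treats each `exclude_lines` entry as a regular expression and excludes any source line containing a match; when the matching line opens a block (such as a `def`), the whole block is excluded. A bare `"pass"` pattern therefore also matches lines like `def verify_password(...)`, silently dropping those functions from the report. The check below compares a bare pattern with an anchored one on a few lines taken from the auth router.

```python
import re

# Each exclude_lines entry is searched for anywhere in a source line
bare = re.compile("pass")
anchored = re.compile(r"^\s*pass\s*$")

lines = [
    "    pass",
    "def verify_password(plain_password: str, hashed_password: str) -> bool:",
    "        hashed_password=hash_password(user_data.password),",
]

for line in lines:
    print(bool(bare.search(line)), bool(anchored.search(line)), line.strip())
```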
# Sample terminal output from pytest --cov=app --cov-report=term-missing
---------- coverage: platform linux, python 3.11.0 ----------
Name                               Stmts   Miss  Cover   Missing
----------------------------------------------------------------
app/__init__.py                        0      0   100%
app/main.py                           15      0   100%
app/database.py                       12      2    83%   18-19
app/models/item.py                     8      0   100%
app/models/user.py                    12      0   100%
app/routers/auth.py                   45      3    93%   67-69
app/routers/items.py                  38      0   100%
app/routers/files.py                  30      5    83%   42-46
app/dependencies.py                   25      4    84%   31-34
app/services/email_service.py         18      8    56%   12-19
app/services/weather_service.py       20     12    40%   8-19
----------------------------------------------------------------
TOTAL                                223     34    85%
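The Cover column is (Stmts - Miss) / Stmts, and the TOTAL row is computed from the summed counts, so large files weigh more than small ones; it is not an average of per-file percentages. Recomputing it from the sample report shows how close this project sits to a `fail_under = 80` gate.

```python
# (file, statements, missed) rows copied from the sample report
rows = [
    ("app/__init__.py", 0, 0),
    ("app/main.py", 15, 0),
    ("app/database.py", 12, 2),
    ("app/models/item.py", 8, 0),
    ("app/models/user.py", 12, 0),
    ("app/routers/auth.py", 45, 3),
    ("app/routers/items.py", 38, 0),
    ("app/routers/files.py", 30, 5),
    ("app/dependencies.py", 25, 4),
    ("app/services/email_service.py", 18, 8),
    ("app/services/weather_service.py", 20, 12),
]

total_stmts = sum(stmts for _, stmts, _ in rows)
total_miss = sum(miss for _, _, miss in rows)
# Cover = executed statements / all statements
coverage_pct = 100 * (total_stmts - total_miss) / total_stmts
print(total_stmts, total_miss, f"{coverage_pct:.2f}%")  # 223 34 84.75%
print("meets fail_under=80:", coverage_pct >= 80)
```

The report rounds 84.75% to 85%; covering the 12 missed lines in `weather_service.py` alone would raise the total by about 5 points.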
| Coverage Level | Meaning | Interpretation |
|---|---|---|
| < 50% | Critical gaps | Many untested code paths — high risk of bugs |
| 50-70% | Basic coverage | Core paths tested but edge cases missing |
| 70-80% | Good coverage | Solid for most projects — good balance |
| 80-90% | Strong coverage | Recommended target for production APIs |
| 90-100% | Comprehensive | Great for critical systems (payments, auth) |
Automating your tests in a CI/CD pipeline ensures that every code change is validated before reaching production.
# .github/workflows/test.yml
name: FastAPI Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
services:
postgres:
image: postgres:15
env:
POSTGRES_USER: testuser
POSTGRES_PASSWORD: testpass
POSTGRES_DB: testdb
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip packages
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements*.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-test.txt
- name: Run linting
run: |
pip install ruff
ruff check app/ tests/
- name: Run type checking
run: |
pip install mypy
mypy app/ --ignore-missing-imports
- name: Run tests with coverage
env:
DATABASE_URL: postgresql://testuser:testpass@localhost:5432/testdb
SECRET_KEY: test-secret-key-for-ci
TESTING: "true"
run: |
pytest tests/ \
--cov=app \
--cov-report=xml:coverage.xml \
--cov-report=html \
--cov-report=term-missing \
--cov-fail-under=80 \
-v \
--tb=short
- name: Upload coverage to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v4
with:
files: coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: test-results-${{ matrix.python-version }}
path: |
coverage.xml
htmlcov/
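The workflow's `env:` block only has an effect if the application or test configuration actually reads those variables; the auth router above, for instance, hard-codes `SECRET_KEY`, so the value set in CI is ignored there. A minimal sketch of reading them with local fallbacks follows. `load_test_settings` and the SQLite fallback URL are hypothetical; adapt them to however the project's `config.py` loads settings.

```python
import os
from typing import Mapping

# Hypothetical local fallback; CI overrides it through the DATABASE_URL variable
DEFAULT_TEST_DATABASE_URL = "sqlite:///./test.db"


def load_test_settings(env: Mapping[str, str] = os.environ) -> dict:
    """Collect the variables the workflow's `env:` block sets (hypothetical helper)."""
    return {
        "database_url": env.get("DATABASE_URL", DEFAULT_TEST_DATABASE_URL),
        "secret_key": env.get("SECRET_KEY", "your-secret-key-here"),
        # Environment values are always strings, so "true" must be parsed explicitly
        "testing": env.get("TESTING", "false").strip().lower() in {"1", "true", "yes"},
    }


ci_env = {
    "DATABASE_URL": "postgresql://testuser:testpass@localhost:5432/testdb",
    "SECRET_KEY": "test-secret-key-for-ci",
    "TESTING": "true",
}
print(load_test_settings(ci_env)["testing"])  # True
print(load_test_settings({})["database_url"])  # sqlite:///./test.db
```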
# Dockerfile.test
FROM python:3.12-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt requirements-test.txt ./
RUN pip install --no-cache-dir -r requirements.txt -r requirements-test.txt
# Copy application code
COPY . .
# Run tests
CMD ["pytest", "tests/", "-v", "--cov=app", "--cov-report=term-missing"]
# docker-compose.test.yml
services:
test:
build:
context: .
dockerfile: Dockerfile.test
environment:
- DATABASE_URL=postgresql://testuser:testpass@db:5432/testdb
- SECRET_KEY=test-secret-key
- TESTING=true
depends_on:
db:
condition: service_healthy
db:
image: postgres:15-alpine
environment:
POSTGRES_USER: testuser
POSTGRES_PASSWORD: testpass
POSTGRES_DB: testdb
healthcheck:
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
interval: 5s
timeout: 5s
retries: 5
# Run tests in Docker
docker compose -f docker-compose.test.yml up --build --abort-on-container-exit --exit-code-from test
# Run specific test files
docker compose -f docker-compose.test.yml run test \
pytest tests/integration/ -v
# Run with coverage report
docker compose -f docker-compose.test.yml run test \
pytest tests/ --cov=app --cov-report=html
# Clean up
docker compose -f docker-compose.test.yml down -v
# .github/workflows/ci-cd.yml
name: CI/CD Pipeline
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
# Stage 1: Lint and Type Check
quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: pip install ruff mypy
- run: ruff check app/ tests/
- run: mypy app/ --ignore-missing-imports
# Stage 2: Unit Tests (fast, no external deps)
unit-tests:
runs-on: ubuntu-latest
needs: quality
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: |
pip install -r requirements.txt
pip install -r requirements-test.txt
- run: pytest tests/unit/ -v --cov=app -m unit
# Stage 3: Integration Tests (needs database)
integration-tests:
runs-on: ubuntu-latest
needs: unit-tests
services:
postgres:
image: postgres:15
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: testdb
ports: ["5432:5432"]
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: |
pip install -r requirements.txt
pip install -r requirements-test.txt
- run: pytest tests/integration/ -v -m integration
env:
DATABASE_URL: postgresql://test:test@localhost:5432/testdb
# Stage 4: Deploy (only on main branch)
deploy:
runs-on: ubuntu-latest
needs: [unit-tests, integration-tests]
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
steps:
- uses: actions/checkout@v4
- name: Deploy to production
run: echo "Deploying to production..."
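The `-m unit` and `-m integration` filters in the pipeline select nothing unless tests carry those markers. One way to avoid decorating every test by hand is to derive the marker from the test's directory in `conftest.py`, as sketched below. The directory mapping and `marker_for_path` are assumptions based on the `tests/unit/` and `tests/integration/` paths the workflow uses; `item.path` requires pytest 7 or later.

```python
# tests/conftest.py (sketch; assumes pytest >= 7 for item.path)
from pathlib import PurePath
from typing import Optional

# Directory names under tests/ that map to markers of the same name
MARKER_DIRS = ("unit", "integration")


def marker_for_path(path) -> Optional[str]:
    """Return the marker for a test file based on the directory below tests/."""
    parts = PurePath(path).parts
    # Walk backwards so the innermost "tests" directory wins
    for i in range(len(parts) - 2, -1, -1):
        if parts[i] == "tests":
            return parts[i + 1] if parts[i + 1] in MARKER_DIRS else None
    return None


def pytest_configure(config):
    # Register the markers so `pytest --strict-markers` accepts them
    config.addinivalue_line("markers", "unit: fast tests with no external services")
    config.addinivalue_line("markers", "integration: tests that need a database")


def pytest_collection_modifyitems(config, items):
    for item in items:
        name = marker_for_path(item.path)
        if name is not None:
            item.add_marker(name)  # add_marker accepts a marker name string
```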
Let us bring everything together into a complete, production-ready test suite for a CRUD API with authentication, database integration, and mocking.
# Complete application structure
bookstore_api/
├── app/
│ ├── __init__.py
│ ├── main.py
│ ├── config.py
│ ├── database.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── book.py
│ │ └── user.py
│ ├── schemas/
│ │ ├── __init__.py
│ │ ├── book.py
│ │ └── user.py
│ ├── routers/
│ │ ├── __init__.py
│ │ ├── books.py
│ │ └── auth.py
│ ├── services/
│ │ └── notification_service.py
│ └── dependencies.py
└── tests/
├── __init__.py
├── conftest.py
├── test_books_crud.py
├── test_auth.py
└── test_notifications.py
# app/main.py
"""FastAPI Bookstore Application."""
from fastapi import FastAPI
from app.routers import books, auth
from app.database import Base, engine
# Create tables
Base.metadata.create_all(bind=engine)
app = FastAPI(title="Bookstore API", version="1.0.0")
app.include_router(auth.router)
app.include_router(books.router)
@app.get("/health")
def health_check():
return {"status": "healthy"}
# app/schemas/book.py
"""Book Pydantic schemas."""
from typing import Optional

from pydantic import BaseModel, Field


class BookCreate(BaseModel):
    title: str = Field(..., min_length=1, max_length=300)
    author: str = Field(..., min_length=1, max_length=200)
    isbn: str = Field(..., pattern=r"^\d{13}$")  # exactly 13 digits
    price: float = Field(..., gt=0)
    description: Optional[str] = None


class BookUpdate(BaseModel):
    # Every field is optional so clients can send partial updates;
    # the ISBN cannot be changed through this schema.
    title: Optional[str] = Field(None, min_length=1, max_length=300)
    author: Optional[str] = Field(None, min_length=1, max_length=200)
    price: Optional[float] = Field(None, gt=0)
    description: Optional[str] = None


class BookResponse(BaseModel):
    id: int
    title: str
    author: str
    isbn: str
    price: float
    description: Optional[str]
    owner_id: int

    # Allow building the response directly from SQLAlchemy objects
    model_config = {"from_attributes": True}
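Because the validation rules live in these schemas, they can be unit-tested directly, with no TestClient and no database. A sketch, assuming Pydantic v2 (which the `pattern=` and `model_config` usage above implies); `BookCreate` and `BookUpdate` are copied here only so the example runs on its own, and in the project you would import them from `app.schemas.book` instead:

```python
# Schema-level unit tests: no HTTP layer, no database.
from typing import Optional

import pytest
from pydantic import BaseModel, Field, ValidationError


class BookCreate(BaseModel):  # copy of app/schemas/book.py
    title: str = Field(..., min_length=1, max_length=300)
    author: str = Field(..., min_length=1, max_length=200)
    isbn: str = Field(..., pattern=r"^\d{13}$")
    price: float = Field(..., gt=0)
    description: Optional[str] = None


class BookUpdate(BaseModel):  # copy of app/schemas/book.py
    title: Optional[str] = Field(None, min_length=1, max_length=300)
    author: Optional[str] = Field(None, min_length=1, max_length=200)
    price: Optional[float] = Field(None, gt=0)
    description: Optional[str] = None


VALID = {
    "title": "Clean Code",
    "author": "Robert C. Martin",
    "isbn": "9780132350884",
    "price": 39.99,
}


def test_valid_payload_uses_defaults():
    book = BookCreate(**VALID)
    assert book.description is None


@pytest.mark.parametrize("isbn", ["123", "97801323508841", "978013235088X"])
def test_isbn_must_be_exactly_13_digits(isbn):
    with pytest.raises(ValidationError):
        BookCreate(**{**VALID, "isbn": isbn})


@pytest.mark.parametrize("price", [0, -10.0])
def test_price_must_be_positive(price):
    with pytest.raises(ValidationError):
        BookCreate(**{**VALID, "price": price})


def test_update_reports_only_fields_that_were_sent():
    # update_book relies on this: exclude_unset drops fields the client omitted
    update = BookUpdate(price=29.99)
    assert update.model_dump(exclude_unset=True) == {"price": 29.99}
```

These tests run in milliseconds, so the API-level 422 tests only need to confirm that FastAPI wires the schema in, not re-check every rule.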
# app/models/book.py
"""Book SQLAlchemy model."""
from sqlalchemy import Column, Integer, String, Float, ForeignKey
from sqlalchemy.orm import relationship

from app.database import Base


class Book(Base):
    __tablename__ = "books"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(300), nullable=False)
    author = Column(String(200), nullable=False)
    isbn = Column(String(13), unique=True, nullable=False)
    price = Column(Float, nullable=False)
    description = Column(String(2000), nullable=True)
    owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    owner = relationship("User", back_populates="books")
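The `unique=True` on `isbn` means the database itself rejects duplicates, independently of the 409 check in the router. A self-contained sketch that verifies the model against in-memory SQLite, assuming SQLAlchemy 1.4 or 2.x; the `User` class here is a stripped-down stand-in (the real `app/models/user.py` has more columns) and `make_session` is a hypothetical helper:

```python
import pytest
from sqlalchemy import Column, Float, ForeignKey, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base, relationship, sessionmaker

Base = declarative_base()


class User(Base):
    # Minimal stand-in for app/models/user.py: only what Book needs
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    books = relationship("Book", back_populates="owner")


class Book(Base):
    # Same columns as app/models/book.py
    __tablename__ = "books"
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(300), nullable=False)
    author = Column(String(200), nullable=False)
    isbn = Column(String(13), unique=True, nullable=False)
    price = Column(Float, nullable=False)
    description = Column(String(2000), nullable=True)
    owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    owner = relationship("User", back_populates="books")


def make_session():
    """Hypothetical helper: a session on a fresh in-memory database."""
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    return sessionmaker(bind=engine)()


def test_duplicate_isbn_violates_unique_constraint():
    session = make_session()
    try:
        session.add(User(id=1))
        session.add(Book(title="A", author="X", isbn="9780132350884", price=1.0, owner_id=1))
        session.commit()
        session.add(Book(title="B", author="Y", isbn="9780132350884", price=2.0, owner_id=1))
        with pytest.raises(IntegrityError):
            session.commit()
    finally:
        session.close()  # also rolls back the failed transaction


def test_relationship_links_book_to_owner():
    session = make_session()
    try:
        owner = User(id=1)
        # Setting Book.owner cascades the new User into the session
        session.add(Book(title="A", author="X", isbn="9780132350884", price=1.0, owner=owner))
        session.commit()
        assert session.get(User, 1).books[0].title == "A"
    finally:
        session.close()
```

The router's explicit ISBN lookup still matters: it turns what would otherwise be an unhandled `IntegrityError` into a clean 409 response.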
# app/routers/books.py
"""Book CRUD router with authentication."""
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from app.database import get_db
from app.dependencies import get_current_user
from app.models.book import Book
from app.models.user import User
from app.schemas.book import BookCreate, BookUpdate, BookResponse

router = APIRouter(prefix="/books", tags=["books"])


# "" (rather than "/") maps the collection routes to exactly /books, so
# requests to /books are served directly instead of being redirected (307)
# to /books/.
@router.get("", response_model=List[BookResponse])
def list_books(
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    author: Optional[str] = None,
    db: Session = Depends(get_db),
):
    """List all books (public endpoint)."""
    query = db.query(Book)
    if author:
        # Case-insensitive substring match on the author name
        query = query.filter(Book.author.ilike(f"%{author}%"))
    return query.offset(skip).limit(limit).all()


@router.get("/{book_id}", response_model=BookResponse)
def get_book(book_id: int, db: Session = Depends(get_db)):
    """Get a book by ID (public endpoint)."""
    book = db.query(Book).filter(Book.id == book_id).first()
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    return book


@router.post("", response_model=BookResponse, status_code=201)
def create_book(
    book_data: BookCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a new book (authenticated)."""
    # Reject duplicate ISBNs with 409 before hitting the unique constraint
    existing = db.query(Book).filter(Book.isbn == book_data.isbn).first()
    if existing:
        raise HTTPException(
            status_code=409,
            detail=f"Book with ISBN {book_data.isbn} already exists"
        )
    book = Book(**book_data.model_dump(), owner_id=current_user.id)
    db.add(book)
    db.commit()
    db.refresh(book)
    return book


@router.put("/{book_id}", response_model=BookResponse)
def update_book(
    book_id: int,
    book_data: BookUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a book (owner only)."""
    book = db.query(Book).filter(Book.id == book_id).first()
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    if book.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")
    # Apply only the fields the client actually sent (partial update)
    update_data = book_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(book, field, value)
    db.commit()
    db.refresh(book)
    return book


@router.delete("/{book_id}", status_code=204)
def delete_book(
    book_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a book (owner only)."""
    book = db.query(Book).filter(Book.id == book_id).first()
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    if book.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")
    db.delete(book)
    db.commit()
    # Returning None produces an empty 204 No Content response
# tests/conftest.py
"""Complete test configuration for the Bookstore API."""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base, get_db
from app.dependencies import get_current_user
from app.models.user import User

# "sqlite://" is an in-memory database. StaticPool reuses a single connection,
# so every session sees the same data; check_same_thread=False is needed
# because TestClient runs sync endpoints in a worker thread.
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine
)


@pytest.fixture(scope="function")
def db_session():
    """Fresh schema per test: create tables, yield a session, drop tables."""
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
        Base.metadata.drop_all(bind=engine)


@pytest.fixture(scope="function")
def client(db_session):
    """TestClient whose get_db dependency yields the test session."""
    def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    with TestClient(app) as c:
        yield c
    app.dependency_overrides.clear()


@pytest.fixture
def test_user(db_session):
    """Create a test user in the database."""
    user = User(
        id=1, name="Test User", email="test@example.com",
        hashed_password="fake_hash", role="user", is_active=True
    )
    db_session.add(user)
    db_session.commit()
    return user


@pytest.fixture
def other_user(db_session):
    """Create a second test user."""
    user = User(
        id=2, name="Other User", email="other@example.com",
        hashed_password="fake_hash", role="user", is_active=True
    )
    db_session.add(user)
    db_session.commit()
    return user


@pytest.fixture
def authenticated_client(client, test_user):
    """Client with get_current_user overridden to return test_user."""
    def override_current_user():
        return test_user

    app.dependency_overrides[get_current_user] = override_current_user
    yield client
    # No cleanup needed here: the client fixture clears all overrides


@pytest.fixture
def other_authenticated_client(client, other_user):
    """Client authenticated as other_user.

    Caution: pytest creates `client` once per test, so this fixture returns
    the same TestClient object as `authenticated_client` and overwrites its
    get_current_user override. Requesting both fixtures in one test makes
    every request run as whichever user was set up last; to act as two users
    within one test, switch the override between requests instead.
    """
    def override_current_user():
        return other_user

    app.dependency_overrides[get_current_user] = override_current_user
    yield client


@pytest.fixture
def sample_book():
    return {
        "title": "Clean Code",
        "author": "Robert C. Martin",
        "isbn": "9780132350884",
        "price": 39.99,
        "description": "A handbook of agile software craftsmanship"
    }


@pytest.fixture
def another_book():
    return {
        "title": "The Pragmatic Programmer",
        "author": "David Thomas",
        "isbn": "9780135957059",
        "price": 49.99,
        "description": "Your journey to mastery"
    }
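`sample_book` and `another_book` cover the common cases, but tests that need many books (like the pagination test) end up hand-building ISBNs. A factory fixture keeps such payloads valid and unique. A sketch, assuming ISBNs only need to match the schema's 13-digit pattern (the generated values are not checksum-valid ISBN-13s); `make_book` and `book_factory` are hypothetical names:

```python
import itertools

import pytest

# Module-level counter: every call gets a new number, so ISBNs never collide
_isbn_counter = itertools.count()


def make_book(**overrides):
    """Return a BookCreate-compatible payload with a unique 13-digit ISBN."""
    n = next(_isbn_counter)
    payload = {
        "title": f"Book {n}",
        "author": "Test Author",
        "isbn": f"978{n:010d}",  # "978" + 10 digits = 13 digits
        "price": 19.99,
        "description": None,
    }
    payload.update(overrides)  # e.g. make_book(author="Jane Doe")
    return payload


@pytest.fixture
def book_factory():
    """Expose the factory to tests without importing it."""
    return make_book
```

A test would then create five books with `for _ in range(5): authenticated_client.post("/books", json=book_factory())`, with no risk of a 409 from a reused ISBN.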
# tests/test_books_crud.py
"""Complete CRUD tests for the Bookstore API."""
from app.dependencies import get_current_user
from app.main import app


def act_as(user):
    """Make get_current_user return `user` for the requests that follow.

    All client fixtures in one test share the same TestClient and the same
    override, so tests that need two users switch the override explicitly.
    The `client` fixture clears all overrides after each test.
    """
    app.dependency_overrides[get_current_user] = lambda: user


class TestListBooks:
    """GET /books (public endpoint)."""

    def test_empty_list(self, client):
        response = client.get("/books")
        assert response.status_code == 200
        assert response.json() == []

    def test_list_with_books(self, authenticated_client, sample_book, another_book):
        authenticated_client.post("/books", json=sample_book)
        authenticated_client.post("/books", json=another_book)
        response = authenticated_client.get("/books")
        assert response.status_code == 200
        assert len(response.json()) == 2

    def test_filter_by_author(self, authenticated_client, sample_book, another_book):
        authenticated_client.post("/books", json=sample_book)
        authenticated_client.post("/books", json=another_book)
        response = authenticated_client.get("/books?author=Martin")
        assert response.status_code == 200
        books = response.json()
        assert len(books) == 1
        assert books[0]["author"] == "Robert C. Martin"

    def test_pagination(self, authenticated_client, sample_book):
        # Create five books, each with a unique 13-digit ISBN
        for i in range(5):
            book = {**sample_book, "isbn": f"978013235{i:04d}", "title": f"Book {i}"}
            assert authenticated_client.post("/books", json=book).status_code == 201
        response = authenticated_client.get("/books?skip=2&limit=2")
        assert response.status_code == 200
        assert len(response.json()) == 2


class TestGetBook:
    """GET /books/{id} (public endpoint)."""

    def test_get_existing_book(self, authenticated_client, sample_book):
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        response = authenticated_client.get(f"/books/{book_id}")
        assert response.status_code == 200
        assert response.json()["title"] == sample_book["title"]
        assert response.json()["isbn"] == sample_book["isbn"]

    def test_get_nonexistent_book(self, client):
        response = client.get("/books/99999")
        assert response.status_code == 404
        assert response.json()["detail"] == "Book not found"

    def test_response_has_all_fields(self, authenticated_client, sample_book):
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        response = authenticated_client.get(f"/books/{book_id}")
        data = response.json()
        assert "id" in data
        assert "title" in data
        assert "author" in data
        assert "isbn" in data
        assert "price" in data
        assert "owner_id" in data


class TestCreateBook:
    """POST /books (authenticated endpoint)."""

    def test_create_success(self, authenticated_client, sample_book):
        response = authenticated_client.post("/books", json=sample_book)
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == sample_book["title"]
        assert data["author"] == sample_book["author"]
        assert data["isbn"] == sample_book["isbn"]
        assert data["owner_id"] == 1  # id of the test_user fixture

    def test_create_requires_authentication(self, client, sample_book):
        """Unauthenticated request should fail."""
        # No override here, so the real get_current_user rejects the request;
        # the exact status code depends on its implementation.
        response = client.post("/books", json=sample_book)
        assert response.status_code in (401, 403)

    def test_duplicate_isbn_rejected(self, authenticated_client, sample_book):
        authenticated_client.post("/books", json=sample_book)
        duplicate = {**sample_book, "title": "Different Title"}
        response = authenticated_client.post("/books", json=duplicate)
        assert response.status_code == 409

    def test_invalid_isbn_format(self, authenticated_client, sample_book):
        invalid_book = {**sample_book, "isbn": "123"}
        response = authenticated_client.post("/books", json=invalid_book)
        assert response.status_code == 422

    def test_negative_price_rejected(self, authenticated_client, sample_book):
        invalid_book = {**sample_book, "isbn": "9781234567890", "price": -10.0}
        response = authenticated_client.post("/books", json=invalid_book)
        assert response.status_code == 422

    def test_missing_required_fields(self, authenticated_client):
        response = authenticated_client.post("/books", json={"title": "Incomplete"})
        assert response.status_code == 422


class TestUpdateBook:
    """PUT /books/{id} (owner only)."""

    def test_owner_can_update(self, authenticated_client, sample_book):
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        response = authenticated_client.put(f"/books/{book_id}", json={
            "title": "Clean Code (2nd Edition)",
            "price": 44.99
        })
        assert response.status_code == 200
        assert response.json()["title"] == "Clean Code (2nd Edition)"
        assert response.json()["price"] == 44.99

    def test_non_owner_cannot_update(
        self, authenticated_client, other_user, sample_book
    ):
        # Create the book as test_user
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        # Same client, now acting as other_user
        act_as(other_user)
        response = authenticated_client.put(f"/books/{book_id}", json={
            "title": "Stolen Book"
        })
        assert response.status_code == 403

    def test_update_nonexistent_book(self, authenticated_client):
        response = authenticated_client.put("/books/99999", json={
            "title": "Ghost Book"
        })
        assert response.status_code == 404

    def test_partial_update(self, authenticated_client, sample_book):
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        # Only update the price
        response = authenticated_client.put(f"/books/{book_id}", json={
            "price": 29.99
        })
        assert response.status_code == 200
        assert response.json()["price"] == 29.99
        assert response.json()["title"] == sample_book["title"]  # Unchanged


class TestDeleteBook:
    """DELETE /books/{id} (owner only)."""

    def test_owner_can_delete(self, authenticated_client, sample_book):
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        response = authenticated_client.delete(f"/books/{book_id}")
        assert response.status_code == 204
        # Verify deletion
        get_resp = authenticated_client.get(f"/books/{book_id}")
        assert get_resp.status_code == 404

    def test_non_owner_cannot_delete(
        self, authenticated_client, other_user, sample_book
    ):
        # Create the book as test_user, then try to delete it as other_user
        create_resp = authenticated_client.post("/books", json=sample_book)
        book_id = create_resp.json()["id"]
        act_as(other_user)
        response = authenticated_client.delete(f"/books/{book_id}")
        assert response.status_code == 403

    def test_delete_nonexistent(self, authenticated_client):
        response = authenticated_client.delete("/books/99999")
        assert response.status_code == 404

    def test_delete_reduces_count(self, authenticated_client, sample_book, another_book):
        r1 = authenticated_client.post("/books", json=sample_book)
        authenticated_client.post("/books", json=another_book)
        assert len(authenticated_client.get("/books").json()) == 2
        authenticated_client.delete(f"/books/{r1.json()['id']}")
        assert len(authenticated_client.get("/books").json()) == 1


class TestNotificationMocking:
    """Notifications on book creation (the `mocker` fixture needs pytest-mock)."""

    def test_notification_sent_on_create(
        self, authenticated_client, sample_book, mocker
    ):
        """Mock the notification service so no real message is sent."""
        # create_book as shown above does not import send_notification yet.
        # create=True lets the patch succeed anyway; without it, patching a
        # missing attribute raises AttributeError.
        mock_notify = mocker.patch(
            "app.routers.books.send_notification",
            return_value=True,
            create=True,
        )
        response = authenticated_client.post("/books", json=sample_book)
        assert response.status_code == 201
        # Once create_book calls send_notification, enable this check:
        # mock_notify.assert_called_once()
# Running the complete test suite

# Run all tests
pytest tests/ -v

# Run with coverage
pytest tests/ --cov=app --cov-report=term-missing --cov-fail-under=80

# Run only CRUD tests
pytest tests/test_books_crud.py -v

# Run a specific test class
pytest tests/test_books_crud.py::TestCreateBook -v

# Run a specific test
pytest tests/test_books_crud.py::TestCreateBook::test_create_success -v

# Run tests matching a keyword
pytest tests/ -k "create" -v

# Run tests in parallel (requires pytest-xdist)
pip install pytest-xdist
pytest tests/ -n auto -v

# Show slowest tests
pytest tests/ --durations=10
| # | Topic | Key Point |
|---|---|---|
| 1 | Testing Philosophy | Follow the testing pyramid — many unit tests, fewer integration tests, few E2E tests |
| 2 | Setup | Use pytest, httpx, and conftest.py for a clean test foundation |
| 3 | TestClient | Use Starlette’s TestClient for synchronous endpoint testing — no server required |
| 4 | Async Testing | Use httpx.AsyncClient with pytest-asyncio for async endpoint and dependency testing |
| 5 | Route Testing | Test all CRUD operations, edge cases, error responses, and query parameter validation |
| 6 | Validation Testing | Test Pydantic schemas independently and verify API returns proper 422 error responses |
| 7 | Database Testing | Use in-memory SQLite with transaction rollback for fast, isolated database tests |
| 8 | Dependency Override | Use app.dependency_overrides to swap real dependencies with test doubles |
| 9 | Mocking | Mock external services with unittest.mock or pytest-mock to avoid real network calls |
| 10 | Auth Testing | Test registration, login, JWT validation, token expiration, and role-based access |
| 11 | File Uploads | Test multipart form data uploads with the files parameter in TestClient |
| 12 | WebSockets | Use client.websocket_connect() to test WebSocket communication |
| 13 | Performance | Use Locust for load testing and pytest-benchmark for endpoint benchmarking |
| 14 | Coverage | Aim for 80%+ coverage; use pytest-cov with --cov-fail-under to enforce minimums |
| 15 | CI/CD | Automate tests with GitHub Actions; use Docker for reproducible test environments |
| 16 | Complete Suite | Combine all patterns into a cohesive test suite that covers auth, CRUD, mocking, and validation |