FastAPI has rapidly become one of the most popular Python web frameworks, and demand for FastAPI developers continues to grow. Whether you are preparing for your first Python backend role or aiming for a senior architect position, this guide covers the questions you are most likely to encounter in a real interview.
The questions are organized into three tiers:
| Level | Focus Areas | What Interviewers Look For |
|---|---|---|
| Junior | Core concepts, basic routing, Pydantic basics, running the app | Solid fundamentals, ability to build simple endpoints |
| Mid-Level | Dependency injection, auth, async patterns, testing, CRUD | Production-quality code, understanding of the request lifecycle |
| Senior | Architecture, ASGI, WebSockets, deployment, CI/CD, security | System design thinking, performance tuning, operational maturity |
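FastAPI is a modern, high-performance Python web framework for building APIs. It is built on top of Starlette (for the web layer) and Pydantic (for data validation). Key reasons to choose FastAPI include:

- High performance, thanks to its async-native ASGI foundation
- Automatic request validation and serialization driven by Python type hints
- Auto-generated interactive documentation (Swagger UI and ReDoc) from an OpenAPI schema
- A built-in dependency injection system (Depends())
- First-class support for async/await

A minimal application looks like this: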
```python
from fastapi import FastAPI

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Hello, World!"}
```
Best practice: Choose FastAPI when you need an async-capable REST or GraphQL API with automatic request validation. If you only need to serve HTML templates with minimal API work, a lighter framework may suffice.
Path parameters are dynamic segments in the URL path. You declare them inside curly braces in the route decorator and as function arguments with type annotations. FastAPI automatically validates and converts the value to the declared type.
```python
from fastapi import FastAPI

app = FastAPI()

@app.get("/users/{user_id}")
def get_user(user_id: int):
    return {"user_id": user_id}

# GET /users/42  -> {"user_id": 42}
# GET /users/abc -> 422 Unprocessable Entity (validation error)
```
You can also use Path() for additional constraints:
```python
from fastapi import FastAPI, Path

app = FastAPI()

@app.get("/items/{item_id}")
def get_item(item_id: int = Path(..., title="Item ID", ge=1, le=10000)):
    return {"item_id": item_id}
```
Common pitfall: If you have both /users/me and /users/{user_id}, put the static route first. FastAPI matches routes in declaration order, so /users/{user_id} would capture "me" as a path parameter if declared first.
Query parameters are key-value pairs appended to the URL after a ?. Any function parameter that is not part of the path is automatically treated as a query parameter.
```python
from typing import Optional

from fastapi import FastAPI

app = FastAPI()

@app.get("/items")
def list_items(skip: int = 0, limit: int = 10, q: Optional[str] = None):
    result = {"skip": skip, "limit": limit}
    if q:
        result["query"] = q
    return result

# GET /items?skip=5&limit=20&q=phone
# -> {"skip": 5, "limit": 20, "query": "phone"}
```
Use Query() for extra validation:
```python
from fastapi import FastAPI, Query

app = FastAPI()

@app.get("/search")
def search(q: str = Query(..., min_length=2, max_length=100)):
    return {"query": q}
```
Best practice: Always set sensible defaults for pagination parameters (skip, limit) and cap the maximum limit to prevent clients from requesting excessively large result sets.
Type hints are central to FastAPI. They serve multiple purposes simultaneously:
| Purpose | How Type Hints Help |
|---|---|
| Request validation | FastAPI validates incoming data against declared types automatically |
| Serialization | Response data is serialized based on the return type or response_model |
| Documentation | Swagger UI reflects parameter types, descriptions, and constraints |
| Editor support | IDEs provide auto-complete, type checking, and refactoring tools |
```python
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class Item(BaseModel):
    name: str
    price: float
    tags: List[str] = []
    description: Optional[str] = None

@app.post("/items", response_model=Item)
def create_item(item: Item):
    # item is already validated and typed
    return item
```
Key insight: Unlike Flask, where you manually call request.get_json() and validate fields yourself, FastAPI uses type hints to handle all of this declaratively. This eliminates an entire class of bugs.
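Pydantic is a data validation and settings management library that uses Python type annotations. FastAPI relies on Pydantic for:

- Parsing and validating request bodies
- Serializing and filtering responses via response_model
- Generating the JSON Schema that powers the OpenAPI docs
- Settings management through BaseSettings (shipped in the separate pydantic-settings package in Pydantic v2)

```python
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field, field_validator

class UserCreate(BaseModel):
    username: str = Field(..., min_length=3, max_length=50)
    email: str
    age: int = Field(..., ge=13, le=120)
    bio: Optional[str] = None

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str) -> str:
        # Deliberately simple check; Pydantic's EmailStr (which needs the
        # email-validator package) is stricter
        if "@" not in v:
            raise ValueError("Invalid email address")
        return v.lower()

class UserResponse(BaseModel):
    id: int
    username: str
    email: str
    created_at: datetime

    # Allow building this model from ORM objects (attribute access)
    model_config = {"from_attributes": True}
```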
The model_config = {"from_attributes": True} setting (formerly class Config: orm_mode = True in Pydantic v1) allows Pydantic to read data from ORM objects like SQLAlchemy models.
Common pitfall: Forgetting to enable from_attributes when returning ORM objects will cause serialization errors.
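The sketch below shows what from_attributes changes. It validates a plain Python object standing in for an ORM row. The UserRow class is hypothetical, but a SQLAlchemy model behaves the same way for attribute access. Without the setting, Pydantic v2 accepts only dicts or model instances and raises a ValidationError.

```python
from datetime import datetime

from pydantic import BaseModel, ValidationError

class UserRow:
    """Hypothetical stand-in for an ORM object: data lives in attributes."""
    def __init__(self):
        self.id = 1
        self.username = "alice"
        self.created_at = datetime(2024, 1, 1)

class StrictUser(BaseModel):
    id: int
    username: str
    created_at: datetime

class OrmUser(StrictUser):
    model_config = {"from_attributes": True}

row = UserRow()
orm_user = OrmUser.model_validate(row)  # reads attributes from the object

try:
    StrictUser.model_validate(row)       # no from_attributes: rejected
    strict_failed = False
except ValidationError:
    strict_failed = True
```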
FastAPI applications are ASGI apps, so you need an ASGI server. The most common choice is Uvicorn.
```bash
# Install
pip install fastapi uvicorn

# Run in development with auto-reload
uvicorn main:app --reload --host 0.0.0.0 --port 8000

# Run in production with multiple workers
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4
```
You can also run it programmatically:
```python
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
```
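For production deployments, running Uvicorn workers under Gunicorn is a long-standing pattern:

```bash
gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
```

Best practice: Use --reload only in development. In production, either let Uvicorn manage its own worker processes with --workers, or use Gunicorn as the process manager with Uvicorn workers. Recent Uvicorn releases deprecate the bundled uvicorn.workers module in favor of the separate uvicorn-worker package.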
These decorators bind a function to an HTTP method. Each method has a specific semantic meaning:
| Decorator | HTTP Method | Purpose | Request Body | Idempotent |
|---|---|---|---|---|
| @app.get | GET | Retrieve data | No | Yes |
| @app.post | POST | Create a resource | Yes | No |
| @app.put | PUT | Replace a resource | Yes | Yes |
| @app.patch | PATCH | Partially update | Yes | No |
| @app.delete | DELETE | Delete a resource | Optional | Yes |
```python
from itertools import count

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class Item(BaseModel):
    name: str
    price: float

items_db = {}
# A monotonically increasing counter never reuses IDs,
# unlike len(items_db) + 1 after a delete
_next_id = count(1)

@app.get("/items/{item_id}")
def get_item(item_id: int):
    if item_id not in items_db:
        # Return a real 404 instead of a 200 with an error payload
        raise HTTPException(status_code=404, detail="Not found")
    return items_db[item_id]

@app.post("/items", status_code=201)
def create_item(item: Item):
    item_id = next(_next_id)
    items_db[item_id] = item.model_dump()
    return {"id": item_id, **item.model_dump()}

@app.put("/items/{item_id}")
def replace_item(item_id: int, item: Item):
    items_db[item_id] = item.model_dump()
    return {"id": item_id, **item.model_dump()}

@app.delete("/items/{item_id}", status_code=204)
def delete_item(item_id: int):
    items_db.pop(item_id, None)
```
FastAPI automatically generates interactive API documentation from your route definitions, type hints, and Pydantic models. Two UIs are available out of the box:
| URL | UI | Description |
|---|---|---|
| /docs | Swagger UI | Interactive documentation with a “Try it out” feature |
| /redoc | ReDoc | Clean, read-only documentation |
| /openapi.json | Raw JSON | The OpenAPI schema as JSON |
```python
from fastapi import FastAPI

# Customize docs metadata
app = FastAPI(
    title="My API",
    description="A comprehensive API for managing items",
    version="1.0.0",
    docs_url="/docs",             # default
    redoc_url="/redoc",           # default
    openapi_url="/openapi.json",  # default
)

# Disable docs in production (openapi_url=None also hides the raw schema)
app_prod = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
```
Best practice: Disable interactive docs in production for security. You can conditionally enable them based on an environment variable.
Uvicorn is a lightning-fast ASGI server implementation. It serves as the bridge between the network and your FastAPI application.
```bash
# Standard install (adds optional extras such as uvloop and httptools)
pip install "uvicorn[standard]"

# Run with SSL
uvicorn main:app --ssl-keyfile=key.pem --ssl-certfile=cert.pem

# Run with a specific log level
uvicorn main:app --log-level warning
```
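Key insight: A single Uvicorn process runs one event loop on one CPU core. To use multiple cores in production, start several worker processes. You can do this with Uvicorn's own --workers option or by running Uvicorn workers under Gunicorn (gunicorn -k uvicorn.workers.UvicornWorker).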
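FastAPI provides several mechanisms for error handling:

1. HTTPException for known, expected errors raised inside a route
2. Custom exception classes paired with handlers registered via @app.exception_handler
3. Overrides of FastAPI's default handlers, such as the one for RequestValidationError

```python
from fastapi import FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

app = FastAPI()
items_db = {1: {"name": "Laptop"}}

# 1. HTTPException for known errors
@app.get("/items/{item_id}")
def get_item(item_id: int):
    if item_id not in items_db:
        raise HTTPException(
            status_code=404,
            detail="Item not found",
            headers={"X-Error": "Item lookup failed"},
        )
    return items_db[item_id]

# 2. Custom exception class
class ItemNotFoundError(Exception):
    def __init__(self, item_id: int):
        self.item_id = item_id

@app.exception_handler(ItemNotFoundError)
async def item_not_found_handler(request: Request, exc: ItemNotFoundError):
    return JSONResponse(
        status_code=404,
        content={"detail": f"Item {exc.item_id} does not exist"},
    )

@app.get("/v2/items/{item_id}")
def get_item_v2(item_id: int):
    if item_id not in items_db:
        raise ItemNotFoundError(item_id)  # handled by item_not_found_handler
    return items_db[item_id]

# 3. Override default validation error handler
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    # jsonable_encoder guards against non-JSON values inside error details
    return JSONResponse(
        status_code=422,
        content=jsonable_encoder({"detail": "Validation failed", "errors": exc.errors()}),
    )
```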
Best practice: Use HTTPException for simple cases. Create custom exception classes and handlers for domain-specific errors to keep your route functions clean.
| Feature | FastAPI | Flask |
|---|---|---|
| Type | ASGI (async-native) | WSGI (sync by default) |
| Validation | Built-in via Pydantic | Manual or via extensions (Marshmallow) |
| Documentation | Auto-generated Swagger & ReDoc | Manual or via Flask-RESTX |
| Performance | Among the fastest Python frameworks (its docs claim parity with Node.js and Go) | Moderate |
| Dependency injection | Built-in Depends() system | Not built-in |
| Async support | Native async/await | Limited (added in Flask 2.0) |
| Ecosystem maturity | Growing rapidly | Very mature, huge plugin ecosystem |
| Learning curve | Moderate (need to understand type hints) | Low (simple and minimal) |
When to choose FastAPI: New API-first projects that need high performance, automatic validation, and auto-generated docs.
When to choose Flask: Projects that need extensive HTML template rendering, or when your team has deep Flask experience and a large existing Flask codebase.
```python
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import JSONResponse

app = FastAPI()

# Method 1: Set default status code in decorator
@app.post("/items", status_code=201)
def create_item(item: dict):
    return {"id": 1, **item}

# Method 2: Use Response parameter for dynamic codes
@app.get("/items/{item_id}")
def get_item(item_id: int, response: Response):
    if item_id == 0:
        response.status_code = 204
        return None
    return {"item_id": item_id}

# Method 3: Return a Response object directly
@app.get("/health")
def health_check():
    healthy = True
    if healthy:
        return JSONResponse(content={"status": "ok"}, status_code=200)
    return JSONResponse(content={"status": "degraded"}, status_code=503)

# Method 4: HTTPException for error codes
@app.get("/secure")
def secure_endpoint():
    raise HTTPException(status_code=403, detail="Forbidden")
```
Best practice: Use the status_code parameter in the decorator for the “happy path” response. Use HTTPException for error paths. This keeps your OpenAPI docs accurate.
FastAPI has a powerful built-in dependency injection system using Depends(). Dependencies are functions (or classes) that are called before your route handler, and their return values are injected as parameters.
```python
from typing import Optional

from fastapi import Depends, FastAPI, Query

app = FastAPI()

# Simple function dependency
def common_parameters(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    q: Optional[str] = None,
):
    return {"skip": skip, "limit": limit, "q": q}

@app.get("/items")
def list_items(params: dict = Depends(common_parameters)):
    return {"params": params}

@app.get("/users")
def list_users(params: dict = Depends(common_parameters)):
    return {"params": params}

# Class-based dependency
class Pagination:
    def __init__(self, skip: int = 0, limit: int = 10):
        self.skip = skip
        self.limit = limit

@app.get("/products")
def list_products(pagination: Pagination = Depends()):
    return {"skip": pagination.skip, "limit": pagination.limit}

# Nested dependencies
# SessionLocal (a SQLAlchemy session factory) and User (an ORM model)
# are assumed to be defined elsewhere in the application.
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

def get_current_user(db=Depends(get_db)):
    # Depends on get_db; placeholder lookup for illustration only
    user = db.query(User).first()
    return user

@app.get("/profile")
def get_profile(user=Depends(get_current_user)):
    return {"username": user.username}
```
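Key insight: Dependencies that use yield act like context managers. The code after yield runs once the route handler has finished, even if the handler raised an exception. This makes them well suited to cleanup tasks like closing database connections. Whether that cleanup runs before or after the response bytes are sent has changed between FastAPI releases, so do not rely on the exact timing.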
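When a request arrives at a FastAPI endpoint, it goes through a well-defined validation pipeline. Conceptually:

1. Path parameters are extracted and validated.
2. Query parameters are extracted and validated.
3. Headers are extracted and validated.
4. Cookies are extracted and validated.
5. The request body is parsed (for example, as JSON).
6. The body is validated against its Pydantic model.
7. Dependencies declared with Depends() are resolved.
8. The route handler executes with the validated data.
9. The return value is validated and serialized against response_model if specified.
10. Cleanup code runs in yield dependencies.

FastAPI's actual implementation resolves dependencies before validating the endpoint's own parameters. It also collects validation errors instead of stopping at the first one. Treat the numbering as a mental model rather than a strict execution order.

```python
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Path, Query
from pydantic import BaseModel

app = FastAPI()

class ItemCreate(BaseModel):
    name: str
    price: float

class ItemResponse(BaseModel):
    id: int
    name: str
    price: float

def verify_token(x_token: str = Header(...)):      # Step 3: header
    if x_token != "secret-token":
        raise HTTPException(status_code=403, detail="Invalid token")
    return x_token

@app.post(
    "/categories/{category_id}/items",
    response_model=ItemResponse,                   # Step 9: response model
    status_code=201,
)
def create_item(
    item: ItemCreate,                              # Steps 5-6: body + validation
    category_id: int = Path(..., ge=1),            # Step 1: path
    q: Optional[str] = Query(None),                # Step 2: query
    token: str = Depends(verify_token),            # Step 7: dependency
):
    # Step 8: handler executes with all validated data
    return ItemResponse(id=1, name=item.name, price=item.price)
```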
If validation fails at any step, FastAPI returns a 422 Unprocessable Entity response with detailed error information.
```python
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm  # form needs python-multipart
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

# Configuration: load the secret from the environment, never hard-code it
SECRET_KEY = os.environ.get("SECRET_KEY", "change-me-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

app = FastAPI()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Models
class Token(BaseModel):
    access_token: str
    token_type: str

class User(BaseModel):
    username: str
    email: str
    disabled: bool = False

class UserInDB(User):
    hashed_password: str

# Demo-only user store; replace with a real database lookup
fake_users_db = {
    "alice": {
        "username": "alice",
        "email": "alice@example.com",
        "hashed_password": pwd_context.hash("secret"),
    }
}

# Helper functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)

def get_user_from_db(username: str) -> Optional[UserInDB]:
    data = fake_users_db.get(username)
    return UserInDB(**data) if data else None

def authenticate_user(username: str, password: str) -> Optional[UserInDB]:
    user = get_user_from_db(username)
    if user is None or not verify_password(password, user.hashed_password):
        return None
    return user

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: Optional[str] = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user_from_db(username)
    if user is None:
        raise credentials_exception
    return user

# Endpoints
@app.post("/token", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}

# response_model=User filters out hashed_password from UserInDB
@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_user)):
    return current_user
```
OAuth2PasswordBearer is a FastAPI security utility class that implements the OAuth2 Password flow. It does two things:

1. It registers the security scheme in the OpenAPI schema, so Swagger UI shows an "Authorize" button that posts credentials to tokenUrl.
2. It extracts the token from the Authorization: Bearer <token> header of incoming requests, responding with 401 if the header is missing (unless auto_error=False).

```python
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer

app = FastAPI()

# tokenUrl is the endpoint where clients POST credentials to get a token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Using it as a dependency simply extracts the token string
@app.get("/protected")
async def protected_route(token: str = Depends(oauth2_scheme)):
    # token is the raw Bearer token string
    # You still need to decode/validate it yourself
    return {"token": token}
```
OAuth2PasswordBearer does not validate the token. It only extracts it. You must combine it with your own validation logic (e.g., JWT decoding) in a dependency.
```python
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker

DATABASE_URL = "postgresql://user:password@localhost:5432/mydb"

engine = create_engine(DATABASE_URL, pool_size=10, max_overflow=20)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# ORM model
class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, nullable=False)
    email = Column(String, nullable=False)

# Pydantic schemas
class UserCreate(BaseModel):
    username: str
    email: str

class UserResponse(BaseModel):
    id: int
    username: str
    email: str
    model_config = {"from_attributes": True}

# Dependency that provides a database session
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

app = FastAPI()

@app.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: int, db: Session = Depends(get_db)):
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user

@app.post("/users", response_model=UserResponse, status_code=201)
def create_user(user_data: UserCreate, db: Session = Depends(get_db)):
    db_user = User(**user_data.model_dump())
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user
```
Key insight: The yield pattern ensures the session is always closed, even if an exception occurs during request processing. This prevents connection leaks.
FastAPI natively supports Python’s async/await syntax because it runs on ASGI (Asynchronous Server Gateway Interface).
```python
import time

import httpx
from fastapi import Depends, FastAPI

app = FastAPI()

# Async endpoint - runs on the event loop
@app.get("/async-data")
async def get_async_data():
    async with httpx.AsyncClient() as client:
        response = await client.get("https://api.example.com/data")
        return response.json()

# Sync endpoint - runs in a thread pool
@app.get("/sync-data")
def get_sync_data():
    # FastAPI automatically runs this in a thread pool
    # so it does not block the event loop
    time.sleep(1)  # Simulates blocking I/O
    return {"data": "result"}

# Async dependency
async def get_async_client():
    async with httpx.AsyncClient() as client:
        yield client

@app.get("/external")
async def call_external(client: httpx.AsyncClient = Depends(get_async_client)):
    response = await client.get("https://api.example.com/resource")
    return response.json()
```
Important rule: If your function uses await, declare it with async def. If it performs blocking I/O (database calls via synchronous drivers, file I/O), use regular def and let FastAPI handle the threading.
| Aspect | async def endpoint | def endpoint (sync) |
|---|---|---|
| Execution | Runs directly on the async event loop | Runs in a separate thread from a thread pool |
| Blocking I/O | Must use async libraries (httpx, aiofiles, asyncpg) | Can safely use blocking libraries (requests, open()) |
| Concurrency | Thousands of concurrent tasks via event loop | Limited by thread pool size (default: 40 threads) |
| CPU-bound work | Blocks the event loop – avoid | Blocks one thread – slightly better |
```python
import httpx
import requests  # blocking HTTP library
from fastapi import FastAPI

app = FastAPI()

# WRONG: blocking call in async function blocks the event loop
@app.get("/bad")
async def bad_endpoint():
    response = requests.get("https://api.example.com")  # blocks event loop
    return response.json()

# CORRECT: use async library in async function
@app.get("/good-async")
async def good_async_endpoint():
    async with httpx.AsyncClient() as client:
        response = await client.get("https://api.example.com")
        return response.json()

# CORRECT: use sync function for blocking calls
@app.get("/good-sync")
def good_sync_endpoint():
    response = requests.get("https://api.example.com")
    return response.json()
```
Common pitfall: Declaring an endpoint with async def and then using blocking libraries like requests or synchronous database drivers. This blocks the entire event loop and kills performance.

A full set of CRUD endpoints for a product resource might look like this:

```python
from typing import List, Optional

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

# Assumes a SQLAlchemy Product model and a get_db dependency
# (yielding a Session) are defined elsewhere in the application.

app = FastAPI()

# --- Pydantic schemas ---
class ProductCreate(BaseModel):
    name: str
    description: Optional[str] = None
    price: float
    category: str

class ProductUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    price: Optional[float] = None
    category: Optional[str] = None

class ProductResponse(BaseModel):
    id: int
    name: str
    description: Optional[str]
    price: float
    category: str

    model_config = {"from_attributes": True}

# --- CRUD functions (service layer) ---
def create_product(db: Session, product: ProductCreate):
    db_product = Product(**product.model_dump())
    db.add(db_product)
    db.commit()
    db.refresh(db_product)
    return db_product

def get_products(db: Session, skip: int = 0, limit: int = 100):
    return db.query(Product).offset(skip).limit(limit).all()

def get_product(db: Session, product_id: int):
    return db.query(Product).filter(Product.id == product_id).first()

def update_product(db: Session, product_id: int, updates: ProductUpdate):
    db_product = get_product(db, product_id)
    if not db_product:
        return None
    # Only apply fields the client actually sent (PATCH semantics)
    update_data = updates.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(db_product, field, value)
    db.commit()
    db.refresh(db_product)
    return db_product

def delete_product(db: Session, product_id: int):
    db_product = get_product(db, product_id)
    if not db_product:
        return False
    db.delete(db_product)
    db.commit()
    return True

# --- Route handlers ---
@app.post("/products", response_model=ProductResponse, status_code=201)
def create(product: ProductCreate, db: Session = Depends(get_db)):
    return create_product(db, product)

@app.get("/products", response_model=List[ProductResponse])
def read_all(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    return get_products(db, skip, limit)

@app.get("/products/{product_id}", response_model=ProductResponse)
def read_one(product_id: int, db: Session = Depends(get_db)):
    product = get_product(db, product_id)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    return product

@app.patch("/products/{product_id}", response_model=ProductResponse)
def update(product_id: int, updates: ProductUpdate, db: Session = Depends(get_db)):
    product = update_product(db, product_id, updates)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    return product

@app.delete("/products/{product_id}", status_code=204)
def delete(product_id: int, db: Session = Depends(get_db)):
    if not delete_product(db, product_id):
        raise HTTPException(status_code=404, detail="Product not found")
```
Best practice: Separate CRUD logic into a service layer (separate module) rather than putting database queries directly in route handlers. This makes the code testable and reusable.
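The update_product function relies on model_dump(exclude_unset=True) to implement PATCH semantics. The sketch below uses a trimmed copy of ProductUpdate to show why. Fields the client never sent are omitted, while a field explicitly set to null is kept.

```python
from typing import Optional

from pydantic import BaseModel

class ProductUpdate(BaseModel):
    name: Optional[str] = None
    price: Optional[float] = None
    category: Optional[str] = None

# Client sent only "price"
partial = ProductUpdate.model_validate({"price": 9.99})
# Client explicitly cleared "category"
cleared = ProductUpdate.model_validate({"category": None})

all_fields = partial.model_dump()
only_sent = partial.model_dump(exclude_unset=True)
explicit_null = cleared.model_dump(exclude_unset=True)
```

Without the flag, every PATCH would overwrite the untouched columns with None.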
```python
# app/routers/users.py
from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

# get_db, User, UserCreate and UserResponse are assumed to be
# importable from the app's database, models and schemas modules.

router = APIRouter(
    prefix="/users",
    tags=["users"],
    responses={404: {"description": "Not found"}},
)

@router.get("/", response_model=List[UserResponse])
def list_users(db: Session = Depends(get_db)):
    return db.query(User).all()

@router.get("/{user_id}", response_model=UserResponse)
def get_user(user_id: int, db: Session = Depends(get_db)):
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return user

@router.post("/", response_model=UserResponse, status_code=201)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    db_user = User(**user.model_dump())
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user
```

```python
# app/routers/products.py
from fastapi import APIRouter

router = APIRouter(prefix="/products", tags=["products"])

@router.get("/")
def list_products():
    return []
```

```python
# app/main.py
from fastapi import FastAPI

from app.routers import products, users

app = FastAPI(title="My API")
app.include_router(users.router)
app.include_router(products.router)

# You can also add a prefix when including
# app.include_router(users.router, prefix="/api/v1")
```
A well-organized project structure looks like this:
```text
app/
    __init__.py
    main.py          # FastAPI app instance and router includes
    config.py        # Settings and configuration
    database.py      # Database engine and session
    models/          # SQLAlchemy models
        __init__.py
        user.py
        product.py
    schemas/         # Pydantic schemas
        __init__.py
        user.py
        product.py
    routers/         # Route handlers
        __init__.py
        users.py
        products.py
    services/        # Business logic
        __init__.py
        user_service.py
        product_service.py
    dependencies/    # Shared dependencies
        __init__.py
        auth.py
        database.py
```
Background tasks let you run code after the response has been sent to the client. They are useful for operations that the client does not need to wait for.
```python
import time

from fastapi import BackgroundTasks, FastAPI
from pydantic import BaseModel

app = FastAPI()

def send_email(email: str, subject: str, body: str):
    # Simulate sending email (runs after response is sent)
    time.sleep(3)
    print(f"Email sent to {email}: {subject}")

def write_audit_log(user_id: int, action: str):
    # Write to audit log after response
    with open("audit.log", "a") as f:
        f.write(f"{user_id}: {action}\n")

class UserCreate(BaseModel):
    username: str
    email: str

@app.post("/users", status_code=201)
def create_user(user: UserCreate, background_tasks: BackgroundTasks):
    # Create user in database (synchronous, client waits)
    new_user = {"id": 1, "username": user.username, "email": user.email}

    # These run AFTER the response is sent
    background_tasks.add_task(send_email, user.email, "Welcome!", "Thanks for joining")
    background_tasks.add_task(write_audit_log, 1, "user_created")
    return new_user
```
When to use background tasks vs. a task queue (Celery/Redis):
| Criteria | BackgroundTasks | Celery / Task Queue |
|---|---|---|
| Duration | Short (seconds) | Long (minutes/hours) |
| Reliability | Lost if server crashes | Persisted in broker, retryable |
| Infrastructure | None extra | Needs Redis/RabbitMQ |
| Use case | Emails, logging, cache invalidation | Video processing, reports, ETL |
```python
import httpx
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Hello"}

@app.get("/items/{item_id}")
def read_item(item_id: int):
    return {"item_id": item_id}

# --- Basic tests ---
client = TestClient(app)

def test_read_root():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Hello"}

def test_read_item():
    response = client.get("/items/42")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42}

def test_invalid_item_id():
    response = client.get("/items/not-a-number")
    assert response.status_code == 422

# --- Testing with dependency overrides ---
def get_db():
    # Production dependency; would normally yield a real database session
    raise RuntimeError("Real database is not available in tests")

@app.get("/users")
def get_users(db=Depends(get_db)):
    return db["users"]

def override_get_db():
    # Test double standing in for the real session
    return {"users": [{"id": 1, "username": "alice"}]}

def test_get_users_with_override():
    app.dependency_overrides[get_db] = override_get_db
    try:
        response = client.get("/users")
        assert response.json() == [{"id": 1, "username": "alice"}]
    finally:
        app.dependency_overrides.clear()  # don't leak into other tests

# --- Async testing with httpx ---
@pytest.mark.anyio
async def test_async_root():
    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://test",
    ) as ac:
        response = await ac.get("/")
    assert response.status_code == 200

# --- Testing with pytest fixtures ---
@pytest.fixture
def test_client():
    with TestClient(app) as c:  # also runs lifespan startup/shutdown
        yield c

def test_with_fixture(test_client):
    response = test_client.get("/")
    assert response.status_code == 200
```
Best practice: Use dependency_overrides to replace real databases, external APIs, and authentication with test doubles. This makes your tests fast and deterministic.
```python
import os
from typing import List

from fastapi import FastAPI, File, Form, HTTPException, UploadFile

app = FastAPI()

UPLOAD_DIR = "uploads"
ALLOWED_TYPES = {"image/jpeg", "image/png", "image/gif", "application/pdf"}
MAX_SIZE = 10 * 1024 * 1024  # 10 MB

os.makedirs(UPLOAD_DIR, exist_ok=True)

def safe_path(filename: str) -> str:
    # Strip directory components so "../../etc/passwd" can't escape UPLOAD_DIR
    return os.path.join(UPLOAD_DIR, os.path.basename(filename or "upload"))

# Single file upload
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    # content_type is client-supplied, so treat it as a hint, not proof
    if file.content_type not in ALLOWED_TYPES:
        raise HTTPException(status_code=400, detail="File type not allowed")

    # Check file size
    contents = await file.read()
    if len(contents) > MAX_SIZE:
        raise HTTPException(status_code=400, detail="File too large")

    with open(safe_path(file.filename), "wb") as f:
        f.write(contents)
    return {"filename": file.filename, "size": len(contents)}

# Multiple file upload
@app.post("/upload-multiple")
async def upload_multiple(files: List[UploadFile] = File(...)):
    results = []
    for file in files:
        contents = await file.read()
        with open(safe_path(file.filename), "wb") as f:
            f.write(contents)
        results.append({"filename": file.filename, "size": len(contents)})
    return results

# File upload with additional form data
@app.post("/upload-with-metadata")
async def upload_with_metadata(
    file: UploadFile = File(...),
    description: str = Form(...),
    category: str = Form("general"),
):
    contents = await file.read()
    return {
        "filename": file.filename,
        "description": description,
        "category": category,
        "size": len(contents),
    }
```
Common pitfall: Calling await file.read() loads the entire file into memory. For large files, use chunked reading:
```python
@app.post("/upload-large")
async def upload_large_file(file: UploadFile = File(...)):
    # Same UPLOAD_DIR as above; basename() guards against path traversal
    file_path = os.path.join(UPLOAD_DIR, os.path.basename(file.filename or "upload"))
    with open(file_path, "wb") as f:
        while chunk := await file.read(1024 * 1024):  # 1 MB chunks
            f.write(chunk)
    return {"filename": file.filename}
```
A scalable FastAPI architecture addresses code organization, deployment topology, and operational concerns. Here is one widely used pattern:

```text
# Project structure for a scalable FastAPI application
project/
    app/
        __init__.py
        main.py              # App factory, middleware, router includes
        config.py            # Pydantic BaseSettings for env-based config
        database.py          # Engine, session factory, base model
        middleware/
            __init__.py
            logging.py       # Request/response logging
            cors.py          # CORS configuration
            rate_limit.py    # Rate limiting middleware
        api/
            __init__.py
            v1/
                __init__.py
                router.py    # Aggregates all v1 routers
                endpoints/
                    users.py
                    products.py
                    orders.py
            v2/
                __init__.py
                router.py
        models/              # SQLAlchemy ORM models
        schemas/             # Pydantic request/response schemas
        services/            # Business logic layer
        repositories/        # Data access layer
        dependencies/        # Shared Depends() functions
        events/              # Startup/shutdown event handlers
        utils/               # Shared utilities
    alembic/                 # Database migrations
    tests/
        conftest.py
        test_users.py
        test_products.py
    docker-compose.yml
    Dockerfile
    pyproject.toml
```
# app/config.py - Environment-based configuration
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
app_name: str = "My API"
debug: bool = False
database_url: str
redis_url: str = "redis://localhost:6379"
secret_key: str
allowed_origins: list[str] = ["http://localhost:3000"]
model_config = {"env_file": ".env"}
@lru_cache()
def get_settings():
return Settings()
# app/main.py - App factory pattern
from fastapi import FastAPI
from contextlib import asynccontextmanager
from app.config import get_settings
from app.api.v1.router import router as v1_router
from app.database import engine, Base
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
# Shutdown
await engine.dispose()
def create_app() -> FastAPI:
settings = get_settings()
app = FastAPI(
title=settings.app_name,
lifespan=lifespan,
docs_url="/docs" if settings.debug else None,
)
app.include_router(v1_router, prefix="/api/v1")
return app
app = create_app()
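The services/ and repositories/ directories in the tree separate business rules from data access. Below is a minimal sketch of that split. It assumes an in-memory repository in place of a SQLAlchemy-backed one. User, UserRepository, InMemoryUserRepository, and UserService are illustrative names, not part of any library.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Protocol


@dataclass
class User:
    id: int
    email: str


class UserRepository(Protocol):
    """Data-access interface: what a module in repositories/ would implement."""

    def get_by_email(self, email: str) -> Optional[User]: ...

    def add(self, email: str) -> User: ...


class InMemoryUserRepository:
    """Hypothetical stand-in for a SQLAlchemy-backed repository."""

    def __init__(self) -> None:
        self._users: Dict[int, User] = {}
        self._next_id = 1

    def get_by_email(self, email: str) -> Optional[User]:
        return next((u for u in self._users.values() if u.email == email), None)

    def add(self, email: str) -> User:
        user = User(id=self._next_id, email=email)
        self._users[user.id] = user
        self._next_id += 1
        return user


class UserService:
    """Business rules (what services/ holds); it knows nothing about HTTP."""

    def __init__(self, repo: UserRepository) -> None:
        self.repo = repo

    def register(self, email: str) -> User:
        email = email.strip().lower()
        if self.repo.get_by_email(email) is not None:
            raise ValueError("email already registered")
        return self.repo.add(email)
```

In a route, a Depends() function would construct UserService with a session-backed repository. The handler would then translate ValueError into HTTPException(status_code=409). Because the service never imports FastAPI, you can unit-test it without an HTTP client.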
Key architectural principles:
- Version your API with URL prefixes (/api/v1, /api/v2) to evolve it without breaking clients.
- Use Pydantic BaseSettings for type-safe, environment-driven configuration.
ASGI (Asynchronous Server Gateway Interface) is the spiritual successor to WSGI. It defines a standard interface between async-capable Python web servers and applications.
| Feature | WSGI | ASGI |
|---|---|---|
| Concurrency model | Synchronous, one request per thread | Asynchronous, event-loop based |
| Protocol support | HTTP only | HTTP, WebSocket, HTTP/2 |
| Connection lifecycle | Request-response only | Long-lived connections supported |
| Frameworks | Flask, Django | FastAPI, Starlette, Django (3.0+) |
| Servers | Gunicorn, uWSGI | Uvicorn, Daphne, Hypercorn |
At its core, an ASGI application is a callable with this signature:
# Raw ASGI application example
async def app(scope, receive, send):
# scope - dict with connection info (type, path, headers, etc.)
# receive - async callable to receive messages from client
# send - async callable to send messages to client
if scope["type"] == "http":
# Read request body
body = b""
while True:
message = await receive()
body += message.get("body", b"")
if not message.get("more_body", False):
break
# Send response
await send({
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"application/json")],
})
await send({
"type": "http.response.body",
"body": b'{"message": "Hello from raw ASGI"}',
})
FastAPI wraps this low-level protocol behind its elegant decorator-based API. When you write @app.get("/items"), FastAPI (via Starlette) handles all the ASGI message passing for you.
The request flow: Client → Uvicorn (ASGI server) → Starlette (ASGI toolkit) → FastAPI (routing + validation) → Your handler
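The sketch below shows what a server such as Uvicorn does with that callable by driving a raw ASGI app entirely in-process. It builds a minimal scope, feeds the body through receive() in two chunks, and records every message passed to send(). The scope is deliberately trimmed; real servers add keys such as asgi, query_string, and client. byte_counter_app and call_asgi are hypothetical names.

```python
import asyncio
from typing import Any, Dict, List


async def byte_counter_app(scope, receive, send):
    """Raw ASGI app: read the whole body, then report its length."""
    if scope["type"] != "http":
        return
    body = b""
    while True:
        message = await receive()
        body += message.get("body", b"")
        if not message.get("more_body", False):
            break
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"received %d bytes" % len(body)})


def call_asgi(app, body_chunks: List[bytes]) -> List[Dict[str, Any]]:
    """Play the server's role: build a scope, feed receive(), collect send()."""
    scope = {"type": "http", "method": "POST", "path": "/", "headers": []}
    # Every chunk except the last sets more_body=True, as a streaming client would.
    incoming = [
        {"type": "http.request", "body": chunk, "more_body": i < len(body_chunks) - 1}
        for i, chunk in enumerate(body_chunks)
    ]
    sent: List[Dict[str, Any]] = []

    async def receive() -> Dict[str, Any]:
        return incoming.pop(0)

    async def send(message: Dict[str, Any]) -> None:
        sent.append(message)

    asyncio.run(app(scope, receive, send))
    return sent
```

Pass [b""] for a request with an empty body, because the app always expects at least one http.request message.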
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List
app = FastAPI()
# Connection manager for multiple clients
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
# Echo back to sender
await manager.send_personal(f"You said: {data}", websocket)
# Broadcast to all
await manager.broadcast(f"Client #{client_id}: {data}")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client #{client_id} left the chat")
# WebSocket with JSON messages
@app.websocket("/ws/json")
async def json_websocket(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
action = data.get("action")
if action == "subscribe":
await websocket.send_json({"status": "subscribed", "channel": data["channel"]})
elif action == "message":
await websocket.send_json({"echo": data["content"]})
except WebSocketDisconnect:
pass
Best practice: For production WebSocket applications, use Redis Pub/Sub or a message broker to coordinate messages across multiple server instances, since in-memory connection managers only work within a single process.
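The ConnectionManager above has two failure modes worth knowing for interviews. First, broadcast() can raise on the first client whose connection has already closed, which skips everyone after it. Second, disconnect() raises ValueError if the socket was already removed. Below is a sketch of a more forgiving variant. It assumes only that each connection exposes an async send_text(). RecordingSocket is a test double standing in for fastapi.WebSocket.

```python
import asyncio
from typing import List, Protocol


class TextSocket(Protocol):
    async def send_text(self, data: str) -> None: ...


class ConnectionManager:
    def __init__(self) -> None:
        self.active_connections: List[TextSocket] = []

    def disconnect(self, websocket: TextSocket) -> None:
        # Tolerate double-disconnect (broadcast may already have pruned it).
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: str) -> None:
        # Snapshot the list so pruning below cannot skip elements, and send
        # concurrently; return_exceptions=True collects failures instead of raising.
        targets = list(self.active_connections)
        results = await asyncio.gather(
            *(ws.send_text(message) for ws in targets), return_exceptions=True
        )
        for ws, result in zip(targets, results):
            if isinstance(result, Exception):
                self.disconnect(ws)


class RecordingSocket:
    """Test double: records sent text, or fails like a closed connection."""

    def __init__(self, fail: bool = False) -> None:
        self.fail = fail
        self.sent: List[str] = []

    async def send_text(self, data: str) -> None:
        if self.fail:
            raise RuntimeError("connection closed")
        self.sent.append(data)
```

Sending concurrently also keeps one slow client from delaying the rest. Across multiple server processes you still need the Redis Pub/Sub approach mentioned in the best practice above.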
High-concurrency optimization in FastAPI involves several layers:
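# 1. Use async everywhere possible
from contextlib import asynccontextmanager
import asyncpg
from fastapi import FastAPI, HTTPException
# Use async database driver
pool = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create the asyncpg pool once per worker process at startup
    global pool
    pool = await asyncpg.create_pool(
        "postgresql://user:pass@localhost/db",
        min_size=10,
        max_size=50,
    )
    yield
    await pool.close()
app = FastAPI(lifespan=lifespan)
@app.get("/users/{user_id}")
async def get_user(user_id: int):
    async with pool.acquire() as conn:
        row = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
    if row is None:
        raise HTTPException(status_code=404, detail="User not found")
    return dict(row)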
# 2. Use connection pooling
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
engine = create_async_engine(
"postgresql+asyncpg://user:pass@localhost/db",
pool_size=20,
max_overflow=10,
pool_timeout=30,
pool_recycle=1800, # Recycle connections after 30 minutes
)
# 3. Add response caching
from fastapi_cache import FastAPICache
from fastapi_cache.backends.redis import RedisBackend
from fastapi_cache.decorator import cache
@app.get("/expensive-query")
@cache(expire=60)
async def expensive_query():
# Result is cached in Redis for 60 seconds
return await run_expensive_computation()
# 4. Use streaming responses for large payloads
from fastapi.responses import StreamingResponse
import json
@app.get("/large-dataset")
async def stream_data():
async def generate():
for chunk in fetch_large_dataset_in_chunks():
yield json.dumps(chunk) + "\n"
return StreamingResponse(generate(), media_type="application/x-ndjson")
# 5. Scale with multiple workers
gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker \
--bind 0.0.0.0:8000 \
--timeout 120 \
--keep-alive 5
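Async endpoints gain throughput by overlapping I/O waits. The pool's max_size caps how many of those waits can hit the database at once. The sketch below simulates that, with asyncio.Semaphore standing in for the connection pool and asyncio.sleep standing in for a 50 ms query. fake_query and run_bounded are hypothetical names. The sketch also records peak concurrency so you can see the cap hold.

```python
import asyncio
import time
from typing import List, Tuple


async def fake_query(i: int) -> int:
    """Stand-in for an awaitable DB call that spends ~50 ms waiting on I/O."""
    await asyncio.sleep(0.05)
    return i


async def run_bounded(n: int, limit: int) -> Tuple[List[int], int, float]:
    """Run n fake queries with at most `limit` in flight; return results, peak, seconds."""
    sem = asyncio.Semaphore(limit)  # plays the role of the pool's max_size
    active = 0
    peak = 0

    async def one(i: int) -> int:
        nonlocal active, peak
        async with sem:
            active += 1
            peak = max(peak, active)
            try:
                return await fake_query(i)
            finally:
                active -= 1

    start = time.perf_counter()
    results = await asyncio.gather(*(one(i) for i in range(n)))
    return list(results), peak, time.perf_counter() - start
```

With 20 requests and a limit of 10, the simulated queries finish in roughly two 50 ms rounds instead of the full second a sequential loop would take, and peak concurrency never exceeds 10. The same reasoning explains why a blocking call inside an async def handler hurts: it stalls the event loop, so nothing else can overlap.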
Optimization checklist:
- Profile with py-spy to find bottlenecks
# SQLAlchemy async connection pooling
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
# Async engines default to an asyncio-compatible queue pool; SQLAlchemy rejects
# the synchronous QueuePool here, so no poolclass is passed.
engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/db",
    pool_size=20,        # Steady-state connections per worker process
    max_overflow=10,     # Extra connections under load (max 30 per process)
    pool_timeout=30,     # Seconds to wait for a connection
    pool_recycle=1800,   # Recycle connections every 30 minutes
    pool_pre_ping=True,  # Test connections before using them
    echo=False,          # Set True to log all SQL
)
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_db() -> AsyncIterator[AsyncSession]:
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
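Each Gunicorn or Uvicorn worker process creates its own engine, so pool settings multiply by the worker count. The small calculation below makes the worst case explicit. The helper names are hypothetical, and the reserve of 10 connections for migrations, admin sessions, and monitoring is an assumption.

```python
def total_db_connections(workers: int, pool_size: int, max_overflow: int) -> int:
    """Worst-case connections one deployment can open (pools are per process)."""
    return workers * (pool_size + max_overflow)


def fits_database(
    workers: int, pool_size: int, max_overflow: int, max_connections: int, reserved: int = 10
) -> bool:
    # Leave headroom for migrations, admin sessions, monitoring, etc.
    return total_db_connections(workers, pool_size, max_overflow) <= max_connections - reserved
```

With 4 workers and the settings above (pool_size=20, max_overflow=10), the app alone can open 120 connections. That already exceeds PostgreSQL's default max_connections of 100.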
Sizing guidelines:
| Parameter | Guideline |
|---|---|
| pool_size | Expected concurrent DB queries per worker process (each worker has its own pool). Start with 5-10 per worker. |
| max_overflow | 50-100% of pool_size for burst handling |
| pool_timeout | Lower values (10-30s) fail fast; higher values queue more requests |
| pool_recycle | Set below your database's idle-connection timeout (e.g. MySQL's wait_timeout) to avoid stale connections |
Check your database's max_connections setting: the total across all workers, workers × (pool_size + max_overflow), must not exceed that limit.
# Method 1: Custom middleware with Redis
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse
import redis.asyncio as aioredis  # the standalone aioredis package is deprecated; redis-py ships it
redis = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    global redis
    redis = aioredis.from_url("redis://localhost")
    yield
    await redis.aclose()  # redis-py >= 5.0.1; older versions use close()
app = FastAPI(lifespan=lifespan)
async def hit(key: str, window: int) -> int:
    # INCR is atomic, so there is no gap between "check" and "increment".
    # Set the TTL only when the key is new, so the window stays fixed
    # instead of being extended by every request.
    count = await redis.incr(key)
    if count == 1:
        await redis.expire(key, window)
    return count
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    client_ip = request.client.host if request.client else "unknown"
    window = 60  # seconds
    max_requests = 100
    if await hit(f"rate_limit:{client_ip}", window) > max_requests:
        return JSONResponse(
            status_code=429,
            content={"detail": "Too many requests"},
            headers={"Retry-After": str(window)},
        )
    return await call_next(request)
# Method 2: Dependency-based rate limiting (per-route)
class RateLimiter:
    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window = window_seconds
    async def __call__(self, request: Request):
        client_ip = request.client.host if request.client else "unknown"
        key = f"rate:{client_ip}:{request.url.path}"
        if await hit(key, self.window) > self.max_requests:
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {self.window} seconds.",
                headers={"Retry-After": str(self.window)},
            )
# Apply different limits to different routes
@app.get("/search", dependencies=[Depends(RateLimiter(max_requests=30, window_seconds=60))])
async def search(q: str):
return {"results": []}
@app.post("/upload", dependencies=[Depends(RateLimiter(max_requests=5, window_seconds=60))])
async def upload():
return {"status": "ok"}
Best practice: Use dependency-based rate limiting so you can apply different limits to different endpoints. Expensive operations (search, uploads) should have stricter limits than simple reads.
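The counting logic is easier to reason about, and to unit-test, without Redis. Below is a sketch of the same fixed-window scheme in plain Python. It assumes a single process and an injectable clock. FixedWindowLimiter is a hypothetical name. Its in-memory counters are not shared between workers, so production code would keep the Redis version.

```python
import time
from typing import Callable, Dict, Tuple


class FixedWindowLimiter:
    """Single-process fixed-window counter mirroring the Redis INCR + EXPIRE scheme."""

    def __init__(
        self,
        max_requests: int,
        window_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.max_requests = max_requests
        self.window = window_seconds
        self.clock = clock
        # key -> (window start time, requests counted in that window)
        self._counters: Dict[str, Tuple[float, int]] = {}

    def hit(self, key: str) -> Tuple[bool, float]:
        """Count one request for key; return (allowed, seconds until the window resets)."""
        now = self.clock()
        start, count = self._counters.get(key, (now, 0))
        if now - start >= self.window:
            # The previous window is over: this is what key expiry does in Redis.
            start, count = now, 0
        count += 1
        self._counters[key] = (start, count)
        return count <= self.max_requests, self.window - (now - start)
```

You could put the returned seconds-until-reset value in the Retry-After header instead of always reporting the full window length.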
# Dockerfile (multi-stage build)
FROM python:3.11-slim AS builder
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
FROM python:3.11-slim
WORKDIR /app
COPY --from=builder /install /usr/local
COPY ./app ./app
EXPOSE 8000
CMD ["gunicorn", "app.main:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000"]
# docker-compose.yml
version: "3.8"
services:
  api:
    build: .
    environment:
      # asyncpg driver to match create_async_engine in the app
      - DATABASE_URL=postgresql+asyncpg://user:pass@db:5432/mydb
      - REDIS_URL=redis://redis:6379
    depends_on:
      - db
      - redis
    networks:
      - backend
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
      - ./certs:/etc/ssl/certs
    depends_on:
      - api
    networks:
      - backend
  db:
    image: postgres:15
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: pass
      POSTGRES_DB: mydb
    volumes:
      - pgdata:/var/lib/postgresql/data
    networks:
      - backend
  redis:
    image: redis:7-alpine
    networks:
      - backend
volumes:
  pgdata:
networks:
  backend:
# nginx.conf
events {
    worker_connections 1024;
}
http {
    upstream fastapi {
        server api:8000;
    }
    server {
        listen 80;
        server_name api.example.com;
        location / {
            proxy_pass http://fastapi;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
        }
        location /ws {
            proxy_pass http://fastapi;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection "upgrade";
            proxy_set_header Host $host;
            # The default read timeout is 60s, which closes idle WebSockets
            proxy_read_timeout 3600s;
        }
    }
}
Key considerations:
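- Start with roughly 2 * CPU_CORES + 1 Gunicorn workers and tune from measurements.
- Configure Nginx to forward the Upgrade header so WebSocket connections can be established.
# Event-driven architecture using an in-process event bus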
from fastapi import FastAPI
from typing import Callable, Dict, List
import asyncio
# Simple event bus
class EventBus:
def __init__(self):
self._handlers: Dict[str, List[Callable]] = {}
def subscribe(self, event_type: str, handler: Callable):
if event_type not in self._handlers:
self._handlers[event_type] = []
self._handlers[event_type].append(handler)
async def publish(self, event_type: str, data: dict):
handlers = self._handlers.get(event_type, [])
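await asyncio.gather(*[handler(data) for handler in handlers], return_exceptions=True)  # failures are returned, not raised, so one bad handler can't fail the request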
event_bus = EventBus()
# Register handlers
async def send_welcome_email(data: dict):
print(f"Sending welcome email to {data['email']}")
async def create_default_settings(data: dict):
print(f"Creating default settings for user {data['user_id']}")
async def notify_admin(data: dict):
print(f"New user registered: {data['username']}")
event_bus.subscribe("user.created", send_welcome_email)
event_bus.subscribe("user.created", create_default_settings)
event_bus.subscribe("user.created", notify_admin)
app = FastAPI()
@app.post("/users")
async def create_user(user: dict):
new_user = {"id": 1, **user}
# Publish event - all handlers run concurrently
await event_bus.publish("user.created", {
"user_id": new_user["id"],
"username": new_user.get("username"),
"email": new_user.get("email")
})
return new_user
# Production: Event-driven with Redis Streams
import json
import redis.asyncio as aioredis  # the standalone aioredis package is deprecated
class RedisEventPublisher:
    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self.redis = None
    async def publish(self, channel: str, event: dict):
        # Connect lazily so a forgotten connect() call can't leave self.redis as None
        if self.redis is None:
            self.redis = aioredis.from_url(self.redis_url)
        # XADD appends an entry to the stream; consumers read it with XREAD/XREADGROUP
        await self.redis.xadd(channel, {"data": json.dumps(event)})
publisher = RedisEventPublisher("redis://localhost")
@app.post("/orders")
async def create_order(order: dict):
    new_order = {"id": 1, "total": order.get("total", 0), **order}
    await publisher.publish("orders", {
        "event": "order.created",
        "order_id": new_order["id"],
        "total": new_order["total"],
    })
    return new_order
Best practice: Start with a simple in-process event bus for monoliths. Move to Redis Streams or Kafka when you need cross-service communication or guaranteed delivery.
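The in-process bus has a risk: a single failing handler (say, the email sender) can turn the whole POST /users into a 500 even though the user was already created. Below is a sketch of a publish() that isolates handler failures and logs them. It assumes handlers are async callables that take a dict. The failure count it returns is an illustrative addition.

```python
import asyncio
import logging
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List

Handler = Callable[[dict], Awaitable[None]]
logger = logging.getLogger("events")


class EventBus:
    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Handler]] = defaultdict(list)

    def subscribe(self, event_type: str, handler: Handler) -> None:
        self._handlers[event_type].append(handler)

    async def publish(self, event_type: str, data: dict) -> int:
        """Run all handlers concurrently; log and count the ones that failed."""
        handlers = list(self._handlers.get(event_type, []))
        # return_exceptions=True: a failure becomes a value in `results`
        # instead of propagating and hiding the other handlers' outcomes.
        results = await asyncio.gather(*(h(data) for h in handlers), return_exceptions=True)
        failures = 0
        for handler, result in zip(handlers, results):
            if isinstance(result, Exception):
                failures += 1
                logger.error("handler %s failed for %s: %r", handler.__name__, event_type, result)
        return failures
```

The handlers still run inside the request. If the client should not wait for them, FastAPI's BackgroundTasks or a queue such as Redis Streams can move that work out of the response path.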
# Install Alembic
pip install alembic
# Initialize Alembic in your project
alembic init alembic
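# alembic/env.py (key configuration)
# With an async URL (postgresql+asyncpg), start from the async template: alembic init -t async alembic
from alembic import context
from app.database import Base
from app.models import user, product, order  # noqa: F401  (import all models so autogenerate sees them)
from app.config import get_settings
config = context.config
settings = get_settings()
config.set_main_option("sqlalchemy.url", settings.database_url)
target_metadata = Base.metadata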
# Create a migration
alembic revision --autogenerate -m "add users table"
# Apply migrations
alembic upgrade head
# Rollback one step
alembic downgrade -1
# View migration history
alembic history --verbose
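# Example migration file
"""add users table
Revision ID: a1b2c3d4e5f6
"""
from alembic import op
import sqlalchemy as sa
# Alembic reads these module-level identifiers to order migrations
revision = "a1b2c3d4e5f6"
down_revision = None  # previous revision ID, or None for the first migration
branch_labels = None
depends_on = None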
def upgrade():
op.create_table(
"users",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("username", sa.String(50), unique=True, nullable=False),
sa.Column("email", sa.String(120), unique=True, nullable=False),
sa.Column("hashed_password", sa.String(255), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now()),
)
op.create_index("ix_users_email", "users", ["email"])
def downgrade():
op.drop_index("ix_users_email", table_name="users")
op.drop_table("users")
Production migration strategy:
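import redis.asyncio as aioredis  # the standalone aioredis package is deprecated
import json
import hashlib
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from functools import wraps
redis = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    global redis
    redis = aioredis.from_url("redis://localhost", decode_responses=True)
    yield
    await redis.aclose()  # redis-py >= 5.0.1; older versions use close()
app = FastAPI(lifespan=lifespan)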
# Strategy 1: Simple key-value caching
async def get_cached_or_fetch(key: str, fetch_func, ttl: int = 300):
cached = await redis.get(key)
if cached:
return json.loads(cached)
data = await fetch_func()
await redis.setex(key, ttl, json.dumps(data))
return data
@app.get("/products/{product_id}")
async def get_product(product_id: int):
async def fetch():
return await db_get_product(product_id)
return await get_cached_or_fetch(f"product:{product_id}", fetch, ttl=600)
# Strategy 2: Cache decorator
def cached(prefix: str, ttl: int = 300):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Build cache key from function args
key_data = f"{prefix}:{args}:{kwargs}"
cache_key = hashlib.md5(key_data.encode()).hexdigest()
cached_result = await redis.get(cache_key)
if cached_result:
return json.loads(cached_result)
result = await func(*args, **kwargs)
await redis.setex(cache_key, ttl, json.dumps(result))
return result
return wrapper
return decorator
# Strategy 3: Cache invalidation on write
@app.post("/products")
async def create_product(product: dict):
new_product = {"id": 1, **product}
# Invalidate list cache
await redis.delete("products:list")
# Cache the new product
await redis.setex(
f"product:{new_product['id']}",
600,
json.dumps(new_product)
)
return new_product
@app.put("/products/{product_id}")
async def update_product(product_id: int, updates: dict):
updated = {"id": product_id, **updates}
# Invalidate specific cache and list cache
await redis.delete(f"product:{product_id}")
await redis.delete("products:list")
return updated
# Strategy 4: HTTP cache headers
from fastapi.responses import JSONResponse
@app.get("/static-config")
async def get_config():
data = {"version": "1.0", "features": ["a", "b"]}
response = JSONResponse(content=data)
response.headers["Cache-Control"] = "public, max-age=3600"
response.headers["ETag"] = '"' + hashlib.md5(json.dumps(data).encode()).hexdigest() + '"'  # ETag values are quoted strings
return response
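The Strategy 4 handler sets an ETag but never checks it, so clients still download the full body. Below is a sketch of the conditional-request half, assuming plain dict payloads. make_etag and conditional_response are hypothetical helpers. The hash is SHA-256 over a canonical JSON encoding. A matching If-None-Match, including its weak W/ form, yields 304 with no body.

```python
import hashlib
import json
from typing import Any, Dict, Optional, Tuple


def make_etag(payload: Any) -> str:
    """Strong ETag for a JSON-serializable payload (quoted, as HTTP requires)."""
    # sort_keys + compact separators give the same bytes for equal data.
    body = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
    return '"' + hashlib.sha256(body).hexdigest()[:32] + '"'


def conditional_response(
    payload: Any, if_none_match: Optional[str]
) -> Tuple[int, Dict[str, str], Optional[Any]]:
    """Return (status, headers, body); 304 with no body when the client's copy is current."""
    etag = make_etag(payload)
    headers = {"ETag": etag, "Cache-Control": "public, max-age=3600"}
    if if_none_match:
        # If-None-Match may list several tags and uses weak comparison,
        # so a W/ prefix on the client's tag is ignored.
        tags = {t.strip() for t in if_none_match.split(",")}
        tags = {t[2:] if t.startswith("W/") else t for t in tags}
        if "*" in tags or etag in tags:
            return 304, headers, None
    return 200, headers, payload
```

In a route, you would read request.headers.get("if-none-match"). Depending on the result, you would then return either Response(status_code=304, headers=headers) or JSONResponse(body, headers=headers).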
Caching strategies summary:
| Strategy | Use Case | TTL Guidance |
|---|---|---|
| Cache-aside (lazy load) | General-purpose; read-heavy data | 5-60 minutes |
| Write-through | Data that must be immediately consistent | Match read cache TTL |
| Cache invalidation | Data modified via your own API | Infinite (invalidate on change) |
| HTTP caching headers | Static or slowly changing responses | Based on data volatility |
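Cache-aside has a subtle failure mode. When a hot key expires, many concurrent requests miss at once and all hit the database (a cache stampede). Below is a single-process sketch that adds a per-key lock to the get_cached_or_fetch idea. An in-memory dict stands in for Redis, and the clock is injectable for testing. TTLCache is a hypothetical class.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Dict, Tuple


class TTLCache:
    """In-process cache-aside helper; a stand-in for the Redis get/setex calls."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._store: Dict[str, Tuple[float, Any]] = {}  # key -> (expires_at, value)
        self._locks: Dict[str, asyncio.Lock] = {}
        self._clock = clock

    def _fresh(self, key: str):
        hit = self._store.get(key)
        return hit if hit and hit[0] > self._clock() else None

    async def get_or_fetch(
        self, key: str, fetch: Callable[[], Awaitable[Any]], ttl: float
    ) -> Any:
        hit = self._fresh(key)
        if hit:
            return hit[1]
        # One lock per key: concurrent misses wait for a single fetch.
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:
            hit = self._fresh(key)  # another task may have filled it meanwhile
            if hit:
                return hit[1]
            value = await fetch()
            self._store[key] = (self._clock() + ttl, value)
            return value

    def invalidate(self, key: str) -> None:
        self._store.pop(key, None)
```

Across multiple workers, the lock would also have to live in Redis (for example, SET with NX and a short expiry). That is beyond this sketch.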
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
import uuid
import re
app = FastAPI(docs_url=None, redoc_url=None) # Disable docs in production
# 1. CORS - restrict allowed origins
app.add_middleware(
CORSMiddleware,
allow_origins=["https://yourfrontend.com"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)
# 2. Trusted hosts - prevent host header attacks
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["api.example.com"])
# 3. HTTPS redirect
app.add_middleware(HTTPSRedirectMiddleware)
# 4. Security headers middleware
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "0"  # the legacy XSS auditor is deprecated; rely on CSP instead
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Content-Security-Policy"] = "default-src 'self'"
return response
# 5. Request ID tracking
@app.middleware("http")
async def add_request_id(request: Request, call_next):
request_id = str(uuid.uuid4())
request.state.request_id = request_id
response = await call_next(request)
response.headers["X-Request-ID"] = request_id
return response
# 6. Input sanitization in Pydantic models
from pydantic import BaseModel, field_validator
class UserInput(BaseModel):
name: str
comment: str
@field_validator("comment")
@classmethod
def sanitize_comment(cls, v):
# Remove potential script tags
cleaned = re.sub(r"<script.*?>.*?</script>", "", v, flags=re.DOTALL | re.IGNORECASE)
return cleaned.strip()
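The regex in sanitize_comment only removes well-formed <script> elements. Payloads such as an <img> tag with an onerror attribute pass straight through. The more common defense is to escape user text at the point where it is inserted into HTML. The sketch below compares the two approaches using the standard library's html.escape. The helper names are hypothetical.

```python
import html
import re

# The validator's approach from above, reproduced for comparison.
SCRIPT_TAG = re.compile(r"<script.*?>.*?</script>", re.DOTALL | re.IGNORECASE)


def strip_script_tags(value: str) -> str:
    return SCRIPT_TAG.sub("", value).strip()


def render_comment(comment: str) -> str:
    """Escape user text when it is placed into HTML, not when it is stored."""
    return f'<p class="comment">{html.escape(comment)}</p>'
```

JSON returned by FastAPI is not executed as HTML, so it is not an XSS vector on its own. Escaping matters wherever a frontend or template later places the string into HTML. Template engines such as Jinja2 escape automatically when autoescaping is enabled.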
Production security checklist:
- Disable the interactive docs in production (/docs, /redoc)
- Restrict CORS to explicit origins (never allow * in production)
# .github/workflows/ci-cd.yml
name: CI/CD Pipeline
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:15
        env:
          POSTGRES_USER: test
          POSTGRES_PASSWORD: test
          POSTGRES_DB: testdb
        ports:
          - 5432:5432
        # Wait until Postgres accepts connections before the steps run
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      redis:
        image: redis:7
        ports:
          - 6379:6379
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pytest-cov pytest-asyncio httpx
      - name: Run linting
        run: |
          pip install ruff
          ruff check app/
      - name: Run type checking
        run: |
          pip install mypy
          mypy app/ --ignore-missing-imports
      - name: Run tests with coverage
        env:
          DATABASE_URL: postgresql://test:test@localhost:5432/testdb
          REDIS_URL: redis://localhost:6379
        run: |
          pytest tests/ -v --cov=app --cov-report=xml
      - name: Upload coverage
        uses: codecov/codecov-action@v3
  deploy:
    needs: test
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/main'
    steps:
      - uses: actions/checkout@v4
      - name: Build Docker image
        run: docker build -t myapi:latest .
      - name: Push to registry
        # Assumes the runner has already authenticated with the registry
        run: |
          docker tag myapi:latest registry.example.com/myapi:latest
          docker push registry.example.com/myapi:latest
      - name: Deploy to production
        run: |
          ssh deploy@production "cd /app && docker-compose pull && docker-compose up -d"
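The deploy job ends with docker-compose up -d and never confirms that the new containers are actually serving traffic. A stdlib-only smoke check could run as a final step. It assumes the app exposes a /health endpoint that returns 200 when ready. The URL, attempt count, and delay are placeholders, and the function names are hypothetical.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_status(url: str, timeout: float = 5.0) -> int:
    """Return the HTTP status for GET url, or 0 if the request failed outright."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status
    except urllib.error.HTTPError as exc:  # 4xx/5xx responses still have a code
        return exc.code
    except (urllib.error.URLError, OSError):  # DNS failure, refused connection, timeout
        return 0


def wait_until_healthy(
    url: str,
    attempts: int = 10,
    delay: float = 3.0,
    fetch: Callable[[str], int] = http_status,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll url until it returns 200, sleeping `delay` seconds between attempts."""
    for attempt in range(1, attempts + 1):
        if fetch(url) == 200:
            return True
        if attempt < attempts:
            sleep(delay)
    return False
```

For example, a final workflow step could run a script that calls wait_until_healthy("https://api.example.com/health") and exits non-zero when it returns False, which marks the deployment as failed.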
CI/CD best practices for FastAPI:
- Use ruff for fast Python linting and formatting.
- Run mypy to catch type errors before runtime.
- Run alembic upgrade head as part of the deploy step.
- Expose a /health endpoint and verify it after deployment.
This guide covered 36 interview questions spanning junior, mid-level, and senior FastAPI topics. Here are the most important themes to remember:
| Level | Key Themes |
|---|---|
| Junior | Understand path/query parameters, Pydantic models, type hints, HTTP methods, error handling, and how to run a FastAPI app with Uvicorn. |
| Mid-Level | Master dependency injection, JWT authentication, async vs sync patterns, CRUD operations, APIRouter organization, testing with TestClient, and file uploads. |
| Senior | Design scalable architectures, understand ASGI, implement WebSockets, optimize for high concurrency, manage database migrations with Alembic, implement caching and rate limiting, deploy with Docker/Nginx, secure the application, and set up CI/CD pipelines. |
Resources for further study:
Deploying a FastAPI application to production requires more than just running uvicorn main:app. A production deployment involves configuring ASGI servers for performance, containerizing your application with Docker, setting up reverse proxies, implementing CI/CD pipelines, managing database migrations, and ensuring security and monitoring are in place.
This comprehensive guide covers everything you need to deploy FastAPI applications reliably, from single-server setups to scalable cloud architectures. Whether you’re deploying to AWS, Heroku, DigitalOcean, or your own infrastructure, you’ll find practical, production-tested configurations here.
Before deploying, your FastAPI application needs proper configuration management, structured logging, and environment-specific settings. The pydantic-settings library provides type-safe configuration that reads from environment variables and .env files.
Install the required package:
pip install pydantic-settings python-dotenv
Create a centralized settings module that all parts of your application can import:
# app/config.py
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from functools import lru_cache
from typing import Optional
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
)
# Application
app_name: str = "FastAPI App"
app_version: str = "1.0.0"
debug: bool = False
environment: str = "production" # development, staging, production
# Server
host: str = "0.0.0.0"
port: int = 8000
workers: int = 4
reload: bool = False
# Database
database_url: str = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
db_pool_size: int = 20
db_max_overflow: int = 10
db_pool_timeout: int = 30
# Redis
redis_url: str = "redis://localhost:6379/0"
# Security
secret_key: str = Field(default="change-me-in-production")
allowed_hosts: list[str] = ["*"]
cors_origins: list[str] = ["http://localhost:3000"]
# JWT
jwt_secret: str = Field(default="jwt-secret-change-me")
jwt_algorithm: str = "HS256"
jwt_expiration_minutes: int = 30
# Logging
log_level: str = "INFO"
log_format: str = "json" # json or text
# External Services
smtp_host: Optional[str] = None
smtp_port: int = 587
sentry_dsn: Optional[str] = None
@lru_cache()
def get_settings() -> Settings:
"""Cached settings instance."""
return Settings()
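Because get_settings() is wrapped in lru_cache, the first call freezes the configuration for the life of the process. That is what you want in production. Tests that change environment variables, however, must clear the cache. The sketch below shows this behavior with a stdlib dataclass standing in for the pydantic-settings class. It has only three fields, and its simplified boolean parser is an assumption, not pydantic's rules.

```python
import os
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    """Stdlib stand-in for the pydantic-settings class (three fields only)."""
    app_name: str
    debug: bool
    log_level: str


def _env_bool(name: str, default: bool) -> bool:
    # Simplified parsing; pydantic-settings accepts a wider set of spellings.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@lru_cache()
def get_settings() -> Settings:
    """Read the environment once; later calls return the same object."""
    return Settings(
        app_name=os.environ.get("APP_NAME", "FastAPI App"),
        debug=_env_bool("DEBUG", False),
        log_level=os.environ.get("LOG_LEVEL", "INFO"),
    )
```

In FastAPI tests, an alternative to cache_clear() is app.dependency_overrides[get_settings] = lambda: Settings(...). This swaps the dependency without touching the environment.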
Create a .env file for local development:
# .env
APP_NAME=MyFastAPIApp
DEBUG=true
ENVIRONMENT=development
DATABASE_URL=postgresql+asyncpg://postgres:password@localhost:5432/mydb
REDIS_URL=redis://localhost:6379/0
SECRET_KEY=dev-secret-key-not-for-production
JWT_SECRET=dev-jwt-secret
LOG_LEVEL=DEBUG
LOG_FORMAT=text
CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"]
Use settings throughout your application:
# app/main.py
from fastapi import FastAPI, Depends
from app.config import Settings, get_settings
app = FastAPI()
@app.get("/info")
async def app_info(settings: Settings = Depends(get_settings)):
return {
"app_name": settings.app_name,
"version": settings.app_version,
"environment": settings.environment,
"debug": settings.debug,
}
Production applications need structured logging (JSON format) for log aggregation tools like ELK Stack, Datadog, or CloudWatch. Use structlog for structured, contextualized logging:
pip install structlog
# app/logging_config.py
import logging
import sys
import structlog
from app.config import get_settings
def setup_logging():
"""Configure structured logging for the application."""
settings = get_settings()
# Choose processors based on environment
if settings.log_format == "json":
renderer = structlog.processors.JSONRenderer()
else:
renderer = structlog.dev.ConsoleRenderer(colors=True)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.format_exc_info,
renderer,
],
wrapper_class=structlog.make_filtering_bound_logger(
getattr(logging, settings.log_level.upper(), logging.INFO)
),
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(file=sys.stdout),
cache_logger_on_first_use=True,
)
def get_logger(name: str = __name__):
"""Get a structured logger instance."""
return structlog.get_logger(name)
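If you would rather not add structlog, the standard logging module can emit the same one-JSON-object-per-line format. This is a swapped-in alternative, not part of the structlog setup above. JsonFormatter and the "fields" extra key are hypothetical conventions.

```python
import json
import logging
from datetime import datetime, timezone


class JsonFormatter(logging.Formatter):
    """Render each log record as a single JSON object per line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname.lower(),
            "logger": record.name,
            "event": record.getMessage(),
        }
        # Structured fields passed as logger.info(..., extra={"fields": {...}})
        payload.update(getattr(record, "fields", {}))
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, default=str)
```

To attach it, create handler = logging.StreamHandler(sys.stdout) and call handler.setFormatter(JsonFormatter()). Then log with logger.info("request_completed", extra={"fields": {"status_code": 200}}).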
Add request logging middleware to track every request:
# app/middleware.py
import time
import uuid
import structlog
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.logging_config import get_logger
logger = get_logger(__name__)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""Log every request with timing and correlation ID."""
async def dispatch(self, request: Request, call_next):
request_id = str(uuid.uuid4())[:8]
start_time = time.perf_counter()
# Add request ID to structlog context
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(request_id=request_id)
logger.info(
"request_started",
method=request.method,
path=request.url.path,
client_ip=request.client.host if request.client else "unknown",
)
response = await call_next(request)
duration = time.perf_counter() - start_time
logger.info(
"request_completed",
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=round(duration * 1000, 2),
)
response.headers["X-Request-ID"] = request_id
return response
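BaseHTTPMiddleware is convenient, but Starlette's documentation notes limitations. One example: ContextVar changes made in an endpoint do not propagate back up to the middleware. A pure ASGI middleware avoids the wrapper. The sketch below uses the hypothetical names RequestIdMiddleware and request_id_var. It reuses an incoming X-Request-ID or generates one, exposes the ID through a ContextVar, and appends it to the response headers.

```python
import contextvars
import uuid
from typing import Any, Awaitable, Callable, Dict

Message = Dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Dict[str, Any], Receive, Send], Awaitable[None]]

# Readable from any code running inside the request (handlers, log processors).
request_id_var = contextvars.ContextVar("request_id", default="-")


class RequestIdMiddleware:
    """Pure ASGI middleware that tags each HTTP request with an ID."""

    def __init__(self, app: ASGIApp, header: bytes = b"x-request-id") -> None:
        self.app = app
        self.header = header

    async def __call__(self, scope: Dict[str, Any], receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Pass WebSocket and lifespan traffic through untouched.
            await self.app(scope, receive, send)
            return
        # ASGI header names are lowercase bytes; reuse a client/proxy-supplied ID.
        incoming = dict(scope.get("headers", [])).get(self.header)
        request_id = incoming.decode("latin-1") if incoming else uuid.uuid4().hex[:8]
        token = request_id_var.set(request_id)

        async def send_with_id(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append((self.header, request_id.encode("latin-1")))
                message = {**message, "headers": headers}
            await send(message)

        try:
            await self.app(scope, receive, send_with_id)
        finally:
            request_id_var.reset(token)
```

Register it with app.add_middleware(RequestIdMiddleware). Clients can set X-Request-ID themselves, so treat the value as a correlation hint, not a trusted identifier.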
Use a factory function to create your FastAPI application with all middleware and configuration applied:
# app/main.py
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import get_settings
from app.logging_config import setup_logging
from app.middleware import RequestLoggingMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application startup and shutdown."""
# Startup
setup_logging()
from app.logging_config import get_logger
logger = get_logger("lifespan")
logger.info("application_starting", environment=get_settings().environment)
# Initialize database, Redis, etc.
# await init_db()
# await init_redis()
yield # Application runs here
# Shutdown
logger.info("application_shutting_down")
# await close_db()
# await close_redis()
def create_app() -> FastAPI:
"""Application factory."""
settings = get_settings()
app = FastAPI(
title=settings.app_name,
version=settings.app_version,
debug=settings.debug,
lifespan=lifespan,
docs_url="/docs" if settings.debug else None,
redoc_url="/redoc" if settings.debug else None,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request logging
app.add_middleware(RequestLoggingMiddleware)
# Include routers
from app.routers import api_router
app.include_router(api_router, prefix="/api/v1")
return app
app = create_app()
FastAPI runs on ASGI (Asynchronous Server Gateway Interface) servers. While Uvicorn is great for development, production deployments need proper process management, graceful shutdowns, and multiple worker processes.
Uvicorn can run with multiple workers for production use:
# Basic production run
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
# With all production options
uvicorn app.main:app \
--host 0.0.0.0 \
--port 8000 \
--workers 4 \
--loop uvloop \
--http httptools \
--log-level warning \
--access-log \
--proxy-headers \
--forwarded-allow-ips="*"
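A common starting point for the worker count is (2 * CPU_CORES) + 1, a heuristic from Gunicorn's documentation. Because Uvicorn workers are async, fewer workers (often one per core) may be enough, so benchmark before settling on a number. Note that --forwarded-allow-ips="*" trusts X-Forwarded-* headers from any client, so use it only when the server is reachable solely through your reverse proxy. You can also configure Uvicorn programmatically: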
# run.py
import uvicorn
from app.config import get_settings
if __name__ == "__main__":
settings = get_settings()
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
workers=settings.workers,
reload=settings.reload,
log_level=settings.log_level.lower(),
proxy_headers=True,
forwarded_allow_ips="*",
)
Gunicorn provides mature process management. Pairing it with Uvicorn workers combines Gunicorn's process supervision with Uvicorn's ASGI performance:
# Install both
pip install gunicorn uvicorn[standard]
# Run with Uvicorn workers
gunicorn app.main:app \
--worker-class uvicorn.workers.UvicornWorker \
--workers 4 \
--bind 0.0.0.0:8000 \
--timeout 120 \
--graceful-timeout 30 \
--keep-alive 5 \
--access-logfile - \
--error-logfile -
Create a Gunicorn configuration file for more control:
# gunicorn.conf.py
import multiprocessing
import os
# Server socket
bind = f"0.0.0.0:{os.getenv('PORT', '8000')}"
backlog = 2048
# Worker processes
workers = int(os.getenv("WEB_CONCURRENCY", multiprocessing.cpu_count() * 2 + 1))
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000
timeout = 120
graceful_timeout = 30
keepalive = 5
# Restart workers after this many requests (prevents memory leaks)
max_requests = 1000
max_requests_jitter = 50
# Logging
accesslog = "-"
errorlog = "-"
loglevel = os.getenv("LOG_LEVEL", "info").lower()
# Process naming
proc_name = "fastapi-app"
# Server hooks
def on_starting(server):
"""Called just before the master process is initialized."""
pass
def post_worker_init(worker):
"""Called just after a worker has been initialized."""
worker.log.info(f"Worker {worker.pid} initialized")
def worker_exit(server, worker):
"""Called when a worker exits."""
worker.log.info(f"Worker {worker.pid} exiting")
# Run with config file
gunicorn app.main:app -c gunicorn.conf.py
Hypercorn supports HTTP/2 and HTTP/3, which can be useful for applications that benefit from multiplexed connections:
pip install hypercorn
# Basic run
hypercorn app.main:app --bind 0.0.0.0:8000 --workers 4
# With HTTP/2
hypercorn app.main:app \
--bind 0.0.0.0:8000 \
--workers 4 \
--certfile cert.pem \
--keyfile key.pem
| Feature | Uvicorn | Gunicorn + Uvicorn | Hypercorn |
|---|---|---|---|
| Process Management | Basic | Advanced (preforking) | Basic |
| Graceful Restart | Limited | Full (SIGHUP) | Limited |
| HTTP/2 | No | No | Yes |
| Worker Recovery | Manual | Automatic | Manual |
| Memory Leak Protection | --limit-max-requests (needs a supervisor to restart) | max_requests | No |
| Production Ready | With care | Yes (recommended) | With care |
Docker provides consistent, reproducible environments across development, staging, and production. A well-crafted Dockerfile ensures your FastAPI application runs the same way everywhere.
# Dockerfile
FROM python:3.12-slim
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser -d /app -s /sbin/nologin appuser
WORKDIR /app
# Install system dependencies
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
curl \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Change ownership to non-root user
RUN chown -R appuser:appuser /app
USER appuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run the application
CMD ["gunicorn", "app.main:app", \
"--worker-class", "uvicorn.workers.UvicornWorker", \
"--workers", "4", \
"--bind", "0.0.0.0:8000", \
"--timeout", "120", \
"--access-logfile", "-"]
Multi-stage builds produce smaller images by separating build dependencies from the runtime environment:
# Dockerfile.multistage
# ---- Build Stage ----
FROM python:3.12-slim AS builder
ENV PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1
WORKDIR /build
# Install build dependencies
RUN apt-get update \
&& apt-get install -y --no-install-recommends build-essential \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --prefix=/install --no-cache-dir -r requirements.txt
# ---- Runtime Stage ----
FROM python:3.12-slim AS runtime
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser -d /app -s /sbin/nologin appuser
# Install runtime-only system dependencies
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl \
&& rm -rf /var/lib/apt/lists/*
# Copy Python packages from builder
COPY --from=builder /install /usr/local
WORKDIR /app
COPY --chown=appuser:appuser . .
USER appuser
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
CMD ["gunicorn", "app.main:app", \
"--worker-class", "uvicorn.workers.UvicornWorker", \
"--workers", "4", \
"--bind", "0.0.0.0:8000"]
Exclude unnecessary files from the build context:
# .dockerignore
__pycache__
*.pyc
*.pyo
.git
.gitignore
.env
.env.*
.venv
venv
*.md
docs/
tests/
.pytest_cache
.coverage
htmlcov/
.mypy_cache
.ruff_cache
docker-compose*.yml
Dockerfile*
.dockerignore
# docker-compose.yml
services:
  app:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql+asyncpg://postgres:password@db:5432/fastapi_db
      - REDIS_URL=redis://redis:6379/0
      - ENVIRONMENT=development
      - DEBUG=true
      - LOG_LEVEL=DEBUG
    volumes:
      - .:/app  # Hot reload in development
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
  db:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: password
      POSTGRES_DB: fastapi_db
    ports:
      - "5432:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 5s
      timeout: 5s
      retries: 5
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
      timeout: 5s
      retries: 5
volumes:
  postgres_data:
  redis_data:
# Build and start all services
docker compose up --build -d
# View logs
docker compose logs -f app
# Run database migrations
docker compose exec app alembic upgrade head
# Stop all services
docker compose down
# Stop and remove volumes (clean slate)
docker compose down -v
Nginx sits in front of your ASGI server to handle SSL termination, static file serving, load balancing, request buffering, and rate limiting. It is the standard production setup for Python web applications.
# nginx/nginx.conf
upstream fastapi_backend {
server app:8000;
}
server {
listen 80;
server_name yourdomain.com www.yourdomain.com;
# Redirect HTTP to HTTPS
return 301 https://$host$request_uri;
}
server {
listen 443 ssl;
http2 on;  # nginx >= 1.25.1; older versions use "listen 443 ssl http2;"
server_name yourdomain.com www.yourdomain.com;
# SSL certificates (Let's Encrypt)
ssl_certificate /etc/letsencrypt/live/yourdomain.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/yourdomain.com/privkey.pem;
# SSL settings
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
ssl_prefer_server_ciphers off;
ssl_session_timeout 1d;
ssl_session_cache shared:SSL:10m;
ssl_session_tickets off;
# Security headers
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header X-XSS-Protection "0" always;  # legacy XSS auditor is deprecated; rely on Content-Security-Policy instead
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
add_header Strict-Transport-Security "max-age=63072000; includeSubDomains" always;
# Request size limit
client_max_body_size 10M;
# Gzip compression
gzip on;
gzip_vary on;
gzip_min_length 1024;
gzip_types text/plain text/css application/json application/javascript text/xml;
# Static files
location /static/ {
alias /app/static/;
expires 30d;
add_header Cache-Control "public, immutable";
}
# API proxy
location / {
proxy_pass http://fastapi_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Timeouts
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
# Buffering
proxy_buffering on;
proxy_buffer_size 4k;
proxy_buffers 8 4k;
}
# Health check endpoint (no logging)
location /health {
proxy_pass http://fastapi_backend/health;
access_log off;
}
}
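FastAPI supports WebSockets, but by default Nginx proxies with HTTP/1.0 and does not forward the hop-by-hop `Upgrade` and `Connection` headers. WebSocket routes therefore need their own location block: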
# Add to the server block
location /ws/ {
proxy_pass http://fastapi_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# WebSocket timeout (keep alive)
proxy_read_timeout 86400s;
proxy_send_timeout 86400s;
}
If you run multiple FastAPI instances, Nginx can load balance between them:
upstream fastapi_backend {
least_conn; # Send to the server with fewest connections
server app1:8000 weight=3; # Higher weight = more traffic
server app2:8000 weight=2;
server app3:8000 weight=1;
# Health checks (Nginx Plus only, use external for OSS)
# health_check interval=10s fails=3 passes=2;
}
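Conceptually, `least_conn` with weights sends a new connection to the server whose active-connection count is lowest relative to its weight. The following is a simplified Python model of that choice, not Nginx code. It makes these assumptions:
- Real Nginx also breaks ties with weighted round-robin; this sketch just keeps list order.
- `Backend` and `pick_backend` are illustrative names.

```python
from dataclasses import dataclass


@dataclass
class Backend:
    name: str
    weight: int
    active: int = 0  # connections currently in progress


def pick_backend(backends: list[Backend]) -> Backend:
    """Return the backend with the lowest active/weight ratio.

    Compares a.active * b.weight against b.active * a.weight to avoid division,
    modeled on how Nginx's least_conn balancer compares peers.
    """
    best = backends[0]
    for candidate in backends[1:]:
        if candidate.active * best.weight < best.active * candidate.weight:
            best = candidate
    return best
```

If six long-lived connections arrive and none finish, this model spreads them 3/2/1 across `app1`/`app2`/`app3`, matching the configured weights.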
# Add to http block (before server blocks)
limit_req_zone $binary_remote_addr zone=api:10m rate=10r/s;
limit_req_zone $binary_remote_addr zone=login:10m rate=1r/s;
limit_req_status 429;  # rejected requests get 503 unless this is set
server {
# ...
# General API rate limiting
location /api/ {
limit_req zone=api burst=20 nodelay;
proxy_pass http://fastapi_backend;
# ... proxy headers
}
# Strict rate limiting for auth endpoints
location /api/auth/ {
limit_req zone=login burst=5 nodelay;
proxy_pass http://fastapi_backend;
# ... proxy headers
}
}
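`limit_req` behaves like a leaky bucket:
- Each client key accumulates "excess", which drains at `rate`.
- With `nodelay`, a request is rejected as soon as the excess would go above `burst`.

The sketch below is a simplified single-key model of that bookkeeping, for intuition only. Nginx itself stores the excess in milli-requests in shared memory, and `LeakyBucket` is an illustrative name.

```python
from typing import Optional


class LeakyBucket:
    """Simplified model of Nginx limit_req (burst + nodelay) for one client key."""

    def __init__(self, rate: float, burst: int):
        self.rate = rate    # requests drained per second (rate=10r/s -> 10)
        self.burst = burst  # extra requests tolerated above the steady rate
        self.excess = 0.0
        self.last: Optional[float] = None

    def allow(self, now: float) -> bool:
        if self.last is None:
            # First request from this key: accept and start tracking
            self.last = now
            return True
        elapsed = max(now - self.last, 0.0)
        excess = max(self.excess - self.rate * elapsed + 1, 0.0)
        if excess > self.burst:
            # Rejected: Nginx answers 503, or the status set by limit_req_status.
            # State is left unchanged, as in Nginx.
            return False
        self.excess = excess
        self.last = now
        return True
```

With the `api` zone settings (`rate=10r/s`, `burst=20`), a client can fire 21 requests at once before the next is rejected. After that, roughly one more request is admitted every 100 ms.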
AWS offers multiple ways to deploy FastAPI, from virtual servers (EC2) to managed containers (ECS/Fargate) to serverless (Lambda). Each approach has different trade-offs in cost, complexity, and scalability.
EC2 gives you full control over the server environment. This is a good starting point for teams familiar with server administration.
#!/bin/bash
# ec2-setup.sh - Run on a fresh Ubuntu 22.04 EC2 instance
# Update system
sudo apt-get update && sudo apt-get upgrade -y
# Install Python 3.12
sudo add-apt-repository ppa:deadsnakes/ppa -y
sudo apt-get install -y python3.12 python3.12-venv python3.12-dev
# Install Nginx
sudo apt-get install -y nginx certbot python3-certbot-nginx
# Install supervisor for process management
sudo apt-get install -y supervisor
# Create application directory
sudo mkdir -p /opt/fastapi-app
sudo chown $USER:$USER /opt/fastapi-app
# Clone your application
cd /opt/fastapi-app
git clone https://github.com/youruser/yourapp.git .
# Create virtual environment
python3.12 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
# Copy environment file
cp .env.production .env
# /etc/supervisor/conf.d/fastapi.conf
# Create /var/log/fastapi first and make it writable by www-data.
# Continuation lines must be indented.
[program:fastapi]
command=/opt/fastapi-app/venv/bin/gunicorn app.main:app
    --worker-class uvicorn.workers.UvicornWorker
    --workers 4
    --bind unix:/tmp/fastapi.sock
    --timeout 120
    --access-logfile /var/log/fastapi/access.log
    --error-logfile /var/log/fastapi/error.log
directory=/opt/fastapi-app
user=www-data
autostart=true
autorestart=true
redirect_stderr=true
stdout_logfile=/var/log/fastapi/supervisor.log
environment=
    ENVIRONMENT="production",
    DATABASE_URL="postgresql+asyncpg://user:pass@rds-endpoint:5432/mydb"
# Start the application
sudo supervisorctl reread
sudo supervisorctl update
sudo supervisorctl start fastapi
# Check status
sudo supervisorctl status fastapi
ECS Fargate runs your Docker containers without managing servers. You define a task (container specs) and a service (how many to run).
# ecs-task-definition.json
{
"family": "fastapi-app",
"networkMode": "awsvpc",
"requiresCompatibilities": ["FARGATE"],
"cpu": "512",
"memory": "1024",
"executionRoleArn": "arn:aws:iam::ACCOUNT:role/ecsTaskExecutionRole",
"containerDefinitions": [
{
"name": "fastapi",
"image": "ACCOUNT.dkr.ecr.us-east-1.amazonaws.com/fastapi-app:latest",
"portMappings": [
{
"containerPort": 8000,
"protocol": "tcp"
}
],
"environment": [
{"name": "ENVIRONMENT", "value": "production"},
{"name": "WORKERS", "value": "2"}
],
"secrets": [
{
"name": "DATABASE_URL",
"valueFrom": "arn:aws:ssm:us-east-1:ACCOUNT:parameter/fastapi/database_url"
},
{
"name": "SECRET_KEY",
"valueFrom": "arn:aws:ssm:us-east-1:ACCOUNT:parameter/fastapi/secret_key"
}
],
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/ecs/fastapi-app",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "ecs"
}
},
"healthCheck": {
"command": ["CMD-SHELL", "curl -f http://localhost:8000/health || exit 1"],
"interval": 30,
"timeout": 5,
"retries": 3,
"startPeriod": 10
}
}
]
}
# Build and push Docker image to ECR
aws ecr get-login-password --region us-east-1 | \
docker login --username AWS --password-stdin ACCOUNT.dkr.ecr.us-east-1.amazonaws.com
docker build -t fastapi-app .
docker tag fastapi-app:latest ACCOUNT.dkr.ecr.us-east-1.amazonaws.com/fastapi-app:latest
docker push ACCOUNT.dkr.ecr.us-east-1.amazonaws.com/fastapi-app:latest
# Register task definition
aws ecs register-task-definition --cli-input-json file://ecs-task-definition.json
# Update the service (create it once beforehand with aws ecs create-service)
aws ecs update-service \
--cluster fastapi-cluster \
--service fastapi-service \
--task-definition fastapi-app \
--desired-count 2 \
--force-new-deployment
Mangum is an adapter that lets you run FastAPI on AWS Lambda behind API Gateway. This is ideal for low-traffic APIs or APIs with bursty traffic patterns.
pip install mangum
# lambda_handler.py
from mangum import Mangum
from app.main import app
# Create the Lambda handler (lifespan="off": startup/shutdown events are not run)
handler = Mangum(app, lifespan="off")
# template.yaml (AWS SAM)
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Globals:
  Function:
    Timeout: 30
    MemorySize: 512
    Runtime: python3.12
Resources:
  FastAPIFunction:
    Type: AWS::Serverless::Function
    Properties:
      Handler: lambda_handler.handler
      CodeUri: .
      Events:
        ApiEvent:
          Type: HttpApi
          Properties:
            Path: /{proxy+}
            Method: ANY
        RootEvent:
          Type: HttpApi
          Properties:
            Path: /
            Method: ANY
      Environment:
        Variables:
          ENVIRONMENT: production
          DATABASE_URL: !Ref DatabaseUrl
      Policies:
        - AmazonSSMReadOnlyAccess
Parameters:
  DatabaseUrl:
    Type: AWS::SSM::Parameter::Value<String>
    Default: /fastapi/database_url
Outputs:
  ApiUrl:
    Description: API Gateway endpoint URL
    Value: !Sub "https://${ServerlessHttpApi}.execute-api.${AWS::Region}.amazonaws.com"
# Deploy with SAM
sam build
sam deploy --guided
| Feature | EC2 | ECS Fargate | Lambda |
|---|---|---|---|
| Server Management | You manage | AWS manages | Fully serverless |
| Scaling | Manual / ASG | Auto-scaling | Automatic |
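| Cost Model | Per instance (per-second billing on Linux) | Per vCPU and GB of memory, per second | Per request plus compute duration (GB-seconds) |
| Cold Start | None | None per request (new tasks take time to launch when scaling) | Yes (typically hundreds of ms to a few seconds) |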
| WebSockets | Yes | Yes | Via API Gateway |
| Best For | Full control | Containers at scale | Low/bursty traffic |
Heroku is one of the simplest platforms for deploying FastAPI. It handles infrastructure, SSL, and scaling with minimal configuration.
Create the required files in your project root:
# Procfile
web: gunicorn app.main:app --worker-class uvicorn.workers.UvicornWorker --workers 2 --bind 0.0.0.0:$PORT --timeout 120
# runtime.txt
python-3.12.3
# requirements.txt
fastapi==0.115.0
uvicorn[standard]==0.30.0
gunicorn==22.0.0
pydantic-settings==2.5.0
sqlalchemy[asyncio]==2.0.35
asyncpg==0.29.0
alembic==1.13.0
python-dotenv==1.0.1
httpx==0.27.0
# Login to Heroku
heroku login
# Create a new app
heroku create my-fastapi-app
# Add PostgreSQL addon
heroku addons:create heroku-postgresql:essential-0
# Add Redis addon
heroku addons:create heroku-redis:mini
# Set environment variables
heroku config:set \
ENVIRONMENT=production \
SECRET_KEY=$(python -c "import secrets; print(secrets.token_urlsafe(32))") \
JWT_SECRET=$(python -c "import secrets; print(secrets.token_urlsafe(32))") \
LOG_LEVEL=INFO \
LOG_FORMAT=json
# Deploy
git push heroku main
# Run migrations
heroku run alembic upgrade head
# View logs
heroku logs --tail
# Scale dynos
heroku ps:scale web=2
Add a release command to automatically run migrations on each deploy:
# Procfile (updated)
web: gunicorn app.main:app --worker-class uvicorn.workers.UvicornWorker --workers 2 --bind 0.0.0.0:$PORT
release: alembic upgrade head
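Heroku's Postgres add-on provides `DATABASE_URL` with a `postgres://` scheme. The async SQLAlchemy engine used throughout this guide needs `postgresql+asyncpg://` instead. Below is a small helper the settings layer could apply before creating the engine; the name `normalize_database_url` is hypothetical.

```python
def normalize_database_url(url: str, driver: str = "asyncpg") -> str:
    """Rewrite postgres:// and postgresql:// URLs to name an explicit SQLAlchemy driver."""
    for scheme in ("postgres://", "postgresql://"):
        if url.startswith(scheme):
            return f"postgresql+{driver}://" + url[len(scheme):]
    # Already names a driver (e.g. postgresql+asyncpg://) or is another database
    return url
```

Keeping this in one place means the application engine and Alembic see the same URL, whichever platform produced it.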
DigitalOcean offers two main options: App Platform (managed PaaS, similar to Heroku) and Droplets (virtual servers, similar to EC2).
Create an app specification file:
# .do/app.yaml
name: fastapi-app
region: nyc
services:
  - name: api
    github:
      repo: youruser/fastapi-app
      branch: main
      deploy_on_push: true
    build_command: pip install -r requirements.txt
    run_command: gunicorn app.main:app --worker-class uvicorn.workers.UvicornWorker --workers 2 --bind 0.0.0.0:$PORT
    envs:
      - key: ENVIRONMENT
        value: production
      - key: SECRET_KEY
        type: SECRET
        value: your-secret-key  # placeholder: set the real value in the dashboard or doctl, not in git
      - key: DATABASE_URL
        scope: RUN_TIME
        value: ${db.DATABASE_URL}
    instance_count: 2
    instance_size_slug: professional-xs
    http_port: 8000
    health_check:
      http_path: /health
databases:
  - engine: PG
    name: db
    num_nodes: 1
    size: db-s-dev-database
    version: "16"
# Deploy using doctl CLI
doctl apps create --spec .do/app.yaml
# List apps
doctl apps list
# View logs
doctl apps logs APP_ID --type run
For a Droplet (virtual server), the setup is similar to EC2. Create a setup script:
#!/bin/bash
# droplet-setup.sh - For Ubuntu 22.04 Droplet
# Update system
apt-get update && apt-get upgrade -y
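# Install dependencies (Ubuntu 22.04 ships Python 3.10; Python 3.12 comes from the deadsnakes PPA)
apt-get install -y software-properties-common
add-apt-repository ppa:deadsnakes/ppa -y
apt-get install -y python3.12 python3.12-venv nginx certbot python3-certbot-nginx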
# Setup application
mkdir -p /opt/fastapi-app
cd /opt/fastapi-app
git clone https://github.com/youruser/yourapp.git .
python3.12 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
# Create systemd service
cat > /etc/systemd/system/fastapi.service << 'UNIT'
[Unit]
Description=FastAPI Application
After=network.target
[Service]
User=www-data
Group=www-data
WorkingDirectory=/opt/fastapi-app
Environment="PATH=/opt/fastapi-app/venv/bin"
EnvironmentFile=/opt/fastapi-app/.env
ExecStart=/opt/fastapi-app/venv/bin/gunicorn app.main:app \
--worker-class uvicorn.workers.UvicornWorker \
--workers 4 \
--bind unix:/tmp/fastapi.sock \
--timeout 120
Restart=always
RestartSec=5
[Install]
WantedBy=multi-user.target
UNIT
# Enable and start
systemctl daemon-reload
systemctl enable fastapi
systemctl start fastapi
# Setup Nginx
cat > /etc/nginx/sites-available/fastapi << 'NGINX'
server {
listen 80;
server_name yourdomain.com;
location / {
proxy_pass http://unix:/tmp/fastapi.sock;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}
NGINX
ln -s /etc/nginx/sites-available/fastapi /etc/nginx/sites-enabled/
nginx -t && systemctl restart nginx
# Setup SSL with Let's Encrypt
certbot --nginx -d yourdomain.com --non-interactive --agree-tos -m you@email.com
Automate testing, building, and deployment with GitHub Actions. A proper CI/CD pipeline ensures every change is tested before it reaches production.
# .github/workflows/ci-cd.yml
name: CI/CD Pipeline
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
env:
  PYTHON_VERSION: "3.12"
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}
jobs:
  # ---- Lint & Type Check ----
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.PYTHON_VERSION }}
      - name: Install dependencies
        run: |
          pip install ruff mypy
          pip install -r requirements.txt
      - name: Run Ruff linter
        run: ruff check .
      - name: Run Ruff formatter check
        run: ruff format --check .
      - name: Run MyPy type checker
        run: mypy app/ --ignore-missing-imports
  # ---- Unit & Integration Tests ----
  test:
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:16-alpine
        env:
          POSTGRES_USER: postgres
          POSTGRES_PASSWORD: password
          POSTGRES_DB: test_db
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      redis:
        image: redis:7-alpine
        ports:
          - 6379:6379
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ env.PYTHON_VERSION }}
          cache: pip
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install -r requirements-dev.txt
      - name: Run tests with coverage
        env:
          DATABASE_URL: postgresql+asyncpg://postgres:password@localhost:5432/test_db
          REDIS_URL: redis://localhost:6379/0
          ENVIRONMENT: testing
          SECRET_KEY: test-secret-key
        run: |
          pytest tests/ -v --cov=app --cov-report=xml --cov-report=term
      - name: Upload coverage report
        uses: codecov/codecov-action@v4
        with:
          files: coverage.xml  # v4 uses "files"; "file" is deprecated
          token: ${{ secrets.CODECOV_TOKEN }}
          fail_ci_if_error: false
  # ---- Build Docker Image ----
  build:
    needs: [lint, test]
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
    permissions:
      contents: read
      packages: write
    steps:
      - uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Log in to Container Registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Build and push
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: |
            ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
            ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
  # ---- Deploy to Production ----
  deploy:
    needs: build
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/main'
    environment: production
    steps:
      - name: Deploy to server
        uses: appleboy/ssh-action@v1
        with:
          host: ${{ secrets.SERVER_HOST }}
          username: ${{ secrets.SERVER_USER }}
          key: ${{ secrets.SSH_PRIVATE_KEY }}
          script: |
            cd /opt/fastapi-app
            docker compose pull
            docker compose up -d --remove-orphans
            docker compose exec -T app alembic upgrade head
            docker system prune -f
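For separate staging and production deployments, replace the single `deploy` job with the jobs below. Two things need to be set up first:
- The `build` job runs only on pushes to `main`, so widen its condition, for example `if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop')`. Otherwise `deploy-staging` is always skipped.
- The manual approval comes from required reviewers configured on the `production` environment in the repository settings.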
  # ---- Deploy to Staging ----
  deploy-staging:
    needs: build
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/develop'
    environment: staging
    steps:
      - name: Deploy to staging
        uses: appleboy/ssh-action@v1
        with:
          host: ${{ secrets.STAGING_HOST }}
          username: ${{ secrets.SERVER_USER }}
          key: ${{ secrets.SSH_PRIVATE_KEY }}
          script: |
            cd /opt/fastapi-staging
            docker compose -f docker-compose.staging.yml pull
            docker compose -f docker-compose.staging.yml up -d
  # ---- Deploy to Production (manual approval) ----
  deploy-production:
    needs: build
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/main'
    environment:
      name: production
      url: https://api.yourdomain.com
    steps:
      - name: Deploy to production
        uses: appleboy/ssh-action@v1
        with:
          host: ${{ secrets.PROD_HOST }}
          username: ${{ secrets.SERVER_USER }}
          key: ${{ secrets.SSH_PRIVATE_KEY }}
          script: |
            cd /opt/fastapi-prod
            docker compose pull
            docker compose up -d --no-deps app
            docker compose exec -T app alembic upgrade head
            # Verify health
            sleep 5
            curl -f http://localhost:8000/health || exit 1
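A fixed `sleep 5` can check too early on a slow start or waste time on a fast one. Polling with a bounded number of retries is more robust.

Below is a small Python sketch a deploy script could run instead of `sleep 5` plus `curl`. It makes these assumptions:
- The function names, attempt count and delay are all illustrative.
- The check and the sleep function are injectable, so the logic can be tested without a server.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_ok(url: str, timeout: float = 3.0) -> bool:
    """Return True if GET url answers with a 2xx status."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except (urllib.error.URLError, OSError):  # HTTPError (4xx/5xx) is a URLError
        return False


def wait_until_healthy(
    check: Callable[[], bool],
    attempts: int = 10,
    delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Call check() until it succeeds or attempts run out, sleeping between tries."""
    for attempt in range(1, attempts + 1):
        if check():
            return True
        if attempt < attempts:
            sleep(delay)
    return False


if __name__ == "__main__":
    import sys

    healthy = wait_until_healthy(lambda: http_ok("http://localhost:8000/health"))
    sys.exit(0 if healthy else 1)  # non-zero exit fails the deploy step
```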
Alembic is the standard migration tool for SQLAlchemy. Managing migrations in production requires careful coordination with your deployment process to avoid downtime and data loss.
# Install Alembic
pip install alembic
# Initialize Alembic (the async template matches an asyncpg-based app)
alembic init -t async alembic
Configure Alembic to use your application’s database URL:
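# alembic/env.py (async variant, as generated by `alembic init -t async`)
# The app's DATABASE_URL uses asyncpg, which a plain synchronous engine cannot drive.
import asyncio
import os
import sys
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.database import Base  # Your SQLAlchemy Base
from app.models import *  # noqa: F401,F403  Import all models so autogenerate sees them

config = context.config

# Override sqlalchemy.url from environment
database_url = os.getenv("DATABASE_URL", "")
# Make sure the URL names the async driver (Heroku-style URLs start with postgres://)
for scheme in ("postgres://", "postgresql://"):
    if database_url.startswith(scheme):
        database_url = "postgresql+asyncpg://" + database_url[len(scheme):]
        break
# Escape % because Alembic's config goes through ConfigParser interpolation
config.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode (generates SQL script)."""
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Create an async engine and run migrations on a sync-adapted connection."""
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode (directly against database)."""
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()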
# Generate a migration from model changes
alembic revision --autogenerate -m "add_users_table"
# Review the generated migration file before applying!
# Then apply
alembic upgrade head
# Rollback one step
alembic downgrade -1
# View migration history
alembic history --verbose
# Show current revision
alembic current
Create an entrypoint script that runs migrations before starting the application:
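#!/bin/bash
# docker-entrypoint.sh
# With several replicas, every container runs this on start; prefer a single
# one-off migration job in that case to avoid concurrent upgrades.
set -e
echo "Running database migrations..."
alembic upgrade head
echo "Starting application..."
exec "$@"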
# Dockerfile (updated)
# ... (previous build steps)
# Place these before USER appuser so chmod runs as root
COPY docker-entrypoint.sh /docker-entrypoint.sh
RUN chmod +x /docker-entrypoint.sh
ENTRYPOINT ["/docker-entrypoint.sh"]
CMD ["gunicorn", "app.main:app", "--worker-class", "uvicorn.workers.UvicornWorker", "--workers", "4", "--bind", "0.0.0.0:8000"]
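For zero-downtime deployments, follow the expand-contract pattern:
1. **Expand:** change the schema in a backward-compatible way by adding the new column and backfilling it.
2. **Transition:** deploy code that writes both columns and reads the new one.
3. **Contract:** drop the old column once no running code depends on it.
# Example: Renaming a column (email -> email_address)
# Each migration is a separate file in alembic/versions/
from alembic import op
import sqlalchemy as sa
# Migration 1: Add new column (expand)
def upgrade():
    op.add_column("users", sa.Column("email_address", sa.String(255), nullable=True))
    # Backfill existing rows
    op.execute("UPDATE users SET email_address = email WHERE email_address IS NULL")
def downgrade():
    op.drop_column("users", "email_address")
# Migration 2: Make new column required and drop old (contract)
# Deploy AFTER all running code reads and writes email_address
def upgrade():
    # Backfill rows that old code inserted after Migration 1 ran
    op.execute("UPDATE users SET email_address = email WHERE email_address IS NULL")
    op.alter_column("users", "email_address", existing_type=sa.String(255), nullable=False)
    op.drop_column("users", "email")
def downgrade():
    op.add_column("users", sa.Column("email", sa.String(255), nullable=True))
    op.execute("UPDATE users SET email = email_address")
    op.alter_column("users", "email_address", existing_type=sa.String(255), nullable=True)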
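Between the two migrations, the application has to keep both columns in sync so that old and new code versions can run side by side. The following self-contained illustration uses an in-memory SQLite database as a stand-in for PostgreSQL; `save_email` is a hypothetical helper. It walks through four steps:
1. The expand step.
2. The initial backfill.
3. A write from old code that only touches `email`.
4. A dual write from new code.

```python
import sqlite3


def save_email(conn: sqlite3.Connection, user_id: int, email: str) -> None:
    """Transition-phase write: keep the old and the new column in sync."""
    conn.execute(
        "UPDATE users SET email = ?, email_address = ? WHERE id = ?",
        (email, email, user_id),
    )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT NOT NULL)")
conn.execute("INSERT INTO users (id, email) VALUES (1, 'old@example.com')")

# Expand: add the nullable column, then backfill existing rows
conn.execute("ALTER TABLE users ADD COLUMN email_address TEXT")
conn.execute("UPDATE users SET email_address = email WHERE email_address IS NULL")

# Old code that is still running writes only the old column...
conn.execute("INSERT INTO users (id, email) VALUES (2, 'late@example.com')")
# ...so backfill again before the contract step makes the column NOT NULL
conn.execute("UPDATE users SET email_address = email WHERE email_address IS NULL")

# New code writes both columns
save_email(conn, 1, "new@example.com")

rows = conn.execute("SELECT id, email, email_address FROM users ORDER BY id").fetchall()
```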
Production applications need comprehensive monitoring to detect issues before users do. This includes health checks, metrics collection, structured logging, and alerting.
# app/routers/health.py
from datetime import datetime, timezone

import redis.asyncio as redis
from fastapi import APIRouter, Depends, Response, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
from app.database import get_db

router = APIRouter(tags=["health"])


@router.get("/health")
async def health_check():
    """Basic liveness check for load balancers."""
    return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}


@router.get("/health/ready")
async def readiness_check(response: Response, db: AsyncSession = Depends(get_db)):
    """Readiness check - verifies all dependencies are available."""
    checks = {}
    # Database check
    try:
        await db.execute(text("SELECT 1"))
        checks["database"] = {"status": "healthy"}
    except Exception as e:
        checks["database"] = {"status": "unhealthy", "error": str(e)}
    # Redis check
    settings = get_settings()
    r = redis.from_url(settings.redis_url)
    try:
        await r.ping()
        checks["redis"] = {"status": "healthy"}
    except Exception as e:
        checks["redis"] = {"status": "unhealthy", "error": str(e)}
    finally:
        await r.aclose()  # redis-py >= 5.0.1; use close() on older versions
    healthy = all(c["status"] == "healthy" for c in checks.values())
    if not healthy:
        # Load balancers and orchestrators look at the status code, not the body
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
    return {
        "status": "healthy" if healthy else "unhealthy",
        "checks": checks,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
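Expose application metrics for Prometheus to scrape. When Gunicorn runs several workers, each worker process keeps its own metrics. Setting the `PROMETHEUS_MULTIPROC_DIR` environment variable (pointing at an empty, writable directory) enables prometheus_client's multiprocess mode, so `/metrics` reports totals across workers: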
pip install prometheus-fastapi-instrumentator
# app/metrics.py
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import Counter, Histogram, Gauge

# Custom metrics
REQUEST_COUNT = Counter(
    "app_requests_total",
    "Total number of requests",
    ["method", "endpoint", "status"]
)
REQUEST_DURATION = Histogram(
    "app_request_duration_seconds",
    "Request duration in seconds",
    ["method", "endpoint"],
    buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
)
ACTIVE_CONNECTIONS = Gauge(
    "app_active_connections",
    "Number of active connections"
)
DB_POOL_SIZE = Gauge(
    "app_db_pool_size",
    "Database connection pool size"
)


def setup_metrics(app):
    """Initialize Prometheus instrumentation."""
    Instrumentator(
        should_group_status_codes=False,
        should_ignore_untemplated=True,
        should_respect_env_var=False,
        excluded_handlers=["/health", "/metrics"],
        env_var_name="ENABLE_METRICS",
    ).instrument(app).expose(app, endpoint="/metrics")
Add metrics to your application factory:
# In app/main.py create_app()
from fastapi import FastAPI

from app.metrics import setup_metrics


def create_app() -> FastAPI:
    # ... previous setup ...
    setup_metrics(app)
    return app
pip install "sentry-sdk[fastapi]"
# app/sentry.py
import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration

from app.config import get_settings


def setup_sentry():
    """Initialize Sentry error tracking."""
    settings = get_settings()
    if settings.sentry_dsn:
        sentry_sdk.init(
            dsn=settings.sentry_dsn,
            environment=settings.environment,
            release=settings.app_version,
            integrations=[
                FastApiIntegration(transaction_style="endpoint"),
                SqlalchemyIntegration(),
            ],
            traces_sample_rate=0.1 if settings.environment == "production" else 1.0,
            profiles_sample_rate=0.1,
            send_default_pii=False,  # Don't send user PII
        )
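With `send_default_pii=False`, the SDK already filters common sensitive data such as cookies. Application-specific secrets, such as a custom API-key header, may still need scrubbing.

Sentry's `before_send` hook receives each event as a dict and can modify it or return `None` to drop it. The header names below are assumptions to adapt to your app.

```python
from typing import Any, Optional

# Hypothetical application-specific headers that must never reach Sentry
SENSITIVE_HEADERS = {"x-api-key", "x-internal-token"}


def scrub_event(event: dict[str, Any], hint: dict[str, Any]) -> Optional[dict[str, Any]]:
    """before_send hook: redact sensitive request headers, keep the rest of the event."""
    request = event.get("request")
    if isinstance(request, dict) and isinstance(request.get("headers"), dict):
        request["headers"] = {
            name: "[Filtered]" if name.lower() in SENSITIVE_HEADERS else value
            for name, value in request["headers"].items()
        }
    return event  # returning None instead would drop the event entirely


# Wire it up: sentry_sdk.init(..., before_send=scrub_event)
```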
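With Prometheus metrics exposed, you can build Grafana dashboards on top of them: request rates, latency percentiles from the duration histogram, error rates by status code, and active connections. A minimal Prometheus and Grafana setup: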
# docker-compose monitoring stack (add these under `services:`)
  prometheus:
    image: prom/prometheus:latest
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus_data:/prometheus
    ports:
      - "9090:9090"
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.retention.time=15d'
  grafana:
    image: grafana/grafana:latest
    ports:
      - "3000:3000"
    volumes:
      - grafana_data:/var/lib/grafana
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin  # change before exposing Grafana
# Also declare prometheus_data and grafana_data under the top-level `volumes:` key
# prometheus.yml
global:
  scrape_interval: 15s
scrape_configs:
  - job_name: "fastapi"
    static_configs:
      - targets: ["app:8000"]
    metrics_path: /metrics
FastAPI is already one of the fastest Python frameworks, but production applications can benefit from caching, async optimization, connection pooling, and profiling.
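# GOOD: Use async clients for I/O-bound operations
import asyncio
import hashlib

import httpx
from fastapi.concurrency import run_in_threadpool


async def fetch_external_data(client: httpx.AsyncClient, url: str) -> dict:
    # Reuse one AsyncClient (e.g. created at startup) so connections are pooled
    response = await client.get(url)
    response.raise_for_status()
    return response.json()


# GOOD: Offload blocking or CPU-heavy calls to a thread pool so they don't block the
# event loop (pbkdf2_hmac releases the GIL; for pure-Python CPU work use a process pool)
async def hash_password(password: str, salt: bytes) -> str:
    # salt: random per user (e.g. os.urandom(16)), stored alongside the hash;
    # 600,000 iterations follows current OWASP guidance for PBKDF2-HMAC-SHA256
    digest = await run_in_threadpool(
        hashlib.pbkdf2_hmac, "sha256", password.encode(), salt, 600_000
    )
    return digest.hex()


# GOOD: Parallel async operations
# get_user_orders, get_notifications and get_recommendations stand for your own async
# functions; don't share a single AsyncSession across concurrently gathered calls
async def get_dashboard_data(user_id: int):
    """Fetch multiple pieces of data concurrently."""
    orders, notifications, recommendations = await asyncio.gather(
        get_user_orders(user_id),
        get_notifications(user_id),
        get_recommendations(user_id),
    )
    return {
        "orders": orders,
        "notifications": notifications,
        "recommendations": recommendations,
    }


# BAD: Sequential async calls (total latency is the sum of all three)
async def get_dashboard_data_slow(user_id: int):
    orders = await get_user_orders(user_id)  # Wait...
    notifications = await get_notifications(user_id)  # Wait...
    recommendations = await get_recommendations(user_id)  # Wait...
    return {
        "orders": orders,
        "notifications": notifications,
        "recommendations": recommendations,
    }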
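`asyncio.gather` starts every call at once. When fanning out over many items, that can exhaust a database pool or trip an upstream rate limit.

A common refinement is to cap concurrency with `asyncio.Semaphore`. Below is a minimal sketch; `gather_limited` is a hypothetical helper name.

```python
import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def gather_limited(
    items: Iterable[T], func: Callable[[T], Awaitable[R]], limit: int = 10
) -> list[R]:
    """Run func(item) for every item with at most `limit` calls in flight.

    Results are returned in input order, like asyncio.gather.
    """
    semaphore = asyncio.Semaphore(limit)

    async def run_one(item: T) -> R:
        async with semaphore:
            return await func(item)

    return await asyncio.gather(*(run_one(item) for item in items))
```

A limit slightly below the database pool size leaves connections free for other requests served by the same worker.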
pip install redis
# app/cache.py
import hashlib
import json
from functools import wraps
from typing import Any, Callable, Optional

import redis.asyncio as redis

from app.config import get_settings

_redis_client: Optional[redis.Redis] = None

# Only these argument types go into cache keys. Injected objects such as an
# AsyncSession have a different repr on every request and would defeat caching.
_KEYABLE_TYPES = (str, int, float, bool, type(None))


async def get_redis() -> redis.Redis:
    """Get or create Redis client."""
    global _redis_client
    if _redis_client is None:
        settings = get_settings()
        _redis_client = redis.from_url(
            settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
    return _redis_client


async def cache_get(key: str) -> Optional[Any]:
    """Get a value from cache."""
    r = await get_redis()
    data = await r.get(key)
    if data is not None:
        return json.loads(data)
    return None


async def cache_set(key: str, value: Any, ttl: int = 300) -> None:
    """Set a JSON-serializable value in cache with TTL (default 5 minutes)."""
    r = await get_redis()
    await r.setex(key, ttl, json.dumps(value))


async def cache_delete(key: str) -> None:
    """Delete a key from cache."""
    r = await get_redis()
    await r.delete(key)


async def cache_delete_pattern(pattern: str) -> None:
    """Delete all keys matching a glob pattern (SCAN-based, unlike the blocking KEYS)."""
    r = await get_redis()
    async for key in r.scan_iter(match=pattern):
        await r.delete(key)


def cached(ttl: int = 300, prefix: str = ""):
    """Decorator for caching endpoint responses."""
    def decorator(func: Callable):
        @wraps(func)  # preserves the signature, so FastAPI still sees the parameters
        async def wrapper(*args, **kwargs):
            # Build cache key from primitive arguments only
            key_args = [a for a in args if isinstance(a, _KEYABLE_TYPES)]
            key_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, _KEYABLE_TYPES)}
            digest = hashlib.sha256(
                json.dumps([key_args, key_kwargs], sort_keys=True).encode()
            ).hexdigest()
            # Keep the prefix readable so cache_delete_pattern("product:*") matches it
            cache_key = f"{prefix}:{func.__name__}:{digest}"
            # Check cache
            cached_result = await cache_get(cache_key)
            if cached_result is not None:
                return cached_result
            # Execute function
            result = await func(*args, **kwargs)
            # Store JSON-compatible data in cache
            if isinstance(result, dict):
                await cache_set(cache_key, result, ttl)
            elif hasattr(result, "model_dump"):
                await cache_set(cache_key, result.model_dump(mode="json"), ttl)
            return result
        return wrapper
    return decorator
Use the caching decorator on your endpoints:
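from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.cache import cached, cache_delete_pattern
from app.database import get_db
# Product (ORM model), ProductResponse and ProductUpdate (schemas) come from your app
router = APIRouter()

@router.get("/products/{product_id}")
@cached(ttl=600, prefix="product")
async def get_product(product_id: int, db: AsyncSession = Depends(get_db)):
    """Get product with 10-minute cache."""
    product = await db.get(Product, product_id)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    # mode="json" turns datetimes/UUIDs into JSON-safe values for Redis
    return ProductResponse.model_validate(product).model_dump(mode="json")

@router.put("/products/{product_id}")
async def update_product(product_id: int, data: ProductUpdate, db: AsyncSession = Depends(get_db)):
    """Update product and invalidate cache."""
    product = await db.get(Product, product_id)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    # ... update logic ...
    await cache_delete_pattern("product:*")
    return ProductResponse.model_validate(product)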
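When a popular key expires, many concurrent requests can miss the cache at the same moment and all hit the database (a "cache stampede"). Within one worker process, concurrent misses for the same key can be collapsed into a single computation. The sketch makes these assumptions:
- `SingleFlight` is a hypothetical helper, not part of any library.
- It only covers a single process. Across several Gunicorn workers or hosts you would need a distributed lock, for example in Redis.

```python
import asyncio
from typing import Any, Awaitable, Callable


class SingleFlight:
    """Collapse concurrent calls for the same key into one in-flight computation."""

    def __init__(self) -> None:
        self._inflight: dict[str, asyncio.Future] = {}

    async def do(self, key: str, fn: Callable[[], Awaitable[Any]]) -> Any:
        existing = self._inflight.get(key)
        if existing is not None:
            # shield() so a cancelled waiter does not cancel the shared result
            return await asyncio.shield(existing)

        future = asyncio.get_running_loop().create_future()
        self._inflight[key] = future
        try:
            result = await fn()
        except asyncio.CancelledError:
            future.cancel()  # release waiters instead of leaving them hanging
            raise
        except Exception as exc:
            future.set_exception(exc)
            future.exception()  # mark as retrieved so it isn't logged when no one waits
            raise
        else:
            future.set_result(result)
            return result
        finally:
            del self._inflight[key]
```

In `get_product`, the cache-miss path could then become something like `await flight.do(cache_key, load_from_db)`, where `flight` is a module-level `SingleFlight()` and `load_from_db` is your own coroutine function.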
# Add GZip middleware for large responses
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware, minimum_size=1000)  # Compress responses > 1KB
Use profiling to find bottlenecks in your application:
# app/profiling.py - Development only
import cProfile
import io
import pstats
import time
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
SLOW_REQUEST_SECONDS = 0.1  # 100ms threshold
class ProfilingMiddleware(BaseHTTPMiddleware):
    """Profile requests and log slow endpoints. DEV ONLY.
    cProfile records everything on the event-loop thread, including other
    requests running concurrently, so profile with one request at a time.
    """
    async def dispatch(self, request: Request, call_next):
        profiler = cProfile.Profile()
        start = time.perf_counter()
        profiler.enable()
        response = await call_next(request)
        profiler.disable()
        # Wall-clock duration of the request (summing per-function cumulative
        # times would double-count nested calls)
        elapsed = time.perf_counter() - start
        if elapsed > SLOW_REQUEST_SECONDS:
            stream = io.StringIO()
            stats = pstats.Stats(profiler, stream=stream)
            stats.sort_stats("cumulative").print_stats(20)
            print(f"SLOW REQUEST: {request.method} {request.url.path} ({elapsed * 1000:.0f} ms)")
            print(stream.getvalue())
        return response
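The profiling middleware above relies on cProfile, which is relatively expensive and captures whatever else runs on the event loop. When you only need to know which endpoints are slow, wall-clock timing is cheaper. Below is a minimal sketch of a hypothetical `log_if_slow` decorator for async functions. The `slow_calls` list and the `print` call stand in for real logging, and the threshold is an assumption to tune.

```python
import asyncio
import functools
import time
from typing import Any, Awaitable, Callable

slow_calls: list[tuple[str, float]] = []  # collected for inspection in this sketch


def log_if_slow(threshold: float = 0.1):
    """Record async calls whose wall-clock time exceeds `threshold` seconds."""

    def decorator(func: Callable[..., Awaitable[Any]]):
        @functools.wraps(func)  # keeps the original signature visible to FastAPI
        async def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - start
                if elapsed > threshold:
                    slow_calls.append((func.__name__, elapsed))
                    print(f"SLOW CALL: {func.__name__} took {elapsed * 1000:.1f} ms")

        return wrapper

    return decorator


@log_if_slow(threshold=0.05)
async def fast_endpoint():
    return "fast"


@log_if_slow(threshold=0.05)
async def slow_endpoint():
    await asyncio.sleep(0.1)  # simulate slow I/O
    return "slow"


async def main():
    await fast_endpoint()
    await slow_endpoint()


asyncio.run(main())
```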
As your application grows, you need strategies to handle increased traffic. Scaling involves horizontal scaling (more instances), load balancing, caching layers, and rate limiting.
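# Scale to multiple instances
docker compose up -d --scale app=4
# Nginx balances across the containers that the "app" service name resolves to when
# its configuration is loaded, so reload Nginx after changing the replica count
docker compose exec nginx nginx -s reload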
# docker-compose.prod.yml - Production scaling
version: "3.9"
services:
app:
build: .
deploy:
replicas: 4
resources:
limits:
cpus: "1.0"
memory: 512M
reservations:
cpus: "0.25"
memory: 128M
restart_policy:
condition: on-failure
delay: 5s
max_attempts: 3
environment:
- DATABASE_URL=postgresql+asyncpg://user:pass@db:5432/mydb
- REDIS_URL=redis://redis:6379/0
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
- ./nginx/certs:/etc/nginx/certs
depends_on:
- app
pip install slowapi
# app/rate_limit.py
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
limiter = Limiter(
key_func=get_remote_address,
default_limits=["100/minute"],
storage_uri="redis://localhost:6379/1",
strategy="fixed-window-elastic-expiry",
)
def setup_rate_limiting(app):
"""Configure rate limiting for the application."""
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
Apply rate limits to specific endpoints:
from fastapi import Request
from app.rate_limit import limiter
@router.post("/auth/login")
@limiter.limit("5/minute")
async def login(request: Request, credentials: LoginRequest):
"""Login with strict rate limiting."""
# ... authentication logic
pass
@router.get("/api/search")
@limiter.limit("30/minute")
async def search(request: Request, q: str):
"""Search with moderate rate limiting."""
# ... search logic
pass
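To make the `"5/minute"` notation concrete, here is a minimal in-memory fixed-window counter. It is a simplified illustration of the fixed-window idea, not slowapi's or the `limits` library's implementation. It keeps state inside a single process, so unlike the Redis `storage_uri` above it would not share counts across replicas. `FixedWindowLimiter` is a hypothetical name.

```python
import time
from collections import defaultdict
from typing import Callable


class FixedWindowLimiter:
    """Allow at most `limit` hits per `window` seconds for each key (e.g. client IP)."""

    def __init__(self, limit: int, window: float, clock: Callable[[], float] = time.monotonic):
        self.limit = limit
        self.window = window
        self.clock = clock
        # key -> (window index, hits counted in that window)
        self._counters: dict[str, tuple[int, int]] = defaultdict(lambda: (-1, 0))

    def hit(self, key: str) -> bool:
        """Return True if the request is allowed, False if it should get a 429."""
        window_index = int(self.clock() // self.window)
        current_index, count = self._counters[key]
        if current_index != window_index:
            # A new window has started: reset the counter
            current_index, count = window_index, 0
        if count >= self.limit:
            return False
        self._counters[key] = (current_index, count + 1)
        return True


# Usage with a controllable clock: "5/minute" for the login endpoint
now = [0.0]
login_limiter = FixedWindowLimiter(limit=5, window=60, clock=lambda: now[0])
results = [login_limiter.hit("203.0.113.7") for _ in range(6)]
print(results)  # [True, True, True, True, True, False]
now[0] = 61.0  # the next minute begins
print(login_limiter.hit("203.0.113.7"))  # True
```

A known weakness of fixed windows is that a client can send a full burst at the end of one window and another at the start of the next; sliding-window strategies smooth this out at some extra cost.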
For long-running tasks, use a task queue to process work asynchronously:
pip install celery[redis]
# app/tasks.py
from celery import Celery
from app.config import get_settings
settings = get_settings()
celery_app = Celery(
"fastapi_tasks",
broker=settings.redis_url,
backend=settings.redis_url,
)
celery_app.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
timezone="UTC",
task_track_started=True,
task_time_limit=300, # 5 minute hard limit
task_soft_time_limit=240, # 4 minute soft limit
worker_max_tasks_per_child=100, # Restart workers after 100 tasks
)
@celery_app.task(bind=True, max_retries=3)
def send_email_task(self, to_email: str, subject: str, body: str):
"""Send email asynchronously."""
try:
# ... send email logic
pass
except Exception as exc:
self.retry(exc=exc, countdown=60) # Retry after 60 seconds
@celery_app.task
def generate_report_task(user_id: int, report_type: str):
"""Generate report in background."""
# ... heavy computation
pass
# Use in FastAPI endpoints
from celery.result import AsyncResult
from app.tasks import celery_app, generate_report_task, send_email_task
@router.post("/reports/generate")
async def generate_report(user_id: int, report_type: str):
    task = generate_report_task.delay(user_id, report_type)
    return {"task_id": task.id, "status": "processing"}
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
    # Bind the lookup to our Celery app so it uses the configured result backend
    result = AsyncResult(task_id, app=celery_app)
    return {
        "task_id": task_id,
        "status": result.status,
        # On failure result.result holds the exception, which is not JSON-serializable
        "result": result.result if result.successful() else None,
    }
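`send_email_task` retries after a fixed 60 seconds. A common alternative is exponential backoff with jitter, which prevents many failing tasks from retrying at the same moment. The helper below is a hypothetical sketch whose result you could pass as `countdown=`. Celery also provides built-in `autoretry_for` / `retry_backoff` task options, which are worth checking in its documentation before hand-rolling this.

```python
import random


def backoff_countdown(retries: int, base: float = 60.0, cap: float = 3600.0, jitter: bool = True) -> float:
    """Seconds to wait before the next retry.

    retries=0 -> 60s, 1 -> 120s, 2 -> 240s, ... capped at `cap`.
    With jitter, a random value between 0 and that delay is returned ("full jitter").
    """
    delay = min(cap, base * (2 ** retries))
    return random.uniform(0, delay) if jitter else delay


# Inside a bound task this could be used as (sketch):
#     raise self.retry(exc=exc, countdown=backoff_countdown(self.request.retries))
print([backoff_countdown(n, jitter=False) for n in range(4)])  # [60.0, 120.0, 240.0, 480.0]
```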
| Strategy | When to Use | Complexity |
|---|---|---|
| Vertical scaling (bigger server) | Quick fix, small apps | Low |
| Horizontal scaling (more instances) | High traffic, stateless apps | Medium |
| Caching (Redis) | Repeated reads, expensive queries | Medium |
| Background tasks (Celery) | Long operations, email, reports | Medium |
| Database read replicas | Read-heavy workloads | High |
| CDN for static assets | Global users, static content | Low |
| Microservices | Large teams, complex domains | Very High |
Security is not optional in production. FastAPI provides several built-in security features, but you need to configure additional layers for a properly hardened deployment.
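Always enforce HTTPS in production. If TLS terminates at a reverse proxy such as Nginx, the app itself receives plain HTTP, so the ASGI server must trust the proxy's X-Forwarded-Proto header (for Uvicorn, --proxy-headers together with --forwarded-allow-ips). Otherwise this middleware sees every request as HTTP and redirects in a loop. If the proxy already redirects HTTP to HTTPS, the middleware is optional. To enforce the redirect in the app, use the HTTPS redirect middleware: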
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
if settings.environment == "production":
app.add_middleware(HTTPSRedirectMiddleware)
# app/security.py
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add security headers to all responses."""
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "0"  # Legacy header; current guidance (e.g. OWASP) is to disable the old browser XSS filter and rely on CSP
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Permissions-Policy"] = (
"camera=(), microphone=(), geolocation=(), payment=()"
)
if request.url.scheme == "https":
response.headers["Strict-Transport-Security"] = (
"max-age=63072000; includeSubDomains; preload"
)
return response
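The middleware adds a fixed set of headers, plus HSTS when the request arrived over HTTPS. That logic can be moved into a pure function that is easy to unit-test without a running server. The sketch below uses a hypothetical `security_headers()` helper that covers a subset of the headers above. The HSTS `max-age` of 63072000 seconds equals two years (730 days × 86,400 s).

```python
def security_headers(scheme: str) -> dict[str, str]:
    """Return the security headers to attach to a response (hypothetical helper)."""
    headers = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "camera=(), microphone=(), geolocation=(), payment=()",
    }
    if scheme == "https":
        # Browsers ignore HSTS received over plain HTTP, so only send it over HTTPS
        two_years = 2 * 365 * 24 * 60 * 60  # 63,072,000 seconds
        headers["Strict-Transport-Security"] = f"max-age={two_years}; includeSubDomains; preload"
    return headers


print(security_headers("https")["Strict-Transport-Security"])
# max-age=63072000; includeSubDomains; preload
```

Inside `dispatch`, you would loop over `security_headers(request.url.scheme).items()` and assign each pair to `response.headers`. Behind a TLS-terminating proxy, `request.url.scheme` is only `https` if the server trusts the forwarded headers.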
from fastapi.middleware.cors import CORSMiddleware
# NEVER use allow_origins=["*"] in production
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://yourdomain.com",
"https://www.yourdomain.com",
"https://admin.yourdomain.com",
],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
allow_headers=["Authorization", "Content-Type", "X-Request-ID"],
expose_headers=["X-Request-ID"],
max_age=3600, # Cache preflight for 1 hour
)
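With an explicit list, the origin check comes down to an exact string comparison against the request's `Origin` header, with no prefix or subdomain matching. The sketch below is a simplified model of that decision. Starlette's `CORSMiddleware` additionally handles preflight requests, `allow_origin_regex`, and more. `allow_origin_header` is a hypothetical helper.

```python
from typing import Optional

ALLOWED_ORIGINS = frozenset({
    "https://yourdomain.com",
    "https://www.yourdomain.com",
    "https://admin.yourdomain.com",
})


def allow_origin_header(origin: Optional[str]) -> Optional[str]:
    """Value for the Access-Control-Allow-Origin header, or None to omit it.

    Matching is exact: scheme, host, and port must all be identical.
    """
    if origin is not None and origin in ALLOWED_ORIGINS:
        # With allow_credentials=True the concrete origin is echoed back;
        # browsers reject "*" on credentialed requests
        return origin
    return None


for candidate in [
    "https://yourdomain.com",
    "http://yourdomain.com",  # wrong scheme
    "https://yourdomain.com.attacker.example",  # look-alike host
    None,  # no Origin header, e.g. a same-origin or non-browser request
]:
    print(candidate, "->", allow_origin_header(candidate))
```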
Never hardcode secrets. Use environment variables and secrets management services:
# app/secrets.py
import boto3
import json
from functools import lru_cache
from pydantic_settings import BaseSettings
@lru_cache()
def get_aws_secret(secret_name: str, region: str = "us-east-1") -> dict:
"""Retrieve secrets from AWS Secrets Manager."""
client = boto3.client("secretsmanager", region_name=region)
response = client.get_secret_value(SecretId=secret_name)
return json.loads(response["SecretString"])
# Usage in settings
class Settings(BaseSettings):
@classmethod
def _load_aws_secrets(cls):
"""Load secrets from AWS Secrets Manager at startup."""
try:
secrets = get_aws_secret("fastapi/production")
return secrets
except Exception:
return {}
def __init__(self, **kwargs):
aws_secrets = self._load_aws_secrets()
# AWS secrets override env vars
for key, value in aws_secrets.items():
if key.lower() not in kwargs:
kwargs[key.lower()] = value
super().__init__(**kwargs)
from pydantic import BaseModel, Field, field_validator
import bleach
import re
class UserInput(BaseModel):
"""User input with validation and sanitization."""
username: str = Field(min_length=3, max_length=50, pattern=r"^[a-zA-Z0-9_-]+$")
email: str = Field(max_length=255)
bio: str = Field(max_length=1000, default="")
@field_validator("bio")
@classmethod
def sanitize_bio(cls, v: str) -> str:
"""Remove HTML tags from bio."""
return bleach.clean(v, tags=[], strip=True)
@field_validator("email")
@classmethod
def validate_email(cls, v: str) -> str:
"""Validate email format."""
email_regex = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_regex, v):
raise ValueError("Invalid email format")
return v.lower()
| Category | Item | Status |
|---|---|---|
| Transport | HTTPS enforced everywhere | Required |
| Transport | HSTS header enabled | Required |
| Auth | Passwords hashed with bcrypt/argon2 | Required |
| Auth | JWT tokens with short expiry | Required |
| Auth | Rate limiting on login endpoints | Required |
| Headers | Security headers on all responses | Required |
| CORS | Specific origins (no wildcards) | Required |
| Input | Pydantic validation on all inputs | Required |
| Secrets | No secrets in code or git | Required |
| Secrets | Use secrets manager (AWS SM, Vault) | Recommended |
| Dependencies | Regular dependency updates | Required |
| Docs | Disable /docs and /redoc in production | Recommended |
Here is a complete production-ready docker-compose setup with FastAPI, PostgreSQL, Redis, Nginx, Celery, and monitoring — everything you need to deploy a real-world application.
fastapi-production/
├── app/
│   ├── __init__.py
│   ├── main.py              # Application factory
│   ├── config.py            # Pydantic settings
│   ├── database.py          # Database setup
│   ├── models/              # SQLAlchemy models
│   ├── schemas/             # Pydantic schemas
│   ├── routers/             # API routes
│   ├── services/            # Business logic
│   ├── middleware.py        # Custom middleware
│   ├── cache.py             # Redis caching
│   ├── tasks.py             # Celery tasks
│   └── logging_config.py    # Structured logging
├── alembic/                 # Database migrations
│   ├── versions/
│   └── env.py
├── nginx/
│   ├── nginx.conf
│   └── certs/
├── tests/
│   ├── conftest.py
│   ├── test_routes/
│   └── test_services/
├── .github/
│   └── workflows/
│       └── ci-cd.yml
├── Dockerfile
├── docker-compose.yml       # Development
├── docker-compose.prod.yml  # Production
├── docker-entrypoint.sh
├── gunicorn.conf.py
├── requirements.txt
├── requirements-dev.txt
├── alembic.ini
├── .env.example
├── .dockerignore
└── .gitignore
# docker-compose.prod.yml
version: "3.9"
services:
# ---- FastAPI Application ----
app:
build:
context: .
dockerfile: Dockerfile
environment:
- ENVIRONMENT=production
- DATABASE_URL=postgresql+asyncpg://fastapi:${DB_PASSWORD}@db:5432/fastapi_prod
- REDIS_URL=redis://redis:6379/0
- SECRET_KEY=${SECRET_KEY}
- JWT_SECRET=${JWT_SECRET}
- LOG_LEVEL=INFO
- LOG_FORMAT=json
- WORKERS=4
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
restart: always
deploy:
replicas: 2
resources:
limits:
cpus: "1.0"
memory: 512M
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- backend
- frontend
# ---- Celery Worker ----
celery-worker:
build: .
command: celery -A app.tasks worker --loglevel=info --concurrency=4
environment:
- DATABASE_URL=postgresql+asyncpg://fastapi:${DB_PASSWORD}@db:5432/fastapi_prod
- REDIS_URL=redis://redis:6379/0
depends_on:
- db
- redis
restart: always
deploy:
replicas: 2
resources:
limits:
cpus: "0.5"
memory: 256M
networks:
- backend
# ---- Celery Beat (Scheduler) ----
celery-beat:
build: .
command: celery -A app.tasks beat --loglevel=info
environment:
- REDIS_URL=redis://redis:6379/0
depends_on:
- redis
restart: always
networks:
- backend
# ---- PostgreSQL ----
db:
image: postgres:16-alpine
environment:
POSTGRES_USER: fastapi
POSTGRES_PASSWORD: ${DB_PASSWORD}
POSTGRES_DB: fastapi_prod
volumes:
- postgres_data:/var/lib/postgresql/data
- ./init.sql:/docker-entrypoint-initdb.d/init.sql
healthcheck:
test: ["CMD-SHELL", "pg_isready -U fastapi -d fastapi_prod"]
interval: 10s
timeout: 5s
retries: 5
restart: always
deploy:
resources:
limits:
cpus: "2.0"
memory: 1G
networks:
- backend
# ---- Redis ----
redis:
image: redis:7-alpine
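# Caution: allkeys-lru may evict any key, including queued Celery messages, because this Redis is also the broker;
# consider a separate Redis instance with maxmemory-policy noeviction for the broker
command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru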
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
restart: always
networks:
- backend
# ---- Nginx Reverse Proxy ----
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro
- ./nginx/certs:/etc/nginx/certs:ro
- static_files:/app/static:ro
depends_on:
- app
restart: always
networks:
- frontend
# ---- Prometheus (Monitoring) ----
prometheus:
image: prom/prometheus:latest
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
- prometheus_data:/prometheus
ports:
- "9090:9090"
restart: always
networks:
- backend
# ---- Grafana (Dashboards) ----
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana_data:/var/lib/grafana
environment:
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD}
restart: always
networks:
- backend
volumes:
postgres_data:
redis_data:
static_files:
prometheus_data:
grafana_data:
networks:
frontend:
driver: bridge
backend:
driver: bridge
# nginx/nginx.conf (production)
worker_processes auto;
error_log /var/log/nginx/error.log warn;
pid /var/run/nginx.pid;
events {
worker_connections 1024;
use epoll;
multi_accept on;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
# Logging format
log_format json_combined escape=json
'{"time":"$time_iso8601",'
'"remote_addr":"$remote_addr",'
'"request":"$request",'
'"status":$status,'
'"body_bytes_sent":$body_bytes_sent,'
'"request_time":$request_time,'
'"upstream_response_time":"$upstream_response_time"}';
access_log /var/log/nginx/access.log json_combined;
# Performance
sendfile on;
tcp_nopush on;
tcp_nodelay on;
keepalive_timeout 65;
types_hash_max_size 2048;
# Gzip
gzip on;
gzip_vary on;
gzip_proxied any;
gzip_min_length 1024;
gzip_types text/plain text/css application/json application/javascript text/xml;
# Rate limiting
limit_req_zone $binary_remote_addr zone=api:10m rate=10r/s;
limit_req_zone $binary_remote_addr zone=auth:10m rate=1r/s;
# Upstream (load balancing across app replicas)
upstream app {
least_conn;
server app:8000;
}
# HTTP -> HTTPS redirect
server {
listen 80;
server_name _;
return 301 https://$host$request_uri;
}
# HTTPS server
server {
listen 443 ssl http2;
server_name yourdomain.com;
ssl_certificate /etc/nginx/certs/fullchain.pem;
ssl_certificate_key /etc/nginx/certs/privkey.pem;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
ssl_prefer_server_ciphers off;
# Security headers
add_header X-Frame-Options "DENY" always;
add_header X-Content-Type-Options "nosniff" always;
add_header Strict-Transport-Security "max-age=63072000" always;
client_max_body_size 10M;
# API endpoints
location /api/ {
limit_req zone=api burst=20 nodelay;
proxy_pass http://app;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Auth endpoints (strict rate limiting)
location /api/auth/ {
limit_req zone=auth burst=5 nodelay;
proxy_pass http://app;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# WebSocket
location /ws/ {
proxy_pass http://app;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_read_timeout 86400s;
}
# Health check (no logging, no rate limit)
location /health {
access_log off;
proxy_pass http://app;
}
# Static files
location /static/ {
alias /app/static/;
expires 30d;
add_header Cache-Control "public, immutable";
}
}
}
# Create .env file for production secrets
cat > .env << 'EOF'
DB_PASSWORD=your-strong-password-here
SECRET_KEY=your-secret-key-here
JWT_SECRET=your-jwt-secret-here
GRAFANA_PASSWORD=admin-password-here
EOF
# Start the full production stack
docker compose -f docker-compose.prod.yml up -d --build
# Check all services are healthy
docker compose -f docker-compose.prod.yml ps
# View application logs
docker compose -f docker-compose.prod.yml logs -f app
# Run database migrations
docker compose -f docker-compose.prod.yml exec app alembic upgrade head
# Scale application horizontally
docker compose -f docker-compose.prod.yml up -d --scale app=4
# Redeploy the app after a code change (Compose recreates the app containers,
# so this is not a true zero-downtime rolling update)
docker compose -f docker-compose.prod.yml build app
docker compose -f docker-compose.prod.yml up -d --no-deps app
# Backup database (-T disables the pseudo-TTY so the redirected dump is not corrupted)
docker compose -f docker-compose.prod.yml exec -T db pg_dump -U fastapi fastapi_prod > backup.sql
| # | Topic | Key Points |
|---|---|---|
| 1 | Configuration | Use pydantic-settings for type-safe configuration from environment variables. Never hardcode secrets. |
| 2 | ASGI Servers | Use Gunicorn with Uvicorn workers for production. Set workers to (2 * CPU) + 1 as a starting point. Enable max_requests so workers are recycled periodically, limiting the impact of memory leaks. |
| 3 | Docker | Use multi-stage builds for smaller images. Run as non-root user. Include health checks. Use .dockerignore to reduce context size. |
| 4 | Nginx | Always use Nginx as a reverse proxy. Handle SSL termination, static files, rate limiting, and WebSocket proxying at the Nginx layer. |
| 5 | AWS | EC2 for full control, ECS/Fargate for managed containers, Lambda with Mangum for serverless. Use SSM Parameter Store or Secrets Manager for secrets. |
| 6 | Heroku | Simplest deployment path. Use Procfile with Gunicorn + Uvicorn workers. Add release phase for auto-migrations. |
| 7 | DigitalOcean | App Platform for managed PaaS or Droplets with systemd for full control. Both work well for FastAPI. |
| 8 | CI/CD | GitHub Actions pipeline: lint, test with services (Postgres, Redis), build Docker image, deploy. Use environments for staging/production separation. |
| 9 | Migrations | Use Alembic for database migrations. Run migrations in Docker entrypoint or release phase. Follow expand-contract pattern for zero-downtime changes. |
| 10 | Monitoring | Health check endpoints for load balancers. Prometheus metrics with Grafana dashboards. Sentry for error tracking. Structured JSON logging. |
| 11 | Performance | Use asyncio.gather for parallel I/O. Cache with Redis. Enable GZip compression. Profile slow endpoints to find bottlenecks. |
| 12 | Scaling | Start with vertical scaling, then horizontal. Use Celery for background tasks. Rate limit with slowapi. Consider read replicas for DB-heavy workloads. |
| 13 | Security | Enforce HTTPS, add security headers, configure CORS properly, validate all inputs with Pydantic, use secrets management, disable docs in production. |
| 14 | Full Stack | Production stack: FastAPI + PostgreSQL + Redis + Nginx + Celery + Prometheus + Grafana. Use docker-compose for orchestration with health checks, resource limits, and network isolation. |
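Row 2's worker guidance can be written directly into `gunicorn.conf.py` (listed in the project layout), which Gunicorn executes as Python. A minimal sketch follows. The bind address, timeout, and `max_requests` numbers are assumptions to tune, and reading the `WORKERS` environment variable (set in docker-compose.prod.yml) is this sketch's own choice.

```python
# gunicorn.conf.py (sketch)
import multiprocessing
import os

# (2 x CPU cores) + 1 by default, overridable via the WORKERS env var
workers = int(os.getenv("WORKERS", str(multiprocessing.cpu_count() * 2 + 1)))
worker_class = "uvicorn.workers.UvicornWorker"
bind = "0.0.0.0:8000"

# Recycle each worker after roughly this many requests to contain slow memory leaks;
# the jitter staggers restarts so workers do not all recycle at the same time
max_requests = 1000
max_requests_jitter = 100

timeout = 30  # seconds a worker may be unresponsive before Gunicorn restarts it
```

With this file in place, `gunicorn -c gunicorn.conf.py app.main:app` picks up the settings (assuming the application object is `app` in `app/main.py`).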
With these configurations and practices in place, your FastAPI application is ready for production traffic. Start simple — you don’t need every component from day one. Begin with Docker + Nginx + Gunicorn, add monitoring as you grow, and scale horizontally when needed.
Security is one of the most critical aspects of any web application. A single vulnerability can expose user data, compromise accounts, and destroy user trust. FastAPI provides excellent built-in support for implementing authentication and authorization, leveraging Python’s type system and dependency injection to create secure, maintainable auth systems.
In this comprehensive tutorial, you will learn how to build a production-ready authentication and authorization system in FastAPI. We will cover everything from password hashing and JWT tokens to role-based access control, OAuth2 scopes, refresh token rotation, and security best practices. By the end, you will have a complete, reusable auth system that you can drop into any FastAPI project.
Before writing any code, it is essential to understand the distinction between authentication and authorization, and the common strategies used to implement them.
Authentication answers the question: “Who are you?” It is the process of verifying a user’s identity. When a user logs in with a username and password, the system authenticates them by checking those credentials against stored records.
Authorization answers the question: “What are you allowed to do?” It determines what resources and actions an authenticated user can access. A regular user might view their own profile, while an admin can manage all users.
| Aspect | Authentication | Authorization |
|---|---|---|
| Question | Who are you? | What can you do? |
| Purpose | Verify identity | Grant/deny access |
| When | Before authorization | After authentication |
| Example | Login with username/password | Admin-only endpoint access |
| Failure Response | 401 Unauthorized | 403 Forbidden |
| Data | Credentials (password, token, biometrics) | Roles, permissions, policies |
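The "Failure Response" row maps directly to two different `HTTPException` status codes in FastAPI. Below is a stdlib-only sketch of that decision, assuming a hypothetical user represented as a set of role strings, or `None` when no valid credentials were presented.

```python
from http import HTTPStatus
from typing import Optional


def access_status(user_roles: Optional[set[str]], required_role: str) -> HTTPStatus:
    """Decide the response status for a protected endpoint."""
    if user_roles is None:
        # Authentication failed: we do not know who the caller is
        return HTTPStatus.UNAUTHORIZED  # 401, sent with a WWW-Authenticate header
    if required_role not in user_roles:
        # Authenticated, but not permitted to do this
        return HTTPStatus.FORBIDDEN  # 403
    return HTTPStatus.OK


print(access_status(None, "admin").value)  # 401
print(access_status({"user"}, "admin").value)  # 403
print(access_status({"user", "admin"}, "admin").value)  # 200
```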
| Strategy | How It Works | Best For | Drawbacks |
|---|---|---|---|
| Session-Based | Server stores session data, client holds session ID cookie | Traditional web apps, server-rendered pages | Stateful, hard to scale horizontally |
| Token-Based (JWT) | Server issues signed token, client sends it with each request | SPAs, mobile apps, microservices | Token revocation is complex |
| API Key | Client sends a pre-shared key in header or query param | Server-to-server, third-party integrations | No user context, key rotation challenges |
| OAuth2 | Delegated auth via authorization server | Third-party login (Google, GitHub) | Complex implementation |
| Basic Auth | Username:password in Authorization header (Base64) | Simple internal tools, development | Credentials sent with every request |
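The Basic Auth row is worth seeing concretely. The header is only Base64-encoded, not encrypted, so it must only travel over HTTPS. The stdlib sketch below shows how a client builds the header and how a server decodes it, using made-up credentials. In FastAPI itself, `fastapi.security.HTTPBasic` performs this parsing for you.

```python
import base64


def basic_auth_header(username: str, password: str) -> str:
    token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii")
    return f"Basic {token}"


def parse_basic_auth(header_value: str) -> tuple[str, str]:
    scheme, _, token = header_value.partition(" ")
    if scheme.lower() != "basic" or not token:
        raise ValueError("Not a Basic Authorization header")
    decoded = base64.b64decode(token).decode("utf-8")
    # Split on the first colon only: passwords may contain ':'
    username, _, password = decoded.partition(":")
    return username, password


print(basic_auth_header("Aladdin", "open sesame"))  # Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==
print(parse_basic_auth(basic_auth_header("alice", "s3cr:et")))  # ('alice', 's3cr:et')
```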
The foundation of any password-based authentication system is secure password hashing. You must never store passwords in plaintext. Instead, you hash them using a one-way cryptographic function that makes it computationally infeasible to recover the original password.
pip install passlib[bcrypt] bcrypt
passlib is a comprehensive password hashing library that supports multiple algorithms. bcrypt is the recommended algorithm for password hashing because it is deliberately slow (making brute-force attacks expensive), includes a built-in salt, and has a configurable work factor that can be increased as hardware gets faster.
# security/password.py
from passlib.context import CryptContext
# Create a password context with bcrypt as the default scheme
pwd_context = CryptContext(
schemes=["bcrypt"],
deprecated="auto", # Automatically mark old schemes as deprecated
bcrypt__rounds=12, # Work factor (2^12 = 4096 iterations)
)
def hash_password(plain_password: str) -> str:
"""
Hash a plaintext password using bcrypt.
Args:
plain_password: The user's plaintext password.
Returns:
The bcrypt hash string (includes algorithm, rounds, salt, and hash).
"""
return pwd_context.hash(plain_password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verify a plaintext password against a stored hash.
Args:
plain_password: The password to verify.
hashed_password: The stored bcrypt hash.
Returns:
True if the password matches, False otherwise.
"""
return pwd_context.verify(plain_password, hashed_password)
A bcrypt hash looks like this:
$2b$12$LJ3m4ys3Lk0TSwMvNCH/8.VkEm8MRzIrvMnGJOgLrMwOOzcnX3iOa
$2b$: algorithm identifier (2 = bcrypt, b = current sub-version)
12$: cost factor (12 rounds = 2^12 iterations)
remaining 53 characters: the 22-character salt followed by the 31-character hash
Key points about bcrypt:
The salt is stored inside the hash string itself, so verify() can extract it and re-hash the input for comparison.
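The salt and cost factor travel inside the hash string, so anyone can read them back without a secret. The stdlib-only parser below is a sketch (`parse_bcrypt_hash` is a hypothetical helper) that splits a bcrypt string into the parts shown in the diagram above. It does not verify anything, which remains `verify_password`'s job. For the common use case of spotting hashes with an outdated cost factor, passlib's `pwd_context.needs_update()` already covers it.

```python
import re
from typing import NamedTuple

# bcrypt uses its own Base64 alphabet: ./A-Za-z0-9
BCRYPT_RE = re.compile(r"^\$(2[abxy])\$(\d{2})\$([./A-Za-z0-9]{22})([./A-Za-z0-9]{31})$")


class BcryptParts(NamedTuple):
    version: str
    cost: int
    salt: str
    checksum: str


def parse_bcrypt_hash(hashed: str) -> BcryptParts:
    match = BCRYPT_RE.match(hashed)
    if match is None:
        raise ValueError("Not a bcrypt hash")
    version, cost, salt, checksum = match.groups()
    return BcryptParts(version, int(cost), salt, checksum)


parts = parse_bcrypt_hash("$2b$12$LJ3m4ys3Lk0TSwMvNCH/8.VkEm8MRzIrvMnGJOgLrMwOOzcnX3iOa")
print(parts.version, parts.cost, 2 ** parts.cost)  # 2b 12 4096
print(parts.salt)  # LJ3m4ys3Lk0TSwMvNCH/8.
```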
# security/password_validation.py
import re
from pydantic import BaseModel, field_validator
class PasswordRequirements(BaseModel):
"""Validates password strength requirements."""
password: str
@field_validator("password")
@classmethod
def validate_password_strength(cls, v: str) -> str:
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
if len(v.encode("utf-8")) > 72:
raise ValueError("Password must not exceed 72 bytes (bcrypt ignores anything longer)")
if not re.search(r"[A-Z]", v):
raise ValueError("Password must contain at least one uppercase letter")
if not re.search(r"[a-z]", v):
raise ValueError("Password must contain at least one lowercase letter")
if not re.search(r"\d", v):
raise ValueError("Password must contain at least one digit")
if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", v):
raise ValueError(
"Password must contain at least one special character"
)
return v
# Usage example
def validate_and_hash_password(plain_password: str) -> str:
"""Validate password strength, then hash it."""
# This will raise ValidationError if password is weak
PasswordRequirements(password=plain_password)
return hash_password(plain_password)
| Practice | Why |
|---|---|
| Use bcrypt or argon2 | Purpose-built for passwords; deliberately slow |
| Never use MD5 or SHA-256 alone | Too fast; vulnerable to brute-force and rainbow tables |
| Let the library handle salts | bcrypt auto-generates cryptographically random salts |
| Set cost factor to at least 12 | Balances security and performance; increase over time |
| Enforce password complexity | Weak passwords undermine even the best hashing |
| Never log passwords | Even hashed passwords should be treated as sensitive |
| Limit password length to 72 bytes | bcrypt truncates input beyond 72 bytes |
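The 72-byte limit applies to the encoded password, not to its character count. A character-based check therefore does not enforce it for non-ASCII input: in UTF-8, an accented letter such as "é" takes two bytes, and many emoji take four. The stdlib sketch below uses a hypothetical `exceeds_bcrypt_limit` helper to measure the byte length.

```python
BCRYPT_MAX_BYTES = 72


def exceeds_bcrypt_limit(password: str) -> bool:
    """True if bcrypt would not use all of this password."""
    return len(password.encode("utf-8")) > BCRYPT_MAX_BYTES


ascii_pw = "a" * 72
accented_pw = "\u00e9" * 40  # 40 characters of "é", but 80 bytes in UTF-8

print(len(ascii_pw), exceeds_bcrypt_limit(ascii_pw))  # 72 False
print(len(accented_pw), exceeds_bcrypt_limit(accented_pw))  # 40 True
```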
FastAPI has built-in support for OAuth2 flows. The OAuth2 “password flow” (also called “Resource Owner Password Credentials”) is the simplest OAuth2 flow where the user provides their username and password directly to your application, which then returns a token.
The OAuth2 password flow works as follows:
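1. The client sends the user's username and password to the /token endpoint.
2. The server verifies the credentials and returns an access token.
3. The client sends that token in the Authorization header (Bearer <token>) for subsequent requests.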
# security/oauth2.py
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
# This tells FastAPI where the token endpoint is located.
# It also enables the "Authorize" button in Swagger UI.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")
# This dependency extracts the token from the Authorization header
async def get_current_token(token: str = Depends(oauth2_scheme)) -> str:
"""
Extract the bearer token from the Authorization header.
The OAuth2PasswordBearer dependency automatically:
1. Looks for the Authorization header
2. Checks it starts with "Bearer "
3. Extracts and returns the token string
4. Returns 401 if the header is missing or malformed
"""
return token
# routers/auth.py
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
router = APIRouter(prefix="/api/v1/auth", tags=["Authentication"])
@router.post("/token")
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends()
):
"""
OAuth2-compatible token endpoint.
OAuth2PasswordRequestForm provides:
- username: str (required)
- password: str (required)
- scope: str (optional, space-separated scopes)
- grant_type: str (optional, must be "password" if provided)
Note: OAuth2 spec requires form data (not JSON) for this endpoint.
The Content-Type must be application/x-www-form-urlencoded.
"""
# Authenticate the user (we will implement this fully later)
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Create access token
access_token = create_access_token(data={"sub": user.username})
return {
"access_token": access_token,
"token_type": "bearer",
}
The OAuth2 specification requires the token endpoint to accept form data with the application/x-www-form-urlencoded content type, not JSON. FastAPI’s OAuth2PasswordRequestForm handles this automatically. When testing with Swagger UI, the “Authorize” button sends credentials in this format.
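On the wire, the token request is an ordinary form post. The stdlib sketch below shows the body a client sends, using made-up credentials (`alice` / `s3cret`) and hypothetical scope names, and then parses it back the way the server side would.

```python
from urllib.parse import parse_qs, urlencode

form = {
    "grant_type": "password",
    "username": "alice",
    "password": "s3cret",
    "scope": "items:read items:write",  # optional, space-separated
}
body = urlencode(form)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
print(body)
# grant_type=password&username=alice&password=s3cret&scope=items%3Aread+items%3Awrite

# OAuth2PasswordRequestForm performs the equivalent parsing on the server
parsed = {key: values[0] for key, values in parse_qs(body).items()}
print(parsed["scope"].split())  # ['items:read', 'items:write']
```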
One of FastAPI’s best features is automatic OpenAPI documentation with built-in OAuth2 support. When you set tokenUrl in OAuth2PasswordBearer, Swagger UI adds an “Authorize” button that lets you log in and automatically includes the token in subsequent requests.
# main.py
from fastapi import FastAPI
from routers import auth
app = FastAPI(
title="FastAPI Auth Tutorial",
description="Complete authentication and authorization system",
version="1.0.0",
)
app.include_router(auth.router)
# Visit http://localhost:8000/docs to see the Authorize button
# Click it, enter credentials, and all protected endpoints
# will automatically include the Bearer token
JSON Web Tokens (JWT) are the industry standard for stateless authentication in modern web applications. A JWT is a compact, URL-safe token that contains claims (statements about the user) and is cryptographically signed to prevent tampering.
A JWT consists of three Base64URL-encoded parts separated by dots:
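eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c
Header: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9
Payload (claims): eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ
Signature: SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c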
| Part | Contains | Example |
|---|---|---|
| Header | Algorithm and token type | {"alg": "HS256", "typ": "JWT"} |
| Payload | Claims (user data, expiration, etc.) | {"sub": "user123", "exp": 1700000000} |
| Signature | HMAC or RSA signature for integrity | Cryptographic hash of header + payload + secret |
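Building an HS256 token by hand shows why tampering is detectable. The stdlib-only sketch below is for illustration only: it performs no expiry checks and no algorithm negotiation. In the application, use python-jose or PyJWT. `encode_hs256` and `decode_hs256` are hypothetical helpers.

```python
import base64
import hashlib
import hmac
import json


def b64url_encode(data: bytes) -> str:
    # JWTs use URL-safe Base64 without '=' padding
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def encode_hs256(payload: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        b64url_encode(json.dumps(part, separators=(",", ":")).encode()) for part in (header, payload)
    )
    signature = hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{b64url_encode(signature)}"


def decode_hs256(token: str, secret: str) -> dict:
    header_b64, payload_b64, signature_b64 = token.split(".")
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    # Constant-time comparison avoids leaking how many bytes matched
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("Invalid signature")
    return json.loads(b64url_decode(payload_b64))


token = encode_hs256({"sub": "user123", "exp": 1700000000}, "demo-secret")
print(decode_hs256(token, "demo-secret"))  # {'sub': 'user123', 'exp': 1700000000}

# Changing a claim without re-signing breaks verification
header_b64, _, signature_b64 = token.split(".")
forged_payload = b64url_encode(json.dumps({"sub": "admin", "exp": 1700000000}).encode())
try:
    decode_hs256(f"{header_b64}.{forged_payload}.{signature_b64}", "demo-secret")
except ValueError as exc:
    print(exc)  # Invalid signature
```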
| Claim | Name | Purpose |
|---|---|---|
| sub | Subject | User identifier (username, user ID) |
| exp | Expiration | When the token expires (Unix timestamp) |
| iat | Issued At | When the token was created |
| jti | JWT ID | Unique token identifier (for revocation) |
| iss | Issuer | Who issued the token |
| aud | Audience | Intended recipient |
| nbf | Not Before | Token is not valid before this time |
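JWT libraries check the time-based claims during decoding, but the rule itself is simple enough to write out. The sketch below assumes Unix-timestamp claims and a `leeway` parameter for clock skew between servers. `check_time_claims` is a hypothetical helper, and the timestamps are made up (exp is iat plus 30 minutes).

```python
import time
from typing import Optional


def check_time_claims(claims: dict, now: Optional[float] = None, leeway: float = 0) -> None:
    """Raise ValueError if the token is expired or not yet valid."""
    now = time.time() if now is None else now
    exp = claims.get("exp")
    if exp is not None and now > exp + leeway:
        raise ValueError("Token has expired")
    nbf = claims.get("nbf")
    if nbf is not None and now < nbf - leeway:
        raise ValueError("Token is not yet valid")


claims = {"sub": "user123", "iat": 1_700_000_000, "nbf": 1_700_000_000, "exp": 1_700_001_800}
check_time_claims(claims, now=1_700_000_900)  # 15 minutes in: valid, no exception
try:
    check_time_claims(claims, now=1_700_002_000)
except ValueError as exc:
    print(exc)  # Token has expired
```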
pip install python-jose[cryptography]
# or alternatively:
pip install PyJWT
# security/jwt_handler.py
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
import uuid
from jose import JWTError, jwt
from pydantic import BaseModel
# Configuration — In production, load these from environment variables
SECRET_KEY = "your-secret-key-change-in-production-use-openssl-rand-hex-32"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
class TokenData(BaseModel):
"""Schema for decoded token data."""
username: Optional[str] = None
scopes: list[str] = []
token_type: str = "access"
jti: Optional[str] = None
class TokenResponse(BaseModel):
"""Schema for token endpoint response."""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int # seconds until access token expires
def create_access_token(
data: dict[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""
Create a JWT access token.
Args:
data: Claims to include in the token (must include 'sub').
expires_delta: Custom expiration time. Defaults to 30 minutes.
Returns:
Encoded JWT string.
"""
to_encode = data.copy()
now = datetime.now(timezone.utc)
expire = now + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({
"exp": expire,
"iat": now,
"jti": str(uuid.uuid4()), # Unique token ID for revocation
"type": "access",
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(
data: dict[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""
Create a JWT refresh token with longer expiration.
Refresh tokens are used to obtain new access tokens without
requiring the user to log in again.
"""
to_encode = data.copy()
now = datetime.now(timezone.utc)
expire = now + (expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))
to_encode.update({
"exp": expire,
"iat": now,
"jti": str(uuid.uuid4()),
"type": "refresh",
})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> TokenData:
"""
Decode and validate a JWT token.
Args:
token: The JWT string to decode.
Returns:
TokenData with the decoded claims.
Raises:
JWTError: If the token is invalid, expired, or tampered with.
"""
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise JWTError("Token missing 'sub' claim")
scopes: list[str] = payload.get("scopes", [])
token_type: str = payload.get("type", "access")
jti: str = payload.get("jti")
return TokenData(
username=username,
scopes=scopes,
token_type=token_type,
jti=jti,
)
# Generate a cryptographically secure random key
openssl rand -hex 32
# Output: a1b2c3d4e5f6... (64 hex characters = 256 bits)
# Or using Python
python -c "import secrets; print(secrets.token_hex(32))"
# config.py
from pydantic_settings import BaseSettings, SettingsConfigDict
class AuthSettings(BaseSettings):
    """
    Authentication configuration loaded from environment variables.
    Set these in your .env file or system environment:
    SECRET_KEY=your-secret-key
    ALGORITHM=HS256
    ACCESS_TOKEN_EXPIRE_MINUTES=30
    REFRESH_TOKEN_EXPIRE_DAYS=7
    """
    # pydantic-settings v2 configuration (replaces the deprecated inner `class Config`)
    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True)
    SECRET_KEY: str
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7
    # Password hashing
    BCRYPT_ROUNDS: int = 12
    # Rate limiting
    LOGIN_RATE_LIMIT: str = "5/minute"
auth_settings = AuthSettings()
With password hashing and JWT tokens in place, we need a user model to store user data and a registration endpoint to create new accounts. We will use SQLAlchemy for the database model and Pydantic for request/response schemas.
pip install sqlalchemy asyncpg aiosqlite # asyncpg for PostgreSQL, aiosqlite for SQLite (development)
# database.py
from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

DATABASE_URL = "sqlite+aiosqlite:///./auth_tutorial.db"
# For PostgreSQL: "postgresql+asyncpg://user:pass@localhost/dbname"

engine = create_async_engine(DATABASE_URL, echo=True)  # echo=True logs all SQL; disable in production
async_session = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


class Base(DeclarativeBase):
    """Base class for all SQLAlchemy models."""


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that provides one database session per request.
    Commits if the request handler finishes normally and rolls back if it
    raises (including HTTPException). The `async with` block closes the session.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def init_db() -> None:
    """Create all tables on application startup."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
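`get_db` is a yield dependency. Code before `yield` runs before the endpoint, and code after it runs once the endpoint has finished. As we understand current FastAPI behavior, an exception raised by the endpoint, including `HTTPException`, is re-raised inside the dependency at the `yield`. As a result, any write made before raising is rolled back unless it was committed explicitly.

The framework-free sketch below models that lifecycle. `RecordingSession`, `session_scope`, and `handle_request` are hypothetical stand-ins for `AsyncSession`, `get_db`, and FastAPI's request handling.

```python
import asyncio
from contextlib import asynccontextmanager


class RecordingSession:
    """Hypothetical stand-in for AsyncSession that records what happened."""

    def __init__(self) -> None:
        self.events: list[str] = []

    async def commit(self) -> None:
        self.events.append("commit")

    async def rollback(self) -> None:
        self.events.append("rollback")

    async def close(self) -> None:
        self.events.append("close")


@asynccontextmanager
async def session_scope(session: RecordingSession):
    # Same shape as get_db: commit on success, roll back on any exception
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()  # plays the role of `async with async_session()`


async def handle_request(handler) -> list[str]:
    """Run a handler inside the scope, the way a yield dependency wraps an endpoint."""
    session = RecordingSession()
    try:
        async with session_scope(session) as db:
            await handler(db)
    except Exception:
        pass  # FastAPI would turn this into an error response
    return session.events
```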
# models/user.py
import enum
from datetime import datetime, timezone

from sqlalchemy import (
    Boolean, Column, DateTime, Enum, ForeignKey, Index, Integer, String,
)
from sqlalchemy.orm import relationship

from database import Base


class UserRole(str, enum.Enum):
    """User role enumeration for RBAC."""
    USER = "user"
    MODERATOR = "moderator"
    ADMIN = "admin"
    SUPER_ADMIN = "super_admin"


class User(Base):
    """
    User database model.
    Stores user credentials, profile information, and role assignments.
    """
    __tablename__ = "users"

    # Primary key
    id = Column(Integer, primary_key=True, index=True, autoincrement=True)

    # Authentication fields
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(255), unique=True, nullable=False, index=True)
    hashed_password = Column(String(255), nullable=False)

    # Profile fields
    full_name = Column(String(100), nullable=True)

    # Role and permissions
    role = Column(Enum(UserRole), default=UserRole.USER, nullable=False)

    # Account status
    is_active = Column(Boolean, default=True, nullable=False)
    is_verified = Column(Boolean, default=False, nullable=False)

    # Timestamps
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
        nullable=False,
    )
    last_login = Column(DateTime(timezone=True), nullable=True)

    # Security fields
    failed_login_attempts = Column(Integer, default=0, nullable=False)
    locked_until = Column(DateTime(timezone=True), nullable=True)

    # API key for programmatic access ("lmsc_" prefix + 64 hex chars = 69 chars)
    api_key = Column(String(128), unique=True, nullable=True, index=True)

    # Relationships
    refresh_tokens = relationship(
        "RefreshToken", back_populates="user", cascade="all, delete-orphan"
    )

    # Table indexes
    __table_args__ = (
        Index("ix_users_email_active", "email", "is_active"),
    )

    def __repr__(self) -> str:
        return f"<User(id={self.id}, username='{self.username}', role='{self.role}')>"


class RefreshToken(Base):
    """
    Stores refresh tokens for token rotation and revocation.
    Each refresh token is stored in the database so it can be:
    - Revoked individually
    - Rotated (old token invalidated when new one is issued)
    - Cleaned up when expired
    """
    __tablename__ = "refresh_tokens"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    token_jti = Column(String(36), unique=True, nullable=False, index=True)
    # The ForeignKey is required: without it SQLAlchemy cannot work out the
    # join condition for the User <-> RefreshToken relationship
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
    )
    is_revoked = Column(Boolean, default=False, nullable=False)
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    expires_at = Column(DateTime(timezone=True), nullable=False)

    # Relationship back to user
    user = relationship("User", back_populates="refresh_tokens")
# schemas/user.py
import re
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, EmailStr, field_validator
class UserCreate(BaseModel):
"""Schema for user registration requests."""
username: str
email: EmailStr
password: str
full_name: Optional[str] = None
@field_validator("username")
@classmethod
def validate_username(cls, v: str) -> str:
if len(v) < 3:
raise ValueError("Username must be at least 3 characters")
if len(v) > 50:
raise ValueError("Username must not exceed 50 characters")
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError(
"Username can only contain letters, numbers, hyphens, "
"and underscores"
)
return v.lower() # Normalize to lowercase
@field_validator("password")
@classmethod
def validate_password(cls, v: str) -> str:
if len(v) < 8:
raise ValueError("Password must be at least 8 characters")
if not re.search(r"[A-Z]", v):
raise ValueError("Password must contain an uppercase letter")
if not re.search(r"[a-z]", v):
raise ValueError("Password must contain a lowercase letter")
if not re.search(r"\d", v):
raise ValueError("Password must contain a digit")
if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", v):
raise ValueError("Password must contain a special character")
return v
class UserResponse(BaseModel):
"""Schema for user data in API responses (never includes password)."""
id: int
username: str
email: str
full_name: Optional[str] = None
role: str
is_active: bool
is_verified: bool
created_at: datetime
model_config = {"from_attributes": True}
class UserUpdate(BaseModel):
"""Schema for updating user profile."""
full_name: Optional[str] = None
email: Optional[EmailStr] = None
class UserInDB(UserResponse):
"""Schema that includes the hashed password (for internal use only)."""
hashed_password: str
# routers/users.py
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from schemas.user import UserCreate, UserResponse
from security.password import hash_password
router = APIRouter(prefix="/api/v1/users", tags=["Users"])
@router.post(
"/register",
response_model=UserResponse,
status_code=status.HTTP_201_CREATED,
summary="Register a new user",
)
async def register_user(
user_data: UserCreate,
db: AsyncSession = Depends(get_db),
):
"""
Register a new user account.
This endpoint:
1. Validates the input (username format, email format, password strength)
2. Checks for duplicate username and email
3. Hashes the password
4. Creates the user record
5. Returns the user data (without password)
"""
# Check for duplicate username
result = await db.execute(
select(User).where(User.username == user_data.username)
)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Username already registered",
)
# Check for duplicate email
result = await db.execute(
select(User).where(User.email == user_data.email)
)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already registered",
)
# Create user with hashed password
new_user = User(
username=user_data.username,
email=user_data.email,
hashed_password=hash_password(user_data.password),
full_name=user_data.full_name,
)
db.add(new_user)
await db.flush() # Flush to get the auto-generated ID
await db.refresh(new_user) # Refresh to load all fields
return new_user
# services/user_service.py
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from models.user import User
from security.password import hash_password, verify_password

MAX_FAILED_ATTEMPTS = 5
LOCKOUT_DURATION = timedelta(minutes=30)


def _as_utc(value: Optional[datetime]) -> Optional[datetime]:
    """SQLite returns naive datetimes; treat them as UTC so comparisons work."""
    if value is not None and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


class UserService:
    """
    Service layer for user-related business logic.
    Separating business logic from route handlers makes the code
    more testable and reusable.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_username(self, username: str) -> Optional[User]:
        """Find a user by username."""
        result = await self.db.execute(
            select(User).where(User.username == username)
        )
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        """Find a user by email address."""
        result = await self.db.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Find a user by ID."""
        result = await self.db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def authenticate(self, username: str, password: str) -> Optional[User]:
        """
        Authenticate a user with username and password.
        Returns the user if credentials are valid, None otherwise.
        Also handles account lockout after too many failed attempts.
        """
        # Usernames are stored lowercase (see UserCreate), so normalize here too
        user = await self.get_by_username(username.lower())
        if not user:
            # Run password hash anyway to prevent timing attacks
            # (so the response time is similar whether the user exists or not)
            hash_password("dummy-password")
            return None

        # Check if account is locked
        locked_until = _as_utc(user.locked_until)
        if locked_until and locked_until > datetime.now(timezone.utc):
            return None

        # Check if account is active
        if not user.is_active:
            return None

        # Verify password
        if not verify_password(password, user.hashed_password):
            user.failed_login_attempts += 1
            # Lock the account after too many failed attempts
            if user.failed_login_attempts >= MAX_FAILED_ATTEMPTS:
                user.locked_until = datetime.now(timezone.utc) + LOCKOUT_DURATION
            # Commit, not just flush: the caller raises HTTPException next,
            # and get_db rolls back on exceptions, which would undo the counter
            await self.db.commit()
            return None

        # Successful login: reset failed attempts
        user.failed_login_attempts = 0
        user.locked_until = None
        user.last_login = datetime.now(timezone.utc)
        await self.db.flush()
        return user

    async def get_by_api_key(self, api_key: str) -> Optional[User]:
        """Find an active user by API key."""
        result = await self.db.execute(
            select(User).where(
                User.api_key == api_key,
                User.is_active == True,  # noqa: E712 (SQLAlchemy expression)
            )
        )
        return result.scalar_one_or_none()
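`authenticate()` mixes the lockout rules with database access. Moving the rules into plain functions makes them easy to unit-test. The sketch below is a hypothetical refactor: `LockoutState`, `is_locked`, `record_failure`, and `record_success` are illustrative names, and the functions mirror the 5-attempt, 30-minute policy. The sketch also treats naive datetimes as UTC. SQLite, the development database here, does not store timezone offsets, so a `locked_until` value read back from it is typically naive, and comparing it with `datetime.now(timezone.utc)` would raise `TypeError`.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

MAX_FAILED_ATTEMPTS = 5
LOCKOUT_DURATION = timedelta(minutes=30)


def as_utc(value: Optional[datetime]) -> Optional[datetime]:
    """Interpret a naive datetime (e.g. read back from SQLite) as UTC."""
    if value is not None and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


@dataclass
class LockoutState:
    """The two User columns the lockout policy reads and writes."""
    failed_login_attempts: int = 0
    locked_until: Optional[datetime] = None


def is_locked(state: LockoutState, now: datetime) -> bool:
    locked_until = as_utc(state.locked_until)
    return locked_until is not None and locked_until > now


def record_failure(state: LockoutState, now: datetime) -> None:
    state.failed_login_attempts += 1
    if state.failed_login_attempts >= MAX_FAILED_ATTEMPTS:
        state.locked_until = now + LOCKOUT_DURATION


def record_success(state: LockoutState) -> None:
    state.failed_login_attempts = 0
    state.locked_until = None
```

Passing `now` in explicitly, instead of calling `datetime.now()` inside the functions, lets a test move the clock past the lockout window without sleeping.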
Now let us bring everything together into a complete login system that authenticates users, generates JWT tokens, and returns them to the client.
# routers/auth.py
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import RefreshToken
from schemas.user import UserResponse
from security.jwt_handler import (
create_access_token,
create_refresh_token,
decode_token,
TokenResponse,
ACCESS_TOKEN_EXPIRE_MINUTES,
REFRESH_TOKEN_EXPIRE_DAYS,
)
from services.user_service import UserService
router = APIRouter(prefix="/api/v1/auth", tags=["Authentication"])
@router.post(
"/token",
response_model=TokenResponse,
summary="Login and get access + refresh tokens",
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db),
):
"""
Authenticate user and return JWT tokens.
This endpoint:
1. Validates credentials against the database
2. Creates a short-lived access token (30 min)
3. Creates a long-lived refresh token (7 days)
4. Stores the refresh token in the database for revocation
5. Returns both tokens to the client
"""
user_service = UserService(db)
user = await user_service.authenticate(
form_data.username, form_data.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Build token claims
token_data = {
"sub": user.username,
"role": user.role.value,
"scopes": form_data.scopes, # Requested OAuth2 scopes; in production, keep only the scopes this user may hold
}
# Create tokens
access_token = create_access_token(data=token_data)
refresh_token = create_refresh_token(data={"sub": user.username})
# Decode refresh token to get its JTI and expiration
refresh_data = decode_token(refresh_token)
# Store refresh token in database for revocation tracking
db_refresh_token = RefreshToken(
token_jti=refresh_data.jti,
user_id=user.id,
expires_at=datetime.now(timezone.utc) + timedelta(
days=REFRESH_TOKEN_EXPIRE_DAYS
),
)
db.add(db_refresh_token)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
# Additional imports used by the logout endpoint (place them at the top of routers/auth.py)
from fastapi import Body
from jose import JWTError
from sqlalchemy import update


@router.post("/logout", summary="Logout and revoke refresh token")
async def logout(
    refresh_token: str = Body(..., embed=True),
    db: AsyncSession = Depends(get_db),
):
    """
    Revoke the user's refresh token.
    The token is read from the JSON body ({"refresh_token": "..."}) rather than
    the query string, so it does not end up in server or proxy access logs.
    Since JWTs are stateless, we cannot truly invalidate an access token.
    Instead, we revoke the refresh token so no new access tokens can be
    obtained. The current access token will expire naturally.
    """
    try:
        token_data = decode_token(refresh_token)
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid refresh token",
        )
    if token_data.token_type != "refresh":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Expected a refresh token",
        )
    # Find and revoke the refresh token in the database
    result = await db.execute(
        update(RefreshToken)
        .where(RefreshToken.token_jti == token_data.jti)
        .values(is_revoked=True)
    )
    if result.rowcount == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Refresh token not found",
        )
    return {"detail": "Successfully logged out"}
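The login flow stores each refresh token's `jti`, and the client example later calls `/auth/refresh`, but the refresh endpoint itself is not shown in this section. The core of refresh-token rotation with reuse detection can be sketched without the framework. Each refresh token is single-use, and a revoked token that shows up again is treated as a sign of theft, so every token for that user is revoked.

`RefreshTokenStore` is a hypothetical in-memory stand-in for the `refresh_tokens` table. A real endpoint would also decode the JWT with `decode_token`, require `token_type == "refresh"`, and check `expires_at`.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class StoredRefreshToken:
    jti: str
    user_id: int
    is_revoked: bool = False


@dataclass
class RefreshTokenStore:
    """In-memory stand-in for the refresh_tokens table (hypothetical)."""
    tokens: dict[str, StoredRefreshToken] = field(default_factory=dict)

    def issue(self, user_id: int) -> str:
        jti = str(uuid.uuid4())
        self.tokens[jti] = StoredRefreshToken(jti=jti, user_id=user_id)
        return jti

    def revoke_all(self, user_id: int) -> None:
        for token in self.tokens.values():
            if token.user_id == user_id:
                token.is_revoked = True

    def rotate(self, jti: str) -> str:
        """Exchange a valid refresh token for a new one; the old one is revoked."""
        token = self.tokens.get(jti)
        if token is None:
            raise PermissionError("Unknown refresh token")
        if token.is_revoked:
            # A revoked token being replayed suggests theft: revoke the whole family
            self.revoke_all(token.user_id)
            raise PermissionError("Refresh token reuse detected")
        token.is_revoked = True
        return self.issue(token.user_id)
```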
How and where you store tokens on the client side significantly impacts security. Here are the common approaches:
| Storage Method | Pros | Cons | Best For |
|---|---|---|---|
| HTTP-Only Cookie | Not readable by JavaScript, so an XSS payload cannot steal the token | Vulnerable to CSRF (mitigate with SameSite and CSRF tokens) | Web applications |
| localStorage | Easy to implement, persists across tabs | Vulnerable to XSS attacks | Low-security apps |
| sessionStorage | Cleared when tab closes | Vulnerable to XSS, lost on tab close | Temporary sessions |
| In-memory variable | Safest from storage attacks | Lost on page refresh | High-security SPAs |
# Alternative login endpoint that sets cookies instead of returning tokens
@router.post("/login", summary="Login with cookie-based token storage")
async def login_with_cookies(
    response: Response,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
):
    """Login and set tokens as HTTP-only cookies."""
    user_service = UserService(db)
    user = await user_service.authenticate(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
        )

    access_token = create_access_token(
        data={"sub": user.username, "role": user.role.value}
    )
    refresh_token = create_refresh_token(data={"sub": user.username})

    # Track the refresh token so it can be revoked, as the /token endpoint does
    refresh_data = decode_token(refresh_token)
    db.add(RefreshToken(
        token_jti=refresh_data.jti,
        user_id=user.id,
        expires_at=datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
    ))

    # Set the access token cookie. Store the raw JWT: OAuth2PasswordBearer only
    # reads the Authorization header, so this cookie needs its own dependency
    # anyway, and a value without a space avoids cookie quoting.
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,      # JavaScript cannot access this cookie
        secure=True,        # Only sent over HTTPS
        samesite="strict",  # Not sent with cross-site requests
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        path="/",
    )
    # Set the refresh token cookie
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        httponly=True,
        secure=True,
        samesite="strict",
        max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400,
        path="/api/v1/auth/refresh",  # Only sent to the refresh endpoint
    )
    return {"detail": "Login successful", "username": user.username}
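`OAuth2PasswordBearer` only looks at the `Authorization` header. An API that also accepts the `access_token` cookie therefore needs its own extraction step. Below is a framework-free sketch of that step; `extract_access_token` is a hypothetical helper. It prefers a Bearer header, falls back to the cookie, and tolerates a cookie value that carries a `"Bearer "` prefix. In FastAPI you could pass it `request.headers.get("Authorization")` and `request.cookies.get("access_token")` from a `Request`, then hand the result to `decode_token`.

```python
from typing import Optional


def extract_access_token(
    authorization: Optional[str], access_token_cookie: Optional[str]
) -> Optional[str]:
    """Return the raw JWT from a Bearer header or the access_token cookie."""
    # 1. An explicit "Authorization: Bearer <token>" header wins
    if authorization:
        scheme, _, token = authorization.partition(" ")
        if scheme.lower() == "bearer" and token.strip():
            return token.strip()
    # 2. Otherwise fall back to the HTTP-only cookie set at login
    if access_token_cookie:
        value = access_token_cookie.strip()
        if value.lower().startswith("bearer "):
            value = value[len("bearer "):].strip()
        return value or None
    return None
```

A `None` result maps naturally onto the 401 that `get_current_user` already raises.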
# Example client-side usage with the requests library
import requests
BASE_URL = "http://localhost:8000/api/v1"
# Step 1: Login to get tokens
login_response = requests.post(
f"{BASE_URL}/auth/token",
data={ # Note: form data, not JSON
"username": "johndoe",
"password": "SecurePass123!",
},
)
tokens = login_response.json()
access_token = tokens["access_token"]
refresh_token = tokens["refresh_token"]
print(f"Access Token: {access_token[:20]}...")
print(f"Expires in: {tokens['expires_in']} seconds")
# Step 2: Access a protected endpoint
headers = {"Authorization": f"Bearer {access_token}"}
profile_response = requests.get(
f"{BASE_URL}/users/me",
headers=headers,
)
print(f"Profile: {profile_response.json()}")
# Step 3: Refresh the access token when it expires
refresh_response = requests.post(
f"{BASE_URL}/auth/refresh",
json={"refresh_token": refresh_token},
)
new_tokens = refresh_response.json()
print(f"New Access Token: {new_tokens['access_token'][:20]}...")
FastAPI’s dependency injection system is most useful when protecting endpoints. We create reusable dependencies that validate tokens and retrieve the current user, then inject them into any endpoint that requires authentication.
# security/dependencies.py
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from security.jwt_handler import decode_token
from services.user_service import UserService
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
) -> User:
"""
Dependency that validates the JWT token and returns the current user.
This dependency:
1. Extracts the Bearer token from the Authorization header
2. Decodes and validates the JWT (checks signature, expiration)
3. Extracts the username from the 'sub' claim
4. Looks up the user in the database
5. Returns the User object (the is_active check is done by get_current_active_user)
If any step fails, it raises a 401 Unauthorized error.
Usage:
@router.get("/protected")
async def protected_route(user: User = Depends(get_current_user)):
return {"message": f"Hello, {user.username}"}
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# Decode the JWT token
token_data = decode_token(token)
if token_data.username is None:
raise credentials_exception
# Ensure this is an access token, not a refresh token
if token_data.token_type != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type. Use an access token.",
headers={"WWW-Authenticate": "Bearer"},
)
except JWTError:
raise credentials_exception
# Look up the user in the database
user_service = UserService(db)
user = await user_service.get_by_username(token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
Dependency that ensures the user is active.
Builds on get_current_user by adding an active check.
Use this for most protected endpoints.
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is deactivated",
)
return current_user
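# routers/users.py (continued)
# These routes live in the same module as /register, so they reuse the
# existing `router`; only the extra imports they need are listed here.
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from schemas.user import UserResponse, UserUpdate
from security.dependencies import get_current_active_user
from services.user_service import UserService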
@router.get("/me", response_model=UserResponse, summary="Get current user profile")
async def get_my_profile(
current_user: User = Depends(get_current_active_user),
):
"""
Return the profile of the currently authenticated user.
This endpoint requires a valid Bearer token in the Authorization header.
The token is automatically validated by the get_current_active_user
dependency, which also retrieves the full user object from the database.
"""
return current_user
@router.put("/me", response_model=UserResponse, summary="Update current user profile")
async def update_my_profile(
updates: UserUpdate,
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""Update the authenticated user's profile."""
if updates.full_name is not None:
current_user.full_name = updates.full_name
if updates.email is not None:
# Check if new email is already taken
existing = await UserService(db).get_by_email(updates.email)
if existing and existing.id != current_user.id:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already in use",
)
current_user.email = updates.email
await db.flush()
await db.refresh(current_user)
return current_user
@router.delete("/me", summary="Delete current user account")
async def delete_my_account(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""Soft-delete the authenticated user's account."""
current_user.is_active = False
await db.flush()
return {"detail": "Account deactivated successfully"}
FastAPI’s dependency injection creates a clean chain of responsibilities. Understanding this chain helps you design your auth system:
# The dependency chain for a protected endpoint:
#
# HTTP Request
# |
# v
# OAuth2PasswordBearer --> Extracts "Bearer <token>" from header
# |
# v
# get_current_user --> Decodes JWT, looks up user in DB
# |
# v
# get_current_active_user --> Checks user.is_active
# |
# v
# require_role(UserRole.ADMIN) --> Checks user.role (RBAC - next section)
# |
# v
# Your endpoint handler --> Receives the validated User object
# Each dependency in the chain can:
# 1. Raise HTTPException to short-circuit the request
# 2. Pass data to the next dependency via return values
# 3. Access other dependencies (like the database session)
# Example: Stacking dependencies
@router.get("/admin/dashboard")
async def admin_dashboard(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""
Both dependencies are resolved:
- get_current_active_user validates the token AND checks is_active
- get_db provides a database session
FastAPI resolves the entire dependency tree automatically.
"""
pass
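By default FastAPI also caches dependency results within a single request (`Depends(..., use_cache=True)`). Here `get_db` is required by both `get_current_user` and the endpoint, yet it runs once, and both receive the same session. The toy resolver below is a simplified, synchronous model of that behavior. `Dep` and `solve` are hypothetical names, not FastAPI internals.

```python
import inspect


class Dep:
    """Toy stand-in for fastapi.Depends: marks a parameter as a dependency."""

    def __init__(self, dependency):
        self.dependency = dependency


def solve(func, cache):
    """Call func after resolving its Dep(...) parameters depth-first.

    `cache` lives for one request: each dependency runs at most once and
    its result is reused, mirroring FastAPI's default use_cache=True.
    """
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if isinstance(param.default, Dep):
            dependency = param.default.dependency
            if dependency not in cache:
                cache[dependency] = solve(dependency, cache)
            kwargs[name] = cache[dependency]
    return func(**kwargs)


calls = []


def get_db():
    calls.append("get_db")
    return object()  # stands in for an AsyncSession


def get_current_user(db=Dep(get_db)):
    return {"username": "johndoe", "db": db}


def get_current_active_user(user=Dep(get_current_user)):
    return user


def endpoint(user=Dep(get_current_active_user), db=Dep(get_db)):
    # True only if the endpoint and get_current_user share one session
    return user["db"] is db
```

Each new request starts with an empty cache, so every request gets its own session.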
# Sometimes you want endpoints that work for both anonymous and
# authenticated users (e.g., a public profile that shows extra
# data to the profile owner).
from fastapi.security import OAuth2PasswordBearer

# auto_error=False makes the dependency return None instead of
# raising 401 when no token is provided
oauth2_scheme_optional = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/token",
    auto_error=False,
)


async def get_current_user_optional(
    token: str | None = Depends(oauth2_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> User | None:
    """
    Optional authentication dependency.
    Returns the User if a valid access token is provided, None otherwise.
    Does NOT raise 401 for missing or invalid tokens.
    """
    if token is None:
        return None
    try:
        token_data = decode_token(token)
    except JWTError:
        return None
    # decode_token also accepts refresh tokens; only access tokens authenticate
    if token_data.token_type != "access":
        return None
    user = await UserService(db).get_by_username(token_data.username)
    if user is None or not user.is_active:
        return None
    return user


# Usage
@router.get("/posts/{post_id}")
async def get_post(
    post_id: int,
    current_user: User | None = Depends(get_current_user_optional),
):
    """
    Public endpoint that shows extra data for authenticated users.
    """
    # get_post_by_id and get_post_analytics are placeholder data-access
    # helpers that this tutorial does not define
    post = await get_post_by_id(post_id)
    response = {"title": post.title, "content": post.content}
    if current_user and current_user.id == post.author_id:
        response["edit_url"] = f"/posts/{post_id}/edit"
        response["analytics"] = await get_post_analytics(post_id)
    return response
Role-Based Access Control restricts system access based on the roles assigned to users. Instead of checking individual permissions for each user, you assign users to roles, and roles have predefined sets of permissions. This simplifies permission management significantly.
# security/rbac.py
import enum
from typing import Optional
from fastapi import Depends, HTTPException, status
from models.user import User, UserRole
from security.dependencies import get_current_active_user
class Permission(str, enum.Enum):
"""Fine-grained permissions for the application."""
# User permissions
READ_OWN_PROFILE = "read:own_profile"
UPDATE_OWN_PROFILE = "update:own_profile"
DELETE_OWN_ACCOUNT = "delete:own_account"
# Post permissions
CREATE_POST = "create:post"
READ_POST = "read:post"
UPDATE_OWN_POST = "update:own_post"
DELETE_OWN_POST = "delete:own_post"
# Admin permissions
READ_ALL_USERS = "read:all_users"
UPDATE_ANY_USER = "update:any_user"
DELETE_ANY_USER = "delete:any_user"
UPDATE_ANY_POST = "update:any_post"
DELETE_ANY_POST = "delete:any_post"
MANAGE_ROLES = "manage:roles"
VIEW_AUDIT_LOG = "view:audit_log"
# Super admin
MANAGE_SYSTEM = "manage:system"
# Map roles to their permissions
ROLE_PERMISSIONS: dict[UserRole, set[Permission]] = {
UserRole.USER: {
Permission.READ_OWN_PROFILE,
Permission.UPDATE_OWN_PROFILE,
Permission.DELETE_OWN_ACCOUNT,
Permission.CREATE_POST,
Permission.READ_POST,
Permission.UPDATE_OWN_POST,
Permission.DELETE_OWN_POST,
},
UserRole.MODERATOR: {
# Inherits all USER permissions plus:
Permission.READ_OWN_PROFILE,
Permission.UPDATE_OWN_PROFILE,
Permission.DELETE_OWN_ACCOUNT,
Permission.CREATE_POST,
Permission.READ_POST,
Permission.UPDATE_OWN_POST,
Permission.DELETE_OWN_POST,
# Moderator-specific
Permission.UPDATE_ANY_POST,
Permission.DELETE_ANY_POST,
Permission.READ_ALL_USERS,
},
UserRole.ADMIN: {
# Inherits all MODERATOR permissions plus:
Permission.READ_OWN_PROFILE,
Permission.UPDATE_OWN_PROFILE,
Permission.DELETE_OWN_ACCOUNT,
Permission.CREATE_POST,
Permission.READ_POST,
Permission.UPDATE_OWN_POST,
Permission.DELETE_OWN_POST,
Permission.UPDATE_ANY_POST,
Permission.DELETE_ANY_POST,
Permission.READ_ALL_USERS,
# Admin-specific
Permission.UPDATE_ANY_USER,
Permission.DELETE_ANY_USER,
Permission.MANAGE_ROLES,
Permission.VIEW_AUDIT_LOG,
},
UserRole.SUPER_ADMIN: {
perm for perm in Permission # All permissions
},
}
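`ROLE_PERMISSIONS` spells out every permission for every role, so the "inherits" comments only hold as long as the lists are kept in sync by hand. One alternative is to build each role's set from the role below it using set union. The sketch uses a trimmed-down `Permission` enum for brevity. `missing_permissions` is a hypothetical helper that computes the same difference `require_permission` checks.

```python
import enum


class UserRole(str, enum.Enum):
    USER = "user"
    MODERATOR = "moderator"
    ADMIN = "admin"
    SUPER_ADMIN = "super_admin"


class Permission(str, enum.Enum):
    # Trimmed subset of the permissions defined above
    READ_POST = "read:post"
    CREATE_POST = "create:post"
    DELETE_ANY_POST = "delete:any_post"
    MANAGE_ROLES = "manage:roles"
    MANAGE_SYSTEM = "manage:system"


# Each role extends the one below it, so inheritance is explicit in the code
USER_PERMS = frozenset({Permission.READ_POST, Permission.CREATE_POST})
MODERATOR_PERMS = USER_PERMS | {Permission.DELETE_ANY_POST}
ADMIN_PERMS = MODERATOR_PERMS | {Permission.MANAGE_ROLES}

ROLE_PERMISSIONS: dict[UserRole, frozenset[Permission]] = {
    UserRole.USER: USER_PERMS,
    UserRole.MODERATOR: MODERATOR_PERMS,
    UserRole.ADMIN: ADMIN_PERMS,
    UserRole.SUPER_ADMIN: frozenset(Permission),  # every permission
}


def missing_permissions(role: UserRole, *required: Permission) -> set[Permission]:
    """Return the required permissions that the role does not grant."""
    return set(required) - ROLE_PERMISSIONS.get(role, frozenset())
```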
# security/rbac.py (continued)
def require_role(*allowed_roles: UserRole):
"""
Dependency factory that restricts access to users with specific roles.
Usage:
@router.get("/admin/users")
async def list_users(
user: User = Depends(require_role(UserRole.ADMIN, UserRole.SUPER_ADMIN))
):
...
"""
async def role_checker(
current_user: User = Depends(get_current_active_user),
) -> User:
if current_user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Insufficient permissions. Required role: "
f"{', '.join(r.value for r in allowed_roles)}",
)
return current_user
return role_checker
def require_permission(*required_permissions: Permission):
"""
Dependency factory that checks for specific permissions.
More granular than role checking — checks if the user's role
grants the required permissions.
Usage:
@router.delete("/posts/{post_id}")
async def delete_post(
post_id: int,
user: User = Depends(require_permission(Permission.DELETE_ANY_POST))
):
...
"""
async def permission_checker(
current_user: User = Depends(get_current_active_user),
) -> User:
user_permissions = ROLE_PERMISSIONS.get(current_user.role, set())
missing = set(required_permissions) - user_permissions
if missing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing permissions: "
f"{', '.join(p.value for p in missing)}",
)
return current_user
return permission_checker
# routers/admin.py
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User, UserRole
from schemas.user import UserResponse
from security.rbac import require_role, require_permission, Permission
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Role-based access: only admins and super admins
@router.get(
"/users",
response_model=list[UserResponse],
summary="List all users (admin only)",
)
async def list_all_users(
skip: int = 0,
limit: int = 100,
current_user: User = Depends(
require_role(UserRole.ADMIN, UserRole.SUPER_ADMIN)
),
db: AsyncSession = Depends(get_db),
):
"""List all users. Requires ADMIN or SUPER_ADMIN role."""
result = await db.execute(
select(User).offset(skip).limit(limit)
)
return result.scalars().all()
# Permission-based access: more granular
@router.put(
"/users/{user_id}/role",
response_model=UserResponse,
summary="Change a user's role",
)
async def change_user_role(
user_id: int,
new_role: UserRole,
current_user: User = Depends(
require_permission(Permission.MANAGE_ROLES)
),
db: AsyncSession = Depends(get_db),
):
"""
Change a user's role. Requires MANAGE_ROLES permission.
Security rules:
- Cannot change your own role
- Cannot assign SUPER_ADMIN role (only direct DB change)
- Cannot change the role of a SUPER_ADMIN
"""
if user_id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot change your own role",
)
if new_role == UserRole.SUPER_ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="SUPER_ADMIN role can only be assigned via database",
)
user = await db.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if user.role == UserRole.SUPER_ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot modify a SUPER_ADMIN user",
)
user.role = new_role
await db.flush()
await db.refresh(user)
return user
# Dashboard with statistics
@router.get("/dashboard", summary="Admin dashboard statistics")
async def admin_dashboard(
current_user: User = Depends(
require_role(UserRole.ADMIN, UserRole.SUPER_ADMIN)
),
db: AsyncSession = Depends(get_db),
):
"""Get system statistics for the admin dashboard."""
total_users = await db.scalar(select(func.count(User.id)))
active_users = await db.scalar(
select(func.count(User.id)).where(User.is_active == True)
)
role_counts = {}
for role in UserRole:
count = await db.scalar(
select(func.count(User.id)).where(User.role == role)
)
role_counts[role.value] = count
return {
"total_users": total_users,
"active_users": active_users,
"inactive_users": total_users - active_users,
"users_by_role": role_counts,
}
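The dashboard issues one `COUNT` query per role. A single `GROUP BY` returns all the counts in one round trip. In SQLAlchemy that is roughly `select(User.role, func.count(User.id)).group_by(User.role)`, with the keys coming back as `UserRole` members. To keep the example self-contained, the sketch below uses the standard-library `sqlite3` module against a hypothetical in-memory `users` table, and it fills in 0 for roles that have no users.

```python
import sqlite3

ROLES = ["user", "moderator", "admin", "super_admin"]


def role_counts(connection: sqlite3.Connection) -> dict[str, int]:
    """Count users per role in one query; roles with no users report 0."""
    counts = dict.fromkeys(ROLES, 0)
    rows = connection.execute("SELECT role, COUNT(id) FROM users GROUP BY role")
    counts.update(rows.fetchall())
    return counts


# Tiny in-memory demo table (hypothetical data)
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, role TEXT NOT NULL, is_active INTEGER NOT NULL)"
)
conn.executemany(
    "INSERT INTO users (role, is_active) VALUES (?, ?)",
    [("user", 1), ("user", 1), ("user", 0), ("moderator", 1), ("admin", 1)],
)
```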
# A cleaner approach using role hierarchy
# Higher roles automatically inherit all lower role permissions
ROLE_HIERARCHY: dict[UserRole, int] = {
UserRole.USER: 1,
UserRole.MODERATOR: 2,
UserRole.ADMIN: 3,
UserRole.SUPER_ADMIN: 4,
}
def require_minimum_role(minimum_role: UserRole):
"""
Require a minimum role level using hierarchy.
A user with ADMIN role automatically passes a check for
MODERATOR or USER level, because ADMIN is higher in the hierarchy.
Usage:
@router.get("/mod/reports")
async def view_reports(
user: User = Depends(require_minimum_role(UserRole.MODERATOR))
):
... # Accessible by MODERATOR, ADMIN, and SUPER_ADMIN
"""
required_level = ROLE_HIERARCHY[minimum_role]
async def check_role(
current_user: User = Depends(get_current_active_user),
) -> User:
user_level = ROLE_HIERARCHY.get(current_user.role, 0)
if user_level < required_level:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Minimum role required: {minimum_role.value}",
)
return current_user
return check_role
While JWT tokens are ideal for user-facing authentication, API keys are better suited for server-to-server communication, third-party integrations, and programmatic access. FastAPI provides built-in support for API key authentication via headers, query parameters, or cookies.
# security/api_key.py
import secrets
from datetime import datetime, timezone
from typing import Optional
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from services.user_service import UserService
# Define where to look for the API key
api_key_header = APIKeyHeader(
name="X-API-Key",
auto_error=False, # Return None instead of raising an error when the header is missing
)
def generate_api_key() -> str:
"""
Generate a cryptographically secure API key.
Format: prefix_randomhex
The prefix makes it easy to identify and rotate keys.
"""
return f"lmsc_{secrets.token_hex(32)}"
async def get_user_from_api_key(
api_key: Optional[str] = Security(api_key_header),
db: AsyncSession = Depends(get_db),
) -> Optional[User]:
"""
Validate an API key and return the associated user.
Returns None if no API key is provided (allowing fallback
to other auth methods).
"""
if api_key is None:
return None
user_service = UserService(db)
user = await user_service.get_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API key owner account is deactivated",
)
return user
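The code above looks users up by the raw key value, which implies keys are stored in plaintext. One common hardening step is to store only a SHA-256 digest of each key and compare digests in constant time. That way a leaked database does not expose usable keys. This step is not part of the original design. Below is a minimal standard-library sketch in which `hash_api_key` and `verify_api_key` are hypothetical helper names.

```python
import hashlib
import hmac
import secrets

API_KEY_PREFIX = "lmsc_"  # same prefix as generate_api_key above


def generate_api_key() -> str:
    """Prefix + 64 hex characters (32 random bytes)."""
    return f"{API_KEY_PREFIX}{secrets.token_hex(32)}"


def hash_api_key(api_key: str) -> str:
    """Digest to persist instead of the raw key (hypothetical helper)."""
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()


def verify_api_key(presented: str, stored_hash: str) -> bool:
    """Cheap prefix check first, then a constant-time digest comparison."""
    if not presented.startswith(API_KEY_PREFIX):
        return False
    return hmac.compare_digest(hash_api_key(presented), stored_hash)


key = generate_api_key()
stored = hash_api_key(key)  # persist this; show `key` to the user once
print(verify_api_key(key, stored))        # True
print(verify_api_key(key + "x", stored))  # False
```

Unlike passwords, these keys are long random strings, so a fast hash is sufficient. A slow hash like bcrypt is not needed here. The database lookup can then filter on the stored digest of the presented key.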
# security/combined_auth.py
from fastapi import Depends, HTTPException, status
from models.user import User
from security.dependencies import get_current_user_optional
from security.api_key import get_user_from_api_key
async def get_current_user_flexible(
jwt_user: User | None = Depends(get_current_user_optional),
api_key_user: User | None = Depends(get_user_from_api_key),
) -> User:
"""
Accept either JWT token OR API key authentication.
This allows endpoints to work for both:
- Browser users (JWT from login)
- API clients (API key in header)
JWT takes priority if both are provided.
"""
user = jwt_user or api_key_user
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required. Provide a Bearer token or X-API-Key.",
headers={"WWW-Authenticate": "Bearer"},
)
return user
# Usage
@router.get("/data")
async def get_data(
current_user: User = Depends(get_current_user_flexible),
):
"""This endpoint accepts both JWT and API key auth."""
return {"user": current_user.username, "data": "..."}
# routers/api_keys.py
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from security.api_key import generate_api_key
from security.dependencies import get_current_active_user
router = APIRouter(prefix="/api/v1/api-keys", tags=["API Keys"])
@router.post("/generate", summary="Generate a new API key")
async def generate_new_api_key(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""
Generate a new API key for the authenticated user.
WARNING: The API key is only shown once. Store it securely.
Generating a new key invalidates the previous one.
"""
new_key = generate_api_key()
current_user.api_key = new_key
await db.flush()
return {
"api_key": new_key,
"message": "Store this key securely. It will not be shown again.",
}
@router.delete("/revoke", summary="Revoke current API key")
async def revoke_api_key(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""Revoke the current user's API key."""
if current_user.api_key is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No API key to revoke",
)
current_user.api_key = None
await db.flush()
return {"detail": "API key revoked successfully"}
# security/rate_limit.py
import time
from collections import defaultdict
from typing import Optional
from fastapi import Depends, HTTPException, Request, status
class RateLimiter:
"""
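In-memory rate limiter using the sliding window log algorithm
(one timestamp is kept per accepted request). State lives in a single
process, so with several workers each one enforces its own limit.
For production, use Redis-based rate limiting (e.g., with
the slowapi library or a custom Redis implementation).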
"""
def __init__(self):
# {key: [timestamp, timestamp, ...]} for requests inside the window
self.requests: dict[str, list[float]] = defaultdict(list)
def is_rate_limited(
self,
key: str,
max_requests: int,
window_seconds: int,
) -> tuple[bool, dict]:
"""
Check if a key has exceeded its rate limit.
Returns:
(is_limited, info_dict) where info_dict contains
remaining requests and reset time.
"""
now = time.time()
window_start = now - window_seconds
# Remove expired entries
self.requests[key] = [
ts for ts in self.requests[key]
if ts > window_start
]
current_count = len(self.requests[key])
if current_count >= max_requests:
reset_time = self.requests[key][0] + window_seconds
return True, {
"limit": max_requests,
"remaining": 0,
"reset": int(reset_time),
}
# Record this request
self.requests[key].append(now)
return False, {
"limit": max_requests,
"remaining": max_requests - current_count - 1,
"reset": int(now + window_seconds),
}
# Global rate limiter instance
rate_limiter = RateLimiter()
def rate_limit(max_requests: int = 60, window_seconds: int = 60):
"""
Rate limiting dependency factory.
Usage:
@router.get("/data", dependencies=[Depends(rate_limit(100, 60))])
async def get_data():
...
"""
async def check_rate_limit(request: Request):
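# Use API key or IP address as the rate limit key.
# Caution: the header is not validated here, so a client could send random
# X-API-Key values to get fresh buckets; validate the key first in production.
api_key = request.headers.get("X-API-Key")
client_host = request.client.host if request.client else "unknown"
client_key = f"apikey:{api_key}" if api_key else client_host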
is_limited, info = rate_limiter.is_rate_limited(
client_key, max_requests, window_seconds
)
if is_limited:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded",
headers={
"X-RateLimit-Limit": str(info["limit"]),
"X-RateLimit-Remaining": str(info["remaining"]),
"X-RateLimit-Reset": str(info["reset"]),
# Never send a zero or negative Retry-After value
"Retry-After": str(max(1, info["reset"] - int(time.time()))),
},
)
return check_rate_limit
# Usage example
@router.get(
"/api/v1/search",
dependencies=[Depends(rate_limit(max_requests=30, window_seconds=60))],
)
async def search(q: str):
"""Rate-limited search endpoint: 30 requests per minute."""
return {"query": q, "results": []}
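The same sliding-window-log algorithm can be tested deterministically if the current time is passed in instead of read from `time.time()`. The sketch below is standalone: the class name and the `now` parameter are assumptions made for illustration. It is not a drop-in replacement for the dependency above.

```python
from collections import defaultdict, deque


class SlidingWindowLog:
    """Keep one timestamp per accepted request; count those inside the window."""

    def __init__(self, max_requests: int, window_seconds: float):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.log: dict[str, deque] = defaultdict(deque)

    def allow(self, key: str, now: float) -> bool:
        entries = self.log[key]
        # Drop timestamps that have slid out of the window (oldest first)
        while entries and entries[0] <= now - self.window_seconds:
            entries.popleft()
        if len(entries) >= self.max_requests:
            return False  # rejected requests are not recorded
        entries.append(now)
        return True


limiter = SlidingWindowLog(max_requests=3, window_seconds=60)
print([limiter.allow("1.2.3.4", t) for t in (0, 1, 2, 3)])  # [True, True, True, False]
print(limiter.allow("1.2.3.4", 60.5))  # True: the request at t=0 has expired
```

A `deque` makes expiring old entries cheap, because they are always at the front. As with the original class, the state lives in one process.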
OAuth2 scopes provide a fine-grained permission system at the token level. Role-based access control checks the user’s role. Scopes, by contrast, define what a specific token is authorized to do. This is particularly useful when users want to grant limited access to third-party applications.
# security/scopes.py
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from jose import JWTError
from security.jwt_handler import decode_token
from models.user import User
from services.user_service import UserService
from database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
# Define available scopes with descriptions
# These appear in Swagger UI's Authorize dialog
OAUTH2_SCOPES = {
"profile:read": "Read your profile information",
"profile:write": "Update your profile information",
"posts:read": "Read posts",
"posts:write": "Create and edit posts",
"posts:delete": "Delete posts",
"users:read": "Read user information (admin)",
"users:write": "Modify user accounts (admin)",
"admin": "Full administrative access",
}
# OAuth2 scheme with scopes
oauth2_scheme_scoped = OAuth2PasswordBearer(
tokenUrl="/api/v1/auth/token",
scopes=OAUTH2_SCOPES,
)
# security/scopes.py (continued)
async def get_current_user_with_scopes(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme_scoped),
db: AsyncSession = Depends(get_db),
) -> User:
"""
Validate token AND check that it has the required scopes.
FastAPI's SecurityScopes automatically collects the scopes
required by the endpoint and all its dependencies.
Args:
security_scopes: Automatically populated by FastAPI with
the scopes required by the endpoint chain.
token: The JWT bearer token.
db: Database session.
Returns:
The authenticated User if token is valid and has required scopes.
"""
# Build the authenticate header value with required scopes
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
authenticate_value = "Bearer"
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": authenticate_value},
)
try:
token_data = decode_token(token)
if token_data.username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
# Look up the user
user_service = UserService(db)
user = await user_service.get_by_username(token_data.username)
if user is None:
raise credentials_exception
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is deactivated",
)
# Check that the token has all required scopes
for scope in security_scopes.scopes:
if scope not in token_data.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Token missing required scope: {scope}",
headers={"WWW-Authenticate": authenticate_value},
)
return user
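Without the token decoding and database lookup, the scope check is just a membership test. It also builds the `WWW-Authenticate` value that RFC 6750 defines for bearer tokens. RFC 6750 pairs a missing scope with HTTP 403, which is what the dependency returns. Below is a framework-free sketch; `check_scopes` and `build_authenticate_header` are hypothetical helpers.

```python
def build_authenticate_header(required_scopes: list[str]) -> str:
    """Same format as authenticate_value in the dependency above."""
    if required_scopes:
        scope_str = " ".join(required_scopes)
        return f'Bearer scope="{scope_str}"'
    return "Bearer"


def check_scopes(required_scopes: list[str], token_scopes: list[str]) -> list[str]:
    """Return the required scopes the token lacks (empty list = access allowed)."""
    granted = set(token_scopes)
    return [scope for scope in required_scopes if scope not in granted]


required = ["posts:read", "posts:write"]
print(build_authenticate_header(required))  # Bearer scope="posts:read posts:write"
print(check_scopes(required, ["posts:read"]))  # ['posts:write']
print(check_scopes(required, ["posts:read", "posts:write", "profile:read"]))  # []
```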
# routers/posts.py
from fastapi import APIRouter, Depends, Security
from models.user import User
from security.scopes import get_current_user_with_scopes
router = APIRouter(prefix="/api/v1/posts", tags=["Posts"])
@router.get("/")
async def list_posts(
current_user: User = Security(
get_current_user_with_scopes,
scopes=["posts:read"],
),
):
"""
List all posts. Requires 'posts:read' scope.
Note: We use Security() instead of Depends() when working with
scopes. Security() is a subclass of Depends() that also passes
the required scopes to the dependency.
"""
return {"posts": [], "user": current_user.username}
@router.post("/")
async def create_post(
title: str,
content: str,
current_user: User = Security(
get_current_user_with_scopes,
scopes=["posts:read", "posts:write"],
),
):
"""Create a post. Requires both 'posts:read' and 'posts:write' scopes."""
return {
"title": title,
"content": content,
"author": current_user.username,
}
@router.delete("/{post_id}")
async def delete_post(
post_id: int,
current_user: User = Security(
get_current_user_with_scopes,
scopes=["posts:delete"],
),
):
"""Delete a post. Requires 'posts:delete' scope."""
return {"detail": f"Post {post_id} deleted"}
# Admin endpoint requiring admin scope
@router.get("/admin/all")
async def admin_list_all_posts(
current_user: User = Security(
get_current_user_with_scopes,
scopes=["admin"],
),
):
"""List all posts with admin details. Requires 'admin' scope."""
return {"posts": [], "total": 0, "admin_view": True}
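# routers/auth.py (not posts.py): updated login endpoint that respects requested scopes.
# It must resolve to tokenUrl="/api/v1/auth/token", so register it on the auth
# router (prefix="/api/v1/auth"). Imports it needs beyond those above:
from fastapi.security import OAuth2PasswordRequestForm
from security.jwt_handler import create_access_token, create_refresh_token
# UserRole must also be imported from wherever the role enum is defined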
@router.post("/token")
async def login_with_scopes(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db),
):
"""
Login endpoint that issues tokens with requested scopes.
The client can request specific scopes during login.
The server validates that the user is allowed to have
those scopes based on their role.
"""
user_service = UserService(db)
user = await user_service.authenticate(
form_data.username, form_data.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
# Determine allowed scopes based on user role
allowed_scopes = get_allowed_scopes_for_role(user.role)
# Filter requested scopes to only include allowed ones
requested_scopes = form_data.scopes # list of scope strings
granted_scopes = [s for s in requested_scopes if s in allowed_scopes]
# If no scopes requested, grant default scopes for the role
if not requested_scopes:
granted_scopes = list(allowed_scopes)
token_data = {
"sub": user.username,
"scopes": granted_scopes,
"role": user.role.value,
}
access_token = create_access_token(data=token_data)
refresh_token = create_refresh_token(data={"sub": user.username})
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"scope": " ".join(granted_scopes), # OAuth2 spec: space-separated
}
def get_allowed_scopes_for_role(role: UserRole) -> set[str]:
"""Map user roles to allowed OAuth2 scopes."""
base_scopes = {"profile:read", "profile:write", "posts:read"}
role_scope_map = {
UserRole.USER: base_scopes | {"posts:write"},
UserRole.MODERATOR: base_scopes | {
"posts:write", "posts:delete", "users:read"
},
UserRole.ADMIN: base_scopes | {
"posts:write", "posts:delete",
"users:read", "users:write", "admin"
},
UserRole.SUPER_ADMIN: set(OAUTH2_SCOPES.keys()), # All scopes
}
return role_scope_map.get(role, base_scopes)
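The grant step in `login_with_scopes` works as follows: keep only the requested scopes that the role allows, or grant everything allowed when nothing is requested. That logic can be checked on its own. The table below is a trimmed, string-keyed stand-in for `get_allowed_scopes_for_role`, not a copy of it.

```python
BASE_SCOPES = {"profile:read", "profile:write", "posts:read"}

# Trimmed, string-keyed stand-in for role_scope_map (illustrative)
ROLE_SCOPES: dict[str, set[str]] = {
    "user": BASE_SCOPES | {"posts:write"},
    "moderator": BASE_SCOPES | {"posts:write", "posts:delete", "users:read"},
}


def grant_scopes(role: str, requested: list[str]) -> list[str]:
    allowed = ROLE_SCOPES.get(role, BASE_SCOPES)
    if not requested:
        # Nothing requested: grant everything the role allows (sorted for stable output)
        return sorted(allowed)
    # Silently drop scopes the role is not entitled to
    return [scope for scope in requested if scope in allowed]


print(grant_scopes("user", ["posts:read", "posts:delete"]))  # ['posts:read']
print(" ".join(grant_scopes("moderator", ["posts:delete", "users:read"])))  # posts:delete users:read
```

RFC 6749 allows a server to issue fewer scopes than requested, provided the token response includes the `scope` field. This is why the endpoint returns the granted scopes as a space-separated string.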
Refresh token rotation is a security technique where a new refresh token is issued every time the old one is used. This limits the window of vulnerability if a refresh token is compromised, because the stolen token becomes invalid after its first use.
# routers/auth.py — refresh endpoint with rotation
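from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from jose import JWTError
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import RefreshToken
from schemas.auth import TokenResponse
from security.jwt_handler import create_access_token, create_refresh_token, decode_token
from services.user_service import UserService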
class RefreshRequest(BaseModel):
refresh_token: str
@router.post(
"/refresh",
response_model=TokenResponse,
summary="Refresh access token with rotation",
)
async def refresh_access_token(
request: RefreshRequest,
db: AsyncSession = Depends(get_db),
):
"""
Exchange a refresh token for a new access token + refresh token.
Implements refresh token rotation:
1. Validates the refresh token (signature, expiration, type)
2. Checks if the token has been revoked
3. Detects token reuse (potential theft)
4. Issues new access + refresh tokens
5. Revokes the old refresh token
If token reuse is detected, ALL refresh tokens for the user
are revoked as a security measure.
"""
# Step 1: Decode and validate the refresh token
try:
token_data = decode_token(request.refresh_token)
except JWTError:
# Don't echo library error details back to the client
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Step 2: Ensure this is actually a refresh token
if token_data.token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type. Expected refresh token.",
)
# Step 3: Look up the refresh token in the database
result = await db.execute(
select(RefreshToken).where(
RefreshToken.token_jti == token_data.jti
)
)
stored_token = result.scalar_one_or_none()
if stored_token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token not found",
)
# Step 4: CRITICAL — Check for token reuse (theft detection)
if stored_token.is_revoked:
# This token was already used! Someone may have stolen it.
# Revoke ALL refresh tokens for this user as a safety measure.
await _revoke_all_user_tokens(stored_token.user_id, db)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token reuse detected. All sessions have been revoked. "
"Please log in again.",
)
# Step 5: Revoke the current refresh token (it is now used)
stored_token.is_revoked = True
# Step 6: Look up the user
user_service = UserService(db)
user = await user_service.get_by_id(stored_token.user_id)
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account not found or deactivated",
)
# Step 7: Issue new tokens
new_token_data = {
"sub": user.username,
"role": user.role.value,
}
new_access_token = create_access_token(data=new_token_data)
new_refresh_token = create_refresh_token(data={"sub": user.username})
# Step 8: Store the new refresh token
new_refresh_data = decode_token(new_refresh_token)
db_new_token = RefreshToken(
token_jti=new_refresh_data.jti,
user_id=user.id,
expires_at=datetime.now(timezone.utc) + timedelta(
days=REFRESH_TOKEN_EXPIRE_DAYS
),
)
db.add(db_new_token)
return TokenResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
async def _revoke_all_user_tokens(user_id: int, db: AsyncSession):
"""
Revoke all refresh tokens for a user.
Called when token reuse is detected as a security measure.
This forces the user (and any attacker) to log in again.
"""
from sqlalchemy import update
await db.execute(
update(RefreshToken)
.where(
RefreshToken.user_id == user_id,
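RefreshToken.is_revoked.is_(False),
)
.values(is_revoked=True)
)
# Commit rather than flush: the caller raises HTTPException right after this,
# and get_db rolls back on exceptions, which would undo the revocation
await db.commit()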
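The rotation and reuse-detection rules can be modelled with an in-memory store, which makes the theft scenario easy to follow. The attacker and the legitimate client both hold the same refresh token. Whoever uses it second trips the reuse check, and every token for that user is revoked. This is a simplified sketch that uses opaque random IDs instead of JWTs, has no expiry, and relies on a hypothetical `TokenStore` class. It is not the database-backed implementation above.

```python
import secrets
from dataclasses import dataclass


@dataclass
class StoredToken:
    user_id: int
    is_revoked: bool = False


class TokenReuseError(Exception):
    """Raised when an already-used refresh token is presented again."""


class TokenStore:
    def __init__(self) -> None:
        self.tokens: dict[str, StoredToken] = {}  # jti -> record

    def issue(self, user_id: int) -> str:
        jti = secrets.token_hex(16)
        self.tokens[jti] = StoredToken(user_id)
        return jti

    def rotate(self, jti: str) -> str:
        stored = self.tokens.get(jti)
        if stored is None:
            raise KeyError("refresh token not found")
        if stored.is_revoked:
            # Reuse detected: revoke every token this user still holds
            for record in self.tokens.values():
                if record.user_id == stored.user_id:
                    record.is_revoked = True
            raise TokenReuseError("token reuse detected")
        stored.is_revoked = True  # each refresh token is single-use
        return self.issue(stored.user_id)


store = TokenStore()
original = store.issue(user_id=1)
client_token = store.rotate(original)  # legitimate client rotates first
try:
    store.rotate(original)  # attacker replays the stolen original
except TokenReuseError:
    print("reuse detected")
print(store.tokens[client_token].is_revoked)  # True: the client must log in again
```

The model depends on two things. First, revoked tokens must still be findable, otherwise the replay looks like an unknown token. Second, in a database-backed version, the mass revocation only helps if it is committed; a transaction rolled back because of the 401 response would silently undo it.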
# services/token_cleanup.py
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from models.user import RefreshToken
async def cleanup_expired_tokens(db: AsyncSession) -> int:
"""
Remove expired refresh tokens from the database.
Revoked tokens are kept until they expire: deleting them early would turn
a replayed (stolen) token into "not found" and bypass reuse detection.
Run this periodically (e.g., daily via a scheduled task)
to keep the refresh_tokens table from growing indefinitely.
Returns:
Number of tokens removed.
"""
now = datetime.now(timezone.utc)
result = await db.execute(
delete(RefreshToken).where(RefreshToken.expires_at < now)
)
await db.commit()
return result.rowcount
# Schedule cleanup on application startup
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from database import async_session
from services.token_cleanup import cleanup_expired_tokens
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager with periodic token cleanup."""
# Start background cleanup task
cleanup_task = asyncio.create_task(periodic_cleanup())
yield # Application runs here
# Shutdown: cancel the cleanup task
cleanup_task.cancel()
try:
await cleanup_task
except asyncio.CancelledError:
pass
async def periodic_cleanup():
"""Run token cleanup every 24 hours."""
while True:
try:
async with async_session() as db:
removed = await cleanup_expired_tokens(db)
print(f"Token cleanup: removed {removed} expired/revoked tokens")
except Exception as e:
print(f"Token cleanup error: {e}")
await asyncio.sleep(86400) # 24 hours
# Use the lifespan in your app
app = FastAPI(lifespan=lifespan)
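The start/cancel pattern in `lifespan` is plain asyncio and can be tried without FastAPI. In this sketch, `periodic` and `background` are illustrative names. The very short interval stands in for the 24-hour sleep in `periodic_cleanup`.

```python
import asyncio
from contextlib import asynccontextmanager


async def periodic(job, interval: float) -> None:
    """Run job() every `interval` seconds until cancelled."""
    while True:
        try:
            await job()
        except Exception as exc:  # keep the loop alive if one run fails
            print(f"job failed: {exc}")
        await asyncio.sleep(interval)


@asynccontextmanager
async def background(job, interval: float):
    task = asyncio.create_task(periodic(job, interval))
    try:
        yield task
    finally:
        # Same shutdown sequence as lifespan: cancel, then await the task
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


async def main() -> int:
    runs = 0

    async def job() -> None:
        nonlocal runs
        runs += 1

    async with background(job, interval=0.01):
        await asyncio.sleep(0.05)  # the "application" runs here
    return runs


print(asyncio.run(main()) >= 1)  # True
```

`asyncio.CancelledError` derives from `BaseException` (Python 3.8+), so the `except Exception` inside the loop does not swallow the cancellation.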
# security/token_blacklist.py
"""
Token blacklisting for immediate access token revocation.
While refresh token rotation handles refresh tokens, sometimes
you need to immediately revoke an access token (e.g., when a user
changes their password or an admin disables an account).
For production, use Redis for O(1) lookups and automatic expiry:
pip install redis
"""
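from datetime import datetime, timezone
import redis.asyncio as redis
from fastapi import Depends, HTTPException, status
from jose import JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models.user import User
from security.dependencies import oauth2_scheme  # assumed location of the OAuth2 scheme
from security.jwt_handler import decode_token
from services.user_service import UserService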
# Redis connection for token blacklist
redis_client = redis.Redis(host="localhost", port=6379, db=0)
async def blacklist_token(jti: str, expires_at: datetime) -> None:
"""
Add a token's JTI to the blacklist.
The entry automatically expires when the token would have expired,
so the blacklist stays clean without manual cleanup.
"""
ttl = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl > 0:
await redis_client.setex(f"blacklist:{jti}", ttl, "revoked")
async def is_token_blacklisted(jti: str) -> bool:
"""Check if a token has been blacklisted."""
result = await redis_client.get(f"blacklist:{jti}")
return result is not None
# Updated get_current_user with blacklist check
async def get_current_user_with_blacklist(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db),
) -> User:
"""Validates token and checks blacklist."""
try:
token_data = decode_token(token)
except JWTError:
raise HTTPException(status_code=401, detail="Invalid token")
# Check blacklist
if token_data.jti and await is_token_blacklisted(token_data.jti):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has been revoked",
)
user_service = UserService(db)
user = await user_service.get_by_username(token_data.username)
if not user or not user.is_active:
raise HTTPException(status_code=401, detail="User not found")
return user
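For tests or single-process development, the Redis calls could be replaced by a dictionary that mimics `SETEX` expiry. The sketch below uses a hypothetical `InMemoryBlacklist` class and takes an explicit `now` so it can be tested deterministically. It is not shared across workers, which is exactly why the section recommends Redis in production.

```python
from datetime import datetime, timedelta, timezone


class InMemoryBlacklist:
    """Dict-based stand-in for the Redis blacklist: jti -> expiry time."""

    def __init__(self) -> None:
        self._entries: dict[str, datetime] = {}

    def add(self, jti: str, expires_at: datetime, now: datetime) -> None:
        # Like blacklist_token: skip tokens that have already expired
        if expires_at > now:
            self._entries[jti] = expires_at

    def contains(self, jti: str, now: datetime) -> bool:
        expires_at = self._entries.get(jti)
        if expires_at is None:
            return False
        if expires_at <= now:
            # Emulate Redis TTL expiry lazily, on read
            del self._entries[jti]
            return False
        return True


now = datetime(2024, 1, 1, tzinfo=timezone.utc)
bl = InMemoryBlacklist()
bl.add("abc", now + timedelta(minutes=30), now)
print(bl.contains("abc", now + timedelta(minutes=10)))  # True
print(bl.contains("abc", now + timedelta(minutes=31)))  # False: entry has expired
```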
A secure authentication system requires more than just correct logic. You need to configure transport security, prevent cross-origin attacks, rate-limit sensitive endpoints, sanitize input, and set appropriate security headers. This section covers the essential security hardening steps for a production FastAPI application.
# middleware/https_redirect.py
from fastapi import FastAPI, Request
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
def configure_https(app: FastAPI, environment: str = "production"):
"""
Configure HTTPS enforcement.
In production, all HTTP requests are redirected to HTTPS.
In development, HTTPS is not enforced.
"""
if environment == "production":
# Redirect all HTTP requests to HTTPS
app.add_middleware(HTTPSRedirectMiddleware)
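class HSTSMiddleware(BaseHTTPMiddleware):
"""
Add HTTP Strict Transport Security header.
Tells browsers to only access the site over HTTPS for the
specified duration, preventing protocol downgrade attacks.
Browsers ignore this header on plain-HTTP responses, so register it
alongside HTTPS enforcement, and keep "preload" only if you intend to
submit the domain to the browsers' HSTS preload list.
"""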
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains; preload"
)
return response
# middleware/cors.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
def configure_cors(app: FastAPI):
"""
Configure Cross-Origin Resource Sharing.
CORS controls which origins (domains) can make requests to your API.
This is crucial for security when your frontend and backend are on
different domains.
"""
# Allowed origins — be specific, never use ["*"] in production
allowed_origins = [
"https://yourdomain.com",
"https://app.yourdomain.com",
"http://localhost:3000", # React dev server (development only; remove in production)
"http://localhost:5173", # Vite dev server (development only; remove in production)
]
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True, # Allow cookies (for HTTP-only token cookies)
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
allow_headers=[
"Authorization",
"Content-Type",
"X-API-Key",
"X-Request-ID",
],
expose_headers=[
"X-RateLimit-Limit",
"X-RateLimit-Remaining",
"X-RateLimit-Reset",
],
max_age=600, # Cache preflight requests for 10 minutes
)
Warning: Never combine allow_origins=["*"] with allow_credentials=True. The CORS specification forbids a wildcard Access-Control-Allow-Origin on credentialed requests, and browsers reject such responses. If you need credentials support, you must list specific origins.
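The rule is easiest to see from the browser's side. For a credentialed request, the response must name the exact requesting origin and include `Access-Control-Allow-Credentials: true`; a literal `*` fails the check. Below is a simplified model of how a server might choose its response headers for a simple (non-preflight) request. `cors_headers` is a hypothetical helper, not Starlette's `CORSMiddleware` code.

```python
from typing import Optional


def cors_headers(
    origin: Optional[str],
    allowed_origins: list[str],
    allow_credentials: bool,
) -> dict[str, str]:
    """Return CORS response headers for a simple cross-origin request."""
    if origin is None:
        return {}  # no Origin header: not a cross-origin browser request
    if origin in allowed_origins:
        headers = {"Access-Control-Allow-Origin": origin, "Vary": "Origin"}
        if allow_credentials:
            headers["Access-Control-Allow-Credentials"] = "true"
        return headers
    if "*" in allowed_origins and not allow_credentials:
        # A wildcard only works for requests sent without credentials (e.g., cookies)
        return {"Access-Control-Allow-Origin": "*"}
    return {}  # no CORS headers: the browser blocks the response


print(cors_headers("https://evil.example", ["https://app.yourdomain.com"], True))  # {}
```

In this model, returning nothing for `*` plus credentials is a deliberately conservative choice. It turns the misconfiguration into a visible failure rather than quietly allowing every origin.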
# middleware/security_headers.py
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
Add security headers to all responses.
These headers protect against common web vulnerabilities:
- XSS (Cross-Site Scripting)
- Clickjacking
- MIME-type sniffing
- Information leakage
"""
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
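# Prevent XSS: only load scripts from this origin (inline scripts are blocked).
# Note: this also blocks the CDN assets and inline script used by FastAPI's /docs page.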
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data:; "
"font-src 'self'; "
"connect-src 'self'"
)
# Prevent clickjacking: Don't allow embedding in iframes
response.headers["X-Frame-Options"] = "DENY"
# Prevent MIME-type sniffing
response.headers["X-Content-Type-Options"] = "nosniff"
# Control referrer information
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Disable the legacy browser XSS filter: modern browsers have removed it,
# and "1; mode=block" could itself be abused; rely on CSP instead
response.headers["X-XSS-Protection"] = "0"
# Control browser features
response.headers["Permissions-Policy"] = (
"camera=(), microphone=(), geolocation=(), "
"payment=(), usb=()"
)
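# Don't leak server information. Uvicorn adds its own "server" header at the
# protocol level, so also start it with --no-server-header
if "server" in response.headers:
del response.headers["server"]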
return response
# security/sanitization.py
import re
import html
from typing import Any
from pydantic import field_validator
def sanitize_string(value: str) -> str:
"""
Sanitize a string input to prevent injection attacks.
This function:
1. Strips leading/trailing whitespace
2. Escapes HTML entities (prevents XSS)
3. Removes null bytes (prevents null byte injection)
Length limits are not applied here; enforce them on the model,
e.g. with Field(max_length=...).
"""
if not isinstance(value, str):
return value
# Strip whitespace
value = value.strip()
# Remove null bytes
value = value.replace("\x00", "")
# Escape HTML entities
value = html.escape(value, quote=True)
return value
def sanitize_search_query(query: str) -> str:
"""
Strip characters common in SQL injection payloads and escape HTML.
This is defense in depth only: parameterized queries (which SQLAlchemy
uses automatically) are what actually prevent SQL injection.
"""
# Limit length first, so truncation cannot cut an HTML entity in half
query = query[:500]
# Remove ; ' " - / * (note: this also strips legitimate hyphens)
query = re.sub(r"[;'\"/*-]", "", query)
# Escape HTML
return html.escape(query)
# Pydantic model with built-in sanitization
from pydantic import BaseModel
class SanitizedInput(BaseModel):
"""Base model that automatically sanitizes string fields."""
@field_validator("*", mode="before")
@classmethod
def sanitize_strings(cls, v: Any) -> Any:
if isinstance(v, str):
return sanitize_string(v)
return v
# Usage: inherit from SanitizedInput instead of BaseModel
class CommentCreate(SanitizedInput):
content: str
post_id: int
# content will be automatically sanitized
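Running a few inputs through the same standard-library steps shows what `sanitize_string` actually produces. This sketch also truncates *before* escaping, which matters whenever a length limit is added. Cutting an already-escaped string can split an entity such as `&amp;` in half. `clean_text` and its `max_length` parameter are illustrative, not part of the original module.

```python
import html


def clean_text(value: str, max_length: int = 500) -> str:
    """Strip whitespace, drop NUL bytes, truncate, then HTML-escape."""
    value = value.strip().replace("\x00", "")
    value = value[:max_length]  # truncate first so no entity is cut in half
    return html.escape(value, quote=True)


print(clean_text("  <b>Hi</b>\x00  "))  # &lt;b&gt;Hi&lt;/b&gt;
print(clean_text("Tom & Jerry", max_length=5))  # Tom &amp;
```

Escaping on input means the stored text is HTML-encoded. Any non-HTML consumer, such as a mobile client or a CSV export, will then see `&lt;` literally. For that reason, many teams store raw text and escape at render time instead.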
pip install slowapi
# middleware/rate_limiting.py
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi import FastAPI, Request
def get_rate_limit_key(request: Request) -> str:
"""
Determine the rate limit key for a request.
Uses the API key if present, otherwise falls back to IP address.
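This gives API key users their own rate limit bucket. Caution: the key is
not validated here, so a client could rotate fake keys to dodge the limit;
for sensitive endpoints such as login, key on the IP address only.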
"""
api_key = request.headers.get("X-API-Key")
if api_key:
return f"apikey:{api_key}"
return get_remote_address(request)
# Create limiter instance
limiter = Limiter(key_func=get_rate_limit_key)
def configure_rate_limiting(app: FastAPI):
"""Configure rate limiting for the application."""
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
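# Usage in endpoints (e.g., routers/auth.py). slowapi requires a
# `request: Request` parameter, and the route decorator must come first.
from fastapi import Depends
from fastapi.security import OAuth2PasswordRequestForm
from middleware.rate_limiting import limiter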
@router.post("/auth/token")
@limiter.limit("5/minute") # Strict limit on login attempts
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
"""Login with rate limiting: 5 attempts per minute."""
...
@router.post("/users/register")
@limiter.limit("3/hour") # Very strict for registration
async def register(request: Request, user_data: UserCreate):
"""Registration with rate limiting: 3 per hour per IP."""
...
@router.get("/api/data")
@limiter.limit("100/minute") # Higher limit for data endpoints
async def get_data(request: Request):
"""Data endpoint: 100 requests per minute."""
...
# security/audit.py
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from fastapi import Request
class AuditEvent(str, Enum):
"""Types of security events to audit."""
LOGIN_SUCCESS = "login_success"
LOGIN_FAILURE = "login_failure"
LOGOUT = "logout"
TOKEN_REFRESH = "token_refresh"
TOKEN_REVOKED = "token_revoked"
TOKEN_REUSE_DETECTED = "token_reuse_detected"
PASSWORD_CHANGED = "password_changed"
ROLE_CHANGED = "role_changed"
ACCOUNT_LOCKED = "account_locked"
ACCOUNT_DEACTIVATED = "account_deactivated"
API_KEY_GENERATED = "api_key_generated"
API_KEY_REVOKED = "api_key_revoked"
UNAUTHORIZED_ACCESS = "unauthorized_access"
RATE_LIMIT_EXCEEDED = "rate_limit_exceeded"
# Configure audit logger
audit_logger = logging.getLogger("security.audit")
audit_logger.setLevel(logging.INFO)
# File handler for persistent audit log
handler = logging.FileHandler("security_audit.log")
handler.setFormatter(
logging.Formatter(
"%(asctime)s | %(levelname)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
audit_logger.addHandler(handler)
def log_security_event(
event: AuditEvent,
request: Optional[Request] = None,
user_id: Optional[int] = None,
username: Optional[str] = None,
details: Optional[str] = None,
):
"""
Log a security event for auditing.
In production, consider sending these to a SIEM system
(Splunk, ELK, etc.) or a dedicated audit database.
"""
client_ip = request.client.host if request else "unknown"
user_agent = (
request.headers.get("user-agent", "unknown")
if request else "unknown"
)
log_entry = (
f"event={event.value} | "
f"ip={client_ip} | "
f"user_id={user_id} | "
f"username={username} | "
f"user_agent={user_agent} | "
f"details={details}"
)
# Use WARNING level for security-critical events
if event in (
AuditEvent.TOKEN_REUSE_DETECTED,
AuditEvent.ACCOUNT_LOCKED,
AuditEvent.UNAUTHORIZED_ACCESS,
AuditEvent.LOGIN_FAILURE,
):
audit_logger.warning(log_entry)
else:
audit_logger.info(log_entry)
# Usage in login endpoint
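# Sketch: parameters simplified; authenticate() stands in for your credential check
async def login(request: Request, username: str, password: str):
user = await authenticate(username, password)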
if not user:
log_security_event(
AuditEvent.LOGIN_FAILURE,
request=request,
username=username,
details="Invalid credentials",
)
raise HTTPException(...)
log_security_event(
AuditEvent.LOGIN_SUCCESS,
request=request,
user_id=user.id,
username=user.username,
)
...
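The pipe-delimited format above is readable, but values such as the user agent come straight from the client. They may contain `|` or newline characters that break parsing or forge extra log lines. Emitting each event as one JSON object avoids this and is straightforward for log pipelines to parse. Below is a standard-library sketch; `format_audit_entry` is a hypothetical replacement for the `log_entry` string, not part of the original module.

```python
import json
from datetime import datetime, timezone
from typing import Optional


def format_audit_entry(
    event: str,
    client_ip: str = "unknown",
    user_id: Optional[int] = None,
    username: Optional[str] = None,
    user_agent: str = "unknown",
    details: Optional[str] = None,
) -> str:
    """Serialize one security event as a single-line JSON string."""
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "event": event,
        "ip": client_ip,
        "user_id": user_id,
        "username": username,
        "user_agent": user_agent,
        "details": details,
    }
    # json.dumps escapes newlines and quotes, so a hostile user agent
    # cannot inject an extra log line
    return json.dumps(entry)


line = format_audit_entry(
    "login_failure", "203.0.113.7", username="alice",
    user_agent="curl/8.0\nevent=login_success",
)
print("\n" in line)  # False
print(json.loads(line)["event"])  # login_failure
```

The returned string can be passed unchanged to `audit_logger.warning(...)` or `audit_logger.info(...)`.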
| Category | Requirement | Status |
|---|---|---|
| Transport | HTTPS enforced with HSTS | Required |
| Transport | TLS 1.2+ only | Required |
| Passwords | bcrypt or argon2 hashing | Required |
| Passwords | Minimum 8 characters with complexity | Required |
| Tokens | Short-lived access tokens (15-30 min) | Required |
| Tokens | Refresh token rotation | Recommended |
| Tokens | Token blacklisting capability | Recommended |
| Headers | CORS properly configured | Required |
| Headers | Security headers set (CSP, X-Frame, etc.) | Required |
| Rate Limiting | Login endpoint rate limited | Required |
| Rate Limiting | Registration endpoint rate limited | Required |
| Input | All input validated and sanitized | Required |
| Input | Parameterized SQL queries | Required |
| Logging | Security events audited | Required |
| Logging | No sensitive data in logs | Required |
| Account | Account lockout after failed attempts | Recommended |
| Account | Email verification for new accounts | Recommended |
Now let us bring everything together into a complete, production-ready authentication system. This section combines all the concepts we have covered into a cohesive application structure that you can use as a template for your own projects.
fastapi-auth/
├── main.py # Application entry point
├── config.py # Configuration settings
├── database.py # Database connection and session
├── requirements.txt # Python dependencies
├── .env # Environment variables (not in git)
├── models/
│ ├── __init__.py
│ └── user.py # User and RefreshToken models
├── schemas/
│ ├── __init__.py
│ ├── user.py # User Pydantic schemas
│ └── auth.py # Auth request/response schemas
├── routers/
│ ├── __init__.py
│ ├── auth.py # Login, logout, refresh endpoints
│ ├── users.py # User registration and profile
│ └── admin.py # Admin-only endpoints
├── security/
│ ├── __init__.py
│ ├── password.py # Password hashing
│ ├── jwt_handler.py # JWT creation and validation
│ ├── dependencies.py # Auth dependencies (get_current_user)
│ ├── rbac.py # Role-based access control
│ ├── api_key.py # API key authentication
│ ├── scopes.py # OAuth2 scopes
│ └── audit.py # Security audit logging
├── services/
│ ├── __init__.py
│ └── user_service.py # User business logic
└── middleware/
├── __init__.py
├── cors.py # CORS configuration
├── security_headers.py # Security headers
└── rate_limiting.py # Rate limiting
# requirements.txt
fastapi==0.115.0
uvicorn[standard]==0.30.0
sqlalchemy[asyncio]==2.0.35
aiosqlite==0.20.0
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.2.0
python-multipart==0.0.9
pydantic[email]==2.9.0
pydantic-settings==2.5.0
slowapi==0.1.9
python-dotenv==1.0.1
# .env
SECRET_KEY=your-generated-secret-key-from-openssl-rand-hex-32
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
DATABASE_URL=sqlite+aiosqlite:///./auth_app.db
ENVIRONMENT=development
BCRYPT_ROUNDS=12
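The `SECRET_KEY` placeholder refers to `openssl rand -hex 32`. If OpenSSL is not at hand, Python's `secrets` module produces a value of the same shape: 32 random bytes, hex-encoded into 64 characters.

```python
import secrets

# 32 random bytes, hex-encoded: the same shape as `openssl rand -hex 32`
secret_key = secrets.token_hex(32)
print(len(secret_key))  # 64
```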
# config.py
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Application
APP_NAME: str = "FastAPI Auth System"
ENVIRONMENT: str = "development"
DEBUG: bool = False
# Database
DATABASE_URL: str = "sqlite+aiosqlite:///./auth_app.db"
# JWT
SECRET_KEY: str
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
# Security
BCRYPT_ROUNDS: int = 12
MAX_LOGIN_ATTEMPTS: int = 5
LOCKOUT_DURATION_MINUTES: int = 30
# CORS
ALLOWED_ORIGINS: list[str] = ["http://localhost:3000"]
# Pydantic v2 style; the inner `class Config` form is deprecated
model_config = SettingsConfigDict(env_file=".env", case_sensitive=True)
settings = Settings()
# database.py
from sqlalchemy.ext.asyncio import (
AsyncSession,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.orm import DeclarativeBase
from config import settings
engine = create_async_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_pre_ping=True, # Verify connections are alive
pool_size=5, # Connection pool size
max_overflow=10, # Extra connections allowed
)
async_session = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
class Base(DeclarativeBase):
pass
async def get_db():
"""Database session dependency with automatic cleanup."""
async with async_session() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def init_db():
"""Create all database tables."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
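get_db() commits only if the endpoint returns normally; any exception raised by the endpoint, including an HTTPException, is thrown back into the dependency and reaches the `except` branch, so writes made before the raise are rolled back unless they were committed explicitly. That matters for any endpoint that writes and then raises, such as the reuse-detection branch of /refresh. The sketch below reproduces the same commit/rollback shape with a hypothetical FakeSession stand-in instead of a real AsyncSession, so the behavior can be observed without a database.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import List


class FakeSession:
    """Hypothetical stand-in for AsyncSession that records what happened to it."""

    def __init__(self):
        self.events: List[str] = []

    async def commit(self):
        self.events.append("commit")

    async def rollback(self):
        self.events.append("rollback")

    async def close(self):
        self.events.append("close")


@asynccontextmanager
async def session_scope(session: FakeSession):
    """Same shape as get_db: commit on success, roll back on any exception."""
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        # Mirrors the cleanup that `async with async_session()` performs
        await session.close()


async def handler(session: FakeSession, fail: bool) -> None:
    async with session_scope(session) as s:
        s.events.append("write")
        if fail:
            # Plays the role of raising HTTPException inside an endpoint
            raise RuntimeError("request failed")


async def main():
    ok, bad = FakeSession(), FakeSession()
    await handler(ok, fail=False)
    try:
        await handler(bad, fail=True)
    except RuntimeError:
        pass
    return ok.events, bad.events


ok_events, bad_events = asyncio.run(main())
print(ok_events)   # ['write', 'commit', 'close']
print(bad_events)  # ['write', 'rollback', 'close']
```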
# schemas/auth.py
from pydantic import BaseModel
class TokenResponse(BaseModel):
"""Response schema for token endpoints."""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
scope: str = ""
class RefreshRequest(BaseModel):
"""Request schema for token refresh."""
refresh_token: str
class PasswordChangeRequest(BaseModel):
"""Request schema for password changes."""
current_password: str
new_password: str
class MessageResponse(BaseModel):
"""Generic message response."""
detail: str
# routers/auth.py
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
from database import get_db
from models.user import RefreshToken, User
from schemas.auth import (
MessageResponse,
PasswordChangeRequest,
RefreshRequest,
TokenResponse,
)
from security.audit import AuditEvent, log_security_event
from security.dependencies import get_current_active_user
from security.jwt_handler import (
create_access_token,
create_refresh_token,
decode_token,
)
from security.password import hash_password, verify_password
from services.user_service import UserService
router = APIRouter(prefix="/api/v1/auth", tags=["Authentication"])
@router.post("/token", response_model=TokenResponse)
async def login(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db),
):
"""
Authenticate user and return JWT access + refresh tokens.
Accepts OAuth2 password flow (form data with username and password).
"""
user_service = UserService(db)
user = await user_service.authenticate(
form_data.username, form_data.password
)
if not user:
log_security_event(
AuditEvent.LOGIN_FAILURE,
request=request,
username=form_data.username,
details="Invalid credentials",
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Determine scopes based on role
from security.scopes import get_allowed_scopes_for_role
granted_scopes = list(get_allowed_scopes_for_role(user.role))
# Create tokens
token_data = {
"sub": user.username,
"role": user.role.value,
"scopes": granted_scopes,
}
access_token = create_access_token(data=token_data)
refresh_token = create_refresh_token(data={"sub": user.username})
# Store refresh token
refresh_data = decode_token(refresh_token)
db_token = RefreshToken(
token_jti=refresh_data.jti,
user_id=user.id,
expires_at=datetime.now(timezone.utc) + timedelta(
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
),
)
db.add(db_token)
log_security_event(
AuditEvent.LOGIN_SUCCESS,
request=request,
user_id=user.id,
username=user.username,
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
scope=" ".join(granted_scopes),
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
request: Request,
body: RefreshRequest,
db: AsyncSession = Depends(get_db),
):
"""Refresh access token using refresh token rotation."""
try:
token_data = decode_token(body.refresh_token)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
if token_data.token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
)
# Find stored token
result = await db.execute(
select(RefreshToken).where(
RefreshToken.token_jti == token_data.jti
)
)
stored_token = result.scalar_one_or_none()
if not stored_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token not found",
)
# Detect reuse
if stored_token.is_revoked:
log_security_event(
AuditEvent.TOKEN_REUSE_DETECTED,
request=request,
user_id=stored_token.user_id,
details=f"Reused JTI: {token_data.jti}",
)
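await db.execute(
update(RefreshToken)
.where(RefreshToken.user_id == stored_token.user_id)
.values(is_revoked=True)
)
# Commit before raising: the HTTPException below propagates into get_db,
# which rolls back and would otherwise undo this revocation.
await db.commit()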
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token reuse detected. All sessions revoked.",
)
# Revoke old token
stored_token.is_revoked = True
# Get user
user_service = UserService(db)
user = await user_service.get_by_id(stored_token.user_id)
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
# Issue new tokens
from security.scopes import get_allowed_scopes_for_role
granted_scopes = list(get_allowed_scopes_for_role(user.role))
new_access = create_access_token(data={
"sub": user.username,
"role": user.role.value,
"scopes": granted_scopes,
})
new_refresh = create_refresh_token(data={"sub": user.username})
# Store new refresh token
new_data = decode_token(new_refresh)
db.add(RefreshToken(
token_jti=new_data.jti,
user_id=user.id,
expires_at=datetime.now(timezone.utc) + timedelta(
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
),
))
log_security_event(
AuditEvent.TOKEN_REFRESH,
request=request,
user_id=user.id,
username=user.username,
)
return TokenResponse(
access_token=new_access,
refresh_token=new_refresh,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
scope=" ".join(granted_scopes),
)
@router.post("/logout", response_model=MessageResponse)
async def logout(
request: Request,
body: RefreshRequest,
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""Logout by revoking the refresh token."""
try:
token_data = decode_token(body.refresh_token)
except JWTError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid token",
)
await db.execute(
update(RefreshToken)
.where(
RefreshToken.token_jti == token_data.jti,
RefreshToken.user_id == current_user.id,  # Only revoke the caller's own token
)
.values(is_revoked=True)
)
log_security_event(
AuditEvent.LOGOUT,
request=request,
user_id=current_user.id,
username=current_user.username,
)
return MessageResponse(detail="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(
request: Request,
body: PasswordChangeRequest,
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
"""Change the authenticated user's password."""
if not verify_password(body.current_password, current_user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
)
current_user.hashed_password = hash_password(body.new_password)
# Revoke all refresh tokens (force re-login on all devices)
await db.execute(
update(RefreshToken)
.where(
RefreshToken.user_id == current_user.id,
RefreshToken.is_revoked == False,
)
.values(is_revoked=True)
)
log_security_event(
AuditEvent.PASSWORD_CHANGED,
request=request,
user_id=current_user.id,
username=current_user.username,
)
return MessageResponse(
detail="Password changed. Please log in again on all devices."
)
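The reuse-detection rules in /refresh are easier to follow without the database and JWT plumbing. Below is an in-memory model of the same logic, assuming the JTI alone identifies a stored token; RefreshTokenStore and StoredToken are hypothetical names, not part of this project.

```python
import secrets
from dataclasses import dataclass
from typing import Dict


@dataclass
class StoredToken:
    user_id: int
    is_revoked: bool = False


class RefreshTokenStore:
    """In-memory model of the RefreshToken table, keyed by JTI."""

    def __init__(self):
        self.tokens: Dict[str, StoredToken] = {}

    def issue(self, user_id: int) -> str:
        jti = secrets.token_hex(16)  # Random unique ID, like the JWT "jti" claim
        self.tokens[jti] = StoredToken(user_id=user_id)
        return jti

    def rotate(self, jti: str) -> str:
        stored = self.tokens.get(jti)
        if stored is None:
            raise PermissionError("Token not found")
        if stored.is_revoked:
            # A revoked token came back: assume it was stolen and
            # revoke every token belonging to that user.
            for token in self.tokens.values():
                if token.user_id == stored.user_id:
                    token.is_revoked = True
            raise PermissionError("Token reuse detected")
        # Normal rotation: retire the presented token, issue a fresh one
        stored.is_revoked = True
        return self.issue(stored.user_id)
```

The design choice worth noticing is that a reuse attempt also kills the *new* token: the server cannot tell whether the legitimate client or the attacker rotated first, so it forces both to log in again.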
# main.py
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from database import init_db
from middleware.security_headers import SecurityHeadersMiddleware
from routers import admin, auth, users
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application startup and shutdown events."""
# Startup
await init_db()
print(f"Database initialized. Environment: {settings.ENVIRONMENT}")
yield
# Shutdown
print("Application shutting down.")
app = FastAPI(
title=settings.APP_NAME,
description=(
"Complete authentication and authorization system with JWT, "
"refresh token rotation, RBAC, OAuth2 scopes, and API keys."
),
version="1.0.0",
lifespan=lifespan,
docs_url="/docs" if settings.ENVIRONMENT != "production" else None,
redoc_url="/redoc" if settings.ENVIRONMENT != "production" else None,
)
# Middleware (order matters — last added runs first)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
allow_headers=["Authorization", "Content-Type", "X-API-Key"],
)
# Routers
app.include_router(auth.router)
app.include_router(users.router)
app.include_router(admin.router)
@app.get("/health")
async def health_check():
"""Health check endpoint for load balancers and monitoring."""
return {"status": "healthy", "version": "1.0.0"}
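The comment "last added runs first" follows from how the middleware stack is assembled: each add_middleware() call wraps everything registered before it. The toy model below uses plain functions, not Starlette's middleware classes, to show the resulting order for the two middlewares registered in main.py.

```python
from typing import Callable, List

Handler = Callable[[List[str]], None]


def make_middleware(name: str) -> Callable[[Handler], Handler]:
    """Build a middleware that logs before and after calling the layer it wraps."""
    def wrap(inner: Handler) -> Handler:
        def handler(log: List[str]) -> None:
            log.append(f"{name}:before")
            inner(log)
            log.append(f"{name}:after")
        return handler
    return wrap


def endpoint(log: List[str]) -> None:
    log.append("endpoint")


def build_stack(app: Handler, added_in_order: List[Callable[[Handler], Handler]]) -> Handler:
    # Each middleware wraps everything added before it, so the one
    # added last ends up outermost and sees the request first.
    for wrap in added_in_order:
        app = wrap(app)
    return app


stack = build_stack(endpoint, [make_middleware("SecurityHeaders"), make_middleware("CORS")])
log: List[str] = []
stack(log)
print(log)
# ['CORS:before', 'SecurityHeaders:before', 'endpoint', 'SecurityHeaders:after', 'CORS:after']
```

Because CORSMiddleware is outermost, it can answer a preflight OPTIONS request itself without the request ever reaching the inner layers.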
# Install dependencies
pip install -r requirements.txt
# Generate a secret key
export SECRET_KEY=$(openssl rand -hex 32)
# Run the development server
uvicorn main:app --reload --host 0.0.0.0 --port 8000
# Open Swagger UI
# http://localhost:8000/docs
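# test_auth_system.py
"""
Integration tests for the complete auth system.
Requires pytest-asyncio with asyncio_mode = auto, since the async fixtures
and test methods below carry no asyncio markers.
Run against a throwaway database, for example:
DATABASE_URL=sqlite+aiosqlite:///./test_auth.db pytest test_auth_system.py -v
"""
import pytest
from httpx import AsyncClient, ASGITransport
from database import Base, engine
from main import app
@pytest.fixture
async def client():
"""Create an async test client backed by freshly created tables."""
# ASGITransport does not send lifespan events, so init_db() never runs;
# recreate the schema here so every test starts from an empty database.
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac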
@pytest.fixture
async def registered_user(client: AsyncClient) -> dict:
"""Register a test user and return credentials."""
user_data = {
"username": "testuser",
"email": "test@example.com",
"password": "TestPass123!",
"full_name": "Test User",
}
response = await client.post("/api/v1/users/register", json=user_data)
assert response.status_code == 201
return user_data
@pytest.fixture
async def auth_tokens(client: AsyncClient, registered_user: dict) -> dict:
"""Login and return access + refresh tokens."""
response = await client.post(
"/api/v1/auth/token",
data={
"username": registered_user["username"],
"password": registered_user["password"],
},
)
assert response.status_code == 200
return response.json()
class TestRegistration:
"""Test user registration."""
async def test_register_success(self, client: AsyncClient):
response = await client.post(
"/api/v1/users/register",
json={
"username": "newuser",
"email": "new@example.com",
"password": "SecurePass123!",
},
)
assert response.status_code == 201
data = response.json()
assert data["username"] == "newuser"
assert "hashed_password" not in data
async def test_register_duplicate_username(
self, client: AsyncClient, registered_user
):
response = await client.post(
"/api/v1/users/register",
json={
"username": registered_user["username"],
"email": "different@example.com",
"password": "SecurePass123!",
},
)
assert response.status_code == 409
async def test_register_weak_password(self, client: AsyncClient):
response = await client.post(
"/api/v1/users/register",
json={
"username": "weakuser",
"email": "weak@example.com",
"password": "weak",
},
)
assert response.status_code == 422
class TestLogin:
"""Test login and token generation."""
async def test_login_success(
self, client: AsyncClient, registered_user
):
response = await client.post(
"/api/v1/auth/token",
data={
"username": registered_user["username"],
"password": registered_user["password"],
},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
async def test_login_wrong_password(
self, client: AsyncClient, registered_user
):
response = await client.post(
"/api/v1/auth/token",
data={
"username": registered_user["username"],
"password": "WrongPass123!",
},
)
assert response.status_code == 401
class TestProtectedEndpoints:
"""Test protected endpoint access."""
async def test_get_profile_authenticated(
self, client: AsyncClient, auth_tokens
):
response = await client.get(
"/api/v1/users/me",
headers={
"Authorization": f"Bearer {auth_tokens['access_token']}"
},
)
assert response.status_code == 200
async def test_get_profile_no_token(self, client: AsyncClient):
response = await client.get("/api/v1/users/me")
assert response.status_code == 401
async def test_get_profile_invalid_token(self, client: AsyncClient):
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": "Bearer invalid-token"},
)
assert response.status_code == 401
class TestTokenRefresh:
"""Test refresh token rotation."""
async def test_refresh_success(
self, client: AsyncClient, auth_tokens
):
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": auth_tokens["refresh_token"]},
)
assert response.status_code == 200
data = response.json()
assert data["access_token"] != auth_tokens["access_token"]
assert data["refresh_token"] != auth_tokens["refresh_token"]
async def test_refresh_reuse_detection(
self, client: AsyncClient, auth_tokens
):
# First refresh: should succeed
response1 = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": auth_tokens["refresh_token"]},
)
assert response1.status_code == 200
# Second refresh with same token: should fail (reuse detected)
response2 = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": auth_tokens["refresh_token"]},
)
assert response2.status_code == 401
assert "reuse" in response2.json()["detail"].lower()
| Method | Endpoint | Auth | Description |
|---|---|---|---|
| POST | /api/v1/users/register | None | Register a new user |
| POST | /api/v1/auth/token | None | Login (get tokens) |
| POST | /api/v1/auth/refresh | None | Refresh access token |
| POST | /api/v1/auth/logout | Bearer | Revoke refresh token |
| POST | /api/v1/auth/change-password | Bearer | Change password |
| GET | /api/v1/users/me | Bearer | Get current user profile |
| PUT | /api/v1/users/me | Bearer | Update profile |
| DELETE | /api/v1/users/me | Bearer | Deactivate account |
| POST | /api/v1/api-keys/generate | Bearer | Generate API key |
| DELETE | /api/v1/api-keys/revoke | Bearer | Revoke API key |
| GET | /api/v1/admin/users | Admin | List all users |
| PUT | /api/v1/admin/users/{id}/role | Admin | Change user role |
| GET | /api/v1/admin/dashboard | Admin | Admin statistics |
| GET | /health | None | Health check |
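The /api/v1/api-keys/generate endpoint is listed above, but its implementation is not shown in this section. A common approach, assumed here rather than taken from this project, is to return the plaintext key to the user exactly once and store only its SHA-256 digest, so a database leak does not expose usable keys. A fast hash is acceptable for API keys (unlike passwords) because the key is long and random rather than human-chosen. The helper names below are hypothetical.

```python
import hashlib
import hmac
import secrets
from typing import Tuple


def generate_api_key(prefix: str = "sk") -> Tuple[str, str]:
    """Return (plaintext_key, sha256_hex). Show the plaintext once; store only the digest."""
    plaintext = f"{prefix}_{secrets.token_urlsafe(32)}"
    digest = hashlib.sha256(plaintext.encode()).hexdigest()
    return plaintext, digest


def verify_api_key(presented: str, stored_digest: str) -> bool:
    """Hash the key from the X-API-Key header and compare in constant time."""
    candidate = hashlib.sha256(presented.encode()).hexdigest()
    return hmac.compare_digest(candidate, stored_digest)
```

In a database-backed version, you would typically look the user up by the digest of the presented key; the readable prefix makes leaked keys easy to spot in logs and secret scanners.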
┌──────────┐                              ┌──────────┐
│          │  1. POST /register           │          │
│          │ ───────────────────────────> │          │
│  Client  │ <─────────────────────────── │  Server  │
│          │  User created (201)          │          │
│          │                              │          │
│          │  2. POST /token              │          │
│          │  (username + password)       │          │
│          │ ───────────────────────────> │          │
│          │ <─────────────────────────── │          │
│          │  access_token +              │          │
│          │  refresh_token               │          │
│          │                              │          │
│          │  3. GET /users/me            │          │
│          │  Authorization: Bearer       │          │
│          │ ───────────────────────────> │          │
│          │ <─────────────────────────── │          │
│          │  User profile (200)          │          │
│          │                              │          │
│          │  4. POST /refresh            │          │
│          │  (when token expires)        │          │
│          │ ───────────────────────────> │          │
│          │ <─────────────────────────── │          │
│          │  new access_token +          │          │
│          │  new refresh_token           │          │
│          │                              │          │
│          │  5. POST /logout             │          │
│          │ ───────────────────────────> │          │
│          │ <─────────────────────────── │          │
│          │  Tokens revoked (200)        │          │
└──────────┘                              └──────────┘
This tutorial covered a comprehensive set of authentication and authorization patterns in FastAPI. Here is a summary of the key concepts and techniques you learned:
| Topic | Key Concept | Implementation |
|---|---|---|
| Authentication vs Authorization | Authentication verifies identity; authorization controls access | Return 401 for auth failures, 403 for permission failures |
| Password Hashing | Never store plaintext passwords; use bcrypt with cost factor 12+ | passlib.context.CryptContext with bcrypt scheme |
| OAuth2 Password Flow | Standard flow for username/password login returning tokens | OAuth2PasswordBearer + OAuth2PasswordRequestForm |
| JWT Tokens | Stateless, signed tokens with claims (sub, exp, iat, jti) | python-jose for creating and verifying tokens |
| User Registration | Validate input, check duplicates, hash password, store user | Pydantic validators + SQLAlchemy models |
| Login System | Authenticate credentials, issue access + refresh tokens | Token endpoint returning TokenResponse |
| Protected Endpoints | Dependency injection validates tokens automatically | Depends(get_current_active_user) |
| RBAC | Roles map to permissions; check before granting access | require_role() and require_permission() dependencies |
| API Keys | Pre-shared keys for programmatic access | APIKeyHeader + database lookup |
| OAuth2 Scopes | Token-level permissions for fine-grained access control | Security(dep, scopes=[...]) + SecurityScopes |
| Refresh Token Rotation | New refresh token on each use; detect reuse as theft | Database-tracked JTIs with revocation flags |
| Security Best Practices | Defense in depth: HTTPS, CORS, headers, rate limiting, audit | Middleware stack + security header configuration |
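To make the JWT row concrete, here is what signing and verifying an HS256 token involves, written with only the standard library. It is meant for understanding, not production use: python-jose, which this project uses, also supports audience and issuer validation, clock leeway, and many more algorithms, so keep using a maintained library in real code. The function names are illustrative.

```python
import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    """Unpadded base64url, the encoding JWT segments use."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    # Restore the padding that _b64url stripped
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: str, secret: str) -> str:
    mac = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256)
    return _b64url(mac.digest())


def encode_hs256(claims: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        _b64url(json.dumps(part, separators=(",", ":")).encode())
        for part in (header, claims)
    )
    return f"{signing_input}.{_sign(signing_input, secret)}"


def decode_hs256(token: str, secret: str) -> dict:
    header_b64, payload_b64, signature = token.split(".")
    signing_input = f"{header_b64}.{payload_b64}"
    # Constant-time comparison avoids leaking how much of the signature matched
    if not hmac.compare_digest(_sign(signing_input, secret), signature):
        raise ValueError("invalid signature")
    header = json.loads(_b64url_decode(header_b64))
    if header.get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    claims = json.loads(_b64url_decode(payload_b64))
    if "exp" in claims and time.time() >= claims["exp"]:
        raise ValueError("token expired")
    return claims
```

The key point for the "stateless" claim in the table: the server verifies a token using only the secret and the token itself. That is also why revocation needs extra state, such as the database-tracked JTIs used for refresh tokens.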
# Import and use these in your endpoints:
from security.dependencies import (
get_current_user, # Basic: validates token, returns User
get_current_active_user, # + checks user.is_active
get_current_user_optional, # Returns User or None (no 401)
)
from security.rbac import (
require_role, # require_role(UserRole.ADMIN)
require_minimum_role, # require_minimum_role(UserRole.MODERATOR)
require_permission, # require_permission(Permission.DELETE_ANY_POST)
)
from security.scopes import (
get_current_user_with_scopes, # Security(dep, scopes=["posts:read"])
)
from security.api_key import (
get_user_from_api_key, # X-API-Key header auth
)
from security.combined_auth import (
get_current_user_flexible, # JWT OR API key
)
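The RBAC helpers above are only referenced by name. The framework-free sketch below shows the logic a require_minimum_role()-style factory typically encodes. UserRole.ADMIN and UserRole.MODERATOR appear in the quick reference; the USER role, the numeric ranks, and the permission strings are assumptions for illustration, and minimum_role_checker is a hypothetical name.

```python
from enum import Enum


class Role(str, Enum):
    # Hypothetical roles modeled on UserRole from the quick reference
    USER = "user"
    MODERATOR = "moderator"
    ADMIN = "admin"


# Assumed hierarchy: a higher rank may do everything a lower rank may do
ROLE_RANK = {Role.USER: 0, Role.MODERATOR: 1, Role.ADMIN: 2}

# Assumed role -> permission mapping
ROLE_PERMISSIONS = {
    Role.USER: {"posts:read", "posts:write"},
    Role.MODERATOR: {"posts:read", "posts:write", "posts:delete_any"},
    Role.ADMIN: {"posts:read", "posts:write", "posts:delete_any", "users:manage"},
}


def has_minimum_role(role: Role, minimum: Role) -> bool:
    return ROLE_RANK[role] >= ROLE_RANK[minimum]


def has_permission(role: Role, permission: str) -> bool:
    return permission in ROLE_PERMISSIONS[role]


def minimum_role_checker(minimum: Role):
    """Factory returning a check function, in the style of require_minimum_role().

    A FastAPI dependency would raise HTTPException(status_code=403) instead.
    """
    def check(user_role: Role) -> Role:
        if not has_minimum_role(user_role, minimum):
            raise PermissionError(f"requires role {minimum.value} or higher")
        return user_role
    return check
```

Returning a closure from a factory is what lets an endpoint declare its requirement inline, as in `Depends(require_minimum_role(UserRole.MODERATOR))`, and a failed check maps to 403 (authenticated but not allowed) rather than 401.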
Testing is not optional: it is a fundamental part of professional software development. FastAPI, built on top of Starlette and Pydantic, provides first-class testing support that makes writing comprehensive tests straightforward. In this tutorial, we will walk through the main aspects of testing FastAPI applications, from simple route tests to complex integration tests with databases, authentication, mocking, and CI/CD pipelines.
By the end of this guide, you will have a complete understanding of how to build a robust test suite that gives you confidence in your FastAPI application’s correctness, performance, and reliability.
Before writing a single test, it is important to understand why we test and what strategies guide our testing decisions.
Testing serves multiple critical purposes in software development:
| Benefit | Description |
|---|---|
| Bug Prevention | Catch errors before they reach production |
| Confidence in Refactoring | Change code without fear of breaking existing functionality |
| Documentation | Tests describe how the system should behave |
| Design Improvement | Writing testable code leads to better architecture |
| Regression Prevention | Ensure fixed bugs stay fixed |
| Deployment Confidence | Automated tests enable continuous deployment |
The testing pyramid is a strategy that guides how many tests of each type you should write:
"""
The Testing Pyramid for FastAPI Applications
/\
/ \ E2E Tests (Few)
/ \ - Full browser/API integration tests
/------\ - Slow, expensive, brittle
/ \
/ Integ. \ Integration Tests (Some)
/ Tests \ - Database, external services
/--------------\ - Moderate speed
/ \
/ Unit Tests \ Unit Tests (Many)
/------------------\ - Fast, isolated, focused
- Test individual functions/classes
"""
# Unit Test Example
def test_calculate_discount():
"""Tests a pure function in isolation."""
assert calculate_discount(100, 10) == 90.0
assert calculate_discount(100, 0) == 100.0
assert calculate_discount(50, 50) == 25.0
# Integration Test Example
def test_create_user_in_database(db_session):
"""Tests interaction with a real database."""
user = User(name="Alice", email="alice@example.com")
db_session.add(user)
db_session.commit()
saved = db_session.query(User).filter_by(email="alice@example.com").first()
assert saved is not None
assert saved.name == "Alice"
# E2E Test Example
def test_full_user_registration_flow(client):
"""Tests the complete registration workflow."""
# Register
response = client.post("/auth/register", json={
"name": "Alice",
"email": "alice@example.com",
"password": "SecurePass123!"
})
assert response.status_code == 201
# Login
response = client.post("/auth/login", json={
"email": "alice@example.com",
"password": "SecurePass123!"
})
assert response.status_code == 200
token = response.json()["access_token"]
# Access protected resource
response = client.get("/users/me", headers={
"Authorization": f"Bearer {token}"
})
assert response.status_code == 200
assert response.json()["email"] == "alice@example.com"
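The unit test above calls calculate_discount(), which this tutorial never defines. For that test to run, a helper like the one below would need to exist. The behavior (percentage in the range 0 to 100, result rounded to two decimals, invalid input rejected) is an assumption chosen to be consistent with the three assertions, not a definition from the original project.

```python
def calculate_discount(price: float, percent: float) -> float:
    """Hypothetical helper: return price reduced by `percent` (0-100), rounded to cents."""
    if price < 0:
        raise ValueError("price must be non-negative")
    if not 0 <= percent <= 100:
        raise ValueError("percent must be between 0 and 100")
    return round(price * (1 - percent / 100), 2)


def test_calculate_discount_edge_cases():
    """A pure function is trivial to test exhaustively at its boundaries."""
    assert calculate_discount(0, 50) == 0.0
    assert calculate_discount(100, 100) == 0.0
    assert calculate_discount(19.99, 15) == 16.99
```

In a pytest suite, the ValueError branches are a natural fit for `pytest.raises(ValueError)`.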
Test-Driven Development (TDD) follows a simple cycle: Red (write a failing test), Green (make it pass), Refactor (improve the code). Here is how TDD works with FastAPI:
"""
TDD Cycle with FastAPI:
1. RED - Write a test for a feature that doesn't exist yet
2. GREEN - Write the minimum code to make the test pass
3. REFACTOR - Clean up while keeping tests green
"""
# Step 1: RED - Write the failing test first
# tests/test_items.py
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_get_items_returns_empty_list():
"""We want GET /items to return an empty list initially."""
response = client.get("/items")
assert response.status_code == 200
assert response.json() == []
# Running this test will FAIL because we haven't created the endpoint yet.
# Step 2: GREEN - Write minimal code to pass
# app/main.py
from fastapi import FastAPI
app = FastAPI()
@app.get("/items")
def get_items():
return []
# Now the test passes!
# Step 3: REFACTOR - Improve the implementation
# app/main.py
from fastapi import FastAPI
from typing import List
from app.schemas.item import Item  # Pydantic schema: response_model needs Pydantic, not SQLAlchemy
app = FastAPI()
items_db: List[Item] = []
@app.get("/items", response_model=List[Item])
def get_items():
"""Retrieve all items."""
return items_db
A well-organized test setup is the foundation for a maintainable test suite. Let us configure everything properly from the start.
# Core testing dependencies
pip install pytest httpx
# Additional testing utilities
pip install pytest-asyncio pytest-cov pytest-mock
# For database testing
pip install sqlalchemy aiosqlite
# For performance testing
pip install locust
# Install all at once
pip install pytest httpx pytest-asyncio pytest-cov pytest-mock sqlalchemy aiosqlite locust
| Package | Purpose |
|---|---|
| pytest | Primary testing framework with powerful fixtures and assertions |
| httpx | Async HTTP client, used for async test client |
| pytest-asyncio | Enables async test functions with pytest |
| pytest-cov | Test coverage reporting |
| pytest-mock | Simplified mocking with pytest fixtures |
| locust | Load testing and performance benchmarking |
my_fastapi_project/
├── app/
│   ├── __init__.py
│   ├── main.py               # FastAPI app instance
│   ├── config.py             # Configuration settings
│   ├── database.py           # Database connection
│   ├── models/
│   │   ├── __init__.py
│   │   ├── user.py           # SQLAlchemy User model
│   │   └── item.py           # SQLAlchemy Item model
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── user.py           # Pydantic User schemas
│   │   └── item.py           # Pydantic Item schemas
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── users.py          # User endpoints
│   │   ├── items.py          # Item endpoints
│   │   └── auth.py           # Authentication endpoints
│   ├── services/
│   │   ├── __init__.py
│   │   ├── user_service.py   # User business logic
│   │   └── email_service.py  # Email sending service
│   └── dependencies.py       # Shared dependencies
├── tests/
│   ├── __init__.py
│   ├── conftest.py           # Shared fixtures (CRITICAL FILE)
│   ├── test_main.py          # App-level tests
│   ├── unit/
│   │   ├── __init__.py
│   │   ├── test_schemas.py   # Pydantic schema tests
│   │   └── test_services.py  # Business logic tests
│   ├── integration/
│   │   ├── __init__.py
│   │   ├── test_users.py     # User endpoint tests
│   │   ├── test_items.py     # Item endpoint tests
│   │   └── test_auth.py      # Auth endpoint tests
│   ├── e2e/
│   │   ├── __init__.py
│   │   └── test_workflows.py # End-to-end workflow tests
│   └── performance/
│       └── locustfile.py     # Load testing scripts
├── pytest.ini                # Pytest configuration
├── .coveragerc               # Coverage configuration
└── requirements-test.txt     # Test dependencies
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts = -v --tb=short --strict-markers
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks integration tests
e2e: marks end-to-end tests
unit: marks unit tests
# .coveragerc
[run]
source = app
omit =
app/__init__.py
app/config.py
tests/*
[report]
show_missing = true
fail_under = 80
exclude_lines =
pragma: no cover
def __repr__
if __name__ == .__main__.
raise NotImplementedError
The conftest.py file is the heart of your test configuration. Pytest discovers it automatically and makes the fixtures defined in it available to every test file in the same directory and its subdirectories.
# tests/conftest.py
"""
Central test configuration and shared fixtures.
Fixtures defined here are automatically available to all test files
without requiring explicit imports.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.database import Base, get_db
# ---------------------------------------------------------------------------
# Test Database Configuration
# ---------------------------------------------------------------------------
# Use an in-memory SQLite database for tests: fast and isolated.
# StaticPool keeps a single connection, so every session sees the same data.
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
# ---------------------------------------------------------------------------
# Database Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="function")
def db_session():
"""
Create a fresh database session for each test.
Tables are created before the test and dropped after,
ensuring complete isolation between tests.
"""
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
"""
Create a TestClient with the test database session injected.
This overrides the real database dependency so that all
endpoint tests use the test database.
"""
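def override_get_db():
# Hand every request the fixture's session; the db_session fixture
# already closes it at teardown, so there is nothing to clean up here.
yield db_session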
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
# Clean up overrides after the test
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Authentication Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def auth_headers(client):
"""
Register a test user and return auth headers with a valid JWT token.
"""
# Register a user
client.post("/auth/register", json={
"name": "Test User",
"email": "test@example.com",
"password": "TestPass123!"
})
# Log in and get token
response = client.post("/auth/login", json={
"email": "test@example.com",
"password": "TestPass123!"
})
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
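def admin_headers(client):
"""
Create an admin user and return auth headers.
Assumes the test app's /auth/register accepts a "role" field. A production
registration endpoint should not let clients pick their own role; in that
case, create the admin directly through the db_session fixture instead.
"""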
client.post("/auth/register", json={
"name": "Admin User",
"email": "admin@example.com",
"password": "AdminPass123!",
"role": "admin"
})
response = client.post("/auth/login", json={
"email": "admin@example.com",
"password": "AdminPass123!"
})
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Sample Data Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_user_data():
"""Return valid user registration data."""
return {
"name": "Alice Johnson",
"email": "alice@example.com",
"password": "SecurePass123!"
}
@pytest.fixture
def sample_item_data():
"""Return valid item creation data."""
return {
"title": "Test Item",
"description": "A test item for unit testing",
"price": 29.99,
"quantity": 10
}
@pytest.fixture
def multiple_items_data():
"""Return a list of items for batch testing."""
return [
{"title": f"Item {i}", "description": f"Description {i}",
"price": 10.0 * i, "quantity": i * 5}
for i in range(1, 6)
]
Note: Place conftest.py in the tests/ directory (or in subdirectories for fixtures scoped to that subtree). Pytest discovers it automatically, so you should never import from conftest.py directly.
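The client fixture relies on app.dependency_overrides, a plain dict that maps an original dependency function to its replacement. The toy resolver below is not FastAPI's actual implementation, and Container and handler are hypothetical names; it only shows why the lookup-before-call design makes swapping the database trivial, and why the fixture must clear() the dict afterwards: overrides live on the app object, so a leftover entry would leak into the next test.

```python
from typing import Any, Callable, Dict


class Container:
    """Toy resolver: checks the overrides dict before calling the real provider."""

    def __init__(self):
        self.overrides: Dict[Callable[[], Any], Callable[[], Any]] = {}

    def resolve(self, provider: Callable[[], Any]) -> Any:
        return self.overrides.get(provider, provider)()


def get_db() -> str:
    return "production-db"


def handler(container: Container) -> str:
    # The handler only names the dependency; the container decides what it gets
    db = container.resolve(get_db)
    return f"querying {db}"
```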
FastAPI’s TestClient is Starlette’s test client, re-exported. It is built on HTTPX (so httpx must be installed) and exposes a requests-style interface, such as client.get(), client.post(), and response.json(). It lets you test your endpoints synchronously without running a server.
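"Without running a server" means the application object is called directly through the ASGI interface: a request `scope` plus `receive` and `send` callables, with no socket involved. The sketch below does this by hand with a toy ASGI app. It is a simplified illustration of the mechanism, not how TestClient is implemented; TestClient routes HTTPX requests through a transport that performs these calls for you on an event loop it manages. The scope contains only the keys this toy app reads, whereas a real server sends many more.

```python
import asyncio
import json


async def toy_app(scope, receive, send):
    """A bare ASGI application: the same interface FastAPI apps implement."""
    request = await receive()  # Body of the incoming request (empty here)
    payload = {
        "method": scope["method"],
        "path": scope["path"],
        "body_bytes": len(request["body"]),
    }
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"application/json")],
    })
    await send({"type": "http.response.body", "body": json.dumps(payload).encode()})


async def call_in_process(app, method: str, path: str):
    """Invoke an ASGI app directly: no socket, no server process."""
    scope = {"type": "http", "method": method, "path": path, "headers": []}
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    await app(scope, receive, send)
    status = sent[0]["status"]
    body = b"".join(m.get("body", b"") for m in sent[1:])
    return status, json.loads(body)


status, data = asyncio.run(call_in_process(toy_app, "GET", "/health"))
print(status, data)  # 200 {'method': 'GET', 'path': '/health', 'body_bytes': 0}
```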
# app/main.py
from fastapi import FastAPI
app = FastAPI(title="My API", version="1.0.0")
@app.get("/")
def read_root():
return {"message": "Hello, World!"}
@app.get("/health")
def health_check():
return {"status": "healthy", "version": "1.0.0"}
@app.get("/items/{item_id}")
def get_item(item_id: int, q: str | None = None):
result = {"item_id": item_id}
if q:
result["query"] = q
return result
# tests/test_main.py
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
class TestRootEndpoint:
"""Tests for the root endpoint."""
def test_root_returns_200(self):
"""GET / should return 200 OK."""
response = client.get("/")
assert response.status_code == 200
def test_root_returns_hello_message(self):
"""GET / should return a hello message."""
response = client.get("/")
data = response.json()
assert data["message"] == "Hello, World!"
def test_root_content_type_is_json(self):
"""GET / should return JSON content type."""
response = client.get("/")
assert response.headers["content-type"] == "application/json"
class TestHealthEndpoint:
"""Tests for the health check endpoint."""
def test_health_returns_200(self):
response = client.get("/health")
assert response.status_code == 200
def test_health_returns_healthy_status(self):
response = client.get("/health")
data = response.json()
assert data["status"] == "healthy"
def test_health_includes_version(self):
response = client.get("/health")
data = response.json()
assert "version" in data
assert data["version"] == "1.0.0"
class TestGetItem:
"""Tests for the GET /items/{item_id} endpoint."""
def test_get_item_returns_correct_id(self):
response = client.get("/items/42")
assert response.status_code == 200
assert response.json()["item_id"] == 42
def test_get_item_with_query_parameter(self):
response = client.get("/items/42?q=search_term")
data = response.json()
assert data["item_id"] == 42
assert data["query"] == "search_term"
def test_get_item_without_query_parameter(self):
response = client.get("/items/42")
data = response.json()
assert "query" not in data
def test_get_item_invalid_id_type(self):
"""Passing a string where int is expected should return 422."""
response = client.get("/items/not_a_number")
assert response.status_code == 422
# tests/test_http_methods.py
"""
Demonstrates testing all HTTP methods with TestClient.
"""
from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)


class TestHTTPMethods:
    """Test all standard HTTP methods."""

    def test_get_request(self):
        """Test a GET request."""
        response = client.get("/items")
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    def test_post_request_with_json(self):
        """Test a POST request with JSON body."""
        payload = {
            "title": "New Item",
            "description": "A brand new item",
            "price": 19.99,
            "quantity": 5,
        }
        response = client.post("/items", json=payload)
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "New Item"
        assert data["price"] == 19.99
        assert "id" in data  # Should have an auto-generated ID

    def test_put_request_updates_resource(self):
        """Test a PUT request to update a resource."""
        # First, create an item
        create_response = client.post("/items", json={
            "title": "Original",
            "description": "Original description",
            "price": 10.00,
            "quantity": 1,
        })
        item_id = create_response.json()["id"]
        # Then update it
        update_payload = {
            "title": "Updated",
            "description": "Updated description",
            "price": 15.00,
            "quantity": 2,
        }
        response = client.put(f"/items/{item_id}", json=update_payload)
        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Updated"
        assert data["price"] == 15.00

    def test_patch_request_partial_update(self):
        """Test a PATCH request for partial updates."""
        # Assumes the app defines PATCH /items/{item_id} that updates only the sent fields
        # Create an item
        create_response = client.post("/items", json={
            "title": "Original",
            "description": "Original description",
            "price": 10.00,
            "quantity": 1,
        })
        item_id = create_response.json()["id"]
        # Partially update (only the price)
        response = client.patch(
            f"/items/{item_id}",
            json={"price": 25.00},
        )
        assert response.status_code == 200
        assert response.json()["price"] == 25.00
        assert response.json()["title"] == "Original"  # Unchanged

    def test_delete_request(self):
        """Test a DELETE request."""
        # Create an item
        create_response = client.post("/items", json={
            "title": "To Delete",
            "description": "This will be deleted",
            "price": 5.00,
            "quantity": 1,
        })
        item_id = create_response.json()["id"]
        # Delete it
        response = client.delete(f"/items/{item_id}")
        assert response.status_code == 204
        # Verify it is gone
        get_response = client.get(f"/items/{item_id}")
        assert get_response.status_code == 404


class TestRequestHeaders:
    """Test sending custom headers."""

    def test_custom_headers(self):
        """Send custom headers with a request."""
        headers = {
            "X-Custom-Header": "test-value",
            "Accept-Language": "en-US",
        }
        response = client.get("/items", headers=headers)
        assert response.status_code == 200

    def test_content_type_header(self):
        """Verify correct content-type in response."""
        response = client.get("/items")
        assert "application/json" in response.headers["content-type"]


class TestQueryParameters:
    """Test various query parameter patterns."""

    def test_single_query_param(self):
        response = client.get("/items?skip=0")
        assert response.status_code == 200

    def test_multiple_query_params(self):
        response = client.get("/items?skip=0&limit=10")
        assert response.status_code == 200

    def test_query_params_with_params_dict(self):
        """Use the params argument for cleaner code."""
        response = client.get("/items", params={
            "skip": 0,
            "limit": 10,
            "search": "test",
        })
        assert response.status_code == 200

    def test_query_param_validation(self):
        """Test that invalid query params are rejected."""
        response = client.get("/items?limit=-1")
        assert response.status_code == 422
# tests/test_responses.py
"""
Comprehensive examples of inspecting response objects.
"""
import time

from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)


def test_response_status_codes():
    """Check various status code scenarios."""
    # Success
    assert client.get("/items").status_code == 200
    # Created
    response = client.post("/items", json={
        "title": "New", "description": "Test",
        "price": 1.0, "quantity": 1,
    })
    assert response.status_code == 201
    # Not Found
    assert client.get("/items/99999").status_code == 404
    # Validation Error
    assert client.post("/items", json={}).status_code == 422


def test_response_json_body():
    """Parse and inspect JSON response body."""
    response = client.get("/items")
    # Parse JSON
    data = response.json()
    # Check structure
    assert isinstance(data, list)
    # Check individual items if any exist
    if len(data) > 0:
        item = data[0]
        assert "id" in item
        assert "title" in item
        assert "price" in item


def test_response_headers():
    """Inspect response headers."""
    response = client.get("/items")
    # Check content type
    assert "application/json" in response.headers["content-type"]
    # Check custom headers if your app sets them
    # assert response.headers.get("X-Request-ID") is not None


def test_response_text():
    """Access raw response text."""
    response = client.get("/items")
    # Raw text representation
    text = response.text
    assert isinstance(text, str)
    # Useful for debugging
    print(f"Response body: {text}")


def test_response_cookies():
    """Check cookies set by the response."""
    response = client.post("/auth/login", json={
        "email": "test@example.com",
        "password": "TestPass123!",
    })
    # Access cookies parsed from the Set-Cookie headers
    cookies = response.cookies
    # If your app sets cookies:
    # assert "session_id" in cookies


def test_response_timing():
    """Measure response time (useful for coarse performance assertions)."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    response = client.get("/health")
    elapsed = time.perf_counter() - start
    assert response.status_code == 200
    assert elapsed < 1.0  # Should respond within 1 second
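TestClient exposes a synchronous interface: it runs your app, including async def endpoints, in an event loop on a separate thread, so ordinary test functions can call it. When the test itself needs to await something (async dependencies, async database sessions, concurrent requests), you need async test support.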
# Install the required packages (aiosqlite is needed for the async SQLite examples)
pip install pytest-asyncio httpx aiosqlite

# pytest.ini: enable auto mode so async tests and fixtures run without extra markers
[pytest]
asyncio_mode = auto
# tests/test_async.py
"""
Async testing with httpx AsyncClient.
The AsyncClient talks to your FastAPI app directly over ASGI (no network),
and the test awaits each request, so it can exercise truly async behavior.
"""
import asyncio

import pytest
from httpx import AsyncClient, ASGITransport

from app.main import app


# With asyncio_mode = auto, a plain @pytest.fixture can be async.
# In pytest-asyncio's strict mode, use @pytest_asyncio.fixture instead
# and mark each test with @pytest.mark.asyncio.
@pytest.fixture
async def async_client():
    """Create an async test client."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client


class TestAsyncEndpoints:
    """Test endpoints using async client."""

    async def test_read_root(self, async_client):
        """Test GET / asynchronously."""
        response = await async_client.get("/")
        assert response.status_code == 200
        assert response.json()["message"] == "Hello, World!"

    async def test_create_item_async(self, async_client):
        """Test POST /items asynchronously."""
        response = await async_client.post("/items", json={
            "title": "Async Item",
            "description": "Created in async test",
            "price": 42.00,
            "quantity": 3,
        })
        assert response.status_code == 201
        assert response.json()["title"] == "Async Item"

    async def test_concurrent_requests(self, async_client):
        """Test multiple concurrent requests."""
        # Send 10 requests concurrently
        tasks = [
            async_client.get(f"/items/{i}")
            for i in range(1, 11)
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        # All should complete without raising (status codes may still vary)
        for response in responses:
            assert not isinstance(response, Exception)
# app/dependencies.py
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./app.db"

async_engine = create_async_engine(ASYNC_DATABASE_URL)
# async_sessionmaker (SQLAlchemy 2.0+) is the async counterpart of sessionmaker
AsyncSessionLocal = async_sessionmaker(async_engine, expire_on_commit=False)


async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session."""
    async with AsyncSessionLocal() as session:
        yield session
# tests/test_async_db.py
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base
from app.dependencies import get_async_db

# Test async engine using in-memory SQLite. StaticPool reuses a single
# connection, so every session sees the same in-memory database.
TEST_ASYNC_DATABASE_URL = "sqlite+aiosqlite://"
test_async_engine = create_async_engine(TEST_ASYNC_DATABASE_URL, poolclass=StaticPool)
TestAsyncSession = async_sessionmaker(test_async_engine, expire_on_commit=False)


@pytest.fixture(autouse=True)
async def setup_test_db():
    """Create tables before each test, drop after."""
    async with test_async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with test_async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def async_db_session():
    """Provide a test async database session."""
    async with TestAsyncSession() as session:
        yield session


@pytest.fixture
async def async_client(async_db_session):
    """Async client with overridden database dependency."""
    async def override_get_async_db():
        yield async_db_session

    app.dependency_overrides[get_async_db] = override_get_async_db
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()


class TestAsyncDatabaseOperations:
    """Test async database operations."""

    async def test_create_and_retrieve_user(self, async_client):
        """Test creating a user with async database."""
        # Create
        response = await async_client.post("/users", json={
            "name": "Async User",
            "email": "async@example.com",
            "password": "AsyncPass123!",
        })
        assert response.status_code == 201
        user_id = response.json()["id"]
        # Retrieve
        response = await async_client.get(f"/users/{user_id}")
        assert response.status_code == 200
        assert response.json()["name"] == "Async User"
| Scenario | Use Sync (TestClient) | Use Async (AsyncClient) |
|---|---|---|
| Simple endpoint testing | Yes | Optional |
| Endpoints that use async dependencies | Yes (the app runs in its own event loop) | Yes |
| Sharing async resources (e.g. an async DB session) between test and app | No | Yes |
| Testing WebSockets | Yes (built-in support) | No |
| Testing concurrent behavior | No | Yes |
| Testing background tasks | Yes (they finish before the call returns) | Yes |
| Testing SSE/streaming | Partial | Yes |
Best practice: Use TestClient for most tests. Reach for AsyncClient only when the test itself must await async code, exercise concurrent requests, or share async resources with the app.
Route testing is where you verify that your API endpoints accept correct input, return proper responses, and handle edge cases gracefully. Let us build a comprehensive set of route tests.
# app/routers/items.py
"""
Item router with full CRUD operations.
"""
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session

from app.database import get_db
from app.models.item import Item as ItemModel
from app.schemas.item import ItemCreate, ItemUpdate, ItemResponse

router = APIRouter(prefix="/items", tags=["items"])


# With prefix="/items", this route's path is "/items/". A request to "/items"
# gets a 307 redirect to "/items/", which TestClient follows by default.
@router.get("/", response_model=List[ItemResponse])
def list_items(
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(10, ge=1, le=100, description="Max items to return"),
    search: Optional[str] = Query(None, min_length=1, max_length=100),
    db: Session = Depends(get_db),
):
    """List items with pagination and optional search."""
    query = db.query(ItemModel)
    if search:
        query = query.filter(ItemModel.title.contains(search))
    # Order explicitly: OFFSET/LIMIT without ORDER BY has no guaranteed order
    items = query.order_by(ItemModel.id).offset(skip).limit(limit).all()
    return items


@router.get("/{item_id}", response_model=ItemResponse)
def get_item(item_id: int, db: Session = Depends(get_db)):
    """Get a single item by ID."""
    item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
    if not item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Item with id {item_id} not found",
        )
    return item


@router.post("/", response_model=ItemResponse, status_code=status.HTTP_201_CREATED)
def create_item(item_data: ItemCreate, db: Session = Depends(get_db)):
    """Create a new item."""
    item = ItemModel(**item_data.model_dump())
    db.add(item)
    db.commit()
    db.refresh(item)
    return item


@router.put("/{item_id}", response_model=ItemResponse)
def update_item(item_id: int, item_data: ItemUpdate, db: Session = Depends(get_db)):
    """Update an existing item (full replacement)."""
    item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
    if not item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Item with id {item_id} not found",
        )
    for field, value in item_data.model_dump().items():
        setattr(item, field, value)
    db.commit()
    db.refresh(item)
    return item


@router.delete("/{item_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_item(item_id: int, db: Session = Depends(get_db)):
    """Delete an item."""
    item = db.query(ItemModel).filter(ItemModel.id == item_id).first()
    if not item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Item with id {item_id} not found",
        )
    db.delete(item)
    db.commit()
# tests/integration/test_items.py
"""
Comprehensive tests for the items router.
Covers all CRUD operations, edge cases, and error scenarios.
"""
import pytest

# The client, sample_item_data, and multiple_items_data fixtures are
# assumed to be defined in tests/conftest.py.


class TestListItems:
    """Tests for GET /items."""

    def test_list_items_empty(self, client):
        """Returns empty list when no items exist."""
        response = client.get("/items")
        assert response.status_code == 200
        assert response.json() == []

    def test_list_items_returns_all(self, client, sample_item_data):
        """Returns all created items."""
        # Create 3 items
        for i in range(3):
            data = {**sample_item_data, "title": f"Item {i}"}
            client.post("/items", json=data)
        response = client.get("/items")
        assert response.status_code == 200
        assert len(response.json()) == 3

    def test_list_items_pagination_skip(self, client, sample_item_data):
        """Skip parameter works correctly."""
        for i in range(5):
            data = {**sample_item_data, "title": f"Item {i}"}
            client.post("/items", json=data)
        response = client.get("/items?skip=3")
        assert response.status_code == 200
        assert len(response.json()) == 2  # 5 total - 3 skipped

    def test_list_items_pagination_limit(self, client, sample_item_data):
        """Limit parameter caps the results."""
        for i in range(5):
            data = {**sample_item_data, "title": f"Item {i}"}
            client.post("/items", json=data)
        response = client.get("/items?limit=2")
        assert response.status_code == 200
        assert len(response.json()) == 2

    def test_list_items_pagination_combined(self, client, sample_item_data):
        """Skip and limit work together."""
        for i in range(10):
            data = {**sample_item_data, "title": f"Item {i}"}
            client.post("/items", json=data)
        response = client.get("/items?skip=2&limit=3")
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 3
        # Relies on the endpoint ordering results by id
        assert data[0]["title"] == "Item 2"

    def test_list_items_search(self, client):
        """Search filter returns matching items only."""
        client.post("/items", json={
            "title": "Python Tutorial",
            "description": "Learn Python",
            "price": 29.99, "quantity": 1,
        })
        client.post("/items", json={
            "title": "Java Tutorial",
            "description": "Learn Java",
            "price": 39.99, "quantity": 1,
        })
        response = client.get("/items?search=Python")
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["title"] == "Python Tutorial"

    def test_list_items_search_no_results(self, client, sample_item_data):
        """Search with no matches returns empty list."""
        client.post("/items", json=sample_item_data)
        response = client.get("/items?search=nonexistent")
        assert response.status_code == 200
        assert response.json() == []

    def test_list_items_invalid_skip(self, client):
        """Negative skip value returns 422."""
        response = client.get("/items?skip=-1")
        assert response.status_code == 422

    def test_list_items_invalid_limit(self, client):
        """Limit exceeding max returns 422."""
        response = client.get("/items?limit=101")
        assert response.status_code == 422

    def test_list_items_limit_zero(self, client):
        """Zero limit returns 422."""
        response = client.get("/items?limit=0")
        assert response.status_code == 422


class TestGetItem:
    """Tests for GET /items/{item_id}."""

    def test_get_existing_item(self, client, sample_item_data):
        """Successfully retrieve an existing item."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        response = client.get(f"/items/{item_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == item_id
        assert data["title"] == sample_item_data["title"]
        assert data["price"] == sample_item_data["price"]

    def test_get_nonexistent_item(self, client):
        """Returns 404 for non-existent item."""
        response = client.get("/items/99999")
        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_get_item_invalid_id_type(self, client):
        """Returns 422 for invalid ID type."""
        response = client.get("/items/abc")
        assert response.status_code == 422

    def test_get_item_response_structure(self, client, sample_item_data):
        """Verify the response has all expected fields."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        response = client.get(f"/items/{item_id}")
        data = response.json()
        expected_fields = {"id", "title", "description", "price", "quantity"}
        assert expected_fields.issubset(set(data.keys()))


class TestCreateItem:
    """Tests for POST /items."""

    def test_create_item_success(self, client, sample_item_data):
        """Successfully create a new item."""
        response = client.post("/items", json=sample_item_data)
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == sample_item_data["title"]
        assert data["description"] == sample_item_data["description"]
        assert data["price"] == sample_item_data["price"]
        assert "id" in data

    def test_create_item_auto_generates_id(self, client, sample_item_data):
        """Created item gets an auto-generated ID."""
        response = client.post("/items", json=sample_item_data)
        assert "id" in response.json()
        assert isinstance(response.json()["id"], int)

    def test_create_item_persists(self, client, sample_item_data):
        """Created item can be retrieved afterward."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        get_response = client.get(f"/items/{item_id}")
        assert get_response.status_code == 200
        assert get_response.json()["title"] == sample_item_data["title"]

    def test_create_multiple_items(self, client, multiple_items_data):
        """Create multiple items and verify count."""
        for item_data in multiple_items_data:
            response = client.post("/items", json=item_data)
            assert response.status_code == 201
        list_response = client.get("/items")
        assert len(list_response.json()) == len(multiple_items_data)

    def test_create_item_missing_required_field(self, client):
        """Missing required fields return 422."""
        response = client.post("/items", json={"title": "Incomplete"})
        assert response.status_code == 422

    def test_create_item_empty_body(self, client):
        """Empty body returns 422."""
        response = client.post("/items", json={})
        assert response.status_code == 422


class TestUpdateItem:
    """Tests for PUT /items/{item_id}."""

    def test_update_item_success(self, client, sample_item_data):
        """Successfully update an existing item."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        update_data = {
            "title": "Updated Title",
            "description": "Updated description",
            "price": 99.99,
            "quantity": 20,
        }
        response = client.put(f"/items/{item_id}", json=update_data)
        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Updated Title"
        assert data["price"] == 99.99

    def test_update_nonexistent_item(self, client):
        """Updating non-existent item returns 404."""
        # The body must pass validation (price > 0); otherwise FastAPI
        # returns 422 before the handler can look up the item
        update_data = {
            "title": "Ghost", "description": "Does not exist",
            "price": 1.0, "quantity": 0,
        }
        response = client.put("/items/99999", json=update_data)
        assert response.status_code == 404

    def test_update_persists_changes(self, client, sample_item_data):
        """Changes are persisted after update."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        client.put(f"/items/{item_id}", json={
            "title": "Persisted",
            "description": "This should persist",
            "price": 55.55,
            "quantity": 5,
        })
        get_response = client.get(f"/items/{item_id}")
        assert get_response.json()["title"] == "Persisted"


class TestDeleteItem:
    """Tests for DELETE /items/{item_id}."""

    def test_delete_item_success(self, client, sample_item_data):
        """Successfully delete an existing item."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        response = client.delete(f"/items/{item_id}")
        assert response.status_code == 204

    def test_delete_item_removes_from_database(self, client, sample_item_data):
        """Deleted item can no longer be retrieved."""
        create_response = client.post("/items", json=sample_item_data)
        item_id = create_response.json()["id"]
        client.delete(f"/items/{item_id}")
        get_response = client.get(f"/items/{item_id}")
        assert get_response.status_code == 404

    def test_delete_nonexistent_item(self, client):
        """Deleting non-existent item returns 404."""
        response = client.delete("/items/99999")
        assert response.status_code == 404

    def test_delete_reduces_count(self, client, sample_item_data):
        """Deleting an item reduces the total count."""
        # Create 3 items
        ids = []
        for i in range(3):
            data = {**sample_item_data, "title": f"Item {i}"}
            resp = client.post("/items", json=data)
            ids.append(resp.json()["id"])
        assert len(client.get("/items").json()) == 3
        # Delete one
        client.delete(f"/items/{ids[0]}")
        assert len(client.get("/items").json()) == 2
FastAPI uses Pydantic for request validation. Testing validation ensures that your API properly rejects bad input and returns informative error messages.
# app/schemas/item.py
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class ItemCreate(BaseModel):
    """Schema for creating a new item."""
    title: str = Field(..., min_length=1, max_length=200)
    description: Optional[str] = Field(None, max_length=1000)
    price: float = Field(..., gt=0, description="Price must be positive")
    quantity: int = Field(..., ge=0, description="Quantity cannot be negative")

    # field_validator runs after the Field constraints, so min_length and
    # max_length are checked on the raw value before it is stripped here
    @field_validator("title")
    @classmethod
    def title_must_not_be_blank(cls, v: str) -> str:
        if v.strip() == "":
            raise ValueError("Title cannot be blank or whitespace only")
        return v.strip()

    @field_validator("price")
    @classmethod
    def price_must_have_two_decimals(cls, v: float) -> float:
        # Float heuristic: 19.99 survives round(v, 2) unchanged, 10.999 does not
        if round(v, 2) != v:
            raise ValueError("Price must have at most 2 decimal places")
        return v


class ItemUpdate(BaseModel):
    """Schema for updating an item."""
    title: str = Field(..., min_length=1, max_length=200)
    description: Optional[str] = Field(None, max_length=1000)
    price: float = Field(..., gt=0)
    quantity: int = Field(..., ge=0)


class ItemResponse(BaseModel):
    """Schema for item responses."""
    id: int
    title: str
    description: Optional[str]
    price: float
    quantity: int

    # Allows building the response directly from SQLAlchemy objects
    model_config = {"from_attributes": True}
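The round(v, 2) check works for typical inputs, but binary floats cannot represent most two-decimal prices exactly, which is why money is often modeled with Decimal instead. Below is a sketch of that alternative using Pydantic's max_digits and decimal_places constraints (PricedItem is a hypothetical model, not part of the app above). Be aware that Pydantic v2 serializes Decimal to a JSON string, so switching would change responses that the float-based tests compare against:

```python
from decimal import Decimal

from pydantic import BaseModel, Field, ValidationError


class PricedItem(BaseModel):
    """Hypothetical alternative to the float-based price field."""
    # decimal_places=2 rejects extra precision exactly, with no float rounding
    price: Decimal = Field(..., gt=0, max_digits=10, decimal_places=2)


item = PricedItem(price="19.99")
print(repr(item.price))              # Decimal('19.99')
print(item.model_dump(mode="json"))  # price comes out as the string "19.99"

for bad in ("10.999", "0"):
    try:
        PricedItem(price=bad)
    except ValidationError as exc:
        print(bad, "rejected:", exc.errors()[0]["type"])
```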
# tests/unit/test_schemas.py
"""
Test Pydantic schema validation independently of the API.
These are pure unit tests: no HTTP calls needed.
"""
import pytest
from pydantic import ValidationError

from app.schemas.item import ItemCreate, ItemUpdate, ItemResponse


class TestItemCreateValidation:
    """Test ItemCreate schema validation rules."""

    def test_valid_item(self):
        """Valid data creates an item successfully."""
        item = ItemCreate(
            title="Test Item",
            description="A valid item",
            price=29.99,
            quantity=10,
        )
        assert item.title == "Test Item"
        assert item.price == 29.99

    def test_title_required(self):
        """Missing title raises ValidationError."""
        with pytest.raises(ValidationError) as exc_info:
            ItemCreate(
                description="No title",
                price=10.0,
                quantity=1,
            )
        errors = exc_info.value.errors()
        assert any(e["loc"] == ("title",) for e in errors)

    def test_title_min_length(self):
        """Empty string title is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(title="", price=10.0, quantity=1)

    def test_title_max_length(self):
        """Title exceeding 200 chars is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(
                title="x" * 201,
                price=10.0,
                quantity=1,
            )

    def test_title_whitespace_only(self):
        """Whitespace-only title is rejected by custom validator."""
        with pytest.raises(ValidationError) as exc_info:
            ItemCreate(title="   ", price=10.0, quantity=1)
        assert "blank" in str(exc_info.value).lower()

    def test_title_stripped(self):
        """Title is stripped of leading/trailing whitespace."""
        item = ItemCreate(
            title="  Trimmed Title  ",
            price=10.0,
            quantity=1,
        )
        assert item.title == "Trimmed Title"

    def test_price_required(self):
        """Missing price raises ValidationError."""
        with pytest.raises(ValidationError):
            ItemCreate(title="No Price", quantity=1)

    def test_price_must_be_positive(self):
        """Zero or negative price is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(title="Free Item", price=0, quantity=1)
        with pytest.raises(ValidationError):
            ItemCreate(title="Negative", price=-10.0, quantity=1)

    def test_price_too_many_decimals(self):
        """Price with more than 2 decimal places is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(title="Precise", price=10.999, quantity=1)

    def test_quantity_cannot_be_negative(self):
        """Negative quantity is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(title="Negative Qty", price=10.0, quantity=-1)

    def test_quantity_zero_allowed(self):
        """Zero quantity is allowed (out of stock)."""
        item = ItemCreate(title="Out of Stock", price=10.0, quantity=0)
        assert item.quantity == 0

    def test_description_optional(self):
        """Description is optional and defaults to None."""
        item = ItemCreate(title="No Description", price=10.0, quantity=1)
        assert item.description is None

    def test_description_max_length(self):
        """Description exceeding 1000 chars is rejected."""
        with pytest.raises(ValidationError):
            ItemCreate(
                title="Long Desc",
                description="x" * 1001,
                price=10.0,
                quantity=1,
            )


class TestValidationErrorResponses:
    """Test that the API returns proper validation error responses."""

    def test_missing_field_error_format(self, client):
        """Validation errors include field location and message."""
        response = client.post("/items", json={
            "description": "Missing title and price"
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        assert isinstance(errors, list)
        assert len(errors) > 0
        # Each error should have location, message, and type
        for error in errors:
            assert "loc" in error
            assert "msg" in error
            assert "type" in error

    def test_invalid_type_error(self, client):
        """Sending wrong type returns descriptive error."""
        response = client.post("/items", json={
            "title": "Test",
            "price": "not_a_number",
            "quantity": 1,
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        price_errors = [e for e in errors if "price" in str(e.get("loc", []))]
        assert len(price_errors) > 0

    def test_multiple_validation_errors(self, client):
        """Multiple invalid fields return multiple errors."""
        response = client.post("/items", json={
            "title": "",
            "price": -5,
            "quantity": -1,
        })
        assert response.status_code == 422
        errors = response.json()["detail"]
        assert len(errors) >= 2  # At least title and price errors

    def test_extra_fields_ignored(self, client):
        """Extra fields in request body are silently ignored."""
        response = client.post("/items", json={
            "title": "Valid Item",
            "description": "Valid",
            "price": 10.0,
            "quantity": 1,
            "extra_field": "should be ignored",
            "another_extra": 42,
        })
        assert response.status_code == 201
        assert "extra_field" not in response.json()
Database testing verifies that your application correctly interacts with the database. The key challenge is test isolation — each test should start with a clean database and not affect other tests.
# app/database.py
"""
Production database configuration.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase

SQLALCHEMY_DATABASE_URL = "postgresql://user:pass@localhost/mydb"

engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


class Base(DeclarativeBase):
    pass


def get_db():
    """Dependency that provides a database session."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# app/models/item.py
"""SQLAlchemy Item model."""
from sqlalchemy import Column, Integer, String, Float

from app.database import Base


class Item(Base):
    __tablename__ = "items"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(200), nullable=False)
    description = Column(String(1000), nullable=True)
    price = Column(Float, nullable=False)
    quantity = Column(Integer, nullable=False, default=0)


# app/models/user.py
"""SQLAlchemy User model."""
from sqlalchemy import Column, Integer, String, Boolean, DateTime
from sqlalchemy.sql import func

from app.database import Base


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    email = Column(String(255), unique=True, nullable=False, index=True)
    hashed_password = Column(String(255), nullable=False)
    is_active = Column(Boolean, default=True)
    role = Column(String(20), default="user")
    created_at = Column(DateTime(timezone=True), server_default=func.now())
# tests/conftest.py: database fixtures with transaction rollback
"""
Two approaches to database test isolation:
1. Drop/create tables for every test (simple but slower)
2. Transaction rollback (fast but more complex)
Below we show the rollback approach for maximum speed.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine
)


# pysqlite's default transaction handling breaks SAVEPOINT semantics.
# These two listeners (the workaround from the SQLAlchemy SQLite dialect
# docs) let SQLAlchemy emit BEGIN itself so savepoints nest correctly.
@event.listens_for(engine, "connect")
def _disable_pysqlite_begin(dbapi_connection, connection_record):
    dbapi_connection.isolation_level = None


@event.listens_for(engine, "begin")
def _emit_begin(conn):
    conn.exec_driver_sql("BEGIN")


# Create tables once for the entire test session
@pytest.fixture(scope="session", autouse=True)
def create_tables():
    """Create all tables once before running tests."""
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)


@pytest.fixture(scope="function")
def db_session():
    """
    Transaction rollback pattern (SQLAlchemy 2.0):
    1. Open a connection and begin an outer transaction
    2. Bind a session to it with join_transaction_mode="create_savepoint",
       so session.commit() and session.rollback() only act on a SAVEPOINT
    3. Run the test
    4. Roll back the outer transaction, undoing all changes
    This is faster than drop/create because it avoids DDL operations.
    """
    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(
        bind=connection, join_transaction_mode="create_savepoint"
    )

    yield session

    # Roll back everything the test did
    session.close()
    transaction.rollback()
    connection.close()


@pytest.fixture(scope="function")
def client(db_session):
    """TestClient with rolled-back database session."""
    def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
# tests/integration/test_database.py
"""
Tests that verify database operations work correctly.
"""
import pytest
from app.models.item import Item
from app.models.user import User
class TestItemDatabaseOperations:
"""Test Item model database interactions."""
def test_create_item_in_database(self, db_session):
"""Directly create and query an item from the database."""
item = Item(
title="DB Test Item",
description="Created directly in DB",
price=19.99,
quantity=5
)
db_session.add(item)
db_session.commit()
# Query it back
saved = db_session.query(Item).filter_by(title="DB Test Item").first()
assert saved is not None
assert saved.price == 19.99
assert saved.quantity == 5
def test_update_item_in_database(self, db_session):
"""Update an item and verify changes persist."""
item = Item(title="Original", price=10.0, quantity=1)
db_session.add(item)
db_session.commit()
# Update
item.title = "Modified"
item.price = 25.0
db_session.commit()
# Verify
updated = db_session.query(Item).filter_by(id=item.id).first()
assert updated.title == "Modified"
assert updated.price == 25.0
def test_delete_item_from_database(self, db_session):
"""Delete an item and verify it is gone."""
item = Item(title="To Delete", price=5.0, quantity=1)
db_session.add(item)
db_session.commit()
item_id = item.id
db_session.delete(item)
db_session.commit()
deleted = db_session.query(Item).filter_by(id=item_id).first()
assert deleted is None
def test_query_items_with_filter(self, db_session):
"""Test complex queries with filters."""
items = [
Item(title="Cheap Item", price=5.0, quantity=100),
Item(title="Mid Item", price=25.0, quantity=50),
Item(title="Expensive Item", price=100.0, quantity=10),
]
db_session.add_all(items)
db_session.commit()
# Find items under $30
cheap_items = db_session.query(Item).filter(Item.price < 30.0).all()
assert len(cheap_items) == 2
# Find items with quantity over 20
stocked = db_session.query(Item).filter(Item.quantity > 20).all()
assert len(stocked) == 2
def test_database_isolation_between_tests(self, db_session):
"""
Verify that each test starts with a clean database.
If isolation works correctly, this test should find
zero items (regardless of what previous tests created).
"""
count = db_session.query(Item).count()
assert count == 0
class TestUserDatabaseOperations:
"""Test User model database interactions."""
def test_create_user(self, db_session):
"""Create a user and verify all fields."""
user = User(
name="Alice",
email="alice@example.com",
hashed_password="hashed_abc123",
role="user"
)
db_session.add(user)
db_session.commit()
saved = db_session.query(User).filter_by(email="alice@example.com").first()
assert saved is not None
assert saved.name == "Alice"
assert saved.is_active is True # Default value
assert saved.role == "user"
def test_email_uniqueness(self, db_session):
"""Duplicate email should raise an error."""
user1 = User(
name="Alice",
email="duplicate@example.com",
hashed_password="hash1"
)
db_session.add(user1)
db_session.commit()
user2 = User(
name="Bob",
email="duplicate@example.com",
hashed_password="hash2"
)
db_session.add(user2)
from sqlalchemy.exc import IntegrityError
with pytest.raises(IntegrityError): # UNIQUE constraint on User.email
db_session.commit()
db_session.rollback()
def test_user_default_values(self, db_session):
"""Verify default values are set correctly."""
user = User(
name="Default User",
email="default@example.com",
hashed_password="hash"
)
db_session.add(user)
db_session.commit()
assert user.is_active is True
assert user.role == "user"
assert user.created_at is not None
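The isolation test above only passes if the db_session fixture undoes each test's writes. A common way to do that is to run every test inside a transaction and roll it back at teardown. The sketch below shows the idea with the standard-library sqlite3 module instead of SQLAlchemy, so it runs on its own. The users table and the helper names (create_schema, run_isolated, insert_one_user, insert_duplicate_email) are hypothetical; a real conftest.py would apply the same pattern to a SQLAlchemy connection and session.

```python
import sqlite3


def create_schema(conn: sqlite3.Connection) -> None:
    """Hypothetical users table with a UNIQUE email column."""
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE NOT NULL)"
    )


def run_isolated(conn: sqlite3.Connection, body) -> None:
    """Run body inside a transaction that is always rolled back afterwards."""
    conn.execute("BEGIN")
    try:
        body(conn)
    finally:
        conn.rollback()  # Discard everything the body wrote, pass or fail


def insert_one_user(conn: sqlite3.Connection) -> None:
    conn.execute("INSERT INTO users (email) VALUES ('alice@example.com')")
    assert conn.execute("SELECT COUNT(*) FROM users").fetchone()[0] == 1


def insert_duplicate_email(conn: sqlite3.Connection) -> None:
    conn.execute("INSERT INTO users (email) VALUES ('dup@example.com')")
    try:
        conn.execute("INSERT INTO users (email) VALUES ('dup@example.com')")
    except sqlite3.IntegrityError:
        return  # Expected: the UNIQUE constraint rejected the second row
    raise AssertionError("duplicate email was accepted")


# isolation_level=None disables implicit transactions, so BEGIN/rollback are explicit
conn = sqlite3.connect(":memory:", isolation_level=None)
create_schema(conn)
for body in (insert_one_user, insert_duplicate_email):
    run_isolated(conn, body)

remaining = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
print(remaining)  # 0
```

Both bodies used the same database, yet no rows survived. Each rollback returned the table to its starting state, which is exactly the property test_database_isolation_between_tests checks.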
# tests/factories.py
"""
Factory fixtures for creating test data.
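Factories encapsulate object creation logic, making tests
cleaner and reducing duplication. Pytest only discovers fixtures
defined in conftest.py or in registered plugin modules, so import
these from conftest.py (or register this module via pytest_plugins).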
"""
import pytest
from app.models.item import Item
from app.models.user import User
@pytest.fixture
def item_factory(db_session):
"""
Factory fixture for creating Item instances.
Usage:
def test_something(item_factory):
item = item_factory(title="Custom Title", price=25.0)
"""
created_items = []
def _create_item(
title="Test Item",
description="Test Description",
price=10.0,
quantity=5,
):
item = Item(
title=title,
description=description,
price=price,
quantity=quantity,
)
db_session.add(item)
db_session.commit()
db_session.refresh(item)
created_items.append(item)
return item
yield _create_item
@pytest.fixture
def user_factory(db_session):
"""Factory fixture for creating User instances."""
counter = 0
def _create_user(
name=None,
email=None,
password="hashed_default",
role="user",
is_active=True,
):
nonlocal counter
counter += 1
user = User(
name=name or f"User {counter}",
email=email or f"user{counter}@example.com",
hashed_password=password,
role=role,
is_active=is_active,
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user
yield _create_user
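The user_factory fixture relies on a closure that keeps a counter between calls, so every generated email is unique. The same pattern is shown below without pytest or a database. FakeUser, make_user_factory and the plain list used as storage are hypothetical stand-ins for the User model and db_session. The sketch uses itertools.count, which avoids the nonlocal statement.

```python
from dataclasses import dataclass
from itertools import count
from typing import Callable, List, Optional


@dataclass
class FakeUser:
    """Stand-in for the User model (no database involved)."""
    name: str
    email: str
    role: str = "user"


def make_user_factory(store: List[FakeUser]) -> Callable[..., FakeUser]:
    """Return a factory that builds users with unique defaults and records them."""
    sequence = count(1)  # Shared by every call to the returned factory

    def create(name: Optional[str] = None, email: Optional[str] = None,
               role: str = "user") -> FakeUser:
        n = next(sequence)
        user = FakeUser(
            name=name or f"User {n}",
            email=email or f"user{n}@example.com",
            role=role,
        )
        store.append(user)  # Plays the role of db_session.add() + commit()
        return user

    return create


created: List[FakeUser] = []
user_factory = make_user_factory(created)
a = user_factory()
b = user_factory(role="admin")
c = user_factory(email="custom@example.com")
print([u.email for u in created])
# ['user1@example.com', 'user2@example.com', 'custom@example.com']
```

Each call overrides only the fields the test cares about. The remaining fields stay unique, so UNIQUE constraints never trip by accident.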
# tests/integration/test_with_factories.py
"""Using factory fixtures for cleaner tests."""
class TestWithFactories:
"""Demonstrate factory fixture usage."""
def test_list_items_with_factory(self, client, item_factory):
"""Use the factory to create test data."""
item_factory(title="Item A", price=10.0)
item_factory(title="Item B", price=20.0)
item_factory(title="Item C", price=30.0)
response = client.get("/items")
assert len(response.json()) == 3
def test_search_with_factory(self, client, item_factory):
"""Factory makes it easy to set up specific scenarios."""
item_factory(title="Python Guide", price=29.99)
item_factory(title="Java Guide", price=39.99)
item_factory(title="Python Advanced", price=49.99)
response = client.get("/items?search=Python")
assert len(response.json()) == 2
def test_user_roles_with_factory(self, client, user_factory):
"""Create users with different roles."""
admin = user_factory(name="Admin", role="admin")
user = user_factory(name="Regular", role="user")
assert admin.role == "admin"
assert user.role == "user"
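FastAPI’s dependency injection system is one of its most powerful features, and app.dependency_overrides makes it easy to swap real dependencies for test doubles. You map the original dependency callable to a replacement, and FastAPI calls the replacement wherever the original is declared with Depends(), including inside other dependencies.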
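app.dependency_overrides is a dictionary keyed by the original dependency callable. When FastAPI resolves a dependency, it looks the callable up there and calls the replacement if one is registered. The toy model below imitates that lookup with plain functions so it runs without FastAPI. resolve and admin_dashboard are hypothetical names, not FastAPI APIs, and the model is a simplification rather than FastAPI's implementation.

```python
from typing import Any, Callable, Dict

# Stand-in for app.dependency_overrides: original callable -> replacement
overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}


def get_current_user() -> dict:
    """The 'real' dependency; in the app it would decode a JWT."""
    raise RuntimeError("real dependency called: no token available")


def resolve(dependency: Callable[..., Any]) -> Any:
    """Call the registered override if there is one, otherwise the original."""
    call = overrides.get(dependency, dependency)
    return call()


def admin_dashboard() -> int:
    """Toy endpoint whose admin check depends on get_current_user."""
    user = resolve(get_current_user)  # Sub-dependency goes through the same lookup
    return 200 if user["role"] == "admin" else 403


overrides[get_current_user] = lambda: {"name": "Test", "role": "admin"}
print(admin_dashboard())  # 200
overrides[get_current_user] = lambda: {"name": "Test", "role": "user"}
print(admin_dashboard())  # 403
overrides.clear()  # Back to the real dependency
```

The override is keyed by the function object itself, not its name. That is why tests import get_current_user from the same module the app uses.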
# app/dependencies.py
"""
Application dependencies that can be overridden in tests.
"""
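from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.user import User
# Reuse the signing settings defined in app/routers/auth.py
from app.routers.auth import ALGORITHM, SECRET_KEY
security = HTTPBearer()
def decode_jwt_token(token: str) -> dict:
"""Decode and validate a JWT; raise 401 if it is malformed, expired, or badly signed."""
try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token"
)
def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db),
) -> User:
"""Decode JWT token and return the current user."""
token = credentials.credentials
payload = decode_jwt_token(token) # Raises HTTPException(401) on an invalid token
user = db.query(User).filter(User.id == int(payload["sub"])).first()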
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
)
return user
def get_current_admin(
user: User = Depends(get_current_user),
) -> User:
"""Ensure the current user has admin role."""
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin access required"
)
return user
class EmailService:
"""Service for sending emails."""
def send_welcome_email(self, email: str, name: str) -> bool:
"""Send a welcome email to a new user."""
# In production, this calls an email API
return True
def send_password_reset(self, email: str, token: str) -> bool:
"""Send a password reset email."""
return True
def get_email_service() -> EmailService:
"""Dependency that provides the email service."""
return EmailService()
# tests/test_dependency_overrides.py
"""
Demonstrates various ways to override dependencies in tests.
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.dependencies import (
get_current_user,
get_current_admin,
get_email_service,
EmailService,
)
from app.models.user import User
# ---------------------------------------------------------------------------
# Approach 1: Simple function override
# ---------------------------------------------------------------------------
class TestWithSimpleOverride:
"""Override a dependency with a simple function."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up and tear down dependency overrides."""
# Create a fake user to return
fake_user = User(
id=1,
name="Test User",
email="test@example.com",
hashed_password="fake",
role="user",
is_active=True,
)
def override_get_current_user():
return fake_user
app.dependency_overrides[get_current_user] = override_get_current_user
yield
app.dependency_overrides.clear()
def test_protected_endpoint(self):
"""Access a protected endpoint without real authentication."""
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200
assert response.json()["email"] == "test@example.com"
# ---------------------------------------------------------------------------
# Approach 2: Parameterized override fixture
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_current_user():
"""Fixture that creates a mock user and overrides the dependency."""
def _create_override(
user_id=1,
name="Test User",
email="test@example.com",
role="user",
):
user = User(
id=user_id,
name=name,
email=email,
hashed_password="fake",
role=role,
is_active=True,
)
def override():
return user
app.dependency_overrides[get_current_user] = override
return user
yield _create_override
app.dependency_overrides.clear()
class TestWithParameterizedOverride:
"""Use parameterized overrides for different scenarios."""
def test_regular_user_access(self, mock_current_user):
"""Test as a regular user."""
mock_current_user(role="user")
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200
def test_admin_endpoint_as_user(self, mock_current_user):
"""Regular user cannot access admin endpoints."""
mock_current_user(role="user")
client = TestClient(app)
response = client.get("/admin/dashboard")
assert response.status_code == 403
def test_admin_endpoint_as_admin(self, mock_current_user):
"""Admin user can access admin endpoints."""
mock_current_user(role="admin")
# No separate override is needed: get_current_admin depends on
# get_current_user, so the overridden admin user flows through it
client = TestClient(app)
response = client.get("/admin/dashboard")
assert response.status_code == 200
# ---------------------------------------------------------------------------
# Approach 3: Override with a mock class
# ---------------------------------------------------------------------------
class MockEmailService(EmailService):
"""Mock email service that records calls instead of sending."""
def __init__(self):
self.sent_emails = []
def send_welcome_email(self, email: str, name: str) -> bool:
self.sent_emails.append({
"type": "welcome",
"email": email,
"name": name,
})
return True
def send_password_reset(self, email: str, token: str) -> bool:
self.sent_emails.append({
"type": "password_reset",
"email": email,
"token": token,
})
return True
@pytest.fixture
def mock_email_service():
"""Override the email service with a mock."""
mock_service = MockEmailService()
def override():
return mock_service
app.dependency_overrides[get_email_service] = override
yield mock_service
app.dependency_overrides.clear()
class TestEmailServiceOverride:
"""Test email-related functionality with mock service."""
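def test_registration_sends_welcome_email(
self, client, mock_email_service
):
"""Registering a user should send a welcome email.
Assumes /auth/register declares Depends(get_email_service) and calls
send_welcome_email(email, name); the simplified router shown later does not.
"""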
client.post("/auth/register", json={
"name": "New User",
"email": "new@example.com",
"password": "Pass123!"
})
# Verify the welcome email was "sent"
assert len(mock_email_service.sent_emails) == 1
assert mock_email_service.sent_emails[0]["type"] == "welcome"
assert mock_email_service.sent_emails[0]["email"] == "new@example.com"
def test_password_reset_sends_email(
self, client, mock_email_service
):
"""Password reset should send a reset email."""
# Assume user already exists
client.post("/auth/password-reset", json={
"email": "existing@example.com"
})
reset_emails = [
e for e in mock_email_service.sent_emails
if e["type"] == "password_reset"
]
assert len(reset_emails) == 1
Always call app.dependency_overrides.clear() after your tests to prevent overrides from leaking between tests. Putting that cleanup in a fixture, after the yield, makes it run automatically, even when the test fails.
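clear() removes every override, including ones another fixture installed and still expects to be present, such as the get_db override from the client fixture. When several fixtures each install their own override, a narrower teardown that restores only its own key is safer. The minimal sketch below uses a plain dict in place of app.dependency_overrides. override_dependency is a hypothetical helper, not part of FastAPI.

```python
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator

Overrides = Dict[Callable[..., Any], Callable[..., Any]]
_MISSING = object()  # Sentinel meaning "no override was registered before"


@contextmanager
def override_dependency(
    overrides: Overrides,
    dependency: Callable[..., Any],
    replacement: Callable[..., Any],
) -> Iterator[None]:
    """Install one override, then restore exactly the previous state on exit."""
    previous = overrides.get(dependency, _MISSING)
    overrides[dependency] = replacement
    try:
        yield
    finally:  # Runs even if the test body raises
        if previous is _MISSING:
            overrides.pop(dependency, None)  # Remove only our key
        else:
            overrides[dependency] = previous


def get_db() -> str:
    return "real db"


def get_email_service() -> str:
    return "real email service"


# Plain dict standing in for app.dependency_overrides, holding a get_db
# override that (say) the client fixture installed and still relies on
overrides: Overrides = {get_db: lambda: "test db"}

with override_dependency(overrides, get_email_service, lambda: "mock email service"):
    print(overrides[get_email_service]())  # mock email service

print(list(overrides) == [get_db])  # True: the get_db override survived
```

Inside a pytest fixture, the body would become: with override_dependency(app.dependency_overrides, get_email_service, lambda: mock): yield mock.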
Real applications interact with external services: APIs, email providers, payment gateways, and cloud storage. In tests, you should mock these services to avoid real network calls, costs, and flakiness.
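A plain MagicMock accepts any arguments, so a test can keep passing after the real method's signature changes. unittest.mock.create_autospec copies the real signatures onto the mock and rejects calls that would fail against the real object. In the sketch below, PaymentGateway and checkout are hypothetical stand-ins for an external client and the code that uses it.

```python
from unittest.mock import create_autospec


class PaymentGateway:
    """Hypothetical external client; the real one would call a payment API."""

    def charge(self, customer_id: str, amount_cents: int) -> dict:
        raise RuntimeError("would make a real network call")


def checkout(gateway: PaymentGateway, customer_id: str, amount_cents: int) -> str:
    """Code under test: charge the customer and return the transaction id."""
    result = gateway.charge(customer_id, amount_cents)
    return result["transaction_id"]


# create_autospec copies PaymentGateway's method signatures onto the mock
gateway = create_autospec(PaymentGateway, instance=True)
gateway.charge.return_value = {"transaction_id": "txn_123"}

txn = checkout(gateway, "cust_42", 1999)
gateway.charge.assert_called_once_with("cust_42", 1999)

# A plain MagicMock would accept this call; the autospec mock rejects it
try:
    gateway.charge("cust_42")  # Missing amount_cents
    rejected = False
except TypeError:
    rejected = True

print(txn, rejected)  # txn_123 True
```

patch() accepts autospec=True to apply the same check when patching a name in place.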
# app/services/weather_service.py
"""External weather API service."""
import httpx
class WeatherService:
"""Fetches weather data from an external API."""
BASE_URL = "https://api.weatherapi.com/v1"
def __init__(self, api_key: str):
self.api_key = api_key
def get_current_weather(self, city: str) -> dict:
"""Fetch current weather for a city."""
response = httpx.get(
f"{self.BASE_URL}/current.json",
params={"key": self.api_key, "q": city}
)
response.raise_for_status()
return response.json()
def get_forecast(self, city: str, days: int = 3) -> dict:
"""Fetch weather forecast."""
response = httpx.get(
f"{self.BASE_URL}/forecast.json",
params={"key": self.api_key, "q": city, "days": days}
)
response.raise_for_status()
return response.json()
# tests/test_mocking.py
"""
Comprehensive mocking examples using unittest.mock.
"""
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
import httpx
from app.services.weather_service import WeatherService
class TestWeatherServiceMocking:
"""Mock external HTTP calls in the weather service."""
@patch("app.services.weather_service.httpx.get")
def test_get_current_weather(self, mock_get):
"""Mock the HTTP call to return fake weather data."""
# Configure the mock to return a fake response
mock_response = MagicMock()
mock_response.json.return_value = {
"location": {"name": "London"},
"current": {
"temp_c": 15.0,
"condition": {"text": "Partly cloudy"}
}
}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
# Call the service
service = WeatherService(api_key="fake-key")
result = service.get_current_weather("London")
# Verify the result
assert result["location"]["name"] == "London"
assert result["current"]["temp_c"] == 15.0
# Verify the HTTP call was made correctly
mock_get.assert_called_once()
assert mock_get.call_args.kwargs["params"]["q"] == "London"
assert mock_get.call_args.kwargs["params"]["key"] == "fake-key"
@patch("app.services.weather_service.httpx.get")
def test_weather_api_error(self, mock_get):
"""Test handling of API errors."""
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=MagicMock(status_code=500)
)
mock_get.return_value = mock_response
service = WeatherService(api_key="fake-key")
with pytest.raises(httpx.HTTPStatusError):
service.get_current_weather("London")
@patch("app.services.weather_service.httpx.get")
def test_weather_network_error(self, mock_get):
"""Test handling of network errors."""
mock_get.side_effect = httpx.ConnectError("Connection refused")
service = WeatherService(api_key="fake-key")
with pytest.raises(httpx.ConnectError):
service.get_current_weather("London")
@patch("app.services.weather_service.httpx.get")
def test_forecast_calls_correct_endpoint(self, mock_get):
"""Verify the forecast method calls the correct URL."""
mock_response = MagicMock()
mock_response.json.return_value = {"forecast": {"forecastday": []}}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
service = WeatherService(api_key="test-key")
service.get_forecast("Paris", days=5)
# Check that the correct URL and query parameters were used
args, kwargs = mock_get.call_args
assert args[0] == "https://api.weatherapi.com/v1/forecast.json"
assert kwargs["params"]["q"] == "Paris"
assert kwargs["params"]["days"] == 5
class TestMockingPatterns:
"""Common mocking patterns for FastAPI tests."""
def test_mock_with_context_manager(self):
"""Use patch as a context manager."""
with patch("app.services.weather_service.httpx.get") as mock_get:
mock_get.return_value = MagicMock(
json=MagicMock(return_value={"data": "test"}),
raise_for_status=MagicMock()
)
service = WeatherService(api_key="key")
result = service.get_current_weather("Tokyo")
assert result == {"data": "test"}
def test_mock_with_side_effect_list(self):
"""Return different values on consecutive calls."""
with patch("app.services.weather_service.httpx.get") as mock_get:
responses = [
MagicMock(
json=MagicMock(return_value={"temp": 20}),
raise_for_status=MagicMock()
),
MagicMock(
json=MagicMock(return_value={"temp": 25}),
raise_for_status=MagicMock()
),
]
mock_get.side_effect = responses
service = WeatherService(api_key="key")
first = service.get_current_weather("London")
second = service.get_current_weather("Paris")
assert first["temp"] == 20
assert second["temp"] == 25
def test_assert_call_count(self):
"""Verify how many times a mock was called."""
with patch("app.services.weather_service.httpx.get") as mock_get:
mock_get.return_value = MagicMock(
json=MagicMock(return_value={}),
raise_for_status=MagicMock()
)
service = WeatherService(api_key="key")
service.get_current_weather("A")
service.get_current_weather("B")
service.get_current_weather("C")
assert mock_get.call_count == 3
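A side_effect list can also mix exceptions with return values. Exception instances in the iterable are raised instead of returned, which makes retry logic easy to test. The sketch below is illustrative only. fetch_with_retry is a hypothetical helper, and the built-in ConnectionError stands in for httpx.ConnectError so the example runs without httpx.

```python
from unittest.mock import MagicMock


def fetch_with_retry(get, url: str, attempts: int = 3) -> dict:
    """Hypothetical helper: retry on ConnectionError, re-raise after the last attempt."""
    for attempt in range(1, attempts + 1):
        try:
            response = get(url)
            response.raise_for_status()
            return response.json()
        except ConnectionError:
            if attempt == attempts:
                raise
    raise ValueError("attempts must be at least 1")


ok_response = MagicMock()
ok_response.json.return_value = {"temp_c": 18.0}

# First call raises, second call returns ok_response
mock_get = MagicMock(side_effect=[ConnectionError("reset"), ok_response])
result = fetch_with_retry(mock_get, "https://api.example.com/weather")

# A single exception instance as side_effect is raised on every call
always_down = MagicMock(side_effect=ConnectionError("down"))
try:
    fetch_with_retry(always_down, "https://api.example.com/weather", attempts=3)
    gave_up = False
except ConnectionError:
    gave_up = True

print(result, mock_get.call_count, gave_up, always_down.call_count)
# {'temp_c': 18.0} 2 True 3
```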
# tests/test_pytest_mock.py
"""
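pytest-mock provides a cleaner interface for mocking. Once the plugin is
installed (pip install pytest-mock), every test can request the 'mocker'
fixture, and its patches are undone automatically when the test ends.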
"""
import pytest
import httpx
class TestWithPytestMock:
"""Use the mocker fixture for cleaner mocking."""
def test_mock_external_call(self, mocker):
"""Mock using the mocker fixture."""
mock_get = mocker.patch("app.services.weather_service.httpx.get")
mock_get.return_value.json.return_value = {
"location": {"name": "Berlin"},
"current": {"temp_c": 22.0}
}
mock_get.return_value.raise_for_status = mocker.MagicMock()
from app.services.weather_service import WeatherService
service = WeatherService(api_key="key")
result = service.get_current_weather("Berlin")
assert result["current"]["temp_c"] == 22.0
def test_spy_on_method(self, mocker):
"""Spy on a method to verify it was called without changing behavior."""
from app.services.weather_service import WeatherService
spy = mocker.spy(WeatherService, "get_current_weather")
# This would still make the real call, so we also mock httpx
mocker.patch("app.services.weather_service.httpx.get").return_value = (
mocker.MagicMock(
json=mocker.MagicMock(return_value={"data": "test"}),
raise_for_status=mocker.MagicMock()
)
)
service = WeatherService(api_key="key")
service.get_current_weather("London")
# Verify the method was called
spy.assert_called_once_with(service, "London")
def test_mock_email_sending(self, mocker, client):
"""Mock email sending in an endpoint test."""
mock_send = mocker.patch(
"app.dependencies.EmailService.send_welcome_email",
return_value=True
)
response = client.post("/auth/register", json={
"name": "New User",
"email": "new@example.com",
"password": "Pass123!"
})
assert response.status_code == 201
# Verify the email was "sent" (assumes the endpoint passes email and name positionally)
mock_send.assert_called_once_with("new@example.com", "New User")
def test_mock_datetime(self, mocker):
"""Mock the current time for time-dependent tests."""
from datetime import datetime
mock_now = mocker.patch("app.services.user_service.datetime")
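mock_now.utcnow.return_value = datetime(2025, 6, 15, 12, 0, 0)
# Any code in app.services.user_service that calls datetime.utcnow() now
# gets the fixed time; call that code here and assert on the result.
# Patch the name where it is looked up (assuming user_service does
# "from datetime import datetime"); datetime.datetime itself is a
# built-in type whose attributes cannot be patched.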
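Patching datetime works, but only if you patch the name the module under test actually looks up. A patch-free alternative is to pass the clock in as a parameter with a real default. This is a sketch under that design assumption. token_expiry and is_expired are hypothetical helpers, and they call datetime.now(timezone.utc) because datetime.utcnow() is deprecated as of Python 3.12.

```python
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional

Clock = Callable[[], datetime]


def utc_now() -> datetime:
    """Default clock: timezone-aware current UTC time."""
    return datetime.now(timezone.utc)


def token_expiry(minutes: int = 30, now: Optional[Clock] = None) -> datetime:
    """When a token issued 'now' should expire."""
    clock = now or utc_now
    return clock() + timedelta(minutes=minutes)


def is_expired(expires_at: datetime, now: Optional[Clock] = None) -> bool:
    clock = now or utc_now
    return clock() >= expires_at


fixed = datetime(2025, 6, 15, 12, 0, tzinfo=timezone.utc)


def frozen() -> datetime:
    """Test clock that always returns the same instant."""
    return fixed


expires = token_expiry(30, now=frozen)
print(expires.isoformat())  # 2025-06-15T12:30:00+00:00
print(is_expired(expires, now=frozen))                                 # False
print(is_expired(expires, now=lambda: fixed + timedelta(minutes=31)))  # True
```

Production code calls these helpers without the now argument. Tests pass a frozen clock and never touch the datetime module.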
# app/services/async_api_service.py
"""Async external API service."""
import httpx
class AsyncAPIService:
"""Async service for external API calls."""
async def fetch_user_data(self, user_id: int) -> dict:
"""Fetch user data from external API."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"https://api.example.com/users/{user_id}"
)
response.raise_for_status()
return response.json()
# tests/test_async_mocking.py
"""Mocking async external calls."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
class TestAsyncMocking:
"""Test async mocking patterns."""
@pytest.mark.asyncio # Requires pytest-asyncio (or asyncio_mode = "auto")
async def test_mock_async_http_call(self):
"""Mock an async HTTP call."""
with patch("app.services.async_api_service.httpx.AsyncClient") as MockClient:
# Set up the mock
mock_response = MagicMock()
mock_response.json.return_value = {
"id": 1,
"name": "Mocked User"
}
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client_instance.__aenter__.return_value = mock_client_instance
mock_client_instance.__aexit__.return_value = None
MockClient.return_value = mock_client_instance
from app.services.async_api_service import AsyncAPIService
service = AsyncAPIService()
result = await service.fetch_user_data(1)
assert result["name"] == "Mocked User"
mock_client_instance.get.assert_called_once()
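The same mocking is easier to read when the HTTP client is passed in rather than patched. In the sketch below, fetch_user_name and the injected client factory are hypothetical. The points carry over to the test above: awaited methods need AsyncMock, the synchronous response can stay a MagicMock, and __aenter__ must return the client so that "async with ... as client" binds it.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def fetch_user_name(client_factory, user_id: int) -> str:
    """Code under test: the HTTP client class is injected instead of imported."""
    async with client_factory() as client:
        response = await client.get(f"https://api.example.com/users/{user_id}")
        response.raise_for_status()
        return response.json()["name"]


# The response is used synchronously (json() is not awaited), so MagicMock fits
response = MagicMock()
response.json.return_value = {"id": 1, "name": "Mocked User"}

client = AsyncMock()                     # client.get(...) returns an awaitable
client.get.return_value = response
client.__aenter__.return_value = client  # "async with ... as client" yields itself
factory = MagicMock(return_value=client)

name = asyncio.run(fetch_user_name(factory, 1))
client.get.assert_awaited_once_with("https://api.example.com/users/1")
print(name)  # Mocked User
```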
Testing authentication ensures that your security layer works correctly: valid credentials grant access, invalid credentials are rejected, and roles are enforced.
# app/routers/auth.py
"""Authentication router."""
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.database import get_db
from app.models.user import User
from app.schemas.user import UserCreate, UserLogin, TokenResponse
router = APIRouter(prefix="/auth", tags=["auth"])
SECRET_KEY = "your-secret-key-here" # Demo only: load from an environment variable in production
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
def create_access_token(data: dict, expires_delta: timedelta = None):
"""Create a JWT access token."""
to_encode = data.copy()
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def hash_password(password: str) -> str:
"""Hash a password."""
return pwd_context.hash(password)
@router.post("/register", status_code=status.HTTP_201_CREATED)
def register(user_data: UserCreate, db: Session = Depends(get_db)):
"""Register a new user."""
# Check if email already exists
existing = db.query(User).filter(User.email == user_data.email).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email already registered"
)
user = User(
name=user_data.name,
email=user_data.email,
hashed_password=hash_password(user_data.password),
role=getattr(user_data, "role", "user"),
)
db.add(user)
db.commit()
db.refresh(user)
return {"id": user.id, "name": user.name, "email": user.email}
@router.post("/login", response_model=TokenResponse)
def login(credentials: UserLogin, db: Session = Depends(get_db)):
"""Authenticate and return a JWT token."""
user = db.query(User).filter(User.email == credentials.email).first()
if not user or not verify_password(credentials.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is deactivated"
)
token = create_access_token(
data={"sub": str(user.id), "role": user.role},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
return {"access_token": token, "token_type": "bearer"}
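passlib with bcrypt handles salting and deliberately slow hashing for you. The standard-library sketch below shows what hash_password and verify_password do conceptually. It uses PBKDF2-HMAC-SHA256 instead of bcrypt, a different algorithm chosen only because it ships with Python. The "salt$hash" storage format and the iteration count are assumptions of this sketch, not passlib's format.

```python
import hashlib
import hmac
import os

ITERATIONS = 200_000  # Work factor: higher is slower to brute-force; tune for your hardware


def hash_password(password: str) -> str:
    """Return 'salt$hash' using PBKDF2-HMAC-SHA256 with a random 16-byte salt."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(digest.hex(), digest_hex)


stored = hash_password("Pass123!")
print(verify_password("Pass123!", stored))           # True
print(verify_password("WrongPassword123!", stored))  # False
print(hash_password("Pass123!") != stored)           # True: a fresh salt changes the hash
```

The random salt is why the registration test can only check that hashed_password is absent from the response, not compare it to a known value.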
# tests/integration/test_auth.py
"""
Comprehensive authentication tests.
"""
import pytest
from jose import jwt
from datetime import datetime, timedelta
# Helper to create test tokens
def create_test_token(
user_id: int = 1,
role: str = "user",
expired: bool = False,
secret: str = "your-secret-key-here",
):
"""Create a JWT token for testing."""
expire = datetime.utcnow() + (
timedelta(minutes=-5) if expired else timedelta(minutes=30)
)
payload = {
"sub": str(user_id),
"role": role,
"exp": expire,
}
return jwt.encode(payload, secret, algorithm="HS256")
class TestRegistration:
"""Tests for POST /auth/register."""
def test_register_success(self, client, sample_user_data):
"""Successfully register a new user."""
response = client.post("/auth/register", json=sample_user_data)
assert response.status_code == 201
data = response.json()
assert data["name"] == sample_user_data["name"]
assert data["email"] == sample_user_data["email"]
assert "id" in data
# Password should NOT be in the response
assert "password" not in data
assert "hashed_password" not in data
def test_register_duplicate_email(self, client, sample_user_data):
"""Registering with an existing email returns 409."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/register", json=sample_user_data)
assert response.status_code == 409
assert "already registered" in response.json()["detail"].lower()
def test_register_invalid_email(self, client):
"""Invalid email format is rejected."""
response = client.post("/auth/register", json={
"name": "Test",
"email": "not-an-email",
"password": "Pass123!"
})
assert response.status_code == 422
def test_register_weak_password(self, client):
"""Weak password is rejected (if validation is implemented)."""
response = client.post("/auth/register", json={
"name": "Test",
"email": "test@example.com",
"password": "123"
})
assert response.status_code == 422
def test_register_missing_fields(self, client):
"""Missing required fields return 422."""
response = client.post("/auth/register", json={
"email": "test@example.com"
})
assert response.status_code == 422
class TestLogin:
"""Tests for POST /auth/login."""
def test_login_success(self, client, sample_user_data):
"""Successfully log in with correct credentials."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
def test_login_returns_valid_jwt(self, client, sample_user_data):
"""The returned token is a valid JWT."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
token = response.json()["access_token"]
# Decode and verify the token
payload = jwt.decode(token, "your-secret-key-here", algorithms=["HS256"])
assert "sub" in payload
assert "exp" in payload
assert "role" in payload
def test_login_wrong_password(self, client, sample_user_data):
"""Wrong password returns 401."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": "WrongPassword123!"
})
assert response.status_code == 401
def test_login_nonexistent_user(self, client):
"""Login with non-existent email returns 401."""
response = client.post("/auth/login", json={
"email": "ghost@example.com",
"password": "Pass123!"
})
assert response.status_code == 401
def test_login_deactivated_account(self, client, db_session, sample_user_data):
"""Deactivated account cannot log in."""
# Register and then deactivate
client.post("/auth/register", json=sample_user_data)
from app.models.user import User
user = db_session.query(User).filter_by(
email=sample_user_data["email"]
).first()
user.is_active = False
db_session.commit()
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
assert response.status_code == 403
class TestProtectedEndpoints:
"""Tests for accessing protected endpoints."""
def test_access_with_valid_token(self, client, auth_headers):
"""Valid token grants access."""
response = client.get("/users/me", headers=auth_headers)
assert response.status_code == 200
def test_access_without_token(self, client):
"""No token returns 401 or 403."""
response = client.get("/users/me")
assert response.status_code in (401, 403)
def test_access_with_invalid_token(self, client):
"""Invalid token returns 401."""
headers = {"Authorization": "Bearer invalid.token.here"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
def test_access_with_expired_token(self, client):
"""Expired token returns 401."""
expired_token = create_test_token(expired=True)
headers = {"Authorization": f"Bearer {expired_token}"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
def test_access_with_wrong_secret(self, client):
"""Token signed with wrong secret is rejected."""
bad_token = create_test_token(secret="wrong-secret")
headers = {"Authorization": f"Bearer {bad_token}"}
response = client.get("/users/me", headers=headers)
assert response.status_code in (401, 403)
class TestRoleBasedAccess:
"""Tests for role-based authorization."""
def test_admin_access_admin_endpoint(self, client, admin_headers):
"""Admin can access admin-only endpoints."""
response = client.get("/admin/users", headers=admin_headers)
assert response.status_code == 200
def test_user_cannot_access_admin_endpoint(self, client, auth_headers):
"""Regular user cannot access admin endpoints."""
response = client.get("/admin/users", headers=auth_headers)
assert response.status_code == 403
def test_role_in_token_payload(self, client, sample_user_data):
"""Token payload contains the correct role."""
client.post("/auth/register", json=sample_user_data)
response = client.post("/auth/login", json={
"email": sample_user_data["email"],
"password": sample_user_data["password"]
})
token = response.json()["access_token"]
payload = jwt.decode(token, "your-secret-key-here", algorithms=["HS256"])
assert payload["role"] == "user"
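The expired-token and wrong-secret tests pass because an HS256 JWT is base64url-encoded JSON plus an HMAC-SHA256 signature over the first two segments. The decoder checks both the signature and the exp claim. The standard-library sketch below shows only that mechanism. encode_hs256, decode_hs256 and rejection_reason are hypothetical names, and the sketch omits checks a real library performs (nbf, aud, clock leeway, key handling). Application code should keep using python-jose or PyJWT.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def _b64url(data: bytes) -> str:
    """Base64url without padding, as JWT segments use."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _b64url_decode(text: str) -> bytes:
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _sign(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()


def encode_hs256(payload: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join([
        _b64url(json.dumps(header, separators=(",", ":")).encode()),
        _b64url(json.dumps(payload, separators=(",", ":")).encode()),
    ])
    return f"{signing_input}.{_b64url(_sign(signing_input, secret))}"


def decode_hs256(token: str, secret: str, now: Optional[float] = None) -> dict:
    """Check the signature, the alg header and exp; raise ValueError on failure."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("malformed token")
    header_b64, payload_b64, signature_b64 = parts
    expected = _sign(f"{header_b64}.{payload_b64}", secret)
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    if json.loads(_b64url_decode(header_b64)).get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    payload = json.loads(_b64url_decode(payload_b64))
    current = time.time() if now is None else now
    if current >= payload["exp"]:
        raise ValueError("token expired")
    return payload


def rejection_reason(token: str, secret: str) -> Optional[str]:
    """Return why decoding failed, or None if the token is accepted."""
    try:
        decode_hs256(token, secret)
    except ValueError as exc:
        return str(exc)
    return None


SECRET = "your-secret-key-here"
issued_at = time.time()
valid = encode_hs256({"sub": "1", "role": "user", "exp": issued_at + 1800}, SECRET)
expired = encode_hs256({"sub": "1", "role": "user", "exp": issued_at - 300}, SECRET)

print(decode_hs256(valid, SECRET)["role"])      # user
print(rejection_reason(valid, "wrong-secret"))  # bad signature
print(rejection_reason(expired, SECRET))        # token expired
```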
Testing file upload endpoints requires sending multipart form data. FastAPI’s TestClient makes this straightforward.
# app/routers/files.py
"""File upload endpoints."""
import os
import shutil
from typing import List
from fastapi import APIRouter, File, UploadFile, HTTPException, status
router = APIRouter(prefix="/files", tags=["files"])
UPLOAD_DIR = "uploads"
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".pdf", ".txt"}
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
@router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
"""Upload a single file."""
# Validate file extension
ext = os.path.splitext(file.filename)[1].lower()
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type '{ext}' not allowed. Allowed: {ALLOWED_EXTENSIONS}"
)
# Read and check file size
contents = await file.read()
if len(contents) > MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File too large. Max size: {MAX_FILE_SIZE} bytes"
)
# Save the file
os.makedirs(UPLOAD_DIR, exist_ok=True)
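# Keep only the base name so a filename like "../../app/main.py" cannot escape UPLOAD_DIR
file_path = os.path.join(UPLOAD_DIR, os.path.basename(file.filename))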
with open(file_path, "wb") as f:
f.write(contents)
return {
"filename": file.filename,
"size": len(contents),
"content_type": file.content_type,
}
@router.post("/upload-multiple")
async def upload_multiple_files(files: List[UploadFile] = File(...)):
"""Upload multiple files at once."""
results = []
for file in files:
contents = await file.read()
results.append({
"filename": file.filename,
"size": len(contents),
"content_type": file.content_type,
})
return {"uploaded": len(results), "files": results}
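Keeping the upload checks in a plain function makes them testable without TestClient. Writing into a temporary directory avoids sharing a fixed uploads folder between tests; pytest's tmp_path fixture plays the same role. This is a sketch under those assumptions. save_upload is a hypothetical helper that mirrors the extension and size checks in upload_file and uses os.path.basename to block path traversal.

```python
import os
import tempfile

ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".pdf", ".txt"}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB


def save_upload(upload_dir: str, filename: str, contents: bytes) -> str:
    """Validate an upload and write it under upload_dir; return the saved path."""
    safe_name = os.path.basename(filename)  # "../../notes.txt" -> "notes.txt"
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"File type '{ext}' not allowed")
    if len(contents) > MAX_FILE_SIZE:
        raise ValueError("File too large")
    os.makedirs(upload_dir, exist_ok=True)
    path = os.path.join(upload_dir, safe_name)
    with open(path, "wb") as f:
        f.write(contents)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = save_upload(tmp, "../../notes.txt", b"hello")
    stayed_inside = os.path.dirname(path) == tmp
    with open(path, "rb") as f:
        saved_bytes = f.read()
    try:
        save_upload(tmp, "malware.exe", b"bad content")
        rejected_exe = False
    except ValueError:
        rejected_exe = True

print(stayed_inside, saved_bytes, rejected_exe)  # True b'hello' True
```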
# tests/integration/test_files.py
"""
Tests for file upload endpoints.
"""
import io
import os
import pytest
import shutil
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
UPLOAD_DIR = "uploads"
@pytest.fixture(autouse=True)
def cleanup_uploads():
"""Remove uploaded files after each test."""
yield
if os.path.exists(UPLOAD_DIR):
shutil.rmtree(UPLOAD_DIR)
class TestSingleFileUpload:
"""Tests for POST /files/upload."""
def test_upload_text_file(self):
"""Upload a text file successfully."""
file_content = b"Hello, this is a test file."
response = client.post(
"/files/upload",
files={"file": ("test.txt", file_content, "text/plain")}
)
assert response.status_code == 200
data = response.json()
assert data["filename"] == "test.txt"
assert data["size"] == len(file_content)
assert data["content_type"] == "text/plain"
def test_upload_image_file(self):
"""Upload a PNG image file."""
# Create a minimal PNG file (1x1 pixel)
png_header = (
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01'
b'\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde'
)
response = client.post(
"/files/upload",
files={"file": ("image.png", png_header, "image/png")}
)
assert response.status_code == 200
assert response.json()["filename"] == "image.png"
def test_upload_pdf_file(self):
"""Upload a PDF file."""
pdf_content = b"%PDF-1.4 fake pdf content"
response = client.post(
"/files/upload",
files={"file": ("document.pdf", pdf_content, "application/pdf")}
)
assert response.status_code == 200
def test_upload_disallowed_extension(self):
"""Uploading a .exe file should be rejected."""
response = client.post(
"/files/upload",
files={"file": ("malware.exe", b"bad content", "application/octet-stream")}
)
assert response.status_code == 400
assert "not allowed" in response.json()["detail"]
def test_upload_file_saves_to_disk(self):
"""Verify the uploaded file is actually saved."""
file_content = b"Persistent content"
client.post(
"/files/upload",
files={"file": ("saved.txt", file_content, "text/plain")}
)
file_path = os.path.join(UPLOAD_DIR, "saved.txt")
assert os.path.exists(file_path)
with open(file_path, "rb") as f:
assert f.read() == file_content
def test_upload_no_file(self):
"""Request without a file should return 422."""
response = client.post("/files/upload")
assert response.status_code == 422
def test_upload_using_io_bytes(self):
"""Upload using io.BytesIO for in-memory files."""
file_obj = io.BytesIO(b"BytesIO content")
response = client.post(
"/files/upload",
files={"file": ("bytesio.txt", file_obj, "text/plain")}
)
assert response.status_code == 200
class TestMultipleFileUpload:
"""Tests for POST /files/upload-multiple."""
def test_upload_multiple_files(self):
"""Upload multiple files at once."""
files = [
("files", ("file1.txt", b"Content 1", "text/plain")),
("files", ("file2.txt", b"Content 2", "text/plain")),
("files", ("file3.txt", b"Content 3", "text/plain")),
]
response = client.post("/files/upload-multiple", files=files)
assert response.status_code == 200
data = response.json()
assert data["uploaded"] == 3
assert len(data["files"]) == 3
def test_upload_mixed_file_types(self):
"""Upload files of different types."""
files = [
("files", ("doc.txt", b"text content", "text/plain")),
("files", ("image.jpg", b"fake jpg", "image/jpeg")),
]
response = client.post("/files/upload-multiple", files=files)
assert response.status_code == 200
assert response.json()["uploaded"] == 2
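The upload tests above assume two things about the `/files/upload` endpoint. It rejects disallowed extensions with a 400 whose detail contains "not allowed", and it writes the file under `UPLOAD_DIR` using the client-supplied name. One way to keep that check unit-testable without HTTP is to put it in a pure helper. The sketch below is hypothetical: `safe_upload_name` and the allowlist are assumptions, not the endpoint defined earlier in this guide. It also strips directory components, so a name like `../../etc/passwd` cannot escape the upload directory.

```python
import os

# Hypothetical allowlist: the real endpoint's list is not shown in this section.
ALLOWED_EXTENSIONS = {".txt", ".png", ".jpg", ".jpeg", ".pdf"}


def safe_upload_name(filename: str) -> str:
    """Return a bare filename that is safe to join onto UPLOAD_DIR.

    Raises ValueError for empty names or extensions outside the allowlist.
    """
    # Normalise Windows separators, then keep only the last path component,
    # so "../../etc/passwd" or "C:\\temp\\x.txt" cannot point outside UPLOAD_DIR.
    name = os.path.basename(filename.replace("\\", "/"))
    if name in {"", ".", ".."}:
        raise ValueError("Empty or invalid filename")
    extension = os.path.splitext(name)[1].lower()
    if extension not in ALLOWED_EXTENSIONS:
        raise ValueError(f"File type '{extension or 'none'}' not allowed")
    return name


if __name__ == "__main__":
    print(safe_upload_name("../../uploads/report.PDF"))  # report.PDF
```

Inside the endpoint, a `ValueError` from this helper would be turned into `HTTPException(status_code=400, detail=str(exc))`, which keeps the `"not allowed" in response.json()["detail"]` assertion meaningful.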
FastAPI supports WebSocket endpoints, and the TestClient provides built-in WebSocket testing capabilities.
# app/routers/websocket.py
"""WebSocket endpoints."""
from typing import List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
router = APIRouter()
class ConnectionManager:
"""Manages active WebSocket connections."""
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
"""WebSocket endpoint for real-time communication."""
await manager.connect(websocket)
try:
# Send welcome message
await websocket.send_json({
"type": "connected",
"client_id": client_id,
"message": f"Welcome, {client_id}!"
})
while True:
# Receive message from client
data = await websocket.receive_text()
# Echo back with processing
response = {
"type": "message",
"client_id": client_id,
"content": data,
"echo": f"Server received: {data}"
}
await websocket.send_json(response)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client {client_id} disconnected")
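In the endpoint above, `broadcast` awaits each `send_text` in turn. If one connection has gone away without a clean disconnect, its error escapes `broadcast` and the remaining clients never get the message. `disconnect` also raises `ValueError` if it is called twice for the same socket.

The following is a minimal sketch of a more forgiving manager. `ResilientConnectionManager`, `TextSender` and `FakeWebSocket` are hypothetical names, and only the two methods that change are shown (`connect` would stay as in the original). The fake socket lets the behaviour be checked without running a server.

```python
import asyncio
from typing import List, Protocol


class TextSender(Protocol):
    """Anything with an async send_text(), such as fastapi.WebSocket."""

    async def send_text(self, data: str) -> None: ...


class ResilientConnectionManager:
    """Hypothetical ConnectionManager variant whose broadcast skips dead sockets."""

    def __init__(self) -> None:
        self.active_connections: List[TextSender] = []

    def disconnect(self, websocket: TextSender) -> None:
        # Ignore sockets that were already removed instead of raising ValueError
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: str) -> None:
        dead: List[TextSender] = []
        # Iterate over a snapshot so the list can be pruned afterwards
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                # The exact error for a closed socket varies; treat any failure as dead
                dead.append(connection)
        for connection in dead:
            self.disconnect(connection)


class FakeWebSocket:
    """Test double: records messages, or fails like a closed connection."""

    def __init__(self, closed: bool = False) -> None:
        self.closed = closed
        self.sent: List[str] = []

    async def send_text(self, data: str) -> None:
        if self.closed:
            raise RuntimeError("connection closed")
        self.sent.append(data)


if __name__ == "__main__":
    manager = ResilientConnectionManager()
    alive, stale = FakeWebSocket(), FakeWebSocket(closed=True)
    manager.active_connections.extend([stale, alive])
    asyncio.run(manager.broadcast("Client 42 disconnected"))
    print(alive.sent, len(manager.active_connections))
```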
# tests/integration/test_websocket.py
"""
Tests for WebSocket endpoints.
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
class TestWebSocketEndpoint:
"""Tests for the WebSocket /ws/{client_id} endpoint."""
def test_websocket_connection(self):
"""Client can connect to the WebSocket."""
with client.websocket_connect("/ws/test-user") as websocket:
# Should receive a welcome message
data = websocket.receive_json()
assert data["type"] == "connected"
assert data["client_id"] == "test-user"
assert "Welcome" in data["message"]
def test_websocket_echo(self):
"""Server echoes back received messages."""
with client.websocket_connect("/ws/echo-user") as websocket:
# Skip welcome message
websocket.receive_json()
# Send a message
websocket.send_text("Hello, Server!")
# Receive the echo
response = websocket.receive_json()
assert response["type"] == "message"
assert response["content"] == "Hello, Server!"
assert "Server received" in response["echo"]
def test_websocket_multiple_messages(self):
"""Can send and receive multiple messages."""
with client.websocket_connect("/ws/multi-user") as websocket:
websocket.receive_json() # Skip welcome
messages = ["Hello", "How are you?", "Goodbye"]
for msg in messages:
websocket.send_text(msg)
response = websocket.receive_json()
assert response["content"] == msg
def test_websocket_json_communication(self):
"""Can send and receive JSON data."""
with client.websocket_connect("/ws/json-user") as websocket:
websocket.receive_json() # Skip welcome
# Send text, receive JSON
websocket.send_text("test message")
response = websocket.receive_json()
assert isinstance(response, dict)
assert "type" in response
assert "content" in response
def test_websocket_client_id_in_response(self):
"""The client_id appears in all responses."""
client_id = "user-123"
with client.websocket_connect(f"/ws/{client_id}") as websocket:
welcome = websocket.receive_json()
assert welcome["client_id"] == client_id
websocket.send_text("test")
response = websocket.receive_json()
assert response["client_id"] == client_id
def test_websocket_disconnect(self):
"""Disconnection is handled gracefully."""
# The 'with' block handles disconnection automatically
with client.websocket_connect("/ws/disconnect-user") as websocket:
data = websocket.receive_json()
assert data["type"] == "connected"
# After the with block, the connection is closed
# No exceptions should be raised
Performance testing ensures your API can handle expected load and identifies bottlenecks before they reach production.
# Install locust
pip install locust
# tests/performance/locustfile.py
"""
Load testing configuration for FastAPI application.
Run with: locust -f tests/performance/locustfile.py --host=http://localhost:8000
"""
import uuid
from locust import HttpUser, task, between, tag
class APIUser(HttpUser):
"""Simulates a typical API user."""
# Wait 1-3 seconds between tasks
wait_time = between(1, 3)
# Relative to AdminUser.weight = 1, this spawns about 10 regular users per admin
weight = 10
def on_start(self):
"""Run once when a simulated user starts."""
# Register and log in with a unique email per simulated user;
# runner.user_count is not unique (e.g. each worker in a distributed run counts separately)
self.email = f"loadtest_{uuid.uuid4().hex}@test.com"
self.client.post("/auth/register", json={
"name": "Load Test User",
"email": self.email,
"password": "LoadTest123!"
})
response = self.client.post("/auth/login", json={
"email": self.email,
"password": "LoadTest123!"
})
if response.status_code == 200:
self.token = response.json().get("access_token", "")
self.headers = {"Authorization": f"Bearer {self.token}"}
else:
self.headers = {}
@task(5)
@tag("read")
def list_items(self):
"""GET /items — most common operation (weight: 5)."""
self.client.get("/items", headers=self.headers)
@task(3)
@tag("read")
def get_single_item(self):
"""GET /items/{id} — common operation (weight: 3)."""
self.client.get("/items/1", headers=self.headers)
@task(1)
@tag("write")
def create_item(self):
"""POST /items — less common (weight: 1)."""
self.client.post("/items", json={
"title": "Load Test Item",
"description": "Created during load testing",
"price": 9.99,
"quantity": 1
}, headers=self.headers)
@task(1)
@tag("read")
def health_check(self):
"""GET /health — monitoring endpoint."""
self.client.get("/health")
class AdminUser(HttpUser):
"""Simulates an admin user — less frequent, heavier operations."""
wait_time = between(5, 10)
weight = 1 # 1 admin per 10 regular users
def on_start(self):
"""Login as admin."""
response = self.client.post("/auth/login", json={
"email": "admin@test.com",
"password": "AdminPass123!"
})
if response.status_code == 200:
self.token = response.json().get("access_token", "")
self.headers = {"Authorization": f"Bearer {self.token}"}
else:
self.headers = {}
@task
def list_all_users(self):
"""GET /admin/users — admin-only endpoint."""
self.client.get("/admin/users", headers=self.headers)
# Run Locust with web UI
locust -f tests/performance/locustfile.py --host=http://localhost:8000
# Run headless (for CI/CD)
locust -f tests/performance/locustfile.py \
--host=http://localhost:8000 \
--headless \
--users 100 \
--spawn-rate 10 \
--run-time 60s \
--csv=results/load_test
# Run only read-tagged tests
locust -f tests/performance/locustfile.py \
--host=http://localhost:8000 \
--tags read \
--headless \
--users 50 \
--spawn-rate 5 \
--run-time 30s
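Locust chooses each simulated user's next task at random in proportion to its `@task` weight. It spawns user classes in proportion to their `weight` attribute, which defaults to 1. The expected split is therefore weight divided by the sum of weights.

The helper below (`expected_shares` is a hypothetical name) does that arithmetic for the weights in the locustfile. Two consequences follow from it:

- If `APIUser` keeps the default weight of 1, users split evenly with `AdminUser`. The 1:10 ratio described in the `AdminUser` comment needs `weight = 10` on `APIUser`.
- Running with `--tags read` drops `create_item`, which shifts the remaining shares.

```python
from fractions import Fraction
from typing import Dict


def expected_shares(weights: Dict[str, int]) -> Dict[str, Fraction]:
    """Turn Locust-style integer weights into expected fractions of the total."""
    total = sum(weights.values())
    return {name: Fraction(weight, total) for name, weight in weights.items()}


# Task weights from the @task(...) decorators on APIUser
api_user_tasks = {"list_items": 5, "get_single_item": 3, "create_item": 1, "health_check": 1}

if __name__ == "__main__":
    for task_name, share in expected_shares(api_user_tasks).items():
        print(f"{task_name}: {float(share):.0%}")
    # User-class weights: APIUser at weight 10 versus AdminUser at weight 1
    print(expected_shares({"APIUser": 10, "AdminUser": 1}))
```

These are per-user task proportions. Actual request rates also depend on `wait_time`: `AdminUser` waits 5 to 10 seconds between tasks, while `APIUser` waits 1 to 3.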
# tests/performance/test_benchmarks.py
"""
Benchmark tests to measure endpoint performance.
Install: pip install pytest-benchmark
Run: pytest tests/performance/test_benchmarks.py --benchmark-only
"""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_health_endpoint_performance(benchmark):
"""Benchmark the health check endpoint."""
result = benchmark(client.get, "/health")
assert result.status_code == 200
def test_list_items_performance(benchmark):
"""Benchmark listing items."""
result = benchmark(client.get, "/items")
assert result.status_code == 200
def test_create_item_performance(benchmark):
"""Benchmark creating an item."""
payload = {
"title": "Benchmark Item",
"description": "Performance test",
"price": 10.0,
"quantity": 1
}
result = benchmark(client.post, "/items", json=payload)
assert result.status_code == 201
@pytest.mark.slow
def test_response_time_under_threshold():
"""Verify critical endpoints respond within acceptable time."""
import time
# Uses the module-level client defined above; taking `client` as a parameter
# would shadow it and require a `client` fixture in conftest.py.
endpoints = [
("GET", "/health", None),
("GET", "/items", None),
("POST", "/items", {
"title": "Speed Test",
"description": "Testing speed",
"price": 10.0,
"quantity": 1
}),
]
max_response_time = 0.5 # 500ms threshold
for method, path, payload in endpoints:
start = time.perf_counter() # monotonic, high-resolution timer
if method == "GET":
response = client.get(path)
else:
response = client.post(path, json=payload)
elapsed = time.perf_counter() - start
assert response.status_code < 400, f"{method} {path} returned {response.status_code}"
assert elapsed < max_response_time, (
f"{method} {path} took {elapsed:.3f}s "
f"(threshold: {max_response_time}s)"
)
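A single timing per endpoint is noisy: the first request can include one-off work, and a GC pause can push any one call over the limit. A steadier check samples several runs and asserts on a percentile.

This is a minimal sketch. `measure` and `p95` are hypothetical helpers, and `statistics.quantiles` needs at least two samples. Inside a test it would be used as `assert p95(measure(lambda: client.get("/health"))) < 0.5`. Because `TestClient` runs the app in-process, these numbers exclude network and server overhead, so Locust remains the tool for realistic load.

```python
import statistics
import time
from typing import Callable, List


def measure(fn: Callable[[], object], runs: int = 20) -> List[float]:
    """Call fn `runs` times and return each duration in seconds."""
    samples: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)
    return samples


def p95(samples: List[float]) -> float:
    """95th percentile; method="inclusive" keeps the result within the observed range."""
    # quantiles(n=20) returns 19 cut points; index 18 is the 95th percentile
    return statistics.quantiles(samples, n=20, method="inclusive")[18]


if __name__ == "__main__":
    timings = measure(lambda: sum(range(10_000)))
    print(f"p95 = {p95(timings) * 1000:.3f} ms")
```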
Test coverage measures what percentage of your code is executed by tests. While 100% coverage does not guarantee bug-free code, it highlights untested code paths.
# Install pytest-cov
pip install pytest-cov
# Run tests with coverage
pytest --cov=app tests/
# Generate HTML coverage report
pytest --cov=app --cov-report=html tests/
# Generate XML report (for CI/CD)
pytest --cov=app --cov-report=xml tests/
# Show missing lines in terminal
pytest --cov=app --cov-report=term-missing tests/
# Fail if coverage is below threshold
pytest --cov=app --cov-fail-under=80 tests/
# Combine multiple report formats
pytest --cov=app \
--cov-report=term-missing \
--cov-report=html:coverage_html \
--cov-report=xml:coverage.xml \
tests/
# pyproject.toml — Coverage configuration
[tool.coverage.run]
source = ["app"]
omit = [
"app/__init__.py",
"app/config.py",
"tests/*",
"*/migrations/*",
]
[tool.coverage.report]
show_missing = true
fail_under = 80
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if __name__ == .__main__.",
"raise NotImplementedError",
'^\s*pass\s*$', # anchored: a bare "pass" would also exclude lines containing "password"
"except ImportError",
]
[tool.coverage.html]
directory = "htmlcov"
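coverage.py treats each `exclude_lines` entry as a regular expression that only has to match part of a line. A bare `"pass"` therefore also excludes lines such as `hashed_password = ...` or a comment mentioning "bypass", which silently hides real statements from the report.

The sketch below approximates that matching with Python's `re` module; it is an illustration, not coverage.py's own code. `excluded` is a hypothetical helper. In `pyproject.toml`, the anchored pattern can be written as the TOML literal string `'^\s*pass\s*$'`.

```python
import re
from typing import List

# Each exclusion pattern is searched anywhere in the line, like coverage.py does.
LOOSE = re.compile("pass")
ANCHORED = re.compile(r"^\s*pass\s*$")

SOURCE_LINES = [
    "    pass",
    "    hashed_password = hash_password(raw_password)",
    "    # bypass the cache in tests",
]


def excluded(pattern: "re.Pattern[str]", lines: List[str]) -> List[str]:
    """Return the lines a coverage-style exclusion pattern would drop."""
    return [line for line in lines if pattern.search(line)]


if __name__ == "__main__":
    print(excluded(LOOSE, SOURCE_LINES))     # all three lines
    print(excluded(ANCHORED, SOURCE_LINES))  # only the bare `pass`
```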
# Sample terminal output from pytest --cov=app --cov-report=term-missing
---------- coverage: platform linux, python 3.11.0 ----------
Name                             Stmts   Miss  Cover   Missing
--------------------------------------------------------------
app/__init__.py                      0      0   100%
app/main.py                         15      0   100%
app/database.py                     12      2    83%   18-19
app/models/item.py                   8      0   100%
app/models/user.py                  12      0   100%
app/routers/auth.py                 45      3    93%   67-69
app/routers/items.py                38      0   100%
app/routers/files.py                30      5    83%   42-46
app/dependencies.py                 25      4    84%   31-34
app/services/email_service.py       18      8    56%   12-19
app/services/weather_service.py     20     12    40%   8-19
--------------------------------------------------------------
TOTAL                              223     34    85%
| Coverage Level | Meaning | Recommendation |
|---|---|---|
| < 50% | Critical gaps | Many untested code paths — high risk of bugs |
| 50-70% | Basic coverage | Core paths tested but edge cases missing |
| 70-80% | Good coverage | Solid for most projects — good balance |
| 80-90% | Strong coverage | Recommended target for production APIs |
| 90-100% | Comprehensive | Great for critical systems (payments, auth) |
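Without branch coverage enabled, each percentage in the sample report is (Stmts - Miss) / Stmts. Modules with no statements, such as an empty `__init__.py`, are reported as 100%. The TOTAL row is computed from the summed counts rather than by averaging the per-file percentages.

The helpers below (hypothetical names) reproduce that arithmetic from the sample output: 189 of 223 statements gives about 84.75%, which is displayed as 85%.

```python
from typing import Iterable, Tuple


def cover_percent(stmts: int, miss: int) -> float:
    """Percentage of statements executed: (Stmts - Miss) / Stmts * 100."""
    if stmts == 0:
        return 100.0  # empty modules such as a bare __init__.py report 100%
    return 100.0 * (stmts - miss) / stmts


def total_percent(rows: Iterable[Tuple[int, int]]) -> float:
    """TOTAL uses summed statements, not an average of file percentages."""
    rows = list(rows)
    return cover_percent(sum(s for s, _ in rows), sum(m for _, m in rows))


# (Stmts, Miss) pairs from the sample report above
SAMPLE_ROWS = [(0, 0), (15, 0), (12, 2), (8, 0), (12, 0), (45, 3),
               (38, 0), (30, 5), (25, 4), (18, 8), (20, 12)]

if __name__ == "__main__":
    print(f"TOTAL: {total_percent(SAMPLE_ROWS):.2f}%")  # 84.75%, shown as 85%
```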
Automating your tests in a CI/CD pipeline ensures that every code change is validated before reaching production.
# .github/workflows/test.yml
name: FastAPI Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
services:
postgres:
image: postgres:15
env:
POSTGRES_USER: testuser
POSTGRES_PASSWORD: testpass
POSTGRES_DB: testdb
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip packages
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements*.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-test.txt
- name: Run linting
run: |
pip install ruff
ruff check app/ tests/
- name: Run type checking
run: |
pip install mypy
mypy app/ --ignore-missing-imports
- name: Run tests with coverage
env:
DATABASE_URL: postgresql://testuser:testpass@localhost:5432/testdb
SECRET_KEY: test-secret-key-for-ci
TESTING: "true"
run: |
pytest tests/ \
--cov=app \
--cov-report=xml:coverage.xml \
--cov-report=html \
--cov-report=term-missing \
--cov-fail-under=80 \
-v \
--tb=short
- name: Upload coverage to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v4
with:
files: coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: test-results-${{ matrix.python-version }}
path: |
coverage.xml
htmlcov/
# Dockerfile.test
FROM python:3.12-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt requirements-test.txt ./
RUN pip install --no-cache-dir -r requirements.txt -r requirements-test.txt
# Copy application code
COPY . .
# Run tests
CMD ["pytest", "tests/", "-v", "--cov=app", "--cov-report=term-missing"]
# docker-compose.test.yml
version: "3.9"
services:
test:
build:
context: .
dockerfile: Dockerfile.test
environment:
- DATABASE_URL=postgresql://testuser:testpass@db:5432/testdb
- SECRET_KEY=test-secret-key
- TESTING=true
depends_on:
db:
condition: service_healthy
db:
image: postgres:15-alpine
environment:
POSTGRES_USER: testuser
POSTGRES_PASSWORD: testpass
POSTGRES_DB: testdb
healthcheck:
test: ["CMD-SHELL", "pg_isready -U testuser -d testdb"]
interval: 5s
timeout: 5s
retries: 5
# Run tests in Docker
docker compose -f docker-compose.test.yml up --build --abort-on-container-exit
# Run specific test files
docker compose -f docker-compose.test.yml run test \
pytest tests/integration/ -v
# Run with coverage report
docker compose -f docker-compose.test.yml run test \
pytest tests/ --cov=app --cov-report=html
# Clean up
docker compose -f docker-compose.test.yml down -v
# .github/workflows/ci-cd.yml
name: CI/CD Pipeline
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
# Stage 1: Lint and Type Check
quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: pip install ruff mypy
- run: ruff check app/ tests/
- run: mypy app/ --ignore-missing-imports
# Stage 2: Unit Tests (fast, no external deps)
unit-tests:
runs-on: ubuntu-latest
needs: quality
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: |
pip install -r requirements.txt
pip install -r requirements-test.txt
- run: pytest tests/unit/ -v --cov=app -m unit
# Stage 3: Integration Tests (needs database)
integration-tests:
runs-on: ubuntu-latest
needs: unit-tests
services:
postgres:
image: postgres:15
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: testdb
ports: ["5432:5432"]
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: |
pip install -r requirements.txt
pip install -r requirements-test.txt
- run: pytest tests/integration/ -v -m integration
env:
DATABASE_URL: postgresql://test:test@localhost:5432/testdb
# Stage 4: Deploy (only on main branch)
deploy:
runs-on: ubuntu-latest
needs: [unit-tests, integration-tests]
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
steps:
- uses: actions/checkout@v4
- name: Deploy to production
run: echo "Deploying to production..."
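The `unit-tests` and `integration-tests` jobs select tests with `-m unit` and `-m integration`. A test only runs in a stage if it carries that marker, for example `@pytest.mark.unit` on the test or `pytestmark = pytest.mark.integration` at module level. If nothing matches, the stage runs no tests and pytest exits with its "no tests collected" status, which fails the job. Unregistered markers also trigger `PytestUnknownMarkWarning`, and they become errors under `--strict-markers`.

A minimal sketch of registering them in `tests/conftest.py` follows. Only the marker names come from the pipeline and the benchmark file; the descriptions are placeholders.

```python
# tests/conftest.py (excerpt): a sketch; only the marker names come from this guide
MARKERS = {
    "unit": "fast tests with no external services",
    "integration": "tests that need the PostgreSQL service",
    "slow": "long-running checks such as response-time tests",
}


def pytest_configure(config):
    """pytest hook: register custom markers so -m unit / -m integration are known."""
    for name, description in MARKERS.items():
        config.addinivalue_line("markers", f"{name}: {description}")
```

The same registration can live in `pyproject.toml` under `[tool.pytest.ini_options]` as a `markers` list. The hook form simply keeps it next to the fixtures.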
Let us bring everything together into a complete, production-ready test suite for a CRUD API with authentication, database integration, and mocking.
# Complete application structure
bookstore_api/
├── app/
│ ├── __init__.py
│ ├── main.py
│ ├── config.py
│ ├── database.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── book.py
│ │ └── user.py
│ ├── schemas/
│ │ ├── __init__.py
│ │ ├── book.py
│ │ └── user.py
│ ├── routers/
│ │ ├── __init__.py
│ │ ├── books.py
│ │ └── auth.py
│ ├── services/
│ │ └── notification_service.py
│ └── dependencies.py
└── tests/
├── __init__.py
├── conftest.py
├── test_books_crud.py
├── test_auth.py
└── test_notifications.py
# app/main.py
"""FastAPI Bookstore Application."""
from fastapi import FastAPI
from app.routers import books, auth
from app.database import Base, engine
# Create tables
Base.metadata.create_all(bind=engine)
app = FastAPI(title="Bookstore API", version="1.0.0")
app.include_router(auth.router)
app.include_router(books.router)
@app.get("/health")
def health_check():
return {"status": "healthy"}
# app/schemas/book.py
"""Book Pydantic schemas."""
from pydantic import BaseModel, Field
from typing import Optional
class BookCreate(BaseModel):
title: str = Field(..., min_length=1, max_length=300)
author: str = Field(..., min_length=1, max_length=200)
isbn: str = Field(..., pattern=r"^\d{13}$")
price: float = Field(..., gt=0)
description: Optional[str] = None
class BookUpdate(BaseModel):
title: Optional[str] = Field(None, min_length=1, max_length=300)
author: Optional[str] = Field(None, min_length=1, max_length=200)
price: Optional[float] = Field(None, gt=0)
description: Optional[str] = None
class BookResponse(BaseModel):
id: int
title: str
author: str
isbn: str
price: float
description: Optional[str]
owner_id: int
model_config = {"from_attributes": True}
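The `isbn` pattern `^\d{13}$` only checks that the value is 13 digits; it does not verify the ISBN-13 check digit. A stricter check is sketched below. `isbn13_is_valid` is a hypothetical helper that could be called from a Pydantic v2 `@field_validator("isbn")` on `BookCreate`, raising `ValueError` when it returns False.

Adopting it would affect the tests later in this section:

- `"9781234567890"` in `test_negative_price_rejected` fails the checksum, so that test would still get a 422, but because of the ISBN rather than the price.
- None of the ISBNs generated in `test_pagination` pass, so that test would create no books.

```python
import re


def isbn13_is_valid(isbn: str) -> bool:
    """True if isbn is 13 ASCII digits with a correct ISBN-13 check digit."""
    if not re.fullmatch(r"[0-9]{13}", isbn):
        return False
    # Digits are weighted 1, 3, 1, 3, ...; a valid ISBN-13 sums to a multiple of 10
    total = sum(int(d) * (3 if i % 2 else 1) for i, d in enumerate(isbn))
    return total % 10 == 0
```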
# app/models/book.py
"""Book SQLAlchemy model."""
from sqlalchemy import Column, Integer, String, Float, ForeignKey
from sqlalchemy.orm import relationship
from app.database import Base
class Book(Base):
__tablename__ = "books"
id = Column(Integer, primary_key=True, index=True)
title = Column(String(300), nullable=False)
author = Column(String(200), nullable=False)
isbn = Column(String(13), unique=True, nullable=False)
price = Column(Float, nullable=False)
description = Column(String(2000), nullable=True)
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
owner = relationship("User", back_populates="books")
# app/routers/books.py
"""Book CRUD router with authentication."""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies import get_current_user
from app.models.book import Book
from app.models.user import User
from app.schemas.book import BookCreate, BookUpdate, BookResponse
router = APIRouter(prefix="/books", tags=["books"])
@router.get("/", response_model=List[BookResponse])
def list_books(
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
author: Optional[str] = None,
db: Session = Depends(get_db),
):
"""List all books (public endpoint)."""
query = db.query(Book)
if author:
query = query.filter(Book.author.ilike(f"%{author}%"))
return query.offset(skip).limit(limit).all()
@router.get("/{book_id}", response_model=BookResponse)
def get_book(book_id: int, db: Session = Depends(get_db)):
"""Get a book by ID (public endpoint)."""
book = db.query(Book).filter(Book.id == book_id).first()
if not book:
raise HTTPException(status_code=404, detail="Book not found")
return book
@router.post("/", response_model=BookResponse, status_code=201)
def create_book(
book_data: BookCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Create a new book (authenticated)."""
# Check for duplicate ISBN
existing = db.query(Book).filter(Book.isbn == book_data.isbn).first()
if existing:
raise HTTPException(
status_code=409,
detail=f"Book with ISBN {book_data.isbn} already exists"
)
book = Book(**book_data.model_dump(), owner_id=current_user.id)
db.add(book)
db.commit()
db.refresh(book)
return book
@router.put("/{book_id}", response_model=BookResponse)
def update_book(
book_id: int,
book_data: BookUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update a book (owner only)."""
book = db.query(Book).filter(Book.id == book_id).first()
if not book:
raise HTTPException(status_code=404, detail="Book not found")
if book.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
update_data = book_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(book, field, value)
db.commit()
db.refresh(book)
return book
@router.delete("/{book_id}", status_code=204)
def delete_book(
book_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Delete a book (owner only)."""
book = db.query(Book).filter(Book.id == book_id).first()
if not book:
raise HTTPException(status_code=404, detail="Book not found")
if book.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
db.delete(book)
db.commit()
# tests/conftest.py
"""Complete test configuration for the Bookstore API."""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.database import Base, get_db
from app.dependencies import get_current_user
from app.models.user import User
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
@pytest.fixture(scope="function")
def db_session():
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db_session):
def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
@pytest.fixture
def test_user(db_session):
"""Create a test user in the database."""
user = User(
id=1, name="Test User", email="test@example.com",
hashed_password="fake_hash", role="user", is_active=True
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def other_user(db_session):
"""Create a second test user."""
user = User(
id=2, name="Other User", email="other@example.com",
hashed_password="fake_hash", role="user", is_active=True
)
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def authenticated_client(client, test_user):
"""Client with authentication dependency overridden."""
def override_current_user():
return test_user
app.dependency_overrides[get_current_user] = override_current_user
yield client
# Note: client fixture already clears overrides
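@pytest.fixture
def other_authenticated_client(client, other_user):
"""Client authenticated as a different user.
Note: this is the same TestClient object as authenticated_client and it
writes to the same global app.dependency_overrides, so a test that requests
both fixtures makes every request as whichever user was set last.
"""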
def override_current_user():
return other_user
app.dependency_overrides[get_current_user] = override_current_user
yield client
@pytest.fixture
def sample_book():
return {
"title": "Clean Code",
"author": "Robert C. Martin",
"isbn": "9780132350884",
"price": 39.99,
"description": "A handbook of agile software craftsmanship"
}
@pytest.fixture
def another_book():
return {
"title": "The Pragmatic Programmer",
"author": "David Thomas",
"isbn": "9780135957059",
"price": 49.99,
"description": "Your journey to mastery"
}
# tests/test_books_crud.py
"""Complete CRUD tests for the Bookstore API."""
import pytest
class TestListBooks:
"""GET /books — Public endpoint."""
def test_empty_list(self, client):
response = client.get("/books")
assert response.status_code == 200
assert response.json() == []
def test_list_with_books(self, authenticated_client, sample_book, another_book):
authenticated_client.post("/books", json=sample_book)
authenticated_client.post("/books", json=another_book)
response = authenticated_client.get("/books")
assert response.status_code == 200
assert len(response.json()) == 2
def test_filter_by_author(self, authenticated_client, sample_book, another_book):
authenticated_client.post("/books", json=sample_book)
authenticated_client.post("/books", json=another_book)
response = authenticated_client.get("/books?author=Martin")
assert response.status_code == 200
books = response.json()
assert len(books) == 1
assert books[0]["author"] == "Robert C. Martin"
def test_pagination(self, authenticated_client, sample_book):
# Create several books with unique ISBNs
for i in range(5):
book = {**sample_book, "isbn": f"978013235{i:04d}", "title": f"Book {i}"}
authenticated_client.post("/books", json=book)
response = authenticated_client.get("/books?skip=2&limit=2")
assert len(response.json()) == 2
class TestGetBook:
"""GET /books/{id} — Public endpoint."""
def test_get_existing_book(self, authenticated_client, sample_book):
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
response = authenticated_client.get(f"/books/{book_id}")
assert response.status_code == 200
assert response.json()["title"] == sample_book["title"]
assert response.json()["isbn"] == sample_book["isbn"]
def test_get_nonexistent_book(self, client):
response = client.get("/books/99999")
assert response.status_code == 404
def test_response_has_all_fields(self, authenticated_client, sample_book):
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
response = authenticated_client.get(f"/books/{book_id}")
data = response.json()
assert "id" in data
assert "title" in data
assert "author" in data
assert "isbn" in data
assert "price" in data
assert "owner_id" in data
class TestCreateBook:
"""POST /books — Authenticated endpoint."""
def test_create_success(self, authenticated_client, sample_book):
response = authenticated_client.post("/books", json=sample_book)
assert response.status_code == 201
data = response.json()
assert data["title"] == sample_book["title"]
assert data["author"] == sample_book["author"]
assert data["isbn"] == sample_book["isbn"]
assert data["owner_id"] == 1 # test_user id
def test_create_requires_authentication(self, client, sample_book):
"""Unauthenticated request should fail."""
response = client.post("/books", json=sample_book)
assert response.status_code in (401, 403)
def test_duplicate_isbn_rejected(self, authenticated_client, sample_book):
authenticated_client.post("/books", json=sample_book)
duplicate = {**sample_book, "title": "Different Title"}
response = authenticated_client.post("/books", json=duplicate)
assert response.status_code == 409
def test_invalid_isbn_format(self, authenticated_client, sample_book):
invalid_book = {**sample_book, "isbn": "123"}
response = authenticated_client.post("/books", json=invalid_book)
assert response.status_code == 422
def test_negative_price_rejected(self, authenticated_client, sample_book):
invalid_book = {**sample_book, "isbn": "9781234567890", "price": -10.0}
response = authenticated_client.post("/books", json=invalid_book)
assert response.status_code == 422
def test_missing_required_fields(self, authenticated_client):
response = authenticated_client.post("/books", json={"title": "Incomplete"})
assert response.status_code == 422
class TestUpdateBook:
"""PUT /books/{id} — Owner only."""
def test_owner_can_update(self, authenticated_client, sample_book):
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
response = authenticated_client.put(f"/books/{book_id}", json={
"title": "Clean Code (2nd Edition)",
"price": 44.99
})
assert response.status_code == 200
assert response.json()["title"] == "Clean Code (2nd Edition)"
assert response.json()["price"] == 44.99
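def test_non_owner_cannot_update(
self, authenticated_client, other_user, sample_book
):
# Both client fixtures share one TestClient and one dependency_overrides
# dict, so switch the authenticated user explicitly between requests.
from app.main import app
from app.dependencies import get_current_user
# Create book as test_user
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
# Try to update as other_user
app.dependency_overrides[get_current_user] = lambda: other_user
response = authenticated_client.put(f"/books/{book_id}", json={
"title": "Stolen Book"
})
assert response.status_code == 403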
def test_update_nonexistent_book(self, authenticated_client):
response = authenticated_client.put("/books/99999", json={
"title": "Ghost Book"
})
assert response.status_code == 404
def test_partial_update(self, authenticated_client, sample_book):
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
# Only update the price
response = authenticated_client.put(f"/books/{book_id}", json={
"price": 29.99
})
assert response.status_code == 200
assert response.json()["price"] == 29.99
assert response.json()["title"] == sample_book["title"] # Unchanged
class TestDeleteBook:
"""DELETE /books/{id} — Owner only."""
def test_owner_can_delete(self, authenticated_client, sample_book):
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
response = authenticated_client.delete(f"/books/{book_id}")
assert response.status_code == 204
# Verify deletion
get_resp = authenticated_client.get(f"/books/{book_id}")
assert get_resp.status_code == 404
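def test_non_owner_cannot_delete(
self, authenticated_client, other_user, sample_book
):
# Switch users explicitly; both client fixtures share one dependency_overrides dict
from app.main import app
from app.dependencies import get_current_user
create_resp = authenticated_client.post("/books", json=sample_book)
book_id = create_resp.json()["id"]
app.dependency_overrides[get_current_user] = lambda: other_user
response = authenticated_client.delete(f"/books/{book_id}")
assert response.status_code == 403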
def test_delete_nonexistent(self, authenticated_client):
response = authenticated_client.delete("/books/99999")
assert response.status_code == 404
def test_delete_reduces_count(self, authenticated_client, sample_book, another_book):
r1 = authenticated_client.post("/books", json=sample_book)
authenticated_client.post("/books", json=another_book)
assert len(authenticated_client.get("/books").json()) == 2
authenticated_client.delete(f"/books/{r1.json()['id']}")
assert len(authenticated_client.get("/books").json()) == 1
class TestNotificationMocking:
"""Test that notifications are sent on book creation."""
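def test_notification_sent_on_create(
self, authenticated_client, sample_book, mocker
):
"""Mock the notification service and check how often it is called."""
# mocker.patch raises AttributeError if the target module lacks the name.
# The books router above does not import send_notification, so create=True
# lets the patch succeed until create_book actually calls it.
mock_notify = mocker.patch(
"app.routers.books.send_notification",
return_value=True,
create=True,
)
response = authenticated_client.post("/books", json=sample_book)
assert response.status_code == 201
# Once create_book calls send_notification, drop create=True and use
# mock_notify.assert_called_once() instead of this weaker check.
assert mock_notify.call_count <= 1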
# Running the complete test suite
# Run all tests
pytest tests/ -v
# Run with coverage
pytest tests/ --cov=app --cov-report=term-missing --cov-fail-under=80
# Run only CRUD tests
pytest tests/test_books_crud.py -v
# Run a specific test class
pytest tests/test_books_crud.py::TestCreateBook -v
# Run a specific test
pytest tests/test_books_crud.py::TestCreateBook::test_create_success -v
# Run tests matching a keyword
pytest tests/ -k "create" -v
# Run tests in parallel (requires pytest-xdist)
pip install pytest-xdist
pytest tests/ -n auto -v
# Show slowest tests
pytest tests/ --durations=10
| # | Topic | Key Point |
|---|---|---|
| 1 | Testing Philosophy | Follow the testing pyramid — many unit tests, fewer integration tests, few E2E tests |
| 2 | Setup | Use pytest, httpx, and conftest.py for a clean test foundation |
| 3 | TestClient | Use Starlette’s TestClient for synchronous endpoint testing — no server required |
| 4 | Async Testing | Use httpx.AsyncClient with pytest-asyncio for async endpoint and dependency testing |
| 5 | Route Testing | Test all CRUD operations, edge cases, error responses, and query parameter validation |
| 6 | Validation Testing | Test Pydantic schemas independently and verify API returns proper 422 error responses |
| 7 | Database Testing | Use in-memory SQLite with transaction rollback for fast, isolated database tests |
| 8 | Dependency Override | Use app.dependency_overrides to swap real dependencies with test doubles |
| 9 | Mocking | Mock external services with unittest.mock or pytest-mock to avoid real network calls |
| 10 | Auth Testing | Test registration, login, JWT validation, token expiration, and role-based access |
| 11 | File Uploads | Test multipart form data uploads with the files parameter in TestClient |
| 12 | WebSockets | Use client.websocket_connect() to test WebSocket communication |
| 13 | Performance | Use Locust for load testing and pytest-benchmark for endpoint benchmarking |
| 14 | Coverage | Aim for 80%+ coverage; use pytest-cov with --cov-fail-under to enforce minimums |
| 15 | CI/CD | Automate tests with GitHub Actions; use Docker for reproducible test environments |
| 16 | Complete Suite | Combine all patterns into a cohesive test suite that covers auth, CRUD, mocking, and validation |
Pydantic is the backbone of FastAPI’s data validation system. Every request body, query parameter, and response model you define in FastAPI is powered by Pydantic under the hood. In this comprehensive tutorial, you will master Pydantic v2 — from basic model definitions and field constraints to custom validators, generic models, and a complete real-world registration system. By the end, you will be able to build bulletproof APIs that validate every piece of incoming data automatically.
fastapi-pydantic and follow along step by step.
Pydantic is a data validation and serialization library for Python that uses standard type annotations to define data schemas. When you create a Pydantic model, every field is automatically validated based on its type annotation and any additional constraints you provide. If the data does not match, Pydantic raises a clear, structured error, so no manual checking is required.
FastAPI was specifically designed around Pydantic because of several key advantages:
| Feature | Benefit |
|---|---|
| Type-first validation | Uses standard Python type hints — no new DSL to learn |
| Automatic JSON Schema | Every model generates OpenAPI-compatible JSON Schema for docs |
| Performance (v2) | Core validation logic rewritten in Rust — up to 50x faster than v1 |
| Serialization | Built-in model_dump() and model_dump_json() for output control |
| Editor support | Full autocomplete and type checking in IDEs like VS Code and PyCharm |
| Ecosystem | Works with SQLAlchemy, MongoDB, settings management, and more |
Pydantic v2 was a major rewrite. Here are the most important changes:
# Pydantic v2 key improvements over v1:
#
# 1. pydantic-core written in Rust for massive speed gains
# 2. model_dump() replaces dict()
# 3. model_dump_json() replaces json()
# 4. model_validate() replaces parse_obj()
# 5. model_validate_json() replaces parse_raw()
# 6. model_config replaces class Config
# 7. @field_validator replaces @validator
# 8. @model_validator replaces @root_validator
# 9. ConfigDict for typed configuration
# 10. Strict mode for no type coercion
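Items 2 to 5 of the list above can be illustrated with a single round trip through the renamed v2 methods. This is a minimal sketch; the `Point` model is a hypothetical example and is not part of the tutorial's code.

```python
from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


# v1: Point.parse_obj(...)  ->  v2: Point.model_validate(...)
p = Point.model_validate({"x": 1, "y": 2})

# v1: p.dict()  ->  v2: p.model_dump()
as_dict = p.model_dump()

# v1: p.json()  ->  v2: p.model_dump_json()
as_json = p.model_dump_json()

# v1: Point.parse_raw(...)  ->  v2: Point.model_validate_json(...)
p2 = Point.model_validate_json(as_json)

print(as_dict)  # {'x': 1, 'y': 2}
print(as_json)  # {"x":1,"y":2}
print(p2 == p)  # True
```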
# Pydantic is installed automatically with FastAPI
# (quote the extra so shells such as zsh do not expand the brackets)
pip install "fastapi[standard]"
# Or install Pydantic separately
pip install pydantic
# Check your version
python -c "import pydantic; print(pydantic.__version__)"
# Should show 2.x.x
from pydantic import BaseModel
class User(BaseModel):
    name: str
    age: int
    email: str
# Valid data - works fine
user = User(name="Alice", age=30, email="alice@example.com")
print(user)
# name='Alice' age=30 email='alice@example.com'
print(user.name) # Alice
print(user.age) # 30
# Pydantic coerces compatible types
user2 = User(name="Bob", age="25", email="bob@example.com")
print(user2.age) # 25 (converted string "25" to int)
print(type(user2.age)) # <class 'int'>
# Invalid data - raises ValidationError
try:
    bad_user = User(name="Charlie", age="not_a_number", email="charlie@example.com")
except Exception as e:
    print(e)
# 1 validation error for User
# age
# Input should be a valid integer, unable to parse string as an integer
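Item 10 in the v2 list, strict mode, turns this coercion off. The sketch below uses two hypothetical models, `LaxAge` and `StrictAge`. It sets `strict=True` in `model_config`; `Field(strict=True)` applies the same rule to a single field.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class LaxAge(BaseModel):
    age: int


class StrictAge(BaseModel):
    # strict=True disables type coercion for every field in this model
    model_config = ConfigDict(strict=True)
    age: int


print(LaxAge(age="25").age)   # 25 (string coerced to int)
print(StrictAge(age=25).age)  # 25 (a real int is still accepted)

strict_rejected = False
try:
    StrictAge(age="25")  # a numeric string is no longer good enough
except ValidationError as e:
    strict_rejected = True
    print(e.error_count())  # 1
```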
Every Pydantic model inherits from BaseModel. Fields are defined using Python type annotations, and you can set defaults, make fields optional, and compose complex structures.
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class Product(BaseModel):
    # Required fields - must be provided
    name: str
    price: float
    # Field with a default value
    quantity: int = 0
    # Optional field - can be None
    description: Optional[str] = None
    # Field with a default factory: datetime.now is called for each new instance
    created_at: datetime = Field(default_factory=datetime.now)
# Only required fields
p1 = Product(name="Laptop", price=999.99)
print(p1)
# name='Laptop' price=999.99 quantity=0 description=None created_at=...
# All fields provided
p2 = Product(
name="Phone",
price=699.99,
quantity=50,
description="Latest smartphone model",
created_at=datetime(2024, 1, 15)
)
print(p2.description) # Latest smartphone model
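`default_factory` matters because a default written as `uuid4()` or `datetime.now()` is evaluated only once, when the class body runs. Every instance then shares that single value.

A minimal sketch with two hypothetical models shows the difference:

```python
from uuid import UUID, uuid4

from pydantic import BaseModel, Field


class Shared(BaseModel):
    # Anti-pattern: uuid4() runs once at class definition,
    # so every instance receives the same UUID
    id: UUID = uuid4()


class PerInstance(BaseModel):
    # default_factory is called each time an instance is created
    id: UUID = Field(default_factory=uuid4)


print(Shared().id == Shared().id)            # True
print(PerInstance().id == PerInstance().id)  # False
```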
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
from datetime import datetime, date
from decimal import Decimal
from uuid import UUID, uuid4
from enum import Enum
class StatusEnum(str, Enum):
    ACTIVE = "active"
    INACTIVE = "inactive"
    PENDING = "pending"
class CompleteExample(BaseModel):
    # String types
    name: str
    description: str = "No description"
    # Numeric types
    age: int
    price: float
    exact_price: Decimal = Decimal("0.00")
    # Boolean
    is_active: bool = True
    # Date/time (default_factory gives each instance its own timestamp)
    created_at: datetime = Field(default_factory=datetime.now)
    birth_date: Optional[date] = None
    # UUID (default_factory generates a fresh UUID per instance)
    id: UUID = Field(default_factory=uuid4)
    # Enum
    status: StatusEnum = StatusEnum.ACTIVE
    # Collections (Pydantic copies mutable defaults, so [] and {} are safe here)
    tags: List[str] = []
    metadata: Dict[str, str] = {}
    scores: List[float] = []
    # Optional fields
    nickname: Optional[str] = None
    parent_id: Optional[int] = None
# Usage
item = CompleteExample(
name="Test Item",
age=5,
price=29.99,
tags=["sale", "featured"],
metadata={"color": "blue", "size": "large"},
scores=[95.5, 87.3, 92.1]
)
print(item.name) # Test Item
print(item.tags) # ['sale', 'featured']
print(item.status) # StatusEnum.ACTIVE
from pydantic import BaseModel
from typing import Optional
class FieldExamples(BaseModel):
    # REQUIRED - no default, must be provided
    name: str
    # REQUIRED - no default, must be provided (can be None)
    middle_name: Optional[str]
    # OPTIONAL - has a default of None
    nickname: Optional[str] = None
    # OPTIONAL - has a default value
    country: str = "US"
# This works - all required fields provided
user = FieldExamples(name="Alice", middle_name=None)
print(user)
# name='Alice' middle_name=None nickname=None country='US'
# This fails - middle_name is required even though it is Optional[str]
try:
    user = FieldExamples(name="Bob")
except Exception as e:
    print(e)
# 1 validation error for FieldExamples
# middle_name
# Field required
Optional[str] without a default value is still required. It means the field accepts str or None, but you must explicitly provide it. To make a field truly optional (not required), give it a default: Optional[str] = None.
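You can check which fields are required by inspecting the class-level `model_fields` mapping. It maps each field name to its `FieldInfo`.

A small sketch that reuses the same field layout as `FieldExamples` above:

```python
from typing import Optional

from pydantic import BaseModel


class FieldExamples(BaseModel):
    name: str
    middle_name: Optional[str]       # accepts None, but is still required
    nickname: Optional[str] = None   # truly optional
    country: str = "US"


required = {
    field_name
    for field_name, info in FieldExamples.model_fields.items()
    if info.is_required()
}
print(sorted(required))  # ['middle_name', 'name']
```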
from pydantic import BaseModel
class Employee(BaseModel):
    first_name: str
    last_name: str
    department: str
    salary: float
    @property
    def full_name(self) -> str:
        return f"{self.first_name} {self.last_name}"
    @property
    def annual_salary(self) -> float:
        return self.salary * 12
    def get_summary(self) -> str:
        return f"{self.full_name} - {self.department} (${self.salary:,.2f}/month)"
emp = Employee(
first_name="Jane",
last_name="Smith",
department="Engineering",
salary=8500.00
)
print(emp.full_name) # Jane Smith
print(emp.annual_salary) # 102000.0
print(emp.get_summary()) # Jane Smith - Engineering ($8,500.00/month)
Pydantic’s Field() function lets you add constraints to individual fields — minimum/maximum values, string lengths, regex patterns, and more. These constraints are enforced automatically and appear in the generated JSON Schema.
from pydantic import BaseModel, Field
class Product(BaseModel):
    name: str = Field(
        ..., # ... means required
        min_length=1,
        max_length=100,
        description="Product name",
        examples=["Laptop Pro 15"]
    )
    price: float = Field(
        ...,
        gt=0, # greater than 0
        le=1_000_000, # less than or equal to 1,000,000
        description="Price in USD"
    )
    quantity: int = Field(
        default=0,
        ge=0, # greater than or equal to 0
        le=10_000,
        description="Items in stock"
    )
    sku: str = Field(
        ...,
        pattern=r"^[A-Z]{2,4}-\d{4,8}$",
        description="Stock keeping unit (e.g., PROD-12345)"
    )
    discount: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Discount as a decimal (0.0 to 1.0)"
    )
# Valid product
product = Product(
name="Wireless Mouse",
price=29.99,
quantity=150,
sku="WM-12345",
discount=0.1
)
print(product)
# Invalid - triggers multiple validation errors
try:
    bad_product = Product(
        name="", # too short
        price=-5, # not greater than 0
        quantity=-1, # not >= 0
        sku="invalid", # does not match pattern
        discount=1.5 # exceeds max
    )
except Exception as e:
    print(e)
from pydantic import BaseModel, Field
class NumericConstraints(BaseModel):
    # gt = greater than (exclusive minimum)
    positive: int = Field(gt=0) # must be > 0
    # ge = greater than or equal (inclusive minimum)
    non_negative: int = Field(ge=0) # must be >= 0
    # lt = less than (exclusive maximum)
    below_hundred: int = Field(lt=100) # must be < 100
    # le = less than or equal (inclusive maximum)
    max_hundred: int = Field(le=100) # must be <= 100
    # Combine constraints
    percentage: float = Field(ge=0.0, le=100.0) # 0 to 100
    # multiple_of
    even_number: int = Field(multiple_of=2) # divisible by 2
    # Price constraint
    price: float = Field(gt=0, le=999999.99)
# Valid
data = NumericConstraints(
positive=1,
non_negative=0,
below_hundred=99,
max_hundred=100,
percentage=85.5,
even_number=42,
price=199.99
)
print(data)
from pydantic import BaseModel, Field
class StringConstraints(BaseModel):
    # Length constraints
    username: str = Field(min_length=3, max_length=30)
    # Exact length (min_length == max_length)
    country_code: str = Field(min_length=2, max_length=2)
    # Regex pattern for phone
    phone: str = Field(pattern=r"^\+?1?\d{9,15}$")
    # Email pattern
    email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w{2,}$")
    # URL-safe slug
    slug: str = Field(pattern=r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
    # Password length
    password: str = Field(min_length=8, max_length=128)
# Valid data
data = StringConstraints(
username="john_doe",
country_code="US",
phone="+15551234567",
email="john@example.com",
slug="my-blog-post",
password="Str0ng!Pass"
)
print(data.username) # john_doe
from pydantic import BaseModel, Field
class APIResponse(BaseModel):
    # alias: the key used in incoming data and, with by_alias=True, in output
    item_id: int = Field(alias="itemId")
    item_name: str = Field(alias="itemName")
    is_available: bool = Field(alias="isAvailable", default=True)
    # validation_alias: only for input; serialization_alias: only for output
    total_price: float = Field(
        validation_alias="total_price",
        serialization_alias="totalPrice"
    )
# Parse from camelCase JSON
data = {"itemId": 42, "itemName": "Widget", "isAvailable": True, "total_price": 29.99}
response = APIResponse(**data)
print(response.item_id) # 42
# Serialize with Python names (default)
print(response.model_dump())
# {'item_id': 42, 'item_name': 'Widget', 'is_available': True, 'total_price': 29.99}
# Serialize with aliases
print(response.model_dump(by_alias=True))
# {'itemId': 42, 'itemName': 'Widget', 'isAvailable': True, 'totalPrice': 29.99}
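Declaring an alias on every field gets repetitive. A model-wide `alias_generator` derives the aliases for you.

This sketch assumes camelCase JSON and uses `to_camel` from `pydantic.alias_generators`. The `CamelModel` name is hypothetical.

```python
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class CamelModel(BaseModel):
    # Every field gets a camelCase alias derived from its snake_case name
    model_config = ConfigDict(alias_generator=to_camel)

    item_id: int
    item_name: str
    is_available: bool = True


item = CamelModel.model_validate({"itemId": 42, "itemName": "Widget"})
print(item.item_id)  # 42
print(item.model_dump(by_alias=True))
# {'itemId': 42, 'itemName': 'Widget', 'isAvailable': True}
```

As with explicit aliases, input is matched by alias by default. Output uses the Python names unless you pass `by_alias=True`.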
from pydantic import BaseModel, EmailStr, HttpUrl, IPvAnyAddress
from pydantic import PositiveInt, PositiveFloat, NonNegativeInt
from pydantic import constr, conint, confloat
class BuiltInTypes(BaseModel):
    # Email validation (requires: pip install "pydantic[email]")
    email: EmailStr
    # URL validation
    website: HttpUrl
    # IP address
    server_ip: IPvAnyAddress
    # Constrained positive numbers
    user_id: PositiveInt # must be > 0
    rating: PositiveFloat # must be > 0.0
    retry_count: NonNegativeInt # must be >= 0
    # Constrained types (older alternative to Field(); the v2 docs discourage
    # constr/conint/confloat in favor of Annotated[...] with constraints)
    username: constr(min_length=3, max_length=50)
    age: conint(ge=0, le=150)
    score: confloat(ge=0.0, le=100.0)
data = BuiltInTypes(
email="user@example.com",
website="https://example.com",
server_ip="192.168.1.1",
user_id=42,
rating=4.5,
retry_count=0,
username="john_doe",
age=30,
score=95.5
)
print(data.email) # user@example.com
print(data.website) # https://example.com/
Pydantic leverages Python’s type annotation system extensively. Understanding how different type hints translate to validation rules is essential for building robust models.
from pydantic import BaseModel
class BasicTypes(BaseModel):
    name: str # accepts strings
    age: int # accepts integers, coerces numeric strings
    height: float # accepts floats, coerces ints
    is_active: bool # accepts booleans, coerces "true"/"false"
    data: bytes # accepts bytes and strings
# Type coercion examples
model = BasicTypes(
name="Alice",
age="30", # string "30" coerced to int 30
height=5, # int 5 coerced to float 5.0
is_active="true", # string "true" coerced to True
data="hello" # string coerced to b"hello"
)
print(f"age: {model.age} (type: {type(model.age).__name__})")
# age: 30 (type: int)
print(f"height: {model.height} (type: {type(model.height).__name__})")
# height: 5.0 (type: float)
print(f"is_active: {model.is_active} (type: {type(model.is_active).__name__})")
# is_active: True (type: bool)
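The same coercion rules apply outside of models. `TypeAdapter` validates a bare type such as `List[int]` or `bool` without defining a model. A short sketch:

```python
from typing import List

from pydantic import TypeAdapter

# Validate a bare type with the same lax coercion rules used inside models
int_list = TypeAdapter(List[int])
print(int_list.validate_python(["1", 2, "3"]))  # [1, 2, 3]

# Booleans accept common string spellings such as "yes"/"no", "on"/"off", "1"/"0"
flag = TypeAdapter(bool)
print(flag.validate_python("yes"))  # True
print(flag.validate_python("off"))  # False
```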
from pydantic import BaseModel, Field
from typing import List, Dict, Set, Tuple, FrozenSet
class CollectionTypes(BaseModel):
    tags: List[str] = []
    scores: List[int] = Field(min_length=1, max_length=100) # required: 1 to 100 scores
    metadata: Dict[str, str] = {}
    unique_tags: Set[str] = set()
    coordinates: Tuple[float, float] = (0.0, 0.0)
    values: Tuple[int, ...] = ()
    permissions: FrozenSet[str] = frozenset()
data = CollectionTypes(
tags=["python", "fastapi", "tutorial"],
scores=[95, 87, 92, 88],
metadata={"author": "Alice", "version": "2.0"},
unique_tags=["python", "python", "fastapi"], # duplicates removed
coordinates=(40.7128, -74.0060),
values=(1, 2, 3, 4, 5),
permissions=["read", "write", "read"] # duplicates removed
)
print(data.unique_tags) # {'python', 'fastapi'} (set order may vary)
print(data.permissions) # frozenset({'read', 'write'}) (order may vary)
from pydantic import BaseModel
from typing import Optional, Union, Literal
class AdvancedTypes(BaseModel):
    nickname: Optional[str] = None
    identifier: Union[int, str]
    role: Literal["admin", "user", "moderator"]
    priority: Literal[1, 2, 3, "high", "medium", "low"]
# Valid
data = AdvancedTypes(identifier=42, role="admin", priority="high")
print(data)
# Unions use "smart" mode by default: an exact type match wins,
# otherwise the first member that validates (with coercion) is used
data2 = AdvancedTypes(identifier="user-abc-123", role="user", priority=2)
print(data2.identifier) # user-abc-123
# Invalid literal
try:
    AdvancedTypes(identifier=1, role="superadmin", priority=5)
except Exception as e:
    print(e)
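Union behavior can be chosen per field:

- In the default "smart" mode, an exact type match wins.
- `union_mode="left_to_right"` keeps the first member that validates, coercing if needed.

A sketch with two hypothetical models:

```python
from typing import Union

from pydantic import BaseModel, Field


class SmartUnion(BaseModel):
    # Default smart mode: a str input exactly matches str, so it stays a str
    identifier: Union[int, str]


class LeftToRight(BaseModel):
    # Try int first; "123" can be coerced to int, so int wins
    identifier: Union[int, str] = Field(union_mode="left_to_right")


print(repr(SmartUnion(identifier="123").identifier))   # '123'
print(repr(LeftToRight(identifier="123").identifier))  # 123
```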
from pydantic import BaseModel, Field
from typing import Annotated
from pydantic.functional_validators import AfterValidator
# Create reusable validated types with Annotated
PositiveName = Annotated[str, Field(min_length=1, max_length=100)]
Percentage = Annotated[float, Field(ge=0.0, le=100.0)]
Port = Annotated[int, Field(ge=1, le=65535)]
def normalize_email(v: str) -> str:
    return v.lower().strip()
NormalizedEmail = Annotated[str, AfterValidator(normalize_email)]
class ServerConfig(BaseModel):
    name: PositiveName
    host: str
    port: Port
    cpu_usage: Percentage
    admin_email: NormalizedEmail
config = ServerConfig(
name="Production Server",
host="api.example.com",
port=8080,
cpu_usage=67.5,
admin_email=" ADMIN@Example.COM "
)
print(config.admin_email) # admin@example.com (normalized!)
print(config.port) # 8080
Annotated types are the recommended way in Pydantic v2 to create reusable field definitions. They keep your models clean and ensure consistent validation across your codebase.
Real-world data is rarely flat. Pydantic excels at validating deeply nested structures by composing models together. Each nested model is fully validated independently.
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class Address(BaseModel):
    street: str = Field(min_length=1, max_length=200)
    city: str = Field(min_length=1, max_length=100)
    state: str = Field(min_length=2, max_length=2)
    zip_code: str = Field(pattern=r"^\d{5}(-\d{4})?$")
    country: str = "US"
class ContactInfo(BaseModel):
    email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w{2,}$")
    phone: Optional[str] = Field(default=None, pattern=r"^\+?1?\d{10,15}$")
    address: Address
class User(BaseModel):
    id: int
    name: str = Field(min_length=1, max_length=100)
    contact: ContactInfo
    created_at: datetime = Field(default_factory=datetime.now)
# Create with nested data
user = User(
id=1,
name="Alice Johnson",
contact={
"email": "alice@example.com",
"phone": "+15551234567",
"address": {
"street": "123 Main St",
"city": "Springfield",
"state": "IL",
"zip_code": "62701"
}
}
)
# Access nested fields
print(user.contact.email) # alice@example.com
print(user.contact.address.city) # Springfield
print(user.contact.address.zip_code) # 62701
# Nested validation works automatically
try:
    bad_user = User(
        id=2,
        name="Bob",
        contact={
            "email": "bob@example.com",
            "address": {
                "street": "456 Oak Ave",
                "city": "Chicago",
                "state": "Illinois", # too long - must be 2 chars
                "zip_code": "abc" # does not match pattern
            }
        }
    )
except Exception as e:
    print(e)
# Shows errors at contact -> address -> state and zip_code
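Each entry in `ValidationError.errors()` carries a `loc` tuple: the path from the top-level model to the failing field. That tuple is what the `contact -> address -> state` output above reflects.

A trimmed-down sketch with hypothetical `Address`, `Contact` and `Customer` models:

```python
from pydantic import BaseModel, Field, ValidationError


class Address(BaseModel):
    state: str = Field(min_length=2, max_length=2)
    zip_code: str = Field(pattern=r"^\d{5}(-\d{4})?$")


class Contact(BaseModel):
    address: Address


class Customer(BaseModel):
    contact: Contact


locations = []
try:
    Customer(contact={"address": {"state": "Illinois", "zip_code": "abc"}})
except ValidationError as e:
    # "loc" lists the path through the nested models, outermost first
    locations = [err["loc"] for err in e.errors()]

print(locations)
# [('contact', 'address', 'state'), ('contact', 'address', 'zip_code')]
```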
from pydantic import BaseModel, Field
from typing import List, Optional
from enum import Enum
class OrderStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    SHIPPED = "shipped"
    DELIVERED = "delivered"
    CANCELLED = "cancelled"
class OrderItem(BaseModel):
    product_name: str = Field(min_length=1)
    quantity: int = Field(gt=0, le=1000)
    unit_price: float = Field(gt=0)
    @property
    def total(self) -> float:
        return round(self.quantity * self.unit_price, 2)
class ShippingAddress(BaseModel):
    name: str
    street: str
    city: str
    state: str
    zip_code: str
class Order(BaseModel):
    order_id: str = Field(pattern=r"^ORD-\d{6,10}$")
    customer_id: int = Field(gt=0)
    items: List[OrderItem] = Field(min_length=1, max_length=50)
    shipping: ShippingAddress
    status: OrderStatus = OrderStatus.PENDING
    notes: Optional[str] = None
    @property
    def order_total(self) -> float:
        return round(sum(item.total for item in self.items), 2)
    @property
    def item_count(self) -> int:
        return sum(item.quantity for item in self.items)
# Create a complete order
order = Order(
order_id="ORD-123456",
customer_id=42,
items=[
{"product_name": "Laptop", "quantity": 1, "unit_price": 999.99},
{"product_name": "Mouse", "quantity": 2, "unit_price": 29.99},
{"product_name": "Keyboard", "quantity": 1, "unit_price": 79.99},
],
shipping={
"name": "Alice Johnson",
"street": "123 Main St",
"city": "Springfield",
"state": "IL",
"zip_code": "62701"
},
notes="Please leave at front door"
)
print(f"Order: {order.order_id}")
print(f"Items: {len(order.items)}")
print(f"Total items: {order.item_count}")
print(f"Order total: ${order.order_total}")
# Order total: $1139.96
from __future__ import annotations
from pydantic import BaseModel, Field
from typing import Optional, List
class Comment(BaseModel):
    id: int
    author: str
    text: str = Field(min_length=1, max_length=5000)
    replies: List[Comment] = [] # self-referential!
    @property
    def reply_count(self) -> int:
        count = len(self.replies)
        for reply in self.replies:
            count += reply.reply_count
        return count
class Category(BaseModel):
    name: str
    slug: str
    children: List[Category] = []
    def find(self, slug: str) -> Optional[Category]:
        if self.slug == slug:
            return self
        for child in self.children:
            result = child.find(slug)
            if result:
                return result
        return None
# Build a comment tree
thread = Comment(
id=1,
author="Alice",
text="Great tutorial on Pydantic!",
replies=[
{
"id": 2,
"author": "Bob",
"text": "Thanks Alice! I learned a lot.",
"replies": [
{"id": 3, "author": "Alice", "text": "Glad to hear that!"}
]
},
{"id": 4, "author": "Charlie", "text": "Could you explain validators more?"}
]
)
print(f"Total replies: {thread.reply_count}") # Total replies: 3
# Build a category tree
categories = Category(
name="Programming",
slug="programming",
children=[
{
"name": "Python",
"slug": "python",
"children": [
{"name": "FastAPI", "slug": "fastapi"},
{"name": "Django", "slug": "django"},
]
},
{
"name": "JavaScript",
"slug": "javascript",
"children": [
{"name": "React", "slug": "react"},
{"name": "Node.js", "slug": "nodejs"},
]
}
]
)
found = categories.find("fastapi")
if found:
    print(f"Found: {found.name}") # Found: FastAPI
While built-in constraints handle most cases, custom validators let you implement any validation logic you need. Pydantic v2 provides @field_validator for single fields and @model_validator for cross-field validation.
from pydantic import BaseModel, Field, field_validator
class UserRegistration(BaseModel):
    username: str = Field(min_length=3, max_length=30)
    email: str
    password: str = Field(min_length=8)
    age: int = Field(ge=13)
    @field_validator("username")
    @classmethod
    def username_must_be_alphanumeric(cls, v: str) -> str:
        if not v.replace("_", "").isalnum():
            raise ValueError("Username must contain only letters, numbers, and underscores")
        if v[0].isdigit():
            raise ValueError("Username cannot start with a number")
        return v.lower() # normalize to lowercase
    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str) -> str:
        if "@" not in v:
            raise ValueError("Invalid email address")
        local, domain = v.rsplit("@", 1)
        if not local or not domain or "." not in domain:
            raise ValueError("Invalid email address")
        return v.lower().strip()
    @field_validator("password")
    @classmethod
    def password_strength(cls, v: str) -> str:
        if not any(c.isupper() for c in v):
            raise ValueError("Password must contain at least one uppercase letter")
        if not any(c.islower() for c in v):
            raise ValueError("Password must contain at least one lowercase letter")
        if not any(c.isdigit() for c in v):
            raise ValueError("Password must contain at least one digit")
        if not any(c in "!@#$%^&*()_+-=[]{}|;:,./?" for c in v):
            raise ValueError("Password must contain at least one special character")
        return v
# Valid registration
user = UserRegistration(
username="Alice_Dev",
email="ALICE@Example.COM",
password="Str0ng!Pass",
age=25
)
print(user.username) # alice_dev (normalized)
print(user.email) # alice@example.com (normalized)
# Invalid registration
try:
    UserRegistration(
        username="123invalid",
        email="not-email",
        password="weakpass",
        age=10
    )
except Exception as e:
    print(e)
from pydantic import BaseModel, field_validator
class DataProcessor(BaseModel):
    value: int
    tags: list[str]
    name: str
    # BEFORE validator - runs before Pydantic's type validation
    @field_validator("value", mode="before")
    @classmethod
    def parse_value(cls, v):
        if isinstance(v, str):
            cleaned = v.strip().replace("$", "").replace("%", "").replace(",", "")
            try:
                return int(float(cleaned))
            except ValueError:
                raise ValueError(f"Cannot parse '{v}' as a number")
        return v
    # AFTER validator (default) - runs after type validation
    @field_validator("tags")
    @classmethod
    def normalize_tags(cls, v: list[str]) -> list[str]:
        seen = set()
        result = []
        for tag in v:
            normalized = tag.lower().strip()
            if normalized and normalized not in seen:
                seen.add(normalized)
                result.append(normalized)
        return result
    # BEFORE validator for type coercion
    @field_validator("name", mode="before")
    @classmethod
    def coerce_name(cls, v):
        if not isinstance(v, str):
            return str(v)
        return v
# Before validator handles type conversion
data = DataProcessor(
value="$1,234",
tags=["Python", "PYTHON", " fastapi ", "FastAPI", "tutorial"],
name=42
)
print(data.value) # 1234
print(data.tags) # ['python', 'fastapi', 'tutorial']
print(data.name) # 42 (converted to string)
from pydantic import BaseModel, Field, field_validator
class PriceRange(BaseModel):
    min_price: float = Field(ge=0)
    max_price: float = Field(ge=0)
    currency: str = Field(min_length=3, max_length=3)
    # Apply one validator to multiple fields
    @field_validator("min_price", "max_price")
    @classmethod
    def round_price(cls, v: float) -> float:
        return round(v, 2)
    @field_validator("currency")
    @classmethod
    def uppercase_currency(cls, v: str) -> str:
        return v.upper()
data = PriceRange(min_price=10.555, max_price=99.999, currency="usd")
print(data.min_price) # 10.55 (the float 10.555 is stored as 10.55499..., so round() goes down)
print(data.max_price) # 100.0
print(data.currency) # USD
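Rounding money with `float` is fragile. `10.555` has no exact binary representation and is stored as slightly less than 10.555, so `round(10.555, 2)` returns `10.55`. A `Decimal`-based helper avoids the surprise.

This sketch uses only the standard library. `round_money` is a hypothetical name; you could call it from a `field_validator` on a `Decimal` field.

```python
from decimal import Decimal, ROUND_HALF_UP


def round_money(value: str) -> Decimal:
    """Round a decimal string to cents using half-up rounding."""
    return Decimal(value).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


print(round(10.555, 2))       # 10.55 (binary float is just below 10.555)
print(round_money("10.555"))  # 10.56
print(round_money("99.999"))  # 100.00
```

Passing the amount as a string keeps it exact. `Decimal(10.555)` built from a float would inherit the same binary error.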
from pydantic import BaseModel, Field, model_validator
from datetime import date
from typing import Optional
class DateRange(BaseModel):
    start_date: date
    end_date: date
    @model_validator(mode="after")
    def validate_date_range(self):
        if self.start_date >= self.end_date:
            raise ValueError("start_date must be before end_date")
        return self
class PasswordChange(BaseModel):
    current_password: str
    new_password: str = Field(min_length=8)
    confirm_password: str
    @model_validator(mode="after")
    def validate_passwords(self):
        if self.new_password != self.confirm_password:
            raise ValueError("new_password and confirm_password must match")
        if self.new_password == self.current_password:
            raise ValueError("New password must be different from current password")
        return self
class DiscountRule(BaseModel):
    discount_type: str # "percentage" or "fixed"
    discount_value: float
    min_order: float = 0.0
    max_discount: Optional[float] = None
    @model_validator(mode="after")
    def validate_discount(self):
        if self.discount_type == "percentage":
            if not (0 < self.discount_value <= 100):
                raise ValueError("Percentage discount must be between 0 and 100")
        elif self.discount_type == "fixed":
            if self.discount_value <= 0:
                raise ValueError("Fixed discount must be positive")
        else:
            raise ValueError("discount_type must be 'percentage' or 'fixed'")
        return self
# Valid date range
dr = DateRange(start_date="2024-01-01", end_date="2024-12-31")
print(dr)
# Invalid - passwords don't match
try:
    PasswordChange(
        current_password="old",
        new_password="NewStr0ng!Pass",
        confirm_password="DifferentPass"
    )
except Exception as e:
    print(e)
from pydantic import BaseModel, model_validator
from typing import Any
class FlexibleInput(BaseModel):
    name: str
    value: float
    @model_validator(mode="before")
    @classmethod
    def normalize_input(cls, data: Any) -> Any:
        # Handle string input like "name:value"
        if isinstance(data, str):
            parts = data.split(":")
            if len(parts) != 2:
                raise ValueError("String input must be 'name:value' format")
            return {"name": parts[0].strip(), "value": float(parts[1].strip())}
        # Handle tuple/list input
        if isinstance(data, (list, tuple)):
            if len(data) != 2:
                raise ValueError("Sequence must have exactly 2 elements")
            return {"name": str(data[0]), "value": float(data[1])}
        # Handle dict with alternative key names (copy first so the caller's dict is not mutated)
        if isinstance(data, dict):
            data = dict(data)
            if "label" in data and "name" not in data:
                data["name"] = data.pop("label")
            if "amount" in data and "value" not in data:
                data["value"] = data.pop("amount")
        return data
# All of these work:
m1 = FlexibleInput(name="temperature", value=98.6)
m2 = FlexibleInput.model_validate("humidity:65.0")
m3 = FlexibleInput.model_validate(["pressure", 1013.25])
m4 = FlexibleInput.model_validate({"label": "wind", "amount": 15.5})
print(m1) # name='temperature' value=98.6
print(m2) # name='humidity' value=65.0
print(m3) # name='pressure' value=1013.25
print(m4) # name='wind' value=15.5
from pydantic import BaseModel, Field, ValidationError, field_validator
class StrictUser(BaseModel):
    name: str = Field(min_length=2, max_length=50)
    age: int = Field(ge=0, le=150)
    email: str
    @field_validator("email")
    @classmethod
    def validate_email(cls, v):
        if "@" not in v:
            raise ValueError("Must contain @")
        return v
# Trigger multiple validation errors
try:
    StrictUser(name="A", age=-5, email="invalid")
except ValidationError as e:
    print(f"Error count: {e.error_count()}")
    for error in e.errors():
        field = ' -> '.join(str(loc) for loc in error['loc'])
        print(f"\nField: {field}")
        print(f" Type: {error['type']}")
        print(f" Message: {error['msg']}")
        print(f" Input: {error['input']}")
    # JSON format for API responses
    import json
    print(json.dumps(e.errors(), indent=2, default=str))
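An API response usually needs something more compact than the full error list. The sketch below collapses each error to a dotted field path and its stable `type` code.

The `SignUp` model and the `error_summary` helper are hypothetical. The response shape is only one possible choice.

```python
from pydantic import BaseModel, Field, ValidationError


class SignUp(BaseModel):
    name: str = Field(min_length=2)
    age: int = Field(ge=0)


def error_summary(exc: ValidationError) -> dict[str, str]:
    """Map each dotted field path to its machine-readable error type."""
    return {
        ".".join(str(part) for part in err["loc"]): err["type"]
        for err in exc.errors()
    }


summary = {}
try:
    SignUp(name="A", age=-5)
except ValidationError as e:
    summary = error_summary(e)

print(summary)  # {'name': 'string_too_short', 'age': 'greater_than_equal'}
```

The `type` codes are more stable than the human-readable `msg` text, so they are safer for clients to branch on.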
Computed fields are derived from other field values. They are included automatically in model_dump() and model_dump_json() output and in the serialization-mode JSON schema, but they are not read from input data. Pydantic v2 introduced the @computed_field decorator for this purpose.
from pydantic import BaseModel, Field, computed_field
class Product(BaseModel):
    name: str
    base_price: float = Field(gt=0)
    tax_rate: float = Field(ge=0, le=1, default=0.08)
    discount: float = Field(ge=0, le=1, default=0)
    @computed_field
    @property
    def discount_amount(self) -> float:
        return round(self.base_price * self.discount, 2)
    @computed_field
    @property
    def subtotal(self) -> float:
        return round(self.base_price - self.discount_amount, 2)
    @computed_field
    @property
    def tax_amount(self) -> float:
        return round(self.subtotal * self.tax_rate, 2)
    @computed_field
    @property
    def total_price(self) -> float:
        return round(self.subtotal + self.tax_amount, 2)
product = Product(name="Laptop", base_price=999.99, discount=0.15)
print(f"Base price: ${product.base_price}")
print(f"Discount: -${product.discount_amount}")
print(f"Subtotal: ${product.subtotal}")
print(f"Tax: +${product.tax_amount}")
print(f"Total: ${product.total_price}")
# Computed fields are included in model_dump()!
data = product.model_dump()
print(data)
# Includes: discount_amount, subtotal, tax_amount, total_price
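Computed fields are output-only. They appear in the JSON schema generated in serialization mode, but not in the validation-mode schema that describes accepted input. A sketch with a hypothetical `Rectangle` model:

```python
from pydantic import BaseModel, computed_field


class Rectangle(BaseModel):
    width: float
    height: float

    @computed_field
    @property
    def area(self) -> float:
        return self.width * self.height


validation_props = Rectangle.model_json_schema(mode="validation")["properties"]
serialization_props = Rectangle.model_json_schema(mode="serialization")["properties"]

print("area" in validation_props)     # False (not accepted as input)
print("area" in serialization_props)  # True (present in output)
print(Rectangle(width=2, height=3).model_dump())
# {'width': 2.0, 'height': 3.0, 'area': 6.0}
```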
from pydantic import BaseModel, Field, computed_field
from datetime import date
from typing import List
class Student(BaseModel):
    first_name: str
    last_name: str
    birth_date: date
    grades: List[float] = Field(default=[])
    @computed_field
    @property
    def full_name(self) -> str:
        return f"{self.first_name} {self.last_name}"
    @computed_field
    @property
    def age(self) -> int:
        today = date.today()
        born = self.birth_date
        return today.year - born.year - (
            (today.month, today.day) < (born.month, born.day)
        )
    @computed_field
    @property
    def gpa(self) -> float:
        if not self.grades:
            return 0.0
        return round(sum(self.grades) / len(self.grades), 2)
    @computed_field
    @property
    def letter_grade(self) -> str:
        gpa = self.gpa
        if gpa >= 90: return "A"
        elif gpa >= 80: return "B"
        elif gpa >= 70: return "C"
        elif gpa >= 60: return "D"
        return "F"
    @computed_field
    @property
    def honor_roll(self) -> bool:
        return self.gpa >= 85.0 and len(self.grades) >= 3
student = Student(
first_name="Alice",
last_name="Johnson",
birth_date="2000-05-15",
grades=[92, 88, 95, 91, 87]
)
print(student.model_dump())
# Includes: full_name, age, gpa, letter_grade, honor_roll
Pydantic v2 uses model_config (a ConfigDict) to control model behavior — how it handles extra fields, whether it strips whitespace, whether it can read from ORM objects, and more.
from pydantic import BaseModel, ConfigDict
class StrictModel(BaseModel):
    model_config = ConfigDict(
        extra="forbid", # Reject extra fields
        str_strip_whitespace=True, # Strip whitespace from strings
        str_to_lower=True, # Convert strings to lowercase
        frozen=False, # Allow mutation
        use_enum_values=True, # Store the enum's .value on the model instead of the member
        validate_default=True, # Validate default values
    )
    name: str
    email: str
    role: str = "user"
# str_strip_whitespace and str_to_lower in action
user = StrictModel(
name=" ALICE JOHNSON ",
email=" Alice@Example.COM "
)
print(user.name) # alice johnson (stripped and lowered)
print(user.email) # alice@example.com
# extra="forbid" rejects unknown fields
try:
    StrictModel(name="Bob", email="bob@test.com", unknown_field="value")
except Exception as e:
    print(e)
# Extra inputs are not permitted
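`use_enum_values` changes what is stored on the instance itself, not only what is serialized. A sketch comparing two hypothetical models:

```python
from enum import Enum

from pydantic import BaseModel, ConfigDict


class Role(str, Enum):
    ADMIN = "admin"
    USER = "user"


class WithEnums(BaseModel):
    role: Role


class WithValues(BaseModel):
    # After validation the enum member is replaced by its .value
    model_config = ConfigDict(use_enum_values=True)
    role: Role


print(type(WithEnums(role="admin").role).__name__)   # Role
print(type(WithValues(role="admin").role).__name__)  # str
```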
from pydantic import BaseModel, ConfigDict
class ForbidExtra(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str
class IgnoreExtra(BaseModel):
    model_config = ConfigDict(extra="ignore")
    name: str
class AllowExtra(BaseModel):
    model_config = ConfigDict(extra="allow")
    name: str
data = {"name": "Alice", "age": 30, "role": "admin"}
# forbid - raises error
try:
    ForbidExtra(**data)
except Exception as e:
    print(f"Forbid: {e}")
# ignore - silently drops extra fields
m2 = IgnoreExtra(**data)
print(f"Ignore: {m2.model_dump()}")
# {'name': 'Alice'}
# allow - keeps extra fields
m3 = AllowExtra(**data)
print(f"Allow: {m3.model_dump()}")
# {'name': 'Alice', 'age': 30, 'role': 'admin'}
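With `extra="allow"`, the extra keys are kept in `model_extra` and are also readable as attributes. A short sketch:

```python
from pydantic import BaseModel, ConfigDict


class AllowExtra(BaseModel):
    model_config = ConfigDict(extra="allow")
    name: str


m = AllowExtra(name="Alice", age=30, role="admin")
print(m.model_extra)  # {'age': 30, 'role': 'admin'}
print(m.age)          # 30 (extras are readable as attributes too)
```

Extras are stored as received, without validation, because the model declares no type for them.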
```python
from pydantic import BaseModel, ConfigDict

# Simulating an ORM model (like SQLAlchemy)
class UserORM:
    def __init__(self, id, name, email, is_active):
        self.id = id
        self.name = name
        self.email = email
        self.is_active = is_active

class UserSchema(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: int
    name: str
    email: str
    is_active: bool

# Create from ORM object
orm_user = UserORM(id=1, name="Alice", email="alice@example.com", is_active=True)
# model_validate reads attributes from the object
user_schema = UserSchema.model_validate(orm_user)
print(user_schema)
# id=1 name='Alice' email='alice@example.com' is_active=True
print(user_schema.model_dump())
# {'id': 1, 'name': 'Alice', 'email': 'alice@example.com', 'is_active': True}
# This is essential for FastAPI + SQLAlchemy integration
```
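model_validate converts one object at a time. For a list of ORM rows, such as a query result, one option is a TypeAdapter over list[UserSchema]. In this sketch a plain class stands in for an ORM row, so no database is involved:

```python
from pydantic import BaseModel, ConfigDict, TypeAdapter


class UserORM:
    # Stand-in for an ORM row object; only attribute access matters here
    def __init__(self, id, name):
        self.id = id
        self.name = name


class UserSchema(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: int
    name: str


rows = [UserORM(1, "Alice"), UserORM(2, "Bob")]

# TypeAdapter validates types that are not BaseModel subclasses, e.g. list[UserSchema]
adapter = TypeAdapter(list[UserSchema])
users = adapter.validate_python(rows, from_attributes=True)

print(users)
# [UserSchema(id=1, name='Alice'), UserSchema(id=2, name='Bob')]
```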
```python
from pydantic import BaseModel, Field, ConfigDict
import json

class CreateUserRequest(BaseModel):
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "username": "alice_dev",
                    "email": "alice@example.com",
                    "full_name": "Alice Johnson",
                    "age": 28
                }
            ]
        }
    )
    username: str = Field(
        min_length=3, max_length=30,
        description="Unique username for the account"
    )
    email: str = Field(description="Valid email address")
    full_name: str = Field(min_length=1, max_length=100)
    age: int = Field(ge=13, le=150, description="Must be 13 or older")

# View the generated JSON Schema
schema = CreateUserRequest.model_json_schema()
print(json.dumps(schema, indent=2))
# This schema is automatically used by FastAPI for Swagger UI
```
```python
from pydantic import BaseModel, ConfigDict, ValidationError

class ImmutableConfig(BaseModel):
    model_config = ConfigDict(frozen=True)
    host: str
    port: int
    debug: bool = False

config = ImmutableConfig(host="localhost", port=8000, debug=True)
print(config.host)  # localhost

# Cannot modify frozen model
try:
    config.host = "production.example.com"
except ValidationError as e:
    print(e)  # ... Instance is frozen ...

# Create a new instance with changes using model_copy
new_config = config.model_copy(
    update={"host": "production.example.com", "debug": False}
)
print(new_config.host)  # production.example.com
print(config.host)      # localhost (original unchanged)
```
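Because frozen=True makes instances immutable, Pydantic also generates a __hash__, so frozen models can serve as dictionary keys or set members. A small illustration (the cache dictionary is hypothetical):

```python
from pydantic import BaseModel, ConfigDict


class ImmutableConfig(BaseModel):
    model_config = ConfigDict(frozen=True)
    host: str
    port: int


a = ImmutableConfig(host="localhost", port=8000)
b = ImmutableConfig(host="localhost", port=8000)

# Equal frozen instances compare equal and hash the same
print(a == b, hash(a) == hash(b))  # True True

cache = {a: "connection-pool-1"}
print(cache[b])      # connection-pool-1
print(len({a, b}))   # 1
```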
Serialization is the process of converting your Pydantic model to a dictionary, JSON, or other format for output. Pydantic v2 provides powerful serialization controls through model_dump() and model_dump_json().
```python
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime

class UserProfile(BaseModel):
    id: int
    username: str
    email: str
    full_name: Optional[str] = None
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    is_active: bool = True
    role: str = "user"
    login_count: int = 0
    created_at: datetime = Field(default_factory=datetime.now)
    tags: List[str] = []

user = UserProfile(
    id=1,
    username="alice",
    email="alice@example.com",
    full_name="Alice Johnson",
    bio="Software developer",
    tags=["python", "fastapi"]
)

# Exclude specific fields
public_data = user.model_dump(exclude={"email", "login_count", "is_active"})
print(public_data)

# Include only specific fields
summary = user.model_dump(include={"id", "username", "full_name"})
print(summary)
# {'id': 1, 'username': 'alice', 'full_name': 'Alice Johnson'}

# Exclude None values
clean_data = user.model_dump(exclude_none=True)
print(clean_data)  # avatar_url won't appear

# Exclude unset values (only fields explicitly provided)
explicit_data = user.model_dump(exclude_unset=True)
print(explicit_data)
# Only id, username, email, full_name, bio, tags appear

# Exclude defaults
non_default = user.model_dump(exclude_defaults=True)
print(non_default)
```
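exclude_unset is the usual building block for PATCH endpoints: only the fields the client actually sent are applied, and an explicit null stays distinguishable from an omitted field. A sketch with hypothetical Product and ProductPatch models:

```python
from typing import Optional

from pydantic import BaseModel


class Product(BaseModel):
    name: str
    price: float
    description: Optional[str] = None


class ProductPatch(BaseModel):
    # Every field optional: the client sends only what it wants to change
    name: Optional[str] = None
    price: Optional[float] = None
    description: Optional[str] = None


stored = Product(name="Mouse", price=29.99, description="Wireless")

# The client changes the price and explicitly clears the description
patch = ProductPatch.model_validate({"price": 24.99, "description": None})

changes = patch.model_dump(exclude_unset=True)
print(changes)  # {'price': 24.99, 'description': None}

updated = stored.model_copy(update=changes)
print(updated)  # name='Mouse' price=24.99 description=None
```

model_copy(update=...) does not re-validate the update values, which is why the sketch validates the incoming data through ProductPatch first. exclude_none would have dropped the explicit null and left the old description in place.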
```python
from pydantic import BaseModel
from typing import Optional

class Address(BaseModel):
    street: str
    city: str
    state: str
    zip_code: str
    country: str = "US"

class ContactInfo(BaseModel):
    email: str
    phone: Optional[str] = None
    address: Address

class UserFull(BaseModel):
    id: int
    name: str
    password_hash: str
    contact: ContactInfo
    internal_notes: Optional[str] = None

user = UserFull(
    id=1,
    name="Alice",
    password_hash="$2b$12$abc123...",
    contact={
        "email": "alice@example.com",
        "phone": "+15551234567",
        "address": {
            "street": "123 Main St",
            "city": "Springfield",
            "state": "IL",
            "zip_code": "62701"
        }
    },
    internal_notes="VIP customer"
)

# Exclude nested fields using dict notation
public_data = user.model_dump(
    exclude={
        "password_hash": True,
        "internal_notes": True,
        "contact": {
            "phone": True,
            "address": {"zip_code", "country"}
        }
    }
)
print(public_data)
# password_hash, internal_notes, phone, zip_code, country all excluded

# Include only specific nested fields
minimal = user.model_dump(
    include={
        "id": True,
        "name": True,
        "contact": {"email": True}
    }
)
print(minimal)
# {'id': 1, 'name': 'Alice', 'contact': {'email': 'alice@example.com'}}
```
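Passing exclude= to every model_dump() call is easy to forget. For a field that should never be serialized, such as a password hash, Field(exclude=True) declares the exclusion on the model itself, and repr=False also keeps the value out of the model's repr (and so out of most log lines that print the model). A minimal sketch:

```python
from pydantic import BaseModel, Field


class UserFull(BaseModel):
    id: int
    name: str
    # Never emitted by model_dump()/model_dump_json(), and hidden from repr()
    password_hash: str = Field(exclude=True, repr=False)


user = UserFull(id=1, name="Alice", password_hash="$2b$12$abc123...")

print(user.model_dump())  # {'id': 1, 'name': 'Alice'}
print(repr(user))         # UserFull(id=1, name='Alice')
print(user.password_hash.startswith("$2b$"))  # True: still readable in code
```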
```python
from pydantic import BaseModel, Field
from datetime import datetime, date
from typing import List
from decimal import Decimal
from uuid import UUID, uuid4

class Invoice(BaseModel):
    id: UUID = Field(default_factory=uuid4)
    customer: str
    amount: Decimal
    tax: Decimal
    issue_date: date
    due_date: date
    items: List[str]
    created_at: datetime = Field(default_factory=datetime.now)

invoice = Invoice(
    customer="Acme Corp",
    amount=Decimal("1500.00"),
    tax=Decimal("120.00"),
    issue_date="2024-03-15",
    due_date="2024-04-15",
    items=["Consulting", "Development", "Testing"]
)

# model_dump_json() handles special types automatically
json_str = invoice.model_dump_json(indent=2)
print(json_str)

# model_dump() vs model_dump_json() with special types
dict_data = invoice.model_dump()
print(type(dict_data["id"]))      # <class 'uuid.UUID'>
print(type(dict_data["amount"]))  # <class 'decimal.Decimal'>

# model_dump with mode="json" converts to JSON-compatible types
json_dict = invoice.model_dump(mode="json")
print(type(json_dict["id"]))      # <class 'str'>
print(type(json_dict["amount"]))  # <class 'str'>
```
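The strings produced by model_dump_json() parse back into the original types with model_validate_json(), which makes a quick check that a model survives a JSON round trip. A trimmed-down Invoice sketch:

```python
from datetime import date
from decimal import Decimal
from uuid import UUID, uuid4

from pydantic import BaseModel, Field


class Invoice(BaseModel):
    id: UUID = Field(default_factory=uuid4)
    amount: Decimal
    issue_date: date


original = Invoice(amount=Decimal("1500.00"), issue_date="2024-03-15")

payload = original.model_dump_json()             # UUID, Decimal, date -> JSON strings
restored = Invoice.model_validate_json(payload)  # JSON strings -> UUID, Decimal, date

print(payload)
print(restored == original)   # True
print(repr(restored.amount))  # Decimal('1500.00')
```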
```python
from pydantic import BaseModel, field_serializer
from datetime import datetime
from decimal import Decimal

class Transaction(BaseModel):
    id: int
    amount: Decimal
    currency: str = "USD"
    timestamp: datetime
    description: str

    @field_serializer("amount")
    def serialize_amount(self, value: Decimal) -> str:
        return f"{value:.2f}"

    @field_serializer("timestamp")
    def serialize_timestamp(self, value: datetime) -> str:
        return value.strftime("%Y-%m-%dT%H:%M:%SZ")

    @field_serializer("description")
    def serialize_description(self, value: str) -> str:
        if len(value) > 100:
            return value[:97] + "..."
        return value

tx = Transaction(
    id=1,
    amount=Decimal("1234.5"),
    timestamp=datetime(2024, 3, 15, 14, 30, 0),
    description="Payment for consulting services"
)
print(tx.model_dump())
# amount shows as '1234.50', timestamp as '2024-03-15T14:30:00Z'
```
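field_serializer ties the formatting rule to one model. When several models share a type such as money, an Annotated type with PlainSerializer is one way to reuse the rule; the Money alias below is a hypothetical name. when_used="json" applies the serializer only in JSON mode, so Python-mode dumps keep the Decimal:

```python
from decimal import Decimal
from typing import Annotated

from pydantic import BaseModel, PlainSerializer

# Validated as Decimal; serialized as a 2-decimal string only in JSON mode
Money = Annotated[
    Decimal,
    PlainSerializer(lambda v: f"{v:.2f}", return_type=str, when_used="json"),
]


class Transaction(BaseModel):
    amount: Money
    fee: Money = Decimal("0")


tx = Transaction(amount=Decimal("1234.5"))
print(tx.model_dump())       # {'amount': Decimal('1234.5'), 'fee': Decimal('0')}
print(tx.model_dump_json())  # {"amount":"1234.50","fee":"0.00"}
```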
```python
from pydantic import BaseModel, Field, ConfigDict

class APIItem(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    item_id: int = Field(alias="itemId")
    item_name: str = Field(alias="itemName")
    unit_price: float = Field(alias="unitPrice")
    is_available: bool = Field(alias="isAvailable", default=True)

# Parse from camelCase (API input)
item = APIItem(**{"itemId": 1, "itemName": "Widget", "unitPrice": 9.99})
# Or parse using Python names (with populate_by_name=True)
item2 = APIItem(item_id=2, item_name="Gadget", unit_price=19.99)

# Serialize with Python names (default)
print(item.model_dump())
# {'item_id': 1, 'item_name': 'Widget', ...}

# Serialize with camelCase aliases
print(item.model_dump(by_alias=True))
# {'itemId': 1, 'itemName': 'Widget', ...}
```
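Declaring Field(alias=...) on every field gets repetitive. Pydantic v2 can generate the aliases with alias_generator; this sketch uses the built-in to_camel helper on a hypothetical CamelModel base class whose config the other models inherit:

```python
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class CamelModel(BaseModel):
    # item_id -> itemId, unit_price -> unitPrice, ... for every subclass field
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)


class APIItem(CamelModel):
    item_id: int
    item_name: str
    unit_price: float
    is_available: bool = True


item = APIItem.model_validate({"itemId": 1, "itemName": "Widget", "unitPrice": 9.99})
print(item.model_dump(by_alias=True))
# {'itemId': 1, 'itemName': 'Widget', 'unitPrice': 9.99, 'isAvailable': True}
```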
Now that you understand Pydantic models, let us see how FastAPI integrates them for automatic request validation. FastAPI uses Pydantic to validate request bodies, query parameters, path parameters, and headers.
```python
from fastapi import FastAPI
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List
from datetime import datetime

app = FastAPI()

class CreateProductRequest(BaseModel):
    name: str = Field(min_length=1, max_length=200, description="Product name")
    description: Optional[str] = Field(default=None, max_length=5000)
    price: float = Field(gt=0, le=1_000_000, description="Price in USD")
    quantity: int = Field(ge=0, le=100_000, default=0)
    tags: List[str] = Field(default=[], max_length=20)
    sku: str = Field(pattern=r"^[A-Z]{2,4}-\d{4,8}$")

    @field_validator("tags")
    @classmethod
    def normalize_tags(cls, v):
        return [tag.lower().strip() for tag in v if tag.strip()]

class ProductResponse(BaseModel):
    id: int
    name: str
    description: Optional[str]
    price: float
    quantity: int
    tags: List[str]
    sku: str
    created_at: datetime

@app.post("/products", response_model=ProductResponse, status_code=201)
async def create_product(product: CreateProductRequest):
    # FastAPI automatically:
    # 1. Reads the JSON request body
    # 2. Validates it against CreateProductRequest
    # 3. Returns 422 with details if validation fails
    # 4. Provides the validated model to your function
    return ProductResponse(
        id=1,
        **product.model_dump(),
        created_at=datetime.now()
    )
```

Test with curl:

```bash
curl -X POST http://localhost:8000/products \
  -H "Content-Type: application/json" \
  -d '{
    "name": "Wireless Mouse",
    "price": 29.99,
    "quantity": 100,
    "tags": ["electronics", "accessories"],
    "sku": "WM-12345"
  }'
```
```python
from fastapi import FastAPI, Query
from typing import Optional, List
from enum import Enum

app = FastAPI()

class SortOrder(str, Enum):
    ASC = "asc"
    DESC = "desc"

class SortField(str, Enum):
    NAME = "name"
    PRICE = "price"
    CREATED = "created_at"

@app.get("/products")
async def list_products(
    search: Optional[str] = Query(default=None, min_length=1, max_length=100),
    page: int = Query(default=1, ge=1, le=10000),
    page_size: int = Query(default=20, ge=1, le=100),
    min_price: Optional[float] = Query(default=None, ge=0),
    max_price: Optional[float] = Query(default=None, ge=0),
    sort_by: SortField = Query(default=SortField.CREATED),
    sort_order: SortOrder = Query(default=SortOrder.DESC),
    tags: Optional[List[str]] = Query(default=None, max_length=10),
    in_stock: Optional[bool] = Query(default=None),
):
    # All query params are validated automatically
    # Example: /products?search=mouse&page=1&min_price=10&sort_by=price
    return {
        "search": search,
        "page": page,
        "page_size": page_size,
        "filters": {
            "min_price": min_price,
            "max_price": max_price,
            "sort_by": sort_by,
            "sort_order": sort_order,
            "tags": tags,
            "in_stock": in_stock
        }
    }
```
```python
from fastapi import FastAPI, Path
from enum import Enum
from uuid import UUID

app = FastAPI()

class ResourceType(str, Enum):
    USERS = "users"
    PRODUCTS = "products"
    ORDERS = "orders"

@app.get("/items/{item_id}")
async def get_item(
    item_id: int = Path(..., gt=0, le=1_000_000, description="The item ID")
):
    return {"item_id": item_id}

@app.get("/users/{user_id}/orders/{order_id}")
async def get_user_order(
    user_id: int = Path(..., gt=0),
    order_id: int = Path(..., gt=0),
):
    return {"user_id": user_id, "order_id": order_id}

# UUID path parameter
@app.get("/resources/{resource_id}")
async def get_resource(resource_id: UUID):
    return {"resource_id": str(resource_id)}

# Enum path parameter
@app.get("/api/{resource_type}/{resource_id}")
async def get_any_resource(
    resource_type: ResourceType,
    resource_id: int = Path(..., gt=0)
):
    return {"resource_type": resource_type.value, "resource_id": resource_id}
```
```python
from fastapi import FastAPI, Header, Cookie
from typing import Optional

app = FastAPI()

@app.get("/protected")
async def protected_route(
    authorization: str = Header(..., alias="Authorization"),
    accept_language: str = Header(default="en", alias="Accept-Language"),
    x_request_id: Optional[str] = Header(default=None, alias="X-Request-ID"),
):
    return {
        "auth": authorization[:20] + "...",
        "language": accept_language,
        "request_id": x_request_id,
    }

@app.get("/preferences")
async def get_preferences(
    session_id: Optional[str] = Cookie(default=None),
    theme: str = Cookie(default="light"),
    language: str = Cookie(default="en"),
):
    return {"session_id": session_id, "theme": theme, "language": language}
```
```python
from fastapi import FastAPI, Body
from pydantic import BaseModel, Field

app = FastAPI()

class Item(BaseModel):
    name: str = Field(min_length=1)
    price: float = Field(gt=0)
    description: str = ""

class UserInfo(BaseModel):
    username: str
    email: str

@app.post("/purchase")
async def make_purchase(
    item: Item,
    user: UserInfo,
    quantity: int = Body(ge=1, le=100),
    notes: str = Body(default=""),
):
    # Expected JSON body:
    # {
    #   "item": {"name": "Widget", "price": 9.99},
    #   "user": {"username": "alice", "email": "alice@example.com"},
    #   "quantity": 3,
    #   "notes": "Gift wrap please"
    # }
    total = item.price * quantity
    return {
        "item": item.model_dump(),
        "user": user.model_dump(),
        "quantity": quantity,
        "total": total,
        "notes": notes
    }

# Force single model to use key in body with embed=True
@app.post("/items")
async def create_item(item: Item = Body(embed=True)):
    # With embed=True, expects: {"item": {"name": "Widget", "price": 9.99}}
    return item
```
Response models control what data your API returns. They ensure sensitive fields are never leaked, provide consistent output structures, and generate accurate API documentation.
```python
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime

app = FastAPI()

# Input model - what the client sends
class UserCreate(BaseModel):
    username: str = Field(min_length=3, max_length=30)
    email: str
    password: str = Field(min_length=8)
    full_name: Optional[str] = None

# Output model - what the client receives (no password!)
class UserResponse(BaseModel):
    id: int
    username: str
    email: str
    full_name: Optional[str] = None
    is_active: bool
    created_at: datetime

fake_db = {}
next_id = 1

@app.post("/users", response_model=UserResponse, status_code=201)
async def create_user(user: UserCreate):
    # response_model=UserResponse ensures:
    # 1. Only fields in UserResponse are returned
    # 2. password is NEVER included in the response
    # 3. Swagger docs show the correct response schema
    global next_id
    db_user = {
        "id": next_id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "password_hash": f"hashed_{user.password}",
        "is_active": True,
        "created_at": datetime.now(),
        "internal_notes": "New user"
    }
    fake_db[next_id] = db_user
    next_id += 1
    # Even though db_user has password_hash and internal_notes,
    # response_model filters them out!
    return db_user

@app.get("/users", response_model=List[UserResponse])
async def list_users():
    return list(fake_db.values())
```
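The filtering that response_model performs can be observed without running a server. FastAPI validates the returned dict against the response model, and Pydantic ignores undeclared keys by default (extra="ignore"), so validating the same dict directly shows which keys survive. A sketch with a trimmed UserResponse:

```python
from datetime import datetime

from pydantic import BaseModel


class UserResponse(BaseModel):
    id: int
    username: str
    is_active: bool
    created_at: datetime


db_user = {
    "id": 1,
    "username": "alice",
    "password_hash": "hashed_secret",  # must never reach the client
    "is_active": True,
    "created_at": datetime(2024, 3, 15, 12, 0),
    "internal_notes": "New user",
}

# Undeclared keys are dropped during validation
public = UserResponse.model_validate(db_user).model_dump()
print(sorted(public))  # ['created_at', 'id', 'is_active', 'username']
```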
```python
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime

app = FastAPI()

# Shared base
class ProductBase(BaseModel):
    name: str = Field(min_length=1, max_length=200)
    description: Optional[str] = None
    price: float = Field(gt=0)
    category: str

# Create input
class ProductCreate(ProductBase):
    sku: str = Field(pattern=r"^[A-Z]{2,4}-\d{4,8}$")
    initial_stock: int = Field(ge=0, default=0)

# Update input - all fields optional
class ProductUpdate(BaseModel):
    name: Optional[str] = Field(default=None, min_length=1, max_length=200)
    description: Optional[str] = None
    price: Optional[float] = Field(default=None, gt=0)
    category: Optional[str] = None

# Full response
class ProductResponse(ProductBase):
    id: int
    sku: str
    stock: int
    is_active: bool
    created_at: datetime
    updated_at: datetime

# Summary response (for list views)
class ProductSummary(BaseModel):
    id: int
    name: str
    price: float
    category: str
    in_stock: bool

@app.post("/products", response_model=ProductResponse, status_code=201)
async def create_product(product: ProductCreate):
    return {
        "id": 1,
        **product.model_dump(),
        "stock": product.initial_stock,
        "is_active": True,
        "created_at": datetime.now(),
        "updated_at": datetime.now()
    }

@app.patch("/products/{product_id}", response_model=ProductResponse)
async def update_product(product_id: int, product: ProductUpdate):
    updates = product.model_dump(exclude_unset=True)
    # In real app: apply updates to database
    return {
        "id": product_id,
        "name": updates.get("name", "Existing Product"),
        "description": updates.get("description"),
        "price": updates.get("price", 29.99),
        "category": updates.get("category", "general"),
        "sku": "PROD-1234",
        "stock": 50,
        "is_active": True,
        "created_at": datetime.now(),
        "updated_at": datetime.now()
    }

@app.get("/products", response_model=list[ProductSummary])
async def list_products():
    return [
        {"id": 1, "name": "Wireless Mouse", "price": 29.99,
         "category": "electronics", "in_stock": True},
        {"id": 2, "name": "Keyboard", "price": 59.99,
         "category": "electronics", "in_stock": False}
    ]
```
```python
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

app = FastAPI()

class UserOut(BaseModel):
    id: int
    name: str
    email: str
    bio: Optional[str] = None
    avatar: Optional[str] = None
    settings: dict = {}

# Exclude None values from response
@app.get("/users/{user_id}", response_model=UserOut,
         response_model_exclude_none=True)
async def get_user(user_id: int):
    return {
        "id": user_id, "name": "Alice", "email": "alice@example.com",
        "bio": None, "avatar": None, "settings": {}
    }
# Response: {"id": 1, "name": "Alice", "email": "alice@example.com", "settings": {}}

# Exclude specific fields
@app.get("/users/{user_id}/public", response_model=UserOut,
         response_model_exclude={"email", "settings"})
async def get_public_user(user_id: int):
    return {
        "id": user_id, "name": "Alice", "email": "alice@example.com",
        "bio": "Developer", "settings": {"theme": "dark"}
    }
# Response: {"id": 1, "name": "Alice", "bio": "Developer", "avatar": null}

# Include only specific fields
@app.get("/users/{user_id}/summary", response_model=UserOut,
         response_model_include={"id", "name"})
async def get_user_summary(user_id: int):
    return {"id": user_id, "name": "Alice", "email": "alice@example.com"}
# Response: {"id": 1, "name": "Alice"}
```
When validation fails, FastAPI automatically returns a 422 Unprocessable Entity response with detailed error information. You can customize this behavior to return errors in your preferred format.
```python
from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

class UserCreate(BaseModel):
    name: str = Field(min_length=2, max_length=50)
    age: int = Field(ge=0, le=150)
    email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w{2,}$")

@app.post("/users")
async def create_user(user: UserCreate):
    return user

# Sending invalid data returns 422 with:
# {
#   "detail": [
#     {
#       "type": "string_too_short",
#       "loc": ["body", "name"],
#       "msg": "String should have at least 2 characters",
#       "input": "A",
#       "ctx": {"min_length": 2}
#     },
#     ...
#   ]
# }
```
```python
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from datetime import datetime

app = FastAPI()

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: Request,
    exc: RequestValidationError
):
    errors = []
    for error in exc.errors():
        field_path = " -> ".join(
            str(loc) for loc in error["loc"] if loc != "body"
        )
        errors.append({
            "field": field_path,
            "message": error["msg"],
            "type": error["type"],
            "value": error.get("input"),
        })
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={
            "status": "error",
            "message": "Validation failed",
            "error_count": len(errors),
            "errors": errors,
            "timestamp": datetime.now().isoformat(),
            "path": str(request.url)
        }
    )
```
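Reshaping exc.errors() is easier to unit-test as a plain function. The sketch below pulls the loop from the handler above into a hypothetical flatten_errors() helper and feeds it errors from a Pydantic ValidationError, adding the "body" prefix that FastAPI puts on request-body locations (request_errors() is also hypothetical):

```python
from pydantic import BaseModel, Field, ValidationError


class UserCreate(BaseModel):
    name: str = Field(min_length=2, max_length=50)
    age: int = Field(ge=0, le=150)


def flatten_errors(errors: list) -> list:
    # Same reshaping as the handler, without building a JSONResponse
    flat = []
    for error in errors:
        field_path = " -> ".join(str(loc) for loc in error["loc"] if loc != "body")
        flat.append({"field": field_path, "message": error["msg"], "type": error["type"]})
    return flat


def request_errors(payload: dict) -> list:
    # Emulate FastAPI, which reports body errors with a leading "body" in loc
    try:
        UserCreate.model_validate(payload)
    except ValidationError as exc:
        return [{**e, "loc": ("body", *e["loc"])} for e in exc.errors()]
    return []


result = flatten_errors(request_errors({"name": "A", "age": -1}))
for item in result:
    print(item)
```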
```python
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field

app = FastAPI()

# Simulated database
existing_usernames = {"alice", "bob", "admin"}
existing_emails = {"alice@example.com", "bob@example.com"}

class UserCreate(BaseModel):
    username: str = Field(min_length=3, max_length=30)
    email: str
    password: str = Field(min_length=8)

@app.post("/users", status_code=201)
async def create_user(user: UserCreate):
    # Pydantic validates structure; we validate business rules
    errors = []
    if user.username.lower() in existing_usernames:
        errors.append({
            "field": "username",
            "message": f"Username '{user.username}' is already taken"
        })
    if user.email.lower() in existing_emails:
        errors.append({
            "field": "email",
            "message": f"Email '{user.email}' is already registered"
        })
    if errors:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "status": "error",
                "message": "Registration failed",
                "errors": errors
            }
        )
    return {"status": "success", "message": f"User '{user.username}' created"}
```
```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from datetime import datetime
from typing import Optional
import logging

logger = logging.getLogger(__name__)
app = FastAPI()

class ErrorResponse:
    @staticmethod
    def build(status_code: int, message: str, errors: Optional[list] = None,
              path: Optional[str] = None) -> dict:
        return {
            "status": "error",
            "status_code": status_code,
            "message": message,
            "errors": errors or [],
            "timestamp": datetime.now().isoformat(),
            "path": path
        }

# Handle Pydantic validation errors (422)
@app.exception_handler(RequestValidationError)
async def request_validation_handler(request: Request, exc: RequestValidationError):
    errors = []
    for error in exc.errors():
        loc = [str(l) for l in error["loc"] if l != "body"]
        errors.append({
            "field": ".".join(loc) if loc else "unknown",
            "message": error["msg"],
            "type": error["type"]
        })
    return JSONResponse(
        status_code=422,
        content=ErrorResponse.build(422, "Request validation failed",
                                    errors, str(request.url))
    )

# Handle HTTP exceptions (400, 401, 403, 404, etc.)
# Registering Starlette's HTTPException also catches errors raised by the
# framework itself (e.g. 404 for unknown routes); fastapi.HTTPException subclasses it.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse.build(
            exc.status_code,
            exc.detail if isinstance(exc.detail, str) else "HTTP Error",
            [exc.detail] if isinstance(exc.detail, dict) else [],
            str(request.url)
        ),
        headers=getattr(exc, "headers", None),  # keep e.g. WWW-Authenticate
    )

# Handle unexpected exceptions (500)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    # exc_info=exc writes the full traceback to the log
    logger.error("Unhandled error on %s", request.url.path, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse.build(500, "An unexpected error occurred",
                                    path=str(request.url))
    )
```
Generic models let you create reusable response wrappers, such as paginated responses, API envelopes, or standardized result types, that work with any data model. Defining the wrapper once and parametrizing it per endpoint (for example PaginatedResponse[UserOut]) keeps response shapes consistent across an API.
```python
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional, List

T = TypeVar("T")

class APIResponse(BaseModel, Generic[T]):
    success: bool = True
    message: str = "OK"
    data: Optional[T] = None

class UserOut(BaseModel):
    id: int
    name: str
    email: str

class ProductOut(BaseModel):
    id: int
    name: str
    price: float

# Use with different types
user_response = APIResponse[UserOut](
    data=UserOut(id=1, name="Alice", email="alice@example.com")
)
print(user_response.model_dump())
# {'success': True, 'message': 'OK',
#  'data': {'id': 1, 'name': 'Alice', 'email': 'alice@example.com'}}

product_response = APIResponse[ProductOut](
    data=ProductOut(id=1, name="Laptop", price=999.99)
)
print(product_response.model_dump())

# Error response
error_response = APIResponse[None](
    success=False, message="User not found", data=None
)
print(error_response.model_dump())
```
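Parametrizing a generic model also changes validation: APIResponse[UserOut] validates data as a UserOut, coercing plain dicts and reporting errors with the nested location. A short sketch:

```python
from typing import Generic, Optional, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T")


class APIResponse(BaseModel, Generic[T]):
    success: bool = True
    message: str = "OK"
    data: Optional[T] = None


class UserOut(BaseModel):
    id: int
    name: str


# data is validated as UserOut; the string "7" is coerced to int in lax mode
resp = APIResponse[UserOut].model_validate({"data": {"id": "7", "name": "Alice"}})
print(type(resp.data).__name__, resp.data.id)  # UserOut 7

# Invalid nested data is rejected with the full location of the problem
try:
    APIResponse[UserOut].model_validate({"data": {"id": "abc", "name": "Bob"}})
    error_loc = None
except ValidationError as exc:
    error_loc = exc.errors()[0]["loc"]
print(error_loc)  # ('data', 'id')
```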
```python
from pydantic import BaseModel, Field, computed_field
from typing import Generic, TypeVar, List, Optional
from math import ceil

T = TypeVar("T")

class PaginatedResponse(BaseModel, Generic[T]):
    items: List[T]
    total: int = Field(ge=0)
    page: int = Field(ge=1)
    page_size: int = Field(ge=1, le=100)

    @computed_field
    @property
    def total_pages(self) -> int:
        return ceil(self.total / self.page_size) if self.page_size else 0

    @computed_field
    @property
    def has_next(self) -> bool:
        return self.page < self.total_pages

    @computed_field
    @property
    def has_previous(self) -> bool:
        return self.page > 1

    @computed_field
    @property
    def next_page(self) -> Optional[int]:
        return self.page + 1 if self.has_next else None

    @computed_field
    @property
    def previous_page(self) -> Optional[int]:
        return self.page - 1 if self.has_previous else None

class UserOut(BaseModel):
    id: int
    name: str
    email: str

# Usage in FastAPI
from fastapi import FastAPI, Query

app = FastAPI()

@app.get("/users", response_model=PaginatedResponse[UserOut])
async def list_users(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=100)
):
    all_users = [
        {"id": i, "name": f"User {i}", "email": f"user{i}@example.com"}
        for i in range(1, 96)
    ]
    total = len(all_users)
    start = (page - 1) * page_size
    end = start + page_size
    items = all_users[start:end]
    return PaginatedResponse[UserOut](
        items=items, total=total, page=page, page_size=page_size
    )

# Response for page=2, page_size=20:
# {
#   "items": [...],
#   "total": 95,
#   "page": 2,
#   "page_size": 20,
#   "total_pages": 5,
#   "has_next": true,
#   "has_previous": true,
#   "next_page": 3,
#   "previous_page": 1
# }
```
```python
from pydantic import BaseModel, Field
from typing import Generic, TypeVar, Optional, List
from datetime import datetime

T = TypeVar("T")

class ErrorDetail(BaseModel):
    field: Optional[str] = None
    message: str
    code: str

class APIEnvelope(BaseModel, Generic[T]):
    success: bool = True
    status_code: int = 200
    message: str = "OK"
    data: Optional[T] = None
    errors: List[ErrorDetail] = []
    meta: dict = {}
    timestamp: datetime = Field(default_factory=datetime.now)

    @classmethod
    def ok(cls, data: T, message: str = "OK", meta: dict = None):
        return cls(data=data, message=message, meta=meta or {})

    @classmethod
    def error(cls, status_code: int, message: str,
              errors: List[ErrorDetail] = None):
        return cls(
            success=False, status_code=status_code,
            message=message, errors=errors or []
        )

    @classmethod
    def not_found(cls, resource: str, id):
        return cls.error(404, f"{resource} with id {id} not found")

# Usage in FastAPI
from fastapi import FastAPI, HTTPException

app = FastAPI()

class UserOut(BaseModel):
    id: int
    name: str
    email: str

@app.get("/users/{user_id}", response_model=APIEnvelope[UserOut])
async def get_user(user_id: int):
    if user_id == 1:
        return APIEnvelope.ok(
            data=UserOut(id=1, name="Alice", email="alice@example.com"),
            meta={"cache": "hit"}
        )
    raise HTTPException(status_code=404, detail=f"User {user_id} not found")

@app.get("/users", response_model=APIEnvelope[List[UserOut]])
async def list_users():
    users = [
        UserOut(id=1, name="Alice", email="alice@example.com"),
        UserOut(id=2, name="Bob", email="bob@example.com"),
    ]
    return APIEnvelope.ok(data=users, meta={"count": len(users)})
```
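FastAPI handles file uploads and HTML form data differently from JSON bodies. Form fields are declared as Form() parameters and files as File() or UploadFile parameters, and Form() accepts the same validation constraints as Query() and Body(). In FastAPI versions before 0.113.0 you cannot use a Pydantic BaseModel directly for form fields; from 0.113.0 on, a model can also be declared as a form model with Annotated[Model, Form()].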
```python
from fastapi import FastAPI, Form, HTTPException

app = FastAPI()

@app.post("/login")
async def login(
    username: str = Form(..., min_length=3, max_length=50),
    password: str = Form(..., min_length=8),
    remember_me: bool = Form(default=False),
):
    # Accepts application/x-www-form-urlencoded or multipart/form-data
    if username == "admin" and password == "Admin!234":
        return {"status": "success", "username": username}
    raise HTTPException(status_code=401, detail="Invalid credentials")

@app.post("/contact")
async def submit_contact_form(
    name: str = Form(..., min_length=2, max_length=100),
    email: str = Form(..., pattern=r"^[\w\.-]+@[\w\.-]+\.\w{2,}$"),
    subject: str = Form(..., min_length=5, max_length=200),
    message: str = Form(..., min_length=20, max_length=5000),
    category: str = Form(default="general"),
    priority: int = Form(default=3, ge=1, le=5),
):
    return {
        "status": "submitted",
        "contact": {
            "name": name,
            "email": email,
            "subject": subject,
            "message_length": len(message),
            "category": category,
            "priority": priority,
        }
    }
```
```python
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from typing import List

app = FastAPI()

MAX_FILE_SIZE = 5 * 1024 * 1024  # 5 MB
ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
ALLOWED_DOC_TYPES = {"application/pdf", "text/plain", "text/csv"}

def validate_image(file: UploadFile) -> None:
    # content_type comes from the client's multipart headers, so treat it as a hint
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Type '{file.content_type}' not allowed. "
                   f"Allowed: {', '.join(sorted(ALLOWED_IMAGE_TYPES))}"
        )

@app.post("/upload/avatar")
async def upload_avatar(file: UploadFile = File(...)):
    validate_image(file)
    contents = await file.read()
    size = len(contents)
    if size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Max: {MAX_FILE_SIZE // (1024*1024)}MB"
        )
    await file.seek(0)
    return {
        "filename": file.filename,
        "content_type": file.content_type,
        "size_bytes": size,
        "size_mb": round(size / (1024 * 1024), 2)
    }

@app.post("/upload/documents")
async def upload_documents(
    files: List[UploadFile] = File(..., description="Max 10 files")
):
    if len(files) > 10:
        raise HTTPException(400, "Maximum 10 files allowed")
    results = []
    errors = []
    allowed = ALLOWED_IMAGE_TYPES | ALLOWED_DOC_TYPES
    for file in files:
        if file.content_type not in allowed:
            errors.append({"file": file.filename, "error": "Type not allowed"})
            continue
        contents = await file.read()
        if len(contents) > MAX_FILE_SIZE:
            errors.append({"file": file.filename, "error": "Too large"})
            continue
        results.append({
            "filename": file.filename,
            "content_type": file.content_type,
            "size_bytes": len(contents)
        })
    return {"uploaded": len(results), "failed": len(errors),
            "files": results, "errors": errors}
```
```python
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from typing import List

app = FastAPI()

@app.post("/products")
async def create_product_with_image(
    # Form fields
    name: str = Form(..., min_length=1, max_length=200),
    description: str = Form(default="", max_length=5000),
    price: float = Form(..., gt=0),
    category: str = Form(..., min_length=1),
    tags: str = Form(default=""),  # comma-separated
    # File fields
    main_image: UploadFile = File(...),
    additional_images: List[UploadFile] = File(default=[]),
):
    # When mixing Form and File, request must use multipart/form-data
    if main_image.content_type not in {"image/jpeg", "image/png", "image/webp"}:
        raise HTTPException(400, "Main image must be JPEG, PNG, or WebP")
    if len(additional_images) > 5:
        raise HTTPException(400, "Maximum 5 additional images")
    tag_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else []
    main_size = len(await main_image.read())
    additional_info = []
    for img in additional_images:
        content = await img.read()
        additional_info.append({
            "filename": img.filename, "size": len(content)
        })
    return {
        "product": {
            "name": name, "description": description,
            "price": price, "category": category, "tags": tag_list
        },
        "images": {
            "main": {"filename": main_image.filename, "size": main_size},
            "additional": additional_info
        }
    }
```
Test with curl:

```bash
curl -X POST http://localhost:8000/products \
  -F "name=Wireless Mouse" \
  -F "price=29.99" \
  -F "category=electronics" \
  -F "tags=sale,featured" \
  -F "main_image=@mouse.jpg" \
  -F "additional_images=@mouse_side.jpg"
```
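The upload handlers above call await file.read() and only then compare the size, so each file is fully loaded into memory before it can be rejected. One alternative is to read in chunks and stop as soon as a limit is exceeded. The read_limited helper, the TooLarge exception, and the FakeUpload stand-in below are hypothetical; UploadFile provides the same async read(size) method. Starlette has already spooled the upload to a temporary file before the endpoint runs, so this bounds the handler's memory use rather than the request size, which is usually capped at the server or proxy:

```python
import asyncio
import io


class TooLarge(Exception):
    """Hypothetical error; a route would translate it into an HTTP 413 response."""


async def read_limited(file, max_bytes: int, chunk_size: int = 64 * 1024) -> bytes:
    # Read chunk by chunk; stop once the running total exceeds max_bytes
    buffer = bytearray()
    while True:
        chunk = await file.read(chunk_size)
        if not chunk:  # empty bytes means end of file
            break
        buffer.extend(chunk)
        if len(buffer) > max_bytes:
            raise TooLarge(f"upload exceeds {max_bytes} bytes")
    return bytes(buffer)


class FakeUpload:
    # Stand-in exposing the same async read(size) method as UploadFile
    def __init__(self, data: bytes):
        self._buf = io.BytesIO(data)

    async def read(self, size: int = -1) -> bytes:
        return self._buf.read(size)


small = asyncio.run(read_limited(FakeUpload(b"x" * 100), max_bytes=1024, chunk_size=32))
print(len(small))  # 100

try:
    asyncio.run(read_limited(FakeUpload(b"x" * 5000), max_bytes=1024, chunk_size=512))
    rejected = False
except TooLarge as exc:
    rejected = True
    print(exc)  # upload exceeds 1024 bytes
```

Inside a route, the call would be contents = await read_limited(file, MAX_FILE_SIZE), with TooLarge caught and re-raised as an HTTPException with status 413.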
Let us bring everything together in a comprehensive user registration API with multi-step validation, address handling, password strength checking, and email verification.
```
registration_system/
├── main.py              # FastAPI application and routes
├── models/
│   ├── __init__.py
│   ├── user.py          # User Pydantic models
│   ├── address.py       # Address models with validation
│   ├── common.py        # Shared/generic models
│   └── validators.py    # Reusable validators
├── services/
│   ├── __init__.py
│   └── user_service.py  # Business logic
└── requirements.txt
```
# models/validators.py
import re
from typing import Annotated
from pydantic import AfterValidator
def validate_password_strength(password: str) -> str:
    if len(password) < 8:
        raise ValueError("Password must be at least 8 characters")
    if len(password) > 128:
        raise ValueError("Password must not exceed 128 characters")
    if not re.search(r"[A-Z]", password):
        raise ValueError("Must contain at least one uppercase letter")
    if not re.search(r"[a-z]", password):
        raise ValueError("Must contain at least one lowercase letter")
    if not re.search(r"\d", password):
        raise ValueError("Must contain at least one digit")
    if not re.search(r"[!@#$%^&*()_+\-=\[\]{}|;:,./?\\\"]", password):
        raise ValueError("Must contain at least one special character")
    # Substring check: an exact-match test could never fire here, because the
    # rules above already reject bare strings such as "password" or "qwerty".
    common_patterns = ["password", "12345678", "qwerty", "abc123"]
    lowered = password.lower()
    if any(pattern in lowered for pattern in common_patterns):
        raise ValueError("Password contains a common pattern")
    return password
def validate_phone_number(phone: str) -> str:
    cleaned = phone.strip()
    # Keep a leading "+" and strip every other non-digit character
    if cleaned.startswith("+"):
        digits = "+" + re.sub(r"\D", "", cleaned[1:])
    else:
        digits = re.sub(r"\D", "", cleaned)
    # Without a country code, assume a US/NANP number
    if not digits.startswith("+"):
        if len(digits) == 10:
            digits = "+1" + digits
        elif len(digits) == 11 and digits.startswith("1"):
            digits = "+" + digits
        else:
            raise ValueError("Invalid phone number format")
    digit_count = len(digits.replace("+", ""))
    if digit_count < 10 or digit_count > 15:
        raise ValueError("Phone number must be 10-15 digits")
    return digits
def validate_username(username: str) -> str:
    username = username.strip().lower()
    if len(username) < 3:
        raise ValueError("Username must be at least 3 characters")
    if len(username) > 30:
        raise ValueError("Username must not exceed 30 characters")
    if not re.match(r"^[a-z][a-z0-9_]*$", username):
        raise ValueError("Must start with letter; only lowercase, numbers, underscores")
    if "__" in username:
        raise ValueError("No consecutive underscores")
    reserved = {"admin", "root", "system", "api", "null", "undefined", "support"}
    if username in reserved:
        raise ValueError(f"Username '{username}' is reserved")
    return username
# Create reusable annotated types
StrongPassword = Annotated[str, AfterValidator(validate_password_strength)]
PhoneNumber = Annotated[str, AfterValidator(validate_phone_number)]
Username = Annotated[str, AfterValidator(validate_username)]
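Username is an ordinary Annotated type, so you can attach the same rule to any number of models. You can also validate it on its own with Pydantic v2's TypeAdapter. A minimal sketch follows. It redeclares a trimmed copy of validate_username so the block runs standalone, and SignupForm and RenameRequest are hypothetical models used only for illustration.

```python
import re
from typing import Annotated

from pydantic import AfterValidator, BaseModel, TypeAdapter, ValidationError


def validate_username(username: str) -> str:
    # Trimmed copy of the validator above: normalize, then check the format.
    username = username.strip().lower()
    if not 3 <= len(username) <= 30:
        raise ValueError("Username must be 3-30 characters")
    if not re.match(r"^[a-z][a-z0-9_]*$", username):
        raise ValueError("Must start with letter; only lowercase, numbers, underscores")
    if "__" in username:
        raise ValueError("No consecutive underscores")
    return username


Username = Annotated[str, AfterValidator(validate_username)]


# The same annotated type reused in two unrelated (hypothetical) models.
class SignupForm(BaseModel):
    username: Username


class RenameRequest(BaseModel):
    new_username: Username


# TypeAdapter validates a bare type without wrapping it in a model.
username_adapter = TypeAdapter(Username)

print(username_adapter.validate_python("  Alice_Dev "))  # alice_dev
print(RenameRequest(new_username="Bob_42").new_username)  # bob_42

try:
    SignupForm(username="bad__name")
except ValidationError as exc:
    # Pydantic v2 prefixes messages from ValueError with "Value error, "
    print(exc.errors()[0]["msg"])  # Value error, No consecutive underscores
```

The rule lives in one function, so changing the username policy later means editing one place rather than every model.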
# models/address.py
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional, Literal
US_STATES = {
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
    "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
    "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
    "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
    "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
    "DC", "PR", "VI", "GU", "AS", "MP"
}
class Address(BaseModel):
    street_line_1: str = Field(min_length=1, max_length=200)
    street_line_2: Optional[str] = Field(default=None, max_length=200)
    city: str = Field(min_length=1, max_length=100)
    state: str = Field(min_length=2, max_length=2)
    zip_code: str = Field(pattern=r"^\d{5}(-\d{4})?$")
    country: Literal["US"] = "US"
    @field_validator("state")
    @classmethod
    def validate_state(cls, v: str) -> str:
        v = v.upper().strip()
        if v not in US_STATES:
            raise ValueError(f"Invalid US state code: {v}")
        return v
    @field_validator("city")
    @classmethod
    def normalize_city(cls, v: str) -> str:
        return v.strip().title()
class AddressCreate(Address):
    address_type: Literal["home", "work", "billing", "shipping"] = "home"
    is_primary: bool = False
    @model_validator(mode="after")
    def validate_po_box(self):
        # Runs after every field has passed validation, so both fields are safe to read
        if self.address_type == "shipping":
            street = self.street_line_1.lower()
            if "po box" in street or "p.o. box" in street:
                raise ValueError("Shipping address cannot be a PO Box")
        return self
# models/user.py
from pydantic import (
    BaseModel, Field, ConfigDict, field_validator,
    model_validator, computed_field, AfterValidator
)
from typing import Optional, List, Annotated
from datetime import date, datetime
from enum import Enum
import re
# Inline password validator so this module runs without validators.py
def _validate_password(p: str) -> str:
    if len(p) < 8:
        raise ValueError("Password must be at least 8 characters")
    if not re.search(r"[A-Z]", p):
        raise ValueError("Must contain uppercase letter")
    if not re.search(r"[a-z]", p):
        raise ValueError("Must contain lowercase letter")
    if not re.search(r"\d", p):
        raise ValueError("Must contain digit")
    if not re.search(r"[!@#$%^&*()\-_=+]", p):
        raise ValueError("Must contain special character")
    return p
StrongPassword = Annotated[str, AfterValidator(_validate_password)]
class Gender(str, Enum):
    MALE = "male"
    FEMALE = "female"
    NON_BINARY = "non_binary"
    PREFER_NOT_TO_SAY = "prefer_not_to_say"
class NotificationPreference(str, Enum):
    EMAIL = "email"
    SMS = "sms"
    PUSH = "push"
    NONE = "none"
class UserRegistration(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
    # Account Info
    username: str = Field(
        min_length=3, max_length=30,
        pattern=r"^[a-zA-Z][a-zA-Z0-9_]*$"
    )
    email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w{2,}$")
    password: StrongPassword
    confirm_password: str
    # Personal Info
    first_name: str = Field(min_length=1, max_length=50)
    last_name: str = Field(min_length=1, max_length=50)
    date_of_birth: date
    gender: Optional[Gender] = None
    phone: Optional[str] = Field(default=None, pattern=r"^\+?1?\d{10,15}$")
    # Preferences
    notifications: List[NotificationPreference] = Field(
        default=[NotificationPreference.EMAIL]
    )
    newsletter: bool = False
    terms_accepted: bool
    @field_validator("username")
    @classmethod
    def normalize_username(cls, v: str) -> str:
        return v.lower()
    @field_validator("email")
    @classmethod
    def normalize_email(cls, v: str) -> str:
        return v.lower()
    @field_validator("first_name", "last_name")
    @classmethod
    def capitalize_name(cls, v: str) -> str:
        return v.strip().title()
    @field_validator("date_of_birth")
    @classmethod
    def validate_age(cls, v: date) -> date:
        today = date.today()
        age = today.year - v.year - ((today.month, today.day) < (v.month, v.day))
        if age < 13:
            raise ValueError("Must be at least 13 years old")
        if age > 120:
            raise ValueError("Invalid date of birth")
        return v
    @field_validator("terms_accepted")
    @classmethod
    def must_accept_terms(cls, v: bool) -> bool:
        if not v:
            raise ValueError("You must accept the terms and conditions")
        return v
    @model_validator(mode="after")
    def validate_registration(self):
        # Cross-field checks; this runs only after every field has validated
        if self.password != self.confirm_password:
            raise ValueError("Passwords do not match")
        if self.username.lower() in self.password.lower():
            raise ValueError("Password must not contain your username")
        if NotificationPreference.SMS in self.notifications and not self.phone:
            raise ValueError("Phone number required for SMS notifications")
        return self
    @computed_field
    @property
    def display_name(self) -> str:
        return f"{self.first_name} {self.last_name}"
class UserResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: int
    username: str
    email: str
    first_name: str
    last_name: str
    display_name: str
    date_of_birth: date
    gender: Optional[Gender]
    phone: Optional[str]
    notifications: List[NotificationPreference]
    is_active: bool
    is_verified: bool
    created_at: datetime
class UserUpdate(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
    first_name: Optional[str] = Field(default=None, min_length=1, max_length=50)
    last_name: Optional[str] = Field(default=None, min_length=1, max_length=50)
    phone: Optional[str] = Field(default=None, pattern=r"^\+?1?\d{10,15}$")
    gender: Optional[Gender] = None
    notifications: Optional[List[NotificationPreference]] = None
    newsletter: Optional[bool] = None
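The age check in validate_age relies on Python treating True as 1. The expression today.year - v.year overcounts by one year whenever this year's birthday has not arrived yet, and the tuple comparison (today.month, today.day) < (v.month, v.day) is True in exactly that case. The small standard-library sketch below makes today an explicit parameter so the results are reproducible. The age_on name is hypothetical; the formula is the one used above.

```python
from datetime import date


def age_on(birth: date, today: date) -> int:
    """Completed years between birth and today (same formula as validate_age)."""
    birthday_not_reached = (today.month, today.day) < (birth.month, birth.day)
    # bool is a subclass of int, so True subtracts exactly one year.
    return today.year - birth.year - birthday_not_reached


print(age_on(date(1995, 3, 15), date(2025, 3, 14)))  # 29 (birthday is tomorrow)
print(age_on(date(1995, 3, 15), date(2025, 3, 15)))  # 30 (birthday is today)
print(age_on(date(2012, 6, 1), date(2025, 5, 31)))   # 12 (would fail the 13+ rule)
```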
# main.py
from fastapi import FastAPI, HTTPException, Request, Query, Path, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from datetime import datetime
from typing import Optional
import secrets
import re
from models.user import UserRegistration  # assumes the project layout shown above
app = FastAPI(
    title="User Registration System",
    description="Complete registration with Pydantic validation",
    version="1.0.0"
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    errors = []
    for error in exc.errors():
        # Drop the leading "body" segment so clients see plain field paths
        loc = [str(part) for part in error["loc"] if part != "body"]
        errors.append({
            "field": ".".join(loc) if loc else "general",
            "message": error["msg"],
            "type": error["type"]
        })
    return JSONResponse(
        status_code=422,
        content={
            "success": False, "message": "Validation failed",
            "errors": errors, "timestamp": datetime.now().isoformat()
        }
    )
# Simulated Database
users_db: dict = {}
next_user_id = 1
email_verifications: dict = {}
@app.post("/api/v1/register", status_code=status.HTTP_201_CREATED)
async def register_user(registration: UserRegistration):
    # The model annotation is what makes FastAPI read and validate the JSON body
    global next_user_id
    for user in users_db.values():
        if user["username"] == registration.username:
            raise HTTPException(409, {"field": "username", "message": "Already taken"})
        if user["email"] == registration.email:
            raise HTTPException(409, {"field": "email", "message": "Already registered"})
    user_data = {
        "id": next_user_id,
        "username": registration.username,
        "email": registration.email,
        "first_name": registration.first_name,
        "last_name": registration.last_name,
        "display_name": registration.display_name,
        "date_of_birth": registration.date_of_birth.isoformat(),
        "gender": registration.gender.value if registration.gender else None,
        "phone": registration.phone,
        "notifications": [n.value for n in registration.notifications],
        "is_active": True,
        "is_verified": False,
        "created_at": datetime.now().isoformat()
    }
    users_db[next_user_id] = user_data
    token = secrets.token_urlsafe(32)
    email_verifications[token] = next_user_id
    next_user_id += 1
    return {
        "success": True,
        "message": "Registration successful! Check your email to verify.",
        "data": {
            "id": user_data["id"],
            "username": user_data["username"],
            "email": user_data["email"],
            "display_name": user_data["display_name"]
        }
    }
@app.get("/api/v1/verify-email/{token}")
async def verify_email(token: str):
    user_id = email_verifications.get(token)
    if not user_id or user_id not in users_db:
        raise HTTPException(400, "Invalid or expired verification token")
    users_db[user_id]["is_verified"] = True
    del email_verifications[token]  # tokens are single-use
    return {"success": True, "message": "Email verified successfully!"}
@app.get("/api/v1/users/{user_id}")
async def get_user(user_id: int = Path(..., gt=0)):
    user = users_db.get(user_id)
    if not user:
        raise HTTPException(404, f"User {user_id} not found")
    return {"success": True, "data": user}
@app.get("/api/v1/users")
async def list_users(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=100),
    search: Optional[str] = Query(default=None, min_length=1),
):
    filtered = list(users_db.values())
    if search:
        s = search.lower()
        filtered = [u for u in filtered
                    if s in u["username"] or s in u["email"]
                    or s in u["display_name"].lower()]
    total = len(filtered)
    start = (page - 1) * page_size
    items = filtered[start:start + page_size]
    return {
        "success": True,
        "data": {
            "items": items, "total": total,
            "page": page, "page_size": page_size,
            "total_pages": -(-total // page_size)  # ceiling division
        }
    }
@app.get("/api/v1/check-username/{username}")
async def check_username(username: str = Path(..., min_length=3, max_length=30)):
    taken = any(u["username"] == username.lower() for u in users_db.values())
    return {"username": username.lower(), "available": not taken}
@app.get("/api/v1/password-strength")
async def check_password_strength(
    password: str = Query(..., min_length=1)
):
    # Demo only: query strings are often written to access and proxy logs,
    # so a production version should accept the password in a POST body.
    score = 0
    feedback = []
    if len(password) >= 8: score += 1
    else: feedback.append("Use at least 8 characters")
    if len(password) >= 12: score += 1
    if len(password) >= 16: score += 1
    if re.search(r"[A-Z]", password): score += 1
    else: feedback.append("Add an uppercase letter")
    if re.search(r"[a-z]", password): score += 1
    else: feedback.append("Add a lowercase letter")
    if re.search(r"\d", password): score += 1
    else: feedback.append("Add a number")
    if re.search(r"[!@#$%^&*()\-_=+]", password): score += 1
    else: feedback.append("Add a special character")
    if score <= 2: strength = "weak"
    elif score <= 4: strength = "fair"
    elif score <= 5: strength = "good"
    else: strength = "strong"
    return {
        "score": score, "max_score": 7, "strength": strength,
        "feedback": feedback, "meets_requirements": score >= 4
    }
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
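UserUpdate is defined in models/user.py, but main.py has no endpoint that uses it. The pattern recommended in the summary is to apply only the fields the client actually sent, using model_dump(exclude_unset=True). The sketch below shows that merge step without any FastAPI dependency. UserUpdateDemo is a trimmed, hypothetical copy of UserUpdate, and apply_update is a hypothetical helper that a PATCH handler could call.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class UserUpdateDemo(BaseModel):
    # Trimmed copy of UserUpdate: every field optional, unknown keys rejected.
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
    first_name: Optional[str] = Field(default=None, min_length=1, max_length=50)
    phone: Optional[str] = Field(default=None, pattern=r"^\+?1?\d{10,15}$")
    newsletter: Optional[bool] = None


def apply_update(stored: dict, update: UserUpdateDemo) -> dict:
    """Return a copy of stored with only the client-supplied fields replaced."""
    changes = update.model_dump(exclude_unset=True)
    return {**stored, **changes}


stored_user = {"id": 1, "first_name": "Alice", "phone": "+15551234567", "newsletter": True}

# The client sends only first_name; phone and newsletter are left untouched.
patched = apply_update(stored_user, UserUpdateDemo(first_name="Ally"))
print(patched)

# An explicit null IS applied, because the client set the field (to None).
cleared = apply_update(stored_user, UserUpdateDemo.model_validate({"phone": None}))
print(cleared["phone"])  # None
```

In main.py, this logic would sit behind a route such as @app.patch("/api/v1/users/{user_id}") and operate on users_db[user_id]. If clients should not be able to clear fields with null, combine exclude_none=True with exclude_unset=True. If the stored dict should hold plain strings for enum fields such as gender, use model_dump(mode="json", exclude_unset=True).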
# Start the server
uvicorn main:app --reload
# 1. Check username availability
curl http://localhost:8000/api/v1/check-username/alice_dev
# 2. Check password strength (single quotes stop bash from treating "!" as history expansion)
curl 'http://localhost:8000/api/v1/password-strength?password=Str0ng!Pass'
# 3. Register a new user
curl -X POST http://localhost:8000/api/v1/register \
-H "Content-Type: application/json" \
-d '{
"username": "alice_dev",
"email": "alice@example.com",
"password": "Str0ng!Pass#1",
"confirm_password": "Str0ng!Pass#1",
"first_name": "Alice",
"last_name": "Johnson",
"date_of_birth": "1995-03-15",
"gender": "female",
"phone": "+15551234567",
"notifications": ["email", "sms"],
"newsletter": true,
"terms_accepted": true
}'
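# 4. Try an invalid registration (several field errors are reported at once; the model-level password-match check is skipped because field validation already failed)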
curl -X POST http://localhost:8000/api/v1/register \
-H "Content-Type: application/json" \
-d '{
"username": "a",
"email": "not-email",
"password": "weak",
"confirm_password": "different",
"first_name": "",
"last_name": "Johnson",
"date_of_birth": "2020-01-01",
"terms_accepted": false
}'
# 5. List users
curl "http://localhost:8000/api/v1/users?page=1&page_size=10"
# 6. Get specific user
curl http://localhost:8000/api/v1/users/1
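Request 4 exercises the custom RequestValidationError handler in main.py, which reshapes Pydantic's error list before returning it. The standard-library sketch below pulls that reshaping step into a hypothetical flatten_errors function so you can check it without starting the server. The sample input uses the loc/msg/type keys that exc.errors() provides, but the individual entries are made up for illustration.

```python
def flatten_errors(raw_errors: list[dict]) -> list[dict]:
    """Same transformation as validation_exception_handler in main.py."""
    errors = []
    for error in raw_errors:
        # Drop the leading "body" segment so clients see plain field paths.
        loc = [str(part) for part in error["loc"] if part != "body"]
        errors.append({
            "field": ".".join(loc) if loc else "general",
            "message": error["msg"],
            "type": error["type"],
        })
    return errors


# Illustrative input shaped like exc.errors(); the entries are invented.
sample = [
    {"loc": ("body", "username"), "msg": "String should have at least 3 characters",
     "type": "string_too_short"},
    {"loc": ("body", "notifications", 0), "msg": "Input should be 'email', 'sms', 'push' or 'none'",
     "type": "enum"},
    {"loc": ("body",), "msg": "Value error, Passwords do not match", "type": "value_error"},
]

for item in flatten_errors(sample):
    print(item["field"], "->", item["message"])
```

List indices become dotted segments (notifications.0). A model-level error has no field name, so it falls back to "general".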
Here is a summary of everything you have learned about Pydantic models and validation in FastAPI:
| Topic | Key Points |
|---|---|
| BaseModel | Inherit from BaseModel, define fields with type annotations, use defaults and Optional |
| Field() | Add constraints: min_length, max_length, gt, lt, ge, le, pattern, multiple_of |
| Type Annotations | Use str, int, float, bool, List, Dict, Optional, Union, Literal, Annotated |
| Nested Models | Compose models within models; validation is recursive and automatic |
| @field_validator | Custom validation for individual fields; use mode="before" for pre-processing |
| @model_validator | Cross-field validation after all fields are set; use for password confirmation, date ranges |
| @computed_field | Derived values included in serialization; use with @property |
| model_config | extra="forbid", str_strip_whitespace, from_attributes, frozen |
| Serialization | model_dump() with exclude, include, exclude_none, exclude_unset, by_alias |
| Request Validation | Body, Query, Path, Header, Cookie — all validated automatically by FastAPI |
| Response Models | response_model filters output; use different input/output schemas to protect sensitive data |
| Error Handling | Customize RequestValidationError handler for consistent API error format |
| Generic Models | Generic[T] for reusable wrappers: paginated responses, API envelopes |
| Form & Files | Use Form() and File() for non-JSON input; a bare BaseModel parameter is read as a JSON body (FastAPI 0.113+ also accepts Annotated[Model, Form()] for form fields) |
Key takeaways:
- Set extra="forbid" on input models to catch typos and unexpected fields
- Build Annotated types for reusable validation logic across models
- Keep single-field rules in @field_validator and cross-field checks in @model_validator
- Use model_dump(exclude_unset=True) in PATCH endpoints so only the provided fields are updated
- Set from_attributes=True when working with SQLAlchemy ORM objects

In the next tutorial, FastAPI – Testing, you will learn how to write comprehensive tests for all of these validation rules using pytest and FastAPI's TestClient. You will see how to test valid inputs, invalid inputs, edge cases, and custom error responses, ensuring your validation logic is bulletproof.