Files
trade-assistant/backend/app/core/csrf.py
T
TradeMate Dev 3e39cf0170 refactor: replace direct WeChat/Alipay with unified pay-api gateway
Switch from direct WeChat Pay / Alipay integrations to the unified
宇之然 pay-api gateway (HMAC-SHA256 auth). Removes wechat_pay.py,
keeps PaymentGateway abstraction, adds UnifiedPayService. Simplifies
payment.py create_order to {plan, pay_type} params. Single webhook
endpoint replaces separate WeChat/Alipay notify handlers.
2026-05-29 18:36:50 +08:00

167 lines
5.9 KiB
Python

"""
CSRF Protection Module for TradeMate
Provides CSRF token generation and validation for form submissions
"""
import logging
import secrets
import time
from typing import Optional, Tuple

from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from app.config import settings
logger = logging.getLogger(__name__)

# --- CSRF token configuration ---------------------------------------------
# Lifetime, in seconds, used as max_age for the CSRF cookie.
CSRF_TOKEN_EXPIRY = 60 * 60  # one hour
# Request/response header that carries the token.
CSRF_HEADER_NAME = "X-CSRF-Token"
# Cookie that carries the token (double-submit cookie pattern).
CSRF_COOKIE_NAME = "csrf_token"

# State-changing HTTP methods that are subject to the CSRF check.
CSRF_PROTECTED_METHODS = {"DELETE", "PATCH", "POST", "PUT"}

# Path *prefixes* that bypass CSRF checks entirely (webhook endpoints).
CSRF_SKIP_ENDPOINTS = [
    "/api/v1/webhook/",
    "/api/v1/payment/webhook",
    "/api/v1/whatsapp/webhook",
]
def generate_csrf_token() -> str:
    """
    Create a new, unpredictable CSRF token.

    Uses 32 bytes from the ``secrets`` CSPRNG, base64url-encoded without
    padding, so the result is a 43-character URL-safe string.
    """
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
def validate_csrf_token(token: Optional[str], request: Request) -> bool:
    """
    Check a client-supplied CSRF token against the request's CSRF cookie.

    This is the double-submit-cookie check: the value the client echoes back
    (normally via the X-CSRF-Token header) must equal the value stored in the
    ``csrf_token`` cookie.

    Checks:
    1. Token is present (non-empty)
    2. The CSRF cookie is present (non-empty)
    3. Token matches the cookie (constant-time comparison)

    No separate expiry check happens here. A token's lifetime is bounded only
    by the cookie's max_age: once the cookie expires, step 2 fails.

    Args:
        token: Token supplied by the client; ``None`` or ``""`` fails.
        request: Incoming request whose cookies hold the expected token.

    Returns:
        True if both values are present and equal, False otherwise.
    """
    if not token:
        return False

    cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
    if not cookie_token:
        return False

    # Constant-time comparison to prevent timing attacks. Compare bytes, not
    # str: compare_digest raises TypeError for str arguments containing
    # non-ASCII characters, and both values are attacker-controlled, so a
    # crafted header/cookie would otherwise turn into a 500 instead of False.
    return secrets.compare_digest(
        token.encode("utf-8"),
        cookie_token.encode("utf-8"),
    )
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF protection middleware for FastAPI (double-submit cookie pattern).

    Per request:
    - ``/health`` and any path starting with an entry of
      ``CSRF_SKIP_ENDPOINTS`` (webhooks) bypass all checks.
    - For methods in ``CSRF_PROTECTED_METHODS``:
        * a request carrying ``Authorization: Bearer ...`` must also send the
          ``X-CSRF-Token`` header;
        * whenever that header is sent, it must match the ``csrf_token``
          cookie (a header without a cookie is rejected).
      Failures are answered with a 403 JSON response whose body matches what
      FastAPI would emit for ``HTTPException(403, detail={...})``.
    - After the request is handled, a Bearer-authenticated request that had
      no CSRF cookie receives a fresh token, both as a cookie and in the
      ``X-CSRF-Token`` response header.

    Requests that carry neither a Bearer header nor a CSRF header are not
    checked.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Health checks and webhook endpoints are exempt from CSRF checks.
        if path == "/health" or path.startswith(tuple(CSRF_SKIP_ENDPOINTS)):
            return await call_next(request)

        has_jwt = request.headers.get("Authorization", "").startswith("Bearer ")

        if request.method in CSRF_PROTECTED_METHODS:
            rejection = self._check_csrf(request, has_jwt)
            if rejection is not None:
                # Return (not raise): middleware runs outside FastAPI's
                # exception handlers, so a raised HTTPException would be
                # reported to the client as a 500 instead of a 403.
                return rejection

        response = await call_next(request)

        # Give an authenticated client its first CSRF token so it can echo it
        # back on later state-changing requests.
        if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME):
            self._issue_token(response)
        return response

    def _check_csrf(self, request: Request, has_jwt: bool) -> Optional[JSONResponse]:
        """Return a 403 response if the request fails the CSRF check, else None."""
        csrf_token = request.headers.get(CSRF_HEADER_NAME)
        cookie_token = request.cookies.get(CSRF_COOKIE_NAME)

        if not csrf_token:
            # JWT-authenticated callers must always double-submit the token.
            if has_jwt:
                return self._reject(
                    request,
                    "CSRF_TOKEN_MISSING",
                    "CSRF token required for this request",
                    required_header=CSRF_HEADER_NAME,
                )
            return None

        if not cookie_token:
            # Token in header but no cookie to compare against - invalid state.
            return self._reject(request, "CSRF_TOKEN_MISMATCH", "CSRF token cookie not found")

        if not validate_csrf_token(csrf_token, request):
            return self._reject(request, "CSRF_TOKEN_INVALID", "Invalid or expired CSRF token")

        return None

    @staticmethod
    def _reject(request: Request, error: str, message: str, **extra: str) -> JSONResponse:
        """Log a CSRF failure and build the 403 response ({"detail": {...}} body)."""
        logger.warning(
            "CSRF check failed (%s) for %s %s", error, request.method, request.url.path
        )
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": {"error": error, "message": message, **extra}},
        )

    @staticmethod
    def _issue_token(response) -> None:
        """Set a new CSRF cookie on *response* and echo the token in a header."""
        token = generate_csrf_token()
        response.set_cookie(
            key=CSRF_COOKIE_NAME,
            value=token,
            max_age=CSRF_TOKEN_EXPIRY,
            httponly=False,  # Must be readable by JavaScript for double-submit
            # `not` (rather than `is False`) keeps the cookie Secure even if
            # DEBUG is a non-bool falsy value.
            secure=not settings.DEBUG,
            samesite="lax",  # Not sent on cross-site subrequests / POSTs
            path="/",
        )
        # Also expose the token in a response header for convenience.
        response.headers[CSRF_HEADER_NAME] = token
def get_csrf_token_from_request(request: Request) -> Optional[str]:
    """
    Return the CSRF token carried by *request*, if any.

    The X-CSRF-Token header takes precedence. When it is missing or empty, the
    csrf_token cookie value is returned instead (None if that is absent too).
    """
    header_value = request.headers.get(CSRF_HEADER_NAME)
    if header_value:
        return header_value
    return request.cookies.get(CSRF_COOKIE_NAME)
def require_csrf_token(request: Request) -> str:
    """
    Dependency function to require a valid CSRF token in route handlers.

    Use with: Depends(require_csrf_token)

    The token must arrive in the X-CSRF-Token header and match the csrf_token
    cookie (double-submit check via ``validate_csrf_token``). The cookie on
    its own is deliberately NOT accepted: browsers attach cookies to
    cross-site requests automatically, so a cookie-only check provides no
    CSRF protection.

    Returns:
        The validated token taken from the request header.

    Raises:
        HTTPException: 403 if the header is missing, or if it does not match
            the CSRF cookie (including when the cookie is absent).
    """
    csrf_token = request.headers.get(CSRF_HEADER_NAME)
    if not csrf_token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="CSRF token required",
        )
    if not validate_csrf_token(csrf_token, request):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired CSRF token",
        )
    return csrf_token