""" CSRF Protection Module for TradeMate Provides CSRF token generation and validation for form submissions """ import secrets import time from typing import Optional, Tuple from fastapi import Request, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware from app.config import settings import logging logger = logging.getLogger(__name__) # CSRF token configuration CSRF_TOKEN_EXPIRY = 3600 # 1 hour CSRF_HEADER_NAME = "X-CSRF-Token" CSRF_COOKIE_NAME = "csrf_token" # Methods that require CSRF protection CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"} # Endpoints that should skip CSRF protection (e.g., webhook endpoints) CSRF_SKIP_ENDPOINTS = [ "/api/v1/webhook/", "/api/v1/payment/", "/api/v1/whatsapp/webhook", "/api/v1/ai/", ] def generate_csrf_token() -> str: """Generate a secure CSRF token""" return secrets.token_urlsafe(32) def validate_csrf_token(token: Optional[str], request: Request) -> bool: """ Validate CSRF token from request. Checks: 1. Token is present 2. Token matches the session/token cookie 3. Token is not expired """ if not token: return False # Get the expected token from cookie or session cookie_token = request.cookies.get(CSRF_COOKIE_NAME) if not cookie_token: return False # Constant-time comparison to prevent timing attacks return secrets.compare_digest(token, cookie_token) class CSRFMiddleware(BaseHTTPMiddleware): """ CSRF protection middleware for FastAPI. For JWT-based APIs, CSRF protection is primarily needed for: 1. Form-based submissions (if any) 2. Any endpoint that uses cookies for authentication 3. Prevention of cross-site request forgery attacks The middleware: - Generates CSRF tokens for authenticated sessions - Validates CSRF tokens on state-changing requests - Skips validation for webhook endpoints and public APIs """ async def dispatch(self, request: Request, call_next): path = request.url.path # Skip CSRF protection for: # 1. Health check endpoint # 2. Webhook endpoints # 3. Public API endpoints (no auth required) if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS): return await call_next(request) # Get authorization header auth_header = request.headers.get("Authorization", "") has_jwt = auth_header.startswith("Bearer ") # For API requests with JWT, we use double-submit cookie pattern # The client should send X-CSRF-Token header matching the csrf_token cookie if request.method in CSRF_PROTECTED_METHODS: # Check for CSRF token in header csrf_token = request.headers.get(CSRF_HEADER_NAME) cookie_token = request.cookies.get(CSRF_COOKIE_NAME) # If there's a JWT but no CSRF token, this might be a direct API call # In that case, we require the CSRF token to be present if has_jwt and not csrf_token: # This is a potential CSRF attempt # For API clients using JWT, we still require CSRF protection # to prevent attacks from malicious websites raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ "error": "CSRF_TOKEN_MISSING", "message": "CSRF token required for this request", "required_header": CSRF_HEADER_NAME, }, ) # Validate the token if present if csrf_token and cookie_token: if not validate_csrf_token(csrf_token, request): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ "error": "CSRF_TOKEN_INVALID", "message": "Invalid or expired CSRF token", }, ) elif csrf_token and not cookie_token: # Token in header but no cookie - invalid state raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ "error": "CSRF_TOKEN_MISMATCH", "message": "CSRF token cookie not found", }, ) response = await call_next(request) # If this is an authenticated request and no CSRF cookie exists, # generate and set one if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME): csrf_token = generate_csrf_token() response.set_cookie( key=CSRF_COOKIE_NAME, value=csrf_token, max_age=CSRF_TOKEN_EXPIRY, httponly=False, # Must be accessible to JavaScript for double-submit secure=settings.DEBUG is False, # Secure in production samesite="lax", # Prevent cross-site requests path="/", ) # Also add the token to response headers for convenience response.headers[CSRF_HEADER_NAME] = csrf_token return response def get_csrf_token_from_request(request: Request) -> Optional[str]: """Helper to extract CSRF token from request""" return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME) def require_csrf_token(request: Request) -> str: """ Dependency function to require CSRF token in route handlers. Use with: Depends(require_csrf_token) """ csrf_token = get_csrf_token_from_request(request) if not csrf_token: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token required", ) return csrf_token