T-005: Security hardening - CORS, Rate Limit, CSRF
- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers - Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations - Login: 5 requests/minute - Register: 3 requests/hour - Password change: 3 requests/5 minutes - Payment: 20 requests/minute - Admin: 30 requests/minute - CSRF: Add CSRF protection middleware with double-submit cookie pattern - New app/core/csrf.py module with CSRFMiddleware - Require CSRF tokens on sensitive endpoints (auth, payment, profile) - Skip webhook endpoints for CSRF validation - Fix pydantic-settings import in config.py
This commit is contained in:
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
CSRF Protection Module for TradeMate
|
||||
Provides CSRF token generation and validation for form submissions
|
||||
"""
|
||||
import secrets
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CSRF token configuration
|
||||
CSRF_TOKEN_EXPIRY = 3600 # 1 hour
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
CSRF_COOKIE_NAME = "csrf_token"
|
||||
|
||||
# Methods that require CSRF protection
|
||||
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||
|
||||
# Endpoints that should skip CSRF protection (e.g., webhook endpoints)
|
||||
CSRF_SKIP_ENDPOINTS = [
|
||||
"/api/v1/webhook/",
|
||||
"/api/v1/payment/notify",
|
||||
"/api/v1/whatsapp/webhook",
|
||||
]
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a secure CSRF token"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def validate_csrf_token(token: Optional[str], request: Request) -> bool:
|
||||
"""
|
||||
Validate CSRF token from request.
|
||||
|
||||
Checks:
|
||||
1. Token is present
|
||||
2. Token matches the session/token cookie
|
||||
3. Token is not expired
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
# Get the expected token from cookie or session
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
if not cookie_token:
|
||||
return False
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
return secrets.compare_digest(token, cookie_token)
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
CSRF protection middleware for FastAPI.
|
||||
|
||||
For JWT-based APIs, CSRF protection is primarily needed for:
|
||||
1. Form-based submissions (if any)
|
||||
2. Any endpoint that uses cookies for authentication
|
||||
3. Prevention of cross-site request forgery attacks
|
||||
|
||||
The middleware:
|
||||
- Generates CSRF tokens for authenticated sessions
|
||||
- Validates CSRF tokens on state-changing requests
|
||||
- Skips validation for webhook endpoints and public APIs
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
path = request.url.path
|
||||
|
||||
# Skip CSRF protection for:
|
||||
# 1. Health check endpoint
|
||||
# 2. Webhook endpoints
|
||||
# 3. Public API endpoints (no auth required)
|
||||
if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS):
|
||||
return await call_next(request)
|
||||
|
||||
# Get authorization header
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
has_jwt = auth_header.startswith("Bearer ")
|
||||
|
||||
# For API requests with JWT, we use double-submit cookie pattern
|
||||
# The client should send X-CSRF-Token header matching the csrf_token cookie
|
||||
|
||||
if request.method in CSRF_PROTECTED_METHODS:
|
||||
# Check for CSRF token in header
|
||||
csrf_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
|
||||
# If there's a JWT but no CSRF token, this might be a direct API call
|
||||
# In that case, we require the CSRF token to be present
|
||||
if has_jwt and not csrf_token:
|
||||
# This is a potential CSRF attempt
|
||||
# For API clients using JWT, we still require CSRF protection
|
||||
# to prevent attacks from malicious websites
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "CSRF_TOKEN_MISSING",
|
||||
"message": "CSRF token required for this request",
|
||||
"required_header": CSRF_HEADER_NAME,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate the token if present
|
||||
if csrf_token and cookie_token:
|
||||
if not validate_csrf_token(csrf_token, request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "CSRF_TOKEN_INVALID",
|
||||
"message": "Invalid or expired CSRF token",
|
||||
},
|
||||
)
|
||||
elif csrf_token and not cookie_token:
|
||||
# Token in header but no cookie - invalid state
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "CSRF_TOKEN_MISMATCH",
|
||||
"message": "CSRF token cookie not found",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# If this is an authenticated request and no CSRF cookie exists,
|
||||
# generate and set one
|
||||
if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME):
|
||||
csrf_token = generate_csrf_token()
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
max_age=CSRF_TOKEN_EXPIRY,
|
||||
httponly=False, # Must be accessible to JavaScript for double-submit
|
||||
secure=settings.DEBUG is False, # Secure in production
|
||||
samesite="lax", # Prevent cross-site requests
|
||||
path="/",
|
||||
)
|
||||
# Also add the token to response headers for convenience
|
||||
response.headers[CSRF_HEADER_NAME] = csrf_token
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_csrf_token_from_request(request: Request) -> Optional[str]:
|
||||
"""Helper to extract CSRF token from request"""
|
||||
return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME)
|
||||
|
||||
|
||||
def require_csrf_token(request: Request) -> str:
|
||||
"""
|
||||
Dependency function to require CSRF token in route handlers.
|
||||
Use with: Depends(require_csrf_token)
|
||||
"""
|
||||
csrf_token = get_csrf_token_from_request(request)
|
||||
if not csrf_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="CSRF token required",
|
||||
)
|
||||
return csrf_token
|
||||
+181
-32
@@ -1,4 +1,8 @@
|
||||
from fastapi import Request, Response
|
||||
"""
|
||||
Enhanced Rate Limiting Module for TradeMate
|
||||
Provides fine-grained rate limiting for sensitive endpoints
|
||||
"""
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config import settings
|
||||
from app.core.security import decode_token
|
||||
@@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,6 +19,7 @@ _redis_pool = None
|
||||
|
||||
|
||||
async def get_redis():
|
||||
"""Get Redis connection pool"""
|
||||
global _redis_pool
|
||||
if _redis_pool is None:
|
||||
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
|
||||
@@ -21,6 +27,7 @@ async def get_redis():
|
||||
|
||||
|
||||
def get_user_tier_from_token(request: Request) -> str:
|
||||
"""Extract user tier from JWT token"""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
request.state.user_id = None
|
||||
@@ -36,29 +43,119 @@ def get_user_tier_from_token(request: Request) -> str:
|
||||
return request.state.user_tier
|
||||
|
||||
|
||||
# Global rate limits (per minute) by tier
|
||||
RATE_LIMITS = {
|
||||
"free": 100,
|
||||
"pro": 500,
|
||||
"enterprise": 2000,
|
||||
"anonymous": 10,
|
||||
"guest": 30,
|
||||
"free": 60,
|
||||
"pro": 300,
|
||||
"enterprise": 1000,
|
||||
}
|
||||
|
||||
# Sensitive endpoint rate limits (stricter limits regardless of tier)
|
||||
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
|
||||
# Authentication endpoints - very strict
|
||||
"/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
|
||||
"/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
|
||||
"/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
|
||||
"/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
|
||||
"/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
|
||||
|
||||
# Password endpoints - very strict
|
||||
"/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
|
||||
|
||||
# Profile endpoints - moderate
|
||||
"/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
|
||||
"/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
|
||||
|
||||
# Payment endpoints - strict
|
||||
"/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
|
||||
|
||||
# WhatsApp endpoints - moderate
|
||||
"/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
|
||||
|
||||
# AI endpoints - moderate
|
||||
"/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
|
||||
"/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
|
||||
|
||||
# Translation endpoints - moderate
|
||||
"/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
|
||||
|
||||
# Marketing endpoints - moderate
|
||||
"/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
|
||||
|
||||
# Quotation endpoints - moderate
|
||||
"/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
|
||||
|
||||
# Customer endpoints - moderate
|
||||
"/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
|
||||
|
||||
# Admin endpoints - strict
|
||||
"/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
|
||||
}
|
||||
|
||||
|
||||
async def check_rate_limit(user_id: str, tier: str) -> int:
|
||||
async def check_endpoint_rate_limit(
|
||||
path: str, method: str, user_id: Optional[str], tier: str
|
||||
) -> tuple[bool, int, str]:
|
||||
"""
|
||||
Check rate limit for a specific endpoint.
|
||||
|
||||
Returns: (allowed, remaining, limit_name)
|
||||
"""
|
||||
if not user_id:
|
||||
# Use IP-based limiting for anonymous users
|
||||
user_id = "anonymous"
|
||||
|
||||
# Find matching endpoint limit
|
||||
matched_limit = None
|
||||
for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items():
|
||||
if path.startswith(endpoint_pattern):
|
||||
matched_limit = config
|
||||
break
|
||||
|
||||
if not matched_limit:
|
||||
# Use tier-based global limit
|
||||
limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
key = f"ratelimit:{user_id}:{int(time.time() // 60)}"
|
||||
limit_name = f"global_{tier}"
|
||||
else:
|
||||
limit = matched_limit["limit"]
|
||||
window = matched_limit["window"]
|
||||
key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}"
|
||||
limit_name = matched_limit["name"]
|
||||
|
||||
r = await get_redis()
|
||||
count = await r.incr(key)
|
||||
|
||||
if count == 1:
|
||||
# Set expiry based on window
|
||||
window = matched_limit["window"] if matched_limit else 60
|
||||
await r.expire(key, window + 5)
|
||||
|
||||
remaining = max(0, limit - count)
|
||||
return remaining > 0, remaining, limit_name
|
||||
|
||||
|
||||
async def check_global_rate_limit(user_id: str, tier: str) -> int:
|
||||
"""Check global rate limit (per minute)"""
|
||||
r = await get_redis()
|
||||
now = time.time()
|
||||
window = 60
|
||||
key = f"ratelimit:{user_id}:{int(now // window)}"
|
||||
|
||||
|
||||
count = await r.incr(key)
|
||||
if count == 1:
|
||||
await r.expire(key, window + 5)
|
||||
|
||||
|
||||
limit = RATE_LIMITS.get(tier, 100)
|
||||
remaining = max(0, limit - count)
|
||||
return remaining
|
||||
|
||||
|
||||
class TierMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to extract user tier from JWT token"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.url.path.startswith("/api/v1"):
|
||||
tier = get_user_tier_from_token(request)
|
||||
@@ -81,66 +178,118 @@ class TierMiddleware(BaseHTTPMiddleware):
|
||||
request.state.user_id = None
|
||||
request.state.user_tier = "anonymous"
|
||||
request.state.tier_config = {}
|
||||
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Enhanced rate limiting middleware with fine-grained endpoint limits.
|
||||
|
||||
Features:
|
||||
- Tier-based global rate limits
|
||||
- Endpoint-specific rate limits for sensitive operations
|
||||
- Separate limits for authentication, payment, and admin operations
|
||||
- Proper rate limit headers in responses
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.url.path.startswith("/api/v1"):
|
||||
path = request.url.path
|
||||
|
||||
# Skip rate limiting for health check and non-API routes
|
||||
if not path.startswith("/api/v1"):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
user_tier = getattr(request.state, "user_tier", None)
|
||||
if user_tier in ("anonymous", None):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
remaining = await check_rate_limit(
|
||||
user_id, user_tier
|
||||
# For anonymous users, still apply basic rate limiting
|
||||
# Use IP-based limiting
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
allowed, remaining, limit_name = await check_endpoint_rate_limit(
|
||||
path, request.method, client_ip, "anonymous"
|
||||
)
|
||||
if remaining == 0:
|
||||
if not allowed:
|
||||
return Response(
|
||||
status_code=429,
|
||||
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
|
||||
media_type="application/json",
|
||||
headers={"Retry-After": "60"},
|
||||
headers={
|
||||
"Retry-After": "60",
|
||||
"X-RateLimit-Limit": "10",
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset": str(int(time.time()) + 60),
|
||||
},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Limit"] = "10"
|
||||
response.headers["X-RateLimit-Name"] = limit_name
|
||||
return response
|
||||
|
||||
try:
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
# Check endpoint-specific rate limit
|
||||
allowed, remaining, limit_name = await check_endpoint_rate_limit(
|
||||
path, request.method, user_id, user_tier
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
return Response(
|
||||
status_code=429,
|
||||
content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}',
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Retry-After": "60",
|
||||
"X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset": str(int(time.time()) + 60),
|
||||
"X-RateLimit-Name": limit_name,
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Limit"] = str(
|
||||
SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100))
|
||||
)
|
||||
response.headers["X-RateLimit-Name"] = limit_name
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limit check failed: {e}")
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to enforce daily quotas for specific operations"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.url.path.startswith("/api/v1"):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
user_tier = getattr(request.state, "user_tier", None)
|
||||
if user_tier in ("anonymous", None):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
tier = user_tier
|
||||
|
||||
|
||||
if tier == "enterprise":
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
|
||||
if method == "GET":
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
quota_map = [
|
||||
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
|
||||
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
|
||||
@@ -148,20 +297,20 @@ class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
|
||||
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
|
||||
]
|
||||
|
||||
|
||||
matched_key = None
|
||||
for prefix, limits in quota_map:
|
||||
if path.startswith(prefix):
|
||||
matched_key = prefix
|
||||
break
|
||||
|
||||
|
||||
if not matched_key:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
limit = quota_map[matched_key].get(tier)
|
||||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
try:
|
||||
r = await get_redis()
|
||||
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
|
||||
@@ -175,5 +324,5 @@ class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Quota check failed: {e}")
|
||||
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
Reference in New Issue
Block a user