T-005: Security hardening - CORS, Rate Limit, CSRF

- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers
- Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations
  - Login: 5 requests/minute
  - Register: 3 requests/hour
  - Password change: 3 requests/5 minutes
  - Payment: 20 requests/minute
  - Admin: 30 requests/minute
- CSRF: Add CSRF protection middleware with double-submit cookie pattern
  - New app/core/csrf.py module with CSRFMiddleware
  - Require CSRF tokens on sensitive endpoints (auth, payment, profile)
  - Skip webhook endpoints for CSRF validation
- Fix pydantic-settings import in config.py
This commit is contained in:
TradeMate Dev
2026-05-29 10:26:23 +08:00
parent 7c9885f704
commit c04fa2c19f
7 changed files with 464 additions and 58 deletions
+181 -32
View File
@@ -1,4 +1,8 @@
from fastapi import Request, Response
"""
Enhanced Rate Limiting Module for TradeMate
Provides fine-grained rate limiting for sensitive endpoints
"""
from fastapi import Request, Response, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
from app.core.security import decode_token
@@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool
import logging
import time
from datetime import datetime
from typing import Dict, Optional
logger = logging.getLogger(__name__)
@@ -14,6 +19,7 @@ _redis_pool = None
async def get_redis():
"""Get Redis connection pool"""
global _redis_pool
if _redis_pool is None:
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
@@ -21,6 +27,7 @@ async def get_redis():
def get_user_tier_from_token(request: Request) -> str:
"""Extract user tier from JWT token"""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
request.state.user_id = None
@@ -36,29 +43,119 @@ def get_user_tier_from_token(request: Request) -> str:
return request.state.user_tier
# Global rate limits (per minute) by tier
RATE_LIMITS = {
"free": 100,
"pro": 500,
"enterprise": 2000,
"anonymous": 10,
"guest": 30,
"free": 60,
"pro": 300,
"enterprise": 1000,
}
# Sensitive endpoint rate limits (stricter limits regardless of tier)
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
# Authentication endpoints - very strict
"/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
"/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
"/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
"/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
"/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
# Password endpoints - very strict
"/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
# Profile endpoints - moderate
"/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
"/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
# Payment endpoints - strict
"/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
# WhatsApp endpoints - moderate
"/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
# AI endpoints - moderate
"/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
"/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
# Translation endpoints - moderate
"/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
# Marketing endpoints - moderate
"/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
# Quotation endpoints - moderate
"/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
# Customer endpoints - moderate
"/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
# Admin endpoints - strict
"/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
}
async def check_rate_limit(user_id: str, tier: str) -> int:
async def check_endpoint_rate_limit(
path: str, method: str, user_id: Optional[str], tier: str
) -> tuple[bool, int, str]:
"""
Check rate limit for a specific endpoint.
Returns: (allowed, remaining, limit_name)
"""
if not user_id:
# Use IP-based limiting for anonymous users
user_id = "anonymous"
# Find matching endpoint limit
matched_limit = None
for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items():
if path.startswith(endpoint_pattern):
matched_limit = config
break
if not matched_limit:
# Use tier-based global limit
limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
key = f"ratelimit:{user_id}:{int(time.time() // 60)}"
limit_name = f"global_{tier}"
else:
limit = matched_limit["limit"]
window = matched_limit["window"]
key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}"
limit_name = matched_limit["name"]
r = await get_redis()
count = await r.incr(key)
if count == 1:
# Set expiry based on window
window = matched_limit["window"] if matched_limit else 60
await r.expire(key, window + 5)
remaining = max(0, limit - count)
return remaining > 0, remaining, limit_name
async def check_global_rate_limit(user_id: str, tier: str) -> int:
"""Check global rate limit (per minute)"""
r = await get_redis()
now = time.time()
window = 60
key = f"ratelimit:{user_id}:{int(now // window)}"
count = await r.incr(key)
if count == 1:
await r.expire(key, window + 5)
limit = RATE_LIMITS.get(tier, 100)
remaining = max(0, limit - count)
return remaining
class TierMiddleware(BaseHTTPMiddleware):
"""Middleware to extract user tier from JWT token"""
async def dispatch(self, request: Request, call_next):
if request.url.path.startswith("/api/v1"):
tier = get_user_tier_from_token(request)
@@ -81,66 +178,118 @@ class TierMiddleware(BaseHTTPMiddleware):
request.state.user_id = None
request.state.user_tier = "anonymous"
request.state.tier_config = {}
response = await call_next(request)
return response
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Enhanced rate limiting middleware with fine-grained endpoint limits.
Features:
- Tier-based global rate limits
- Endpoint-specific rate limits for sensitive operations
- Separate limits for authentication, payment, and admin operations
- Proper rate limit headers in responses
"""
async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"):
path = request.url.path
# Skip rate limiting for health check and non-API routes
if not path.startswith("/api/v1"):
return await call_next(request)
user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None):
return await call_next(request)
try:
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
remaining = await check_rate_limit(
user_id, user_tier
# For anonymous users, still apply basic rate limiting
# Use IP-based limiting
client_ip = request.client.host if request.client else "unknown"
allowed, remaining, limit_name = await check_endpoint_rate_limit(
path, request.method, client_ip, "anonymous"
)
if remaining == 0:
if not allowed:
return Response(
status_code=429,
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
media_type="application/json",
headers={"Retry-After": "60"},
headers={
"Retry-After": "60",
"X-RateLimit-Limit": "10",
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
},
)
response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = "10"
response.headers["X-RateLimit-Name"] = limit_name
return response
try:
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
# Check endpoint-specific rate limit
allowed, remaining, limit_name = await check_endpoint_rate_limit(
path, request.method, user_id, user_tier
)
if not allowed:
return Response(
status_code=429,
content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}',
media_type="application/json",
headers={
"Retry-After": "60",
"X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
"X-RateLimit-Name": limit_name,
},
)
response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = str(
SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100))
)
response.headers["X-RateLimit-Name"] = limit_name
return response
except Exception as e:
logger.warning(f"Rate limit check failed: {e}")
return await call_next(request)
class QuotaMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce daily quotas for specific operations"""
async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"):
return await call_next(request)
user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None):
return await call_next(request)
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
tier = user_tier
if tier == "enterprise":
return await call_next(request)
path = request.url.path
method = request.method
if method == "GET":
return await call_next(request)
quota_map = [
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
@@ -148,20 +297,20 @@ class QuotaMiddleware(BaseHTTPMiddleware):
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
]
matched_key = None
for prefix, limits in quota_map:
if path.startswith(prefix):
matched_key = prefix
break
if not matched_key:
return await call_next(request)
limit = quota_map[matched_key].get(tier)
if limit is None:
return await call_next(request)
try:
r = await get_redis()
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
@@ -175,5 +324,5 @@ class QuotaMiddleware(BaseHTTPMiddleware):
raise
except Exception as e:
logger.warning(f"Quota check failed: {e}")
return await call_next(request)