"""
Enhanced Rate Limiting Module for TradeMate
Provides fine-grained rate limiting for sensitive endpoints
"""
import json
import logging
import time
from datetime import datetime, timezone
from typing import Dict, NamedTuple, Optional

import redis.asyncio as aioredis
from fastapi import HTTPException, Request, Response, status
from redis.asyncio import ConnectionPool
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
from app.core.security import decode_token

logger = logging.getLogger(__name__)

_redis_pool = None


async def get_redis() -> aioredis.Redis:
    """Return a Redis client backed by a lazily created, module-wide connection pool.

    The pool is built from ``settings.REDIS_URL`` on first use and shared by
    every subsequent call; each call returns a new ``Redis`` wrapper around it.
    """
    # NOTE(review): a plain ConnectionPool with max_connections raises
    # ConnectionError when the limit is reached instead of waiting. The
    # middlewares below fail open on such errors; consider
    # BlockingConnectionPool if that shows up under load.
    global _redis_pool
    if _redis_pool is None:
        _redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
    return aioredis.Redis(connection_pool=_redis_pool)


def get_user_tier_from_token(request: Request) -> str:
    """Resolve the caller's tier from a ``Bearer`` JWT and record it on the request.

    Side effects:
        Sets ``request.state.user_id`` (the token's ``sub`` claim, or ``None``)
        and ``request.state.user_tier``.

    Returns:
        ``"anonymous"`` when the Authorization header is missing, is not a
        Bearer token, or ``decode_token`` yields a falsy payload; otherwise the
        token's ``tier`` claim, defaulting to ``"free"``.
    """
    auth = request.headers.get("Authorization", "")
    payload = decode_token(auth[len("Bearer "):]) if auth.startswith("Bearer ") else None
    if not payload:
        request.state.user_id = None
        request.state.user_tier = "anonymous"
        return "anonymous"

    request.state.user_id = payload.get("sub")
    request.state.user_tier = payload.get("tier", "free")
    return request.state.user_tier


# Global rate limits (requests per minute) by tier
RATE_LIMITS = {
    "anonymous": 10,
    "guest": 30,
    "free": 60,
    "pro": 300,
    "enterprise": 1000,
}

# Sensitive endpoint rate limits (stricter limits regardless of tier).
# Keys are path prefixes; when several match, the longest (most specific)
# prefix wins, so e.g. "/api/v1/auth/login/guest" is not shadowed by
# "/api/v1/auth/login". "window" is in seconds; "name" identifies the counter.
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
    # Authentication endpoints - very strict
    "/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
    "/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
    "/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
    "/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
    "/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
    # Password endpoints - very strict
    "/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
    # Profile endpoints - moderate
    "/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
    "/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
    # Payment endpoints - strict
    "/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
    # WhatsApp endpoints - moderate
    "/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
    # AI endpoints - moderate
    "/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
    "/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
    # Translation endpoints - moderate
    "/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
    # Marketing endpoints - moderate
    "/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
    # Quotation endpoints - moderate
    "/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
    # Customer endpoints - moderate
    "/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
    # Admin endpoints - strict
    "/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
}

# Window (seconds) used for the tier-based global limit.
_GLOBAL_WINDOW = 60


def _tier_limit(tier: str) -> int:
    """Return the per-minute global limit for ``tier`` (free-tier limit for unknown tiers)."""
    return RATE_LIMITS.get(tier, RATE_LIMITS["free"])


def _match_sensitive_endpoint(path: str) -> Optional[dict]:
    """Return the SENSITIVE_ENDPOINT_LIMITS entry with the longest prefix of ``path``, or None."""
    best = max(
        (prefix for prefix in SENSITIVE_ENDPOINT_LIMITS if path.startswith(prefix)),
        key=len,
        default=None,
    )
    return SENSITIVE_ENDPOINT_LIMITS[best] if best is not None else None


class _RateLimitResult(NamedTuple):
    """Outcome of consuming one request from a fixed-window rate limit."""

    allowed: bool  # True while the counter is within the limit
    remaining: int  # requests left in the current window (never negative)
    limit_name: str  # counter name, e.g. "login_attempts" or "global_free"
    limit: int  # the limit that was actually enforced
    reset_at: int  # unix timestamp at which the current window ends


async def _consume_rate_limit(path: str, identity: str, tier: str) -> _RateLimitResult:
    """Increment the fixed-window counter that applies to ``path`` for ``identity``.

    Sensitive endpoints use their own limit/window; everything else uses the
    tier's per-minute global limit. Redis errors propagate to the caller.
    """
    matched = _match_sensitive_endpoint(path)
    now = time.time()
    if matched is None:
        limit = _tier_limit(tier)
        window = _GLOBAL_WINDOW
        bucket = int(now // window)
        key = f"ratelimit:{identity}:{bucket}"
        limit_name = f"global_{tier}"
    else:
        limit = matched["limit"]
        window = matched["window"]
        bucket = int(now // window)
        key = f"ratelimit:{matched['name']}:{identity}:{bucket}"
        limit_name = matched["name"]

    r = await get_redis()
    count = await r.incr(key)
    if count == 1:
        # First hit in this window: let the key expire shortly after the window ends.
        await r.expire(key, window + 5)

    return _RateLimitResult(
        # The limit-th request is still allowed; only requests beyond it are rejected.
        allowed=count <= limit,
        remaining=max(0, limit - count),
        limit_name=limit_name,
        limit=limit,
        reset_at=(bucket + 1) * window,
    )


async def check_endpoint_rate_limit(
    path: str, method: str, user_id: Optional[str], tier: str
) -> tuple[bool, int, str]:
    """
    Check (and consume) the rate limit for a specific endpoint.

    Args:
        path: Request path, matched by longest prefix against SENSITIVE_ENDPOINT_LIMITS.
        method: HTTP method. Currently unused: limits apply to every method.
        user_id: Identity to count against; falsy values share one "anonymous" bucket.
        tier: Caller tier, used for the global limit when no sensitive endpoint matches.

    Returns: (allowed, remaining, limit_name)
    """
    result = await _consume_rate_limit(path, user_id or "anonymous", tier)
    return result.allowed, result.remaining, result.limit_name


async def check_global_rate_limit(user_id: str, tier: str) -> int:
    """Consume one request from the tier's per-minute global limit.

    Shares its Redis counter with the global (non-sensitive) bucket used by
    ``check_endpoint_rate_limit``. Unknown tiers fall back to the free-tier limit.

    Returns:
        Requests remaining in the current minute (never negative).
    """
    r = await get_redis()
    window = _GLOBAL_WINDOW
    key = f"ratelimit:{user_id}:{int(time.time() // window)}"
    count = await r.incr(key)
    if count == 1:
        await r.expire(key, window + 5)
    return max(0, _tier_limit(tier) - count)


def _tier_config(tier: str) -> dict:
    """Return the product/customer caps for ``tier``; tiers without an entry get the free caps."""
    configs = {
        "free": {
            "max_products": settings.FREE_MAX_PRODUCTS,
            "max_customers": settings.FREE_MAX_CUSTOMERS,
        },
        "pro": {
            "max_products": settings.PRO_MAX_PRODUCTS,
            "max_customers": settings.PRO_MAX_CUSTOMERS,
        },
        "enterprise": {
            "max_products": 9999,
            "max_customers": 99999,
        },
    }
    return configs.get(tier, configs["free"])


class TierMiddleware(BaseHTTPMiddleware):
    """Attach the caller's identity, tier and tier caps to ``request.state``.

    For ``/api/v1`` routes the JWT is decoded via ``get_user_tier_from_token``
    and ``tier_config`` is set from the tier. Other routes are marked anonymous
    with an empty ``tier_config``.
    """

    async def dispatch(self, request: Request, call_next):
        if request.url.path.startswith("/api/v1"):
            tier = get_user_tier_from_token(request)
            request.state.tier_config = _tier_config(tier)
        else:
            request.state.user_id = None
            request.state.user_tier = "anonymous"
            request.state.tier_config = {}
        return await call_next(request)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Enhanced rate limiting middleware with fine-grained endpoint limits.

    Features:
    - Tier-based global rate limits
    - Endpoint-specific rate limits for sensitive operations
    - Separate limits for authentication, payment, and admin operations
    - Rate limit headers reflecting the limit and window actually enforced

    Reads ``request.state.user_tier`` / ``user_id`` (as set by TierMiddleware).
    Anonymous callers, and authenticated callers without a user id, are
    counted by client IP. If the Redis check itself fails the request is
    allowed through (fail open) and a warning is logged.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Skip rate limiting for health check and non-API routes
        if not path.startswith("/api/v1"):
            return await call_next(request)

        user_tier = getattr(request.state, "user_tier", None)
        client_ip = request.client.host if request.client else "unknown"
        anonymous = user_tier in ("anonymous", None)
        if anonymous:
            tier, identity = "anonymous", client_ip
        else:
            tier = user_tier
            # Fall back to the client IP rather than skipping limits entirely.
            identity = getattr(request.state, "user_id", None) or client_ip

        # Only the Redis check is guarded: exceptions from the downstream
        # handler must propagate, and the handler must never run twice.
        try:
            result = await _consume_rate_limit(path, identity, tier)
        except Exception as e:
            logger.warning("Rate limit check failed: %s", e)
            return await call_next(request)

        if not result.allowed:
            return self._rate_limited_response(result, anonymous)

        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(result.remaining)
        response.headers["X-RateLimit-Limit"] = str(result.limit)
        response.headers["X-RateLimit-Name"] = result.limit_name
        return response

    @staticmethod
    def _rate_limited_response(result: _RateLimitResult, anonymous: bool) -> Response:
        """Build the 429 JSON response, with Retry-After/Reset derived from the window end."""
        if anonymous:
            detail = "Too many requests, try again later"
        else:
            detail = f"Rate limit exceeded for {result.limit_name}. Try again later."
        retry_after = max(1, result.reset_at - int(time.time()))
        return Response(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=json.dumps(
                {"error": "RATE_LIMITED", "detail": detail}, separators=(",", ":")
            ),
            media_type="application/json",
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": str(result.limit),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": str(result.reset_at),
                "X-RateLimit-Name": result.limit_name,
            },
        )


def _quota_rules() -> tuple[tuple[str, dict], ...]:
    """Return (path prefix, {tier: daily limit}) pairs, more specific prefixes first.

    First match wins, so longer prefixes must precede their parents.
    """
    return (
        ("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
        # NOTE(review): *_TRANSLATE_CHARS is compared against a per-request
        # counter below, not a character count — confirm this is intended.
        ("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
        ("/api/v1/marketing/generate", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
        ("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
        ("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
    )


class QuotaMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce daily quotas for specific operations.

    Applies only to non-GET ``/api/v1`` requests from identified, non-enterprise
    users whose tier has a limit for the matched prefix. Each such request
    increments a per-user, per-prefix, per-UTC-day Redis counter; exceeding the
    limit raises ``QuotaExceededError``. Redis failures are logged and the
    request is allowed through.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        if not path.startswith("/api/v1"):
            return await call_next(request)

        tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        if tier in ("anonymous", None, "enterprise") or not user_id:
            return await call_next(request)
        if request.method == "GET":
            return await call_next(request)

        matched = next(
            ((prefix, limits) for prefix, limits in _quota_rules() if path.startswith(prefix)),
            None,
        )
        if matched is None:
            return await call_next(request)
        matched_key, limits = matched

        limit = limits.get(tier)
        if limit is None:
            return await call_next(request)

        try:
            r = await get_redis()
            day = datetime.now(timezone.utc).strftime("%Y%m%d")
            key = f"quota:{user_id}:{matched_key}:{day}"
            current = await r.incr(key)
            await r.expire(key, 86400)
        except Exception as e:
            logger.warning("Quota check failed: %s", e)
            return await call_next(request)

        if current > limit:
            # Imported lazily, as in the original (possibly to avoid an import cycle).
            from app.core.exceptions import QuotaExceededError

            # NOTE(review): exceptions raised from a BaseHTTPMiddleware are
            # outside Starlette's ExceptionMiddleware, so app-level exception
            # handlers may not turn this into a 4xx — confirm how it is handled.
            raise QuotaExceededError(matched_key)

        request.state.quota_remaining = limit - current
        return await call_next(request)