"""Tier resolution, per-minute rate limiting and daily quota middleware.

All three middlewares only act on paths that start with ``/api/v1``; other
requests are passed straight through.
"""

import logging
import time
from datetime import datetime, timezone

import redis.asyncio as aioredis
from fastapi import Request, Response
from redis.asyncio import ConnectionPool
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
from app.core.security import decode_token

logger = logging.getLogger(__name__)

# Only paths under this prefix are subject to tier/rate/quota handling.
_API_PREFIX = "/api/v1"

# Length of one rate-limit window, in seconds.
_RATE_WINDOW_SECONDS = 60

# Lazily created connection pool shared by every client from get_redis().
_redis_pool = None


async def get_redis():
    """Return a Redis client backed by the module-wide connection pool.

    The pool is created on first use from ``settings.REDIS_URL`` with at most
    20 connections and reused afterwards.
    """
    global _redis_pool
    if _redis_pool is None:
        _redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
    return aioredis.Redis(connection_pool=_redis_pool)


def get_user_tier_from_token(request: Request) -> str:
    """Resolve the caller's tier from the ``Authorization`` bearer token.

    Stores ``user_id`` and ``user_tier`` on ``request.state`` as a side effect.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        ``"anonymous"`` when there is no bearer token or ``decode_token``
        returns a falsy payload; otherwise the payload's ``tier`` value,
        defaulting to ``"free"``.
    """
    auth = request.headers.get("Authorization", "")
    # "Bearer " is 7 characters long, hence the slice.
    payload = decode_token(auth[7:]) if auth.startswith("Bearer ") else None
    if not payload:
        request.state.user_id = None
        request.state.user_tier = "anonymous"
        return "anonymous"

    request.state.user_id = payload.get("sub")
    request.state.user_tier = payload.get("tier", "free")
    return request.state.user_tier


# Requests allowed per window for each tier; unknown tiers fall back to 100.
RATE_LIMITS = {
    "free": 100,
    "pro": 500,
    "enterprise": 2000,
}


async def _consume_rate_limit(user_id: str, tier: str) -> tuple:
    """Count one request in the current window and return ``(count, limit)``.

    ``count`` is the number of requests seen in this window including the
    current one; ``limit`` is the tier's allowance from ``RATE_LIMITS``.
    """
    r = await get_redis()
    bucket = int(time.time() // _RATE_WINDOW_SECONDS)
    key = f"ratelimit:{user_id}:{bucket}"
    count = await r.incr(key)
    if count == 1:
        # First hit in this window: let the key outlive the window by a few
        # seconds and then expire.
        await r.expire(key, _RATE_WINDOW_SECONDS + 5)
    return count, RATE_LIMITS.get(tier, 100)


async def check_rate_limit(user_id: str, tier: str) -> int:
    """Record one request for ``user_id`` and return the remaining allowance.

    Args:
        user_id: Identifier used in the Redis counter key.
        tier: Tier name looked up in ``RATE_LIMITS``.

    Returns:
        Requests left in the current window, never below 0. A return value of
        0 does not by itself mean the current request exceeded the limit: the
        request that uses the last slot also yields 0.
    """
    count, limit = await _consume_rate_limit(user_id, tier)
    return max(0, limit - count)


class TierMiddleware(BaseHTTPMiddleware):
    """Populate ``request.state`` with the caller's id, tier and tier config."""

    async def dispatch(self, request: Request, call_next):
        """Set ``user_id``, ``user_tier`` and ``tier_config`` then continue.

        Non-API paths get an anonymous identity and an empty ``tier_config``.
        """
        if not request.url.path.startswith(_API_PREFIX):
            request.state.user_id = None
            request.state.user_tier = "anonymous"
            request.state.tier_config = {}
            return await call_next(request)

        tier = get_user_tier_from_token(request)
        tier_config = {
            "free": {
                "max_products": settings.FREE_MAX_PRODUCTS,
                "max_customers": settings.FREE_MAX_CUSTOMERS,
            },
            "pro": {
                "max_products": settings.PRO_MAX_PRODUCTS,
                "max_customers": settings.PRO_MAX_CUSTOMERS,
            },
            "enterprise": {
                "max_products": 9999,
                "max_customers": 99999,
            },
        }
        # Tiers without an entry (including "anonymous") get the free config.
        request.state.tier_config = tier_config.get(tier, tier_config["free"])
        return await call_next(request)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce the per-window request limit for authenticated API callers.

    Reads ``user_tier`` and ``user_id`` from ``request.state``.
    NOTE(review): presumably these are set by ``TierMiddleware`` running
    earlier in the stack -- confirm the registration order.
    """

    async def dispatch(self, request: Request, call_next):
        """Reject with 429 once the caller exceeds its limit, else continue.

        Anonymous callers and callers without a ``user_id`` are not limited.
        If the Redis check itself fails, the request is allowed through.
        """
        if not request.url.path.startswith(_API_PREFIX):
            return await call_next(request)

        user_tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        if user_tier in ("anonymous", None) or not user_id:
            return await call_next(request)

        # Only the Redis check is guarded. Wrapping call_next here would
        # mislabel handler errors as rate-limit failures and dispatch the
        # request a second time.
        try:
            count, limit = await _consume_rate_limit(user_id, user_tier)
        except Exception as e:
            logger.warning("Rate limit check failed: %s", e)
            return await call_next(request)

        # Reject only requests beyond the limit; the request that uses the
        # last available slot is still served.
        if count > limit:
            return Response(
                status_code=429,
                content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
                media_type="application/json",
                headers={"Retry-After": "60"},
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(max(0, limit - count))
        return response


class QuotaMiddleware(BaseHTTPMiddleware):
    """Enforce per-day usage quotas on selected API endpoints.

    Reads ``user_tier`` and ``user_id`` from ``request.state``.
    NOTE(review): presumably these are set by ``TierMiddleware`` running
    earlier in the stack -- confirm the registration order.
    """

    async def dispatch(self, request: Request, call_next):
        """Count the request against the caller's daily quota, then continue.

        Skipped for non-API paths, anonymous callers, callers without a
        ``user_id``, the ``enterprise`` tier, ``GET`` requests, paths with no
        quota entry and tiers with no configured limit. On success
        ``request.state.quota_remaining`` is set. If the Redis check fails the
        request is allowed through.

        Raises:
            QuotaExceededError: When the day's counter exceeds the limit.
                NOTE(review): how this is turned into an HTTP response
                depends on handling outside this file -- confirm.
        """
        path = request.url.path
        if not path.startswith(_API_PREFIX):
            return await call_next(request)

        tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        if tier in ("anonymous", None) or not user_id:
            return await call_next(request)
        if tier == "enterprise" or request.method == "GET":
            return await call_next(request)

        # Order matters: more specific prefixes must precede their parents,
        # because the first matching prefix wins.
        quota_map = [
            ("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
            ("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
            ("/api/v1/marketing/generate", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
            ("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
            ("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
        ]
        match = next(
            ((prefix, limits) for prefix, limits in quota_map if path.startswith(prefix)),
            None,
        )
        if match is None:
            return await call_next(request)

        # quota_map is a list, so the limits are taken from the matched entry
        # rather than looked up by prefix.
        matched_key, tier_limits = match
        limit = tier_limits.get(tier)
        if limit is None:
            return await call_next(request)

        # Imported lazily, but before the try block, so the name is always
        # bound by the time the quota is evaluated.
        from app.core.exceptions import QuotaExceededError

        try:
            r = await get_redis()
            day = datetime.now(timezone.utc).strftime("%Y%m%d")
            key = f"quota:{user_id}:{matched_key}:{day}"
            current = await r.incr(key)
            await r.expire(key, 86400)
        except Exception as e:
            # Best effort: a failing quota store must not block the request.
            logger.warning("Quota check failed: %s", e)
            return await call_next(request)

        if current > limit:
            raise QuotaExceededError(matched_key)

        request.state.quota_remaining = limit - current
        return await call_next(request)