bed5c7abef
- Separate workspace landing from login for better UX - Referral system rewards both parties with Pro days - Quota enforcement prevents abuse without breaking endpoints - 7-day free trial with auto-downgrade on expiry - Admin-managed search provider config (SearXNG, Bing) - 15% discount on annual subscriptions - MCP search server wrapping opencode search - Fix discovery module field name mismatch causing 422
180 lines
6.0 KiB
Python
180 lines
6.0 KiB
Python
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from app.config import settings
|
|
from app.core.security import decode_token
|
|
import redis.asyncio as aioredis
|
|
from redis.asyncio import ConnectionPool
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
|
|
# Module-wide logger; used below to report failed rate-limit and quota checks.
logger = logging.getLogger(__name__)
|
|
|
|
# Shared redis.asyncio ConnectionPool; None until get_redis() creates it on first call.
_redis_pool = None
|
|
|
|
|
|
async def get_redis():
    """Return a Redis client backed by the module-wide connection pool.

    The pool is created from ``settings.REDIS_URL`` (capped at 20 connections)
    the first time this coroutine runs and is reused on every later call; each
    call returns a new ``Redis`` object wrapping that same pool.
    """
    global _redis_pool
    pool = _redis_pool
    if pool is None:
        pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
        _redis_pool = pool
    return aioredis.Redis(connection_pool=pool)
|
|
|
|
|
|
def get_user_tier_from_token(request: Request) -> str:
    """Resolve the caller's tier from its Bearer token and record it on ``request.state``.

    Side effects: sets ``request.state.user_id`` and ``request.state.user_tier``.

    A request without an ``Authorization: Bearer ...`` header, or whose token
    ``decode_token`` returns a falsy result for, is treated as ``"anonymous"``
    with ``user_id = None``. Otherwise ``user_id`` comes from the token's
    ``sub`` claim and the tier from its ``tier`` claim (default ``"free"``).

    Returns:
        The tier string that was stored on ``request.state.user_tier``.
    """
    header = request.headers.get("Authorization", "")
    # Only Bearer tokens are decoded; the slice drops the "Bearer " prefix.
    claims = decode_token(header[7:]) if header.startswith("Bearer ") else None

    if claims:
        user_id = claims.get("sub")
        tier = claims.get("tier", "free")
    else:
        user_id = None
        tier = "anonymous"

    request.state.user_id = user_id
    request.state.user_tier = tier
    return tier
|
|
|
|
|
|
# Requests allowed per rate-limit window, keyed by tier name.
# check_rate_limit() falls back to 100 for tiers not listed here.
RATE_LIMITS = dict(
    free=100,
    pro=500,
    enterprise=2000,
)
|
|
|
|
|
|
async def check_rate_limit(user_id: str, tier: str, window: int = 60) -> int:
    """Count one request for ``user_id`` in the current fixed window and return the headroom.

    Time is split into epoch-aligned windows of ``window`` seconds. Each window
    has its own Redis counter, keyed ``ratelimit:{user_id}:{window_index}``.

    Args:
        user_id: Identifier the counter is scoped to.
        tier: Tier name looked up in ``RATE_LIMITS``; unknown tiers get a limit of 100.
        window: Window length in seconds (default 60).

    Returns:
        ``max(0, limit - count)``, where ``count`` includes this request.
        NOTE(review): this is already 0 on the ``limit``-th request. A caller
        that rejects on ``remaining == 0`` therefore admits ``limit - 1``
        requests per window. Confirm whether that is intended.

    Redis errors are not caught here; they propagate to the caller.
    """
    r = await get_redis()
    window_index = int(time.time() // window)
    key = f"ratelimit:{user_id}:{window_index}"

    # Send INCR and EXPIRE together in one MULTI/EXEC transaction. Otherwise a
    # failure between the two calls can leave a counter with no TTL, which stays
    # in Redis forever. Refreshing the TTL on every hit is harmless: the key name
    # is window-scoped, so it stops receiving hits once the window rolls over.
    async with r.pipeline(transaction=True) as pipe:
        count, _ = await pipe.incr(key).expire(key, window + 5).execute()

    limit = RATE_LIMITS.get(tier, 100)
    return max(0, limit - count)
|
|
|
|
|
|
class TierMiddleware(BaseHTTPMiddleware):
    """Attach caller identity, tier and per-tier limits to ``request.state``.

    For ``/api/v1`` paths, the tier is resolved from the Bearer token with
    ``get_user_tier_from_token``, which also sets ``user_id`` and ``user_tier``.
    The matching limits are then stored as ``request.state.tier_config``. Tiers
    with no entry in the table, including ``"anonymous"``, get the free-tier
    limits. Every other path is marked anonymous with an empty ``tier_config``.
    """

    @staticmethod
    def _tier_limit_table():
        """Build the tier -> {max_products, max_customers} table from current settings."""
        return {
            "free": dict(
                max_products=settings.FREE_MAX_PRODUCTS,
                max_customers=settings.FREE_MAX_CUSTOMERS,
            ),
            "pro": dict(
                max_products=settings.PRO_MAX_PRODUCTS,
                max_customers=settings.PRO_MAX_CUSTOMERS,
            ),
            "enterprise": dict(
                max_products=9999,
                max_customers=99999,
            ),
        }

    async def dispatch(self, request: Request, call_next):
        if not request.url.path.startswith("/api/v1"):
            state = request.state
            state.user_id = None
            state.user_tier = "anonymous"
            state.tier_config = {}
            return await call_next(request)

        tier = get_user_tier_from_token(request)
        table = self._tier_limit_table()
        request.state.tier_config = table.get(tier, table["free"])
        return await call_next(request)
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply the per-tier request rate limit to authenticated ``/api/v1`` calls.

    Reads ``request.state.user_tier`` and ``request.state.user_id``, which are
    expected to have been set upstream (e.g. by ``TierMiddleware``). Requests
    with an anonymous or missing tier, or no user id, pass through unthrottled.

    - Over-limit requests get a 429 JSON response with ``Retry-After: 60``.
    - Allowed responses carry ``X-RateLimit-Remaining``.
    - If the limiter check itself fails, a warning is logged and the request is
      let through (fail-open).
    """

    async def dispatch(self, request: Request, call_next):
        if not request.url.path.startswith("/api/v1"):
            return await call_next(request)

        user_tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        if user_tier in ("anonymous", None) or not user_id:
            return await call_next(request)

        # Only the limiter lookup is guarded. Previously the downstream
        # ``call_next`` was inside the same ``try``. An exception raised by the
        # endpoint was then logged as a rate-limit failure, and the endpoint was
        # invoked a second time by the fallback ``call_next``.
        try:
            remaining = await check_rate_limit(user_id, user_tier)
        except Exception as e:
            logger.warning(f"Rate limit check failed: {e}")
            return await call_next(request)

        # NOTE(review): check_rate_limit returns max(0, limit - count), which is
        # already 0 when this request brings the count to the limit. Rejecting
        # on 0 therefore admits limit - 1 requests per window. Confirm intent.
        if remaining == 0:
            return Response(
                status_code=429,
                content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
                media_type="application/json",
                headers={"Retry-After": "60"},
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response
|
|
|
|
|
|
class QuotaMiddleware(BaseHTTPMiddleware):
    """Enforce per-day usage quotas on mutating ``/api/v1`` requests.

    Only metered callers are counted: those with a non-anonymous, non-enterprise
    ``request.state.user_tier`` and a ``user_id`` (both expected to be set
    upstream, e.g. by ``TierMiddleware``). Only paths in the quota table count.

    Each metered request increments a Redis counter keyed by user, matched route
    prefix and UTC date.

    - Going over the tier's limit raises ``QuotaExceededError(route_prefix)``.
    - Otherwise the headroom is stored on ``request.state.quota_remaining``.
    - Redis failures are logged and the request is let through (best-effort).
    """

    # Methods that never consume quota. HEAD and OPTIONS (e.g. CORS preflight)
    # used to be counted because only GET was exempt.
    _EXEMPT_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

    # TTL (seconds) for a daily counter; the date in the key already bounds its use.
    _COUNTER_TTL = 86400

    @staticmethod
    def _quota_table():
        """Return ``(path_prefix, {tier: daily_limit})`` pairs; the first prefix match wins."""
        # Order matters: a longer prefix must come before the prefix it extends.
        # Settings are read on every call, i.e. at request time.
        return [
            ("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
            # NOTE(review): these limits are named *_TRANSLATE_CHARS, but this
            # middleware counts each request as 1, not by characters. Confirm intent.
            ("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
            # The two marketing prefixes share limit values but keep separate
            # counters, because the counter key includes the matched prefix.
            ("/api/v1/marketing/generate", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
            ("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
            ("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
        ]

    @classmethod
    async def _increment_daily_counter(cls, user_id, route_key):
        """Increment and return today's (UTC) counter for ``user_id`` on ``route_key``."""
        r = await get_redis()
        day = time.strftime("%Y%m%d", time.gmtime())
        key = f"quota:{user_id}:{route_key}:{day}"
        # INCR and EXPIRE go out together in one MULTI/EXEC round trip.
        async with r.pipeline(transaction=True) as pipe:
            current, _ = await pipe.incr(key).expire(key, cls._COUNTER_TTL).execute()
        return current

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        if not path.startswith("/api/v1"):
            return await call_next(request)

        tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        # Anonymous or unidentified callers and enterprise accounts are not metered.
        if tier in ("anonymous", None, "enterprise") or not user_id:
            return await call_next(request)
        if request.method in self._EXEMPT_METHODS:
            return await call_next(request)

        # Keep the limits found during the scan. The old code looked them up
        # again with ``quota_map[matched_key]``, which indexed a list with a
        # string and raised TypeError on every metered request.
        match = next(
            (
                (prefix, limits)
                for prefix, limits in self._quota_table()
                if path.startswith(prefix)
            ),
            None,
        )
        if match is None:
            return await call_next(request)
        matched_key, limits = match

        limit = limits.get(tier)
        if limit is None:
            return await call_next(request)

        # Only the Redis work is guarded. The old version imported
        # QuotaExceededError inside the try. A Redis error raised before that
        # import made ``except QuotaExceededError`` fail with UnboundLocalError
        # instead of reaching this logged fallback.
        try:
            current = await self._increment_daily_counter(user_id, matched_key)
        except Exception as e:
            logger.warning(f"Quota check failed: {e}")
            return await call_next(request)

        if current > limit:
            # Function-scope import kept from the original (presumably to avoid
            # an import cycle; TODO confirm).
            from app.core.exceptions import QuotaExceededError

            # NOTE(review): this is raised from middleware. In Starlette/FastAPI,
            # user middleware sits outside ExceptionMiddleware, so a handler
            # registered for this exception type may not run and the client may
            # see a 500. Confirm the resulting status.
            raise QuotaExceededError(matched_key)

        request.state.quota_remaining = limit - current
        return await call_next(request)
|