Files
trade-assistant/backend/app/core/middleware.py
T
TradeMate Dev 7b62c2f8b4 feat: 修复 H5 底部导航覆盖 + 更新项目进度文档
## H5 底部导航修复 (Bug #10)
- 精简 App.vue,移除重复 tabbar,仅保留全局样式
- uni-page 设置 height: calc(100% - 50px) + overflow-y: auto
- 内容区域精确停在底部导航上方,独立滚动不再叠加
- 恢复 custom-tab-bar 组件

## 项目进度文档
- PROGRESS.md 更新至 10 个 Bug 修复
- 新增 H5 底部导航修复记录
- 新增历史变更条目
2026-05-12 20:24:42 +08:00

191 lines
6.1 KiB
Python

import logging
import time
from datetime import datetime, timezone

import redis.asyncio as aioredis
from fastapi import Request, Response
from redis.asyncio import ConnectionPool
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
from app.core.security import decode_token
logger = logging.getLogger(__name__)

# Process-wide Redis connection pool; created lazily by get_redis() on first use.
_redis_pool = None


async def get_redis():
    """Return an asyncio Redis client bound to the module's shared pool.

    The first call builds the pool from ``settings.REDIS_URL`` with at most
    20 connections; later calls reuse that same pool. A new ``Redis`` client
    object wrapping the pool is constructed on every call.
    """
    global _redis_pool
    pool = _redis_pool
    if pool is None:
        pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
        _redis_pool = pool
    return aioredis.Redis(connection_pool=pool)
def get_user_tier_from_token(request: Request) -> str:
    """Resolve the caller's tier from the ``Authorization: Bearer`` header.

    Always sets ``request.state.user_id`` and ``request.state.user_tier``.
    The caller is treated as anonymous (``None`` / ``"anonymous"``) when there
    is no bearer token or ``decode_token`` returns a falsy payload. Otherwise
    ``user_id`` is the payload's ``sub`` claim and ``user_tier`` its ``tier``
    claim, defaulting to ``"free"``.

    Returns:
        The value stored in ``request.state.user_tier``.
    """
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    token = credentials.strip()
    # The auth-scheme is case-insensitive (RFC 7235 section 2.1), so accept
    # "bearer <token>" as well as "Bearer <token>". Skip decoding when no
    # token follows the scheme.
    payload = decode_token(token) if scheme.lower() == "bearer" and token else None
    if not payload:
        request.state.user_id = None
        request.state.user_tier = "anonymous"
        return "anonymous"
    request.state.user_id = payload.get("sub")
    request.state.user_tier = payload.get("tier", "free")
    return request.state.user_tier
# Requests allowed per user within one 60-second rate-limit window, by tier.
# check_rate_limit() looks tiers up here with a fallback for unlisted tiers.
RATE_LIMITS = dict(
    free=100,
    pro=500,
    enterprise=2000,
)
async def check_rate_limit(user_id: str, tier: str) -> int:
    """Record one request for *user_id* and return its remaining allowance.

    Uses a fixed 60-second window. The Redis counter key is
    ``ratelimit:{user_id}:{epoch_seconds // 60}``, so each minute bucket
    starts counting from zero.

    Args:
        user_id: Identifier embedded in the Redis counter key.
        tier: Key into ``RATE_LIMITS``; tiers missing from it get the
            ``"free"`` limit.

    Returns:
        ``limit - count`` clamped at 0. ``count`` already includes this
        request, so this is the allowance left *after* it.
        NOTE(review): a caller that treats 0 as "exceeded" therefore also
        rejects the request that exactly reaches the limit.

    Raises:
        Redis client errors propagate unchanged.
    """
    window = 60
    key = f"ratelimit:{user_id}:{int(time.time() // window)}"
    r = await get_redis()
    # INCR and EXPIRE run in one MULTI/EXEC transaction, so a counter is never
    # left without a TTL (e.g. if the connection drops between two separate
    # commands).
    async with r.pipeline(transaction=True) as pipe:
        count, _ = await pipe.incr(key).expire(key, window + 5).execute()
    limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
    return max(0, limit - count)
class TierMiddleware(BaseHTTPMiddleware):
    """Populate ``request.state`` with the caller's identity, tier and tier caps.

    For ``/api/v1`` requests, ``get_user_tier_from_token`` sets ``user_id`` and
    ``user_tier``. ``tier_config`` becomes that tier's product/customer caps,
    or the ``"free"`` caps for any tier without its own entry (such as
    ``"anonymous"``).

    For any other path, the caller is recorded as anonymous and
    ``tier_config`` is an empty dict.
    """

    async def dispatch(self, request: Request, call_next):
        if not request.url.path.startswith("/api/v1"):
            request.state.user_id = None
            request.state.user_tier = "anonymous"
            request.state.tier_config = {}
            return await call_next(request)

        tier = get_user_tier_from_token(request)
        caps_by_tier = {
            "free": dict(
                max_products=settings.FREE_MAX_PRODUCTS,
                max_customers=settings.FREE_MAX_CUSTOMERS,
            ),
            "pro": dict(
                max_products=settings.PRO_MAX_PRODUCTS,
                max_customers=settings.PRO_MAX_CUSTOMERS,
            ),
            # Hard-coded high caps; presumably meant as "unlimited" -- confirm.
            "enterprise": dict(max_products=9999, max_customers=99999),
        }
        request.state.tier_config = caps_by_tier.get(tier, caps_by_tier["free"])
        return await call_next(request)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce the per-minute request limit for authenticated ``/api/v1`` calls.

    Reads ``request.state.user_tier`` and ``request.state.user_id``. These are
    expected to be populated upstream (presumably by TierMiddleware -- verify
    the middleware registration order).

    - Anonymous or user-less requests pass through unmetered.
    - If the Redis check itself fails, the request proceeds unmetered
      (fail open).
    - When the check returns 0, a 429 JSON response is returned.
    - Otherwise the downstream response gets an ``X-RateLimit-Remaining``
      header.
    """

    async def dispatch(self, request: Request, call_next):
        if not request.url.path.startswith("/api/v1"):
            return await call_next(request)
        user_tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        if user_tier in ("anonymous", None) or not user_id:
            return await call_next(request)

        # Guard only the Redis lookup. If call_next() were inside this try, an
        # exception raised by the downstream handler would be caught here and
        # the request would be dispatched a second time.
        try:
            remaining = await check_rate_limit(user_id, user_tier)
        except Exception as exc:
            logger.warning("Rate limit check failed: %s", exc)
            return await call_next(request)

        # NOTE(review): check_rate_limit() reports the allowance left *after*
        # this request, so treating 0 as "exceeded" also rejects the request
        # that exactly reaches the limit -- confirm the intended allowance.
        if remaining == 0:
            return Response(
                status_code=429,
                content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
                media_type="application/json",
                headers={"Retry-After": "60"},
            )
        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response
class QuotaMiddleware(BaseHTTPMiddleware):
    """Enforce daily per-feature quotas for authenticated free and pro users.

    Applies only to requests that meet all of these conditions:
    - the path is under ``/api/v1``;
    - the method is not GET/HEAD/OPTIONS;
    - the user is authenticated and not enterprise;
    - the path matches one of the metered feature prefixes.

    Each such request increments a Redis counter keyed by user, feature prefix
    and UTC date.

    - When the counter exceeds the tier's daily limit, ``QuotaExceededError``
      is raised.
    - On a Redis error the failure is logged and the request proceeds
      unmetered.
    - Otherwise the remaining allowance is stored on
      ``request.state.quota_remaining``.
    """

    # Request methods that never consume quota.
    _EXEMPT_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

    @staticmethod
    def _daily_limits():
        """Map each metered path prefix to its ``{tier: daily limit}`` table."""
        return {
            "/api/v1/translate": {
                # NOTE(review): these settings are named *_TRANSLATE_CHARS but
                # dispatch() counts requests, not characters -- confirm unit.
                "free": settings.FREE_DAILY_TRANSLATE_CHARS,
                "pro": settings.PRO_DAILY_TRANSLATE_CHARS,
            },
            "/api/v1/translate/reply": {
                "free": settings.FREE_DAILY_REPLIES,
                "pro": settings.PRO_DAILY_REPLIES,
            },
            "/api/v1/marketing": {
                "free": settings.FREE_DAILY_MARKETING,
                "pro": settings.PRO_DAILY_MARKETING,
            },
            "/api/v1/quotations": {
                "free": settings.FREE_DAILY_QUOTATIONS,
                "pro": settings.PRO_DAILY_QUOTATIONS,
            },
        }

    @classmethod
    def _match_quota(cls, path, tier):
        """Return ``(prefix, limit)`` for *path* and *tier*, or ``None`` if unmetered.

        The longest matching prefix wins. "/api/v1/translate" is itself a
        prefix of "/api/v1/translate/reply", so first-match in insertion order
        would bill reply requests against the translate quota.
        """
        limits_by_prefix = cls._daily_limits()
        for prefix in sorted(limits_by_prefix, key=len, reverse=True):
            if path.startswith(prefix):
                limit = limits_by_prefix[prefix].get(tier)
                return None if limit is None else (prefix, limit)
        return None

    async def dispatch(self, request: Request, call_next):
        if not request.url.path.startswith("/api/v1"):
            return await call_next(request)
        tier = getattr(request.state, "user_tier", None)
        user_id = getattr(request.state, "user_id", None)
        # Anonymous, user-less and enterprise callers are not metered.
        if tier in ("anonymous", None, "enterprise") or not user_id:
            return await call_next(request)
        if request.method in self._EXEMPT_METHODS:
            return await call_next(request)
        match = self._match_quota(request.url.path, tier)
        if match is None:
            return await call_next(request)
        prefix, limit = match

        # Deferred import (presumably to avoid an import cycle -- confirm).
        # The raise below sits outside the try, so only Redis errors take the
        # fail-open path.
        from app.core.exceptions import QuotaExceededError

        day = datetime.now(timezone.utc).strftime("%Y%m%d")
        key = f"quota:{user_id}:{prefix}:{day}"
        try:
            r = await get_redis()
            # INCR and EXPIRE in one MULTI/EXEC so the counter always gets a TTL.
            async with r.pipeline(transaction=True) as pipe:
                current, _ = await pipe.incr(key).expire(key, 86400).execute()
        except Exception as exc:
            # Fail open: on a Redis error the request proceeds unmetered.
            logger.warning("Quota check failed: %s", exc)
            return await call_next(request)

        if current > limit:
            # NOTE(review): an exception raised from a BaseHTTPMiddleware runs
            # outside Starlette's ExceptionMiddleware, so app-level exception
            # handlers may not turn this into a 4xx -- verify what clients see.
            raise QuotaExceededError(prefix)
        request.state.quota_remaining = limit - current
        return await call_next(request)