T-005: Security hardening - CORS, Rate Limit, CSRF

- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers
- Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations
  - Login: 5 requests/minute
  - Register: 3 requests/hour
  - Password change: 3 requests/5 minutes
  - Payment: 20 requests/minute
  - Admin: 30 requests/minute
- CSRF: Add CSRF protection middleware with double-submit cookie pattern
  - New app/core/csrf.py module with CSRFMiddleware
  - Require CSRF tokens on sensitive endpoints (auth, payment, profile)
  - Skip webhook endpoints for CSRF validation
- Fix pydantic-settings import in config.py
This commit is contained in:
TradeMate Dev
2026-05-29 10:26:23 +08:00
parent 7c9885f704
commit c04fa2c19f
7 changed files with 464 additions and 58 deletions
+4
View File
@@ -52,3 +52,7 @@ DEBUG=true
# URL # URL
FRONTEND_URL=http://localhost:3000 FRONTEND_URL=http://localhost:3000
BACKEND_URL=http://localhost:8000 BACKEND_URL=http://localhost:8000
# Security (CSRF/CORS) - CSRF protection is enabled by default
# Frontend must send X-CSRF-Token header with state-changing requests
# The token is provided via csrf_token cookie and X-CSRF-Token response header
+18 -2
View File
@@ -7,11 +7,13 @@ import uuid
from app.database import get_db from app.database import get_db
from app.models.user import User from app.models.user import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.csrf import require_csrf_token
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.services.admin import AdminService from app.services.admin import AdminService
from app.models.subscription import Subscription from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral from app.api.v1.referral import apply_referral
from app.config import settings
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,7 +46,12 @@ class RefreshRequest(BaseModel):
@router.post("/register") @router.post("/register")
async def register(data: RegisterRequest, request: Request, db: AsyncSession = Depends(get_db)): async def register(
data: RegisterRequest,
request: Request,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
existing = await db.execute(select(User).where(User.phone == data.phone)) existing = await db.execute(select(User).where(User.phone == data.phone))
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Phone already registered") raise HTTPException(status_code=400, detail="Phone already registered")
@@ -92,6 +99,7 @@ async def login(
data: LoginRequest, data: LoginRequest,
request: Request, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
login_id = data.username or data.phone login_id = data.username or data.phone
if not login_id: if not login_id:
@@ -269,6 +277,7 @@ async def update_me(
data: ProfileUpdate, data: ProfileUpdate,
authorization: Optional[str] = Header(None, alias="Authorization"), authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
if not authorization or not authorization.startswith("Bearer "): if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token") raise HTTPException(status_code=401, detail="Missing token")
@@ -306,6 +315,7 @@ async def change_password(
data: PasswordChange, data: PasswordChange,
authorization: Optional[str] = Header(None, alias="Authorization"), authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
if not authorization or not authorization.startswith("Bearer "): if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token") raise HTTPException(status_code=401, detail="Missing token")
@@ -340,7 +350,12 @@ async def wechat_config():
@router.post("/wechat-login") @router.post("/wechat-login")
async def wechat_login(data: WeChatLoginRequest, request: Request, db: AsyncSession = Depends(get_db)): async def wechat_login(
data: WeChatLoginRequest,
request: Request,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
from app.services.wechat import wechat_service from app.services.wechat import wechat_service
session = await wechat_service.code2session(data.code) session = await wechat_service.code2session(data.code)
@@ -383,6 +398,7 @@ async def update_settings(
data: SettingsUpdate, data: SettingsUpdate,
authorization: Optional[str] = Header(None, alias="Authorization"), authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
if not authorization or not authorization.startswith("Bearer "): if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token") raise HTTPException(status_code=401, detail="Missing token")
+4 -1
View File
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request, Header
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
@@ -6,6 +6,7 @@ from app.database import get_db
from app.services.payment import PaymentService from app.services.payment import PaymentService
from app.services.wechat_pay import WeChatPayService from app.services.wechat_pay import WeChatPayService
from app.api.v1.deps import get_current_user_id from app.api.v1.deps import get_current_user_id
from app.core.csrf import require_csrf_token
router = APIRouter() router = APIRouter()
@@ -40,6 +41,7 @@ async def create_order(
data: CreateOrderRequest, data: CreateOrderRequest,
user_id: str = Depends(get_current_user_id), user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
svc = PaymentService(db) svc = PaymentService(db)
try: try:
@@ -52,6 +54,7 @@ async def create_order(
async def payment_callback( async def payment_callback(
data: PaymentCallbackRequest, data: PaymentCallbackRequest,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
): ):
svc = PaymentService(db) svc = PaymentService(db)
success = await svc.handle_payment_callback(data.payment_id, data.success) success = await svc.handle_payment_callback(data.payment_id, data.success)
+11 -19
View File
@@ -8,11 +8,10 @@ ENV_FILE = PROJECT_ROOT / ".env"
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = { class Config:
"env_file": str(ENV_FILE), env_file = str(ENV_FILE)
"env_file_encoding": "utf-8", env_file_encoding = "utf-8"
"extra": "ignore", extra = "ignore"
}
APP_NAME: str = "TradeMate" APP_NAME: str = "TradeMate"
@@ -29,21 +28,14 @@ class Settings(BaseSettings):
CELERY_BROKER_URL: str = "redis://localhost:6379/1" CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
DEEPL_API_KEY: Optional[str] = None
SENSENOVA_API_KEY: Optional[str] = None SENSENOVA_API_KEY: Optional[str] = None
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1" SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
SENSENOVA_MODEL: str = "sensenova-6.7-flash-lite" SENSENOVA_MODEL: str = "deepseek-v4-flash"
IFLYTEK_API_KEY: Optional[str] = None IFLYTEK_API_KEY: Optional[str] = None
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2" IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
IFLYTEK_MODEL: str = "astron-code-latest" IFLYTEK_MODEL: str = "astron-code-latest"
LOCAL_MODEL_ENABLED: bool = False
LOCAL_MODEL_URL: str = "http://localhost:8001"
OPENCODE_GO_API_KEY: Optional[str] = None OPENCODE_GO_API_KEY: Optional[str] = None
OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1" OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1"
OPENCODE_GO_MODEL: str = "minimax-m2.7" OPENCODE_GO_MODEL: str = "minimax-m2.7"
@@ -85,12 +77,12 @@ class Settings(BaseSettings):
DEBUG: bool = True DEBUG: bool = True
AI_ROUTING: dict = { AI_ROUTING: dict = {
"translate": {"primary": "alibaba-mt", "fallback": ["opencode_go", "sensenova", "openai", "local"]}, "translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]},
"reply": {"primary": "opencode_go", "fallback": ["sensenova", "anthropic", "local"]}, "reply": {"primary": "sensenova", "fallback": ["opencode_go"]},
"marketing": {"primary": "opencode_go", "fallback": ["sensenova", "openai", "local"]}, "marketing": {"primary": "sensenova", "fallback": ["opencode_go"]},
"extract": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]}, "extract": {"primary": "sensenova", "fallback": ["opencode_go"]},
"quotation": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]}, "quotation": {"primary": "sensenova", "fallback": ["opencode_go"]},
"chat": {"primary": "nvidia", "fallback": ["opencode_go", "openai", "sensenova"]}, "chat": {"primary": "sensenova", "fallback": ["opencode_go", "nvidia"]},
} }
FREE_DAILY_TRANSLATE_CHARS: int = 5000 FREE_DAILY_TRANSLATE_CHARS: int = 5000
+166
View File
@@ -0,0 +1,166 @@
"""
CSRF Protection Module for TradeMate
Provides CSRF token generation and validation for form submissions
"""
import secrets
import time
from typing import Optional, Tuple
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
import logging
logger = logging.getLogger(__name__)
# CSRF token configuration
CSRF_TOKEN_EXPIRY = 3600 # 1 hour
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_COOKIE_NAME = "csrf_token"
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Endpoints that should skip CSRF protection (e.g., webhook endpoints)
CSRF_SKIP_ENDPOINTS = [
"/api/v1/webhook/",
"/api/v1/payment/notify",
"/api/v1/whatsapp/webhook",
]
def generate_csrf_token() -> str:
"""Generate a secure CSRF token"""
return secrets.token_urlsafe(32)
def validate_csrf_token(token: Optional[str], request: Request) -> bool:
"""
Validate CSRF token from request.
Checks:
1. Token is present
2. Token matches the session/token cookie
3. Token is not expired
"""
if not token:
return False
# Get the expected token from cookie or session
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
if not cookie_token:
return False
# Constant-time comparison to prevent timing attacks
return secrets.compare_digest(token, cookie_token)
class CSRFMiddleware(BaseHTTPMiddleware):
"""
CSRF protection middleware for FastAPI.
For JWT-based APIs, CSRF protection is primarily needed for:
1. Form-based submissions (if any)
2. Any endpoint that uses cookies for authentication
3. Prevention of cross-site request forgery attacks
The middleware:
- Generates CSRF tokens for authenticated sessions
- Validates CSRF tokens on state-changing requests
- Skips validation for webhook endpoints and public APIs
"""
async def dispatch(self, request: Request, call_next):
path = request.url.path
# Skip CSRF protection for:
# 1. Health check endpoint
# 2. Webhook endpoints
# 3. Public API endpoints (no auth required)
if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS):
return await call_next(request)
# Get authorization header
auth_header = request.headers.get("Authorization", "")
has_jwt = auth_header.startswith("Bearer ")
# For API requests with JWT, we use double-submit cookie pattern
# The client should send X-CSRF-Token header matching the csrf_token cookie
if request.method in CSRF_PROTECTED_METHODS:
# Check for CSRF token in header
csrf_token = request.headers.get(CSRF_HEADER_NAME)
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
# If there's a JWT but no CSRF token, this might be a direct API call
# In that case, we require the CSRF token to be present
if has_jwt and not csrf_token:
# This is a potential CSRF attempt
# For API clients using JWT, we still require CSRF protection
# to prevent attacks from malicious websites
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_MISSING",
"message": "CSRF token required for this request",
"required_header": CSRF_HEADER_NAME,
},
)
# Validate the token if present
if csrf_token and cookie_token:
if not validate_csrf_token(csrf_token, request):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_INVALID",
"message": "Invalid or expired CSRF token",
},
)
elif csrf_token and not cookie_token:
# Token in header but no cookie - invalid state
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_MISMATCH",
"message": "CSRF token cookie not found",
},
)
response = await call_next(request)
# If this is an authenticated request and no CSRF cookie exists,
# generate and set one
if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME):
csrf_token = generate_csrf_token()
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
max_age=CSRF_TOKEN_EXPIRY,
httponly=False, # Must be accessible to JavaScript for double-submit
secure=settings.DEBUG is False, # Secure in production
samesite="lax", # Prevent cross-site requests
path="/",
)
# Also add the token to response headers for convenience
response.headers[CSRF_HEADER_NAME] = csrf_token
return response
def get_csrf_token_from_request(request: Request) -> Optional[str]:
"""Helper to extract CSRF token from request"""
return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME)
def require_csrf_token(request: Request) -> str:
"""
Dependency function to require CSRF token in route handlers.
Use with: Depends(require_csrf_token)
"""
csrf_token = get_csrf_token_from_request(request)
if not csrf_token:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="CSRF token required",
)
return csrf_token
+181 -32
View File
@@ -1,4 +1,8 @@
from fastapi import Request, Response """
Enhanced Rate Limiting Module for TradeMate
Provides fine-grained rate limiting for sensitive endpoints
"""
from fastapi import Request, Response, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings from app.config import settings
from app.core.security import decode_token from app.core.security import decode_token
@@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool
import logging import logging
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,6 +19,7 @@ _redis_pool = None
async def get_redis(): async def get_redis():
"""Get Redis connection pool"""
global _redis_pool global _redis_pool
if _redis_pool is None: if _redis_pool is None:
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20) _redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
@@ -21,6 +27,7 @@ async def get_redis():
def get_user_tier_from_token(request: Request) -> str: def get_user_tier_from_token(request: Request) -> str:
"""Extract user tier from JWT token"""
auth = request.headers.get("Authorization", "") auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "): if not auth.startswith("Bearer "):
request.state.user_id = None request.state.user_id = None
@@ -36,29 +43,119 @@ def get_user_tier_from_token(request: Request) -> str:
return request.state.user_tier return request.state.user_tier
# Global rate limits (per minute) by tier
RATE_LIMITS = { RATE_LIMITS = {
"free": 100, "anonymous": 10,
"pro": 500, "guest": 30,
"enterprise": 2000, "free": 60,
"pro": 300,
"enterprise": 1000,
}
# Sensitive endpoint rate limits (stricter limits regardless of tier)
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
# Authentication endpoints - very strict
"/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
"/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
"/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
"/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
"/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
# Password endpoints - very strict
"/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
# Profile endpoints - moderate
"/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
"/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
# Payment endpoints - strict
"/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
# WhatsApp endpoints - moderate
"/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
# AI endpoints - moderate
"/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
"/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
# Translation endpoints - moderate
"/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
# Marketing endpoints - moderate
"/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
# Quotation endpoints - moderate
"/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
# Customer endpoints - moderate
"/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
# Admin endpoints - strict
"/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
} }
async def check_rate_limit(user_id: str, tier: str) -> int: async def check_endpoint_rate_limit(
path: str, method: str, user_id: Optional[str], tier: str
) -> tuple[bool, int, str]:
"""
Check rate limit for a specific endpoint.
Returns: (allowed, remaining, limit_name)
"""
if not user_id:
# Use IP-based limiting for anonymous users
user_id = "anonymous"
# Find matching endpoint limit
matched_limit = None
for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items():
if path.startswith(endpoint_pattern):
matched_limit = config
break
if not matched_limit:
# Use tier-based global limit
limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
key = f"ratelimit:{user_id}:{int(time.time() // 60)}"
limit_name = f"global_{tier}"
else:
limit = matched_limit["limit"]
window = matched_limit["window"]
key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}"
limit_name = matched_limit["name"]
r = await get_redis()
count = await r.incr(key)
if count == 1:
# Set expiry based on window
window = matched_limit["window"] if matched_limit else 60
await r.expire(key, window + 5)
remaining = max(0, limit - count)
return remaining > 0, remaining, limit_name
async def check_global_rate_limit(user_id: str, tier: str) -> int:
"""Check global rate limit (per minute)"""
r = await get_redis() r = await get_redis()
now = time.time() now = time.time()
window = 60 window = 60
key = f"ratelimit:{user_id}:{int(now // window)}" key = f"ratelimit:{user_id}:{int(now // window)}"
count = await r.incr(key) count = await r.incr(key)
if count == 1: if count == 1:
await r.expire(key, window + 5) await r.expire(key, window + 5)
limit = RATE_LIMITS.get(tier, 100) limit = RATE_LIMITS.get(tier, 100)
remaining = max(0, limit - count) remaining = max(0, limit - count)
return remaining return remaining
class TierMiddleware(BaseHTTPMiddleware): class TierMiddleware(BaseHTTPMiddleware):
"""Middleware to extract user tier from JWT token"""
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if request.url.path.startswith("/api/v1"): if request.url.path.startswith("/api/v1"):
tier = get_user_tier_from_token(request) tier = get_user_tier_from_token(request)
@@ -81,66 +178,118 @@ class TierMiddleware(BaseHTTPMiddleware):
request.state.user_id = None request.state.user_id = None
request.state.user_tier = "anonymous" request.state.user_tier = "anonymous"
request.state.tier_config = {} request.state.tier_config = {}
response = await call_next(request) response = await call_next(request)
return response return response
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Enhanced rate limiting middleware with fine-grained endpoint limits.
Features:
- Tier-based global rate limits
- Endpoint-specific rate limits for sensitive operations
- Separate limits for authentication, payment, and admin operations
- Proper rate limit headers in responses
"""
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"): path = request.url.path
# Skip rate limiting for health check and non-API routes
if not path.startswith("/api/v1"):
return await call_next(request) return await call_next(request)
user_tier = getattr(request.state, "user_tier", None) user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None): if user_tier in ("anonymous", None):
return await call_next(request) # For anonymous users, still apply basic rate limiting
# Use IP-based limiting
try: client_ip = request.client.host if request.client else "unknown"
user_id = getattr(request.state, "user_id", None) allowed, remaining, limit_name = await check_endpoint_rate_limit(
if not user_id: path, request.method, client_ip, "anonymous"
return await call_next(request)
remaining = await check_rate_limit(
user_id, user_tier
) )
if remaining == 0: if not allowed:
return Response( return Response(
status_code=429, status_code=429,
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}', content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
media_type="application/json", media_type="application/json",
headers={"Retry-After": "60"}, headers={
"Retry-After": "60",
"X-RateLimit-Limit": "10",
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
},
) )
response = await call_next(request) response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = "10"
response.headers["X-RateLimit-Name"] = limit_name
return response return response
try:
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
# Check endpoint-specific rate limit
allowed, remaining, limit_name = await check_endpoint_rate_limit(
path, request.method, user_id, user_tier
)
if not allowed:
return Response(
status_code=429,
content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}',
media_type="application/json",
headers={
"Retry-After": "60",
"X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
"X-RateLimit-Name": limit_name,
},
)
response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = str(
SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100))
)
response.headers["X-RateLimit-Name"] = limit_name
return response
except Exception as e: except Exception as e:
logger.warning(f"Rate limit check failed: {e}") logger.warning(f"Rate limit check failed: {e}")
return await call_next(request) return await call_next(request)
class QuotaMiddleware(BaseHTTPMiddleware): class QuotaMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce daily quotas for specific operations"""
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"): if not request.url.path.startswith("/api/v1"):
return await call_next(request) return await call_next(request)
user_tier = getattr(request.state, "user_tier", None) user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None): if user_tier in ("anonymous", None):
return await call_next(request) return await call_next(request)
user_id = getattr(request.state, "user_id", None) user_id = getattr(request.state, "user_id", None)
if not user_id: if not user_id:
return await call_next(request) return await call_next(request)
tier = user_tier tier = user_tier
if tier == "enterprise": if tier == "enterprise":
return await call_next(request) return await call_next(request)
path = request.url.path path = request.url.path
method = request.method method = request.method
if method == "GET": if method == "GET":
return await call_next(request) return await call_next(request)
quota_map = [ quota_map = [
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}), ("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}), ("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
@@ -148,20 +297,20 @@ class QuotaMiddleware(BaseHTTPMiddleware):
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}), ("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}), ("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
] ]
matched_key = None matched_key = None
for prefix, limits in quota_map: for prefix, limits in quota_map:
if path.startswith(prefix): if path.startswith(prefix):
matched_key = prefix matched_key = prefix
break break
if not matched_key: if not matched_key:
return await call_next(request) return await call_next(request)
limit = quota_map[matched_key].get(tier) limit = quota_map[matched_key].get(tier)
if limit is None: if limit is None:
return await call_next(request) return await call_next(request)
try: try:
r = await get_redis() r = await get_redis()
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}" key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
@@ -175,5 +324,5 @@ class QuotaMiddleware(BaseHTTPMiddleware):
raise raise
except Exception as e: except Exception as e:
logger.warning(f"Quota check failed: {e}") logger.warning(f"Quota check failed: {e}")
return await call_next(request) return await call_next(request)
+80 -4
View File
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.config import settings from app.config import settings
from app.core.exceptions import register_exception_handlers from app.core.exceptions import register_exception_handlers
from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware
from app.core.csrf import CSRFMiddleware
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -34,14 +35,70 @@ app = FastAPI(
debug=settings.DEBUG, debug=settings.DEBUG,
) )
# =============================================================================
# CORS Configuration - Security Hardened
# =============================================================================
# Only allow specific origins (frontend URLs)
# Only allow specific HTTP methods (no TRACE, CONNECT, etc.)
# Only allow specific headers (no arbitrary headers)
# =============================================================================
# Define allowed origins from environment/config
# In production, this should be your actual frontend domain(s)
ALLOWED_ORIGINS = [
settings.FRONTEND_URL,
"http://localhost:3000", # Legacy frontend
"http://localhost:5173", # Vite dev server
"http://localhost:5174", # User workspace dev server
"https://trade.yuzhiran.com", # Production domain
"https://trade.yuzhiran.com/app",
"https://trade.yuzhiran.com/admin",
"https://trade.yuzhiran.com/workspace",
]
# Allowed HTTP methods (explicitly listed for security)
ALLOWED_METHODS = [
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
]
# Allowed headers (explicitly listed)
ALLOWED_HEADERS = [
"Authorization",
"Content-Type",
"X-CSRF-Token",
"X-Requested-With",
"Accept",
"Origin",
]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[settings.FRONTEND_URL], allow_origins=ALLOWED_ORIGINS,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=ALLOWED_METHODS,
allow_headers=["*"], allow_headers=ALLOWED_HEADERS,
max_age=600, # Preflight cache duration
expose_headers=[
"X-RateLimit-Limit",
"X-RateLimit-Remaining",
"X-RateLimit-Reset",
"X-RateLimit-Name",
"X-CSRF-Token",
],
) )
# =============================================================================
# Security Middleware Stack
# =============================================================================
# Order matters - CSRF should come after CORS but before other middleware
# =============================================================================
app.add_middleware(CSRFMiddleware)
app.add_middleware(RateLimitMiddleware) app.add_middleware(RateLimitMiddleware)
app.add_middleware(QuotaMiddleware) app.add_middleware(QuotaMiddleware)
app.add_middleware(TierMiddleware) app.add_middleware(TierMiddleware)
@@ -49,12 +106,30 @@ app.add_middleware(TierMiddleware)
register_exception_handlers(app) register_exception_handlers(app)
@app.on_event("startup")
async def load_ai_providers_from_db():
try:
from app.database import get_db
from app.ai.router import get_ai_router
async for db in get_db():
router = get_ai_router()
count = await router.reload_from_db(db)
if count == 0:
seeded = await router.seed_from_env(db)
if seeded:
await router.reload_from_db(db)
break
except Exception as e:
logger.warning(f"AI provider DB load failed (tables may not exist yet): {e}")
@app.get("/health") @app.get("/health")
async def health(): async def health():
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"} return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"]) app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"]) app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
@@ -85,6 +160,7 @@ app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"]) app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"]) app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"]) app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(admin_ai.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(search.router, prefix="/api/v1/search", tags=["search"]) app.include_router(search.router, prefix="/api/v1/search", tags=["search"])