T-005: Security hardening - CORS, Rate Limit, CSRF
- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers - Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations - Login: 5 requests/minute - Register: 3 requests/hour - Password change: 3 requests/5 minutes - Payment: 20 requests/minute - Admin: 30 requests/minute - CSRF: Add CSRF protection middleware with double-submit cookie pattern - New app/core/csrf.py module with CSRFMiddleware - Require CSRF tokens on sensitive endpoints (auth, payment, profile) - Skip webhook endpoints for CSRF validation - Fix pydantic-settings import in config.py
This commit is contained in:
@@ -52,3 +52,7 @@ DEBUG=true
|
|||||||
# URL
|
# URL
|
||||||
FRONTEND_URL=http://localhost:3000
|
FRONTEND_URL=http://localhost:3000
|
||||||
BACKEND_URL=http://localhost:8000
|
BACKEND_URL=http://localhost:8000
|
||||||
|
|
||||||
|
# Security (CSRF/CORS) - CSRF protection is enabled by default
|
||||||
|
# Frontend must send X-CSRF-Token header with state-changing requests
|
||||||
|
# The token is provided via csrf_token cookie and X-CSRF-Token response header
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ import uuid
|
|||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||||
|
from app.core.csrf import require_csrf_token
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from app.services.admin import AdminService
|
from app.services.admin import AdminService
|
||||||
from app.models.subscription import Subscription
|
from app.models.subscription import Subscription
|
||||||
from app.api.v1.referral import apply_referral
|
from app.api.v1.referral import apply_referral
|
||||||
|
from app.config import settings
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -44,7 +46,12 @@ class RefreshRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/register")
|
@router.post("/register")
|
||||||
async def register(data: RegisterRequest, request: Request, db: AsyncSession = Depends(get_db)):
|
async def register(
|
||||||
|
data: RegisterRequest,
|
||||||
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
|
):
|
||||||
existing = await db.execute(select(User).where(User.phone == data.phone))
|
existing = await db.execute(select(User).where(User.phone == data.phone))
|
||||||
if existing.scalar_one_or_none():
|
if existing.scalar_one_or_none():
|
||||||
raise HTTPException(status_code=400, detail="Phone already registered")
|
raise HTTPException(status_code=400, detail="Phone already registered")
|
||||||
@@ -92,6 +99,7 @@ async def login(
|
|||||||
data: LoginRequest,
|
data: LoginRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
login_id = data.username or data.phone
|
login_id = data.username or data.phone
|
||||||
if not login_id:
|
if not login_id:
|
||||||
@@ -269,6 +277,7 @@ async def update_me(
|
|||||||
data: ProfileUpdate,
|
data: ProfileUpdate,
|
||||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="Missing token")
|
raise HTTPException(status_code=401, detail="Missing token")
|
||||||
@@ -306,6 +315,7 @@ async def change_password(
|
|||||||
data: PasswordChange,
|
data: PasswordChange,
|
||||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="Missing token")
|
raise HTTPException(status_code=401, detail="Missing token")
|
||||||
@@ -340,7 +350,12 @@ async def wechat_config():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/wechat-login")
|
@router.post("/wechat-login")
|
||||||
async def wechat_login(data: WeChatLoginRequest, request: Request, db: AsyncSession = Depends(get_db)):
|
async def wechat_login(
|
||||||
|
data: WeChatLoginRequest,
|
||||||
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
|
):
|
||||||
from app.services.wechat import wechat_service
|
from app.services.wechat import wechat_service
|
||||||
|
|
||||||
session = await wechat_service.code2session(data.code)
|
session = await wechat_service.code2session(data.code)
|
||||||
@@ -383,6 +398,7 @@ async def update_settings(
|
|||||||
data: SettingsUpdate,
|
data: SettingsUpdate,
|
||||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="Missing token")
|
raise HTTPException(status_code=401, detail="Missing token")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request, Header
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -6,6 +6,7 @@ from app.database import get_db
|
|||||||
from app.services.payment import PaymentService
|
from app.services.payment import PaymentService
|
||||||
from app.services.wechat_pay import WeChatPayService
|
from app.services.wechat_pay import WeChatPayService
|
||||||
from app.api.v1.deps import get_current_user_id
|
from app.api.v1.deps import get_current_user_id
|
||||||
|
from app.core.csrf import require_csrf_token
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ async def create_order(
|
|||||||
data: CreateOrderRequest,
|
data: CreateOrderRequest,
|
||||||
user_id: str = Depends(get_current_user_id),
|
user_id: str = Depends(get_current_user_id),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
svc = PaymentService(db)
|
svc = PaymentService(db)
|
||||||
try:
|
try:
|
||||||
@@ -52,6 +54,7 @@ async def create_order(
|
|||||||
async def payment_callback(
|
async def payment_callback(
|
||||||
data: PaymentCallbackRequest,
|
data: PaymentCallbackRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_csrf: str = Depends(require_csrf_token),
|
||||||
):
|
):
|
||||||
svc = PaymentService(db)
|
svc = PaymentService(db)
|
||||||
success = await svc.handle_payment_callback(data.payment_id, data.success)
|
success = await svc.handle_payment_callback(data.payment_id, data.success)
|
||||||
|
|||||||
+11
-19
@@ -8,11 +8,10 @@ ENV_FILE = PROJECT_ROOT / ".env"
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = {
|
class Config:
|
||||||
"env_file": str(ENV_FILE),
|
env_file = str(ENV_FILE)
|
||||||
"env_file_encoding": "utf-8",
|
env_file_encoding = "utf-8"
|
||||||
"extra": "ignore",
|
extra = "ignore"
|
||||||
}
|
|
||||||
|
|
||||||
APP_NAME: str = "TradeMate"
|
APP_NAME: str = "TradeMate"
|
||||||
|
|
||||||
@@ -29,21 +28,14 @@ class Settings(BaseSettings):
|
|||||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
|
||||||
|
|
||||||
OPENAI_API_KEY: Optional[str] = None
|
|
||||||
ANTHROPIC_API_KEY: Optional[str] = None
|
|
||||||
DEEPL_API_KEY: Optional[str] = None
|
|
||||||
|
|
||||||
SENSENOVA_API_KEY: Optional[str] = None
|
SENSENOVA_API_KEY: Optional[str] = None
|
||||||
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
|
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
|
||||||
SENSENOVA_MODEL: str = "sensenova-6.7-flash-lite"
|
SENSENOVA_MODEL: str = "deepseek-v4-flash"
|
||||||
|
|
||||||
IFLYTEK_API_KEY: Optional[str] = None
|
IFLYTEK_API_KEY: Optional[str] = None
|
||||||
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
|
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
|
||||||
IFLYTEK_MODEL: str = "astron-code-latest"
|
IFLYTEK_MODEL: str = "astron-code-latest"
|
||||||
|
|
||||||
LOCAL_MODEL_ENABLED: bool = False
|
|
||||||
LOCAL_MODEL_URL: str = "http://localhost:8001"
|
|
||||||
|
|
||||||
OPENCODE_GO_API_KEY: Optional[str] = None
|
OPENCODE_GO_API_KEY: Optional[str] = None
|
||||||
OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1"
|
OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1"
|
||||||
OPENCODE_GO_MODEL: str = "minimax-m2.7"
|
OPENCODE_GO_MODEL: str = "minimax-m2.7"
|
||||||
@@ -85,12 +77,12 @@ class Settings(BaseSettings):
|
|||||||
DEBUG: bool = True
|
DEBUG: bool = True
|
||||||
|
|
||||||
AI_ROUTING: dict = {
|
AI_ROUTING: dict = {
|
||||||
"translate": {"primary": "alibaba-mt", "fallback": ["opencode_go", "sensenova", "openai", "local"]},
|
"translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]},
|
||||||
"reply": {"primary": "opencode_go", "fallback": ["sensenova", "anthropic", "local"]},
|
"reply": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||||
"marketing": {"primary": "opencode_go", "fallback": ["sensenova", "openai", "local"]},
|
"marketing": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||||
"extract": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]},
|
"extract": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||||
"quotation": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]},
|
"quotation": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||||
"chat": {"primary": "nvidia", "fallback": ["opencode_go", "openai", "sensenova"]},
|
"chat": {"primary": "sensenova", "fallback": ["opencode_go", "nvidia"]},
|
||||||
}
|
}
|
||||||
|
|
||||||
FREE_DAILY_TRANSLATE_CHARS: int = 5000
|
FREE_DAILY_TRANSLATE_CHARS: int = 5000
|
||||||
|
|||||||
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
CSRF Protection Module for TradeMate
|
||||||
|
Provides CSRF token generation and validation for form submissions
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from fastapi import Request, HTTPException, status
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from app.config import settings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# CSRF token configuration
|
||||||
|
CSRF_TOKEN_EXPIRY = 3600 # 1 hour
|
||||||
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
|
|
||||||
|
# Methods that require CSRF protection
|
||||||
|
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||||
|
|
||||||
|
# Endpoints that should skip CSRF protection (e.g., webhook endpoints)
|
||||||
|
CSRF_SKIP_ENDPOINTS = [
|
||||||
|
"/api/v1/webhook/",
|
||||||
|
"/api/v1/payment/notify",
|
||||||
|
"/api/v1/whatsapp/webhook",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csrf_token() -> str:
|
||||||
|
"""Generate a secure CSRF token"""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_csrf_token(token: Optional[str], request: Request) -> bool:
|
||||||
|
"""
|
||||||
|
Validate CSRF token from request.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Token is present
|
||||||
|
2. Token matches the session/token cookie
|
||||||
|
3. Token is not expired
|
||||||
|
"""
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get the expected token from cookie or session
|
||||||
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
if not cookie_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Constant-time comparison to prevent timing attacks
|
||||||
|
return secrets.compare_digest(token, cookie_token)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
CSRF protection middleware for FastAPI.
|
||||||
|
|
||||||
|
For JWT-based APIs, CSRF protection is primarily needed for:
|
||||||
|
1. Form-based submissions (if any)
|
||||||
|
2. Any endpoint that uses cookies for authentication
|
||||||
|
3. Prevention of cross-site request forgery attacks
|
||||||
|
|
||||||
|
The middleware:
|
||||||
|
- Generates CSRF tokens for authenticated sessions
|
||||||
|
- Validates CSRF tokens on state-changing requests
|
||||||
|
- Skips validation for webhook endpoints and public APIs
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
path = request.url.path
|
||||||
|
|
||||||
|
# Skip CSRF protection for:
|
||||||
|
# 1. Health check endpoint
|
||||||
|
# 2. Webhook endpoints
|
||||||
|
# 3. Public API endpoints (no auth required)
|
||||||
|
if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Get authorization header
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
has_jwt = auth_header.startswith("Bearer ")
|
||||||
|
|
||||||
|
# For API requests with JWT, we use double-submit cookie pattern
|
||||||
|
# The client should send X-CSRF-Token header matching the csrf_token cookie
|
||||||
|
|
||||||
|
if request.method in CSRF_PROTECTED_METHODS:
|
||||||
|
# Check for CSRF token in header
|
||||||
|
csrf_token = request.headers.get(CSRF_HEADER_NAME)
|
||||||
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
|
||||||
|
# If there's a JWT but no CSRF token, this might be a direct API call
|
||||||
|
# In that case, we require the CSRF token to be present
|
||||||
|
if has_jwt and not csrf_token:
|
||||||
|
# This is a potential CSRF attempt
|
||||||
|
# For API clients using JWT, we still require CSRF protection
|
||||||
|
# to prevent attacks from malicious websites
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
"error": "CSRF_TOKEN_MISSING",
|
||||||
|
"message": "CSRF token required for this request",
|
||||||
|
"required_header": CSRF_HEADER_NAME,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the token if present
|
||||||
|
if csrf_token and cookie_token:
|
||||||
|
if not validate_csrf_token(csrf_token, request):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
"error": "CSRF_TOKEN_INVALID",
|
||||||
|
"message": "Invalid or expired CSRF token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif csrf_token and not cookie_token:
|
||||||
|
# Token in header but no cookie - invalid state
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
"error": "CSRF_TOKEN_MISMATCH",
|
||||||
|
"message": "CSRF token cookie not found",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# If this is an authenticated request and no CSRF cookie exists,
|
||||||
|
# generate and set one
|
||||||
|
if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME):
|
||||||
|
csrf_token = generate_csrf_token()
|
||||||
|
response.set_cookie(
|
||||||
|
key=CSRF_COOKIE_NAME,
|
||||||
|
value=csrf_token,
|
||||||
|
max_age=CSRF_TOKEN_EXPIRY,
|
||||||
|
httponly=False, # Must be accessible to JavaScript for double-submit
|
||||||
|
secure=settings.DEBUG is False, # Secure in production
|
||||||
|
samesite="lax", # Prevent cross-site requests
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
# Also add the token to response headers for convenience
|
||||||
|
response.headers[CSRF_HEADER_NAME] = csrf_token
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_csrf_token_from_request(request: Request) -> Optional[str]:
|
||||||
|
"""Helper to extract CSRF token from request"""
|
||||||
|
return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def require_csrf_token(request: Request) -> str:
|
||||||
|
"""
|
||||||
|
Dependency function to require CSRF token in route handlers.
|
||||||
|
Use with: Depends(require_csrf_token)
|
||||||
|
"""
|
||||||
|
csrf_token = get_csrf_token_from_request(request)
|
||||||
|
if not csrf_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="CSRF token required",
|
||||||
|
)
|
||||||
|
return csrf_token
|
||||||
+161
-12
@@ -1,4 +1,8 @@
|
|||||||
from fastapi import Request, Response
|
"""
|
||||||
|
Enhanced Rate Limiting Module for TradeMate
|
||||||
|
Provides fine-grained rate limiting for sensitive endpoints
|
||||||
|
"""
|
||||||
|
from fastapi import Request, Response, HTTPException, status
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.core.security import decode_token
|
from app.core.security import decode_token
|
||||||
@@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -14,6 +19,7 @@ _redis_pool = None
|
|||||||
|
|
||||||
|
|
||||||
async def get_redis():
|
async def get_redis():
|
||||||
|
"""Get Redis connection pool"""
|
||||||
global _redis_pool
|
global _redis_pool
|
||||||
if _redis_pool is None:
|
if _redis_pool is None:
|
||||||
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
|
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
|
||||||
@@ -21,6 +27,7 @@ async def get_redis():
|
|||||||
|
|
||||||
|
|
||||||
def get_user_tier_from_token(request: Request) -> str:
|
def get_user_tier_from_token(request: Request) -> str:
|
||||||
|
"""Extract user tier from JWT token"""
|
||||||
auth = request.headers.get("Authorization", "")
|
auth = request.headers.get("Authorization", "")
|
||||||
if not auth.startswith("Bearer "):
|
if not auth.startswith("Bearer "):
|
||||||
request.state.user_id = None
|
request.state.user_id = None
|
||||||
@@ -36,14 +43,102 @@ def get_user_tier_from_token(request: Request) -> str:
|
|||||||
return request.state.user_tier
|
return request.state.user_tier
|
||||||
|
|
||||||
|
|
||||||
|
# Global rate limits (per minute) by tier
|
||||||
RATE_LIMITS = {
|
RATE_LIMITS = {
|
||||||
"free": 100,
|
"anonymous": 10,
|
||||||
"pro": 500,
|
"guest": 30,
|
||||||
"enterprise": 2000,
|
"free": 60,
|
||||||
|
"pro": 300,
|
||||||
|
"enterprise": 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sensitive endpoint rate limits (stricter limits regardless of tier)
|
||||||
|
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
|
||||||
|
# Authentication endpoints - very strict
|
||||||
|
"/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
|
||||||
|
"/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
|
||||||
|
"/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
|
||||||
|
"/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
|
||||||
|
"/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
|
||||||
|
|
||||||
|
# Password endpoints - very strict
|
||||||
|
"/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
|
||||||
|
|
||||||
|
# Profile endpoints - moderate
|
||||||
|
"/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
|
||||||
|
"/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
|
||||||
|
|
||||||
|
# Payment endpoints - strict
|
||||||
|
"/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
|
||||||
|
|
||||||
|
# WhatsApp endpoints - moderate
|
||||||
|
"/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
|
||||||
|
|
||||||
|
# AI endpoints - moderate
|
||||||
|
"/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
|
||||||
|
"/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
|
||||||
|
|
||||||
|
# Translation endpoints - moderate
|
||||||
|
"/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
|
||||||
|
|
||||||
|
# Marketing endpoints - moderate
|
||||||
|
"/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
|
||||||
|
|
||||||
|
# Quotation endpoints - moderate
|
||||||
|
"/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
|
||||||
|
|
||||||
|
# Customer endpoints - moderate
|
||||||
|
"/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
|
||||||
|
|
||||||
|
# Admin endpoints - strict
|
||||||
|
"/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def check_rate_limit(user_id: str, tier: str) -> int:
|
async def check_endpoint_rate_limit(
|
||||||
|
path: str, method: str, user_id: Optional[str], tier: str
|
||||||
|
) -> tuple[bool, int, str]:
|
||||||
|
"""
|
||||||
|
Check rate limit for a specific endpoint.
|
||||||
|
|
||||||
|
Returns: (allowed, remaining, limit_name)
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
# Use IP-based limiting for anonymous users
|
||||||
|
user_id = "anonymous"
|
||||||
|
|
||||||
|
# Find matching endpoint limit
|
||||||
|
matched_limit = None
|
||||||
|
for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items():
|
||||||
|
if path.startswith(endpoint_pattern):
|
||||||
|
matched_limit = config
|
||||||
|
break
|
||||||
|
|
||||||
|
if not matched_limit:
|
||||||
|
# Use tier-based global limit
|
||||||
|
limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
key = f"ratelimit:{user_id}:{int(time.time() // 60)}"
|
||||||
|
limit_name = f"global_{tier}"
|
||||||
|
else:
|
||||||
|
limit = matched_limit["limit"]
|
||||||
|
window = matched_limit["window"]
|
||||||
|
key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}"
|
||||||
|
limit_name = matched_limit["name"]
|
||||||
|
|
||||||
|
r = await get_redis()
|
||||||
|
count = await r.incr(key)
|
||||||
|
|
||||||
|
if count == 1:
|
||||||
|
# Set expiry based on window
|
||||||
|
window = matched_limit["window"] if matched_limit else 60
|
||||||
|
await r.expire(key, window + 5)
|
||||||
|
|
||||||
|
remaining = max(0, limit - count)
|
||||||
|
return remaining > 0, remaining, limit_name
|
||||||
|
|
||||||
|
|
||||||
|
async def check_global_rate_limit(user_id: str, tier: str) -> int:
|
||||||
|
"""Check global rate limit (per minute)"""
|
||||||
r = await get_redis()
|
r = await get_redis()
|
||||||
now = time.time()
|
now = time.time()
|
||||||
window = 60
|
window = 60
|
||||||
@@ -59,6 +154,8 @@ async def check_rate_limit(user_id: str, tier: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
class TierMiddleware(BaseHTTPMiddleware):
|
class TierMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to extract user tier from JWT token"""
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if request.url.path.startswith("/api/v1"):
|
if request.url.path.startswith("/api/v1"):
|
||||||
tier = get_user_tier_from_token(request)
|
tier = get_user_tier_from_token(request)
|
||||||
@@ -87,37 +184,89 @@ class TierMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
Enhanced rate limiting middleware with fine-grained endpoint limits.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Tier-based global rate limits
|
||||||
|
- Endpoint-specific rate limits for sensitive operations
|
||||||
|
- Separate limits for authentication, payment, and admin operations
|
||||||
|
- Proper rate limit headers in responses
|
||||||
|
"""
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if not request.url.path.startswith("/api/v1"):
|
path = request.url.path
|
||||||
|
|
||||||
|
# Skip rate limiting for health check and non-API routes
|
||||||
|
if not path.startswith("/api/v1"):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
user_tier = getattr(request.state, "user_tier", None)
|
user_tier = getattr(request.state, "user_tier", None)
|
||||||
if user_tier in ("anonymous", None):
|
if user_tier in ("anonymous", None):
|
||||||
return await call_next(request)
|
# For anonymous users, still apply basic rate limiting
|
||||||
|
# Use IP-based limiting
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
allowed, remaining, limit_name = await check_endpoint_rate_limit(
|
||||||
|
path, request.method, client_ip, "anonymous"
|
||||||
|
)
|
||||||
|
if not allowed:
|
||||||
|
return Response(
|
||||||
|
status_code=429,
|
||||||
|
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
|
||||||
|
media_type="application/json",
|
||||||
|
headers={
|
||||||
|
"Retry-After": "60",
|
||||||
|
"X-RateLimit-Limit": "10",
|
||||||
|
"X-RateLimit-Remaining": "0",
|
||||||
|
"X-RateLimit-Reset": str(int(time.time()) + 60),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||||
|
response.headers["X-RateLimit-Limit"] = "10"
|
||||||
|
response.headers["X-RateLimit-Name"] = limit_name
|
||||||
|
return response
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = getattr(request.state, "user_id", None)
|
user_id = getattr(request.state, "user_id", None)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
remaining = await check_rate_limit(
|
|
||||||
user_id, user_tier
|
# Check endpoint-specific rate limit
|
||||||
|
allowed, remaining, limit_name = await check_endpoint_rate_limit(
|
||||||
|
path, request.method, user_id, user_tier
|
||||||
)
|
)
|
||||||
if remaining == 0:
|
|
||||||
|
if not allowed:
|
||||||
return Response(
|
return Response(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
|
content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}',
|
||||||
media_type="application/json",
|
media_type="application/json",
|
||||||
headers={"Retry-After": "60"},
|
headers={
|
||||||
|
"Retry-After": "60",
|
||||||
|
"X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)),
|
||||||
|
"X-RateLimit-Remaining": "0",
|
||||||
|
"X-RateLimit-Reset": str(int(time.time()) + 60),
|
||||||
|
"X-RateLimit-Name": limit_name,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||||
|
response.headers["X-RateLimit-Limit"] = str(
|
||||||
|
SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100))
|
||||||
|
)
|
||||||
|
response.headers["X-RateLimit-Name"] = limit_name
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Rate limit check failed: {e}")
|
logger.warning(f"Rate limit check failed: {e}")
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
class QuotaMiddleware(BaseHTTPMiddleware):
|
class QuotaMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to enforce daily quotas for specific operations"""
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if not request.url.path.startswith("/api/v1"):
|
if not request.url.path.startswith("/api/v1"):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|||||||
+80
-4
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.core.exceptions import register_exception_handlers
|
from app.core.exceptions import register_exception_handlers
|
||||||
from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware
|
from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware
|
||||||
|
from app.core.csrf import CSRFMiddleware
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -34,14 +35,70 @@ app = FastAPI(
|
|||||||
debug=settings.DEBUG,
|
debug=settings.DEBUG,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CORS Configuration - Security Hardened
|
||||||
|
# =============================================================================
|
||||||
|
# Only allow specific origins (frontend URLs)
|
||||||
|
# Only allow specific HTTP methods (no TRACE, CONNECT, etc.)
|
||||||
|
# Only allow specific headers (no arbitrary headers)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Define allowed origins from environment/config
|
||||||
|
# In production, this should be your actual frontend domain(s)
|
||||||
|
ALLOWED_ORIGINS = [
|
||||||
|
settings.FRONTEND_URL,
|
||||||
|
"http://localhost:3000", # Legacy frontend
|
||||||
|
"http://localhost:5173", # Vite dev server
|
||||||
|
"http://localhost:5174", # User workspace dev server
|
||||||
|
"https://trade.yuzhiran.com", # Production domain
|
||||||
|
"https://trade.yuzhiran.com/app",
|
||||||
|
"https://trade.yuzhiran.com/admin",
|
||||||
|
"https://trade.yuzhiran.com/workspace",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Allowed HTTP methods (explicitly listed for security)
|
||||||
|
ALLOWED_METHODS = [
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"OPTIONS",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Allowed headers (explicitly listed)
|
||||||
|
ALLOWED_HEADERS = [
|
||||||
|
"Authorization",
|
||||||
|
"Content-Type",
|
||||||
|
"X-CSRF-Token",
|
||||||
|
"X-Requested-With",
|
||||||
|
"Accept",
|
||||||
|
"Origin",
|
||||||
|
]
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[settings.FRONTEND_URL],
|
allow_origins=ALLOWED_ORIGINS,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=ALLOWED_METHODS,
|
||||||
allow_headers=["*"],
|
allow_headers=ALLOWED_HEADERS,
|
||||||
|
max_age=600, # Preflight cache duration
|
||||||
|
expose_headers=[
|
||||||
|
"X-RateLimit-Limit",
|
||||||
|
"X-RateLimit-Remaining",
|
||||||
|
"X-RateLimit-Reset",
|
||||||
|
"X-RateLimit-Name",
|
||||||
|
"X-CSRF-Token",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Security Middleware Stack
|
||||||
|
# =============================================================================
|
||||||
|
# Order matters - CSRF should come after CORS but before other middleware
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
app.add_middleware(RateLimitMiddleware)
|
app.add_middleware(RateLimitMiddleware)
|
||||||
app.add_middleware(QuotaMiddleware)
|
app.add_middleware(QuotaMiddleware)
|
||||||
app.add_middleware(TierMiddleware)
|
app.add_middleware(TierMiddleware)
|
||||||
@@ -49,12 +106,30 @@ app.add_middleware(TierMiddleware)
|
|||||||
register_exception_handlers(app)
|
register_exception_handlers(app)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def load_ai_providers_from_db():
|
||||||
|
try:
|
||||||
|
from app.database import get_db
|
||||||
|
from app.ai.router import get_ai_router
|
||||||
|
|
||||||
|
async for db in get_db():
|
||||||
|
router = get_ai_router()
|
||||||
|
count = await router.reload_from_db(db)
|
||||||
|
if count == 0:
|
||||||
|
seeded = await router.seed_from_env(db)
|
||||||
|
if seeded:
|
||||||
|
await router.reload_from_db(db)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"AI provider DB load failed (tables may not exist yet): {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
||||||
|
|
||||||
|
|
||||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search
|
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
||||||
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
||||||
@@ -85,6 +160,7 @@ app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
|
|||||||
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
|
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
|
||||||
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
|
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
|
||||||
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
|
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
|
||||||
|
app.include_router(admin_ai.router, prefix="/api/v1/admin", tags=["admin"])
|
||||||
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])
|
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user