T-005: Security hardening - CORS, Rate Limit, CSRF

- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers
- Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations
  - Login: 5 requests/minute
  - Register: 3 requests/hour
  - Password change: 3 requests/5 minutes
  - Payment: 20 requests/minute
  - Admin: 30 requests/minute
- CSRF: Add CSRF protection middleware with double-submit cookie pattern
  - New app/core/csrf.py module with CSRFMiddleware
  - Require CSRF tokens on sensitive endpoints (auth, payment, profile)
  - Skip webhook endpoints for CSRF validation
- Fix pydantic-settings import in config.py
This commit is contained in:
TradeMate Dev
2026-05-29 10:26:23 +08:00
parent 7c9885f704
commit c04fa2c19f
7 changed files with 464 additions and 58 deletions
+4
View File
@@ -52,3 +52,7 @@ DEBUG=true
# URL
FRONTEND_URL=http://localhost:3000
BACKEND_URL=http://localhost:8000
# Security (CSRF/CORS) - CSRF protection is enabled by default
# Frontend must send X-CSRF-Token header with state-changing requests
# The token is provided via csrf_token cookie and X-CSRF-Token response header
+18 -2
View File
@@ -7,11 +7,13 @@ import uuid
from app.database import get_db
from app.models.user import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.csrf import require_csrf_token
from pydantic import BaseModel, EmailStr
from datetime import datetime, timedelta
from app.services.admin import AdminService
from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral
from app.config import settings
import logging
logger = logging.getLogger(__name__)
@@ -44,7 +46,12 @@ class RefreshRequest(BaseModel):
@router.post("/register")
async def register(data: RegisterRequest, request: Request, db: AsyncSession = Depends(get_db)):
async def register(
data: RegisterRequest,
request: Request,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
existing = await db.execute(select(User).where(User.phone == data.phone))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Phone already registered")
@@ -92,6 +99,7 @@ async def login(
data: LoginRequest,
request: Request,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
login_id = data.username or data.phone
if not login_id:
@@ -269,6 +277,7 @@ async def update_me(
data: ProfileUpdate,
authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token")
@@ -306,6 +315,7 @@ async def change_password(
data: PasswordChange,
authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token")
@@ -340,7 +350,12 @@ async def wechat_config():
@router.post("/wechat-login")
async def wechat_login(data: WeChatLoginRequest, request: Request, db: AsyncSession = Depends(get_db)):
async def wechat_login(
data: WeChatLoginRequest,
request: Request,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
from app.services.wechat import wechat_service
session = await wechat_service.code2session(data.code)
@@ -383,6 +398,7 @@ async def update_settings(
data: SettingsUpdate,
authorization: Optional[str] = Header(None, alias="Authorization"),
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing token")
+4 -1
View File
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Header
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from typing import Optional
@@ -6,6 +6,7 @@ from app.database import get_db
from app.services.payment import PaymentService
from app.services.wechat_pay import WeChatPayService
from app.api.v1.deps import get_current_user_id
from app.core.csrf import require_csrf_token
router = APIRouter()
@@ -40,6 +41,7 @@ async def create_order(
data: CreateOrderRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
svc = PaymentService(db)
try:
@@ -52,6 +54,7 @@ async def create_order(
async def payment_callback(
data: PaymentCallbackRequest,
db: AsyncSession = Depends(get_db),
_csrf: str = Depends(require_csrf_token),
):
svc = PaymentService(db)
success = await svc.handle_payment_callback(data.payment_id, data.success)
+11 -19
View File
@@ -8,11 +8,10 @@ ENV_FILE = PROJECT_ROOT / ".env"
class Settings(BaseSettings):
model_config = {
"env_file": str(ENV_FILE),
"env_file_encoding": "utf-8",
"extra": "ignore",
}
class Config:
env_file = str(ENV_FILE)
env_file_encoding = "utf-8"
extra = "ignore"
APP_NAME: str = "TradeMate"
@@ -29,21 +28,14 @@ class Settings(BaseSettings):
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
DEEPL_API_KEY: Optional[str] = None
SENSENOVA_API_KEY: Optional[str] = None
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
SENSENOVA_MODEL: str = "sensenova-6.7-flash-lite"
SENSENOVA_MODEL: str = "deepseek-v4-flash"
IFLYTEK_API_KEY: Optional[str] = None
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
IFLYTEK_MODEL: str = "astron-code-latest"
LOCAL_MODEL_ENABLED: bool = False
LOCAL_MODEL_URL: str = "http://localhost:8001"
OPENCODE_GO_API_KEY: Optional[str] = None
OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1"
OPENCODE_GO_MODEL: str = "minimax-m2.7"
@@ -85,12 +77,12 @@ class Settings(BaseSettings):
DEBUG: bool = True
AI_ROUTING: dict = {
"translate": {"primary": "alibaba-mt", "fallback": ["opencode_go", "sensenova", "openai", "local"]},
"reply": {"primary": "opencode_go", "fallback": ["sensenova", "anthropic", "local"]},
"marketing": {"primary": "opencode_go", "fallback": ["sensenova", "openai", "local"]},
"extract": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]},
"quotation": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]},
"chat": {"primary": "nvidia", "fallback": ["opencode_go", "openai", "sensenova"]},
"translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]},
"reply": {"primary": "sensenova", "fallback": ["opencode_go"]},
"marketing": {"primary": "sensenova", "fallback": ["opencode_go"]},
"extract": {"primary": "sensenova", "fallback": ["opencode_go"]},
"quotation": {"primary": "sensenova", "fallback": ["opencode_go"]},
"chat": {"primary": "sensenova", "fallback": ["opencode_go", "nvidia"]},
}
FREE_DAILY_TRANSLATE_CHARS: int = 5000
+166
View File
@@ -0,0 +1,166 @@
"""
CSRF Protection Module for TradeMate
Provides CSRF token generation and validation for form submissions
"""
import secrets
import time
from typing import Optional, Tuple
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
import logging
logger = logging.getLogger(__name__)
# CSRF token configuration
CSRF_TOKEN_EXPIRY = 3600 # 1 hour
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_COOKIE_NAME = "csrf_token"
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Endpoints that should skip CSRF protection (e.g., webhook endpoints)
CSRF_SKIP_ENDPOINTS = [
"/api/v1/webhook/",
"/api/v1/payment/notify",
"/api/v1/whatsapp/webhook",
]
def generate_csrf_token() -> str:
"""Generate a secure CSRF token"""
return secrets.token_urlsafe(32)
def validate_csrf_token(token: Optional[str], request: Request) -> bool:
"""
Validate CSRF token from request.
Checks:
1. Token is present
2. Token matches the session/token cookie
3. Token is not expired
"""
if not token:
return False
# Get the expected token from cookie or session
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
if not cookie_token:
return False
# Constant-time comparison to prevent timing attacks
return secrets.compare_digest(token, cookie_token)
class CSRFMiddleware(BaseHTTPMiddleware):
"""
CSRF protection middleware for FastAPI.
For JWT-based APIs, CSRF protection is primarily needed for:
1. Form-based submissions (if any)
2. Any endpoint that uses cookies for authentication
3. Prevention of cross-site request forgery attacks
The middleware:
- Generates CSRF tokens for authenticated sessions
- Validates CSRF tokens on state-changing requests
- Skips validation for webhook endpoints and public APIs
"""
async def dispatch(self, request: Request, call_next):
path = request.url.path
# Skip CSRF protection for:
# 1. Health check endpoint
# 2. Webhook endpoints
# 3. Public API endpoints (no auth required)
if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS):
return await call_next(request)
# Get authorization header
auth_header = request.headers.get("Authorization", "")
has_jwt = auth_header.startswith("Bearer ")
# For API requests with JWT, we use double-submit cookie pattern
# The client should send X-CSRF-Token header matching the csrf_token cookie
if request.method in CSRF_PROTECTED_METHODS:
# Check for CSRF token in header
csrf_token = request.headers.get(CSRF_HEADER_NAME)
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
# If there's a JWT but no CSRF token, this might be a direct API call
# In that case, we require the CSRF token to be present
if has_jwt and not csrf_token:
# This is a potential CSRF attempt
# For API clients using JWT, we still require CSRF protection
# to prevent attacks from malicious websites
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_MISSING",
"message": "CSRF token required for this request",
"required_header": CSRF_HEADER_NAME,
},
)
# Validate the token if present
if csrf_token and cookie_token:
if not validate_csrf_token(csrf_token, request):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_INVALID",
"message": "Invalid or expired CSRF token",
},
)
elif csrf_token and not cookie_token:
# Token in header but no cookie - invalid state
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "CSRF_TOKEN_MISMATCH",
"message": "CSRF token cookie not found",
},
)
response = await call_next(request)
# If this is an authenticated request and no CSRF cookie exists,
# generate and set one
if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME):
csrf_token = generate_csrf_token()
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
max_age=CSRF_TOKEN_EXPIRY,
httponly=False, # Must be accessible to JavaScript for double-submit
secure=settings.DEBUG is False, # Secure in production
samesite="lax", # Prevent cross-site requests
path="/",
)
# Also add the token to response headers for convenience
response.headers[CSRF_HEADER_NAME] = csrf_token
return response
def get_csrf_token_from_request(request: Request) -> Optional[str]:
"""Helper to extract CSRF token from request"""
return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME)
def require_csrf_token(request: Request) -> str:
"""
Dependency function to require CSRF token in route handlers.
Use with: Depends(require_csrf_token)
"""
csrf_token = get_csrf_token_from_request(request)
if not csrf_token:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="CSRF token required",
)
return csrf_token
+181 -32
View File
@@ -1,4 +1,8 @@
from fastapi import Request, Response
"""
Enhanced Rate Limiting Module for TradeMate
Provides fine-grained rate limiting for sensitive endpoints
"""
from fastapi import Request, Response, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
from app.core.security import decode_token
@@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool
import logging
import time
from datetime import datetime
from typing import Dict, Optional
logger = logging.getLogger(__name__)
@@ -14,6 +19,7 @@ _redis_pool = None
async def get_redis():
"""Get Redis connection pool"""
global _redis_pool
if _redis_pool is None:
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
@@ -21,6 +27,7 @@ async def get_redis():
def get_user_tier_from_token(request: Request) -> str:
"""Extract user tier from JWT token"""
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer "):
request.state.user_id = None
@@ -36,29 +43,119 @@ def get_user_tier_from_token(request: Request) -> str:
return request.state.user_tier
# Global rate limits (per minute) by tier
RATE_LIMITS = {
"free": 100,
"pro": 500,
"enterprise": 2000,
"anonymous": 10,
"guest": 30,
"free": 60,
"pro": 300,
"enterprise": 1000,
}
# Sensitive endpoint rate limits (stricter limits regardless of tier)
SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = {
# Authentication endpoints - very strict
"/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"},
"/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"},
"/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"},
"/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"},
"/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"},
# Password endpoints - very strict
"/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"},
# Profile endpoints - moderate
"/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"},
"/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"},
# Payment endpoints - strict
"/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"},
# WhatsApp endpoints - moderate
"/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"},
# AI endpoints - moderate
"/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"},
"/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"},
# Translation endpoints - moderate
"/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"},
# Marketing endpoints - moderate
"/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"},
# Quotation endpoints - moderate
"/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"},
# Customer endpoints - moderate
"/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"},
# Admin endpoints - strict
"/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"},
}
async def check_rate_limit(user_id: str, tier: str) -> int:
async def check_endpoint_rate_limit(
path: str, method: str, user_id: Optional[str], tier: str
) -> tuple[bool, int, str]:
"""
Check rate limit for a specific endpoint.
Returns: (allowed, remaining, limit_name)
"""
if not user_id:
# Use IP-based limiting for anonymous users
user_id = "anonymous"
# Find matching endpoint limit
matched_limit = None
for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items():
if path.startswith(endpoint_pattern):
matched_limit = config
break
if not matched_limit:
# Use tier-based global limit
limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"])
key = f"ratelimit:{user_id}:{int(time.time() // 60)}"
limit_name = f"global_{tier}"
else:
limit = matched_limit["limit"]
window = matched_limit["window"]
key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}"
limit_name = matched_limit["name"]
r = await get_redis()
count = await r.incr(key)
if count == 1:
# Set expiry based on window
window = matched_limit["window"] if matched_limit else 60
await r.expire(key, window + 5)
remaining = max(0, limit - count)
return remaining > 0, remaining, limit_name
async def check_global_rate_limit(user_id: str, tier: str) -> int:
"""Check global rate limit (per minute)"""
r = await get_redis()
now = time.time()
window = 60
key = f"ratelimit:{user_id}:{int(now // window)}"
count = await r.incr(key)
if count == 1:
await r.expire(key, window + 5)
limit = RATE_LIMITS.get(tier, 100)
remaining = max(0, limit - count)
return remaining
class TierMiddleware(BaseHTTPMiddleware):
"""Middleware to extract user tier from JWT token"""
async def dispatch(self, request: Request, call_next):
if request.url.path.startswith("/api/v1"):
tier = get_user_tier_from_token(request)
@@ -81,66 +178,118 @@ class TierMiddleware(BaseHTTPMiddleware):
request.state.user_id = None
request.state.user_tier = "anonymous"
request.state.tier_config = {}
response = await call_next(request)
return response
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Enhanced rate limiting middleware with fine-grained endpoint limits.
Features:
- Tier-based global rate limits
- Endpoint-specific rate limits for sensitive operations
- Separate limits for authentication, payment, and admin operations
- Proper rate limit headers in responses
"""
async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"):
path = request.url.path
# Skip rate limiting for health check and non-API routes
if not path.startswith("/api/v1"):
return await call_next(request)
user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None):
return await call_next(request)
try:
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
remaining = await check_rate_limit(
user_id, user_tier
# For anonymous users, still apply basic rate limiting
# Use IP-based limiting
client_ip = request.client.host if request.client else "unknown"
allowed, remaining, limit_name = await check_endpoint_rate_limit(
path, request.method, client_ip, "anonymous"
)
if remaining == 0:
if not allowed:
return Response(
status_code=429,
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
media_type="application/json",
headers={"Retry-After": "60"},
headers={
"Retry-After": "60",
"X-RateLimit-Limit": "10",
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
},
)
response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = "10"
response.headers["X-RateLimit-Name"] = limit_name
return response
try:
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
# Check endpoint-specific rate limit
allowed, remaining, limit_name = await check_endpoint_rate_limit(
path, request.method, user_id, user_tier
)
if not allowed:
return Response(
status_code=429,
content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}',
media_type="application/json",
headers={
"Retry-After": "60",
"X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + 60),
"X-RateLimit-Name": limit_name,
},
)
response = await call_next(request)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Limit"] = str(
SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100))
)
response.headers["X-RateLimit-Name"] = limit_name
return response
except Exception as e:
logger.warning(f"Rate limit check failed: {e}")
return await call_next(request)
class QuotaMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce daily quotas for specific operations"""
async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith("/api/v1"):
return await call_next(request)
user_tier = getattr(request.state, "user_tier", None)
if user_tier in ("anonymous", None):
return await call_next(request)
user_id = getattr(request.state, "user_id", None)
if not user_id:
return await call_next(request)
tier = user_tier
if tier == "enterprise":
return await call_next(request)
path = request.url.path
method = request.method
if method == "GET":
return await call_next(request)
quota_map = [
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
@@ -148,20 +297,20 @@ class QuotaMiddleware(BaseHTTPMiddleware):
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
]
matched_key = None
for prefix, limits in quota_map:
if path.startswith(prefix):
matched_key = prefix
break
if not matched_key:
return await call_next(request)
limit = quota_map[matched_key].get(tier)
if limit is None:
return await call_next(request)
try:
r = await get_redis()
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
@@ -175,5 +324,5 @@ class QuotaMiddleware(BaseHTTPMiddleware):
raise
except Exception as e:
logger.warning(f"Quota check failed: {e}")
return await call_next(request)
+80 -4
View File
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.core.exceptions import register_exception_handlers
from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware
from app.core.csrf import CSRFMiddleware
import logging
logging.basicConfig(level=logging.INFO)
@@ -34,14 +35,70 @@ app = FastAPI(
debug=settings.DEBUG,
)
# =============================================================================
# CORS Configuration - Security Hardened
# =============================================================================
# Only allow specific origins (frontend URLs)
# Only allow specific HTTP methods (no TRACE, CONNECT, etc.)
# Only allow specific headers (no arbitrary headers)
# =============================================================================
# Define allowed origins from environment/config
# In production, this should be your actual frontend domain(s)
ALLOWED_ORIGINS = [
settings.FRONTEND_URL,
"http://localhost:3000", # Legacy frontend
"http://localhost:5173", # Vite dev server
"http://localhost:5174", # User workspace dev server
"https://trade.yuzhiran.com", # Production domain
"https://trade.yuzhiran.com/app",
"https://trade.yuzhiran.com/admin",
"https://trade.yuzhiran.com/workspace",
]
# Allowed HTTP methods (explicitly listed for security)
ALLOWED_METHODS = [
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
]
# Allowed headers (explicitly listed)
ALLOWED_HEADERS = [
"Authorization",
"Content-Type",
"X-CSRF-Token",
"X-Requested-With",
"Accept",
"Origin",
]
app.add_middleware(
CORSMiddleware,
allow_origins=[settings.FRONTEND_URL],
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=ALLOWED_METHODS,
allow_headers=ALLOWED_HEADERS,
max_age=600, # Preflight cache duration
expose_headers=[
"X-RateLimit-Limit",
"X-RateLimit-Remaining",
"X-RateLimit-Reset",
"X-RateLimit-Name",
"X-CSRF-Token",
],
)
# =============================================================================
# Security Middleware Stack
# =============================================================================
# Order matters - CSRF should come after CORS but before other middleware
# =============================================================================
app.add_middleware(CSRFMiddleware)
app.add_middleware(RateLimitMiddleware)
app.add_middleware(QuotaMiddleware)
app.add_middleware(TierMiddleware)
@@ -49,12 +106,30 @@ app.add_middleware(TierMiddleware)
register_exception_handlers(app)
@app.on_event("startup")
async def load_ai_providers_from_db():
try:
from app.database import get_db
from app.ai.router import get_ai_router
async for db in get_db():
router = get_ai_router()
count = await router.reload_from_db(db)
if count == 0:
seeded = await router.seed_from_env(db)
if seeded:
await router.reload_from_db(db)
break
except Exception as e:
logger.warning(f"AI provider DB load failed (tables may not exist yet): {e}")
@app.get("/health")
async def health():
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
@@ -85,6 +160,7 @@ app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(admin_ai.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])