diff --git a/backend/.env.example b/backend/.env.example index db73853..daa44e0 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -52,3 +52,7 @@ DEBUG=true # URL FRONTEND_URL=http://localhost:3000 BACKEND_URL=http://localhost:8000 + +# Security (CSRF/CORS) - CSRF protection is enabled by default +# Frontend must send X-CSRF-Token header with state-changing requests +# The token is provided via csrf_token cookie and X-CSRF-Token response header diff --git a/backend/app/api/v1/auth.py b/backend/app/api/v1/auth.py index cf3424e..c784098 100644 --- a/backend/app/api/v1/auth.py +++ b/backend/app/api/v1/auth.py @@ -7,11 +7,13 @@ import uuid from app.database import get_db from app.models.user import User from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token +from app.core.csrf import require_csrf_token from pydantic import BaseModel, EmailStr from datetime import datetime, timedelta from app.services.admin import AdminService from app.models.subscription import Subscription from app.api.v1.referral import apply_referral +from app.config import settings import logging logger = logging.getLogger(__name__) @@ -44,7 +46,12 @@ class RefreshRequest(BaseModel): @router.post("/register") -async def register(data: RegisterRequest, request: Request, db: AsyncSession = Depends(get_db)): +async def register( + data: RegisterRequest, + request: Request, + db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), +): existing = await db.execute(select(User).where(User.phone == data.phone)) if existing.scalar_one_or_none(): raise HTTPException(status_code=400, detail="Phone already registered") @@ -92,6 +99,7 @@ async def login( data: LoginRequest, request: Request, db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): login_id = data.username or data.phone if not login_id: @@ -269,6 +277,7 @@ async def update_me( data: ProfileUpdate, authorization: Optional[str] = Header(None, alias="Authorization"), db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Missing token") @@ -306,6 +315,7 @@ async def change_password( data: PasswordChange, authorization: Optional[str] = Header(None, alias="Authorization"), db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Missing token") @@ -340,7 +350,12 @@ async def wechat_config(): @router.post("/wechat-login") -async def wechat_login(data: WeChatLoginRequest, request: Request, db: AsyncSession = Depends(get_db)): +async def wechat_login( + data: WeChatLoginRequest, + request: Request, + db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), +): from app.services.wechat import wechat_service session = await wechat_service.code2session(data.code) @@ -383,6 +398,7 @@ async def update_settings( data: SettingsUpdate, authorization: Optional[str] = Header(None, alias="Authorization"), db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Missing token") diff --git a/backend/app/api/v1/payment.py b/backend/app/api/v1/payment.py index e6ca093..0153a6c 100644 --- a/backend/app/api/v1/payment.py +++ b/backend/app/api/v1/payment.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Header from sqlalchemy.ext.asyncio import AsyncSession from pydantic import BaseModel from typing import Optional @@ -6,6 +6,7 @@ from app.database import get_db from app.services.payment import PaymentService from app.services.wechat_pay import WeChatPayService from app.api.v1.deps import get_current_user_id +from app.core.csrf import require_csrf_token router = APIRouter() @@ -40,6 +41,7 @@ async def create_order( data: CreateOrderRequest, user_id: str = Depends(get_current_user_id), db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): svc = PaymentService(db) try: @@ -52,6 +54,7 @@ async def create_order( async def payment_callback( data: PaymentCallbackRequest, db: AsyncSession = Depends(get_db), + _csrf: str = Depends(require_csrf_token), ): svc = PaymentService(db) success = await svc.handle_payment_callback(data.payment_id, data.success) diff --git a/backend/app/config.py b/backend/app/config.py index 5ccffed..e0ac648 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -8,11 +8,10 @@ ENV_FILE = PROJECT_ROOT / ".env" class Settings(BaseSettings): - model_config = { - "env_file": str(ENV_FILE), - "env_file_encoding": "utf-8", - "extra": "ignore", - } + class Config: + env_file = str(ENV_FILE) + env_file_encoding = "utf-8" + extra = "ignore" APP_NAME: str = "TradeMate" @@ -29,21 +28,14 @@ class Settings(BaseSettings): CELERY_BROKER_URL: str = "redis://localhost:6379/1" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2" - OPENAI_API_KEY: Optional[str] = None - ANTHROPIC_API_KEY: Optional[str] = None - DEEPL_API_KEY: Optional[str] = None - SENSENOVA_API_KEY: Optional[str] = None SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1" - SENSENOVA_MODEL: str = "sensenova-6.7-flash-lite" + SENSENOVA_MODEL: str = "deepseek-v4-flash" IFLYTEK_API_KEY: Optional[str] = None IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2" IFLYTEK_MODEL: str = "astron-code-latest" - LOCAL_MODEL_ENABLED: bool = False - LOCAL_MODEL_URL: str = "http://localhost:8001" - OPENCODE_GO_API_KEY: Optional[str] = None OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1" OPENCODE_GO_MODEL: str = "minimax-m2.7" @@ -85,12 +77,12 @@ class Settings(BaseSettings): DEBUG: bool = True AI_ROUTING: dict = { - "translate": {"primary": "alibaba-mt", "fallback": ["opencode_go", "sensenova", "openai", "local"]}, - "reply": {"primary": "opencode_go", "fallback": ["sensenova", "anthropic", "local"]}, - "marketing": {"primary": "opencode_go", "fallback": ["sensenova", "openai", "local"]}, - "extract": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]}, - "quotation": {"primary": "opencode_go", "fallback": ["sensenova", "openai"]}, - "chat": {"primary": "nvidia", "fallback": ["opencode_go", "openai", "sensenova"]}, + "translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]}, + "reply": {"primary": "sensenova", "fallback": ["opencode_go"]}, + "marketing": {"primary": "sensenova", "fallback": ["opencode_go"]}, + "extract": {"primary": "sensenova", "fallback": ["opencode_go"]}, + "quotation": {"primary": "sensenova", "fallback": ["opencode_go"]}, + "chat": {"primary": "sensenova", "fallback": ["opencode_go", "nvidia"]}, } FREE_DAILY_TRANSLATE_CHARS: int = 5000 diff --git a/backend/app/core/csrf.py b/backend/app/core/csrf.py new file mode 100644 index 0000000..e1b4486 --- /dev/null +++ b/backend/app/core/csrf.py @@ -0,0 +1,166 @@ +""" +CSRF Protection Module for TradeMate +Provides CSRF token generation and validation for form submissions +""" +import secrets +import time +from typing import Optional, Tuple +from fastapi import Request, HTTPException, status +from starlette.middleware.base import BaseHTTPMiddleware +from app.config import settings +import logging + +logger = logging.getLogger(__name__) + +# CSRF token configuration +CSRF_TOKEN_EXPIRY = 3600 # 1 hour +CSRF_HEADER_NAME = "X-CSRF-Token" +CSRF_COOKIE_NAME = "csrf_token" + +# Methods that require CSRF protection +CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"} + +# Endpoints that should skip CSRF protection (e.g., webhook endpoints) +CSRF_SKIP_ENDPOINTS = [ + "/api/v1/webhook/", + "/api/v1/payment/notify", + "/api/v1/whatsapp/webhook", +] + + +def generate_csrf_token() -> str: + """Generate a secure CSRF token""" + return secrets.token_urlsafe(32) + + +def validate_csrf_token(token: Optional[str], request: Request) -> bool: + """ + Validate CSRF token from request. + + Checks: + 1. Token is present + 2. Token matches the session/token cookie + 3. Token is not expired + """ + if not token: + return False + + # Get the expected token from cookie or session + cookie_token = request.cookies.get(CSRF_COOKIE_NAME) + if not cookie_token: + return False + + # Constant-time comparison to prevent timing attacks + return secrets.compare_digest(token, cookie_token) + + +class CSRFMiddleware(BaseHTTPMiddleware): + """ + CSRF protection middleware for FastAPI. + + For JWT-based APIs, CSRF protection is primarily needed for: + 1. Form-based submissions (if any) + 2. Any endpoint that uses cookies for authentication + 3. Prevention of cross-site request forgery attacks + + The middleware: + - Generates CSRF tokens for authenticated sessions + - Validates CSRF tokens on state-changing requests + - Skips validation for webhook endpoints and public APIs + """ + + async def dispatch(self, request: Request, call_next): + path = request.url.path + + # Skip CSRF protection for: + # 1. Health check endpoint + # 2. Webhook endpoints + # 3. Public API endpoints (no auth required) + if path == "/health" or any(path.startswith(skip) for skip in CSRF_SKIP_ENDPOINTS): + return await call_next(request) + + # Get authorization header + auth_header = request.headers.get("Authorization", "") + has_jwt = auth_header.startswith("Bearer ") + + # For API requests with JWT, we use double-submit cookie pattern + # The client should send X-CSRF-Token header matching the csrf_token cookie + + if request.method in CSRF_PROTECTED_METHODS: + # Check for CSRF token in header + csrf_token = request.headers.get(CSRF_HEADER_NAME) + cookie_token = request.cookies.get(CSRF_COOKIE_NAME) + + # If there's a JWT but no CSRF token, this might be a direct API call + # In that case, we require the CSRF token to be present + if has_jwt and not csrf_token: + # This is a potential CSRF attempt + # For API clients using JWT, we still require CSRF protection + # to prevent attacks from malicious websites + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "CSRF_TOKEN_MISSING", + "message": "CSRF token required for this request", + "required_header": CSRF_HEADER_NAME, + }, + ) + + # Validate the token if present + if csrf_token and cookie_token: + if not validate_csrf_token(csrf_token, request): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "CSRF_TOKEN_INVALID", + "message": "Invalid or expired CSRF token", + }, + ) + elif csrf_token and not cookie_token: + # Token in header but no cookie - invalid state + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "CSRF_TOKEN_MISMATCH", + "message": "CSRF token cookie not found", + }, + ) + + response = await call_next(request) + + # If this is an authenticated request and no CSRF cookie exists, + # generate and set one + if has_jwt and not request.cookies.get(CSRF_COOKIE_NAME): + csrf_token = generate_csrf_token() + response.set_cookie( + key=CSRF_COOKIE_NAME, + value=csrf_token, + max_age=CSRF_TOKEN_EXPIRY, + httponly=False, # Must be accessible to JavaScript for double-submit + secure=settings.DEBUG is False, # Secure in production + samesite="lax", # Prevent cross-site requests + path="/", + ) + # Also add the token to response headers for convenience + response.headers[CSRF_HEADER_NAME] = csrf_token + + return response + + +def get_csrf_token_from_request(request: Request) -> Optional[str]: + """Helper to extract CSRF token from request""" + return request.headers.get(CSRF_HEADER_NAME) or request.cookies.get(CSRF_COOKIE_NAME) + + +def require_csrf_token(request: Request) -> str: + """ + Dependency function to require CSRF token in route handlers. + Use with: Depends(require_csrf_token) + """ + csrf_token = get_csrf_token_from_request(request) + if not csrf_token: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="CSRF token required", + ) + return csrf_token diff --git a/backend/app/core/middleware.py b/backend/app/core/middleware.py index 926589e..33c5097 100644 --- a/backend/app/core/middleware.py +++ b/backend/app/core/middleware.py @@ -1,4 +1,8 @@ -from fastapi import Request, Response +""" +Enhanced Rate Limiting Module for TradeMate +Provides fine-grained rate limiting for sensitive endpoints +""" +from fastapi import Request, Response, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware from app.config import settings from app.core.security import decode_token @@ -7,6 +11,7 @@ from redis.asyncio import ConnectionPool import logging import time from datetime import datetime +from typing import Dict, Optional logger = logging.getLogger(__name__) @@ -14,6 +19,7 @@ _redis_pool = None async def get_redis(): + """Get Redis connection pool""" global _redis_pool if _redis_pool is None: _redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20) @@ -21,6 +27,7 @@ async def get_redis(): def get_user_tier_from_token(request: Request) -> str: + """Extract user tier from JWT token""" auth = request.headers.get("Authorization", "") if not auth.startswith("Bearer "): request.state.user_id = None @@ -36,29 +43,119 @@ def get_user_tier_from_token(request: Request) -> str: return request.state.user_tier +# Global rate limits (per minute) by tier RATE_LIMITS = { - "free": 100, - "pro": 500, - "enterprise": 2000, + "anonymous": 10, + "guest": 30, + "free": 60, + "pro": 300, + "enterprise": 1000, +} + +# Sensitive endpoint rate limits (stricter limits regardless of tier) +SENSITIVE_ENDPOINT_LIMITS: Dict[str, dict] = { + # Authentication endpoints - very strict + "/api/v1/auth/login": {"limit": 5, "window": 60, "name": "login_attempts"}, + "/api/v1/auth/register": {"limit": 3, "window": 3600, "name": "register_attempts"}, + "/api/v1/auth/refresh": {"limit": 10, "window": 60, "name": "refresh_attempts"}, + "/api/v1/auth/login/guest": {"limit": 10, "window": 60, "name": "guest_login"}, + "/api/v1/auth/wechat-login": {"limit": 5, "window": 60, "name": "wechat_login"}, + + # Password endpoints - very strict + "/api/v1/auth/password": {"limit": 3, "window": 300, "name": "password_change"}, + + # Profile endpoints - moderate + "/api/v1/auth/me": {"limit": 30, "window": 60, "name": "profile_update"}, + "/api/v1/auth/settings": {"limit": 30, "window": 60, "name": "settings_update"}, + + # Payment endpoints - strict + "/api/v1/payment": {"limit": 20, "window": 60, "name": "payment_requests"}, + + # WhatsApp endpoints - moderate + "/api/v1/whatsapp": {"limit": 30, "window": 60, "name": "whatsapp_requests"}, + + # AI endpoints - moderate + "/api/v1/ai": {"limit": 60, "window": 60, "name": "ai_requests"}, + "/api/v1/ai-assistant": {"limit": 60, "window": 60, "name": "ai_assistant"}, + + # Translation endpoints - moderate + "/api/v1/translate": {"limit": 60, "window": 60, "name": "translate_requests"}, + + # Marketing endpoints - moderate + "/api/v1/marketing": {"limit": 30, "window": 60, "name": "marketing_requests"}, + + # Quotation endpoints - moderate + "/api/v1/quotations": {"limit": 30, "window": 60, "name": "quotation_requests"}, + + # Customer endpoints - moderate + "/api/v1/customers": {"limit": 60, "window": 60, "name": "customer_requests"}, + + # Admin endpoints - strict + "/api/v1/admin": {"limit": 30, "window": 60, "name": "admin_requests"}, } -async def check_rate_limit(user_id: str, tier: str) -> int: +async def check_endpoint_rate_limit( + path: str, method: str, user_id: Optional[str], tier: str +) -> tuple[bool, int, str]: + """ + Check rate limit for a specific endpoint. + + Returns: (allowed, remaining, limit_name) + """ + if not user_id: + # Use IP-based limiting for anonymous users + user_id = "anonymous" + + # Find matching endpoint limit + matched_limit = None + for endpoint_pattern, config in SENSITIVE_ENDPOINT_LIMITS.items(): + if path.startswith(endpoint_pattern): + matched_limit = config + break + + if not matched_limit: + # Use tier-based global limit + limit = RATE_LIMITS.get(tier, RATE_LIMITS["free"]) + key = f"ratelimit:{user_id}:{int(time.time() // 60)}" + limit_name = f"global_{tier}" + else: + limit = matched_limit["limit"] + window = matched_limit["window"] + key = f"ratelimit:{matched_limit['name']}:{user_id}:{int(time.time() // window)}" + limit_name = matched_limit["name"] + + r = await get_redis() + count = await r.incr(key) + + if count == 1: + # Set expiry based on window + window = matched_limit["window"] if matched_limit else 60 + await r.expire(key, window + 5) + + remaining = max(0, limit - count) + return remaining > 0, remaining, limit_name + + +async def check_global_rate_limit(user_id: str, tier: str) -> int: + """Check global rate limit (per minute)""" r = await get_redis() now = time.time() window = 60 key = f"ratelimit:{user_id}:{int(now // window)}" - + count = await r.incr(key) if count == 1: await r.expire(key, window + 5) - + limit = RATE_LIMITS.get(tier, 100) remaining = max(0, limit - count) return remaining class TierMiddleware(BaseHTTPMiddleware): + """Middleware to extract user tier from JWT token""" + async def dispatch(self, request: Request, call_next): if request.url.path.startswith("/api/v1"): tier = get_user_tier_from_token(request) @@ -81,66 +178,118 @@ class TierMiddleware(BaseHTTPMiddleware): request.state.user_id = None request.state.user_tier = "anonymous" request.state.tier_config = {} - + response = await call_next(request) return response class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Enhanced rate limiting middleware with fine-grained endpoint limits. + + Features: + - Tier-based global rate limits + - Endpoint-specific rate limits for sensitive operations + - Separate limits for authentication, payment, and admin operations + - Proper rate limit headers in responses + """ + async def dispatch(self, request: Request, call_next): - if not request.url.path.startswith("/api/v1"): + path = request.url.path + + # Skip rate limiting for health check and non-API routes + if not path.startswith("/api/v1"): return await call_next(request) - + user_tier = getattr(request.state, "user_tier", None) if user_tier in ("anonymous", None): - return await call_next(request) - - try: - user_id = getattr(request.state, "user_id", None) - if not user_id: - return await call_next(request) - remaining = await check_rate_limit( - user_id, user_tier + # For anonymous users, still apply basic rate limiting + # Use IP-based limiting + client_ip = request.client.host if request.client else "unknown" + allowed, remaining, limit_name = await check_endpoint_rate_limit( + path, request.method, client_ip, "anonymous" ) - if remaining == 0: + if not allowed: return Response( status_code=429, content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}', media_type="application/json", - headers={"Retry-After": "60"}, + headers={ + "Retry-After": "60", + "X-RateLimit-Limit": "10", + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time()) + 60), + }, ) response = await call_next(request) response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Limit"] = "10" + response.headers["X-RateLimit-Name"] = limit_name return response + + try: + user_id = getattr(request.state, "user_id", None) + if not user_id: + return await call_next(request) + + # Check endpoint-specific rate limit + allowed, remaining, limit_name = await check_endpoint_rate_limit( + path, request.method, user_id, user_tier + ) + + if not allowed: + return Response( + status_code=429, + content=f'{{"error":"RATE_LIMITED","detail":"Rate limit exceeded for {limit_name}. Try again later."}}', + media_type="application/json", + headers={ + "Retry-After": "60", + "X-RateLimit-Limit": str(RATE_LIMITS.get(user_tier, 100)), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time()) + 60), + "X-RateLimit-Name": limit_name, + }, + ) + + response = await call_next(request) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Limit"] = str( + SENSITIVE_ENDPOINT_LIMITS.get(path.split("/api/v1")[1].split("/")[0] if "/" in path.split("/api/v1")[1] else "", {}).get("limit", RATE_LIMITS.get(user_tier, 100)) + ) + response.headers["X-RateLimit-Name"] = limit_name + return response + except Exception as e: logger.warning(f"Rate limit check failed: {e}") return await call_next(request) class QuotaMiddleware(BaseHTTPMiddleware): + """Middleware to enforce daily quotas for specific operations""" + async def dispatch(self, request: Request, call_next): if not request.url.path.startswith("/api/v1"): return await call_next(request) - + user_tier = getattr(request.state, "user_tier", None) if user_tier in ("anonymous", None): return await call_next(request) - + user_id = getattr(request.state, "user_id", None) if not user_id: return await call_next(request) - + tier = user_tier - + if tier == "enterprise": return await call_next(request) - + path = request.url.path method = request.method - + if method == "GET": return await call_next(request) - + quota_map = [ ("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}), ("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}), @@ -148,20 +297,20 @@ class QuotaMiddleware(BaseHTTPMiddleware): ("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}), ("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}), ] - + matched_key = None for prefix, limits in quota_map: if path.startswith(prefix): matched_key = prefix break - + if not matched_key: return await call_next(request) - + limit = quota_map[matched_key].get(tier) if limit is None: return await call_next(request) - + try: r = await get_redis() key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}" @@ -175,5 +324,5 @@ class QuotaMiddleware(BaseHTTPMiddleware): raise except Exception as e: logger.warning(f"Quota check failed: {e}") - + return await call_next(request) diff --git a/backend/app/main.py b/backend/app/main.py index 8bee3e2..57df4f8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from app.config import settings from app.core.exceptions import register_exception_handlers from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware +from app.core.csrf import CSRFMiddleware import logging logging.basicConfig(level=logging.INFO) @@ -34,14 +35,70 @@ app = FastAPI( debug=settings.DEBUG, ) +# ============================================================================= +# CORS Configuration - Security Hardened +# ============================================================================= +# Only allow specific origins (frontend URLs) +# Only allow specific HTTP methods (no TRACE, CONNECT, etc.) +# Only allow specific headers (no arbitrary headers) +# ============================================================================= + +# Define allowed origins from environment/config +# In production, this should be your actual frontend domain(s) +ALLOWED_ORIGINS = [ + settings.FRONTEND_URL, + "http://localhost:3000", # Legacy frontend + "http://localhost:5173", # Vite dev server + "http://localhost:5174", # User workspace dev server + "https://trade.yuzhiran.com", # Production domain + "https://trade.yuzhiran.com/app", + "https://trade.yuzhiran.com/admin", + "https://trade.yuzhiran.com/workspace", +] + +# Allowed HTTP methods (explicitly listed for security) +ALLOWED_METHODS = [ + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "OPTIONS", +] + +# Allowed headers (explicitly listed) +ALLOWED_HEADERS = [ + "Authorization", + "Content-Type", + "X-CSRF-Token", + "X-Requested-With", + "Accept", + "Origin", +] + app.add_middleware( CORSMiddleware, - allow_origins=[settings.FRONTEND_URL], + allow_origins=ALLOWED_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=ALLOWED_METHODS, + allow_headers=ALLOWED_HEADERS, + max_age=600, # Preflight cache duration + expose_headers=[ + "X-RateLimit-Limit", + "X-RateLimit-Remaining", + "X-RateLimit-Reset", + "X-RateLimit-Name", + "X-CSRF-Token", + ], ) +# ============================================================================= +# Security Middleware Stack +# ============================================================================= +# Order matters - CSRF should come after CORS but before other middleware +# ============================================================================= + +app.add_middleware(CSRFMiddleware) app.add_middleware(RateLimitMiddleware) app.add_middleware(QuotaMiddleware) app.add_middleware(TierMiddleware) @@ -49,12 +106,30 @@ app.add_middleware(TierMiddleware) register_exception_handlers(app) +@app.on_event("startup") +async def load_ai_providers_from_db(): + try: + from app.database import get_db + from app.ai.router import get_ai_router + + async for db in get_db(): + router = get_ai_router() + count = await router.reload_from_db(db) + if count == 0: + seeded = await router.seed_from_env(db) + if seeded: + await router.reload_from_db(db) + break + except Exception as e: + logger.warning(f"AI provider DB load failed (tables may not exist yet): {e}") + + @app.get("/health") async def health(): return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"} -from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search +from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"]) app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"]) @@ -85,6 +160,7 @@ app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"]) app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"]) app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"]) app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"]) +app.include_router(admin_ai.router, prefix="/api/v1/admin", tags=["admin"]) app.include_router(search.router, prefix="/api/v1/search", tags=["search"])