Files
trade-assistant/backend/app/ai/router.py
T
TradeMate Dev c6206787da Initial commit: TradeMate 外贸小助手 MVP
项目结构:
- backend/     Python FastAPI 后端
- uni-app/     uni-app跨端前端
- docs/        设计文档
- docker-compose.yml  Docker编排
- nginx/scripts/systemd 运维配置

已完成功能:
- 用户认证 (JWT)
- 智能翻译 + 回复建议
- 营销素材生成
- 客户管理 + 沉默检测
- 报价单管理
- 产品库管理
- 汇率换算
- 推送通知 (uni-push)
- WhatsApp Webhook框架
- Celery定时任务
2026-05-08 18:17:12 +08:00

111 lines
4.2 KiB
Python

from typing import Dict, Any, Optional, List
from app.ai.base import AIProvider
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider
from app.config import settings
from app.ai.trade_corpus import TradeCorpus
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class AIRouter:
    """Dispatch AI tasks to the configured providers, with ordered fallback.

    Providers are registered at construction time from ``settings``. Each task
    type is looked up in the routing table (``settings.AI_ROUTING``) to obtain
    a primary provider name and a list of fallback provider names.
    """

    # Rule applied when a task type has no entry in the routing table.
    # Only ever read, never mutated, so sharing one dict is safe.
    _DEFAULT_RULE: Dict[str, Any] = {"primary": "openai", "fallback": ["local"]}

    def __init__(self):
        """Load the routing table, build the corpus and register providers."""
        self.providers: Dict[str, AIProvider] = {}
        # Tolerate an unset/None routing table; every task then uses the default rule.
        self.routing_rules = settings.AI_ROUTING or {}
        self.corpus = TradeCorpus()
        self._init_providers()

    def _init_providers(self) -> None:
        """Register every provider whose API key / enable flag is set.

        A provider whose constructor raises is logged and skipped, so one bad
        configuration does not prevent the router from being created.
        """
        # (registry key, label used in log messages, enabled?, lazy constructor)
        candidates = (
            ("openai", "OpenAI", settings.OPENAI_API_KEY,
             lambda: OpenAIProvider(api_key=settings.OPENAI_API_KEY)),
            ("anthropic", "Claude", settings.ANTHROPIC_API_KEY,
             lambda: ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY)),
            ("deepl", "DeepL", settings.DEEPL_API_KEY,
             lambda: DeepLProvider(api_key=settings.DEEPL_API_KEY)),
            ("local", "Local", settings.LOCAL_MODEL_ENABLED,
             lambda: LocalProvider(model_url=settings.LOCAL_MODEL_URL)),
        )
        for key, label, enabled, factory in candidates:
            if enabled:
                self._register_provider(key, label, factory)

    def _register_provider(self, key: str, label: str, factory) -> None:
        """Call ``factory`` and store its result under ``key``; log the outcome."""
        try:
            self.providers[key] = factory()
        except Exception as exc:
            # Deliberate best-effort: a broken provider must not abort start-up.
            logger.warning("%s init failed: %s", label, exc)
        else:
            logger.info("%s provider ready", label)

    # The return annotation is quoted so it is not evaluated at definition time.
    def get_providers_for_task(self, task_type: str) -> List["AIProvider"]:
        """Return the registered providers to try for ``task_type``, in order.

        The primary provider comes first, followed by the fallbacks in their
        configured order; names that are not registered and duplicates are
        dropped. If none of the preferred providers is registered, every
        registered provider is returned instead.
        """
        rules = (self.routing_rules or {}).get(task_type)
        if rules is None:
            # Unknown task type, or a task explicitly mapped to None.
            rules = self._DEFAULT_RULE

        fallback = rules.get("fallback") or []  # "fallback": None means "no fallbacks"
        if isinstance(fallback, str):
            # A bare string is one provider name, not a sequence of characters.
            fallback = [fallback]

        ordered = []
        seen = set()
        for name in [rules.get("primary"), *fallback]:
            if name and name in self.providers and name not in seen:
                ordered.append(self.providers[name])
                seen.add(name)

        if not ordered:
            ordered = list(self.providers.values())
            logger.warning(
                "No preferred providers for '%s', using all available", task_type
            )
        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Call ``method`` on each candidate provider until one succeeds.

        Args:
            task_type: Routing key used to select and order the providers.
            method: Name of the coroutine method to call on each provider.
            *args, **kwargs: Forwarded unchanged to the provider method.

        Returns:
            The first successful provider result, with ``"provider_used"`` set
            to that provider's ``name``.

        Raises:
            Exception: If no provider produced a result. The most recent
                provider failure, if any, is chained as ``__cause__``.
        """
        last_error: Optional[Exception] = None
        unsupported: List[str] = []  # providers that raised NotImplementedError

        for provider in self.get_providers_for_task(task_type):
            label = getattr(provider, "name", type(provider).__name__)
            try:
                result = await getattr(provider, method)(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError:
                # The provider does not offer this operation; try the next one.
                unsupported.append(str(label))
                continue
            except Exception as exc:
                logger.warning("%s failed for %s: %s", label, task_type, exc)
                last_error = exc
                continue

        # Build a detail that is never the uninformative "None".
        if last_error is not None:
            detail = str(last_error)
        elif unsupported:
            detail = f"'{method}' not implemented by: {', '.join(unsupported)}"
        else:
            detail = "no providers available"
        raise Exception(
            f"All providers failed for '{task_type}'. Last error: {detail}"
        ) from last_error

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Run the ``translate`` provider method under the "translate" route."""
        # NOTE(review): arguments are forwarded as (text, source_lang,
        # target_lang, context), i.e. source before target, unlike this
        # wrapper's own signature. Presumably this matches
        # AIProvider.translate; confirm against app.ai.base.
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
        """Run the ``reply`` provider method under the "reply" route."""
        return await self.execute("reply", "reply", inquiry, context, tone)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
        """Run the ``generate_marketing`` provider method under the "marketing" route."""
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Run the ``extract_info`` provider method under the "extract" route."""
        return await self.execute("extract", "extract_info", text, schema)
# Process-wide AIRouter, created on the first get_ai_router() call.
_router_instance: Optional[AIRouter] = None


def get_ai_router() -> AIRouter:
    """Return the shared ``AIRouter``, constructing it on first use."""
    global _router_instance
    router = _router_instance
    if router is None:
        router = AIRouter()
        _router_instance = router
    return router