Files
trade-assistant/backend/app/ai/router.py
T
TradeMate Dev c397740748 feat: WeChat Pay integration, translation quota management, login UX fixes
- WeChat Pay APIv3 integration (JSAPI + Native) with cert-based auth
- TranslationQuota model + admin management UI (配额 tab)
- Alibaba MT provider now checks quota before translation
- Fix: admin tabs scrollable on mobile, remove header-card
- Fix: profile/login navigation - logout stays on profile, login returns to profile
- Fix: login form now visible by default (no extra click to show)
- Fix: home page translate link uses navigateTo (was switchTab to non-tabBar page)
- Add .coverage and apiclient_key.pem to gitignore
2026-05-20 18:30:12 +08:00

168 lines
6.8 KiB
Python

from typing import Dict, Any, Optional, List
from app.ai.base import AIProvider
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider, SparkProvider, SensenovaProvider, OpencodeGoProvider, NvidiaProvider, AlibabaMTProvider
from app.config import settings
from app.ai.trade_corpus import TradeCorpus
import logging
logger = logging.getLogger(__name__)
class AIRouter:
    """Route AI tasks to the configured providers with ordered fallback.

    Providers are built once from ``settings`` when the router is created.
    Only providers whose credentials or enable flag are set, and whose
    constructor succeeds, end up in ``self.providers``. Each task type is then
    resolved to an ordered provider list via ``settings.AI_ROUTING``.
    """

    def __init__(self):
        # Registered name -> provider instance. Insertion order matters: it is
        # the last-resort order used by get_providers_for_task.
        self.providers: Dict[str, AIProvider] = {}
        # Presumably task_type -> {"primary": str, "fallback": [str, ...]}
        # (shape inferred from get_providers_for_task) -- TODO confirm in settings.
        self.routing_rules = settings.AI_ROUTING
        self.corpus = TradeCorpus()
        self._init_providers()

    def _register_provider(self, key, label, factory):
        """Build a provider with ``factory`` and store it under ``key``.

        Construction errors are logged and swallowed, so a single misconfigured
        provider does not stop the router or the other providers from starting.

        Args:
            key: Name the provider is registered under (used by routing rules).
            label: Human-readable provider name used in log messages.
            factory: Zero-argument callable that returns the provider instance.
        """
        try:
            self.providers[key] = factory()
        except Exception as e:
            logger.warning("%s init failed: %s", label, e)
        else:
            logger.info("%s provider ready", label)

    def _init_providers(self):
        """Register every provider whose credentials or enable flag are set.

        Registration order is kept in ``self.providers`` and becomes the
        last-resort fallback order, so keep this sequence stable.
        """
        if settings.OPENAI_API_KEY:
            self._register_provider(
                "openai",
                "OpenAI",
                lambda: OpenAIProvider(api_key=settings.OPENAI_API_KEY),
            )
        if settings.SENSENOVA_API_KEY:
            self._register_provider(
                "sensenova",
                "Sensenova",
                lambda: SensenovaProvider(
                    api_key=settings.SENSENOVA_API_KEY,
                    model=settings.SENSENOVA_MODEL,
                    base_url=settings.SENSENOVA_BASE_URL,
                ),
            )
        if settings.OPENCODE_GO_API_KEY:
            self._register_provider(
                "opencode_go",
                "OpencodeGo",
                lambda: OpencodeGoProvider(
                    api_key=settings.OPENCODE_GO_API_KEY,
                    model=settings.OPENCODE_GO_MODEL,
                    base_url=settings.OPENCODE_GO_BASE_URL,
                ),
            )
        if settings.NVIDIA_API_KEY:
            self._register_provider(
                "nvidia",
                "Nvidia",
                lambda: NvidiaProvider(
                    api_key=settings.NVIDIA_API_KEY,
                    model=settings.NVIDIA_MODEL,
                    base_url=settings.NVIDIA_BASE_URL,
                ),
            )
        if settings.ANTHROPIC_API_KEY:
            self._register_provider(
                "anthropic",
                "Claude",
                lambda: ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY),
            )
        if settings.DEEPL_API_KEY:
            self._register_provider(
                "deepl",
                "DeepL",
                lambda: DeepLProvider(api_key=settings.DEEPL_API_KEY),
            )
        if settings.IFLYTEK_API_KEY:
            self._register_provider(
                "spark",
                "Spark",
                lambda: SparkProvider(
                    api_key=settings.IFLYTEK_API_KEY,
                    model=settings.IFLYTEK_MODEL,
                    base_url=settings.IFLYTEK_API_BASE,
                ),
            )
        # Alibaba MT needs both halves of the access-key pair.
        if settings.ALIBABA_ACCESS_KEY_ID and settings.ALIBABA_ACCESS_KEY_SECRET:
            self._register_provider(
                "alibaba-mt",
                "Alibaba MT",
                lambda: AlibabaMTProvider(
                    access_key_id=settings.ALIBABA_ACCESS_KEY_ID,
                    access_key_secret=settings.ALIBABA_ACCESS_KEY_SECRET,
                ),
            )
        if settings.LOCAL_MODEL_ENABLED:
            self._register_provider(
                "local",
                "Local",
                lambda: LocalProvider(model_url=settings.LOCAL_MODEL_URL),
            )

    def get_providers_for_task(self, task_type: str) -> List[AIProvider]:
        """Return the providers to try for ``task_type``, in priority order.

        The configured primary comes first, followed by the configured
        fallbacks. Names that are not registered are skipped, as are duplicates.
        A task type with no rule defaults to primary "openai" with fallback
        "local". If none of the preferred providers is registered, every
        registered provider is returned, in registration order. That list is
        empty when no provider is configured at all.
        """
        rules = (self.routing_rules or {}).get(
            task_type,
            {"primary": "openai", "fallback": ["local"]},
        )
        ordered: List[AIProvider] = []
        seen = set()
        # `or ()` tolerates an explicit ``"fallback": None`` in the config.
        for name in [rules.get("primary"), *(rules.get("fallback") or ())]:
            if name and name in self.providers and name not in seen:
                ordered.append(self.providers[name])
                seen.add(name)
        if not ordered:
            logger.warning(
                "No preferred providers for '%s', using all available", task_type
            )
            ordered = list(self.providers.values())
        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Await ``method`` on each provider for ``task_type`` until one succeeds.

        A provider that raises ``NotImplementedError`` is skipped without
        logging. Any other exception is logged, and the next provider is tried.
        The winning provider's result dict gets a ``provider_used`` key.

        Raises:
            Exception: if no provider produced a result. When at least one
                provider failed, the last failure is chained as ``__cause__``.
        """
        providers = self.get_providers_for_task(task_type)
        if not providers:
            raise Exception(f"No AI providers available for '{task_type}'")
        last_error: Optional[Exception] = None
        for provider in providers:
            try:
                method_fn = getattr(provider, method)
                result = await method_fn(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError:
                continue
            except Exception as e:
                logger.warning("%s failed for %s: %s", provider.name, task_type, e)
                last_error = e
        if last_error is None:
            # Every candidate raised NotImplementedError; there is no real error to report.
            raise Exception(f"No provider implements '{method}' for '{task_type}'")
        raise Exception(
            f"All providers failed for '{task_type}'. Last error: {last_error}"
        ) from last_error

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate ``text`` using the "translate" routing rule."""
        # NOTE(review): forwarded positionally as (text, source_lang, target_lang,
        # context), a different order from this method's own signature. This
        # assumes the provider's translate() expects that order -- confirm in app.ai.base.
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Draft a reply to ``inquiry`` using the "reply" routing rule."""
        return await self.execute("reply", "reply", inquiry, context, tone, preference_context)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Generate marketing copy via the providers' ``generate_marketing``."""
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Extract information from ``text`` via the providers' ``extract_info``."""
        return await self.execute("extract", "extract_info", text, schema)

    async def chat(
        self,
        message: str,
        history: Optional[list] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send a chat ``message`` using the "chat" routing rule."""
        return await self.execute("chat", "chat", message, history, system_prompt)
# Lazily created module-wide router; populated on the first get_ai_router() call.
_router_instance = None


def get_ai_router() -> AIRouter:
    """Return the shared AIRouter, constructing it the first time it is needed."""
    global _router_instance
    router = _router_instance
    if router is None:
        # First call: build the router and cache it for every later caller.
        router = _router_instance = AIRouter()
    return router