Files
trade-assistant/backend/app/ai/router.py
T
TradeMate Dev 7b62c2f8b4 feat: 修复 H5 底部导航覆盖 + 更新项目进度文档
## H5 底部导航修复 (Bug #10)
- 精简 App.vue,移除重复 tabbar,仅保留全局样式
- uni-page 设置 height: calc(100% - 50px) + overflow-y: auto
- 内容区域精确停在底部导航上方,独立滚动不再叠加
- 恢复 custom-tab-bar 组件

## 项目进度文档
- PROGRESS.md 更新至 10 个 Bug 修复
- 新增 H5 底部导航修复记录
- 新增历史变更条目
2026-05-12 20:24:42 +08:00

133 lines
5.3 KiB
Python

from typing import Dict, Any, Optional, List
from app.ai.base import AIProvider
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider, SparkProvider, SensenovaProvider
from app.config import settings
from app.ai.trade_corpus import TradeCorpus
import logging
logger = logging.getLogger(__name__)


class AIRouter:
    """Route AI tasks to the configured providers, falling back on failure.

    Providers are created once, in ``__init__``, from ``settings``. A provider
    is registered only if its API key (or, for the local model, its enable
    flag) is set and its constructor succeeds. Each task method
    (``translate``, ``reply``, ``marketing``, ``extract``) goes through
    :meth:`execute`. That method tries providers in the order chosen by
    :meth:`get_providers_for_task`.
    """

    # Routing rule used when the routing table has no entry for a task type.
    _DEFAULT_RULE: Dict[str, Any] = {"primary": "openai", "fallback": ["local"]}

    def __init__(self):
        # Registry key (e.g. "openai", "anthropic") -> ready provider instance.
        # Insertion order is the last-resort order used when none of a task's
        # preferred providers is registered.
        self.providers: Dict[str, AIProvider] = {}
        # Task type -> {"primary": key, "fallback": [key, ...]}.
        self.routing_rules = settings.AI_ROUTING
        # NOTE(review): not referenced by this class's own methods; presumably
        # used by callers of the router -- confirm before removing.
        self.corpus = TradeCorpus()
        self._init_providers()

    def _init_providers(self):
        """Instantiate and register every provider enabled in ``settings``.

        Providers are considered in a fixed order: OpenAI, Sensenova, Claude,
        DeepL, Spark, local. A provider whose constructor raises is logged and
        skipped, so one misconfigured backend cannot stop the router from
        starting.
        """
        # Each spec is (registry key, log label, enabled?, zero-arg factory).
        # The factories are lambdas so that per-provider settings (model, base
        # URL, ...) are only read for providers that are actually enabled.
        specs = [
            (
                "openai",
                "OpenAI",
                settings.OPENAI_API_KEY,
                lambda: OpenAIProvider(api_key=settings.OPENAI_API_KEY),
            ),
            (
                "sensenova",
                "Sensenova",
                settings.SENSENOVA_API_KEY,
                lambda: SensenovaProvider(
                    api_key=settings.SENSENOVA_API_KEY,
                    model=settings.SENSENOVA_MODEL,
                    base_url=settings.SENSENOVA_BASE_URL,
                ),
            ),
            (
                "anthropic",
                "Claude",
                settings.ANTHROPIC_API_KEY,
                lambda: ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY),
            ),
            (
                "deepl",
                "DeepL",
                settings.DEEPL_API_KEY,
                lambda: DeepLProvider(api_key=settings.DEEPL_API_KEY),
            ),
            (
                "spark",
                "Spark",
                settings.IFLYTEK_API_KEY,
                lambda: SparkProvider(
                    api_key=settings.IFLYTEK_API_KEY,
                    model=settings.IFLYTEK_MODEL,
                    base_url=settings.IFLYTEK_API_BASE,
                ),
            ),
            (
                "local",
                "Local",
                settings.LOCAL_MODEL_ENABLED,
                lambda: LocalProvider(model_url=settings.LOCAL_MODEL_URL),
            ),
        ]
        for name, label, enabled, factory in specs:
            if enabled:
                self._register_provider(name, label, factory)

    def _register_provider(self, name: str, label: str, factory) -> None:
        """Build one provider with *factory* and store it under *name*.

        Any exception raised by *factory* is logged as a warning and the
        provider is skipped (best effort, by design).
        """
        try:
            provider = factory()
        except Exception as e:
            logger.warning("%s init failed: %s", label, e)
            return
        self.providers[name] = provider
        logger.info("%s provider ready", label)

    def get_providers_for_task(self, task_type: str) -> List["AIProvider"]:
        """Return the providers to try for *task_type*, most preferred first.

        The task's rule (or ``_DEFAULT_RULE`` if there is none) names a
        primary provider and a list of fallbacks. Names that are not
        registered are skipped, and each provider appears at most once. If
        none of the named providers is registered, every registered provider
        is returned in registration order. That list may be empty.
        """
        # ``or {}`` tolerates an unset routing table.
        rules = (self.routing_rules or {}).get(task_type, self._DEFAULT_RULE)
        # ``or []`` tolerates an explicit ``"fallback": None`` in the config.
        candidates = [rules.get("primary"), *(rules.get("fallback") or [])]
        ordered: List["AIProvider"] = []
        seen = set()
        for name in candidates:
            if name and name in self.providers and name not in seen:
                ordered.append(self.providers[name])
                seen.add(name)
        if not ordered:
            ordered = list(self.providers.values())
            logger.warning(
                "No preferred providers for '%s', using all available", task_type
            )
        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Call ``method`` on each provider for *task_type* until one succeeds.

        Args:
            task_type: Routing key used to choose and order providers.
            method: Name of the provider coroutine method to await.
            *args, **kwargs: Forwarded unchanged to that method.

        Returns:
            The first successful result, with ``"provider_used"`` set to the
            provider's ``name``. The returned mapping is mutated in place.
            A result that does not support item assignment is treated as
            that provider failing.

        Raises:
            Exception: if no provider produced a result. When at least one
                provider raised, the last such error is chained as the cause.

        A provider raising ``NotImplementedError`` is skipped silently. Any
        other exception is logged, and the next provider is tried.
        """
        providers = self.get_providers_for_task(task_type)
        last_error = None
        for provider in providers:
            try:
                method_fn = getattr(provider, method)
                result = await method_fn(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError:
                continue
            except Exception as e:
                logger.warning("%s failed for %s: %s", provider.name, task_type, e)
                last_error = e
        if last_error is None:
            # No provider raised a real error. Either none are registered, or
            # every one raised NotImplementedError for this method.
            reason = (
                "no provider implements this method"
                if providers
                else "no providers are configured"
            )
            raise Exception(f"All providers failed for '{task_type}': {reason}")
        raise Exception(
            f"All providers failed for '{task_type}'. Last error: {last_error}"
        ) from last_error

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate *text* into *target_lang* via the "translate" route.

        The arguments are forwarded positionally as
        ``(text, source_lang, target_lang, context)``. That order differs from
        this method's own signature. NOTE(review): assumed to match the
        provider ``translate`` signature in ``app.ai.base`` -- confirm.
        """
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Draft a reply to *inquiry* via the "reply" route."""
        return await self.execute("reply", "reply", inquiry, context, tone, preference_context)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Generate marketing copy via the "marketing" route.

        The provider method called is ``generate_marketing``.
        """
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Extract fields described by *schema* from *text* via the "extract" route.

        The provider method called is ``extract_info``.
        """
        return await self.execute("extract", "extract_info", text, schema)
# Module-level router shared by every get_ai_router() caller; created lazily.
_router_instance: Optional["AIRouter"] = None


def get_ai_router() -> AIRouter:
    """Return the shared :class:`AIRouter`, constructing it on first use.

    Later calls return the same instance, so providers are initialised only
    once for this module. If construction raises, nothing is cached and the
    next call tries again. No locking is performed.
    """
    global _router_instance
    router = _router_instance
    if router is None:
        router = AIRouter()
        _router_instance = router
    return router