from typing import Dict, Any, Optional, List
import logging

from app.ai.base import AIProvider
from app.ai.providers import (
    SparkProvider,
    SensenovaProvider,
    OpencodeGoProvider,
    NvidiaProvider,
    AlibabaMTProvider,
)
from app.config import settings
from app.ai.trade_corpus import TradeCorpus

logger = logging.getLogger(__name__)


class AIRouter:
    """Route AI tasks to configured providers, falling back on failure.

    Provider instances are kept in ``self.providers``. Each instance is
    registered under several keys (database id, display name and provider
    type) so routing rules may refer to a provider by any of them.
    """

    def __init__(self):
        # Lookup key -> provider instance; populated by reload_from_db().
        self.providers: Dict[str, AIProvider] = {}
        # Mapping of task type -> {"primary": key, "fallback": [keys]}.
        self.routing_rules = settings.AI_ROUTING
        # Not used by any method in this module; kept for external callers.
        self.corpus = TradeCorpus()
        # Defined up front so the flag can be read before schedule_reload()
        # has ever been called without raising AttributeError.
        self._needs_reload = False

    async def reload_from_db(self, db_session) -> int:
        """Rebuild the provider table from the enabled rows in the database.

        Rows are read in ascending ``priority`` order. The current provider
        table is replaced only when at least one provider could be built;
        otherwise the existing table is left untouched.

        Args:
            db_session: Session object exposing an awaitable ``execute``.

        Returns:
            The number of enabled rows read from the database (including
            rows whose provider could not be built).
        """
        # Aliased so the ORM model does not shadow the AIProvider base class
        # imported at module level.
        from app.models.ai_provider import AIProvider as AIProviderModel
        from sqlalchemy import select

        result = await db_session.execute(
            select(AIProviderModel)
            .where(AIProviderModel.enabled == True)  # noqa: E712 - SQL expression, not a Python comparison
            .order_by(AIProviderModel.priority)
        )
        rows = result.scalars().all()
        # The requested reload has now been carried out.
        self._needs_reload = False

        new_providers: Dict[str, AIProvider] = {}
        built = 0
        for row in rows:
            inst = self._build_provider(row)
            if inst is None:
                continue
            built += 1
            key = row.id.hex if hasattr(row.id, "hex") else str(row.id)
            new_providers[key] = inst
            new_providers[row.name] = inst
            # Rows arrive in priority order, so the first row of a given type
            # must keep the type alias instead of being overwritten by a
            # later, lower-priority row of the same type.
            new_providers.setdefault(row.provider_type, inst)

        if new_providers:
            self.providers = new_providers
            logger.info("Loaded %d AI providers from DB", built)
        elif rows:
            logger.warning(
                "None of the %d enabled AI providers in DB could be built; "
                "keeping current providers",
                len(rows),
            )
        else:
            logger.warning("No enabled AI providers found in DB")
        return len(rows)

    async def seed_from_env(self, db_session) -> int:
        """Insert one provider row per credential set present in settings.

        NOTE(review): no check is made for rows that already exist, so
        calling this repeatedly adds duplicates - presumably callers only
        invoke it when the table is empty; verify against the caller.

        Args:
            db_session: Session object exposing ``add`` and an awaitable
                ``commit``.

        Returns:
            The number of rows added (0 when no credentials are configured,
            in which case nothing is committed).
        """
        from app.models.ai_provider import AIProvider as AIProviderModel

        seeds = []
        if settings.SENSENOVA_API_KEY:
            seeds.append(AIProviderModel(
                name="Sensenova (商汤)",
                provider_type="sensenova",
                api_key=settings.SENSENOVA_API_KEY,
                base_url=settings.SENSENOVA_BASE_URL,
                model_name=settings.SENSENOVA_MODEL,
                priority=0,
                enabled=True,
            ))
        if settings.OPENCODE_GO_API_KEY:
            seeds.append(AIProviderModel(
                name="OpencodeGo",
                provider_type="opencode_go",
                api_key=settings.OPENCODE_GO_API_KEY,
                base_url=settings.OPENCODE_GO_BASE_URL,
                model_name=settings.OPENCODE_GO_MODEL,
                priority=1,
                enabled=True,
            ))
        if settings.NVIDIA_API_KEY:
            seeds.append(AIProviderModel(
                name="NVIDIA",
                provider_type="nvidia",
                api_key=settings.NVIDIA_API_KEY,
                base_url=settings.NVIDIA_BASE_URL,
                model_name=settings.NVIDIA_MODEL,
                priority=2,
                enabled=True,
            ))
        if settings.IFLYTEK_API_KEY:
            seeds.append(AIProviderModel(
                name="讯飞 Spark",
                provider_type="spark",
                api_key=settings.IFLYTEK_API_KEY,
                base_url=settings.IFLYTEK_API_BASE,
                model_name=settings.IFLYTEK_MODEL,
                priority=3,
                enabled=True,
            ))
        if settings.ALIBABA_ACCESS_KEY_ID and settings.ALIBABA_ACCESS_KEY_SECRET:
            seeds.append(AIProviderModel(
                name="阿里翻译",
                provider_type="alibaba-mt",
                api_key=settings.ALIBABA_ACCESS_KEY_ID,
                api_secret=settings.ALIBABA_ACCESS_KEY_SECRET,
                model_name="alibaba-mt",
                priority=4,
                enabled=True,
            ))

        count = 0
        for seed in seeds:
            db_session.add(seed)
            count += 1

        if count:
            await db_session.commit()
            logger.info("Seeded %d AI providers from .env into DB", count)
        return count

    def schedule_reload(self):
        """Flag the router as needing a reload from the database."""
        self._needs_reload = True
        logger.info("AI router scheduled for reload on next call")

    def _build_provider(self, p) -> Optional[AIProvider]:
        """Instantiate the provider client described by database row ``p``.

        Args:
            p: Row exposing ``provider_type``, ``api_key``, ``model_name``,
                ``base_url``, ``name`` and (for ``alibaba-mt``)
                ``api_secret``.

        Returns:
            The provider instance, or ``None`` when the type is unknown or
            construction raises. Failures are logged rather than propagated
            so one bad row does not prevent the others from loading.
        """
        try:
            kind = p.provider_type
            if kind == "alibaba-mt":
                # Alibaba MT takes a key id/secret pair instead of the
                # api_key/model/base_url triple used by the other providers.
                return AlibabaMTProvider(
                    access_key_id=p.api_key,
                    access_key_secret=p.api_secret or "",
                )

            standard_providers = {
                "sensenova": SensenovaProvider,
                "opencode_go": OpencodeGoProvider,
                "nvidia": NvidiaProvider,
                "spark": SparkProvider,
            }
            provider_cls = standard_providers.get(kind)
            if provider_cls is None:
                logger.warning("Unknown provider type: %s", kind)
                return None
            return provider_cls(
                api_key=p.api_key,
                model=p.model_name,
                base_url=p.base_url,
            )
        except Exception as e:  # deliberate best-effort: skip the bad row
            logger.warning("Failed to build provider %s: %s", p.name, e)
            return None

    @staticmethod
    def _unique_instances(instances) -> List[AIProvider]:
        """Return ``instances`` in order with repeated objects removed."""
        seen_ids = set()
        unique: List[AIProvider] = []
        for inst in instances:
            # Identity, not equality: the same object is registered under
            # several keys, and providers are not known to be hashable.
            if id(inst) not in seen_ids:
                seen_ids.add(id(inst))
                unique.append(inst)
        return unique

    def get_providers_for_task(self, task_type: str) -> List[AIProvider]:
        """Return the providers to try for ``task_type``, in order.

        The routing rule's primary provider comes first, followed by its
        fallbacks; keys that are not currently loaded are skipped. When no
        rule key resolves, every loaded provider is returned instead. Each
        provider instance appears at most once.

        Args:
            task_type: Key into the routing rules.

        Returns:
            Ordered list of distinct provider instances (possibly empty).
        """
        rules = self.routing_rules.get(
            task_type,
            {"primary": "sensenova", "fallback": ["opencode_go"]},
        )

        wanted = []
        primary = rules.get("primary")
        if primary:
            wanted.append(primary)
        # `or []` tolerates an explicit None for "fallback" in the rules.
        wanted.extend(rules.get("fallback") or [])

        ordered = self._unique_instances(
            self.providers[name] for name in wanted if name in self.providers
        )
        if not ordered:
            # self.providers maps several keys to each instance, so the raw
            # values would repeat every provider; deduplicate to avoid
            # retrying the same provider.
            ordered = self._unique_instances(self.providers.values())
            logger.warning(
                "No preferred providers for '%s', using all available",
                task_type,
            )
        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Call ``method`` on each candidate provider until one succeeds.

        Args:
            task_type: Routing key used to pick the candidate providers.
            method: Name of the coroutine method to call on the provider.
            *args: Positional arguments forwarded to the provider method.
            **kwargs: Keyword arguments forwarded to the provider method.

        Returns:
            The provider's result dict with ``"provider_used"`` set to the
            name of the provider that produced it.

        Raises:
            RuntimeError: When no provider produced a result. The last
                provider error, if any, is attached as ``__cause__``.
        """
        last_error: Optional[Exception] = None
        for provider in self.get_providers_for_task(task_type):
            try:
                method_fn = getattr(provider, method)
                result = await method_fn(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError:
                # Provider does not support this operation; try the next one.
                logger.debug("%s does not implement %s", provider.name, method)
                continue
            except Exception as e:
                logger.warning("%s failed for %s: %s", provider.name, task_type, e)
                last_error = e
                continue

        # RuntimeError is an Exception subclass, so existing
        # `except Exception` handlers keep working.
        raise RuntimeError(
            f"All providers failed for '{task_type}'. Last error: {last_error}"
        ) from last_error

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate ``text`` via the providers routed for "translate".

        NOTE(review): the provider is called with
        ``(text, source_lang, target_lang, context)``, which differs from
        this method's parameter order - presumably that matches the
        provider ``translate`` signature; confirm.
        """
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``reply`` method via the "reply" route."""
        return await self.execute("reply", "reply", inquiry, context, tone, preference_context)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``generate_marketing`` method via the "marketing" route."""
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Run the provider ``extract_info`` method via the "extract" route."""
        return await self.execute("extract", "extract_info", text, schema)

    async def chat(self, message: str, history: Optional[list] = None, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``chat`` method via the "chat" route."""
        return await self.execute("chat", "chat", message, history, system_prompt)


# Process-wide router, created on first use by get_ai_router().
_router_instance: Optional[AIRouter] = None


def get_ai_router() -> AIRouter:
    """Return the shared AIRouter, creating it on the first call."""
    global _router_instance
    if _router_instance is None:
        _router_instance = AIRouter()
    return _router_instance