Files
trade-assistant/backend/app/ai/router.py
T
TradeMate Dev 5d2bced39f docs: update project docs and clean up redundant files
- PROGRESS.md: update to 2026-05-29 with security hardening (T-005),
  4-frontend architecture, AI provider refactoring, discovery features,
  landing page/referral/quota, desktop layout, admin AI management
- AGENTS.md: add AI provider list (Alibaba/NVIDIA, removed Claude/DeepL/Local),
  DB-driven config, CSRF/rate-limit/CORS notes, admin_ai reload quirk
- .env.example: sync with actual config, replace deprecated providers
  with current Sensenova/OpencodeGo/NVIDIA/Spark/Alibaba
- docs/PROJECT_STATUS.md: archive (fully superseded by PROGRESS.md)
- Remove generated JS files (_bing_search.js, _batch_search.js)
- Remove empty directories (data/corpus, data/models)
- Remove backend/.coverage (test artifact)
- Fix services/.gitignore to cover _bing_search.js
- Include pending AI provider DB admin feature (admin_ai, AIProvider model,
  AIProviders.vue, migration) and T-008 test report
2026-05-29 11:15:33 +08:00

184 lines
7.4 KiB
Python

from typing import Dict, Any, Optional, List
from app.ai.base import AIProvider
from app.ai.providers import SparkProvider, SensenovaProvider, OpencodeGoProvider, NvidiaProvider, AlibabaMTProvider
from app.config import settings
from app.ai.trade_corpus import TradeCorpus
import logging
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class AIRouter:
    """Dispatch AI tasks to configured providers, falling back on failure.

    Provider instances are built from database rows by ``reload_from_db`` and
    stored in ``self.providers``. For each task type, ``settings.AI_ROUTING``
    names a primary provider and an optional list of fallbacks; ``execute``
    tries them in that order until one returns a result.
    """

    def __init__(self):
        """Create a router with an empty provider registry."""
        # Lookup table of provider instances. reload_from_db() registers every
        # instance under several aliases (row id, display name, provider type),
        # so distinct keys may point at the same object.
        self.providers: Dict[str, AIProvider] = {}
        # Mapping of task type -> {"primary": <key>, "fallback": [<key>, ...]}.
        self.routing_rules = settings.AI_ROUTING
        self.corpus = TradeCorpus()
        # Set by schedule_reload(); cleared again by reload_from_db().
        self._needs_reload = False

    async def reload_from_db(self, db_session) -> int:
        """Rebuild the provider registry from the enabled rows in the database.

        Rows are read in ascending ``priority`` order. Each provider that can
        be built is registered under its row id, its name and its provider
        type. When several rows share a name or provider type, the first row
        (lowest ``priority`` value) keeps that alias.

        If no provider could be built, the existing registry is left untouched.

        Args:
            db_session: Async database session used to run the query.

        Returns:
            The number of enabled rows read from the database. This can be
            larger than the number of providers actually built.
        """
        from sqlalchemy import select

        # Aliased so the ORM model does not shadow the provider base class
        # imported at module level under the same name.
        from app.models.ai_provider import AIProvider as AIProviderModel

        result = await db_session.execute(
            select(AIProviderModel)
            # "== True" builds a SQL expression; it is not a Python comparison.
            .where(AIProviderModel.enabled == True)  # noqa: E712
            .order_by(AIProviderModel.priority)
        )
        rows = result.scalars().all()

        new_providers: Dict[str, AIProvider] = {}
        built = 0
        for row in rows:
            instance = self._build_provider(row)
            if instance is None:
                continue
            built += 1
            row_key = row.id.hex if hasattr(row.id, 'hex') else str(row.id)
            new_providers[row_key] = instance
            # setdefault: rows arrive in priority order, so the first row to
            # claim a shared alias must not be overwritten by a later one.
            new_providers.setdefault(row.name, instance)
            new_providers.setdefault(row.provider_type, instance)

        if new_providers:
            self.providers = new_providers
            logger.info(f"Loaded {built} AI providers from DB")
        else:
            logger.warning("No enabled AI providers found in DB")

        # The requested reload (if any) has now been carried out.
        self._needs_reload = False
        return len(rows)

    async def seed_from_env(self, db_session) -> int:
        """Insert provider rows for every provider configured in settings.

        A row is created for each provider whose credentials are present in
        ``settings``. The session is committed only if at least one row was
        added.

        NOTE(review): no check for already-existing rows is made here, so
        calling this twice would insert duplicates - presumably callers only
        seed an empty table; confirm.

        Args:
            db_session: Async database session the rows are added to.

        Returns:
            The number of rows added.
        """
        # Aliased so the ORM model does not shadow the provider base class
        # imported at module level under the same name.
        from app.models.ai_provider import AIProvider as AIProviderModel

        seeds = []
        if settings.SENSENOVA_API_KEY:
            seeds.append(AIProviderModel(
                name="Sensenova (商汤)", provider_type="sensenova",
                api_key=settings.SENSENOVA_API_KEY,
                base_url=settings.SENSENOVA_BASE_URL,
                model_name=settings.SENSENOVA_MODEL, priority=0, enabled=True,
            ))
        if settings.OPENCODE_GO_API_KEY:
            seeds.append(AIProviderModel(
                name="OpencodeGo", provider_type="opencode_go",
                api_key=settings.OPENCODE_GO_API_KEY,
                base_url=settings.OPENCODE_GO_BASE_URL,
                model_name=settings.OPENCODE_GO_MODEL, priority=1, enabled=True,
            ))
        if settings.NVIDIA_API_KEY:
            seeds.append(AIProviderModel(
                name="NVIDIA", provider_type="nvidia",
                api_key=settings.NVIDIA_API_KEY,
                base_url=settings.NVIDIA_BASE_URL,
                model_name=settings.NVIDIA_MODEL, priority=2, enabled=True,
            ))
        if settings.IFLYTEK_API_KEY:
            seeds.append(AIProviderModel(
                name="讯飞 Spark", provider_type="spark",
                api_key=settings.IFLYTEK_API_KEY,
                base_url=settings.IFLYTEK_API_BASE,
                model_name=settings.IFLYTEK_MODEL, priority=3, enabled=True,
            ))
        if settings.ALIBABA_ACCESS_KEY_ID and settings.ALIBABA_ACCESS_KEY_SECRET:
            # This provider uses a key id / secret pair instead of a base URL.
            seeds.append(AIProviderModel(
                name="阿里翻译", provider_type="alibaba-mt",
                api_key=settings.ALIBABA_ACCESS_KEY_ID,
                api_secret=settings.ALIBABA_ACCESS_KEY_SECRET,
                model_name="alibaba-mt", priority=4, enabled=True,
            ))

        for seed in seeds:
            db_session.add(seed)
        count = len(seeds)
        if count:
            await db_session.commit()
            logger.info(f"Seeded {count} AI providers from .env into DB")
        return count

    def schedule_reload(self):
        """Flag the router as needing a reload.

        Only the ``_needs_reload`` flag is set; nothing in this class acts on
        it, so whoever holds a database session is expected to call
        ``reload_from_db`` (which clears the flag).
        """
        self._needs_reload = True
        logger.info("AI router scheduled for reload on next call")

    def _build_provider(self, p) -> "Optional[AIProvider]":
        """Instantiate the provider class matching a database row.

        Args:
            p: Provider row; ``provider_type``, ``api_key``, ``api_secret``,
                ``model_name``, ``base_url`` and ``name`` are read from it.

        Returns:
            The provider instance, or ``None`` if the provider type is unknown
            or construction raised (both cases are logged as warnings).
        """
        try:
            kind = p.provider_type
            if kind == "alibaba-mt":
                # Credentials are a key id / secret pair; no model or base URL.
                return AlibabaMTProvider(
                    access_key_id=p.api_key,
                    access_key_secret=p.api_secret or "",
                )
            # The remaining providers share the same constructor keywords.
            keyed_providers = {
                "sensenova": SensenovaProvider,
                "opencode_go": OpencodeGoProvider,
                "nvidia": NvidiaProvider,
                "spark": SparkProvider,
            }
            provider_cls = keyed_providers.get(kind)
            if provider_cls is None:
                logger.warning(f"Unknown provider type: {kind}")
                return None
            return provider_cls(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
        except Exception as e:
            # Best effort: one broken row must not prevent the others loading.
            logger.warning(f"Failed to build provider {p.name}: {e}")
            return None

    def get_providers_for_task(self, task_type: str) -> "List[AIProvider]":
        """Return the providers to try for ``task_type``, in order.

        The primary provider from the routing rules comes first, followed by
        the configured fallbacks. Keys that are not registered are skipped.
        If none of the configured keys is registered, every registered
        provider is returned instead.

        Each provider instance appears at most once, even when it is
        registered under several aliases.

        Args:
            task_type: Routing key such as ``"translate"`` or ``"chat"``.

        Returns:
            Ordered list of distinct provider instances; empty only when no
            providers are registered at all.
        """
        rules = self.routing_rules.get(
            task_type,
            {"primary": "sensenova", "fallback": ["opencode_go"]},
        )
        ordered = []
        seen = set()

        def add_unique(instance):
            # Deduplicate on identity: the registry maps several aliases to
            # the same instance, so comparing key names is not enough.
            if id(instance) not in seen:
                seen.add(id(instance))
                ordered.append(instance)

        candidates = [rules.get("primary"), *(rules.get("fallback") or [])]
        for key in candidates:
            if key and key in self.providers:
                add_unique(self.providers[key])

        if not ordered:
            for instance in self.providers.values():
                add_unique(instance)
            logger.warning(f"No preferred providers for '{task_type}', using all available")
        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Call ``method`` on each candidate provider until one succeeds.

        Args:
            task_type: Routing key used to select and order the providers.
            method: Name of the coroutine method to call on the provider.
            *args: Positional arguments forwarded to the provider method.
            **kwargs: Keyword arguments forwarded to the provider method.

        Returns:
            The provider's result dict with an added ``"provider_used"`` entry
            holding the provider's ``name``.

        Raises:
            Exception: If no provider produced a result. The last provider
                error (or ``NotImplementedError``) is attached as the cause.
        """
        last_error = None
        unsupported = None
        for provider in self.get_providers_for_task(task_type):
            try:
                handler = getattr(provider, method)
                result = await handler(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError as exc:
                # This provider does not offer the operation; move on quietly.
                unsupported = exc
            except Exception as exc:
                logger.warning(f"{provider.name} failed for {task_type}: {exc}")
                last_error = exc

        # Prefer a real failure over "not implemented" when reporting.
        if last_error is not None:
            cause = last_error
            detail = str(last_error)
        elif unsupported is not None:
            cause = unsupported
            detail = f"no provider implements '{method}'"
        else:
            cause = None
            detail = "no providers available"
        raise Exception(
            f"All providers failed for '{task_type}'. Last error: {detail}"
        ) from cause

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate ``text`` via the providers routed for ``"translate"``.

        NOTE(review): ``source_lang`` is forwarded before ``target_lang``, the
        reverse of this method's own parameter order - presumably matching the
        provider ``translate`` signature; confirm against the provider base
        class.
        """
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``reply`` method via the ``"reply"`` route."""
        return await self.execute("reply", "reply", inquiry, context, tone, preference_context)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``generate_marketing`` method via the ``"marketing"`` route."""
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Run the provider ``extract_info`` method via the ``"extract"`` route."""
        return await self.execute("extract", "extract_info", text, schema)

    async def chat(self, message: str, history: Optional[list] = None, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Run the provider ``chat`` method via the ``"chat"`` route."""
        return await self.execute("chat", "chat", message, history, system_prompt)
# Shared AIRouter instance; stays None until get_ai_router() is first called.
_router_instance = None


def get_ai_router() -> AIRouter:
    """Return the module-wide AIRouter, constructing it on the first call."""
    global _router_instance
    if _router_instance is not None:
        return _router_instance
    _router_instance = AIRouter()
    return _router_instance