# Snapshot of commit f17a6ccbac (Python, 213 lines, 8.4 KiB).
# Commit summary:
#   - Make AI routing rules DB-driven (read from system_configs, removed from config.py)
#   - Add translation quota tracking to LLM translation (OpenAIProvider)
#   - Add Alibaba MT ECS RAM role support (STS token, no AccessKey needed)
#   - Fix admin sidebar link for the AI模型配置 page
#   - Fix Quota.vue API path (quotas → translation-quotas)
#   - Fix login auto-redirect to dashboard
#   - Add provider dropdown selects to AI routing config UI
#   - Clean up stale ai_provider_* system_configs records
#   - Remove OpencodeGo, Spark providers (code + DB)
#   - Update deploy config: nginx port 8000, systemd cwd
from typing import Dict, Any, Optional, List
|
|
from app.ai.base import AIProvider
|
|
from app.ai.providers import SensenovaProvider, NvidiaProvider, AlibabaMTProvider
|
|
from app.ai.trade_corpus import TradeCorpus
|
|
from app.config import settings
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)

# Built-in routing table, used whenever no routing configuration is stored in
# the database. Each task type maps to:
#   "primary"  - the provider key tried first
#   "fallback" - provider keys tried afterwards, in order
# Keys here must match the keys under which providers are registered on the
# router (provider_type, display name, or id).
DEFAULT_ROUTING: Dict[str, dict] = {
    "translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "nvidia"]},
    "reply": {"primary": "sensenova", "fallback": ["nvidia"]},
    "marketing": {"primary": "sensenova", "fallback": ["nvidia"]},
    "extract": {"primary": "sensenova", "fallback": ["nvidia"]},
    "quotation": {"primary": "sensenova", "fallback": ["nvidia"]},
    "chat": {"primary": "sensenova", "fallback": ["nvidia"]},
}
|
|
|
|
|
|
class AIRouter:
    """Dispatch AI tasks to configured providers with ordered fallback.

    Providers are held in ``self.providers``; every provider instance is
    registered under three keys (row id, display name, provider type) so that
    routing rules may reference any of them. ``self.routing_rules`` maps a
    task type to ``{"primary": <key>, "fallback": [<key>, ...]}``.
    """

    def __init__(self):
        """Create a router with no providers and the default routing rules."""
        self.providers: Dict[str, AIProvider] = {}
        self.routing_rules = dict(DEFAULT_ROUTING)
        self.corpus = TradeCorpus()
        # Defined up front so the flag can be read before schedule_reload()
        # has ever been called.
        self._needs_reload = False

    async def reload_from_db(self, db_session) -> int:
        """Rebuild the provider table and routing rules from the database.

        Args:
            db_session: Async DB session; ``execute`` is awaited on it.

        Returns:
            The number of enabled provider rows read (including rows for
            which no provider instance could be built).
        """
        # Aliased so the ORM model does not shadow the module-level
        # ``AIProvider`` base class inside this method.
        from app.models.ai_provider import AIProvider as AIProviderModel
        from sqlalchemy import select

        result = await db_session.execute(
            select(AIProviderModel)
            .where(AIProviderModel.enabled == True)  # noqa: E712 - SQL expression
            .order_by(AIProviderModel.priority)
        )
        rows = result.scalars().all()

        new_providers: Dict[str, AIProvider] = {}
        for row in rows:
            inst = self._build_provider(row)
            if inst is None:
                continue
            row_key = row.id.hex if hasattr(row.id, 'hex') else str(row.id)
            # One instance, three aliases: id, display name, provider type.
            # NOTE(review): when several rows share a provider_type, the row
            # with the largest priority value wins the type alias because it
            # is assigned last - confirm that this is intended.
            new_providers[row_key] = inst
            new_providers[row.name] = inst
            new_providers[row.provider_type] = inst

        if new_providers:
            self.providers = new_providers
            logger.info(f"Loaded {len(rows)} AI providers from DB")
        else:
            # Keep whatever providers were already registered.
            logger.warning("No enabled AI providers found in DB")

        await self._load_routing_rules(db_session)
        return len(rows)

    async def _load_routing_rules(self, db_session):
        """Load routing rules from ``system_configs``, layered over the defaults.

        The consolidated ``ai_routing`` record is preferred. If it is absent
        (or its value is not a dict), per-task ``ai_provider_<task>`` records
        are consulted. If neither exists the defaults are used unchanged.
        """
        from app.models.system_config import SystemConfig
        from sqlalchemy import select

        # Try consolidated key first
        result = await db_session.execute(
            select(SystemConfig).where(SystemConfig.key == "ai_routing")
        )
        cfg = result.scalar_one_or_none()
        if cfg and isinstance(cfg.value, dict):
            self.routing_rules = {**DEFAULT_ROUTING, **cfg.value}
            logger.info("Loaded routing rules from system_configs (ai_routing)")
            return

        # Fallback: load individual per-task keys
        task_keys = {
            "translate": "ai_provider_translate",
            "reply": "ai_provider_reply",
            "marketing": "ai_provider_marketing",
            "extract": "ai_provider_extract",
            "quotation": "ai_provider_quotation",
        }
        loaded = {}
        for task, key in task_keys.items():
            result = await db_session.execute(
                select(SystemConfig).where(SystemConfig.key == key)
            )
            cfg = result.scalar_one_or_none()
            if cfg and isinstance(cfg.value, dict):
                loaded[task] = cfg.value

        if loaded:
            self.routing_rules = {**DEFAULT_ROUTING, **loaded}
            logger.info(f"Loaded routing rules from system_configs (individual keys): {list(loaded.keys())}")
        else:
            self.routing_rules = dict(DEFAULT_ROUTING)
            logger.info("No routing rules in system_configs, using defaults")

    async def seed_from_env(self, db_session) -> int:
        """Insert provider rows built from application settings.

        Sensenova and NVIDIA rows are added only when their API key setting is
        non-empty. The Alibaba MT row is always added, with empty strings when
        no access key is configured (presumably so role-based credentials can
        be used instead - verify against AlibabaMTProvider).

        This method does not check for existing rows.

        Returns:
            The number of rows added; the session is committed when > 0.
        """
        from app.models.ai_provider import AIProvider as AIProviderModel

        count = 0
        seeds = []

        if settings.SENSENOVA_API_KEY:
            seeds.append(AIProviderModel(
                name="Sensenova (商汤)", provider_type="sensenova",
                api_key=settings.SENSENOVA_API_KEY,
                base_url=settings.SENSENOVA_BASE_URL,
                model_name=settings.SENSENOVA_MODEL, priority=0, enabled=True,
            ))
        if settings.NVIDIA_API_KEY:
            seeds.append(AIProviderModel(
                name="NVIDIA", provider_type="nvidia",
                api_key=settings.NVIDIA_API_KEY,
                base_url=settings.NVIDIA_BASE_URL,
                model_name=settings.NVIDIA_MODEL, priority=1, enabled=True,
            ))
        seeds.append(AIProviderModel(
            name="阿里翻译", provider_type="alibaba-mt",
            api_key=settings.ALIBABA_ACCESS_KEY_ID or "",
            api_secret=settings.ALIBABA_ACCESS_KEY_SECRET or "",
            model_name="alibaba-mt", priority=3, enabled=True,
        ))

        for seed in seeds:
            db_session.add(seed)
            count += 1
        if count:
            await db_session.commit()
            logger.info(f"Seeded {count} AI providers from .env into DB")
        return count

    def schedule_reload(self):
        """Set the ``_needs_reload`` flag; nothing in this class consumes it."""
        self._needs_reload = True
        logger.info("AI router scheduled for reload on next call")

    def _build_provider(self, p) -> Optional["AIProvider"]:
        """Instantiate the provider client described by a provider row.

        Args:
            p: Provider row exposing ``provider_type``, ``api_key``,
                ``api_secret``, ``model_name``, ``base_url`` and ``name``.

        Returns:
            The provider instance, or ``None`` when the type is unknown or
            construction raises (the failure is logged, not propagated).
        """
        try:
            t = p.provider_type
            if t == "sensenova":
                return SensenovaProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
            elif t == "nvidia":
                return NvidiaProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
            elif t == "alibaba-mt":
                return AlibabaMTProvider(access_key_id=p.api_key, access_key_secret=p.api_secret or "")
            else:
                logger.warning(f"Unknown provider type: {t}")
                return None
        except Exception as e:
            # Best effort: one misconfigured row must not block the others.
            logger.warning(f"Failed to build provider {p.name}: {e}")
            return None

    @staticmethod
    def _unique_instances(instances) -> list:
        """Return *instances* in order with repeated objects (by identity) removed."""
        unique = []
        seen_ids = set()
        for inst in instances:
            if id(inst) not in seen_ids:
                seen_ids.add(id(inst))
                unique.append(inst)
        return unique

    def get_providers_for_task(self, task_type: str) -> List["AIProvider"]:
        """Return the providers to try for *task_type*, in attempt order.

        The primary provider comes first, followed by the fallbacks. Keys that
        are not registered are skipped. Each provider instance appears at most
        once even if the rules reference it through several aliases. When no
        rule matches a registered provider, every registered provider is
        returned (once each).

        Args:
            task_type: Routing key such as ``"translate"`` or ``"chat"``.
        """
        rules = self.routing_rules.get(
            task_type,
            {"primary": "sensenova", "fallback": ["nvidia"]},
        )

        # DB-supplied rules may carry null or a bare string as the fallback.
        fallback = rules.get("fallback") or []
        if isinstance(fallback, str):
            fallback = [fallback]

        candidates = []
        primary = rules.get("primary")
        if primary:
            candidates.append(primary)
        candidates.extend(fallback)

        # Dedupe by instance, not by key: self.providers maps several aliases
        # to the same object.
        ordered = self._unique_instances(
            self.providers[key] for key in candidates if key in self.providers
        )

        if not ordered:
            ordered = self._unique_instances(self.providers.values())
            logger.warning(f"No preferred providers for '{task_type}', using all available")

        return ordered

    async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
        """Call *method* on each provider for *task_type* until one succeeds.

        Args:
            task_type: Routing key used to select and order providers.
            method: Name of the coroutine method to call on the provider.
            *args, **kwargs: Forwarded unchanged to the provider method.

        Returns:
            The first successful result, with ``"provider_used"`` set to the
            name of the provider that produced it.

        Raises:
            Exception: When every provider fails or none implements *method*;
                the last provider error (if any) is attached as the cause.
        """
        providers = self.get_providers_for_task(task_type)
        last_error = None

        for provider in providers:
            try:
                method_fn = getattr(provider, method)
                result = await method_fn(*args, **kwargs)
                result["provider_used"] = provider.name
                return result
            except NotImplementedError:
                # Provider does not support this operation; not an error.
                continue
            except Exception as e:
                logger.warning(f"{provider.name} failed for {task_type}: {e}")
                last_error = e
                continue

        raise Exception(f"All providers failed for '{task_type}'. Last error: {last_error}") from last_error

    async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate *text* via the providers routed for ``"translate"``."""
        # NOTE(review): arguments are forwarded as (text, source_lang,
        # target_lang, context), i.e. source before target - presumably the
        # provider signature; confirm against AIProvider.translate.
        return await self.execute("translate", "translate", text, source_lang, target_lang, context)

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Generate a reply to *inquiry* via the providers routed for ``"reply"``."""
        return await self.execute("reply", "reply", inquiry, context, tone, preference_context)

    async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Call ``generate_marketing`` on the providers routed for ``"marketing"``."""
        return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)

    async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``extract_info`` on the providers routed for ``"extract"``."""
        return await self.execute("extract", "extract_info", text, schema)

    async def chat(self, message: str, history: Optional[list] = None, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Send a chat *message* via the providers routed for ``"chat"``."""
        return await self.execute("chat", "chat", message, history, system_prompt)
|
|
|
|
|
|
# Process-wide router, created lazily by get_ai_router().
_router_instance: Optional["AIRouter"] = None


def get_ai_router() -> "AIRouter":
    """Return the shared AIRouter, creating it on first use.

    The instance is built with ``AIRouter()`` only; no providers are loaded
    from the database here.
    """
    global _router_instance
    if _router_instance is None:
        _router_instance = AIRouter()
    return _router_instance
|