chore: post-deployment cleanup and docs update
- Make AI routing rules DB-driven (read from system_configs, removed from config.py)
- Add translation quota tracking to LLM translation (OpenAIProvider)
- Add Alibaba MT ECS RAM role support (STS token, no AccessKey needed)
- Fix admin sidebar link for the AI模型配置 page
- Fix Quota.vue API path (quotas → translation-quotas)
- Fix login auto-redirect to dashboard
- Add provider dropdown selects to AI routing config UI
- Clean up stale ai_provider_* system_configs records
- Remove OpencodeGo and Spark providers (code + DB)
- Update deploy config: nginx port 8000, systemd cwd
This commit is contained in:
@@ -9,7 +9,7 @@ import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
revision = "add_payment_transactions"
|
||||
down_revision = "add_ai_providers_table"
|
||||
down_revision = "add_ai_providers"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from .openai import OpenAIProvider
|
||||
from .spark import SparkProvider
|
||||
from .sensenova import SensenovaProvider
|
||||
from .opencode_go import OpencodeGoProvider
|
||||
from .nvidia import NvidiaProvider
|
||||
from .alibaba import AlibabaMTProvider
|
||||
|
||||
__all__ = ["OpenAIProvider", "SparkProvider", "SensenovaProvider", "OpencodeGoProvider", "NvidiaProvider", "AlibabaMTProvider"]
|
||||
__all__ = ["OpenAIProvider", "SensenovaProvider", "NvidiaProvider", "AlibabaMTProvider"]
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from aliyunsdkcore.client import AcsClient
|
||||
from aliyunsdkcore.auth.credentials import StsTokenCredential
|
||||
from aliyunsdkalimt.request.v20181012 import TranslateGeneralRequest, TranslateECommerceRequest
|
||||
from app.services.translation_quota import TranslationQuotaService
|
||||
from app.database import AsyncSessionLocal
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,11 +18,55 @@ ALIBABA_LANG_MAP = {
|
||||
"id": "id", "ms": "ms", "tl": "tl", "hi": "hi",
|
||||
}
|
||||
|
||||
ECS_METADATA_URL = "http://100.100.100.200/latest/meta-data/ram/security-credentials/"


def _fetch_ecs_ram_credentials():
    """Try to obtain temporary STS credentials from the ECS metadata endpoint.

    Two GET requests are issued against ``ECS_METADATA_URL`` (2-second timeout
    each): the first reads the role name, the second reads the credential
    document stored under that role name.

    Returns:
        A ``(AccessKeyId, AccessKeySecret, SecurityToken)`` tuple when the
        credential document reports ``Code == "Success"``; ``None`` when the
        role name is empty, the code is anything else, or any exception is
        raised along the way (the exception is logged at debug level).
    """
    try:
        import urllib.request

        def _http_get(url):
            # Plain GET with a short timeout; the body is decoded with the default codec.
            request = urllib.request.Request(url, method="GET")
            with urllib.request.urlopen(request, timeout=2) as response:
                return response.read().decode()

        role_name = _http_get(ECS_METADATA_URL).strip()
        if not role_name:
            logger.warning("ECS metadata returned empty role name")
            return None

        data = json.loads(_http_get(f"{ECS_METADATA_URL}{role_name}"))
        if data.get("Code") != "Success":
            logger.warning(f"ECS metadata returned non-success: {data.get('Code')}")
        else:
            logger.info(f"Fetched STS token for role {role_name}, expires {data.get('Expiration')}")
            return (data["AccessKeyId"], data["AccessKeySecret"], data["SecurityToken"])
    except Exception as e:
        # Best effort: every failure simply means "no role credentials available".
        logger.debug(f"ECS metadata fetch failed: {e}")
    return None
|
||||
|
||||
|
||||
def _build_acs_client(access_key_id: str = "", access_key_secret: str = "",
                      region_id: str = "cn-hangzhou") -> AcsClient:
    """Create an ``AcsClient`` from the first credential source that is available.

    Order of precedence:
      1. STS credentials returned by ``_fetch_ecs_ram_credentials()``.
      2. The ``access_key_id`` / ``access_key_secret`` arguments, each falling
         back to the ``ALIBABA_ACCESS_KEY_ID`` / ``ALIBABA_ACCESS_KEY_SECRET``
         environment variable when empty.

    Args:
        access_key_id: Explicit AccessKey ID; may be empty.
        access_key_secret: Explicit AccessKey secret; may be empty.
        region_id: Region passed to the client.

    Returns:
        The configured ``AcsClient``.

    Raises:
        ValueError: If neither source yields usable credentials.
    """
    # NOTE(review): the STS token is fetched once here; nothing in this function
    # refreshes it after expiry — confirm how long-lived clients are handled.
    sts = _fetch_ecs_ram_credentials()
    if sts:
        key_id, key_secret, security_token = sts
        client = AcsClient(
            credential=StsTokenCredential(key_id, key_secret, security_token),
            region_id=region_id,
        )
        logger.info("Alibaba MT using ECS RAM role (STS token)")
        return client

    key_id = access_key_id or os.getenv("ALIBABA_ACCESS_KEY_ID", "")
    key_secret = access_key_secret or os.getenv("ALIBABA_ACCESS_KEY_SECRET", "")
    if not (key_id and key_secret):
        raise ValueError("No Alibaba Cloud credentials found (neither ECS RAM role nor AccessKey)")

    logger.info("Alibaba MT using AccessKey credentials")
    return AcsClient(key_id, key_secret, region_id)
|
||||
|
||||
|
||||
class AlibabaMTProvider:
|
||||
def __init__(self, access_key_id: str, access_key_secret: str,
|
||||
def __init__(self, access_key_id: str = "", access_key_secret: str = "",
|
||||
region_id: str = "cn-hangzhou"):
|
||||
self.client = AcsClient(access_key_id, access_key_secret, region_id)
|
||||
self.client = _build_acs_client(access_key_id, access_key_secret, region_id)
|
||||
self._name = "alibaba-mt"
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str],
|
||||
|
||||
@@ -51,6 +51,20 @@ class OpenAIProvider(AIProvider):
|
||||
self._cheap_model = "gpt-4o-mini" if model == "gpt-4o" else model
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
    """Translate ``text`` while enforcing and recording the ``"llm"`` translation quota.

    Args:
        text: Text to translate; its length is what gets charged to the quota.
        source_lang: Source language, forwarded to ``_do_translate``.
        target_lang: Target language, forwarded to ``_do_translate``.
        context: Optional context, forwarded to ``_do_translate``.

    Returns:
        Whatever ``_do_translate`` returns.

    Raises:
        Exception: If ``check_quota("llm")`` reports no quota available.
    """
    # Imported at call time, as in the original.
    from app.services.translation_quota import TranslationQuotaService
    from app.database import AsyncSessionLocal

    # NOTE(review): the session stays open for the whole translation call —
    # confirm this is acceptable for the connection pool.
    async with AsyncSessionLocal() as session:
        quota = TranslationQuotaService(session)
        allowed = await quota.check_quota("llm")
        if not allowed:
            raise Exception("LLM translation quota exhausted or disabled")

        outcome = await self._do_translate(text, source_lang, target_lang, context)

        # Charge the quota only when a non-empty translation came back.
        produced_text = outcome and outcome.get("translated_text")
        if produced_text:
            await quota.consume("llm", len(text))
            await session.commit()
        return outcome
|
||||
|
||||
async def _do_translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: this is about {context}"
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from app.ai.providers.openai import OpenAIProvider
|
||||
|
||||
|
||||
class OpencodeGoProvider(OpenAIProvider):
    """``OpenAIProvider`` subclass with OpencodeGo defaults for model and base URL."""

    def __init__(
        self,
        api_key: str,
        model: str = "deepseek-v4-flash",
        base_url: str = "https://opencode.ai/zen/go/v1",
    ):
        """Initialise the parent provider, then set the ``opencode-go-<model>`` name."""
        super().__init__(api_key=api_key, model=model, base_url=base_url)
        # Set after the parent constructor so this value is the one that sticks.
        self._name = f"opencode-go-{model}"
|
||||
@@ -1,90 +0,0 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
# System prompts keyed by task type; each one tells the model to return only the payload.
SYSTEM_PROMPTS = {
    "translate": "You are a professional translator specialized in foreign trade. "
                 "Translate business terms accurately. Return ONLY the translated text.",
    "reply": "You are an experienced foreign trade sales expert. Write professional, "
             "clear business replies. Return ONLY the reply text.",
    "marketing": "You are a creative copywriter for international trade. "
                 "Return ONLY the marketing copy, no explanations.",
    "extract": "Extract structured data from text. Return ONLY valid JSON.",
}


class SparkProvider(AIProvider):
    """AI provider that talks to a chat-completions endpoint through ``AsyncOpenAI``.

    The endpoint defaults to ``settings.IFLYTEK_API_BASE`` when no ``base_url``
    is supplied.
    """

    def __init__(self, api_key: str, model: str = "astron-code-latest", base_url: Optional[str] = None):
        """Create the async client.

        Args:
            api_key: API key passed to ``AsyncOpenAI``.
            model: Model name sent with every request.
            base_url: Endpoint override; falls back to ``settings.IFLYTEK_API_BASE``.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        from app.config import settings
        try:
            from openai import AsyncOpenAI
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("openai>=1.0 is required for SparkProvider") from exc
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url or settings.IFLYTEK_API_BASE,
        )
        self.model = model
        self._name = f"spark-{model}"

    async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
        """Translate ``text`` into ``target_lang``.

        The source language is mentioned in the prompt only when it is set and
        is not ``"auto"``.

        Returns:
            ``{"translated_text": <str>, "provider": <name>}``.
        """
        system = SYSTEM_PROMPTS["translate"]
        if context:
            system += f"\nContext: {context}"
        prompt = f"Translate {f'from {source_lang} ' if source_lang and source_lang != 'auto' else ''}to {target_lang}:\n\n{text}"
        content = await self._call(system, prompt)
        return {"translated_text": content, "provider": self.name}

    async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Draft a reply to a customer ``inquiry``.

        Truthy entries of ``context`` are rendered as ``key: value`` lines ahead
        of the inquiry; ``tone`` and ``preference_context`` extend the system prompt.

        Returns:
            ``{"reply": <str>, "provider": <name>}``.
        """
        system = SYSTEM_PROMPTS["reply"] + f"\nTone: {tone}"
        if preference_context:
            system += f"\nUser preference: {preference_context}"
        ctx = ""
        if context:
            ctx = "\n".join(f"{k}: {v}" for k, v in context.items() if v)
        prompt = f"{ctx}\nCustomer inquiry:\n{inquiry}\n\nWrite a reply:"
        content = await self._call(system, prompt)
        return {"reply": content, "provider": self.name}

    async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
        """Generate marketing copy for ``product_info``.

        ``product_info`` is serialised to JSON (non-ASCII kept as-is) and the
        request uses a 1500-token limit.

        Returns:
            ``{"content": <str>, "provider": <name>}``.
        """
        system = SYSTEM_PROMPTS["marketing"] + f"\nStyle: {style}\nAudience: {target}\nLanguage: {language}"
        if preference_context:
            system += f"\nUser preference: {preference_context}"
        info = json.dumps(product_info, ensure_ascii=False)
        prompt = f"Product:\n{info}\n\nGenerate marketing copy:"
        content = await self._call(system, prompt, max_tokens=1500)
        return {"content": content, "provider": self.name}

    async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        """Extract data from ``text`` according to ``schema`` and parse the reply as JSON.

        Returns:
            ``{"data": <parsed>, "confidence": 0.9, "provider": <name>}`` on success,
            or ``{"data": {}, "confidence": 0.0, "provider": <name>}`` when the
            reply is not valid JSON (including an empty reply).
        """
        system = SYSTEM_PROMPTS["extract"]
        prompt = f"Schema:\n{json.dumps(schema, indent=2)}\n\nText:\n{text}\n\nJSON:"
        content = await self._call(system, prompt, response_format={"type": "json_object"})
        try:
            data = json.loads(content)
            return {"data": data, "confidence": 0.9, "provider": self.name}
        except json.JSONDecodeError:
            return {"data": {}, "confidence": 0.0, "provider": self.name}

    async def _call(self, system: str, prompt: str, max_tokens: int = 1000, response_format: Optional[Dict] = None) -> str:
        """Send one system + user exchange and return the first choice's text.

        Args:
            system: System prompt.
            prompt: User prompt.
            max_tokens: Completion token limit.
            response_format: Optional ``response_format`` payload, sent only when truthy.

        Returns:
            The message content, or ``""`` when the API returns no content.
        """
        kwargs = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        if response_format:
            kwargs["response_format"] = response_format
        resp = await self.client.chat.completions.create(**kwargs)
        # ``content`` may be None; normalise to "" so callers always receive a str
        # (json.loads(None) would otherwise raise an uncaught TypeError in extract_info).
        return resp.choices[0].message.content or ""

    @property
    def name(self) -> str:
        """Provider name in the form ``spark-<model>``."""
        return self._name

    @property
    def cost_per_1k_tokens(self) -> float:
        """Cost per 1k tokens; this provider always reports ``0.0``."""
        return 0.0
|
||||
+59
-30
@@ -1,17 +1,26 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.base import AIProvider
|
||||
from app.ai.providers import SparkProvider, SensenovaProvider, OpencodeGoProvider, NvidiaProvider, AlibabaMTProvider
|
||||
from app.config import settings
|
||||
from app.ai.providers import SensenovaProvider, NvidiaProvider, AlibabaMTProvider
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Built-in routing table: for each task type, the primary provider and the
# ordered list of fallback providers.
DEFAULT_ROUTING: Dict[str, dict] = {
    "translate": {
        "primary": "sensenova",
        "fallback": ["alibaba-mt", "nvidia"],
    },
    "reply": {
        "primary": "sensenova",
        "fallback": ["nvidia"],
    },
    "marketing": {
        "primary": "sensenova",
        "fallback": ["nvidia"],
    },
    "extract": {
        "primary": "sensenova",
        "fallback": ["nvidia"],
    },
    "quotation": {
        "primary": "sensenova",
        "fallback": ["nvidia"],
    },
    "chat": {
        "primary": "sensenova",
        "fallback": ["nvidia"],
    },
}
|
||||
|
||||
|
||||
class AIRouter:
|
||||
def __init__(self):
|
||||
self.providers: Dict[str, AIProvider] = {}
|
||||
self.routing_rules = settings.AI_ROUTING
|
||||
self.routing_rules = dict(DEFAULT_ROUTING)
|
||||
self.corpus = TradeCorpus()
|
||||
|
||||
async def reload_from_db(self, db_session) -> int:
|
||||
@@ -38,8 +47,47 @@ class AIRouter:
|
||||
else:
|
||||
logger.warning("No enabled AI providers found in DB")
|
||||
|
||||
await self._load_routing_rules(db_session)
|
||||
return len(rows)
|
||||
|
||||
async def _load_routing_rules(self, db_session):
    """Populate ``self.routing_rules`` from ``system_configs``.

    Resolution order:
      1. The consolidated ``ai_routing`` row, when its value is a dict.
      2. The legacy per-task ``ai_provider_*`` rows, when at least one has a
         dict value.
      3. A copy of ``DEFAULT_ROUTING``.

    In cases 1 and 2 the stored entries are overlaid on ``DEFAULT_ROUTING``,
    so tasks missing from the database keep their default rule.

    Args:
        db_session: Async database session used to run the queries.
    """
    from app.models.system_config import SystemConfig
    from sqlalchemy import select

    # Try consolidated key first
    result = await db_session.execute(
        select(SystemConfig).where(SystemConfig.key == "ai_routing")
    )
    cfg = result.scalar_one_or_none()
    if cfg and isinstance(cfg.value, dict):
        self.routing_rules = {**DEFAULT_ROUTING, **cfg.value}
        logger.info("Loaded routing rules from system_configs (ai_routing)")
        return

    # Fallback: load individual per-task keys
    task_keys = {
        "translate": "ai_provider_translate",
        "reply": "ai_provider_reply",
        "marketing": "ai_provider_marketing",
        "extract": "ai_provider_extract",
        "quotation": "ai_provider_quotation",
    }
    # One IN query instead of one round-trip per task.
    # NOTE(review): if a key has duplicate rows the last one wins here, where the
    # per-key lookup would have raised — presumably ``key`` is unique; confirm.
    result = await db_session.execute(
        select(SystemConfig).where(SystemConfig.key.in_(list(task_keys.values())))
    )
    value_by_key = {row.key: row.value for row in result.scalars().all()}

    # Iterate task_keys so ``loaded`` keeps the declaration order.
    loaded = {
        task: value_by_key[key]
        for task, key in task_keys.items()
        if isinstance(value_by_key.get(key), dict)
    }

    if loaded:
        self.routing_rules = {**DEFAULT_ROUTING, **loaded}
        logger.info(f"Loaded routing rules from system_configs (individual keys): {list(loaded.keys())}")
    else:
        self.routing_rules = dict(DEFAULT_ROUTING)
        logger.info("No routing rules in system_configs, using defaults")
|
||||
|
||||
async def seed_from_env(self, db_session) -> int:
|
||||
from app.models.ai_provider import AIProvider
|
||||
|
||||
@@ -53,34 +101,19 @@ class AIRouter:
|
||||
base_url=settings.SENSENOVA_BASE_URL,
|
||||
model_name=settings.SENSENOVA_MODEL, priority=0, enabled=True,
|
||||
))
|
||||
if settings.OPENCODE_GO_API_KEY:
|
||||
seeds.append(AIProvider(
|
||||
name="OpencodeGo", provider_type="opencode_go",
|
||||
api_key=settings.OPENCODE_GO_API_KEY,
|
||||
base_url=settings.OPENCODE_GO_BASE_URL,
|
||||
model_name=settings.OPENCODE_GO_MODEL, priority=1, enabled=True,
|
||||
))
|
||||
if settings.NVIDIA_API_KEY:
|
||||
seeds.append(AIProvider(
|
||||
name="NVIDIA", provider_type="nvidia",
|
||||
api_key=settings.NVIDIA_API_KEY,
|
||||
base_url=settings.NVIDIA_BASE_URL,
|
||||
model_name=settings.NVIDIA_MODEL, priority=2, enabled=True,
|
||||
))
|
||||
if settings.IFLYTEK_API_KEY:
|
||||
seeds.append(AIProvider(
|
||||
name="讯飞 Spark", provider_type="spark",
|
||||
api_key=settings.IFLYTEK_API_KEY,
|
||||
base_url=settings.IFLYTEK_API_BASE,
|
||||
model_name=settings.IFLYTEK_MODEL, priority=3, enabled=True,
|
||||
))
|
||||
if settings.ALIBABA_ACCESS_KEY_ID and settings.ALIBABA_ACCESS_KEY_SECRET:
|
||||
seeds.append(AIProvider(
|
||||
name="阿里翻译", provider_type="alibaba-mt",
|
||||
api_key=settings.ALIBABA_ACCESS_KEY_ID,
|
||||
api_secret=settings.ALIBABA_ACCESS_KEY_SECRET,
|
||||
model_name="alibaba-mt", priority=4, enabled=True,
|
||||
model_name=settings.NVIDIA_MODEL, priority=1, enabled=True,
|
||||
))
|
||||
seeds.append(AIProvider(
|
||||
name="阿里翻译", provider_type="alibaba-mt",
|
||||
api_key=settings.ALIBABA_ACCESS_KEY_ID or "",
|
||||
api_secret=settings.ALIBABA_ACCESS_KEY_SECRET or "",
|
||||
model_name="alibaba-mt", priority=3, enabled=True,
|
||||
))
|
||||
|
||||
for p in seeds:
|
||||
db_session.add(p)
|
||||
@@ -99,12 +132,8 @@ class AIRouter:
|
||||
t = p.provider_type
|
||||
if t == "sensenova":
|
||||
return SensenovaProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
|
||||
elif t == "opencode_go":
|
||||
return OpencodeGoProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
|
||||
elif t == "nvidia":
|
||||
return NvidiaProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
|
||||
elif t == "spark":
|
||||
return SparkProvider(api_key=p.api_key, model=p.model_name, base_url=p.base_url)
|
||||
elif t == "alibaba-mt":
|
||||
return AlibabaMTProvider(access_key_id=p.api_key, access_key_secret=p.api_secret or "")
|
||||
else:
|
||||
@@ -117,7 +146,7 @@ class AIRouter:
|
||||
def get_providers_for_task(self, task_type: str) -> List[AIProvider]:
|
||||
rules = self.routing_rules.get(
|
||||
task_type,
|
||||
{"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||
{"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
)
|
||||
ordered = []
|
||||
seen = set()
|
||||
|
||||
@@ -42,9 +42,7 @@ class AIProviderUpdate(BaseModel):
|
||||
|
||||
PROVIDER_TYPE_LABELS = {
|
||||
"sensenova": "Sensenova (商汤)",
|
||||
"opencode_go": "OpencodeGo",
|
||||
"nvidia": "NVIDIA",
|
||||
"spark": "讯飞 Spark",
|
||||
"alibaba-mt": "阿里翻译",
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
ACTION_INSTRUCTIONS = """
|
||||
当用户想要执行操作时(如添加客户、创建产品、生成报价单等),请执行以下步骤:
|
||||
当用户想要执行操作时(如添加客户、创建产品、生成报价单、发送跟进、营销生成等),请执行以下步骤:
|
||||
1. 从用户消息中提取所有必要的信息
|
||||
2. 在回复末尾附上 JSON 格式的动作块,格式如下:
|
||||
|
||||
@@ -28,9 +28,20 @@ ACTION_INSTRUCTIONS = """
|
||||
[{"type": "create_customer", "label": "添加客户", "fields": {"name": "...", "phone": "...", "email": "...", "company": "...", "country": "...", "notes": "..."}}]
|
||||
```
|
||||
|
||||
支持的 action type:
|
||||
- create_customer:添加客户,fields 支持 name(必填), phone, email, company, country, notes
|
||||
- create_product:添加产品(开发中)
|
||||
支持的 action type 及字段说明:
|
||||
- create_customer:添加客户,fields 支持 name(必填), phone, email, company, country, website, notes
|
||||
- create_product:添加产品,fields 支持 name(必填), name_en, description, description_en, category, price, price_unit(默认USD), moq, keywords(逗号分隔)
|
||||
- create_quotation:生成报价单,fields 支持 customer_name(必填), product_info(必填), quantity(必填), price, terms
|
||||
- scan_followups:扫描待跟进客户,fields 不需要(空对象)
|
||||
- send_followup:发送跟进消息,fields 支持 customer_name(必填), message(必填)
|
||||
- generate_marketing:生成营销素材,fields 支持 product_name(必填), target_market, tone(如professional/casual), language
|
||||
- discovery_search:搜索潜在客户,fields 支持 keywords(必填), country, industry
|
||||
- navigate:跳转到指定页面,fields 支持 path(必填, 如 /customers /products /quotations /marketing /discovery /followup /translate /team /analytics)
|
||||
- search_users:搜索用户,fields 支持 query(必填)
|
||||
- update_user:修改用户信息,fields 支持 user_id(必填), username, phone, email, role, status
|
||||
- update_config:更新系统配置,fields 支持 key(必填), value(必填)
|
||||
- review_certification:审核认证,fields 支持 id(必填), action(approved/rejected), reason
|
||||
- process_invoice:处理发票,fields 支持 id(必填), action(approve/reject)
|
||||
|
||||
如果用户没有提供足够信息,请先询问缺少的字段,不要生成 action。
|
||||
如果用户明确表示要执行操作但缺少信息,生成 action 但标注缺失的字段。
|
||||
|
||||
@@ -32,13 +32,7 @@ class Settings(BaseSettings):
|
||||
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
|
||||
SENSENOVA_MODEL: str = "deepseek-v4-flash"
|
||||
|
||||
IFLYTEK_API_KEY: Optional[str] = None
|
||||
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
|
||||
IFLYTEK_MODEL: str = "astron-code-latest"
|
||||
|
||||
OPENCODE_GO_API_KEY: Optional[str] = None
|
||||
OPENCODE_GO_BASE_URL: str = "https://opencode.ai/zen/go/v1"
|
||||
OPENCODE_GO_MODEL: str = "minimax-m2.7"
|
||||
|
||||
NVIDIA_API_KEY: Optional[str] = None
|
||||
NVIDIA_BASE_URL: str = "https://integrate.api.nvidia.com/v1"
|
||||
@@ -74,15 +68,6 @@ class Settings(BaseSettings):
|
||||
SENTRY_DSN: Optional[str] = None
|
||||
DEBUG: bool = True
|
||||
|
||||
AI_ROUTING: dict = {
|
||||
"translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]},
|
||||
"reply": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||
"marketing": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||
"extract": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||
"quotation": {"primary": "sensenova", "fallback": ["opencode_go"]},
|
||||
"chat": {"primary": "sensenova", "fallback": ["opencode_go", "nvidia"]},
|
||||
}
|
||||
|
||||
FREE_DAILY_TRANSLATE_CHARS: int = 5000
|
||||
FREE_DAILY_REPLIES: int = 20
|
||||
FREE_DAILY_MARKETING: int = 5
|
||||
|
||||
@@ -288,11 +288,14 @@ class AdminService:
|
||||
|
||||
async def _seed_default_configs(self):
|
||||
defaults = [
|
||||
SystemConfig(key="ai_provider_translate", value={"primary": "sensenova", "fallback": ["alibaba-mt", "opencode_go"]}, description="翻译任务 AI 模型选择"),
|
||||
SystemConfig(key="ai_provider_reply", value={"primary": "sensenova", "fallback": ["opencode_go"]}, description="回复建议 AI 模型选择"),
|
||||
SystemConfig(key="ai_provider_marketing", value={"primary": "sensenova", "fallback": ["opencode_go"]}, description="营销文案 AI 模型选择"),
|
||||
SystemConfig(key="ai_provider_extract", value={"primary": "sensenova", "fallback": ["opencode_go"]}, description="信息提取 AI 模型选择"),
|
||||
SystemConfig(key="ai_provider_quotation", value={"primary": "sensenova", "fallback": ["opencode_go"]}, description="报价单 AI 模型选择"),
|
||||
SystemConfig(key="ai_routing", value={
|
||||
"translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "nvidia"]},
|
||||
"reply": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"marketing": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"extract": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"quotation": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"chat": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
}, description="AI 路由规则:各任务的主选/备用供应商"),
|
||||
SystemConfig(key="feature_guest_mode", value={"enabled": True}, description="游客模式开关"),
|
||||
SystemConfig(key="feature_wechat_login", value={"enabled": False}, description="微信登录开关"),
|
||||
SystemConfig(key="feature_registration", value={"enabled": True}, description="新用户注册开关"),
|
||||
@@ -305,6 +308,19 @@ class AdminService:
|
||||
self.db.add(cfg)
|
||||
await self.db.flush()
|
||||
|
||||
async def _migrate_routing_configs(self):
    """Delete the legacy per-task ``ai_provider_*`` routing rows.

    These rows are superseded by the consolidated ``ai_routing`` config. The
    deletion is flushed on ``self.db`` but not committed here.
    """
    from sqlalchemy import delete

    # NOTE(review): the legacy values are dropped, not copied into ``ai_routing`` —
    # confirm no deployment still relies on customised per-task rows.
    stale_keys = [
        "ai_provider_translate",
        "ai_provider_reply",
        "ai_provider_marketing",
        "ai_provider_extract",
        "ai_provider_quotation",
    ]
    # Single DELETE for all keys (exact matches) instead of one statement per key.
    result = await self.db.execute(
        delete(SystemConfig).where(SystemConfig.key.in_(stale_keys))
    )
    await self.db.flush()
    # Log only when rows were actually removed.
    if result.rowcount:
        logger.info("Cleaned up stale ai_provider_* routing configs")
|
||||
|
||||
async def list_config(self) -> List[Dict[str, Any]]:
|
||||
result = await self.db.execute(
|
||||
select(func.count(SystemConfig.id))
|
||||
@@ -312,6 +328,28 @@ class AdminService:
|
||||
if result.scalar() == 0:
|
||||
await self._seed_default_configs()
|
||||
|
||||
await self._migrate_routing_configs()
|
||||
|
||||
# Ensure consolidated ai_routing exists
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.key == "ai_routing")
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
self.db.add(SystemConfig(
|
||||
key="ai_routing",
|
||||
value={
|
||||
"translate": {"primary": "sensenova", "fallback": ["alibaba-mt", "nvidia"]},
|
||||
"reply": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"marketing": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"extract": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"quotation": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
"chat": {"primary": "sensenova", "fallback": ["nvidia"]},
|
||||
},
|
||||
description="AI 路由规则:各任务的主选/备用供应商",
|
||||
))
|
||||
await self.db.flush()
|
||||
logger.info("Seeded ai_routing config")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).order_by(SystemConfig.key)
|
||||
)
|
||||
@@ -336,6 +374,12 @@ class AdminService:
|
||||
config.value = value
|
||||
config.updated_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
|
||||
if key == "ai_routing":
|
||||
from app.ai.router import get_ai_router
|
||||
await get_ai_router().reload_from_db(self.db)
|
||||
logger.info("AI router reloaded after ai_routing config update")
|
||||
|
||||
return {
|
||||
"key": config.key,
|
||||
"value": config.value,
|
||||
|
||||
@@ -12,6 +12,10 @@ class TranslationQuotaService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
def _default_desc(self, version: str) -> str:
|
||||
labels = {"ecommerce": "阿里云翻译电商版", "general": "阿里云翻译通用版", "llm": "AI模型翻译"}
|
||||
return labels.get(version, f"阿里云翻译{version}版")
|
||||
|
||||
async def _get_or_create(self, version: str) -> TranslationQuota:
|
||||
result = await self.db.execute(
|
||||
select(TranslationQuota).where(TranslationQuota.version == version)
|
||||
@@ -25,7 +29,7 @@ class TranslationQuotaService:
|
||||
used_chars=0,
|
||||
current_month=now.strftime("%Y-%m"),
|
||||
enabled=True,
|
||||
description=f"阿里云翻译{version}版",
|
||||
description=self._default_desc(version),
|
||||
)
|
||||
self.db.add(quota)
|
||||
await self.db.flush()
|
||||
@@ -57,8 +61,7 @@ class TranslationQuotaService:
|
||||
return remaining
|
||||
|
||||
async def get_all_quotas(self) -> list:
|
||||
default_versions = ["ecommerce", "general"]
|
||||
for v in default_versions:
|
||||
for v in ("ecommerce", "general", "llm"):
|
||||
await self._get_or_create(v)
|
||||
|
||||
result = await self.db.execute(select(TranslationQuota).order_by(TranslationQuota.version))
|
||||
|
||||
@@ -14,11 +14,12 @@ class TestConfig:
|
||||
assert settings.REFRESH_TOKEN_EXPIRE_DAYS == 30
|
||||
|
||||
def test_ai_routing_config(self):
|
||||
assert "translate" in settings.AI_ROUTING
|
||||
assert "reply" in settings.AI_ROUTING
|
||||
assert "marketing" in settings.AI_ROUTING
|
||||
assert "extract" in settings.AI_ROUTING
|
||||
assert "primary" in settings.AI_ROUTING["translate"]
|
||||
from app.ai.router import DEFAULT_ROUTING
|
||||
assert "translate" in DEFAULT_ROUTING
|
||||
assert "reply" in DEFAULT_ROUTING
|
||||
assert "marketing" in DEFAULT_ROUTING
|
||||
assert "extract" in DEFAULT_ROUTING
|
||||
assert "primary" in DEFAULT_ROUTING["translate"]
|
||||
|
||||
def test_free_tier_limits(self):
|
||||
assert settings.FREE_DAILY_TRANSLATE_CHARS == 5000
|
||||
|
||||
Reference in New Issue
Block a user