Initial commit: TradeMate 外贸小助手 MVP
项目结构: - backend/ Python FastAPI 后端 - uni-app/ uni-app跨端前端 - docs/ 设计文档 - docker-compose.yml Docker编排 - nginx/scripts/systemd 运维配置 已完成功能: - 用户认证 (JWT) - 智能翻译 + 回复建议 - 营销素材生成 - 客户管理 + 沉默检测 - 报价单管理 - 产品库管理 - 汇率换算 - 推送通知 (uni-push) - WhatsApp Webhook框架 - Celery定时任务
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .router import get_ai_router
|
||||
|
||||
__all__ = ["get_ai_router"]
|
||||
@@ -0,0 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def translate(
|
||||
self, text: str, source_lang: Optional[str], target_lang: str,
|
||||
context: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reply(
|
||||
self, inquiry: str, context: Optional[Dict[str, Any]] = None,
|
||||
tone: str = "professional",
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_marketing(
|
||||
self, product_info: Dict[str, Any], target: str,
|
||||
style: str = "professional", language: str = "en",
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def extract_info(
|
||||
self, text: str, schema: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
pass
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return False
|
||||
@@ -0,0 +1,6 @@
|
||||
from .openai import OpenAIProvider
|
||||
from .claude import ClaudeProvider
|
||||
from .deepl import DeepLProvider
|
||||
from .local import LocalProvider
|
||||
|
||||
__all__ = ["OpenAIProvider", "ClaudeProvider", "DeepLProvider", "LocalProvider"]
|
||||
@@ -0,0 +1,83 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from anthropic import AsyncAnthropic
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
SYSTEM_PROMPTS = {
|
||||
"marketing": "You are a world-class copywriter for international trade. Write persuasive, "
|
||||
"culturally-adapted marketing content that converts. You excel at storytelling "
|
||||
"and emotional appeal in business contexts.",
|
||||
"reply": "You are a senior international sales representative with 20 years of experience. "
|
||||
"Your replies are warm, professional, and strategically move the conversation "
|
||||
"toward closing the deal.",
|
||||
"translate": "You are a professional translator specializing in trade documents. "
|
||||
"Preserve all numbers, terms, and formatting. Translate meaning, not words.",
|
||||
"extract": "Extract structured data from text. Return ONLY valid JSON.",
|
||||
}
|
||||
|
||||
|
||||
class ClaudeProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "claude-sonnet-4-20250514"):
|
||||
self.client = AsyncAnthropic(api_key=api_key)
|
||||
self.model = model
|
||||
self._name = f"claude-sonnet"
|
||||
self._pricing = {"input": 0.003, "output": 0.015}
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: {context}"
|
||||
prompt = f"Translate to {target_lang}:\n\n{text}"
|
||||
content = await self._call(system, prompt)
|
||||
return {"translated_text": content, "provider": self.name}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"]
|
||||
context_str = ""
|
||||
if context:
|
||||
for k, v in context.items():
|
||||
if v:
|
||||
context_str += f"{k}: {v}\n"
|
||||
prompt = f"{context_str}\nCustomer says:\n{inquiry}\n\nYour reply ({tone} tone):"
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"]
|
||||
info = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product:\n{info}\n\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nWrite marketing copy:"
|
||||
content = await self._call(system, prompt, max_tokens=1500)
|
||||
return {"content": content, "provider": self.name}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["extract"]
|
||||
prompt = f"Schema:\n{json.dumps(schema, indent=2)}\n\nText:\n{text}\n\nJSON:"
|
||||
content = await self._call(system, prompt, max_tokens=1000)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
return {"data": data, "confidence": 0.9, "provider": self.name}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "error": "parse_failed"}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000) -> str:
|
||||
resp = await self.client.messages.create(
|
||||
model=self.model,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
)
|
||||
return resp.content[0].text
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return (self._pricing["input"] + self._pricing["output"]) / 2
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import httpx
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
class DeepLProvider(AIProvider):
|
||||
def __init__(self, api_key: str, endpoint: str = "https://api.deepl.com/v2"):
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self._name = "deepl"
|
||||
self._cost_per_char = 0.000006
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
params = {
|
||||
"auth_key": self.api_key,
|
||||
"text": text,
|
||||
"target_lang": target_lang.upper()[:2],
|
||||
}
|
||||
if source_lang and source_lang != "auto":
|
||||
params["source_lang"] = source_lang.upper()[:2]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(f"{self.endpoint}/translate", data=params, timeout=15)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
t = data["translations"][0]
|
||||
return {
|
||||
"translated_text": t["text"],
|
||||
"provider": self.name,
|
||||
"detected_source_lang": t.get("detected_source_language", source_lang),
|
||||
"char_count": len(text),
|
||||
"cost": len(text) * self._cost_per_char,
|
||||
}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support reply generation")
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support marketing generation")
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support info extraction")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return self._cost_per_char * 1000
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json, httpx
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
class LocalProvider(AIProvider):
|
||||
def __init__(self, model_url: str = "http://localhost:8001", model_name: str = "gemma-3-8b"):
|
||||
self.model_url = model_url.rstrip("/")
|
||||
self.model_name = model_name
|
||||
self._name = f"local-{model_name}"
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
prompt = f"Translate{ f' from {source_lang}' if source_lang else ''} to {target_lang}:\n{text}\n\nTranslation:"
|
||||
result = await self._generate(prompt)
|
||||
return {"translated_text": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
ctx = ""
|
||||
if context:
|
||||
ctx = "\n".join(f"{k}: {v}" for k, v in context.items() if v)
|
||||
prompt = f"{ctx}\nCustomer: {inquiry}\n\nWrite a {tone} reply:"
|
||||
result = await self._generate(prompt)
|
||||
return {"reply": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
info = json.dumps(product_info, ensure_ascii=False)
|
||||
prompt = f"Product: {info}\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nMarketing copy:"
|
||||
result = await self._generate(prompt, max_tokens=800)
|
||||
return {"content": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
prompt = f"Extract JSON from text matching schema:\nSchema: {json.dumps(schema)}\n\nText: {text}\n\nJSON:"
|
||||
result = await self._generate(prompt, max_tokens=500)
|
||||
try:
|
||||
return {"data": json.loads(result), "confidence": 0.7, "provider": self.name, "cost": 0.0}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "cost": 0.0, "error": "parse_failed"}
|
||||
|
||||
async def _generate(self, prompt: str, max_tokens: int = 500) -> str:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.model_url}/v1/completions",
|
||||
json={"model": self.model_name, "prompt": prompt, "max_tokens": max_tokens, "temperature": 0.7, "stream": False},
|
||||
timeout=60,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["choices"][0]["text"].strip()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return 0.0
|
||||
@@ -0,0 +1,102 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
SYSTEM_PROMPTS = {
|
||||
"translate": "You are a professional translator specialized in foreign trade and e-commerce. "
|
||||
"Accurately translate business terms like MOQ, FOB, CIF, lead time, etc. "
|
||||
"Return ONLY the translated text, no explanations.",
|
||||
"reply": "You are an experienced foreign trade sales expert. Write professional, "
|
||||
"clear business replies. Be concise but warm. Include relevant details "
|
||||
"naturally. Return ONLY the reply text, no explanations.",
|
||||
"marketing": "You are a creative copywriter for international trade. Write compelling "
|
||||
"marketing content that drives action. Adapt to the target audience's culture. "
|
||||
"Return ONLY the copy, no explanations.",
|
||||
"extract": "You extract structured data from text. Return ONLY valid JSON matching the requested schema.",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o"):
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
self._name = f"openai-{model}"
|
||||
self._pricing = {
|
||||
"gpt-4o": {"input": 0.01, "output": 0.03},
|
||||
"gpt-4o-mini": {"input": 0.0015, "output": 0.006},
|
||||
}
|
||||
self._cheap_model = "gpt-4o-mini" if model == "gpt-4o" else model
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: this is about {context}"
|
||||
if source_lang and source_lang != "auto":
|
||||
system += f"\nSource language: {source_lang}"
|
||||
|
||||
content = await self._call(system, f"Translate to {target_lang}:\n\n{text}", model=self._cheap_model)
|
||||
return {"translated_text": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"] + f"\nTone: {tone}"
|
||||
|
||||
context_str = ""
|
||||
if context:
|
||||
if context.get("product"):
|
||||
context_str += f"Product: {context['product']}\n"
|
||||
if context.get("price"):
|
||||
context_str += f"Price: {context['price']}\n"
|
||||
if context.get("customer_history"):
|
||||
context_str += f"Customer history: {context['customer_history']}\n"
|
||||
if context.get("conversation_history"):
|
||||
context_str += f"Previous messages: {context['conversation_history']}\n"
|
||||
|
||||
prompt = f"{context_str}\nCustomer inquiry:\n{inquiry}\n\nWrite a reply:"
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"] + f"\nStyle: {style}\nTarget audience: {target}\nLanguage: {language}"
|
||||
|
||||
product_str = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product information:\n{product_str}\n\nGenerate marketing copy:"
|
||||
content = await self._call(system, prompt)
|
||||
return {"content": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["extract"]
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
prompt = f"Schema:\n{schema_str}\n\nText:\n{text}\n\nExtracted JSON:"
|
||||
content = await self._call(system, prompt, response_format={"type": "json_object"})
|
||||
try:
|
||||
data = json.loads(content)
|
||||
return {"data": data, "confidence": 0.9, "provider": self.name}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "error": "parse_failed"}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000, response_format: Optional[Dict] = None, model: Optional[str] = None) -> str:
|
||||
kwargs = {
|
||||
"model": model or self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
resp = await self.client.chat.completions.create(**kwargs)
|
||||
return resp.choices[0].message.content
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
p = self._pricing.get(self.model, {"input": 0.01, "output": 0.03})
|
||||
return (p["input"] + p["output"]) / 2
|
||||
@@ -0,0 +1,110 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.base import AIProvider
|
||||
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider
|
||||
from app.config import settings
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIRouter:
|
||||
def __init__(self):
|
||||
self.providers: Dict[str, AIProvider] = {}
|
||||
self.routing_rules = settings.AI_ROUTING
|
||||
self.corpus = TradeCorpus()
|
||||
self._init_providers()
|
||||
|
||||
def _init_providers(self):
|
||||
if settings.OPENAI_API_KEY:
|
||||
try:
|
||||
self.providers["openai"] = OpenAIProvider(api_key=settings.OPENAI_API_KEY)
|
||||
logger.info("OpenAI provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenAI init failed: {e}")
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
try:
|
||||
self.providers["anthropic"] = ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY)
|
||||
logger.info("Claude provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Claude init failed: {e}")
|
||||
|
||||
if settings.DEEPL_API_KEY:
|
||||
try:
|
||||
self.providers["deepl"] = DeepLProvider(api_key=settings.DEEPL_API_KEY)
|
||||
logger.info("DeepL provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"DeepL init failed: {e}")
|
||||
|
||||
if settings.LOCAL_MODEL_ENABLED:
|
||||
try:
|
||||
self.providers["local"] = LocalProvider(model_url=settings.LOCAL_MODEL_URL)
|
||||
logger.info("Local provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Local init failed: {e}")
|
||||
|
||||
def get_providers_for_task(self, task_type: str) -> List[AIProvider]:
|
||||
rules = self.routing_rules.get(
|
||||
task_type,
|
||||
{"primary": "openai", "fallback": ["local"]},
|
||||
)
|
||||
ordered = []
|
||||
seen = set()
|
||||
|
||||
primary = rules.get("primary")
|
||||
if primary and primary in self.providers:
|
||||
ordered.append(self.providers[primary])
|
||||
seen.add(primary)
|
||||
|
||||
for name in rules.get("fallback", []):
|
||||
if name in self.providers and name not in seen:
|
||||
ordered.append(self.providers[name])
|
||||
seen.add(name)
|
||||
|
||||
if not ordered:
|
||||
ordered = list(self.providers.values())
|
||||
logger.warning(f"No preferred providers for '{task_type}', using all available")
|
||||
|
||||
return ordered
|
||||
|
||||
async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
|
||||
providers = self.get_providers_for_task(task_type)
|
||||
last_error = None
|
||||
|
||||
for provider in providers:
|
||||
try:
|
||||
method_fn = getattr(provider, method)
|
||||
result = await method_fn(*args, **kwargs)
|
||||
result["provider_used"] = provider.name
|
||||
return result
|
||||
except NotImplementedError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"{provider.name} failed for {task_type}: {e}")
|
||||
last_error = e
|
||||
continue
|
||||
|
||||
raise Exception(f"All providers failed for '{task_type}'. Last error: {last_error}")
|
||||
|
||||
async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
return await self.execute("translate", "translate", text, source_lang, target_lang, context)
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
return await self.execute("reply", "reply", inquiry, context, tone)
|
||||
|
||||
async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
return await self.execute("marketing", "generate_marketing", product_info, target, style, language)
|
||||
|
||||
async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return await self.execute("extract", "extract_info", text, schema)
|
||||
|
||||
|
||||
_router_instance = None
|
||||
|
||||
|
||||
def get_ai_router() -> AIRouter:
|
||||
global _router_instance
|
||||
if _router_instance is None:
|
||||
_router_instance = AIRouter()
|
||||
return _router_instance
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import select, text
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradeCorpus:
|
||||
def __init__(self):
|
||||
self._ready = False
|
||||
|
||||
async def record(
|
||||
self,
|
||||
source_text: str,
|
||||
target_text: str,
|
||||
task_type: str,
|
||||
provider: str,
|
||||
source_lang: Optional[str] = None,
|
||||
target_lang: Optional[str] = None,
|
||||
quality_score: float = 0.5,
|
||||
user_edited: bool = False,
|
||||
metadata: Optional[Dict] = None,
|
||||
):
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
entry = CorpusEntry(
|
||||
source_text=source_text[:2000],
|
||||
target_text=target_text[:2000],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
task_type=task_type,
|
||||
provider_used=provider,
|
||||
quality_score=quality_score,
|
||||
user_edited=user_edited,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record corpus entry: {e}")
|
||||
|
||||
async def find_similar(self, text: str, task_type: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(CorpusEntry)
|
||||
.where(CorpusEntry.task_type == task_type)
|
||||
.where(CorpusEntry.quality_score >= 0.6)
|
||||
.order_by(CorpusEntry.quality_score.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"source": e.source_text,
|
||||
"target": e.target_text,
|
||||
"score": e.quality_score,
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Corpus search failed: {e}")
|
||||
return []
|
||||
|
||||
async def rate_entry(self, entry_id: str, rating: int):
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(CorpusEntry).where(CorpusEntry.id == entry_id)
|
||||
)
|
||||
entry = result.scalar_one_or_none()
|
||||
if entry:
|
||||
entry.user_rating = rating
|
||||
entry.quality_score = rating / 5.0
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rate corpus entry: {e}")
|
||||
Reference in New Issue
Block a user