7b62c2f8b4
## H5 底部导航修复 (Bug #10)
- 精简 App.vue，移除重复 tabbar，仅保留全局样式
- uni-page 设置 `height: calc(100% - 50px)` + `overflow-y: auto`
- 内容区域精确停在底部导航上方，独立滚动不再叠加
- 恢复 custom-tab-bar 组件

## 项目进度文档
- PROGRESS.md 更新至 10 个 Bug 修复
- 新增 H5 底部导航修复记录
- 新增历史变更条目
187 lines
6.4 KiB
Python
187 lines
6.4 KiB
Python
from typing import Dict, Any, Optional, List
|
|
from sqlalchemy import select, func, and_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from datetime import datetime, timedelta
|
|
import logging
|
|
|
|
# Logger named after this module, so its records sit under the package's
# dotted logger hierarchy and can be configured from there.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class CorpusTrainer:
    """Batch maintenance jobs over ``CorpusEntry`` rows.

    Every step works through the injected async session and only calls
    ``flush()``; committing or rolling back the transaction is left to the
    caller.
    """

    def __init__(self, db: AsyncSession):
        """Store the async database session used by every step."""
        self.db = db

    async def compute_embeddings(self, batch_size: int = 50) -> Dict[str, Any]:
        """Generate embeddings for up to ``batch_size`` entries that have none.

        Args:
            batch_size: Maximum number of rows selected in this call.

        Returns:
            ``{"processed": <rows selected>, "updated": <rows that received
            an embedding>}``.

        Rows whose embedding cannot be generated keep ``embedding = None`` and
        are therefore selected again by the next call.
        """
        from app.models.corpus import CorpusEntry

        result = await self.db.execute(
            select(CorpusEntry).where(CorpusEntry.embedding.is_(None)).limit(batch_size)
        )
        entries = result.scalars().all()

        updated = 0
        if entries:
            import httpx

            # One client for the whole batch: HTTP connections are pooled
            # instead of being re-opened (new TCP/TLS handshake) per entry.
            async with httpx.AsyncClient() as client:
                for entry in entries:
                    try:
                        embedding = await self._generate_embedding(
                            entry.source_text, client=client
                        )
                        if embedding:
                            entry.embedding = embedding
                            updated += 1
                    except Exception as e:
                        logger.warning(f"Embedding failed for entry {entry.id}: {e}")

        await self.db.flush()
        return {"processed": len(entries), "updated": updated}

    async def score_entries(self, batch_size: int = 100) -> Dict[str, Any]:
        """Assign ``quality_score`` to up to ``batch_size`` unscored entries.

        Args:
            batch_size: Maximum number of rows selected in this call.

        Returns:
            ``{"processed": n, "updated": n}``; every selected row is scored,
            so both counts are equal.
        """
        from app.models.corpus import CorpusEntry

        result = await self.db.execute(
            select(CorpusEntry)
            .where(CorpusEntry.quality_score.is_(None))
            .limit(batch_size)
        )
        entries = result.scalars().all()

        for entry in entries:
            entry.quality_score = self._calculate_quality_score(entry)

        await self.db.flush()
        return {"processed": len(entries), "updated": len(entries)}

    async def deduplicate(self) -> Dict[str, Any]:
        """Delete entries that duplicate another's ``(source_text, task_type)``.

        For every duplicated pair the row with the smallest ``id`` is kept
        and the rest are deleted. Rows whose ``source_text`` or ``task_type``
        is NULL never satisfy the SQL equality join below and are left alone.

        Returns:
            ``{"duplicates_removed": <rows deleted>}``.
        """
        from app.models.corpus import CorpusEntry

        # One row per duplicated key, carrying the id that should survive.
        keepers = (
            select(
                CorpusEntry.source_text,
                CorpusEntry.task_type,
                func.min(CorpusEntry.id).label("keep_id"),
            )
            .group_by(CorpusEntry.source_text, CorpusEntry.task_type)
            .having(func.count(CorpusEntry.id) > 1)
            .subquery()
        )

        result = await self.db.execute(
            select(CorpusEntry).where(
                and_(
                    CorpusEntry.source_text == keepers.c.source_text,
                    CorpusEntry.task_type == keepers.c.task_type,
                    CorpusEntry.id != keepers.c.keep_id,
                )
            )
        )
        duplicates = result.scalars().all()

        for dup in duplicates:
            await self.db.delete(dup)

        await self.db.flush()
        return {"duplicates_removed": len(duplicates)}

    async def prune_low_quality(self, min_score: float = 0.2, max_age_days: int = 90) -> Dict[str, Any]:
        """Delete old, rarely used entries scored below ``min_score``.

        A row is deleted only when all hold: ``quality_score < min_score``,
        ``created_at`` is older than ``max_age_days`` days, and
        ``usage_count`` is NULL or below 2. Rows with a NULL
        ``quality_score`` never match the comparison and are kept.

        Returns:
            ``{"pruned": <rows deleted>}``.
        """
        from datetime import timezone

        from app.models.corpus import CorpusEntry

        # Naive UTC "now", identical in value to the deprecated
        # datetime.utcnow(); assumes created_at is stored as naive UTC, as
        # the original comparison implied — confirm against the model.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        cutoff = now_utc - timedelta(days=max_age_days)

        result = await self.db.execute(
            select(CorpusEntry).where(
                and_(
                    CorpusEntry.quality_score < min_score,
                    CorpusEntry.created_at < cutoff,
                    CorpusEntry.usage_count.is_(None) | (CorpusEntry.usage_count < 2),
                )
            )
        )
        entries = result.scalars().all()

        for e in entries:
            await self.db.delete(e)

        await self.db.flush()
        return {"pruned": len(entries)}

    async def get_stats(self) -> Dict[str, Any]:
        """Return row counts describing the corpus.

        Keys: ``total_entries``, ``by_task_type`` (task_type -> count),
        ``with_embeddings`` (embedding not NULL), ``high_quality``
        (``quality_score >= 0.7``) and ``low_quality``
        (``quality_score < 0.3``). Unscored rows count in neither quality
        bucket.
        """
        from app.models.corpus import CorpusEntry

        total = await self.db.execute(select(func.count(CorpusEntry.id)))
        by_type = await self.db.execute(
            select(CorpusEntry.task_type, func.count(CorpusEntry.id))
            .group_by(CorpusEntry.task_type)
        )
        with_embeddings = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.embedding.isnot(None))
        )
        high_quality = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score >= 0.7)
        )
        low_quality = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score < 0.3)
        )

        return {
            "total_entries": total.scalar() or 0,
            "by_task_type": {row[0]: row[1] for row in by_type.all()},
            "with_embeddings": with_embeddings.scalar() or 0,
            "high_quality": high_quality.scalar() or 0,
            "low_quality": low_quality.scalar() or 0,
        }

    async def run_pipeline(self) -> Dict[str, Any]:
        """Run one pass of every maintenance step and report the results.

        Order: deduplicate -> score -> prune -> embed -> stats. Pruning runs
        before embedding so no embedding API calls are spent on rows that
        this same run deletes.
        """
        dedup_result = await self.deduplicate()
        score_result = await self.score_entries()
        prune_result = await self.prune_low_quality()
        embed_result = await self.compute_embeddings()
        stats = await self.get_stats()

        return {
            "deduplication": dedup_result,
            "scoring": score_result,
            "embeddings": embed_result,
            "pruning": prune_result,
            "stats": stats,
        }

    def _calculate_quality_score(self, entry) -> float:
        """Heuristic quality score in ``[0, 1]`` (rounded to 2 decimals).

        Rules, applied in order:
          * base 0.5, replaced by ``user_rating / 5`` when a (truthy) rating
            is present;
          * ``user_edited``: -0.1 (floored at 0);
          * ``usage_count > 5``: +0.15 (capped at 1.0);
          * both source and target longer than 10 chars: +0.1 (capped);
          * either source or target shorter than 3 chars: -0.3 (floored).
        """
        score = 0.5

        if entry.user_rating:
            score = entry.user_rating / 5.0

        if entry.user_edited:
            score = max(score - 0.1, 0)

        if entry.usage_count and entry.usage_count > 5:
            score = min(score + 0.15, 1.0)

        src_len = len(entry.source_text) if entry.source_text else 0
        tgt_len = len(entry.target_text) if entry.target_text else 0
        if src_len > 10 and tgt_len > 10:
            score = min(score + 0.1, 1.0)
        if src_len < 3 or tgt_len < 3:
            score = max(score - 0.3, 0)

        return round(score, 2)

    async def _generate_embedding(
        self, text: str, client: Optional[Any] = None
    ) -> Optional[List[float]]:
        """Return an embedding for ``text``, or ``None`` on any failure.

        Calls the OpenAI embeddings endpoint with ``text-embedding-3-small``
        when ``settings.OPENAI_API_KEY`` is set. Returns ``None`` when no key
        is configured, when ``text`` is empty/None, on a non-200 response, or
        when any exception occurs (the latter two are logged).

        Args:
            text: Text to embed; only the first 8000 characters are sent.
            client: Optional ``httpx.AsyncClient`` to reuse across calls.
                When omitted, a short-lived client is created for this call.
        """
        if not text:
            # The embeddings API rejects empty input; skip the doomed request.
            return None

        try:
            from app.config import settings
            import httpx

            if not settings.OPENAI_API_KEY:
                return None

            url = "https://api.openai.com/v1/embeddings"
            request_kwargs = {
                "headers": {"Authorization": f"Bearer {settings.OPENAI_API_KEY}"},
                "json": {"model": "text-embedding-3-small", "input": text[:8000]},
                "timeout": 30,
            }
            if client is None:
                async with httpx.AsyncClient() as own_client:
                    resp = await own_client.post(url, **request_kwargs)
            else:
                resp = await client.post(url, **request_kwargs)

            if resp.status_code != 200:
                # Log instead of dropping silently so quota/auth problems
                # are visible to operators.
                logger.warning(
                    f"Embedding request failed with HTTP {resp.status_code}: {resp.text[:200]}"
                )
                return None

            return resp.json()["data"][0]["embedding"]
        except Exception as e:
            logger.warning(f"Embedding generation failed: {e}")
            return None
|