from typing import Dict, Any, Optional, List
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timedelta, timezone
import logging

logger = logging.getLogger(__name__)


class CorpusTrainer:
    """Maintenance pipeline for corpus entries.

    Provides deduplication, quality scoring, embedding generation, pruning
    and summary statistics over ``CorpusEntry`` rows.

    Every mutating step only calls ``flush()`` on the session; this class never
    commits, so transaction control stays with whoever owns ``db``.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def compute_embeddings(self, batch_size: int = 50) -> Dict[str, Any]:
        """Generate embeddings for up to ``batch_size`` entries that have none.

        Entries whose embedding could not be produced keep a NULL embedding,
        so they are selected again on the next call.

        Returns:
            ``{"processed": <entries fetched>, "updated": <embeddings stored>}``.
        """
        from app.models.corpus import CorpusEntry

        result = await self.db.execute(
            select(CorpusEntry).where(CorpusEntry.embedding.is_(None)).limit(batch_size)
        )
        entries = result.scalars().all()
        updated = 0
        for entry in entries:
            try:
                embedding = await self._generate_embedding(entry.source_text)
            except Exception as e:
                # Best effort: one bad entry must not abort the whole batch.
                logger.warning("Embedding failed for entry %s: %s", entry.id, e)
                continue
            if embedding:
                entry.embedding = embedding
                updated += 1
        await self.db.flush()
        return {"processed": len(entries), "updated": updated}

    async def score_entries(self, batch_size: int = 100) -> Dict[str, Any]:
        """Assign a heuristic quality score to up to ``batch_size`` unscored entries.

        Returns:
            ``{"processed": n, "updated": n}``. Every fetched entry is scored,
            so both values are always equal.
        """
        from app.models.corpus import CorpusEntry

        result = await self.db.execute(
            select(CorpusEntry)
            .where(CorpusEntry.quality_score.is_(None))
            .limit(batch_size)
        )
        entries = result.scalars().all()
        for entry in entries:
            entry.quality_score = self._calculate_quality_score(entry)
        await self.db.flush()
        return {"processed": len(entries), "updated": len(entries)}

    async def deduplicate(self) -> Dict[str, Any]:
        """Delete duplicate entries, keeping the lowest ``id`` per group.

        A group is all rows sharing the same ``(source_text, task_type)``.
        Only groups with more than one row are considered.

        NOTE(review): the match back to the grouped rows uses SQL ``=``, so
        groups whose ``source_text`` or ``task_type`` is NULL never match and
        their duplicates are not removed.

        Returns:
            ``{"duplicates_removed": <rows deleted>}``.
        """
        from app.models.corpus import CorpusEntry

        # One row per duplicated (source_text, task_type) pair, carrying the id to keep.
        subquery = (
            select(
                CorpusEntry.source_text,
                CorpusEntry.task_type,
                func.min(CorpusEntry.id).label("keep_id"),
            )
            .group_by(CorpusEntry.source_text, CorpusEntry.task_type)
            .having(func.count(CorpusEntry.id) > 1)
            .subquery()
        )
        result = await self.db.execute(
            select(CorpusEntry).where(
                and_(
                    CorpusEntry.source_text == subquery.c.source_text,
                    CorpusEntry.task_type == subquery.c.task_type,
                    CorpusEntry.id != subquery.c.keep_id,
                )
            )
        )
        duplicates = result.scalars().all()
        for dup in duplicates:
            await self.db.delete(dup)
        await self.db.flush()
        return {"duplicates_removed": len(duplicates)}

    async def prune_low_quality(
        self, min_score: float = 0.2, max_age_days: int = 90
    ) -> Dict[str, Any]:
        """Delete entries that are low quality, old, and rarely used.

        An entry is deleted only when all of the following hold:

        - ``quality_score < min_score``
        - ``created_at`` is older than ``max_age_days``
        - ``usage_count`` is NULL or below 2

        Entries with a NULL ``quality_score`` are never pruned here.

        Returns:
            ``{"pruned": <rows deleted>}``.
        """
        from app.models.corpus import CorpusEntry

        # Naive UTC timestamp: same value as the deprecated datetime.utcnow(),
        # so the comparison against created_at behaves exactly as before.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        cutoff = now_utc - timedelta(days=max_age_days)
        result = await self.db.execute(
            select(CorpusEntry).where(
                and_(
                    CorpusEntry.quality_score < min_score,
                    CorpusEntry.created_at < cutoff,
                    CorpusEntry.usage_count.is_(None) | (CorpusEntry.usage_count < 2),
                )
            )
        )
        entries = result.scalars().all()
        for e in entries:
            await self.db.delete(e)
        await self.db.flush()
        return {"pruned": len(entries)}

    async def get_stats(self) -> Dict[str, Any]:
        """Return corpus-wide counts.

        The counts are: total entries, entries per ``task_type``, entries with
        an embedding, high quality (score >= 0.7) and low quality (score < 0.3).
        """
        from app.models.corpus import CorpusEntry

        total = await self.db.execute(select(func.count(CorpusEntry.id)))
        by_type = await self.db.execute(
            select(CorpusEntry.task_type, func.count(CorpusEntry.id))
            .group_by(CorpusEntry.task_type)
        )
        with_embeddings = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.embedding.isnot(None))
        )
        high_quality = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score >= 0.7)
        )
        low_quality = await self.db.execute(
            select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score < 0.3)
        )
        return {
            "total_entries": total.scalar() or 0,
            "by_task_type": {row[0]: row[1] for row in by_type.all()},
            "with_embeddings": with_embeddings.scalar() or 0,
            "high_quality": high_quality.scalar() or 0,
            "low_quality": low_quality.scalar() or 0,
        }

    async def run_pipeline(self) -> Dict[str, Any]:
        """Run every maintenance step once and return each step's result.

        The steps run in order: dedup, score, embed, prune, stats. Dedup runs
        first so later steps do no work on rows that are about to be deleted.
        Scoring runs before pruning because pruning filters on
        ``quality_score``.
        """
        dedup_result = await self.deduplicate()
        score_result = await self.score_entries()
        embed_result = await self.compute_embeddings()
        prune_result = await self.prune_low_quality()
        stats = await self.get_stats()
        return {
            "deduplication": dedup_result,
            "scoring": score_result,
            "embeddings": embed_result,
            "pruning": prune_result,
            "stats": stats,
        }

    def _calculate_quality_score(self, entry) -> float:
        """Compute a heuristic quality score in ``[0, 1]``, rounded to 2 places.

        Starts at 0.5, or at ``user_rating / 5`` when a truthy rating is
        present, then applies these adjustments in order:

        - -0.1 if the user edited the entry
        - +0.15 if it was used more than 5 times
        - +0.1 if both texts are longer than 10 characters
        - -0.3 if either text is shorter than 3 characters

        Each adjustment is clamped to ``[0, 1]``.
        """
        score = 0.5
        if entry.user_rating:
            # NOTE(review): presumably a 1-5 rating scale; a rating of 0 is
            # treated as "no rating" because of the truthiness check.
            score = entry.user_rating / 5.0
        if entry.user_edited:
            score = max(score - 0.1, 0)
        if entry.usage_count and entry.usage_count > 5:
            score = min(score + 0.15, 1.0)
        src_len = len(entry.source_text) if entry.source_text else 0
        tgt_len = len(entry.target_text) if entry.target_text else 0
        if src_len > 10 and tgt_len > 10:
            score = min(score + 0.1, 1.0)
        if src_len < 3 or tgt_len < 3:
            score = max(score - 0.3, 0)
        return round(score, 2)

    async def _generate_embedding(self, text: str) -> Optional[List[float]]:
        """Request an OpenAI embedding for ``text``; return ``None`` on any failure.

        Returns ``None`` without a network call when ``text`` is empty or when
        no ``OPENAI_API_KEY`` is configured. HTTP errors and exceptions are
        logged and also yield ``None``, so callers treat this as best effort.
        """
        if not text:
            # Nothing to embed; avoid a pointless (and rejected) API request.
            return None
        try:
            from app.config import settings
            import httpx

            if not settings.OPENAI_API_KEY:
                return None
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    "https://api.openai.com/v1/embeddings",
                    headers={"Authorization": f"Bearer {settings.OPENAI_API_KEY}"},
                    # Truncation is by characters, not tokens.
                    json={"model": "text-embedding-3-small", "input": text[:8000]},
                    timeout=30,
                )
                if resp.status_code != 200:
                    # Previously dropped silently; log the status so API
                    # failures (auth, rate limits, bad input) are visible.
                    logger.warning(
                        "Embedding request failed with HTTP status %s", resp.status_code
                    )
                    return None
                data = resp.json()
                return data["data"][0]["embedding"]
        except Exception as e:
            logger.warning("Embedding generation failed: %s", e)
        return None