from datetime import datetime, timedelta
from unittest.mock import patch, AsyncMock

import pytest

from app.models.corpus import CorpusEntry
from app.services.corpus_trainer import CorpusTrainer


class TestCorpusTrainer:
    """Tests for ``CorpusTrainer`` run against the ``db_session`` fixture."""

    async def test_get_stats_empty(self, db_session):
        """Stats report zero entries when nothing has been added."""
        trainer = CorpusTrainer(db_session)
        stats = await trainer.get_stats()

        assert stats["total_entries"] == 0
        assert stats["with_embeddings"] == 0

    async def test_get_stats_with_data(self, db_session):
        """Stats count entries per task type and per quality bucket."""
        entries = [
            CorpusEntry(
                source_text="Hello",
                target_text="你好",
                task_type="translate",
                quality_score=0.8,
            ),
            CorpusEntry(
                source_text="Goodbye",
                target_text="再见",
                task_type="translate",
                quality_score=0.6,
            ),
        ]
        db_session.add_all(entries)
        await db_session.commit()

        trainer = CorpusTrainer(db_session)
        stats = await trainer.get_stats()

        assert stats["total_entries"] == 2
        assert stats["by_task_type"]["translate"] == 2
        assert stats["high_quality"] == 1
        assert stats["low_quality"] == 0

    async def test_score_entries(self, db_session):
        """Scoring assigns every unscored entry a score within [0.0, 1.0]."""
        entries = [
            CorpusEntry(
                source_text="Hello world",
                target_text="你好世界",
                task_type="translate",
            ),
            CorpusEntry(
                source_text="Hi",
                target_text="嗨",
                task_type="translate",
            ),
        ]
        db_session.add_all(entries)
        await db_session.commit()

        trainer = CorpusTrainer(db_session)
        result = await trainer.score_entries(batch_size=10)

        assert result["processed"] == 2
        assert result["updated"] == 2

        # Reload each entry so the assertions see the persisted score.
        for entry in entries:
            await db_session.refresh(entry)
            assert entry.quality_score is not None
            assert 0.0 <= entry.quality_score <= 1.0

    async def test_deduplicate(self, db_session):
        """Two entries with identical texts collapse to a single entry."""
        first = CorpusEntry(
            source_text="Duplicate text",
            target_text="重复文本",
            task_type="translate",
            quality_score=0.8,
            created_at=datetime.utcnow(),
        )
        second = CorpusEntry(
            source_text="Duplicate text",
            target_text="重复文本",
            task_type="translate",
            quality_score=0.7,
            created_at=datetime.utcnow(),
        )
        db_session.add_all([first, second])
        await db_session.commit()

        trainer = CorpusTrainer(db_session)
        result = await trainer.deduplicate()

        assert result["duplicates_removed"] == 1

        stats = await trainer.get_stats()
        assert stats["total_entries"] == 1

    async def test_prune_low_quality(self, db_session):
        """An old, unused entry scored below ``min_score`` is pruned."""
        # 100 days old, so well past the 30-day max_age_days used below.
        old = datetime.utcnow() - timedelta(days=100)
        entry = CorpusEntry(
            source_text="x",
            target_text="y",
            task_type="translate",
            quality_score=0.1,
            created_at=old,
            usage_count=0,
        )
        db_session.add(entry)
        await db_session.commit()

        trainer = CorpusTrainer(db_session)
        result = await trainer.prune_low_quality(min_score=0.2, max_age_days=30)

        assert result["pruned"] == 1

        stats = await trainer.get_stats()
        assert stats["total_entries"] == 0

    async def test_run_pipeline(self, db_session):
        """The pipeline result contains one section per stage plus stats."""
        trainer = CorpusTrainer(db_session)
        result = await trainer.run_pipeline()

        for section in ("deduplication", "scoring", "embeddings", "pruning", "stats"):
            assert section in result

    def test_calculate_quality_score_with_rating(self, db_session):
        """A user rating of 4 on reasonably long text scores at least 0.7."""
        trainer = CorpusTrainer(db_session)
        entry = CorpusEntry(
            source_text="Good source text with enough length",
            target_text="Good target text with enough length",
            task_type="translate",
            user_rating=4,
        )

        score = trainer._calculate_quality_score(entry)

        assert 0.7 <= score <= 1.0

    def test_calculate_quality_score_short_text(self, db_session):
        """Two-character source and target texts score below 0.5."""
        trainer = CorpusTrainer(db_session)
        entry = CorpusEntry(
            source_text="ab",
            target_text="cd",
            task_type="translate",
        )

        score = trainer._calculate_quality_score(entry)

        assert score < 0.5

    def test_calculate_quality_score_with_usage(self, db_session):
        """A usage count of 10 on reasonably long text scores at least 0.6."""
        trainer = CorpusTrainer(db_session)
        entry = CorpusEntry(
            source_text="Good source text here with proper length",
            target_text="Good target text here with proper length",
            task_type="translate",
            usage_count=10,
        )

        score = trainer._calculate_quality_score(entry)

        assert score >= 0.6

    async def test_embedding_generation_skipped_without_key(self, db_session):
        """Embedding generation returns ``None`` when no API key is set."""
        from app.config import settings

        original = settings.OPENAI_API_KEY
        settings.OPENAI_API_KEY = None
        try:
            trainer = CorpusTrainer(db_session)
            embedding = await trainer._generate_embedding("test")
            assert embedding is None
        finally:
            # Restore the shared settings object even when the assertion or
            # the call fails, so the cleared key cannot leak into other tests.
            settings.OPENAI_API_KEY = original