import copy
import json
import logging
import math
from typing import Dict, Any, Optional, List
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc, func

from app.models.agent_pipeline import AgentPipeline
from app.models.customer import Customer
from app.ai.router import get_ai_router
from app.services.discovery import DiscoveryService
from app.services.marketing import MarketingService
from app.services.followup_engine import FollowupEngine

logger = logging.getLogger(__name__)


class AgentOrchestrator:
    """AI "digital employee" that chains discovery -> analysis -> outreach.

    ``start_pipeline`` runs the stages in order and persists progress on an
    ``AgentPipeline`` row after each step. A ``MarketingService`` and a
    ``FollowupEngine`` are constructed but not used by the methods defined here.
    All writes are flushed only; committing is left to the caller.
    """

    # Score thresholds; subclasses may override them.
    DEFAULT_MATCH_SCORE = 50  # used when a score is missing or not numeric
    OUTREACH_MIN_SCORE = 60  # minimum score for a lead to get outreach copy
    OUTREACH_MAX_LEADS = 5  # outreach is generated for at most this many leads
    HIGH_MATCH_SCORE = 70  # "high match" bucket and auto-save-as-customer threshold
    MEDIUM_MATCH_SCORE = 50  # lower bound of the "medium match" bucket

    def __init__(self, db: AsyncSession):
        self.db = db
        self.ai = get_ai_router()
        self.discovery = DiscoveryService(db=db)
        self.marketing = MarketingService()
        self.followup = FollowupEngine(db)

    async def start_pipeline(
        self,
        user_id: str,
        product_name: str,
        product_description: str,
        target_market: str,
    ) -> Dict[str, Any]:
        """Create a pipeline row and run discover -> analyze -> outreach on it.

        Args:
            user_id: Owner of the pipeline and of any auto-saved customers.
            product_name: Product name; combined with the description as the
                search / analysis query.
            product_description: Free-text product description.
            target_market: Market passed to the discovery search.

        Returns:
            The pipeline serialized by ``_pipeline_to_dict``. If any stage
            raises, the exception is logged, the pipeline is marked
            ``"failed"`` with ``error_message`` set, and the dict is still
            returned (the exception is not re-raised).
        """
        pipeline = AgentPipeline(
            user_id=user_id,
            product_name=product_name,
            product_description=product_description,
            target_market=target_market,
        )
        self.db.add(pipeline)
        await self.db.flush()
        pipeline_id = str(pipeline.id)

        data = pipeline.pipeline_data or {}
        stages: Dict[str, Any] = data.get("stages", {})
        leads: List[Dict[str, Any]] = data.get("leads", [])
        prior_summary = data.get("summary", {})
        product_query = f"{product_name} {product_description}"

        try:
            # ── Stage 1: Discover ──
            stages["discover"] = {"status": "running", "message": "正在搜索潜在客户..."}
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=10)

            search_result = await self.discovery.search(product_query, target_market)
            companies = search_result.get("companies") or []
            provider = search_result.get("provider", "unknown")
            stages["discover"] = {
                "status": "completed",
                "message": f"已发现 {len(companies)} 家潜在客户",
                "provider": provider,
                "count": len(companies),
            }
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=30)

            # ── Stage 2: Analyze ──
            stages["analyze"] = {"status": "running", "message": "正在分析客户匹配度..."}
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=40)

            analyzed_leads = []
            for idx, company in enumerate(companies):
                analysis = await self._analyze_company(company, product_query)
                analyzed_leads.append(self._build_lead(idx, company, analysis))
            analyzed_leads.sort(key=lambda lead: lead["match_score"], reverse=True)
            leads = analyzed_leads

            stages["analyze"] = {
                "status": "completed",
                "message": f"已完成 {len(leads)} 家客户分析",
                "count": len(leads),
            }
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=65)

            # ── Stage 3: Outreach ──
            stages["outreach"] = {"status": "running", "message": "正在生成触达文案..."}
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=75)

            top_leads = [
                lead for lead in leads if lead["match_score"] >= self.OUTREACH_MIN_SCORE
            ][: self.OUTREACH_MAX_LEADS]
            for lead in top_leads:
                # top_leads shares dict objects with leads, so this updates leads too.
                lead["outreach"] = await self._generate_outreach(
                    lead, product_name, product_description
                )

            saved_count = await self._save_customers(user_id, pipeline_id, leads)

            stages["outreach"] = {
                "status": "completed",
                "message": f"已为 {len(top_leads)} 个高匹配客户生成触达文案,自动保存 {saved_count} 个客户",
                "top_count": len(top_leads),
                "saved_count": saved_count,
            }
            await self._checkpoint(pipeline, stages, leads, prior_summary, progress=90)

            # ── Complete ──
            stages["complete"] = {
                "status": "completed",
                "message": f"AI数字员工任务完成!发现 {len(leads)} 个潜在客户,分析完成,高匹配客户已保存并生成触达文案。",
            }
            pipeline.status = "completed"
            await self._checkpoint(
                pipeline, stages, leads, self._build_summary(leads, saved_count), progress=100
            )

        except Exception as e:
            logger.exception("Pipeline %s failed: %s", pipeline_id, e)
            pipeline.status = "failed"
            pipeline.error_message = str(e)[:500]
            # Any stage still "running" was interrupted by this error; without this
            # it would be reported as running forever.
            for stage in stages.values():
                if isinstance(stage, dict) and stage.get("status") == "running":
                    stage["status"] = "failed"
            stages.setdefault("discover", {"status": "pending", "message": ""})
            # NOTE(review): if the error came from a flush, the session may need a
            # rollback before this flush can succeed — confirm session handling.
            await self._checkpoint(pipeline, stages, leads, prior_summary)

        return await self._pipeline_to_dict(pipeline)

    async def _checkpoint(
        self,
        pipeline: AgentPipeline,
        stages: Dict[str, Any],
        leads: List[Dict[str, Any]],
        summary: Dict[str, Any],
        progress: Optional[int] = None,
    ) -> None:
        """Store a snapshot of the pipeline state (and optional progress), then flush."""
        # `stages` and the lead dicts are mutated in place between checkpoints.
        # SQLAlchemy compares the newly assigned value to the previously assigned
        # object with ==, so handing it the same inner objects would look
        # "unchanged" and skip the UPDATE. A deep copy gives it an independent
        # snapshot.
        pipeline.pipeline_data = copy.deepcopy(
            {"stages": stages, "leads": leads, "summary": summary}
        )
        if progress is not None:
            pipeline.progress = progress
        await self.db.flush()

    async def _analyze_company(
        self, company: Dict[str, Any], product_query: str
    ) -> Dict[str, Any]:
        """Return analysis for one search hit; never raises for analysis failures.

        Only hits whose ``contact`` is an http(s)-looking string are sent to
        ``discovery.analyze``. Other hits keep the score from the search
        result.
        """
        company_url = company.get("contact") or ""
        if not (isinstance(company_url, str) and company_url.startswith("http")):
            return {
                "match_score": company.get("match_score", self.DEFAULT_MATCH_SCORE),
                "match_reason": "基于搜索结果的初步评估",
            }
        try:
            analysis = await self.discovery.analyze(company_url, product_query)
        except Exception as e:
            logger.warning("Analysis failed for %s: %s", company.get("name"), e)
            return {"match_score": self.DEFAULT_MATCH_SCORE, "match_reason": "分析失败"}
        return analysis if isinstance(analysis, dict) else {}

    @classmethod
    def _build_lead(
        cls, idx: int, company: Dict[str, Any], analysis: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge a search hit and its analysis into the lead dict stored on the pipeline."""
        # `or` (rather than a .get default) also replaces explicit None/"" values,
        # which would otherwise break slicing and string handling downstream.
        return {
            "id": str(idx + 1),
            "name": company.get("name") or "未知",
            "description": company.get("description") or "",
            "url": company.get("contact") or "",
            "country": company.get("country") or "",
            "source": company.get("source") or "web",
            "match_score": cls._coerce_score(analysis.get("match_score")),
            "match_reason": analysis.get("match_reason") or "",
            "contact_info": analysis.get("contact_info") or {},
            "company_summary": analysis.get("company_summary") or "",
            "product_fit": analysis.get("product_fit") or "",
            "outreach": None,
        }

    @classmethod
    def _coerce_score(cls, value: Any) -> float:
        """Return ``value`` as a finite number, or ``DEFAULT_MATCH_SCORE``.

        The score's type is not guaranteed (it may be ``None`` or a numeric
        string, e.g. ``"85"``). A non-numeric score would make the sort and
        threshold comparisons raise ``TypeError`` and fail the whole pipeline.
        """
        if value is None or isinstance(value, bool):
            return cls.DEFAULT_MATCH_SCORE
        try:
            score = float(value)
        except (TypeError, ValueError):
            return cls.DEFAULT_MATCH_SCORE
        if not math.isfinite(score):
            return cls.DEFAULT_MATCH_SCORE
        return int(score) if score.is_integer() else score

    async def _generate_outreach(
        self, lead: Dict[str, Any], product_name: str, product_description: str
    ) -> Any:
        """Generate outreach copy for one lead; returns None if generation fails."""
        company_info = {
            "name": lead["name"],
            "url": lead["url"],
            # company_summary is always present on a lead but may be empty.
            "description": lead.get("company_summary") or lead["description"],
        }
        try:
            return await self.discovery.outreach(
                company_info,
                {"name": product_name, "description": product_description},
            )
        except Exception as e:
            logger.warning("Outreach failed for %s: %s", lead["name"], e)
            return None

    async def _save_customers(
        self, user_id: str, pipeline_id: str, leads: List[Dict[str, Any]]
    ) -> int:
        """Add high-scoring leads with a URL as ``Customer`` rows; return how many were added.

        A lead is skipped when the user already has a customer with the same
        name, or when an earlier lead in this run had the same name.
        """
        saved_count = 0
        seen_names = set()
        for lead in leads:
            if lead["match_score"] < self.HIGH_MATCH_SCORE or not lead.get("url"):
                continue
            name = lead["name"]
            # Track names locally so same-name leads in one run are not both
            # added, regardless of the session's autoflush setting.
            if name in seen_names:
                continue
            seen_names.add(name)

            # limit(1) + first(): the user may already have several customers
            # with this name, which would make scalar_one_or_none() raise.
            existing = await self.db.execute(
                select(Customer)
                .where(Customer.user_id == user_id, Customer.name == name)
                .limit(1)
            )
            if existing.scalars().first() is not None:
                continue

            self.db.add(
                Customer(
                    user_id=user_id,
                    name=name,
                    company=(lead.get("company_summary") or name)[:200],
                    country=lead.get("country") or "",
                    description=(lead.get("description") or "")[:500],
                    status="lead",
                    source=f"ai_agent:{pipeline_id}",
                )
            )
            saved_count += 1

        if saved_count > 0:
            await self.db.flush()
        return saved_count

    @classmethod
    def _build_summary(
        cls, leads: List[Dict[str, Any]], saved_count: int
    ) -> Dict[str, Any]:
        """Bucket leads by match score and count generated outreach."""
        scores = [lead["match_score"] for lead in leads]
        return {
            "total_leads": len(leads),
            "high_match": sum(1 for s in scores if s >= cls.HIGH_MATCH_SCORE),
            "medium_match": sum(
                1 for s in scores if cls.MEDIUM_MATCH_SCORE <= s < cls.HIGH_MATCH_SCORE
            ),
            "low_match": sum(1 for s in scores if s < cls.MEDIUM_MATCH_SCORE),
            "outreach_generated": sum(1 for lead in leads if lead.get("outreach")),
            "customers_saved": saved_count,
        }

    async def get_pipeline(self, pipeline_id: str, user_id: str) -> Optional[Dict[str, Any]]:
        """Return one pipeline owned by ``user_id`` as a dict, or None if not found."""
        result = await self.db.execute(
            select(AgentPipeline).where(
                AgentPipeline.id == pipeline_id,
                AgentPipeline.user_id == user_id,
            )
        )
        pipeline = result.scalar_one_or_none()
        if not pipeline:
            return None
        return await self._pipeline_to_dict(pipeline)

    async def list_pipelines(self, user_id: str, page: int = 1, size: int = 20) -> Dict[str, Any]:
        """Return one page of the user's pipelines (newest first) plus the total count.

        ``page`` is 1-based. ``page`` and ``size`` are clamped to at least 1 so
        the offset can never be negative.
        """
        page = max(page, 1)
        size = max(size, 1)

        query = (
            select(AgentPipeline)
            .where(AgentPipeline.user_id == user_id)
            .order_by(desc(AgentPipeline.created_at))
            .offset((page - 1) * size)
            .limit(size)
        )
        result = await self.db.execute(query)
        pipelines = result.scalars().all()

        # Count in the database instead of loading every row just to len() it.
        count_q = (
            select(func.count())
            .select_from(AgentPipeline)
            .where(AgentPipeline.user_id == user_id)
        )
        total = (await self.db.execute(count_q)).scalar_one()

        items = [await self._pipeline_to_dict(p) for p in pipelines]
        return {"items": items, "total": total, "page": page, "size": size}

    async def _pipeline_to_dict(self, p: AgentPipeline) -> Dict[str, Any]:
        """Serialize a pipeline row; timestamps become ISO strings (or None)."""
        return {
            "id": str(p.id),
            "status": p.status,
            "progress": p.progress,
            "product_name": p.product_name,
            "product_description": p.product_description,
            "target_market": p.target_market,
            "pipeline_data": p.pipeline_data,
            "error_message": p.error_message,
            "created_at": p.created_at.isoformat() if p.created_at else None,
            "updated_at": p.updated_at.isoformat() if p.updated_at else None,
        }