Add admin-frontend and user-frontend standalone projects, certification/invoice/discovery features, fix auth header and theme consistency
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""add certification and invoice
|
||||
|
||||
Revision ID: ecab04cc0e1d
|
||||
Revises: 93a81b22bd80
|
||||
Create Date: 2026-05-22 09:20:37.807327
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = 'ecab04cc0e1d'
|
||||
down_revision: Union[str, None] = '93a81b22bd80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply revision ecab04cc0e1d.

    Creates the ``certifications`` and ``invoices`` tables (each with an index
    on ``user_id``), rebuilds ``ix_preference_analyses_user_id`` as a unique
    index, and replaces the ``system_configs_key_key`` unique constraint plus
    its index with a single unique index on ``system_configs.key``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Enum types are declared up front; they are still created together with
    # the table whose column uses them.
    cert_type_enum = sa.Enum('individual', 'enterprise', name='certtype')
    cert_status_enum = sa.Enum('pending', 'approved', 'rejected', name='certstatus')
    invoice_type_enum = sa.Enum('individual', 'enterprise', name='invoicetype')
    invoice_status_enum = sa.Enum('pending', 'issued', 'rejected', name='invoicestatus')

    # certifications must exist before invoices, which references it.
    op.create_table(
        'certifications',
        sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('cert_type', cert_type_enum, nullable=False),
        sa.Column('personal_name', sa.String(length=100), nullable=True),
        sa.Column('personal_id', sa.String(length=30), nullable=True),
        sa.Column('company_name', sa.String(length=255), nullable=True),
        sa.Column('tax_id', sa.String(length=30), nullable=True),
        sa.Column('business_license_url', sa.String(length=500), nullable=True),
        sa.Column('status', cert_status_enum, nullable=True),
        sa.Column('reject_reason', sa.Text(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id']),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(
        op.f('ix_certifications_user_id'), 'certifications', ['user_id'], unique=False
    )

    op.create_table(
        'invoices',
        sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('certification_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('invoice_type', invoice_type_enum, nullable=False),
        sa.Column('title', sa.String(length=255), nullable=False),
        sa.Column('tax_id', sa.String(length=30), nullable=True),
        sa.Column('amount', sa.Float(), nullable=False),
        sa.Column('status', invoice_status_enum, nullable=True),
        sa.Column('reject_reason', sa.Text(), nullable=True),
        sa.Column('issued_at', sa.DateTime(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['certification_id'], ['certifications.id']),
        sa.ForeignKeyConstraint(['user_id'], ['users.id']),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_invoices_user_id'), 'invoices', ['user_id'], unique=False)

    # Rebuild the preference_analyses index as unique.
    op.drop_index('ix_preference_analyses_user_id', table_name='preference_analyses')
    op.create_index(
        op.f('ix_preference_analyses_user_id'),
        'preference_analyses',
        ['user_id'],
        unique=True,
    )

    # Collapse the unique constraint + plain index on system_configs.key into
    # one unique index.
    op.drop_constraint('system_configs_key_key', 'system_configs', type_='unique')
    op.drop_index('ix_system_configs_key', table_name='system_configs')
    op.create_index(op.f('ix_system_configs_key'), 'system_configs', ['key'], unique=True)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert revision ecab04cc0e1d.

    Restores the previous index/constraint layout on ``system_configs`` and
    ``preference_analyses``, drops the ``invoices`` and ``certifications``
    tables, and finally drops the enum types those tables introduced.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_system_configs_key'), table_name='system_configs')
    op.create_index('ix_system_configs_key', 'system_configs', ['key'], unique=False)
    op.create_unique_constraint('system_configs_key_key', 'system_configs', ['key'])
    op.drop_index(op.f('ix_preference_analyses_user_id'), table_name='preference_analyses')
    op.create_index(
        'ix_preference_analyses_user_id', 'preference_analyses', ['user_id'], unique=False
    )
    # invoices references certifications, so it has to go first.
    op.drop_index(op.f('ix_invoices_user_id'), table_name='invoices')
    op.drop_table('invoices')
    op.drop_index(op.f('ix_certifications_user_id'), table_name='certifications')
    op.drop_table('certifications')
    # ### end Alembic commands ###

    # Dropping a table does not drop the PostgreSQL enum types created for its
    # columns. Without this, re-running upgrade() fails because the types
    # already exist. IF EXISTS keeps the downgrade safe to repeat.
    for enum_name in ('invoicestatus', 'invoicetype', 'certstatus', 'certtype'):
        op.execute(f'DROP TYPE IF EXISTS {enum_name}')
|
||||
@@ -13,7 +13,7 @@ class NvidiaProvider(OpenAIProvider):
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
http_client=httpx.AsyncClient(timeout=httpx.Timeout(60.0)),
|
||||
http_client=httpx.AsyncClient(timeout=httpx.Timeout(20.0)),
|
||||
)
|
||||
self._name = f"nvidia-{model}"
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.admin import AdminService
|
||||
from app.services.translation_quota import TranslationQuotaService
|
||||
from app.services.certification import CertificationService
|
||||
from app.services.invoice import InvoiceService
|
||||
from app.api.v1.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
@@ -212,3 +214,63 @@ async def reset_translation_quota(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Quota not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/certifications")
async def admin_list_certifications(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    status: Optional[str] = Query(None),
    _: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of certification requests, optionally filtered by status.

    Access is gated by the ``require_admin`` dependency. ``page`` is 1-based
    and ``size`` is limited to 1..100 by the query validators.
    """
    return await CertificationService(db).list_all(page, size, status)
|
||||
|
||||
|
||||
@router.post("/certifications/{cert_id}/review")
async def admin_review_certification(
    cert_id: str,
    data: dict,
    _: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Approve or reject a certification request.

    The request body must carry ``action`` (``"approve"`` or ``"reject"``) and
    may carry ``reason``. Responds 400 for any other action and 404 when the
    service returns a falsy result for ``cert_id``.
    """
    _validate_uuid(cert_id)  # NOTE(review): helper defined elsewhere; presumably rejects malformed ids
    requested = data.get("action")
    if requested != "approve" and requested != "reject":
        raise HTTPException(status_code=400, detail="Action must be 'approve' or 'reject'")

    outcome = await CertificationService(db).review(cert_id, requested, data.get("reason"))
    if outcome:
        return outcome
    raise HTTPException(status_code=404, detail="Certification not found")
|
||||
|
||||
|
||||
@router.get("/invoices")
async def admin_list_invoices(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    status: Optional[str] = Query(None),
    _: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of invoice requests, optionally filtered by status.

    Access is gated by the ``require_admin`` dependency. ``page`` is 1-based
    and ``size`` is limited to 1..100 by the query validators.
    """
    return await InvoiceService(db).list_all(page, size, status)
|
||||
|
||||
|
||||
@router.post("/invoices/{invoice_id}/process")
async def admin_process_invoice(
    invoice_id: str,
    data: dict,
    _: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Issue or reject an invoice request.

    The request body must carry ``action`` (``"issue"`` or ``"reject"``) and
    may carry ``reason``. Responds 400 for any other action and 404 when the
    service returns a falsy result for ``invoice_id``.
    """
    _validate_uuid(invoice_id)  # NOTE(review): helper defined elsewhere; presumably rejects malformed ids
    requested = data.get("action")
    if requested != "issue" and requested != "reject":
        raise HTTPException(status_code=400, detail="Action must be 'issue' or 'reject'")

    outcome = await InvoiceService(db).process(invoice_id, requested, data.get("reason"))
    if outcome:
        return outcome
    raise HTTPException(status_code=404, detail="Invoice not found")
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.certification import CertificationService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CertSubmitRequest(BaseModel):
    """Request body for submitting a certification application.

    Only ``cert_type`` is required; the remaining fields default to ``None``.
    """

    cert_type: str

    # Fields for an individual applicant.
    personal_name: Optional[str] = None
    personal_id: Optional[str] = None

    # Fields for an enterprise applicant.
    company_name: Optional[str] = None
    tax_id: Optional[str] = None
    business_license_url: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/submit")
async def submit_certification(
    data: CertSubmitRequest,
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Submit a certification application for the authenticated user.

    Responds 400 with the service's message when the result dict contains an
    ``"error"`` key; otherwise wraps the result in a success envelope.
    """
    outcome = await CertificationService(db).submit(user_id, data.model_dump())
    if "error" not in outcome:
        return {"success": True, "data": outcome}
    raise HTTPException(status_code=400, detail=outcome["error"])
|
||||
|
||||
|
||||
@router.get("/status")
async def get_certification_status(
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated user's certification as provided by the service.

    ``data`` is whatever ``CertificationService.get_user_cert`` returns, which
    may be ``None``.
    """
    latest = await CertificationService(db).get_user_cert(user_id)
    return {"success": True, "data": latest}
|
||||
@@ -0,0 +1,61 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from app.services.discovery import DiscoveryService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request body for the lead search endpoint."""

    product_description: str
    # Target market; defaults to "US" when the client omits it.
    target_market: str = "US"
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
    """Request body for the company analysis endpoint."""

    company_url: str
    product_description: str
|
||||
|
||||
|
||||
class OutreachRequest(BaseModel):
    """Request body for the outreach copy endpoint.

    Both fields are free-form dicts; the endpoint requires each to carry a
    non-empty ``"name"`` entry.
    """

    company: Dict[str, Any]
    product: Dict[str, Any]
|
||||
|
||||
|
||||
@router.post("/search")
async def search_leads(req: SearchRequest):
    """Search for potential buyers matching a product description.

    Responds 400 when the description is blank and 500 (with the exception
    text in the detail) when the discovery service raises.
    """
    if not req.product_description.strip():
        raise HTTPException(status_code=400, detail="请填写产品描述")

    service = DiscoveryService()
    try:
        found = await service.search(req.product_description, req.target_market)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"搜索失败: {str(e)}")
    return {"success": True, "data": found}
|
||||
|
||||
|
||||
@router.post("/analyze")
async def analyze_company(req: AnalyzeRequest):
    """Analyze how well a company matches the caller's product.

    Responds 400 when the URL or the description is blank and 500 (with the
    exception text in the detail) when the discovery service raises.
    """
    if not req.company_url.strip():
        raise HTTPException(status_code=400, detail="请填写公司网址")
    if not req.product_description.strip():
        raise HTTPException(status_code=400, detail="请填写产品描述")

    service = DiscoveryService()
    try:
        analysis = await service.analyze(req.company_url, req.product_description)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
    return {"success": True, "data": analysis}
|
||||
|
||||
|
||||
@router.post("/outreach")
async def generate_outreach(req: OutreachRequest):
    """Generate outreach copy for a target company and product.

    Responds 400 when either dict lacks a truthy ``"name"`` and 500 (with the
    exception text in the detail) when the discovery service raises.
    """
    if not req.company.get("name"):
        raise HTTPException(status_code=400, detail="请填写公司名称")
    if not req.product.get("name"):
        raise HTTPException(status_code=400, detail="请填写产品名称")

    service = DiscoveryService()
    try:
        copy = await service.outreach(req.company, req.product)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}")
    return {"success": True, "data": copy}
|
||||
@@ -0,0 +1,39 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.invoice import InvoiceService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class InvoiceApplyRequest(BaseModel):
    """Request body for applying for an invoice.

    ``tax_id`` is the only optional field.
    """

    invoice_type: str
    title: str
    tax_id: Optional[str] = None
    amount: float
|
||||
|
||||
|
||||
@router.post("/apply")
async def apply_invoice(
    data: InvoiceApplyRequest,
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Create an invoice application for the authenticated user.

    Responds 400 with the service's message when the result dict contains an
    ``"error"`` key; otherwise wraps the result in a success envelope.
    """
    outcome = await InvoiceService(db).apply(user_id, data.model_dump())
    if "error" not in outcome:
        return {"success": True, "data": outcome}
    raise HTTPException(status_code=400, detail=outcome["error"])
|
||||
|
||||
|
||||
@router.get("/list")
async def list_invoices(
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated user's invoices in a success envelope."""
    own_invoices = await InvoiceService(db).list_user(user_id)
    return {"success": True, "data": own_invoices}
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
@@ -8,10 +8,11 @@ ENV_FILE = PROJECT_ROOT / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
class Config:
|
||||
env_file = str(ENV_FILE)
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
model_config = {
|
||||
"env_file": str(ENV_FILE),
|
||||
"env_file_encoding": "utf-8",
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
APP_NAME: str = "TradeMate"
|
||||
|
||||
@@ -71,6 +72,9 @@ class Settings(BaseSettings):
|
||||
|
||||
EXCHANGE_RATE_API_KEY: Optional[str] = None
|
||||
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GOOGLE_CSE_ID: Optional[str] = None
|
||||
|
||||
UPLOAD_DIR: str = "./uploads"
|
||||
MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024
|
||||
|
||||
|
||||
+4
-1
@@ -54,7 +54,7 @@ async def health():
|
||||
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
||||
|
||||
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, certification, invoice
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
||||
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
||||
@@ -78,6 +78,9 @@ app.include_router(silent_pattern.router, prefix="/api/v1/silent-pattern", tags=
|
||||
app.include_router(training.router, prefix="/api/v1/training", tags=["training"])
|
||||
app.include_router(followup.router, prefix="/api/v1/followup", tags=["followup"])
|
||||
app.include_router(ai_assistant.router, prefix="/api/v1/ai", tags=["ai-assistant"])
|
||||
app.include_router(discovery.router, prefix="/api/v1/discovery", tags=["discovery"])
|
||||
app.include_router(certification.router, prefix="/api/v1/certification", tags=["certification"])
|
||||
app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,6 +12,8 @@ from .device import Device
|
||||
from .followup import FollowupStrategy, FollowupLog
|
||||
from .system_config import SystemConfig
|
||||
from .translation_quota import TranslationQuota
|
||||
from .certification import Certification, CertType, CertStatus
|
||||
from .invoice import Invoice, InvoiceType, InvoiceStatus
|
||||
|
||||
__all__ = [
|
||||
"User", "Product",
|
||||
@@ -28,4 +30,6 @@ __all__ = [
|
||||
"FollowupStrategy", "FollowupLog",
|
||||
"SystemConfig",
|
||||
"TranslationQuota",
|
||||
"Certification", "CertType", "CertStatus",
|
||||
"Invoice", "InvoiceType", "InvoiceStatus",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, ForeignKey, Enum as SAEnum
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
import enum
|
||||
|
||||
|
||||
class CertType(str, enum.Enum):
    """Kind of applicant a certification is issued for."""

    individual = "individual"
    enterprise = "enterprise"
|
||||
|
||||
|
||||
class CertStatus(str, enum.Enum):
    """Review state of a certification request."""

    pending = "pending"
    approved = "approved"
    rejected = "rejected"
|
||||
|
||||
|
||||
class Certification(Base):
    """ORM model for a user's certification request (table ``certifications``)."""

    __tablename__ = "certifications"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = Column(
        UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True
    )
    cert_type = Column(SAEnum(CertType), nullable=False)

    # Details supplied for an individual applicant.
    personal_name = Column(String(100))
    personal_id = Column(String(30))

    # Details supplied for an enterprise applicant.
    company_name = Column(String(255))
    tax_id = Column(String(30))
    business_license_url = Column(String(500))

    # Review outcome; new rows start as pending.
    status = Column(SAEnum(CertStatus), default=CertStatus.pending)
    reject_reason = Column(Text)

    # Naive UTC timestamps maintained by SQLAlchemy column defaults.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    user = relationship("User")
|
||||
@@ -0,0 +1,41 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, Float, ForeignKey, Enum as SAEnum
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
import enum
|
||||
|
||||
|
||||
class InvoiceType(str, enum.Enum):
    """Kind of recipient an invoice is issued to."""

    individual = "individual"
    enterprise = "enterprise"
|
||||
|
||||
|
||||
class InvoiceStatus(str, enum.Enum):
    """Processing state of an invoice request."""

    pending = "pending"
    issued = "issued"
    rejected = "rejected"
|
||||
|
||||
|
||||
class Invoice(Base):
    """ORM model for a user's invoice request (table ``invoices``)."""

    __tablename__ = "invoices"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = Column(
        UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True
    )
    # Optional link to a certification row.
    certification_id = Column(
        UUID(as_uuid=True), ForeignKey("certifications.id"), nullable=True
    )

    invoice_type = Column(SAEnum(InvoiceType), nullable=False)
    title = Column(String(255), nullable=False)
    tax_id = Column(String(30))
    # NOTE(review): stored as a binary float; confirm this precision is
    # acceptable for monetary amounts.
    amount = Column(Float, nullable=False)

    # Processing outcome; new rows start as pending.
    status = Column(SAEnum(InvoiceStatus), default=InvoiceStatus.pending)
    reject_reason = Column(Text)
    issued_at = Column(DateTime)

    # Naive UTC timestamps maintained by SQLAlchemy column defaults.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    user = relationship("User")
    certification = relationship("Certification")
|
||||
@@ -0,0 +1,112 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.models.certification import Certification, CertType, CertStatus
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class CertificationService:
    """Operations on certification requests, bound to one async DB session.

    The service only flushes; committing is left to whoever owns the session.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def submit(self, user_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a pending certification for ``user_id``.

        Args:
            user_id: UUID string of the applicant.
            data: Submitted fields; ``cert_type`` is required, the rest optional.

        Returns:
            ``{"id", "status"}`` on success, or ``{"error": message}`` when the
            user already has a pending request or ``cert_type`` is not a valid
            ``CertType`` value.
        """
        if await self._get_pending(user_id):
            return {"error": "已有审核中的认证申请,请勿重复提交"}

        # Report an unknown type through the same error channel the endpoint
        # already maps to HTTP 400, instead of letting ValueError become a 500.
        try:
            cert_type = CertType(data["cert_type"])
        except ValueError:
            return {"error": "无效的认证类型"}

        cert = Certification(
            user_id=uuid.UUID(user_id),
            cert_type=cert_type,
            personal_name=data.get("personal_name"),
            personal_id=data.get("personal_id"),
            company_name=data.get("company_name"),
            tax_id=data.get("tax_id"),
            business_license_url=data.get("business_license_url"),
            status=CertStatus.pending,
        )
        self.db.add(cert)
        await self.db.flush()
        return {"id": str(cert.id), "status": cert.status.value}

    async def get_user_cert(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the user's most recently created certification, or ``None``."""
        result = await self.db.execute(
            select(Certification)
            .where(Certification.user_id == uuid.UUID(user_id))
            .order_by(desc(Certification.created_at))
            .limit(1)
        )
        cert = result.scalar_one_or_none()
        if not cert:
            return None
        return {
            "id": str(cert.id),
            "cert_type": cert.cert_type.value,
            "personal_name": cert.personal_name,
            "personal_id": cert.personal_id,
            "company_name": cert.company_name,
            "tax_id": cert.tax_id,
            "business_license_url": cert.business_license_url,
            "status": cert.status.value,
            "reject_reason": cert.reject_reason,
            "created_at": cert.created_at.isoformat() if cert.created_at else None,
            "updated_at": cert.updated_at.isoformat() if cert.updated_at else None,
        }

    async def list_all(self, page: int, size: int, status: Optional[str] = None) -> Dict[str, Any]:
        """Return one page of certifications, newest first.

        Args:
            page: 1-based page number.
            size: Page size.
            status: Optional ``CertStatus`` value to filter on.

        Returns:
            ``{"items", "total", "page", "size"}`` where ``total`` counts all
            rows matching the filter.

        Raises:
            ValueError: If ``status`` is given but is not a ``CertStatus`` value.
        """
        from sqlalchemy import func

        status_filter = CertStatus(status) if status else None

        page_query = select(Certification).order_by(desc(Certification.created_at))
        # Count in the database rather than loading every row to len() it.
        count_query = select(func.count()).select_from(Certification)
        if status_filter is not None:
            page_query = page_query.where(Certification.status == status_filter)
            count_query = count_query.where(Certification.status == status_filter)

        offset = (page - 1) * size
        result = await self.db.execute(page_query.offset(offset).limit(size))
        certs = result.scalars().all()
        total = (await self.db.execute(count_query)).scalar_one()

        return {
            "items": [
                {
                    "id": str(c.id),
                    "user_id": str(c.user_id),
                    "cert_type": c.cert_type.value,
                    "personal_name": c.personal_name,
                    "personal_id": c.personal_id,
                    "company_name": c.company_name,
                    "tax_id": c.tax_id,
                    "status": c.status.value,
                    "reject_reason": c.reject_reason,
                    "created_at": c.created_at.isoformat() if c.created_at else None,
                }
                for c in certs
            ],
            "total": total,
            "page": page,
            "size": size,
        }

    async def review(self, cert_id: str, action: str, reason: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Approve or reject a certification.

        ``action == "approve"`` approves; any other value rejects and stores
        ``reason``. Returns ``{"id", "status"}``, or ``None`` when no row has
        the given id.
        """
        result = await self.db.execute(
            select(Certification).where(Certification.id == uuid.UUID(cert_id))
        )
        cert = result.scalar_one_or_none()
        if not cert:
            return None
        if action == "approve":
            cert.status = CertStatus.approved
        else:
            cert.status = CertStatus.rejected
            cert.reject_reason = reason
        await self.db.flush()
        return {"id": str(cert.id), "status": cert.status.value}

    async def _get_pending(self, user_id: str) -> Optional[Certification]:
        """Return one pending certification of ``user_id``, if any exists."""
        result = await self.db.execute(
            select(Certification)
            .where(
                Certification.user_id == uuid.UUID(user_id),
                Certification.status == CertStatus.pending,
            )
            .limit(1)
        )
        return result.scalar_one_or_none()
|
||||
@@ -0,0 +1,272 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from app.ai.router import get_ai_router
|
||||
from app.services.search_web import search_companies, fetch_page_text
|
||||
from app.services.mcp_search_client import mcp_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ANALYZE_MATCH_PROMPT = """你是外贸客户分析专家。分析目标公司的业务描述,判断其与用户产品的匹配度。
|
||||
|
||||
请以 JSON 格式返回(不要用 markdown 代码块标记):
|
||||
{
|
||||
"match_score": 0-100,
|
||||
"match_reason": "为什么匹配/不匹配",
|
||||
"company_summary": "这家公司的主要业务",
|
||||
"product_fit": "产品匹配度说明",
|
||||
"contact_info": {
|
||||
"emails": ["找到的邮箱"],
|
||||
"phones": ["找到的电话"],
|
||||
"social": ["LinkedIn等社媒链接"]
|
||||
}
|
||||
}
|
||||
|
||||
只返回 JSON,不要其他内容。"""
|
||||
|
||||
|
||||
class DiscoveryService:
    """Lead discovery: web search, company match analysis and outreach copy.

    Every AI-backed operation falls back to a static template when no AI
    provider is configured or the AI reply cannot be parsed.
    """

    def __init__(self):
        ai_router = get_ai_router()
        self.ai = ai_router
        self._ai_available = len(ai_router.providers) > 0

    async def search(self, product_description: str, target_market: str) -> Dict[str, Any]:
        """Find candidate companies for a product in a market.

        Tries MCP search first, then web search, and finally an AI-generated
        (or templated) buyer strategy when neither search yields results.
        """
        queries = self._build_queries(product_description, target_market)

        all_results = await self._mcp_search_all(queries)
        if all_results:
            return {
                "companies": all_results[:15],
                "query": product_description,
                "market": target_market,
                "provider": "mcp_search",
            }

        all_results = await self._google_search_all(queries)
        if all_results:
            return {
                "companies": all_results[:15],
                "query": product_description,
                "market": target_market,
                "provider": "web_search",
            }

        logger.info("No real search results, using AI strategy")
        return await self._ai_strategy(product_description, target_market)

    async def analyze(self, company_url: str, product_description: str) -> Dict[str, Any]:
        """Score how well the company at ``company_url`` matches the product.

        Returns the parsed AI reply with ``url`` and ``provider`` added, or the
        template analysis when AI is unavailable or its reply is unusable.
        """
        # Without an AI provider the page text is never used, so skip the fetch.
        if not self._ai_available:
            return self._template_analysis(company_url)

        # NOTE(review): company_url is caller-supplied and fetched server-side;
        # confirm fetch_page_text restricts which hosts it will contact.
        page_text = await fetch_page_text(company_url)

        prompt = f"""用户的产品:{product_description}

目标公司信息:
URL: {company_url}
网页内容:{page_text[:2500] if page_text else "无法获取网页内容"}

请分析该公司的业务与用户产品的匹配度。"""
        try:
            result = await self.ai.chat(prompt, system_prompt=ANALYZE_MATCH_PROMPT)
            content = result.get("reply", "")
            parsed = self._extract_json(content)
            if parsed:
                parsed["url"] = company_url
                parsed["provider"] = result.get("provider_used", "unknown")
                return parsed
        except Exception as e:
            logger.warning(f"Analysis AI parse failed: {e}")
        return self._template_analysis(company_url)

    async def outreach(self, company_info: Dict[str, Any], product_info: Dict[str, Any]) -> Dict[str, Any]:
        """Generate outreach copy for a company/product pair.

        Returns the parsed AI reply with ``provider`` added, or the template
        copy when AI is unavailable or its reply is unusable.
        """
        if not self._ai_available:
            return self._template_outreach(company_info, product_info)

        prompt = f"""目标公司信息:
{json.dumps(company_info, ensure_ascii=False)}

我的产品信息:
{json.dumps(product_info, ensure_ascii=False)}

请生成个性化触达文案。"""
        system = """你是外贸开发信专家。根据目标公司信息和你的产品,生成个性化触达文案。

请以 JSON 格式返回(不要用 markdown 代码块标记):
{
"subject": "邮件标题(如适用)",
"linkedin_message": "LinkedIn 私信文案(150字以内)",
"whatsapp_message": "WhatsApp 消息文案(100字以内)",
"email_body": "邮件正文(含开头问候、自我介绍、价值主张、行动号召、签名)",
"key_points": ["客户关注的3个要点"],
"tips": ["发送时的建议"]
}"""
        try:
            result = await self.ai.chat(prompt, system_prompt=system)
            content = result.get("reply", "")
            parsed = self._extract_json(content)
            if parsed:
                parsed["provider"] = result.get("provider_used", "unknown")
                return parsed
        except Exception as e:
            logger.warning(f"Outreach AI parse failed: {e}")
        return self._template_outreach(company_info, product_info)

    async def _mcp_search_all(self, queries: list) -> list:
        """Run the first two queries through MCP search concurrently.

        Results are collected within an overall 8-second budget; unfinished
        tasks are cancelled. Returns at most 15 filtered results, or ``[]``.
        """
        seen_urls = set()
        tasks = [asyncio.create_task(mcp_search(q, max_results=6)) for q in queries[:2]]
        all_results = []
        try:
            for coro in asyncio.as_completed(tasks, timeout=8):
                try:
                    results = await coro
                    for r in results:
                        url = r.get("url", "").rstrip("/")
                        if url and url not in seen_urls:
                            seen_urls.add(url)
                            all_results.append(r)
                except Exception as e:
                    # Also covers the TimeoutError raised once the budget is spent.
                    logger.debug(f"MCP search query failed: {e}")
        except asyncio.TimeoutError:
            logger.warning("MCP search overall timeout")
        finally:
            for t in tasks:
                if not t.done():
                    t.cancel()
            # Wait for cancellations to settle so no task outlives this call.
            await asyncio.gather(*tasks, return_exceptions=True)
        if all_results:
            return self._dedup_and_filter(all_results)[:15]
        return []

    def _dedup_and_filter(self, results: list) -> list:
        """Drop duplicate URLs and hosts that are unlikely to be buyers.

        A result is removed when its URL is empty or already seen (ignoring a
        trailing slash), when its host is under ``.cn``, has an ``edu``/``gov``
        label or an ``ac`` second-level label (e.g. ``ox.ac.uk``), or when the
        host contains one of the known publisher/reference site names.

        Host labels are compared whole, so ``www.cnn.com`` or
        ``www.education-supplies.com`` are no longer mistaken for ``.cn`` /
        ``.edu`` hosts.
        """
        from urllib.parse import urlsplit

        reference_sites = (
            "sciencedirect", "mdpi", "springer", "wiley", "acm.org",
            "ieee.org", "researchgate", "nature.com", "oup.com",
            "sagepub", "tandfonline", "ncbi", "semanticscholar",
            "britannica", "dictionary", "cambridge", "iciba", "wikipedia",
        )
        seen = set()
        filtered = []
        for r in results:
            url = r.get("url", "").rstrip("/")
            if not url or url in seen:
                continue
            seen.add(url)

            # urlsplit only recognises the host when a "//" authority marker
            # is present, so add one for scheme-less URLs.
            try:
                hostname = urlsplit(url if "://" in url else "//" + url).hostname or ""
            except ValueError:
                hostname = ""  # unparsable URL: no host rule can apply
            labels = hostname.split(".")

            if labels[-1] == "cn":  # covers both .cn and .com.cn
                continue
            if any(label in ("edu", "gov") for label in labels[1:]):
                continue
            if "ac" in labels[1:-1]:
                continue
            if any(site in hostname for site in reference_sites):
                continue
            filtered.append(r)
        return filtered

    async def _google_search_all(self, queries: list) -> list:
        """Run up to three queries through web search, one after another.

        Stops issuing queries once 15 unique URLs are collected, then returns
        at most 15 filtered results.
        """
        all_results = []
        seen_urls = set()
        for q in queries[:3]:
            if len(all_results) >= 15:
                break  # cap reached: do not spend another search call
            results = await search_companies(q, max_results=8)
            for r in results:
                url = r.get("url", "").rstrip("/")
                if not url or url in seen_urls:
                    continue
                seen_urls.add(url)
                all_results.append(r)
                if len(all_results) >= 15:
                    break
        return self._dedup_and_filter(all_results)[:15]

    def _build_queries(self, product: str, market: str) -> list:
        """Return the search phrasings tried for a product/market pair."""
        return [
            f"{product} importer {market}",
            f"{product} distributor {market}",
            f"{product} wholesale buyer {market}",
            f"{product} procurement {market}",
            f"{product} company {market}",
            f"buy {product} from {market}",
            f"{product} supply chain {market}",
            f"top {product} manufacturers {market}",
            f"{product} import export {market}",
            f"{product} trading company {market}",
        ]

    def _extract_json(self, text: str) -> Optional[dict]:
        """Parse a JSON object out of an AI reply.

        Strips an optional Markdown code fence (with or without a ``json``
        tag in any case), then tries the whole text and, failing that, the
        span from the first ``{`` to the last ``}``. Returns ``None`` when no
        JSON object can be parsed.
        """
        text = text.strip()
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # Callers assign keys on the result, so only an object is usable.
            if isinstance(parsed, dict):
                return parsed
        return None

    async def _ai_strategy(self, product: str, market: str) -> Dict[str, Any]:
        """Ask the AI for buyer personas and a sourcing strategy.

        Falls back to the template strategy when AI is unavailable, raises, or
        returns a reply that cannot be parsed.
        """
        if not self._ai_available:
            return self._template_strategy(product, market)
        system = """你是外贸客户发现专家。根据用户的产品和目标市场,分析出潜在买家画像和获取策略。

请以 JSON 格式返回(不要用 markdown 代码块标记):
{
"buyer_personas": [{"type": "", "description": "", "channels": [], "search_queries": []}],
"strategy": "",
"tips": []
}"""
        prompt = f"产品:{product}\n目标市场:{market}\n请分析潜在买家画像和获取策略。"
        try:
            result = await self.ai.chat(prompt, system_prompt=system)
            content = result.get("reply", "")
            parsed = self._extract_json(content)
            if parsed:
                parsed["provider"] = result.get("provider_used", "unknown")
                return parsed
            return self._template_strategy(product, market)
        except Exception as e:
            logger.warning(f"AI strategy failed: {e}")
            return self._template_strategy(product, market)

    def _template_strategy(self, product: str, market: str) -> Dict[str, Any]:
        """Return the static buyer strategy used when AI is not usable."""
        return {
            "buyer_personas": [
                {"type": "进口商/批发商", "description": f"从中国进口{product}并在{market}批发的贸易商", "channels": ["LinkedIn", "Google"], "search_queries": [f"{product} importer {market}"]},
                {"type": "品牌商/OEM买家", "description": f"在{market}销售自有品牌{product}的公司", "channels": ["LinkedIn", "行业展会"], "search_queries": [f"{product} manufacturer {market}"]},
            ],
            "strategy": f"建议在 LinkedIn 和 Google 搜索 {market} 的 {product} 相关公司",
            "tips": ["使用多个搜索词", "找到公司后在 LinkedIn 找决策人"],
            "provider": "template",
        }

    def _template_analysis(self, url: str) -> Dict[str, Any]:
        """Return the neutral analysis used when AI is not usable."""
        return {
            "match_score": 50,
            "match_reason": "无法获取网页内容进行分析,建议手动查看",
            "url": url,
            "provider": "template",
        }

    def _template_outreach(self, company: Dict[str, Any], product: Dict[str, Any]) -> Dict[str, Any]:
        """Return static outreach copy built from the company and product names."""
        company_name = company.get("name", "")
        product_name = product.get("name", "")
        return {
            "subject": f"关于{product_name}的合作机会",
            "linkedin_message": f"您好!了解到贵司{company_name}在经营相关业务,我们专业生产{product_name},品质稳定,价格有竞争力。如有兴趣,我可以发详细资料供参考。",
            "whatsapp_message": f"Hello! We are a professional {product_name} manufacturer. Interested in exploring cooperation? Happy to share details.",
            "email_body": f"Dear {company_name} team,\n\nWe are a professional {product_name} manufacturer with competitive pricing and consistent quality. Would you be open to a quick chat to explore potential cooperation?\n\nBest regards,\n[Your Name]",
            "key_points": ["产品质量有保障", "价格有竞争力", "可定制"],
            "tips": ["发送前先了解对方背景", "LinkedIn 消息要简短"],
            "provider": "template",
        }
|
||||
@@ -0,0 +1,126 @@
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.models.invoice import Invoice, InvoiceType, InvoiceStatus
|
||||
from app.models.certification import Certification, CertStatus
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class InvoiceService:
    """Create, list and process invoice requests through an async SQLAlchemy session.

    The methods only ``flush()`` the session; nothing in this class commits.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _approved_certification(self, user_id: str, cert_type: str) -> Optional[Certification]:
        """Return one approved certification of ``cert_type`` for the user, or ``None``."""
        lookup = (
            select(Certification)
            .where(
                Certification.user_id == uuid.UUID(user_id),
                Certification.cert_type == cert_type,
                Certification.status == CertStatus.approved,
            )
            .limit(1)
        )
        found = await self.db.execute(lookup)
        return found.scalar_one_or_none()

    @staticmethod
    def _serialize(inv: Invoice, include_user: bool = False) -> Dict[str, Any]:
        """Turn an ``Invoice`` row into a plain dict (optionally with ``user_id``)."""
        row: Dict[str, Any] = {"id": str(inv.id)}
        if include_user:
            row["user_id"] = str(inv.user_id)
        row.update(
            {
                "invoice_type": inv.invoice_type.value,
                "title": inv.title,
                "tax_id": inv.tax_id,
                "amount": inv.amount,
                "status": inv.status.value,
                "reject_reason": inv.reject_reason,
                "issued_at": inv.issued_at.isoformat() if inv.issued_at else None,
                "created_at": inv.created_at.isoformat() if inv.created_at else None,
            }
        )
        return row

    async def apply(self, user_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a pending invoice request for ``user_id``.

        ``data`` must contain ``"invoice_type"``, ``"title"`` and ``"amount"``
        (a missing key raises ``KeyError``); ``"tax_id"`` is optional.
        The user needs an approved certification matching the invoice type:
        ``individual`` invoices need an ``"individual"`` certification, every
        other type needs an ``"enterprise"`` one.

        Returns ``{"id", "status"}`` on success, or ``{"error": message}``
        when the required certification is missing.
        """
        invoice_type = InvoiceType(data["invoice_type"])

        if invoice_type == InvoiceType.individual:
            cert_type, missing_message = "individual", "请先完成个人实名认证"
        else:
            cert_type, missing_message = "enterprise", "请先完成企业认证"

        cert = await self._approved_certification(user_id, cert_type)
        if not cert:
            return {"error": missing_message}

        invoice = Invoice(
            user_id=uuid.UUID(user_id),
            certification_id=cert.id,
            invoice_type=invoice_type,
            title=data["title"],
            tax_id=data.get("tax_id"),
            amount=data["amount"],
            status=InvoiceStatus.pending,
        )
        self.db.add(invoice)
        # Flush so the returned dict can read the new row's id.
        await self.db.flush()
        return {"id": str(invoice.id), "status": invoice.status.value}

    async def list_user(self, user_id: str) -> List[Dict[str, Any]]:
        """Return all invoices of ``user_id``, newest first, as plain dicts."""
        result = await self.db.execute(
            select(Invoice)
            .where(Invoice.user_id == uuid.UUID(user_id))
            .order_by(desc(Invoice.created_at))
        )
        return [self._serialize(inv) for inv in result.scalars().all()]

    async def list_all(self, page: int, size: int, status: Optional[str] = None) -> Dict[str, Any]:
        """Return one page of invoices (newest first), optionally filtered by status.

        ``page`` is 1-based; values below 1 are treated as the first page so
        the query never gets a negative OFFSET. ``"total"`` is the number of
        rows matching the filter across all pages, not the length of the
        current page.
        """
        # Local import: ``func`` is not among this module's top-level imports.
        from sqlalchemy import func

        page_query = select(Invoice).order_by(desc(Invoice.created_at))
        count_query = select(func.count()).select_from(Invoice)
        if status:
            status_filter = Invoice.status == InvoiceStatus(status)
            page_query = page_query.where(status_filter)
            count_query = count_query.where(status_filter)

        total = (await self.db.execute(count_query)).scalar_one()

        offset = max(page - 1, 0) * size
        result = await self.db.execute(page_query.offset(offset).limit(size))
        invoices = result.scalars().all()

        return {
            "items": [self._serialize(inv, include_user=True) for inv in invoices],
            "total": total,
            "page": page,
            "size": size,
        }

    async def process(self, invoice_id: str, action: str, reason: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Issue or reject an invoice.

        Returns ``None`` when no invoice has ``invoice_id``; otherwise
        ``{"id", "status"}`` after the update has been flushed.
        """
        result = await self.db.execute(
            select(Invoice).where(Invoice.id == uuid.UUID(invoice_id))
        )
        inv = result.scalar_one_or_none()
        if not inv:
            return None

        if action == "issue":
            inv.status = InvoiceStatus.issued
            # Naive UTC timestamp, as in the original code.
            # NOTE(review): presumably the column is a naive DateTime - confirm
            # before switching to timezone-aware values.
            inv.issued_at = datetime.utcnow()
        else:
            # NOTE(review): every action other than "issue" rejects the invoice,
            # including typos. Kept for compatibility; confirm the allowed
            # action values with the caller before tightening this.
            inv.status = InvoiceStatus.rejected
            inv.reject_reason = reason

        await self.db.flush()
        return {"id": str(inv.id), "status": inv.status.value}
|
||||
@@ -0,0 +1,101 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from mcp.client.stdio import stdio_client, StdioServerParameters
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_SCRIPT = os.path.join(os.path.dirname(__file__), "mcp_search_server.py")
|
||||
VENV_PYTHON = sys.executable
|
||||
|
||||
|
||||
class MCPClientManager:
    """Process-wide holder of one MCP stdio client session for the search server.

    ``get_instance()`` lazily starts the server script with the current
    interpreter and initializes a ``ClientSession``; ``search()`` calls the
    server's ``web_search`` tool; ``close()`` tears everything down.
    """

    # Shared singleton state; _lock serializes creation of the instance.
    _instance: Optional["MCPClientManager"] = None
    _lock = asyncio.Lock()

    def __init__(self):
        self._session: Optional[ClientSession] = None
        self._read = None
        self._write = None
        self._ctx = None
        self._initialized = False

    @classmethod
    async def get_instance(cls) -> "MCPClientManager":
        """Return the initialized singleton, starting it on first use.

        The check is repeated after acquiring the lock so concurrent callers
        start the server only once. If start-up fails or exceeds 10 seconds,
        whatever was already opened is closed (best effort), the singleton is
        cleared, and the original exception is re-raised.
        """
        if cls._instance is None or not cls._instance._initialized:
            async with cls._lock:
                if cls._instance is None or not cls._instance._initialized:
                    candidate = cls()
                    cls._instance = candidate
                    try:
                        await asyncio.wait_for(candidate._start(), timeout=10)
                    except Exception as e:
                        logger.warning(f"MCP init failed: {e}")
                        # Release the transport/session opened before the
                        # failure; otherwise they are orphaned along with the
                        # discarded instance.
                        try:
                            await asyncio.wait_for(candidate.close(), timeout=5)
                        except Exception as cleanup_err:
                            logger.debug(f"MCP cleanup after failed init: {cleanup_err}")
                        cls._instance = None
                        raise
        return cls._instance

    async def _start(self):
        """Open the stdio transport, enter the client session and initialize it.

        Each step has its own 5-second timeout.
        """
        params = StdioServerParameters(
            command=VENV_PYTHON,
            args=[SERVER_SCRIPT],
        )
        self._ctx = stdio_client(params)
        self._read, self._write = await asyncio.wait_for(
            self._ctx.__aenter__(), timeout=5
        )
        # Store the session object before entering it so close() can still
        # reach it if entering or initializing fails part-way.
        session = ClientSession(self._read, self._write)
        self._session = session
        self._session = await asyncio.wait_for(session.__aenter__(), timeout=5)
        await asyncio.wait_for(self._session.initialize(), timeout=5)
        self._initialized = True
        logger.info("MCP search client initialized")

    async def search(self, query: str, max_results: int = 10) -> List[Dict[str, str]]:
        """Call the server's ``web_search`` tool and return its ``"results"`` list.

        Returns ``[]`` when the client is not initialized, when the tool
        returns no content, or when the call fails or exceeds 10 seconds.
        """
        if not self._initialized or self._session is None:
            logger.warning("MCP client not initialized")
            return []
        try:
            result = await asyncio.wait_for(
                self._session.call_tool(
                    "web_search",
                    {"query": query, "max_results": max_results},
                ),
                timeout=10,
            )
            if result.content and len(result.content) > 0:
                # The first content item's text is parsed as a JSON document.
                data = json.loads(result.content[0].text)
                return data.get("results", [])
            return []
        except Exception as e:
            # Covers timeouts too (asyncio.TimeoutError is an Exception).
            logger.warning(f"MCP search call failed: {e}")
            return []

    async def close(self):
        """Tear down the session and transport, ignoring teardown errors.

        Also clears the class-level singleton so the next ``get_instance()``
        starts a fresh client. Calling it again is a no-op.
        """
        self._initialized = False
        MCPClientManager._instance = None
        session, ctx = self._session, self._ctx
        # Drop the handles first so a repeated close() does nothing.
        self._session = None
        self._ctx = None
        self._read = None
        self._write = None
        if session:
            try:
                await session.__aexit__(None, None, None)
            except (BaseExceptionGroup, RuntimeError, Exception) as e:
                # Shutdown is best effort; a failing exit must not propagate.
                logger.debug(f"MCP session close ignored error: {e!r}")
        if ctx:
            try:
                await ctx.__aexit__(None, None, None)
            except (BaseExceptionGroup, RuntimeError, Exception) as e:
                logger.debug(f"MCP transport close ignored error: {e!r}")
|
||||
|
||||
|
||||
async def mcp_search(query: str, max_results: int = 10) -> List[Dict[str, str]]:
    """Search through the shared ``MCPClientManager``; never raises ``Exception``.

    Any failure while obtaining the manager or running the search is logged
    as a warning and turned into an empty list.
    """
    try:
        manager = await MCPClientManager.get_instance()
        hits = await manager.search(query, max_results)
    except Exception as err:
        logger.warning(f"MCP search failed: {err}")
        return []
    return hits
|
||||
@@ -0,0 +1,105 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Dict
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
NODE_BIN = "/usr/bin/node"
|
||||
|
||||
BING_SCRIPT = r"""
|
||||
const p = require('puppeteer');
|
||||
(async () => {
|
||||
const b = await p.launch({headless:true,args:['--no-sandbox','--disable-setuid-sandbox','--disable-blink-features=AutomationControlled']});
|
||||
const page = await b.newPage();
|
||||
await page.setUserAgent('Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36');
|
||||
await page.setExtraHTTPHeaders({'Accept-Language':'en-US,en;q=0.9'});
|
||||
await page.evaluateOnNewDocument(() => { Object.defineProperty(navigator, 'webdriver', {get:()=>undefined}); });
|
||||
const q = process.argv[process.argv.length - 2];
|
||||
const max = parseInt(process.argv[process.argv.length - 1] || '10', 10);
|
||||
const sk = ['bing.com','google.com','facebook.com','twitter.com','instagram.com','youtube.com','reddit.com','amazon.com','wikipedia.org','baidu.com','linkedin.com','pinterest.com','ebay.com','walmart.com','w3.org','whatsapp.com','wechat.com','qq.com','taobao.com','tmall.com','alibaba.com','alipay.com','dict','dictionary','translate','zhihu.com','baike.baidu.com','sogou.com','163.com','sohu.com','sina.com','iciba.com','cambridge','britannica','sciencedirect','mdpi.com','springer','wiley.com','acm.org','ieee.org','researchgate','semanticscholar','ncbi.nlm.nih','nature.com','oup.com','sagepub','tandfonline'];
|
||||
try {
|
||||
await page.goto('https://cn.bing.com/search?q=' + encodeURIComponent(q) + '&setlang=en-US&cc=US', {waitUntil:'domcontentloaded',timeout:10000});
|
||||
await page.waitForSelector('.b_algo', {timeout:5000}).catch(()=>{});
|
||||
const results = await page.evaluate((m, sk) => {
|
||||
const reCJK = /[\u4e00-\u9fff\u3400-\u4dbf]/;
|
||||
const found = []; const seen = new Set();
|
||||
document.querySelectorAll('li.b_algo').forEach(li => {
|
||||
const a = li.querySelector('h2 a'); if (!a) return;
|
||||
let url = (a.href || '').replace(/\/$/,'');
|
||||
if (!url.startsWith('http') || seen.has(url)) return;
|
||||
seen.add(url);
|
||||
if (sk.some(d => url.includes(d))) return;
|
||||
const hostname = url.replace(/^https?:\/\//,'').split('/')[0];
|
||||
if (hostname.endsWith('.cn') || hostname.endsWith('.com.cn') || hostname.endsWith('.edu') || hostname.endsWith('.ac')) return;
|
||||
const title = (a.textContent||'').trim().substring(0,100);
|
||||
if (reCJK.test(title)) return;
|
||||
const s = li.querySelector('.b_caption p, .b_lineclamp2');
|
||||
found.push({title, url, snippet:s?s.textContent.trim().substring(0,200):''});
|
||||
});
|
||||
return found.slice(0,m);
|
||||
}, max, sk);
|
||||
console.log(JSON.stringify(results));
|
||||
} catch(e) { console.log('[]'); }
|
||||
await b.close();
|
||||
})();
|
||||
"""
|
||||
|
||||
|
||||
BING_SCRIPT_FILE = os.path.join(os.path.dirname(__file__), "_bing_search.js")
|
||||
NODE_MODULES = os.path.join(PROJECT_ROOT, "node_modules")
|
||||
|
||||
|
||||
async def search_bing(query: str, max_results: int = 10) -> List[Dict[str, str]]:
    """Run the Puppeteer Bing scraper and return its results.

    The embedded ``BING_SCRIPT`` is materialized at ``BING_SCRIPT_FILE`` and
    executed with ``NODE_BIN``; the first stdout line starting with ``[`` is
    parsed as the JSON result list. The blocking ``subprocess.run`` call is
    executed in the default executor so the event loop keeps running while
    node works (up to the 15-second timeout).

    Returns ``[]`` on a non-zero exit code, on timeout, when no JSON line is
    printed, or on any other error.
    """

    def _run_node() -> "subprocess.CompletedProcess":
        # Only (re)write the script when it is missing or stale. Calls can now
        # overlap, and truncating the file while another node process is
        # loading it would break that run.
        try:
            with open(BING_SCRIPT_FILE, "r", encoding="utf-8") as existing:
                up_to_date = existing.read() == BING_SCRIPT
        except OSError:
            up_to_date = False
        if not up_to_date:
            with open(BING_SCRIPT_FILE, "w", encoding="utf-8") as f:
                f.write(BING_SCRIPT)

        env = os.environ.copy()
        # Let `require('puppeteer')` resolve from the project's node_modules.
        env["NODE_PATH"] = NODE_MODULES
        return subprocess.run(
            [NODE_BIN, BING_SCRIPT_FILE, query, str(max_results)],
            capture_output=True,
            text=True,
            timeout=15,
            cwd=PROJECT_ROOT,
            env=env,
        )

    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _run_node)
        if result.returncode != 0:
            logger.warning(f"Bing search failed: {result.stderr[:300]}")
            return []
        for line in result.stdout.strip().split("\n"):
            line = line.strip()
            # Skip any non-JSON output; the script prints the result array
            # on a single line.
            if line.startswith("["):
                return json.loads(line)
        return []
    except subprocess.TimeoutExpired:
        logger.warning("Bing search timed out")
        return []
    except Exception as e:
        # Includes json.JSONDecodeError from a malformed result line.
        logger.warning(f"Bing search error: {e}")
        return []
|
||||
|
||||
|
||||
mcp = FastMCP("trade-search", log_level="WARNING")
|
||||
|
||||
|
||||
@mcp.tool(
    name="web_search",
    description="Search the web for companies, buyers, or business information. Returns title, URL, and snippet for each result. Useful for finding potential customers, researching companies, or gathering market intelligence.",
)
async def web_search(query: str, max_results: int = 10) -> str:
    """MCP tool: run ``search_bing`` and return a JSON string.

    The payload is ``{"results": [...], "error": null}``; a falsy search
    result is reported as an empty list.
    """
    hits = await search_bing(query, max_results)
    payload = {"results": hits if hits else [], "error": None}
    return json.dumps(payload)
|
||||
|
||||
|
||||
def main():
    """Entry point: run the ``mcp`` server's ``run_stdio_async()`` coroutine."""
    server_coro = mcp.run_stdio_async()
    asyncio.run(server_coro)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,73 @@
|
||||
from typing import List, Dict, Optional
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GOOGLE_CSE_URL = "https://www.googleapis.com/customsearch/v1"
|
||||
|
||||
IGNORE_DOMAINS = [
|
||||
"google.com", "facebook.com", "twitter.com", "instagram.com",
|
||||
"youtube.com", "reddit.com", "amazon.com", "ebay.com",
|
||||
"wikipedia.org", "linkedin.com", "pinterest.com", "baidu.com",
|
||||
"bing.com", "duckduckgo.com",
|
||||
]
|
||||
|
||||
|
||||
async def search_companies(query: str, max_results: int = 10) -> List[Dict[str, str]]:
    """Search via Google Custom Search when it is configured.

    Both ``settings.GOOGLE_API_KEY`` and ``settings.GOOGLE_CSE_ID`` must be
    non-empty; otherwise an info line is logged and ``[]`` is returned.
    """
    api_key = settings.GOOGLE_API_KEY or ""
    cse_id = settings.GOOGLE_CSE_ID or ""
    if not (api_key and cse_id):
        logger.info("Google CSE not configured, using template results")
        return []
    return await _google_cse(query, max_results, api_key, cse_id)
|
||||
|
||||
|
||||
async def _google_cse(query: str, max_results: int, api_key: str, cse_id: str) -> List[Dict[str, str]]:
    """Query the Google Custom Search JSON API and return filtered hits.

    Each hit is ``{"title", "url", "snippet"}`` (title capped at 100 chars,
    snippet at 200, trailing ``/`` stripped from the URL). Results whose host
    belongs to one of ``IGNORE_DOMAINS`` are dropped. At most
    ``min(max_results, 10)`` items are requested from the API.

    Returns ``[]`` for a non-positive ``max_results``, a non-200 response,
    or any exception.
    """
    # Local import: urlsplit is not among this module's top-level imports.
    from urllib.parse import urlsplit

    if max_results <= 0:
        # Nothing to fetch; avoid spending an API call on an invalid `num`.
        return []

    def _is_ignored(url: str) -> bool:
        # Match ignored domains against the host only, on whole-label
        # boundaries. A substring test on the full URL also rejects pages
        # that merely mention e.g. "linkedin.com" in their path or query.
        host = (urlsplit(url).hostname or "").lower()
        dotted_host = f".{host}."
        return any(f".{domain}." in dotted_host for domain in IGNORE_DOMAINS)

    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.get(GOOGLE_CSE_URL, params={
                "key": api_key,
                "cx": cse_id,
                "q": query,
                "num": min(max_results, 10),
                "lr": "lang_en",
            })
            if resp.status_code != 200:
                logger.warning(f"Google CSE returned {resp.status_code}")
                return []
            data = resp.json()
            results = []
            for item in data.get("items", []):
                url = item.get("link", "")
                if not url or _is_ignored(url):
                    continue
                results.append({
                    # `or` also covers an explicit null title/snippet.
                    "title": (item.get("title") or url)[:100],
                    "url": url.rstrip("/"),
                    "snippet": (item.get("snippet") or "")[:200],
                })
            return results[:max_results]
    except Exception as e:
        logger.warning(f"Google CSE failed: {e}")
        return []
|
||||
|
||||
|
||||
async def fetch_page_text(url: str) -> Optional[str]:
    """Download ``url`` and return its visible text, or ``None``.

    Redirects are followed. For a 200 response, ``script``, ``style``,
    ``nav``, ``footer`` and ``header`` elements are removed, whitespace is
    collapsed, and the text is cut to 3000 characters. ``None`` is returned
    for any other status, when the text is 100 characters or shorter, or
    when any exception occurs (logged at debug level).
    """
    try:
        async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
            resp = await client.get(url, headers={"User-Agent": "Mozilla/5.0"})
            if resp.status_code != 200:
                return None

            # Imported only once a page is actually available to parse.
            from bs4 import BeautifulSoup
            import re

            soup = BeautifulSoup(resp.text, "html.parser")
            for unwanted in soup(["script", "style", "nav", "footer", "header"]):
                unwanted.decompose()
            raw_text = soup.get_text(separator=" ", strip=True)
            collapsed = re.sub(r"\s+", " ", raw_text)[:3000]
            if len(collapsed) > 100:
                return collapsed
            return None
    except Exception as e:
        logger.debug(f"fetch {url} failed: {e}")
    return None
|
||||
@@ -1,14 +1,15 @@
|
||||
fastapi==0.100.0
|
||||
uvicorn==0.23.2
|
||||
fastapi==0.136.1
|
||||
uvicorn==0.47.0
|
||||
sqlalchemy==1.4.48
|
||||
asyncpg==0.27.0
|
||||
pydantic==1.10.12
|
||||
pydantic==2.13.4
|
||||
pydantic-settings==2.14.1
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-multipart==0.0.6
|
||||
redis==4.5.5
|
||||
celery==5.2.7
|
||||
httpx==0.23.3
|
||||
httpx>=0.23.3,<0.28
|
||||
openai==1.12.0
|
||||
anthropic==0.8.1
|
||||
jinja2==3.1.2
|
||||
@@ -19,4 +20,6 @@ pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
weasyprint==60.2
|
||||
openpyxl==3.1.2
|
||||
edge-tts>=6.0.0
|
||||
edge-tts>=6.0.0
|
||||
mcp==1.27.1
|
||||
starlette==1.0.0
|
||||
Reference in New Issue
Block a user