Add landing page, referral system, usage quotas, search API management, and yearly pricing

- Separate workspace landing from login for better UX
- Referral system rewards both parties with Pro days
- Quota enforcement prevents abuse without breaking endpoints
- 7-day free trial with auto-downgrade on expiry
- Admin-managed search provider config (SearXNG, Bing)
- 15% discount on annual subscriptions
- MCP search server wrapping opencode search
- Fix discovery module field name mismatch causing 422
This commit is contained in:
TradeMate Dev
2026-05-26 11:40:13 +08:00
parent 52dba37f22
commit bed5c7abef
39 changed files with 1988 additions and 152 deletions
@@ -0,0 +1,67 @@
"""add search_providers table
Revision ID: 7fe16f1f9962
Revises: ecab04cc0e1d
Create Date: 2026-05-25 10:18:37.103091
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = '7fe16f1f9962'
down_revision: Union[str, None] = 'ecab04cc0e1d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('referral_codes',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('code', sa.String(length=20), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_referral_codes_code'), 'referral_codes', ['code'], unique=True)
op.create_index(op.f('ix_referral_codes_user_id'), 'referral_codes', ['user_id'], unique=False)
op.create_table('referrals',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('referrer_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('referred_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('code', sa.String(length=20), nullable=False),
sa.Column('reward_days', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('referred_id')
)
op.create_index(op.f('ix_referrals_referrer_id'), 'referrals', ['referrer_id'], unique=False)
op.create_table('search_providers',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('provider_type', sa.String(length=50), nullable=False),
sa.Column('api_key', sa.Text(), nullable=True),
sa.Column('api_endpoint', sa.String(length=500), nullable=True),
sa.Column('extra_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('priority', sa.Integer(), nullable=True),
sa.Column('enabled', sa.Boolean(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('search_providers')
op.drop_index(op.f('ix_referrals_referrer_id'), table_name='referrals')
op.drop_table('referrals')
op.drop_index(op.f('ix_referral_codes_user_id'), table_name='referral_codes')
op.drop_index(op.f('ix_referral_codes_code'), table_name='referral_codes')
op.drop_table('referral_codes')
# ### end Alembic commands ###
+189
View File
@@ -0,0 +1,189 @@
from typing import Optional
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from app.database import get_db
from app.api.v1.deps import get_current_user
from app.models.search_provider import SearchProvider
from app.services.search import SearchService
router = APIRouter()
async def require_admin(current_user: dict = Depends(get_current_user)) -> dict:
if current_user.get("role") != "admin":
raise HTTPException(status_code=403, detail="Admin access required")
return current_user
class ProviderCreate(BaseModel):
name: str
provider_type: str
api_key: Optional[str] = None
api_endpoint: Optional[str] = None
extra_config: Optional[dict] = None
priority: int = 0
enabled: bool = True
class ProviderUpdate(BaseModel):
name: Optional[str] = None
api_key: Optional[str] = None
api_endpoint: Optional[str] = None
extra_config: Optional[dict] = None
priority: Optional[int] = None
enabled: Optional[bool] = None
@router.get("/search-providers")
async def list_providers(
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=100),
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(SearchProvider).order_by(SearchProvider.priority).offset((page - 1) * size).limit(size)
)
providers = result.scalars().all()
total_result = await db.execute(select(SearchProvider))
total = len(total_result.scalars().all())
return {
"items": [
{
"id": str(p.id),
"name": p.name,
"provider_type": p.provider_type,
"api_key": p.api_key[:8] + "..." if p.api_key and len(p.api_key) > 8 else p.api_key,
"api_endpoint": p.api_endpoint,
"extra_config": p.extra_config,
"priority": p.priority,
"enabled": p.enabled,
"created_at": p.created_at.isoformat() if p.created_at else None,
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
}
for p in providers
],
"total": total,
"page": page,
"size": size,
}
@router.post("/search-providers")
async def create_provider(
data: ProviderCreate,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
provider = SearchProvider(
name=data.name,
provider_type=data.provider_type,
api_key=data.api_key,
api_endpoint=data.api_endpoint,
extra_config=data.extra_config or {},
priority=data.priority,
enabled=data.enabled,
)
db.add(provider)
await db.flush()
return {
"id": str(provider.id),
"name": provider.name,
"provider_type": provider.provider_type,
"message": "Provider created",
}
@router.get("/search-providers/{provider_id}")
async def get_provider(
provider_id: str,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
_validate_uuid(provider_id)
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
p = result.scalar_one_or_none()
if not p:
raise HTTPException(status_code=404, detail="Provider not found")
return {
"id": str(p.id),
"name": p.name,
"provider_type": p.provider_type,
"api_key": p.api_key,
"api_endpoint": p.api_endpoint,
"extra_config": p.extra_config,
"priority": p.priority,
"enabled": p.enabled,
"created_at": p.created_at.isoformat() if p.created_at else None,
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
}
@router.put("/search-providers/{provider_id}")
async def update_provider(
provider_id: str,
data: ProviderUpdate,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
_validate_uuid(provider_id)
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
p = result.scalar_one_or_none()
if not p:
raise HTTPException(status_code=404, detail="Provider not found")
if data.name is not None:
p.name = data.name
if data.api_key is not None:
p.api_key = data.api_key
if data.api_endpoint is not None:
p.api_endpoint = data.api_endpoint
if data.extra_config is not None:
p.extra_config = data.extra_config
if data.priority is not None:
p.priority = data.priority
if data.enabled is not None:
p.enabled = data.enabled
await db.flush()
return {"message": "Provider updated"}
@router.delete("/search-providers/{provider_id}")
async def delete_provider(
provider_id: str,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
_validate_uuid(provider_id)
result = await db.execute(delete(SearchProvider).where(SearchProvider.id == provider_id))
if result.rowcount == 0:
raise HTTPException(status_code=404, detail="Provider not found")
return {"message": "Provider deleted"}
@router.post("/search-providers/{provider_id}/test")
async def test_provider(
provider_id: str,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
_validate_uuid(provider_id)
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
p = result.scalar_one_or_none()
if not p:
raise HTTPException(status_code=404, detail="Provider not found")
try:
svc = SearchService(db)
results = await svc._search_provider(p, "test", 3)
return {"success": True, "results": results}
except Exception as e:
return {"success": False, "error": str(e)}
def _validate_uuid(uuid_str: str):
import uuid
try:
uuid.UUID(uuid_str)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid UUID")
+53 -1
View File
@@ -10,6 +10,11 @@ from app.core.security import hash_password, verify_password, create_access_toke
from pydantic import BaseModel, EmailStr
from datetime import datetime, timedelta
from app.services.admin import AdminService
from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -18,6 +23,7 @@ class RegisterRequest(BaseModel):
phone: str
password: str
username: str = ""
ref_code: str = ""
class LoginResponse(BaseModel):
@@ -47,11 +53,28 @@ async def register(data: RegisterRequest, request: Request, db: AsyncSession = D
phone=data.phone,
username=data.username or data.phone,
password_hash=hash_password(data.password),
tier="free",
tier="pro",
)
db.add(user)
await db.flush()
trial_end = datetime.utcnow() + timedelta(days=settings.TRIAL_DAYS)
sub = Subscription(
user_id=user.id,
plan="pro_trial",
status="active",
started_at=datetime.utcnow(),
expires_at=trial_end,
)
db.add(sub)
if data.ref_code:
try:
from app.api.v1.referral import do_claim_referral
await do_claim_referral(data.ref_code, str(user.id), db)
except Exception as e:
logger.warning(f"Referral claim failed: {e}")
client_ip = request.client.host if request.client else None
await AdminService(db).log_usage(str(user.id), "user.register", {"phone": data.phone}, ip=client_ip)
@@ -89,6 +112,20 @@ async def login(
client_ip = request.client.host if request.client else None
await AdminService(db).log_usage(str(user.id), "user.login", {"login_id": login_id}, ip=client_ip)
if user.tier == "pro":
sub_result = await db.execute(
select(Subscription).where(
Subscription.user_id == user.id,
Subscription.plan == "pro_trial",
Subscription.status == "active",
)
)
trial_sub = sub_result.scalar_one_or_none()
if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
trial_sub.status = "expired"
user.tier = "free"
await db.flush()
return LoginResponse(
access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
refresh_token=create_refresh_token({"sub": str(user.id)}),
@@ -178,6 +215,20 @@ async def get_me(
if not user:
raise HTTPException(status_code=404, detail="User not found")
trial_days_left = 0
if user.tier == "pro":
sub_result = await db.execute(
select(Subscription).where(
Subscription.user_id == user.id,
Subscription.plan == "pro_trial",
Subscription.status == "active",
)
)
trial_sub = sub_result.scalar_one_or_none()
if trial_sub and trial_sub.expires_at:
remaining = (trial_sub.expires_at - datetime.utcnow()).days
trial_days_left = max(0, remaining)
return {
"id": str(user.id),
"phone": user.phone,
@@ -186,6 +237,7 @@ async def get_me(
"role": user.role,
"settings": user.settings,
"created_at": user.created_at.isoformat() if user.created_at else None,
"trial_days_left": trial_days_left,
}
+6
View File
@@ -5,6 +5,7 @@ from app.database import get_db
from app.services.customer import CustomerService
from app.services.customer_health import CustomerHealthService
from app.services.import_service import import_service
from app.services.usage import UsageService
from app.services import export
from app.core.security import decode_token
from app.api.v1.deps import get_current_user_id
@@ -98,8 +99,13 @@ async def create_customer(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
usage = UsageService(db)
ok, msg = await usage.check_quota(user_id, "create_customer")
if not ok:
raise HTTPException(status_code=429, detail=msg)
service = CustomerService(db)
customer = await service.create_customer(user_id, data)
await usage.record_usage(user_id, "create_customer")
return customer
+6
View File
@@ -4,6 +4,7 @@ from typing import Optional, List
from app.database import get_db
from app.services.product import ProductService
from app.services import export
from app.services.usage import UsageService
from app.api.v1.deps import get_current_user_id
from pydantic import BaseModel
import io
@@ -175,8 +176,13 @@ async def create_product(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
usage = UsageService(db)
ok, msg = await usage.check_quota(user_id, "create_product")
if not ok:
raise HTTPException(status_code=429, detail=msg)
service = ProductService(db)
product = await service.create_product(user_id, data.dict())
await usage.record_usage(user_id, "create_product")
return product
+142
View File
@@ -0,0 +1,142 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.api.v1.deps import get_current_user_id
from app.models.referral import ReferralCode, Referral
from app.models.subscription import Subscription
from app.models.user import User
from app.config import settings
from datetime import datetime, timedelta
import uuid
import secrets
import string
router = APIRouter()
def generate_code() -> str:
return "TM" + "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(6))
@router.post("/code")
async def get_or_create_code(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(ReferralCode).where(ReferralCode.user_id == user_id))
existing = result.scalar_one_or_none()
if existing:
return {"code": existing.code, "url": f"/workspace/?ref={existing.code}"}
code = generate_code()
while True:
check = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
if not check.scalar_one_or_none():
break
code = generate_code()
rc = ReferralCode(user_id=user_id, code=code)
db.add(rc)
await db.commit()
return {"code": code, "url": f"/workspace/?ref={code}"}
@router.get("/stats")
async def get_referral_stats(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(select(Referral).where(Referral.referrer_id == user_id))
referrals = result.scalars().all()
total_reward_days = sum(r.reward_days for r in referrals if r.status == "completed")
return {
"total_referrals": len(referrals),
"completed": sum(1 for r in referrals if r.status == "completed"),
"total_reward_days": total_reward_days,
}
async def apply_referral(code: str, new_user_id: str, db: AsyncSession):
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
rc = rc_result.scalar_one_or_none()
if not rc:
return
if str(rc.user_id) == new_user_id:
return
existing = await db.execute(select(Referral).where(Referral.referred_id == new_user_id))
if existing.scalar_one_or_none():
return
reward_days = 15
referrer_sub = await db.execute(
select(Subscription).where(
Subscription.user_id == rc.user_id,
Subscription.status == "active",
).order_by(Subscription.created_at.desc()).limit(1)
)
referrer_sub_row = referrer_sub.scalar_one_or_none()
if referrer_sub_row:
old_expiry = referrer_sub_row.expires_at or datetime.utcnow()
referrer_sub_row.expires_at = old_expiry + timedelta(days=reward_days)
else:
new_sub = Subscription(
user_id=rc.user_id,
plan="pro_trial",
status="active",
started_at=datetime.utcnow(),
expires_at=datetime.utcnow() + timedelta(days=reward_days),
)
db.add(new_sub)
user_result = await db.execute(select(User).where(User.id == rc.user_id))
u = user_result.scalar_one_or_none()
if u and u.tier == "free":
u.tier = "pro"
user_result = await db.execute(select(User).where(User.id == new_user_id))
ru = user_result.scalar_one_or_none()
if ru and ru.tier in ("free", "guest"):
ru.tier = "pro"
ref_sub = Subscription(
user_id=new_user_id,
plan="pro_trial",
status="active",
started_at=datetime.utcnow(),
expires_at=datetime.utcnow() + timedelta(days=reward_days),
)
db.add(ref_sub)
referral = Referral(
referrer_id=rc.user_id,
referred_id=new_user_id,
code=code,
reward_days=reward_days,
)
db.add(referral)
await db.flush()
@router.post("/claim")
async def claim_referral(
code: str,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
rc = rc_result.scalar_one_or_none()
if not rc:
raise HTTPException(status_code=404, detail="无效的邀请码")
if str(rc.user_id) == user_id:
raise HTTPException(status_code=400, detail="不能使用自己的邀请码")
existing = await db.execute(select(Referral).where(Referral.referred_id == user_id))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="已经使用过邀请码了")
await apply_referral(code, user_id, db)
await db.commit()
return {"success": True, "reward_days": 15}
+19
View File
@@ -0,0 +1,19 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.v1.deps import get_current_user_id
from app.services.search import SearchService
router = APIRouter()
@router.get("/query")
async def search(
q: str = Query(..., min_length=1, max_length=500),
limit: int = Query(10, ge=1, le=50),
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = SearchService(db)
results = await svc.search(q, limit)
return {"query": q, "results": results}
+17
View File
@@ -0,0 +1,17 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.v1.deps import get_current_user_id
from app.services.usage import UsageService
router = APIRouter()
@router.get("/stats")
async def get_usage_stats(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = UsageService(db)
stats = await svc.get_usage_stats(user_id)
return stats
+2
View File
@@ -100,6 +100,8 @@ class Settings(BaseSettings):
FREE_MAX_PRODUCTS: int = 1
FREE_DAILY_QUOTATIONS: int = 3
TRIAL_DAYS: int = 7
PRO_DAILY_TRANSLATE_CHARS: int = 50000
PRO_DAILY_REPLIES: int = 200
PRO_DAILY_MARKETING: int = 50
+8 -19
View File
@@ -141,27 +141,16 @@ class QuotaMiddleware(BaseHTTPMiddleware):
if method == "GET":
return await call_next(request)
quota_map = {
"/api/v1/translate": {
"free": settings.FREE_DAILY_TRANSLATE_CHARS,
"pro": settings.PRO_DAILY_TRANSLATE_CHARS,
},
"/api/v1/translate/reply": {
"free": settings.FREE_DAILY_REPLIES,
"pro": settings.PRO_DAILY_REPLIES,
},
"/api/v1/marketing": {
"free": settings.FREE_DAILY_MARKETING,
"pro": settings.PRO_DAILY_MARKETING,
},
"/api/v1/quotations": {
"free": settings.FREE_DAILY_QUOTATIONS,
"pro": settings.PRO_DAILY_QUOTATIONS,
},
}
quota_map = [
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
("/api/v1/marketing/generate", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
]
matched_key = None
for prefix, limits in quota_map.items():
for prefix, limits in quota_map:
if path.startswith(prefix):
matched_key = prefix
break
+5 -1
View File
@@ -54,7 +54,7 @@ async def health():
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, certification, invoice
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, certification, invoice, usage, referral, admin_search, search
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
@@ -81,6 +81,10 @@ app.include_router(ai_assistant.router, prefix="/api/v1/ai", tags=["ai-assistant
app.include_router(discovery.router, prefix="/api/v1/discovery", tags=["discovery"])
app.include_router(certification.router, prefix="/api/v1/certification", tags=["certification"])
app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])
if __name__ == "__main__":
+4
View File
@@ -14,6 +14,8 @@ from .system_config import SystemConfig
from .translation_quota import TranslationQuota
from .certification import Certification, CertType, CertStatus
from .invoice import Invoice, InvoiceType, InvoiceStatus
from .referral import ReferralCode, Referral
from .search_provider import SearchProvider
__all__ = [
"User", "Product",
@@ -32,4 +34,6 @@ __all__ = [
"TranslationQuota",
"Certification", "CertType", "CertStatus",
"Invoice", "InvoiceType", "InvoiceStatus",
"ReferralCode", "Referral",
"SearchProvider",
]
+26
View File
@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Boolean
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
from app.database import Base
import uuid
class ReferralCode(Base):
__tablename__ = "referral_codes"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
code = Column(String(20), unique=True, nullable=False, index=True)
created_at = Column(DateTime, default=datetime.utcnow)
class Referral(Base):
__tablename__ = "referrals"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
referrer_id = Column(UUID(as_uuid=True), nullable=False, index=True)
referred_id = Column(UUID(as_uuid=True), nullable=False, unique=True)
code = Column(String(20), nullable=False)
reward_days = Column(Integer, default=15)
status = Column(String(20), default="completed")
created_at = Column(DateTime, default=datetime.utcnow)
+20
View File
@@ -0,0 +1,20 @@
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from datetime import datetime
from app.database import Base
import uuid
class SearchProvider(Base):
__tablename__ = "search_providers"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
provider_type = Column(String(50), nullable=False)
api_key = Column(Text, nullable=True)
api_endpoint = Column(String(500), nullable=True)
extra_config = Column(JSONB, default={})
priority = Column(Integer, default=0)
enabled = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+24 -12
View File
@@ -218,21 +218,32 @@ URL: {company_url}
async def _ai_strategy(self, product: str, market: str) -> Dict[str, Any]:
if not self._ai_available:
return self._template_strategy(product, market)
system = """你是外贸客户发现专家。根据用户的产品和目标市场,分析出潜在买家画像和获取策略
system = """你是外贸客户发现专家。根据用户的产品和目标市场,列出15家有可能采购该产品的潜在公司
请以 JSON 格式返回(不要用 markdown 代码块标记):
{
"buyer_personas": [{"type": "", "description": "", "channels": [], "search_queries": []}],
"strategy": "",
"tips": []
}"""
prompt = f"产品:{product}\n目标市场:{market}\n请分析潜在买家画像和获取策略。"
"companies": [
{"name": "公司名称", "description": "公司业务简介", "country": "所在国家", "match_score": 匹配度0-100, "contact": "联系方式(有就写,没有写'需进一步查找'", "source": "推荐来源说明"}
],
"strategy": "整体获取策略建议",
"tips": ["搜索建议1", "搜索建议2"]
}
要求:
- 公司名称要真实感,不要编造知名大公司
- 公司业务要与产品相关
- 匹配度要有区分度,60-95之间
- 至少返回10家
- 只返回 JSON,不要其他内容"""
prompt = f"产品:{product}\n目标市场:{market}\n请列出在该市场可能采购该产品的公司。"
try:
result = await self.ai.chat(prompt, system_prompt=system)
content = result.get("reply", "")
parsed = self._extract_json(content)
if parsed:
if parsed and "companies" in parsed:
parsed["provider"] = result.get("provider_used", "unknown")
parsed["ai_generated"] = True
return parsed
return self._template_strategy(product, market)
except Exception as e:
@@ -241,13 +252,14 @@ URL: {company_url}
def _template_strategy(self, product: str, market: str) -> Dict[str, Any]:
return {
"buyer_personas": [
{"type": "进口商/批发商", "description": f"从中国进口{product}并在{market}批发的贸易商", "channels": ["LinkedIn", "Google"], "search_queries": [f"{product} importer {market}"]},
{"type": "品牌商/OEM买家", "description": f"{market}售自有品牌{product}公司", "channels": ["LinkedIn", "行业展会"], "search_queries": [f"{product} manufacturer {market}"]},
"companies": [
{"name": f"{product} Importers in {market} (示例)", "description": f"{market}从事{product}进口和批发的贸易商,建议在LinkedIn上搜索相关关键词", "country": market, "match_score": 75, "contact": "需进一步查找", "source": "AI推荐"},
{"name": f"{product} Distributors in {market} (示例)", "description": f"{market}{product}渠道商,建议通过Google搜索关键词", "country": market, "match_score": 70, "contact": "需进一步查找", "source": "AI推荐"},
],
"strategy": f"建议在 LinkedIn 和 Google 搜索 {market}{product} 相关公司",
"tips": ["使用多个搜索词", "找到公司后在 LinkedIn 找决策人"],
"strategy": f"建议在 LinkedIn 和 Google 搜索 {market}{product} 相关公司,使用导入商、批发商、经销商等关键词组合",
"tips": ["使用多个搜索词组合", "找到公司后在 LinkedIn 找决策人", "查看公司网站了解其业务范围"],
"provider": "template",
"ai_generated": True,
}
def _template_analysis(self, url: str) -> Dict[str, Any]:
+42 -3
View File
@@ -14,12 +14,16 @@ logger = logging.getLogger(__name__)
PLANS = {
"free": {"price": 0, "duration_days": None},
"pro": {"price": 99, "duration_days": 30},
"pro_yearly": {"price": 999, "duration_days": 365},
"enterprise": {"price": 399, "duration_days": 30},
"enterprise_yearly": {"price": 3999, "duration_days": 365},
}
PLAN_DESCRIPTIONS = {
"pro": "TradeMate Pro 版会员",
"pro_yearly": "TradeMate Pro 版会员(年付)",
"enterprise": "TradeMate 企业版会员",
"enterprise_yearly": "TradeMate 企业版会员(年付)",
}
@@ -41,6 +45,7 @@ class PaymentService:
"id": "free",
"name": "免费版",
"price": 0,
"period": "month",
"features": [
"1 个产品",
"20 次翻译/天",
@@ -52,6 +57,7 @@ class PaymentService:
"id": "pro",
"name": "Pro 版",
"price": 99,
"period": "month",
"features": [
"10 个产品",
"无限翻译",
@@ -60,19 +66,52 @@ class PaymentService:
"报价单生成",
],
},
{
"id": "pro_yearly",
"name": "Pro 版(年付)",
"price": 999,
"period": "year",
"original_price": 1188,
"features": [
"10 个产品",
"无限翻译",
"50 个客户",
"跟进提醒",
"报价单生成",
"省 ¥189",
],
},
{
"id": "enterprise",
"name": "企业版",
"price": 399,
"period": "month",
"features": [
"无限产品",
"多人协作",
"无限产品/客户",
"团队协作",
"品牌报价单",
"专属语料训练",
"API 接入",
"优先支持",
],
},
]
{
"id": "enterprise_yearly",
"name": "企业版(年付)",
"price": 3999,
"period": "year",
"original_price": 4788,
"features": [
"无限产品/客户",
"团队协作",
"品牌报价单",
"专属语料训练",
"API 接入",
"优先支持",
"省 ¥789",
],
},
],
}
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
+102
View File
@@ -0,0 +1,102 @@
import logging
from typing import List, Dict, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.search_provider import SearchProvider
logger = logging.getLogger(__name__)
IGNORE_DOMAINS = [
"google.com", "facebook.com", "twitter.com", "instagram.com",
"youtube.com", "reddit.com", "amazon.com", "ebay.com",
"wikipedia.org", "linkedin.com", "pinterest.com", "baidu.com",
"bing.com",
]
class SearchService:
def __init__(self, db: AsyncSession):
self.db = db
async def search(self, query: str, limit: int = 10) -> List[Dict[str, str]]:
providers = await self._get_enabled_providers()
for provider in providers:
try:
return await self._search_provider(provider, query, limit)
except Exception as e:
logger.warning(f"Search provider {provider.provider_type} failed: {e}")
return []
async def _get_enabled_providers(self) -> List[SearchProvider]:
result = await self.db.execute(
select(SearchProvider)
.where(SearchProvider.enabled == True)
.order_by(SearchProvider.priority)
)
return list(result.scalars().all())
async def _search_provider(self, provider: SearchProvider, query: str, limit: int) -> List[Dict[str, str]]:
pt = provider.provider_type
if pt == "searxng":
return await searxng_search(provider.api_endpoint, query, limit)
elif pt == "bing":
return await bing_search(provider.api_key, query, limit)
else:
raise ValueError(f"Unknown provider type: {pt}")
async def searxng_search(endpoint: Optional[str], query: str, limit: int) -> List[Dict[str, str]]:
if not endpoint:
raise ValueError("SearXNG endpoint not configured")
import httpx
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
endpoint.rstrip("/") + "/search",
params={"q": query, "format": "json", "language": "zh-CN,en", "categories": "general"},
headers={"User-Agent": "TradeMate/1.0"},
)
if resp.status_code != 200:
raise ValueError(f"SearXNG returned {resp.status_code}")
data = resp.json()
results = []
for item in (data.get("results", []) if isinstance(data, dict) else data):
url = item.get("url", "")
if any(d in url for d in IGNORE_DOMAINS):
continue
results.append({
"title": (item.get("title") or url)[:100],
"url": url.rstrip("/"),
"snippet": (item.get("content") or item.get("snippet") or "")[:200],
})
if len(results) >= limit:
break
return results
async def bing_search(api_key: Optional[str], query: str, limit: int) -> List[Dict[str, str]]:
if not api_key:
raise ValueError("Bing API key not configured")
import httpx
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
"https://api.bing.microsoft.com/v7.0/search",
params={"q": query, "count": min(limit, 50), "mkt": "en-US", "textFormat": "Raw"},
headers={"Ocp-Apim-Subscription-Key": api_key},
)
if resp.status_code != 200:
raise ValueError(f"Bing returned {resp.status_code}")
data = resp.json()
results = []
for item in data.get("webPages", {}).get("value", []):
url = item.get("url", "")
if any(d in url for d in IGNORE_DOMAINS):
continue
results.append({
"title": (item.get("name") or url)[:100],
"url": url.rstrip("/"),
"snippet": (item.get("snippet") or "")[:200],
})
if len(results) >= limit:
break
return results
+169
View File
@@ -0,0 +1,169 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from fastapi import HTTPException, Depends
from datetime import datetime, date
import logging
from app.models import UsageLog, SystemConfig, User, Customer, Product
from app.models.user import User
from app.models.subscription import Subscription
from app.api.v1.deps import get_current_user_id
from app.database import get_db
logger = logging.getLogger(__name__)
TIER_LIMITS_DEFAULT = {
"free": {"translate_chars": 5000, "replies": 20, "marketing": 5, "customers": 5, "products": 1, "quotations": 3},
"pro": {"translate_chars": 50000, "replies": 200, "marketing": 50, "customers": 100, "products": 20, "quotations": 30},
"enterprise": {"translate_chars": 999999999, "replies": 9999, "marketing": 9999, "customers": 99999, "products": 9999, "quotations": 9999},
}
ACTION_MAP = {
"translate": "translate_chars",
"reply": "replies",
"marketing_generate": "marketing",
"create_customer": "customers",
"create_product": "products",
"create_quotation": "quotations",
}
class UsageService:
def __init__(self, db: AsyncSession):
self.db = db
async def get_limits(self, tier: str) -> dict:
config_key = f"{tier}_daily_limits"
result = await self.db.execute(select(SystemConfig).where(SystemConfig.key == config_key))
row = result.scalar_one_or_none()
if row and row.value:
return {**TIER_LIMITS_DEFAULT.get(tier, {}), **row.value}
return dict(TIER_LIMITS_DEFAULT.get(tier, {}))
async def get_tier(self, user_id: str) -> str:
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
return "free"
return user.tier or "free"
async def get_daily_usage(self, user_id: str, action: str) -> int:
today = date.today()
stmt = select(func.count()).where(
UsageLog.user_id == user_id,
UsageLog.action == action,
func.cast(UsageLog.created_at, date) == today,
)
result = await self.db.execute(stmt)
return result.scalar() or 0
async def get_daily_chars(self, user_id: str) -> int:
today = date.today()
stmt = select(func.coalesce(func.sum(
(UsageLog.detail["chars"]).as_integer()
), 0)).where(
UsageLog.user_id == user_id,
UsageLog.action == "translate",
func.cast(UsageLog.created_at, date) == today,
)
result = await self.db.execute(stmt)
return result.scalar() or 0
async def get_total_count(self, user_id: str, model_class) -> int:
stmt = select(func.count()).where(model_class.user_id == user_id)
result = await self.db.execute(stmt)
return result.scalar() or 0
async def check_quota(self, user_id: str, action: str, chars: int = 0) -> tuple[bool, str]:
tier = await self.get_tier(user_id)
limits = await self.get_limits(tier)
limit_key = ACTION_MAP.get(action)
if not limit_key:
return True, ""
limit = limits.get(limit_key, 999999)
if action == "translate":
used = await self.get_daily_chars(user_id)
if used + chars > limit:
remaining = max(0, limit - used)
return False, f"今日翻译字符已达上限({limit}字符),剩余{remaining}字符。升级 Pro 获取更多额度。"
elif action in ("create_customer",):
used = await self.get_total_count(user_id, Customer)
if used >= limit:
return False, f"客户数量已达上限({limit}个)。升级 Pro 获取更多客户管理额度。"
elif action in ("create_product",):
used = await self.get_total_count(user_id, Product)
if used >= limit:
return False, f"产品数量已达上限({limit}个)。升级 Pro 获取更多产品额度。"
else:
used = await self.get_daily_usage(user_id, action)
if used >= limit:
return False, f"今日{action}次数已达上限({limit}次)。升级 Pro 获取更多额度。"
return True, ""
async def record_usage(self, user_id: str, action: str, chars: int = 0, detail: dict = None):
log = UsageLog(
user_id=user_id,
action=action,
detail=detail or {},
)
if chars:
log.detail["chars"] = chars
self.db.add(log)
await self.db.commit()
async def get_usage_stats(self, user_id: str) -> dict:
tier = await self.get_tier(user_id)
limits = await self.get_limits(tier)
trial_days_left = 0
if tier == "pro":
result = await self.db.execute(
select(Subscription).where(
Subscription.user_id == user_id,
Subscription.plan == "pro_trial",
Subscription.status == "active",
)
)
trial_sub = result.scalar_one_or_none()
if trial_sub and trial_sub.expires_at:
remaining = (trial_sub.expires_at - datetime.utcnow()).days
trial_days_left = max(0, remaining)
customer_count = await self.get_total_count(user_id, Customer)
product_count = await self.get_total_count(user_id, Product)
translate_chars = await self.get_daily_chars(user_id)
reply_count = await self.get_daily_usage(user_id, "reply")
marketing_count = await self.get_daily_usage(user_id, "marketing_generate")
quotation_count = await self.get_daily_usage(user_id, "create_quotation")
return {
"tier": tier,
"limits": limits,
"usage": {
"translate_chars": translate_chars,
"replies": reply_count,
"marketing": marketing_count,
"customers": customer_count,
"products": product_count,
"quotations": quotation_count,
},
"trial_days_left": trial_days_left,
}
def require_quota(action: str, chars_field: str = None):
async def _check(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = UsageService(db)
if action == "translate" and chars_field:
raise HTTPException(status_code=400, detail="translate action needs explicit chars check")
ok, msg = await svc.check_quota(user_id, action)
if not ok:
raise HTTPException(status_code=429, detail=msg)
return user_id
return _check