Add landing page, referral system, usage quotas, search API management, and yearly pricing
- Separate workspace landing from login for better UX - Referral system rewards both parties with Pro days - Quota enforcement prevents abuse without breaking endpoints - 7-day free trial with auto-downgrade on expiry - Admin-managed search provider config (SearXNG, Bing) - 15% discount on annual subscriptions - MCP search server wrapping opencode search - Fix discovery module field name mismatch causing 422
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
"""add search_providers table
|
||||
|
||||
Revision ID: 7fe16f1f9962
|
||||
Revises: ecab04cc0e1d
|
||||
Create Date: 2026-05-25 10:18:37.103091
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '7fe16f1f9962'
|
||||
down_revision: Union[str, None] = 'ecab04cc0e1d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('referral_codes',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('code', sa.String(length=20), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_referral_codes_code'), 'referral_codes', ['code'], unique=True)
|
||||
op.create_index(op.f('ix_referral_codes_user_id'), 'referral_codes', ['user_id'], unique=False)
|
||||
op.create_table('referrals',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('referrer_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('referred_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('code', sa.String(length=20), nullable=False),
|
||||
sa.Column('reward_days', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('referred_id')
|
||||
)
|
||||
op.create_index(op.f('ix_referrals_referrer_id'), 'referrals', ['referrer_id'], unique=False)
|
||||
op.create_table('search_providers',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False),
|
||||
sa.Column('provider_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('api_key', sa.Text(), nullable=True),
|
||||
sa.Column('api_endpoint', sa.String(length=500), nullable=True),
|
||||
sa.Column('extra_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('priority', sa.Integer(), nullable=True),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('search_providers')
|
||||
op.drop_index(op.f('ix_referrals_referrer_id'), table_name='referrals')
|
||||
op.drop_table('referrals')
|
||||
op.drop_index(op.f('ix_referral_codes_user_id'), table_name='referral_codes')
|
||||
op.drop_index(op.f('ix_referral_codes_code'), table_name='referral_codes')
|
||||
op.drop_table('referral_codes')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,189 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user
|
||||
from app.models.search_provider import SearchProvider
|
||||
from app.services.search import SearchService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def require_admin(current_user: dict = Depends(get_current_user)) -> dict:
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
return current_user
|
||||
|
||||
|
||||
class ProviderCreate(BaseModel):
|
||||
name: str
|
||||
provider_type: str
|
||||
api_key: Optional[str] = None
|
||||
api_endpoint: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ProviderUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_endpoint: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
priority: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
@router.get("/search-providers")
|
||||
async def list_providers(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=100),
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(SearchProvider).order_by(SearchProvider.priority).offset((page - 1) * size).limit(size)
|
||||
)
|
||||
providers = result.scalars().all()
|
||||
total_result = await db.execute(select(SearchProvider))
|
||||
total = len(total_result.scalars().all())
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"provider_type": p.provider_type,
|
||||
"api_key": p.api_key[:8] + "..." if p.api_key and len(p.api_key) > 8 else p.api_key,
|
||||
"api_endpoint": p.api_endpoint,
|
||||
"extra_config": p.extra_config,
|
||||
"priority": p.priority,
|
||||
"enabled": p.enabled,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
for p in providers
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/search-providers")
|
||||
async def create_provider(
|
||||
data: ProviderCreate,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
provider = SearchProvider(
|
||||
name=data.name,
|
||||
provider_type=data.provider_type,
|
||||
api_key=data.api_key,
|
||||
api_endpoint=data.api_endpoint,
|
||||
extra_config=data.extra_config or {},
|
||||
priority=data.priority,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(provider)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": str(provider.id),
|
||||
"name": provider.name,
|
||||
"provider_type": provider.provider_type,
|
||||
"message": "Provider created",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/search-providers/{provider_id}")
|
||||
async def get_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"provider_type": p.provider_type,
|
||||
"api_key": p.api_key,
|
||||
"api_endpoint": p.api_endpoint,
|
||||
"extra_config": p.extra_config,
|
||||
"priority": p.priority,
|
||||
"enabled": p.enabled,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/search-providers/{provider_id}")
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
data: ProviderUpdate,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
if data.name is not None:
|
||||
p.name = data.name
|
||||
if data.api_key is not None:
|
||||
p.api_key = data.api_key
|
||||
if data.api_endpoint is not None:
|
||||
p.api_endpoint = data.api_endpoint
|
||||
if data.extra_config is not None:
|
||||
p.extra_config = data.extra_config
|
||||
if data.priority is not None:
|
||||
p.priority = data.priority
|
||||
if data.enabled is not None:
|
||||
p.enabled = data.enabled
|
||||
await db.flush()
|
||||
return {"message": "Provider updated"}
|
||||
|
||||
|
||||
@router.delete("/search-providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(delete(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {"message": "Provider deleted"}
|
||||
|
||||
|
||||
@router.post("/search-providers/{provider_id}/test")
|
||||
async def test_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
try:
|
||||
svc = SearchService(db)
|
||||
results = await svc._search_provider(p, "test", 3)
|
||||
return {"success": True, "results": results}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
def _validate_uuid(uuid_str: str):
|
||||
import uuid
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid UUID")
|
||||
@@ -10,6 +10,11 @@ from app.core.security import hash_password, verify_password, create_access_toke
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime, timedelta
|
||||
from app.services.admin import AdminService
|
||||
from app.models.subscription import Subscription
|
||||
from app.api.v1.referral import apply_referral
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -18,6 +23,7 @@ class RegisterRequest(BaseModel):
|
||||
phone: str
|
||||
password: str
|
||||
username: str = ""
|
||||
ref_code: str = ""
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
@@ -47,11 +53,28 @@ async def register(data: RegisterRequest, request: Request, db: AsyncSession = D
|
||||
phone=data.phone,
|
||||
username=data.username or data.phone,
|
||||
password_hash=hash_password(data.password),
|
||||
tier="free",
|
||||
tier="pro",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
trial_end = datetime.utcnow() + timedelta(days=settings.TRIAL_DAYS)
|
||||
sub = Subscription(
|
||||
user_id=user.id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=trial_end,
|
||||
)
|
||||
db.add(sub)
|
||||
|
||||
if data.ref_code:
|
||||
try:
|
||||
from app.api.v1.referral import do_claim_referral
|
||||
await do_claim_referral(data.ref_code, str(user.id), db)
|
||||
except Exception as e:
|
||||
logger.warning(f"Referral claim failed: {e}")
|
||||
|
||||
client_ip = request.client.host if request.client else None
|
||||
await AdminService(db).log_usage(str(user.id), "user.register", {"phone": data.phone}, ip=client_ip)
|
||||
|
||||
@@ -89,6 +112,20 @@ async def login(
|
||||
client_ip = request.client.host if request.client else None
|
||||
await AdminService(db).log_usage(str(user.id), "user.login", {"login_id": login_id}, ip=client_ip)
|
||||
|
||||
if user.tier == "pro":
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user.id,
|
||||
Subscription.plan == "pro_trial",
|
||||
Subscription.status == "active",
|
||||
)
|
||||
)
|
||||
trial_sub = sub_result.scalar_one_or_none()
|
||||
if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
|
||||
trial_sub.status = "expired"
|
||||
user.tier = "free"
|
||||
await db.flush()
|
||||
|
||||
return LoginResponse(
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
@@ -178,6 +215,20 @@ async def get_me(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
trial_days_left = 0
|
||||
if user.tier == "pro":
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user.id,
|
||||
Subscription.plan == "pro_trial",
|
||||
Subscription.status == "active",
|
||||
)
|
||||
)
|
||||
trial_sub = sub_result.scalar_one_or_none()
|
||||
if trial_sub and trial_sub.expires_at:
|
||||
remaining = (trial_sub.expires_at - datetime.utcnow()).days
|
||||
trial_days_left = max(0, remaining)
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
@@ -186,6 +237,7 @@ async def get_me(
|
||||
"role": user.role,
|
||||
"settings": user.settings,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"trial_days_left": trial_days_left,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.database import get_db
|
||||
from app.services.customer import CustomerService
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
from app.services.import_service import import_service
|
||||
from app.services.usage import UsageService
|
||||
from app.services import export
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
@@ -98,8 +99,13 @@ async def create_customer(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
usage = UsageService(db)
|
||||
ok, msg = await usage.check_quota(user_id, "create_customer")
|
||||
if not ok:
|
||||
raise HTTPException(status_code=429, detail=msg)
|
||||
service = CustomerService(db)
|
||||
customer = await service.create_customer(user_id, data)
|
||||
await usage.record_usage(user_id, "create_customer")
|
||||
return customer
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional, List
|
||||
from app.database import get_db
|
||||
from app.services.product import ProductService
|
||||
from app.services import export
|
||||
from app.services.usage import UsageService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from pydantic import BaseModel
|
||||
import io
|
||||
@@ -175,8 +176,13 @@ async def create_product(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
usage = UsageService(db)
|
||||
ok, msg = await usage.check_quota(user_id, "create_product")
|
||||
if not ok:
|
||||
raise HTTPException(status_code=429, detail=msg)
|
||||
service = ProductService(db)
|
||||
product = await service.create_product(user_id, data.dict())
|
||||
await usage.record_usage(user_id, "create_product")
|
||||
return product
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.models.referral import ReferralCode, Referral
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import secrets
|
||||
import string
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_code() -> str:
|
||||
return "TM" + "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(6))
|
||||
|
||||
|
||||
@router.post("/code")
|
||||
async def get_or_create_code(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(ReferralCode).where(ReferralCode.user_id == user_id))
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
return {"code": existing.code, "url": f"/workspace/?ref={existing.code}"}
|
||||
|
||||
code = generate_code()
|
||||
while True:
|
||||
check = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
if not check.scalar_one_or_none():
|
||||
break
|
||||
code = generate_code()
|
||||
|
||||
rc = ReferralCode(user_id=user_id, code=code)
|
||||
db.add(rc)
|
||||
await db.commit()
|
||||
return {"code": code, "url": f"/workspace/?ref={code}"}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_referral_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Referral).where(Referral.referrer_id == user_id))
|
||||
referrals = result.scalars().all()
|
||||
total_reward_days = sum(r.reward_days for r in referrals if r.status == "completed")
|
||||
return {
|
||||
"total_referrals": len(referrals),
|
||||
"completed": sum(1 for r in referrals if r.status == "completed"),
|
||||
"total_reward_days": total_reward_days,
|
||||
}
|
||||
|
||||
|
||||
async def apply_referral(code: str, new_user_id: str, db: AsyncSession):
|
||||
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
rc = rc_result.scalar_one_or_none()
|
||||
if not rc:
|
||||
return
|
||||
|
||||
if str(rc.user_id) == new_user_id:
|
||||
return
|
||||
|
||||
existing = await db.execute(select(Referral).where(Referral.referred_id == new_user_id))
|
||||
if existing.scalar_one_or_none():
|
||||
return
|
||||
|
||||
reward_days = 15
|
||||
|
||||
referrer_sub = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == rc.user_id,
|
||||
Subscription.status == "active",
|
||||
).order_by(Subscription.created_at.desc()).limit(1)
|
||||
)
|
||||
referrer_sub_row = referrer_sub.scalar_one_or_none()
|
||||
if referrer_sub_row:
|
||||
old_expiry = referrer_sub_row.expires_at or datetime.utcnow()
|
||||
referrer_sub_row.expires_at = old_expiry + timedelta(days=reward_days)
|
||||
else:
|
||||
new_sub = Subscription(
|
||||
user_id=rc.user_id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(days=reward_days),
|
||||
)
|
||||
db.add(new_sub)
|
||||
user_result = await db.execute(select(User).where(User.id == rc.user_id))
|
||||
u = user_result.scalar_one_or_none()
|
||||
if u and u.tier == "free":
|
||||
u.tier = "pro"
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == new_user_id))
|
||||
ru = user_result.scalar_one_or_none()
|
||||
if ru and ru.tier in ("free", "guest"):
|
||||
ru.tier = "pro"
|
||||
ref_sub = Subscription(
|
||||
user_id=new_user_id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(days=reward_days),
|
||||
)
|
||||
db.add(ref_sub)
|
||||
|
||||
referral = Referral(
|
||||
referrer_id=rc.user_id,
|
||||
referred_id=new_user_id,
|
||||
code=code,
|
||||
reward_days=reward_days,
|
||||
)
|
||||
db.add(referral)
|
||||
await db.flush()
|
||||
|
||||
|
||||
@router.post("/claim")
|
||||
async def claim_referral(
|
||||
code: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
rc = rc_result.scalar_one_or_none()
|
||||
if not rc:
|
||||
raise HTTPException(status_code=404, detail="无效的邀请码")
|
||||
|
||||
if str(rc.user_id) == user_id:
|
||||
raise HTTPException(status_code=400, detail="不能使用自己的邀请码")
|
||||
|
||||
existing = await db.execute(select(Referral).where(Referral.referred_id == user_id))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="已经使用过邀请码了")
|
||||
|
||||
await apply_referral(code, user_id, db)
|
||||
await db.commit()
|
||||
return {"success": True, "reward_days": 15}
|
||||
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.search import SearchService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/query")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=1, max_length=500),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = SearchService(db)
|
||||
results = await svc.search(q, limit)
|
||||
return {"query": q, "results": results}
|
||||
@@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.usage import UsageService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_usage_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = UsageService(db)
|
||||
stats = await svc.get_usage_stats(user_id)
|
||||
return stats
|
||||
@@ -100,6 +100,8 @@ class Settings(BaseSettings):
|
||||
FREE_MAX_PRODUCTS: int = 1
|
||||
FREE_DAILY_QUOTATIONS: int = 3
|
||||
|
||||
TRIAL_DAYS: int = 7
|
||||
|
||||
PRO_DAILY_TRANSLATE_CHARS: int = 50000
|
||||
PRO_DAILY_REPLIES: int = 200
|
||||
PRO_DAILY_MARKETING: int = 50
|
||||
|
||||
@@ -141,27 +141,16 @@ class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
if method == "GET":
|
||||
return await call_next(request)
|
||||
|
||||
quota_map = {
|
||||
"/api/v1/translate": {
|
||||
"free": settings.FREE_DAILY_TRANSLATE_CHARS,
|
||||
"pro": settings.PRO_DAILY_TRANSLATE_CHARS,
|
||||
},
|
||||
"/api/v1/translate/reply": {
|
||||
"free": settings.FREE_DAILY_REPLIES,
|
||||
"pro": settings.PRO_DAILY_REPLIES,
|
||||
},
|
||||
"/api/v1/marketing": {
|
||||
"free": settings.FREE_DAILY_MARKETING,
|
||||
"pro": settings.PRO_DAILY_MARKETING,
|
||||
},
|
||||
"/api/v1/quotations": {
|
||||
"free": settings.FREE_DAILY_QUOTATIONS,
|
||||
"pro": settings.PRO_DAILY_QUOTATIONS,
|
||||
},
|
||||
}
|
||||
quota_map = [
|
||||
("/api/v1/translate/reply", {"free": settings.FREE_DAILY_REPLIES, "pro": settings.PRO_DAILY_REPLIES}),
|
||||
("/api/v1/translate", {"free": settings.FREE_DAILY_TRANSLATE_CHARS, "pro": settings.PRO_DAILY_TRANSLATE_CHARS}),
|
||||
("/api/v1/marketing/generate", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
|
||||
("/api/v1/marketing", {"free": settings.FREE_DAILY_MARKETING, "pro": settings.PRO_DAILY_MARKETING}),
|
||||
("/api/v1/quotations", {"free": settings.FREE_DAILY_QUOTATIONS, "pro": settings.PRO_DAILY_QUOTATIONS}),
|
||||
]
|
||||
|
||||
matched_key = None
|
||||
for prefix, limits in quota_map.items():
|
||||
for prefix, limits in quota_map:
|
||||
if path.startswith(prefix):
|
||||
matched_key = prefix
|
||||
break
|
||||
|
||||
+5
-1
@@ -54,7 +54,7 @@ async def health():
|
||||
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
||||
|
||||
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, certification, invoice
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, certification, invoice, usage, referral, admin_search, search
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
||||
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
||||
@@ -81,6 +81,10 @@ app.include_router(ai_assistant.router, prefix="/api/v1/ai", tags=["ai-assistant
|
||||
app.include_router(discovery.router, prefix="/api/v1/discovery", tags=["discovery"])
|
||||
app.include_router(certification.router, prefix="/api/v1/certification", tags=["certification"])
|
||||
app.include_router(invoice.router, prefix="/api/v1/invoices", tags=["invoices"])
|
||||
app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
|
||||
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
|
||||
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
|
||||
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -14,6 +14,8 @@ from .system_config import SystemConfig
|
||||
from .translation_quota import TranslationQuota
|
||||
from .certification import Certification, CertType, CertStatus
|
||||
from .invoice import Invoice, InvoiceType, InvoiceStatus
|
||||
from .referral import ReferralCode, Referral
|
||||
from .search_provider import SearchProvider
|
||||
|
||||
__all__ = [
|
||||
"User", "Product",
|
||||
@@ -32,4 +34,6 @@ __all__ = [
|
||||
"TranslationQuota",
|
||||
"Certification", "CertType", "CertStatus",
|
||||
"Invoice", "InvoiceType", "InvoiceStatus",
|
||||
"ReferralCode", "Referral",
|
||||
"SearchProvider",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class ReferralCode(Base):
|
||||
__tablename__ = "referral_codes"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
code = Column(String(20), unique=True, nullable=False, index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
class Referral(Base):
|
||||
__tablename__ = "referrals"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
referrer_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
referred_id = Column(UUID(as_uuid=True), nullable=False, unique=True)
|
||||
code = Column(String(20), nullable=False)
|
||||
reward_days = Column(Integer, default=15)
|
||||
status = Column(String(20), default="completed")
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,20 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class SearchProvider(Base):
|
||||
__tablename__ = "search_providers"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(100), nullable=False)
|
||||
provider_type = Column(String(50), nullable=False)
|
||||
api_key = Column(Text, nullable=True)
|
||||
api_endpoint = Column(String(500), nullable=True)
|
||||
extra_config = Column(JSONB, default={})
|
||||
priority = Column(Integer, default=0)
|
||||
enabled = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -218,21 +218,32 @@ URL: {company_url}
|
||||
async def _ai_strategy(self, product: str, market: str) -> Dict[str, Any]:
|
||||
if not self._ai_available:
|
||||
return self._template_strategy(product, market)
|
||||
system = """你是外贸客户发现专家。根据用户的产品和目标市场,分析出潜在买家画像和获取策略。
|
||||
system = """你是外贸客户发现专家。根据用户的产品和目标市场,列出15家有可能采购该产品的潜在公司。
|
||||
|
||||
请以 JSON 格式返回(不要用 markdown 代码块标记):
|
||||
{
|
||||
"buyer_personas": [{"type": "", "description": "", "channels": [], "search_queries": []}],
|
||||
"strategy": "",
|
||||
"tips": []
|
||||
}"""
|
||||
prompt = f"产品:{product}\n目标市场:{market}\n请分析潜在买家画像和获取策略。"
|
||||
"companies": [
|
||||
{"name": "公司名称", "description": "公司业务简介", "country": "所在国家", "match_score": 匹配度0-100, "contact": "联系方式(有就写,没有写'需进一步查找')", "source": "推荐来源说明"}
|
||||
],
|
||||
"strategy": "整体获取策略建议",
|
||||
"tips": ["搜索建议1", "搜索建议2"]
|
||||
}
|
||||
|
||||
要求:
|
||||
- 公司名称要真实感,不要编造知名大公司
|
||||
- 公司业务要与产品相关
|
||||
- 匹配度要有区分度,60-95之间
|
||||
- 至少返回10家
|
||||
- 只返回 JSON,不要其他内容"""
|
||||
|
||||
prompt = f"产品:{product}\n目标市场:{market}\n请列出在该市场可能采购该产品的公司。"
|
||||
try:
|
||||
result = await self.ai.chat(prompt, system_prompt=system)
|
||||
content = result.get("reply", "")
|
||||
parsed = self._extract_json(content)
|
||||
if parsed:
|
||||
if parsed and "companies" in parsed:
|
||||
parsed["provider"] = result.get("provider_used", "unknown")
|
||||
parsed["ai_generated"] = True
|
||||
return parsed
|
||||
return self._template_strategy(product, market)
|
||||
except Exception as e:
|
||||
@@ -241,13 +252,14 @@ URL: {company_url}
|
||||
|
||||
def _template_strategy(self, product: str, market: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"buyer_personas": [
|
||||
{"type": "进口商/批发商", "description": f"从中国进口{product}并在{market}批发的贸易商", "channels": ["LinkedIn", "Google"], "search_queries": [f"{product} importer {market}"]},
|
||||
{"type": "品牌商/OEM买家", "description": f"在{market}销售自有品牌{product}的公司", "channels": ["LinkedIn", "行业展会"], "search_queries": [f"{product} manufacturer {market}"]},
|
||||
"companies": [
|
||||
{"name": f"{product} Importers in {market} (示例)", "description": f"在{market}从事{product}进口和批发的贸易商,建议在LinkedIn上搜索相关关键词", "country": market, "match_score": 75, "contact": "需进一步查找", "source": "AI推荐"},
|
||||
{"name": f"{product} Distributors in {market} (示例)", "description": f"在{market}分销{product}的渠道商,建议通过Google搜索关键词", "country": market, "match_score": 70, "contact": "需进一步查找", "source": "AI推荐"},
|
||||
],
|
||||
"strategy": f"建议在 LinkedIn 和 Google 搜索 {market} 的 {product} 相关公司",
|
||||
"tips": ["使用多个搜索词", "找到公司后在 LinkedIn 找决策人"],
|
||||
"strategy": f"建议在 LinkedIn 和 Google 搜索 {market} 的 {product} 相关公司,使用导入商、批发商、经销商等关键词组合",
|
||||
"tips": ["使用多个搜索词组合", "找到公司后在 LinkedIn 找决策人", "查看公司网站了解其业务范围"],
|
||||
"provider": "template",
|
||||
"ai_generated": True,
|
||||
}
|
||||
|
||||
def _template_analysis(self, url: str) -> Dict[str, Any]:
|
||||
|
||||
@@ -14,12 +14,16 @@ logger = logging.getLogger(__name__)
|
||||
PLANS = {
|
||||
"free": {"price": 0, "duration_days": None},
|
||||
"pro": {"price": 99, "duration_days": 30},
|
||||
"pro_yearly": {"price": 999, "duration_days": 365},
|
||||
"enterprise": {"price": 399, "duration_days": 30},
|
||||
"enterprise_yearly": {"price": 3999, "duration_days": 365},
|
||||
}
|
||||
|
||||
PLAN_DESCRIPTIONS = {
|
||||
"pro": "TradeMate Pro 版会员",
|
||||
"pro_yearly": "TradeMate Pro 版会员(年付)",
|
||||
"enterprise": "TradeMate 企业版会员",
|
||||
"enterprise_yearly": "TradeMate 企业版会员(年付)",
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +45,7 @@ class PaymentService:
|
||||
"id": "free",
|
||||
"name": "免费版",
|
||||
"price": 0,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"1 个产品",
|
||||
"20 次翻译/天",
|
||||
@@ -52,6 +57,7 @@ class PaymentService:
|
||||
"id": "pro",
|
||||
"name": "Pro 版",
|
||||
"price": 99,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
@@ -60,19 +66,52 @@ class PaymentService:
|
||||
"报价单生成",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro_yearly",
|
||||
"name": "Pro 版(年付)",
|
||||
"price": 999,
|
||||
"period": "year",
|
||||
"original_price": 1188,
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
"省 ¥189",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise",
|
||||
"name": "企业版",
|
||||
"price": 399,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"无限产品",
|
||||
"多人协作",
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
],
|
||||
},
|
||||
]
|
||||
{
|
||||
"id": "enterprise_yearly",
|
||||
"name": "企业版(年付)",
|
||||
"price": 3999,
|
||||
"period": "year",
|
||||
"original_price": 4788,
|
||||
"features": [
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
"省 ¥789",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.search_provider import SearchProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IGNORE_DOMAINS = [
|
||||
"google.com", "facebook.com", "twitter.com", "instagram.com",
|
||||
"youtube.com", "reddit.com", "amazon.com", "ebay.com",
|
||||
"wikipedia.org", "linkedin.com", "pinterest.com", "baidu.com",
|
||||
"bing.com",
|
||||
]
|
||||
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def search(self, query: str, limit: int = 10) -> List[Dict[str, str]]:
|
||||
providers = await self._get_enabled_providers()
|
||||
for provider in providers:
|
||||
try:
|
||||
return await self._search_provider(provider, query, limit)
|
||||
except Exception as e:
|
||||
logger.warning(f"Search provider {provider.provider_type} failed: {e}")
|
||||
return []
|
||||
|
||||
async def _get_enabled_providers(self) -> List[SearchProvider]:
|
||||
result = await self.db.execute(
|
||||
select(SearchProvider)
|
||||
.where(SearchProvider.enabled == True)
|
||||
.order_by(SearchProvider.priority)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def _search_provider(self, provider: SearchProvider, query: str, limit: int) -> List[Dict[str, str]]:
|
||||
pt = provider.provider_type
|
||||
if pt == "searxng":
|
||||
return await searxng_search(provider.api_endpoint, query, limit)
|
||||
elif pt == "bing":
|
||||
return await bing_search(provider.api_key, query, limit)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {pt}")
|
||||
|
||||
|
||||
async def searxng_search(endpoint: Optional[str], query: str, limit: int) -> List[Dict[str, str]]:
|
||||
if not endpoint:
|
||||
raise ValueError("SearXNG endpoint not configured")
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
endpoint.rstrip("/") + "/search",
|
||||
params={"q": query, "format": "json", "language": "zh-CN,en", "categories": "general"},
|
||||
headers={"User-Agent": "TradeMate/1.0"},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"SearXNG returned {resp.status_code}")
|
||||
data = resp.json()
|
||||
results = []
|
||||
for item in (data.get("results", []) if isinstance(data, dict) else data):
|
||||
url = item.get("url", "")
|
||||
if any(d in url for d in IGNORE_DOMAINS):
|
||||
continue
|
||||
results.append({
|
||||
"title": (item.get("title") or url)[:100],
|
||||
"url": url.rstrip("/"),
|
||||
"snippet": (item.get("content") or item.get("snippet") or "")[:200],
|
||||
})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
|
||||
async def bing_search(api_key: Optional[str], query: str, limit: int) -> List[Dict[str, str]]:
|
||||
if not api_key:
|
||||
raise ValueError("Bing API key not configured")
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.bing.microsoft.com/v7.0/search",
|
||||
params={"q": query, "count": min(limit, 50), "mkt": "en-US", "textFormat": "Raw"},
|
||||
headers={"Ocp-Apim-Subscription-Key": api_key},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"Bing returned {resp.status_code}")
|
||||
data = resp.json()
|
||||
results = []
|
||||
for item in data.get("webPages", {}).get("value", []):
|
||||
url = item.get("url", "")
|
||||
if any(d in url for d in IGNORE_DOMAINS):
|
||||
continue
|
||||
results.append({
|
||||
"title": (item.get("name") or url)[:100],
|
||||
"url": url.rstrip("/"),
|
||||
"snippet": (item.get("snippet") or "")[:200],
|
||||
})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from fastapi import HTTPException, Depends
|
||||
from datetime import datetime, date
|
||||
import logging
|
||||
|
||||
from app.models import UsageLog, SystemConfig, User, Customer, Product
|
||||
from app.models.user import User
|
||||
from app.models.subscription import Subscription
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.database import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TIER_LIMITS_DEFAULT = {
|
||||
"free": {"translate_chars": 5000, "replies": 20, "marketing": 5, "customers": 5, "products": 1, "quotations": 3},
|
||||
"pro": {"translate_chars": 50000, "replies": 200, "marketing": 50, "customers": 100, "products": 20, "quotations": 30},
|
||||
"enterprise": {"translate_chars": 999999999, "replies": 9999, "marketing": 9999, "customers": 99999, "products": 9999, "quotations": 9999},
|
||||
}
|
||||
|
||||
ACTION_MAP = {
|
||||
"translate": "translate_chars",
|
||||
"reply": "replies",
|
||||
"marketing_generate": "marketing",
|
||||
"create_customer": "customers",
|
||||
"create_product": "products",
|
||||
"create_quotation": "quotations",
|
||||
}
|
||||
|
||||
|
||||
class UsageService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_limits(self, tier: str) -> dict:
|
||||
config_key = f"{tier}_daily_limits"
|
||||
result = await self.db.execute(select(SystemConfig).where(SystemConfig.key == config_key))
|
||||
row = result.scalar_one_or_none()
|
||||
if row and row.value:
|
||||
return {**TIER_LIMITS_DEFAULT.get(tier, {}), **row.value}
|
||||
return dict(TIER_LIMITS_DEFAULT.get(tier, {}))
|
||||
|
||||
async def get_tier(self, user_id: str) -> str:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
return "free"
|
||||
return user.tier or "free"
|
||||
|
||||
async def get_daily_usage(self, user_id: str, action: str) -> int:
|
||||
today = date.today()
|
||||
stmt = select(func.count()).where(
|
||||
UsageLog.user_id == user_id,
|
||||
UsageLog.action == action,
|
||||
func.cast(UsageLog.created_at, date) == today,
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_daily_chars(self, user_id: str) -> int:
|
||||
today = date.today()
|
||||
stmt = select(func.coalesce(func.sum(
|
||||
(UsageLog.detail["chars"]).as_integer()
|
||||
), 0)).where(
|
||||
UsageLog.user_id == user_id,
|
||||
UsageLog.action == "translate",
|
||||
func.cast(UsageLog.created_at, date) == today,
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_total_count(self, user_id: str, model_class) -> int:
|
||||
stmt = select(func.count()).where(model_class.user_id == user_id)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def check_quota(self, user_id: str, action: str, chars: int = 0) -> tuple[bool, str]:
|
||||
tier = await self.get_tier(user_id)
|
||||
limits = await self.get_limits(tier)
|
||||
limit_key = ACTION_MAP.get(action)
|
||||
if not limit_key:
|
||||
return True, ""
|
||||
|
||||
limit = limits.get(limit_key, 999999)
|
||||
|
||||
if action == "translate":
|
||||
used = await self.get_daily_chars(user_id)
|
||||
if used + chars > limit:
|
||||
remaining = max(0, limit - used)
|
||||
return False, f"今日翻译字符已达上限({limit}字符),剩余{remaining}字符。升级 Pro 获取更多额度。"
|
||||
elif action in ("create_customer",):
|
||||
used = await self.get_total_count(user_id, Customer)
|
||||
if used >= limit:
|
||||
return False, f"客户数量已达上限({limit}个)。升级 Pro 获取更多客户管理额度。"
|
||||
elif action in ("create_product",):
|
||||
used = await self.get_total_count(user_id, Product)
|
||||
if used >= limit:
|
||||
return False, f"产品数量已达上限({limit}个)。升级 Pro 获取更多产品额度。"
|
||||
else:
|
||||
used = await self.get_daily_usage(user_id, action)
|
||||
if used >= limit:
|
||||
return False, f"今日{action}次数已达上限({limit}次)。升级 Pro 获取更多额度。"
|
||||
|
||||
return True, ""
|
||||
|
||||
async def record_usage(self, user_id: str, action: str, chars: int = 0, detail: dict = None):
|
||||
log = UsageLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
detail=detail or {},
|
||||
)
|
||||
if chars:
|
||||
log.detail["chars"] = chars
|
||||
self.db.add(log)
|
||||
await self.db.commit()
|
||||
|
||||
async def get_usage_stats(self, user_id: str) -> dict:
|
||||
tier = await self.get_tier(user_id)
|
||||
limits = await self.get_limits(tier)
|
||||
|
||||
trial_days_left = 0
|
||||
if tier == "pro":
|
||||
result = await self.db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user_id,
|
||||
Subscription.plan == "pro_trial",
|
||||
Subscription.status == "active",
|
||||
)
|
||||
)
|
||||
trial_sub = result.scalar_one_or_none()
|
||||
if trial_sub and trial_sub.expires_at:
|
||||
remaining = (trial_sub.expires_at - datetime.utcnow()).days
|
||||
trial_days_left = max(0, remaining)
|
||||
|
||||
customer_count = await self.get_total_count(user_id, Customer)
|
||||
product_count = await self.get_total_count(user_id, Product)
|
||||
translate_chars = await self.get_daily_chars(user_id)
|
||||
reply_count = await self.get_daily_usage(user_id, "reply")
|
||||
marketing_count = await self.get_daily_usage(user_id, "marketing_generate")
|
||||
quotation_count = await self.get_daily_usage(user_id, "create_quotation")
|
||||
|
||||
return {
|
||||
"tier": tier,
|
||||
"limits": limits,
|
||||
"usage": {
|
||||
"translate_chars": translate_chars,
|
||||
"replies": reply_count,
|
||||
"marketing": marketing_count,
|
||||
"customers": customer_count,
|
||||
"products": product_count,
|
||||
"quotations": quotation_count,
|
||||
},
|
||||
"trial_days_left": trial_days_left,
|
||||
}
|
||||
|
||||
|
||||
def require_quota(action: str, chars_field: str = None):
|
||||
async def _check(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = UsageService(db)
|
||||
if action == "translate" and chars_field:
|
||||
raise HTTPException(status_code=400, detail="translate action needs explicit chars check")
|
||||
ok, msg = await svc.check_quota(user_id, action)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=429, detail=msg)
|
||||
return user_id
|
||||
return _check
|
||||
Reference in New Issue
Block a user