feat: credit-based billing system

- New DB models: credit_packages, subscription_plans, user_credits, credit_consumptions, credit_purchases
- CreditService: balance, deduct, add_credits, grant_free_trial, history
- User API: /api/v1/credits/* (balance/history/packages/purchase/subscribe)
- Admin API: /api/v1/admin/credit-* (CRUD packages/plans, user credits, consumptions)
- PaymentService.create_credit_order + handle_callback for credit purchases
- Credit deduction on: discovery, translate, marketing, ai_chat, followup
- Free trial 30 credits on registration
- Documentation: docs/CREDIT_SYSTEM.md
This commit is contained in:
TradeMate Dev
2026-06-12 10:39:45 +08:00
parent 5d895ae12c
commit 2a107a42f3
21 changed files with 1528 additions and 33 deletions
@@ -0,0 +1,101 @@
"""add credit system tables (packages, plans, user credits, consumptions, purchases)
Revision ID: add_credit_system
Revises: add_perf_indexes
Create Date: 2026-06-12
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
revision = "add_credit_system"
down_revision = "add_perf_indexes"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"credit_packages",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("name_en", sa.String(100), nullable=False),
sa.Column("credits", sa.Integer, nullable=False),
sa.Column("price", sa.Float, nullable=False),
sa.Column("price_usd", sa.Float, nullable=True),
sa.Column("original_price", sa.Float, nullable=True),
sa.Column("is_active", sa.Boolean, default=True),
sa.Column("sort_order", sa.Integer, default=0),
sa.Column("created_at", sa.DateTime, default=sa.func.now()),
sa.Column("updated_at", sa.DateTime, default=sa.func.now()),
)
op.create_table(
"subscription_plans",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("name_en", sa.String(100), nullable=False),
sa.Column("credits_per_month", sa.Integer, nullable=False),
sa.Column("price", sa.Float, nullable=False),
sa.Column("price_usd", sa.Float, nullable=True),
sa.Column("duration_days", sa.Integer, default=30),
sa.Column("is_active", sa.Boolean, default=True),
sa.Column("sort_order", sa.Integer, default=0),
sa.Column("created_at", sa.DateTime, default=sa.func.now()),
sa.Column("updated_at", sa.DateTime, default=sa.func.now()),
)
op.create_table(
"user_credits",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False, unique=True, index=True),
sa.Column("balance", sa.Float, default=0),
sa.Column("total_purchased", sa.Float, default=0),
sa.Column("total_used", sa.Float, default=0),
sa.Column("subscription_plan_id", UUID(as_uuid=True), sa.ForeignKey("subscription_plans.id"), nullable=True),
sa.Column("subscription_expires_at", sa.DateTime, nullable=True),
sa.Column("subscription_auto_renew", sa.Boolean, default=False),
sa.Column("free_trial_used", sa.Boolean, default=False),
sa.Column("daily_translate_chars", sa.Integer, default=0),
sa.Column("daily_translate_date", sa.Date, nullable=True),
sa.Column("created_at", sa.DateTime, default=sa.func.now()),
sa.Column("updated_at", sa.DateTime, default=sa.func.now()),
)
op.create_table(
"credit_consumptions",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("result_type", sa.String(50), nullable=False),
sa.Column("reference_id", UUID(as_uuid=True), nullable=True),
sa.Column("credits_change", sa.Float, nullable=False),
sa.Column("balance_after", sa.Float, nullable=False),
sa.Column("source", sa.String(30), nullable=False),
sa.Column("description", sa.String(500), nullable=True),
sa.Column("metadata", JSONB, nullable=True),
sa.Column("created_at", sa.DateTime, default=sa.func.now()),
)
op.create_index("idx_credit_consumptions_user", "credit_consumptions", ["user_id", sa.text("created_at DESC")])
op.create_index("idx_credit_consumptions_type", "credit_consumptions", ["result_type"])
op.create_table(
"credit_purchases",
sa.Column("id", UUID(as_uuid=True), primary_key=True),
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False, index=True),
sa.Column("package_id", UUID(as_uuid=True), sa.ForeignKey("credit_packages.id"), nullable=True),
sa.Column("subscription_plan_id", UUID(as_uuid=True), sa.ForeignKey("subscription_plans.id"), nullable=True),
sa.Column("credits", sa.Integer, nullable=False),
sa.Column("amount", sa.Float, nullable=False),
sa.Column("currency", sa.String(3), default="CNY"),
sa.Column("payment_method", sa.String(20), nullable=True),
sa.Column("status", sa.String(20), default="pending"),
sa.Column("payment_transaction_id", UUID(as_uuid=True), sa.ForeignKey("payment_transactions.id"), nullable=True),
sa.Column("created_at", sa.DateTime, default=sa.func.now()),
sa.Column("paid_at", sa.DateTime, nullable=True),
)
op.create_index("idx_credit_purchases_user", "credit_purchases", ["user_id"])
op.create_index("idx_credit_purchases_status", "credit_purchases", ["status"])
def downgrade():
op.drop_table("credit_purchases")
op.drop_table("credit_consumptions")
op.drop_table("user_credits")
op.drop_table("subscription_plans")
op.drop_table("credit_packages")
@@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'add_perf_indexes'
down_revision = 'add_payment_transactions_table'
down_revision = 'add_payment_transactions'
branch_labels = None
depends_on = None
+308
View File
@@ -0,0 +1,308 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.api.v1.admin import require_admin
from app.services.credit import CreditService
from app.models.credit_package import CreditPackage, SubscriptionPlan
from app.models.user_credit import UserCredit
from app.models.user import User
import uuid
router = APIRouter()
class PackageForm(BaseModel):
name: str
name_en: str
credits: int
price: float
price_usd: Optional[float] = None
original_price: Optional[float] = None
is_active: bool = True
sort_order: int = 0
class PlanForm(BaseModel):
name: str
name_en: str
credits_per_month: int
price: float
price_usd: Optional[float] = None
duration_days: int = 30
is_active: bool = True
sort_order: int = 0
class AdjustCreditsForm(BaseModel):
user_id: str
credits: float
reason: str = ""
@router.get("/credit-packages")
async def list_packages(
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(CreditPackage).order_by(CreditPackage.sort_order)
)
return [{
"id": str(p.id),
"name": p.name,
"name_en": p.name_en,
"credits": p.credits,
"price": p.price,
"price_usd": p.price_usd,
"original_price": p.original_price,
"is_active": p.is_active,
"sort_order": p.sort_order,
} for p in result.scalars().all()]
@router.post("/credit-packages")
async def create_package(
data: PackageForm,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
pkg = CreditPackage(**data.model_dump())
db.add(pkg)
await db.flush()
return {"id": str(pkg.id), "status": "ok"}
@router.put("/credit-packages/{pkg_id}")
async def update_package(
pkg_id: str,
data: PackageForm,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
try:
uid = uuid.UUID(pkg_id)
except ValueError:
raise HTTPException(status_code=400, detail="无效ID")
result = await db.execute(select(CreditPackage).where(CreditPackage.id == uid))
pkg = result.scalar_one_or_none()
if not pkg:
raise HTTPException(status_code=404, detail="次数包不存在")
for k, v in data.model_dump().items():
setattr(pkg, k, v)
await db.flush()
return {"status": "ok"}
@router.delete("/credit-packages/{pkg_id}")
async def delete_package(
pkg_id: str,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
try:
uid = uuid.UUID(pkg_id)
except ValueError:
raise HTTPException(status_code=400, detail="无效ID")
result = await db.execute(select(CreditPackage).where(CreditPackage.id == uid))
pkg = result.scalar_one_or_none()
if not pkg:
raise HTTPException(status_code=404, detail="次数包不存在")
await db.delete(pkg)
await db.flush()
return {"status": "ok"}
@router.get("/subscription-plans")
async def list_plans(
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(SubscriptionPlan).order_by(SubscriptionPlan.sort_order)
)
return [{
"id": str(p.id),
"name": p.name,
"name_en": p.name_en,
"credits_per_month": p.credits_per_month,
"price": p.price,
"price_usd": p.price_usd,
"duration_days": p.duration_days,
"is_active": p.is_active,
"sort_order": p.sort_order,
} for p in result.scalars().all()]
@router.post("/subscription-plans")
async def create_plan(
data: PlanForm,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
plan = SubscriptionPlan(**data.model_dump())
db.add(plan)
await db.flush()
return {"id": str(plan.id), "status": "ok"}
@router.put("/subscription-plans/{plan_id}")
async def update_plan(
plan_id: str,
data: PlanForm,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
try:
uid = uuid.UUID(plan_id)
except ValueError:
raise HTTPException(status_code=400, detail="无效ID")
result = await db.execute(select(SubscriptionPlan).where(SubscriptionPlan.id == uid))
plan = result.scalar_one_or_none()
if not plan:
raise HTTPException(status_code=404, detail="订阅套餐不存在")
for k, v in data.model_dump().items():
setattr(plan, k, v)
await db.flush()
return {"status": "ok"}
@router.delete("/subscription-plans/{plan_id}")
async def delete_plan(
plan_id: str,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
try:
uid = uuid.UUID(plan_id)
except ValueError:
raise HTTPException(status_code=400, detail="无效ID")
result = await db.execute(select(SubscriptionPlan).where(SubscriptionPlan.id == uid))
plan = result.scalar_one_or_none()
if not plan:
raise HTTPException(status_code=404, detail="订阅套餐不存在")
await db.delete(plan)
await db.flush()
return {"status": "ok"}
@router.get("/user-credits")
async def list_user_credits(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
offset = (page - 1) * size
result = await db.execute(
select(UserCredit).order_by(UserCredit.updated_at.desc()).offset(offset).limit(size)
)
items = result.scalars().all()
from sqlalchemy import func
count_result = await db.execute(select(func.count(UserCredit.id)))
total = count_result.scalar() or 0
enriched = []
for uc in items:
user_result = await db.execute(select(User).where(User.id == uc.user_id))
user = user_result.scalar_one_or_none()
enriched.append({
"id": str(uc.id),
"user_id": str(uc.user_id),
"username": user.username if user else "N/A",
"balance": uc.balance,
"total_purchased": uc.total_purchased,
"total_used": uc.total_used,
"subscription_plan_id": str(uc.subscription_plan_id) if uc.subscription_plan_id else None,
"subscription_expires_at": uc.subscription_expires_at.isoformat() if uc.subscription_expires_at else None,
"free_trial_used": uc.free_trial_used,
"updated_at": uc.updated_at.isoformat() if uc.updated_at else None,
})
return {"items": enriched, "total": total, "page": page, "size": size}
class AdjustForm(BaseModel):
user_id: str
credits: float
reason: str = ""
@router.post("/user-credits/adjust")
async def adjust_credits(
data: AdjustForm,
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
try:
uid = uuid.UUID(data.user_id)
except ValueError:
raise HTTPException(status_code=400, detail="无效用户ID")
balance = await svc.add_credits(
user_id=uid,
credits=data.credits,
source="admin_grant",
description=data.reason or f"管理员调整: {data.credits:+.1f}",
)
return {"status": "ok", "balance": balance}
@router.get("/credit-consumptions")
async def list_consumptions(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=200),
user_id: str = Query(None),
result_type: str = Query(None),
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
from app.models.credit_consumption import CreditConsumption
from sqlalchemy import select, func, desc
conditions = []
if user_id:
try:
conditions.append(CreditConsumption.user_id == uuid.UUID(user_id))
except ValueError:
pass
if result_type:
conditions.append(CreditConsumption.result_type == result_type)
stmt = select(CreditConsumption).where(*conditions).order_by(
desc(CreditConsumption.created_at)
).offset((page - 1) * size).limit(size)
result = await db.execute(stmt)
items = result.scalars().all()
count_stmt = select(func.count(CreditConsumption.id)).where(*conditions)
count_result = await db.execute(count_stmt)
total = count_result.scalar() or 0
return {
"items": [{
"id": str(c.id),
"user_id": str(c.user_id),
"result_type": c.result_type,
"credits_change": c.credits_change,
"balance_after": c.balance_after,
"source": c.source,
"description": c.description,
"created_at": c.created_at.isoformat() if c.created_at else None,
} for c in items],
"total": total,
"page": page,
"size": size,
}
@router.get("/credit-stats")
async def credit_stats(
_: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
return await svc.get_stats()
+9
View File
@@ -9,6 +9,7 @@ from app.ai.local_faq import match_faq
from app.api.v1.deps import get_current_user_id
from app.models.system_config import SystemConfig
from app.services.admin import AdminService
from app.services.credit import CreditService
import logging
import time
import re
@@ -108,6 +109,14 @@ async def chat(
f"db={t1-t0:.2f}s orm={t2-t1:.2f}s total={t4-t_start:.2f}s"
)
else:
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "ai_chat")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f})"
)
t3 = time.time()
ai = get_ai_router()
result = await ai.chat(data.message, data.history or [], system_prompt)
+4
View File
@@ -81,6 +81,10 @@ async def register(
)
db.add(sub)
from app.services.credit import CreditService
credit_svc = CreditService(db)
await credit_svc.grant_free_trial(user.id)
if data.ref_code:
try:
from app.api.v1.referral import do_claim_referral
+121
View File
@@ -0,0 +1,121 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.v1.deps import get_current_user_id
from app.services.credit import CreditService
from app.services.payment import PaymentService
router = APIRouter()
class PurchaseRequest(BaseModel):
package_id: str
pay_type: str = "alipay"
class SubscribeRequest(BaseModel):
plan_id: str
pay_type: str = "alipay"
@router.get("/balance")
async def get_balance(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
return await svc.get_balance(user_id)
@router.get("/history")
async def get_history(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
return await svc.get_history(user_id, page, size)
@router.get("/packages")
async def list_packages(
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
return await svc.get_packages()
@router.get("/subscription-plans")
async def list_subscription_plans(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
return await svc.get_subscription_plans()
@router.post("/purchase")
async def purchase_package(
req: PurchaseRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
packages = await svc.get_packages()
pkg = next((p for p in packages if p["id"] == req.package_id), None)
if not pkg:
raise HTTPException(status_code=404, detail="次数包不存在")
pay_svc = PaymentService(db)
order = await pay_svc.create_credit_order(
user_id=user_id,
amount=pkg["price"],
description=f"购买 {pkg['name']} ({pkg['credits']}次)",
pay_type=req.pay_type,
metadata={"credit_package_id": req.package_id, "credits": pkg["credits"]},
)
return order
@router.post("/subscribe")
async def subscribe_plan(
req: SubscribeRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
svc = CreditService(db)
plans = await svc.get_subscription_plans()
plan = next((p for p in plans if p["id"] == req.plan_id), None)
if not plan:
raise HTTPException(status_code=404, detail="订阅套餐不存在")
pay_svc = PaymentService(db)
order = await pay_svc.create_credit_order(
user_id=user_id,
amount=plan["price"],
description=f"开通 {plan['name']} (每月{plan['credits_per_month']}次)",
pay_type=req.pay_type,
metadata={"subscription_plan_id": req.plan_id, "credits_per_month": plan["credits_per_month"]},
)
return order
@router.post("/cancel-subscription")
async def cancel_subscription(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
from app.models.user_credit import UserCredit
from sqlalchemy import select
result = await db.execute(select(UserCredit).where(UserCredit.user_id == user_id))
uc = result.scalar_one_or_none()
if not uc or not uc.subscription_plan_id:
raise HTTPException(status_code=400, detail="没有有效的订阅")
uc.subscription_auto_renew = False
await db.flush()
return {"success": True, "message": "已取消自动续费,当前订阅到期后不再续费"}
+71 -7
View File
@@ -4,6 +4,11 @@ from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.services.discovery import DiscoveryService
from app.services.credit import CreditService
from app.api.v1.deps import get_current_user_id
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -23,45 +28,104 @@ class OutreachRequest(BaseModel):
product: Dict[str, Any]
CREDIT_COST = {
"search": 10,
"analyze": 5,
"outreach": 3,
}
async def _deduct_credits(user_id: str, result_type: str, db: AsyncSession):
svc = CreditService(db)
ok, balance = await svc.deduct(user_id, result_type)
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 {CREDIT_COST.get(result_type, 1)})"
)
return balance
@router.post("/search")
async def search_leads(req: SearchRequest, db: AsyncSession = Depends(get_db)):
async def search_leads(
req: SearchRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
if not req.product_description.strip():
raise HTTPException(status_code=400, detail="请填写产品描述")
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "lead_search")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 10)"
)
svc = DiscoveryService(db=db)
try:
result = await svc.search(req.product_description, req.target_market)
return {"success": True, "data": result}
return {"success": True, "data": result, "credits_remaining": balance - 10}
except Exception as e:
await credit_svc.add_credits(user_id, 10, "refund", "搜索失败退回次数")
logger.error(f"Search failed: {e}")
raise HTTPException(status_code=500, detail="搜索失败,请稍后重试")
@router.post("/analyze")
async def analyze_company(req: AnalyzeRequest):
async def analyze_company(
req: AnalyzeRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
if not req.company_url.strip():
raise HTTPException(status_code=400, detail="请填写公司网址")
if not req.product_description.strip():
raise HTTPException(status_code=400, detail="请填写产品描述")
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "company_analysis")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
)
svc = DiscoveryService()
try:
result = await svc.analyze(req.company_url, req.product_description)
return {"success": True, "data": result}
return {"success": True, "data": result, "credits_remaining": balance - 5}
except Exception as e:
await credit_svc.add_credits(user_id, 5, "refund", "分析失败退回次数")
logger.error(f"Analysis failed: {e}")
raise HTTPException(status_code=500, detail="分析失败,请稍后重试")
@router.post("/outreach")
async def generate_outreach(req: OutreachRequest):
async def generate_outreach(
req: OutreachRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
if not req.company.get("name"):
raise HTTPException(status_code=400, detail="请填写公司名称")
if not req.product.get("name"):
raise HTTPException(status_code=400, detail="请填写产品名称")
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "outreach")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 3)"
)
svc = DiscoveryService()
try:
result = await svc.generate_outreach(req.company, req.product)
return {"success": True, "data": result}
return {"success": True, "data": result, "credits_remaining": balance - 3}
except Exception as e:
await credit_svc.add_credits(user_id, 3, "refund", "生成失败退回次数")
logger.error(f"Outreach generation failed: {e}")
raise HTTPException(status_code=500, detail="生成失败,请稍后重试")
+9
View File
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from app.database import get_db
from app.services.followup_engine import FollowupEngine
from app.services.credit import CreditService
from app.api.v1.deps import get_current_user_id
router = APIRouter()
@@ -84,6 +85,14 @@ async def trigger_followup_scan(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "followup_scan")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 2)"
)
engine = FollowupEngine(db)
result = await engine.scan_and_followup()
return result
+30 -2
View File
@@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.services.marketing import MarketingService
from app.services.preference import UserPreferenceService
from app.services.credit import CreditService
from app.core.security import decode_token
from app.api.v1.deps import get_current_user_id
from app.config import settings
@@ -45,6 +46,14 @@ async def generate_marketing(
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "marketing_content")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
)
service = MarketingService()
pref_service = UserPreferenceService(db)
pref_context = await pref_service.get_preference_context(user_id, "marketing")
@@ -63,6 +72,7 @@ async def generate_marketing(
"product": data.product_name,
"target": data.target,
"count": len(results),
"credits_remaining": balance - 5,
}
@@ -70,7 +80,16 @@ async def generate_marketing(
async def generate_keywords(
data: KeywordsRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "marketing_content")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
)
service = MarketingService()
product_info = {
"name": data.product_name,
@@ -79,14 +98,23 @@ async def generate_keywords(
}
keywords = await service.generate_keywords(product_info, data.language, data.count)
return {"keywords": keywords, "product": data.product_name}
return {"keywords": keywords, "product": data.product_name, "credits_remaining": balance - 5}
@router.post("/competitor-analysis")
async def competitor_analysis(
data: CompetitorRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "competitor_analysis")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 10)"
)
service = MarketingService()
product_info = {
"name": data.product_name,
@@ -95,4 +123,4 @@ async def competitor_analysis(
}
analysis = await service.analyze_competitors(product_info, data.market)
return {"analysis": analysis, "product": data.product_name, "market": data.market}
return {"analysis": analysis, "product": data.product_name, "market": data.market, "credits_remaining": balance - 10}
+1 -2
View File
@@ -104,8 +104,7 @@ async def import_products(
from app.config import settings
MAX_UPLOAD_SIZE = settings.MAX_UPLOAD_SIZE
MAX_UPLOAD_SIZE = settings.MAX_UPLOAD_SIZE
filename = file.filename or "unknown"
file_size = 0
+27
View File
@@ -6,6 +6,7 @@ from app.database import get_db
from app.services.translation import TranslationService
from app.services.tts import tts_service
from app.services.preference import UserPreferenceService
from app.services.credit import CreditService
from app.core.security import decode_token
from app.api.v1.deps import get_current_user_id
@@ -35,6 +36,7 @@ class ExtractRequest(BaseModel):
async def translate_text(
data: TranslateRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
service = TranslationService()
result = await service.translate(
@@ -44,6 +46,13 @@ async def translate_text(
context=data.context,
user_id=user_id,
)
credit_svc = CreditService(db)
char_count = len(data.text)
await credit_svc.deduct(
user_id, "translate",
metadata={"chars": char_count, "target_lang": data.target_lang},
)
return result
@@ -54,6 +63,15 @@ async def generate_reply(
db: AsyncSession = Depends(get_db),
):
pref_service = UserPreferenceService(db)
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "reply_suggest")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 2)"
)
pref_context = await pref_service.get_preference_context(user_id, "reply")
service = TranslationService()
@@ -71,7 +89,16 @@ async def generate_reply(
async def extract_info(
data: ExtractRequest,
user_id: str = Depends(get_current_user_id),
db: AsyncSession = Depends(get_db),
):
credit_svc = CreditService(db)
ok, balance = await credit_svc.deduct(user_id, "info_extract")
if not ok:
raise HTTPException(
status_code=402,
detail=f"次数不足 (剩余 {balance:.1f}, 需要 1)"
)
service = TranslationService()
result = await service.extract_info(data.text, data.extract_type)
return {"extracted": result, "type": data.extract_type}
+3 -1
View File
@@ -129,7 +129,7 @@ async def health():
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup, ai_assistant, discovery, discovery_record, certification, invoice, usage, referral, admin_search, search, admin_ai, credits, admin_credits
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
@@ -161,6 +161,8 @@ app.include_router(usage.router, prefix="/api/v1/usage", tags=["usage"])
app.include_router(referral.router, prefix="/api/v1/referral", tags=["referral"])
app.include_router(admin_search.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(admin_ai.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(admin_credits.router, prefix="/api/v1/admin", tags=["admin"])
app.include_router(credits.router, prefix="/api/v1/credits", tags=["credits"])
app.include_router(search.router, prefix="/api/v1/search", tags=["search"])
+8
View File
@@ -19,6 +19,10 @@ from .search_provider import SearchProvider
from .discovery_record import DiscoveryRecord
from .ai_provider import AIProvider
from .payment_transaction import PaymentTransaction
from .credit_package import CreditPackage, SubscriptionPlan
from .user_credit import UserCredit
from .credit_consumption import CreditConsumption
from .credit_purchase import CreditPurchase
__all__ = [
"User", "Product",
@@ -37,4 +41,8 @@ __all__ = [
"DiscoveryRecord",
"AIProvider",
"PaymentTransaction",
"CreditPackage", "SubscriptionPlan",
"UserCredit",
"CreditConsumption",
"CreditPurchase",
]
+20
View File
@@ -0,0 +1,20 @@
from sqlalchemy import Column, String, Float, DateTime, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from datetime import datetime
from app.database import Base
import uuid
class CreditConsumption(Base):
__tablename__ = "credit_consumptions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
result_type = Column(String(50), nullable=False)
reference_id = Column(UUID(as_uuid=True), nullable=True)
credits_change = Column(Float, nullable=False)
balance_after = Column(Float, nullable=False)
source = Column(String(30), nullable=False)
description = Column(String(500))
metadata_ = Column("metadata", JSONB)
created_at = Column(DateTime, default=datetime.utcnow)
+37
View File
@@ -0,0 +1,37 @@
from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
from app.database import Base
import uuid
class CreditPackage(Base):
__tablename__ = "credit_packages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
name_en = Column(String(100), nullable=False)
credits = Column(Integer, nullable=False)
price = Column(Float, nullable=False)
price_usd = Column(Float)
original_price = Column(Float)
is_active = Column(Boolean, default=True)
sort_order = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class SubscriptionPlan(Base):
__tablename__ = "subscription_plans"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), nullable=False)
name_en = Column(String(100), nullable=False)
credits_per_month = Column(Integer, nullable=False)
price = Column(Float, nullable=False)
price_usd = Column(Float)
duration_days = Column(Integer, default=30)
is_active = Column(Boolean, default=True)
sort_order = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+22
View File
@@ -0,0 +1,22 @@
from sqlalchemy import Column, String, Integer, Float, DateTime, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
from app.database import Base
import uuid
class CreditPurchase(Base):
__tablename__ = "credit_purchases"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
package_id = Column(UUID(as_uuid=True), ForeignKey("credit_packages.id"), nullable=True)
subscription_plan_id = Column(UUID(as_uuid=True), ForeignKey("subscription_plans.id"), nullable=True)
credits = Column(Integer, nullable=False)
amount = Column(Float, nullable=False)
currency = Column(String(3), default="CNY")
payment_method = Column(String(20))
status = Column(String(20), default="pending")
payment_transaction_id = Column(UUID(as_uuid=True), ForeignKey("payment_transactions.id"), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
paid_at = Column(DateTime, nullable=True)
+26
View File
@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Float, Boolean, DateTime, ForeignKey, Integer, Date
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
from app.database import Base
import uuid
class UserCredit(Base):
__tablename__ = "user_credits"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, unique=True, index=True)
balance = Column(Float, default=0)
total_purchased = Column(Float, default=0)
total_used = Column(Float, default=0)
subscription_plan_id = Column(UUID(as_uuid=True), ForeignKey("subscription_plans.id"), nullable=True)
subscription_expires_at = Column(DateTime, nullable=True)
subscription_auto_renew = Column(Boolean, default=False)
free_trial_used = Column(Boolean, default=False)
daily_translate_chars = Column(Integer, default=0)
daily_translate_date = Column(Date, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+256
View File
@@ -0,0 +1,256 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, desc
from datetime import datetime, date
from decimal import Decimal
import logging
from app.models import UserCredit, CreditConsumption, CreditPackage, SubscriptionPlan, CreditPurchase
from app.models.system_config import SystemConfig
logger = logging.getLogger(__name__)
DEFAULT_CONSUMPTION_RATES = {
"lead_search": 10,
"company_analysis": 5,
"market_intel": 20,
"translate_per_1000chars": 1,
"reply_suggest": 2,
"outreach": 3,
"marketing_content": 5,
"competitor_analysis": 10,
"ai_chat_per_10msg": 1,
"info_extract": 1,
"quotation": 2,
"followup_scan": 2,
}
FREE_TRIAL_CREDITS = 30
DAILY_FREE_TRANSLATE_CHARS = 1000
class CreditService:
def __init__(self, db: AsyncSession):
self.db = db
async def _ensure_credit(self, user_id: str) -> UserCredit:
result = await self.db.execute(
select(UserCredit).where(UserCredit.user_id == user_id)
)
uc = result.scalar_one_or_none()
if not uc:
uc = UserCredit(user_id=user_id, balance=0)
self.db.add(uc)
await self.db.flush()
return uc
async def get_balance(self, user_id: str) -> dict:
uc = await self._ensure_credit(user_id)
rates = await self._get_rates()
return {
"balance": uc.balance,
"total_purchased": uc.total_purchased,
"total_used": uc.total_used,
"subscription": {
"plan_id": str(uc.subscription_plan_id) if uc.subscription_plan_id else None,
"expires_at": uc.subscription_expires_at.isoformat() if uc.subscription_expires_at else None,
"auto_renew": uc.subscription_auto_renew,
} if uc.subscription_plan_id else None,
"free_trial_used": uc.free_trial_used,
"daily_free_translate_chars_left": max(0, DAILY_FREE_TRANSLATE_CHARS - await self._daily_translate_chars(uc)),
"rates": rates,
}
async def deduct(self, user_id: str, result_type: str, reference_id: str = None, amount: float = None, metadata: dict = None) -> tuple[bool, float]:
rates = await self._get_rates()
cost = amount or rates.get(result_type, 1)
uc = await self._ensure_credit(user_id)
if result_type == "translate":
char_count = (metadata or {}).get("chars", 0)
if char_count > 0:
daily_free = await self._daily_translate_chars(uc)
free_remaining = max(0, DAILY_FREE_TRANSLATE_CHARS - daily_free)
free_used = min(free_remaining, char_count)
paid_chars = char_count - free_used
cost = (paid_chars / 1000) * rates.get("translate_per_1000chars", 1)
if free_used > 0:
today = date.today()
if uc.daily_translate_date != today:
uc.daily_translate_date = today
uc.daily_translate_chars = 0
uc.daily_translate_chars += free_used
await self.db.flush()
if cost <= 0:
await self._log(user_id, result_type, reference_id, 0, uc.balance, "daily_free", metadata)
return True, uc.balance
if uc.balance < cost:
return False, uc.balance
uc.balance -= cost
uc.total_used += cost
balance_after = uc.balance
await self._log(user_id, result_type, reference_id, -cost, balance_after, "credit", metadata)
await self.db.flush()
return True, balance_after
async def add_credits(self, user_id: str, credits: float, source: str, description: str = None) -> float:
uc = await self._ensure_credit(user_id)
uc.balance += credits
if credits > 0:
uc.total_purchased += credits
balance_after = uc.balance
await self._log(user_id, "topup", None, credits, balance_after, source, {"description": description})
await self.db.flush()
return balance_after
async def grant_free_trial(self, user_id: str) -> float:
uc = await self._ensure_credit(user_id)
if uc.free_trial_used:
return uc.balance
return await self.add_credits(
user_id, FREE_TRIAL_CREDITS, "free_trial",
f"新用户注册赠送 {FREE_TRIAL_CREDITS}"
)
async def consume_for_subscription(self, user_id: str, plan_id: str) -> tuple[bool, str]:
result = await self.db.execute(
select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id, SubscriptionPlan.is_active == True)
)
plan = result.scalar_one_or_none()
if not plan:
return False, "套餐不存在"
uc = await self._ensure_credit(user_id)
amount = plan.price
return True, "ok"
async def _log(self, user_id: str, result_type: str, reference_id: str,
credits_change: float, balance_after: float, source: str, metadata: dict = None):
log = CreditConsumption(
user_id=user_id,
result_type=result_type,
reference_id=reference_id,
credits_change=credits_change,
balance_after=balance_after,
source=source,
metadata_=metadata or {},
)
self.db.add(log)
async def get_history(self, user_id: str, page: int = 1, size: int = 20) -> dict:
offset = (page - 1) * size
stmt = select(CreditConsumption).where(
CreditConsumption.user_id == user_id
).order_by(desc(CreditConsumption.created_at)).offset(offset).limit(size)
result = await self.db.execute(stmt)
items = result.scalars().all()
count_stmt = select(func.count()).where(CreditConsumption.user_id == user_id)
count_result = await self.db.execute(count_stmt)
total = count_result.scalar() or 0
return {
"items": [{
"id": str(item.id),
"result_type": item.result_type,
"credits_change": item.credits_change,
"balance_after": item.balance_after,
"source": item.source,
"description": item.description,
"created_at": item.created_at.isoformat() if item.created_at else None,
} for item in items],
"total": total,
"page": page,
"size": size,
}
async def _get_rates(self) -> dict:
result = await self.db.execute(
select(SystemConfig).where(SystemConfig.key == "credit_consumption_rates")
)
row = result.scalar_one_or_none()
if row and row.value:
return {**DEFAULT_CONSUMPTION_RATES, **row.value}
return dict(DEFAULT_CONSUMPTION_RATES)
async def _daily_translate_chars(self, uc: UserCredit) -> int:
today = date.today()
if uc.daily_translate_date != today:
return 0
return uc.daily_translate_chars or 0
async def get_packages(self) -> list:
result = await self.db.execute(
select(CreditPackage).where(CreditPackage.is_active == True).order_by(CreditPackage.sort_order)
)
return [{
"id": str(p.id),
"name": p.name,
"name_en": p.name_en,
"credits": p.credits,
"price": p.price,
"price_usd": p.price_usd,
"original_price": p.original_price,
} for p in result.scalars().all()]
async def get_subscription_plans(self) -> list:
result = await self.db.execute(
select(SubscriptionPlan).where(SubscriptionPlan.is_active == True).order_by(SubscriptionPlan.sort_order)
)
return [{
"id": str(p.id),
"name": p.name,
"name_en": p.name_en,
"credits_per_month": p.credits_per_month,
"price": p.price,
"price_usd": p.price_usd,
"duration_days": p.duration_days,
} for p in result.scalars().all()]
async def get_stats(self) -> dict:
result = await self.db.execute(
select(func.coalesce(func.sum(UserCredit.total_purchased), 0))
)
total_purchased = result.scalar()
result = await self.db.execute(
select(func.coalesce(func.sum(UserCredit.balance), 0))
)
total_balance = result.scalar()
result = await self.db.execute(select(func.count(UserCredit.id)))
total_users = result.scalar()
result = await self.db.execute(
select(func.coalesce(func.sum(CreditConsumption.credits_change), 0)).where(
CreditConsumption.credits_change < 0
)
)
total_consumed = abs(result.scalar() or 0)
return {
"total_purchased": total_purchased,
"total_balance": total_balance,
"total_consumed": total_consumed,
"total_users_with_credits": total_users,
}
CREDIT_CONSUMPTION = {
"lead_search": 10,
"company_analysis": 5,
"market_intel": 20,
"translate_per_1000chars": 1,
"reply_suggest": 2,
"outreach": 3,
"marketing_content": 5,
"competitor_analysis": 10,
"ai_chat": 1,
"info_extract": 1,
"quotation": 2,
"followup_scan": 2,
}
+1 -1
View File
@@ -259,7 +259,7 @@ URL: {company_url}
return json.loads(text)
except json.JSONDecodeError:
import re
brace = text.find("{")
brace = text.find("{")
end = text.rfind("}")
if brace >= 0 and end > brace:
try:
+74 -19
View File
@@ -143,6 +143,39 @@ class PaymentService:
**gw_result,
}
async def create_credit_order(self, user_id: str, amount: float,
description: str, pay_type: str = "alipay",
metadata: dict = None) -> Dict[str, Any]:
order_no = gen_order_no(user_id)
gw = get_gateway(pay_type)
meta_remark = {"uid": user_id, "oid": order_no, "type": "credit_purchase"}
if metadata:
meta_remark.update(metadata)
gw_result = await gw.create_order(order_no, int(amount * 100),
description, pay_type=pay_type,
remark=json.dumps(meta_remark, separators=(",", ":")))
txn = PaymentTransaction(
user_id=user_id, order_no=order_no, plan="credit_purchase",
amount=amount, gateway="unified", pay_type=pay_type,
status="pending", description=json.dumps(metadata or {}, ensure_ascii=False),
gateway_order_no=gw_result.get("gateway_order_id", ""),
)
self.db.add(txn)
await self.db.flush()
return {
"status": "pending",
"order_id": order_no,
"amount": amount,
"currency": "CNY",
"gateway": "unified",
"pay_type": pay_type,
"metadata": metadata or {},
**gw_result,
}
async def handle_callback(self, order_no: str, gateway_order_id: str,
gateway_order_no: str, success: bool,
amount: float = 0, notify_raw: str = "") -> bool:
@@ -162,30 +195,52 @@ class PaymentService:
txn.paid_at = datetime.utcnow()
txn.notify_raw = notify_raw
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "active"
sub.started_at = datetime.utcnow()
if PLANS[sub.plan]["duration_days"]:
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
if txn.plan == "credit_purchase":
from app.services.credit import CreditService
credit_svc = CreditService(self.db)
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
user = user_result.scalar_one_or_none()
if user:
user.tier = txn.plan
if txn.description:
try:
meta = json.loads(txn.description)
credits = meta.get("credits", 0)
except (json.JSONDecodeError, TypeError):
credits = 0
else:
credits = 0
if not credits:
credits = max(1, int(txn.amount / 0.79))
await credit_svc.add_credits(
txn.user_id, credits, "package",
f"支付完成 - 获得 {credits} 次信用额度"
)
else:
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "active"
sub.started_at = datetime.utcnow()
if PLANS[sub.plan]["duration_days"]:
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
user = user_result.scalar_one_or_none()
if user:
user.tier = txn.plan
else:
txn.status = "failed"
txn.notify_raw = notify_raw
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "failed"
if txn.plan != "credit_purchase":
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "failed"
await self.db.flush()
return True