feat: credit-based billing system
- New DB models: credit_packages, subscription_plans, user_credits, credit_consumptions, credit_purchases - CreditService: balance, deduct, add_credits, grant_free_trial, history - User API: /api/v1/credits/* (balance/history/packages/purchase/subscribe) - Admin API: /api/v1/admin/credit-* (CRUD packages/plans, user credits, consumptions) - PaymentService.create_credit_order + handle_callback for credit purchases - Credit deduction on: discovery, translate, marketing, ai_chat, followup - Free trial 30 credits on registration - Documentation: docs/CREDIT_SYSTEM.md
This commit is contained in:
@@ -0,0 +1,308 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.api.v1.admin import require_admin
|
||||
from app.services.credit import CreditService
|
||||
from app.models.credit_package import CreditPackage, SubscriptionPlan
|
||||
from app.models.user_credit import UserCredit
|
||||
from app.models.user import User
|
||||
import uuid
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PackageForm(BaseModel):
|
||||
name: str
|
||||
name_en: str
|
||||
credits: int
|
||||
price: float
|
||||
price_usd: Optional[float] = None
|
||||
original_price: Optional[float] = None
|
||||
is_active: bool = True
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class PlanForm(BaseModel):
|
||||
name: str
|
||||
name_en: str
|
||||
credits_per_month: int
|
||||
price: float
|
||||
price_usd: Optional[float] = None
|
||||
duration_days: int = 30
|
||||
is_active: bool = True
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class AdjustCreditsForm(BaseModel):
|
||||
user_id: str
|
||||
credits: float
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.get("/credit-packages")
|
||||
async def list_packages(
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(CreditPackage).order_by(CreditPackage.sort_order)
|
||||
)
|
||||
return [{
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"name_en": p.name_en,
|
||||
"credits": p.credits,
|
||||
"price": p.price,
|
||||
"price_usd": p.price_usd,
|
||||
"original_price": p.original_price,
|
||||
"is_active": p.is_active,
|
||||
"sort_order": p.sort_order,
|
||||
} for p in result.scalars().all()]
|
||||
|
||||
|
||||
@router.post("/credit-packages")
|
||||
async def create_package(
|
||||
data: PackageForm,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
pkg = CreditPackage(**data.model_dump())
|
||||
db.add(pkg)
|
||||
await db.flush()
|
||||
return {"id": str(pkg.id), "status": "ok"}
|
||||
|
||||
|
||||
@router.put("/credit-packages/{pkg_id}")
|
||||
async def update_package(
|
||||
pkg_id: str,
|
||||
data: PackageForm,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
uid = uuid.UUID(pkg_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="无效ID")
|
||||
result = await db.execute(select(CreditPackage).where(CreditPackage.id == uid))
|
||||
pkg = result.scalar_one_or_none()
|
||||
if not pkg:
|
||||
raise HTTPException(status_code=404, detail="次数包不存在")
|
||||
for k, v in data.model_dump().items():
|
||||
setattr(pkg, k, v)
|
||||
await db.flush()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.delete("/credit-packages/{pkg_id}")
|
||||
async def delete_package(
|
||||
pkg_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
uid = uuid.UUID(pkg_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="无效ID")
|
||||
result = await db.execute(select(CreditPackage).where(CreditPackage.id == uid))
|
||||
pkg = result.scalar_one_or_none()
|
||||
if not pkg:
|
||||
raise HTTPException(status_code=404, detail="次数包不存在")
|
||||
await db.delete(pkg)
|
||||
await db.flush()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/subscription-plans")
|
||||
async def list_plans(
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(SubscriptionPlan).order_by(SubscriptionPlan.sort_order)
|
||||
)
|
||||
return [{
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"name_en": p.name_en,
|
||||
"credits_per_month": p.credits_per_month,
|
||||
"price": p.price,
|
||||
"price_usd": p.price_usd,
|
||||
"duration_days": p.duration_days,
|
||||
"is_active": p.is_active,
|
||||
"sort_order": p.sort_order,
|
||||
} for p in result.scalars().all()]
|
||||
|
||||
|
||||
@router.post("/subscription-plans")
|
||||
async def create_plan(
|
||||
data: PlanForm,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
plan = SubscriptionPlan(**data.model_dump())
|
||||
db.add(plan)
|
||||
await db.flush()
|
||||
return {"id": str(plan.id), "status": "ok"}
|
||||
|
||||
|
||||
@router.put("/subscription-plans/{plan_id}")
|
||||
async def update_plan(
|
||||
plan_id: str,
|
||||
data: PlanForm,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
uid = uuid.UUID(plan_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="无效ID")
|
||||
result = await db.execute(select(SubscriptionPlan).where(SubscriptionPlan.id == uid))
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="订阅套餐不存在")
|
||||
for k, v in data.model_dump().items():
|
||||
setattr(plan, k, v)
|
||||
await db.flush()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.delete("/subscription-plans/{plan_id}")
|
||||
async def delete_plan(
|
||||
plan_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
uid = uuid.UUID(plan_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="无效ID")
|
||||
result = await db.execute(select(SubscriptionPlan).where(SubscriptionPlan.id == uid))
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="订阅套餐不存在")
|
||||
await db.delete(plan)
|
||||
await db.flush()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/user-credits")
|
||||
async def list_user_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
offset = (page - 1) * size
|
||||
result = await db.execute(
|
||||
select(UserCredit).order_by(UserCredit.updated_at.desc()).offset(offset).limit(size)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
from sqlalchemy import func
|
||||
count_result = await db.execute(select(func.count(UserCredit.id)))
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
enriched = []
|
||||
for uc in items:
|
||||
user_result = await db.execute(select(User).where(User.id == uc.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
enriched.append({
|
||||
"id": str(uc.id),
|
||||
"user_id": str(uc.user_id),
|
||||
"username": user.username if user else "N/A",
|
||||
"balance": uc.balance,
|
||||
"total_purchased": uc.total_purchased,
|
||||
"total_used": uc.total_used,
|
||||
"subscription_plan_id": str(uc.subscription_plan_id) if uc.subscription_plan_id else None,
|
||||
"subscription_expires_at": uc.subscription_expires_at.isoformat() if uc.subscription_expires_at else None,
|
||||
"free_trial_used": uc.free_trial_used,
|
||||
"updated_at": uc.updated_at.isoformat() if uc.updated_at else None,
|
||||
})
|
||||
|
||||
return {"items": enriched, "total": total, "page": page, "size": size}
|
||||
|
||||
|
||||
class AdjustForm(BaseModel):
|
||||
user_id: str
|
||||
credits: float
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.post("/user-credits/adjust")
|
||||
async def adjust_credits(
|
||||
data: AdjustForm,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
try:
|
||||
uid = uuid.UUID(data.user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="无效用户ID")
|
||||
balance = await svc.add_credits(
|
||||
user_id=uid,
|
||||
credits=data.credits,
|
||||
source="admin_grant",
|
||||
description=data.reason or f"管理员调整: {data.credits:+.1f} 次",
|
||||
)
|
||||
return {"status": "ok", "balance": balance}
|
||||
|
||||
|
||||
@router.get("/credit-consumptions")
|
||||
async def list_consumptions(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=200),
|
||||
user_id: str = Query(None),
|
||||
result_type: str = Query(None),
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
from app.models.credit_consumption import CreditConsumption
|
||||
from sqlalchemy import select, func, desc
|
||||
|
||||
conditions = []
|
||||
if user_id:
|
||||
try:
|
||||
conditions.append(CreditConsumption.user_id == uuid.UUID(user_id))
|
||||
except ValueError:
|
||||
pass
|
||||
if result_type:
|
||||
conditions.append(CreditConsumption.result_type == result_type)
|
||||
|
||||
stmt = select(CreditConsumption).where(*conditions).order_by(
|
||||
desc(CreditConsumption.created_at)
|
||||
).offset((page - 1) * size).limit(size)
|
||||
result = await db.execute(stmt)
|
||||
items = result.scalars().all()
|
||||
|
||||
count_stmt = select(func.count(CreditConsumption.id)).where(*conditions)
|
||||
count_result = await db.execute(count_stmt)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"items": [{
|
||||
"id": str(c.id),
|
||||
"user_id": str(c.user_id),
|
||||
"result_type": c.result_type,
|
||||
"credits_change": c.credits_change,
|
||||
"balance_after": c.balance_after,
|
||||
"source": c.source,
|
||||
"description": c.description,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
} for c in items],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/credit-stats")
|
||||
async def credit_stats(
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
return await svc.get_stats()
|
||||
@@ -9,6 +9,7 @@ from app.ai.local_faq import match_faq
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.services.admin import AdminService
|
||||
from app.services.credit import CreditService
|
||||
import logging
|
||||
import time
|
||||
import re
|
||||
@@ -108,6 +109,14 @@ async def chat(
|
||||
f"db={t1-t0:.2f}s orm={t2-t1:.2f}s total={t4-t_start:.2f}s"
|
||||
)
|
||||
else:
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "ai_chat")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f})"
|
||||
)
|
||||
|
||||
t3 = time.time()
|
||||
ai = get_ai_router()
|
||||
result = await ai.chat(data.message, data.history or [], system_prompt)
|
||||
|
||||
@@ -81,6 +81,10 @@ async def register(
|
||||
)
|
||||
db.add(sub)
|
||||
|
||||
from app.services.credit import CreditService
|
||||
credit_svc = CreditService(db)
|
||||
await credit_svc.grant_free_trial(user.id)
|
||||
|
||||
if data.ref_code:
|
||||
try:
|
||||
from app.api.v1.referral import do_claim_referral
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.credit import CreditService
|
||||
from app.services.payment import PaymentService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PurchaseRequest(BaseModel):
|
||||
package_id: str
|
||||
pay_type: str = "alipay"
|
||||
|
||||
|
||||
class SubscribeRequest(BaseModel):
|
||||
plan_id: str
|
||||
pay_type: str = "alipay"
|
||||
|
||||
|
||||
@router.get("/balance")
|
||||
async def get_balance(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
return await svc.get_balance(user_id)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_history(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
return await svc.get_history(user_id, page, size)
|
||||
|
||||
|
||||
@router.get("/packages")
|
||||
async def list_packages(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
return await svc.get_packages()
|
||||
|
||||
|
||||
@router.get("/subscription-plans")
|
||||
async def list_subscription_plans(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
return await svc.get_subscription_plans()
|
||||
|
||||
|
||||
@router.post("/purchase")
|
||||
async def purchase_package(
|
||||
req: PurchaseRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
packages = await svc.get_packages()
|
||||
pkg = next((p for p in packages if p["id"] == req.package_id), None)
|
||||
if not pkg:
|
||||
raise HTTPException(status_code=404, detail="次数包不存在")
|
||||
|
||||
pay_svc = PaymentService(db)
|
||||
order = await pay_svc.create_credit_order(
|
||||
user_id=user_id,
|
||||
amount=pkg["price"],
|
||||
description=f"购买 {pkg['name']} ({pkg['credits']}次)",
|
||||
pay_type=req.pay_type,
|
||||
metadata={"credit_package_id": req.package_id, "credits": pkg["credits"]},
|
||||
)
|
||||
return order
|
||||
|
||||
|
||||
@router.post("/subscribe")
|
||||
async def subscribe_plan(
|
||||
req: SubscribeRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = CreditService(db)
|
||||
plans = await svc.get_subscription_plans()
|
||||
plan = next((p for p in plans if p["id"] == req.plan_id), None)
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="订阅套餐不存在")
|
||||
|
||||
pay_svc = PaymentService(db)
|
||||
order = await pay_svc.create_credit_order(
|
||||
user_id=user_id,
|
||||
amount=plan["price"],
|
||||
description=f"开通 {plan['name']} (每月{plan['credits_per_month']}次)",
|
||||
pay_type=req.pay_type,
|
||||
metadata={"subscription_plan_id": req.plan_id, "credits_per_month": plan["credits_per_month"]},
|
||||
)
|
||||
return order
|
||||
|
||||
|
||||
@router.post("/cancel-subscription")
|
||||
async def cancel_subscription(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
from app.models.user_credit import UserCredit
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(select(UserCredit).where(UserCredit.user_id == user_id))
|
||||
uc = result.scalar_one_or_none()
|
||||
if not uc or not uc.subscription_plan_id:
|
||||
raise HTTPException(status_code=400, detail="没有有效的订阅")
|
||||
|
||||
uc.subscription_auto_renew = False
|
||||
await db.flush()
|
||||
return {"success": True, "message": "已取消自动续费,当前订阅到期后不再续费"}
|
||||
@@ -4,6 +4,11 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.discovery import DiscoveryService
|
||||
from app.services.credit import CreditService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -23,45 +28,104 @@ class OutreachRequest(BaseModel):
|
||||
product: Dict[str, Any]
|
||||
|
||||
|
||||
CREDIT_COST = {
|
||||
"search": 10,
|
||||
"analyze": 5,
|
||||
"outreach": 3,
|
||||
}
|
||||
|
||||
|
||||
async def _deduct_credits(user_id: str, result_type: str, db: AsyncSession):
|
||||
svc = CreditService(db)
|
||||
ok, balance = await svc.deduct(user_id, result_type)
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 {CREDIT_COST.get(result_type, 1)})"
|
||||
)
|
||||
return balance
|
||||
|
||||
|
||||
@router.post("/search")
|
||||
async def search_leads(req: SearchRequest, db: AsyncSession = Depends(get_db)):
|
||||
async def search_leads(
|
||||
req: SearchRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not req.product_description.strip():
|
||||
raise HTTPException(status_code=400, detail="请填写产品描述")
|
||||
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "lead_search")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 10)"
|
||||
)
|
||||
|
||||
svc = DiscoveryService(db=db)
|
||||
try:
|
||||
result = await svc.search(req.product_description, req.target_market)
|
||||
return {"success": True, "data": result}
|
||||
return {"success": True, "data": result, "credits_remaining": balance - 10}
|
||||
except Exception as e:
|
||||
await credit_svc.add_credits(user_id, 10, "refund", "搜索失败退回次数")
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="搜索失败,请稍后重试")
|
||||
|
||||
|
||||
@router.post("/analyze")
|
||||
async def analyze_company(req: AnalyzeRequest):
|
||||
async def analyze_company(
|
||||
req: AnalyzeRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not req.company_url.strip():
|
||||
raise HTTPException(status_code=400, detail="请填写公司网址")
|
||||
if not req.product_description.strip():
|
||||
raise HTTPException(status_code=400, detail="请填写产品描述")
|
||||
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "company_analysis")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
|
||||
)
|
||||
|
||||
svc = DiscoveryService()
|
||||
try:
|
||||
result = await svc.analyze(req.company_url, req.product_description)
|
||||
return {"success": True, "data": result}
|
||||
return {"success": True, "data": result, "credits_remaining": balance - 5}
|
||||
except Exception as e:
|
||||
await credit_svc.add_credits(user_id, 5, "refund", "分析失败退回次数")
|
||||
logger.error(f"Analysis failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="分析失败,请稍后重试")
|
||||
|
||||
|
||||
@router.post("/outreach")
|
||||
async def generate_outreach(req: OutreachRequest):
|
||||
async def generate_outreach(
|
||||
req: OutreachRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not req.company.get("name"):
|
||||
raise HTTPException(status_code=400, detail="请填写公司名称")
|
||||
if not req.product.get("name"):
|
||||
raise HTTPException(status_code=400, detail="请填写产品名称")
|
||||
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "outreach")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 3)"
|
||||
)
|
||||
|
||||
svc = DiscoveryService()
|
||||
try:
|
||||
result = await svc.generate_outreach(req.company, req.product)
|
||||
return {"success": True, "data": result}
|
||||
return {"success": True, "data": result, "credits_remaining": balance - 3}
|
||||
except Exception as e:
|
||||
await credit_svc.add_credits(user_id, 3, "refund", "生成失败退回次数")
|
||||
logger.error(f"Outreach generation failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="生成失败,请稍后重试")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Optional
|
||||
from app.database import get_db
|
||||
from app.services.followup_engine import FollowupEngine
|
||||
from app.services.credit import CreditService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
@@ -84,6 +85,14 @@ async def trigger_followup_scan(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "followup_scan")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 2)"
|
||||
)
|
||||
|
||||
engine = FollowupEngine(db)
|
||||
result = await engine.scan_and_followup()
|
||||
return result
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.marketing import MarketingService
|
||||
from app.services.preference import UserPreferenceService
|
||||
from app.services.credit import CreditService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.config import settings
|
||||
@@ -45,6 +46,14 @@ async def generate_marketing(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "marketing_content")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
|
||||
)
|
||||
|
||||
service = MarketingService()
|
||||
pref_service = UserPreferenceService(db)
|
||||
pref_context = await pref_service.get_preference_context(user_id, "marketing")
|
||||
@@ -63,6 +72,7 @@ async def generate_marketing(
|
||||
"product": data.product_name,
|
||||
"target": data.target,
|
||||
"count": len(results),
|
||||
"credits_remaining": balance - 5,
|
||||
}
|
||||
|
||||
|
||||
@@ -70,7 +80,16 @@ async def generate_marketing(
|
||||
async def generate_keywords(
|
||||
data: KeywordsRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "marketing_content")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 5)"
|
||||
)
|
||||
|
||||
service = MarketingService()
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
@@ -79,14 +98,23 @@ async def generate_keywords(
|
||||
}
|
||||
keywords = await service.generate_keywords(product_info, data.language, data.count)
|
||||
|
||||
return {"keywords": keywords, "product": data.product_name}
|
||||
return {"keywords": keywords, "product": data.product_name, "credits_remaining": balance - 5}
|
||||
|
||||
|
||||
@router.post("/competitor-analysis")
|
||||
async def competitor_analysis(
|
||||
data: CompetitorRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "competitor_analysis")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 10)"
|
||||
)
|
||||
|
||||
service = MarketingService()
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
@@ -95,4 +123,4 @@ async def competitor_analysis(
|
||||
}
|
||||
analysis = await service.analyze_competitors(product_info, data.market)
|
||||
|
||||
return {"analysis": analysis, "product": data.product_name, "market": data.market}
|
||||
return {"analysis": analysis, "product": data.product_name, "market": data.market, "credits_remaining": balance - 10}
|
||||
|
||||
@@ -104,8 +104,7 @@ async def import_products(
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
MAX_UPLOAD_SIZE = settings.MAX_UPLOAD_SIZE
|
||||
MAX_UPLOAD_SIZE = settings.MAX_UPLOAD_SIZE
|
||||
|
||||
filename = file.filename or "unknown"
|
||||
file_size = 0
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.database import get_db
|
||||
from app.services.translation import TranslationService
|
||||
from app.services.tts import tts_service
|
||||
from app.services.preference import UserPreferenceService
|
||||
from app.services.credit import CreditService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
@@ -35,6 +36,7 @@ class ExtractRequest(BaseModel):
|
||||
async def translate_text(
|
||||
data: TranslateRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = TranslationService()
|
||||
result = await service.translate(
|
||||
@@ -44,6 +46,13 @@ async def translate_text(
|
||||
context=data.context,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
credit_svc = CreditService(db)
|
||||
char_count = len(data.text)
|
||||
await credit_svc.deduct(
|
||||
user_id, "translate",
|
||||
metadata={"chars": char_count, "target_lang": data.target_lang},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -54,6 +63,15 @@ async def generate_reply(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
pref_service = UserPreferenceService(db)
|
||||
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "reply_suggest")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 2)"
|
||||
)
|
||||
|
||||
pref_context = await pref_service.get_preference_context(user_id, "reply")
|
||||
|
||||
service = TranslationService()
|
||||
@@ -71,7 +89,16 @@ async def generate_reply(
|
||||
async def extract_info(
|
||||
data: ExtractRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
credit_svc = CreditService(db)
|
||||
ok, balance = await credit_svc.deduct(user_id, "info_extract")
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"次数不足 (剩余 {balance:.1f}, 需要 1)"
|
||||
)
|
||||
|
||||
service = TranslationService()
|
||||
result = await service.extract_info(data.text, data.extract_type)
|
||||
return {"extracted": result, "type": data.extract_type}
|
||||
|
||||
Reference in New Issue
Block a user