2a107a42f3
- New DB models: credit_packages, subscription_plans, user_credits, credit_consumptions, credit_purchases - CreditService: balance, deduct, add_credits, grant_free_trial, history - User API: /api/v1/credits/* (balance/history/packages/purchase/subscribe) - Admin API: /api/v1/admin/credit-* (CRUD packages/plans, user credits, consumptions) - PaymentService.create_credit_order + handle_callback for credit purchases - Credit deduction on: discovery, translate, marketing, ai_chat, followup - Free trial 30 credits on registration - Documentation: docs/CREDIT_SYSTEM.md
489 lines
20 KiB
Python
489 lines
20 KiB
Python
import json
|
||
import logging
|
||
import hashlib
|
||
from typing import Optional, Dict, Any, List
|
||
from datetime import datetime, timedelta
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, desc
|
||
from app.models.subscription import Subscription
|
||
from app.models.payment_transaction import PaymentTransaction
|
||
from app.models.user import User
|
||
from app.config import settings
|
||
from app.services.unified_pay import UnifiedPayService
|
||
from app.services.payment_gateway import PaymentGateway
|
||
|
||
logger = logging.getLogger(__name__)

# Billing catalogue: plan id -> price (CNY, whole yuan) and billing-period
# length in days. ``duration_days`` is None for the free plan, in which case
# no expiry date is computed when a subscription is activated.
PLANS: Dict[str, Dict[str, Any]] = dict(
    free=dict(price=0, duration_days=None),
    pro=dict(price=99, duration_days=30),
    pro_yearly=dict(price=999, duration_days=365),
    enterprise=dict(price=399, duration_days=30),
    enterprise_yearly=dict(price=3999, duration_days=365),
)

# Order descriptions sent to the payment gateway for each paid plan.
# The free plan has no entry because it never creates a gateway order.
PLAN_DESCRIPTIONS: Dict[str, str] = dict(
    pro="TradeMate Pro 版会员",
    pro_yearly="TradeMate Pro 版会员(年付)",
    enterprise="TradeMate 企业版会员",
    enterprise_yearly="TradeMate 企业版会员(年付)",
)
|
||
|
||
# Registry of configured payment gateways, keyed by gateway name.
# Populated by init_gateways() (only the "unified" key is ever set there).
GATEWAY_MAP: Dict[str, PaymentGateway] = {}
|
||
|
||
|
||
def init_gateways():
    """Register the unified payment gateway if an API key is configured.

    When ``settings.PAY_API_KEY`` is falsy, ``GATEWAY_MAP`` is left untouched,
    and ``get_gateway`` will later report that no gateway is configured.
    """
    if not settings.PAY_API_KEY:
        return
    GATEWAY_MAP["unified"] = UnifiedPayService()
|
||
|
||
|
||
def get_gateway(pay_type: str) -> PaymentGateway:
    """Return the unified gateway, provided it accepts ``pay_type``.

    Every pay type is routed through the single "unified" gateway.

    Raises:
        ValueError: if no gateway is registered (PAY_API_KEY unset), or if the
            registered gateway does not support ``pay_type``.
    """
    unified = GATEWAY_MAP.get("unified")
    if not unified:
        raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
    if unified.supports(pay_type):
        return unified
    raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat)")
|
||
|
||
|
||
def gen_order_no(user_id: str) -> str:
    """Build a merchant order number for ``user_id``.

    Format: ``"TM"`` + an 18-digit UTC timestamp + the user-id tail.
    The timestamp is ``YYYYmmddHHMMSS`` plus the first four digits of the
    microsecond field. The tail is the last eight characters of ``user_id``,
    or the whole id when it is shorter.
    """
    stamp = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:18]
    tail = user_id[-8:]  # slicing already yields the whole string when len < 8
    return "TM" + stamp + tail
|
||
|
||
|
||
class PaymentService:
    """Subscription and credit-purchase payments routed through the gateway.

    All persistence goes through the injected ``AsyncSession``. Methods call
    ``flush()`` but never ``commit()``, so committing is left to the caller.
    Monetary amounts are CNY yuan. They are converted to integer fen only when
    talking to the gateway.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_fen(amount) -> int:
        """Convert a yuan amount (int/float/Decimal) to integer fen for the gateway.

        Rounds instead of truncating. Binary floats such as
        ``0.29 * 100 == 28.999999999999996`` would otherwise lose one fen
        when passed through ``int()``.
        """
        return int(round(amount * 100))

    @staticmethod
    def _credits_from_description(description) -> int:
        """Read the purchased credit count from a credit order's stored metadata.

        ``create_credit_order`` stores the caller metadata as a JSON object in
        ``PaymentTransaction.description``. The function returns 0 when that
        text is empty, is not a JSON object, or has no usable ``credits`` value.
        """
        if not description:
            return 0
        try:
            meta = json.loads(description)
        except (json.JSONDecodeError, TypeError):
            return 0
        if not isinstance(meta, dict):
            return 0
        try:
            return int(meta.get("credits") or 0)
        except (TypeError, ValueError):
            return 0

    async def _find_user(self, user_id):
        """Return the User with ``user_id``, or None."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def _find_subscription(self, order_no: str):
        """Return the Subscription created for ``order_no`` (its payment_id), or None."""
        result = await self.db.execute(
            select(Subscription).where(Subscription.payment_id == order_no)
        )
        return result.scalar_one_or_none()

    async def _find_txn(self, order_no: str, user_id: Optional[str] = None):
        """Return the PaymentTransaction for ``order_no``, or None.

        ``user_id=None`` applies no ownership filter (admin access). Any other
        value, including ``""``, must match the transaction's owner exactly.
        """
        stmt = select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
        if user_id is not None:
            stmt = stmt.where(PaymentTransaction.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def _count_transactions(self, criteria: list) -> int:
        """Return ``COUNT(*)`` of PaymentTransaction rows matching all ``criteria``."""
        from sqlalchemy import func  # local: the module header only imports select/desc

        stmt = select(func.count()).select_from(PaymentTransaction)
        if criteria:
            stmt = stmt.where(*criteria)
        result = await self.db.execute(stmt)
        return result.scalar_one()

    async def _close_pending(self, txn, order_no: str) -> Dict[str, Any]:
        """Close a pending order at the gateway and mark it ``closed`` locally."""
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "pending":
            raise ValueError("只有待支付订单可关闭")
        gw = get_gateway(txn.pay_type)
        await gw.close_order(order_no)
        txn.status = "closed"
        await self.db.flush()
        return {"status": "ok", "order_no": order_no}

    @staticmethod
    def _check_refundable(txn) -> None:
        """Raise ValueError unless ``txn`` exists, is paid, and has an unrefunded balance."""
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "paid":
            raise ValueError("只有已支付订单可退款")
        if (txn.refund_amount or 0) >= txn.amount:
            raise ValueError("该订单已全额退款")

    async def _refund_remaining(self, txn, reason: str, log_prefix: str) -> None:
        """Ask the gateway to refund the unrefunded balance of ``txn``.

        Raises:
            ValueError: if no gateway supports the pay type, or if the gateway
                call fails (the original error is chained as ``__cause__``).
        """
        gw = get_gateway(txn.pay_type)
        remaining = self._to_fen(txn.amount - (txn.refund_amount or 0))
        try:
            gw_result = await gw.refund(txn.order_no, remaining, reason)
        except Exception as e:
            raise ValueError(f"退款请求失败: {e}") from e
        logger.info("%s %s: %s", log_prefix, txn.order_no, gw_result)

    async def _mark_refunded(self, txn, reason: str) -> None:
        """Record a full refund on ``txn``.

        Also expires the linked subscription and drops the user back to
        ``free`` when their tier is the refunded plan.
        """
        txn.status = "refunded"
        txn.refund_amount = txn.amount
        txn.refund_reason = reason
        txn.refunded_at = datetime.utcnow()

        sub = await self._find_subscription(txn.order_no)
        if sub:
            sub.status = "expired"

        user = await self._find_user(txn.user_id)
        if user and user.tier == txn.plan:
            user.tier = "free"
        # NOTE(review): for plan == "credit_purchase" nothing above applies, so
        # credits granted by handle_callback are NOT revoked on refund. Confirm
        # whether CreditService offers a reversal and call it here.
        await self.db.flush()

    async def _activate_subscription(self, txn, now: datetime) -> None:
        """Activate the subscription paid for by ``txn`` and upgrade the user's tier."""
        sub = await self._find_subscription(txn.order_no)
        if sub:
            sub.status = "active"
            sub.started_at = now
            # .get(): an unknown/retired plan simply gets no expiry instead of KeyError.
            duration = PLANS.get(sub.plan, {}).get("duration_days")
            if duration:
                sub.expires_at = now + timedelta(days=duration)

        user = await self._find_user(txn.user_id)
        if user:
            user.tier = txn.plan

    async def _grant_purchased_credits(self, txn) -> None:
        """Credit the buyer of a paid ``credit_purchase`` transaction."""
        # Local import kept from the original (presumably avoids an import cycle).
        from app.services.credit import CreditService

        credits = self._credits_from_description(txn.description)
        if credits <= 0:
            # Fallback when the order metadata carries no credit count: one
            # credit per 0.79 CNY paid, minimum one. float() keeps this working
            # if amount is stored as Decimal (Decimal / float raises TypeError).
            credits = max(1, int(float(txn.amount) / 0.79))

        await CreditService(self.db).add_credits(
            txn.user_id, credits, "package",
            f"支付完成 - 获得 {credits} 次信用额度",
        )

    # ------------------------------------------------------------------ #
    # User-facing API
    # ------------------------------------------------------------------ #

    async def get_plans(self) -> Dict[str, Any]:
        """Return the plan catalogue shown to users plus the available gateway names.

        Prices come from ``PLANS``, so the catalogue cannot drift from what
        ``create_order`` charges. A yearly plan's ``original_price`` is twelve
        times its monthly plan's price, and the "省 ¥N" feature is the
        difference.
        """
        pro_features = ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]
        enterprise_features = ["无限产品/客户", "团队协作", "品牌报价单",
                               "专属语料训练", "API 接入", "优先支持"]

        def monthly(plan_id: str, name: str, features: List[str]) -> Dict[str, Any]:
            return {"id": plan_id, "name": name,
                    "price": PLANS[plan_id]["price"], "period": "month",
                    "features": list(features)}

        def yearly(plan_id: str, name: str, monthly_id: str,
                   features: List[str]) -> Dict[str, Any]:
            price = PLANS[plan_id]["price"]
            list_price = PLANS[monthly_id]["price"] * 12
            return {"id": plan_id, "name": name, "price": price, "period": "year",
                    "original_price": list_price,
                    "features": [*features, f"省 ¥{list_price - price}"]}

        return {
            "plans": [
                monthly("free", "免费版", ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]),
                monthly("pro", "Pro 版", pro_features),
                yearly("pro_yearly", "Pro 版(年付)", "pro", pro_features),
                monthly("enterprise", "企业版", enterprise_features),
                yearly("enterprise_yearly", "企业版(年付)", "enterprise", enterprise_features),
            ],
            "gateways": list(GATEWAY_MAP.keys()) or ["unified"],
        }

    async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
        """Summarise the user's tier and newest ``active`` subscription row.

        ``plan`` is the user's tier (``"free"`` if the user is missing). With no
        active subscription row, status is reported as ``"active"``, with no
        expiry and ``auto_renew`` False.
        """
        result = await self.db.execute(
            select(Subscription).where(
                Subscription.user_id == user_id,
                Subscription.status == "active",
            ).order_by(Subscription.created_at.desc()).limit(1)
        )
        sub = result.scalar_one_or_none()
        user = await self._find_user(user_id)

        return {
            "plan": user.tier if user else "free",
            "status": sub.status if sub else "active",
            "expires_at": sub.expires_at.isoformat() if sub and sub.expires_at else None,
            "auto_renew": sub.auto_renew if sub else False,
        }

    async def create_order(self, user_id: str, plan: str,
                           pay_type: str = "alipay") -> Dict[str, Any]:
        """Start a subscription purchase for ``plan``.

        Free plans are applied immediately (tier change, no gateway order).
        Paid plans create a gateway order plus ``pending`` Subscription and
        PaymentTransaction rows. Activation happens later in
        ``handle_callback``.

        Returns:
            Order summary merged with whatever the gateway returned.

        Raises:
            ValueError: unknown plan, unknown user, or unusable gateway/pay type.
        """
        if plan not in PLANS:
            raise ValueError(f"无效套餐: {plan}")
        plan_info = PLANS[plan]

        user = await self._find_user(user_id)
        if not user:
            raise ValueError("用户不存在")

        if plan_info["price"] == 0:
            user.tier = plan
            await self.db.flush()
            return {"status": "ok", "plan": plan, "amount": 0}

        order_no = gen_order_no(user_id)
        description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
        remark = json.dumps({"uid": user_id, "oid": order_no},
                            ensure_ascii=False, separators=(",", ":"))

        gw = get_gateway(pay_type)
        gw_result = await gw.create_order(order_no, self._to_fen(plan_info["price"]),
                                          description, pay_type=pay_type, remark=remark)

        self.db.add(Subscription(
            user_id=user_id, plan=plan, status="pending",
            amount=plan_info["price"], payment_id=order_no,
            payment_provider="unified",
        ))
        self.db.add(PaymentTransaction(
            user_id=user_id, order_no=order_no, plan=plan,
            amount=plan_info["price"], gateway="unified", pay_type=pay_type,
            status="pending", description=description,
            gateway_order_no=gw_result.get("gateway_order_id", ""),
        ))
        await self.db.flush()

        return {
            "status": "pending",
            "order_id": order_no,
            "plan": plan,
            "amount": plan_info["price"],
            "currency": "CNY",
            "gateway": "unified",
            "pay_type": pay_type,
            **gw_result,
        }

    async def create_credit_order(self, user_id: str, amount: float,
                                  description: str, pay_type: str = "alipay",
                                  metadata: Optional[dict] = None) -> Dict[str, Any]:
        """Start a credit-package purchase of ``amount`` yuan.

        ``metadata`` is stored as JSON in the transaction's ``description``.
        ``handle_callback`` reads its ``credits`` key to decide how many
        credits to grant. It is also echoed into the gateway remark; the
        reserved keys ``uid``/``oid``/``type`` always win over caller values.

        Raises:
            ValueError: non-positive amount, or unusable gateway/pay type.
        """
        if amount is None or amount <= 0:
            raise ValueError(f"无效金额: {amount}")

        extra = dict(metadata or {})
        order_no = gen_order_no(user_id)
        gw = get_gateway(pay_type)

        # Reserved keys last, so caller metadata cannot spoof uid/oid/type.
        remark = {**extra, "uid": user_id, "oid": order_no, "type": "credit_purchase"}
        gw_result = await gw.create_order(order_no, self._to_fen(amount),
                                          description, pay_type=pay_type,
                                          remark=json.dumps(remark, separators=(",", ":")))

        self.db.add(PaymentTransaction(
            user_id=user_id, order_no=order_no, plan="credit_purchase",
            amount=amount, gateway="unified", pay_type=pay_type,
            status="pending", description=json.dumps(extra, ensure_ascii=False),
            gateway_order_no=gw_result.get("gateway_order_id", ""),
        ))
        await self.db.flush()

        return {
            "status": "pending",
            "order_id": order_no,
            "amount": amount,
            "currency": "CNY",
            "gateway": "unified",
            "pay_type": pay_type,
            "metadata": metadata or {},
            **gw_result,
        }

    async def handle_callback(self, order_no: str, gateway_order_id: str,
                              gateway_order_no: str, success: bool,
                              amount: float = 0, notify_raw: str = "") -> bool:
        """Settle a pending order from a gateway notification.

        On success the transaction is marked ``paid``. A credit purchase then
        grants credits; a subscription order activates the subscription and
        upgrades the user. On failure the transaction, and any linked
        subscription, are marked ``failed``.

        The transaction row is locked (``SELECT ... FOR UPDATE``) before its
        status is checked. Two concurrent notifications for the same order
        therefore cannot both see ``pending`` and both grant the purchase.

        NOTE(review): ``amount`` is accepted but not compared with the stored
        order amount. Its unit (yuan vs fen) is not visible here, so confirm
        it and add the check.

        Returns:
            False if no transaction matches ``order_no``. Otherwise True,
            including when the transaction is no longer pending (a repeat
            notification is ignored).
        """
        result = await self.db.execute(
            select(PaymentTransaction)
            .where(PaymentTransaction.order_no == order_no)
            .with_for_update()
        )
        txn = result.scalar_one_or_none()
        if not txn:
            return False
        if txn.status != "pending":
            return True

        if not success:
            txn.status = "failed"
            txn.notify_raw = notify_raw
            if txn.plan != "credit_purchase":
                sub = await self._find_subscription(order_no)
                if sub:
                    sub.status = "failed"
            await self.db.flush()
            return True

        now = datetime.utcnow()
        txn.status = "paid"
        txn.gateway_order_id = gateway_order_id
        txn.gateway_order_no = gateway_order_no
        txn.paid_at = now
        txn.notify_raw = notify_raw

        if txn.plan == "credit_purchase":
            await self._grant_purchased_credits(txn)
        else:
            await self._activate_subscription(txn, now)

        await self.db.flush()
        return True

    async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Return the stored state of one of the user's orders.

        Raises:
            ValueError: if the order does not exist or belongs to another user.
        """
        txn = await self._find_txn(order_no, user_id)
        if not txn:
            raise ValueError("订单不存在")
        return {
            "order_no": txn.order_no, "plan": txn.plan,
            "amount": txn.amount, "currency": txn.currency,
            "gateway": txn.gateway, "pay_type": txn.pay_type,
            "status": txn.status,
            "gateway_order_no": txn.gateway_order_no,
            "paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
            "refund_amount": txn.refund_amount,
            "created_at": txn.created_at.isoformat(),
        }

    async def close_order(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Close one of the user's pending orders at the gateway and locally.

        Raises:
            ValueError: unknown order, order not pending, or unusable gateway.
        """
        txn = await self._find_txn(order_no, user_id)
        return await self._close_pending(txn, order_no)

    async def query_refund(self, order_no: str, user_id: str = "") -> Dict[str, Any]:
        """Return local refund details plus the gateway's refund query result.

        An empty ``user_id`` skips the ownership filter.

        Raises:
            ValueError: unknown order, order not refunded, or unusable gateway.
        """
        txn = await self._find_txn(order_no, user_id or None)
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "refunded":
            raise ValueError("该订单未退款")

        gw = get_gateway(txn.pay_type)
        gw_result = await gw.query_refund(order_no)
        return {
            "order_no": order_no,
            "status": txn.status,
            "refund_amount": txn.refund_amount,
            "refund_reason": txn.refund_reason,
            "refunded_at": txn.refunded_at.isoformat() if txn.refunded_at else None,
            "gateway": gw_result,
        }

    async def list_transactions(self, user_id: str,
                                page: int = 1, size: int = 20) -> Dict[str, Any]:
        """Return one page of the user's transactions, newest first, plus the total count.

        ``page`` and ``size`` are clamped to at least 1, so the SQL OFFSET is
        never negative.
        """
        page, size = max(page, 1), max(size, 1)
        criteria = [PaymentTransaction.user_id == user_id]

        total = await self._count_transactions(criteria)
        result = await self.db.execute(
            select(PaymentTransaction).where(*criteria)
            .order_by(desc(PaymentTransaction.created_at))
            .offset((page - 1) * size).limit(size)
        )
        items = result.scalars().all()

        return {
            "items": [{
                "order_no": t.order_no, "plan": t.plan,
                "amount": t.amount, "gateway": t.gateway,
                "pay_type": t.pay_type, "status": t.status,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
            } for t in items],
            "total": total, "page": page, "size": size,
        }

    async def refund(self, user_id: str, order_no: str,
                     reason: str = "") -> Dict[str, Any]:
        """Fully refund one of the user's paid orders.

        Raises:
            ValueError: unknown order, order not paid or already fully
                refunded, unusable gateway, or gateway refund failure.
        """
        txn = await self._find_txn(order_no, user_id)
        self._check_refundable(txn)
        await self._refund_remaining(txn, reason, "Refund")
        await self._mark_refunded(txn, reason)
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}

    # ------------------------------------------------------------------ #
    # Admin API
    # ------------------------------------------------------------------ #

    async def admin_list_payments(self, page: int = 1, size: int = 20,
                                  gateway: str = "", status: str = "",
                                  user_id: str = "",
                                  pay_type: str = "") -> Dict[str, Any]:
        """Return one page of all transactions, newest first, with optional filters.

        Each non-empty filter argument adds an equality condition. ``page`` and
        ``size`` are clamped to at least 1.
        """
        page, size = max(page, 1), max(size, 1)
        criteria = [
            column == value
            for column, value in (
                (PaymentTransaction.gateway, gateway),
                (PaymentTransaction.status, status),
                (PaymentTransaction.user_id, user_id),
                (PaymentTransaction.pay_type, pay_type),
            )
            if value
        ]

        total = await self._count_transactions(criteria)
        query = select(PaymentTransaction)
        if criteria:
            query = query.where(*criteria)
        result = await self.db.execute(
            query.order_by(desc(PaymentTransaction.created_at))
            .offset((page - 1) * size).limit(size)
        )
        items = result.scalars().all()

        return {
            "items": [{
                "id": str(t.id), "user_id": str(t.user_id),
                "order_no": t.order_no, "plan": t.plan,
                "amount": t.amount, "gateway": t.gateway,
                "pay_type": t.pay_type, "status": t.status,
                "gateway_order_no": t.gateway_order_no,
                "refund_amount": t.refund_amount,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
                "refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
            } for t in items],
            "total": total, "page": page, "size": size,
        }

    async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
        """Fully refund any paid order, regardless of its owner.

        Raises:
            ValueError: unknown order, order not paid or already fully
                refunded, unusable gateway, or gateway refund failure.
        """
        txn = await self._find_txn(order_no)
        self._check_refundable(txn)
        await self._refund_remaining(txn, reason, "Admin refund")
        await self._mark_refunded(txn, reason)
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
                "user_id": str(txn.user_id)}

    async def admin_close_order(self, order_no: str) -> Dict[str, Any]:
        """Close any pending order at the gateway and locally.

        Raises:
            ValueError: unknown order, order not pending, or unusable gateway.
        """
        txn = await self._find_txn(order_no)
        return await self._close_pending(txn, order_no)

    async def admin_payment_stats(self) -> Dict[str, Any]:
        """Aggregate counts and money totals over every payment transaction.

        NOTE(review): loads every row into memory. Move this to SQL aggregates
        if the table grows large.
        """
        result = await self.db.execute(select(PaymentTransaction))
        rows = result.scalars().all()

        def count(pred) -> int:
            return sum(1 for r in rows if pred(r))

        return {
            "total_count": len(rows),
            "total_revenue": sum(r.amount for r in rows if r.status == "paid"),
            "total_refund": sum(r.refund_amount or 0 for r in rows),
            "paid_count": count(lambda r: r.status == "paid"),
            "pending_count": count(lambda r: r.status == "pending"),
            "refunded_count": count(lambda r: r.status == "refunded"),
            "failed_count": count(lambda r: r.status == "failed"),
            "wechat_count": count(lambda r: r.gateway == "unified" and r.pay_type == "wechat"),
            "alipay_count": count(lambda r: r.gateway == "unified" and r.pay_type == "alipay"),
        }
|
||
|
||
|
||
# Import-time side effect: register the unified gateway in GATEWAY_MAP.
# This only happens when settings.PAY_API_KEY is set.
init_gateways()
|