Files
trade-assistant/backend/app/services/payment.py
T
TradeMate Dev 2a107a42f3 feat: credit-based billing system
- New DB models: credit_packages, subscription_plans, user_credits, credit_consumptions, credit_purchases
- CreditService: balance, deduct, add_credits, grant_free_trial, history
- User API: /api/v1/credits/* (balance/history/packages/purchase/subscribe)
- Admin API: /api/v1/admin/credit-* (CRUD packages/plans, user credits, consumptions)
- PaymentService.create_credit_order + handle_callback for credit purchases
- Credit deduction on: discovery, translate, marketing, ai_chat, followup
- Free trial 30 credits on registration
- Documentation: docs/CREDIT_SYSTEM.md
2026-06-12 10:39:45 +08:00

489 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import hashlib
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.payment_transaction import PaymentTransaction
from app.models.subscription import Subscription
from app.models.user import User
from app.services.payment_gateway import PaymentGateway
from app.services.unified_pay import UnifiedPayService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Billing plans keyed by plan id. ``duration_days`` is the validity period
# granted by one payment; ``None`` means the plan carries no fixed duration.
PLANS = {
    "free": {"price": 0, "duration_days": None},
    "pro": {"price": 99, "duration_days": 30},
    "pro_yearly": {"price": 999, "duration_days": 365},
    "enterprise": {"price": 399, "duration_days": 30},
    "enterprise_yearly": {"price": 3999, "duration_days": 365},
}

# Order descriptions for the paid plans, keyed by plan id.
PLAN_DESCRIPTIONS = {
    "pro": "TradeMate Pro 版会员",
    "pro_yearly": "TradeMate Pro 版会员(年付)",
    "enterprise": "TradeMate 企业版会员",
    "enterprise_yearly": "TradeMate 企业版会员(年付)",
}
# Registry of payment gateway instances, keyed by gateway name.
GATEWAY_MAP: Dict[str, PaymentGateway] = {}


def init_gateways():
    """Register the unified gateway when ``settings.PAY_API_KEY`` is set."""
    if not settings.PAY_API_KEY:
        # No API key configured: leave the registry untouched.
        return
    GATEWAY_MAP["unified"] = UnifiedPayService()
def get_gateway(pay_type: str) -> PaymentGateway:
    """Return the registered gateway that can process ``pay_type``.

    Only the ``"unified"`` entry of ``GATEWAY_MAP`` is consulted.

    Raises:
        ValueError: if no unified gateway is registered, or if the gateway
            reports that it does not support ``pay_type``.
    """
    gateway = GATEWAY_MAP.get("unified")
    if not gateway:
        raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
    if not gateway.supports(pay_type):
        # The original message opened a parenthesis it never closed.
        raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat)")
    return gateway
def gen_order_no(user_id: str) -> str:
    """Build an order number: ``"TM"`` + UTC timestamp + tail of the user id.

    The timestamp is ``YYYYmmddHHMMSS`` followed by the first four
    microsecond digits (18 characters in total). The suffix is the last
    eight characters of ``user_id``, or the whole id when it is shorter.

    ``user_id`` is coerced with ``str()`` so non-string ids (e.g. UUID
    objects) are accepted as well.

    NOTE(review): two orders created for the same user within the same
    100 µs window get the same number.
    """
    # Same wall-clock value as the deprecated datetime.utcnow(); only the
    # formatted string is used, so the attached tzinfo is irrelevant here.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")[:18]
    # Slicing already yields the whole string when it is shorter than 8.
    return f"TM{stamp}{str(user_id)[-8:]}"
class PaymentService:
    """Subscription and credit-purchase payment operations on one DB session.

    Methods stage their changes with ``flush()`` only; nothing in this class
    commits the session.
    """

    # Price per credit used to estimate how many credits a paid credit order
    # is worth when its stored metadata carries no usable "credits" value.
    _FALLBACK_CREDIT_PRICE = 0.79

    def __init__(self, db: "AsyncSession"):
        self.db = db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as the deprecated ``datetime.utcnow()``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _to_minor_units(amount: float) -> int:
        """Convert ``amount`` to integer hundredths, rounding to nearest.

        Plain ``int(amount * 100)`` truncates, so binary float error turns
        e.g. 19.99 into 1998 instead of 1999.
        """
        return int(round(amount * 100))

    async def _find_transaction(
        self, order_no: str, user_id: Optional[str] = None
    ) -> Optional["PaymentTransaction"]:
        """Load the transaction for ``order_no``.

        When ``user_id`` is not ``None`` the lookup is additionally
        restricted to that user (an empty string still restricts).
        """
        stmt = select(PaymentTransaction).where(
            PaymentTransaction.order_no == order_no
        )
        if user_id is not None:
            stmt = stmt.where(PaymentTransaction.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def _find_subscription(self, order_no: str) -> Optional["Subscription"]:
        """Load the subscription whose ``payment_id`` equals ``order_no``."""
        result = await self.db.execute(
            select(Subscription).where(Subscription.payment_id == order_no)
        )
        return result.scalar_one_or_none()

    async def _find_user(self, user_id: str) -> Optional["User"]:
        """Load the user with the given id, or ``None``."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def _count_transactions(self, *conditions: Any) -> int:
        """Count transactions matching ``conditions`` with SQL ``COUNT(*)``."""
        stmt = select(func.count()).select_from(PaymentTransaction)
        if conditions:
            stmt = stmt.where(*conditions)
        result = await self.db.execute(stmt)
        return result.scalar_one()

    def _purchased_credits(self, txn: "PaymentTransaction") -> int:
        """Return how many credits a paid credit order grants.

        Reads ``"credits"`` from the JSON stored in ``txn.description``;
        when that is missing, unparsable, not an object, or falsy, falls
        back to ``amount / _FALLBACK_CREDIT_PRICE`` (at least 1).
        """
        credits = 0
        if txn.description:
            try:
                meta = json.loads(txn.description)
            except (json.JSONDecodeError, TypeError):
                meta = None
            if isinstance(meta, dict):
                credits = meta.get("credits", 0)
        if not credits:
            credits = max(1, int(txn.amount / self._FALLBACK_CREDIT_PRICE))
        return credits

    async def _activate_subscription(
        self, txn: "PaymentTransaction", now: datetime
    ) -> None:
        """Activate the subscription for a paid order and set the user tier."""
        sub = await self._find_subscription(txn.order_no)
        if sub:
            sub.status = "active"
            sub.started_at = now
            # .get() keeps an unknown plan id from aborting the callback.
            duration_days = PLANS.get(sub.plan, {}).get("duration_days")
            if duration_days:
                sub.expires_at = now + timedelta(days=duration_days)
        user = await self._find_user(txn.user_id)
        if user:
            user.tier = txn.plan

    async def _close_pending(
        self, txn: Optional["PaymentTransaction"], order_no: str
    ) -> Dict[str, Any]:
        """Close a pending order at the gateway and mark it ``closed``.

        Raises:
            ValueError: if ``txn`` is missing or not in ``pending`` status.
        """
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "pending":
            raise ValueError("只有待支付订单可关闭")
        gateway = get_gateway(txn.pay_type)
        await gateway.close_order(order_no)
        txn.status = "closed"
        await self.db.flush()
        return {"status": "ok", "order_no": order_no}

    async def _refund_paid(
        self,
        txn: Optional["PaymentTransaction"],
        order_no: str,
        reason: str,
        log_label: str,
    ) -> "PaymentTransaction":
        """Refund the outstanding amount of a paid order.

        Requests the refund from the gateway, marks the transaction
        ``refunded``, expires the linked subscription (if any) and resets
        the user's tier to ``free`` when it still equals the refunded plan.

        NOTE(review): for ``credit_purchase`` orders the granted credits
        are not taken back here.

        Raises:
            ValueError: if the order is missing, not ``paid``, already
                fully refunded, or the gateway request fails.
        """
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "paid":
            raise ValueError("只有已支付订单可退款")
        already_refunded = txn.refund_amount or 0
        if already_refunded >= txn.amount:
            raise ValueError("该订单已全额退款")

        gateway = get_gateway(txn.pay_type)
        remaining = self._to_minor_units(txn.amount - already_refunded)
        try:
            gw_result = await gateway.refund(txn.order_no, remaining, reason)
        except Exception as e:
            raise ValueError(f"退款请求失败: {e}") from e
        logger.info("%s %s: %s", log_label, txn.order_no, gw_result)

        txn.status = "refunded"
        txn.refund_amount = txn.amount
        txn.refund_reason = reason
        txn.refunded_at = self._utcnow()

        sub = await self._find_subscription(order_no)
        if sub:
            sub.status = "expired"
        user = await self._find_user(txn.user_id)
        if user and user.tier == txn.plan:
            user.tier = "free"
        await self.db.flush()
        return txn

    # ------------------------------------------------------------------
    # User-facing operations
    # ------------------------------------------------------------------

    async def get_plans(self) -> Dict[str, Any]:
        """Return the plan catalogue and the names of registered gateways.

        Prices are read from ``PLANS`` so the catalogue cannot drift from
        the amounts actually charged. ``gateways`` falls back to
        ``["unified"]`` when no gateway is registered.
        """
        return {
            "plans": [
                {"id": "free", "name": "免费版",
                 "price": PLANS["free"]["price"], "period": "month",
                 "features": ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]},
                {"id": "pro", "name": "Pro 版",
                 "price": PLANS["pro"]["price"], "period": "month",
                 "features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]},
                {"id": "pro_yearly", "name": "Pro 版(年付)",
                 "price": PLANS["pro_yearly"]["price"], "period": "year",
                 "original_price": 1188,
                 "features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成", "省 ¥189"]},
                {"id": "enterprise", "name": "企业版",
                 "price": PLANS["enterprise"]["price"], "period": "month",
                 "features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持"]},
                {"id": "enterprise_yearly", "name": "企业版(年付)",
                 "price": PLANS["enterprise_yearly"]["price"], "period": "year",
                 "original_price": 4788,
                 "features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持", "省 ¥789"]},
            ],
            "gateways": list(GATEWAY_MAP.keys()) or ["unified"],
        }

    async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
        """Describe the user's current plan.

        ``plan`` comes from ``User.tier`` (``"free"`` when the user is not
        found); the remaining fields come from the most recently created
        ``active`` subscription, with defaults when there is none.
        """
        result = await self.db.execute(
            select(Subscription)
            .where(
                Subscription.user_id == user_id,
                Subscription.status == "active",
            )
            .order_by(Subscription.created_at.desc())
            .limit(1)
        )
        sub = result.scalar_one_or_none()
        user = await self._find_user(user_id)
        return {
            "plan": user.tier if user else "free",
            "status": sub.status if sub else "active",
            "expires_at": sub.expires_at.isoformat() if sub and sub.expires_at else None,
            "auto_renew": sub.auto_renew if sub else False,
        }

    async def create_order(self, user_id: str, plan: str,
                           pay_type: str = "alipay") -> Dict[str, Any]:
        """Create a subscription order for ``plan``.

        A zero-priced plan is applied to the user immediately. Otherwise an
        order is created at the gateway and a pending ``Subscription`` and
        ``PaymentTransaction`` are staged.

        Returns:
            Order details merged with the gateway's response (gateway keys
            win on collision).

        Raises:
            ValueError: unknown plan, unknown user, or no usable gateway.
        """
        if plan not in PLANS:
            raise ValueError(f"无效套餐: {plan}")
        price = PLANS[plan]["price"]

        user = await self._find_user(user_id)
        if not user:
            raise ValueError("用户不存在")

        if price == 0:
            user.tier = plan
            await self.db.flush()
            return {"status": "ok", "plan": plan, "amount": 0}

        order_no = gen_order_no(user_id)
        description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
        remark = json.dumps({"uid": user_id, "oid": order_no},
                            ensure_ascii=False, separators=(",", ":"))
        gateway = get_gateway(pay_type)
        gw_result = await gateway.create_order(
            order_no, self._to_minor_units(price), description,
            pay_type=pay_type, remark=remark,
        )

        self.db.add(Subscription(
            user_id=user_id, plan=plan, status="pending",
            amount=price, payment_id=order_no,
            payment_provider="unified",
        ))
        self.db.add(PaymentTransaction(
            user_id=user_id, order_no=order_no, plan=plan,
            amount=price, gateway="unified", pay_type=pay_type,
            status="pending", description=description,
            gateway_order_no=gw_result.get("gateway_order_id", ""),
        ))
        await self.db.flush()
        return {
            "status": "pending",
            "order_id": order_no,
            "plan": plan,
            "amount": price,
            "currency": "CNY",
            "gateway": "unified",
            "pay_type": pay_type,
            **gw_result,
        }

    async def create_credit_order(self, user_id: str, amount: float,
                                  description: str, pay_type: str = "alipay",
                                  metadata: Optional[dict] = None) -> Dict[str, Any]:
        """Create a gateway order for a credit purchase.

        ``description`` is sent to the gateway. ``metadata`` is merged into
        the gateway remark and stored as JSON in the transaction's
        ``description`` column, where ``handle_callback`` later reads the
        ``"credits"`` value.

        Raises:
            ValueError: if no usable gateway supports ``pay_type``.
        """
        order_no = gen_order_no(user_id)
        gateway = get_gateway(pay_type)

        remark = {"uid": user_id, "oid": order_no, "type": "credit_purchase"}
        if metadata:
            remark.update(metadata)

        # Rounded, not truncated: int(19.99 * 100) would charge 1998.
        gw_result = await gateway.create_order(
            order_no, self._to_minor_units(amount), description,
            pay_type=pay_type,
            remark=json.dumps(remark, separators=(",", ":")),
        )

        self.db.add(PaymentTransaction(
            user_id=user_id, order_no=order_no, plan="credit_purchase",
            amount=amount, gateway="unified", pay_type=pay_type,
            status="pending",
            description=json.dumps(metadata or {}, ensure_ascii=False),
            gateway_order_no=gw_result.get("gateway_order_id", ""),
        ))
        await self.db.flush()
        return {
            "status": "pending",
            "order_id": order_no,
            "amount": amount,
            "currency": "CNY",
            "gateway": "unified",
            "pay_type": pay_type,
            "metadata": metadata or {},
            **gw_result,
        }

    async def handle_callback(self, order_no: str, gateway_order_id: str,
                              gateway_order_no: str, success: bool,
                              amount: float = 0, notify_raw: str = "") -> bool:
        """Apply a gateway payment notification to the matching order.

        Returns ``False`` when the order is unknown and ``True`` otherwise.
        An order that is no longer ``pending`` is left untouched, so repeated
        notifications are not applied twice.

        On success the transaction becomes ``paid`` and either credits are
        granted (``credit_purchase``) or the subscription is activated and
        the user's tier updated. On failure the transaction, and any linked
        subscription, becomes ``failed``.

        NOTE(review): ``amount`` is accepted but never compared with the
        stored order amount.
        NOTE(review): the row is read without a lock; two notifications
        processed concurrently could both observe ``pending``.
        """
        txn = await self._find_transaction(order_no)
        if not txn:
            return False
        if txn.status != "pending":
            return True

        txn.notify_raw = notify_raw
        if success:
            now = self._utcnow()
            txn.status = "paid"
            txn.gateway_order_id = gateway_order_id
            txn.gateway_order_no = gateway_order_no
            txn.paid_at = now
            if txn.plan == "credit_purchase":
                # Imported here rather than at module top, as in the original.
                from app.services.credit import CreditService

                credits = self._purchased_credits(txn)
                await CreditService(self.db).add_credits(
                    txn.user_id, credits, "package",
                    f"支付完成 - 获得 {credits} 次信用额度"
                )
            else:
                await self._activate_subscription(txn, now)
        else:
            txn.status = "failed"
            if txn.plan != "credit_purchase":
                sub = await self._find_subscription(order_no)
                if sub:
                    sub.status = "failed"

        await self.db.flush()
        return True

    async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Return the details of one of the user's orders.

        Raises:
            ValueError: if the user has no order with that number.
        """
        txn = await self._find_transaction(order_no, user_id)
        if not txn:
            raise ValueError("订单不存在")
        return {
            "order_no": txn.order_no, "plan": txn.plan,
            "amount": txn.amount, "currency": txn.currency,
            "gateway": txn.gateway, "pay_type": txn.pay_type,
            "status": txn.status,
            "gateway_order_no": txn.gateway_order_no,
            "paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
            "refund_amount": txn.refund_amount,
            "created_at": txn.created_at.isoformat(),
        }

    async def close_order(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Close one of the user's pending orders.

        Raises:
            ValueError: if the order is missing or not pending.
        """
        txn = await self._find_transaction(order_no, user_id)
        return await self._close_pending(txn, order_no)

    async def query_refund(self, order_no: str, user_id: str = "") -> Dict[str, Any]:
        """Return local and gateway refund details for a refunded order.

        The lookup is restricted to ``user_id`` only when it is non-empty.

        Raises:
            ValueError: if the order is missing or has not been refunded.
        """
        txn = await self._find_transaction(order_no, user_id or None)
        if not txn:
            raise ValueError("订单不存在")
        if txn.status != "refunded":
            raise ValueError("该订单未退款")
        gateway = get_gateway(txn.pay_type)
        gw_result = await gateway.query_refund(order_no)
        return {
            "order_no": order_no,
            "status": txn.status,
            "refund_amount": txn.refund_amount,
            "refund_reason": txn.refund_reason,
            "refunded_at": txn.refunded_at.isoformat() if txn.refunded_at else None,
            "gateway": gw_result,
        }

    async def list_transactions(self, user_id: str,
                                page: int = 1, size: int = 20) -> Dict[str, Any]:
        """Return one page of the user's transactions, newest first."""
        owned_by_user = PaymentTransaction.user_id == user_id
        # COUNT(*) in the database instead of loading every id.
        total = await self._count_transactions(owned_by_user)
        result = await self.db.execute(
            select(PaymentTransaction)
            .where(owned_by_user)
            .order_by(desc(PaymentTransaction.created_at))
            .offset((page - 1) * size)
            .limit(size)
        )
        return {
            "items": [{
                "order_no": t.order_no, "plan": t.plan,
                "amount": t.amount, "gateway": t.gateway,
                "pay_type": t.pay_type, "status": t.status,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
            } for t in result.scalars().all()],
            "total": total, "page": page, "size": size,
        }

    async def refund(self, user_id: str, order_no: str,
                     reason: str = "") -> Dict[str, Any]:
        """Refund one of the user's paid orders in full.

        Raises:
            ValueError: if the order is missing, not paid, already fully
                refunded, or the gateway request fails.
        """
        txn = await self._find_transaction(order_no, user_id)
        txn = await self._refund_paid(txn, order_no, reason, "Refund")
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}

    # ------------------------------------------------------------------
    # Admin operations
    # ------------------------------------------------------------------

    async def admin_list_payments(self, page: int = 1, size: int = 20,
                                  gateway: str = "", status: str = "",
                                  user_id: str = "",
                                  pay_type: str = "") -> Dict[str, Any]:
        """Return one page of all transactions, newest first.

        Each non-empty filter argument restricts the listing and the total.
        """
        requested = (
            (PaymentTransaction.gateway, gateway),
            (PaymentTransaction.status, status),
            (PaymentTransaction.user_id, user_id),
            (PaymentTransaction.pay_type, pay_type),
        )
        filters = [column == value for column, value in requested if value]

        total = await self._count_transactions(*filters)
        stmt = select(PaymentTransaction).order_by(desc(PaymentTransaction.created_at))
        if filters:
            stmt = stmt.where(*filters)
        result = await self.db.execute(stmt.offset((page - 1) * size).limit(size))
        return {
            "items": [{
                "id": str(t.id), "user_id": str(t.user_id),
                "order_no": t.order_no, "plan": t.plan,
                "amount": t.amount, "gateway": t.gateway,
                "pay_type": t.pay_type, "status": t.status,
                "gateway_order_no": t.gateway_order_no,
                "refund_amount": t.refund_amount,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
                "refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
            } for t in result.scalars().all()],
            "total": total, "page": page, "size": size,
        }

    async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
        """Refund any paid order in full, regardless of owner.

        Raises:
            ValueError: if the order is missing, not paid, already fully
                refunded, or the gateway request fails.
        """
        txn = await self._find_transaction(order_no)
        txn = await self._refund_paid(txn, order_no, reason, "Admin refund")
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
                "user_id": str(txn.user_id)}

    async def admin_close_order(self, order_no: str) -> Dict[str, Any]:
        """Close any pending order, regardless of owner.

        Raises:
            ValueError: if the order is missing or not pending.
        """
        txn = await self._find_transaction(order_no)
        return await self._close_pending(txn, order_no)

    async def admin_payment_stats(self) -> Dict[str, Any]:
        """Aggregate counts and amounts over every transaction.

        Revenue sums the amount of ``paid`` rows; refunds sum
        ``refund_amount`` of all rows. The wechat/alipay counts only
        consider rows whose gateway is ``unified``.

        NOTE(review): aggregation still happens in Python over all rows.
        """
        # Only the five columns needed are fetched, not full ORM entities.
        result = await self.db.execute(select(
            PaymentTransaction.status,
            PaymentTransaction.amount,
            PaymentTransaction.refund_amount,
            PaymentTransaction.gateway,
            PaymentTransaction.pay_type,
        ))

        total_count = 0
        total_revenue = 0
        total_refund = 0
        by_status = {"paid": 0, "pending": 0, "refunded": 0, "failed": 0}
        by_pay_type = {"wechat": 0, "alipay": 0}
        for status, amount, refund_amount, gateway, pay_type in result.all():
            total_count += 1
            if status in by_status:
                by_status[status] += 1
            if status == "paid":
                total_revenue += amount or 0
            total_refund += refund_amount or 0
            if gateway == "unified" and pay_type in by_pay_type:
                by_pay_type[pay_type] += 1

        return {
            "total_count": total_count, "total_revenue": total_revenue,
            "total_refund": total_refund, "paid_count": by_status["paid"],
            "pending_count": by_status["pending"],
            "refunded_count": by_status["refunded"],
            "failed_count": by_status["failed"],
            "wechat_count": by_pay_type["wechat"],
            "alipay_count": by_pay_type["alipay"],
        }
# Populate GATEWAY_MAP once, when this module is imported.
init_gateways()