refactor: replace direct WeChat/Alipay with unified pay-api gateway
Switch from direct WeChat Pay / Alipay integrations to the unified
宇之然 pay-api gateway (HMAC-SHA256 auth). Removes wechat_pay.py,
keeps PaymentGateway abstraction, adds UnifiedPayService. Simplifies
payment.py create_order to {plan, pay_type} params. Single webhook
endpoint replaces separate WeChat/Alipay notify handlers.
This commit is contained in:
+284
-171
@@ -1,13 +1,15 @@
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, desc
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.payment_transaction import PaymentTransaction
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
from app.services.wechat_pay import WeChatPayService
|
||||
from app.services.unified_pay import UnifiedPayService
|
||||
from app.services.payment_gateway import PaymentGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,92 +28,50 @@ PLAN_DESCRIPTIONS = {
|
||||
"enterprise_yearly": "TradeMate 企业版会员(年付)",
|
||||
}
|
||||
|
||||
GATEWAY_MAP: Dict[str, PaymentGateway] = {}
|
||||
|
||||
|
||||
def init_gateways():
|
||||
if settings.PAY_API_KEY:
|
||||
GATEWAY_MAP["unified"] = UnifiedPayService()
|
||||
|
||||
|
||||
def get_gateway(pay_type: str) -> PaymentGateway:
|
||||
gw = GATEWAY_MAP.get("unified")
|
||||
if not gw:
|
||||
raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
|
||||
if not gw.supports(pay_type):
|
||||
raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat)")
|
||||
return gw
|
||||
|
||||
|
||||
def gen_order_no(user_id: str) -> str:
|
||||
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:18]
|
||||
suffix = user_id[-8:] if len(user_id) >= 8 else user_id
|
||||
return f"TM{ts}{suffix}"
|
||||
|
||||
|
||||
class PaymentService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self._wxpay = None
|
||||
|
||||
@property
|
||||
def wxpay(self) -> Optional[WeChatPayService]:
|
||||
if self._wxpay is None and settings.WECHAT_PAY_MCH_ID:
|
||||
self._wxpay = WeChatPayService()
|
||||
return self._wxpay
|
||||
|
||||
async def get_plans(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"plans": [
|
||||
{
|
||||
"id": "free",
|
||||
"name": "免费版",
|
||||
"price": 0,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"1 个产品",
|
||||
"20 次翻译/天",
|
||||
"5 个客户",
|
||||
"基础回复建议",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro",
|
||||
"name": "Pro 版",
|
||||
"price": 99,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro_yearly",
|
||||
"name": "Pro 版(年付)",
|
||||
"price": 999,
|
||||
"period": "year",
|
||||
"original_price": 1188,
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
"省 ¥189",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise",
|
||||
"name": "企业版",
|
||||
"price": 399,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise_yearly",
|
||||
"name": "企业版(年付)",
|
||||
"price": 3999,
|
||||
"period": "year",
|
||||
"original_price": 4788,
|
||||
"features": [
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
"省 ¥789",
|
||||
],
|
||||
},
|
||||
{"id": "free", "name": "免费版", "price": 0, "period": "month",
|
||||
"features": ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]},
|
||||
{"id": "pro", "name": "Pro 版", "price": 99, "period": "month",
|
||||
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]},
|
||||
{"id": "pro_yearly", "name": "Pro 版(年付)", "price": 999, "period": "year",
|
||||
"original_price": 1188,
|
||||
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成", "省 ¥189"]},
|
||||
{"id": "enterprise", "name": "企业版", "price": 399, "period": "month",
|
||||
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持"]},
|
||||
{"id": "enterprise_yearly", "name": "企业版(年付)", "price": 3999, "period": "year",
|
||||
"original_price": 4788,
|
||||
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持", "省 ¥789"]},
|
||||
],
|
||||
"gateways": list(GATEWAY_MAP.keys()) or ["unified"],
|
||||
}
|
||||
|
||||
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
|
||||
@@ -122,12 +82,8 @@ class PaymentService:
|
||||
).order_by(Subscription.created_at.desc()).limit(1)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"plan": user.tier if user else "free",
|
||||
"status": sub.status if sub else "active",
|
||||
@@ -136,124 +92,281 @@ class PaymentService:
|
||||
}
|
||||
|
||||
async def create_order(self, user_id: str, plan: str,
|
||||
pay_type: str = "jsapi") -> Dict[str, Any]:
|
||||
pay_type: str = "alipay") -> Dict[str, Any]:
|
||||
if plan not in PLANS:
|
||||
raise ValueError(f"Invalid plan: {plan}")
|
||||
|
||||
raise ValueError(f"无效套餐: {plan}")
|
||||
plan_info = PLANS[plan]
|
||||
if plan_info["price"] == 0:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = plan
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "plan": plan, "amount": 0}
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
order_id = f"ORD{datetime.utcnow().strftime('%Y%m%d%H%M%S')}{user_id[-6:]}"
|
||||
if plan_info["price"] == 0:
|
||||
user.tier = plan
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "plan": plan, "amount": 0}
|
||||
|
||||
order_no = gen_order_no(user_id)
|
||||
description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
|
||||
|
||||
gw = get_gateway(pay_type)
|
||||
gw_result = await gw.create_order(order_no, int(plan_info["price"] * 100),
|
||||
description, pay_type=pay_type)
|
||||
|
||||
sub = Subscription(
|
||||
user_id=user_id,
|
||||
plan=plan,
|
||||
status="pending",
|
||||
amount=plan_info["price"],
|
||||
payment_id=order_id,
|
||||
user_id=user_id, plan=plan, status="pending",
|
||||
amount=plan_info["price"], payment_id=order_no,
|
||||
payment_provider="unified",
|
||||
)
|
||||
self.db.add(sub)
|
||||
|
||||
txn = PaymentTransaction(
|
||||
user_id=user_id, order_no=order_no, plan=plan,
|
||||
amount=plan_info["price"], gateway="unified", pay_type=pay_type,
|
||||
status="pending", description=description,
|
||||
gateway_order_no=gw_result.get("gateway_order_id", ""),
|
||||
)
|
||||
self.db.add(txn)
|
||||
await self.db.flush()
|
||||
|
||||
wxpay_available = self.wxpay is not None and settings.WECHAT_PAY_NOTIFY_URL not in (
|
||||
"", "https://example.com/api/v1/payment/notify"
|
||||
)
|
||||
|
||||
if wxpay_available:
|
||||
try:
|
||||
if pay_type == "jsapi":
|
||||
openid = user.wechat_openid
|
||||
if not openid:
|
||||
raise ValueError("用户未绑定微信,请在微信小程序中登录后支付")
|
||||
|
||||
wx_result = await self.wxpay.create_jsapi_order(
|
||||
order_id, openid, int(plan_info["price"] * 100), description
|
||||
)
|
||||
prepay_id = wx_result.get("prepay_id", "")
|
||||
pay_params = self.wxpay.build_jsapi_pay_params(prepay_id)
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"pay_type": "jsapi",
|
||||
"pay_params": pay_params,
|
||||
}
|
||||
|
||||
elif pay_type == "native":
|
||||
wx_result = await self.wxpay.create_native_order(
|
||||
order_id, int(plan_info["price"] * 100), description
|
||||
)
|
||||
code_url = wx_result.get("code_url", "")
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"pay_type": "native",
|
||||
"code_url": code_url,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"WeChat Pay order failed: {e}")
|
||||
raise ValueError(f"支付创建失败: {str(e)}")
|
||||
|
||||
# 开发环境回退:生成模拟支付参数
|
||||
pay_params = {
|
||||
"appId": settings.WECHAT_APP_ID or "",
|
||||
"timeStamp": str(int(datetime.utcnow().timestamp())),
|
||||
"nonceStr": hashlib.md5(order_id.encode()).hexdigest()[:16],
|
||||
"package": f"prepay_id={order_id}",
|
||||
"signType": "MD5",
|
||||
}
|
||||
sign_str = "&".join(f"{k}={v}" for k, v in sorted(pay_params.items()))
|
||||
sign_str += f"&key={settings.SECRET_KEY}"
|
||||
pay_params["paySign"] = hashlib.md5(sign_str.encode()).hexdigest().upper()
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"order_id": order_no,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"gateway": "unified",
|
||||
"pay_type": pay_type,
|
||||
"pay_params": pay_params,
|
||||
**gw_result,
|
||||
}
|
||||
|
||||
async def handle_payment_callback(self, payment_id: str, success: bool) -> bool:
|
||||
async def handle_callback(self, order_no: str, gateway_order_id: str,
|
||||
gateway_order_no: str, success: bool,
|
||||
amount: float = 0, notify_raw: str = "") -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == payment_id)
|
||||
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if not sub:
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
return False
|
||||
if txn.status != "pending":
|
||||
return True
|
||||
|
||||
if success:
|
||||
sub.status = "active"
|
||||
sub.started_at = datetime.utcnow()
|
||||
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
|
||||
txn.status = "paid"
|
||||
txn.gateway_order_id = gateway_order_id
|
||||
txn.gateway_order_no = gateway_order_no
|
||||
txn.paid_at = datetime.utcnow()
|
||||
txn.notify_raw = notify_raw
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == sub.user_id))
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "active"
|
||||
sub.started_at = datetime.utcnow()
|
||||
if PLANS[sub.plan]["duration_days"]:
|
||||
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = sub.plan
|
||||
user.tier = txn.plan
|
||||
else:
|
||||
sub.status = "failed"
|
||||
txn.status = "failed"
|
||||
txn.notify_raw = notify_raw
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "failed"
|
||||
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(
|
||||
PaymentTransaction.order_no == order_no,
|
||||
PaymentTransaction.user_id == user_id,
|
||||
)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
return {
|
||||
"order_no": txn.order_no, "plan": txn.plan,
|
||||
"amount": txn.amount, "currency": txn.currency,
|
||||
"gateway": txn.gateway, "pay_type": txn.pay_type,
|
||||
"status": txn.status,
|
||||
"gateway_order_no": txn.gateway_order_no,
|
||||
"paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
|
||||
"refund_amount": txn.refund_amount,
|
||||
"created_at": txn.created_at.isoformat(),
|
||||
}
|
||||
|
||||
payment_service = PaymentService
|
||||
async def list_transactions(self, user_id: str,
|
||||
page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(PaymentTransaction).where(
|
||||
PaymentTransaction.user_id == user_id
|
||||
).order_by(desc(PaymentTransaction.created_at))
|
||||
total_q = select(PaymentTransaction.id).where(
|
||||
PaymentTransaction.user_id == user_id
|
||||
)
|
||||
total_result = await self.db.execute(total_q)
|
||||
total = len(total_result.scalars().all())
|
||||
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
|
||||
items = result.scalars().all()
|
||||
return {
|
||||
"items": [{
|
||||
"order_no": t.order_no, "plan": t.plan,
|
||||
"amount": t.amount, "gateway": t.gateway,
|
||||
"pay_type": t.pay_type, "status": t.status,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
|
||||
} for t in items],
|
||||
"total": total, "page": page, "size": size,
|
||||
}
|
||||
|
||||
async def refund(self, user_id: str, order_no: str,
|
||||
reason: str = "") -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(
|
||||
PaymentTransaction.order_no == order_no,
|
||||
PaymentTransaction.user_id == user_id,
|
||||
)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
if txn.status != "paid":
|
||||
raise ValueError("只有已支付订单可退款")
|
||||
if txn.refund_amount >= txn.amount:
|
||||
raise ValueError("该订单已全额退款")
|
||||
|
||||
gw = get_gateway(txn.pay_type)
|
||||
remaining = int((txn.amount - txn.refund_amount) * 100)
|
||||
try:
|
||||
gw_result = await gw.refund(txn.order_no, remaining, reason)
|
||||
logger.info(f"Refund {txn.order_no}: {gw_result}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"退款请求失败: {e}")
|
||||
|
||||
txn.status = "refunded"
|
||||
txn.refund_amount = txn.amount
|
||||
txn.refund_reason = reason
|
||||
txn.refunded_at = datetime.utcnow()
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "expired"
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user and user.tier == txn.plan:
|
||||
user.tier = "free"
|
||||
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}
|
||||
|
||||
async def admin_list_payments(self, page: int = 1, size: int = 20,
|
||||
gateway: str = "", status: str = "",
|
||||
user_id: str = "") -> Dict[str, Any]:
|
||||
query = select(PaymentTransaction).order_by(desc(PaymentTransaction.created_at))
|
||||
count_query = select(PaymentTransaction.id)
|
||||
if gateway:
|
||||
query = query.where(PaymentTransaction.gateway == gateway)
|
||||
count_query = count_query.where(PaymentTransaction.gateway == gateway)
|
||||
if status:
|
||||
query = query.where(PaymentTransaction.status == status)
|
||||
count_query = count_query.where(PaymentTransaction.status == status)
|
||||
if user_id:
|
||||
query = query.where(PaymentTransaction.user_id == user_id)
|
||||
count_query = count_query.where(PaymentTransaction.user_id == user_id)
|
||||
|
||||
total_result = await self.db.execute(count_query)
|
||||
total = len(total_result.scalars().all())
|
||||
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
|
||||
items = result.scalars().all()
|
||||
return {
|
||||
"items": [{
|
||||
"id": str(t.id), "user_id": str(t.user_id),
|
||||
"order_no": t.order_no, "plan": t.plan,
|
||||
"amount": t.amount, "gateway": t.gateway,
|
||||
"pay_type": t.pay_type, "status": t.status,
|
||||
"gateway_order_no": t.gateway_order_no,
|
||||
"refund_amount": t.refund_amount,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
|
||||
"refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
|
||||
} for t in items],
|
||||
"total": total, "page": page, "size": size,
|
||||
}
|
||||
|
||||
async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
if txn.status != "paid":
|
||||
raise ValueError("只有已支付订单可退款")
|
||||
|
||||
gw = get_gateway(txn.pay_type)
|
||||
remaining = int((txn.amount - txn.refund_amount) * 100)
|
||||
try:
|
||||
gw_result = await gw.refund(txn.order_no, remaining, reason)
|
||||
logger.info(f"Admin refund {txn.order_no}: {gw_result}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"退款请求失败: {e}")
|
||||
|
||||
txn.status = "refunded"
|
||||
txn.refund_amount = txn.amount
|
||||
txn.refund_reason = reason
|
||||
txn.refunded_at = datetime.utcnow()
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "expired"
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user and user.tier == txn.plan:
|
||||
user.tier = "free"
|
||||
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
|
||||
"user_id": str(txn.user_id)}
|
||||
|
||||
async def admin_payment_stats(self) -> Dict[str, Any]:
|
||||
all_txns = await self.db.execute(select(PaymentTransaction))
|
||||
rows = all_txns.scalars().all()
|
||||
total_count = len(rows)
|
||||
total_revenue = sum(r.amount for r in rows if r.status == "paid")
|
||||
total_refund = sum(r.refund_amount for r in rows)
|
||||
paid_count = sum(1 for r in rows if r.status == "paid")
|
||||
pending_count = sum(1 for r in rows if r.status == "pending")
|
||||
refunded_count = sum(1 for r in rows if r.status == "refunded")
|
||||
failed_count = sum(1 for r in rows if r.status == "failed")
|
||||
wechat_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "wechat")
|
||||
alipay_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "alipay")
|
||||
return {
|
||||
"total_count": total_count, "total_revenue": total_revenue,
|
||||
"total_refund": total_refund, "paid_count": paid_count,
|
||||
"pending_count": pending_count, "refunded_count": refunded_count,
|
||||
"failed_count": failed_count, "wechat_count": wechat_count,
|
||||
"alipay_count": alipay_count,
|
||||
}
|
||||
|
||||
|
||||
init_gateways()
|
||||
|
||||
Reference in New Issue
Block a user