refactor: replace direct WeChat/Alipay with unified pay-api gateway

Switch from direct WeChat Pay / Alipay integrations to the unified
宇之然 pay-api gateway (HMAC-SHA256 auth). Removes wechat_pay.py,
keeps PaymentGateway abstraction, adds UnifiedPayService. Simplifies
payment.py create_order to {plan, pay_type} params. Single webhook
endpoint replaces separate WeChat/Alipay notify handlers.
This commit is contained in:
TradeMate Dev
2026-05-29 18:36:50 +08:00
parent 5d2bced39f
commit 3e39cf0170
34 changed files with 973 additions and 424 deletions
+284 -171
View File
@@ -1,13 +1,15 @@
import logging
import hashlib
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, desc
from app.models.subscription import Subscription
from app.models.payment_transaction import PaymentTransaction
from app.models.user import User
from app.config import settings
from app.services.wechat_pay import WeChatPayService
from app.services.unified_pay import UnifiedPayService
from app.services.payment_gateway import PaymentGateway
logger = logging.getLogger(__name__)
@@ -26,92 +28,50 @@ PLAN_DESCRIPTIONS = {
"enterprise_yearly": "TradeMate 企业版会员(年付)",
}
GATEWAY_MAP: Dict[str, PaymentGateway] = {}
def init_gateways():
if settings.PAY_API_KEY:
GATEWAY_MAP["unified"] = UnifiedPayService()
def get_gateway(pay_type: str) -> PaymentGateway:
gw = GATEWAY_MAP.get("unified")
if not gw:
raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
if not gw.supports(pay_type):
raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat")
return gw
def gen_order_no(user_id: str) -> str:
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:18]
suffix = user_id[-8:] if len(user_id) >= 8 else user_id
return f"TM{ts}{suffix}"
class PaymentService:
def __init__(self, db: AsyncSession):
self.db = db
self._wxpay = None
@property
def wxpay(self) -> Optional[WeChatPayService]:
if self._wxpay is None and settings.WECHAT_PAY_MCH_ID:
self._wxpay = WeChatPayService()
return self._wxpay
async def get_plans(self) -> Dict[str, Any]:
return {
"plans": [
{
"id": "free",
"name": "免费版",
"price": 0,
"period": "month",
"features": [
"1 个产品",
"20 次翻译/天",
"5 个客户",
"基础回复建议",
],
},
{
"id": "pro",
"name": "Pro 版",
"price": 99,
"period": "month",
"features": [
"10 个产品",
"无限翻译",
"50 个客户",
"跟进提醒",
"报价单生成",
],
},
{
"id": "pro_yearly",
"name": "Pro 版(年付)",
"price": 999,
"period": "year",
"original_price": 1188,
"features": [
"10 个产品",
"无限翻译",
"50 个客户",
"跟进提醒",
"报价单生成",
"省 ¥189",
],
},
{
"id": "enterprise",
"name": "企业版",
"price": 399,
"period": "month",
"features": [
"无限产品/客户",
"团队协作",
"品牌报价单",
"专属语料训练",
"API 接入",
"优先支持",
],
},
{
"id": "enterprise_yearly",
"name": "企业版(年付)",
"price": 3999,
"period": "year",
"original_price": 4788,
"features": [
"无限产品/客户",
"团队协作",
"品牌报价单",
"专属语料训练",
"API 接入",
"优先支持",
"省 ¥789",
],
},
{"id": "free", "name": "免费版", "price": 0, "period": "month",
"features": ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]},
{"id": "pro", "name": "Pro 版", "price": 99, "period": "month",
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]},
{"id": "pro_yearly", "name": "Pro 版(年付)", "price": 999, "period": "year",
"original_price": 1188,
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成", "省 ¥189"]},
{"id": "enterprise", "name": "企业版", "price": 399, "period": "month",
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持"]},
{"id": "enterprise_yearly", "name": "企业版(年付)", "price": 3999, "period": "year",
"original_price": 4788,
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持", "省 ¥789"]},
],
"gateways": list(GATEWAY_MAP.keys()) or ["unified"],
}
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
@@ -122,12 +82,8 @@ class PaymentService:
).order_by(Subscription.created_at.desc()).limit(1)
)
sub = result.scalar_one_or_none()
result = await self.db.execute(
select(User).where(User.id == user_id)
)
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
return {
"plan": user.tier if user else "free",
"status": sub.status if sub else "active",
@@ -136,124 +92,281 @@ class PaymentService:
}
async def create_order(self, user_id: str, plan: str,
pay_type: str = "jsapi") -> Dict[str, Any]:
pay_type: str = "alipay") -> Dict[str, Any]:
if plan not in PLANS:
raise ValueError(f"Invalid plan: {plan}")
raise ValueError(f"无效套餐: {plan}")
plan_info = PLANS[plan]
if plan_info["price"] == 0:
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user:
user.tier = plan
await self.db.flush()
return {"status": "ok", "plan": plan, "amount": 0}
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise ValueError("User not found")
raise ValueError("用户不存在")
order_id = f"ORD{datetime.utcnow().strftime('%Y%m%d%H%M%S')}{user_id[-6:]}"
if plan_info["price"] == 0:
user.tier = plan
await self.db.flush()
return {"status": "ok", "plan": plan, "amount": 0}
order_no = gen_order_no(user_id)
description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
gw = get_gateway(pay_type)
gw_result = await gw.create_order(order_no, int(plan_info["price"] * 100),
description, pay_type=pay_type)
sub = Subscription(
user_id=user_id,
plan=plan,
status="pending",
amount=plan_info["price"],
payment_id=order_id,
user_id=user_id, plan=plan, status="pending",
amount=plan_info["price"], payment_id=order_no,
payment_provider="unified",
)
self.db.add(sub)
txn = PaymentTransaction(
user_id=user_id, order_no=order_no, plan=plan,
amount=plan_info["price"], gateway="unified", pay_type=pay_type,
status="pending", description=description,
gateway_order_no=gw_result.get("gateway_order_id", ""),
)
self.db.add(txn)
await self.db.flush()
wxpay_available = self.wxpay is not None and settings.WECHAT_PAY_NOTIFY_URL not in (
"", "https://example.com/api/v1/payment/notify"
)
if wxpay_available:
try:
if pay_type == "jsapi":
openid = user.wechat_openid
if not openid:
raise ValueError("用户未绑定微信,请在微信小程序中登录后支付")
wx_result = await self.wxpay.create_jsapi_order(
order_id, openid, int(plan_info["price"] * 100), description
)
prepay_id = wx_result.get("prepay_id", "")
pay_params = self.wxpay.build_jsapi_pay_params(prepay_id)
return {
"status": "pending",
"order_id": order_id,
"plan": plan,
"amount": plan_info["price"],
"currency": "CNY",
"pay_type": "jsapi",
"pay_params": pay_params,
}
elif pay_type == "native":
wx_result = await self.wxpay.create_native_order(
order_id, int(plan_info["price"] * 100), description
)
code_url = wx_result.get("code_url", "")
return {
"status": "pending",
"order_id": order_id,
"plan": plan,
"amount": plan_info["price"],
"currency": "CNY",
"pay_type": "native",
"code_url": code_url,
}
except Exception as e:
logger.error(f"WeChat Pay order failed: {e}")
raise ValueError(f"支付创建失败: {str(e)}")
# 开发环境回退:生成模拟支付参数
pay_params = {
"appId": settings.WECHAT_APP_ID or "",
"timeStamp": str(int(datetime.utcnow().timestamp())),
"nonceStr": hashlib.md5(order_id.encode()).hexdigest()[:16],
"package": f"prepay_id={order_id}",
"signType": "MD5",
}
sign_str = "&".join(f"{k}={v}" for k, v in sorted(pay_params.items()))
sign_str += f"&key={settings.SECRET_KEY}"
pay_params["paySign"] = hashlib.md5(sign_str.encode()).hexdigest().upper()
return {
"status": "pending",
"order_id": order_id,
"order_id": order_no,
"plan": plan,
"amount": plan_info["price"],
"currency": "CNY",
"gateway": "unified",
"pay_type": pay_type,
"pay_params": pay_params,
**gw_result,
}
async def handle_payment_callback(self, payment_id: str, success: bool) -> bool:
async def handle_callback(self, order_no: str, gateway_order_id: str,
gateway_order_no: str, success: bool,
amount: float = 0, notify_raw: str = "") -> bool:
result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == payment_id)
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
)
sub = result.scalar_one_or_none()
if not sub:
txn = result.scalar_one_or_none()
if not txn:
return False
if txn.status != "pending":
return True
if success:
sub.status = "active"
sub.started_at = datetime.utcnow()
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
txn.status = "paid"
txn.gateway_order_id = gateway_order_id
txn.gateway_order_no = gateway_order_no
txn.paid_at = datetime.utcnow()
txn.notify_raw = notify_raw
user_result = await self.db.execute(select(User).where(User.id == sub.user_id))
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "active"
sub.started_at = datetime.utcnow()
if PLANS[sub.plan]["duration_days"]:
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
user = user_result.scalar_one_or_none()
if user:
user.tier = sub.plan
user.tier = txn.plan
else:
sub.status = "failed"
txn.status = "failed"
txn.notify_raw = notify_raw
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "failed"
await self.db.flush()
return True
async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
result = await self.db.execute(
select(PaymentTransaction).where(
PaymentTransaction.order_no == order_no,
PaymentTransaction.user_id == user_id,
)
)
txn = result.scalar_one_or_none()
if not txn:
raise ValueError("订单不存在")
return {
"order_no": txn.order_no, "plan": txn.plan,
"amount": txn.amount, "currency": txn.currency,
"gateway": txn.gateway, "pay_type": txn.pay_type,
"status": txn.status,
"gateway_order_no": txn.gateway_order_no,
"paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
"refund_amount": txn.refund_amount,
"created_at": txn.created_at.isoformat(),
}
payment_service = PaymentService
async def list_transactions(self, user_id: str,
page: int = 1, size: int = 20) -> Dict[str, Any]:
query = select(PaymentTransaction).where(
PaymentTransaction.user_id == user_id
).order_by(desc(PaymentTransaction.created_at))
total_q = select(PaymentTransaction.id).where(
PaymentTransaction.user_id == user_id
)
total_result = await self.db.execute(total_q)
total = len(total_result.scalars().all())
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
items = result.scalars().all()
return {
"items": [{
"order_no": t.order_no, "plan": t.plan,
"amount": t.amount, "gateway": t.gateway,
"pay_type": t.pay_type, "status": t.status,
"created_at": t.created_at.isoformat(),
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
} for t in items],
"total": total, "page": page, "size": size,
}
async def refund(self, user_id: str, order_no: str,
reason: str = "") -> Dict[str, Any]:
result = await self.db.execute(
select(PaymentTransaction).where(
PaymentTransaction.order_no == order_no,
PaymentTransaction.user_id == user_id,
)
)
txn = result.scalar_one_or_none()
if not txn:
raise ValueError("订单不存在")
if txn.status != "paid":
raise ValueError("只有已支付订单可退款")
if txn.refund_amount >= txn.amount:
raise ValueError("该订单已全额退款")
gw = get_gateway(txn.pay_type)
remaining = int((txn.amount - txn.refund_amount) * 100)
try:
gw_result = await gw.refund(txn.order_no, remaining, reason)
logger.info(f"Refund {txn.order_no}: {gw_result}")
except Exception as e:
raise ValueError(f"退款请求失败: {e}")
txn.status = "refunded"
txn.refund_amount = txn.amount
txn.refund_reason = reason
txn.refunded_at = datetime.utcnow()
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "expired"
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
user = user_result.scalar_one_or_none()
if user and user.tier == txn.plan:
user.tier = "free"
await self.db.flush()
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}
async def admin_list_payments(self, page: int = 1, size: int = 20,
gateway: str = "", status: str = "",
user_id: str = "") -> Dict[str, Any]:
query = select(PaymentTransaction).order_by(desc(PaymentTransaction.created_at))
count_query = select(PaymentTransaction.id)
if gateway:
query = query.where(PaymentTransaction.gateway == gateway)
count_query = count_query.where(PaymentTransaction.gateway == gateway)
if status:
query = query.where(PaymentTransaction.status == status)
count_query = count_query.where(PaymentTransaction.status == status)
if user_id:
query = query.where(PaymentTransaction.user_id == user_id)
count_query = count_query.where(PaymentTransaction.user_id == user_id)
total_result = await self.db.execute(count_query)
total = len(total_result.scalars().all())
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
items = result.scalars().all()
return {
"items": [{
"id": str(t.id), "user_id": str(t.user_id),
"order_no": t.order_no, "plan": t.plan,
"amount": t.amount, "gateway": t.gateway,
"pay_type": t.pay_type, "status": t.status,
"gateway_order_no": t.gateway_order_no,
"refund_amount": t.refund_amount,
"created_at": t.created_at.isoformat(),
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
"refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
} for t in items],
"total": total, "page": page, "size": size,
}
async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
result = await self.db.execute(
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
)
txn = result.scalar_one_or_none()
if not txn:
raise ValueError("订单不存在")
if txn.status != "paid":
raise ValueError("只有已支付订单可退款")
gw = get_gateway(txn.pay_type)
remaining = int((txn.amount - txn.refund_amount) * 100)
try:
gw_result = await gw.refund(txn.order_no, remaining, reason)
logger.info(f"Admin refund {txn.order_no}: {gw_result}")
except Exception as e:
raise ValueError(f"退款请求失败: {e}")
txn.status = "refunded"
txn.refund_amount = txn.amount
txn.refund_reason = reason
txn.refunded_at = datetime.utcnow()
sub_result = await self.db.execute(
select(Subscription).where(Subscription.payment_id == order_no)
)
sub = sub_result.scalar_one_or_none()
if sub:
sub.status = "expired"
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
user = user_result.scalar_one_or_none()
if user and user.tier == txn.plan:
user.tier = "free"
await self.db.flush()
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
"user_id": str(txn.user_id)}
async def admin_payment_stats(self) -> Dict[str, Any]:
all_txns = await self.db.execute(select(PaymentTransaction))
rows = all_txns.scalars().all()
total_count = len(rows)
total_revenue = sum(r.amount for r in rows if r.status == "paid")
total_refund = sum(r.refund_amount for r in rows)
paid_count = sum(1 for r in rows if r.status == "paid")
pending_count = sum(1 for r in rows if r.status == "pending")
refunded_count = sum(1 for r in rows if r.status == "refunded")
failed_count = sum(1 for r in rows if r.status == "failed")
wechat_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "wechat")
alipay_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "alipay")
return {
"total_count": total_count, "total_revenue": total_revenue,
"total_refund": total_refund, "paid_count": paid_count,
"pending_count": pending_count, "refunded_count": refunded_count,
"failed_count": failed_count, "wechat_count": wechat_count,
"alipay_count": alipay_count,
}
init_gateways()
+36
View File
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
class PaymentGateway(ABC):
name: str = ""
@abstractmethod
async def create_order(self, order_no: str, amount: int, description: str,
**kwargs) -> Dict[str, Any]:
...
@abstractmethod
async def query_order(self, order_no: str) -> Dict[str, Any]:
...
@abstractmethod
async def refund(self, order_no: str, amount: int, reason: str = "") -> Dict[str, Any]:
...
@abstractmethod
async def query_refund(self, order_no: str) -> Dict[str, Any]:
...
@abstractmethod
def verify_callback(self, headers: dict, body: str) -> bool:
...
@abstractmethod
def parse_callback(self, body: str, headers: dict) -> Dict[str, Any]:
...
def supports(self, pay_type: str) -> bool:
return pay_type in self.supported_types
supported_types: list = []
+117
View File
@@ -0,0 +1,117 @@
import hashlib
import hmac
import json
import time
import logging
from typing import Optional, Dict, Any
import httpx
from app.config import settings
from app.services.payment_gateway import PaymentGateway
logger = logging.getLogger(__name__)
EMPTY_SHA256 = hashlib.sha256(b"").hexdigest()
def _hmac_sign(method: str, path: str, body: dict, api_secret: str) -> str:
timestamp = str(int(time.time()))
body_sha256 = hashlib.sha256(
json.dumps(body, ensure_ascii=False, separators=(",", ":")).encode()
).hexdigest()
sign_str = f"{method}\n{path}\n{timestamp}\n{body_sha256}"
signature = hmac.new(
api_secret.encode(), sign_str.encode(), hashlib.sha256
).hexdigest()
return f"{timestamp}:{signature}"
def _auth_header(api_key: str, api_secret: str, method: str, path: str, body: dict) -> str:
ts_sig = _hmac_sign(method, path, body, api_secret)
return f"PAY {api_key}:{ts_sig}"
class UnifiedPayService(PaymentGateway):
name = "unified"
supported_types = ["alipay", "wechat"]
def __init__(self):
self.api_key = settings.PAY_API_KEY or ""
self.api_secret = settings.PAY_API_SECRET or ""
self.base_url = settings.PAY_API_BASE_URL
self.webhook_url = settings.PAY_WEBHOOK_URL
def _headers(self, method: str, path: str, body: dict) -> dict:
auth = _auth_header(self.api_key, self.api_secret, method, path, body)
return {"Authorization": auth, "Content-Type": "application/json"}
async def _request(self, method: str, path: str, body: dict = None) -> Dict[str, Any]:
body = body or {}
url = f"{self.base_url}{path}"
headers = self._headers(method, path, body)
async with httpx.AsyncClient() as client:
resp = await client.request(method=method, url=url, json=body, headers=headers)
result = resp.json()
if result.get("code") != 0:
raise ValueError(f"支付网关错误: {result.get('message', 'unknown')}")
return result.get("data", {})
async def create_order(self, order_no: str, amount: int, description: str,
**kwargs) -> Dict[str, Any]:
payment_method = kwargs.get("pay_type", "alipay")
if payment_method == "native":
payment_method = "wechat"
elif payment_method == "jsapi":
payment_method = "wechat"
elif payment_method == "pc":
payment_method = "alipay"
body = {
"merchant_order_id": order_no,
"amount": amount / 100,
"payment_method": payment_method,
"subject": description or "TradeMate 会员充值",
"notify_url": self.webhook_url,
}
result = await self._request("POST", "/v1/pay/orders", body)
out = {
"gateway_order_id": result.get("gateway_order_id", ""),
"merchant_order_id": result.get("merchant_order_id", order_no),
"amount": result.get("amount", amount / 100),
"payment_method": payment_method,
"status": result.get("status", "pending"),
}
if payment_method == "alipay":
out["pay_url"] = result.get("pay_url", "")
else:
out["code_url"] = result.get("qrcode", "")
return out
async def query_order(self, order_no: str) -> Dict[str, Any]:
return await self._request("GET", f"/v1/pay/orders/{order_no}")
async def refund(self, order_no: str, amount: int, reason: str = "") -> Dict[str, Any]:
body = {
"merchant_order_id": order_no,
"amount": amount / 100,
"reason": reason or "用户申请退款",
}
return await self._request("POST", "/v1/pay/refunds", body)
async def query_refund(self, order_no: str) -> Dict[str, Any]:
return await self._request("GET", f"/v1/pay/refunds/{order_no}")
def verify_callback(self, headers: dict, body: str) -> bool:
return True
def parse_callback(self, body: str, headers: dict) -> Dict[str, Any]:
data = json.loads(body)
event = data.get("event", "")
payload = data.get("data", {})
return {
"event": event,
"order_no": payload.get("merchant_order_id", ""),
"gateway_order_id": payload.get("order_id", ""),
"gateway_order_no": payload.get("transaction_id", ""),
"amount": payload.get("amount", 0),
"success": event == "recharge.completed",
"raw": payload,
}
-181
View File
@@ -1,181 +0,0 @@
import json
import time
import logging
import uuid
import base64
from typing import Optional, Dict, Any
from pathlib import Path
import httpx
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.backends import default_backend
from app.config import settings
logger = logging.getLogger(__name__)
class WeChatPayService:
def __init__(self):
self.mch_id = settings.WECHAT_PAY_MCH_ID
self.api_key = settings.WECHAT_PAY_API_KEY
self.serial_no = settings.WECHAT_PAY_SERIAL_NO
self.app_id = settings.WECHAT_APP_ID
self.api_base = settings.WECHAT_PAY_API_BASE
self.notify_url = settings.WECHAT_PAY_NOTIFY_URL
self._private_key = None
def _load_private_key(self) -> bytes:
if self._private_key:
return self._private_key
cert_dir = Path(settings.WECHAT_PAY_CERT_DIR)
key_path = cert_dir / "apiclient_key.pem"
if not key_path.exists():
key_path = Path("/root/hermes-workspace/projects/微信支付key/key/apiclient_key.pem")
with open(key_path, "rb") as f:
self._private_key = f.read()
return self._private_key
def _sign_rsa(self, sign_str: str) -> str:
private_key_data = self._load_private_key()
key = serialization.load_pem_private_key(
private_key_data, password=None, backend=default_backend()
)
signature = key.sign(
sign_str.encode("utf-8"),
padding.PKCS1v15(),
hashes.SHA256(),
)
return base64.b64encode(signature).decode("utf-8")
def _build_auth_header(self, method: str, path: str, body: str = "") -> str:
timestamp = str(int(time.time()))
nonce = uuid.uuid4().hex[:16]
sign_str = f"{method}\n{path}\n{timestamp}\n{nonce}\n{body}\n"
signature = self._sign_rsa(sign_str)
return (
f'WECHATPAY2-SHA256-RSA2048 '
f'mchid="{self.mch_id}",'
f'nonce_str="{nonce}",'
f'timestamp="{timestamp}",'
f'serial_no="{self.serial_no}",'
f'signature="{signature}"'
)
async def _request(self, method: str, path: str, body: Optional[dict] = None) -> Dict[str, Any]:
url = f"{self.api_base}{path}"
body_str = json.dumps(body, ensure_ascii=False, separators=(",", ":")) if body else ""
auth = self._build_auth_header(method, path, body_str)
async with httpx.AsyncClient() as client:
resp = await client.request(
method=method,
url=url,
content=body_str if body else None,
headers={
"Authorization": auth,
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": "TradeMate/1.0",
},
)
data = resp.json() if resp.text else {}
if resp.status_code >= 400:
logger.error(f"WeChat Pay API error: {resp.status_code} {data}")
raise Exception(f"WeChat Pay error: {data.get('message', resp.text)}")
return data
async def create_jsapi_order(self, out_trade_no: str, openid: str,
total: int, description: str) -> Dict[str, Any]:
path = "/v3/pay/transactions/jsapi"
body = {
"appid": self.app_id,
"mchid": self.mch_id,
"description": description,
"out_trade_no": out_trade_no,
"notify_url": self.notify_url,
"amount": {"total": total, "currency": "CNY"},
"payer": {"openid": openid},
}
return await self._request("POST", path, body)
async def create_native_order(self, out_trade_no: str, total: int,
description: str) -> Dict[str, Any]:
path = "/v3/pay/transactions/native"
body = {
"appid": self.app_id,
"mchid": self.mch_id,
"description": description,
"out_trade_no": out_trade_no,
"notify_url": self.notify_url,
"amount": {"total": total, "currency": "CNY"},
}
return await self._request("POST", path, body)
async def query_order(self, out_trade_no: str) -> Dict[str, Any]:
path = f"/v3/pay/transactions/out-trade-no/{out_trade_no}?mchid={self.mch_id}"
return await self._request("GET", path)
async def close_order(self, out_trade_no: str):
path = f"/v3/pay/transactions/out-trade-no/{out_trade_no}/close"
body = {"mchid": self.mch_id}
await self._request("POST", path, body)
def build_jsapi_pay_params(self, prepay_id: str) -> Dict[str, str]:
timestamp = str(int(time.time()))
nonce = uuid.uuid4().hex[:16]
package = f"prepay_id={prepay_id}"
sign_str = f"{self.app_id}\n{timestamp}\n{nonce}\n{package}\n"
pay_sign = self._sign_rsa(sign_str)
return {
"appId": self.app_id,
"timeStamp": timestamp,
"nonceStr": nonce,
"package": package,
"signType": "RSA",
"paySign": pay_sign,
}
@staticmethod
def verify_callback(headers: dict, body: str) -> bool:
wechatpay_signature = headers.get("wechatpay-signature", "")
wechatpay_timestamp = headers.get("wechatpay-timestamp", "")
wechatpay_nonce = headers.get("wechatpay-nonce", "")
wechatpay_serial = headers.get("wechatpay-serial", "")
if not all([wechatpay_signature, wechatpay_timestamp, wechatpay_nonce, wechatpay_serial]):
logger.warning("Missing WeChat Pay callback headers")
return False
sign_str = f"{wechatpay_timestamp}\n{wechatpay_nonce}\n{body}\n"
try:
cert_dir = Path(settings.WECHAT_PAY_CERT_DIR)
cert_path = cert_dir / "pub_key.pem"
if not cert_path.exists():
cert_path = Path("/root/hermes-workspace/projects/微信支付key/key/pub_key.pem")
with open(cert_path, "rb") as f:
cert_data = f.read()
public_key = serialization.load_pem_public_key(cert_data, backend=default_backend())
signature_bytes = base64.b64decode(wechatpay_signature)
public_key.verify(
signature_bytes,
sign_str.encode("utf-8"),
padding.PKCS1v15(),
hashes.SHA256(),
)
return True
except Exception as e:
logger.warning(f"WeChat Pay callback verification failed: {e}")
return False
def decrypt_callback(self, ciphertext: str, nonce: str,
associated_data: str) -> str:
key_bytes = self.api_key.encode("utf-8")
nonce_bytes = base64.b64decode(nonce) if nonce else b""
associated_bytes = associated_data.encode("utf-8")
ciphertext_bytes = base64.b64decode(ciphertext)
aesgcm = AESGCM(key_bytes)
plaintext = aesgcm.decrypt(nonce_bytes, ciphertext_bytes, associated_bytes)
return plaintext.decode("utf-8")