refactor: replace direct WeChat/Alipay with unified pay-api gateway
Switch from direct WeChat Pay / Alipay integrations to the unified
宇之然 pay-api gateway (HMAC-SHA256 auth). Removes wechat_pay.py,
keeps PaymentGateway abstraction, adds UnifiedPayService. Simplifies
payment.py create_order to {plan, pay_type} params. Single webhook
endpoint replaces separate WeChat/Alipay notify handlers.
This commit is contained in:
@@ -8,6 +8,7 @@ from app.services.admin import AdminService
|
||||
from app.services.translation_quota import TranslationQuotaService
|
||||
from app.services.certification import CertificationService
|
||||
from app.services.invoice import InvoiceService
|
||||
from app.services.payment import PaymentService
|
||||
from app.api.v1.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
@@ -274,3 +275,41 @@ async def admin_process_invoice(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Invoice not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/payments")
|
||||
async def admin_list_payments(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
gateway: str = Query(default=""),
|
||||
status: str = Query(default=""),
|
||||
user_id: str = Query(default=""),
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
return await svc.admin_list_payments(page, size, gateway, status, user_id)
|
||||
|
||||
|
||||
@router.get("/payments/stats")
|
||||
async def admin_payment_stats(
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
return await svc.admin_payment_stats()
|
||||
|
||||
|
||||
@router.post("/payments/refund")
|
||||
async def admin_refund(
|
||||
data: dict,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
order_no = data.get("order_no", "")
|
||||
reason = data.get("reason", "")
|
||||
svc = PaymentService(db)
|
||||
try:
|
||||
return await svc.admin_refund(order_no, reason)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Header
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from app.database import get_db
|
||||
from app.services.payment import PaymentService
|
||||
from app.services.wechat_pay import WeChatPayService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.core.csrf import require_csrf_token
|
||||
|
||||
@@ -13,12 +12,12 @@ router = APIRouter()
|
||||
|
||||
class CreateOrderRequest(BaseModel):
|
||||
plan: str
|
||||
pay_type: str = "jsapi"
|
||||
pay_type: str = "alipay"
|
||||
|
||||
|
||||
class PaymentCallbackRequest(BaseModel):
|
||||
payment_id: str
|
||||
success: bool
|
||||
class RefundRequest(BaseModel):
|
||||
order_no: str
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.get("/plans")
|
||||
@@ -50,42 +49,65 @@ async def create_order(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def payment_callback(
|
||||
data: PaymentCallbackRequest,
|
||||
@router.get("/query/{order_no}")
|
||||
async def query_payment(
|
||||
order_no: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
try:
|
||||
return await svc.query_payment(user_id, order_no)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/transactions")
|
||||
async def list_transactions(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
return await svc.list_transactions(user_id, page, size)
|
||||
|
||||
|
||||
@router.post("/refund")
|
||||
async def refund(
|
||||
data: RefundRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_csrf: str = Depends(require_csrf_token),
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
success = await svc.handle_payment_callback(data.payment_id, data.success)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Order not found")
|
||||
return {"status": "ok"}
|
||||
try:
|
||||
return await svc.refund(user_id, data.order_no, data.reason)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/notify")
|
||||
async def wechat_pay_notify(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
@router.post("/webhook")
|
||||
async def unified_webhook(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
body = await request.body()
|
||||
body_str = body.decode("utf-8")
|
||||
headers = dict(request.headers)
|
||||
|
||||
wxpay = WeChatPayService()
|
||||
if not wxpay.verify_callback(headers, body_str):
|
||||
raise HTTPException(status_code=401, detail="签名验证失败")
|
||||
|
||||
import json
|
||||
data = json.loads(body_str)
|
||||
resource = data.get("resource", {})
|
||||
ciphertext = resource.get("ciphertext", "")
|
||||
nonce = resource.get("nonce", "")
|
||||
associated_data = resource.get("associated_data", "")
|
||||
try:
|
||||
data = json.loads(body_str)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="无效的 JSON")
|
||||
|
||||
plaintext = wxpay.decrypt_callback(ciphertext, nonce, associated_data)
|
||||
pay_data = json.loads(plaintext)
|
||||
out_trade_no = pay_data.get("out_trade_no", "")
|
||||
trade_state = pay_data.get("trade_state", "")
|
||||
event = data.get("event", "")
|
||||
pay_data = data.get("data", {})
|
||||
merchant_order_id = pay_data.get("merchant_order_id", "")
|
||||
order_id = pay_data.get("order_id", "")
|
||||
transaction_id = pay_data.get("transaction_id", "")
|
||||
amount = pay_data.get("amount", 0)
|
||||
success = event == "recharge.completed"
|
||||
|
||||
success = trade_state == "SUCCESS"
|
||||
svc = PaymentService(db)
|
||||
await svc.handle_payment_callback(out_trade_no, success)
|
||||
return {"code": "SUCCESS", "message": "OK"}
|
||||
await svc.handle_callback(
|
||||
merchant_order_id, order_id, transaction_id,
|
||||
success, amount, body_str,
|
||||
)
|
||||
return {"code": 0, "message": "OK"}
|
||||
|
||||
@@ -55,12 +55,10 @@ class Settings(BaseSettings):
|
||||
WECHAT_APP_SECRET: Optional[str] = None
|
||||
WECHAT_PUSH_TEMPLATE_ID: Optional[str] = None
|
||||
|
||||
WECHAT_PAY_MCH_ID: Optional[str] = None
|
||||
WECHAT_PAY_API_KEY: Optional[str] = None
|
||||
WECHAT_PAY_SERIAL_NO: Optional[str] = None
|
||||
WECHAT_PAY_CERT_DIR: str = "./certs"
|
||||
WECHAT_PAY_NOTIFY_URL: str = "https://example.com/api/v1/payment/notify"
|
||||
WECHAT_PAY_API_BASE: str = "https://api.mch.weixin.qq.com"
|
||||
PAY_API_KEY: Optional[str] = None
|
||||
PAY_API_SECRET: Optional[str] = None
|
||||
PAY_API_BASE_URL: str = "https://www.yzrcloud.cn/api/gateway"
|
||||
PAY_WEBHOOK_URL: str = "https://example.com/api/v1/payment/webhook"
|
||||
|
||||
EXCHANGE_RATE_API_KEY: Optional[str] = None
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||
# Endpoints that should skip CSRF protection (e.g., webhook endpoints)
|
||||
CSRF_SKIP_ENDPOINTS = [
|
||||
"/api/v1/webhook/",
|
||||
"/api/v1/payment/notify",
|
||||
"/api/v1/payment/webhook",
|
||||
"/api/v1/whatsapp/webhook",
|
||||
]
|
||||
|
||||
|
||||
@@ -18,26 +18,23 @@ from .referral import ReferralCode, Referral
|
||||
from .search_provider import SearchProvider
|
||||
from .discovery_record import DiscoveryRecord
|
||||
from .ai_provider import AIProvider
|
||||
from .payment_transaction import PaymentTransaction
|
||||
|
||||
__all__ = [
|
||||
"User", "Product",
|
||||
"Customer", "Conversation", "Message",
|
||||
"Quotation", "QuotationItem",
|
||||
"CorpusEntry",
|
||||
"Team", "TeamMember",
|
||||
"UsageLog",
|
||||
"Notification",
|
||||
"Feedback",
|
||||
"Subscription",
|
||||
"CorpusEntry",
|
||||
"Notification",
|
||||
"Team", "TeamMember",
|
||||
"Feedback",
|
||||
"PreferenceAnalysis", "MarketingEffect",
|
||||
"Device",
|
||||
"FollowupStrategy", "FollowupLog",
|
||||
"SystemConfig",
|
||||
"TranslationQuota",
|
||||
"Certification", "CertType", "CertStatus",
|
||||
"Invoice", "InvoiceType", "InvoiceStatus",
|
||||
"Certification", "Invoice", "InvoiceType", "InvoiceStatus",
|
||||
"ReferralCode", "Referral",
|
||||
"SearchProvider",
|
||||
"DiscoveryRecord",
|
||||
"AIProvider",
|
||||
"PaymentTransaction",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Float, Text, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class PaymentTransaction(Base):
|
||||
__tablename__ = "payment_transactions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
order_no = Column(String(64), unique=True, nullable=False, index=True)
|
||||
gateway_order_id = Column(String(128), nullable=True)
|
||||
gateway_order_no = Column(String(128), nullable=True)
|
||||
plan = Column(String(50), nullable=False)
|
||||
amount = Column(Float, nullable=False)
|
||||
currency = Column(String(10), default="CNY")
|
||||
gateway = Column(String(20), nullable=False)
|
||||
pay_type = Column(String(20), nullable=False)
|
||||
status = Column(String(20), default="pending")
|
||||
description = Column(Text, nullable=True)
|
||||
refund_amount = Column(Float, default=0)
|
||||
refund_reason = Column(Text, nullable=True)
|
||||
paid_at = Column(DateTime, nullable=True)
|
||||
refunded_at = Column(DateTime, nullable=True)
|
||||
notify_raw = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
+284
-171
@@ -1,13 +1,15 @@
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, desc
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.payment_transaction import PaymentTransaction
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
from app.services.wechat_pay import WeChatPayService
|
||||
from app.services.unified_pay import UnifiedPayService
|
||||
from app.services.payment_gateway import PaymentGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,92 +28,50 @@ PLAN_DESCRIPTIONS = {
|
||||
"enterprise_yearly": "TradeMate 企业版会员(年付)",
|
||||
}
|
||||
|
||||
GATEWAY_MAP: Dict[str, PaymentGateway] = {}
|
||||
|
||||
|
||||
def init_gateways():
|
||||
if settings.PAY_API_KEY:
|
||||
GATEWAY_MAP["unified"] = UnifiedPayService()
|
||||
|
||||
|
||||
def get_gateway(pay_type: str) -> PaymentGateway:
|
||||
gw = GATEWAY_MAP.get("unified")
|
||||
if not gw:
|
||||
raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
|
||||
if not gw.supports(pay_type):
|
||||
raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat)")
|
||||
return gw
|
||||
|
||||
|
||||
def gen_order_no(user_id: str) -> str:
|
||||
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:18]
|
||||
suffix = user_id[-8:] if len(user_id) >= 8 else user_id
|
||||
return f"TM{ts}{suffix}"
|
||||
|
||||
|
||||
class PaymentService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self._wxpay = None
|
||||
|
||||
@property
|
||||
def wxpay(self) -> Optional[WeChatPayService]:
|
||||
if self._wxpay is None and settings.WECHAT_PAY_MCH_ID:
|
||||
self._wxpay = WeChatPayService()
|
||||
return self._wxpay
|
||||
|
||||
async def get_plans(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"plans": [
|
||||
{
|
||||
"id": "free",
|
||||
"name": "免费版",
|
||||
"price": 0,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"1 个产品",
|
||||
"20 次翻译/天",
|
||||
"5 个客户",
|
||||
"基础回复建议",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro",
|
||||
"name": "Pro 版",
|
||||
"price": 99,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro_yearly",
|
||||
"name": "Pro 版(年付)",
|
||||
"price": 999,
|
||||
"period": "year",
|
||||
"original_price": 1188,
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
"省 ¥189",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise",
|
||||
"name": "企业版",
|
||||
"price": 399,
|
||||
"period": "month",
|
||||
"features": [
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise_yearly",
|
||||
"name": "企业版(年付)",
|
||||
"price": 3999,
|
||||
"period": "year",
|
||||
"original_price": 4788,
|
||||
"features": [
|
||||
"无限产品/客户",
|
||||
"团队协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
"优先支持",
|
||||
"省 ¥789",
|
||||
],
|
||||
},
|
||||
{"id": "free", "name": "免费版", "price": 0, "period": "month",
|
||||
"features": ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]},
|
||||
{"id": "pro", "name": "Pro 版", "price": 99, "period": "month",
|
||||
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]},
|
||||
{"id": "pro_yearly", "name": "Pro 版(年付)", "price": 999, "period": "year",
|
||||
"original_price": 1188,
|
||||
"features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成", "省 ¥189"]},
|
||||
{"id": "enterprise", "name": "企业版", "price": 399, "period": "month",
|
||||
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持"]},
|
||||
{"id": "enterprise_yearly", "name": "企业版(年付)", "price": 3999, "period": "year",
|
||||
"original_price": 4788,
|
||||
"features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练", "API 接入", "优先支持", "省 ¥789"]},
|
||||
],
|
||||
"gateways": list(GATEWAY_MAP.keys()) or ["unified"],
|
||||
}
|
||||
|
||||
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
|
||||
@@ -122,12 +82,8 @@ class PaymentService:
|
||||
).order_by(Subscription.created_at.desc()).limit(1)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"plan": user.tier if user else "free",
|
||||
"status": sub.status if sub else "active",
|
||||
@@ -136,124 +92,281 @@ class PaymentService:
|
||||
}
|
||||
|
||||
async def create_order(self, user_id: str, plan: str,
|
||||
pay_type: str = "jsapi") -> Dict[str, Any]:
|
||||
pay_type: str = "alipay") -> Dict[str, Any]:
|
||||
if plan not in PLANS:
|
||||
raise ValueError(f"Invalid plan: {plan}")
|
||||
|
||||
raise ValueError(f"无效套餐: {plan}")
|
||||
plan_info = PLANS[plan]
|
||||
if plan_info["price"] == 0:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = plan
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "plan": plan, "amount": 0}
|
||||
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
raise ValueError("用户不存在")
|
||||
|
||||
order_id = f"ORD{datetime.utcnow().strftime('%Y%m%d%H%M%S')}{user_id[-6:]}"
|
||||
if plan_info["price"] == 0:
|
||||
user.tier = plan
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "plan": plan, "amount": 0}
|
||||
|
||||
order_no = gen_order_no(user_id)
|
||||
description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
|
||||
|
||||
gw = get_gateway(pay_type)
|
||||
gw_result = await gw.create_order(order_no, int(plan_info["price"] * 100),
|
||||
description, pay_type=pay_type)
|
||||
|
||||
sub = Subscription(
|
||||
user_id=user_id,
|
||||
plan=plan,
|
||||
status="pending",
|
||||
amount=plan_info["price"],
|
||||
payment_id=order_id,
|
||||
user_id=user_id, plan=plan, status="pending",
|
||||
amount=plan_info["price"], payment_id=order_no,
|
||||
payment_provider="unified",
|
||||
)
|
||||
self.db.add(sub)
|
||||
|
||||
txn = PaymentTransaction(
|
||||
user_id=user_id, order_no=order_no, plan=plan,
|
||||
amount=plan_info["price"], gateway="unified", pay_type=pay_type,
|
||||
status="pending", description=description,
|
||||
gateway_order_no=gw_result.get("gateway_order_id", ""),
|
||||
)
|
||||
self.db.add(txn)
|
||||
await self.db.flush()
|
||||
|
||||
wxpay_available = self.wxpay is not None and settings.WECHAT_PAY_NOTIFY_URL not in (
|
||||
"", "https://example.com/api/v1/payment/notify"
|
||||
)
|
||||
|
||||
if wxpay_available:
|
||||
try:
|
||||
if pay_type == "jsapi":
|
||||
openid = user.wechat_openid
|
||||
if not openid:
|
||||
raise ValueError("用户未绑定微信,请在微信小程序中登录后支付")
|
||||
|
||||
wx_result = await self.wxpay.create_jsapi_order(
|
||||
order_id, openid, int(plan_info["price"] * 100), description
|
||||
)
|
||||
prepay_id = wx_result.get("prepay_id", "")
|
||||
pay_params = self.wxpay.build_jsapi_pay_params(prepay_id)
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"pay_type": "jsapi",
|
||||
"pay_params": pay_params,
|
||||
}
|
||||
|
||||
elif pay_type == "native":
|
||||
wx_result = await self.wxpay.create_native_order(
|
||||
order_id, int(plan_info["price"] * 100), description
|
||||
)
|
||||
code_url = wx_result.get("code_url", "")
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"pay_type": "native",
|
||||
"code_url": code_url,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"WeChat Pay order failed: {e}")
|
||||
raise ValueError(f"支付创建失败: {str(e)}")
|
||||
|
||||
# 开发环境回退:生成模拟支付参数
|
||||
pay_params = {
|
||||
"appId": settings.WECHAT_APP_ID or "",
|
||||
"timeStamp": str(int(datetime.utcnow().timestamp())),
|
||||
"nonceStr": hashlib.md5(order_id.encode()).hexdigest()[:16],
|
||||
"package": f"prepay_id={order_id}",
|
||||
"signType": "MD5",
|
||||
}
|
||||
sign_str = "&".join(f"{k}={v}" for k, v in sorted(pay_params.items()))
|
||||
sign_str += f"&key={settings.SECRET_KEY}"
|
||||
pay_params["paySign"] = hashlib.md5(sign_str.encode()).hexdigest().upper()
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"order_id": order_no,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"gateway": "unified",
|
||||
"pay_type": pay_type,
|
||||
"pay_params": pay_params,
|
||||
**gw_result,
|
||||
}
|
||||
|
||||
async def handle_payment_callback(self, payment_id: str, success: bool) -> bool:
|
||||
async def handle_callback(self, order_no: str, gateway_order_id: str,
|
||||
gateway_order_no: str, success: bool,
|
||||
amount: float = 0, notify_raw: str = "") -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == payment_id)
|
||||
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if not sub:
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
return False
|
||||
if txn.status != "pending":
|
||||
return True
|
||||
|
||||
if success:
|
||||
sub.status = "active"
|
||||
sub.started_at = datetime.utcnow()
|
||||
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
|
||||
txn.status = "paid"
|
||||
txn.gateway_order_id = gateway_order_id
|
||||
txn.gateway_order_no = gateway_order_no
|
||||
txn.paid_at = datetime.utcnow()
|
||||
txn.notify_raw = notify_raw
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == sub.user_id))
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "active"
|
||||
sub.started_at = datetime.utcnow()
|
||||
if PLANS[sub.plan]["duration_days"]:
|
||||
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = sub.plan
|
||||
user.tier = txn.plan
|
||||
else:
|
||||
sub.status = "failed"
|
||||
txn.status = "failed"
|
||||
txn.notify_raw = notify_raw
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "failed"
|
||||
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(
|
||||
PaymentTransaction.order_no == order_no,
|
||||
PaymentTransaction.user_id == user_id,
|
||||
)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
return {
|
||||
"order_no": txn.order_no, "plan": txn.plan,
|
||||
"amount": txn.amount, "currency": txn.currency,
|
||||
"gateway": txn.gateway, "pay_type": txn.pay_type,
|
||||
"status": txn.status,
|
||||
"gateway_order_no": txn.gateway_order_no,
|
||||
"paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
|
||||
"refund_amount": txn.refund_amount,
|
||||
"created_at": txn.created_at.isoformat(),
|
||||
}
|
||||
|
||||
payment_service = PaymentService
|
||||
async def list_transactions(self, user_id: str,
|
||||
page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(PaymentTransaction).where(
|
||||
PaymentTransaction.user_id == user_id
|
||||
).order_by(desc(PaymentTransaction.created_at))
|
||||
total_q = select(PaymentTransaction.id).where(
|
||||
PaymentTransaction.user_id == user_id
|
||||
)
|
||||
total_result = await self.db.execute(total_q)
|
||||
total = len(total_result.scalars().all())
|
||||
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
|
||||
items = result.scalars().all()
|
||||
return {
|
||||
"items": [{
|
||||
"order_no": t.order_no, "plan": t.plan,
|
||||
"amount": t.amount, "gateway": t.gateway,
|
||||
"pay_type": t.pay_type, "status": t.status,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
|
||||
} for t in items],
|
||||
"total": total, "page": page, "size": size,
|
||||
}
|
||||
|
||||
async def refund(self, user_id: str, order_no: str,
|
||||
reason: str = "") -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(
|
||||
PaymentTransaction.order_no == order_no,
|
||||
PaymentTransaction.user_id == user_id,
|
||||
)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
if txn.status != "paid":
|
||||
raise ValueError("只有已支付订单可退款")
|
||||
if txn.refund_amount >= txn.amount:
|
||||
raise ValueError("该订单已全额退款")
|
||||
|
||||
gw = get_gateway(txn.pay_type)
|
||||
remaining = int((txn.amount - txn.refund_amount) * 100)
|
||||
try:
|
||||
gw_result = await gw.refund(txn.order_no, remaining, reason)
|
||||
logger.info(f"Refund {txn.order_no}: {gw_result}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"退款请求失败: {e}")
|
||||
|
||||
txn.status = "refunded"
|
||||
txn.refund_amount = txn.amount
|
||||
txn.refund_reason = reason
|
||||
txn.refunded_at = datetime.utcnow()
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "expired"
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user and user.tier == txn.plan:
|
||||
user.tier = "free"
|
||||
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}
|
||||
|
||||
async def admin_list_payments(self, page: int = 1, size: int = 20,
|
||||
gateway: str = "", status: str = "",
|
||||
user_id: str = "") -> Dict[str, Any]:
|
||||
query = select(PaymentTransaction).order_by(desc(PaymentTransaction.created_at))
|
||||
count_query = select(PaymentTransaction.id)
|
||||
if gateway:
|
||||
query = query.where(PaymentTransaction.gateway == gateway)
|
||||
count_query = count_query.where(PaymentTransaction.gateway == gateway)
|
||||
if status:
|
||||
query = query.where(PaymentTransaction.status == status)
|
||||
count_query = count_query.where(PaymentTransaction.status == status)
|
||||
if user_id:
|
||||
query = query.where(PaymentTransaction.user_id == user_id)
|
||||
count_query = count_query.where(PaymentTransaction.user_id == user_id)
|
||||
|
||||
total_result = await self.db.execute(count_query)
|
||||
total = len(total_result.scalars().all())
|
||||
result = await self.db.execute(query.offset((page - 1) * size).limit(size))
|
||||
items = result.scalars().all()
|
||||
return {
|
||||
"items": [{
|
||||
"id": str(t.id), "user_id": str(t.user_id),
|
||||
"order_no": t.order_no, "plan": t.plan,
|
||||
"amount": t.amount, "gateway": t.gateway,
|
||||
"pay_type": t.pay_type, "status": t.status,
|
||||
"gateway_order_no": t.gateway_order_no,
|
||||
"refund_amount": t.refund_amount,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
"paid_at": t.paid_at.isoformat() if t.paid_at else None,
|
||||
"refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
|
||||
} for t in items],
|
||||
"total": total, "page": page, "size": size,
|
||||
}
|
||||
|
||||
async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
|
||||
)
|
||||
txn = result.scalar_one_or_none()
|
||||
if not txn:
|
||||
raise ValueError("订单不存在")
|
||||
if txn.status != "paid":
|
||||
raise ValueError("只有已支付订单可退款")
|
||||
|
||||
gw = get_gateway(txn.pay_type)
|
||||
remaining = int((txn.amount - txn.refund_amount) * 100)
|
||||
try:
|
||||
gw_result = await gw.refund(txn.order_no, remaining, reason)
|
||||
logger.info(f"Admin refund {txn.order_no}: {gw_result}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"退款请求失败: {e}")
|
||||
|
||||
txn.status = "refunded"
|
||||
txn.refund_amount = txn.amount
|
||||
txn.refund_reason = reason
|
||||
txn.refunded_at = datetime.utcnow()
|
||||
|
||||
sub_result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == order_no)
|
||||
)
|
||||
sub = sub_result.scalar_one_or_none()
|
||||
if sub:
|
||||
sub.status = "expired"
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == txn.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user and user.tier == txn.plan:
|
||||
user.tier = "free"
|
||||
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
|
||||
"user_id": str(txn.user_id)}
|
||||
|
||||
async def admin_payment_stats(self) -> Dict[str, Any]:
|
||||
all_txns = await self.db.execute(select(PaymentTransaction))
|
||||
rows = all_txns.scalars().all()
|
||||
total_count = len(rows)
|
||||
total_revenue = sum(r.amount for r in rows if r.status == "paid")
|
||||
total_refund = sum(r.refund_amount for r in rows)
|
||||
paid_count = sum(1 for r in rows if r.status == "paid")
|
||||
pending_count = sum(1 for r in rows if r.status == "pending")
|
||||
refunded_count = sum(1 for r in rows if r.status == "refunded")
|
||||
failed_count = sum(1 for r in rows if r.status == "failed")
|
||||
wechat_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "wechat")
|
||||
alipay_count = sum(1 for r in rows if r.gateway == "unified" and r.pay_type == "alipay")
|
||||
return {
|
||||
"total_count": total_count, "total_revenue": total_revenue,
|
||||
"total_refund": total_refund, "paid_count": paid_count,
|
||||
"pending_count": pending_count, "refunded_count": refunded_count,
|
||||
"failed_count": failed_count, "wechat_count": wechat_count,
|
||||
"alipay_count": alipay_count,
|
||||
}
|
||||
|
||||
|
||||
init_gateways()
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class PaymentGateway(ABC):
|
||||
name: str = ""
|
||||
|
||||
@abstractmethod
|
||||
async def create_order(self, order_no: str, amount: int, description: str,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def query_order(self, order_no: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def refund(self, order_no: str, amount: int, reason: str = "") -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def query_refund(self, order_no: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def verify_callback(self, headers: dict, body: str) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def parse_callback(self, body: str, headers: dict) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
def supports(self, pay_type: str) -> bool:
|
||||
return pay_type in self.supported_types
|
||||
|
||||
supported_types: list = []
|
||||
@@ -0,0 +1,117 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
from app.config import settings
|
||||
from app.services.payment_gateway import PaymentGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMPTY_SHA256 = hashlib.sha256(b"").hexdigest()
|
||||
|
||||
|
||||
def _hmac_sign(method: str, path: str, body: dict, api_secret: str) -> str:
|
||||
timestamp = str(int(time.time()))
|
||||
body_sha256 = hashlib.sha256(
|
||||
json.dumps(body, ensure_ascii=False, separators=(",", ":")).encode()
|
||||
).hexdigest()
|
||||
sign_str = f"{method}\n{path}\n{timestamp}\n{body_sha256}"
|
||||
signature = hmac.new(
|
||||
api_secret.encode(), sign_str.encode(), hashlib.sha256
|
||||
).hexdigest()
|
||||
return f"{timestamp}:{signature}"
|
||||
|
||||
|
||||
def _auth_header(api_key: str, api_secret: str, method: str, path: str, body: dict) -> str:
|
||||
ts_sig = _hmac_sign(method, path, body, api_secret)
|
||||
return f"PAY {api_key}:{ts_sig}"
|
||||
|
||||
|
||||
class UnifiedPayService(PaymentGateway):
|
||||
name = "unified"
|
||||
supported_types = ["alipay", "wechat"]
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = settings.PAY_API_KEY or ""
|
||||
self.api_secret = settings.PAY_API_SECRET or ""
|
||||
self.base_url = settings.PAY_API_BASE_URL
|
||||
self.webhook_url = settings.PAY_WEBHOOK_URL
|
||||
|
||||
def _headers(self, method: str, path: str, body: dict) -> dict:
|
||||
auth = _auth_header(self.api_key, self.api_secret, method, path, body)
|
||||
return {"Authorization": auth, "Content-Type": "application/json"}
|
||||
|
||||
async def _request(self, method: str, path: str, body: dict = None) -> Dict[str, Any]:
|
||||
body = body or {}
|
||||
url = f"{self.base_url}{path}"
|
||||
headers = self._headers(method, path, body)
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.request(method=method, url=url, json=body, headers=headers)
|
||||
result = resp.json()
|
||||
if result.get("code") != 0:
|
||||
raise ValueError(f"支付网关错误: {result.get('message', 'unknown')}")
|
||||
return result.get("data", {})
|
||||
|
||||
async def create_order(self, order_no: str, amount: int, description: str,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
payment_method = kwargs.get("pay_type", "alipay")
|
||||
if payment_method == "native":
|
||||
payment_method = "wechat"
|
||||
elif payment_method == "jsapi":
|
||||
payment_method = "wechat"
|
||||
elif payment_method == "pc":
|
||||
payment_method = "alipay"
|
||||
body = {
|
||||
"merchant_order_id": order_no,
|
||||
"amount": amount / 100,
|
||||
"payment_method": payment_method,
|
||||
"subject": description or "TradeMate 会员充值",
|
||||
"notify_url": self.webhook_url,
|
||||
}
|
||||
result = await self._request("POST", "/v1/pay/orders", body)
|
||||
out = {
|
||||
"gateway_order_id": result.get("gateway_order_id", ""),
|
||||
"merchant_order_id": result.get("merchant_order_id", order_no),
|
||||
"amount": result.get("amount", amount / 100),
|
||||
"payment_method": payment_method,
|
||||
"status": result.get("status", "pending"),
|
||||
}
|
||||
if payment_method == "alipay":
|
||||
out["pay_url"] = result.get("pay_url", "")
|
||||
else:
|
||||
out["code_url"] = result.get("qrcode", "")
|
||||
return out
|
||||
|
||||
async def query_order(self, order_no: str) -> Dict[str, Any]:
|
||||
return await self._request("GET", f"/v1/pay/orders/{order_no}")
|
||||
|
||||
async def refund(self, order_no: str, amount: int, reason: str = "") -> Dict[str, Any]:
|
||||
body = {
|
||||
"merchant_order_id": order_no,
|
||||
"amount": amount / 100,
|
||||
"reason": reason or "用户申请退款",
|
||||
}
|
||||
return await self._request("POST", "/v1/pay/refunds", body)
|
||||
|
||||
async def query_refund(self, order_no: str) -> Dict[str, Any]:
|
||||
return await self._request("GET", f"/v1/pay/refunds/{order_no}")
|
||||
|
||||
def verify_callback(self, headers: dict, body: str) -> bool:
|
||||
return True
|
||||
|
||||
def parse_callback(self, body: str, headers: dict) -> Dict[str, Any]:
|
||||
data = json.loads(body)
|
||||
event = data.get("event", "")
|
||||
payload = data.get("data", {})
|
||||
return {
|
||||
"event": event,
|
||||
"order_no": payload.get("merchant_order_id", ""),
|
||||
"gateway_order_id": payload.get("order_id", ""),
|
||||
"gateway_order_no": payload.get("transaction_id", ""),
|
||||
"amount": payload.get("amount", 0),
|
||||
"success": event == "recharge.completed",
|
||||
"raw": payload,
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import uuid
|
||||
import base64
|
||||
from typing import Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from cryptography.hazmat.primitives import serialization, hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeChatPayService:
|
||||
def __init__(self):
|
||||
self.mch_id = settings.WECHAT_PAY_MCH_ID
|
||||
self.api_key = settings.WECHAT_PAY_API_KEY
|
||||
self.serial_no = settings.WECHAT_PAY_SERIAL_NO
|
||||
self.app_id = settings.WECHAT_APP_ID
|
||||
self.api_base = settings.WECHAT_PAY_API_BASE
|
||||
self.notify_url = settings.WECHAT_PAY_NOTIFY_URL
|
||||
self._private_key = None
|
||||
|
||||
def _load_private_key(self) -> bytes:
|
||||
if self._private_key:
|
||||
return self._private_key
|
||||
cert_dir = Path(settings.WECHAT_PAY_CERT_DIR)
|
||||
key_path = cert_dir / "apiclient_key.pem"
|
||||
if not key_path.exists():
|
||||
key_path = Path("/root/hermes-workspace/projects/微信支付key/key/apiclient_key.pem")
|
||||
with open(key_path, "rb") as f:
|
||||
self._private_key = f.read()
|
||||
return self._private_key
|
||||
|
||||
def _sign_rsa(self, sign_str: str) -> str:
|
||||
private_key_data = self._load_private_key()
|
||||
key = serialization.load_pem_private_key(
|
||||
private_key_data, password=None, backend=default_backend()
|
||||
)
|
||||
signature = key.sign(
|
||||
sign_str.encode("utf-8"),
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
return base64.b64encode(signature).decode("utf-8")
|
||||
|
||||
def _build_auth_header(self, method: str, path: str, body: str = "") -> str:
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = uuid.uuid4().hex[:16]
|
||||
sign_str = f"{method}\n{path}\n{timestamp}\n{nonce}\n{body}\n"
|
||||
signature = self._sign_rsa(sign_str)
|
||||
return (
|
||||
f'WECHATPAY2-SHA256-RSA2048 '
|
||||
f'mchid="{self.mch_id}",'
|
||||
f'nonce_str="{nonce}",'
|
||||
f'timestamp="{timestamp}",'
|
||||
f'serial_no="{self.serial_no}",'
|
||||
f'signature="{signature}"'
|
||||
)
|
||||
|
||||
async def _request(self, method: str, path: str, body: Optional[dict] = None) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}{path}"
|
||||
body_str = json.dumps(body, ensure_ascii=False, separators=(",", ":")) if body else ""
|
||||
auth = self._build_auth_header(method, path, body_str)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
content=body_str if body else None,
|
||||
headers={
|
||||
"Authorization": auth,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "TradeMate/1.0",
|
||||
},
|
||||
)
|
||||
data = resp.json() if resp.text else {}
|
||||
if resp.status_code >= 400:
|
||||
logger.error(f"WeChat Pay API error: {resp.status_code} {data}")
|
||||
raise Exception(f"WeChat Pay error: {data.get('message', resp.text)}")
|
||||
return data
|
||||
|
||||
async def create_jsapi_order(self, out_trade_no: str, openid: str,
|
||||
total: int, description: str) -> Dict[str, Any]:
|
||||
path = "/v3/pay/transactions/jsapi"
|
||||
body = {
|
||||
"appid": self.app_id,
|
||||
"mchid": self.mch_id,
|
||||
"description": description,
|
||||
"out_trade_no": out_trade_no,
|
||||
"notify_url": self.notify_url,
|
||||
"amount": {"total": total, "currency": "CNY"},
|
||||
"payer": {"openid": openid},
|
||||
}
|
||||
return await self._request("POST", path, body)
|
||||
|
||||
async def create_native_order(self, out_trade_no: str, total: int,
|
||||
description: str) -> Dict[str, Any]:
|
||||
path = "/v3/pay/transactions/native"
|
||||
body = {
|
||||
"appid": self.app_id,
|
||||
"mchid": self.mch_id,
|
||||
"description": description,
|
||||
"out_trade_no": out_trade_no,
|
||||
"notify_url": self.notify_url,
|
||||
"amount": {"total": total, "currency": "CNY"},
|
||||
}
|
||||
return await self._request("POST", path, body)
|
||||
|
||||
async def query_order(self, out_trade_no: str) -> Dict[str, Any]:
|
||||
path = f"/v3/pay/transactions/out-trade-no/{out_trade_no}?mchid={self.mch_id}"
|
||||
return await self._request("GET", path)
|
||||
|
||||
async def close_order(self, out_trade_no: str):
|
||||
path = f"/v3/pay/transactions/out-trade-no/{out_trade_no}/close"
|
||||
body = {"mchid": self.mch_id}
|
||||
await self._request("POST", path, body)
|
||||
|
||||
def build_jsapi_pay_params(self, prepay_id: str) -> Dict[str, str]:
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = uuid.uuid4().hex[:16]
|
||||
package = f"prepay_id={prepay_id}"
|
||||
sign_str = f"{self.app_id}\n{timestamp}\n{nonce}\n{package}\n"
|
||||
pay_sign = self._sign_rsa(sign_str)
|
||||
|
||||
return {
|
||||
"appId": self.app_id,
|
||||
"timeStamp": timestamp,
|
||||
"nonceStr": nonce,
|
||||
"package": package,
|
||||
"signType": "RSA",
|
||||
"paySign": pay_sign,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def verify_callback(headers: dict, body: str) -> bool:
|
||||
wechatpay_signature = headers.get("wechatpay-signature", "")
|
||||
wechatpay_timestamp = headers.get("wechatpay-timestamp", "")
|
||||
wechatpay_nonce = headers.get("wechatpay-nonce", "")
|
||||
wechatpay_serial = headers.get("wechatpay-serial", "")
|
||||
|
||||
if not all([wechatpay_signature, wechatpay_timestamp, wechatpay_nonce, wechatpay_serial]):
|
||||
logger.warning("Missing WeChat Pay callback headers")
|
||||
return False
|
||||
|
||||
sign_str = f"{wechatpay_timestamp}\n{wechatpay_nonce}\n{body}\n"
|
||||
try:
|
||||
cert_dir = Path(settings.WECHAT_PAY_CERT_DIR)
|
||||
cert_path = cert_dir / "pub_key.pem"
|
||||
if not cert_path.exists():
|
||||
cert_path = Path("/root/hermes-workspace/projects/微信支付key/key/pub_key.pem")
|
||||
with open(cert_path, "rb") as f:
|
||||
cert_data = f.read()
|
||||
public_key = serialization.load_pem_public_key(cert_data, backend=default_backend())
|
||||
signature_bytes = base64.b64decode(wechatpay_signature)
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
sign_str.encode("utf-8"),
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"WeChat Pay callback verification failed: {e}")
|
||||
return False
|
||||
|
||||
def decrypt_callback(self, ciphertext: str, nonce: str,
|
||||
associated_data: str) -> str:
|
||||
key_bytes = self.api_key.encode("utf-8")
|
||||
nonce_bytes = base64.b64decode(nonce) if nonce else b""
|
||||
associated_bytes = associated_data.encode("utf-8")
|
||||
ciphertext_bytes = base64.b64decode(ciphertext)
|
||||
|
||||
aesgcm = AESGCM(key_bytes)
|
||||
plaintext = aesgcm.decrypt(nonce_bytes, ciphertext_bytes, associated_bytes)
|
||||
return plaintext.decode("utf-8")
|
||||
Reference in New Issue
Block a user