import json
import logging
import hashlib
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc, func

from app.models.subscription import Subscription
from app.models.payment_transaction import PaymentTransaction
from app.models.user import User
from app.config import settings
from app.services.unified_pay import UnifiedPayService
from app.services.payment_gateway import PaymentGateway

logger = logging.getLogger(__name__)

# Plan catalogue. `price` is in CNY (orders report currency "CNY"); the
# gateway is sent price * 100 (presumably fen). `duration_days=None` means
# the plan never expires.
PLANS = {
    "free": {"price": 0, "duration_days": None},
    "pro": {"price": 99, "duration_days": 30},
    "pro_yearly": {"price": 999, "duration_days": 365},
    "enterprise": {"price": 399, "duration_days": 30},
    "enterprise_yearly": {"price": 3999, "duration_days": 365},
}

# Order descriptions passed to the payment gateway for each paid plan.
PLAN_DESCRIPTIONS = {
    "pro": "TradeMate Pro 版会员",
    "pro_yearly": "TradeMate Pro 版会员(年付)",
    "enterprise": "TradeMate 企业版会员",
    "enterprise_yearly": "TradeMate 企业版会员(年付)",
}

# Registered payment gateways keyed by name; populated by init_gateways().
GATEWAY_MAP: Dict[str, PaymentGateway] = {}


def init_gateways():
    """Register the unified payment gateway when ``settings.PAY_API_KEY`` is set."""
    if settings.PAY_API_KEY:
        GATEWAY_MAP["unified"] = UnifiedPayService()


def get_gateway(pay_type: str) -> PaymentGateway:
    """Return the gateway able to handle *pay_type*.

    Only the "unified" gateway is consulted, whatever *pay_type* is.

    Raises:
        ValueError: if no gateway is configured, or it does not support
            *pay_type*.
    """
    gw = GATEWAY_MAP.get("unified")
    if not gw:
        raise ValueError("支付网关未配置,请设置 PAY_API_KEY")
    if not gw.supports(pay_type):
        raise ValueError(f"支付方式 {pay_type} 不被支持(仅支持 alipay/wechat)")
    return gw


def gen_order_no(user_id: str) -> str:
    """Build a merchant order number: ``"TM"`` + UTC timestamp + user-id suffix.

    The timestamp keeps only the first four microsecond digits (100 µs
    resolution), so two orders for the same user inside that window would
    collide. The suffix is the last 8 characters of *user_id* (or all of it
    when shorter).
    """
    ts = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:18]
    # Slicing already yields the whole string when it is shorter than 8.
    suffix = user_id[-8:]
    return f"TM{ts}{suffix}"


def _to_cents(amount) -> int:
    """Convert a yuan amount to integer minor units (x100), rounding.

    Plain ``int(amount * 100)`` truncates binary floating-point error, e.g.
    ``int(0.29 * 100) == 28``; rounding first yields the intended 29.
    """
    return int(round(amount * 100))


class PaymentService:
    """Subscription billing: plan catalogue, orders, gateway callbacks,
    refunds, and admin reporting, all on one async DB session.

    Methods mutate ORM objects and ``flush()``; committing is left to the
    caller.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _get_user(self, user_id) -> Optional[User]:
        """Load a user by id, or None when absent."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def _get_subscription(self, order_no: str) -> Optional[Subscription]:
        """Load the subscription created for *order_no*, or None."""
        result = await self.db.execute(
            select(Subscription).where(Subscription.payment_id == order_no)
        )
        return result.scalar_one_or_none()

    async def _get_txn(self, order_no: str,
                       user_id: Optional[str] = None) -> PaymentTransaction:
        """Load a transaction by order number, scoped to *user_id* unless None.

        Raises:
            ValueError: if no matching transaction exists.
        """
        query = select(PaymentTransaction).where(
            PaymentTransaction.order_no == order_no
        )
        if user_id is not None:
            query = query.where(PaymentTransaction.user_id == user_id)
        result = await self.db.execute(query)
        txn = result.scalar_one_or_none()
        if not txn:
            raise ValueError("订单不存在")
        return txn

    async def _count_txns(self, *criteria) -> int:
        """Return ``COUNT(*)`` of transactions matching *criteria* (all rows if none)."""
        query = select(func.count()).select_from(PaymentTransaction)
        if criteria:
            query = query.where(*criteria)
        return (await self.db.execute(query)).scalar_one()

    async def _close_pending_txn(self, txn: PaymentTransaction) -> Dict[str, Any]:
        """Close a pending order at the gateway and mark it "closed".

        Raises:
            ValueError: if the order is not pending or no gateway is usable.
        """
        if txn.status != "pending":
            raise ValueError("只有待支付订单可关闭")
        gw = get_gateway(txn.pay_type)
        await gw.close_order(txn.order_no)
        txn.status = "closed"
        await self.db.flush()
        return {"status": "ok", "order_no": txn.order_no}

    async def _refund_paid_txn(self, txn: PaymentTransaction, reason: str,
                               log_label: str) -> None:
        """Refund the unrefunded remainder of a paid order and revoke the plan.

        Marks the transaction "refunded", expires its subscription, and drops
        the user back to "free" if their current tier is the refunded plan.

        Raises:
            ValueError: if the order is not paid, is already fully refunded,
                no gateway is usable, or the gateway refund call fails.
        """
        if txn.status != "paid":
            raise ValueError("只有已支付订单可退款")
        # refund_amount may be NULL on rows that were never refunded.
        already_refunded = txn.refund_amount or 0
        if already_refunded >= txn.amount:
            raise ValueError("该订单已全额退款")
        gw = get_gateway(txn.pay_type)
        remaining = _to_cents(txn.amount - already_refunded)
        try:
            gw_result = await gw.refund(txn.order_no, remaining, reason)
        except Exception as e:
            raise ValueError(f"退款请求失败: {e}") from e
        logger.info("%s %s: %s", log_label, txn.order_no, gw_result)

        txn.status = "refunded"
        txn.refund_amount = txn.amount
        txn.refund_reason = reason
        txn.refunded_at = datetime.utcnow()

        sub = await self._get_subscription(txn.order_no)
        if sub:
            sub.status = "expired"
        user = await self._get_user(txn.user_id)
        # Only downgrade if the user has not since moved to a different plan.
        if user and user.tier == txn.plan:
            user.tier = "free"
        await self.db.flush()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_plans(self) -> Dict[str, Any]:
        """Return the plan catalogue for display plus the configured gateways."""
        return {
            "plans": [
                {"id": "free", "name": "免费版", "price": PLANS["free"]["price"],
                 "period": "month",
                 "features": ["1 个产品", "20 次翻译/天", "5 个客户", "基础回复建议"]},
                {"id": "pro", "name": "Pro 版", "price": PLANS["pro"]["price"],
                 "period": "month",
                 "features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成"]},
                {"id": "pro_yearly", "name": "Pro 版(年付)",
                 "price": PLANS["pro_yearly"]["price"], "period": "year",
                 "original_price": 1188,
                 "features": ["10 个产品", "无限翻译", "50 个客户", "跟进提醒", "报价单生成",
                              "省 ¥189"]},
                {"id": "enterprise", "name": "企业版",
                 "price": PLANS["enterprise"]["price"], "period": "month",
                 "features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练",
                              "API 接入", "优先支持"]},
                {"id": "enterprise_yearly", "name": "企业版(年付)",
                 "price": PLANS["enterprise_yearly"]["price"], "period": "year",
                 "original_price": 4788,
                 "features": ["无限产品/客户", "团队协作", "品牌报价单", "专属语料训练",
                              "API 接入", "优先支持", "省 ¥789"]},
            ],
            # Fall back to advertising "unified" even when none is configured.
            "gateways": list(GATEWAY_MAP.keys()) or ["unified"],
        }

    async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
        """Summarise the user's tier and their most recent active subscription.

        The plan comes from ``User.tier`` (default "free"); status, expiry and
        auto-renew come from the newest "active" subscription, if any.
        """
        result = await self.db.execute(
            select(Subscription).where(
                Subscription.user_id == user_id,
                Subscription.status == "active",
            ).order_by(Subscription.created_at.desc()).limit(1)
        )
        sub = result.scalar_one_or_none()
        user = await self._get_user(user_id)
        return {
            "plan": user.tier if user else "free",
            "status": sub.status if sub else "active",
            "expires_at": sub.expires_at.isoformat() if sub and sub.expires_at else None,
            "auto_renew": sub.auto_renew if sub else False,
        }

    async def create_order(self, user_id: str, plan: str,
                           pay_type: str = "alipay") -> Dict[str, Any]:
        """Create a payment order for *plan*, or switch directly to a free plan.

        For paid plans this creates the gateway order, then records a pending
        Subscription and PaymentTransaction keyed by the same order number.

        Raises:
            ValueError: unknown plan, unknown user, or gateway unavailable.
        """
        if plan not in PLANS:
            raise ValueError(f"无效套餐: {plan}")
        plan_info = PLANS[plan]

        user = await self._get_user(user_id)
        if not user:
            raise ValueError("用户不存在")

        if plan_info["price"] == 0:
            # Free plan: no payment needed, change tier immediately.
            user.tier = plan
            await self.db.flush()
            return {"status": "ok", "plan": plan, "amount": 0}

        order_no = gen_order_no(user_id)
        description = PLAN_DESCRIPTIONS.get(plan, f"TradeMate {plan}")
        # Compact JSON identifying the user and order, attached to the gateway order.
        remark = json.dumps({"uid": user_id, "oid": order_no},
                            ensure_ascii=False, separators=(",", ":"))
        gw = get_gateway(pay_type)
        # NOTE(review): the gateway order is created before the DB rows are
        # flushed; if the flush fails the gateway order is orphaned.
        gw_result = await gw.create_order(order_no, _to_cents(plan_info["price"]),
                                          description, pay_type=pay_type, remark=remark)

        self.db.add(Subscription(
            user_id=user_id,
            plan=plan,
            status="pending",
            amount=plan_info["price"],
            payment_id=order_no,
            payment_provider="unified",
        ))
        self.db.add(PaymentTransaction(
            user_id=user_id,
            order_no=order_no,
            plan=plan,
            amount=plan_info["price"],
            gateway="unified",
            pay_type=pay_type,
            status="pending",
            description=description,
            gateway_order_no=gw_result.get("gateway_order_id", ""),
        ))
        await self.db.flush()

        return {
            "status": "pending",
            "order_id": order_no,
            "plan": plan,
            "amount": plan_info["price"],
            "currency": "CNY",
            "gateway": "unified",
            "pay_type": pay_type,
            **gw_result,
        }

    async def handle_callback(self, order_no: str, gateway_order_id: str,
                              gateway_order_no: str, success: bool,
                              amount: float = 0, notify_raw: str = "") -> bool:
        """Apply a gateway payment notification to a pending order.

        On success: marks the transaction paid, activates its subscription
        (with an expiry when the plan has a duration), and sets the user's
        tier. On failure: marks the transaction and subscription failed.

        Returns:
            False if the order is unknown; True otherwise, including when the
            order was already processed (repeat notifications are no-ops).
        """
        result = await self.db.execute(
            select(PaymentTransaction).where(PaymentTransaction.order_no == order_no)
        )
        txn = result.scalar_one_or_none()
        if not txn:
            return False
        if txn.status != "pending":
            if success and txn.status != "paid":
                # e.g. the user paid after the order was closed; needs manual follow-up.
                logger.warning("Success callback for %s order %s ignored",
                               txn.status, order_no)
            return True

        # NOTE(review): `amount` is accepted but never compared with
        # txn.amount; its unit (yuan vs fen) is not visible here. Confirm and
        # verify it before granting the plan.
        now = datetime.utcnow()
        sub = await self._get_subscription(order_no)
        if success:
            txn.status = "paid"
            txn.gateway_order_id = gateway_order_id
            txn.gateway_order_no = gateway_order_no
            txn.paid_at = now
            txn.notify_raw = notify_raw
            if sub:
                sub.status = "active"
                sub.started_at = now
                duration_days = PLANS[sub.plan]["duration_days"]
                if duration_days:
                    sub.expires_at = now + timedelta(days=duration_days)
            user = await self._get_user(txn.user_id)
            if user:
                user.tier = txn.plan
        else:
            txn.status = "failed"
            txn.notify_raw = notify_raw
            if sub:
                sub.status = "failed"
        await self.db.flush()
        return True

    async def query_payment(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Return details of one of the user's orders.

        Raises:
            ValueError: if the order does not exist for this user.
        """
        txn = await self._get_txn(order_no, user_id)
        return {
            "order_no": txn.order_no,
            "plan": txn.plan,
            "amount": txn.amount,
            "currency": txn.currency,
            "gateway": txn.gateway,
            "pay_type": txn.pay_type,
            "status": txn.status,
            "gateway_order_no": txn.gateway_order_no,
            "paid_at": txn.paid_at.isoformat() if txn.paid_at else None,
            "refund_amount": txn.refund_amount,
            "created_at": txn.created_at.isoformat(),
        }

    async def close_order(self, user_id: str, order_no: str) -> Dict[str, Any]:
        """Close one of the user's pending orders.

        Raises:
            ValueError: order missing, not pending, or gateway unavailable.
        """
        txn = await self._get_txn(order_no, user_id)
        return await self._close_pending_txn(txn)

    async def query_refund(self, order_no: str, user_id: str = "") -> Dict[str, Any]:
        """Return local refund details plus the gateway's refund status.

        An empty *user_id* skips the ownership check (admin use).

        Raises:
            ValueError: order missing, not refunded, or gateway unavailable.
        """
        txn = await self._get_txn(order_no, user_id or None)
        if txn.status != "refunded":
            raise ValueError("该订单未退款")
        gw = get_gateway(txn.pay_type)
        gw_result = await gw.query_refund(order_no)
        return {
            "order_no": order_no,
            "status": txn.status,
            "refund_amount": txn.refund_amount,
            "refund_reason": txn.refund_reason,
            "refunded_at": txn.refunded_at.isoformat() if txn.refunded_at else None,
            "gateway": gw_result,
        }

    async def list_transactions(self, user_id: str, page: int = 1,
                                size: int = 20) -> Dict[str, Any]:
        """Return one page of the user's transactions, newest first.

        *page* and *size* are clamped to at least 1 so OFFSET is never negative.
        """
        page, size = max(page, 1), max(size, 1)
        criteria = [PaymentTransaction.user_id == user_id]
        total = await self._count_txns(*criteria)
        result = await self.db.execute(
            select(PaymentTransaction)
            .where(*criteria)
            .order_by(desc(PaymentTransaction.created_at))
            .offset((page - 1) * size)
            .limit(size)
        )
        items = result.scalars().all()
        return {
            "items": [{
                "order_no": t.order_no,
                "plan": t.plan,
                "amount": t.amount,
                "gateway": t.gateway,
                "pay_type": t.pay_type,
                "status": t.status,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
            } for t in items],
            "total": total,
            "page": page,
            "size": size,
        }

    async def refund(self, user_id: str, order_no: str,
                     reason: str = "") -> Dict[str, Any]:
        """Fully refund one of the user's paid orders and revoke the plan.

        Raises:
            ValueError: order missing, not paid, already refunded, or gateway failure.
        """
        txn = await self._get_txn(order_no, user_id)
        await self._refund_paid_txn(txn, reason, log_label="Refund")
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount}

    async def admin_list_payments(self, page: int = 1, size: int = 20,
                                  gateway: str = "", status: str = "",
                                  user_id: str = "",
                                  pay_type: str = "") -> Dict[str, Any]:
        """Return one page of all transactions, newest first, with optional filters.

        Empty filter values are ignored. *page* and *size* are clamped to at
        least 1.
        """
        page, size = max(page, 1), max(size, 1)
        filters = (
            (PaymentTransaction.gateway, gateway),
            (PaymentTransaction.status, status),
            (PaymentTransaction.user_id, user_id),
            (PaymentTransaction.pay_type, pay_type),
        )
        criteria = [column == value for column, value in filters if value]

        total = await self._count_txns(*criteria)
        query = select(PaymentTransaction).order_by(desc(PaymentTransaction.created_at))
        if criteria:
            query = query.where(*criteria)
        result = await self.db.execute(query.offset((page - 1) * size).limit(size))
        items = result.scalars().all()
        return {
            "items": [{
                "id": str(t.id),
                "user_id": str(t.user_id),
                "order_no": t.order_no,
                "plan": t.plan,
                "amount": t.amount,
                "gateway": t.gateway,
                "pay_type": t.pay_type,
                "status": t.status,
                "gateway_order_no": t.gateway_order_no,
                "refund_amount": t.refund_amount,
                "created_at": t.created_at.isoformat(),
                "paid_at": t.paid_at.isoformat() if t.paid_at else None,
                "refunded_at": t.refunded_at.isoformat() if t.refunded_at else None,
            } for t in items],
            "total": total,
            "page": page,
            "size": size,
        }

    async def admin_refund(self, order_no: str, reason: str = "") -> Dict[str, Any]:
        """Fully refund any paid order (no ownership check) and revoke the plan.

        Raises:
            ValueError: order missing, not paid, already refunded, or gateway failure.
        """
        txn = await self._get_txn(order_no)
        await self._refund_paid_txn(txn, reason, log_label="Admin refund")
        return {"status": "ok", "order_no": order_no, "refund_amount": txn.amount,
                "user_id": str(txn.user_id)}

    async def admin_close_order(self, order_no: str) -> Dict[str, Any]:
        """Close any pending order (no ownership check).

        Raises:
            ValueError: order missing, not pending, or gateway unavailable.
        """
        txn = await self._get_txn(order_no)
        return await self._close_pending_txn(txn)

    async def admin_payment_stats(self) -> Dict[str, Any]:
        """Return aggregate counts and sums across all transactions.

        Revenue counts only rows in status "paid"; refunds sum refund_amount
        over all rows. Aggregation runs in SQL (one GROUP BY) so the table is
        not loaded into memory.
        """
        grouped = await self.db.execute(
            select(
                PaymentTransaction.status,
                PaymentTransaction.gateway,
                PaymentTransaction.pay_type,
                func.count(),
                func.coalesce(func.sum(PaymentTransaction.amount), 0),
                func.coalesce(func.sum(PaymentTransaction.refund_amount), 0),
            ).group_by(
                PaymentTransaction.status,
                PaymentTransaction.gateway,
                PaymentTransaction.pay_type,
            )
        )

        status_counts: Dict[str, int] = {}
        total_count = 0
        total_revenue = 0
        total_refund = 0
        wechat_count = 0
        alipay_count = 0
        for status, gateway, pay_type, count, amount_sum, refund_sum in grouped.all():
            total_count += count
            total_refund += refund_sum
            status_counts[status] = status_counts.get(status, 0) + count
            if status == "paid":
                total_revenue += amount_sum
            if gateway == "unified":
                if pay_type == "wechat":
                    wechat_count += count
                elif pay_type == "alipay":
                    alipay_count += count

        return {
            "total_count": total_count,
            "total_revenue": total_revenue,
            "total_refund": total_refund,
            "paid_count": status_counts.get("paid", 0),
            "pending_count": status_counts.get("pending", 0),
            "refunded_count": status_counts.get("refunded", 0),
            "failed_count": status_counts.get("failed", 0),
            "wechat_count": wechat_count,
            "alipay_count": alipay_count,
        }


# NOTE(review): gateways are registered as an import-time side effect.
init_gateways()