"""Credit accounting service: balances, deductions, top-ups, and catalogue queries."""
import logging
from datetime import datetime, date
from decimal import Decimal

from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import UserCredit, CreditConsumption, CreditPackage, SubscriptionPlan, CreditPurchase
from app.models.system_config import SystemConfig

logger = logging.getLogger(__name__)

# Default credit cost per action, keyed by ``result_type``. Individual entries can
# be overridden by the ``credit_consumption_rates`` SystemConfig row (merged in
# ``CreditService._get_rates``).
DEFAULT_CONSUMPTION_RATES = {
    "lead_search": 10,
    "company_analysis": 5,
    "market_intel": 20,
    "translate_per_1000chars": 1,
    "reply_suggest": 2,
    "outreach": 3,
    "marketing_content": 5,
    "competitor_analysis": 10,
    "ai_chat_per_10msg": 1,
    "info_extract": 1,
    "quotation": 2,
    "followup_scan": 2,
}

# Credits granted once per user by ``CreditService.grant_free_trial``.
FREE_TRIAL_CREDITS = 30
# Translation characters per calendar day (``date.today()``) that are not charged.
DAILY_FREE_TRANSLATE_CHARS = 1000


class CreditService:
    """Credit balance, consumption ledger, and catalogue operations on one session.

    Methods only ``flush`` the session; committing or rolling back is left to
    whoever owns ``db``.

    NOTE(review): balances are read, checked, and written without a row lock
    (no ``SELECT ... FOR UPDATE``), so concurrent requests for the same user can
    overspend. Likewise two concurrent ``_ensure_credit`` calls may both try to
    insert a row. Consider locking if this runs under concurrency.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _ensure_credit(self, user_id: str) -> UserCredit:
        """Return the user's ``UserCredit`` row, creating (and flushing) it with balance 0 if missing."""
        result = await self.db.execute(
            select(UserCredit).where(UserCredit.user_id == user_id)
        )
        uc = result.scalar_one_or_none()
        if not uc:
            uc = UserCredit(user_id=user_id, balance=0)
            self.db.add(uc)
            await self.db.flush()
        return uc

    async def get_balance(self, user_id: str) -> dict:
        """Return a serialisable summary of the user's credits, subscription, and current rates."""
        uc = await self._ensure_credit(user_id)
        rates = await self._get_rates()
        return {
            "balance": uc.balance,
            "total_purchased": uc.total_purchased,
            "total_used": uc.total_used,
            "subscription": {
                "plan_id": str(uc.subscription_plan_id) if uc.subscription_plan_id else None,
                "expires_at": uc.subscription_expires_at.isoformat() if uc.subscription_expires_at else None,
                "auto_renew": uc.subscription_auto_renew,
            } if uc.subscription_plan_id else None,
            "free_trial_used": uc.free_trial_used,
            "daily_free_translate_chars_left": max(
                0, DAILY_FREE_TRANSLATE_CHARS - await self._daily_translate_chars(uc)
            ),
            "rates": rates,
        }

    async def deduct(self, user_id: str, result_type: str, reference_id: str = None,
                     amount: float = None, metadata: dict = None) -> tuple[bool, float]:
        """Charge ``user_id`` for one ``result_type`` action and record it in the ledger.

        Args:
            user_id: Owner of the credit record (created on demand).
            result_type: Rate key. ``"translate"`` is priced per 1000 characters
                after the daily free allowance.
            reference_id: Optional id stored on the ledger entry.
            amount: Explicit cost overriding the configured rate. An explicit
                ``0`` is honoured. Overridden for ``"translate"`` when
                ``metadata["chars"]`` is positive.
            metadata: Stored on the ledger entry. For translations,
                ``metadata["chars"]`` is the character count.

        Returns:
            ``(True, balance_after)`` on success. ``(False, balance)`` when the
            balance cannot cover the cost; in that case neither the balance
            nor the free-translation allowance is modified.

        NOTE(review): translation cost can be fractional (float). If
        ``UserCredit.balance`` is a Numeric/Decimal column, float arithmetic on
        it will raise ``TypeError``. Confirm the column type.
        """
        rates = await self._get_rates()
        # ``is not None`` rather than ``or`` so an explicit amount of 0 is not
        # silently replaced by the configured rate.
        cost = amount if amount is not None else rates.get(result_type, 1)
        uc = await self._ensure_credit(user_id)

        free_used = 0
        if result_type == "translate":
            char_count = (metadata or {}).get("chars", 0)
            if char_count > 0:
                used_today = await self._daily_translate_chars(uc)
                free_remaining = max(0, DAILY_FREE_TRANSLATE_CHARS - used_today)
                free_used = min(free_remaining, char_count)
                paid_chars = char_count - free_used
                cost = (paid_chars / 1000) * rates.get("translate_per_1000chars", 1)

        # Reject before recording any free-allowance usage, so a failed request
        # does not burn the user's daily free characters.
        if cost > 0 and uc.balance < cost:
            return False, uc.balance

        if free_used > 0:
            self._record_free_translate_chars(uc, free_used)

        if cost <= 0:
            # Fully covered by the free allowance (or a zero/negative cost).
            await self._log(user_id, result_type, reference_id, 0, uc.balance, "daily_free", metadata)
            await self.db.flush()
            return True, uc.balance

        uc.balance -= cost
        uc.total_used += cost
        balance_after = uc.balance
        await self._log(user_id, result_type, reference_id, -cost, balance_after, "credit", metadata)
        await self.db.flush()
        return True, balance_after

    @staticmethod
    def _record_free_translate_chars(uc: UserCredit, chars: int) -> None:
        """Add ``chars`` to today's free-translation counter, resetting it on a new day."""
        today = date.today()
        if uc.daily_translate_date != today:
            uc.daily_translate_date = today
            uc.daily_translate_chars = 0
        uc.daily_translate_chars = (uc.daily_translate_chars or 0) + chars

    async def add_credits(self, user_id: str, credits: float, source: str, description: str = None) -> float:
        """Add ``credits`` to the balance, log a ``"topup"`` ledger entry, and return the new balance.

        Only positive ``credits`` count toward ``total_purchased``. A negative
        value is applied as-is, with no floor at zero.
        """
        uc = await self._ensure_credit(user_id)
        uc.balance += credits
        if credits > 0:
            uc.total_purchased += credits
        balance_after = uc.balance
        await self._log(user_id, "topup", None, credits, balance_after, source, {"description": description})
        await self.db.flush()
        return balance_after

    async def grant_free_trial(self, user_id: str) -> float:
        """Grant ``FREE_TRIAL_CREDITS`` once per user and return the resulting balance.

        If the trial was already used, the balance is returned unchanged.
        """
        uc = await self._ensure_credit(user_id)
        if uc.free_trial_used:
            return uc.balance
        # Mark the trial as consumed; without this, the grant could be repeated
        # indefinitely.
        uc.free_trial_used = True
        return await self.add_credits(
            user_id, FREE_TRIAL_CREDITS, "free_trial", f"新用户注册赠送 {FREE_TRIAL_CREDITS} 次"
        )

    async def consume_for_subscription(self, user_id: str, plan_id: str) -> tuple[bool, str]:
        """Check that ``plan_id`` is an active subscription plan.

        Returns ``(False, "套餐不存在")`` ("plan does not exist") when no active
        plan matches. Otherwise returns ``(True, "ok")``.

        NOTE(review): this is a stub. Apart from ensuring the user's credit row
        exists, it does not charge ``plan.price`` or update any subscription
        fields. Do not treat ``True`` as "subscription activated".
        """
        result = await self.db.execute(
            select(SubscriptionPlan).where(
                SubscriptionPlan.id == plan_id, SubscriptionPlan.is_active == True  # noqa: E712
            )
        )
        if result.scalar_one_or_none() is None:
            return False, "套餐不存在"
        # Kept for its side effect: creates the user's credit row if missing.
        await self._ensure_credit(user_id)
        return True, "ok"

    async def _log(self, user_id: str, result_type: str, reference_id: str, credits_change: float,
                   balance_after: float, source: str, metadata: dict = None):
        """Add a ``CreditConsumption`` ledger row to the session. Not flushed here."""
        log = CreditConsumption(
            user_id=user_id,
            result_type=result_type,
            reference_id=reference_id,
            credits_change=credits_change,
            balance_after=balance_after,
            source=source,
            metadata_=metadata or {},
        )
        self.db.add(log)

    async def get_history(self, user_id: str, page: int = 1, size: int = 20) -> dict:
        """Return one page of the user's ledger, newest first, plus the total entry count.

        ``page`` is 1-based. ``page`` and ``size`` are clamped to at least 1
        so the query never gets a negative OFFSET or non-positive LIMIT.
        """
        page = max(1, page)
        size = max(1, size)
        offset = (page - 1) * size
        stmt = (
            select(CreditConsumption)
            .where(CreditConsumption.user_id == user_id)
            .order_by(desc(CreditConsumption.created_at))
            .offset(offset)
            .limit(size)
        )
        items = (await self.db.execute(stmt)).scalars().all()

        count_stmt = (
            select(func.count())
            .select_from(CreditConsumption)
            .where(CreditConsumption.user_id == user_id)
        )
        total = (await self.db.execute(count_stmt)).scalar() or 0

        return {
            "items": [{
                "id": str(item.id),
                "result_type": item.result_type,
                "credits_change": item.credits_change,
                "balance_after": item.balance_after,
                "source": item.source,
                # NOTE(review): ``_log`` never sets ``description`` (add_credits
                # stores it in metadata_). Confirm the model has this attribute.
                "description": item.description,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            } for item in items],
            "total": total,
            "page": page,
            "size": size,
        }

    async def _get_rates(self) -> dict:
        """Return the default rates overlaid with the ``credit_consumption_rates`` config, if it is a mapping."""
        result = await self.db.execute(
            select(SystemConfig).where(SystemConfig.key == "credit_consumption_rates")
        )
        row = result.scalar_one_or_none()
        if row and isinstance(row.value, dict) and row.value:
            return {**DEFAULT_CONSUMPTION_RATES, **row.value}
        return dict(DEFAULT_CONSUMPTION_RATES)

    async def _daily_translate_chars(self, uc: UserCredit) -> int:
        """Return free translation characters already used today (0 if the stored date is not today)."""
        if uc.daily_translate_date != date.today():
            return 0
        return uc.daily_translate_chars or 0

    async def get_packages(self) -> list:
        """List active credit packages ordered by ``sort_order``."""
        result = await self.db.execute(
            select(CreditPackage).where(CreditPackage.is_active == True).order_by(CreditPackage.sort_order)  # noqa: E712
        )
        return [{
            "id": str(p.id),
            "name": p.name,
            "name_en": p.name_en,
            "credits": p.credits,
            "price": p.price,
            "price_usd": p.price_usd,
            "original_price": p.original_price,
        } for p in result.scalars().all()]

    async def get_subscription_plans(self) -> list:
        """List active subscription plans ordered by ``sort_order``."""
        result = await self.db.execute(
            select(SubscriptionPlan).where(SubscriptionPlan.is_active == True).order_by(SubscriptionPlan.sort_order)  # noqa: E712
        )
        return [{
            "id": str(p.id),
            "name": p.name,
            "name_en": p.name_en,
            "credits_per_month": p.credits_per_month,
            "price": p.price,
            "price_usd": p.price_usd,
            "duration_days": p.duration_days,
        } for p in result.scalars().all()]

    async def get_stats(self) -> dict:
        """Return system-wide credit aggregates.

        ``total_consumed`` is the absolute sum of all negative ledger entries.
        NOTE(review): ``total_users_with_credits`` counts every ``UserCredit``
        row, including zero balances.
        """
        result = await self.db.execute(
            select(func.coalesce(func.sum(UserCredit.total_purchased), 0))
        )
        total_purchased = result.scalar()
        result = await self.db.execute(
            select(func.coalesce(func.sum(UserCredit.balance), 0))
        )
        total_balance = result.scalar()
        result = await self.db.execute(select(func.count(UserCredit.id)))
        total_users = result.scalar()
        result = await self.db.execute(
            select(func.coalesce(func.sum(CreditConsumption.credits_change), 0)).where(
                CreditConsumption.credits_change < 0
            )
        )
        total_consumed = abs(result.scalar() or 0)
        return {
            "total_purchased": total_purchased,
            "total_balance": total_balance,
            "total_consumed": total_consumed,
            "total_users_with_credits": total_users,
        }


# Legacy rate table, not referenced in this module. It differs from
# DEFAULT_CONSUMPTION_RATES in one key ("ai_chat" vs "ai_chat_per_10msg").
# NOTE(review): presumably kept for external importers; verify before removing.
CREDIT_CONSUMPTION = {
    "lead_search": 10,
    "company_analysis": 5,
    "market_intel": 20,
    "translate_per_1000chars": 1,
    "reply_suggest": 2,
    "outreach": 3,
    "marketing_content": 5,
    "competitor_analysis": 10,
    "ai_chat": 1,
    "info_extract": 1,
    "quotation": 2,
    "followup_scan": 2,
}