"""Usage-quota accounting: per-tier limits, usage lookups, usage recording, and a FastAPI quota dependency."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from fastapi import HTTPException, Depends
from datetime import datetime, date
from sqlalchemy import Date
import logging
from typing import Optional

from app.models import UsageLog, SystemConfig, User, Customer, Product
from app.models.user import User
from app.models.subscription import Subscription
from app.api.v1.deps import get_current_user_id
from app.database import get_db

logger = logging.getLogger(__name__)

# Built-in limits per tier. Keys match the values of ACTION_MAP.
# "translate_chars", "replies", "marketing" and "quotations" are enforced per
# day; "customers" and "products" are enforced against the user's total row
# count (see UsageService.check_quota). A SystemConfig row keyed
# "<tier>_daily_limits" may override individual entries (see get_limits).
TIER_LIMITS_DEFAULT = {
    "free": {"translate_chars": 5000, "replies": 20, "marketing": 5, "customers": 5, "products": 1, "quotations": 3},
    "pro": {"translate_chars": 50000, "replies": 200, "marketing": 50, "customers": 100, "products": 20, "quotations": 30},
    "enterprise": {"translate_chars": 999999999, "replies": 9999, "marketing": 9999, "customers": 99999, "products": 9999, "quotations": 9999},
}

# Action name (as passed to check_quota / require_quota) -> key in the limits
# dict. Actions not listed here are never quota-checked.
ACTION_MAP = {
    "translate": "translate_chars",
    "reply": "replies",
    "marketing_generate": "marketing",
    "create_customer": "customers",
    "create_product": "products",
    "create_quotation": "quotations",
}


class UsageService:
    """Look up tier limits and usage counters for a user, and record usage events.

    All queries run on the AsyncSession passed to the constructor.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_limits(self, tier: str) -> dict:
        """Return the effective limits for *tier*.

        Starts from ``TIER_LIMITS_DEFAULT[tier]`` (empty for an unknown tier).
        It then overlays the value of the SystemConfig row keyed
        ``"<tier>_daily_limits"``, when that row exists and its value is
        non-empty. A stored value that is not a dict is ignored with a warning,
        instead of raising on ``**``.

        Returns:
            A new dict; callers may mutate it freely.
        """
        defaults = TIER_LIMITS_DEFAULT.get(tier, {})
        config_key = f"{tier}_daily_limits"
        result = await self.db.execute(select(SystemConfig).where(SystemConfig.key == config_key))
        row = result.scalar_one_or_none()
        if row and row.value:
            if isinstance(row.value, dict):
                # Config entries win over the built-in defaults, key by key.
                return {**defaults, **row.value}
            logger.warning(
                "SystemConfig %r holds a %s, expected a dict; using default limits",
                config_key,
                type(row.value).__name__,
            )
        # NOTE(review): an unknown tier with no config row yields {}. check_quota
        # then falls back to 999999 for every limit, i.e. effectively unlimited.
        # Confirm this fail-open behavior is intended.
        return dict(defaults)

    async def get_tier(self, user_id: str) -> str:
        """Return the user's tier, or ``"free"`` if the user does not exist or has no tier set."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if not user:
            return "free"
        return user.tier or "free"

    async def get_daily_usage(self, user_id: str, action: str) -> int:
        """Count the user's UsageLog rows for *action* whose ``created_at`` date is today."""
        # NOTE(review): "today" is the app server's local date (date.today()).
        # It is compared with created_at cast to DATE in the database, while
        # get_usage_stats uses datetime.utcnow(). Confirm both use the same timezone.
        today = date.today()
        stmt = select(func.count()).where(
            UsageLog.user_id == user_id,
            UsageLog.action == action,
            func.cast(UsageLog.created_at, Date) == today,
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_daily_chars(self, user_id: str) -> int:
        """Sum ``detail["chars"]`` over the user's ``"translate"`` logs created today (0 if none)."""
        today = date.today()
        stmt = select(func.coalesce(func.sum(
            (UsageLog.detail["chars"]).as_integer()
        ), 0)).where(
            UsageLog.user_id == user_id,
            UsageLog.action == "translate",
            func.cast(UsageLog.created_at, Date) == today,
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_total_count(self, user_id: str, model_class) -> int:
        """Count all rows of *model_class* owned by the user (no date filter).

        Assumes *model_class* has a ``user_id`` column.
        """
        stmt = select(func.count()).where(model_class.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def check_quota(self, user_id: str, action: str, chars: int = 0) -> tuple[bool, str]:
        """Decide whether the user may perform *action* now.

        Args:
            user_id: the acting user.
            action: an ACTION_MAP key; unknown actions are always allowed.
            chars: characters about to be translated (used only for ``"translate"``).

        Returns:
            ``(True, "")`` if allowed, otherwise ``(False, message)`` with a
            user-facing explanation.
        """
        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)
        limit_key = ACTION_MAP.get(action)
        if not limit_key:
            return True, ""

        # A limit missing from the tier config is treated as effectively unlimited.
        limit = limits.get(limit_key, 999999)

        if action == "translate":
            # Character budget: reject if this request would push today's total over the limit.
            used = await self.get_daily_chars(user_id)
            if used + chars > limit:
                remaining = max(0, limit - used)
                return False, f"今日翻译字符已达上限({limit}字符),剩余{remaining}字符。升级 Pro 获取更多额度。"
        elif action == "create_customer":
            # Customers and products are capped by total count, not per day.
            used = await self.get_total_count(user_id, Customer)
            if used >= limit:
                return False, f"客户数量已达上限({limit}个)。升级 Pro 获取更多客户管理额度。"
        elif action == "create_product":
            used = await self.get_total_count(user_id, Product)
            if used >= limit:
                return False, f"产品数量已达上限({limit}个)。升级 Pro 获取更多产品额度。"
        else:
            # Remaining actions are capped by the number of log entries today.
            used = await self.get_daily_usage(user_id, action)
            if used >= limit:
                return False, f"今日{action}次数已达上限({limit}次)。升级 Pro 获取更多额度。"
        return True, ""

    async def record_usage(self, user_id: str, action: str, chars: int = 0, detail: Optional[dict] = None):
        """Insert a UsageLog row for *action* and commit the session.

        Args:
            user_id: the acting user.
            action: action name stored on the log row.
            chars: characters consumed. When non-zero it is stored as
                ``detail["chars"]``, which get_daily_chars sums.
            detail: extra JSON payload. It is copied, so the caller's dict is
                never modified.
        """
        # Build a fresh dict: the original code stored the caller's dict and
        # then wrote "chars" into it, mutating the caller's object.
        payload = dict(detail) if detail else {}
        if chars:
            payload["chars"] = chars
        log = UsageLog(
            user_id=user_id,
            action=action,
            detail=payload,
        )
        self.db.add(log)
        await self.db.commit()

    async def get_usage_stats(self, user_id: str) -> dict:
        """Return the user's tier, effective limits, current usage and trial days left.

        ``usage`` holds today's counts for per-day limits and all-time counts
        for customers and products. ``trial_days_left`` is non-zero only for
        ``"pro"`` users with an active ``"pro_trial"`` subscription that has an
        ``expires_at``.
        """
        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)

        trial_days_left = 0
        if tier == "pro":
            result = await self.db.execute(
                select(Subscription).where(
                    Subscription.user_id == user_id,
                    Subscription.plan == "pro_trial",
                    Subscription.status == "active",
                )
            )
            # NOTE(review): scalar_one_or_none raises if a user has more than one
            # active pro_trial subscription.
            trial_sub = result.scalar_one_or_none()
            if trial_sub and trial_sub.expires_at:
                # Assumes expires_at is a naive UTC datetime. Subtracting an
                # aware datetime from utcnow() would raise TypeError.
                remaining = (trial_sub.expires_at - datetime.utcnow()).days
                trial_days_left = max(0, remaining)

        customer_count = await self.get_total_count(user_id, Customer)
        product_count = await self.get_total_count(user_id, Product)
        translate_chars = await self.get_daily_chars(user_id)
        reply_count = await self.get_daily_usage(user_id, "reply")
        marketing_count = await self.get_daily_usage(user_id, "marketing_generate")
        quotation_count = await self.get_daily_usage(user_id, "create_quotation")

        return {
            "tier": tier,
            "limits": limits,
            "usage": {
                "translate_chars": translate_chars,
                "replies": reply_count,
                "marketing": marketing_count,
                "customers": customer_count,
                "products": product_count,
                "quotations": quotation_count,
            },
            "trial_days_left": trial_days_left,
        }


def require_quota(action: str, chars_field: str = None):
    """Build a FastAPI dependency that rejects the request when *action* is over quota.

    The dependency resolves the current user and a DB session, and runs
    ``UsageService.check_quota(user_id, action)``. That call does not pass
    ``chars``, so a ``"translate"`` check only rejects once today's total
    already exceeds the limit.

    Raises (from the dependency):
        HTTPException(400): *action* is ``"translate"`` and *chars_field* is set;
            translation character checks must be done explicitly by the endpoint.
        HTTPException(429): the quota is exhausted; ``detail`` is the user-facing message.

    Returns (from the dependency):
        The current user's id.
    """
    async def _check(
        user_id: str = Depends(get_current_user_id),
        db: AsyncSession = Depends(get_db),
    ):
        svc = UsageService(db)
        if action == "translate" and chars_field:
            raise HTTPException(status_code=400, detail="translate action needs explicit chars check")
        ok, msg = await svc.check_quota(user_id, action)
        if not ok:
            raise HTTPException(status_code=429, detail=msg)
        return user_id

    return _check