5d2bced39f
- PROGRESS.md: update to 2026-05-29 with security hardening (T-005), 4-frontend architecture, AI provider refactoring, discovery features, landing page/referral/quota, desktop layout, admin AI management - AGENTS.md: add AI provider list (Alibaba/NVIDIA, removed Claude/DeepL/Local), DB-driven config, CSRF/rate-limit/CORS notes, admin_ai reload quirk - .env.example: sync with actual config, replace deprecated providers with current Sensenova/OpencodeGo/NVIDIA/Spark/Alibaba - docs/PROJECT_STATUS.md: archive (fully superseded by PROGRESS.md) - Remove generated JS files (_bing_search.js, _batch_search.js) - Remove empty directories (data/corpus, data/models) - Remove backend/.coverage (test artifact) - Fix services/.gitignore to cover _bing_search.js - Include pending AI provider DB admin feature (admin_ai, AIProvider model, AIProviders.vue, migration) and T-008 test report
172 lines
6.6 KiB
Python
172 lines
6.6 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, func
|
|
from fastapi import HTTPException, Depends
|
|
from datetime import datetime, date
|
|
from sqlalchemy import Date
|
|
from typing import Tuple
|
|
import logging
|
|
|
|
from app.models import UsageLog, SystemConfig, User, Customer, Product
|
|
from app.models.user import User
|
|
from app.models.subscription import Subscription
|
|
from app.api.v1.deps import get_current_user_id
|
|
from app.database import get_db
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Built-in limits per subscription tier. UsageService.get_limits merges a
# "<tier>_daily_limits" SystemConfig value on top of these, key by key.
# In UsageService.check_quota, translate_chars / replies / marketing /
# quotations are checked against today's usage, while customers / products
# are checked against the user's total number of rows.
TIER_LIMITS_DEFAULT = {
    "free": {
        "translate_chars": 5_000,
        "replies": 20,
        "marketing": 5,
        "customers": 5,
        "products": 1,
        "quotations": 3,
    },
    "pro": {
        "translate_chars": 50_000,
        "replies": 200,
        "marketing": 50,
        "customers": 100,
        "products": 20,
        "quotations": 30,
    },
    "enterprise": {
        "translate_chars": 999_999_999,
        "replies": 9_999,
        "marketing": 9_999,
        "customers": 99_999,
        "products": 9_999,
        "quotations": 9_999,
    },
}
|
|
|
|
# Maps an action name (as passed to UsageService.check_quota) to the key of
# its limit in a tier's limits dict. Actions not listed here are never
# quota-limited by check_quota.
ACTION_MAP = dict(
    translate="translate_chars",
    reply="replies",
    marketing_generate="marketing",
    create_customer="customers",
    create_product="products",
    create_quotation="quotations",
)
|
|
|
|
|
|
class UsageService:
    """Look up, enforce, and record per-user usage quotas.

    Limits come from ``TIER_LIMITS_DEFAULT``, optionally overridden per tier by
    a ``SystemConfig`` row keyed ``"<tier>_daily_limits"``. Daily counters are
    derived from ``UsageLog`` rows; customer/product limits are checked against
    the total number of rows the user owns.
    """

    def __init__(self, db: AsyncSession):
        # Every query and write below goes through this session.
        self.db = db

    async def get_limits(self, tier: str) -> dict:
        """Return the effective limits for ``tier``.

        If a ``"<tier>_daily_limits"`` SystemConfig row exists with a truthy
        value, its entries override the built-in defaults key by key. An
        unknown tier without such a row yields an empty dict.
        """
        config_key = f"{tier}_daily_limits"
        result = await self.db.execute(select(SystemConfig).where(SystemConfig.key == config_key))
        row = result.scalar_one_or_none()
        defaults = TIER_LIMITS_DEFAULT.get(tier, {})
        if row and row.value:
            # NOTE(review): assumes row.value is a mapping (e.g. a JSON column).
            return {**defaults, **row.value}
        return dict(defaults)

    async def get_tier(self, user_id: str) -> str:
        """Return the user's tier, or ``"free"`` if the user or tier is missing."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if not user:
            return "free"
        return user.tier or "free"

    async def get_daily_usage(self, user_id: str, action: str) -> int:
        """Count today's ``UsageLog`` rows for ``user_id`` with the given ``action``."""
        # NOTE(review): "today" is the server's local date, while get_usage_stats
        # uses datetime.utcnow(); confirm created_at is stored in the same zone.
        today = date.today()
        stmt = select(func.count()).where(
            UsageLog.user_id == user_id,
            UsageLog.action == action,
            func.cast(UsageLog.created_at, Date) == today,
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_daily_chars(self, user_id: str) -> int:
        """Sum ``detail["chars"]`` over today's ``translate`` logs for ``user_id``.

        Rows whose detail has no ``chars`` entry contribute nothing (SQL SUM
        skips NULLs); returns 0 when there are no matching rows.
        """
        # NOTE(review): same local-date caveat as get_daily_usage.
        today = date.today()
        stmt = select(func.coalesce(func.sum(
            (UsageLog.detail["chars"]).as_integer()
        ), 0)).where(
            UsageLog.user_id == user_id,
            UsageLog.action == "translate",
            func.cast(UsageLog.created_at, Date) == today,
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_total_count(self, user_id: str, model_class) -> int:
        """Count all ``model_class`` rows owned by ``user_id`` (not limited to today)."""
        stmt = select(func.count()).where(model_class.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def check_quota(self, user_id: str, action: str, chars: int = 0) -> Tuple[bool, str]:
        """Decide whether ``user_id`` may perform ``action`` now.

        Args:
            user_id: The acting user.
            action: An ``ACTION_MAP`` key; any other action is always allowed.
            chars: For ``"translate"``, the number of characters about to be
                translated. 0 means "not known yet": the request is then refused
                only once today's character budget is already used up.

        Returns:
            ``(True, "")`` when allowed, otherwise ``(False, message)`` where
            ``message`` is a user-facing (Chinese) explanation.
        """
        limit_key = ACTION_MAP.get(action)
        if not limit_key:
            # Unmapped actions are never limited; skip the tier/limit queries.
            return True, ""

        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)
        # A key missing from the tier's limits is treated as effectively unlimited.
        limit = limits.get(limit_key, 999999)

        if action == "translate":
            used = await self.get_daily_chars(user_id)
            # `used >= limit` also refuses chars=0 pre-checks once the budget is
            # exhausted; `used + 0 > limit` alone let them through at used == limit.
            if used >= limit or used + chars > limit:
                remaining = max(0, limit - used)
                return False, f"今日翻译字符已达上限({limit}字符),剩余{remaining}字符。升级 Pro 获取更多额度。"
        elif action == "create_customer":
            used = await self.get_total_count(user_id, Customer)
            if used >= limit:
                return False, f"客户数量已达上限({limit}个)。升级 Pro 获取更多客户管理额度。"
        elif action == "create_product":
            used = await self.get_total_count(user_id, Product)
            if used >= limit:
                return False, f"产品数量已达上限({limit}个)。升级 Pro 获取更多产品额度。"
        else:
            # Remaining mapped actions are limited per day by log-row count.
            used = await self.get_daily_usage(user_id, action)
            if used >= limit:
                return False, f"今日{action}次数已达上限({limit}次)。升级 Pro 获取更多额度。"

        return True, ""

    async def record_usage(self, user_id: str, action: str, chars: int = 0, detail: dict = None):
        """Persist one ``UsageLog`` row for ``action`` and commit the session.

        A non-zero ``chars`` is stored under ``detail["chars"]``, which is what
        get_daily_chars sums. The caller's ``detail`` dict is copied, never
        mutated.
        """
        payload = dict(detail) if detail else {}
        if chars:
            payload["chars"] = chars
        log = UsageLog(
            user_id=user_id,
            action=action,
            detail=payload,
        )
        self.db.add(log)
        await self.db.commit()

    async def get_usage_stats(self, user_id: str) -> dict:
        """Summarize the user's tier, effective limits, current usage, and trial days.

        ``trial_days_left`` is only computed for ``"pro"`` users with an active
        ``"pro_trial"`` subscription; it is whole days remaining, floored at 0.
        """
        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)

        trial_days_left = 0
        if tier == "pro":
            result = await self.db.execute(
                select(Subscription)
                .where(
                    Subscription.user_id == user_id,
                    Subscription.plan == "pro_trial",
                    Subscription.status == "active",
                    Subscription.expires_at.is_not(None),
                )
                # Tolerate duplicate active trials instead of raising
                # MultipleResultsFound; report the one that expires last.
                .order_by(Subscription.expires_at.desc())
                .limit(1)
            )
            trial_sub = result.scalars().first()
            if trial_sub and trial_sub.expires_at:
                expires_at = trial_sub.expires_at
                # Match the stored value's awareness: naive minus aware raises TypeError.
                now = datetime.now(expires_at.tzinfo) if expires_at.tzinfo else datetime.utcnow()
                trial_days_left = max(0, (expires_at - now).days)

        customer_count = await self.get_total_count(user_id, Customer)
        product_count = await self.get_total_count(user_id, Product)
        translate_chars = await self.get_daily_chars(user_id)
        reply_count = await self.get_daily_usage(user_id, "reply")
        marketing_count = await self.get_daily_usage(user_id, "marketing_generate")
        quotation_count = await self.get_daily_usage(user_id, "create_quotation")

        return {
            "tier": tier,
            "limits": limits,
            "usage": {
                "translate_chars": translate_chars,
                "replies": reply_count,
                "marketing": marketing_count,
                "customers": customer_count,
                "products": product_count,
                "quotations": quotation_count,
            },
            "trial_days_left": trial_days_left,
        }
|
|
|
|
|
|
def require_quota(action: str, chars_field: str = None):
    """Build a FastAPI dependency that enforces the quota for ``action``.

    The returned dependency resolves the current user id and a DB session,
    asks ``UsageService.check_quota`` whether ``action`` is allowed (without a
    character count), and returns the user id on success.

    Raises (from the dependency, at request time):
        HTTPException(400): if ``action == "translate"`` and ``chars_field`` is
            truthy; character-based checks must be done explicitly by the route.
        HTTPException(429): when the quota is exhausted; ``detail`` carries the
            user-facing message from ``check_quota``.
    """
    async def _quota_guard(
        user_id: str = Depends(get_current_user_id),
        db: AsyncSession = Depends(get_db),
    ):
        # NOTE(review): this guards against a server-side misuse of the factory,
        # yet it surfaces to clients as a 400 on every request.
        if chars_field and action == "translate":
            raise HTTPException(status_code=400, detail="translate action needs explicit chars check")

        allowed, message = await UsageService(db).check_quota(user_id, action)
        if allowed:
            return user_id
        raise HTTPException(status_code=429, detail=message)

    return _quota_guard
|