bed5c7abef
- Separate workspace landing from login for better UX - Referral system rewards both parties with Pro days - Quota enforcement prevents abuse without breaking endpoints - 7-day free trial with auto-downgrade on expiry - Admin-managed search provider config (SearXNG, Bing) - 15% discount on annual subscriptions - MCP search server wrapping opencode search - Fix discovery module field name mismatch causing 422
170 lines
6.6 KiB
Python
170 lines
6.6 KiB
Python
import logging
import math
from datetime import datetime, date, timedelta

from fastapi import HTTPException, Depends
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import UsageLog, SystemConfig, User, Customer, Product
from app.models.user import User
from app.models.subscription import Subscription
from app.api.v1.deps import get_current_user_id
from app.database import get_db
|
|
|
|
logger = logging.getLogger(__name__)


# Built-in quota table per subscription tier. Values stored in SystemConfig
# under "<tier>_daily_limits" override these key by key (UsageService.get_limits).
# "translate_chars", "replies", "marketing" and "quotations" are compared with
# today's usage; "customers" and "products" with the user's total row count.
# The enterprise numbers are large sentinels rather than real caps.
TIER_LIMITS_DEFAULT = {
    "free": dict(
        translate_chars=5000,
        replies=20,
        marketing=5,
        customers=5,
        products=1,
        quotations=3,
    ),
    "pro": dict(
        translate_chars=50000,
        replies=200,
        marketing=50,
        customers=100,
        products=20,
        quotations=30,
    ),
    "enterprise": dict(
        translate_chars=999999999,
        replies=9999,
        marketing=9999,
        customers=99999,
        products=9999,
        quotations=9999,
    ),
}


# Maps a metered action name (as passed to check_quota / record_usage) to the
# limit key it consumes in TIER_LIMITS_DEFAULT. Actions not listed are unmetered.
ACTION_MAP = dict(
    translate="translate_chars",
    reply="replies",
    marketing_generate="marketing",
    create_customer="customers",
    create_product="products",
    create_quotation="quotations",
)
|
|
|
|
|
|
class UsageService:
    """Quota lookup, enforcement and usage logging bound to one DB session.

    Every method issues its own queries on ``self.db``; only ``record_usage``
    commits.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _today_window() -> tuple[datetime, datetime]:
        """Return the half-open ``[start, end)`` datetime range spanning today.

        Replaces the former ``func.cast(created_at, date)`` filter: SQLAlchemy's
        ``cast`` expects a SQL type (e.g. ``sqlalchemy.Date``), not the Python
        ``date`` class. A plain range comparison is also portable across
        backends and can use an index on ``created_at``.
        """
        # NOTE(review): date.today() is the server's *local* date, while
        # get_usage_stats compares against datetime.utcnow(); if created_at is
        # stored in UTC the daily quota resets at local midnight — confirm.
        start = datetime.combine(date.today(), datetime.min.time())
        return start, start + timedelta(days=1)

    async def get_limits(self, tier: str) -> dict:
        """Return the effective limits for ``tier``.

        Starts from a copy of ``TIER_LIMITS_DEFAULT[tier]`` (empty for an unknown
        tier) and overlays, key by key, the mapping stored in the ``SystemConfig``
        row keyed ``"<tier>_daily_limits"``. A stored value that is not a dict is
        ignored with a warning instead of breaking every quota check.
        """
        limits = dict(TIER_LIMITS_DEFAULT.get(tier, {}))
        config_key = f"{tier}_daily_limits"
        result = await self.db.execute(select(SystemConfig).where(SystemConfig.key == config_key))
        row = result.scalar_one_or_none()
        if row is None or not row.value:
            return limits
        if isinstance(row.value, dict):
            # Admin-managed overrides win over the built-in defaults.
            limits.update(row.value)
        else:
            logger.warning("Ignoring non-dict SystemConfig value for %r: %r", config_key, row.value)
        return limits

    async def get_tier(self, user_id: str) -> str:
        """Return the user's tier, or ``"free"`` if the user is missing or has no tier set."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None:
            return "free"
        return user.tier or "free"

    async def get_daily_usage(self, user_id: str, action: str) -> int:
        """Return how many ``UsageLog`` rows the user logged today for ``action``."""
        start, end = self._today_window()
        stmt = (
            select(func.count())
            .select_from(UsageLog)
            .where(
                UsageLog.user_id == user_id,
                UsageLog.action == action,
                UsageLog.created_at >= start,
                UsageLog.created_at < end,
            )
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_daily_chars(self, user_id: str) -> int:
        """Return the sum of ``detail["chars"]`` over today's ``"translate"`` logs (0 if none)."""
        start, end = self._today_window()
        stmt = select(
            func.coalesce(func.sum(UsageLog.detail["chars"].as_integer()), 0)
        ).where(
            UsageLog.user_id == user_id,
            UsageLog.action == "translate",
            UsageLog.created_at >= start,
            UsageLog.created_at < end,
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def get_total_count(self, user_id: str, model_class) -> int:
        """Return how many ``model_class`` rows have ``user_id`` (all time, no date filter).

        ``model_class`` must expose a ``user_id`` column.
        """
        stmt = select(func.count()).select_from(model_class).where(model_class.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def check_quota(self, user_id: str, action: str, chars: int = 0) -> tuple[bool, str]:
        """Check whether ``user_id`` may perform ``action`` now.

        Args:
            user_id: The acting user.
            action: An action name; names absent from ``ACTION_MAP`` are unmetered.
            chars: For ``"translate"``, the characters about to be consumed.

        Returns:
            ``(True, "")`` when allowed, otherwise ``(False, message)`` with a
            user-facing message describing the exhausted limit.
        """
        limit_key = ACTION_MAP.get(action)
        if not limit_key:
            # Unmetered action: skip the tier/limit DB round-trips entirely.
            return True, ""

        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)
        # NOTE(review): a tier absent from TIER_LIMITS_DEFAULT and SystemConfig
        # falls back to this effectively-unlimited value (fail-open) — confirm intent.
        limit = limits.get(limit_key, 999999)

        if action == "translate":
            # Character budget: block if this request would push today's total over.
            used = await self.get_daily_chars(user_id)
            if used + chars > limit:
                remaining = max(0, limit - used)
                return False, f"今日翻译字符已达上限({limit}字符),剩余{remaining}字符。升级 Pro 获取更多额度。"
        elif action == "create_customer":
            # Customers and products are capped by total stored rows, not per day.
            used = await self.get_total_count(user_id, Customer)
            if used >= limit:
                return False, f"客户数量已达上限({limit}个)。升级 Pro 获取更多客户管理额度。"
        elif action == "create_product":
            used = await self.get_total_count(user_id, Product)
            if used >= limit:
                return False, f"产品数量已达上限({limit}个)。升级 Pro 获取更多产品额度。"
        else:
            used = await self.get_daily_usage(user_id, action)
            if used >= limit:
                return False, f"今日{action}次数已达上限({limit}次)。升级 Pro 获取更多额度。"

        return True, ""

    async def record_usage(self, user_id: str, action: str, chars: int = 0, detail: dict = None):
        """Insert a ``UsageLog`` row for ``action`` and commit.

        ``detail`` is copied before use, so the caller's dict is never mutated;
        a non-zero ``chars`` is stored under ``detail["chars"]`` (the key
        ``get_daily_chars`` sums).
        """
        payload = dict(detail) if detail else {}
        if chars:
            payload["chars"] = chars
        self.db.add(UsageLog(user_id=user_id, action=action, detail=payload))
        await self.db.commit()

    async def get_usage_stats(self, user_id: str) -> dict:
        """Summarize the user's tier, limits, current usage and trial status.

        Returns a dict with ``tier``, ``limits`` (from ``get_limits``), ``usage``
        (same keys as the limits: today's values for translate_chars / replies /
        marketing / quotations, all-time counts for customers / products) and
        ``trial_days_left``.

        ``trial_days_left`` is only computed for ``"pro"`` users with an active
        ``"pro_trial"`` subscription and is rounded *up* (6.5 days left reports
        7, and a still-running trial never reports 0); otherwise it is 0.
        """
        tier = await self.get_tier(user_id)
        limits = await self.get_limits(tier)

        trial_days_left = 0
        if tier == "pro":
            result = await self.db.execute(
                select(Subscription)
                .where(
                    Subscription.user_id == user_id,
                    Subscription.plan == "pro_trial",
                    Subscription.status == "active",
                    Subscription.expires_at.is_not(None),
                )
                .order_by(Subscription.expires_at.desc())
            )
            # first() instead of scalar_one_or_none(): several active trial rows
            # would otherwise raise MultipleResultsFound and fail the request.
            trial_sub = result.scalars().first()
            if trial_sub is not None:
                # NOTE(review): assumes expires_at is a naive UTC datetime; an
                # aware value would make this subtraction raise TypeError.
                remaining = trial_sub.expires_at - datetime.utcnow()
                trial_days_left = max(0, math.ceil(remaining / timedelta(days=1)))

        customer_count = await self.get_total_count(user_id, Customer)
        product_count = await self.get_total_count(user_id, Product)
        translate_chars = await self.get_daily_chars(user_id)
        reply_count = await self.get_daily_usage(user_id, "reply")
        marketing_count = await self.get_daily_usage(user_id, "marketing_generate")
        quotation_count = await self.get_daily_usage(user_id, "create_quotation")

        return {
            "tier": tier,
            "limits": limits,
            "usage": {
                "translate_chars": translate_chars,
                "replies": reply_count,
                "marketing": marketing_count,
                "customers": customer_count,
                "products": product_count,
                "quotations": quotation_count,
            },
            "trial_days_left": trial_days_left,
        }
|
|
|
|
|
|
def require_quota(action: str, chars_field: str = None):
    """Build a FastAPI dependency that rejects requests once ``action``'s quota is spent.

    The returned dependency resolves the current user id and a DB session, runs
    ``UsageService.check_quota(user_id, action)`` and returns the user id when
    allowed, or raises ``HTTPException(429)`` carrying the quota message.

    ``chars_field`` acts only as a guard: pairing it with the ``"translate"``
    action makes every request fail with ``HTTPException(400)``, because the
    translate quota depends on the request's character count and must be checked
    explicitly by the endpoint. Without ``chars_field``, translate is checked
    with ``chars=0``.
    """
    # Both operands are fixed when the dependency is built, so decide once here.
    misconfigured_translate = action == "translate" and chars_field

    async def _enforce(
        user_id: str = Depends(get_current_user_id),
        db: AsyncSession = Depends(get_db),
    ):
        if misconfigured_translate:
            raise HTTPException(status_code=400, detail="translate action needs explicit chars check")
        allowed, message = await UsageService(db).check_quota(user_id, action)
        if allowed:
            return user_id
        raise HTTPException(status_code=429, detail=message)

    return _enforce
|