Add landing page, referral system, usage quotas, search API management, and yearly pricing
- Separate workspace landing from login for better UX - Referral system rewards both parties with Pro days - Quota enforcement prevents abuse without breaking endpoints - 7-day free trial with auto-downgrade on expiry - Admin-managed search provider config (SearXNG, Bing) - 15% discount on annual subscriptions - MCP search server wrapping opencode search - Fix discovery module field name mismatch causing 422
This commit is contained in:
@@ -0,0 +1,189 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user
|
||||
from app.models.search_provider import SearchProvider
|
||||
from app.services.search import SearchService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def require_admin(current_user: dict = Depends(get_current_user)) -> dict:
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
return current_user
|
||||
|
||||
|
||||
class ProviderCreate(BaseModel):
|
||||
name: str
|
||||
provider_type: str
|
||||
api_key: Optional[str] = None
|
||||
api_endpoint: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ProviderUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_endpoint: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
priority: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
@router.get("/search-providers")
|
||||
async def list_providers(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=100),
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(SearchProvider).order_by(SearchProvider.priority).offset((page - 1) * size).limit(size)
|
||||
)
|
||||
providers = result.scalars().all()
|
||||
total_result = await db.execute(select(SearchProvider))
|
||||
total = len(total_result.scalars().all())
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"provider_type": p.provider_type,
|
||||
"api_key": p.api_key[:8] + "..." if p.api_key and len(p.api_key) > 8 else p.api_key,
|
||||
"api_endpoint": p.api_endpoint,
|
||||
"extra_config": p.extra_config,
|
||||
"priority": p.priority,
|
||||
"enabled": p.enabled,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
for p in providers
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/search-providers")
|
||||
async def create_provider(
|
||||
data: ProviderCreate,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
provider = SearchProvider(
|
||||
name=data.name,
|
||||
provider_type=data.provider_type,
|
||||
api_key=data.api_key,
|
||||
api_endpoint=data.api_endpoint,
|
||||
extra_config=data.extra_config or {},
|
||||
priority=data.priority,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
db.add(provider)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": str(provider.id),
|
||||
"name": provider.name,
|
||||
"provider_type": provider.provider_type,
|
||||
"message": "Provider created",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/search-providers/{provider_id}")
|
||||
async def get_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"provider_type": p.provider_type,
|
||||
"api_key": p.api_key,
|
||||
"api_endpoint": p.api_endpoint,
|
||||
"extra_config": p.extra_config,
|
||||
"priority": p.priority,
|
||||
"enabled": p.enabled,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/search-providers/{provider_id}")
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
data: ProviderUpdate,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
if data.name is not None:
|
||||
p.name = data.name
|
||||
if data.api_key is not None:
|
||||
p.api_key = data.api_key
|
||||
if data.api_endpoint is not None:
|
||||
p.api_endpoint = data.api_endpoint
|
||||
if data.extra_config is not None:
|
||||
p.extra_config = data.extra_config
|
||||
if data.priority is not None:
|
||||
p.priority = data.priority
|
||||
if data.enabled is not None:
|
||||
p.enabled = data.enabled
|
||||
await db.flush()
|
||||
return {"message": "Provider updated"}
|
||||
|
||||
|
||||
@router.delete("/search-providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(delete(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {"message": "Provider deleted"}
|
||||
|
||||
|
||||
@router.post("/search-providers/{provider_id}/test")
|
||||
async def test_provider(
|
||||
provider_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_validate_uuid(provider_id)
|
||||
result = await db.execute(select(SearchProvider).where(SearchProvider.id == provider_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
try:
|
||||
svc = SearchService(db)
|
||||
results = await svc._search_provider(p, "test", 3)
|
||||
return {"success": True, "results": results}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
def _validate_uuid(uuid_str: str):
|
||||
import uuid
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid UUID")
|
||||
@@ -10,6 +10,11 @@ from app.core.security import hash_password, verify_password, create_access_toke
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime, timedelta
|
||||
from app.services.admin import AdminService
|
||||
from app.models.subscription import Subscription
|
||||
from app.api.v1.referral import apply_referral
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -18,6 +23,7 @@ class RegisterRequest(BaseModel):
|
||||
phone: str
|
||||
password: str
|
||||
username: str = ""
|
||||
ref_code: str = ""
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
@@ -47,11 +53,28 @@ async def register(data: RegisterRequest, request: Request, db: AsyncSession = D
|
||||
phone=data.phone,
|
||||
username=data.username or data.phone,
|
||||
password_hash=hash_password(data.password),
|
||||
tier="free",
|
||||
tier="pro",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
trial_end = datetime.utcnow() + timedelta(days=settings.TRIAL_DAYS)
|
||||
sub = Subscription(
|
||||
user_id=user.id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=trial_end,
|
||||
)
|
||||
db.add(sub)
|
||||
|
||||
if data.ref_code:
|
||||
try:
|
||||
from app.api.v1.referral import do_claim_referral
|
||||
await do_claim_referral(data.ref_code, str(user.id), db)
|
||||
except Exception as e:
|
||||
logger.warning(f"Referral claim failed: {e}")
|
||||
|
||||
client_ip = request.client.host if request.client else None
|
||||
await AdminService(db).log_usage(str(user.id), "user.register", {"phone": data.phone}, ip=client_ip)
|
||||
|
||||
@@ -89,6 +112,20 @@ async def login(
|
||||
client_ip = request.client.host if request.client else None
|
||||
await AdminService(db).log_usage(str(user.id), "user.login", {"login_id": login_id}, ip=client_ip)
|
||||
|
||||
if user.tier == "pro":
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user.id,
|
||||
Subscription.plan == "pro_trial",
|
||||
Subscription.status == "active",
|
||||
)
|
||||
)
|
||||
trial_sub = sub_result.scalar_one_or_none()
|
||||
if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
|
||||
trial_sub.status = "expired"
|
||||
user.tier = "free"
|
||||
await db.flush()
|
||||
|
||||
return LoginResponse(
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
@@ -178,6 +215,20 @@ async def get_me(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
trial_days_left = 0
|
||||
if user.tier == "pro":
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user.id,
|
||||
Subscription.plan == "pro_trial",
|
||||
Subscription.status == "active",
|
||||
)
|
||||
)
|
||||
trial_sub = sub_result.scalar_one_or_none()
|
||||
if trial_sub and trial_sub.expires_at:
|
||||
remaining = (trial_sub.expires_at - datetime.utcnow()).days
|
||||
trial_days_left = max(0, remaining)
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
@@ -186,6 +237,7 @@ async def get_me(
|
||||
"role": user.role,
|
||||
"settings": user.settings,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"trial_days_left": trial_days_left,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.database import get_db
|
||||
from app.services.customer import CustomerService
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
from app.services.import_service import import_service
|
||||
from app.services.usage import UsageService
|
||||
from app.services import export
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
@@ -98,8 +99,13 @@ async def create_customer(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
usage = UsageService(db)
|
||||
ok, msg = await usage.check_quota(user_id, "create_customer")
|
||||
if not ok:
|
||||
raise HTTPException(status_code=429, detail=msg)
|
||||
service = CustomerService(db)
|
||||
customer = await service.create_customer(user_id, data)
|
||||
await usage.record_usage(user_id, "create_customer")
|
||||
return customer
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional, List
|
||||
from app.database import get_db
|
||||
from app.services.product import ProductService
|
||||
from app.services import export
|
||||
from app.services.usage import UsageService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from pydantic import BaseModel
|
||||
import io
|
||||
@@ -175,8 +176,13 @@ async def create_product(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
usage = UsageService(db)
|
||||
ok, msg = await usage.check_quota(user_id, "create_product")
|
||||
if not ok:
|
||||
raise HTTPException(status_code=429, detail=msg)
|
||||
service = ProductService(db)
|
||||
product = await service.create_product(user_id, data.dict())
|
||||
await usage.record_usage(user_id, "create_product")
|
||||
return product
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.models.referral import ReferralCode, Referral
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import secrets
|
||||
import string
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def generate_code() -> str:
|
||||
return "TM" + "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(6))
|
||||
|
||||
|
||||
@router.post("/code")
|
||||
async def get_or_create_code(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(ReferralCode).where(ReferralCode.user_id == user_id))
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
return {"code": existing.code, "url": f"/workspace/?ref={existing.code}"}
|
||||
|
||||
code = generate_code()
|
||||
while True:
|
||||
check = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
if not check.scalar_one_or_none():
|
||||
break
|
||||
code = generate_code()
|
||||
|
||||
rc = ReferralCode(user_id=user_id, code=code)
|
||||
db.add(rc)
|
||||
await db.commit()
|
||||
return {"code": code, "url": f"/workspace/?ref={code}"}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_referral_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Referral).where(Referral.referrer_id == user_id))
|
||||
referrals = result.scalars().all()
|
||||
total_reward_days = sum(r.reward_days for r in referrals if r.status == "completed")
|
||||
return {
|
||||
"total_referrals": len(referrals),
|
||||
"completed": sum(1 for r in referrals if r.status == "completed"),
|
||||
"total_reward_days": total_reward_days,
|
||||
}
|
||||
|
||||
|
||||
async def apply_referral(code: str, new_user_id: str, db: AsyncSession):
|
||||
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
rc = rc_result.scalar_one_or_none()
|
||||
if not rc:
|
||||
return
|
||||
|
||||
if str(rc.user_id) == new_user_id:
|
||||
return
|
||||
|
||||
existing = await db.execute(select(Referral).where(Referral.referred_id == new_user_id))
|
||||
if existing.scalar_one_or_none():
|
||||
return
|
||||
|
||||
reward_days = 15
|
||||
|
||||
referrer_sub = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == rc.user_id,
|
||||
Subscription.status == "active",
|
||||
).order_by(Subscription.created_at.desc()).limit(1)
|
||||
)
|
||||
referrer_sub_row = referrer_sub.scalar_one_or_none()
|
||||
if referrer_sub_row:
|
||||
old_expiry = referrer_sub_row.expires_at or datetime.utcnow()
|
||||
referrer_sub_row.expires_at = old_expiry + timedelta(days=reward_days)
|
||||
else:
|
||||
new_sub = Subscription(
|
||||
user_id=rc.user_id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(days=reward_days),
|
||||
)
|
||||
db.add(new_sub)
|
||||
user_result = await db.execute(select(User).where(User.id == rc.user_id))
|
||||
u = user_result.scalar_one_or_none()
|
||||
if u and u.tier == "free":
|
||||
u.tier = "pro"
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == new_user_id))
|
||||
ru = user_result.scalar_one_or_none()
|
||||
if ru and ru.tier in ("free", "guest"):
|
||||
ru.tier = "pro"
|
||||
ref_sub = Subscription(
|
||||
user_id=new_user_id,
|
||||
plan="pro_trial",
|
||||
status="active",
|
||||
started_at=datetime.utcnow(),
|
||||
expires_at=datetime.utcnow() + timedelta(days=reward_days),
|
||||
)
|
||||
db.add(ref_sub)
|
||||
|
||||
referral = Referral(
|
||||
referrer_id=rc.user_id,
|
||||
referred_id=new_user_id,
|
||||
code=code,
|
||||
reward_days=reward_days,
|
||||
)
|
||||
db.add(referral)
|
||||
await db.flush()
|
||||
|
||||
|
||||
@router.post("/claim")
|
||||
async def claim_referral(
|
||||
code: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
rc_result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
|
||||
rc = rc_result.scalar_one_or_none()
|
||||
if not rc:
|
||||
raise HTTPException(status_code=404, detail="无效的邀请码")
|
||||
|
||||
if str(rc.user_id) == user_id:
|
||||
raise HTTPException(status_code=400, detail="不能使用自己的邀请码")
|
||||
|
||||
existing = await db.execute(select(Referral).where(Referral.referred_id == user_id))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="已经使用过邀请码了")
|
||||
|
||||
await apply_referral(code, user_id, db)
|
||||
await db.commit()
|
||||
return {"success": True, "reward_days": 15}
|
||||
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.search import SearchService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/query")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=1, max_length=500),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = SearchService(db)
|
||||
results = await svc.search(q, limit)
|
||||
return {"query": q, "results": results}
|
||||
@@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.services.usage import UsageService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_usage_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
svc = UsageService(db)
|
||||
stats = await svc.get_usage_stats(user_id)
|
||||
return stats
|
||||
Reference in New Issue
Block a user