from datetime import datetime, timedelta, timezone
import secrets
import string
from typing import Optional
import uuid

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.api.v1.deps import get_current_user_id
from app.models.referral import ReferralCode, Referral
from app.models.subscription import Subscription
from app.models.user import User
from app.config import settings

router = APIRouter()

# Days of Pro access granted to BOTH the referrer and the referred user.
REFERRAL_REWARD_DAYS = 15

_CODE_PREFIX = "TM"
_CODE_ALPHABET = string.ascii_uppercase + string.digits
_CODE_RANDOM_LENGTH = 6


def generate_code() -> str:
    """Return a random referral code: "TM" followed by 6 uppercase letters/digits.

    Uses ``secrets`` rather than ``random`` so codes are not predictable.
    """
    suffix = "".join(secrets.choice(_CODE_ALPHABET) for _ in range(_CODE_RANDOM_LENGTH))
    return _CODE_PREFIX + suffix


def _code_payload(code: str) -> dict:
    """Build the response body returned by ``POST /code``."""
    return {"code": code, "url": f"/workspace/?ref={code}"}


async def _find_referral_code(db: AsyncSession, code: str):
    """Return the ReferralCode row for *code*, or None if no such code exists."""
    result = await db.execute(select(ReferralCode).where(ReferralCode.code == code))
    return result.scalar_one_or_none()


@router.post("/code")
async def get_or_create_code(
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's referral code, creating one on first call.

    Returns:
        ``{"code": <code>, "url": "/workspace/?ref=<code>"}``.
    """
    result = await db.execute(select(ReferralCode).where(ReferralCode.user_id == user_id))
    existing = result.scalar_one_or_none()
    if existing:
        return _code_payload(existing.code)

    # Keep drawing random codes until one is not already in use.
    # NOTE(review): check-then-insert is not atomic; concurrent requests could
    # race unless ReferralCode has unique constraints on ``code`` / ``user_id``
    # — confirm in the model.
    code = generate_code()
    while await _find_referral_code(db, code) is not None:
        code = generate_code()

    db.add(ReferralCode(user_id=user_id, code=code))
    await db.commit()
    return _code_payload(code)


@router.get("/stats")
async def get_referral_stats(
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Summarise the referrals made by the current user.

    Returns:
        ``total_referrals``: every Referral row with this user as referrer.
        ``completed``: how many of those have status ``"completed"``.
        ``total_reward_days``: reward days summed over the completed ones.
    """
    result = await db.execute(select(Referral).where(Referral.referrer_id == user_id))
    referrals = result.scalars().all()
    completed = [r for r in referrals if r.status == "completed"]
    return {
        "total_referrals": len(referrals),
        "completed": len(completed),
        # Column nullability isn't visible here; treat a NULL as 0 instead of crashing.
        "total_reward_days": sum(r.reward_days or 0 for r in completed),
    }


def _extended_expiry(current: Optional[datetime], now: datetime, days: int) -> datetime:
    """Return a new expiry that adds *days* of access to *current*.

    Extends from *current* only if it is still in the future; otherwise
    extends from *now*. This way a lapsed-but-still-"active" subscription
    still receives the full reward.

    *now* is a naive UTC timestamp (from ``datetime.utcnow()``). If *current*
    is timezone-aware, *now* is tagged as UTC so the two can be compared.
    """
    if current is None:
        return now + timedelta(days=days)
    if current.tzinfo is not None:
        now = now.replace(tzinfo=timezone.utc)
    return max(current, now) + timedelta(days=days)


def _trial_subscription(user_id, now: datetime, days: int) -> Subscription:
    """Build an active ``pro_trial`` Subscription lasting *days* from *now*."""
    return Subscription(
        user_id=user_id,
        plan="pro_trial",
        status="active",
        started_at=now,
        expires_at=now + timedelta(days=days),
    )


async def _reward_referrer(db: AsyncSession, referrer_id, now: datetime, days: int) -> None:
    """Extend the referrer's newest active subscription by *days*.

    If the referrer has no active subscription, start a Pro trial instead and
    upgrade a "free" referrer to "pro".
    """
    sub_result = await db.execute(
        select(Subscription)
        .where(
            Subscription.user_id == referrer_id,
            Subscription.status == "active",
        )
        .order_by(Subscription.created_at.desc())
        .limit(1)
    )
    active_sub = sub_result.scalar_one_or_none()
    if active_sub:
        active_sub.expires_at = _extended_expiry(active_sub.expires_at, now, days)
        return

    db.add(_trial_subscription(referrer_id, now, days))
    user_result = await db.execute(select(User).where(User.id == referrer_id))
    referrer = user_result.scalar_one_or_none()
    if referrer and referrer.tier == "free":
        referrer.tier = "pro"


async def _reward_referred(db: AsyncSession, referred_id: str, now: datetime, days: int) -> None:
    """Upgrade a "free"/"guest" referred user to "pro" with a trial subscription.

    Users already on another tier are left unchanged.
    """
    user_result = await db.execute(select(User).where(User.id == referred_id))
    referred = user_result.scalar_one_or_none()
    if referred and referred.tier in ("free", "guest"):
        referred.tier = "pro"
        db.add(_trial_subscription(referred_id, now, days))


async def apply_referral(
    code: str,
    new_user_id: str,
    db: AsyncSession,
    reward_days: int = REFERRAL_REWARD_DAYS,
):
    """Grant referral rewards for *new_user_id* signing up with *code*.

    Silently does nothing if:
    - the code does not exist,
    - it belongs to *new_user_id*, or
    - *new_user_id* has already been referred.

    Otherwise it rewards both parties and records a Referral row. Changes are
    flushed but NOT committed; the caller owns the transaction.

    Args:
        code: referral code being redeemed.
        new_user_id: id of the user redeeming the code.
        db: session used for all reads/writes.
        reward_days: days of access granted to each party.
    """
    rc = await _find_referral_code(db, code)
    if rc is None or str(rc.user_id) == new_user_id:
        return
    existing = await db.execute(select(Referral).where(Referral.referred_id == new_user_id))
    if existing.scalar_one_or_none():
        return

    # One timestamp for the whole operation keeps started_at/expires_at consistent.
    now = datetime.utcnow()
    await _reward_referrer(db, rc.user_id, now, reward_days)
    await _reward_referred(db, new_user_id, now, reward_days)

    # NOTE(review): ``status`` is not set here, so it comes from the model
    # default. /stats only counts rows whose status is "completed"; confirm
    # the default, or granted rewards will not appear in the stats.
    db.add(
        Referral(
            referrer_id=rc.user_id,
            referred_id=new_user_id,
            code=code,
            reward_days=reward_days,
        )
    )
    await db.flush()


@router.post("/claim")
async def claim_referral(
    code: str,
    user_id: str = Depends(get_current_user_id),
    db: AsyncSession = Depends(get_db),
):
    """Redeem referral *code* for the current user and commit the rewards.

    Runs the same checks as ``apply_referral`` up front, so the client gets a
    specific error instead of a silent no-op.

    Raises:
        HTTPException(404): the code does not exist.
        HTTPException(400): the code is the caller's own, or the caller has
            already redeemed a code.
    """
    rc = await _find_referral_code(db, code)
    if not rc:
        # "Invalid referral code"
        raise HTTPException(status_code=404, detail="无效的邀请码")
    if str(rc.user_id) == user_id:
        # "You cannot use your own referral code"
        raise HTTPException(status_code=400, detail="不能使用自己的邀请码")
    existing = await db.execute(select(Referral).where(Referral.referred_id == user_id))
    if existing.scalar_one_or_none():
        # "You have already used a referral code"
        raise HTTPException(status_code=400, detail="已经使用过邀请码了")

    await apply_referral(code, user_id, db)
    await db.commit()
    return {"success": True, "reward_days": REFERRAL_REWARD_DAYS}