from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
import uuid
from app.database import get_db
from app.models.user import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.csrf import require_csrf_token
from pydantic import BaseModel, EmailStr, field_validator
from datetime import datetime, timedelta
from app.services.admin import AdminService
from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()

# Minimum accepted password length. Registration and login must agree on this
# value: if registration accepted shorter passwords than login validates,
# the resulting account could never sign in.
MIN_PASSWORD_LENGTH = 6


def _check_password_length(v: str) -> str:
    """Return *v* unchanged, or raise ValueError if it is shorter than MIN_PASSWORD_LENGTH."""
    if len(v) < MIN_PASSWORD_LENGTH:
        raise ValueError(f'Password must be at least {MIN_PASSWORD_LENGTH} characters')
    return v


class RegisterRequest(BaseModel):
    """Body of POST /register."""

    phone: str
    password: str
    username: str = ""  # empty -> the phone number is used as the username
    ref_code: str = ""  # optional referral code, claimed best-effort

    @field_validator('password')
    @classmethod
    def validate_password(cls, v: str) -> str:
        return _check_password_length(v)


class LoginResponse(BaseModel):
    """Token pair plus a public summary of the authenticated user."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    user: dict


class LoginRequest(BaseModel):
    """Body of POST /login; either username or phone identifies the account."""

    username: str = ""
    phone: str = ""
    password: str

    @field_validator('password')
    @classmethod
    def validate_password(cls, v: str) -> str:
        return _check_password_length(v)


class RefreshRequest(BaseModel):
    """Body of POST /refresh."""

    refresh_token: str


@router.post("/register")
async def register(
    data: RegisterRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Create a phone/password account that starts on a pro trial.

    The new user gets tier "pro", an active "pro_trial" subscription lasting
    ``settings.TRIAL_DAYS`` days, and the credit service's free-trial grant.
    If a referral code is supplied it is claimed best-effort.

    Raises:
        HTTPException(400): the phone number is already registered.
    """
    existing = await db.execute(select(User).where(User.phone == data.phone))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="Phone already registered")

    user = User(
        phone=data.phone,
        username=data.username or data.phone,
        password_hash=hash_password(data.password),
        tier="pro",
    )
    db.add(user)
    await db.flush()  # populates user.id, which the rows below reference

    now = datetime.utcnow()
    db.add(
        Subscription(
            user_id=user.id,
            plan="pro_trial",
            status="active",
            started_at=now,
            expires_at=now + timedelta(days=settings.TRIAL_DAYS),
        )
    )

    from app.services.credit import CreditService

    await CreditService(db).grant_free_trial(user.id)

    if data.ref_code:
        # Best-effort: an invalid or already-used referral code must not
        # block account creation.
        try:
            from app.api.v1.referral import do_claim_referral

            await do_claim_referral(data.ref_code, str(user.id), db)
        except Exception as e:
            logger.warning("Referral claim failed: %s", e)

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(str(user.id), "user.register", {"phone": data.phone}, ip=client_ip)

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
    }


async def _find_login_user(db: AsyncSession, login_id: str) -> Optional[User]:
    """Resolve a login identifier to a user, preferring an exact phone match.

    Phone and username are looked up separately. A single ``phone == x OR
    username == x`` query with ``scalar_one_or_none()`` raises
    MultipleResultsFound (an HTTP 500) when one account's username equals
    another account's phone. Usernames default to the phone number and can be
    changed via PUT /me, so such collisions are possible.
    NOTE(review): if usernames are not unique in the DB, ``first()`` picks an
    arbitrary match — confirm a uniqueness constraint exists.
    """
    result = await db.execute(select(User).where(User.phone == login_id))
    user = result.scalars().first()
    if user is None:
        result = await db.execute(select(User).where(User.username == login_id))
        user = result.scalars().first()
    return user


@router.post("/login", response_model=LoginResponse)
async def login(
    data: LoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate by phone or username and issue an access/refresh token pair.

    ``data.username`` takes precedence over ``data.phone`` as the identifier.
    Tokens are issued only after one check: if a "pro" user's active
    "pro_trial" subscription is past ``expires_at``, the subscription is
    marked "expired" and the user is downgraded to "free". The issued access
    token therefore carries the current tier.

    Raises:
        HTTPException(422): neither username nor phone was supplied.
        HTTPException(401): unknown identifier, password-less account, or wrong password.
    """
    login_id = data.username or data.phone
    if not login_id:
        raise HTTPException(status_code=422, detail="phone required")

    user = await _find_login_user(db, login_id)
    # Accounts created through WeChat login have no password hash; they must
    # be rejected here rather than passing None into verify_password.
    if not user or not user.password_hash or not verify_password(data.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(str(user.id), "user.login", {"login_id": login_id}, ip=client_ip)

    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalar_one_or_none()
        if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
            trial_sub.status = "expired"
            user.tier = "free"
            await db.flush()

    return LoginResponse(
        access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            "role": user.role,
        },
    )
def _decode_bearer(authorization: Optional[str]) -> dict:
    """Decode and return the JWT payload from an ``Authorization: Bearer <token>`` header.

    Raises:
        HTTPException(401): header missing or not a Bearer header ("Missing token"),
            or the token is rejected by decode_token / has no "sub" ("Invalid token").
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")
    payload = decode_token(authorization[7:])
    # A payload without "sub" would otherwise surface later as a KeyError (500).
    if not payload or not payload.get("sub"):
        raise HTTPException(status_code=401, detail="Invalid token")
    return payload


async def _load_user_or_404(db: AsyncSession, user_id) -> User:
    """Fetch the User with primary key *user_id*, raising HTTPException(404) if absent."""
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user


@router.post("/login/guest")
async def guest_login(request: Request, db: AsyncSession = Depends(get_db)):
    """Issue a 24-hour guest access token and a guest refresh token.

    Guest logins are limited per client IP: at most ``settings.GUEST_LOGIN_LIMIT``
    per ``settings.GUEST_LOGIN_WINDOW`` (passed to Redis EXPIRE, i.e. seconds).
    The counter lives in Redis. If Redis is unavailable the login proceeds
    without rate limiting, and a warning is logged.

    Raises:
        HTTPException(429): the per-IP limit was exceeded.
    """
    from app.core.redis import get_redis

    client_ip = request.client.host if request.client else "unknown"
    cache_key = f"guest_login:{client_ip}"
    try:
        redis_client = await get_redis()
        window = settings.GUEST_LOGIN_WINDOW
        limit = settings.GUEST_LOGIN_LIMIT
        # Check the limit against the value returned by INCR. Each request then
        # sees its own count, and concurrent requests cannot all pass a
        # separate read-then-increment check.
        pipe = redis_client.pipeline()
        pipe.incr(cache_key)
        pipe.expire(cache_key, window)
        count, _ = await pipe.execute()
        if int(count) > limit:
            raise HTTPException(
                status_code=429,
                detail="Too many guest login attempts. Please try again later or register an account."
            )
    except HTTPException:
        raise
    except Exception:
        # Deliberate best-effort: a Redis outage must not block guest access,
        # but it should not be silent either.
        logger.warning("Guest login rate limiting skipped: Redis unavailable", exc_info=True)

    guest_id = str(uuid.uuid4())
    access_token = create_access_token(
        {"sub": guest_id, "tier": "guest", "role": "guest", "is_guest": True},
        expires_delta=timedelta(hours=24),
    )
    refresh_token = create_refresh_token({"sub": guest_id, "is_guest": True})
    await AdminService(db).log_usage(
        guest_id,
        "user.login_guest",
        {},
        ip=request.client.host if request.client else None,
    )
    return LoginResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        user={
            "id": guest_id,
            "phone": None,
            "username": "游客用户",
            "tier": "guest",
            "is_guest": True,
        },
    )


@router.post("/refresh")
async def refresh(data: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for a new access token.

    A guest refresh token yields a guest access token. For registered users,
    tier and role are re-read from the database. The refresh tokens created in
    this module are built from ``{"sub": ...}`` only, so taking tier and role
    from the payload would downgrade every pro or admin session to
    "free"/"user" on refresh.

    Raises:
        HTTPException(401): token invalid, not a refresh token, or its user no longer exists.
    """
    payload = decode_token(data.refresh_token)
    if not payload or payload.get("type") != "refresh" or not payload.get("sub"):
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    if payload.get("is_guest"):
        claims = {"is_guest": True, "tier": "guest", "role": "guest"}
    else:
        result = await db.execute(select(User).where(User.id == payload["sub"]))
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        claims = {"tier": user.tier, "role": user.role}

    return {
        "access_token": create_access_token({"sub": payload["sub"], **claims}),
        "token_type": "bearer",
    }


@router.get("/me")
async def get_me(
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's profile.

    Guests get a synthetic profile built from the token. For "pro" users,
    ``trial_days_left`` is the number of whole days remaining on an active
    "pro_trial" subscription (never negative); for everyone else it is 0.

    Raises:
        HTTPException(401): missing or invalid token.
        HTTPException(404): the token's user no longer exists.
    """
    payload = _decode_bearer(authorization)

    if payload.get("is_guest"):
        return {
            "id": payload["sub"],
            "phone": None,
            "username": "游客用户",
            "tier": "guest",
            "role": "guest",
            "is_guest": True,
            "settings": {},
            "created_at": None,
        }

    user = await _load_user_or_404(db, payload["sub"])

    trial_days_left = 0
    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalar_one_or_none()
        if trial_sub and trial_sub.expires_at:
            remaining = (trial_sub.expires_at - datetime.utcnow()).days
            trial_days_left = max(0, remaining)

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
        "settings": user.settings,
        "created_at": user.created_at.isoformat() if user.created_at else None,
        "trial_days_left": trial_days_left,
    }


class ProfileUpdate(BaseModel):
    """Body of PUT /me; omitted or null fields are left unchanged."""

    username: Optional[str] = None
    email: Optional[str] = None


class PasswordChange(BaseModel):
    """Body of PUT /password."""

    old_password: str
    new_password: str

    @field_validator('new_password')
    @classmethod
    def validate_new_password(cls, v: str) -> str:
        # Same minimum as LoginRequest: a shorter new password would lock the
        # user out, because login rejects it before checking credentials.
        if len(v) < 6:
            raise ValueError('Password must be at least 6 characters')
        return v


class SettingsUpdate(BaseModel):
    """Body of PATCH /settings; only supplied, non-null keys are merged."""

    preferred_translate_provider: Optional[str] = None
    reply_tone: Optional[str] = None
    timezone: Optional[str] = None
    languages: Optional[list] = None


class WeChatLoginRequest(BaseModel):
    """Body of POST /wechat-login."""

    code: str
    encrypted_data: str = ""
    iv: str = ""


@router.put("/me")
async def update_me(
    data: ProfileUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Update the caller's username and/or email.

    Raises:
        HTTPException(401): missing or invalid token.
        HTTPException(403): caller is a guest.
        HTTPException(404): the token's user no longer exists.
    """
    payload = _decode_bearer(authorization)
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update profile")

    user = await _load_user_or_404(db, payload["sub"])
    if data.username is not None:
        user.username = data.username
    if data.email is not None:
        user.email = data.email
    await db.flush()

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "email": user.email,
        "tier": user.tier,
        "role": user.role,
    }


@router.put("/password")
async def change_password(
    data: PasswordChange,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Replace the caller's password after verifying the old one.

    Raises:
        HTTPException(401): missing or invalid token.
        HTTPException(403): caller is a guest.
        HTTPException(404): the token's user no longer exists.
        HTTPException(400): the old password is wrong or the account has none.
    """
    payload = _decode_bearer(authorization)
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot change password")

    user = await _load_user_or_404(db, payload["sub"])
    # WeChat-created accounts have no password hash to verify against.
    if not user.password_hash or not verify_password(data.old_password, user.password_hash):
        raise HTTPException(status_code=400, detail="旧密码不正确")

    user.password_hash = hash_password(data.new_password)
    await db.flush()
    return {"message": "密码修改成功"}


@router.get("/wechat/config")
async def wechat_config():
    """Report whether WeChat login is configured, and expose the public app id."""
    return {
        "available": bool(settings.WECHAT_APP_ID and settings.WECHAT_APP_SECRET),
        "app_id": settings.WECHAT_APP_ID or "",
    }


@router.post("/wechat-login")
async def wechat_login(
    data: WeChatLoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Log in (creating a "free" account on first use) with a WeChat login code.

    Raises:
        HTTPException(400): code2session failed or returned no openid.
    """
    from app.services.wechat import wechat_service

    wx_session = await wechat_service.code2session(data.code)
    if not wx_session:
        raise HTTPException(status_code=400, detail="WeChat login failed")

    openid = wx_session.get("openid")
    if not openid:
        # Without this guard, the lookup below compiles to
        # "wechat_openid IS NULL". That could match, and log the caller in as,
        # an existing account that has no WeChat link.
        raise HTTPException(status_code=400, detail="WeChat login failed")

    result = await db.execute(select(User).where(User.wechat_openid == openid))
    user = result.scalar_one_or_none()

    is_new = user is None
    if is_new:
        user = User(
            wechat_openid=openid,
            username=f"wx_{openid[-8:]}",
            tier="free",
        )
        db.add(user)
        await db.flush()

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(str(user.id), "user.wechat_login", {"is_new": is_new}, ip=client_ip)

    return LoginResponse(
        access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            "role": user.role,
        },
    )


@router.patch("/settings")
async def update_settings(
    data: SettingsUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Merge the supplied, non-null settings keys into the caller's settings.

    Raises:
        HTTPException(401): missing or invalid token.
        HTTPException(403): caller is a guest.
        HTTPException(404): the token's user no longer exists.
    """
    payload = _decode_bearer(authorization)
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update settings")

    user = await _load_user_or_404(db, payload["sub"])

    # Assign a NEW dict. Mutating user.settings in place and reassigning the
    # same object is not seen as a change by SQLAlchemy's attribute history
    # for a plain JSON column, so the update would never be persisted.
    # NOTE(review): unnecessary if the column is wrapped in MutableDict —
    # not visible here.
    merged = dict(user.settings or {})
    merged.update(
        {key: value for key, value in data.model_dump(exclude_unset=True).items() if value is not None}
    )
    user.settings = merged
    await db.flush()
    return {"settings": user.settings}