fa3050a17c
Anonymous users have no CSRF cookie, so require_csrf_token always raises 403 on first visit. This broke all first-time logins and registrations. CSRF protection is unnecessary here since there's no authenticated session to forge requests against.
422 lines
13 KiB
Python
422 lines
13 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from typing import Optional
|
|
import uuid
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
|
from app.core.csrf import require_csrf_token
|
|
from pydantic import BaseModel, EmailStr
|
|
from datetime import datetime, timedelta
|
|
from app.services.admin import AdminService
|
|
from app.models.subscription import Subscription
|
|
from app.api.v1.referral import apply_referral
|
|
from app.config import settings
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
class RegisterRequest(BaseModel):
    """JSON body accepted by ``POST /register``."""

    # Phone number; the register handler rejects one that is already stored.
    phone: str
    # Plain-text password; the handler hashes it before saving.
    password: str
    # Optional display name; the handler falls back to ``phone`` when empty.
    username: str = ""
    # Optional referral code; an empty string skips the referral claim.
    ref_code: str = ""
|
|
|
|
|
|
class LoginResponse(BaseModel):
    """Token pair plus a user summary, as returned by the login endpoints."""

    # Token to be sent as ``Authorization: Bearer <token>``.
    access_token: str
    # Token accepted by ``POST /refresh`` to mint a new access token.
    refresh_token: str
    # Token scheme advertised to the client.
    token_type: str = "bearer"
    # Free-form user summary; each login handler decides which keys it holds.
    user: dict
|
|
|
|
|
|
class LoginRequest(BaseModel):
    """JSON body accepted by ``POST /login``.

    Either ``username`` or ``phone`` identifies the account; the handler
    prefers ``username`` when both are non-empty.
    """

    username: str = ""
    phone: str = ""
    # Plain-text password checked against the stored hash.
    password: str
|
|
|
|
|
|
class RefreshRequest(BaseModel):
    """JSON body accepted by ``POST /refresh``."""

    # Refresh token previously handed out by one of the login endpoints.
    refresh_token: str
|
|
|
|
|
|
@router.post("/register")
async def register(
    data: RegisterRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Create a phone-based account together with an active ``pro_trial``.

    Args:
        data: Registration payload (phone, password, optional username and
            referral code).
        request: Incoming request; only used to read the client IP for the
            usage log.
        db: Database session.

    Returns:
        A dict with the new user's ``id``, ``phone``, ``username``, ``tier``
        and ``role``.

    Raises:
        HTTPException: 400 when the phone number is already registered.
    """
    existing = await db.execute(select(User).where(User.phone == data.phone))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="Phone already registered")

    # NOTE(review): only the phone is checked for uniqueness here, while the
    # login endpoint matches on phone OR username — confirm that a database
    # constraint (or the client) prevents username collisions.
    user = User(
        phone=data.phone,
        username=data.username or data.phone,
        password_hash=hash_password(data.password),
        tier="pro",
    )
    db.add(user)
    # Flush before reading user.id below (presumably assigned on insert).
    await db.flush()

    # Take one timestamp so the trial's start and end are derived from the
    # same instant instead of two separate clock reads.
    now = datetime.utcnow()
    sub = Subscription(
        user_id=user.id,
        plan="pro_trial",
        status="active",
        started_at=now,
        expires_at=now + timedelta(days=settings.TRIAL_DAYS),
    )
    db.add(sub)

    if data.ref_code:
        # Best effort: a failing referral claim must not block registration.
        try:
            from app.api.v1.referral import do_claim_referral

            await do_claim_referral(data.ref_code, str(user.id), db)
        except Exception as e:
            logger.warning("Referral claim failed: %s", e)

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.register", {"phone": data.phone}, ip=client_ip
    )

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
    }
|
|
|
|
|
|
@router.post("/login", response_model=LoginResponse)
async def login(
    data: LoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate by phone or username and issue an access/refresh pair.

    If the account is on the ``pro`` tier and its active ``pro_trial``
    subscription has expired, the subscription is marked ``expired`` and the
    user is moved to the ``free`` tier before the tokens are issued.

    Args:
        data: Login payload; ``username`` is preferred over ``phone`` when
            both are given.
        request: Incoming request; only used to read the client IP for the
            usage log.
        db: Database session.

    Returns:
        A ``LoginResponse`` with both tokens and a short user summary.

    Raises:
        HTTPException: 422 when neither username nor phone is supplied,
            401 when no matching account accepts the password.
    """
    login_id = data.username or data.phone
    if not login_id:
        raise HTTPException(status_code=422, detail="phone required")

    result = await db.execute(
        select(User).where(
            (User.phone == login_id) | (User.username == login_id)
        )
    )
    # More than one row can match, because one account's username may equal
    # another account's phone. Check the password against every candidate
    # rather than letting scalar_one_or_none() raise on multiple rows.
    # Accounts without a stored hash (e.g. created via WeChat login) can
    # never match a password.
    candidates = result.scalars().all()
    user = next(
        (
            candidate
            for candidate in candidates
            if candidate.password_hash
            and verify_password(data.password, candidate.password_hash)
        ),
        None,
    )
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.login", {"login_id": login_id}, ip=client_ip
    )

    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalar_one_or_none()
        # Downgrade only when the trial is past its expiry time; this runs
        # before the tokens are built so they carry the updated tier.
        if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
            trial_sub.status = "expired"
            user.tier = "free"
            await db.flush()

    return LoginResponse(
        access_token=create_access_token(
            {"sub": str(user.id), "tier": user.tier, "role": user.role}
        ),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
        },
    )
|
|
|
|
|
|
@router.post("/login/guest")
async def guest_login(request: Request, db: AsyncSession = Depends(get_db)):
    """Issue tokens for an anonymous guest identified by a fresh UUID.

    No ``User`` row is created; the guest identity lives only in the token
    claims. The access token is created with a 24-hour ``expires_delta``.

    Returns:
        A ``LoginResponse`` whose ``user`` summary is flagged ``is_guest``.
    """
    guest_id = str(uuid.uuid4())
    access_claims = {"sub": guest_id, "tier": "guest", "role": "guest", "is_guest": True}
    refresh_claims = {"sub": guest_id, "is_guest": True}

    tokens = {
        "access_token": create_access_token(
            access_claims, expires_delta=timedelta(hours=24)
        ),
        "refresh_token": create_refresh_token(refresh_claims),
    }

    caller_ip = None
    if request.client:
        caller_ip = request.client.host
    await AdminService(db).log_usage(guest_id, "user.login_guest", {}, ip=caller_ip)

    guest_profile = {
        "id": guest_id,
        "phone": None,
        "username": "游客用户",
        "tier": "guest",
        "is_guest": True,
    }
    return LoginResponse(token_type="bearer", user=guest_profile, **tokens)
|
|
|
|
|
|
@router.post("/refresh")
async def refresh(data: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for a new access token.

    Refresh tokens issued by the login endpoints in this module carry only
    ``sub`` (plus ``is_guest`` for guests), so the tier and role of a
    registered user are read from the database. Copying them from the
    refresh-token payload would always fall back to ``free`` / ``user``.

    Args:
        data: Payload holding the refresh token.
        db: Database session used to load the current tier and role.

    Returns:
        A dict with the new ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 when the token is invalid, is not a refresh
            token, has no subject, or refers to a user that no longer exists.
    """
    payload = decode_token(data.refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    subject = payload.get("sub")
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    if payload.get("is_guest"):
        # Guests have no database row; their claims are fixed.
        extra = {"is_guest": True, "tier": "guest", "role": "guest"}
    else:
        result = await db.execute(select(User).where(User.id == subject))
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        # Use the stored values so upgrades, downgrades and role changes
        # show up in the next access token.
        extra = {"tier": user.tier, "role": user.role}

    return {
        "access_token": create_access_token({"sub": subject, **extra}),
        "token_type": "bearer",
    }
|
|
|
|
|
|
@router.get("/me")
async def get_me(
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
):
    """Return the profile that belongs to the presented bearer token.

    Guest tokens get a synthetic profile built from the token claims alone.
    Registered users are loaded from the database, and ``trial_days_left``
    is computed from their active ``pro_trial`` subscription, if any.

    Raises:
        HTTPException: 401 when the header is missing or malformed, or the
            token cannot be decoded; 404 when the user row does not exist.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(status_code=401, detail="Missing token")

    claims = decode_token(authorization[len("Bearer "):])
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid token")

    if claims.get("is_guest"):
        guest_profile = {
            "id": claims["sub"],
            "phone": None,
            "username": "游客用户",
            "tier": "guest",
            "role": "guest",
            "is_guest": True,
            "settings": {},
            "created_at": None,
        }
        return guest_profile

    lookup = await db.execute(select(User).where(User.id == claims["sub"]))
    account = lookup.scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=404, detail="User not found")

    days_left = 0
    if account.tier == "pro":
        trial_query = select(Subscription).where(
            Subscription.user_id == account.id,
            Subscription.plan == "pro_trial",
            Subscription.status == "active",
        )
        trial = (await db.execute(trial_query)).scalar_one_or_none()
        if trial and trial.expires_at:
            # Whole days only; a negative difference is clamped to zero.
            delta = trial.expires_at - datetime.utcnow()
            days_left = max(0, delta.days)

    if account.created_at:
        created = account.created_at.isoformat()
    else:
        created = None

    return {
        "id": str(account.id),
        "phone": account.phone,
        "username": account.username,
        "tier": account.tier,
        "role": account.role,
        "settings": account.settings,
        "created_at": created,
        "trial_days_left": days_left,
    }
|
|
|
|
|
|
class ProfileUpdate(BaseModel):
    """JSON body accepted by ``PUT /me``.

    Both fields are optional; the handler leaves a field untouched when its
    value is ``None``. They are annotated ``Optional[str]`` so the declared
    type matches the ``None`` default and an explicit JSON ``null`` is
    accepted.
    """

    username: Optional[str] = None
    email: Optional[str] = None
|
|
|
|
|
|
class PasswordChange(BaseModel):
    """JSON body accepted by ``PUT /password``."""

    # Current password; the handler verifies it against the stored hash.
    old_password: str
    # Replacement password; the handler hashes it before saving.
    new_password: str
|
|
|
|
|
|
class SettingsUpdate(BaseModel):
    """JSON body accepted by ``PATCH /settings``.

    Every field is optional; the handler only applies fields that were sent
    and are not ``None``. The annotations are ``Optional[...]`` so the
    declared types match the ``None`` defaults.
    """

    preferred_translate_provider: Optional[str] = None
    reply_tone: Optional[str] = None
    timezone: Optional[str] = None
    languages: Optional[list] = None
|
|
|
|
|
|
class WeChatLoginRequest(BaseModel):
    """JSON body accepted by ``POST /wechat-login``."""

    # Login code that the handler passes to ``wechat_service.code2session``.
    code: str
    # Accepted in the payload; not read by the handler in this module.
    encrypted_data: str = ""
    # Accepted in the payload; not read by the handler in this module.
    iv: str = ""
|
|
|
|
|
|
@router.put("/me")
async def update_me(
    data: ProfileUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Update the caller's username and/or email.

    Only fields whose value is not ``None`` are written. Guest tokens are
    rejected.

    Raises:
        HTTPException: 401 for a missing, malformed or undecodable token;
            403 for guest tokens; 404 when the user row does not exist.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(status_code=401, detail="Missing token")

    claims = decode_token(authorization[len("Bearer "):])
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid token")

    if claims.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update profile")

    lookup = await db.execute(select(User).where(User.id == claims["sub"]))
    account = lookup.scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=404, detail="User not found")

    # Copy across only the fields that carry a value.
    for field_name in ("username", "email"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(account, field_name, new_value)

    await db.flush()

    profile = {
        "id": str(account.id),
        "phone": account.phone,
        "username": account.username,
        "email": account.email,
        "tier": account.tier,
        "role": account.role,
    }
    return profile
|
|
|
|
|
|
@router.put("/password")
async def change_password(
    data: PasswordChange,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Replace the caller's password after verifying the current one.

    Args:
        data: Old and new password.
        authorization: ``Authorization: Bearer <token>`` header.
        db: Database session.
        _csrf: Result of the CSRF dependency; unused beyond enforcing it.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 401 for a missing, malformed or undecodable token;
            403 for guest tokens; 404 when the user row does not exist;
            400 when the old password does not match.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")

    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot change password")

    result = await db.execute(select(User).where(User.id == payload["sub"]))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Accounts created through WeChat login are stored without a password
    # hash. Treat that as a failed check instead of handing an empty hash to
    # verify_password, whose behaviour for that input is not visible here.
    if not user.password_hash or not verify_password(data.old_password, user.password_hash):
        raise HTTPException(status_code=400, detail="旧密码不正确")

    user.password_hash = hash_password(data.new_password)
    await db.flush()
    return {"message": "密码修改成功"}
|
|
|
|
|
|
@router.get("/wechat/config")
async def wechat_config():
    """Report whether WeChat login is configured and expose the app id.

    Returns:
        ``available`` is true only when both the app id and the app secret
        are set; ``app_id`` is the configured id or an empty string.
    """
    from app.config import settings

    app_id = settings.WECHAT_APP_ID
    configured = bool(app_id and settings.WECHAT_APP_SECRET)
    return {"available": configured, "app_id": app_id or ""}
|
|
|
|
|
|
@router.post("/wechat-login")
async def wechat_login(
    data: WeChatLoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Log in (or auto-register) a user identified by a WeChat openid.

    The login code is exchanged through ``wechat_service.code2session``. If
    no user with the returned openid exists, a ``free``-tier user is created.

    Args:
        data: Payload carrying the WeChat login ``code``.
        request: Incoming request; only used to read the client IP for the
            usage log.
        db: Database session.
        _csrf: Result of the CSRF dependency; unused beyond enforcing it.

    Returns:
        A ``LoginResponse`` with both tokens and a short user summary.

    Raises:
        HTTPException: 400 when the code exchange fails or yields no openid.
    """
    # NOTE(review): this is also an anonymous login endpoint, yet it still
    # depends on require_csrf_token — confirm whether that is intended.
    from app.services.wechat import wechat_service

    session = await wechat_service.code2session(data.code)
    openid = session.get("openid") if session else None
    # A session without an openid must be rejected. Otherwise the lookup
    # below would compare wechat_openid against None, which can match
    # accounts that never linked WeChat, and the openid[-8:] slice would
    # raise a TypeError.
    if not openid:
        raise HTTPException(status_code=400, detail="WeChat login failed")

    result = await db.execute(select(User).where(User.wechat_openid == openid))
    user = result.scalar_one_or_none()
    is_new = False

    if not user:
        user = User(
            wechat_openid=openid,
            username=f"wx_{openid[-8:]}",
            tier="free",
        )
        db.add(user)
        # Flush before reading user.id below (presumably assigned on insert).
        await db.flush()
        is_new = True

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.wechat_login", {"is_new": is_new}, ip=client_ip
    )

    return LoginResponse(
        access_token=create_access_token(
            {"sub": str(user.id), "tier": user.tier, "role": user.role}
        ),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            "role": user.role,
        },
    )
|
|
|
|
|
|
@router.patch("/settings")
async def update_settings(
    data: SettingsUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Merge the supplied settings into the caller's stored settings.

    Only fields that were sent in the request and are not ``None`` are
    applied; all other stored keys are kept.

    Args:
        data: Partial settings payload.
        authorization: ``Authorization: Bearer <token>`` header.
        db: Database session.
        _csrf: Result of the CSRF dependency; unused beyond enforcing it.

    Returns:
        A dict with the merged ``settings``.

    Raises:
        HTTPException: 401 for a missing, malformed or undecodable token;
            404 when the user row does not exist.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")

    result = await db.execute(select(User).where(User.id == payload["sub"]))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    updates = {
        key: value
        for key, value in data.dict(exclude_unset=True).items()
        if value is not None
    }
    # Build a NEW dict instead of mutating user.settings in place. Assigning
    # the same, already-mutated object back may not be detected as a change
    # by the ORM unless the column uses mutation tracking (not visible
    # here), so the update could silently fail to persist. The local name
    # also avoids shadowing the module-level ``settings`` config object.
    merged_settings = {**(user.settings or {}), **updates}

    user.settings = merged_settings
    await db.flush()

    return {"settings": user.settings}
|