Files
trade-assistant/backend/app/api/v1/auth.py
T
TradeMate Dev c04fa2c19f T-005: Security hardening - CORS, Rate Limit, CSRF
- CORS: Restrict allowed origins to specific frontend URLs, limit methods and headers
- Rate Limit: Add fine-grained endpoint-specific rate limits for sensitive operations
  - Login: 5 requests/minute
  - Register: 3 requests/hour
  - Password change: 3 requests/5 minutes
  - Payment: 20 requests/minute
  - Admin: 30 requests/minute
- CSRF: Add CSRF protection middleware with double-submit cookie pattern
  - New app/core/csrf.py module with CSRFMiddleware
  - Require CSRF tokens on sensitive endpoints (auth, payment, profile)
  - Skip webhook endpoints for CSRF validation
- Fix pydantic-settings import in config.py
2026-05-29 10:26:23 +08:00

424 lines
13 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
import uuid
from app.database import get_db
from app.models.user import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.csrf import require_csrf_token
from pydantic import BaseModel, EmailStr
from datetime import datetime, timedelta
from app.services.admin import AdminService
from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral
from app.config import settings
import logging
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which the endpoint decorators in this module register handlers.
router = APIRouter()
class RegisterRequest(BaseModel):
    """Request body accepted by the ``/register`` endpoint."""

    # Phone number; the handler rejects one that is already registered.
    phone: str
    # Plain-text password; the handler hashes it before storing.
    password: str
    # Optional display name; the handler falls back to the phone when empty.
    username: str = ""
    # Optional referral code; a claim is attempted only when non-empty.
    ref_code: str = ""
class LoginResponse(BaseModel):
    """Token pair plus a user summary, as returned by the login endpoints."""

    # Bearer token sent on authenticated requests.
    access_token: str
    # Token exchanged at ``/refresh`` for a new access token.
    refresh_token: str
    # Authorization scheme the client should use with ``access_token``.
    token_type: str = "bearer"
    # Free-form summary of the account (id, phone, username, tier, ...).
    user: dict
class LoginRequest(BaseModel):
    """Request body accepted by the ``/login`` endpoint."""

    # Either identifier may be supplied; the handler prefers ``username``
    # and falls back to ``phone`` when it is empty.
    username: str = ""
    phone: str = ""
    # Plain-text password to verify against the stored hash.
    password: str
class RefreshRequest(BaseModel):
    """Request body accepted by the ``/refresh`` endpoint."""

    # Previously issued refresh token to exchange for a new access token.
    refresh_token: str
@router.post("/register")
async def register(
    data: RegisterRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Create a phone-based account and start its "pro" trial subscription.

    Args:
        data: Registration payload (phone, password, optional username and
            referral code).
        request: Incoming request; only the client address is read from it.
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A dict with the new account's id, phone, username, tier and role.

    Raises:
        HTTPException: 422 when the phone is blank or the password is empty;
            400 when the phone is already registered.
    """
    # Reject blank credentials up front, mirroring the 422 that ``login``
    # returns for an empty identifier.
    if not data.phone.strip():
        raise HTTPException(status_code=422, detail="phone required")
    if not data.password:
        raise HTTPException(status_code=422, detail="password required")

    existing = await db.execute(select(User).where(User.phone == data.phone))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="Phone already registered")

    user = User(
        phone=data.phone,
        username=data.username or data.phone,
        password_hash=hash_password(data.password),
        tier="pro",
    )
    db.add(user)
    # Flush before user.id is read below (presumably this is what assigns
    # the primary key -- confirm against the User model).
    await db.flush()

    # Read the clock once so the trial window is exactly TRIAL_DAYS long.
    now = datetime.utcnow()
    sub = Subscription(
        user_id=user.id,
        plan="pro_trial",
        status="active",
        started_at=now,
        expires_at=now + timedelta(days=settings.TRIAL_DAYS),
    )
    db.add(sub)

    if data.ref_code:
        # Best effort: a failed referral claim must not fail the registration.
        try:
            from app.api.v1.referral import do_claim_referral

            await do_claim_referral(data.ref_code, str(user.id), db)
        except Exception as e:
            logger.warning("Referral claim failed: %s", e)

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.register", {"phone": data.phone}, ip=client_ip
    )
    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
    }
@router.post("/login", response_model=LoginResponse)
async def login(
    data: LoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Authenticate by username or phone and issue an access/refresh token pair.

    Also expires a lapsed "pro_trial" subscription and downgrades the account
    to the "free" tier before the tokens are minted.

    Args:
        data: Login payload; ``username`` is preferred over ``phone``.
        request: Incoming request; only the client address is read from it.
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A ``LoginResponse`` with both tokens and a user summary.

    Raises:
        HTTPException: 422 when neither identifier is supplied; 401 when no
            matching account accepts the password.
    """
    login_id = data.username or data.phone
    if not login_id:
        raise HTTPException(status_code=422, detail="phone required")

    result = await db.execute(
        select(User).where(
            (User.phone == login_id) | (User.username == login_id)
        )
    )
    # One account's username may equal another account's phone, so the lookup
    # can return several rows. Fetch them all rather than letting
    # scalar_one_or_none() raise, and pick the row the password verifies for.
    # Rows without a password hash (created by WeChat login) never match.
    user = next(
        (
            candidate
            for candidate in result.scalars().all()
            if candidate.password_hash
            and verify_password(data.password, candidate.password_hash)
        ),
        None,
    )
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.login", {"login_id": login_id}, ip=client_ip
    )

    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalar_one_or_none()
        # Downgrade before minting tokens so they carry the current tier.
        if (
            trial_sub
            and trial_sub.expires_at
            and trial_sub.expires_at < datetime.utcnow()
        ):
            trial_sub.status = "expired"
            user.tier = "free"
            await db.flush()

    return LoginResponse(
        access_token=create_access_token(
            {"sub": str(user.id), "tier": user.tier, "role": user.role}
        ),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            # Included for parity with the register and WeChat responses.
            "role": user.role,
        },
    )
@router.post("/login/guest")
async def guest_login(request: Request, db: AsyncSession = Depends(get_db)):
    """Issue tokens for an anonymous guest identified by a fresh random UUID.

    No user row is created; the guest identity lives only in the token
    claims. The login is recorded through ``AdminService.log_usage``.

    Returns:
        A ``LoginResponse`` whose ``user`` entry describes the guest.
    """
    guest_id = str(uuid.uuid4())

    # The access token is requested with an explicit 24-hour lifetime.
    access_claims = {
        "sub": guest_id,
        "tier": "guest",
        "role": "guest",
        "is_guest": True,
    }
    access_token = create_access_token(
        access_claims, expires_delta=timedelta(hours=24)
    )
    refresh_token = create_refresh_token({"sub": guest_id, "is_guest": True})

    client = request.client
    client_ip = client.host if client else None
    await AdminService(db).log_usage(
        guest_id, "user.login_guest", {}, ip=client_ip
    )

    guest_profile = {
        "id": guest_id,
        "phone": None,
        "username": "游客用户",
        "tier": "guest",
        "is_guest": True,
    }
    return LoginResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        user=guest_profile,
    )
@router.post("/refresh")
async def refresh(data: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for a new access token.

    Refresh tokens issued by the login handlers in this module carry only
    ``sub`` (plus ``is_guest`` for guests), so tier and role for a regular
    account are read from the database instead of the token payload. Taking
    them from the payload would always yield the "free"/"user" defaults.

    Args:
        data: Payload holding the refresh token.
        db: Database session supplied by ``get_db``.

    Returns:
        A dict with the new ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 when the token is invalid, is not a refresh token,
            has no subject, or refers to an account that no longer exists.
    """
    payload = decode_token(data.refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    subject = payload.get("sub")
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Preserve guest/role information in the new access token.
    if payload.get("is_guest"):
        extra = {"is_guest": True, "tier": "guest", "role": "guest"}
    else:
        result = await db.execute(select(User).where(User.id == subject))
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        extra = {"tier": user.tier, "role": user.role}

    return {
        "access_token": create_access_token({"sub": subject, **extra}),
        "token_type": "bearer",
    }
@router.get("/me")
async def get_me(
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
):
    """Return the profile of the caller identified by the bearer token.

    Guests get a synthetic profile built from the token alone; registered
    users are loaded from the database.

    Args:
        authorization: Raw ``Authorization`` header (``Bearer <token>``).
        db: Database session supplied by ``get_db``.

    Returns:
        A dict describing the caller, including ``trial_days_left``.

    Raises:
        HTTPException: 401 when the header is missing or malformed, or the
            token is invalid, is a refresh token, or has no subject; 404 when
            the account does not exist.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    # Refresh tokens are marked type == "refresh" (see the /refresh handler)
    # and must not be usable as access tokens.
    if not payload or payload.get("type") == "refresh":
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid token")

    if payload.get("is_guest"):
        return {
            "id": user_id,
            "phone": None,
            "username": "游客用户",
            "tier": "guest",
            "role": "guest",
            "is_guest": True,
            "settings": {},
            "created_at": None,
            # Same key as the registered-user response, for a uniform shape.
            "trial_days_left": 0,
        }

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    trial_days_left = 0
    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalar_one_or_none()
        if trial_sub and trial_sub.expires_at:
            # Whole days remaining; clamped so a lapsed trial reports 0.
            remaining = (trial_sub.expires_at - datetime.utcnow()).days
            trial_days_left = max(0, remaining)

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
        "settings": user.settings,
        "created_at": user.created_at.isoformat() if user.created_at else None,
        "trial_days_left": trial_days_left,
    }
class ProfileUpdate(BaseModel):
    """Request body for ``PUT /me``; a field left as None is not changed."""

    # Declared Optional so an explicit JSON null validates; the handler
    # treats None as "leave unchanged".
    username: Optional[str] = None
    email: Optional[str] = None
class PasswordChange(BaseModel):
    """Request body for ``PUT /password``."""

    # Current password; verified against the stored hash before any change.
    old_password: str
    # Replacement password; hashed by the handler before storing.
    new_password: str
class SettingsUpdate(BaseModel):
    """Request body for ``PATCH /settings``; None values are ignored."""

    # Every field is Optional so an explicit JSON null validates; the handler
    # skips None values when merging into the stored settings.
    preferred_translate_provider: Optional[str] = None
    reply_tone: Optional[str] = None
    timezone: Optional[str] = None
    languages: Optional[list] = None
class WeChatLoginRequest(BaseModel):
    """Request body for ``POST /wechat-login``."""

    # Login code that the handler passes to ``wechat_service.code2session``.
    code: str
    # Accepted alongside the code; the handler in this module does not read
    # either field.
    encrypted_data: str = ""
    iv: str = ""
@router.put("/me")
async def update_me(
    data: ProfileUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Update the caller's username and/or email.

    Args:
        data: Fields to change; None means "leave unchanged".
        authorization: Raw ``Authorization`` header (``Bearer <token>``).
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A dict with the account's id, phone, username, email, tier and role.

    Raises:
        HTTPException: 401 for a missing or invalid token (including a refresh
            token or one with no subject); 403 for guests; 404 when the
            account does not exist; 400 when the requested username is
            already another account's username or phone.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    # Refresh tokens are marked type == "refresh" (see the /refresh handler)
    # and must not be usable as access tokens.
    if not payload or payload.get("type") == "refresh":
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update profile")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    if data.username is not None and data.username != user.username:
        # The login handler matches an identifier against both phone and
        # username, so a username equal to another account's phone or
        # username would make that lookup ambiguous.
        clash = await db.execute(
            select(User.id)
            .where(
                User.id != user.id,
                (User.username == data.username) | (User.phone == data.username),
            )
            .limit(1)
        )
        if clash.first() is not None:
            raise HTTPException(status_code=400, detail="Username already taken")
        user.username = data.username
    if data.email is not None:
        user.email = data.email
    await db.flush()

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "email": user.email,
        "tier": user.tier,
        "role": user.role,
    }
@router.put("/password")
async def change_password(
    data: PasswordChange,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Replace the caller's password after verifying the current one.

    Args:
        data: Current and new passwords.
        authorization: Raw ``Authorization`` header (``Bearer <token>``).
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 401 for a missing or invalid token (including a refresh
            token or one with no subject); 403 for guests; 404 when the
            account does not exist; 400 when the current password does not
            verify or the account has no password set.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    # Refresh tokens are marked type == "refresh" (see the /refresh handler)
    # and must not be usable as access tokens.
    if not payload or payload.get("type") == "refresh":
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot change password")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Accounts created through WeChat login have no password hash; treat that
    # as a failed check instead of handing None to verify_password.
    if not user.password_hash or not verify_password(
        data.old_password, user.password_hash
    ):
        raise HTTPException(status_code=400, detail="旧密码不正确")

    user.password_hash = hash_password(data.new_password)
    await db.flush()
    return {"message": "密码修改成功"}
@router.get("/wechat/config")
async def wechat_config():
    """Report whether WeChat login is configured and expose the app id.

    Returns:
        A dict with ``available`` (True only when both the app id and the app
        secret are set) and ``app_id`` (empty string when unset).
    """
    from app.config import settings

    app_id = settings.WECHAT_APP_ID
    is_configured = bool(app_id and settings.WECHAT_APP_SECRET)
    return {"available": is_configured, "app_id": app_id or ""}
@router.post("/wechat-login")
async def wechat_login(
    data: WeChatLoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Log in via a WeChat code, creating a "free" account on first use.

    Args:
        data: Payload carrying the WeChat login ``code``.
        request: Incoming request; only the client address is read from it.
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A ``LoginResponse`` with both tokens and a user summary.

    Raises:
        HTTPException: 400 when the code exchange fails or yields no openid.
    """
    from app.services.wechat import wechat_service

    session = await wechat_service.code2session(data.code)
    if not session:
        raise HTTPException(status_code=400, detail="WeChat login failed")

    openid = session.get("openid")
    # Without this guard a missing openid turns the lookup below into
    # "wechat_openid IS NULL", which matches phone-registered accounts.
    if not openid:
        raise HTTPException(status_code=400, detail="WeChat login failed")

    result = await db.execute(select(User).where(User.wechat_openid == openid))
    user = result.scalar_one_or_none()
    is_new = False
    if not user:
        user = User(
            wechat_openid=openid,
            username=f"wx_{openid[-8:]}",
            tier="free",
        )
        db.add(user)
        # Flush before user.id is read below (presumably this is what assigns
        # the primary key -- confirm against the User model).
        await db.flush()
        is_new = True

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.wechat_login", {"is_new": is_new}, ip=client_ip
    )
    return LoginResponse(
        access_token=create_access_token(
            {"sub": str(user.id), "tier": user.tier, "role": user.role}
        ),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            "role": user.role,
        },
    )
@router.patch("/settings")
async def update_settings(
    data: SettingsUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Merge the supplied non-null fields into the caller's stored settings.

    Args:
        data: Settings to change; unset and None fields are ignored.
        authorization: Raw ``Authorization`` header (``Bearer <token>``).
        db: Database session supplied by ``get_db``.
        _csrf: Unused value; the dependency exists to enforce the CSRF check.

    Returns:
        A dict with the merged ``settings``.

    Raises:
        HTTPException: 401 for a missing or invalid token (including a refresh
            token or one with no subject); 403 for guests; 404 when the
            account does not exist.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")

    payload = decode_token(authorization[7:])
    # Refresh tokens are marked type == "refresh" (see the /refresh handler)
    # and must not be usable as access tokens.
    if not payload or payload.get("type") == "refresh":
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = payload.get("sub")
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid token")
    # Guests have no user row; answer 403 like the other profile handlers
    # instead of falling through to a 404.
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update settings")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    changes = {
        key: value
        for key, value in data.dict(exclude_unset=True).items()
        if value is not None
    }
    # Assign a NEW dict: mutating the loaded dict in place and reassigning the
    # same object would not register as a change on a plain JSON column, so
    # the update would never be written. The local is named ``merged`` to
    # avoid shadowing the imported ``settings`` config object.
    merged = {**(user.settings or {}), **changes}
    user.settings = merged
    await db.flush()
    return {"settings": user.settings}