Files
trade-assistant/backend/app/api/v1/auth.py
T
TradeMate Dev 9e9c7ac270 fix: additional code quality and performance improvements
Code quality:
- Replace empty except blocks with proper logging
- Create shared pagination utility function
- Remove duplicate UUID validation code
- Fix dead code in translation.py

Performance:
- Fix N+1 query in followup engine (use join instead of loop)
- Add eager loading for customer health scores
- Create database indexes for common query patterns:
  - customers: (user_id, status), (user_id, last_contact_at)
  - payment_transactions: (user_id, created_at)
  - followup_logs: (user_id, customer_id)
  - notifications: (user_id, is_read)

Configuration:
- Centralize magic numbers in config.py:
  - Payment prices
  - File upload limits
  - Rate limiting settings
  - Pagination defaults
- Update auth.py to use centralized rate limiting config
- Update customer/product imports to use centralized upload limits
- Update import_service.py to use centralized MAX_ROWS
2026-06-11 18:25:08 +08:00

460 lines
14 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status, Header, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
import uuid
from app.database import get_db
from app.models.user import User
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
from app.core.csrf import require_csrf_token
from pydantic import BaseModel, EmailStr, field_validator
from datetime import datetime, timedelta
from app.services.admin import AdminService
from app.models.subscription import Subscription
from app.api.v1.referral import apply_referral
from app.config import settings
import logging
# Module-level logger, named after this module's dotted path.
logger = logging.getLogger(__name__)
# Router that collects the authentication endpoints declared in this module;
# presumably included by the application under a versioned auth prefix — confirm.
router = APIRouter()
class RegisterRequest(BaseModel):
    """Request body for ``POST /register``.

    ``username`` falls back to the phone number when left empty, and
    ``ref_code`` is an optional referral code claimed after the account is
    created.
    """

    phone: str
    password: str
    username: str = ""
    ref_code: str = ""

    @field_validator("phone")
    @classmethod
    def validate_phone(cls, v: str) -> str:
        """Reject a blank phone number; phone is the identifier registration de-duplicates on."""
        if not v.strip():
            raise ValueError("Phone must not be empty")
        return v

    @field_validator("password")
    @classmethod
    def validate_password(cls, v: str) -> str:
        """Enforce the same 6-character minimum that ``LoginRequest`` requires.

        Without this check an account could be registered with a password
        that the login endpoint then rejects, locking the user out.
        """
        if len(v) < 6:
            raise ValueError("Password must be at least 6 characters")
        return v
class LoginResponse(BaseModel):
    """Token pair returned by the login endpoints, plus a summary of the user."""

    # Bearer token the client sends back as "Authorization: Bearer <token>".
    access_token: str
    # Token exchanged at POST /refresh for a new access token.
    refresh_token: str
    token_type: str = "bearer"
    # Public profile fields of the authenticated user (id, phone, username, tier, ...).
    user: dict
class LoginRequest(BaseModel):
    """Credentials for ``POST /login``.

    Either ``username`` or ``phone`` identifies the account (the endpoint
    uses ``username`` when both are sent). Passwords shorter than six
    characters fail validation before any database lookup happens.
    """

    username: str = ""
    phone: str = ""
    password: str

    @field_validator('password')
    @classmethod
    def validate_password(cls, v: str) -> str:
        """Accept only passwords of at least 6 characters."""
        if len(v) >= 6:
            return v
        raise ValueError('Password must be at least 6 characters')
class RefreshRequest(BaseModel):
    """Request body for ``POST /refresh``."""

    # Token previously issued as ``refresh_token`` by one of the login endpoints.
    refresh_token: str
@router.post("/register")
async def register(
    data: RegisterRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Create a phone/password account that starts on a pro trial.

    The user is created on the ``pro`` tier together with an active
    ``pro_trial`` subscription lasting ``settings.TRIAL_DAYS`` days. When
    ``ref_code`` is supplied the referral is claimed best-effort: a failure is
    logged and does not abort the registration.

    Raises:
        HTTPException(400): the phone number is already used as another
            account's phone or username ("Phone already registered"), or the
            requested username is already used as another account's username
            or phone ("Username already taken").
    """
    username = data.username or data.phone

    # /login matches its identifier against BOTH the phone and the username
    # column, so a new identifier must not collide with either column of an
    # existing account; otherwise that lookup becomes ambiguous.
    phone_clash = await db.execute(
        select(User.id)
        .where((User.phone == data.phone) | (User.username == data.phone))
        .limit(1)
    )
    if phone_clash.first() is not None:
        raise HTTPException(status_code=400, detail="Phone already registered")
    if username != data.phone:
        name_clash = await db.execute(
            select(User.id)
            .where((User.username == username) | (User.phone == username))
            .limit(1)
        )
        if name_clash.first() is not None:
            raise HTTPException(status_code=400, detail="Username already taken")

    user = User(
        phone=data.phone,
        username=username,
        password_hash=hash_password(data.password),
        tier="pro",
    )
    db.add(user)
    # Flush so user.id is populated before it is referenced below.
    await db.flush()

    # Take one timestamp so the trial's start and end are exactly TRIAL_DAYS apart.
    now = datetime.utcnow()
    db.add(
        Subscription(
            user_id=user.id,
            plan="pro_trial",
            status="active",
            started_at=now,
            expires_at=now + timedelta(days=settings.TRIAL_DAYS),
        )
    )

    if data.ref_code:
        try:
            from app.api.v1.referral import do_claim_referral

            await do_claim_referral(data.ref_code, str(user.id), db)
        except Exception as e:
            # Best-effort: an invalid referral code must not block sign-up.
            # NOTE(review): if the claim fails after writing through ``db`` the
            # session may need a rollback/savepoint — confirm against
            # do_claim_referral.
            logger.warning("Referral claim failed: %s", e)

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(
        str(user.id), "user.register", {"phone": data.phone}, ip=client_ip
    )
    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
    }
@router.post("/login", response_model=LoginResponse)
async def login(
    data: LoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate with phone or username plus password, and issue tokens.

    The identifier (``username`` if given, else ``phone``) is matched against
    both the phone and username columns. If the user is on the ``pro`` tier
    and their active ``pro_trial`` subscription has passed ``expires_at``,
    the trial is marked ``expired`` and the user is moved to ``free`` before
    the access token is created, so the token carries the updated tier.

    Raises:
        HTTPException(422): neither ``username`` nor ``phone`` was supplied.
        HTTPException(401): no matching account, the account has no password
            hash, or the password does not verify.
    """
    login_id = data.username or data.phone
    if not login_id:
        raise HTTPException(status_code=422, detail="phone required")

    result = await db.execute(
        select(User).where(
            (User.phone == login_id) | (User.username == login_id)
        )
    )
    # The OR can match more than one row (e.g. one account's username equal to
    # another's phone); scalar_one_or_none() would raise and surface as a 500.
    # Prefer the exact phone match, otherwise take a username match.
    candidates = result.scalars().all()
    user = next((u for u in candidates if u.phone == login_id), None)
    if user is None and candidates:
        user = candidates[0]

    # Accounts created through /wechat-login are stored without a password
    # hash; reject them as bad credentials instead of verifying against None.
    if (
        user is None
        or not user.password_hash
        or not verify_password(data.password, user.password_hash)
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(str(user.id), "user.login", {"login_id": login_id}, ip=client_ip)

    if user.tier == "pro":
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.user_id == user.id,
                Subscription.plan == "pro_trial",
                Subscription.status == "active",
            )
        )
        trial_sub = sub_result.scalars().first()
        if trial_sub and trial_sub.expires_at and trial_sub.expires_at < datetime.utcnow():
            # NOTE(review): this demotes to "free" without checking any other
            # (paid) subscription the user may hold — confirm that cannot happen.
            trial_sub.status = "expired"
            user.tier = "free"
            await db.flush()

    return LoginResponse(
        access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            # Included for parity with the /wechat-login response.
            "role": user.role,
        },
    )
@router.post("/login/guest")
async def guest_login(request: Request, db: AsyncSession = Depends(get_db)):
    """Issue a token pair for a new anonymous guest identity.

    Each call mints a fresh random guest id; no User row is created. The
    access token is valid for 24 hours and carries ``tier``/``role`` set to
    ``"guest"`` plus ``is_guest``.

    Guest logins are limited per client IP to ``settings.GUEST_LOGIN_LIMIT``
    per ``settings.GUEST_LOGIN_WINDOW`` seconds using a Redis counter. If Redis
    is unreachable the limit is skipped (fail open) and a warning is logged.

    Raises:
        HTTPException(429): the per-IP guest login limit has been reached.
    """
    from app.core.redis import get_redis

    client_ip = request.client.host if request.client else None
    cache_key = f"guest_login:{client_ip or 'unknown'}"
    try:
        redis_client = await get_redis()
        # Increment and read the TTL in one pipeline, so concurrent requests
        # cannot all pass a stale read of the counter. The expiry is set only
        # when the key has none, so the window starts at the first attempt and
        # a key can never be left without a TTL.
        pipe = redis_client.pipeline()
        pipe.incr(cache_key)
        pipe.ttl(cache_key)
        count, ttl = await pipe.execute()
        if ttl < 0:
            await redis_client.expire(cache_key, settings.GUEST_LOGIN_WINDOW)
        if count > settings.GUEST_LOGIN_LIMIT:
            raise HTTPException(
                status_code=429,
                detail="Too many guest login attempts. Please try again later or register an account."
            )
    except HTTPException:
        raise
    except Exception:
        # Fail open so guest access survives a Redis outage, but leave a trace.
        logger.warning("Guest login rate limiting unavailable", exc_info=True)

    guest_id = str(uuid.uuid4())
    access_token = create_access_token(
        {"sub": guest_id, "tier": "guest", "role": "guest", "is_guest": True},
        expires_delta=timedelta(hours=24),
    )
    refresh_token = create_refresh_token({"sub": guest_id, "is_guest": True})
    await AdminService(db).log_usage(guest_id, "user.login_guest", {}, ip=client_ip)
    return LoginResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        user={
            "id": guest_id,
            "phone": None,
            "username": "游客用户",
            "tier": "guest",
            "is_guest": True,
        },
    )
@router.post("/refresh")
async def refresh(data: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a refresh token for a new access token.

    Guest refresh tokens yield a guest access token. For registered users the
    ``tier`` and ``role`` claims are read from the database. Refresh tokens
    are created in this module with only ``sub`` (plus ``is_guest`` for
    guests), so reading those claims from the token made every refreshed
    token fall back to tier ``"free"`` and role ``"user"``. Reading the row
    also refuses refresh for users that no longer exist.

    Raises:
        HTTPException(401): the token does not decode, is not a refresh token,
            has no subject, or its user no longer exists.
    """
    payload = decode_token(data.refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    subject = payload.get("sub")
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Preserve guest / role information in the new access token.
    if payload.get("is_guest"):
        claims = {"is_guest": True, "tier": "guest", "role": "guest"}
    else:
        result = await db.execute(select(User).where(User.id == subject))
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        claims = {"tier": user.tier, "role": user.role}

    return {
        "access_token": create_access_token({"sub": subject, **claims}),
        "token_type": "bearer",
    }
@router.get("/me")
async def get_me(
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
):
    """Return the profile of the caller identified by the bearer token.

    Guest tokens are answered from the token claims alone, without touching
    the database. For registered users the response also carries
    ``trial_days_left``: whole days remaining on an active ``pro_trial``
    subscription, or 0 when the user is not on the pro tier or has no such
    trial.

    Raises:
        HTTPException(401): missing/non-Bearer header or undecodable token.
        HTTPException(404): the token's subject no longer exists.
    """
    prefix = "Bearer "
    has_bearer = authorization is not None and authorization.startswith(prefix)
    if not has_bearer:
        raise HTTPException(status_code=401, detail="Missing token")

    claims = decode_token(authorization[len(prefix):])
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid token")

    if claims.get("is_guest"):
        return dict(
            id=claims["sub"],
            phone=None,
            username="游客用户",
            tier="guest",
            role="guest",
            is_guest=True,
            settings={},
            created_at=None,
        )

    lookup = await db.execute(select(User).where(User.id == claims["sub"]))
    user = lookup.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    days_left = 0
    if user.tier == "pro":
        trial_query = select(Subscription).where(
            Subscription.user_id == user.id,
            Subscription.plan == "pro_trial",
            Subscription.status == "active",
        )
        trial = (await db.execute(trial_query)).scalar_one_or_none()
        if trial and trial.expires_at:
            days_left = max(0, (trial.expires_at - datetime.utcnow()).days)

    created = user.created_at
    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "tier": user.tier,
        "role": user.role,
        "settings": user.settings,
        "created_at": created.isoformat() if created else None,
        "trial_days_left": days_left,
    }
class ProfileUpdate(BaseModel):
    """Partial profile update for ``PUT /me``; fields left as ``None`` are not changed."""

    # Optional[...] so an explicit JSON null is accepted and means "leave as is";
    # under pydantic v2 a bare ``str`` annotation rejects null with a 422.
    username: Optional[str] = None
    # NOTE(review): EmailStr is imported by this module but not used here, so the
    # address is not format-validated; switching needs the email-validator package.
    email: Optional[str] = None
class PasswordChange(BaseModel):
    """Request body for ``PUT /password``."""

    old_password: str
    new_password: str

    @field_validator("new_password")
    @classmethod
    def validate_new_password(cls, v: str) -> str:
        """Apply the same 6-character minimum that ``LoginRequest`` enforces.

        Without it a user could set a password that the login endpoint then
        refuses, locking themselves out.
        """
        if len(v) < 6:
            raise ValueError("Password must be at least 6 characters")
        return v
class SettingsUpdate(BaseModel):
    """Partial update of ``User.settings`` for ``PATCH /settings``.

    Only keys the client actually sends are merged, and ``None`` values are
    skipped by the handler. Optional[...] lets an explicit JSON null through
    (a bare ``str``/``list`` annotation makes pydantic v2 reject it with a 422).
    """

    preferred_translate_provider: Optional[str] = None
    reply_tone: Optional[str] = None
    timezone: Optional[str] = None
    languages: Optional[list] = None
class WeChatLoginRequest(BaseModel):
    """Request body for ``POST /wechat-login``."""

    # Login code that wechat_login exchanges via wechat_service.code2session.
    code: str
    # NOTE(review): accepted but not read by wechat_login in this module —
    # presumably meant for decrypting WeChat user info; confirm or remove.
    encrypted_data: str = ""
    iv: str = ""
@router.put("/me")
async def update_me(
    data: ProfileUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Update the caller's username and/or email.

    Fields sent as ``None`` (or omitted) are left unchanged. Guests are refused.

    Raises:
        HTTPException(401): missing/non-Bearer header or undecodable token.
        HTTPException(403): the token belongs to a guest.
        HTTPException(404): the token's subject no longer exists.
        HTTPException(400): the new username is already another account's
            username or phone.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")
    payload = decode_token(authorization[7:])
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update profile")

    result = await db.execute(select(User).where(User.id == payload["sub"]))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    if data.username is not None and data.username != user.username:
        # /login matches its identifier against both phone and username, so a
        # username shared with another account (in either column) would make
        # that lookup ambiguous.
        clash = await db.execute(
            select(User.id)
            .where(
                (User.username == data.username) | (User.phone == data.username),
                User.id != user.id,
            )
            .limit(1)
        )
        if clash.first() is not None:
            raise HTTPException(status_code=400, detail="Username already taken")
        user.username = data.username
    if data.email is not None:
        user.email = data.email
    await db.flush()

    return {
        "id": str(user.id),
        "phone": user.phone,
        "username": user.username,
        "email": user.email,
        "tier": user.tier,
        "role": user.role,
    }
@router.put("/password")
async def change_password(
    data: PasswordChange,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Change the caller's password after verifying the current one.

    Raises:
        HTTPException(401): missing/non-Bearer header or undecodable token.
        HTTPException(403): the token belongs to a guest.
        HTTPException(404): the token's subject no longer exists.
        HTTPException(400): the old password does not verify; also returned
            for accounts that have no password hash (e.g. WeChat-created ones).
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")
    payload = decode_token(authorization[7:])
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot change password")

    result = await db.execute(select(User).where(User.id == payload["sub"]))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # /wechat-login creates users without a password hash; never hand None to
    # verify_password.
    if not user.password_hash or not verify_password(data.old_password, user.password_hash):
        # Detail text: "old password is incorrect".
        raise HTTPException(status_code=400, detail="旧密码不正确")

    user.password_hash = hash_password(data.new_password)
    await db.flush()
    # Message text: "password changed successfully".
    return {"message": "密码修改成功"}
@router.get("/wechat/config")
async def wechat_config():
    """Tell the client whether WeChat login is configured.

    ``available`` is true only when both the app id and the app secret are
    set. Only the app id is returned (empty string when unset); the secret
    is never included in the response.
    """
    return {
        "available": bool(settings.WECHAT_APP_ID and settings.WECHAT_APP_SECRET),
        "app_id": settings.WECHAT_APP_ID or "",
    }
@router.post("/wechat-login")
async def wechat_login(
    data: WeChatLoginRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Log in, or sign up, with a WeChat login code.

    The code is exchanged for an ``openid`` via ``wechat_service.code2session``.
    An account whose ``wechat_openid`` matches is reused; otherwise a new
    ``free``-tier user named ``wx_<last 8 chars of openid>`` is created with
    no phone and no password hash.

    Raises:
        HTTPException(400): the code exchange failed or returned no openid.
    """
    from app.services.wechat import wechat_service

    session = await wechat_service.code2session(data.code)
    openid = session.get("openid") if session else None
    if not openid:
        # Never query with a missing openid: ``User.wechat_openid == None``
        # compiles to ``IS NULL`` and could match an account that was not
        # created through WeChat (register() never sets wechat_openid),
        # logging the caller in as someone else.
        raise HTTPException(status_code=400, detail="WeChat login failed")

    result = await db.execute(select(User).where(User.wechat_openid == openid))
    user = result.scalar_one_or_none()
    is_new = False
    if not user:
        user = User(
            wechat_openid=openid,
            username=f"wx_{openid[-8:]}",
            tier="free",
        )
        db.add(user)
        await db.flush()
        is_new = True

    client_ip = request.client.host if request.client else None
    await AdminService(db).log_usage(str(user.id), "user.wechat_login", {"is_new": is_new}, ip=client_ip)
    return LoginResponse(
        access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        user={
            "id": str(user.id),
            "phone": user.phone,
            "username": user.username,
            "tier": user.tier,
            "role": user.role,
        },
    )
@router.patch("/settings")
async def update_settings(
    data: SettingsUpdate,
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: AsyncSession = Depends(get_db),
    _csrf: str = Depends(require_csrf_token),
):
    """Merge the supplied preference keys into the caller's ``User.settings``.

    Only fields present in the request body are considered, and ``None``
    values are skipped, so existing keys are never cleared by this endpoint.
    Guests are refused, matching the other profile-mutating endpoints.

    Raises:
        HTTPException(401): missing/non-Bearer header or undecodable token.
        HTTPException(403): the token belongs to a guest.
        HTTPException(404): the token's subject no longer exists.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing token")
    payload = decode_token(authorization[7:])
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("is_guest"):
        raise HTTPException(status_code=403, detail="Guests cannot update settings")

    result = await db.execute(select(User).where(User.id == payload["sub"]))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    updates = {
        key: value
        for key, value in data.model_dump(exclude_unset=True).items()
        if value is not None
    }
    # Assign a NEW dict. SQLAlchemy does not detect in-place mutation of a JSON
    # value (unless the column uses sqlalchemy.ext.mutable), and assigning the
    # same mutated object back compares equal to itself, so no UPDATE would be
    # emitted and the change could silently be lost.
    user.settings = {**(user.settings or {}), **updates}
    await db.flush()
    return {"settings": user.settings}