127 lines
4.7 KiB
Python
127 lines
4.7 KiB
Python
import uuid
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List

from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.certification import Certification, CertStatus
from app.models.invoice import Invoice, InvoiceType, InvoiceStatus
|
|
|
|
|
|
class InvoiceService:
    """Service for creating, listing and processing invoice requests.

    Every method works through the ``AsyncSession`` given to the constructor.
    Mutating methods only ``flush()``; this service never commits, so the
    caller that owns the session decides when to commit or roll back.
    """

    def __init__(self, db: AsyncSession):
        # Session used for every query and mutation issued by this service.
        self.db = db

    # ----------------------------------------------------------------- helpers

    async def _find_approved_cert(
        self, owner_id: uuid.UUID, cert_type: str
    ) -> Optional[Certification]:
        """Return one approved certification of ``cert_type`` for ``owner_id``, or None."""
        result = await self.db.execute(
            select(Certification)
            .where(
                Certification.user_id == owner_id,
                Certification.cert_type == cert_type,
                Certification.status == CertStatus.approved,
            )
            .limit(1)
        )
        return result.scalar_one_or_none()

    @staticmethod
    def _serialize(inv: Invoice, include_user_id: bool = False) -> Dict[str, Any]:
        """Convert an ``Invoice`` row into a plain dict.

        Enum fields are emitted as their ``.value``, ids as strings and
        timestamps as ISO-8601 strings (``None`` when unset). ``user_id`` is
        included only when ``include_user_id`` is true: ``list_all`` includes
        it, ``list_user`` does not.
        """
        item: Dict[str, Any] = {"id": str(inv.id)}
        if include_user_id:
            item["user_id"] = str(inv.user_id)
        item.update(
            {
                "invoice_type": inv.invoice_type.value,
                "title": inv.title,
                "tax_id": inv.tax_id,
                "amount": inv.amount,
                "status": inv.status.value,
                "reject_reason": inv.reject_reason,
                "issued_at": inv.issued_at.isoformat() if inv.issued_at else None,
                "created_at": inv.created_at.isoformat() if inv.created_at else None,
            }
        )
        return item

    # -------------------------------------------------------------- public API

    async def apply(self, user_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a pending invoice request for ``user_id``.

        The applicant must hold an approved certification that matches the
        invoice type: an ``"individual"`` certification for
        ``InvoiceType.individual``, and an ``"enterprise"`` certification for
        every other invoice type.

        Args:
            user_id: UUID string of the requesting user.
            data: Request payload. ``invoice_type``, ``title`` and ``amount``
                are required; ``tax_id`` is optional.

        Returns:
            ``{"id", "status"}`` of the new invoice, or ``{"error": <message>}``
            when the required certification is missing.

        Raises:
            KeyError: A required payload key is absent.
            ValueError: ``invoice_type`` is not a valid ``InvoiceType`` value,
                or ``user_id`` is not a valid UUID.
        """
        invoice_type = InvoiceType(data["invoice_type"])
        owner_id = uuid.UUID(user_id)

        # Individual invoices require personal real-name verification; all
        # other invoice types require enterprise verification. The messages
        # read "please complete personal real-name verification first" /
        # "please complete enterprise verification first".
        if invoice_type == InvoiceType.individual:
            cert_type, missing_cert_error = "individual", "请先完成个人实名认证"
        else:
            cert_type, missing_cert_error = "enterprise", "请先完成企业认证"

        cert = await self._find_approved_cert(owner_id, cert_type)
        if cert is None:
            return {"error": missing_cert_error}

        # NOTE(review): `amount` is stored as given, with no validation here;
        # confirm the caller or schema enforces a sane value.
        invoice = Invoice(
            user_id=owner_id,
            certification_id=cert.id,
            invoice_type=invoice_type,
            title=data["title"],
            tax_id=data.get("tax_id"),
            amount=data["amount"],
            status=InvoiceStatus.pending,
        )
        self.db.add(invoice)
        # Flush emits the INSERT so column defaults (such as the id) are
        # populated before we read them back; committing is left to the caller.
        await self.db.flush()
        return {"id": str(invoice.id), "status": invoice.status.value}

    async def list_user(self, user_id: str) -> List[Dict[str, Any]]:
        """Return all invoices belonging to ``user_id``, newest first.

        Items do not include ``user_id``.

        Raises:
            ValueError: ``user_id`` is not a valid UUID.
        """
        result = await self.db.execute(
            select(Invoice)
            .where(Invoice.user_id == uuid.UUID(user_id))
            .order_by(desc(Invoice.created_at))
        )
        return [self._serialize(inv) for inv in result.scalars().all()]

    async def list_all(
        self, page: int, size: int, status: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return one page of all invoices, newest first, optionally filtered.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            size: Maximum number of items on the page.
            status: Optional ``InvoiceStatus`` value to filter on; a falsy
                value means no filter.

        Returns:
            ``{"items", "total", "page", "size"}``. ``total`` is the number of
            invoices matching the filter across all pages.

        Raises:
            ValueError: ``status`` is not a valid ``InvoiceStatus`` value.
        """
        # A page below 1 would produce a negative OFFSET, which the database rejects.
        page = max(page, 1)

        query = select(Invoice).order_by(desc(Invoice.created_at))
        count_query = select(func.count()).select_from(Invoice)
        if status:
            condition = Invoice.status == InvoiceStatus(status)
            query = query.where(condition)
            count_query = count_query.where(condition)

        # Count separately: len() of the fetched page is capped at `size` and
        # does not tell the client how many matching invoices exist overall.
        total = (await self.db.execute(count_query)).scalar_one()

        offset = (page - 1) * size
        result = await self.db.execute(query.offset(offset).limit(size))
        items = [
            self._serialize(inv, include_user_id=True)
            for inv in result.scalars().all()
        ]
        return {"items": items, "total": total, "page": page, "size": size}

    async def process(
        self, invoice_id: str, action: str, reason: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Issue or reject the invoice identified by ``invoice_id``.

        ``action == "issue"`` marks the invoice issued and stamps ``issued_at``
        with the current UTC time (stored naive). Any other ``action`` marks it
        rejected and stores ``reason`` as ``reject_reason``.

        Returns:
            ``{"id", "status"}`` after the update, or ``None`` when no invoice
            has that id.

        Raises:
            ValueError: ``invoice_id`` is not a valid UUID.
        """
        result = await self.db.execute(
            select(Invoice).where(Invoice.id == uuid.UUID(invoice_id))
        )
        inv = result.scalar_one_or_none()
        if inv is None:
            return None

        # NOTE(review): the transition is applied whatever the current status
        # is, and every action other than "issue" counts as a rejection.
        # Confirm that callers validate `action`, and decide whether only
        # pending invoices should be processable.
        if action == "issue":
            inv.status = InvoiceStatus.issued
            # Same naive-UTC value datetime.utcnow() returned, without its
            # Python 3.12+ deprecation.
            inv.issued_at = datetime.now(timezone.utc).replace(tzinfo=None)
        else:
            inv.status = InvoiceStatus.rejected
            inv.reject_reason = reason
        await self.db.flush()
        return {"id": str(inv.id), "status": inv.status.value}
|