from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, func, and_, false
from app.models.notification import Notification
from datetime import datetime


class NotificationService:
    """Per-user notification queries and mutations on an ``AsyncSession``.

    Every query is scoped to the ``user_id`` passed in, so one user can never
    read or modify another user's notifications. Mutating methods ``flush``
    their changes but never commit; committing the transaction is left to the
    caller.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _owner_filter(user_id: str, unread_only: bool = False) -> List[Any]:
        """Return WHERE criteria limiting rows to ``user_id`` (and unread ones if asked)."""
        criteria: List[Any] = [Notification.user_id == user_id]
        if unread_only:
            # ``== false()`` renders ``is_read = false`` exactly like ``== False``
            # but without the E712 "comparison to False" lint.
            criteria.append(Notification.is_read == false())
        return criteria

    @staticmethod
    def _serialize(n: Notification) -> Dict[str, Any]:
        """Convert one Notification row into the dict shape used in list responses."""
        return {
            "id": str(n.id),
            "title": n.title,
            "content": n.content,
            "type": n.notification_type,
            "reference_type": n.reference_type,
            "reference_id": n.reference_id,
            "is_read": n.is_read,
            "created_at": n.created_at.isoformat() if n.created_at else None,
        }

    async def _get_owned(
        self, user_id: str, notification_id: str
    ) -> Optional[Notification]:
        """Return notification ``notification_id`` if it belongs to ``user_id``, else None."""
        result = await self.db.execute(
            select(Notification).where(
                and_(
                    Notification.id == notification_id,
                    Notification.user_id == user_id,
                )
            )
        )
        return result.scalar_one_or_none()

    async def list_notifications(
        self, user_id: str, page: int = 1, size: int = 20, unread_only: bool = False
    ) -> Dict[str, Any]:
        """Return one page of a user's notifications, newest first.

        Args:
            user_id: Owner whose notifications are listed.
            page: 1-based page number; values below 1 are treated as 1.
            size: Page size; negative values are treated as 0.
            unread_only: When True, only rows with ``is_read`` false are
                listed and counted.

        Returns:
            ``{"items": [...], "total": int, "page": int, "size": int}`` where
            ``total`` counts all matching rows (not just this page) and
            ``page``/``size`` echo the effective (clamped) values used.
        """
        # A negative OFFSET/LIMIT is an error on some databases (e.g. PostgreSQL)
        # and silently reinterpreted on others, so normalise it up front.
        page = max(page, 1)
        size = max(size, 0)
        criteria = self._owner_filter(user_id, unread_only)

        count_result = await self.db.execute(
            select(func.count(Notification.id)).where(*criteria)
        )
        total = count_result.scalar() or 0

        page_result = await self.db.execute(
            select(Notification)
            .where(*criteria)
            .order_by(Notification.created_at.desc())
            .offset((page - 1) * size)
            .limit(size)
        )
        return {
            "items": [self._serialize(n) for n in page_result.scalars().all()],
            "total": total,
            "page": page,
            "size": size,
        }

    async def get_unread_count(self, user_id: str) -> int:
        """Return how many of ``user_id``'s notifications are unread (0 if none)."""
        result = await self.db.execute(
            select(func.count(Notification.id)).where(
                *self._owner_filter(user_id, unread_only=True)
            )
        )
        return result.scalar() or 0

    async def mark_read(self, user_id: str, notification_id: str) -> bool:
        """Mark one of ``user_id``'s notifications as read.

        Returns:
            True if the notification exists and belongs to ``user_id``;
            False otherwise (nothing is changed in that case).
        """
        notification = await self._get_owned(user_id, notification_id)
        if notification is None:
            return False
        notification.is_read = True
        await self.db.flush()
        return True

    async def mark_all_read(self, user_id: str) -> int:
        """Mark every unread notification of ``user_id`` as read.

        Issues a single bulk UPDATE rather than loading each row into memory.

        Returns:
            The number of notifications that were unread and are now read.
        """
        result = await self.db.execute(
            update(Notification)
            .where(*self._owner_filter(user_id, unread_only=True))
            .values(is_read=True)
            # Keep any Notification objects already loaded in this session
            # consistent with the new database state.
            .execution_options(synchronize_session="fetch")
        )
        return result.rowcount or 0

    async def delete_notification(self, user_id: str, notification_id: str) -> bool:
        """Delete one of ``user_id``'s notifications.

        Returns:
            True if the notification existed, belonged to ``user_id`` and was
            deleted; False otherwise.
        """
        notification = await self._get_owned(user_id, notification_id)
        if notification is None:
            return False
        await self.db.delete(notification)
        await self.db.flush()
        return True

    @staticmethod
    async def create_notification(
        db: AsyncSession,
        user_id: str,
        title: str,
        content: str,
        notification_type: str = "system",
        reference_type: Optional[str] = None,
        reference_id: Optional[str] = None,
    ) -> Notification:
        """Add a new notification for ``user_id`` to ``db`` and flush it.

        Static so callers that only hold a session (not a service instance)
        can create notifications.

        Returns:
            The flushed ``Notification`` instance.
        """
        notification = Notification(
            user_id=user_id,
            title=title,
            content=content,
            notification_type=notification_type,
            reference_type=reference_type,
            reference_id=reference_id,
        )
        db.add(notification)
        await db.flush()
        return notification