from typing import List, Optional, Dict, Any

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload

from backend.models import Channel, Notification


async def get_channel(db: AsyncSession, channel_id: int) -> Optional[Channel]:
    """Return the channel whose ``id`` equals *channel_id*, or ``None``."""
    result = await db.execute(select(Channel).where(Channel.id == channel_id))
    return result.scalar_one_or_none()


async def get_channel_by_name(db: AsyncSession, name: str) -> Optional[Channel]:
    """Return the channel whose ``name`` equals *name*, or ``None``.

    ``scalar_one_or_none`` raises if more than one row matches, so this
    relies on channel names being unique (presumably enforced by the model —
    confirm against ``Channel``).
    """
    result = await db.execute(select(Channel).where(Channel.name == name))
    return result.scalar_one_or_none()


async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[Channel]:
    """Return one page of channels, ordered by primary key.

    Args:
        db: Active async session.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
    """
    # Explicit ORDER BY: without it, OFFSET/LIMIT pages are not guaranteed to
    # be stable between calls, so rows could be skipped or repeated.
    result = await db.execute(
        select(Channel).order_by(Channel.id).offset(skip).limit(limit)
    )
    return list(result.scalars().all())


async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
    """Return every active channel that carries at least one of *tags*.

    Only the active-channel filter is done in SQL; tag matching is done in
    Python on ``channel.tags`` so it does not depend on database-specific
    JSON functions. There is no row cap: earlier versions only looked at the
    first 1000 channels and silently missed any beyond that.
    """
    if not tags:
        # any() over an empty tag list is always False; skip the query.
        return []
    result = await db.execute(
        select(Channel).where(Channel.is_active == True)  # noqa: E712 - SQL expression
    )
    return [
        channel
        for channel in result.scalars()
        if any(tag in (channel.tags or []) for tag in tags)
    ]


async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
    """Create a channel from *channel_data*, commit, and return the refreshed row.

    *channel_data* is passed straight to ``Channel(**...)``, so its keys must
    be attributes the ``Channel`` constructor accepts.
    """
    db_channel = Channel(**channel_data)
    db.add(db_channel)
    await db.commit()
    await db.refresh(db_channel)
    return db_channel


async def update_channel(
    db: AsyncSession, channel_id: int, channel_data: Dict[str, Any]
) -> Optional[Channel]:
    """Apply a partial update to channel *channel_id*.

    Keys in *channel_data* whose value is ``None`` are ignored, so this
    function cannot be used to set a column to NULL.

    Returns:
        The refreshed channel, or ``None`` if no channel has that id.
    """
    db_channel = await get_channel(db, channel_id)
    if not db_channel:
        return None
    for key, value in channel_data.items():
        if value is not None:
            setattr(db_channel, key, value)
    await db.commit()
    await db.refresh(db_channel)
    return db_channel


async def delete_channel(db: AsyncSession, channel_id: int) -> bool:
    """Delete channel *channel_id*; return ``True`` if it existed, else ``False``."""
    db_channel = await get_channel(db, channel_id)
    if not db_channel:
        return False
    await db.delete(db_channel)
    await db.commit()
    return True


# Notification CRUD
async def create_notification(
    db: AsyncSession, channel_id: int, notification_data: Dict[str, Any]
) -> Notification:
    """Create a notification bound to *channel_id*, commit, and return it refreshed.

    *notification_data* must not itself contain ``channel_id``; Python
    rejects the duplicate keyword argument with ``TypeError``.
    """
    db_notification = Notification(channel_id=channel_id, **notification_data)
    db.add(db_notification)
    await db.commit()
    await db.refresh(db_notification)
    return db_notification


async def update_notification_status(
    db: AsyncSession,
    notification_id: int,
    status: str,
    error_msg: Optional[str] = None,
    sent_at: Optional[Any] = None,
) -> Optional[Notification]:
    """Set the status fields of notification *notification_id*.

    ``status``, ``error_msg`` and ``sent_at`` are all overwritten, so omitting
    *error_msg* or *sent_at* clears any previous value.
    NOTE(review): *sent_at* is presumably a ``datetime``; its type is not
    visible here — confirm against ``Notification.sent_at``.

    Returns:
        The refreshed notification, or ``None`` if no row has that id.
    """
    result = await db.execute(
        select(Notification).where(Notification.id == notification_id)
    )
    notification = result.scalar_one_or_none()
    if notification:
        notification.status = status
        notification.error_msg = error_msg
        notification.sent_at = sent_at
        await db.commit()
        await db.refresh(notification)
    return notification


async def get_notifications(
    db: AsyncSession,
    skip: int = 0,
    limit: int = 100,
    channel_id: Optional[int] = None,
    status: Optional[str] = None,
) -> List[Notification]:
    """Return one page of notifications, newest ``created_at`` first.

    Args:
        db: Active async session.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        channel_id: If not ``None``, only notifications for this channel.
        status: If truthy, only notifications with this status; an empty
            string is treated like ``None`` (no status filter).
    """
    query = select(Notification)
    # `is not None` so that an id of 0 still filters instead of being ignored.
    if channel_id is not None:
        query = query.where(Notification.channel_id == channel_id)
    if status:
        query = query.where(Notification.status == status)
    query = query.order_by(Notification.created_at.desc())
    query = query.offset(skip).limit(limit)
    result = await db.execute(query)
    return list(result.scalars().all())


async def get_notification_stats(db: AsyncSession) -> Dict[str, int]:
    """Return notification counts per status plus an overall ``'total'``.

    The total is the sum of the per-status counts. That equals
    ``COUNT(id)`` over the table, so no second query is needed. A status
    value literally named ``'total'`` would be overwritten by the overall
    count.
    """
    result = await db.execute(
        select(Notification.status, func.count(Notification.id))
        .group_by(Notification.status)
    )
    stats = {status: count for status, count in result.all()}
    stats['total'] = sum(stats.values())
    return stats