131 lines
4.1 KiB
Python

from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import selectinload
from backend.models import Channel, Notification
async def get_channel(db: AsyncSession, channel_id: int) -> Optional[Channel]:
    """Look up one channel by its ``id`` column.

    Returns the matching ``Channel``, or ``None`` when no row has that id.
    """
    by_id = select(Channel).where(Channel.id == channel_id)
    rows = await db.execute(by_id)
    return rows.scalar_one_or_none()
async def get_channel_by_name(db: AsyncSession, name: str) -> Optional[Channel]:
    """Look up one channel by its ``name`` column.

    Returns the matching ``Channel`` or ``None``. ``scalar_one_or_none``
    raises ``MultipleResultsFound`` if several channels share the name.
    """
    by_name = select(Channel).where(Channel.name == name)
    rows = await db.execute(by_name)
    return rows.scalar_one_or_none()
async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[Channel]:
    """Return one page of channels using offset/limit pagination.

    Args:
        db: async session used to run the query.
        skip: number of rows to skip (SQL OFFSET).
        limit: maximum number of rows to return (SQL LIMIT).

    No ORDER BY is applied, so row order (and therefore page contents) is
    whatever the database chooses.
    """
    page = select(Channel)
    page = page.offset(skip)
    page = page.limit(limit)
    rows = await db.execute(page)
    return rows.scalars().all()
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
    """Return every active channel that carries at least one of ``tags``.

    A channel matches when ``channel.is_active`` is truthy and any element of
    ``tags`` appears in ``channel.tags`` (a missing/``None`` tag list counts as
    empty). Matching is done in Python rather than SQL; ``Channel.tags`` is
    presumably a JSON list column — confirm against the model.

    Args:
        db: async session used to load channels.
        tags: tags to match; an empty list matches nothing.

    Returns:
        The matching channels, in the order the database returned them.
    """
    if not tags:
        # No tag can match an empty list; avoid the database round trip.
        return []

    # Load every channel. The previous version went through
    # get_channels(db, limit=1000), which silently ignored any channel past
    # the first 1000 rows.
    rows = await db.execute(select(Channel))

    matches = []
    for channel in rows.scalars():
        if not channel.is_active:
            continue
        channel_tags = channel.tags or []
        if any(tag in channel_tags for tag in tags):
            matches.append(channel)
    return matches
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
    """Insert a new channel and return it reloaded from the database.

    ``channel_data`` is unpacked as keyword arguments to ``Channel``. The
    session is committed, then the instance is refreshed so database-generated
    column values are populated on the returned object.
    """
    fresh = Channel(**channel_data)
    db.add(fresh)
    await db.commit()
    # Reload after commit so server-side defaults/generated ids are visible.
    await db.refresh(fresh)
    return fresh
async def update_channel(
    db: AsyncSession,
    channel_id: int,
    channel_data: Dict[str, Any]
) -> Optional[Channel]:
    """Apply a partial update to the channel with id ``channel_id``.

    Each key of ``channel_data`` is set as an attribute on the channel, except
    entries whose value is ``None``: those are skipped, so ``None`` means
    "leave unchanged" and a field cannot be cleared to ``None`` here.

    Returns:
        The committed and refreshed channel, or ``None`` if no channel has
        that id.
    """
    target = await get_channel(db, channel_id)
    if not target:
        return None

    changes = {
        field: new_value
        for field, new_value in channel_data.items()
        if new_value is not None
    }
    for field, new_value in changes.items():
        setattr(target, field, new_value)

    await db.commit()
    await db.refresh(target)
    return target
async def delete_channel(db: AsyncSession, channel_id: int) -> bool:
    """Delete the channel with id ``channel_id`` and commit.

    Returns ``True`` if the channel existed and was deleted, ``False`` if no
    channel has that id. NOTE(review): what happens to this channel's
    notifications depends on the model's FK/cascade settings — not visible here.
    """
    doomed = await get_channel(db, channel_id)
    if doomed:
        await db.delete(doomed)
        await db.commit()
        return True
    return False
# Notification CRUD
async def create_notification(
    db: AsyncSession,
    channel_id: int,
    notification_data: Dict[str, Any]
) -> Notification:
    """Insert a notification for ``channel_id`` and return it reloaded.

    ``notification_data`` supplies the remaining ``Notification`` constructor
    keywords; it must not itself contain ``channel_id`` (that would be a
    duplicate keyword argument). The session is committed and the instance
    refreshed before it is returned.
    """
    record = Notification(channel_id=channel_id, **notification_data)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
async def update_notification_status(
    db: AsyncSession,
    notification_id: int,
    status: str,
    error_msg: Optional[str] = None,
    sent_at=None
) -> Optional[Notification]:
    """Set the status fields of one notification and commit.

    ``status``, ``error_msg`` and ``sent_at`` are all written unconditionally.
    NOTE(review): omitting ``error_msg``/``sent_at`` therefore overwrites any
    stored values with ``None`` — confirm that callers always pass them.

    Returns:
        The refreshed notification, or ``None`` if no notification has
        ``notification_id``.
    """
    lookup = await db.execute(
        select(Notification).where(Notification.id == notification_id)
    )
    record = lookup.scalar_one_or_none()
    if not record:
        return record

    record.status = status
    record.error_msg = error_msg
    record.sent_at = sent_at
    await db.commit()
    await db.refresh(record)
    return record
async def get_notifications(
    db: AsyncSession,
    skip: int = 0,
    limit: int = 100,
    channel_id: Optional[int] = None,
    status: Optional[str] = None
) -> List[Notification]:
    """Return one page of notifications, newest first, optionally filtered.

    Args:
        db: async session used to run the query.
        skip: number of rows to skip (SQL OFFSET).
        limit: maximum number of rows to return (SQL LIMIT).
        channel_id: if not ``None``, only notifications for this channel.
        status: if non-empty, only notifications with this status; ``None``
            or ``""`` means no status filter.

    Returns:
        Notifications ordered by ``created_at`` descending, ties broken by
        ``id`` descending.
    """
    query = select(Notification)
    # Compare against None explicitly: a truthiness test silently dropped the
    # filter when channel_id == 0.
    if channel_id is not None:
        query = query.where(Notification.channel_id == channel_id)
    # Truthiness kept on purpose: an empty status string is treated as
    # "no status filter", as before.
    if status:
        query = query.where(Notification.status == status)
    # Secondary sort on id: rows sharing a created_at would otherwise come back
    # in an arbitrary order, so offset pages could repeat or skip rows.
    query = query.order_by(Notification.created_at.desc(), Notification.id.desc())
    query = query.offset(skip).limit(limit)
    result = await db.execute(query)
    return result.scalars().all()
async def get_notification_stats(db: AsyncSession) -> Dict[str, int]:
    """Count notifications per status, plus an overall ``'total'`` entry.

    Returns:
        A dict mapping each distinct ``Notification.status`` value to its row
        count, with key ``'total'`` holding the count across all statuses
        (``0`` for an empty table). If a status is literally ``'total'``, its
        entry is overwritten by the overall count.
    """
    from sqlalchemy import func

    result = await db.execute(
        select(Notification.status, func.count(Notification.id))
        .group_by(Notification.status)
    )
    stats = {status: count for status, count in result.fetchall()}
    # Derive the total from the grouped counts instead of issuing a second
    # COUNT query. This saves a round trip. It also guarantees that the total
    # equals the sum of the per-status counts, which two separate statements
    # may not satisfy if rows are inserted in between (isolation-dependent).
    stats['total'] = sum(stats.values())
    return stats