feat: add database models and CRUD operations
This commit is contained in:
parent
d5344e244e
commit
1a5f32dc4e
1
backend/__init__.py
Normal file
1
backend/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# backend package
|
||||||
BIN
backend/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/config.cpython-310.pyc
Normal file
BIN
backend/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/database.cpython-310.pyc
Normal file
BIN
backend/__pycache__/database.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/models.cpython-310.pyc
Normal file
BIN
backend/__pycache__/models.cpython-310.pyc
Normal file
Binary file not shown.
3
backend/config.py
Normal file
3
backend/config.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./notify_center.db")
|
||||||
130
backend/crud.py
Normal file
130
backend/crud.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, and_, or_
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
from backend.models import Channel, Notification
|
||||||
|
|
||||||
|
async def get_channel(db: AsyncSession, channel_id: int) -> Optional[Channel]:
|
||||||
|
result = await db.execute(select(Channel).where(Channel.id == channel_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_channel_by_name(db: AsyncSession, name: str) -> Optional[Channel]:
|
||||||
|
result = await db.execute(select(Channel).where(Channel.name == name))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[Channel]:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Channel).offset(skip).limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
|
||||||
|
"""获取包含任一标签的所有活动通道"""
|
||||||
|
conditions = [Channel.tags.contains([tag]) for tag in tags]
|
||||||
|
result = await db.execute(
|
||||||
|
select(Channel).where(
|
||||||
|
and_(
|
||||||
|
Channel.is_active == True,
|
||||||
|
or_(*conditions)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
|
||||||
|
db_channel = Channel(**channel_data)
|
||||||
|
db.add(db_channel)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_channel)
|
||||||
|
return db_channel
|
||||||
|
|
||||||
|
async def update_channel(
|
||||||
|
db: AsyncSession,
|
||||||
|
channel_id: int,
|
||||||
|
channel_data: Dict[str, Any]
|
||||||
|
) -> Optional[Channel]:
|
||||||
|
db_channel = await get_channel(db, channel_id)
|
||||||
|
if not db_channel:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in channel_data.items():
|
||||||
|
if value is not None:
|
||||||
|
setattr(db_channel, key, value)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_channel)
|
||||||
|
return db_channel
|
||||||
|
|
||||||
|
async def delete_channel(db: AsyncSession, channel_id: int) -> bool:
|
||||||
|
db_channel = await get_channel(db, channel_id)
|
||||||
|
if not db_channel:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await db.delete(db_channel)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Notification CRUD
|
||||||
|
async def create_notification(
|
||||||
|
db: AsyncSession,
|
||||||
|
channel_id: int,
|
||||||
|
notification_data: Dict[str, Any]
|
||||||
|
) -> Notification:
|
||||||
|
db_notification = Notification(channel_id=channel_id, **notification_data)
|
||||||
|
db.add(db_notification)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_notification)
|
||||||
|
return db_notification
|
||||||
|
|
||||||
|
async def update_notification_status(
|
||||||
|
db: AsyncSession,
|
||||||
|
notification_id: int,
|
||||||
|
status: str,
|
||||||
|
error_msg: Optional[str] = None,
|
||||||
|
sent_at = None
|
||||||
|
) -> Optional[Notification]:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Notification).where(Notification.id == notification_id)
|
||||||
|
)
|
||||||
|
notification = result.scalar_one_or_none()
|
||||||
|
if notification:
|
||||||
|
notification.status = status
|
||||||
|
notification.error_msg = error_msg
|
||||||
|
notification.sent_at = sent_at
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(notification)
|
||||||
|
return notification
|
||||||
|
|
||||||
|
async def get_notifications(
|
||||||
|
db: AsyncSession,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
channel_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None
|
||||||
|
) -> List[Notification]:
|
||||||
|
query = select(Notification)
|
||||||
|
|
||||||
|
if channel_id:
|
||||||
|
query = query.where(Notification.channel_id == channel_id)
|
||||||
|
if status:
|
||||||
|
query = query.where(Notification.status == status)
|
||||||
|
|
||||||
|
query = query.order_by(Notification.created_at.desc())
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_notification_stats(db: AsyncSession) -> Dict[str, int]:
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Notification.status, func.count(Notification.id))
|
||||||
|
.group_by(Notification.status)
|
||||||
|
)
|
||||||
|
stats = {status: count for status, count in result.fetchall()}
|
||||||
|
|
||||||
|
total = await db.execute(select(func.count(Notification.id)))
|
||||||
|
stats['total'] = total.scalar()
|
||||||
|
|
||||||
|
return stats
|
||||||
22
backend/database.py
Normal file
22
backend/database.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||||
|
from sqlalchemy import text
|
||||||
|
from backend.config import DATABASE_URL
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
AsyncSessionLocal = sessionmaker(
|
||||||
|
engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_db():
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def init_db():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
34
backend/models.py
Normal file
34
backend/models.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, JSON
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from backend.database import Base
|
||||||
|
|
||||||
|
class Channel(Base):
|
||||||
|
__tablename__ = "channels"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
name = Column(String, unique=True, nullable=False, index=True)
|
||||||
|
type = Column(String, nullable=False)
|
||||||
|
config = Column(JSON, default=dict)
|
||||||
|
tags = Column(JSON, default=list)
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
notifications = relationship("Notification", back_populates="channel")
|
||||||
|
|
||||||
|
class Notification(Base):
|
||||||
|
__tablename__ = "notifications"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
channel_id = Column(Integer, ForeignKey("channels.id"))
|
||||||
|
title = Column(String, nullable=True)
|
||||||
|
body = Column(Text, nullable=False)
|
||||||
|
priority = Column(String, default="normal")
|
||||||
|
status = Column(String, default="pending")
|
||||||
|
error_msg = Column(Text, nullable=True)
|
||||||
|
sent_at = Column(DateTime, nullable=True)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
channel = relationship("Channel", back_populates="notifications")
|
||||||
72
backend/schemas.py
Normal file
72
backend/schemas.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
# Channel Schemas
|
||||||
|
class ChannelBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
config: Dict[str, Any] = {}
|
||||||
|
tags: List[str] = []
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
class ChannelCreate(ChannelBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ChannelUpdate(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
type: Optional[str] = None
|
||||||
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
is_active: Optional[bool] = None
|
||||||
|
|
||||||
|
class Channel(ChannelBase):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class ChannelList(BaseModel):
|
||||||
|
channels: List[Channel]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
# Notification Schemas
|
||||||
|
class NotificationBase(BaseModel):
|
||||||
|
title: Optional[str] = None
|
||||||
|
body: str
|
||||||
|
priority: str = "normal"
|
||||||
|
|
||||||
|
class NotificationCreate(NotificationBase):
|
||||||
|
channel_id: int
|
||||||
|
|
||||||
|
class Notification(NotificationBase):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
channel_id: int
|
||||||
|
status: str
|
||||||
|
error_msg: Optional[str] = None
|
||||||
|
sent_at: Optional[datetime] = None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
class NotificationResult(BaseModel):
|
||||||
|
channel: str
|
||||||
|
channel_id: int
|
||||||
|
status: str
|
||||||
|
notification_id: int
|
||||||
|
error_msg: Optional[str] = None
|
||||||
|
|
||||||
|
class NotifyRequest(BaseModel):
|
||||||
|
channels: Optional[List[str]] = None
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
title: Optional[str] = None
|
||||||
|
body: str
|
||||||
|
priority: str = "normal"
|
||||||
|
|
||||||
|
class NotifyResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
results: List[NotificationResult]
|
||||||
|
total: int
|
||||||
|
sent: int
|
||||||
|
failed: int
|
||||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# tests package
|
||||||
BIN
backend/tests/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
23
backend/tests/test_database.py
Normal file
23
backend/tests/test_database.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from backend.database import engine, Base, get_db, init_db
|
||||||
|
from backend import models # 导入模型以注册表
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_database_connection():
|
||||||
|
"""测试数据库连接是否正常"""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
result = await conn.execute(text("SELECT 1"))
|
||||||
|
assert result.scalar() == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tables_created():
|
||||||
|
"""测试表是否正确创建"""
|
||||||
|
await init_db()
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||||
|
tables = [row[0] for row in result.fetchall()]
|
||||||
|
assert "channels" in tables
|
||||||
|
assert "notifications" in tables
|
||||||
Loading…
x
Reference in New Issue
Block a user