feat: add database models and CRUD operations
This commit is contained in:
parent
d5344e244e
commit
1a5f32dc4e
1
backend/__init__.py
Normal file
1
backend/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# backend package
|
||||
BIN
backend/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/config.cpython-310.pyc
Normal file
BIN
backend/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/database.cpython-310.pyc
Normal file
BIN
backend/__pycache__/database.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/models.cpython-310.pyc
Normal file
BIN
backend/__pycache__/models.cpython-310.pyc
Normal file
Binary file not shown.
3
backend/config.py
Normal file
3
backend/config.py
Normal file
@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./notify_center.db")
|
||||
130
backend/crud.py
Normal file
130
backend/crud.py
Normal file
@ -0,0 +1,130 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from backend.models import Channel, Notification
|
||||
|
||||
async def get_channel(db: AsyncSession, channel_id: int) -> Optional[Channel]:
|
||||
result = await db.execute(select(Channel).where(Channel.id == channel_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_channel_by_name(db: AsyncSession, name: str) -> Optional[Channel]:
|
||||
result = await db.execute(select(Channel).where(Channel.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[Channel]:
|
||||
result = await db.execute(
|
||||
select(Channel).offset(skip).limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
|
||||
"""获取包含任一标签的所有活动通道"""
|
||||
conditions = [Channel.tags.contains([tag]) for tag in tags]
|
||||
result = await db.execute(
|
||||
select(Channel).where(
|
||||
and_(
|
||||
Channel.is_active == True,
|
||||
or_(*conditions)
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
|
||||
db_channel = Channel(**channel_data)
|
||||
db.add(db_channel)
|
||||
await db.commit()
|
||||
await db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
async def update_channel(
|
||||
db: AsyncSession,
|
||||
channel_id: int,
|
||||
channel_data: Dict[str, Any]
|
||||
) -> Optional[Channel]:
|
||||
db_channel = await get_channel(db, channel_id)
|
||||
if not db_channel:
|
||||
return None
|
||||
|
||||
for key, value in channel_data.items():
|
||||
if value is not None:
|
||||
setattr(db_channel, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_channel)
|
||||
return db_channel
|
||||
|
||||
async def delete_channel(db: AsyncSession, channel_id: int) -> bool:
|
||||
db_channel = await get_channel(db, channel_id)
|
||||
if not db_channel:
|
||||
return False
|
||||
|
||||
await db.delete(db_channel)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
# Notification CRUD
|
||||
async def create_notification(
|
||||
db: AsyncSession,
|
||||
channel_id: int,
|
||||
notification_data: Dict[str, Any]
|
||||
) -> Notification:
|
||||
db_notification = Notification(channel_id=channel_id, **notification_data)
|
||||
db.add(db_notification)
|
||||
await db.commit()
|
||||
await db.refresh(db_notification)
|
||||
return db_notification
|
||||
|
||||
async def update_notification_status(
|
||||
db: AsyncSession,
|
||||
notification_id: int,
|
||||
status: str,
|
||||
error_msg: Optional[str] = None,
|
||||
sent_at = None
|
||||
) -> Optional[Notification]:
|
||||
result = await db.execute(
|
||||
select(Notification).where(Notification.id == notification_id)
|
||||
)
|
||||
notification = result.scalar_one_or_none()
|
||||
if notification:
|
||||
notification.status = status
|
||||
notification.error_msg = error_msg
|
||||
notification.sent_at = sent_at
|
||||
await db.commit()
|
||||
await db.refresh(notification)
|
||||
return notification
|
||||
|
||||
async def get_notifications(
|
||||
db: AsyncSession,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
channel_id: Optional[int] = None,
|
||||
status: Optional[str] = None
|
||||
) -> List[Notification]:
|
||||
query = select(Notification)
|
||||
|
||||
if channel_id:
|
||||
query = query.where(Notification.channel_id == channel_id)
|
||||
if status:
|
||||
query = query.where(Notification.status == status)
|
||||
|
||||
query = query.order_by(Notification.created_at.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_notification_stats(db: AsyncSession) -> Dict[str, int]:
|
||||
from sqlalchemy import func
|
||||
|
||||
result = await db.execute(
|
||||
select(Notification.status, func.count(Notification.id))
|
||||
.group_by(Notification.status)
|
||||
)
|
||||
stats = {status: count for status, count in result.fetchall()}
|
||||
|
||||
total = await db.execute(select(func.count(Notification.id)))
|
||||
stats['total'] = total.scalar()
|
||||
|
||||
return stats
|
||||
22
backend/database.py
Normal file
22
backend/database.py
Normal file
@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy import text
|
||||
from backend.config import DATABASE_URL
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
Base = declarative_base()
|
||||
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def get_db():
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def init_db():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
34
backend/models.py
Normal file
34
backend/models.py
Normal file
@ -0,0 +1,34 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from backend.database import Base
|
||||
|
||||
class Channel(Base):
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
type = Column(String, nullable=False)
|
||||
config = Column(JSON, default=dict)
|
||||
tags = Column(JSON, default=list)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
notifications = relationship("Notification", back_populates="channel")
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
channel_id = Column(Integer, ForeignKey("channels.id"))
|
||||
title = Column(String, nullable=True)
|
||||
body = Column(Text, nullable=False)
|
||||
priority = Column(String, default="normal")
|
||||
status = Column(String, default="pending")
|
||||
error_msg = Column(Text, nullable=True)
|
||||
sent_at = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
channel = relationship("Channel", back_populates="notifications")
|
||||
72
backend/schemas.py
Normal file
72
backend/schemas.py
Normal file
@ -0,0 +1,72 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
# Channel Schemas
|
||||
class ChannelBase(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
config: Dict[str, Any] = {}
|
||||
tags: List[str] = []
|
||||
is_active: bool = True
|
||||
|
||||
class ChannelCreate(ChannelBase):
|
||||
pass
|
||||
|
||||
class ChannelUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class Channel(ChannelBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class ChannelList(BaseModel):
|
||||
channels: List[Channel]
|
||||
total: int
|
||||
|
||||
# Notification Schemas
|
||||
class NotificationBase(BaseModel):
|
||||
title: Optional[str] = None
|
||||
body: str
|
||||
priority: str = "normal"
|
||||
|
||||
class NotificationCreate(NotificationBase):
|
||||
channel_id: int
|
||||
|
||||
class Notification(NotificationBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
channel_id: int
|
||||
status: str
|
||||
error_msg: Optional[str] = None
|
||||
sent_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
class NotificationResult(BaseModel):
|
||||
channel: str
|
||||
channel_id: int
|
||||
status: str
|
||||
notification_id: int
|
||||
error_msg: Optional[str] = None
|
||||
|
||||
class NotifyRequest(BaseModel):
|
||||
channels: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
title: Optional[str] = None
|
||||
body: str
|
||||
priority: str = "normal"
|
||||
|
||||
class NotifyResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[NotificationResult]
|
||||
total: int
|
||||
sent: int
|
||||
failed: int
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# tests package
|
||||
BIN
backend/tests/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
23
backend/tests/test_database.py
Normal file
23
backend/tests/test_database.py
Normal file
@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from backend.database import engine, Base, get_db, init_db
|
||||
from backend import models # 导入模型以注册表
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_connection():
|
||||
"""测试数据库连接是否正常"""
|
||||
async with engine.begin() as conn:
|
||||
result = await conn.execute(text("SELECT 1"))
|
||||
assert result.scalar() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tables_created():
|
||||
"""测试表是否正确创建"""
|
||||
await init_db()
|
||||
async with engine.begin() as conn:
|
||||
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||
tables = [row[0] for row in result.fetchall()]
|
||||
assert "channels" in tables
|
||||
assert "notifications" in tables
|
||||
Loading…
x
Reference in New Issue
Block a user