feat: add database models and CRUD operations

This commit is contained in:
OpenClaw Agent 2026-02-07 17:25:47 +00:00
parent d5344e244e
commit 1a5f32dc4e
14 changed files with 286 additions and 0 deletions

1
backend/__init__.py Normal file
View File

@ -0,0 +1 @@
# backend package

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

3
backend/config.py Normal file
View File

@ -0,0 +1,3 @@
import os
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./notify_center.db")

130
backend/crud.py Normal file
View File

@ -0,0 +1,130 @@
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import selectinload
from backend.models import Channel, Notification
async def get_channel(db: AsyncSession, channel_id: int) -> Optional[Channel]:
result = await db.execute(select(Channel).where(Channel.id == channel_id))
return result.scalar_one_or_none()
async def get_channel_by_name(db: AsyncSession, name: str) -> Optional[Channel]:
result = await db.execute(select(Channel).where(Channel.name == name))
return result.scalar_one_or_none()
async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> List[Channel]:
result = await db.execute(
select(Channel).offset(skip).limit(limit)
)
return result.scalars().all()
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
"""获取包含任一标签的所有活动通道"""
conditions = [Channel.tags.contains([tag]) for tag in tags]
result = await db.execute(
select(Channel).where(
and_(
Channel.is_active == True,
or_(*conditions)
)
)
)
return result.scalars().all()
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
db_channel = Channel(**channel_data)
db.add(db_channel)
await db.commit()
await db.refresh(db_channel)
return db_channel
async def update_channel(
db: AsyncSession,
channel_id: int,
channel_data: Dict[str, Any]
) -> Optional[Channel]:
db_channel = await get_channel(db, channel_id)
if not db_channel:
return None
for key, value in channel_data.items():
if value is not None:
setattr(db_channel, key, value)
await db.commit()
await db.refresh(db_channel)
return db_channel
async def delete_channel(db: AsyncSession, channel_id: int) -> bool:
db_channel = await get_channel(db, channel_id)
if not db_channel:
return False
await db.delete(db_channel)
await db.commit()
return True
# Notification CRUD
async def create_notification(
db: AsyncSession,
channel_id: int,
notification_data: Dict[str, Any]
) -> Notification:
db_notification = Notification(channel_id=channel_id, **notification_data)
db.add(db_notification)
await db.commit()
await db.refresh(db_notification)
return db_notification
async def update_notification_status(
db: AsyncSession,
notification_id: int,
status: str,
error_msg: Optional[str] = None,
sent_at = None
) -> Optional[Notification]:
result = await db.execute(
select(Notification).where(Notification.id == notification_id)
)
notification = result.scalar_one_or_none()
if notification:
notification.status = status
notification.error_msg = error_msg
notification.sent_at = sent_at
await db.commit()
await db.refresh(notification)
return notification
async def get_notifications(
db: AsyncSession,
skip: int = 0,
limit: int = 100,
channel_id: Optional[int] = None,
status: Optional[str] = None
) -> List[Notification]:
query = select(Notification)
if channel_id:
query = query.where(Notification.channel_id == channel_id)
if status:
query = query.where(Notification.status == status)
query = query.order_by(Notification.created_at.desc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
return result.scalars().all()
async def get_notification_stats(db: AsyncSession) -> Dict[str, int]:
from sqlalchemy import func
result = await db.execute(
select(Notification.status, func.count(Notification.id))
.group_by(Notification.status)
)
stats = {status: count for status, count in result.fetchall()}
total = await db.execute(select(func.count(Notification.id)))
stats['total'] = total.scalar()
return stats

22
backend/database.py Normal file
View File

@ -0,0 +1,22 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import text
from backend.config import DATABASE_URL
engine = create_async_engine(DATABASE_URL, echo=False)
Base = declarative_base()
AsyncSessionLocal = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def get_db():
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

34
backend/models.py Normal file
View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional, List
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, JSON
from sqlalchemy.orm import relationship
from backend.database import Base
class Channel(Base):
__tablename__ = "channels"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, nullable=False, index=True)
type = Column(String, nullable=False)
config = Column(JSON, default=dict)
tags = Column(JSON, default=list)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
notifications = relationship("Notification", back_populates="channel")
class Notification(Base):
__tablename__ = "notifications"
id = Column(Integer, primary_key=True, index=True)
channel_id = Column(Integer, ForeignKey("channels.id"))
title = Column(String, nullable=True)
body = Column(Text, nullable=False)
priority = Column(String, default="normal")
status = Column(String, default="pending")
error_msg = Column(Text, nullable=True)
sent_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
channel = relationship("Channel", back_populates="notifications")

72
backend/schemas.py Normal file
View File

@ -0,0 +1,72 @@
from datetime import datetime
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, ConfigDict
# Channel Schemas
class ChannelBase(BaseModel):
name: str
type: str
config: Dict[str, Any] = {}
tags: List[str] = []
is_active: bool = True
class ChannelCreate(ChannelBase):
pass
class ChannelUpdate(BaseModel):
name: Optional[str] = None
type: Optional[str] = None
config: Optional[Dict[str, Any]] = None
tags: Optional[List[str]] = None
is_active: Optional[bool] = None
class Channel(ChannelBase):
model_config = ConfigDict(from_attributes=True)
id: int
created_at: datetime
updated_at: datetime
class ChannelList(BaseModel):
channels: List[Channel]
total: int
# Notification Schemas
class NotificationBase(BaseModel):
title: Optional[str] = None
body: str
priority: str = "normal"
class NotificationCreate(NotificationBase):
channel_id: int
class Notification(NotificationBase):
model_config = ConfigDict(from_attributes=True)
id: int
channel_id: int
status: str
error_msg: Optional[str] = None
sent_at: Optional[datetime] = None
created_at: datetime
class NotificationResult(BaseModel):
channel: str
channel_id: int
status: str
notification_id: int
error_msg: Optional[str] = None
class NotifyRequest(BaseModel):
channels: Optional[List[str]] = None
tags: Optional[List[str]] = None
title: Optional[str] = None
body: str
priority: str = "normal"
class NotifyResponse(BaseModel):
success: bool
results: List[NotificationResult]
total: int
sent: int
failed: int

View File

@ -0,0 +1 @@
# tests package

Binary file not shown.

View File

@ -0,0 +1,23 @@
import pytest
import asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from backend.database import engine, Base, get_db, init_db
from backend import models # 导入模型以注册表
@pytest.mark.asyncio
async def test_database_connection():
"""测试数据库连接是否正常"""
async with engine.begin() as conn:
result = await conn.execute(text("SELECT 1"))
assert result.scalar() == 1
@pytest.mark.asyncio
async def test_tables_created():
"""测试表是否正确创建"""
await init_db()
async with engine.begin() as conn:
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
tables = [row[0] for row in result.fetchall()]
assert "channels" in tables
assert "notifications" in tables