feat: add notification service with apprise integration and tests

This commit is contained in:
OpenClaw Agent 2026-02-07 17:30:04 +00:00
parent 1a5f32dc4e
commit 53bc48fee6
6 changed files with 305 additions and 10 deletions

Binary file not shown.

Binary file not shown.

View File

@ -20,16 +20,16 @@ async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> Lis
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]: async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
"""获取包含任一标签的所有活动通道""" """获取包含任一标签的所有活动通道"""
conditions = [Channel.tags.contains([tag]) for tag in tags] # SQLite JSON 查询:使用 json_extract 或 like 匹配
result = await db.execute( all_channels = await get_channels(db, limit=1000)
select(Channel).where( result = []
and_( for channel in all_channels:
Channel.is_active == True, if not channel.is_active:
or_(*conditions) continue
) channel_tags = channel.tags or []
) if any(tag in channel_tags for tag in tags):
) result.append(channel)
return result.scalars().all() return result
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel: async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
db_channel = Channel(**channel_data) db_channel = Channel(**channel_data)

189
backend/notify_service.py Normal file
View File

@ -0,0 +1,189 @@
from typing import List, Optional, Dict, Any
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
import apprise
from backend import crud
from backend.models import Channel
class NotifyService:
def __init__(self):
self.apobj = apprise.Apprise()
def _build_apprise_url(self, channel: Channel) -> Optional[str]:
"""根据通道类型构建 apprise URL"""
config = channel.config or {}
if channel.type == "discord":
webhook_url = config.get("webhook_url", "")
if webhook_url:
return f"{webhook_url}"
elif channel.type == "telegram":
bot_token = config.get("bot_token", "")
chat_id = config.get("chat_id", "")
if bot_token and chat_id:
return f"tgram://{bot_token}/{chat_id}"
elif channel.type == "email":
smtp_host = config.get("smtp_host", "")
smtp_port = config.get("smtp_port", 587)
username = config.get("username", "")
password = config.get("password", "")
to_email = config.get("to_email", "")
if username and password and to_email:
return f"mailtos://{username}:{password}@{smtp_host}:{smtp_port}?to={to_email}"
elif channel.type == "slack":
webhook_url = config.get("webhook_url", "")
if webhook_url:
return f"{webhook_url}"
elif channel.type == "webhook":
url = config.get("url", "")
if url:
return f"json://{url.replace('https://', '').replace('http://', '')}"
# 支持 apprise 原生 URL 格式
elif channel.type == "apprise":
return config.get("url", "")
return None
async def send_notification(
self,
db: AsyncSession,
channel_id: int,
title: Optional[str],
body: str,
priority: str = "normal"
) -> Dict[str, Any]:
"""发送单条通知到指定通道"""
channel = await crud.get_channel(db, channel_id)
if not channel:
return {
"channel_id": channel_id,
"status": "failed",
"error_msg": "Channel not found"
}
if not channel.is_active:
return {
"channel": channel.name,
"channel_id": channel_id,
"status": "skipped",
"error_msg": "Channel is inactive"
}
# 创建通知记录
notification = await crud.create_notification(
db, channel_id, {
"title": title,
"body": body,
"priority": priority,
"status": "pending"
}
)
# 构建 apprise URL
apprise_url = self._build_apprise_url(channel)
if not apprise_url:
error_msg = f"Invalid configuration for channel type: {channel.type}"
await crud.update_notification_status(
db, notification.id, "failed", error_msg
)
return {
"channel": channel.name,
"channel_id": channel_id,
"status": "failed",
"notification_id": notification.id,
"error_msg": error_msg
}
try:
# 发送通知
apobj = apprise.Apprise()
apobj.add(apprise_url)
# 构建消息
message = body
if title:
message = f"**{title}**\n\n{body}"
# 发送
result = apobj.notify(body=message)
if result:
await crud.update_notification_status(
db, notification.id, "sent", sent_at=datetime.utcnow()
)
return {
"channel": channel.name,
"channel_id": channel_id,
"status": "sent",
"notification_id": notification.id
}
else:
error_msg = "Failed to send notification"
await crud.update_notification_status(
db, notification.id, "failed", error_msg
)
return {
"channel": channel.name,
"channel_id": channel_id,
"status": "failed",
"notification_id": notification.id,
"error_msg": error_msg
}
except Exception as e:
error_msg = str(e)
await crud.update_notification_status(
db, notification.id, "failed", error_msg
)
return {
"channel": channel.name,
"channel_id": channel_id,
"status": "failed",
"notification_id": notification.id,
"error_msg": error_msg
}
async def send_to_channels(
self,
db: AsyncSession,
channels: Optional[List[str]],
tags: Optional[List[str]],
title: Optional[str],
body: str,
priority: str = "normal"
) -> List[Dict[str, Any]]:
"""批量发送通知到多个通道或按标签发送"""
results = []
target_channels = []
# 按名称获取通道
if channels:
for channel_name in channels:
channel = await crud.get_channel_by_name(db, channel_name)
if channel:
target_channels.append(channel)
# 按标签获取通道
if tags:
tagged_channels = await crud.get_channels_by_tags(db, tags)
# 合并去重
existing_ids = {c.id for c in target_channels}
for channel in tagged_channels:
if channel.id not in existing_ids:
target_channels.append(channel)
existing_ids.add(channel.id)
# 发送通知
for channel in target_channels:
result = await self.send_notification(
db, channel.id, title, body, priority
)
results.append(result)
return results

View File

@ -0,0 +1,106 @@
import pytest
import pytest_asyncio
import uuid
from unittest.mock import Mock, patch
from sqlalchemy.ext.asyncio import AsyncSession
from backend.database import AsyncSessionLocal, init_db
from backend import crud
from backend.notify_service import NotifyService
@pytest_asyncio.fixture
async def db():
await init_db()
async with AsyncSessionLocal() as session:
yield session
await session.rollback()
@pytest.mark.asyncio
async def test_send_to_discord(db):
"""测试发送到 Discord"""
unique_name = f"测试Discord_{uuid.uuid4().hex[:8]}"
# 创建测试通道
channel = await crud.create_channel(db, {
"name": unique_name,
"type": "discord",
"config": {"webhook_url": "https://discord.com/api/webhooks/test"}
})
service = NotifyService()
with patch('apprise.Apprise.notify') as mock_notify:
mock_notify.return_value = True
result = await service.send_notification(
db, channel.id, "测试标题", "测试内容", "normal"
)
assert result['status'] == 'sent'
assert result['channel_id'] == channel.id
mock_notify.assert_called_once()
@pytest.mark.asyncio
async def test_send_to_multiple_channels(db):
"""测试批量发送到多个通道"""
# 创建多个通道
name_a = f"通道A_{uuid.uuid4().hex[:8]}"
name_b = f"通道B_{uuid.uuid4().hex[:8]}"
await crud.create_channel(db, {
"name": name_a,
"type": "discord",
"config": {"webhook_url": "https://discord.com/webhook"},
"tags": ["alerts"]
})
await crud.create_channel(db, {
"name": name_b,
"type": "telegram",
"config": {"bot_token": "123456:ABC", "chat_id": "12345"},
"tags": ["alerts"]
})
service = NotifyService()
with patch('apprise.Apprise.notify') as mock_notify:
mock_notify.return_value = True
results = await service.send_to_channels(
db, [name_a, name_b], None, "测试", "内容", "normal"
)
assert len(results) == 2
assert all(r['status'] == 'sent' for r in results)
@pytest.mark.asyncio
async def test_send_by_tags(db):
"""测试按标签发送"""
unique_tag = f"alerts_{uuid.uuid4().hex[:8]}"
name_prod = f"生产告警_{uuid.uuid4().hex[:8]}"
name_dev = f"开发告警_{uuid.uuid4().hex[:8]}"
name_info = f"普通通知_{uuid.uuid4().hex[:8]}"
await crud.create_channel(db, {
"name": name_prod,
"type": "discord",
"config": {"webhook_url": "https://discord.com/webhook"},
"tags": [unique_tag, "production"]
})
await crud.create_channel(db, {
"name": name_dev,
"type": "telegram",
"config": {"bot_token": "123456:ABC", "chat_id": "12345"},
"tags": [unique_tag, "dev"]
})
await crud.create_channel(db, {
"name": name_info,
"type": "email",
"config": {"username": "test@test.com", "password": "pass", "to_email": "to@test.com", "smtp_host": "smtp.test.com"},
"tags": ["info"]
})
service = NotifyService()
with patch('apprise.Apprise.notify') as mock_notify:
mock_notify.return_value = True
results = await service.send_to_channels(
db, None, [unique_tag], "告警", "服务器异常", "high"
)
assert len(results) == 2 # 只有带唯一标签的通道
assert mock_notify.call_count == 2