feat: add notification service with apprise integration and tests
This commit is contained in:
parent
1a5f32dc4e
commit
53bc48fee6
BIN
backend/__pycache__/crud.cpython-310.pyc
Normal file
BIN
backend/__pycache__/crud.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend/__pycache__/notify_service.cpython-310.pyc
Normal file
BIN
backend/__pycache__/notify_service.cpython-310.pyc
Normal file
Binary file not shown.
@ -20,16 +20,16 @@ async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> Lis
|
||||
|
||||
async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]:
|
||||
"""获取包含任一标签的所有活动通道"""
|
||||
conditions = [Channel.tags.contains([tag]) for tag in tags]
|
||||
result = await db.execute(
|
||||
select(Channel).where(
|
||||
and_(
|
||||
Channel.is_active == True,
|
||||
or_(*conditions)
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
# SQLite JSON 查询:使用 json_extract 或 like 匹配
|
||||
all_channels = await get_channels(db, limit=1000)
|
||||
result = []
|
||||
for channel in all_channels:
|
||||
if not channel.is_active:
|
||||
continue
|
||||
channel_tags = channel.tags or []
|
||||
if any(tag in channel_tags for tag in tags):
|
||||
result.append(channel)
|
||||
return result
|
||||
|
||||
async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel:
|
||||
db_channel = Channel(**channel_data)
|
||||
|
||||
189
backend/notify_service.py
Normal file
189
backend/notify_service.py
Normal file
@ -0,0 +1,189 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import apprise
|
||||
from backend import crud
|
||||
from backend.models import Channel
|
||||
|
||||
class NotifyService:
|
||||
def __init__(self):
|
||||
self.apobj = apprise.Apprise()
|
||||
|
||||
def _build_apprise_url(self, channel: Channel) -> Optional[str]:
|
||||
"""根据通道类型构建 apprise URL"""
|
||||
config = channel.config or {}
|
||||
|
||||
if channel.type == "discord":
|
||||
webhook_url = config.get("webhook_url", "")
|
||||
if webhook_url:
|
||||
return f"{webhook_url}"
|
||||
|
||||
elif channel.type == "telegram":
|
||||
bot_token = config.get("bot_token", "")
|
||||
chat_id = config.get("chat_id", "")
|
||||
if bot_token and chat_id:
|
||||
return f"tgram://{bot_token}/{chat_id}"
|
||||
|
||||
elif channel.type == "email":
|
||||
smtp_host = config.get("smtp_host", "")
|
||||
smtp_port = config.get("smtp_port", 587)
|
||||
username = config.get("username", "")
|
||||
password = config.get("password", "")
|
||||
to_email = config.get("to_email", "")
|
||||
|
||||
if username and password and to_email:
|
||||
return f"mailtos://{username}:{password}@{smtp_host}:{smtp_port}?to={to_email}"
|
||||
|
||||
elif channel.type == "slack":
|
||||
webhook_url = config.get("webhook_url", "")
|
||||
if webhook_url:
|
||||
return f"{webhook_url}"
|
||||
|
||||
elif channel.type == "webhook":
|
||||
url = config.get("url", "")
|
||||
if url:
|
||||
return f"json://{url.replace('https://', '').replace('http://', '')}"
|
||||
|
||||
# 支持 apprise 原生 URL 格式
|
||||
elif channel.type == "apprise":
|
||||
return config.get("url", "")
|
||||
|
||||
return None
|
||||
|
||||
async def send_notification(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
channel_id: int,
|
||||
title: Optional[str],
|
||||
body: str,
|
||||
priority: str = "normal"
|
||||
) -> Dict[str, Any]:
|
||||
"""发送单条通知到指定通道"""
|
||||
channel = await crud.get_channel(db, channel_id)
|
||||
if not channel:
|
||||
return {
|
||||
"channel_id": channel_id,
|
||||
"status": "failed",
|
||||
"error_msg": "Channel not found"
|
||||
}
|
||||
|
||||
if not channel.is_active:
|
||||
return {
|
||||
"channel": channel.name,
|
||||
"channel_id": channel_id,
|
||||
"status": "skipped",
|
||||
"error_msg": "Channel is inactive"
|
||||
}
|
||||
|
||||
# 创建通知记录
|
||||
notification = await crud.create_notification(
|
||||
db, channel_id, {
|
||||
"title": title,
|
||||
"body": body,
|
||||
"priority": priority,
|
||||
"status": "pending"
|
||||
}
|
||||
)
|
||||
|
||||
# 构建 apprise URL
|
||||
apprise_url = self._build_apprise_url(channel)
|
||||
if not apprise_url:
|
||||
error_msg = f"Invalid configuration for channel type: {channel.type}"
|
||||
await crud.update_notification_status(
|
||||
db, notification.id, "failed", error_msg
|
||||
)
|
||||
return {
|
||||
"channel": channel.name,
|
||||
"channel_id": channel_id,
|
||||
"status": "failed",
|
||||
"notification_id": notification.id,
|
||||
"error_msg": error_msg
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送通知
|
||||
apobj = apprise.Apprise()
|
||||
apobj.add(apprise_url)
|
||||
|
||||
# 构建消息
|
||||
message = body
|
||||
if title:
|
||||
message = f"**{title}**\n\n{body}"
|
||||
|
||||
# 发送
|
||||
result = apobj.notify(body=message)
|
||||
|
||||
if result:
|
||||
await crud.update_notification_status(
|
||||
db, notification.id, "sent", sent_at=datetime.utcnow()
|
||||
)
|
||||
return {
|
||||
"channel": channel.name,
|
||||
"channel_id": channel_id,
|
||||
"status": "sent",
|
||||
"notification_id": notification.id
|
||||
}
|
||||
else:
|
||||
error_msg = "Failed to send notification"
|
||||
await crud.update_notification_status(
|
||||
db, notification.id, "failed", error_msg
|
||||
)
|
||||
return {
|
||||
"channel": channel.name,
|
||||
"channel_id": channel_id,
|
||||
"status": "failed",
|
||||
"notification_id": notification.id,
|
||||
"error_msg": error_msg
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
await crud.update_notification_status(
|
||||
db, notification.id, "failed", error_msg
|
||||
)
|
||||
return {
|
||||
"channel": channel.name,
|
||||
"channel_id": channel_id,
|
||||
"status": "failed",
|
||||
"notification_id": notification.id,
|
||||
"error_msg": error_msg
|
||||
}
|
||||
|
||||
async def send_to_channels(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
channels: Optional[List[str]],
|
||||
tags: Optional[List[str]],
|
||||
title: Optional[str],
|
||||
body: str,
|
||||
priority: str = "normal"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""批量发送通知到多个通道或按标签发送"""
|
||||
results = []
|
||||
target_channels = []
|
||||
|
||||
# 按名称获取通道
|
||||
if channels:
|
||||
for channel_name in channels:
|
||||
channel = await crud.get_channel_by_name(db, channel_name)
|
||||
if channel:
|
||||
target_channels.append(channel)
|
||||
|
||||
# 按标签获取通道
|
||||
if tags:
|
||||
tagged_channels = await crud.get_channels_by_tags(db, tags)
|
||||
# 合并去重
|
||||
existing_ids = {c.id for c in target_channels}
|
||||
for channel in tagged_channels:
|
||||
if channel.id not in existing_ids:
|
||||
target_channels.append(channel)
|
||||
existing_ids.add(channel.id)
|
||||
|
||||
# 发送通知
|
||||
for channel in target_channels:
|
||||
result = await self.send_notification(
|
||||
db, channel.id, title, body, priority
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
Binary file not shown.
106
backend/tests/test_notify_service.py
Normal file
106
backend/tests/test_notify_service.py
Normal file
@ -0,0 +1,106 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from backend.database import AsyncSessionLocal, init_db
|
||||
from backend import crud
|
||||
from backend.notify_service import NotifyService
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db():
|
||||
await init_db()
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_discord(db):
|
||||
"""测试发送到 Discord"""
|
||||
unique_name = f"测试Discord_{uuid.uuid4().hex[:8]}"
|
||||
# 创建测试通道
|
||||
channel = await crud.create_channel(db, {
|
||||
"name": unique_name,
|
||||
"type": "discord",
|
||||
"config": {"webhook_url": "https://discord.com/api/webhooks/test"}
|
||||
})
|
||||
|
||||
service = NotifyService()
|
||||
|
||||
with patch('apprise.Apprise.notify') as mock_notify:
|
||||
mock_notify.return_value = True
|
||||
result = await service.send_notification(
|
||||
db, channel.id, "测试标题", "测试内容", "normal"
|
||||
)
|
||||
|
||||
assert result['status'] == 'sent'
|
||||
assert result['channel_id'] == channel.id
|
||||
mock_notify.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_multiple_channels(db):
|
||||
"""测试批量发送到多个通道"""
|
||||
# 创建多个通道
|
||||
name_a = f"通道A_{uuid.uuid4().hex[:8]}"
|
||||
name_b = f"通道B_{uuid.uuid4().hex[:8]}"
|
||||
await crud.create_channel(db, {
|
||||
"name": name_a,
|
||||
"type": "discord",
|
||||
"config": {"webhook_url": "https://discord.com/webhook"},
|
||||
"tags": ["alerts"]
|
||||
})
|
||||
await crud.create_channel(db, {
|
||||
"name": name_b,
|
||||
"type": "telegram",
|
||||
"config": {"bot_token": "123456:ABC", "chat_id": "12345"},
|
||||
"tags": ["alerts"]
|
||||
})
|
||||
|
||||
service = NotifyService()
|
||||
|
||||
with patch('apprise.Apprise.notify') as mock_notify:
|
||||
mock_notify.return_value = True
|
||||
results = await service.send_to_channels(
|
||||
db, [name_a, name_b], None, "测试", "内容", "normal"
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(r['status'] == 'sent' for r in results)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_by_tags(db):
|
||||
"""测试按标签发送"""
|
||||
unique_tag = f"alerts_{uuid.uuid4().hex[:8]}"
|
||||
name_prod = f"生产告警_{uuid.uuid4().hex[:8]}"
|
||||
name_dev = f"开发告警_{uuid.uuid4().hex[:8]}"
|
||||
name_info = f"普通通知_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
await crud.create_channel(db, {
|
||||
"name": name_prod,
|
||||
"type": "discord",
|
||||
"config": {"webhook_url": "https://discord.com/webhook"},
|
||||
"tags": [unique_tag, "production"]
|
||||
})
|
||||
await crud.create_channel(db, {
|
||||
"name": name_dev,
|
||||
"type": "telegram",
|
||||
"config": {"bot_token": "123456:ABC", "chat_id": "12345"},
|
||||
"tags": [unique_tag, "dev"]
|
||||
})
|
||||
await crud.create_channel(db, {
|
||||
"name": name_info,
|
||||
"type": "email",
|
||||
"config": {"username": "test@test.com", "password": "pass", "to_email": "to@test.com", "smtp_host": "smtp.test.com"},
|
||||
"tags": ["info"]
|
||||
})
|
||||
|
||||
service = NotifyService()
|
||||
|
||||
with patch('apprise.Apprise.notify') as mock_notify:
|
||||
mock_notify.return_value = True
|
||||
results = await service.send_to_channels(
|
||||
db, None, [unique_tag], "告警", "服务器异常", "high"
|
||||
)
|
||||
|
||||
assert len(results) == 2 # 只有带唯一标签的通道
|
||||
assert mock_notify.call_count == 2
|
||||
Loading…
x
Reference in New Issue
Block a user