diff --git a/backend/__pycache__/crud.cpython-310.pyc b/backend/__pycache__/crud.cpython-310.pyc new file mode 100644 index 0000000..260515f Binary files /dev/null and b/backend/__pycache__/crud.cpython-310.pyc differ diff --git a/backend/__pycache__/notify_service.cpython-310.pyc b/backend/__pycache__/notify_service.cpython-310.pyc new file mode 100644 index 0000000..52cb2fb Binary files /dev/null and b/backend/__pycache__/notify_service.cpython-310.pyc differ diff --git a/backend/crud.py b/backend/crud.py index 72184c5..0aed4c1 100644 --- a/backend/crud.py +++ b/backend/crud.py @@ -20,16 +20,16 @@ async def get_channels(db: AsyncSession, skip: int = 0, limit: int = 100) -> Lis async def get_channels_by_tags(db: AsyncSession, tags: List[str]) -> List[Channel]: """获取包含任一标签的所有活动通道""" - conditions = [Channel.tags.contains([tag]) for tag in tags] - result = await db.execute( - select(Channel).where( - and_( - Channel.is_active == True, - or_(*conditions) - ) - ) - ) - return result.scalars().all() + # SQLite JSON 查询:使用 json_extract 或 like 匹配 + all_channels = await get_channels(db, limit=1000) + result = [] + for channel in all_channels: + if not channel.is_active: + continue + channel_tags = channel.tags or [] + if any(tag in channel_tags for tag in tags): + result.append(channel) + return result async def create_channel(db: AsyncSession, channel_data: Dict[str, Any]) -> Channel: db_channel = Channel(**channel_data) diff --git a/backend/notify_service.py b/backend/notify_service.py new file mode 100644 index 0000000..e096ca7 --- /dev/null +++ b/backend/notify_service.py @@ -0,0 +1,189 @@ +from typing import List, Optional, Dict, Any +from datetime import datetime +from sqlalchemy.ext.asyncio import AsyncSession +import apprise +from backend import crud +from backend.models import Channel + +class NotifyService: + def __init__(self): + self.apobj = apprise.Apprise() + + def _build_apprise_url(self, channel: Channel) -> Optional[str]: + """根据通道类型构建 apprise URL""" + config = channel.config or {} + + if channel.type == "discord": + webhook_url = config.get("webhook_url", "") + if webhook_url: + return f"{webhook_url}" + + elif channel.type == "telegram": + bot_token = config.get("bot_token", "") + chat_id = config.get("chat_id", "") + if bot_token and chat_id: + return f"tgram://{bot_token}/{chat_id}" + + elif channel.type == "email": + smtp_host = config.get("smtp_host", "") + smtp_port = config.get("smtp_port", 587) + username = config.get("username", "") + password = config.get("password", "") + to_email = config.get("to_email", "") + + if username and password and to_email: + return f"mailtos://{username}:{password}@{smtp_host}:{smtp_port}?to={to_email}" + + elif channel.type == "slack": + webhook_url = config.get("webhook_url", "") + if webhook_url: + return f"{webhook_url}" + + elif channel.type == "webhook": + url = config.get("url", "") + if url: + return f"json://{url.replace('https://', '').replace('http://', '')}" + + # 支持 apprise 原生 URL 格式 + elif channel.type == "apprise": + return config.get("url", "") + + return None + + async def send_notification( + self, + db: AsyncSession, + channel_id: int, + title: Optional[str], + body: str, + priority: str = "normal" + ) -> Dict[str, Any]: + """发送单条通知到指定通道""" + channel = await crud.get_channel(db, channel_id) + if not channel: + return { + "channel_id": channel_id, + "status": "failed", + "error_msg": "Channel not found" + } + + if not channel.is_active: + return { + "channel": channel.name, + "channel_id": channel_id, + "status": "skipped", + "error_msg": "Channel is inactive" + } + + # 创建通知记录 + notification = await crud.create_notification( + db, channel_id, { + "title": title, + "body": body, + "priority": priority, + "status": "pending" + } + ) + + # 构建 apprise URL + apprise_url = self._build_apprise_url(channel) + if not apprise_url: + error_msg = f"Invalid configuration for channel type: {channel.type}" + await crud.update_notification_status( + db, notification.id, "failed", error_msg + ) + return { + "channel": channel.name, + "channel_id": channel_id, + "status": "failed", + "notification_id": notification.id, + "error_msg": error_msg + } + + try: + # 发送通知 + apobj = apprise.Apprise() + apobj.add(apprise_url) + + # 构建消息 + message = body + if title: + message = f"**{title}**\n\n{body}" + + # 发送 + result = apobj.notify(body=message) + + if result: + await crud.update_notification_status( + db, notification.id, "sent", sent_at=datetime.utcnow() + ) + return { + "channel": channel.name, + "channel_id": channel_id, + "status": "sent", + "notification_id": notification.id + } + else: + error_msg = "Failed to send notification" + await crud.update_notification_status( + db, notification.id, "failed", error_msg + ) + return { + "channel": channel.name, + "channel_id": channel_id, + "status": "failed", + "notification_id": notification.id, + "error_msg": error_msg + } + + except Exception as e: + error_msg = str(e) + await crud.update_notification_status( + db, notification.id, "failed", error_msg + ) + return { + "channel": channel.name, + "channel_id": channel_id, + "status": "failed", + "notification_id": notification.id, + "error_msg": error_msg + } + + async def send_to_channels( + self, + db: AsyncSession, + channels: Optional[List[str]], + tags: Optional[List[str]], + title: Optional[str], + body: str, + priority: str = "normal" + ) -> List[Dict[str, Any]]: + """批量发送通知到多个通道或按标签发送""" + results = [] + target_channels = [] + + # 按名称获取通道 + if channels: + for channel_name in channels: + channel = await crud.get_channel_by_name(db, channel_name) + if channel: + target_channels.append(channel) + + # 按标签获取通道 + if tags: + tagged_channels = await crud.get_channels_by_tags(db, tags) + # 合并去重 + existing_ids = {c.id for c in target_channels} + for channel in tagged_channels: + if channel.id not in existing_ids: + target_channels.append(channel) + existing_ids.add(channel.id) + + # 发送通知 + for channel in target_channels: + result = await self.send_notification( + db, channel.id, title, body, priority + ) + results.append(result) + + return results diff --git a/backend/tests/__pycache__/test_notify_service.cpython-310-pytest-9.0.2.pyc b/backend/tests/__pycache__/test_notify_service.cpython-310-pytest-9.0.2.pyc new file mode 100644 index 0000000..041cdee Binary files /dev/null and b/backend/tests/__pycache__/test_notify_service.cpython-310-pytest-9.0.2.pyc differ diff --git a/backend/tests/test_notify_service.py b/backend/tests/test_notify_service.py new file mode 100644 index 0000000..9f05e9d --- /dev/null +++ b/backend/tests/test_notify_service.py @@ -0,0 +1,106 @@ +import pytest +import pytest_asyncio +import uuid +from unittest.mock import Mock, patch +from sqlalchemy.ext.asyncio import AsyncSession +from backend.database import AsyncSessionLocal, init_db +from backend import crud +from backend.notify_service import NotifyService + +@pytest_asyncio.fixture +async def db(): + await init_db() + async with AsyncSessionLocal() as session: + yield session + await session.rollback() + +@pytest.mark.asyncio +async def test_send_to_discord(db): + """测试发送到 Discord""" + unique_name = f"测试Discord_{uuid.uuid4().hex[:8]}" + # 创建测试通道 + channel = await crud.create_channel(db, { + "name": unique_name, + "type": "discord", + "config": {"webhook_url": "https://discord.com/api/webhooks/test"} + }) + + service = NotifyService() + + with patch('apprise.Apprise.notify') as mock_notify: + mock_notify.return_value = True + result = await service.send_notification( + db, channel.id, "测试标题", "测试内容", "normal" + ) + + assert result['status'] == 'sent' + assert result['channel_id'] == channel.id + mock_notify.assert_called_once() + +@pytest.mark.asyncio +async def test_send_to_multiple_channels(db): + """测试批量发送到多个通道""" + # 创建多个通道 + name_a = f"通道A_{uuid.uuid4().hex[:8]}" + name_b = f"通道B_{uuid.uuid4().hex[:8]}" + await crud.create_channel(db, { + "name": name_a, + "type": "discord", + "config": {"webhook_url": "https://discord.com/webhook"}, + "tags": ["alerts"] + }) + await crud.create_channel(db, { + "name": name_b, + "type": "telegram", + "config": {"bot_token": "123456:ABC", "chat_id": "12345"}, + "tags": ["alerts"] + }) + + service = NotifyService() + + with patch('apprise.Apprise.notify') as mock_notify: + mock_notify.return_value = True + results = await service.send_to_channels( + db, [name_a, name_b], None, "测试", "内容", "normal" + ) + + assert len(results) == 2 + assert all(r['status'] == 'sent' for r in results) + +@pytest.mark.asyncio +async def test_send_by_tags(db): + """测试按标签发送""" + unique_tag = f"alerts_{uuid.uuid4().hex[:8]}" + name_prod = f"生产告警_{uuid.uuid4().hex[:8]}" + name_dev = f"开发告警_{uuid.uuid4().hex[:8]}" + name_info = f"普通通知_{uuid.uuid4().hex[:8]}" + + await crud.create_channel(db, { + "name": name_prod, + "type": "discord", + "config": {"webhook_url": "https://discord.com/webhook"}, + "tags": [unique_tag, "production"] + }) + await crud.create_channel(db, { + "name": name_dev, + "type": "telegram", + "config": {"bot_token": "123456:ABC", "chat_id": "12345"}, + "tags": [unique_tag, "dev"] + }) + await crud.create_channel(db, { + "name": name_info, + "type": "email", + "config": {"username": "test@test.com", "password": "pass", "to_email": "to@test.com", "smtp_host": "smtp.test.com"}, + "tags": ["info"] + }) + + service = NotifyService() + + with patch('apprise.Apprise.notify') as mock_notify: + mock_notify.return_value = True + results = await service.send_to_channels( + db, None, [unique_tag], "告警", "服务器异常", "high" + ) + + assert len(results) == 2 # 只有带唯一标签的通道 + assert mock_notify.call_count == 2