107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
import pytest
|
|
import pytest_asyncio
|
|
import uuid
|
|
from unittest.mock import Mock, patch
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from backend.database import AsyncSessionLocal, init_db
|
|
from backend import crud
|
|
from backend.notify_service import NotifyService
|
|
|
|
@pytest_asyncio.fixture
async def db():
    """Provide an async database session to a single test.

    ``init_db()`` is awaited first so the schema exists. The session opened
    from ``AsyncSessionLocal`` is handed to the test, and once the test body
    returns, any still-pending work in it is rolled back before the session
    context is exited.

    NOTE(review): a rollback only discards *uncommitted* changes. If
    ``crud.create_channel`` commits (not visible from here), created rows
    outlive the test, which is presumably why the tests below use
    uuid-suffixed names. Confirm against ``backend.crud``.
    """
    await init_db()
    async with AsyncSessionLocal() as test_session:
        yield test_session
        # Teardown: drop whatever the test left uncommitted in this session.
        await test_session.rollback()
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_discord(db):
    """Send one notification to a single Discord channel.

    ``apprise.Apprise.notify`` is patched to report success, so nothing goes
    over the network. The result must be marked ``'sent'``, must reference the
    channel that was targeted, and Apprise must have been invoked exactly once.
    """
    # A random suffix keeps the channel name unique across runs/tests.
    channel_payload = {
        "name": f"测试Discord_{uuid.uuid4().hex[:8]}",
        "type": "discord",
        "config": {"webhook_url": "https://discord.com/api/webhooks/test"},
    }
    target_channel = await crud.create_channel(db, channel_payload)

    notifier = NotifyService()

    with patch('apprise.Apprise.notify', return_value=True) as fake_notify:
        outcome = await notifier.send_notification(
            db, target_channel.id, "测试标题", "测试内容", "normal"
        )

    assert outcome['status'] == 'sent'
    assert outcome['channel_id'] == target_channel.id
    fake_notify.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_multiple_channels(db):
    """Send one message to two channels selected explicitly by name.

    Apprise's ``notify`` is patched to succeed. The test checks that one
    result is returned per named channel, that every result reports
    ``'sent'``, and that Apprise was invoked once per channel.
    """
    name_a = f"通道A_{uuid.uuid4().hex[:8]}"
    name_b = f"通道B_{uuid.uuid4().hex[:8]}"
    # Per-test tag instead of the shared literal "alerts": the rows created
    # here may outlive the fixture's rollback (if crud commits), and a shared
    # tag would leak them into any other test that selects channels by tag.
    shared_tag = f"alerts_{uuid.uuid4().hex[:8]}"

    await crud.create_channel(db, {
        "name": name_a,
        "type": "discord",
        "config": {"webhook_url": "https://discord.com/webhook"},
        "tags": [shared_tag],
    })
    await crud.create_channel(db, {
        "name": name_b,
        "type": "telegram",
        "config": {"bot_token": "123456:ABC", "chat_id": "12345"},
        "tags": [shared_tag],
    })

    service = NotifyService()

    with patch('apprise.Apprise.notify') as mock_notify:
        mock_notify.return_value = True
        # Select by names only (tags=None).
        results = await service.send_to_channels(
            db, [name_a, name_b], None, "测试", "内容", "normal"
        )

    assert len(results) == 2
    assert all(r['status'] == 'sent' for r in results)
    # Each named channel must have triggered exactly one Apprise send.
    assert mock_notify.call_count == 2
|
|
|
|
@pytest.mark.asyncio
async def test_send_by_tags(db):
    """Send to every channel carrying a given tag, and only those.

    Three channels are created: two share a per-test unique tag and one
    (email) does not. Sending by that tag must reach exactly the two tagged
    channels, every send must succeed, and Apprise must be invoked twice.
    """
    # The uuid-suffixed tag keeps channels from other tests (or earlier runs)
    # out of the selection.
    unique_tag = f"alerts_{uuid.uuid4().hex[:8]}"
    name_prod = f"生产告警_{uuid.uuid4().hex[:8]}"
    name_dev = f"开发告警_{uuid.uuid4().hex[:8]}"
    name_info = f"普通通知_{uuid.uuid4().hex[:8]}"

    await crud.create_channel(db, {
        "name": name_prod,
        "type": "discord",
        "config": {"webhook_url": "https://discord.com/webhook"},
        "tags": [unique_tag, "production"],
    })
    await crud.create_channel(db, {
        "name": name_dev,
        "type": "telegram",
        "config": {"bot_token": "123456:ABC", "chat_id": "12345"},
        "tags": [unique_tag, "dev"],
    })
    # Control channel: does NOT carry unique_tag, so it must be skipped.
    await crud.create_channel(db, {
        "name": name_info,
        "type": "email",
        "config": {
            "username": "test@test.com",
            "password": "pass",
            "to_email": "to@test.com",
            "smtp_host": "smtp.test.com",
        },
        "tags": ["info"],
    })

    service = NotifyService()

    with patch('apprise.Apprise.notify') as mock_notify:
        mock_notify.return_value = True
        # Select by tag only (names=None).
        results = await service.send_to_channels(
            db, None, [unique_tag], "告警", "服务器异常", "high"
        )

    # Only the channels carrying the unique tag are selected.
    assert len(results) == 2
    # Without this, two failed sends would still satisfy the count checks.
    assert all(r['status'] == 'sent' for r in results)
    assert mock_notify.call_count == 2
|