"""
Unit tests for web_tools module.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, Mock, patch
from typing import Any, cast
from aiohttp import ClientError
from bs4 import BeautifulSoup
from src.percepta_mcp.tools.web_tools import WebScraper
from src.percepta_mcp.config import Settings
from tests.patches.async_mocks import MockClientSession, create_mock_response
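# NOTE: these tests assume the async_mocks helpers behave as they are used below:
# MockClientSession exposes set_response(url, response) and a get(url, **kwargs)
# usable as an async context manager, and create_mock_response(status, body)
# builds a response whose .status, .reason, and body text match its arguments.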
class TestWebScraper:
"""Test cases for WebScraper class."""
@pytest.fixture
def mock_settings(self) -> Mock:
"""Create mock settings."""
settings = Mock(spec=Settings)
return settings
@pytest.fixture
def web_scraper(self, mock_settings: Mock) -> WebScraper:
"""Create WebScraper instance."""
scraper = WebScraper(mock_settings)
        # Attach a MockClientSession directly (cast to Any so the assignment type-checks)
scraper.session = cast(Any, MockClientSession())
return scraper
@pytest.fixture
def sample_html(self) -> str:
"""Create sample HTML content."""
return """
<!DOCTYPE html>
<html>
<head>
<title>Test Page</title>
<meta name="description" content="A test page">
<meta property="og:title" content="OG Test Page">
<meta name="twitter:card" content="summary">
<script type="application/ld+json">
{"@context": "https://schema.org", "@type": "WebPage"}
</script>
</head>
<body>
<h1 id="title">Welcome</h1>
<p class="content">This is test content.</p>
<a href="/link1" title="Link 1">Internal Link</a>
<a href="https://external.com" title="Link 2">External Link</a>
<img src="/image1.jpg" alt="Test Image" width="100" height="200">
<form action="/submit" method="POST">
<input type="text" name="username" placeholder="Username" required>
<select name="category">
<option value="1">Category 1</option>
<option value="2" selected>Category 2</option>
</select>
<textarea name="message"></textarea>
<input type="submit" value="Submit">
</form>
<div itemscope itemtype="https://schema.org/Person">
<span itemprop="name">John Doe</span>
<span itemprop="email" content="john@example.com">Email</span>
</div>
</body>
</html>
"""
@pytest.fixture
def sample_sitemap_xml(self) -> str:
"""Create sample sitemap XML."""
return """<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
<url>
<loc>https://example.com/</loc>
<lastmod>2023-01-01</lastmod>
<changefreq>daily</changefreq>
<priority>1.0</priority>
</url>
<url>
<loc>https://example.com/page2</loc>
<lastmod>2023-01-02</lastmod>
<changefreq>weekly</changefreq>
<priority>0.8</priority>
</url>
</urlset>
"""
@pytest.fixture
    def mock_response(self) -> AsyncMock:
"""Create mock aiohttp response."""
response = AsyncMock()
response.status = 200
response.reason = "OK"
response.headers = {"content-type": "text/html; charset=utf-8"}
return response
@pytest.fixture
def mock_session(self, mock_response: Any) -> AsyncMock:
"""Create mock aiohttp session."""
session = AsyncMock()
session.closed = False
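        # Chain __aenter__ so the mock supports "async with session.get(url) as response:"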
session.get.return_value.__aenter__.return_value = mock_response
return session
def test_init(self, mock_settings: Mock) -> None:
"""Test WebScraper initialization."""
web_scraper = WebScraper(mock_settings)
assert web_scraper.settings == mock_settings
assert web_scraper.session is None
@pytest.mark.asyncio
async def test_get_session_creates_new(self, web_scraper: WebScraper) -> None:
"""Test session creation when none exists."""
        # Clear the session pre-set by the fixture
web_scraper.session = None
with patch('aiohttp.ClientSession') as mock_client_session:
mock_session = MockClientSession()
mock_client_session.return_value = mock_session
session = await web_scraper._get_session()
assert isinstance(session, MockClientSession)
assert web_scraper.session == session
mock_client_session.assert_called_once()
@pytest.mark.asyncio
async def test_get_session_reuses_existing(self, web_scraper):
"""Test session reuse when one exists."""
mock_session = AsyncMock()
mock_session.closed = False
web_scraper.session = mock_session
session = await web_scraper._get_session()
assert session == mock_session
@pytest.mark.asyncio
async def test_get_session_recreates_if_closed(self, web_scraper):
"""Test session recreation when existing is closed."""
old_session = MockClientSession()
old_session.closed = True
web_scraper.session = old_session
with patch('aiohttp.ClientSession') as mock_client_session:
new_session = MockClientSession()
mock_client_session.return_value = new_session
session = await web_scraper._get_session()
assert session == new_session
assert web_scraper.session == new_session
    @pytest.mark.asyncio
    async def test_scrape_success_text_extraction(self, web_scraper, sample_html):
"""Test successful web scraping with text extraction."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        # Call the method under test and verify the result
        result = await web_scraper.scrape("https://example.com", None, "text")
assert result["success"] is True
        # BeautifulSoup-based extraction returns the page's visible text
assert "Welcome" in result["data"]
assert "This is test content" in result["data"]
assert result["url"] == "https://example.com"
assert result["title"] == "Test Page"
assert result["content_type"] == "text/html"
assert "Welcome" in result["data"]
assert "timestamp" in result
@pytest.mark.asyncio
async def test_scrape_with_selector_text(self, web_scraper, sample_html):
"""Test scraping with CSS selector for text extraction."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", ".content", "text")
assert result["success"] is True
assert result["data"] == ["This is test content."]
@pytest.mark.asyncio
async def test_scrape_with_selector_html(self, web_scraper, sample_html):
"""Test scraping with CSS selector for HTML extraction."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", "h1", "html")
assert result["success"] is True
assert len(result["data"]) == 1
assert "Welcome" in result["data"][0]
assert "<h1" in result["data"][0]
@pytest.mark.asyncio
async def test_scrape_with_selector_attributes(self, web_scraper, sample_html):
"""Test scraping with CSS selector for attributes extraction."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", "img", "attributes")
assert result["success"] is True
assert isinstance(result["data"], list)
assert len(result["data"]) >= 1
assert "src" in result["data"][0]
assert len(result["data"]) == 1
attrs = result["data"][0]
assert attrs["src"] == "/image1.jpg"
assert attrs["alt"] == "Test Image"
assert attrs["width"] == "100"
assert attrs["height"] == "200"
@pytest.mark.asyncio
async def test_scrape_extract_links(self, web_scraper, sample_html):
"""Test scraping to extract all links."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", None, "links")
assert result["success"] is True
assert isinstance(result["data"], list)
assert len(result["data"]) >= 2 # 样本HTML包含至少两个链接
links = result["data"]
assert len(links) == 2
assert links[0]["url"] == "https://example.com/link1"
assert links[0]["text"] == "Internal Link"
assert links[0]["internal"] is True
assert links[1]["url"] == "https://external.com"
assert links[1]["internal"] is False
@pytest.mark.asyncio
async def test_scrape_extract_images(self, web_scraper, sample_html):
"""Test scraping to extract all images."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", None, "images")
assert result["success"] is True
assert isinstance(result["data"], list)
assert len(result["data"]) >= 1
images = result["data"]
assert len(images) == 1
assert images[0]["url"] == "https://example.com/image1.jpg"
assert images[0]["alt"] == "Test Image"
assert images[0]["width"] == "100"
assert images[0]["height"] == "200"
@pytest.mark.asyncio
async def test_scrape_extract_metadata(self, web_scraper, sample_html):
"""Test scraping to extract page metadata."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
        result = await web_scraper.scrape("https://example.com", None, "metadata")
assert result["success"] is True
assert "meta_tags" in result["data"]
assert "open_graph" in result["data"]
metadata = result["data"]
assert metadata["title"] == "Test Page"
assert metadata["meta_tags"]["description"] == "A test page"
assert metadata["open_graph"]["og:title"] == "OG Test Page"
assert metadata["twitter_card"]["twitter:card"] == "summary"
@pytest.mark.asyncio
async def test_scrape_extract_structured_data(self, web_scraper, sample_html):
"""Test scraping to extract structured data."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.scrape("https://example.com", None, "structured")
assert result["success"] is True
structured = result["data"]
assert "json_ld" in structured
assert "microdata" in structured
assert len(structured["json_ld"]) == 1
assert structured["json_ld"][0]["@type"] == "WebPage"
assert len(structured["microdata"]) == 1
assert structured["microdata"][0]["type"] == "https://schema.org/Person"
@pytest.mark.asyncio
async def test_scrape_http_error(self, web_scraper):
"""Test scraping with HTTP error response."""
mock_session = MockClientSession()
mock_response = create_mock_response(404, "")
mock_response.reason = "Not Found"
        mock_session.set_response("https://example.com/missing", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.scrape("https://example.com/missing")
assert result["success"] is False
assert "HTTP 404: Not Found" in result["error"]
assert result["url"] == "https://example.com/missing"
@pytest.mark.asyncio
async def test_scrape_network_error(self, web_scraper):
"""Test scraping with network error."""
        # A session whose get() always raises ClientError
        class ErroringSession(MockClientSession):
            def get(self, url: str, **kwargs):
                raise ClientError("Connection failed")
        web_scraper.session = ErroringSession()
result = await web_scraper.scrape("https://example.com")
assert result["success"] is False
assert "Connection failed" in result["error"]
@pytest.mark.asyncio
async def test_scrape_selector_not_found(self, web_scraper, sample_html):
"""Test scraping with selector that doesn't match anything."""
        # Serve the sample HTML for the target URL through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.scrape("https://example.com", ".nonexistent", "text")
assert result["success"] is True
assert result["data"] is None
@pytest.mark.asyncio
async def test_crawl_sitemap_success(self, web_scraper, sample_sitemap_xml):
"""Test successful sitemap crawling."""
        # Serve the sample sitemap XML through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_sitemap_xml)
        mock_session.set_response("https://example.com/sitemap.xml", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.crawl_sitemap("https://example.com/sitemap.xml")
assert result["success"] is True
assert result["sitemap_url"] == "https://example.com/sitemap.xml"
assert result["total_found"] == 2
assert len(result["urls"]) == 2
# Check first URL
url1 = result["urls"][0]
assert url1["url"] == "https://example.com/"
assert url1["lastmod"] == "2023-01-01"
assert url1["changefreq"] == "daily"
assert url1["priority"] == "1.0"
# Check second URL
url2 = result["urls"][1]
assert url2["url"] == "https://example.com/page2"
assert url2["lastmod"] == "2023-01-02"
@pytest.mark.asyncio
async def test_crawl_sitemap_max_urls(self, web_scraper, sample_sitemap_xml):
"""Test sitemap crawling with max_urls limit."""
        # Serve the sample sitemap XML through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_sitemap_xml)
        mock_session.set_response("https://example.com/sitemap.xml", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.crawl_sitemap("https://example.com/sitemap.xml", max_urls=1)
assert result["success"] is True
assert result["total_found"] == 1
assert len(result["urls"]) == 1
@pytest.mark.asyncio
async def test_crawl_sitemap_http_error(self, web_scraper):
"""Test sitemap crawling with HTTP error."""
        mock_session = MockClientSession()
        mock_response = create_mock_response(404, "")
        mock_response.reason = "Not Found"
        mock_session.set_response("https://example.com/sitemap.xml", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.crawl_sitemap("https://example.com/sitemap.xml")
assert result["success"] is False
assert "HTTP 404: Not Found" in result["error"]
@pytest.mark.asyncio
async def test_extract_forms_success(self, web_scraper, sample_html):
"""Test successful form extraction."""
        # Serve the sample HTML (which contains one form) through a MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.extract_forms("https://example.com")
assert result["success"] is True
assert result["total_forms"] == 1
assert len(result["forms"]) == 1
form = result["forms"][0]
assert form["action"] == "/submit"
assert form["method"] == "POST"
assert len(form["fields"]) == 4 # text, select, textarea, submit
# Check text input
text_field = next(f for f in form["fields"] if f["name"] == "username")
assert text_field["type"] == "text"
assert text_field["placeholder"] == "Username"
assert text_field["required"] is True
# Check select field
select_field = next(f for f in form["fields"] if f["name"] == "category")
assert select_field["tag"] == "select"
assert len(select_field["options"]) == 2
assert select_field["options"][1]["selected"] is True
@pytest.mark.asyncio
async def test_extract_forms_http_error(self, web_scraper):
"""Test form extraction with HTTP error."""
        mock_session = MockClientSession()
        mock_response = create_mock_response(500, "")
        mock_response.reason = "Internal Server Error"
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.extract_forms("https://example.com")
assert result["success"] is False
assert "HTTP 500: Internal Server Error" in result["error"]
@pytest.mark.asyncio
async def test_extract_forms_no_forms(self, web_scraper):
"""Test form extraction when no forms exist."""
html_no_forms = "<html><body><p>No forms here</p></body></html>"
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, html_no_forms)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.extract_forms("https://example.com")
assert result["success"] is True
assert result["total_forms"] == 0
assert len(result["forms"]) == 0
@pytest.mark.asyncio
async def test_close_session(self, web_scraper):
"""Test session closing."""
mock_session = AsyncMock()
mock_session.closed = False
web_scraper.session = mock_session
await web_scraper.close()
mock_session.close.assert_called_once()
assert web_scraper.session is None
@pytest.mark.asyncio
async def test_close_no_session(self, web_scraper):
"""Test closing when no session exists."""
web_scraper.session = None
# Should not raise any exception
await web_scraper.close()
@pytest.mark.asyncio
async def test_close_already_closed_session(self, web_scraper):
"""Test closing when session is already closed."""
        # close() should not attempt to close a session that is already closed;
        # AsyncMock matches the session type the implementation expects.
        mock_session = AsyncMock()
        mock_session.closed = True
        web_scraper.session = mock_session
        await web_scraper.close()
        mock_session.close.assert_not_called()
@pytest.mark.asyncio
async def test_concurrent_requests(self, web_scraper, sample_html):
"""Test multiple concurrent scraping requests."""
        # Serve the same sample HTML for all three URLs through one MockClientSession
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, sample_html)
        mock_session.set_response("https://example.com/page1", mock_response)
        mock_session.set_response("https://example.com/page2", mock_response)
        mock_session.set_response("https://example.com/page3", mock_response)
        web_scraper.session = mock_session
# Run multiple requests concurrently
tasks = [
web_scraper.scrape("https://example.com/page1"),
web_scraper.scrape("https://example.com/page2"),
web_scraper.scrape("https://example.com/page3")
]
results = await asyncio.gather(*tasks)
# All should succeed
for result in results:
assert result["success"] is True
assert result["title"] == "Test Page"
@pytest.mark.asyncio
async def test_extract_data_error_handling(self, web_scraper):
"""Test error handling in _extract_data method."""
        soup = BeautifulSoup("<html></html>", 'html.parser')
        # Patch soup.select to raise, simulating a failure inside extraction
with patch.object(soup, 'select', side_effect=Exception("Extraction error")):
result = await web_scraper._extract_data(soup, ".test", "text", "https://example.com")
assert result is None
@pytest.mark.asyncio
async def test_invalid_json_ld_handling(self, web_scraper):
"""Test handling of invalid JSON-LD data."""
html_invalid_json = """
<html>
<head>
<script type="application/ld+json">
{invalid json}
</script>
<script type="application/ld+json">
{"valid": "json"}
</script>
</head>
<body></body>
</html>
"""
        # Serve HTML containing one invalid and one valid JSON-LD block
        mock_session = MockClientSession()
        mock_response = create_mock_response(200, html_invalid_json)
        mock_session.set_response("https://example.com", mock_response)
        web_scraper.session = mock_session
result = await web_scraper.scrape("https://example.com", None, "structured")
assert result["success"] is True
structured = result["data"]
# Should only have the valid JSON-LD entry
assert len(structured["json_ld"]) == 1
assert structured["json_ld"][0]["valid"] == "json"
@pytest.mark.asyncio
async def test_logging_behavior(self, web_scraper, sample_html, caplog):
"""Test that appropriate logging occurs."""
with patch.object(web_scraper, '_get_session') as mock_get_session:
            # Use a MockClientSession in place of an AsyncMock
            mock_session = MockClientSession()
            mock_response = create_mock_response(200, sample_html)
            mock_session.set_response("https://example.com", mock_response)
            mock_session.set_response("https://example.com/sitemap.xml", mock_response)
mock_get_session.return_value = mock_session
with caplog.at_level("INFO"):
await web_scraper.scrape("https://example.com")
await web_scraper.crawl_sitemap("https://example.com/sitemap.xml")
await web_scraper.extract_forms("https://example.com")
# Check that info logs were created
info_logs = [record for record in caplog.records if record.levelname == "INFO"]
assert len(info_logs) >= 3
# Test error logging
            error_session = MockClientSession()
            def raise_network_error(url, **kwargs):
                raise Exception("Network error")
            error_session.get = raise_network_error
mock_get_session.return_value = error_session
with caplog.at_level("ERROR"):
await web_scraper.scrape("https://example.com")
error_logs = [record for record in caplog.records if record.levelname == "ERROR"]
assert len(error_logs) >= 1