# search.py • 6.18 kB
"""Search functionality for the arXiv MCP server."""
import arxiv
import json
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from dateutil import parser
import mcp.types as types
from pathlib import Path
from ..config import Settings
settings = Settings()
# MCP tool declaration for "search_papers": the input schema accepted by
# handle_search below. Only "query" is required; all other fields are
# optional refinements.
# NOTE(review): the arXiv API does not supply citation counts (see the
# comment in _process_paper), so "min_citations" can only ever match
# papers whose citation_count was enriched upstream — confirm intent.
search_tool = types.Tool(
    name="search_papers",
    description="Search for papers on arXiv with advanced filtering",
    inputSchema={
        "type": "object",
        "properties": {
            "query": {"type": "string"},
            "max_results": {"type": "integer"},
            "date_from": {"type": "string"},
            "date_to": {"type": "string"},
            "categories": {"type": "array", "items": {"type": "string"}},
            "save_to_file": {"type": "string", "description": "Optional file path to save results"},
            "min_citations": {"type": "integer", "description": "Minimum citation count filter"},
        },
        "required": ["query"],
    },
)
def _is_within_date_range(
date: datetime, start: datetime | None, end: datetime | None
) -> bool:
"""Check if a date falls within the specified range."""
if start and not start.tzinfo:
start = start.replace(tzinfo=timezone.utc)
if end and not end.tzinfo:
end = end.replace(tzinfo=timezone.utc)
if start and date < start:
return False
if end and date > end:
return False
return True
def _process_paper(paper: arxiv.Result) -> Dict[str, Any]:
    """Convert one arXiv search result into a JSON-serializable dict.

    The dict mirrors the fields exposed to MCP clients, including a
    ``resource_uri`` of the form ``arxiv://<short-id>``.
    """
    short_id = paper.get_short_id()
    return {
        "id": short_id,
        "title": paper.title,
        "authors": [a.name for a in paper.authors],
        "abstract": paper.summary,
        "categories": paper.categories,
        "published": paper.published.isoformat(),
        "url": paper.pdf_url,
        "resource_uri": f"arxiv://{short_id}",
        # The arXiv API itself does not report citation counts, so this
        # stays 0 unless something upstream attached the attribute.
        "citation_count": getattr(paper, 'citation_count', 0),
    }
def save_results_to_file(results: Dict[str, Any] | List[Dict[str, Any]], file_path: str) -> bool:
    """Save search results to a JSON file.

    Args:
        results: JSON-serializable payload. The original annotation said
            ``List[Dict]``, but handle_search actually passes the full
            response dict — the widened annotation matches the call site.
        file_path: Destination path; missing parent directories are created.

    Returns:
        True on success, False if the write or serialization failed.
    """
    try:
        output_path = Path(file_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        return True
    except (OSError, TypeError, ValueError) as e:
        # Narrowed from bare Exception: OSError covers filesystem failures,
        # TypeError/ValueError cover non-serializable payloads. Best-effort
        # behavior (report False, don't raise) is preserved.
        print(f"Error saving to file: {e}")
        return False
async def handle_search(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Handle paper search requests.

    Automatically adds field specifiers to plain queries for better relevance.
    This fixes issue #33 where queries sorted by date returned irrelevant results.

    Args:
        arguments: Tool arguments matching ``search_tool``'s input schema;
            only "query" is required.

    Returns:
        A single TextContent element whose text is a JSON payload with the
        matched papers, or an "Error: ..." message on failure.
    """
    try:
        client = arxiv.Client()
        max_results = min(int(arguments.get("max_results", 10)), settings.MAX_RESULTS)

        # Build the search query. Plain queries get an explicit field
        # specifier so the terms actually match paper content.
        query = arguments["query"]
        if not any(field in query for field in ["all:", "ti:", "abs:", "au:", "cat:"]):
            if '"' in query:
                # Keep quoted phrases intact as a single all: expression.
                query = f"all:{query}"
            else:
                terms = query.split()
                if len(terms) > 1:
                    # Unquoted multi-word queries: require every term.
                    query = " AND ".join(f"all:{term}" for term in terms)
                else:
                    query = f"all:{query}"

        if categories := arguments.get("categories"):
            category_filter = " OR ".join(f"cat:{cat}" for cat in categories)
            query = f"({query}) AND ({category_filter})"

        # Parse optional date bounds. Attach UTC only when the parsed value
        # is naive: the original unconditionally replace()d tzinfo, which
        # silently overwrote an explicit offset like "...T00:00:00+02:00".
        date_from = None
        date_to = None
        try:
            if "date_from" in arguments:
                date_from = parser.parse(arguments["date_from"])
                if date_from.tzinfo is None:
                    date_from = date_from.replace(tzinfo=timezone.utc)
            if "date_to" in arguments:
                date_to = parser.parse(arguments["date_to"])
                if date_to.tzinfo is None:
                    date_to = date_to.replace(tzinfo=timezone.utc)
        except (ValueError, TypeError) as e:
            return [
                types.TextContent(
                    type="text", text=f"Error: Invalid date format - {str(e)}"
                )
            ]

        # The arXiv API has no reliable date-range query syntax, so date
        # filtering happens in the post-processing loop below.
        search = arxiv.Search(
            query=query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate,
        )

        # "min_citations" is advertised in the tool schema but was never
        # applied. Default 0 keeps prior behavior when it is absent.
        # NOTE: arXiv does not supply citation counts (_process_paper
        # defaults citation_count to 0), so any positive threshold may
        # filter out every result unless counts are enriched upstream.
        min_citations = int(arguments.get("min_citations", 0))

        results = []
        for paper in client.results(search):
            # Post-process date filtering (the API query can't do it).
            if (date_from or date_to) and not _is_within_date_range(
                paper.published, date_from, date_to
            ):
                continue
            paper_data = _process_paper(paper)
            if paper_data["citation_count"] < min_citations:
                continue
            results.append(paper_data)
            if len(results) >= max_results:
                break

        response_data = {"total_results": len(results), "papers": results}

        # Optionally persist the response; record the outcome either way.
        if save_file := arguments.get("save_to_file"):
            if save_results_to_file(response_data, save_file):
                response_data["saved_to"] = save_file
            else:
                response_data["save_error"] = f"Failed to save to {save_file}"

        return [
            types.TextContent(type="text", text=json.dumps(response_data, indent=2))
        ]
    except Exception as e:
        # Top-level tool boundary: surface any unexpected failure as text
        # output rather than crashing the MCP server.
        return [types.TextContent(type="text", text=f"Error: {str(e)}")]