# services.py (10.2 kB)
"""Services for interacting with the arXiv API."""
import time
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlencode

import httpx

from arxiv.models import ArxivEntry, ArxivSearchResult, Author, Link
class ArxivService:
    """Service for searching and downloading arXiv papers.

    Queries the arXiv Atom API at ``BASE_URL``, parses the XML feed into
    ``arxiv.models`` objects, and stores downloaded PDFs under
    ``download_dir``. Every outgoing request (API query or PDF download)
    first goes through ``_rate_limit`` so consecutive requests are spaced by
    at least ``rate_limit_delay`` seconds.
    """

    BASE_URL = "https://export.arxiv.org/api/query"
    # Prefix -> namespace URI map passed to ElementTree find/findall calls
    # for the namespaced elements of the arXiv Atom feed.
    NAMESPACE = {
        "atom": "http://www.w3.org/2005/Atom",
        "arxiv": "http://arxiv.org/schemas/atom",
        "opensearch": "http://a9.com/-/spec/opensearch/1.1/",
    }

    def __init__(
        self,
        download_dir: str = "./.arxiv",
        rate_limit_delay: float = 3.0,
    ):
        """Initialize the arXiv service.

        Args:
            download_dir: Directory to store downloaded PDFs. It is created
                (including parents) if it does not exist yet.
            rate_limit_delay: Minimum delay between requests (seconds).
                arXiv recommends 3s.
        """
        self.download_dir = Path(download_dir)
        self.download_dir.mkdir(parents=True, exist_ok=True)
        self.rate_limit_delay = rate_limit_delay
        # time.monotonic() value of the previous request; -inf means "no
        # request yet", so the very first request is never delayed.
        self._last_request_time = float("-inf")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _rate_limit(self) -> None:
        """Sleep until ``rate_limit_delay`` seconds have passed since the last request."""
        # A monotonic clock is used because, with time.time(), a wall-clock
        # step backwards would make ``elapsed`` negative and stretch the
        # sleep arbitrarily, and a step forwards would skip the delay.
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < self.rate_limit_delay:
            time.sleep(self.rate_limit_delay - elapsed)
        self._last_request_time = time.monotonic()

    @staticmethod
    def _text(
        elem: Optional[ET.Element], default: Optional[str] = None
    ) -> Optional[str]:
        """Return the text of ``elem``, or ``default`` if it is missing or has no text.

        ElementTree reports ``text`` as ``None`` for childless empty elements
        such as ``<title/>``, so checking ``elem is not None`` alone is not
        enough before calling string methods on the result.
        """
        if elem is None or elem.text is None:
            return default
        return elem.text

    @staticmethod
    def _int_text(elem: Optional[ET.Element], default: int = 0) -> int:
        """Return the text of ``elem`` parsed as an int, or ``default`` if absent/blank.

        Raises:
            ValueError: If the element has non-numeric text.
        """
        text = ArxivService._text(elem)
        if text is None or not text.strip():
            return default
        return int(text)

    @staticmethod
    def _parse_datetime(elem: Optional[ET.Element]) -> datetime:
        """Parse an Atom timestamp element such as ``2023-01-02T03:04:05Z``.

        A trailing ``Z`` is rewritten to ``+00:00`` because
        ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on.
        If the element is missing or has no text, the current time in UTC is
        returned. It is timezone-aware, so it stays comparable with parsed
        ``Z`` timestamps; a naive local time would not be.
        """
        text = ArxivService._text(elem)
        if text is None or not text.strip():
            return datetime.now(timezone.utc)
        return datetime.fromisoformat(text.strip().replace("Z", "+00:00"))

    @staticmethod
    def _normalize_id(arxiv_id: str) -> str:
        """Strip an ``arXiv:`` prefix, surrounding whitespace, and a ``vN`` version suffix.

        Examples: ``"arXiv:2301.12345v2"`` -> ``"2301.12345"``,
        ``"solv-int/9901001v1"`` -> ``"solv-int/9901001"``.
        """
        cleaned = arxiv_id.replace("arXiv:", "").strip()
        # Split on the LAST "v" so that old-style IDs whose archive name
        # contains a "v" (e.g. "solv-int/...") still lose their version.
        base, sep, version = cleaned.rpartition("v")
        if sep and base and version.isdigit():
            return base
        return cleaned

    @staticmethod
    def _pdf_filename(arxiv_id: str) -> str:
        """Return a flat filename (inside ``download_dir``) for the PDF of ``arxiv_id``."""
        # Old-style IDs contain "/" (e.g. "hep-th/9901001"); left as-is, the
        # path would point into a non-existent subdirectory and the write
        # would fail.
        return f"{arxiv_id.replace('/', '_')}.pdf"

    @staticmethod
    def _find_pdf_url(links: List[Link]) -> Optional[str]:
        """Return the href of the PDF link among ``links``, or None if there is none.

        Prefers the link titled ``"pdf"``, then falls back to any link whose
        MIME type is ``application/pdf`` in case the title attribute is absent.
        """
        for link in links:
            if link.title == "pdf":
                return link.href
        for link in links:
            if link.type == "application/pdf":
                return link.href
        return None

    # ------------------------------------------------------------------
    # XML parsing
    # ------------------------------------------------------------------

    def _parse_author(self, author_elem: ET.Element) -> Author:
        """Parse an ``<atom:author>`` element into an ``Author``."""
        return Author(
            name=self._text(author_elem.find("atom:name", self.NAMESPACE), ""),
            affiliation=self._text(
                author_elem.find("arxiv:affiliation", self.NAMESPACE)
            ),
        )

    def _parse_link(self, link_elem: ET.Element) -> Link:
        """Parse an ``<atom:link>`` element into a ``Link`` (attributes only)."""
        return Link(
            href=link_elem.get("href", ""),
            title=link_elem.get("title"),
            rel=link_elem.get("rel"),
            type=link_elem.get("type"),
        )

    def _parse_entry(self, entry_elem: ET.Element) -> ArxivEntry:
        """Parse an ``<atom:entry>`` element into an ``ArxivEntry``.

        Missing string fields become ``""`` (required ones) or ``None``
        (optional ones); missing timestamps fall back to the current UTC time.
        """
        ns = self.NAMESPACE

        authors = [
            self._parse_author(author)
            for author in entry_elem.findall("atom:author", ns)
        ]
        links = [
            self._parse_link(link) for link in entry_elem.findall("atom:link", ns)
        ]
        categories = [
            cat.get("term", "") for cat in entry_elem.findall("atom:category", ns)
        ]
        primary_category_elem = entry_elem.find("arxiv:primary_category", ns)

        return ArxivEntry(
            id=self._text(entry_elem.find("atom:id", ns), ""),
            title=self._text(entry_elem.find("atom:title", ns), "").strip(),
            summary=self._text(entry_elem.find("atom:summary", ns), "").strip(),
            authors=authors,
            published=self._parse_datetime(entry_elem.find("atom:published", ns)),
            updated=self._parse_datetime(entry_elem.find("atom:updated", ns)),
            primary_category=(
                primary_category_elem.get("term", "")
                if primary_category_elem is not None
                else ""
            ),
            categories=categories,
            links=links,
            pdf_url=self._find_pdf_url(links),
            comment=self._text(entry_elem.find("arxiv:comment", ns)),
            journal_ref=self._text(entry_elem.find("arxiv:journal_ref", ns)),
            doi=self._text(entry_elem.find("arxiv:doi", ns)),
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def search(
        self,
        query: str,
        max_results: int = 10,
        start: int = 0,
        sort_by: str = "relevance",
        sort_order: str = "descending",
    ) -> ArxivSearchResult:
        """Search arXiv papers.

        Args:
            query: Search query (supports field prefixes like ti:, au:, abs:, cat:)
            max_results: Maximum number of results to return
            start: Starting index for pagination
            sort_by: Sort criterion (relevance, lastUpdatedDate, submittedDate)
            sort_order: Sort order (ascending, descending)

        Returns:
            ArxivSearchResult containing the search results. Missing or empty
            opensearch counters are reported as 0.

        Raises:
            httpx.HTTPStatusError: If the API responds with an error status.
            xml.etree.ElementTree.ParseError: If the response is not valid XML.

        Example queries:
            - "ti:machine learning" - Search in title
            - "au:Hinton" - Search by author
            - "abs:neural networks" - Search in abstract
            - "cat:cs.AI" - Search in category
            - "ti:transformer AND au:Vaswani" - Combined search
        """
        self._rate_limit()
        params = {
            "search_query": query,
            "start": start,
            "max_results": max_results,
            "sortBy": sort_by,
            "sortOrder": sort_order,
        }
        url = f"{self.BASE_URL}?{urlencode(params)}"
        with httpx.Client(timeout=30.0) as client:
            response = client.get(url)
            response.raise_for_status()
            root = ET.fromstring(response.content)

        ns = self.NAMESPACE
        entries = [
            self._parse_entry(entry) for entry in root.findall("atom:entry", ns)
        ]
        return ArxivSearchResult(
            total_results=self._int_text(root.find("opensearch:totalResults", ns)),
            start_index=self._int_text(root.find("opensearch:startIndex", ns)),
            items_per_page=self._int_text(root.find("opensearch:itemsPerPage", ns)),
            entries=entries,
        )

    def get(
        self,
        arxiv_id: str,
        download_pdf: bool = True,
        force_download: bool = False,
    ) -> ArxivEntry:
        """Get a specific arXiv paper by ID.

        The ID is normalized first (``arXiv:`` prefix and any ``vN`` version
        suffix removed), so the lookup is made for the unversioned ID.

        Args:
            arxiv_id: arXiv ID (e.g., "2301.12345" or "arXiv:2301.12345")
            download_pdf: Whether to download the PDF
            force_download: If True, download even if file exists locally

        Returns:
            ArxivEntry for the requested paper

        Raises:
            ValueError: If no paper matches the ID.
        """
        arxiv_id = self._normalize_id(arxiv_id)
        result = self.search(query=f"id:{arxiv_id}", max_results=1)
        if not result.entries:
            raise ValueError(f"No paper found with ID: {arxiv_id}")
        entry = result.entries[0]
        if download_pdf and entry.pdf_url:
            self._download_pdf(entry, force_download)
        return entry

    def _download_pdf(self, entry: ArxivEntry, force: bool = False) -> Path:
        """Download PDF for an arXiv entry.

        Args:
            entry: ArxivEntry to download
            force: If True, download even if file exists

        Returns:
            Path to the downloaded PDF file

        Raises:
            ValueError: If the entry has no PDF URL.
            httpx.HTTPStatusError: If the PDF request fails.
        """
        if not entry.pdf_url:
            raise ValueError(f"No PDF URL available for entry: {entry.arxiv_id}")

        filepath = self.download_dir / self._pdf_filename(entry.arxiv_id)
        if filepath.exists() and not force:
            print(f"PDF already exists: {filepath}")
            return filepath

        self._rate_limit()
        print(f"Downloading PDF: {entry.arxiv_id}")
        with httpx.Client(timeout=60.0, follow_redirects=True) as client:
            response = client.get(entry.pdf_url)
            response.raise_for_status()
            content = response.content

        # Write to a sibling temp file and rename it into place, so an
        # interrupted write never leaves a truncated PDF that the exists()
        # check above would later treat as a finished download.
        tmp_path = filepath.with_name(filepath.name + ".part")
        try:
            tmp_path.write_bytes(content)
            tmp_path.replace(filepath)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise
        print(f"Downloaded to: {filepath}")
        return filepath

    def download_pdf(
        self, arxiv_id: str, force_download: bool = False
    ) -> Optional[Path]:
        """Download PDF for a paper by arXiv ID.

        Args:
            arxiv_id: arXiv ID (e.g., "2301.12345")
            force_download: If True, download even if file exists locally

        Returns:
            Path to the downloaded PDF file, or None if not found
        """
        try:
            entry = self.get(arxiv_id, download_pdf=False)
            return self._download_pdf(entry, force=force_download)
        except Exception as e:
            # Deliberate best-effort boundary: any failure (lookup, network,
            # filesystem) is reported on stdout and mapped to None.
            print(f"Error downloading PDF: {e}")
            return None