# models.py • 3.08 kB
"""Pydantic models for arXiv API responses."""
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
class Author(BaseModel):
    """One author credited on an arXiv paper.

    ``affiliation`` is optional. ``__str__`` shows it in parentheses
    after the name only when it is present and non-empty.
    """

    name: str
    affiliation: Optional[str] = None

    def __str__(self) -> str:
        """Return ``"Name (Affiliation)"``, or just ``"Name"`` when there is no affiliation."""
        return f"{self.name} ({self.affiliation})" if self.affiliation else self.name
class Link(BaseModel):
    """A hyperlink attached to an arXiv entry.

    Only ``href`` is required. ``title``, ``rel`` and ``type`` are optional
    and default to ``None`` (presumably they mirror the Atom ``<link>``
    attributes of the same names; confirm against the parser).
    """

    # Link target.
    href: str
    # Optional human-readable label.
    title: Optional[str] = Field(default=None)
    # Optional relation of the link to the entry.
    rel: Optional[str] = Field(default=None)
    # Optional content type (presumably a MIME type; confirm).
    type: Optional[str] = Field(default=None)
class ArxivEntry(BaseModel):
    """Single arXiv paper entry.

    ``id`` may hold either a bare identifier or a full abstract URL
    (``.../arxiv.org/abs/<id>``). Use :attr:`arxiv_id` for the identifier
    with the URL prefix and trailing version suffix removed.
    """

    id: str = Field(..., description="arXiv ID (e.g., 2301.12345)")
    title: str
    summary: str
    authors: List[Author]
    published: datetime
    updated: datetime
    primary_category: str
    categories: List[str]
    links: List[Link]
    pdf_url: Optional[str] = None
    comment: Optional[str] = None
    journal_ref: Optional[str] = None
    doi: Optional[str] = None

    @property
    def arxiv_id(self) -> str:
        """Extract the arXiv ID from the full URL, without version suffix.

        Examples:
            ``http://arxiv.org/abs/2301.12345v2`` -> ``2301.12345``
            ``solv-int/9701001v1``                -> ``solv-int/9701001``
            ``2301.12345``                        -> ``2301.12345``
        """
        id_str = self.id
        if "arxiv.org/abs/" in id_str:
            id_str = id_str.split("/abs/")[-1]
        # Remove only a *trailing* "v<digits>" version suffix. The previous
        # split-on-every-"v" approach failed whenever the archive name itself
        # contained a "v" (e.g. "solv-int/..."), leaving the version attached.
        head, sep, version = id_str.rpartition("v")
        if sep and version.isdigit():
            id_str = head
        return id_str

    def __str__(self) -> str:
        """Render a multi-line, human-readable summary of the entry.

        At most three authors are listed, followed by a count of the rest.
        Optional fields are included only when set, and the abstract is
        truncated to 500 characters with a trailing ``...``.
        """
        authors_str = ", ".join(str(a) for a in self.authors[:3])
        if len(self.authors) > 3:
            authors_str += f" et al. ({len(self.authors) - 3} more)"
        lines = [
            f"ID: {self.arxiv_id}",
            f"Title: {self.title}",
            f"Authors: {authors_str}",
            f"Published: {self.published.strftime('%Y-%m-%d')}",
            f"Categories: {', '.join(self.categories)}",
        ]
        if self.pdf_url:
            lines.append(f"PDF: {self.pdf_url}")
        if self.comment:
            lines.append(f"Comment: {self.comment}")
        if self.journal_ref:
            lines.append(f"Journal: {self.journal_ref}")
        if self.doi:
            lines.append(f"DOI: {self.doi}")
        lines.append(f"\nAbstract:\n{self.summary[:500]}{'...' if len(self.summary) > 500 else ''}")
        return "\n".join(lines)
class ArxivSearchResult(BaseModel):
    """Result from arXiv API search (one page of entries).

    ``__str__`` treats ``start_index`` as zero-based, because it adds 1 for
    display. NOTE(review): confirm this against the code that populates it.
    """

    total_results: int
    start_index: int
    items_per_page: int
    entries: List[ArxivEntry]

    def __str__(self) -> str:
        """Render a summary header followed by each entry, separated by ruler lines."""
        if self.entries:
            first = self.start_index + 1
            last = self.start_index + len(self.entries)
            showing = f"{first}-{last}"
        else:
            # An empty page previously rendered as a nonsensical "1-0" range.
            showing = "none"
        lines = [
            f"Total results: {self.total_results}",
            f"Showing: {showing}",
            f"Items per page: {self.items_per_page}",
            "",
        ]
        for i, entry in enumerate(self.entries, 1):
            lines.extend(["=" * 80, f"Result {i}:", str(entry), ""])
        return "\n".join(lines)