import arxiv
import os
import json
from tqdm import tqdm
# Characters replaced with "_" when building file names: whitespace, path
# separators, punctuation reserved on Windows, and ASCII control characters
# (so a title containing a line break cannot put a newline in the name).
_UNSAFE_FILENAME_CHARS = str.maketrans(
    {ch: "_" for ch in ' /\\:*?"<>|' + "".join(map(chr, range(32)))}
)


def _safe_pdf_filename(title, arxiv_id):
    """Return a filesystem-safe ``<title>_<arxiv_id>.pdf`` name for a paper."""
    safe_title = title[:50].translate(_UNSAFE_FILENAME_CHARS)
    return f"{safe_title}_{arxiv_id}.pdf"


def _paper_metadata(result, filename, arxiv_id):
    """Build the JSON-serializable metadata record for one search result."""
    return {
        "title": result.title,
        "authors": [author.name for author in result.authors],
        "published": result.published.strftime("%Y-%m-%d"),
        "summary": result.summary[:500],
        "filename": filename,
        "arxiv_id": arxiv_id,
    }


def download_papers(count=25, output_dir="./sample_papers"):
    """Download up to ``count`` AI/ML papers from arXiv into ``output_dir``.

    A fixed list of search queries is run in order; each query requests
    ``count // len(queries) + 2`` results sorted by relevance. Every distinct
    paper is saved as ``<sanitized title>_<arxiv id>.pdf``. A paper whose PDF
    is already present in ``output_dir`` is not downloaded again but still
    counts toward ``count`` and still gets a metadata entry. A paper returned
    by more than one query is processed only once.

    After the loop, ``papers_metadata.json`` is written to ``output_dir`` with
    one record per counted paper (title, authors, published date, first 500
    characters of the summary, file name, arXiv id). The file is overwritten
    on every call.

    Args:
        count: Maximum number of papers to collect. Fewer may be collected if
            the queries return too few distinct results or downloads fail.
        output_dir: Directory for the PDFs and the metadata file; created if
            it does not exist.

    Returns:
        None. Progress and a final summary are printed to stdout.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Search queries for diverse AI topics
    queries = [
        "retrieval augmented generation",
        "large language model fine-tuning",
        "transformer attention mechanism",
        "prompt engineering",
        "RLHF reinforcement learning",
    ]

    papers_metadata = []
    seen_ids = set()  # arXiv ids already handled in this run
    downloaded = 0
    client = None  # created lazily so count <= 0 never touches the API

    print(f"Downloading {count} AI/ML papers from arXiv...")

    for query in queries:
        if downloaded >= count:
            break

        if client is None:
            # One shared client for all queries; Search.results() is
            # deprecated in favour of Client.results(search).
            client = arxiv.Client()

        search = arxiv.Search(
            query=query,
            max_results=count // len(queries) + 2,
            sort_by=arxiv.SortCriterion.Relevance,
        )

        for result in tqdm(client.results(search), desc=f"Query: {query[:30]}"):
            if downloaded >= count:
                break

            arxiv_id = result.entry_id.split('/')[-1]

            # The same paper can match several queries; count it only once.
            if arxiv_id in seen_ids:
                continue
            seen_ids.add(arxiv_id)

            filename = _safe_pdf_filename(result.title, arxiv_id)
            filepath = os.path.join(output_dir, filename)

            # Only hit the network when the PDF is not already on disk.
            if not os.path.exists(filepath):
                try:
                    result.download_pdf(dirpath=output_dir, filename=filename)
                except Exception as e:
                    # Best effort: report the failure and move on to the next
                    # paper rather than aborting the whole run.
                    print(f"✗ Failed: {result.title[:60]} - {e}")
                    # Drop any partially written file so a later run does not
                    # mistake it for a completed download.
                    if os.path.exists(filepath):
                        try:
                            os.remove(filepath)
                        except OSError as cleanup_error:
                            print(f"✗ Could not remove partial file {filepath} - {cleanup_error}")
                    continue
                print(f"✓ Downloaded: {result.title[:60]}")

            # Record metadata for every counted paper, including ones that
            # were already present, so the JSON matches the reported total.
            papers_metadata.append(_paper_metadata(result, filename, arxiv_id))
            downloaded += 1

    # Save metadata JSON
    metadata_path = os.path.join(output_dir, "papers_metadata.json")
    with open(metadata_path, 'w', encoding="utf-8") as f:
        json.dump(papers_metadata, f, indent=2)

    print(f"\n✅ Downloaded {downloaded} papers to {output_dir}")
    print(f"📄 Metadata saved to {metadata_path}")
if __name__ == "__main__":
    import argparse

    # Script entry point: read the paper count and target directory from the
    # command line, then hand them to download_papers().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--count",
        type=int,
        default=25,
        help="Number of papers to download",
    )
    cli.add_argument(
        "--output",
        default="./sample_papers",
        help="Output directory",
    )
    options = cli.parse_args()

    download_papers(count=options.count, output_dir=options.output)