"""Hugging Face client for papers, models, and datasets."""
import requests
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Optional
from huggingface_hub import HfApi
class HuggingFaceClient:
"""Client for Hugging Face papers, models, and datasets."""
# Task types for filtering
NLP_TASKS = [
"text-generation",
"translation",
"question-answering",
"summarization",
"fill-mask",
"text-classification",
]
CV_TASKS = [
"image-classification",
"object-detection",
"image-segmentation",
"image-to-image",
"depth-estimation",
]
AUDIO_TASKS = [
"automatic-speech-recognition",
"text-to-speech",
"audio-classification",
"audio-to-audio",
]
MULTIMODAL_TASKS = [
"image-to-text",
"text-to-image",
"visual-question-answering",
"image-text-to-text",
]
def __init__(self):
"""Initialize Hugging Face client."""
self.api = HfApi()
self.papers_base_url = "https://huggingface.co/api/daily_papers"
def get_daily_papers(self, days: int = 1) -> List[Dict]:
"""Get daily papers from Hugging Face.
Args:
days: Number of days to look back (1-7)
Returns:
List of paper dictionaries
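        Example (illustrative; queries the live Hub API, so results vary):
            >>> client = HuggingFaceClient()
            >>> papers = client.get_daily_papers(days=2)
            >>> top_title = papers[0]["title"] if papers else None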
"""
        # Clamp to the documented 1-7 day window.
        days = max(1, min(days, 7))
        papers = []
        for day_offset in range(days):
            date = datetime.now(timezone.utc) - timedelta(days=day_offset)
            date_str = date.strftime("%Y-%m-%d")
try:
url = f"{self.papers_base_url}?date={date_str}"
response = requests.get(url, timeout=10)
response.raise_for_status()
daily_papers = response.json()
                for item in daily_papers:
                    # The daily_papers endpoint nests core paper metadata under
                    # a "paper" key; fall back to the flat item so both shapes work.
                    paper = item.get("paper") or item
                    # Ensure the published date carries timezone info.
                    published_date = paper.get("publishedAt") or date_str
                    if published_date and "T" not in published_date:
                        # Date only: append midnight UTC.
                        published_date = f"{published_date}T00:00:00+00:00"
papers.append({
"title": paper.get("title", ""),
"authors": paper.get("authors", []),
"summary": paper.get("summary", ""),
"published": published_date,
"url": f"https://huggingface.co/papers/{paper.get('id', '')}",
"arxiv_id": paper.get("id", ""),
"upvotes": paper.get("upvotes", 0),
"num_comments": paper.get("numComments", 0),
"thumbnail": paper.get("thumbnail", ""),
"source": "huggingface",
})
except requests.RequestException as e:
print(f"Error fetching papers for {date_str}: {e}")
continue
# Sort by upvotes
papers.sort(key=lambda x: x.get("upvotes", 0), reverse=True)
return papers
def get_trending_models(
self,
task: Optional[str] = None,
library: Optional[str] = None,
sort: str = "downloads",
limit: int = 50,
) -> List[Dict]:
"""Get trending models from Hugging Face.
Args:
task: Filter by task (e.g., 'text-generation', 'image-classification')
library: Filter by library (e.g., 'pytorch', 'transformers')
sort: Sort by 'downloads', 'likes', 'trending', or 'created'
limit: Maximum number of results
Returns:
List of model dictionaries
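        Example (illustrative; queries the live Hub API, so results vary):
            >>> client = HuggingFaceClient()
            >>> models = client.get_trending_models(task="text-generation", limit=5)
            >>> ids = [m["id"] for m in models]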
"""
        try:
            # The Hub API expects camelCase sort keys; map the friendly
            # names from the docstring ("created", "trending") onto them.
            sort_key = {"created": "createdAt", "trending": "trendingScore"}.get(sort, sort)
            models = self.api.list_models(
                filter=task,
                library=library,
                sort=sort_key,
                direction=-1,
                limit=limit,
            )
            results = []
            for model in models:
                # Hub info objects omit unset fields, so read them defensively
                # with getattr defaults instead of hasattr chains. Un-namespaced
                # ids (e.g. "gpt2") have no author to recover from the id.
                author = getattr(model, "author", None)
                if not author and "/" in model.id:
                    author = model.id.split("/")[0]
                model_info = {
                    "id": model.id,
                    "author": author,
                    "model_name": model.id.split("/")[-1],
                    "url": f"https://huggingface.co/{model.id}",
                    "downloads": getattr(model, "downloads", 0) or 0,
                    "likes": getattr(model, "likes", 0) or 0,
                    "tags": getattr(model, "tags", None) or [],
                    "pipeline_tag": getattr(model, "pipeline_tag", None),
                    "library": getattr(model, "library_name", None),
                    "created_at": model.created_at.isoformat() if getattr(model, "created_at", None) else None,
                    "last_modified": model.last_modified.isoformat() if getattr(model, "last_modified", None) else None,
                    "source": "huggingface",
                }
                results.append(model_info)
            return results
except Exception as e:
print(f"Error fetching models: {e}")
return []
def get_trending_datasets(
self,
task: Optional[str] = None,
sort: str = "downloads",
limit: int = 50,
) -> List[Dict]:
"""Get trending datasets from Hugging Face.
Args:
task: Filter by task category
sort: Sort by 'downloads', 'likes', or 'created'
limit: Maximum number of results
Returns:
List of dataset dictionaries
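        Example (illustrative; queries the live Hub API, so results vary):
            >>> client = HuggingFaceClient()
            >>> datasets = client.get_trending_datasets(task="question-answering", limit=5)
            >>> names = [d["dataset_name"] for d in datasets]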
"""
        try:
            # Dataset task tags are namespaced on the Hub
            # (e.g. "task_categories:question-answering"), so prefix bare names.
            task_filter = f"task_categories:{task}" if task and ":" not in task else task
            # Same camelCase sort-key mapping as for models.
            sort_key = {"created": "createdAt", "trending": "trendingScore"}.get(sort, sort)
            datasets = self.api.list_datasets(
                filter=task_filter,
                sort=sort_key,
                direction=-1,
                limit=limit,
            )
results = []
            for dataset in datasets:
                # Same defensive getattr pattern as for models.
                author = getattr(dataset, "author", None)
                if not author and "/" in dataset.id:
                    author = dataset.id.split("/")[0]
                dataset_info = {
                    "id": dataset.id,
                    "author": author,
                    "dataset_name": dataset.id.split("/")[-1],
                    "url": f"https://huggingface.co/datasets/{dataset.id}",
                    "downloads": getattr(dataset, "downloads", 0) or 0,
                    "likes": getattr(dataset, "likes", 0) or 0,
                    "tags": getattr(dataset, "tags", None) or [],
                    "created_at": dataset.created_at.isoformat() if getattr(dataset, "created_at", None) else None,
                    "last_modified": dataset.last_modified.isoformat() if getattr(dataset, "last_modified", None) else None,
                    "source": "huggingface",
                }
                results.append(dataset_info)
            return results
except Exception as e:
print(f"Error fetching datasets: {e}")
return []
def get_recent_models(self, days: int = 7, limit: int = 50) -> List[Dict]:
"""Get recently created or updated models.
Args:
days: Number of days to look back
limit: Maximum number of results
Returns:
List of model dictionaries
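        Example (illustrative; queries the live Hub API, so results vary):
            >>> client = HuggingFaceClient()
            >>> fresh = client.get_recent_models(days=3, limit=10)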
"""
        # Fetch newest models first (sorted by creation date), oversampling
        # so enough survive the date filter below.
        all_models = self.get_trending_models(sort="created", limit=limit * 2)
# Filter by date
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
recent_models = []
for model in all_models:
if model.get("created_at"):
created_date = datetime.fromisoformat(model["created_at"].replace("Z", "+00:00"))
if created_date >= cutoff_date:
recent_models.append(model)
if len(recent_models) >= limit:
break
return recent_models
def get_llm_models(self, limit: int = 30) -> List[Dict]:
"""Get popular LLM models.
Args:
limit: Maximum number of results
Returns:
List of LLM model dictionaries
"""
return self.get_trending_models(
task="text-generation",
sort="downloads",
limit=limit,
)
def get_multimodal_models(self, limit: int = 30) -> List[Dict]:
"""Get popular multimodal models.
Args:
limit: Maximum number of results
Returns:
List of multimodal model dictionaries
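        Example (illustrative; queries the live Hub API, so results vary):
            >>> client = HuggingFaceClient()
            >>> mm_models = client.get_multimodal_models(limit=8)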
"""
results = []
        # Split the budget across tasks, requesting at least one per task
        # so small limits don't round down to zero.
        per_task = max(1, limit // len(self.MULTIMODAL_TASKS))
        for task in self.MULTIMODAL_TASKS:
            models = self.get_trending_models(task=task, sort="downloads", limit=per_task)
            results.extend(models)
# Sort by downloads and deduplicate
seen = set()
unique_results = []
for model in sorted(results, key=lambda x: x.get("downloads", 0), reverse=True):
if model["id"] not in seen:
seen.add(model["id"])
unique_results.append(model)
return unique_results[:limit]
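

if __name__ == "__main__":
    # Ad-hoc smoke test (an assumed entry point, not part of the original API):
    # exercises the main fetchers against the live Hub, so network access is
    # required and counts/ordering change day to day.
    client = HuggingFaceClient()

    papers = client.get_daily_papers(days=1)
    print(f"Daily papers: {len(papers)}")
    for paper in papers[:3]:
        print(f"  [{paper['upvotes']:>4} upvotes] {paper['title']}")

    models = client.get_llm_models(limit=5)
    print(f"Top LLMs by downloads: {len(models)}")
    for model in models:
        print(f"  [{model['downloads']:>12}] {model['id']}")

    datasets = client.get_trending_datasets(limit=5)
    print(f"Trending datasets: {len(datasets)}")
    for dataset in datasets:
        print(f"  [{dataset['likes']:>6} likes] {dataset['id']}")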