"""
PyTorch HUD API client implementation
"""
import json
import requests
import logging
import time
from typing import Dict, Any, List, Optional
import base64
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("PyTorchHud")
class PyTorchHudAPIError(Exception):
"""Base exception for API errors."""
pass
class PyTorchHudAPI:
"""Python wrapper for the PyTorch Hud APIs."""
BASE_URL = "https://hud.pytorch.org/api"
def __init__(self, base_url: Optional[str] = None, retry_attempts: int = 3, retry_delay: float = 1.0):
"""Initialize with optional custom base URL.
Args:
base_url: Optional custom base URL for the API
retry_attempts: Number of times to retry failed requests (default: 3)
retry_delay: Initial delay between retries in seconds, doubles after each retry (default: 1.0)
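        Example (a minimal sketch, assuming the public hud.pytorch.org endpoint):
            >>> api = PyTorchHudAPI(retry_attempts=2, retry_delay=0.5)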
"""
self.base_url = base_url or self.BASE_URL
self.retry_attempts = retry_attempts
self.retry_delay = retry_delay
self._clickhouse_queries_cache: Optional[List[str]] = None
def _make_request(self, endpoint: str, params: Optional[Dict[str, Any]] = None,
retry_remaining: Optional[int] = None) -> Dict[str, Any]:
"""Make a GET request to the API with retry logic.
Args:
endpoint: API endpoint path
params: Query parameters
retry_remaining: Number of retry attempts left (internal use)
Returns:
Parsed JSON response
Raises:
PyTorchHudAPIError: If the request fails after all retries
"""
if retry_remaining is None:
retry_remaining = self.retry_attempts
url = f"{self.base_url}/{endpoint}"
try:
logger.debug(f"Making request to {url} with params {params}")
            response = requests.get(url, params=params, timeout=30)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if retry_remaining > 0:
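                # Exponential backoff: the first retry waits retry_delay seconds
                # and the wait doubles on each subsequent attempt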
delay = self.retry_delay * (2 ** (self.retry_attempts - retry_remaining))
logger.warning(f"Request to {url} failed: {e}. Retrying in {delay:.2f}s... ({retry_remaining} attempts left)")
time.sleep(delay)
return self._make_request(endpoint, params, retry_remaining - 1)
else:
logger.error(f"Request to {url} failed after {self.retry_attempts} attempts: {e}")
raise PyTorchHudAPIError(f"Failed to make request to {url}: {e}") from e
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response from {url}: {e}")
raise PyTorchHudAPIError(f"Failed to parse JSON response from {url}: {e}") from e
def get_hud_data(self, repo_owner: str, repo_name: str, branch_or_commit_sha: str,
per_page: Optional[int] = None, merge_lf: Optional[bool] = None,
page: Optional[int] = None) -> Dict[str, Any]:
"""Get HUD data for a specific commit or branch.
Args:
repo_owner: Repository owner (e.g., 'pytorch')
repo_name: Repository name (e.g., 'pytorch')
branch_or_commit_sha: Branch name (e.g., 'main') or commit SHA
- When passing a branch name like 'main', returns recent commits on that branch
- When passing a full commit SHA, returns data starting from that specific commit
(the requested commit will be the first in the result list)
per_page: Number of items per page
merge_lf: Whether to merge LandingFlow data
page: Page number for pagination
Returns:
Dictionary containing HUD data for the specified commit(s)
Note:
The API doesn't accept "HEAD" as a special value. To get the latest commit,
use a branch name like "main" instead.
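        Example (illustrative; the response shape depends on the live HUD API,
        so inspect the keys rather than assuming them):
            >>> api = PyTorchHudAPI()
            >>> data = api.get_hud_data("pytorch", "pytorch", "main", per_page=5)
            >>> keys = sorted(data.keys())  # inspect the top-level response shape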
"""
if page is None or page < 1:
page = 1
if per_page is None or per_page < 1:
per_page = 20
# Use the branch_or_commit_sha parameter directly in the endpoint
endpoint = f"hud/{repo_owner}/{repo_name}/{branch_or_commit_sha}/{page}"
params = {"per_page": per_page, "mergeLF": str(merge_lf).lower()}
logger.info(f"Making HUD data request to {endpoint} with params {params}")
return self._make_request(endpoint, params)
def query_clickhouse(self, query_name: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Run a ClickHouse query by name with parameters.
Args:
query_name: Name of the ClickHouse query to run
parameters: Query parameters
Returns:
Query results
Note:
            The ClickHouse API is sensitive to the parameter format. This method
            automatically JSON-encodes the parameters as the API requires.
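        Example (illustrative; 'queued_jobs' appears in the known-queries list
        below, but the parameters it accepts are an assumption):
            >>> results = api.query_clickhouse("queued_jobs", {})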
"""
endpoint = f"clickhouse/{query_name}"
params = {}
if parameters is not None:
# ClickHouse API requires JSON-encoded parameters
params["parameters"] = json.dumps(parameters)
return self._make_request(endpoint, params)
def get_clickhouse_queries(self, use_cache: bool = True) -> List[str]:
"""Get a list of all available ClickHouse queries.
        This method attempts to discover the available ClickHouse queries by
        listing the query directories in the pytorch/test-infra repository,
        falling back to a hardcoded list of known queries.
Args:
use_cache: Whether to use cached results if available
Returns:
List of query names
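        Example (illustrative; the GitHub lookup requires network access):
            >>> queries = api.get_clickhouse_queries()
            >>> has_flaky = "flaky_tests" in queries  # guaranteed only for the fallback list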
"""
if use_cache and self._clickhouse_queries_cache is not None:
return self._clickhouse_queries_cache
# Try to fetch from repo directory structure first
try:
# This is a basic implementation. The real implementation would
# need to be more sophisticated to parse the directory structure.
github_queries: List[str] = []
url = "https://api.github.com/repos/pytorch/test-infra/contents/torchci/clickhouse_queries"
            response = requests.get(url, timeout=30)
response.raise_for_status()
for item in response.json():
if item['type'] == 'dir':
github_queries.append(item['name'])
if github_queries:
self._clickhouse_queries_cache = github_queries
return github_queries
except Exception as e:
logger.warning(f"Failed to fetch queries from GitHub: {e}")
# Fallback to a hardcoded list based on known queries
hardcoded_queries: List[str] = [
"master_commit_red",
"queued_jobs",
"disabled_test_historical",
"master_commit_red_percent",
"master_commit_red_jobs",
"nightly_jobs_red",
"nightly_jobs_red_by_name",
"commit_jobs_query",
"commit_jobs_batch_query",
"flaky_tests",
"disabled_tests",
"tts_avg",
"tts_percentile",
"ttrs_percentiles"
]
self._clickhouse_queries_cache = hardcoded_queries
return hardcoded_queries
def get_clickhouse_query_parameters(self, query_name: str) -> Dict[str, Any]:
"""Get the expected parameters for a specific ClickHouse query.
Args:
query_name: Name of the query
Returns:
Dictionary of parameter names and example values
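        Example (illustrative; 'flaky_tests' is taken from the known-queries list,
        and the generic time-window fallback applies if the fetch fails):
            >>> params = api.get_clickhouse_query_parameters("flaky_tests")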
"""
try:
# Try to fetch from repo
url = f"https://api.github.com/repos/pytorch/test-infra/contents/torchci/clickhouse_queries/{query_name}/params.json"
            response = requests.get(url, timeout=30)
response.raise_for_status()
# Get file contents (Base64 encoded)
content = response.json()['content']
decoded_content = base64.b64decode(content).decode('utf-8')
return json.loads(decoded_content)
except Exception as e:
logger.warning(f"Failed to fetch parameters for query {query_name}: {e}")
# Fallback to common parameters
from datetime import datetime, timedelta
now = datetime.now()
return {
"startTime": (now - timedelta(days=7)).isoformat(),
"stopTime": now.isoformat(),
"timezone": "America/Los_Angeles"
}
def get_artifacts(self, provider: str, job_id: str) -> Dict[str, Any]:
"""Get artifacts for a job.
Args:
provider: Artifact provider (e.g., 's3')
job_id: Job ID
"""
endpoint = f"artifacts/{provider}/{job_id}"
return self._make_request(endpoint)
def get_s3_log_url(self, job_id: str) -> str:
"""Get the S3 log URL for a job.
Args:
job_id: Job ID
Returns:
S3 log URL
"""
return f"https://ossci-raw-job-status.s3.amazonaws.com/log/{job_id}"
def find_commits_with_similar_failures(self, failure: str,
repo: Optional[str] = None,
workflow_name: Optional[str] = None,
branch_name: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
min_score: float = 1.0) -> Dict[str, Any]:
"""Find commits and jobs with similar failure text using the OpenSearch API.
This is useful for investigating CI failures by finding historical jobs with similar
error messages. It can help narrow down when a particular issue first appeared or
identify patterns across different jobs and workflows.
Args:
failure: String containing the error or failure text to search for
repo: Optional repository filter (e.g., "pytorch/pytorch")
workflow_name: Optional filter for specific workflow
branch_name: Optional filter for specific branch (like "main")
start_date: ISO format date to begin search from (required by API, defaults to 7 days ago)
end_date: ISO format date to end search at (required by API, defaults to now)
min_score: Minimum relevance score for matches (defaults to 1.0)
Returns:
Dictionary with matching jobs and their commit details, containing:
- matches: List of jobs with matching failure text
- total_matches: Total number of matches found
- total_lines: Total number of matching lines
Note:
Results are limited to the first 100 matching lines per job,
and lines are truncated to 100 characters for brevity.
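        Example (illustrative; the failure string and filters are hypothetical):
            >>> results = api.find_commits_with_similar_failures(
            ...     "CUDA error: out of memory",
            ...     repo="pytorch/pytorch",
            ...     branch_name="main",
            ... )
            >>> count = results.get("total_matches")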
"""
endpoint = "search"
# Set default dates if not provided
if not start_date or not end_date:
from datetime import datetime, timedelta
now = datetime.now()
if not end_date:
end_date = now.isoformat()
if not start_date:
start_date = (now - timedelta(days=7)).isoformat()
params = {
"failure": failure,
"startDate": start_date,
"endDate": end_date,
"minScore": min_score
}
if repo:
params["repo"] = repo
if workflow_name:
params["workflowName"] = workflow_name
if branch_name:
params["branchName"] = branch_name
return self._make_request(endpoint, params)
# Alias for backward compatibility
search_logs = find_commits_with_similar_failures
def download_log(self, job_id: str) -> str:
"""Download the full text log for a job.
Args:
job_id: The job ID
Returns:
The log content as a string
Raises:
PyTorchHudAPIError: If the log cannot be downloaded
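        Example (illustrative; the job ID is hypothetical and logs can be large):
            >>> log_text = api.download_log("1234567890")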
"""
url = f"https://ossci-raw-job-status.s3.amazonaws.com/log/{job_id}"
try:
            # Logs can be large, so allow a generous timeout
            response = requests.get(url, timeout=60)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
logger.error(f"Failed to download log for job {job_id}: {e}")
raise PyTorchHudAPIError(f"Failed to download log for job {job_id}: {e}") from e