# PyTorch HUD MCP Server

""" PyTorch HUD log analysis tools """

import os
import re
import datetime
from typing import Dict, Any, Optional, List, cast

from mcp.server.fastmcp import Context

from pytorch_hud.api.client import PyTorchHudAPI

# Initialize API client singleton
api = PyTorchHudAPI()

# Text options for every log file this module reads or writes. Logs are
# treated as UTF-8; "replace" substitutes undecodable bytes instead of
# aborting the whole analysis on the first bad byte.
_LOG_ENCODING = "utf-8"
_LOG_ERRORS = "replace"


def get_artifacts(provider: str, job_id: str) -> Dict[str, Any]:
    """Get artifacts for a job."""
    return api.get_artifacts(provider, job_id)


def get_s3_log_url(job_id: str) -> str:
    """Get the S3 log URL for a job."""
    return api.get_s3_log_url(job_id)


def find_commits_with_similar_failures(failure: str,
                                       repo: Optional[str] = None,
                                       workflow_name: Optional[str] = None,
                                       branch_name: Optional[str] = None,
                                       start_date: Optional[str] = None,
                                       end_date: Optional[str] = None,
                                       min_score: float = 1.0) -> Dict[str, Any]:
    """Find commits and jobs with similar failure text using the OpenSearch API.

    This tool is extremely useful for investigating CI failures by finding
    historical jobs with similar error messages. It can help:
    - Determine when a particular issue first appeared
    - Find related failures across different jobs/workflows
    - Identify patterns in failures across time periods
    - Check if failures are associated with specific branches or workflows

    Args:
        failure: String containing the error or failure text to search for
        repo: Optional repository filter (e.g., "pytorch/pytorch")
        workflow_name: Optional filter for specific workflow
        branch_name: Optional filter for specific branch (like "main")
        start_date: ISO format date to begin search from (defaults to 7 days ago)
        end_date: ISO format date to end search at (defaults to now)
        min_score: Minimum relevance score for matches (defaults to 1.0)

    Returns:
        Dictionary with matching jobs and their commit details, containing:
        - matches: List of jobs with matching failure text
        - total_matches: Total number of matches found
        - total_lines: Total number of matching lines

    Example:
        ```python
        # Find jobs with OOM errors in the past week
        results = find_commits_with_similar_failures(
            failure="CUDA out of memory",
            repo="pytorch/pytorch",
            workflow_name="linux-bionic-cuda12.1-py3.10-gcc9"
        )

        # Check when an issue first appeared (14 days ago)
        from datetime import datetime, timedelta
        now = datetime.now()
        results = find_commits_with_similar_failures(
            failure="PACKAGES DO NOT MATCH THE HASHES",
            start_date=(now - timedelta(days=14)).isoformat(),
            end_date=now.isoformat()
        )
        ```

    Note:
        Results are limited to the first 100 matching lines per job, and lines
        are truncated to 100 characters for brevity.
    """
    return api.find_commits_with_similar_failures(
        failure=failure,
        repo=repo,
        workflow_name=workflow_name,
        branch_name=branch_name,
        start_date=start_date,
        end_date=end_date,
        min_score=min_score
    )


# Alias for backward compatibility
search_logs = find_commits_with_similar_failures


def _count_lines(text: str) -> int:
    """Return the number of lines in *text*.

    A trailing newline terminates the last line rather than starting a new
    one, so this agrees with iterating the written file line by line. Empty
    text has zero lines.
    """
    if not text:
        return 0
    return text.count('\n') + (0 if text.endswith('\n') else 1)


async def download_log_to_file(job_id: int, ctx: Optional[Context] = None) -> Dict[str, Any]:
    """Download a job log to a temporary file for analysis.

    This tool helps with analyzing large log files by downloading them to local
    storage instead of loading them entirely into context.

    Args:
        job_id: The job ID to download
        ctx: MCP context

    Returns:
        Dictionary with file path and metadata. On success it has the keys
        ``success`` (True), ``file_path``, ``job_id``, ``size_bytes``,
        ``line_count`` and ``url``. On failure it has ``success`` (False),
        ``error`` and ``job_id``.
    """
    if ctx:
        await ctx.info(f"Downloading log for job {job_id}")

    # Logs are stored under ./temp_logs relative to the current working
    # directory; nothing in this module deletes them afterwards.
    logs_dir = os.path.join(os.getcwd(), "temp_logs")
    os.makedirs(logs_dir, exist_ok=True)

    # The API client is called with the job ID as a string.
    job_id_str = str(job_id)

    # The timestamp (one-second resolution) keeps repeated downloads of the
    # same job from overwriting each other.
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    filepath = os.path.join(logs_dir, f"job_{job_id}_{timestamp}.log")

    try:
        log_content = api.download_log(job_id_str)

        # Explicit encoding so the file does not depend on the platform default.
        with open(filepath, 'w', encoding=_LOG_ENCODING, errors=_LOG_ERRORS) as f:
            f.write(log_content)

        file_size = os.path.getsize(filepath)
        line_count = _count_lines(log_content)

        if ctx:
            await ctx.info(f"Log downloaded successfully: {filepath} ({file_size} bytes, {line_count} lines)")

        return {
            "success": True,
            "file_path": filepath,
            "job_id": job_id,
            "size_bytes": file_size,
            "line_count": line_count,
            "url": api.get_s3_log_url(job_id_str)
        }
    except Exception as e:
        if ctx:
            await ctx.error(f"Failed to download log: {e}")
        return {
            "success": False,
            "error": str(e),
            "job_id": job_id
        }


# Patterns used by extract_log_patterns when the caller supplies none.
_DEFAULT_LOG_PATTERNS: Dict[str, str] = {
    "error": r"(?i)error:",
    "exception": r"(?i)exception:",
    "warning": r"(?i)warning:",
    "test_failed": r"FAILED.*test_",
    "test_results": r"Ran (\d+) tests.*?(\d+) failures",
    "cuda_error": r"CUDA error|CUDA exception|cudaError",
    "out_of_memory": r"OutOfMemoryError|OOM|out of memory",
    "build_failed": r"Build failed|compilation failed|error: command .* failed"
}
# Sample lines kept per pattern, and the maximum characters kept per sample.
_MAX_SAMPLES_PER_PATTERN = 5
_MAX_SAMPLE_CHARS = 150


async def extract_log_patterns(file_path: str, patterns: Optional[Dict[str, str]] = None,
                               ctx: Optional[Context] = None) -> Dict[str, Any]:
    """Extract matches for specified patterns from a log file.

    This tool helps analyze log files by finding patterns of interest
    without loading the entire file into context.

    Args:
        file_path: Path to the log file
        patterns: Dictionary of pattern_name:regex_pattern pairs to search for.
            If None (or empty), uses default patterns for common errors and warnings
        ctx: MCP context

    Returns:
        Dictionary with pattern matches and counts. On success:
        - success: True
        - file_path: the analyzed path
        - matches: pattern name -> list (kept for compatibility; currently
          always empty)
        - counts: pattern name -> number of matching lines
        - samples: pattern name -> up to 5 samples, each with ``line_num``,
          ``text`` (stripped, at most 150 chars) and ``groups`` (the regex
          groups, or None when the pattern has no groups)
        Only patterns that matched at least once appear in these maps.
        On failure (missing file, invalid regex, read error): ``success``
        False plus ``error`` (and ``file_path`` for errors after the
        existence check).
    """
    if not os.path.exists(file_path):
        return {
            "success": False,
            "error": f"File not found: {file_path}"
        }

    if ctx:
        await ctx.info(f"Analyzing log file: {file_path}")

    use_patterns = patterns or _DEFAULT_LOG_PATTERNS

    matches: Dict[str, List[Any]] = {}
    counts: Dict[str, int] = {}
    samples: Dict[str, List[Dict[str, Any]]] = {}
    results: Dict[str, Any] = {
        "success": True,
        "file_path": file_path,
        "matches": matches,
        "counts": counts,
        "samples": samples
    }

    if ctx:
        await ctx.info(f"Searching for {len(use_patterns)} patterns: {', '.join(use_patterns.keys())}")

    try:
        # Compiled inside the try so an invalid caller-supplied regex is
        # reported through the error result instead of raising re.error.
        compiled_patterns = {name: re.compile(pattern) for name, pattern in use_patterns.items()}

        with open(file_path, 'r', encoding=_LOG_ENCODING, errors=_LOG_ERRORS) as f:
            for line_num, line in enumerate(f, 1):
                for name, pattern in compiled_patterns.items():
                    match = pattern.search(line)
                    if not match:
                        continue

                    # First hit for this pattern: create its entries.
                    if name not in counts:
                        matches[name] = []
                        counts[name] = 0
                        samples[name] = []

                    counts[name] += 1

                    sample_list = samples[name]
                    if len(sample_list) < _MAX_SAMPLES_PER_PATTERN:
                        sample_list.append({
                            "line_num": line_num,
                            "text": line.strip()[:_MAX_SAMPLE_CHARS],
                            "groups": match.groups() or None
                        })

        if ctx:
            await ctx.info(f"Analysis complete. Found matches for {len(counts)} patterns.")
            for name, count in counts.items():
                await ctx.info(f" - {name}: {count} matches")

        return results
    except Exception as e:
        if ctx:
            await ctx.error(f"Error analyzing log: {e}")
        return {
            "success": False,
            "error": str(e),
            "file_path": file_path
        }


# pytest summary candidates: "=== <n> <word>, <n> <word>, ..." (e.g.
# "=== 1 failed, 10 passed, 2 skipped in 3.2s ===" or "=== 10 passed in 1s ===").
_PYTEST_SUMMARY_RE = re.compile(r"=+ (\d+ \w+(?:, \d+ \w+)*)")
# Individual outcome counts inside a pytest summary; "\b" and the required
# space keep "xfailed"/"xpassed" from being counted as failed/passed.
_PYTEST_COUNT_RE = re.compile(r"(\d+) (failed|passed|skipped)\b")
_UNITTEST_SUMMARY_RE = re.compile(r"Ran ([\d]+) tests in ([\d\.]+)s")
# unittest final status line, e.g. "FAILED (failures=2, errors=1, skipped=3)".
_UNITTEST_FAILURE_RE = re.compile(r"FAILED \((.+)\)")
_TEST_FAILURE_RE = re.compile(r"FAIL: (test\w+)")
_ERROR_FAILURE_RE = re.compile(r"ERROR: (test\w+)")
# At most this many failed tests are reported, each with this many lines of
# context starting at the failure line itself.
_MAX_FAILED_TESTS = 20
_FAILURE_CONTEXT_LINES = 5


def _parse_pytest_summary(line: str) -> Optional[Dict[str, int]]:
    """Return failed/passed/skipped counts from a pytest summary line, or None.

    Outcomes missing from the summary are reported as 0. Candidates that
    contain none of the three outcomes (e.g. "== 5 warnings in 1s ==") are
    ignored.
    """
    for candidate in _PYTEST_SUMMARY_RE.finditer(line):
        found = {kind: int(num) for num, kind in _PYTEST_COUNT_RE.findall(candidate.group(1))}
        if found:
            return {kind: found.get(kind, 0) for kind in ("failed", "passed", "skipped")}
    return None


def _apply_unittest_failure_counts(details: str, test_counts: Dict[str, int]) -> None:
    """Update *test_counts* from the inside of a unittest "FAILED (...)" line.

    failed = failures + errors and skipped = skipped. When the total from a
    preceding "Ran N tests" line covers them, passed is the remainder.
    Details without a failures= or errors= field are ignored.
    """
    fields: Dict[str, int] = {}
    for part in details.split(","):
        key, sep, value = part.partition("=")
        value = value.strip()
        if sep and value.isdigit():
            fields[key.strip()] = int(value)

    if "failures" not in fields and "errors" not in fields:
        return

    failed = fields.get("failures", 0) + fields.get("errors", 0)
    skipped = fields.get("skipped", 0)
    test_counts["failed"] = failed
    test_counts["skipped"] = skipped
    if test_counts["total"] >= failed + skipped:
        test_counts["passed"] = test_counts["total"] - failed - skipped


async def extract_test_results(file_path: str, ctx: Optional[Context] = None) -> Dict[str, Any]:
    """Extract test results specifically from a log file.

    This tool specializes in finding test execution results from various
    testing frameworks (pytest, unittest) without loading the entire log
    into context.

    Args:
        file_path: Path to the log file
        ctx: MCP context

    Returns:
        Dictionary with test statistics and failures:
        - success: True
        - file_path: the analyzed path
        - test_counts: total/passed/failed/skipped. When several summaries
          appear in the log, later ones overwrite earlier ones
        - failed_tests: up to 20 entries with ``test_name``, ``line_num`` and
          ``context`` (the failure line plus the next 4 lines, stripped)
        - duration: seconds string from the last unittest "Ran N tests in Xs"
          line, else None
        On failure: ``success`` False plus ``error`` (and ``file_path`` for
        errors after the existence check).
    """
    if not os.path.exists(file_path):
        return {
            "success": False,
            "error": f"File not found: {file_path}"
        }

    if ctx:
        await ctx.info(f"Extracting test results from: {file_path}")

    test_counts: Dict[str, int] = {
        "total": 0,
        "passed": 0,
        "failed": 0,
        "skipped": 0
    }
    failed_tests: List[Dict[str, Any]] = []
    results: Dict[str, Any] = {
        "success": True,
        "file_path": file_path,
        "test_counts": test_counts,
        "failed_tests": failed_tests,
        "duration": None
    }

    try:
        # The whole file is read so failure context can look ahead.
        with open(file_path, 'r', encoding=_LOG_ENCODING, errors=_LOG_ERRORS) as f:
            lines = f.readlines()

        for line_num, raw_line in enumerate(lines, 1):
            line = raw_line.strip()

            pytest_counts = _parse_pytest_summary(line)
            if pytest_counts is not None:
                test_counts.update(pytest_counts)
                test_counts["total"] = (
                    test_counts["failed"] + test_counts["passed"] + test_counts["skipped"]
                )

            unittest_match = _UNITTEST_SUMMARY_RE.search(line)
            if unittest_match:
                test_counts["total"] = int(unittest_match.group(1))
                results["duration"] = unittest_match.group(2)

            unittest_failure = _UNITTEST_FAILURE_RE.search(line)
            if unittest_failure:
                _apply_unittest_failure_counts(unittest_failure.group(1), test_counts)

            for pattern in (_TEST_FAILURE_RE, _ERROR_FAILURE_RE):
                failure_match = pattern.search(line)
                if failure_match and len(failed_tests) < _MAX_FAILED_TESTS:
                    start = line_num - 1  # 0-based index of the failure line
                    failed_tests.append({
                        "test_name": failure_match.group(1),
                        "line_num": line_num,
                        "context": [ctx_line.strip()
                                    for ctx_line in lines[start:start + _FAILURE_CONTEXT_LINES]]
                    })

        if ctx:
            if test_counts["total"] > 0:
                await ctx.info(f"Found test results: {test_counts['total']} total tests")
                await ctx.info(f" - Passed: {test_counts['passed']}")
                await ctx.info(f" - Failed: {test_counts['failed']}")
                await ctx.info(f" - Skipped: {test_counts['skipped']}")
                if failed_tests:
                    await ctx.info(f" - Found {len(failed_tests)} failed test details")
            else:
                await ctx.info("No test results found in the log")

        return results
    except Exception as e:
        if ctx:
            await ctx.error(f"Error extracting test results: {e}")
        return {
            "success": False,
            "error": str(e),
            "file_path": file_path
        }


async def filter_log_sections(file_path: str, start_pattern: Optional[str] = None,
                              end_pattern: Optional[str] = None, max_lines: int = 100,
                              ctx: Optional[Context] = None) -> Dict[str, Any]:
    """Extract specific sections from a log file based on start/end patterns.

    This tool helps retrieve only relevant sections of large log files
    without loading the entire file into context.

    A section begins at a line matching ``start_pattern``. It ends at (and
    includes) the next line matching ``end_pattern``, or at end of file.
    Once a section holds ``max_lines`` lines, the next line closes it with a
    truncation note. That line is dropped and is not itself checked as the
    start of a new section. Start matches inside an open section are
    treated as ordinary content.

    Args:
        file_path: Path to the log file
        start_pattern: Regex pattern that marks the start of a section (required)
        end_pattern: Regex pattern that marks the end of a section
        max_lines: Maximum number of lines to return per section
        ctx: MCP context

    Returns:
        Dictionary with extracted sections: ``success``, ``file_path``,
        ``sections`` (each with ``start_line``, ``content`` and
        ``truncated``) and ``section_count``. On failure: ``success`` False
        plus ``error`` (and ``file_path`` for errors raised while processing).
    """
    if not os.path.exists(file_path):
        return {
            "success": False,
            "error": f"File not found: {file_path}"
        }

    if ctx:
        await ctx.info(f"Filtering sections from log file: {file_path}")

    if not start_pattern:
        return {
            "success": False,
            "error": "Start pattern is required"
        }

    try:
        start_re = re.compile(start_pattern)
        end_re = re.compile(end_pattern) if end_pattern else None

        sections: List[Dict[str, Any]] = []
        results: Dict[str, Any] = {
            "success": True,
            "file_path": file_path,
            "sections": sections,
            "section_count": 0
        }

        def close_section(content: List[str], start_line: int, truncated: bool) -> None:
            # Record a finished section and keep the count in sync.
            sections.append({
                "start_line": start_line,
                "content": "\n".join(content),
                "truncated": truncated
            })
            results["section_count"] = len(sections)

        in_section = False
        current_section: List[str] = []
        current_start_line = 0

        with open(file_path, 'r', encoding=_LOG_ENCODING, errors=_LOG_ERRORS) as f:
            for line_num, line in enumerate(f, 1):
                if not in_section:
                    if start_re.search(line):
                        in_section = True
                        current_section = [line.rstrip()]
                        current_start_line = line_num
                    continue

                if len(current_section) >= max_lines:
                    current_section.append(f"... [truncated after {max_lines} lines] ...")
                    close_section(current_section, current_start_line, True)
                    in_section = False
                    current_section = []
                    continue

                current_section.append(line.rstrip())

                # The end line is kept as the last line of its section.
                if end_re and end_re.search(line):
                    close_section(current_section, current_start_line, False)
                    in_section = False
                    current_section = []

        # A section still open at end of file is reported as-is.
        if in_section and current_section:
            close_section(current_section, current_start_line, False)

        if ctx:
            await ctx.info(f"Found {results['section_count']} matching sections")

        return results
    except Exception as e:
        if ctx:
            await ctx.error(f"Error filtering log sections: {e}")
        return {
            "success": False,
            "error": str(e),
            "file_path": file_path
        }