JVM MCP Server

  • src
  • jvm_mcp_server
"""Arthas客户端实现""" import subprocess import os import time import telnetlib import socket import paramiko import logging import re from typing import Optional, Dict, Union, Any from .config import ArthasConfig # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class ArthasClient: """Arthas客户端封装类""" _connection_pool = None # 类级别的连接池 _config = None # 类级别的配置对象 @classmethod def get_connection_pool(cls): """获取连接池实例""" if cls._connection_pool is None: from .connection_pool import ArthasConnectionPool # 延迟导入 cls._connection_pool = ArthasConnectionPool() return cls._connection_pool @classmethod def get_config(cls) -> ArthasConfig: """获取配置实例""" if cls._config is None: config_file = os.path.join(os.path.dirname(__file__), '../../config/arthas.json') cls._config = ArthasConfig.load(config_file) return cls._config def __init__(self, telnet_port: int = 3658, ssh_host: str = None, ssh_port: int = 22, ssh_password: str = None): """ 初始化Arthas客户端 Args: telnet_port: Arthas telnet端口 ssh_host: SSH连接地址,格式为 user@host,为None时表示本地连接 ssh_port: SSH端口,默认22 ssh_password: SSH密码,为None时表示使用密钥认证 """ self.arthas_boot_path = "arthas-boot.jar" self.telnet_port = telnet_port self.telnet = None self.ssh = None self.attached_pid = None self.arthas_started = False # 新增:标记Arthas是否已启动 self.local_port = None # 新增:保存本地转发端口 # SSH连接信息 self.ssh_host = ssh_host if ssh_host and '@' in ssh_host: self.ssh_user, self.ssh_host = ssh_host.split('@') logger.info(f"SSH连接信息: 用户={self.ssh_user}, 主机={self.ssh_host}, 端口={ssh_port}") else: self.ssh_user = None self.ssh_host = None logger.info("使用本地连接模式") self.ssh_port = ssh_port self.ssh_password = ssh_password # 如果是远程连接,建立SSH连接 if self.ssh_host: try: self._setup_ssh_connection() except Exception as e: logger.error(f"SSH连接失败: {e}") raise else: self._download_arthas() def _setup_ssh_connection(self): """建立SSH连接""" try: logger.info(f"正在连接到远程服务器: {self.ssh_host}:{self.ssh_port}") self.ssh = paramiko.SSHClient() self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if self.ssh_password: logger.info("使用密码认证") self.ssh.connect(self.ssh_host, self.ssh_port, self.ssh_user, self.ssh_password) else: logger.info("使用密钥认证") self.ssh.connect(self.ssh_host, self.ssh_port, self.ssh_user) logger.info("SSH连接成功") # 在远程服务器上下载arthas self._download_arthas_remote() except paramiko.AuthenticationException: logger.error("SSH认证失败,请检查用户名和密码") if self.ssh: self.ssh.close() self.ssh = None raise except paramiko.SSHException as e: logger.error(f"SSH连接出错: {e}") if self.ssh: self.ssh.close() self.ssh = None raise except Exception as e: logger.error(f"连接远程服务器失败: {e}") if self.ssh: self.ssh.close() self.ssh = None raise def _download_arthas_remote(self): """在远程服务器上下载Arthas""" logger.info("正在远程服务器上下载Arthas...") # 首先检查文件是否已存在 check_cmd = "[ -f arthas-boot.jar ] && echo 'exists'" stdin, stdout, stderr = self.ssh.exec_command(check_cmd) if stdout.read().decode('utf-8').strip() == 'exists': logger.info("远程服务器上已存在arthas-boot.jar") return # 下载文件 cmd = "curl -s -o arthas-boot.jar https://arthas.aliyun.com/arthas-boot.jar" stdin, stdout, stderr = self.ssh.exec_command(cmd) # 验证文件是否下载成功 verify_cmd = "[ -f arthas-boot.jar ] && [ -s arthas-boot.jar ] && echo 'success'" stdin, stdout, stderr = self.ssh.exec_command(verify_cmd) if stdout.read().decode('utf-8').strip() != 'success': error_msg = "下载的arthas-boot.jar文件不存在或大小为0" logger.error(error_msg) raise Exception(error_msg) logger.info("Arthas下载成功") def _download_arthas(self): """在本地下载Arthas启动器""" if not os.path.exists(self.arthas_boot_path): logger.info("正在本地下载Arthas...") try: subprocess.run( ["curl", "-o", self.arthas_boot_path, "https://arthas.aliyun.com/arthas-boot.jar"], check=True ) logger.info("Arthas下载成功") except subprocess.CalledProcessError as e: logger.error(f"下载Arthas失败: {e}") raise def _attach_to_process(self, pid: int): """连接到指定的Java进程""" if self.attached_pid == pid and hasattr(self, 'arthas_channel') and self.arthas_channel: logger.debug(f"已经连接到进程 {pid}") return # 如果已经连接到其他进程,先断开 self._disconnect() logger.info(f"正在连接到Java进程 {pid}") if self.ssh_host: # 确保SSH连接有效 self._ensure_ssh_connection() try: # 检查Java进程是否存在 check_pid_cmd = f"ps -p {pid} > /dev/null 2>&1 && echo 'exists'" stdin, stdout, stderr = self.ssh.exec_command(check_pid_cmd) if stdout.read().decode('utf-8').strip() != 'exists': error_msg = f"进程 {pid} 不存在" logger.error(error_msg) raise Exception(error_msg) # 启动Arthas并保持会话 cmd = f"java -jar arthas-boot.jar --telnet-port {self.telnet_port} --http-port -1 {pid}" logger.debug(f"执行远程命令: {cmd}") # 使用get_pty=True来模拟终端,并保持会话 self.arthas_channel = self.ssh.get_transport().open_session() self.arthas_channel.get_pty() self.arthas_channel.exec_command(cmd) # 等待Arthas启动,同时检查输出中是否有错误信息 start_time = time.time() success = False error_msg = None buffer = "" while time.time() - start_time < 30: # 最多等待30秒 if self.arthas_channel.recv_ready(): output = self.arthas_channel.recv(1024).decode('utf-8') buffer += output logger.debug(f"Arthas输出: {output}") if "Can not attach to target process" in buffer: error_msg = "无法附加到目标进程,可能是权限问题" break elif "ERROR" in buffer: error_msg = f"启动Arthas时发生错误: {buffer}" break elif "as.sh" in buffer or "$" in buffer: # Arthas的命令提示符 success = True break time.sleep(0.1) if not success: if error_msg is None: error_msg = "启动Arthas超时" logger.error(error_msg) self._disconnect() raise Exception(error_msg) # 等待一段时间确保Arthas完全启动 time.sleep(2) logger.info("Arthas启动成功") self.attached_pid = pid except Exception as e: logger.error(f"连接过程中发生错误: {e}") self._disconnect() raise else: # 本地启动Arthas logger.info("在本地启动Arthas") try: # 使用subprocess.Popen启动Arthas cmd = [ "java", "-jar", self.arthas_boot_path, "--target-ip", "127.0.0.1", "--telnet-port", str(self.telnet_port), "--http-port", "-1", str(pid) ] logger.debug(f"执行本地命令: {' '.join(cmd)}") # 使用subprocess.PIPE来捕获输出 process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, # 使用文本模式 bufsize=1 # 行缓冲 ) # 等待Arthas启动并检查输出 start_time = time.time() success = False error_msg = None while time.time() - start_time < 30: # 最多等待30秒 # 检查进程是否还在运行 if process.poll() is not None: error_msg = f"Arthas进程意外退出,返回码: {process.returncode}" break # 尝试建立telnet连接 try: logger.debug(f"尝试连接到本地端口 {self.telnet_port}") self.telnet = telnetlib.Telnet("127.0.0.1", self.telnet_port, timeout=2) # 等待提示符确认连接成功 response = self.telnet.read_until(b"$", timeout=2).decode('utf-8') if "arthas" in response.lower(): logger.info(f"成功连接到进程 {pid}") success = True self.attached_pid = pid # 保存进程引用以便后续管理 self.arthas_process = process break else: self.telnet.close() self.telnet = None except (socket.error, EOFError, socket.timeout): # 连接失败,继续等待 pass # 检查是否有错误输出 stderr_data = process.stderr.readline() if stderr_data: error_msg = f"Arthas启动错误: {stderr_data.strip()}" break time.sleep(1) if not success: # 如果没有成功,确保清理资源 if process.poll() is None: process.terminate() process.wait(timeout=5) if error_msg is None: error_msg = "启动Arthas超时" logger.error(error_msg) self._disconnect() raise Exception(error_msg) except Exception as e: logger.error(f"启动Arthas失败: {e}") self._disconnect() raise def _find_free_port(self) -> int: """查找可用的本地端口""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(('', 0)) s.listen(1) port = s.getsockname()[1] return port def _check_connection(self) -> bool: """检查telnet连接是否有效""" if not self.telnet: return False try: self.telnet.write(b"\n") self.telnet.read_until(b"$", timeout=1) return True except (socket.error, EOFError): return False def _disconnect(self): """断开与Arthas的连接""" if hasattr(self, 'arthas_channel') and self.arthas_channel: try: # 发送quit命令给Arthas self.arthas_channel.send('quit\n'.encode('utf-8')) time.sleep(1) # 等待命令执行 self.arthas_channel.close() except: pass finally: self.arthas_channel = None self.attached_pid = None if self.telnet: try: self.telnet.write(b"quit\n") # 发送quit命令给Arthas time.sleep(1) # 等待命令执行 self.telnet.close() except: pass finally: self.telnet = None self.attached_pid = None # 清理本地Arthas进程 if hasattr(self, 'arthas_process') and self.arthas_process: try: if self.arthas_process.poll() is None: # 如果进程还在运行 self.arthas_process.terminate() # 先尝试正常终止 try: self.arthas_process.wait(timeout=5) # 等待最多5秒 except subprocess.TimeoutExpired: self.arthas_process.kill() # 如果等待超时,强制终止 except: pass finally: self.arthas_process = None if self.ssh: try: self.ssh.close() logger.debug("已关闭SSH连接") except Exception as e: logger.warning(f"关闭SSH连接时出错: {e}") finally: self.ssh = None def _check_ssh_connection(self) -> bool: """检查SSH连接是否有效""" if not self.ssh: return False try: self.ssh.exec_command('echo 1') return True except: return False def _ensure_ssh_connection(self): """确保SSH连接有效,如果断开则重连""" if not self._check_ssh_connection(): logger.info("SSH连接已断开,尝试重新连接") self._setup_ssh_connection() def _execute_command(self, pid: int, command: str) -> str: """执行Arthas命令""" try: # 从连接池获取连接 logger.info(f"从连接池获取连接 pid={pid}, command={command}") conn = self.get_connection_pool().get_connection(pid) try: # 执行命令 result = conn.client._execute_command_internal(command) if isinstance(result, dict) and "raw_output" in result: return result["raw_output"] return result finally: # 归还连接 self.get_connection_pool().return_connection(conn) except Exception as e: logger.error(f"执行命令时发生错误: {e}") raise def _execute_command_internal(self, command: str) -> str: """执行Arthas命令并返回结果""" config = self.get_config() cmd_config = config.get_command_config(command.split()[0]) # 获取命令的配置 logger.info(f"开始执行命令: {command}") max_retries = cmd_config.max_retries if cmd_config else 3 retry_interval = cmd_config.retry_interval if cmd_config else 1 timeout = cmd_config.timeout if cmd_config else 10 max_output_size = 50000 # 设置最大输出大小为50KB for retry in range(max_retries): try: if self.ssh_host: logger.debug("使用SSH模式执行命令") if not hasattr(self, 'arthas_channel') or not self.arthas_channel: raise Exception("Arthas会话未建立") # 清空之前的输出 while self.arthas_channel.recv_ready(): self.arthas_channel.recv(1024) logger.debug(f"发送命令: {command}") self.arthas_channel.send(command + "\n") # 等待并收集输出 output = "" start_time = time.time() output_size = 0 truncated = False while time.time() - start_time < timeout: if self.arthas_channel.recv_ready(): chunk = self.arthas_channel.recv(4096).decode('utf-8') logger.debug(f"接收到数据块: {len(chunk)} 字节") chunk_size = len(chunk.encode('utf-8')) if output_size + chunk_size > max_output_size: logger.warning(f"输出超过大小限制 ({max_output_size} 字节),进行截断") remaining = max_output_size - output_size if remaining > 0: output += chunk[:remaining] truncated = True break output += chunk output_size += chunk_size if "$" in chunk: # 命令提示符表示命令执行完成 logger.debug("检测到命令提示符,命令执行完成") break time.sleep(0.1) if time.time() - start_time >= timeout: raise TimeoutError(f"命令执行超时: {command}") # 移除命令回显和提示符 lines = output.split("\n") lines = [line for line in lines if line and not line.startswith(command) and "$" not in line] result = "\n".join(lines) if truncated: result += "\n... (输出已截断,超过50KB)" logger.info(f"命令执行成功,输出大小: {len(result)} 字节") return result else: logger.debug("使用本地模式执行命令") if not hasattr(self, 'telnet') or not self.telnet: self.telnet = telnetlib.Telnet('127.0.0.1', self.telnet_port, timeout=timeout) # 清空之前的输出 self.telnet.read_very_eager() logger.debug(f"发送命令: {command}") self.telnet.write(command.encode() + b"\n") # 等待并收集输出 output = "" start_time = time.time() output_size = 0 truncated = False while time.time() - start_time < timeout: try: chunk = self.telnet.read_eager().decode('utf-8') if chunk: logger.debug(f"接收到数据块: {len(chunk)} 字节") chunk_size = len(chunk.encode('utf-8')) if output_size + chunk_size > max_output_size: logger.warning(f"输出超过大小限制 ({max_output_size} 字节),进行截断") remaining = max_output_size - output_size if remaining > 0: output += chunk[:remaining] truncated = True break output += chunk output_size += chunk_size if "$" in chunk: # 命令提示符表示命令执行完成 logger.debug("检测到命令提示符,命令执行完成") break else: time.sleep(0.1) except EOFError: logger.error("连接已关闭") break if time.time() - start_time >= timeout: raise TimeoutError(f"命令执行超时: {command}") # 移除命令回显和提示符 lines = output.split("\n") lines = [line for line in lines if line and not line.startswith(command) and "$" not in line] result = "\n".join(lines) if truncated: result += "\n... (输出已截断,超过50KB)" logger.info(f"命令执行成功,输出大小: {len(result)} 字节") return result except (TimeoutError, socket.timeout) as e: logger.warning(f"命令执行超时 (重试 {retry + 1}/{max_retries}): {e}") if retry < max_retries - 1: time.sleep(retry_interval) continue raise except Exception as e: logger.error(f"命令执行失败: {str(e)}") if retry < max_retries - 1: time.sleep(retry_interval) continue raise raise Exception(f"命令执行失败,已重试{max_retries}次: {command}") def __del__(self): """析构函数,确保断开连接""" self._disconnect() def _format_thread_info(self, output: str) -> str: """格式化线程信息输出 Args: output: 原始输出字符串 Returns: 格式化后的输出字符串 """ try: # 移除ANSI转义序列 output = re.sub(r'\x1b\[[0-9;]*[a-zA-Z]', '', output) # 移除空行和命令提示符 lines = [line.strip() for line in output.split('\n') if line.strip() and not line.strip().endswith('$')] # 如果输出为空,返回原始输出 if not lines: return output return '\n'.join(lines) except Exception as e: logger.error(f"格式化线程信息失败: {str(e)}") return output # 如果格式化失败,返回原始输出 def get_thread_info(self, pid: int) -> Dict[str, Any]: """获取指定进程的线程信息 Args: pid: 进程ID Returns: 包含线程信息的字典 """ try: output = self._execute_command(pid, "thread -n 20") formatted_output = self._format_thread_info(output) return { "raw_output": formatted_output, "timestamp": time.time() } except Exception as e: logger.error(f"获取线程信息失败: {str(e)}") raise def get_jvm_info(self, pid: int) -> str: """获取JVM信息""" return self._execute_command(pid, "jvm") def get_memory_info(self, pid: int) -> str: """获取内存信息""" return self._execute_command(pid, "memory") def get_stack_trace( self, pid: int, thread_id: Optional[int] = None, top_n: Optional[int] = None, find_blocking: bool = False, interval: Optional[int] = None, show_all: bool = False ) -> Dict[str, Any]: """获取线程堆栈信息 Args: pid: 进程ID thread_id: 线程ID top_n: 显示最忙的前N个线程 find_blocking: 是否查找阻塞线程 interval: CPU使用率统计的采样间隔(毫秒) show_all: 是否显示所有线程 Returns: 包含堆栈信息的字典 """ try: cmd = ["thread"] if thread_id is not None: cmd.append(str(thread_id)) elif top_n is not None: cmd.extend(["-n", str(top_n)]) elif show_all: cmd.append("--all") else: cmd.extend(["-n", "20"]) # 默认显示前20个线程 if find_blocking: cmd.append("-b") if interval is not None: cmd.extend(["-i", str(interval)]) output = self._execute_command(pid, " ".join(cmd)) formatted_output = self._format_thread_info(output) return { "raw_output": formatted_output, "timestamp": time.time() } except Exception as e: logger.error(f"获取堆栈信息失败: {str(e)}") raise def get_class_info(self, pid: int, class_pattern: str, show_detail: bool = False, show_field: bool = False, use_regex: bool = False, depth: int = None, classloader_hash: str = None, classloader_class: str = None, max_matches: int = None) -> str: """获取类信息 Args: pid: 进程ID class_pattern: 类名表达式匹配 show_detail: 是否显示详细信息 show_field: 是否显示成员变量信息(需要show_detail=True) use_regex: 是否使用正则表达式匹配 depth: 指定输出静态变量时属性的遍历深度 classloader_hash: 指定class的ClassLoader的hashcode classloader_class: 指定执行表达式的ClassLoader的class name max_matches: 具有详细信息的匹配类的最大数量 """ command = f"sc" # 添加参数 if show_detail: command += " -d" if show_field and show_detail: # show_field需要配合-d使用 command += " -f" if use_regex: command += " -E" if depth is not None: command += f" -x {depth}" if classloader_hash: command += f" -c {classloader_hash}" if classloader_class: command += f" --classLoaderClass {classloader_class}" if max_matches is not None: command += f" -n {max_matches}" # 添加类名匹配模式 command += f" {class_pattern}" return self._execute_command(pid, command) def list_java_processes(self) -> str: """列出Java进程""" if self.ssh_host: # 确保SSH连接有效 self._ensure_ssh_connection() stdin, stdout, stderr = self.ssh.exec_command("jps -l -v") return stdout.read().decode('utf-8') else: result = subprocess.run(["jps", "-l", "-v"], capture_output=True, text=True) return result.stdout def get_version(self, pid: int) -> str: """获取Arthas版本信息""" return self._execute_command(pid, "version") def get_stack_trace_by_method(self, pid: int, class_pattern: str, method_pattern: str, condition: str = None, use_regex: bool = False, max_matches: int = None, max_times: int = None) -> str: """获取方法的调用路径 Args: pid: 进程ID class_pattern: 类名表达式匹配 method_pattern: 方法名表达式匹配 condition: 条件表达式,例如:'params[0]<0' 或 '#cost>10' use_regex: 是否开启正则表达式匹配,默认为通配符匹配 max_matches: 指定Class最大匹配数量,默认值为50 max_times: 执行次数限制 """ command = f"stack {class_pattern} {method_pattern}" # 添加参数 if condition: command += f" '{condition}'" if use_regex: command += " -E" if max_matches is not None: command += f" -m {max_matches}" if max_times is not None: command += f" -n {max_times}" return self._execute_command(pid, command) def decompile_class(self, pid: int, class_pattern: str, method_pattern: str = None) -> str: """反编译指定类的源码 Args: pid: 进程ID class_pattern: 类名表达式 method_pattern: 可选的方法名,如果指定则只反编译特定方法 """ command = f"jad {class_pattern}" if method_pattern: command += f" {method_pattern}" return self._execute_command(pid, command) def search_method(self, pid: int, class_pattern: str, method_pattern: str = None, show_detail: bool = False, use_regex: bool = False, classloader_hash: str = None, classloader_class: str = None, max_matches: int = None) -> str: """查看类的方法信息 Args: pid: 进程ID class_pattern: 类名表达式匹配 method_pattern: 可选的方法名表达式 show_detail: 是否展示每个方法的详细信息 use_regex: 是否开启正则表达式匹配,默认为通配符匹配 classloader_hash: 指定class的ClassLoader的hashcode classloader_class: 指定执行表达式的ClassLoader的class name max_matches: 具有详细信息的匹配类的最大数量(默认为100) """ command = f"sm" # 添加参数 if show_detail: command += " -d" if use_regex: command += " -E" if classloader_hash: command += f" -c {classloader_hash}" if classloader_class: command += f" --classLoaderClass {classloader_class}" if max_matches is not None: command += f" -n {max_matches}" # 添加类名和方法名匹配模式 command += f" {class_pattern}" if method_pattern: command += f" {method_pattern}" return self._execute_command(pid, command) def watch_method(self, pid: int, class_pattern: str, method_pattern: str, watch_params: bool = True, watch_return: bool = True, condition: str = None, max_times: int = 10) -> str: """监控方法的调用情况 Args: pid: 进程ID class_pattern: 类名表达式 method_pattern: 方法名表达式 watch_params: 是否监控参数 watch_return: 是否监控返回值 condition: 条件表达式 max_times: 最大监控次数 """ command = f"watch {class_pattern} {method_pattern}" if watch_params: command += " params" if watch_return: command += " returnObj" if condition: command += f" '{condition}'" command += f" -n {max_times}" return self._execute_command(pid, command) def get_logger_info(self, pid: int, name: str = None) -> str: """获取logger信息 Args: pid: 进程ID name: logger名称 """ command = "logger" if name: command += f" --name {name}" return self._execute_command(pid, command) def set_logger_level(self, pid: int, name: str, level: str) -> str: """设置logger级别 Args: pid: 进程ID name: logger名称 level: 日志级别(trace, debug, info, warn, error) """ return self._execute_command(pid, f"logger --name {name} --level {level}") def get_dashboard(self, pid: int) -> str: """获取系统实时数据面板""" return self._execute_command(pid, "dashboard")