Skip to main content
Glama

NetBrain MCP

by NorthLaneMS
server.py60.1 kB
from mcp.server.fastmcp import FastMCP from typing import Dict, List, Optional, Any import asyncio import logging import sys import os import glob import json # 导入我们创建的模块 from tool_manager import tool_manager, ToolCategory from network_devices import ( device_manager, NetworkDevice, DeviceCredential, DeviceVendor, DeviceType, DeviceStatus, ConnectionProtocol ) from device_connector import connection_manager, CommandResult, SCRAPLI_IMPORT_SUCCESS # 导入资源管理模块 from mcp_resources import ( resource_manager, ResourceType ) # 导入新的模板系统 from template_system import ( template_manager, render_template_with_resources, Message, UserMessage, AssistantMessage, SystemMessage ) # 导入设备命令提示模板系统 import device_prompts # 导入拓扑发现模块 from topology_discovery_improved import create_improved_topology_discovery # 导入网络扫描模块 from network_scanner import create_network_scanner # 引入相同的日志格式化处理 class JsonFormatter(logging.Formatter): def __init__(self, fmt=None, datefmt=None, style='%'): super().__init__(fmt, datefmt, style) def format(self, record): log_record = super().format(record) return log_record.encode('utf-8', errors='replace').decode('utf-8') # 配置日志格式 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(stream=sys.stdout) ] ) # 设置所有处理器使用UTF-8编码格式化 for handler in logging.root.handlers: if isinstance(handler, logging.StreamHandler): handler.setFormatter(JsonFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) logger = logging.getLogger("netbrain_mcp") # 创建一个MCP服务器 mcp = FastMCP("NetBrain MCP") # 创建拓扑发现引擎实例 topology_discovery = create_improved_topology_discovery(device_manager, connection_manager) # 创建网络扫描器实例 network_scanner = create_network_scanner(device_manager) # 工具分类:设备管理 @mcp.tool() async def list_devices( vendor: Optional[str] = None, device_type: Optional[str] = None, status: Optional[str] = None, tag: Optional[str] = None ) -> List[Dict[str, Any]]: """ 列出网络设备,支持过滤 Args: vendor: 设备厂商,可选值:cisco, huawei, h3c, juniper, arista, fortinet, checkpoint, other device_type: 设备类型,可选值:router, switch, firewall, load_balancer, wireless_controller, access_point, other status: 设备状态,可选值:online, offline, unreachable, maintenance, unknown tag: 标签过滤 Returns: 设备列表 """ # 转换过滤参数 vendor_enum = DeviceVendor(vendor) if vendor else None device_type_enum = DeviceType(device_type) if device_type else None status_enum = DeviceStatus(status) if status else None # 获取设备列表 devices = device_manager.list_devices( vendor=vendor_enum, device_type=device_type_enum, status=status_enum, tag=tag ) # 转换为字典列表 return [device.to_dict() for device in devices] @mcp.tool() async def add_device( name: str, ip_address: str, device_type: str, vendor: str, platform: str = "", model: str = "", os_version: str = "", location: str = "", description: str = "", tags: str = "" ) -> Dict[str, Any]: """ 添加新的网络设备 Args: name: 设备名称 ip_address: 设备IP地址 device_type: 设备类型,可选值:router, switch, firewall, load_balancer, wireless_controller, access_point, other vendor: 设备厂商,可选值:cisco, huawei, h3c, juniper, arista, fortinet, checkpoint, other platform: 设备平台(scrapli平台类型),如 cisco_iosxe, huawei_vrp, juniper_junos 等 model: 设备型号 os_version: 操作系统版本 location: 设备位置 description: 设备描述 tags: 设备标签,使用逗号分隔多个标签 Returns: 新添加的设备信息 """ # 转换参数 device_type_enum = DeviceType(device_type) vendor_enum = DeviceVendor(vendor) # 如果未提供platform,根据vendor推断 if not platform: if vendor.lower() == "cisco": platform = "cisco_iosxe" elif vendor.lower() == "huawei": platform = "huawei_vrp" elif vendor.lower() == "juniper": platform = "juniper_junos" elif vendor.lower() == "arista": platform = "arista_eos" else: platform = f"{vendor.lower()}" # 处理标签 tag_list = [] if tags: tag_list = [tag.strip() for tag in tags.split(',') if tag.strip()] # 创建设备对象 device = NetworkDevice( name=name, ip_address=ip_address, device_type=device_type_enum, vendor=vendor_enum, platform=platform, model=model, os_version=os_version, location=location, description=description, tags=tag_list ) # 添加设备 device_id = device_manager.add_device(device) return device.to_dict() @mcp.tool() async def get_device(device_id: str) -> Optional[Dict[str, Any]]: """ 获取设备详细信息 Args: device_id: 设备ID Returns: 设备信息或None """ device = device_manager.get_device(device_id) if device: return device.to_dict() return None @mcp.tool() async def update_device( device_id: str, name: str = "", ip_address: str = "", model: str = "", os_version: str = "", status: str = "", location: str = "", description: str = "", tags: str = "" ) -> Optional[Dict[str, Any]]: """ 更新设备信息 Args: device_id: 设备ID name: 设备名称 ip_address: 设备IP地址 model: 设备型号 os_version: 操作系统版本 status: 设备状态,可选值:online, offline, unreachable, maintenance, unknown location: 设备位置 description: 设备描述 tags: 设备标签,使用逗号分隔多个标签 Returns: 更新后的设备信息或None """ # 准备更新参数 update_kwargs = {} if name: update_kwargs["name"] = name if ip_address: update_kwargs["ip_address"] = ip_address if model: update_kwargs["model"] = model if os_version: update_kwargs["os_version"] = os_version if status: update_kwargs["status"] = DeviceStatus(status) if location: update_kwargs["location"] = location if description: update_kwargs["description"] = description if tags: # 将逗号分隔的标签转换为列表 tag_list = [tag.strip() for tag in tags.split(',') if tag.strip()] update_kwargs["tags"] = tag_list # 更新设备 device = device_manager.update_device(device_id, **update_kwargs) if device: return device.to_dict() return None @mcp.tool() async def delete_device(device_id: str) -> bool: """ 删除设备 Args: device_id: 设备ID Returns: 是否删除成功 """ return device_manager.delete_device(device_id) # 工具分类:凭据管理 @mcp.tool() async def add_credential( name: str, username: str, password: str = "", protocol: str = "ssh", port: Optional[int] = None, enable_password: str = "", ssh_key_file: str = "" ) -> Dict[str, Any]: """ 添加设备凭据 Args: name: 凭据名称 username: 用户名 password: 密码 protocol: 连接协议,可选值:ssh, telnet, snmp, http, https, netconf port: 端口号,可选(默认SSH=22, Telnet=23) enable_password: 特权模式密码(思科设备) ssh_key_file: SSH密钥文件路径 Returns: 凭据ID """ # 确保密码或SSH密钥的要求基于协议 if protocol.lower() == "ssh" and not password and not ssh_key_file: return {"success": False, "message": "SSH连接需要提供密码或SSH密钥文件"} elif protocol.lower() != "ssh" and not password: return {"success": False, "message": f"{protocol.upper()}连接需要提供密码"} # 转换参数 protocol_enum = ConnectionProtocol(protocol.lower()) # 设置默认端口 if port is None: if protocol_enum == ConnectionProtocol.SSH: port = 22 elif protocol_enum == ConnectionProtocol.TELNET: port = 23 elif protocol_enum == ConnectionProtocol.HTTP: port = 80 elif protocol_enum == ConnectionProtocol.HTTPS: port = 443 else: port = 0 # 记录凭据信息(不包含敏感信息) logger.info(f"添加凭据: name={name}, username={username}, protocol={protocol}, port={port}") if ssh_key_file: logger.info(f"使用SSH密钥文件: {ssh_key_file}") # 创建凭据对象 credential = DeviceCredential( name=name, username=username, password=password, protocol=protocol_enum, port=port, enable_password=enable_password, ssh_key_file=ssh_key_file ) # 添加凭据 credential_id = device_manager.add_credential(credential) return {"success": True, "id": credential_id, "name": name} @mcp.tool() async def list_credentials() -> List[Dict[str, Any]]: """ 列出所有设备凭据 Returns: 凭据列表 """ credentials = device_manager.list_credentials() return [ { "id": cred.id, "name": cred.name, "username": cred.username, "password": cred.password, "protocol": cred.protocol.value, "port": cred.port, "ssh_key_file": cred.ssh_key_file } for cred in credentials ] # 工具分类:设备连接 @mcp.tool() async def connect_device(device_id: str, credential_id: str) -> Dict[str, Any]: """ 连接到网络设备 Args: device_id: 设备ID credential_id: 凭据ID Returns: 连接结果 """ # 获取设备和凭据 device = device_manager.get_device(device_id) credential = device_manager.get_credential(credential_id) if not device: return {"success": False, "message": f"设备不存在: {device_id}"} if not credential: return {"success": False, "message": f"凭据不存在: {credential_id}"} # 记录连接信息 logger.info(f"尝试连接设备: {device.name} ({device.ip_address}) 使用凭据: {credential.name}") logger.info(f"连接协议: {credential.protocol.value}, 端口: {credential.port or '默认'}") # 连接设备 success, error = await connection_manager.connect_device(device, credential) if success: # 获取更多的连接信息 prompt = "" version_info = "" try: # 获取连接的connector connection_key = f"{device_id}_{credential_id}" connector = connection_manager.active_connections.get(connection_key) if connector and connector.connection: # 获取提示符 if hasattr(connector.connection, "get_prompt"): if asyncio.iscoroutinefunction(connector.connection.get_prompt): prompt = await connector.connection.get_prompt() else: prompt = await asyncio.to_thread(connector.connection.get_prompt) logger.info(f"成功获取设备提示符: {prompt}") # 使用与test_scrapli_connection相同的方式获取版本信息 try: from scrapli import Scrapli # 确定平台类型 platform = device.platform if hasattr(device, 'platform') and device.platform else "" # 根据平台类型选择命令 if "huawei" in platform.lower(): command = "display version" elif "cisco" in platform.lower(): command = "show version" else: command = "show version" logger.info(f"尝试获取版本信息,平台: {platform},命令: {command}") # 直接使用Scrapli连接对象发送命令 if hasattr(connector.connection, "send_command"): # 等待1秒确保连接稳定 await asyncio.sleep(1) if asyncio.iscoroutinefunction(connector.connection.send_command): resp = await connector.connection.send_command(command) if hasattr(resp, "result"): version_info = resp.result else: version_info = str(resp) else: # 同步调用需要运行在线程中 resp = await asyncio.to_thread(connector.connection.send_command, command) if hasattr(resp, "result"): version_info = resp.result else: version_info = str(resp) logger.info(f"成功获取版本信息,长度: {len(version_info)}") # 限制长度 if len(version_info) > 500: version_info = version_info[:500] + "..." else: logger.warning("连接对象没有send_command方法") version_info = "无法获取版本信息:连接对象不支持发送命令" except Exception as e: logger.warning(f"获取版本信息时出错: {str(e)}") # 回退到使用connection_manager发送命令 try: result, _ = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=60 ) if result and result.success: version_info = result.output if len(version_info) > 500: version_info = version_info[:500] + "..." else: version_info = "无法获取版本信息" except Exception as e2: logger.error(f"备用方法获取版本信息也失败: {str(e2)}") version_info = "无法获取版本信息" except Exception as e: logger.warning(f"获取设备详细信息时出错: {str(e)}") version_info = "获取版本信息时出错" return { "success": True, "message": f"成功连接到设备: {device.name} ({device.ip_address})", "protocol": credential.protocol.value, "prompt": prompt, "version_info": version_info } else: logger.error(f"连接设备失败: {device.name} ({device.ip_address}), 错误: {error}") return {"success": False, "message": f"连接设备失败: {error}"} @mcp.tool() async def disconnect_device(device_id: str, credential_id: str) -> Dict[str, Any]: """ 断开与网络设备的连接 Args: device_id: 设备ID credential_id: 凭据ID Returns: 断开连接结果 """ # 断开连接 success, error = await connection_manager.disconnect_device(device_id, credential_id) if success: return {"success": True, "message": "成功断开连接"} else: return {"success": False, "message": f"断开连接失败: {error}"} @mcp.tool() async def send_command( device_id: str, credential_id: str, command: str, timeout: int = 30 ) -> Dict[str, Any]: """ 向网络设备发送命令 Args: device_id: 设备ID credential_id: 凭据ID command: 要执行的命令 timeout: 命令超时时间(秒) Returns: 命令执行结果 """ # 发送命令 result, error = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=timeout ) if result: return { "success": result.success, "command": result.command, "output": result.output, "error_message": result.error_message, "execution_time": result.execution_time.isoformat() } else: return { "success": False, "command": command, "output": "", "error_message": error, "execution_time": None } @mcp.tool() async def send_commands( device_id: str, credential_id: str, commands: str, timeout: int = 30 ) -> List[Dict[str, Any]]: """ 向网络设备发送多个命令 Args: device_id: 设备ID credential_id: 凭据ID commands: 要执行的命令列表,使用分号(;)分隔多个命令 timeout: 每个命令的超时时间(秒) Returns: 命令执行结果列表 """ # 将分号分隔的命令字符串转换为列表 command_list = [cmd.strip() for cmd in commands.split(';') if cmd.strip()] if not command_list: return [{"success": False, "message": "未提供有效的命令"}] # 发送命令 results, error = await connection_manager.send_commands( device_id=device_id, credential_id=credential_id, commands=command_list, timeout=timeout ) if results: return [ { "success": result.success, "command": result.command, "output": result.output, "error_message": result.error_message, "execution_time": result.execution_time.isoformat() } for result in results ] else: return [ { "success": False, "command": command, "output": "", "error_message": error, "execution_time": None } for command in command_list ] @mcp.tool() async def get_active_connections() -> List[Dict[str, Any]]: """ 获取活动连接列表 Returns: 活动连接信息列表 """ return connection_manager.get_active_connections() # 工具分类:拓扑发现 @mcp.tool() async def discover_topology(device_ids: str) -> Dict[str, Any]: """ 从指定设备开始发现网络拓扑 Args: device_ids: 设备ID列表,用逗号分隔,如"device1,device2" Returns: 拓扑发现结果 """ try: # 解析设备ID列表 if isinstance(device_ids, str): device_id_list = [device_id.strip() for device_id in device_ids.split(',') if device_id.strip()] else: device_id_list = device_ids if not device_id_list: return {"success": False, "message": "请提供至少一个设备ID"} logger.info(f"开始拓扑发现,设备列表: {device_id_list}") # 验证设备是否存在 valid_devices = [] for device_id in device_id_list: device = device_manager.get_device(device_id) if device: valid_devices.append(device_id) else: logger.warning(f"设备不存在: {device_id}") if not valid_devices: return {"success": False, "message": "没有找到有效的设备"} # 开始拓扑发现 topology = await topology_discovery.discover_topology_from_devices(valid_devices) return { "success": True, "message": f"拓扑发现完成,发现{len(topology.nodes)}个节点,{len(topology.links)}条链路", "topology": topology.to_dict() } except Exception as e: logger.error(f"拓扑发现失败: {e}") return {"success": False, "message": f"拓扑发现失败: {str(e)}"} @mcp.tool() async def get_topology() -> Dict[str, Any]: """ 获取当前网络拓扑 Returns: 当前拓扑信息 """ try: topology = topology_discovery.get_topology() return { "success": True, "topology": topology.to_dict() } except Exception as e: logger.error(f"获取拓扑失败: {e}") return {"success": False, "message": f"获取拓扑失败: {str(e)}"} @mcp.tool() async def clear_topology() -> Dict[str, Any]: """ 清空拓扑数据 Returns: 操作结果 """ try: topology_discovery.clear_topology() return {"success": True, "message": "拓扑数据已清空"} except Exception as e: logger.error(f"清空拓扑失败: {e}") return {"success": False, "message": f"清空拓扑失败: {str(e)}"} @mcp.tool() async def get_device_neighbors(device_id: str) -> Dict[str, Any]: """ 获取指定设备的邻居设备 Args: device_id: 设备ID Returns: 邻居设备列表 """ try: topology = topology_discovery.get_topology() neighbors = topology.get_device_neighbors(device_id) # 获取邻居设备的详细信息 neighbor_details = [] for neighbor_id in neighbors: device = device_manager.get_device(neighbor_id) if device: neighbor_details.append(device.to_dict()) return { "success": True, "device_id": device_id, "neighbors": neighbor_details } except Exception as e: logger.error(f"获取设备邻居失败: {e}") return {"success": False, "message": f"获取设备邻居失败: {str(e)}"} @mcp.tool() async def discover_device_neighbors(device_id: str) -> Dict[str, Any]: """ 发现单个设备的邻居(实时发现) Args: device_id: 设备ID Returns: 发现的邻居信息 """ try: device = device_manager.get_device(device_id) if not device: return {"success": False, "message": f"设备不存在: {device_id}"} logger.info(f"开始发现设备 {device.name} 的邻居") # 发现设备邻居 neighbors, interfaces = await topology_discovery.discover_device_neighbors(device) return { "success": True, "device_id": device_id, "device_name": device.name, "neighbors": neighbors, "interfaces": [iface.to_dict() for iface in interfaces], "neighbor_count": len(neighbors), "interface_count": len(interfaces) } except Exception as e: logger.error(f"发现设备邻居失败: {e}") return {"success": False, "message": f"发现设备邻居失败: {str(e)}"} @mcp.tool() async def get_topology_statistics() -> Dict[str, Any]: """ 获取拓扑统计信息 Returns: 拓扑统计数据 """ try: topology = topology_discovery.get_topology() # 统计不同协议的链路数量 protocol_stats = {} for link in topology.links: protocol = link.protocol.value protocol_stats[protocol] = protocol_stats.get(protocol, 0) + 1 # 统计不同厂商的设备数量 vendor_stats = {} for node in topology.nodes.values(): vendor = node.vendor vendor_stats[vendor] = vendor_stats.get(vendor, 0) + 1 return { "success": True, "total_nodes": len(topology.nodes), "total_links": len(topology.links), "protocol_distribution": protocol_stats, "vendor_distribution": vendor_stats, "last_discovery": topology.last_discovery.isoformat() if topology.last_discovery else None, "discovery_scope": topology.discovery_scope } except Exception as e: logger.error(f"获取拓扑统计失败: {e}") return {"success": False, "message": f"获取拓扑统计失败: {str(e)}"} # 工具分类:网络扫描 @mcp.tool() async def scan_network_range( network: str, timeout: float = 3.0, max_concurrent: int = 50, ping_enabled: bool = True, port_scan_enabled: bool = True, snmp_enabled: bool = True, auto_create_devices: bool = False ) -> Dict[str, Any]: """ 扫描指定网络范围内的活跃设备 Args: network: 网络范围,如 "192.168.1.0/24" 或 "10.0.0.0/16" timeout: 扫描超时时间(秒) max_concurrent: 最大并发扫描数 ping_enabled: 是否启用ping扫描 port_scan_enabled: 是否启用端口扫描 snmp_enabled: 是否启用SNMP扫描 auto_create_devices: 是否自动创建发现的设备 Returns: 扫描结果 """ try: from network_scanner import ScanConfiguration # 创建扫描配置 config = ScanConfiguration( timeout=timeout, max_concurrent=max_concurrent, ping_enabled=ping_enabled, port_scan_enabled=port_scan_enabled, snmp_enabled=snmp_enabled ) logger.info(f"开始网络扫描: {network}") # 执行网络扫描 scan_results = await network_scanner.scan_network_range(network, config) # 如果启用自动创建设备,则创建发现的设备 discovered_devices = [] if auto_create_devices: discovered_devices = await network_scanner.discover_devices_from_scan(scan_results, auto_create=True) return { "success": True, "message": f"网络扫描完成,发现 {len(scan_results)} 个活跃主机", "network_range": network, "alive_hosts": len(scan_results), "scan_results": [result.to_dict() for result in scan_results], "auto_created_devices": len(discovered_devices), "discovered_devices": [device.to_dict() for device in discovered_devices] } except Exception as e: logger.error(f"网络扫描失败: {e}") return {"success": False, "message": f"网络扫描失败: {str(e)}"} @mcp.tool() async def get_scan_results( ip_address: Optional[str] = None, alive_only: bool = True ) -> Dict[str, Any]: """ 获取网络扫描结果 Args: ip_address: 特定IP地址,如果不提供则返回所有结果 alive_only: 是否只返回活跃主机 Returns: 扫描结果 """ try: if ip_address: # 获取特定IP的扫描结果 result = network_scanner.scan_results.get(ip_address) if result: return { "success": True, "scan_result": result.to_dict() } else: return {"success": False, "message": f"未找到IP {ip_address} 的扫描结果"} else: # 获取所有扫描结果 results = list(network_scanner.scan_results.values()) if alive_only: results = [r for r in results if r.is_alive] return { "success": True, "total_results": len(results), "scan_results": [result.to_dict() for result in results] } except Exception as e: logger.error(f"获取扫描结果失败: {e}") return {"success": False, "message": f"获取扫描结果失败: {str(e)}"} @mcp.tool() async def get_scan_statistics() -> Dict[str, Any]: """ 获取网络扫描统计信息 Returns: 扫描统计数据 """ try: stats = network_scanner.get_scan_statistics() return { "success": True, "statistics": stats } except Exception as e: logger.error(f"获取扫描统计失败: {e}") return {"success": False, "message": f"获取扫描统计失败: {str(e)}"} @mcp.tool() async def discover_devices_from_scan_results( min_response_time: Optional[float] = None, required_ports: str = "", vendor_filter: Optional[str] = None, device_type_filter: Optional[str] = None ) -> Dict[str, Any]: """ 从扫描结果中发现并创建设备 Args: min_response_time: 最小响应时间过滤(毫秒) required_ports: 必需的开放端口,使用逗号分隔,如"22,23,161" vendor_filter: 厂商过滤,如"cisco"、"huawei" device_type_filter: 设备类型过滤,如"switch"、"router" Returns: 发现和创建的设备信息 """ try: # 获取所有活跃的扫描结果 all_results = [r for r in network_scanner.scan_results.values() if r.is_alive] # 应用过滤条件 filtered_results = [] for result in all_results: # 响应时间过滤 if min_response_time and result.response_time and result.response_time > min_response_time: continue # 端口过滤 if required_ports: required_port_list = [int(p.strip()) for p in required_ports.split(',') if p.strip().isdigit()] if not all(port in result.open_ports for port in required_port_list): continue # 厂商过滤 if vendor_filter and result.vendor: if vendor_filter.lower() not in result.vendor.lower(): continue # 设备类型过滤 if device_type_filter and result.device_type: if device_type_filter.lower() not in result.device_type.lower(): continue filtered_results.append(result) # 从过滤后的结果创建设备 discovered_devices = await network_scanner.discover_devices_from_scan(filtered_results, auto_create=True) return { "success": True, "message": f"从 {len(filtered_results)} 个扫描结果中创建了 {len(discovered_devices)} 个设备", "filtered_results_count": len(filtered_results), "created_devices_count": len(discovered_devices), "created_devices": [device.to_dict() for device in discovered_devices] } except Exception as e: logger.error(f"从扫描结果创建设备失败: {e}") return {"success": False, "message": f"从扫描结果创建设备失败: {str(e)}"} @mcp.tool() async def clear_scan_results() -> Dict[str, Any]: """ 清空网络扫描结果 Returns: 操作结果 """ try: network_scanner.clear_scan_results() return {"success": True, "message": "扫描结果已清空"} except Exception as e: logger.error(f"清空扫描结果失败: {e}") return {"success": False, "message": f"清空扫描结果失败: {str(e)}"} @mcp.tool() async def scan_single_host( ip_address: str, timeout: float = 3.0, port_scan_enabled: bool = True, snmp_enabled: bool = True ) -> Dict[str, Any]: """ 扫描单个主机 Args: ip_address: 目标IP地址 timeout: 扫描超时时间(秒) port_scan_enabled: 是否启用端口扫描 snmp_enabled: 是否启用SNMP扫描 Returns: 单个主机扫描结果 """ try: from network_scanner import ScanConfiguration # 创建扫描配置 config = ScanConfiguration( timeout=timeout, port_scan_enabled=port_scan_enabled, snmp_enabled=snmp_enabled ) logger.info(f"开始扫描单个主机: {ip_address}") # 执行单个主机扫描 scan_result = await network_scanner.scan_single_host(ip_address, config) # 保存到扫描结果 network_scanner.scan_results[ip_address] = scan_result network_scanner.save_scan_results() return { "success": True, "message": f"主机扫描完成: {ip_address}", "is_alive": scan_result.is_alive, "scan_result": scan_result.to_dict() } except Exception as e: logger.error(f"扫描单个主机失败: {e}") return {"success": False, "message": f"扫描单个主机失败: {str(e)}"} # =================== 资源工具 =================== @mcp.tool() async def list_resources() -> List[Dict[str, Any]]: """ 列出可用的MCP资源 Returns: 资源列表 """ return resource_manager.list_available_resources() @mcp.tool() async def get_resource(uri: str, use_cache: bool = True) -> Dict[str, Any]: """ 获取指定URI的资源 Args: uri: 资源URI use_cache: 是否使用缓存 Returns: 资源内容 """ return await resource_manager.get_resource(uri, use_cache) @mcp.tool() async def clear_resource_cache(uri: Optional[str] = None) -> Dict[str, Any]: """ 清除资源缓存 Args: uri: 要清除的特定资源URI,如果为None则清除所有缓存 Returns: 操作结果 """ success = resource_manager.clear_cache(uri) if success: if uri: return {"success": True, "message": f"已清除资源缓存: {uri}"} else: return {"success": True, "message": "已清除所有资源缓存"} else: return {"success": False, "message": "清除缓存失败"} # =================== 模板工具 =================== @mcp.tool() async def list_templates() -> List[Dict[str, Any]]: """ 列出可用的提示模板 Returns: 模板列表 """ return template_manager.list_templates() @mcp.tool() async def render_template( template_name: str, context: str = "{}", resource_uris: str = "" ) -> Dict[str, Any]: """ 渲染提示模板 Args: template_name: 模板名称 context: JSON格式的渲染上下文,如 {"name": "value"} resource_uris: 资源URI映射,格式为 "key1=uri1,key2=uri2" Returns: 渲染结果 """ try: # 解析context context_dict = {} if context.strip(): try: context_dict = json.loads(context) if not isinstance(context_dict, dict): context_dict = {} except json.JSONDecodeError: logger.warning(f"无法解析context JSON: {context}") return {"success": False, "message": "无效的JSON上下文格式"} # 解析resource_uris resource_dict = {} if resource_uris.strip(): uri_pairs = resource_uris.split(',') for pair in uri_pairs: if '=' in pair: key, uri = pair.split('=', 1) resource_dict[key.strip()] = uri.strip() if resource_dict: # 渲染带有资源的模板 result = await render_template_with_resources( template_name=template_name, context=context_dict, resource_uris=resource_dict, resource_manager=resource_manager ) else: # 渲染简单模板 result = template_manager.render_template(template_name, context_dict) if result: if isinstance(result, list): # 处理消息列表 messages = [ { "role": msg.role, "content": msg.content } for msg in result ] return {"success": True, "messages": messages} else: # 处理字符串结果 return {"success": True, "rendered_template": result} else: return {"success": False, "message": f"渲染模板失败: {template_name}"} except Exception as e: logger.error(f"渲染模板异常: {str(e)}") return {"success": False, "message": f"渲染模板异常: {str(e)}"} # =================== 资源注册 =================== # 将资源处理函数注册到资源管理器 @resource_manager.register_resource("greeting/{name}", ResourceType.SYSTEM) async def get_greeting(name: str) -> str: """获取个性化问候语""" return f"你好,{name}!欢迎使用NetBrain MCP!" @resource_manager.register_resource("device/{device_id}", ResourceType.DEVICE) async def get_device_resource(device_id: str) -> Dict[str, Any]: """获取设备资源""" device = device_manager.get_device(device_id) if device: return device.to_dict() return {"error": f"设备不存在: {device_id}"} @resource_manager.register_resource("device/{device_id}/config", ResourceType.CONFIG) async def get_device_config(device_id: str) -> Dict[str, Any]: """获取设备配置资源""" device = device_manager.get_device(device_id) if not device: return {"error": f"设备不存在: {device_id}"} # 这里我们会尝试获取设备的运行配置 # 首先获取默认凭据 credentials = device_manager.list_credentials() if not credentials: return {"error": "没有可用的凭据"} credential_id = device.credential_id or credentials[0].id credential = device_manager.get_credential(credential_id) # 尝试获取设备配置 if device.vendor == DeviceVendor.CISCO: command = "show running-config" elif device.vendor == DeviceVendor.HUAWEI: command = "display current-configuration" elif device.vendor == DeviceVendor.H3C: command = "display current-configuration" elif device.vendor == DeviceVendor.JUNIPER: command = "show configuration" else: command = "show running-config" # 默认命令 try: # 连接设备 success, error = await connection_manager.connect_device(device, credential) if not success: return {"error": f"连接设备失败: {error}"} # 发送命令 result, error = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=60 # 配置命令可能需要更长时间 ) # 断开连接 await connection_manager.disconnect_device(device_id, credential_id) if result and result.success: return { "device_id": device_id, "device_name": device.name, "config_type": "running", "config_content": result.output, "timestamp": result.execution_time.isoformat() } else: return {"error": f"获取配置失败: {error or '未知错误'}"} except Exception as e: return {"error": f"获取配置时发生错误: {str(e)}"} @resource_manager.register_resource("device/{device_id}/interfaces", ResourceType.DEVICE) async def get_device_interfaces(device_id: str) -> Dict[str, Any]: """获取设备接口资源""" device = device_manager.get_device(device_id) if not device: return {"error": f"设备不存在: {device_id}"} # 获取默认凭据 credentials = device_manager.list_credentials() if not credentials: return {"error": "没有可用的凭据"} credential_id = device.credential_id or credentials[0].id credential = device_manager.get_credential(credential_id) # 根据设备类型确定接口命令 if device.vendor == DeviceVendor.CISCO: command = "show interfaces" elif device.vendor == DeviceVendor.HUAWEI: command = "display interface" elif device.vendor == DeviceVendor.H3C: command = "display interface" elif device.vendor == DeviceVendor.JUNIPER: command = "show interfaces detail" else: command = "show interfaces" # 默认命令 try: # 连接设备 success, error = await connection_manager.connect_device(device, credential) if not success: return {"error": f"连接设备失败: {error}"} # 发送命令 result, error = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=30 ) # 断开连接 await connection_manager.disconnect_device(device_id, credential_id) if result and result.success: return { "device_id": device_id, "device_name": device.name, "interfaces_output": result.output, "timestamp": result.execution_time.isoformat() } else: return {"error": f"获取接口信息失败: {error or '未知错误'}"} except Exception as e: return {"error": f"获取接口信息时发生错误: {str(e)}"} @resource_manager.register_resource("device/{device_id}/routes", ResourceType.DEVICE) async def get_device_routes(device_id: str) -> Dict[str, Any]: """获取设备路由表资源""" device = device_manager.get_device(device_id) if not device: return {"error": f"设备不存在: {device_id}"} # 获取默认凭据 credentials = device_manager.list_credentials() if not credentials: return {"error": "没有可用的凭据"} credential_id = device.credential_id or credentials[0].id credential = device_manager.get_credential(credential_id) # 根据设备类型确定路由命令 if device.vendor == DeviceVendor.CISCO: command = "show ip route" elif device.vendor == DeviceVendor.HUAWEI: command = "display ip routing-table" elif device.vendor == DeviceVendor.H3C: command = "display ip routing-table" elif device.vendor == DeviceVendor.JUNIPER: command = "show route" else: command = "show ip route" # 默认命令 try: # 连接设备 success, error = await connection_manager.connect_device(device, credential) if not success: return {"error": f"连接设备失败: {error}"} # 发送命令 result, error = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=30 ) # 断开连接 await connection_manager.disconnect_device(device_id, credential_id) if result and result.success: return { "device_id": device_id, "device_name": device.name, "routes_output": result.output, "timestamp": result.execution_time.isoformat() } else: return {"error": f"获取路由表失败: {error or '未知错误'}"} except Exception as e: return {"error": f"获取路由表时发生错误: {str(e)}"} @resource_manager.register_resource("credentials", ResourceType.CREDENTIAL) async def list_credentials_resource() -> Dict[str, Any]: """ 列出所有设备凭据 Returns: 凭据列表 """ credentials = device_manager.list_credentials() return { "credentials": [ { "id": cred.id, "name": cred.name, "username": cred.username, "password": cred.password, "protocol": cred.protocol.value, "port": cred.port, "ssh_key_file": cred.ssh_key_file } for cred in credentials ] } @resource_manager.register_resource("system/status", ResourceType.SYSTEM) async def get_system_status() -> Dict[str, Any]: """获取系统状态资源""" devices_count = len(device_manager.list_devices()) credentials_count = len(device_manager.list_credentials()) active_connections = len(connection_manager.get_active_connections()) return { "status": "running", "devices_count": devices_count, "credentials_count": credentials_count, "active_connections": active_connections, "version": "0.1.0" } @resource_manager.register_resource("topology", ResourceType.TOPOLOGY) async def get_topology_resource() -> Dict[str, Any]: """获取网络拓扑资源""" try: topology = topology_discovery.get_topology() return topology.to_dict() except Exception as e: logger.error(f"获取拓扑资源失败: {e}") return {"error": f"获取拓扑资源失败: {str(e)}"} @resource_manager.register_resource("topology/statistics", ResourceType.TOPOLOGY) async def get_topology_statistics_resource() -> Dict[str, Any]: """获取拓扑统计资源""" try: topology = topology_discovery.get_topology() # 统计不同协议的链路数量 protocol_stats = {} for link in topology.links: protocol = link.protocol.value protocol_stats[protocol] = protocol_stats.get(protocol, 0) + 1 # 统计不同厂商的设备数量 vendor_stats = {} for node in topology.nodes.values(): vendor = node.vendor vendor_stats[vendor] = vendor_stats.get(vendor, 0) + 1 return { "total_nodes": len(topology.nodes), "total_links": len(topology.links), "protocol_distribution": protocol_stats, "vendor_distribution": vendor_stats, "last_discovery": topology.last_discovery.isoformat() if topology.last_discovery else None, "discovery_scope": topology.discovery_scope } except Exception as e: logger.error(f"获取拓扑统计资源失败: {e}") return {"error": f"获取拓扑统计资源失败: {str(e)}"} @resource_manager.register_resource("device/{device_id}/neighbors", ResourceType.TOPOLOGY) async def get_device_neighbors_resource(device_id: str) -> Dict[str, Any]: """获取设备邻居资源""" try: topology = topology_discovery.get_topology() neighbors = topology.get_device_neighbors(device_id) # 获取邻居设备的详细信息 neighbor_details = [] for neighbor_id in neighbors: device = device_manager.get_device(neighbor_id) if device: neighbor_details.append(device.to_dict()) return { "device_id": device_id, "neighbors": neighbor_details, "neighbor_count": len(neighbor_details) } except Exception as e: logger.error(f"获取设备邻居资源失败: {e}") return {"error": f"获取设备邻居资源失败: {str(e)}"} @resource_manager.register_resource("scan/results", ResourceType.SCAN) async def get_scan_results_resource() -> Dict[str, Any]: """获取网络扫描结果资源""" try: results = [r for r in network_scanner.scan_results.values() if r.is_alive] return { "total_results": len(results), "scan_results": [result.to_dict() for result in results] } except Exception as e: logger.error(f"获取扫描结果资源失败: {e}") return {"error": f"获取扫描结果资源失败: {str(e)}"} @resource_manager.register_resource("scan/statistics", ResourceType.SCAN) async def get_scan_statistics_resource() -> Dict[str, Any]: """获取网络扫描统计资源""" try: stats = network_scanner.get_scan_statistics() return stats except Exception as e: logger.error(f"获取扫描统计资源失败: {e}") return {"error": f"获取扫描统计资源失败: {str(e)}"} @resource_manager.register_resource("scan/result/{ip_address}", ResourceType.SCAN) async def get_scan_result_resource(ip_address: str) -> Dict[str, Any]: """获取特定IP的扫描结果资源""" try: result = network_scanner.scan_results.get(ip_address) if result: return result.to_dict() else: return {"error": f"未找到IP {ip_address} 的扫描结果"} except Exception as e: logger.error(f"获取扫描结果资源失败: {e}") return {"error": f"获取扫描结果资源失败: {str(e)}"} # 向MCP服务器注册资源获取方法 @mcp.resource("{uri}") async def mcp_resource_handler(uri: str) -> Any: """MCP资源处理器,处理所有资源请求""" logger.info(f"处理资源请求: {uri}") return await resource_manager.get_resource(uri) # 向MCP服务器注册提示模板 @mcp.prompt("{name}") async def mcp_prompt_handler(name: str) -> str: """MCP提示模板处理器,处理所有模板请求""" logger.info(f"处理模板请求: {name}") template = template_manager.get_template(name) if template: # 返回模板描述或模板名称 return template.description or template.name return f"错误: 未找到模板 '{name}'" # 测试工具 @mcp.tool() async def test_scrapli_connection( host: str, username: str, password: str = None, platform: str = "cisco_iosxe", port: int = 22, protocol: str = "ssh", ssh_key_file: Optional[str] = None, connect_timeout: int = 15 ) -> Dict[str, Any]: """ 测试Scrapli连接(直接使用Scrapli) Args: host: 设备主机名或IP地址 username: 用户名 password: 密码(与ssh_key_file至少提供一个) platform: 平台类型,如cisco_iosxe, huawei_vrp等 port: 端口号(默认22,Telnet使用23) protocol: 连接协议,ssh或telnet ssh_key_file: SSH密钥文件路径,可选 connect_timeout: 连接超时时间(秒) Returns: 连接测试结果 """ # 从device_connector中导入SCRAPLI_IMPORT_SUCCESS检查scrapli是否可用 if not SCRAPLI_IMPORT_SUCCESS: return { "success": False, "message": "Scrapli库导入失败,请确保已安装: pip install scrapli scrapli-community" } try: from scrapli import Scrapli # 修正平台名称处理 if platform.lower() == "huawei": platform = "huawei_vrp" # 确定连接协议 is_telnet = protocol.lower() == "telnet" if is_telnet and port == 22: # 如果是Telnet但端口仍为默认SSH端口,修正为默认Telnet端口 port = 23 logger.info(f"连接参数: host={host}, username={username}, platform={platform}, port={port}, protocol={protocol}") # 打印密码信息(安全地) if password: logger.info(f"密码类型: {type(password)}, 长度: {len(str(password))}") password = str(password) # 确保密码是字符串类型 # 准备连接参数 device_params = { "host": host, "auth_username": username, "auth_strict_key": False, "platform": platform, "port": port, "timeout_socket": connect_timeout, "timeout_transport": connect_timeout * 2, } # 根据协议设置传输方式 if is_telnet: device_params["transport"] = "telnet" logger.info("使用Telnet连接") else: # 指定Windows兼容的传输方式 device_params["transport"] = "paramiko" # 在Windows上使用paramiko而不是system logger.info("使用SSH连接 (paramiko)") # 设置认证方式,优先使用SSH密钥 if not is_telnet and ssh_key_file and os.path.exists(ssh_key_file): device_params["auth_private_key"] = ssh_key_file logger.info(f"使用SSH密钥认证: {ssh_key_file}") elif password: device_params["auth_password"] = password logger.info("使用密码认证") else: return {"success": False, "message": "必须提供密码或SSH密钥文件"} # 详细记录连接参数(移除敏感信息) safe_params = device_params.copy() if "auth_password" in safe_params: safe_params["auth_password"] = "******" if "auth_private_key" in safe_params: safe_params["auth_private_key"] = f"[使用密钥文件: {ssh_key_file}]" logger.info(f"Scrapli测试连接参数: {safe_params}") # 建立连接 conn = Scrapli(**device_params) conn.open() # 获取设备提示符 prompt = conn.get_prompt() # 测试发送命令 try: # 根据平台类型选择命令 if "huawei" in platform: command = "display version" elif "cisco" in platform: command = "show version" else: command = "show version" resp = conn.send_command(command) version_info = resp.result[:200] + "..." if len(resp.result) > 200 else resp.result except Exception as e: logger.warning(f"获取版本信息失败: {str(e)}") version_info = "无法获取版本信息" # 关闭连接 conn.close() return { "success": True, "message": f"连接成功!设备提示符: {prompt}", "prompt": prompt, "version_info": version_info } except ImportError as e: return {"success": False, "message": f"Scrapli库未安装: {str(e)}"} except Exception as e: error_type = type(e).__name__ error_msg = str(e) logger.error(f"Scrapli连接测试失败: {error_type} - {error_msg}") return { "success": False, "message": f"连接失败: {error_type} - {error_msg}", "error_type": error_type, "error_details": error_msg } @mcp.tool() async def test_telnet_connection( host: str, username: str, password: str, platform: str = "cisco_iosxe", port: int = 23, connect_timeout: int = 15 ) -> Dict[str, Any]: """ 测试Telnet连接(快速测试) Args: host: 设备主机名或IP地址 username: 用户名 password: 密码 platform: 平台类型,如cisco_iosxe, huawei_vrp等 port: 端口号(默认23) connect_timeout: 连接超时时间(秒) Returns: 连接测试结果 """ # 直接调用test_scrapli_connection,但固定protocol为telnet return await test_scrapli_connection( host=host, username=username, password=password, platform=platform, port=port, protocol="telnet", connect_timeout=connect_timeout ) @mcp.tool() async def send_telnet_command( device_id: str, credential_id: str, command: str, timeout: int = 60 ) -> Dict[str, Any]: """ 向Telnet连接的网络设备发送命令(使用更长的超时时间) Args: device_id: 设备ID credential_id: 凭据ID command: 要执行的命令 timeout: 命令超时时间(秒,默认比普通命令更长) Returns: 命令执行结果 """ # 获取设备和凭据 device = device_manager.get_device(device_id) credential = device_manager.get_credential(credential_id) if not device: return {"success": False, "message": f"设备不存在: {device_id}"} if not credential: return {"success": False, "message": f"凭据不存在: {credential_id}"} # 验证是否为Telnet凭据 if credential.protocol != ConnectionProtocol.TELNET: return {"success": False, "message": f"此函数仅支持Telnet连接,当前凭据协议为: {credential.protocol.value}"} # 记录命令执行信息 logger.info(f"通过Telnet向设备 {device.name} ({device.ip_address}) 发送命令: {command}") logger.info(f"超时设置: {timeout}秒") # 发送命令 result, error = await connection_manager.send_command( device_id=device_id, credential_id=credential_id, command=command, timeout=timeout ) if result: # 记录执行结果 if result.success: logger.info(f"命令执行成功,输出长度: {len(result.output)}") else: logger.warning(f"命令执行失败: {result.error_message}") return { "success": result.success, "command": result.command, "output": result.output, "error_message": result.error_message, "execution_time": result.execution_time.isoformat() } else: logger.error(f"执行命令失败: {error}") return { "success": False, "command": command, "output": "", "error_message": error, "execution_time": None } # 运行服务器 if __name__ == "__main__": # 检测启动方式 import os # 获取启动命令行参数 import sys command_args = " ".join(sys.argv) # 检查环境变量或命令行,判断是否通过MCP CLI运行 is_mcp_cli = os.environ.get("MCP_CLI_RUN", "0") == "1" or "mcp dev" in command_args or "mcp run" in command_args # 设置环境变量,使其他模块也能知道是通过MCP CLI运行 if is_mcp_cli: os.environ["MCP_CLI_RUN"] = "1" # 始终启动Web服务器(后台运行) import threading import uvicorn def start_web_server(): """启动Web服务器""" from web.app import app uvicorn.run(app, host="0.0.0.0", port=8088) # 在另一个线程中启动Web服务器 web_thread = threading.Thread(target=start_web_server, daemon=True) web_thread.start() print("NetBrain MCP Web界面已启动,访问 http://localhost:8088") if is_mcp_cli: print("MCP CLI模式已检测到") # 在8000端口运行,与MCP Inspector预期一致 mcp.run()

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/NorthLaneMS/NetBrain_MCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server