Skip to main content
Glama

FastAPI OpenAPI MCP Server

by jason-chang
custom_tool_example.py (40.9 kB)
"""Custom tool examples.

Shows how to create custom MCP tools and integrate them into an OpenAPI MCP
server. Demonstrates several kinds of custom tools: API analysis, client code
generation, test-data generation and performance analysis.
"""

import json
import keyword
import re
import zlib
from collections import Counter
from collections.abc import Iterator
from datetime import datetime
from typing import Any

from fastapi import FastAPI

from openapi_mcp import OpenApiMcpServer
from openapi_mcp.tools.base import BaseTool
from openapi_mcp.types import ToolDefinition, ToolResult

# Keys of an OpenAPI path item that denote operations. A path item may also
# carry non-operation keys ('parameters', 'summary', ...) that must be skipped.
_HTTP_METHODS = frozenset(
    {'get', 'put', 'post', 'delete', 'options', 'head', 'patch', 'trace'}
)
# Methods for which the generated clients send a JSON request body.
_BODY_METHODS = frozenset({'POST', 'PUT', 'PATCH'})
_PATH_PARAM_RE = re.compile(r'\{([^}]+)\}')
_DEFAULT_BASE_URL = 'https://api.example.com'


def _iter_operations(
    spec: dict, filter_tags: list[str] | None = None
) -> Iterator[tuple[str, str, dict]]:
    """Yield ``(path, method, operation)`` for every operation in *spec*.

    Non-operation keys of a path item are skipped. When *filter_tags* is
    non-empty, only operations carrying at least one of those tags are yielded.
    """
    for path, path_item in spec.get('paths', {}).items():
        if not isinstance(path_item, dict):
            continue
        for method, details in path_item.items():
            if method.lower() not in _HTTP_METHODS or not isinstance(details, dict):
                continue
            if filter_tags and not any(
                tag in filter_tags for tag in details.get('tags') or []
            ):
                continue
            yield path, method, details


def _stable_hash(text: str) -> int:
    """Return a hash of *text* that is identical across interpreter runs.

    The built-in ``hash()`` of a string is salted per process, which would make
    generated test data change from run to run.
    """
    return zlib.crc32(text.encode('utf-8'))


def _base_url(spec: dict) -> str:
    """Return the first server URL of *spec*, or a placeholder if none is set."""
    servers = spec.get('servers') or [{}]
    return servers[0].get('url', _DEFAULT_BASE_URL)


def _client_class_name(api_title: str) -> str:
    """Derive a valid client class identifier from the API title."""
    stem = re.sub(r'\W+', '_', api_title).strip('_').lower().title() or 'Api'
    if stem[0].isdigit():
        stem = f'_{stem}'
    return f'{stem}Client'


def _query_param_names(details: dict) -> list[str]:
    """Return the names of the query parameters declared on an operation."""
    return [
        param['name']
        for param in details.get('parameters') or []
        if param.get('in') == 'query' and 'name' in param
    ]


# Demo FastAPI application
def create_demo_app() -> FastAPI:
    """Create the FastAPI application used by the demo."""
    app = FastAPI(
        title='Custom Tools Demo API',
        version='1.0.0',
        description='展示自定义 Tools 的 API',
    )

    @app.get('/products', tags=['products'])
    async def list_products():
        return {'products': []}

    @app.get('/orders', tags=['orders'])
    async def list_orders():
        return {'orders': []}

    @app.get('/users', tags=['users'])
    async def list_users():
        return {'users': []}

    return app


# Custom tool 1: API analyzer
class ApiAnalyzerTool(BaseTool):
    """API analyzer tool: inspects the OpenAPI spec and reports statistics."""

    def __init__(self):
        super().__init__()
        self.name = 'api_analyzer'
        self.description = '分析 OpenAPI 规范,提供详细的统计和分析信息'

    def get_definition(self) -> ToolDefinition:
        """Return the tool definition (name, description and input schema)."""
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema={
                'type': 'object',
                'properties': {
                    'analysis_type': {
                        'type': 'string',
                        'enum': ['overview', 'endpoints', 'models', 'tags', 'security'],
                        'description': '分析类型',
                        'default': 'overview',
                    },
                    'include_details': {
                        'type': 'boolean',
                        'description': '是否包含详细信息',
                        'default': False,
                    },
                    'filter_tag': {
                        'type': 'string',
                        'description': '按标签过滤(可选)',
                    },
                },
            },
        )

    async def execute(self, arguments: dict[str, Any]) -> ToolResult:
        """Run the requested analysis.

        Any failure is reported as an unsuccessful ``ToolResult`` rather than
        raised, since this is the tool's outermost boundary.
        """
        try:
            analysis_type = arguments.get('analysis_type', 'overview')
            include_details = arguments.get('include_details', False)
            filter_tag = arguments.get('filter_tag')

            spec = await self.get_openapi_spec()
            if filter_tag:
                spec = self._filter_spec_by_tag(spec, filter_tag)

            result = self._analyze_spec(spec, analysis_type, include_details)
            return ToolResult(
                success=True, data=result, message=f'API {analysis_type} 分析完成'
            )
        except Exception as e:
            return ToolResult(success=False, error=str(e), message='API 分析失败')

    def _filter_spec_by_tag(self, spec: dict, tag: str) -> dict:
        """Return a shallow copy of *spec* keeping only operations tagged *tag*.

        Path-level non-operation keys are preserved for paths that keep at
        least one operation.
        """
        filtered_paths = {}
        for path, path_item in spec.get('paths', {}).items():
            if not isinstance(path_item, dict):
                continue
            kept = {
                method: details
                for method, details in path_item.items()
                if method.lower() in _HTTP_METHODS
                and isinstance(details, dict)
                and tag in (details.get('tags') or [])
            }
            if kept:
                shared = {
                    key: value
                    for key, value in path_item.items()
                    if key.lower() not in _HTTP_METHODS
                }
                filtered_paths[path] = {**shared, **kept}
        return {**spec, 'paths': filtered_paths}

    def _analyze_spec(
        self, spec: dict, analysis_type: str, include_details: bool
    ) -> dict:
        """Dispatch to the analysis selected by *analysis_type*.

        Raises:
            ValueError: if *analysis_type* is not a known analysis.
        """
        if analysis_type == 'overview':
            return self._analyze_overview(spec)
        analyzers = {
            'endpoints': self._analyze_endpoints,
            'models': self._analyze_models,
            'tags': self._analyze_tags,
            'security': self._analyze_security,
        }
        if analysis_type not in analyzers:
            raise ValueError(f'未知的分析类型: {analysis_type}')
        return analyzers[analysis_type](spec, include_details)

    def _analyze_overview(self, spec: dict) -> dict:
        """Summarize the API info block plus path and endpoint counts."""
        info = spec.get('info', {})
        path_count = len(spec.get('paths', {}))
        endpoint_count = sum(1 for _ in _iter_operations(spec))

        return {
            'api_info': {
                'title': info.get('title'),
                'version': info.get('version'),
                'description': info.get('description'),
            },
            'statistics': {
                'total_paths': path_count,
                'total_endpoints': endpoint_count,
                'average_endpoints_per_path': endpoint_count / path_count
                if path_count > 0
                else 0,
            },
        }

    def _analyze_endpoints(self, spec: dict, include_details: bool) -> dict:
        """Count endpoints per HTTP method and per tag."""
        by_method: Counter = Counter()
        by_tag: Counter = Counter()
        detailed = []

        for path, method, details in _iter_operations(spec):
            tags = details.get('tags') or []
            by_method[method.lower()] += 1
            by_tag.update(tags)
            detailed.append(
                {
                    'path': path,
                    'method': method,
                    'summary': details.get('summary', ''),
                    'tags': tags,
                }
            )

        result = {
            'by_method': dict(by_method),
            'by_tag': dict(by_tag),
            'total': sum(by_method.values()),
        }
        if include_details:
            result['detailed_endpoints'] = detailed
        return result

    def _analyze_models(self, spec: dict, include_details: bool) -> dict:
        """Report schema counts by type and a complexity score per model."""
        schemas = spec.get('components', {}).get('schemas', {})

        # A schema without an explicit 'type' is counted as 'object'.
        model_types = Counter(
            model_def.get('type', 'object') for model_def in schemas.values()
        )
        model_stats = {
            'total_models': len(schemas),
            'model_types': dict(model_types),
            'complexity_scores': {
                name: self._calculate_model_complexity(model_def)
                for name, model_def in schemas.items()
            },
        }
        if include_details:
            model_stats['model_details'] = schemas
        return model_stats

    def _analyze_tags(self, spec: dict, include_details: bool) -> dict:
        """Report per-tag endpoint counts.

        Tags used on operations but not declared in the top-level ``tags`` list
        are included too (with an empty description); otherwise a spec that
        declares no tags would report nothing.
        """
        tag_stats = {}
        for tag_info in spec.get('tags', []):
            tag_stats[tag_info.get('name')] = {
                'description': tag_info.get('description', ''),
                'endpoint_count': 0,
                'external_docs': tag_info.get('externalDocs', {}),
            }

        for _path, _method, details in _iter_operations(spec):
            for tag in details.get('tags') or []:
                entry = tag_stats.setdefault(
                    tag,
                    {'description': '', 'endpoint_count': 0, 'external_docs': {}},
                )
                entry['endpoint_count'] += 1

        return {'total_tags': len(tag_stats), 'tags': tag_stats}

    def _analyze_security(self, spec: dict, include_details: bool) -> dict:
        """Report security schemes and how many endpoints are secured."""
        security_schemes = spec.get('components', {}).get('securitySchemes', {})
        global_security = spec.get('security') or []

        scheme_types = Counter(
            scheme_def.get('type') for scheme_def in security_schemes.values()
        )
        secured = 0
        public = 0
        for _path, _method, details in _iter_operations(spec):
            # An operation-level 'security' overrides the global requirement;
            # an explicit empty list makes the operation public.
            effective = (
                details['security'] if 'security' in details else global_security
            )
            if effective:
                secured += 1
            else:
                public += 1

        security_stats = {
            'total_security_schemes': len(security_schemes),
            'scheme_types': dict(scheme_types),
            'global_security_requirements': len(global_security),
            'public_endpoints': public,
            'secured_endpoints': secured,
        }
        if include_details:
            security_stats['security_schemes'] = security_schemes
            security_stats['global_security'] = global_security
        return security_stats

    def _calculate_model_complexity(self, model_def: dict) -> int:
        """Score a schema: 1, plus 2 per property, plus a per-kind bonus."""
        score = 1
        if model_def.get('type') == 'object':
            properties = model_def.get('properties', {})
            score += len(properties) * 2
            for prop_def in properties.values():
                if '$ref' in prop_def:
                    score += 3  # references add complexity
                elif prop_def.get('type') == 'array':
                    score += 2  # arrays add complexity
                elif prop_def.get('type') == 'object':
                    score += 4  # nested objects add the most
        return score


# Custom tool 2: code generator
class CodeGeneratorTool(BaseTool):
    """Code generator tool: produces client code from the OpenAPI spec."""

    def __init__(self):
        super().__init__()
        self.name = 'code_generator'
        self.description = '根据 OpenAPI 规范生成各种语言的客户端代码'

    def get_definition(self) -> ToolDefinition:
        """Return the tool definition (name, description and input schema)."""
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema={
                'type': 'object',
                'properties': {
                    'language': {
                        'type': 'string',
                        # NOTE(review): 'java' and 'go' are advertised here but
                        # _generate_code has no generator for them and rejects
                        # them with ValueError.
                        'enum': [
                            'python',
                            'javascript',
                            'typescript',
                            'java',
                            'go',
                            'curl',
                        ],
                        'description': '目标编程语言',
                        'default': 'python',
                    },
                    'output_format': {
                        'type': 'string',
                        'enum': ['classes', 'functions', 'sdk', 'examples'],
                        'description': '输出格式',
                        'default': 'functions',
                    },
                    'include_auth': {
                        'type': 'boolean',
                        'description': '是否包含认证代码',
                        'default': True,
                    },
                    'filter_tags': {
                        'type': 'array',
                        'items': {'type': 'string'},
                        'description': '只包含指定标签的端点(可选)',
                    },
                },
                'required': ['language'],
            },
        )

    async def execute(self, arguments: dict[str, Any]) -> ToolResult:
        """Generate client code; failures are returned as a failed result."""
        try:
            # Fall back to the schema's documented default instead of KeyError.
            language = arguments.get('language', 'python')
            output_format = arguments.get('output_format', 'functions')
            include_auth = arguments.get('include_auth', True)
            filter_tags = arguments.get('filter_tags')

            spec = await self.get_openapi_spec()
            code = self._generate_code(
                spec, language, output_format, include_auth, filter_tags
            )

            return ToolResult(
                success=True,
                data={
                    'language': language,
                    'format': output_format,
                    'code': code,
                    'auth_included': include_auth,
                },
                message=f'{language.title()} 代码生成完成',
            )
        except Exception as e:
            return ToolResult(success=False, error=str(e), message='代码生成失败')

    def _generate_code(
        self,
        spec: dict,
        language: str,
        output_format: str,
        include_auth: bool,
        filter_tags: list[str] | None,
    ) -> str:
        """Dispatch to the generator for *language*.

        Raises:
            ValueError: if no generator exists for *language*.
        """
        if language == 'curl':
            return self._generate_curl_examples(spec, filter_tags)
        generators = {
            'python': self._generate_python_code,
            'javascript': self._generate_javascript_code,
            'typescript': self._generate_typescript_code,
        }
        if language not in generators:
            raise ValueError(f'不支持的编程语言: {language}')
        return generators[language](spec, output_format, include_auth, filter_tags)

    def _generate_python_code(
        self,
        spec: dict,
        output_format: str,
        include_auth: bool,
        filter_tags: list[str] | None,
    ) -> str:
        """Generate a ``requests``-based Python client.

        *output_format* is accepted for interface symmetry but not used.
        """
        info = spec.get('info', {})
        api_title = info.get('title', 'API')
        class_name = _client_class_name(api_title)
        base_url = _base_url(spec)

        code = f'''"""
Auto-generated {api_title} Python Client
Generated at: {datetime.now().isoformat()}
"""

import requests
import json
from typing import Dict, Any, Optional, List

'''

        if include_auth:
            # Plain (non-f) string: braces must be single, or the generated
            # code would literally contain '{{}}'.
            code += '''
class Auth:
    """Authentication helper class"""

    def __init__(self, api_key: Optional[str] = None, bearer_token: Optional[str] = None):
        self.api_key = api_key
        self.bearer_token = bearer_token

    def get_headers(self) -> Dict[str, str]:
        """Get authentication headers"""
        headers = {}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        if self.bearer_token:
            headers["Authorization"] = f"Bearer {self.bearer_token}"
        return headers

'''

        # 'Auth' is quoted so the generated module still imports when the Auth
        # class was omitted (include_auth=False).
        code += f'''
class {class_name}:
    """{api_title} Client"""

    def __init__(self, base_url: str = "{base_url}", auth: Optional["Auth"] = None):
        self.base_url = base_url.rstrip("/")
        self.auth = auth
        self.session = requests.Session()

    def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
        """Make HTTP request"""
        url = f"{{self.base_url}}{{endpoint}}"
        headers = kwargs.get("headers", {{}})
        if self.auth:
            headers.update(self.auth.get_headers())
        kwargs["headers"] = headers

        response = self.session.request(method, url, **kwargs)
        response.raise_for_status()
        return response.json()
'''

        # One method per endpoint.
        for path, method, details in _iter_operations(spec, filter_tags):
            method_name = self._generate_method_name(method, path, details)
            endpoint_path = self._convert_path_to_template(path)
            signature = self._generate_method_params(path, details, method)
            request_args = self._generate_request_args(method, path, details)
            summary = details.get('summary') or ''

            # The endpoint is emitted as an f-string so that '{param}' path
            # placeholders are filled from the method arguments.
            code += f'''
    def {method_name}(self{signature}) -> Dict[str, Any]:
        """{summary}"""
        return self._make_request(
            "{method.upper()}",
            f"{endpoint_path}"{request_args}
        )
'''

        return code

    def _generate_javascript_code(
        self,
        spec: dict,
        output_format: str,
        include_auth: bool,
        filter_tags: list[str] | None,
    ) -> str:
        """Generate a ``fetch``-based JavaScript client.

        *output_format* and *include_auth* are accepted for interface symmetry
        but not used; the client always supports an optional API key.
        """
        info = spec.get('info', {})
        api_title = info.get('title', 'API')
        class_name = _client_class_name(api_title)
        base_url = _base_url(spec)

        # '...options' is spread first so it cannot clobber the merged headers.
        code = f'''/**
 * Auto-generated {api_title} JavaScript Client
 * Generated at: {datetime.now().isoformat()}
 */

class {class_name} {{
    constructor(baseUrl = "{base_url}", apiKey = null) {{
        this.baseUrl = baseUrl.replace(/\\/$/, "");
        this.apiKey = apiKey;
    }}

    async _makeRequest(method, endpoint, options = {{}}) {{
        const url = `${{this.baseUrl}}${{endpoint}}`;
        const headers = {{
            "Content-Type": "application/json",
            ...options.headers
        }};

        if (this.apiKey) {{
            headers["X-API-Key"] = this.apiKey;
        }}

        const response = await fetch(url, {{
            ...options,
            method,
            headers
        }});

        if (!response.ok) {{
            throw new Error(`HTTP error! status: ${{response.status}}`);
        }}

        return await response.json();
    }}
'''

        # One method per endpoint.
        for path, method, details in _iter_operations(spec, filter_tags):
            method_name = self._generate_method_name(method, path, details)
            # '/users/{id}' -> '/users/${id}' for use inside a template literal.
            js_path = _PATH_PARAM_RE.sub(r'${\1}', path)
            js_params = self._generate_js_method_params(path, details, method)
            params_object = self._generate_js_params_object(path, details)
            request_args = self._generate_js_request_args(method, path)

            code += f'''
    async {method_name}({js_params}) {{
        const params = {{
{params_object}
        }};

        // Remove undefined values
        Object.keys(params).forEach(key =>
            params[key] === undefined && delete params[key]
        );
        const query = new URLSearchParams(params).toString();

        return await this._makeRequest(
            "{method.upper()}",
            `{js_path}` + (query ? `?${{query}}` : ""){request_args}
        );
    }}
'''

        code += """
}
"""
        return code

    def _generate_curl_examples(self, spec: dict, filter_tags: list[str] | None) -> str:
        """Generate one cURL example per endpoint."""
        info = spec.get('info', {})
        base_url = _base_url(spec)

        code = f'''# {info.get('title', 'API')} cURL Examples
# Generated at: {datetime.now().isoformat()}

BASE_URL="{base_url}"

'''

        for path, method, details in _iter_operations(spec, filter_tags):
            endpoint_path = self._convert_path_to_template(path)
            full_url = f'${{BASE_URL}}{endpoint_path}'
            summary = details.get('summary', f'{method.upper()} {path}')

            code += f'''
# {summary}
curl -X {method.upper()} "{full_url}" \\
  -H "Content-Type: application/json" \\
  -H "Accept: application/json"

'''

        return code

    def _generate_method_name(self, method: str, path: str, details: dict) -> str:
        """Build a method name from the HTTP method and path.

        Path parameters become ``by_<name>`` so that ``/users`` and
        ``/users/{id}`` do not collide, and every character that is not valid
        in an identifier is replaced with an underscore.
        """
        segments = []
        for segment in path.strip('/').split('/'):
            param = _PATH_PARAM_RE.fullmatch(segment)
            segments.append(f'by_{param.group(1)}' if param else segment)
        clean_path = re.sub(r'[\W_]+', '_', '_'.join(segments)).strip('_')

        method_prefix = {
            'get': 'get',
            'post': 'create',
            'put': 'update',
            'patch': 'patch',
            'delete': 'delete',
        }.get(method.lower(), method.lower())

        if clean_path:
            return f'{method_prefix}_{clean_path}'
        return method_prefix

    def _convert_path_to_template(self, path: str) -> str:
        """Return *path* with ``{param}`` placeholders normalized.

        OpenAPI placeholders already use the ``{name}`` form that Python
        f-strings expect, so the substitution leaves the path unchanged.
        """
        return _PATH_PARAM_RE.sub(r'{\1}', path)

    def _generate_method_params(
        self, path: str, details: dict, method: str | None = None
    ) -> str:
        """Build the Python parameter list that follows ``self``.

        Path parameters come first, then required query parameters, then
        optional ones, so a defaulted parameter never precedes a required one.
        When *method* sends a body, an optional ``data`` parameter is appended.
        """
        params = [f'{name}: str' for name in _PATH_PARAM_RE.findall(path)]

        required = []
        optional = []
        for param in details.get('parameters') or []:
            if param.get('in') != 'query' or 'name' not in param:
                continue
            # Simplification: every query parameter is typed as str.
            if param.get('required', False):
                required.append(f'{param["name"]}: str')
            else:
                optional.append(f'{param["name"]}: Optional[str] = None')
        params.extend(required)
        params.extend(optional)

        if method is not None and method.upper() in _BODY_METHODS:
            params.append('data: Optional[Dict[str, Any]] = None')

        if not params:
            return ''
        return ', ' + ', '.join(params)

    def _generate_js_method_params(
        self, path: str, details: dict, method: str | None = None
    ) -> str:
        """Build the JavaScript parameter list (no leading comma).

        Contains path parameters, query parameters and, when *method* sends a
        body, a trailing ``data`` parameter.
        """
        params = _PATH_PARAM_RE.findall(path) + _query_param_names(details)
        if method is not None and method.upper() in _BODY_METHODS:
            params.append('data')
        return ', '.join(params)

    def _generate_js_params_object(self, path: str, details: dict) -> str:
        """Build the body of the JavaScript query-parameter object literal.

        Only query parameters are listed; path parameters are interpolated
        into the URL and must not be repeated in the query string.
        """
        names = _query_param_names(details)
        if not names:
            return '            // no parameters'
        return ',\n'.join(f'            {name}' for name in names)

    def _generate_request_args(
        self, method: str, path: str, details: dict | None = None
    ) -> str:
        """Build the extra keyword arguments of the Python request call.

        Emits ``params={...}`` when *details* declares query parameters
        (``requests`` drops entries whose value is None) and ``json=data`` for
        methods that send a body.
        """
        args = []
        names = _query_param_names(details or {})
        if names:
            items = ', '.join(f'"{name}": {name}' for name in names)
            args.append(f'params={{{items}}}')
        if method.upper() in _BODY_METHODS:
            args.append('json=data')
        return ''.join(f',\n            {arg}' for arg in args)

    def _generate_js_request_args(self, method: str, path: str) -> str:
        """Build the options argument of the JavaScript request call."""
        if method.upper() in _BODY_METHODS:
            return ',\n            { body: JSON.stringify(data) }'
        return ''

    def _generate_typescript_code(
        self,
        spec: dict,
        output_format: str,
        include_auth: bool,
        filter_tags: list[str] | None,
    ) -> str:
        """Generate TypeScript code (simplified).

        Reuses the JavaScript output and prepends type definitions.
        """
        js_code = self._generate_javascript_code(
            spec, output_format, include_auth, filter_tags
        )

        types_code = """
// Type definitions
interface ApiResponse {
  [key: string]: any;
}

interface RequestOptions {
  headers?: Record<string, string>;
  body?: string;
}

"""

        # Annotate only the class declaration itself; a blanket replace of
        # 'class' would also rewrite any other occurrence of that substring.
        ts_code = re.sub(
            r'^class ',
            'class /* implements */ ',
            js_code,
            count=1,
            flags=re.MULTILINE,
        )
        return types_code + ts_code.replace(
            'options = {}', 'options: RequestOptions = {}'
        )


# Custom tool 3: test data generator
class TestDataGeneratorTool(BaseTool):
    """Test data generator tool: builds sample data from OpenAPI models."""

    def __init__(self):
        super().__init__()
        self.name = 'test_data_generator'
        self.description = '根据 OpenAPI 模型定义生成测试数据'

    def get_definition(self) -> ToolDefinition:
        """Return the tool definition (name, description and input schema)."""
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema={
                'type': 'object',
                'properties': {
                    'model_name': {
                        'type': 'string',
                        'description': '模型名称(可选,如果不指定则生成所有模型的数据)',
                    },
                    'count': {
                        'type': 'integer',
                        'description': '生成数据条数',
                        'default': 1,
                        'minimum': 1,
                        'maximum': 100,
                    },
                    'format': {
                        'type': 'string',
                        'enum': ['json', 'python', 'typescript'],
                        'description': '输出格式',
                        'default': 'json',
                    },
                    'locale': {
                        'type': 'string',
                        'description': '地区设置(影响生成的本地化数据)',
                        'default': 'en',
                    },
                },
            },
        )

    async def execute(self, arguments: dict[str, Any]) -> ToolResult:
        """Generate test data; failures are returned as a failed result."""
        try:
            model_name = arguments.get('model_name')
            count = arguments.get('count', 1)
            format_type = arguments.get('format', 'json')
            locale = arguments.get('locale', 'en')

            # Enforce the bounds the input schema declares; bool is excluded
            # explicitly because it is a subclass of int.
            if (
                isinstance(count, bool)
                or not isinstance(count, int)
                or not 1 <= count <= 100
            ):
                raise ValueError(f'count 必须是 1 到 100 之间的整数: {count!r}')

            spec = await self.get_openapi_spec()
            schemas = spec.get('components', {}).get('schemas', {})

            if model_name and model_name not in schemas:
                return ToolResult(
                    success=False,
                    error=f"模型 '{model_name}' 不存在",
                    message='模型名称错误',
                )

            selected = {model_name: schemas[model_name]} if model_name else schemas
            result = {
                name: self._generate_model_data(model_def, schemas, locale, count)
                for name, model_def in selected.items()
            }

            if format_type == 'json':
                output = json.dumps(result, indent=2, ensure_ascii=False)
            elif format_type == 'python':
                output = self._format_as_python(result)
            elif format_type == 'typescript':
                output = self._format_as_typescript(result)
            else:
                raise ValueError(f'不支持的输出格式: {format_type}')

            return ToolResult(
                success=True,
                data={
                    'format': format_type,
                    'locale': locale,
                    'count': count,
                    'data': output,
                },
                message='测试数据生成完成',
            )
        except Exception as e:
            return ToolResult(success=False, error=str(e), message='测试数据生成失败')

    def _generate_model_data(
        self, model_def: dict, all_schemas: dict, locale: str, count: int
    ) -> list[dict]:
        """Generate *count* sample values for one model."""
        return [
            self._generate_value_for_schema(model_def, all_schemas, locale)
            for _ in range(count)
        ]

    def _generate_value_for_schema(
        self,
        schema: dict,
        all_schemas: dict,
        locale: str,
        field_name: str = 'value',
        _resolving: frozenset = frozenset(),
    ) -> Any:
        """Generate a sample value for *schema*.

        Args:
            schema: the schema to generate a value for.
            all_schemas: ``components.schemas``, used to resolve ``$ref``.
            locale: locale used for localized string values.
            field_name: name of the property being generated; lets string
                generation pick a value that fits the field.
            _resolving: names of the models currently being expanded, used to
                stop on circular references (returns None for the inner one).
        """
        if '$ref' in schema:
            ref_path = schema['$ref']
            if ref_path.startswith('#/components/schemas/'):
                model_name = ref_path.split('/')[-1]
                if model_name in _resolving:
                    return None  # circular reference: stop expanding
                if model_name in all_schemas:
                    return self._generate_value_for_schema(
                        all_schemas[model_name],
                        all_schemas,
                        locale,
                        field_name,
                        _resolving | {model_name},
                    )

        # Simplification: for composed schemas use the first alternative.
        if 'type' not in schema:
            for keyword_name in ('anyOf', 'oneOf', 'allOf'):
                options = schema.get(keyword_name)
                if options:
                    return self._generate_value_for_schema(
                        options[0], all_schemas, locale, field_name, _resolving
                    )

        schema_type = schema.get('type', 'string')

        if schema_type == 'object':
            return {
                prop_name: self._generate_value_for_schema(
                    prop_schema, all_schemas, locale, prop_name, _resolving
                )
                for prop_name, prop_schema in schema.get('properties', {}).items()
            }
        if schema_type == 'array':
            items_schema = schema.get('items', {'type': 'string'})
            # Between 1 and 3 elements, honouring minItems within that range.
            array_length = min(3, max(1, schema.get('minItems', 1)))
            return [
                self._generate_value_for_schema(
                    items_schema, all_schemas, locale, field_name, _resolving
                )
                for _ in range(array_length)
            ]
        if schema_type == 'string':
            return self._generate_string_value(schema, locale, field_name)
        if schema_type == 'integer':
            return self._generate_integer_value(schema)
        if schema_type == 'number':
            return self._generate_number_value(schema)
        if schema_type == 'boolean':
            return True
        return 'generated_value'

    def _generate_string_value(
        self, schema: dict, locale: str, field_name: str = 'value'
    ) -> str:
        """Generate a string honouring enum, format and length constraints."""
        # An enum restricts the value regardless of any declared format.
        if schema.get('enum'):
            return schema['enum'][0]

        format_type = schema.get('format')
        if format_type == 'email':
            return f'user{_stable_hash(locale) % 1000}@example.com'
        if format_type == 'date':
            return '2024-01-15'
        if format_type == 'date-time':
            return '2024-01-15T10:30:00Z'
        if format_type == 'uuid':
            return '550e8400-e29b-41d4-a716-446655440000'

        min_length = schema.get('minLength', 1)
        max_length = min(schema.get('maxLength', 20), 50)
        length = min(max_length, max(min_length, 5))

        # Pick a value that suits the field name where it gives a hint.
        lowered = field_name.lower()
        if 'name' in lowered:
            return self._generate_name(locale, length)
        if 'email' in lowered:
            return f'test{length}@example.com'
        if 'description' in lowered or 'content' in lowered:
            return 'This is a generated description for testing purposes.'
        return 'x' * length

    def _generate_integer_value(self, schema: dict) -> int:
        """Generate a deterministic integer within [minimum, maximum]."""
        if 'enum' in schema:
            return schema['enum'][0]

        minimum = int(schema.get('minimum', 0))
        maximum = int(schema.get('maximum', 100))
        if maximum <= minimum:
            # Degenerate or inverted range: avoid a zero/negative modulus.
            return minimum
        return minimum + _stable_hash('test') % (maximum - minimum + 1)

    def _generate_number_value(self, schema: dict) -> float:
        """Generate a deterministic number within [minimum, maximum]."""
        minimum = schema.get('minimum', 0.0)
        maximum = schema.get('maximum', 100.0)
        fraction = (_stable_hash('test') % 100) / 100
        return round(minimum + fraction * (maximum - minimum), 2)

    def _generate_name(self, locale: str, length: int) -> str:
        """Pick a person name for *locale*, truncated to *length* characters."""
        names = {
            'en': ['John', 'Jane', 'Bob', 'Alice', 'Charlie', 'Diana'],
            'zh': ['张三', '李四', '王五', '赵六', '孙七', '周八'],
            'ja': ['田中', '佐藤', '鈴木', '高橋', '伊藤'],
        }
        name_list = names.get(locale, names['en'])
        name = name_list[_stable_hash('test') % len(name_list)]
        return name[:length]

    def _format_as_python(self, data: dict) -> str:
        """Render the generated data as Python source.

        Each model becomes a class whose class attributes show the first
        sample, with an ``__init__`` accepting arbitrary fields so the emitted
        ``Class(**{...})`` instances are actually constructible.
        """
        output = '# Generated test data\n\n'

        for model_name, values in data.items():
            class_name = self._to_pascal_case(model_name)
            output += f'class {class_name}:\n'

            first_value = values[0] if values else None
            if isinstance(first_value, dict):
                for key, value in first_value.items():
                    # Only names usable as attributes can be emitted here.
                    if key.isidentifier() and not keyword.iskeyword(key):
                        output += f'    {key} = {value!r}\n'
                output += '\n'
            output += '    def __init__(self, **fields):\n'
            output += '        self.__dict__.update(fields)\n'

            output += f'\n# Test instances\n{model_name.lower()}_data = [\n'
            for value in values:
                if isinstance(value, dict):
                    output += f'    {class_name}(**{value!r}),\n'
                else:
                    # Non-object models (e.g. plain strings) have no fields.
                    output += f'    {value!r},\n'
            output += ']\n\n'

        return output

    def _format_as_typescript(self, data: dict) -> str:
        """Render the generated data as TypeScript source."""
        output = '// Generated test data\n\n'

        # Type definitions.
        for model_name, values in data.items():
            interface_name = self._to_pascal_case(model_name)
            first_value = values[0] if values else None
            if first_value is not None and not isinstance(first_value, dict):
                # Non-object models get a type alias instead of an interface.
                alias = self._infer_typescript_type(first_value)
                output += f'type {interface_name} = {alias};\n\n'
                continue
            output += f'interface {interface_name} {{\n'
            for key, value in (first_value or {}).items():
                type_str = self._infer_typescript_type(value)
                prop = key if key.isidentifier() else json.dumps(key)
                output += f'  {prop}: {type_str};\n'
            output += '}\n\n'

        # Sample data.
        for model_name, values in data.items():
            output += (
                f'const {model_name.lower()}TestData: '
                f'{self._to_pascal_case(model_name)}[] = [\n'
            )
            for value in values:
                output += f'  {json.dumps(value, ensure_ascii=False)},\n'
            output += '];\n\n'

        return output

    def _to_pascal_case(self, snake_str: str) -> str:
        """Convert ``snake_case`` to ``PascalCase``.

        Only the first letter of each word is upper-cased; the rest is left
        untouched so names that are already PascalCase survive unchanged.
        """
        return ''.join(word[:1].upper() + word[1:] for word in snake_str.split('_'))

    def _infer_typescript_type(self, value: Any) -> str:
        """Infer the TypeScript type of a generated Python value."""
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            return 'boolean'
        if isinstance(value, str):
            return 'string'
        if isinstance(value, (int, float)):
            return 'number'
        if value is None:
            return 'null'
        if isinstance(value, list):
            if value:
                return f'{self._infer_typescript_type(value[0])}[]'
            return 'any[]'
        if isinstance(value, dict):
            return 'Record<string, any>'
        return 'any'


# Custom tool 4: performance analyzer
class PerformanceAnalyzerTool(BaseTool):
    """Performance analyzer tool: heuristics over the OpenAPI spec."""

    def __init__(self):
        super().__init__()
        self.name = 'performance_analyzer'
        self.description = '分析 OpenAPI 规范,评估 API 性能特征和优化建议'

    def get_definition(self) -> ToolDefinition:
        """Return the tool definition (name, description and input schema)."""
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema={
                'type': 'object',
                'properties': {
                    'analysis_type': {
                        'type': 'string',
                        'enum': [
                            'complexity',
                            'optimization',
                            'bottlenecks',
                            'recommendations',
                        ],
                        'description': '分析类型',
                        'default': 'recommendations',
                    },
                    'traffic_profile': {
                        'type': 'string',
                        'enum': ['low', 'medium', 'high', 'enterprise'],
                        'description': '预期流量级别',
                        'default': 'medium',
                    },
                },
            },
        )

    async def execute(self, arguments: dict[str, Any]) -> ToolResult:
        """Run the analysis; failures are returned as a failed result."""
        try:
            analysis_type = arguments.get('analysis_type', 'recommendations')
            traffic_profile = arguments.get('traffic_profile', 'medium')

            spec = await self.get_openapi_spec()
            result = self._analyze_performance(spec, analysis_type, traffic_profile)

            return ToolResult(
                success=True, data=result, message=f'性能分析完成: {analysis_type}'
            )
        except Exception as e:
            return ToolResult(success=False, error=str(e), message='性能分析失败')

    def _analyze_performance(
        self, spec: dict, analysis_type: str, traffic_profile: str
    ) -> dict:
        """Dispatch to the analysis selected by *analysis_type*.

        Raises:
            ValueError: if *analysis_type* is not a known analysis.
        """
        if analysis_type == 'complexity':
            return self._analyze_complexity(spec)
        if analysis_type == 'optimization':
            return self._analyze_optimization_opportunities(spec)
        if analysis_type == 'bottlenecks':
            return self._identify_potential_bottlenecks(spec, traffic_profile)
        if analysis_type == 'recommendations':
            return self._generate_recommendations(spec, traffic_profile)
        raise ValueError(f'未知的分析类型: {analysis_type}')

    def _analyze_complexity(self, spec: dict) -> dict:
        """Score every endpoint and report the most and least complex ones."""
        complexity_scores = [
            {
                'endpoint': f'{method.upper()} {path}',
                'score': self._calculate_endpoint_complexity(path, details),
                'factors': self._get_complexity_factors(details, path),
            }
            for path, method, details in _iter_operations(spec)
        ]
        complexity_scores.sort(key=lambda item: item['score'], reverse=True)

        total = len(complexity_scores)
        # Guard against a spec without endpoints (would divide by zero).
        average = sum(item['score'] for item in complexity_scores) / total if total else 0

        return {
            'average_complexity': average,
            'most_complex': complexity_scores[:5],
            'least_complex': complexity_scores[-5:],
            'total_endpoints': total,
        }

    def _analyze_optimization_opportunities(self, spec: dict) -> dict:
        """List caching, pagination and compression opportunities."""
        opportunities = []

        for path, method, details in _iter_operations(spec):
            endpoint = f'{method.upper()} {path}'
            is_get = method.upper() == 'GET'

            if is_get and not self._has_cache_headers(details):
                opportunities.append(
                    {
                        'type': 'caching',
                        'endpoint': endpoint,
                        'suggestion': '添加缓存头以提高 GET 请求性能',
                        'impact': 'high',
                    }
                )

            if is_get and 'list' in path and not self._has_pagination(details):
                opportunities.append(
                    {
                        'type': 'pagination',
                        'endpoint': endpoint,
                        'suggestion': '添加分页参数以处理大数据集',
                        'impact': 'medium',
                    }
                )

            if not self._has_compression(details):
                opportunities.append(
                    {
                        'type': 'compression',
                        'endpoint': endpoint,
                        'suggestion': '启用响应压缩以减少带宽使用',
                        'impact': 'medium',
                    }
                )

        # Rank explicitly: sorting the labels as strings would put 'medium'
        # ahead of 'high'.
        impact_rank = {'high': 0, 'medium': 1, 'low': 2}
        return {
            'total_opportunities': len(opportunities),
            'opportunities': sorted(
                opportunities,
                key=lambda item: impact_rank.get(item['impact'], len(impact_rank)),
            ),
        }

    def _identify_potential_bottlenecks(self, spec: dict, traffic_profile: str) -> dict:
        """List endpoints that look likely to become bottlenecks."""
        bottlenecks = []

        for path, method, details in _iter_operations(spec):
            endpoint = f'{method.upper()} {path}'

            complex_params = self._get_complex_parameters(details)
            if complex_params:
                bottlenecks.append(
                    {
                        'type': 'complex_parameters',
                        'endpoint': endpoint,
                        'description': f'复杂参数可能影响性能: {", ".join(complex_params)}',
                        'severity': 'medium',
                    }
                )

            if self._has_large_response_potential(details):
                bottlenecks.append(
                    {
                        'type': 'large_response',
                        'endpoint': endpoint,
                        'description': '可能返回大量数据,影响响应时间',
                        'severity': 'high',
                    }
                )

            if not self._has_rate_limiting_hints(details):
                bottlenecks.append(
                    {
                        'type': 'no_rate_limiting',
                        'endpoint': endpoint,
                        'description': '缺少速率限制提示,可能在流量高峰时出现问题',
                        'severity': 'medium',
                    }
                )

        return {
            'traffic_profile': traffic_profile,
            'total_bottlenecks': len(bottlenecks),
            'bottlenecks': bottlenecks,
        }

    def _generate_recommendations(self, spec: dict, traffic_profile: str) -> dict:
        """Build a prioritized list of performance recommendations."""
        # Baseline recommendations.
        recommendations = [
            {
                'category': 'caching',
                'priority': 'high',
                'recommendation': '为所有 GET 端点实现适当的缓存策略',
                'implementation': '使用 HTTP 缓存头或 Redis 缓存',
            },
            {
                'category': 'monitoring',
                'priority': 'high',
                'recommendation': '实现全面的性能监控',
                'implementation': '监控响应时间、错误率、吞吐量等关键指标',
            },
        ]

        # Recommendations that depend on the expected traffic level.
        if traffic_profile in ['high', 'enterprise']:
            recommendations.extend(
                [
                    {
                        'category': 'scaling',
                        'priority': 'critical',
                        'recommendation': '实现水平扩展能力',
                        'implementation': '使用负载均衡器和自动扩缩容',
                    },
                    {
                        'category': 'database',
                        'priority': 'high',
                        'recommendation': '优化数据库查询和连接池',
                        'implementation': '添加索引、使用读写分离、优化连接池配置',
                    },
                ]
            )

        # Recommendations derived from the spec itself.
        total_endpoints = sum(1 for _ in _iter_operations(spec))
        if total_endpoints > 50:
            recommendations.append(
                {
                    'category': 'architecture',
                    'priority': 'medium',
                    'recommendation': '考虑 API 拆分为微服务',
                    'implementation': '按业务功能拆分大型 API',
                }
            )

        priority_rank = {'critical': 0, 'high': 1, 'medium': 2}
        return {
            'traffic_profile': traffic_profile,
            'total_recommendations': len(recommendations),
            'recommendations': sorted(
                recommendations,
                key=lambda item: priority_rank[item['priority']],
            ),
        }

    def _calculate_endpoint_complexity(self, path: str, details: dict) -> int:
        """Score an endpoint from its path params, parameters and responses."""
        score = 1
        score += len(_PATH_PARAM_RE.findall(path)) * 2
        score += len(details.get('parameters') or [])

        for response in (details.get('responses') or {}).values():
            if 'application/json' in response.get('content', {}):
                score += 3  # a JSON response body adds complexity
        return score

    def _get_complexity_factors(self, details: dict, path: str = '') -> list[str]:
        """Name the factors that make an endpoint complex.

        *path* is the endpoint's URL template. The operation object does not
        contain its own path, so it has to be supplied by the caller for the
        path-parameter factor to be evaluated.
        """
        factors = []
        if len(details.get('parameters') or []) > 5:
            factors.append('多个参数')
        if len(_PATH_PARAM_RE.findall(path)) > 2:
            factors.append('多个路径参数')
        if len(details.get('responses') or {}) > 3:
            factors.append('多个响应类型')
        return factors

    def _has_cache_headers(self, details: dict) -> bool:
        """Heuristic: does the operation mention 'cache' anywhere?"""
        # Simplified check; a real one would inspect response header definitions.
        return 'cache' in str(details).lower()

    def _has_pagination(self, details: dict) -> bool:
        """Does the operation declare a well-known pagination parameter?"""
        pagination_params = {'limit', 'offset', 'page', 'size', 'skip'}
        return any(
            param.get('name') in pagination_params
            for param in details.get('parameters') or []
        )

    def _has_compression(self, details: dict) -> bool:
        """Heuristic: does the operation mention 'compression' anywhere?"""
        return 'compression' in str(details).lower()

    def _get_complex_parameters(self, details: dict) -> list[str]:
        """Return the names of array- or object-typed parameters.

        Checks both the parameter's own ``type`` and ``schema.type``, since
        OpenAPI 3 nests the type under ``schema``.
        """
        complex_params = []
        for param in details.get('parameters') or []:
            param_type = param.get('type') or (param.get('schema') or {}).get('type')
            if param_type in ('array', 'object'):
                complex_params.append(param.get('name', ''))
        return complex_params

    def _has_large_response_potential(self, details: dict) -> bool:
        """Does any JSON response return a top-level array?"""
        for response in (details.get('responses') or {}).values():
            content = response.get('content', {})
            if 'application/json' in content:
                schema = content['application/json'].get('schema', {})
                if schema.get('type') == 'array':
                    return True
        return False

    def _has_rate_limiting_hints(self, details: dict) -> bool:
        """Heuristic: does the operation mention 'rate' or 'limit' anywhere?"""
        text = str(details).lower()
        return 'rate' in text or 'limit' in text


# Registration of the custom tools and demo entry point
def main():
    """Build the demo app, register the custom tools and print usage notes."""
    print('🔧 OpenAPI MCP 自定义工具演示')
    print('=' * 60)

    app = create_demo_app()
    mcp_server = OpenApiMcpServer(app)

    custom_tools = [
        ApiAnalyzerTool(),
        CodeGeneratorTool(),
        TestDataGeneratorTool(),
        PerformanceAnalyzerTool(),
    ]

    for tool in custom_tools:
        mcp_server.register_tool(tool)
        print(f'✅ 注册自定义工具: {tool.name}')

    mcp_server.mount('/mcp')

    print('\n📋 已注册的自定义工具:')
    for tool in custom_tools:
        definition = tool.get_definition()
        print(f'  🔧 {tool.name}')
        print(f'     描述: {definition.description}')
        print(f'     参数: {len(definition.input_schema.get("properties", {}))} 个')

    print('\n💡 自定义工具功能:')
    print('  1. API 分析器 - 深入分析 OpenAPI 规范')
    print('  2. 代码生成器 - 生成多语言客户端代码')
    print('  3. 测试数据生成器 - 自动生成测试数据')
    print('  4. 性能分析器 - 评估和优化 API 性能')

    print('\n🚀 使用方法:')
    print('  1. 启动服务器: python custom_tool_example.py')
    print('  2. 连接 MCP 客户端到 http://localhost:8000/mcp')
    print('  3. 调用自定义工具进行相应操作')

    print('\n📖 扩展自定义工具:')
    print('  - 继承 BaseTool 基类')
    print('  - 实现 get_definition() 方法定义工具接口')
    print('  - 实现 execute() 方法提供工具功能')
    print('  - 使用 register_tool() 注册到服务器')

    return app


if __name__ == '__main__':
    import uvicorn

    app = main()
    print('\n🌐 启动服务器...')
    print('   访问 http://localhost:8000/docs 查看 API 文档')
    print('   连接 http://localhost:8000/mcp 使用自定义工具')

    uvicorn.run(app, host='0.0.0.0', port=8000)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/jason-chang/fastapi-openapi-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.