Skip to main content
Glama

AWS Security MCP

ec2_tools.py105 kB
"""EC2 tools for AWS Security MCP. This module provides tools for working with EC2 resources, including: - Listing and searching EC2 instances - Listing and filtering security groups, including by inbound rules and ports - Identifying resources with public internet access - Analyzing VPCs and networking components - Finding security vulnerabilities like open ports """ import logging from typing import Any, Dict, List, Optional, Union import json from datetime import datetime import boto3 from aws_security_mcp.services import ec2 from aws_security_mcp.services.base import get_client from aws_security_mcp.tools import register_tool # Configure logging logger = logging.getLogger(__name__) # Define custom JSON encoder for datetime objects class CustomJSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles datetime objects.""" def default(self, obj): if isinstance(obj, datetime): return obj.isoformat() return super().default(obj) @register_tool('list_ec2_instances') async def list_ec2_instances( limit: Optional[int] = None, search_term: str = "", state: str = "running", next_token: Optional[str] = None, session_context: Optional[str] = None ) -> str: """List EC2 instances with details. Args: limit: Maximum number of instances to return (None for all) search_term: Optional search term to filter instances by name, ID, or type state: Instance state to filter by (default is "running"). Set to empty string to show all states. next_token: Pagination token from a previous request (optional) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with EC2 instance information """ logger.info(f"Listing EC2 instances (limit={limit}, search_term='{search_term}', state='{state}', next_token={next_token}, session_context={session_context})") try: ec2_client = get_client("ec2", session_context=session_context) # Prepare parameters for the API call kwargs = {} # Add filters based on state if provided if state: filters = [{ 'Name': 'instance-state-name', 'Values': [state] }] kwargs['Filters'] = filters # Add pagination token if provided if next_token: kwargs['NextToken'] = next_token # Add max results if limit is specified if limit: kwargs['MaxResults'] = min(limit, 1000) # AWS API max is 1000 # Get instances response = ec2_client.describe_instances(**kwargs) # Extract instances from reservations instances = [] for reservation in response.get('Reservations', []): instances.extend(reservation.get('Instances', [])) # Filter instances based on search term if provided filtered_instances = [] for instance in instances: if search_term: # Convert all values to lowercase for case-insensitive search instance_id = instance.get('InstanceId', '').lower() instance_type = instance.get('InstanceType', '').lower() private_ip = instance.get('PrivateIpAddress', '').lower() if instance.get('PrivateIpAddress') else '' public_ip = instance.get('PublicIpAddress', '').lower() if instance.get('PublicIpAddress') else '' vpc_id = instance.get('VpcId', '').lower() if instance.get('VpcId') else '' subnet_id = instance.get('SubnetId', '').lower() if instance.get('SubnetId') else '' # Get name tag if available instance_name = next((tag['Value'].lower() for tag in instance.get('Tags', []) if tag['Key'] == 'Name'), '') search_lower = search_term.lower() # Skip if search term not found in any field if (search_lower not in instance_id and search_lower not in instance_type and search_lower not in private_ip and search_lower not in public_ip and search_lower not in vpc_id and search_lower not in subnet_id and search_lower not in instance_name): continue # Extract security groups security_groups = [] for sg in instance.get('SecurityGroups', []): security_groups.append({ 'id': sg.get('GroupId', ''), 'name': sg.get('GroupName', '') }) # Get instance name from tags name = next((tag.get('Value', '') for tag in instance.get('Tags', []) if tag.get('Key') == 'Name'), '') # Format instance details formatted_instance = { 'id': instance.get('InstanceId', ''), 'name': name, 'type': instance.get('InstanceType', ''), 'state': instance.get('State', {}).get('Name', ''), 'private_ip': instance.get('PrivateIpAddress', ''), 'public_ip': instance.get('PublicIpAddress', ''), 'vpc_id': instance.get('VpcId', ''), 'subnet_id': instance.get('SubnetId', ''), 'image_id': instance.get('ImageId', ''), 'launch_time': instance.get('LaunchTime', '').isoformat() if instance.get('LaunchTime') else '', 'key_name': instance.get('KeyName', ''), 'security_groups': security_groups, 'iam_instance_profile': instance.get('IamInstanceProfile', {}).get('Arn', '') if instance.get('IamInstanceProfile') else '', 'ebs_volumes': [ { 'device': mapping.get('DeviceName', ''), 'volume_id': mapping.get('Ebs', {}).get('VolumeId', '') if mapping.get('Ebs') else '' } for mapping in instance.get('BlockDeviceMappings', []) ] } filtered_instances.append(formatted_instance) # Check if we've reached the limit if limit and len(filtered_instances) >= limit: break # Extract next token for pagination next_token = response.get('NextToken') # Format the final result result = { 'instances': filtered_instances, 'count': len(filtered_instances), 'has_more': bool(next_token), 'next_token': next_token } return json.dumps(result, cls=CustomJSONEncoder) except Exception as e: error_message = f"Error listing EC2 instances: {str(e)}" logger.error(error_message, exc_info=True) return json.dumps({ "error": error_message, "instances": [], "count": 0, "has_more": False }) @register_tool('count_ec2_instances') async def count_ec2_instances( state: str = "", has_public_access: Optional[bool] = None, port: Optional[int] = None, session_context: Optional[str] = None ) -> str: """Count EC2 instances, optionally filtering by state and security group rules. Args: state: Optional instance state to filter by (e.g., running, stopped, terminated) has_public_access: If set, only count instances with (True) or without (False) public internet access port: Optional specific port to check for access (e.g., 22 for SSH) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with instance count information """ logger.info(f"Counting EC2 instances (state='{state}', has_public_access={has_public_access}, port={port}, session_context={session_context})") try: # Handle state filtering - convert single string state to list if provided states = [state] if state else None # Get all instances with state filtering - pass session_context to ec2 service calls all_instances = ec2.get_all_instances(max_items=None, states=states, session_context=session_context) # If we're not filtering by public access, return the count now if has_public_access is None and port is None: count_by_state = {} for instance in all_instances: current_state = instance.get('State', {}).get('Name', 'unknown') count_by_state[current_state] = count_by_state.get(current_state, 0) + 1 total_count = len(all_instances) if state: return json.dumps({"summary": f"Found {total_count} EC2 instance(s) in state '{state}'"}) else: result = f"Total EC2 instances: {total_count}\n\nBreakdown by state:\n" for s, count in sorted(count_by_state.items()): result += f"- {s}: {count}\n" return json.dumps({"summary": result}) # We need to check for public access, so get security groups all_security_groups = ec2.get_all_security_groups(max_items=None, session_context=session_context) # Create lookup dictionary for security groups sg_dict = {sg.get('GroupId'): sg for sg in all_security_groups} # Track instances with public access public_instances = [] non_public_instances = [] # Check each instance for instance in all_instances: # Get security groups attached to the instance instance_sgs = instance.get('SecurityGroups', []) instance_public = False for instance_sg in instance_sgs: sg_id = instance_sg.get('GroupId') # Get the full security group details sg = sg_dict.get(sg_id) if not sg: continue # Check inbound rules for rule in sg.get('IpPermissions', []): rule_protocol = rule.get('IpProtocol', '') from_port = rule.get('FromPort') to_port = rule.get('ToPort') # Check if port matches if specified port_match = False if port is not None: if rule_protocol == '-1': # All traffic port_match = True elif from_port is not None and to_port is not None: if from_port <= port <= to_port: port_match = True else: port_match = True # No specific port to check # Skip if port doesn't match if not port_match: continue # Check for public CIDR ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') if cidr in ('0.0.0.0/0', '::/0'): instance_public = True break if instance_public: break if instance_public: public_instances.append(instance) else: non_public_instances.append(instance) # Count results based on public access filter if has_public_access is True: instances_to_count = public_instances elif has_public_access is False: instances_to_count = non_public_instances else: # Port filter only, regardless of public access instances_to_count = public_instances + non_public_instances # Formulate result message if port is not None: msg = f"Found {len(instances_to_count)} EC2 instance(s)" if has_public_access is True: msg += " with public access" elif has_public_access is False: msg += " without public access" msg += f" on port {port}" if state: msg += f" in state '{state}'" return json.dumps({"summary": msg}) else: msg = f"Found {len(instances_to_count)} EC2 instance(s)" if has_public_access is True: msg += " with public internet access" elif has_public_access is False: msg += " without public internet access" if state: msg += f" in state '{state}'" return json.dumps({"summary": msg}) except Exception as e: logger.error(f"Error counting EC2 instances: {e}") return json.dumps({"error": {"message": f"Error counting EC2 instances: {str(e)}", "type": type(e).__name__}}) @register_tool('list_security_groups') async def list_security_groups( limit: Optional[int] = None, search_term: str = "", next_token: Optional[str] = None, session_context: Optional[str] = None ) -> str: """List EC2 security groups with details. Args: limit: Maximum number of security groups to return (None for all) search_term: Optional search term to filter security groups. Supports special syntax: - Standard text search by name, ID, description, or VPC ID - port:XX - Find security groups with specific port open (e.g., port:22 for SSH) - protocol:XX - Find security groups allowing specific protocol (e.g., protocol:http) - public:true - Find security groups open to the internet (0.0.0.0/0) - cidr:X.X.X.X/X - Find security groups allowing specific CIDR range next_token: Pagination token from a previous request (optional) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with security group information """ logger.info(f"Listing security groups (limit={limit}, search_term='{search_term}', next_token={next_token}, session_context={session_context})") try: if next_token: # If next_token is provided, we're requesting a specific page response = ec2.describe_security_groups( next_token=next_token, session_context=session_context ) security_groups = response.get('SecurityGroups', []) # Apply text search filter if provided if search_term: security_groups = ec2.filter_security_groups_by_text( search_term=search_term, session_context=session_context ) next_token_value = response.get('NextToken') else: # No next_token provided, get with optional filtering if search_term: # Special search syntax handling is done in filter_security_groups_by_text security_groups = ec2.filter_security_groups_by_text( search_term=search_term, session_context=session_context ) next_token_value = None else: # Get all security groups if no search term response = ec2.describe_security_groups(session_context=session_context) security_groups = response.get('SecurityGroups', []) next_token_value = response.get('NextToken') # Format response formatted_groups = [] for sg in security_groups: # Extract basic info group_id = sg.get('GroupId', '') group_name = sg.get('GroupName', '') vpc_id = sg.get('VpcId', '') description = sg.get('Description', '') # Format inbound rules inbound_rules = [] for rule in sg.get('IpPermissions', []): protocol = rule.get('IpProtocol', '') from_port = rule.get('FromPort', '') to_port = rule.get('ToPort', '') # Format protocol display if protocol == '-1': protocol_display = 'All' port_range = 'All' elif protocol == 'tcp': protocol_display = 'TCP' if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" elif protocol == 'udp': protocol_display = 'UDP' if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" elif protocol == 'icmp': protocol_display = 'ICMP' port_range = 'N/A' else: protocol_display = protocol if from_port or to_port: if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" else: port_range = 'All' # Add sources sources = [] # IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') source_entry = { 'type': 'IPv4 CIDR', 'value': cidr, 'description': description } # Flag if open to the internet if cidr == '0.0.0.0/0': source_entry['is_public'] = True sources.append(source_entry) # IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') source_entry = { 'type': 'IPv6 CIDR', 'value': cidr, 'description': description } # Flag if open to the internet if cidr == '::/0': source_entry['is_public'] = True sources.append(source_entry) # Security group sources for sg_source in rule.get('UserIdGroupPairs', []): source_entry = { 'type': 'Security Group', 'value': sg_source.get('GroupId', ''), 'description': sg_source.get('Description', '') } sources.append(source_entry) # Prefix list sources for prefix_source in rule.get('PrefixListIds', []): source_entry = { 'type': 'Prefix List', 'value': prefix_source.get('PrefixListId', ''), 'description': prefix_source.get('Description', '') } sources.append(source_entry) # Add rule to inbound_rules rule_entry = { 'protocol': protocol_display, 'port_range': port_range, 'sources': sources } inbound_rules.append(rule_entry) # Format outbound rules outbound_rules = [] for rule in sg.get('IpPermissionsEgress', []): protocol = rule.get('IpProtocol', '') from_port = rule.get('FromPort', '') to_port = rule.get('ToPort', '') # Format protocol display if protocol == '-1': protocol_display = 'All' port_range = 'All' elif protocol == 'tcp': protocol_display = 'TCP' if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" elif protocol == 'udp': protocol_display = 'UDP' if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" elif protocol == 'icmp': protocol_display = 'ICMP' port_range = 'N/A' else: protocol_display = protocol if from_port or to_port: if from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" else: port_range = 'All' # Add destinations destinations = [] # IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') dest_entry = { 'type': 'IPv4 CIDR', 'value': cidr, 'description': description } # Flag if open to the internet if cidr == '0.0.0.0/0': dest_entry['is_public'] = True destinations.append(dest_entry) # IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') dest_entry = { 'type': 'IPv6 CIDR', 'value': cidr, 'description': description } # Flag if open to the internet if cidr == '::/0': dest_entry['is_public'] = True destinations.append(dest_entry) # Security group destinations for sg_dest in rule.get('UserIdGroupPairs', []): dest_entry = { 'type': 'Security Group', 'value': sg_dest.get('GroupId', ''), 'description': sg_dest.get('Description', '') } destinations.append(dest_entry) # Prefix list destinations for prefix_dest in rule.get('PrefixListIds', []): dest_entry = { 'type': 'Prefix List', 'value': prefix_dest.get('PrefixListId', ''), 'description': prefix_dest.get('Description', '') } destinations.append(dest_entry) # Add rule to outbound_rules rule_entry = { 'protocol': protocol_display, 'port_range': port_range, 'destinations': destinations } outbound_rules.append(rule_entry) # Add security group to formatted_groups sg_entry = { 'id': group_id, 'name': group_name, 'vpc_id': vpc_id, 'description': description, 'inbound_rules': inbound_rules, 'outbound_rules': outbound_rules } formatted_groups.append(sg_entry) # Create summary and response result = { 'security_groups': formatted_groups, 'count': len(formatted_groups), 'total_count': len(formatted_groups), # For legacy compatibility 'has_more': next_token_value is not None, 'next_token': next_token_value } return json.dumps(result) except Exception as e: logger.error(f"Error listing security groups: {e}", exc_info=True) return json.dumps({ 'error': str(e), 'security_groups': [], 'count': 0, 'total_count': 0, 'has_more': False }) @register_tool('list_vpcs') async def list_vpcs( limit: Optional[int] = None, search_term: str = "", next_token: Optional[str] = None, session_context: Optional[str] = None ) -> str: """List VPCs with details. Args: limit: Maximum number of VPCs to return (None for all) search_term: Optional search term to filter VPCs by ID or CIDR next_token: Optional pagination token for fetching next page of results session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with VPC information """ logger.info(f"Listing VPCs (limit={limit}, search_term='{search_term}', next_token={next_token}, session_context={session_context})") try: ec2_client = get_client("ec2", session_context=session_context) # Prepare parameters for the API call kwargs = {} # Add pagination token if provided if next_token: kwargs['NextToken'] = next_token # Add max results if limit is specified if limit: kwargs['MaxResults'] = min(limit, 1000) # AWS API max is 1000 # Get VPCs response = ec2_client.describe_vpcs(**kwargs) vpcs = response.get('Vpcs', []) # Filter VPCs based on search term if provided filtered_vpcs = [] for vpc in vpcs: if search_term: # Convert all values to lowercase for case-insensitive search vpc_id = vpc.get('VpcId', '').lower() cidr_block = vpc.get('CidrBlock', '').lower() vpc_name = next((tag['Value'].lower() for tag in vpc.get('Tags', []) if tag['Key'].lower() == 'name'), '') search_lower = search_term.lower() # Skip if search term not found in any field if (search_lower not in vpc_id and search_lower not in cidr_block and search_lower not in vpc_name): continue # Format VPC data formatted_vpc = { 'id': vpc.get('VpcId'), 'cidr_block': vpc.get('CidrBlock'), 'state': vpc.get('State'), 'is_default': vpc.get('IsDefault', False), 'tags': {tag['Key']: tag['Value'] for tag in vpc.get('Tags', [])} if 'Tags' in vpc else {}, 'cidr_block_association_set': vpc.get('CidrBlockAssociationSet', []), 'ipv6_cidr_block_association_set': vpc.get('Ipv6CidrBlockAssociationSet', []), 'instance_tenancy': vpc.get('InstanceTenancy') } # Add name from tags if available formatted_vpc['name'] = next((tag['Value'] for tag in vpc.get('Tags', []) if tag['Key'] == 'Name'), '') filtered_vpcs.append(formatted_vpc) # Check if we've reached the limit if limit and len(filtered_vpcs) >= limit: break # Extract next token for pagination next_token = response.get('NextToken') # Format the final result result = { 'vpcs': filtered_vpcs, 'count': len(filtered_vpcs), 'has_more': bool(next_token), 'next_token': next_token } return json.dumps(result, cls=CustomJSONEncoder) except Exception as e: error_message = f"Error listing VPCs: {str(e)}" logger.error(error_message, exc_info=True) return json.dumps({ "error": error_message, "vpcs": [], "count": 0, "has_more": False }) @register_tool('list_route_tables') async def list_route_tables( limit: Optional[int] = None, search_term: str = "", next_token: Optional[str] = None, session_context: Optional[str] = None ) -> str: """List route tables with details. Args: limit: Maximum number of route tables to return (None for all) search_term: Optional search term to filter route tables by ID or VPC ID next_token: Optional pagination token for fetching next page of results session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with route table information """ logger.info(f"Listing route tables (limit={limit}, search_term='{search_term}', next_token={next_token}, session_context={session_context})") try: ec2_client = get_client("ec2", session_context=session_context) # Prepare parameters for the API call kwargs = {} # Set up filters based on search term filters = None if search_term and (search_term.startswith("vpc:") or search_term.startswith("vpc-")): # Check if search is for a specific VPC if search_term.startswith("vpc:"): vpc_id = search_term.split(':', 1)[1] else: vpc_id = search_term filters = [{ 'Name': 'vpc-id', 'Values': [vpc_id] }] kwargs['Filters'] = filters search_term = "" # Clear search term since we're using a filter # Add pagination token if provided if next_token: kwargs['NextToken'] = next_token # Add max results if limit is specified if limit: kwargs['MaxResults'] = min(limit, 1000) # AWS API max is 1000 # Get route tables response = ec2_client.describe_route_tables(**kwargs) route_tables = response.get('RouteTables', []) # Filter route tables based on search term if provided filtered_route_tables = [] for rt in route_tables: if search_term: # Convert all values to lowercase for case-insensitive search rt_id = rt.get('RouteTableId', '').lower() vpc_id = rt.get('VpcId', '').lower() rt_name = next((tag['Value'].lower() for tag in rt.get('Tags', []) if tag['Key'].lower() == 'name'), '') search_lower = search_term.lower() # Skip if search term not found in any field if (search_lower not in rt_id and search_lower not in vpc_id and search_lower not in rt_name): continue # Format route table data formatted_rt = { 'id': rt.get('RouteTableId'), 'vpc_id': rt.get('VpcId'), 'routes': [ { 'destination_cidr_block': route.get('DestinationCidrBlock', ''), 'destination_ipv6_cidr_block': route.get('DestinationIpv6CidrBlock', ''), 'destination_prefix_list_id': route.get('DestinationPrefixListId', ''), 'gateway_id': route.get('GatewayId', ''), 'instance_id': route.get('InstanceId', ''), 'nat_gateway_id': route.get('NatGatewayId', ''), 'network_interface_id': route.get('NetworkInterfaceId', ''), 'vpc_peering_connection_id': route.get('VpcPeeringConnectionId', ''), 'state': route.get('State', '') } for route in rt.get('Routes', []) ], 'associations': [ { 'id': assoc.get('RouteTableAssociationId', ''), 'subnet_id': assoc.get('SubnetId', ''), 'gateway_id': assoc.get('GatewayId', ''), 'main': assoc.get('Main', False) } for assoc in rt.get('Associations', []) ], 'propagating_vgws': [vgw.get('GatewayId', '') for vgw in rt.get('PropagatingVgws', [])], 'tags': {tag['Key']: tag['Value'] for tag in rt.get('Tags', [])} if 'Tags' in rt else {} } # Add name from tags if available formatted_rt['name'] = next((tag['Value'] for tag in rt.get('Tags', []) if tag['Key'] == 'Name'), '') # Check if this is the main route table formatted_rt['is_main'] = any(assoc.get('Main', False) for assoc in rt.get('Associations', [])) # Determine if this route table has internet access formatted_rt['has_internet_access'] = any( route.get('GatewayId', '').startswith('igw-') for route in rt.get('Routes', []) ) filtered_route_tables.append(formatted_rt) # Check if we've reached the limit if limit and len(filtered_route_tables) >= limit: break # Extract next token for pagination next_token = response.get('NextToken') # Format the final result result = { 'route_tables': filtered_route_tables, 'count': len(filtered_route_tables), 'has_more': bool(next_token), 'next_token': next_token } return json.dumps(result, cls=CustomJSONEncoder) except Exception as e: error_message = f"Error listing route tables: {str(e)}" logger.error(error_message, exc_info=True) return json.dumps({ "error": error_message, "route_tables": [], "count": 0, "has_more": False }) @register_tool('list_subnets') async def list_subnets( vpc_id: Optional[str] = None, include_details: bool = True, limit: Optional[int] = None, search_term: str = "", next_token: Optional[str] = None, session_context: Optional[str] = None ) -> str: """List all subnets in a VPC or across all VPCs. Args: vpc_id: Optional VPC ID to list subnets for. If None, lists subnets across all VPCs. include_details: Whether to include detailed subnet information (route tables, ACLs) limit: Maximum number of subnets to return (None for all) search_term: Optional text to filter subnets by ID, VPC ID, CIDR, or tags next_token: Optional pagination token for fetching next page of results session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with subnet information """ logger.info(f"Listing subnets (vpc_id={vpc_id}, include_details={include_details}, limit={limit}, search_term='{search_term}', next_token={next_token}, session_context={session_context})") try: # Set up filters if VPC ID is provided filters = None if vpc_id: filters = [{ 'Name': 'vpc-id', 'Values': [vpc_id] }] if next_token: # If next_token is provided, we're requesting a specific page response = ec2.describe_subnets( filters=filters, next_token=next_token, session_context=session_context ) subnets = response.get('Subnets', []) # Apply text search filter if provided if search_term: subnets = ec2.filter_subnets_by_text(subnets, search_term) next_token_value = response.get('NextToken') else: # No next_token provided, get with optional filtering if search_term: # Get all subnets and filter by text subnets = ec2.get_all_subnets(filters=filters, session_context=session_context) if search_term: subnets = ec2.filter_subnets_by_text(subnets, search_term) next_token_value = None else: # Get all subnets if no search term response = ec2.describe_subnets(filters=filters, session_context=session_context) subnets = response.get('Subnets', []) next_token_value = response.get('NextToken') # Format response formatted_subnets = [] for subnet in subnets: # Extract basic info subnet_id = subnet.get('SubnetId', '') subnet_vpc_id = subnet.get('VpcId', '') cidr_block = subnet.get('CidrBlock', '') availability_zone = subnet.get('AvailabilityZone', '') state = subnet.get('State', '') available_ip_count = subnet.get('AvailableIpAddressCount', 0) # Get public IP assignment setting map_public_ip = subnet.get('MapPublicIpOnLaunch', False) # Get name tag if available name = '' for tag in subnet.get('Tags', []): if tag.get('Key') == 'Name': name = tag.get('Value', '') break # Format tags tags = {} for tag in subnet.get('Tags', []): tags[tag.get('Key', '')] = tag.get('Value', '') # Create basic subnet details subnet_details = { 'id': subnet_id, 'name': name, 'vpc_id': subnet_vpc_id, 'cidr_block': cidr_block, 'availability_zone': availability_zone, 'state': state, 'available_ip_count': available_ip_count, 'map_public_ip_on_launch': map_public_ip, 'tags': tags } # Add detailed information if requested if include_details: # Get route table associations for this subnet route_table_filters = [{ 'Name': 'association.subnet-id', 'Values': [subnet_id] }] route_tables = ec2.get_all_route_tables(filters=route_table_filters, session_context=session_context) # Format route tables associated_route_tables = [] for rt in route_tables: rt_id = rt.get('RouteTableId', '') # Get association information association = None for assoc in rt.get('Associations', []): if assoc.get('SubnetId') == subnet_id: association = { 'id': assoc.get('RouteTableAssociationId', ''), 'main': assoc.get('Main', False) } break # Get name tag if available rt_name = '' for tag in rt.get('Tags', []): if tag.get('Key') == 'Name': rt_name = tag.get('Value', '') break associated_route_tables.append({ 'id': rt_id, 'name': rt_name, 'association': association }) subnet_details['route_tables'] = associated_route_tables # Get network ACLs for this subnet acl_filters = [{ 'Name': 'association.subnet-id', 'Values': [subnet_id] }] acl_response = ec2.describe_network_acls(filters=acl_filters, session_context=session_context) acls = acl_response.get('NetworkAcls', []) # Format network ACLs associated_acls = [] for acl in acls: acl_id = acl.get('NetworkAclId', '') is_default = acl.get('IsDefault', False) # Format entries entries = [] for entry in acl.get('Entries', []): entry_data = { 'rule_number': entry.get('RuleNumber', 0), 'protocol': entry.get('Protocol', ''), 'rule_action': entry.get('RuleAction', ''), 'egress': entry.get('Egress', False), 'cidr_block': entry.get('CidrBlock', '') } # Add port range information if available if 'PortRange' in entry: entry_data['port_range'] = { 'from': entry.get('PortRange', {}).get('From', 0), 'to': entry.get('PortRange', {}).get('To', 0) } entries.append(entry_data) # Get association information association = None for assoc in acl.get('Associations', []): if assoc.get('SubnetId') == subnet_id: association = { 'id': assoc.get('NetworkAclAssociationId', '') } break associated_acls.append({ 'id': acl_id, 'is_default': is_default, 'association': association, 'entries': entries }) subnet_details['network_acls'] = associated_acls formatted_subnets.append(subnet_details) # Create summary and response result = { 'subnets': formatted_subnets, 'count': len(formatted_subnets), 'total_count': len(formatted_subnets), # For legacy compatibility 'has_more': next_token_value is not None, 'next_token': next_token_value } return json.dumps(result) except Exception as e: logger.error(f"Error listing subnets: {e}", exc_info=True) return json.dumps({ 'error': str(e), 'subnets': [], 'count': 0, 'total_count': 0, 'has_more': False }) @register_tool() async def list_ec2_resources(resource_type: str = "all", limit: Optional[int] = None, search_term: str = "", state: str = "running", next_token: Optional[str] = None, session_context: Optional[str] = None) -> str: """List EC2 resources of the specified type. Args: resource_type: Type of EC2 resource to list (instances, security_groups, vpcs, route_tables, subnets, or all) limit: Maximum number of resources to return (None for all) search_term: Optional search term to filter resources state: Instance state to filter by (default is "running"). Only applies to instances. next_token: Optional pagination token for fetching next page of results session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with EC2 resource information """ logger.info(f"Listing EC2 resources (type={resource_type}, limit={limit}, search_term='{search_term}', state='{state}', next_token={next_token}, session_context={session_context})") try: # Normalize resource type resource_type = resource_type.lower() # Validate resource type valid_types = ["all", "instances", "security_groups", "vpcs", "route_tables", "subnets"] if resource_type not in valid_types: return json.dumps({ "error": f"Invalid resource type: {resource_type}. Valid types are: {', '.join(valid_types)}" }) if resource_type == "instances" or resource_type == "all": instances_result = await list_ec2_instances(limit=limit, search_term=search_term, state=state, next_token=next_token if resource_type == "instances" else None, session_context=session_context) instances_data = json.loads(instances_result) else: instances_data = {"instances": [], "count": 0, "has_more": False} if resource_type == "security_groups" or resource_type == "all": sg_result = await list_security_groups(limit=limit, search_term=search_term, next_token=next_token if resource_type == "security_groups" else None, session_context=session_context) sg_data = json.loads(sg_result) else: sg_data = {"security_groups": [], "count": 0, "has_more": False} if resource_type == "vpcs" or resource_type == "all": vpc_result = await list_vpcs(limit=limit, search_term=search_term, next_token=next_token if resource_type == "vpcs" else None, session_context=session_context) vpc_data = json.loads(vpc_result) else: vpc_data = {"vpcs": [], "count": 0, "has_more": False} if resource_type == "route_tables" or resource_type == "all": rt_result = await list_route_tables(limit=limit, search_term=search_term, next_token=next_token if resource_type == "route_tables" else None, session_context=session_context) rt_data = json.loads(rt_result) else: rt_data = {"route_tables": [], "count": 0, "has_more": False} if resource_type == "subnets" or resource_type == "all": subnet_result = await list_subnets(limit=limit, search_term=search_term, next_token=next_token if resource_type == "subnets" else None, session_context=session_context) subnet_data = json.loads(subnet_result) else: subnet_data = {"subnets": [], "count": 0, "has_more": False} # Build comprehensive result result = { "resource_type": resource_type, "instances": instances_data.get("instances", []) if resource_type in ["instances", "all"] else [], "security_groups": sg_data.get("security_groups", []) if resource_type in ["security_groups", "all"] else [], "vpcs": vpc_data.get("vpcs", []) if resource_type in ["vpcs", "all"] else [], "route_tables": rt_data.get("route_tables", []) if resource_type in ["route_tables", "all"] else [], "subnets": subnet_data.get("subnets", []) if resource_type in ["subnets", "all"] else [], "counts": { "instances": instances_data.get("count", 0) if resource_type in ["instances", "all"] else 0, "security_groups": sg_data.get("count", 0) if resource_type in ["security_groups", "all"] else 0, "vpcs": vpc_data.get("count", 0) if resource_type in ["vpcs", "all"] else 0, "route_tables": rt_data.get("count", 0) if resource_type in ["route_tables", "all"] else 0, "subnets": subnet_data.get("count", 0) if resource_type in ["subnets", "all"] else 0 }, "has_more": ( (resource_type == "instances" and instances_data.get("has_more", False)) or (resource_type == "security_groups" and sg_data.get("has_more", False)) or (resource_type == "vpcs" and vpc_data.get("has_more", False)) or (resource_type == "route_tables" and rt_data.get("has_more", False)) or (resource_type == "subnets" and subnet_data.get("has_more", False)) ), "next_token": ( instances_data.get("next_token") if resource_type == "instances" and instances_data.get("has_more", False) else sg_data.get("next_token") if resource_type == "security_groups" and sg_data.get("has_more", False) else vpc_data.get("next_token") if resource_type == "vpcs" and vpc_data.get("has_more", False) else rt_data.get("next_token") if resource_type == "route_tables" and rt_data.get("has_more", False) else subnet_data.get("next_token") if resource_type == "subnets" and subnet_data.get("has_more", False) else None ) } return json.dumps(result) except Exception as e: logger.error(f"Error listing EC2 resources: {e}", exc_info=True) return json.dumps({ "error": str(e), "resource_type": resource_type, "instances": [], "security_groups": [], "vpcs": [], "route_tables": [], "subnets": [], "counts": { "instances": 0, "security_groups": 0, "vpcs": 0, "route_tables": 0, "subnets": 0 }, "has_more": False }) @register_tool() async def find_public_security_groups(port: Optional[int] = None, session_context: Optional[str] = None) -> str: """Find security groups with public internet access (0.0.0.0/0). Args: port: Optional specific port to check for public access (e.g., 22 for SSH) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with security groups that allow public access """ logger.info(f"Finding security groups with public internet access (port={port}, session_context={session_context})") try: # Get all security groups using pagination security_groups = [] ec2_client = ec2.get_ec2_client(session_context=session_context) paginator = ec2_client.get_paginator('describe_security_groups') # Iterate through all pages for page in paginator.paginate(): security_groups.extend(page.get('SecurityGroups', [])) # Filter security groups with public access public_sgs = [] for sg in security_groups: is_public = False is_port_match = False # Check each inbound rule for rule in sg.get('IpPermissions', []): rule_protocol = rule.get('IpProtocol', '') from_port = rule.get('FromPort') to_port = rule.get('ToPort') # Check if port matches if specified if port is not None: if rule_protocol == '-1': # All traffic is_port_match = True elif from_port is not None and to_port is not None: if from_port <= port <= to_port: is_port_match = True else: is_port_match = False else: is_port_match = True # No specific port to check # Check for public CIDR ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') if cidr in ('0.0.0.0/0', '::/0'): is_public = True break # If we found a public rule matching our criteria, no need to check further if is_public and is_port_match: public_sgs.append(sg) break if not public_sgs: if port is not None: return json.dumps({"summary": f"No security groups found with public internet access on port {port}", "security_groups": []}) else: return json.dumps({"summary": "No security groups found with public internet access", "security_groups": []}) # Format the results with JSON objects instead of strings formatted_sgs = [] for sg in public_sgs: # Create a structured JSON object for each security group formatted_sg = { "id": sg.get('GroupId'), "name": sg.get('GroupName'), "description": sg.get('Description'), "vpc_id": sg.get('VpcId'), "inbound_rules_count": len(sg.get('IpPermissions', [])), "outbound_rules_count": len(sg.get('IpPermissionsEgress', [])), "tags": {tag.get('Key'): tag.get('Value') for tag in sg.get('Tags', [])}, "inbound_rules": [] } # Format inbound rules for rule in sg.get('IpPermissions', []): protocol = rule.get('IpProtocol', '-1') # Handle protocol display if protocol == '-1': protocol_display = 'All Traffic' elif protocol == '6': protocol_display = 'TCP' elif protocol == '17': protocol_display = 'UDP' elif protocol == '1': protocol_display = 'ICMP' else: protocol_display = protocol.upper() # Handle port range display from_port = rule.get('FromPort') to_port = rule.get('ToPort') if protocol == '-1': port_range = 'All' elif from_port is None or to_port is None: port_range = 'N/A' elif from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" # Process IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv4" } formatted_sg["inbound_rules"].append(formatted_rule) # Process IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv6" } formatted_sg["inbound_rules"].append(formatted_rule) formatted_sgs.append(formatted_sg) summary = f"Found {len(public_sgs)} security group(s) with public internet access" if port is not None: summary += f" on port {port}" return json.dumps({"summary": summary, "security_groups": formatted_sgs}) except Exception as e: logger.error(f"Error finding public security groups: {e}") return json.dumps({"error": {"message": f"Error finding public security groups: {str(e)}", "type": type(e).__name__}}) @register_tool() async def find_instances_with_public_access(port: Optional[int] = None, state: str = "running", session_context: Optional[str] = None) -> str: """Find EC2 instances that have public internet access through their security groups. Args: port: Optional specific port to check for public access (e.g., 22 for SSH) state: Instance state to filter by (default is "running") session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with publicly accessible instances """ logger.info(f"Finding EC2 instances with public access (port={port}, state='{state}', session_context={session_context})") try: # Get all instances in the specified state - pass session_context to ec2 service calls states = [state] if state else None instances = ec2.get_all_instances(max_items=None, states=states, session_context=session_context) # Get all security groups - pass session_context to ec2 service calls security_groups = ec2.get_all_security_groups(max_items=None, session_context=session_context) # Create lookup dictionary for security groups sg_dict = {sg.get('GroupId'): sg for sg in security_groups} # Track publicly accessible instances and their exposed ports public_instances = [] # Process each instance for instance in instances: instance_id = instance.get('InstanceId') # Extract instance security groups instance_sgs = instance.get('SecurityGroups', []) # Check each security group for public access rules public_ports = [] public_sg_ids = [] for instance_sg in instance_sgs: sg_id = instance_sg.get('GroupId') # Get the full security group details sg = sg_dict.get(sg_id) if not sg: continue # Check inbound rules for rule in sg.get('IpPermissions', []): # Check if this matches our target port if specified if port is not None: # Skip rules that don't match our port # Handle protocol -1 (all traffic) if rule.get('IpProtocol') != '-1': from_port = rule.get('FromPort') to_port = rule.get('ToPort') # Skip if port is not in range if from_port is not None and to_port is not None: if port < from_port or port > to_port: continue # Check for public CIDR ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') if cidr == '0.0.0.0/0': # This rule allows public access if rule.get('IpProtocol') == '-1': # All ports public_ports.append('ALL') elif rule.get('FromPort') == rule.get('ToPort'): # Single port public_ports.append(str(rule.get('FromPort'))) else: # Port range public_ports.append(f"{rule.get('FromPort')}-{rule.get('ToPort')}") # Track the security group ID if sg_id not in public_sg_ids: public_sg_ids.append(sg_id) # Check IPv6 ranges too for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') if cidr == '::/0': # This rule allows public IPv6 access if rule.get('IpProtocol') == '-1': # All ports public_ports.append('ALL (IPv6)') elif rule.get('FromPort') == rule.get('ToPort'): # Single port public_ports.append(f"{rule.get('FromPort')} (IPv6)") else: # Port range public_ports.append(f"{rule.get('FromPort')}-{rule.get('ToPort')} (IPv6)") # Track the security group ID if sg_id not in public_sg_ids: public_sg_ids.append(sg_id) # If we found public access ports for this instance, add it to our list if public_ports: # Extract name tag name = None for tag in instance.get('Tags', []): if tag.get('Key') == 'Name': name = tag.get('Value') break # Get IAM Instance Profile information iam_profile_info = None if instance.get('IamInstanceProfile'): iam_profile = instance.get('IamInstanceProfile', {}) iam_profile_info = { "id": iam_profile.get('Id'), "arn": iam_profile.get('Arn') } # Format the instance details instance_details = { "id": instance_id, "name": name, "type": instance.get('InstanceType'), "state": instance.get('State', {}).get('Name'), "private_ip": instance.get('PrivateIpAddress'), "public_ip": instance.get('PublicIpAddress'), "vpc_id": instance.get('VpcId'), "subnet_id": instance.get('SubnetId'), "public_ports": public_ports, "public_security_groups": public_sg_ids, "iam_instance_profile": iam_profile_info } public_instances.append(instance_details) # Build the summary message if port is not None: summary = f"Found {len(public_instances)} {state} instances with public access on port {port}" else: summary = f"Found {len(public_instances)} {state} instances with public internet access" # Build the final response result = { "summary": summary, "count": len(public_instances), "instances": public_instances } return json.dumps(result) except Exception as e: logger.error(f"Error finding instances with public access: {e}") return json.dumps({"error": f"Error finding instances with public access: {str(e)}"}) @register_tool() async def find_resource_by_ip(ip_address: str, session_context: Optional[str] = None) -> str: """Find AWS resources associated with a specific IP address. Args: ip_address: IP address to search for (public or private) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with information about resources using the IP address """ logger.info(f"Searching for resources with IP address: {ip_address}") try: # Get EC2 client - pass session_context to get cross-account access client = ec2.get_ec2_client(session_context=session_context) # Try to find ENIs with this IP private_ip_filters = [ {'Name': 'private-ip-address', 'Values': [ip_address]} ] # First check for private IPs private_response = client.describe_network_interfaces(Filters=private_ip_filters) interfaces = private_response.get('NetworkInterfaces', []) # If not found as private IP, check for public IPs if not interfaces: logger.info(f"IP {ip_address} not found as private IP, checking public IPs") public_ip_filters = [ {'Name': 'association.public-ip', 'Values': [ip_address]} ] public_response = client.describe_network_interfaces(Filters=public_ip_filters) interfaces = public_response.get('NetworkInterfaces', []) if not interfaces: return json.dumps({"summary": f"No resources found with IP address {ip_address}"}) # Process the results results = [] for eni in interfaces: eni_id = eni.get('NetworkInterfaceId', 'Unknown') subnet_id = eni.get('SubnetId', 'Unknown') vpc_id = eni.get('VpcId', 'Unknown') private_ip = eni.get('PrivateIpAddress', 'Unknown') # Get public IP if available public_ip = None association = eni.get('Association', {}) if association: public_ip = association.get('PublicIp') # Determine the resource type and ID resource_type = "Unknown" resource_id = "Unknown" attachment = eni.get('Attachment', {}) if attachment: instance_id = attachment.get('InstanceId') if instance_id: resource_type = "EC2 Instance" resource_id = instance_id # Get instance details try: instance_response = client.describe_instances(InstanceIds=[instance_id]) reservations = instance_response.get('Reservations', []) if reservations and reservations[0].get('Instances'): instance = reservations[0]['Instances'][0] instance_name = None for tag in instance.get('Tags', []): if tag.get('Key') == 'Name': instance_name = tag.get('Value') break if instance_name: resource_id = f"{instance_id} ({instance_name})" except Exception as e: logger.warning(f"Error getting instance details: {e}") # Check for other resource types owner_id = eni.get('RequesterId') description = eni.get('Description', '') if description and owner_id: # Check for ELB if 'ELB' in description or 'elasticloadbalancing' in owner_id: resource_type = "Elastic Load Balancer" resource_id = description # Check for RDS elif 'RDS' in description or 'rds.amazonaws.com' in owner_id: resource_type = "RDS Database" resource_id = description # Check for Lambda elif 'lambda' in description.lower() or 'lambda' in owner_id.lower(): resource_type = "Lambda Function" resource_id = description # Check for NAT Gateway elif 'NAT' in description or 'nat-' in description: resource_type = "NAT Gateway" resource_id = description # Check for ECS tasks elif 'ecs-tasks' in owner_id: resource_type = "ECS Task" resource_id = description # Format the result result = { "ip_address": ip_address, "private_ip": private_ip if private_ip != ip_address and private_ip != "Unknown" else None, "public_ip": public_ip if public_ip and public_ip != ip_address else None, "network_interface": eni_id, "resource_type": resource_type, "resource_id": resource_id, "vpc": vpc_id, "subnet": subnet_id, "status": eni.get('Status', 'Unknown'), "description": description if description else None, "security_groups": [ { "id": sg.get('GroupId', 'Unknown'), "name": sg.get('GroupName', 'Unknown') } for sg in eni.get('Groups', []) ] } results.append(result) summary = f"Found {len(interfaces)} resource(s) with IP address {ip_address}" return json.dumps({"summary": summary, "resources": results}) except Exception as e: logger.error(f"Error finding resources by IP: {e}") return json.dumps({"error": {"message": f"Error searching for IP address {ip_address}: {str(e)}", "type": type(e).__name__}}) @register_tool() async def find_instances_by_port(port: int, state: str = "running", session_context: Optional[str] = None) -> str: """Find EC2 instances with security groups allowing access on a specific port. Args: port: The port number to check for state: Instance state to filter by (default is "running") session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with instances that have the specified port open """ logger.info(f"Finding EC2 instances with port {port} open (state='{state}', session_context={session_context})") try: # Get all instances in the specified state - pass session_context to ec2 service calls states = [state] if state else None instances = ec2.get_all_instances(max_items=None, states=states, session_context=session_context) # Get all security groups - pass session_context to ec2 service calls security_groups = ec2.get_all_security_groups(max_items=None, session_context=session_context) # Identify security groups that allow the specified port port_open_sg_ids = [] for sg in security_groups: sg_id = sg.get('GroupId') # Check inbound rules for rule in sg.get('IpPermissions', []): # All traffic (-1) allows any port if rule.get('IpProtocol') == '-1': port_open_sg_ids.append(sg_id) break # For TCP/UDP protocols, check port range from_port = rule.get('FromPort') to_port = rule.get('ToPort') # Check if port is in range (if both ports are defined) if from_port is not None and to_port is not None: if from_port <= port <= to_port: port_open_sg_ids.append(sg_id) break # Create lookup dictionary for security groups sg_dict = {sg.get('GroupId'): sg for sg in security_groups} # Find instances with the open port port_open_instances = [] for instance in instances: instance_id = instance.get('InstanceId') # Extract instance security groups instance_sgs = instance.get('SecurityGroups', []) # Check if any of the instance's security groups allow the port instance_open_sg_ids = [] for instance_sg in instance_sgs: sg_id = instance_sg.get('GroupId') if sg_id in port_open_sg_ids: instance_open_sg_ids.append(sg_id) # If we found security groups allowing this port, add the instance to our list if instance_open_sg_ids: # Extract name tag name = None for tag in instance.get('Tags', []): if tag.get('Key') == 'Name': name = tag.get('Value') break # Get IAM Instance Profile information iam_profile_info = None if instance.get('IamInstanceProfile'): iam_profile = instance.get('IamInstanceProfile', {}) iam_profile_info = { "id": iam_profile.get('Id'), "arn": iam_profile.get('Arn') } # Format the instance details instance_details = { "id": instance_id, "name": name, "type": instance.get('InstanceType'), "state": instance.get('State', {}).get('Name'), "private_ip": instance.get('PrivateIpAddress'), "public_ip": instance.get('PublicIpAddress'), "vpc_id": instance.get('VpcId'), "subnet_id": instance.get('SubnetId'), "security_groups_with_port_open": instance_open_sg_ids, "iam_instance_profile": iam_profile_info } # Get security group details sg_details = [] for sg_id in instance_open_sg_ids: sg = sg_dict.get(sg_id) if sg: sg_detail = { "id": sg_id, "name": sg.get('GroupName'), "description": sg.get('Description'), "rules": [] } # Add rules that include the port for rule in sg.get('IpPermissions', []): # Check if rule includes the port if rule.get('IpProtocol') == '-1' or (rule.get('FromPort', float('inf')) <= port <= rule.get('ToPort', float('-inf'))): sg_detail["rules"].append({ "protocol": rule.get('IpProtocol'), "port_range": f"{rule.get('FromPort')}-{rule.get('ToPort')}" if rule.get('FromPort') != rule.get('ToPort') else str(rule.get('FromPort')), "sources": [cidr.get('CidrIp') for cidr in rule.get('IpRanges', [])] }) sg_details.append(sg_detail) instance_details["security_group_details"] = sg_details port_open_instances.append(instance_details) # Build the summary message summary = f"Found {len(port_open_instances)} {state} instances with port {port} open" # Build the final response result = { "summary": summary, "count": len(port_open_instances), "instances": port_open_instances } return json.dumps(result) except Exception as e: logger.error(f"Error finding instances by port: {e}") return json.dumps({"error": f"Error finding instances by port: {str(e)}"}) @register_tool() async def find_security_groups_by_port(port: int, session_context: Optional[str] = None) -> str: """Find security groups with a specific port open. Args: port: Port number to check for (e.g., 22 for SSH, 3389 for RDP) session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with security groups that have the specified port open """ logger.info(f"Finding security groups with port {port} open (session_context={session_context})") try: # Get all security groups using pagination security_groups = [] ec2_client = ec2.get_ec2_client(session_context=session_context) paginator = ec2_client.get_paginator('describe_security_groups') # Iterate through all pages for page in paginator.paginate(): security_groups.extend(page.get('SecurityGroups', [])) # Filter security groups with the specified port open matching_sgs = [] for sg in security_groups: is_port_match = False # Check each inbound rule for rule in sg.get('IpPermissions', []): rule_protocol = rule.get('IpProtocol', '') from_port = rule.get('FromPort') to_port = rule.get('ToPort') # Check if port matches if rule_protocol == '-1': # All traffic is_port_match = True elif from_port is not None and to_port is not None: if from_port <= port <= to_port: is_port_match = True # If we found a rule matching our criteria, no need to check further if is_port_match: matching_sgs.append(sg) break if not matching_sgs: return json.dumps({"summary": f"No security groups found with port {port} open", "security_groups": []}) # Format the results with JSON objects instead of strings formatted_sgs = [] for sg in matching_sgs: # Create a structured JSON object for each security group formatted_sg = { "id": sg.get('GroupId'), "name": sg.get('GroupName'), "description": sg.get('Description'), "vpc_id": sg.get('VpcId'), "inbound_rules_count": len(sg.get('IpPermissions', [])), "outbound_rules_count": len(sg.get('IpPermissionsEgress', [])), "tags": {tag.get('Key'): tag.get('Value') for tag in sg.get('Tags', [])}, "inbound_rules": [] } # Format inbound rules for rule in sg.get('IpPermissions', []): protocol = rule.get('IpProtocol', '-1') # Handle protocol display if protocol == '-1': protocol_display = 'All Traffic' elif protocol == '6': protocol_display = 'TCP' elif protocol == '17': protocol_display = 'UDP' elif protocol == '1': protocol_display = 'ICMP' else: protocol_display = protocol.upper() # Handle port range display from_port = rule.get('FromPort') to_port = rule.get('ToPort') if protocol == '-1': port_range = 'All' elif from_port is None or to_port is None: port_range = 'N/A' elif from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" # Process IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv4" } formatted_sg["inbound_rules"].append(formatted_rule) # Process IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv6" } formatted_sg["inbound_rules"].append(formatted_rule) formatted_sgs.append(formatted_sg) summary = f"Found {len(matching_sgs)} security group(s) with port {port} open" return json.dumps({"summary": summary, "security_groups": formatted_sgs}) except Exception as e: logger.error(f"Error finding security groups by port: {e}") return json.dumps({"error": {"message": f"Error finding security groups by port: {str(e)}", "type": type(e).__name__}}) @register_tool() async def batch_describe_security_groups(security_group_ids: List[str], session_context: Optional[str] = None) -> str: """Batch describe multiple security groups by ID. Args: security_group_ids: List of security group IDs to describe session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with detailed information about multiple security groups """ logger.info(f"Batch describing {len(security_group_ids)} security groups (session_context={session_context})") if not security_group_ids: return json.dumps({"error": "No security group IDs provided"}) try: # EC2 API can accept multiple security group IDs in a single call # But we'll batch in smaller chunks to be safe batch_size = 100 all_security_groups = [] not_found_sgs = [] client = ec2.get_ec2_client(session_context=session_context) # Process in batches for i in range(0, len(security_group_ids), batch_size): batch = security_group_ids[i:i+batch_size] try: response = client.describe_security_groups(GroupIds=batch) # Extract security groups security_groups = response.get('SecurityGroups', []) all_security_groups.extend(security_groups) # Check if any security groups were not found found_ids = set() for sg in security_groups: found_ids.add(sg.get('GroupId')) for sg_id in batch: if sg_id not in found_ids: not_found_sgs.append(sg_id) except Exception as batch_error: logger.warning(f"Error describing batch {i//batch_size}: {batch_error}") # If the error is due to non-existent security groups, extract those IDs error_message = str(batch_error) if "InvalidGroup.NotFound" in error_message or "InvalidGroupId.Malformed" in error_message: # Try to extract the security group IDs from the error message import re missing_ids = re.findall(r'sg-[a-zA-Z0-9]+', error_message) if missing_ids: not_found_sgs.extend(missing_ids) else: # If we can't extract IDs, mark all as not found not_found_sgs.extend(batch) else: # For other errors, mark all in this batch as failed not_found_sgs.extend(batch) # Format the results formatted_sgs = [] for sg in all_security_groups: formatted_sg = { "id": sg.get('GroupId'), "name": sg.get('GroupName'), "description": sg.get('Description'), "vpc_id": sg.get('VpcId'), "tags": {tag.get('Key'): tag.get('Value') for tag in sg.get('Tags', [])}, "inbound_rules": [], "outbound_rules": [] } # Format inbound rules for rule in sg.get('IpPermissions', []): protocol = rule.get('IpProtocol', '-1') # Handle protocol display if protocol == '-1': protocol_display = 'All Traffic' elif protocol == '6': protocol_display = 'TCP' elif protocol == '17': protocol_display = 'UDP' elif protocol == '1': protocol_display = 'ICMP' else: protocol_display = protocol.upper() # Handle port range display from_port = rule.get('FromPort') to_port = rule.get('ToPort') if protocol == '-1': port_range = 'All' elif from_port is None or to_port is None: port_range = 'N/A' elif from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" # Process IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv4" } formatted_sg["inbound_rules"].append(formatted_rule) # Process IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": cidr, "description": description, "type": "IPv6" } formatted_sg["inbound_rules"].append(formatted_rule) # Process security group references for group_pair in rule.get('UserIdGroupPairs', []): group_id = group_pair.get('GroupId', '') description = group_pair.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "source": f"sg-{group_id}" if not group_id.startswith('sg-') else group_id, "description": description, "type": "Security Group" } formatted_sg["inbound_rules"].append(formatted_rule) # Format outbound rules (similar logic as inbound rules) for rule in sg.get('IpPermissionsEgress', []): protocol = rule.get('IpProtocol', '-1') # Handle protocol display if protocol == '-1': protocol_display = 'All Traffic' elif protocol == '6': protocol_display = 'TCP' elif protocol == '17': protocol_display = 'UDP' elif protocol == '1': protocol_display = 'ICMP' else: protocol_display = protocol.upper() # Handle port range display from_port = rule.get('FromPort') to_port = rule.get('ToPort') if protocol == '-1': port_range = 'All' elif from_port is None or to_port is None: port_range = 'N/A' elif from_port == to_port: port_range = str(from_port) else: port_range = f"{from_port}-{to_port}" # Process IP ranges for ip_range in rule.get('IpRanges', []): cidr = ip_range.get('CidrIp', '') description = ip_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "destination": cidr, "description": description, "type": "IPv4" } formatted_sg["outbound_rules"].append(formatted_rule) # Process IPv6 ranges for ipv6_range in rule.get('Ipv6Ranges', []): cidr = ipv6_range.get('CidrIpv6', '') description = ipv6_range.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "destination": cidr, "description": description, "type": "IPv6" } formatted_sg["outbound_rules"].append(formatted_rule) # Process security group references for group_pair in rule.get('UserIdGroupPairs', []): group_id = group_pair.get('GroupId', '') description = group_pair.get('Description', '') formatted_rule = { "protocol": protocol_display, "port_range": port_range, "destination": f"sg-{group_id}" if not group_id.startswith('sg-') else group_id, "description": description, "type": "Security Group" } formatted_sg["outbound_rules"].append(formatted_rule) formatted_sgs.append(formatted_sg) # Build the final response result = { "summary": f"Found {len(formatted_sgs)} of {len(security_group_ids)} requested security groups", "security_groups": formatted_sgs, "count": len(formatted_sgs), "not_found": not_found_sgs, "not_found_count": len(not_found_sgs) } return json.dumps(result, default=lambda o: o.isoformat() if hasattr(o, 'isoformat') else str(o)) except Exception as e: logger.error(f"Error batch describing security groups: {e}") return json.dumps({"error": {"message": f"Error batch describing security groups: {str(e)}", "type": type(e).__name__}}) @register_tool() async def batch_describe_instances(instance_ids: List[str], session_context: Optional[str] = None) -> str: """Batch describe multiple EC2 instances by ID. Args: instance_ids: List of EC2 instance IDs to describe session_context: Optional session key for cross-account access (e.g., "123456789012_aws_dev") Returns: JSON formatted string with detailed information about multiple EC2 instances """ logger.info(f"Batch describing {len(instance_ids)} EC2 instances (session_context={session_context})") if not instance_ids: return json.dumps({"error": "No instance IDs provided"}) try: all_instances = [] not_found_instances = [] # Get EC2 client - pass session_context to get cross-account access client = ec2.get_ec2_client(session_context=session_context) # The EC2 API has a limit of 1000 IDs per call, so we need to batch # This isn't limiting the results, just working within API constraints max_ids_per_request = 1000 # AWS API limit # Process all instance IDs in necessary batches for i in range(0, len(instance_ids), max_ids_per_request): batch = instance_ids[i:i+max_ids_per_request] try: response = client.describe_instances(InstanceIds=batch) # Extract instances from reservations found_instances = [] for reservation in response.get('Reservations', []): found_instances.extend(reservation.get('Instances', [])) # Add to our cumulative list all_instances.extend(found_instances) # Check if any instance IDs were not found found_ids = set(instance.get('InstanceId') for instance in found_instances) for instance_id in batch: if instance_id not in found_ids: not_found_instances.append(instance_id) except Exception as batch_error: logger.warning(f"Error describing batch {i//max_ids_per_request}: {batch_error}") # If the error is due to non-existent instances, extract those IDs error_message = str(batch_error) if "InvalidInstanceID.NotFound" in error_message: # Try to extract the instance IDs from the error message import re missing_ids = re.findall(r'i-[a-zA-Z0-9]+', error_message) if missing_ids: not_found_instances.extend(missing_ids) else: # If we can't extract IDs, mark all as not found not_found_instances.extend(batch) else: # For other errors, mark all in this batch as failed not_found_instances.extend(batch) # Format the results - no limiting of results here formatted_instances = [] for instance in all_instances: # Extract name tag if available name = None for tag in instance.get('Tags', []): if tag.get('Key') == 'Name': name = tag.get('Value') break # Basic instance details formatted_instance = { "id": instance.get('InstanceId'), "name": name, "type": instance.get('InstanceType'), "state": instance.get('State', {}).get('Name'), "private_ip": instance.get('PrivateIpAddress'), "public_ip": instance.get('PublicIpAddress'), "vpc_id": instance.get('VpcId'), "subnet_id": instance.get('SubnetId'), "image_id": instance.get('ImageId'), "availability_zone": instance.get('Placement', {}).get('AvailabilityZone'), "launch_time": instance.get('LaunchTime').isoformat() if instance.get('LaunchTime') else None, "key_name": instance.get('KeyName'), "monitoring": instance.get('Monitoring', {}).get('State'), "platform": instance.get('Platform'), "architecture": instance.get('Architecture'), "root_device_type": instance.get('RootDeviceType'), "root_device_name": instance.get('RootDeviceName'), "virtualization_type": instance.get('VirtualizationType'), "tags": {tag.get('Key'): tag.get('Value') for tag in instance.get('Tags', [])} } # Security groups - include ALL security groups formatted_instance["security_groups"] = [ { "id": sg.get('GroupId'), "name": sg.get('GroupName') } for sg in instance.get('SecurityGroups', []) ] # IAM Instance Profile information if available if instance.get('IamInstanceProfile'): iam_profile = instance.get('IamInstanceProfile', {}) formatted_instance["iam_instance_profile"] = { "id": iam_profile.get('Id'), "arn": iam_profile.get('Arn') } else: formatted_instance["iam_instance_profile"] = None # EBS volume information - include ALL volumes formatted_instance["ebs_volumes"] = [] for block_device in instance.get('BlockDeviceMappings', []): if 'Ebs' in block_device: volume_info = { "device_name": block_device.get('DeviceName'), "volume_id": block_device.get('Ebs', {}).get('VolumeId'), "status": block_device.get('Ebs', {}).get('Status'), "delete_on_termination": block_device.get('Ebs', {}).get('DeleteOnTermination', False), "volume_type": block_device.get('Ebs', {}).get('VolumeType'), "size_gb": block_device.get('Ebs', {}).get('Size'), "encrypted": block_device.get('Ebs', {}).get('Encrypted', False), "kms_key_id": block_device.get('Ebs', {}).get('KmsKeyId') } formatted_instance["ebs_volumes"].append(volume_info) # Network interfaces - include ALL interfaces formatted_instance["network_interfaces"] = [] for eni in instance.get('NetworkInterfaces', []): eni_info = { "id": eni.get('NetworkInterfaceId'), "description": eni.get('Description'), "subnet_id": eni.get('SubnetId'), "vpc_id": eni.get('VpcId'), "private_ip": eni.get('PrivateIpAddress'), "private_dns_name": eni.get('PrivateDnsName'), "status": eni.get('Status'), "mac_address": eni.get('MacAddress'), "security_groups": [ { "id": sg.get('GroupId'), "name": sg.get('GroupName') } for sg in eni.get('Groups', []) ], "source_dest_check": eni.get('SourceDestCheck') } # Add all private IPs eni_info["private_ips"] = [ { "ip": ip.get('PrivateIpAddress'), "is_primary": ip.get('Primary', False) } for ip in eni.get('PrivateIpAddresses', []) ] # Add public IP if available if eni.get('Association', {}).get('PublicIp'): eni_info["public_ip"] = eni.get('Association', {}).get('PublicIp') eni_info["public_dns_name"] = eni.get('Association', {}).get('PublicDnsName') formatted_instance["network_interfaces"].append(eni_info) # Product codes formatted_instance["product_codes"] = [ { "id": pc.get('ProductCodeId'), "type": pc.get('ProductCodeType') } for pc in instance.get('ProductCodes', []) ] # CPU options if instance.get('CpuOptions'): formatted_instance["cpu_options"] = { "cores": instance.get('CpuOptions', {}).get('CoreCount'), "threads_per_core": instance.get('CpuOptions', {}).get('ThreadsPerCore') } # Instance lifecycle (spot, scheduled) if instance.get('InstanceLifecycle'): formatted_instance["lifecycle"] = instance.get('InstanceLifecycle') if instance.get('SpotInstanceRequestId'): formatted_instance["spot_instance_request_id"] = instance.get('SpotInstanceRequestId') # State transition reason if instance.get('StateTransitionReason'): formatted_instance["state_transition_reason"] = instance.get('StateTransitionReason') # Hypervisor information if instance.get('Hypervisor'): formatted_instance["hypervisor"] = instance.get('Hypervisor') # ENA support if 'EnaSupport' in instance: formatted_instance["ena_support"] = instance.get('EnaSupport') # Public DNS name if instance.get('PublicDnsName'): formatted_instance["public_dns_name"] = instance.get('PublicDnsName') # Private DNS name if instance.get('PrivateDnsName'): formatted_instance["private_dns_name"] = instance.get('PrivateDnsName') formatted_instances.append(formatted_instance) # Build the final response - all data included, no limits result = { "summary": f"Found {len(formatted_instances)} of {len(instance_ids)} requested EC2 instances", "instances": formatted_instances, "count": len(formatted_instances), "not_found": not_found_instances, "not_found_count": len(not_found_instances) } return json.dumps(result, default=lambda o: o.isoformat() if hasattr(o, 'isoformat') else str(o)) except Exception as e: logger.error(f"Error batch describing EC2 instances: {e}") return json.dumps({"error": {"message": f"Error batch describing EC2 instances: {str(e)}", "type": type(e).__name__}})

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/groovyBugify/aws-security-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server