#!/usr/bin/env python3
"""
NameChecker MCP Server
A Model Context Protocol server that provides domain name availability checking
capabilities to AI assistants. Supports both stdio and SSE transports.
"""
import argparse
import asyncio
import json
import logging
import re
import socket
import sys
from typing import Optional, Dict, Any
from urllib.parse import urlparse

import httpx
from mcp import types
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.server.sse import SseServerTransport
from pydantic import BaseModel, Field, field_validator
# Configure root logging once at import time. basicConfig's default handler
# writes to stderr, which matters because the stdio transport (stdio_server)
# presumably owns stdout.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Low-level MCP server instance; tools are registered on it below through the
# @server.list_tools() and @server.call_tool() decorators.
server = Server(name="namechecker-mcp")
# Full domain name: one or more dot-separated labels, each 1-63 letters,
# digits or hyphens that neither start nor end with a hyphen.
# NOTE(review): currently not referenced anywhere in this file.
_DOMAIN_LABEL_PATTERN = r"[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?"
DOMAIN_REGEX = re.compile(rf"^{_DOMAIN_LABEL_PATTERN}(\.{_DOMAIN_LABEL_PATTERN})*$")
# Whitelist of common TLDs. A domain whose last label is in this set is
# treated as already carrying its TLD.
VALID_TLDS = set(
    """
    com org net edu gov mil int
    co io ai app dev tech xyz
    info biz name pro mobi travel museum aero coop jobs tel xxx
    uk de fr jp cn au ca us ru in br it es nl pl se
    """.split()
)
class DomainValidationError(Exception):
    """Signals that a domain name failed validation.

    NOTE(review): not raised anywhere in this file; validation failures
    currently surface as ``ValueError``.
    """
class DomainCheckError(Exception):
    """Signals that a domain availability lookup could not be completed."""
# A single DNS hostname label: 1-63 ASCII letters, digits or hyphens that
# neither start nor end with a hyphen.
_HOSTNAME_LABEL_RE = re.compile(r"[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?")


class DomainRequest(BaseModel):
    """Validated, normalised input for a domain check.

    ``domain`` is lower-cased with a leading ``www.`` and a trailing root dot
    removed; if it looks like a URL (contains ``://``) only its host name is
    kept. ``tld`` is lower-cased, may be given with a leading dot, and falls
    back to ``"com"``.
    """

    # Raw input length bound, applied before normalisation.
    domain: str = Field(..., min_length=1, max_length=255)
    # Callers fall back to this when ``domain`` has no recognised TLD.
    tld: Optional[str] = Field(default="com", min_length=2, max_length=10)

    @field_validator("domain")
    @classmethod
    def validate_domain_format(cls, v: str) -> str:
        """Normalise a domain name and check that every label is well formed.

        Args:
            v: Raw domain name or URL supplied by the caller.

        Returns:
            The normalised domain (lower case, no scheme, ``www.`` or trailing dot).

        Raises:
            ValueError: If nothing is left after normalisation, or any label is
                not a valid hostname label. Internationalised labels are
                checked in their IDNA (punycode) form.
        """
        v = v.strip()
        if "://" in v:
            # Parse the URL as given; ``hostname`` drops port, credentials and
            # path, and is already lower-cased.
            v = urlparse(v).hostname or ""
        # Lower-case before stripping "www." so "WWW.example.com" is handled.
        v = v.lower()
        if v.startswith("www."):
            v = v[4:]
        if v.endswith("."):
            # Accept fully qualified names such as "example.com.".
            v = v[:-1]
        if not v:
            raise ValueError("Domain cannot be empty")
        try:
            # The idna codec converts non-ASCII labels to their "xn--" form and
            # rejects empty or over-long labels.
            ascii_domain = v.encode("idna").decode("ascii")
        except UnicodeError:
            raise ValueError("Invalid domain name format") from None
        # Check every label (not just the first) so inputs such as
        # "bad_label.com" are rejected even when the TLD is recognised.
        if not all(
            _HOSTNAME_LABEL_RE.fullmatch(label) for label in ascii_domain.split(".")
        ):
            raise ValueError("Invalid domain name format")
        return v

    @field_validator("tld")
    @classmethod
    def validate_tld(cls, v: Optional[str]) -> str:
        """Normalise the TLD: lower case, optional leading dot removed.

        ``None`` maps to ``"com"``.

        Raises:
            ValueError: If the TLD is not 2-10 ASCII letters.
        """
        if v is None:
            return "com"
        v = v.strip().lower().removeprefix(".")
        if not re.match(r"^[a-zA-Z]{2,10}$", v):
            raise ValueError("Invalid TLD format")
        return v
class DomainChecker:
    """Check domain availability; DNS resolution is currently the only method."""

    def __init__(self, timeout: int = 30):
        """Create a checker.

        Args:
            timeout: Seconds allowed for each DNS lookup; also passed to the
                HTTP client.
        """
        self.timeout = timeout
        # NOTE(review): no check below issues HTTP requests yet; the client is
        # only closed again in __aexit__ (and by the module's cleanup()).
        self.client = httpx.AsyncClient(timeout=timeout)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

    def _build_full_domain(self, domain: str, tld: str) -> str:
        """Return ``domain`` as-is if it already ends in a known TLD, else append ``.tld``."""
        if "." in domain and domain.split(".")[-1] in VALID_TLDS:
            return domain  # Domain already has TLD
        return f"{domain}.{tld}"

    async def check_dns_resolution(self, full_domain: str) -> bool:
        """Guess availability from whether ``full_domain`` resolves.

        A name that resolves to any IPv4 or IPv6 address is treated as
        registered (``False``). A lookup that fails with a "not found" style
        error is treated as possibly available (``True``). Transient or
        non-recoverable resolver failures, timeouts and other errors return
        ``False``, so a resolver outage is never reported as a free domain.

        DNS is only a heuristic: a registered domain without address records
        also looks available.
        """
        # Resolver failures that say nothing about whether the name exists.
        inconclusive = {
            getattr(socket, code)
            for code in ("EAI_AGAIN", "EAI_FAIL")
            if hasattr(socket, code)
        }
        loop = asyncio.get_running_loop()
        try:
            # getaddrinfo covers both A and AAAA records (gethostbyname is
            # IPv4-only); wait_for applies the configured timeout.
            await asyncio.wait_for(
                loop.getaddrinfo(full_domain, None), timeout=self.timeout
            )
        except socket.gaierror as e:
            if e.errno in inconclusive:
                logger.warning(f"DNS resolution error for {full_domain}: {e}")
                return False
            # Name did not resolve: it might be available.
            return True
        except Exception as e:
            logger.warning(f"DNS resolution error for {full_domain}: {e}")
            # Assume unavailable on error to be safe
            return False
        # Resolution succeeded, so the domain is likely registered.
        return False

    async def check_availability(self, domain: str, tld: str) -> bool:
        """
        Check domain availability.

        Args:
            domain: Domain name to check
            tld: Top-level domain, used when ``domain`` has no known TLD

        Returns:
            bool: True if domain appears to be available, False otherwise

        Raises:
            DomainCheckError: If the DNS check raises unexpectedly
        """
        full_domain = self._build_full_domain(domain, tld)
        logger.info(f"Checking availability for: {full_domain}")
        try:
            dns_result = await self.check_dns_resolution(full_domain)
        except Exception as e:
            logger.error(f"DNS resolution failed for {full_domain}: {e}")
            raise DomainCheckError(
                f"Failed to check domain availability for {full_domain}: {e}"
            ) from e
        logger.info(
            f"DNS resolution result for {full_domain}: {'available' if dns_result else 'unavailable'}"
        )
        return dns_result
# Shared DomainChecker, created by main() or lazily by
# check_domain_availability(); None until then.
domain_checker: Optional["DomainChecker"] = None
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
    """Return the tool catalogue: availability lookup and syntax validation."""
    availability_schema = {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": "Domain name to check (without TLD unless already included)",
            },
            "tld": {
                "type": "string",
                "description": "Top-level domain (default: 'com')",
                "default": "com",
            },
        },
        "required": ["domain"],
    }
    syntax_schema = {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": "Domain name to validate",
            },
        },
        "required": ["domain"],
    }
    availability_tool = types.Tool(
        name="check_domain_availability",
        description="Check if a domain name is available for registration",
        inputSchema=availability_schema,
    )
    syntax_tool = types.Tool(
        name="validate_domain_syntax",
        description="Validate domain name syntax according to RFC standards",
        inputSchema=syntax_schema,
    )
    return [availability_tool, syntax_tool]
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
    """Dispatch an MCP tool call to the matching implementation.

    Args:
        name: Tool name, as advertised by handle_list_tools().
        arguments: Tool arguments; ``None`` is treated as no arguments.

    Returns:
        A single text content item. Availability results are ``"True"`` /
        ``"False"``; syntax-validation results are serialised as JSON.

    Raises:
        ValueError: For an unknown tool name or invalid domain input.
    """
    arguments = arguments or {}
    if name == "check_domain_availability":
        result = await check_domain_availability(
            arguments.get("domain", ""), arguments.get("tld", "com")
        )
        text = str(result)
    elif name == "validate_domain_syntax":
        details = await validate_domain_syntax(arguments.get("domain", ""))
        # JSON (not a Python repr) so clients can parse the result reliably.
        text = json.dumps(details)
    else:
        raise ValueError(f"Unknown tool: {name}")
    return [types.TextContent(type="text", text=text)]
async def check_domain_availability(domain: str, tld: str = "com") -> bool:
    """
    Check if a domain name is available for registration.

    Args:
        domain: Domain name to check (without TLD unless already included)
        tld: Top-level domain (default: "com")

    Returns:
        bool: True if domain appears to be available, False if unavailable

    Raises:
        ValueError: If domain format is invalid
        DomainCheckError: If availability check fails
    """
    global domain_checker
    try:
        request = DomainRequest(domain=domain, tld=tld)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError.
        logger.error(f"Domain validation error: {e}")
        raise ValueError(f"Invalid domain format: {e}") from e

    try:
        # Lazily create the shared checker when main() has not done so.
        if domain_checker is None:
            domain_checker = DomainChecker()
        # request.tld is always a str after validation; the fallback only
        # satisfies the Optional annotation.
        return await domain_checker.check_availability(
            request.domain, request.tld or "com"
        )
    except DomainCheckError:
        # Already descriptive; re-wrapping would duplicate the message.
        raise
    except Exception as e:
        logger.error(f"Domain check error: {e}")
        raise DomainCheckError(f"Failed to check domain availability: {e}") from e
async def validate_domain_syntax(domain: str) -> Dict[str, Any]:
    """
    Validate domain name syntax according to RFC standards.

    Args:
        domain: Domain name (or URL) to validate

    Returns:
        dict: On success ``{"valid": True, "domain", "tld", "full_domain",
        "length"}``; on failure ``{"valid": False, "error", "domain", "tld"}``.
        When ``domain`` contains no dot, the default TLD is reported.
    """
    try:
        request = DomainRequest(domain=domain)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError.
        return {"valid": False, "error": str(e), "domain": domain, "tld": None}

    if "." in request.domain:
        domain_part, _, tld_part = request.domain.rpartition(".")
    else:
        domain_part, tld_part = request.domain, request.tld
    full_domain = f"{domain_part}.{tld_part}"

    # RFC 1035 section 2.3.4: every individual label is at most 63 octets.
    if any(len(label) > 63 for label in full_domain.split(".")):
        return {
            "valid": False,
            "error": "Domain label exceeds 63 characters",
            "domain": domain_part,
            "tld": tld_part,
        }

    # 255 octets on the wire leaves at most 253 characters in dotted text
    # form (without a trailing dot).
    total_length = len(full_domain)
    if total_length > 253:
        return {
            "valid": False,
            "error": "Total domain length exceeds 253 characters",
            "domain": domain_part,
            "tld": tld_part,
        }

    return {
        "valid": True,
        "domain": domain_part,
        "tld": tld_part,
        "full_domain": full_domain,
        "length": total_length,
    }
async def run_stdio():
    """Serve MCP requests over the process's standard input/output streams."""
    logger.info("Starting NameChecker MCP server with stdio transport")
    async with stdio_server() as streams:
        reader, writer = streams
        await server.run(reader, writer, server.create_initialization_options())
async def run_sse(port: int):
    """Run the MCP server over HTTP with the SSE transport on 127.0.0.1:<port>.

    Endpoints:
        GET  /sse        -- event stream for one MCP session
        POST /messages/  -- client-to-server messages, handled by the transport
        GET  /health     -- JSON liveness probe
    """
    logger.info(f"Starting NameChecker MCP server with SSE transport on port {port}")
    import uvicorn
    from starlette.applications import Starlette
    from starlette.middleware.cors import CORSMiddleware
    from starlette.responses import JSONResponse, Response
    from starlette.routing import Mount, Route

    # SseServerTransport is given the POST path for client messages; keep it
    # in sync with the Mount below.
    sse_transport = SseServerTransport("/messages/")

    async def handle_sse(request):
        """Serve one SSE session, running the MCP server over its streams."""
        logger.debug(f"SSE connection request with params: {dict(request.query_params)}")
        try:
            # connect_sse is ASGI-level: it takes the raw scope/receive/send,
            # not a Starlette Request.
            async with sse_transport.connect_sse(
                request.scope, request.receive, request._send
            ) as streams:
                await server.run(
                    streams[0], streams[1], server.create_initialization_options()
                )
        except Exception as e:
            logger.error(f"SSE connection error: {e}")
            raise
        # The transport has already streamed the response; return an empty
        # Response so Starlette does not try to call None after disconnect.
        return Response()

    async def health_check(request):
        """Health check endpoint."""
        return JSONResponse(
            {
                "status": "healthy",
                "server": "namechecker-mcp",
                "transport": "sse",
                "port": port,
            }
        )

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse, methods=["GET"]),
            # handle_post_message is itself an ASGI app (scope, receive, send),
            # so it is mounted directly instead of being called with a Request.
            Mount("/messages/", app=sse_transport.handle_post_message),
            Route("/health", endpoint=health_check, methods=["GET"]),
        ]
    )

    # Add CORS middleware for web clients
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Configure appropriately for production
        # Nothing here reads cookies or auth headers. Combining credentials
        # with a wildcard origin would let any web page make credentialed
        # calls to this local server.
        allow_credentials=False,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["*"],
    )

    config = uvicorn.Config(
        app,
        host="127.0.0.1",  # Bind to localhost for security
        port=port,
        log_level="info",
        access_log=True,
    )
    server_instance = uvicorn.Server(config)
    logger.info("SSE endpoints available at:")
    logger.info(f" - SSE stream: http://127.0.0.1:{port}/sse")
    logger.info(f" - Messages: http://127.0.0.1:{port}/messages/")
    logger.info(f" - Health: http://127.0.0.1:{port}/health")
    await server_instance.serve()
async def cleanup():
    """Close the shared DomainChecker's HTTP client, if one was created.

    The global is reset to ``None`` (even if closing fails), so a later
    check_domain_availability() call builds a fresh checker instead of reusing
    a closed one, and repeated cleanup() calls are no-ops.
    """
    global domain_checker
    if domain_checker is not None:
        try:
            await domain_checker.client.aclose()
        finally:
            domain_checker = None
def parse_arguments():
    """Build the command-line parser and return the parsed namespace.

    Recognised options: ``--transport`` (stdio/sse), ``--port``,
    ``--log-level`` and ``--timeout``.
    """
    parser = argparse.ArgumentParser(
        description="NameChecker MCP Server - Domain availability checking for AI assistants"
    )
    option_specs = (
        (
            "--transport",
            {
                "choices": ["stdio", "sse"],
                "default": "stdio",
                "help": "Transport protocol to use (default: stdio)",
            },
        ),
        (
            "--port",
            {
                "type": int,
                "default": 8000,
                "help": "Port number for SSE transport (default: 8000)",
            },
        ),
        (
            "--log-level",
            {
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                "default": "INFO",
                "help": "Logging verbosity (default: INFO)",
            },
        ),
        (
            "--timeout",
            {
                "type": int,
                "default": 30,
                "help": "Request timeout in seconds (default: 30)",
            },
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
async def main():
    """Parse CLI arguments, configure logging and run the selected transport.

    Exits with status 1 if the server raises an unexpected exception;
    cleanup() runs in every case.
    """
    args = parse_arguments()
    # Configure logging level
    logging.getLogger().setLevel(getattr(logging, args.log_level))
    # Initialize global domain checker with timeout
    global domain_checker
    domain_checker = DomainChecker(timeout=args.timeout)
    try:
        if args.transport == "sse":
            await run_sse(args.port)
        else:
            await run_stdio()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down...")
    except Exception as e:
        # logger.exception also records the traceback, not just the message.
        logger.exception(f"Server error: {e}")
        sys.exit(1)
    finally:
        await cleanup()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() surfaces Ctrl+C here, after main() has been cancelled
        # and its finally-block cleanup has run; exit quietly, not with a traceback.
        logger.info("Received interrupt signal, shutting down...")
        sys.exit(130)  # 128 + SIGINT, the conventional exit status