"""
Enhanced MCP Server for Snowflake - Compatible with LangGraph Agent Architecture
This server provides specialized tools for each agent archetype and supports
session management, caching, and feedback collection.
"""
import asyncio
import logging
import time
import json
import functools
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta
from collections import defaultdict
from dataclasses import dataclass, field
from mcp.server.fastmcp import FastMCP
from mcp import types
from .config import load_config
from .db_client import SnowflakeDB
logger = logging.getLogger("aai_mcp_snowflake_server")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
mcp_app = FastMCP("aai-mcp-snowflake-enhanced")
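# Simple shared-state dict on the app instance; start_http() stores the
# EnhancedAppContext under the "ctx" key for all tools to use.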
mcp_app.state = {}
# ============================================================================
# Session Management
# ============================================================================
@dataclass
class SessionState:
"""Maintains state for a client session"""
session_id: str
created_at: datetime
last_accessed: datetime
query_history: List[Dict[str, Any]] = field(default_factory=list)
cache: Dict[str, Any] = field(default_factory=dict)
feedback_data: List[Dict[str, Any]] = field(default_factory=list)
agent_context: Dict[str, Any] = field(default_factory=dict)
class SessionManager:
"""Manages client sessions with state persistence"""
def __init__(self, ttl_minutes: int = 60):
self.sessions: Dict[str, SessionState] = {}
self.ttl = timedelta(minutes=ttl_minutes)
def get_or_create(self, session_id: str) -> SessionState:
"""Get existing session or create new one"""
if session_id in self.sessions:
session = self.sessions[session_id]
session.last_accessed = datetime.now()
else:
session = SessionState(
session_id=session_id,
created_at=datetime.now(),
last_accessed=datetime.now()
)
self.sessions[session_id] = session
# Clean up old sessions
self._cleanup_expired()
return session
def _cleanup_expired(self):
"""Remove expired sessions"""
now = datetime.now()
expired = [
sid for sid, session in self.sessions.items()
if now - session.last_accessed > self.ttl
]
for sid in expired:
del self.sessions[sid]
# ============================================================================
# Query Cache
# ============================================================================
class QueryCache:
"""Intelligent caching for query results"""
def __init__(self, max_size: int = 100, ttl_seconds: int = 300):
self.cache: Dict[str, Tuple[Any, datetime]] = {}
self.max_size = max_size
self.ttl = timedelta(seconds=ttl_seconds)
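        # Per-key hit counter; the least-hit entry is evicted once the cache is full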
self.hit_count = defaultdict(int)
def get(self, key: str) -> Optional[Any]:
"""Get cached result if available and not expired"""
if key in self.cache:
result, timestamp = self.cache[key]
if datetime.now() - timestamp < self.ttl:
self.hit_count[key] += 1
return result
else:
del self.cache[key]
return None
def set(self, key: str, value: Any):
"""Cache a result"""
        # Evict the least-frequently-hit entry when at capacity
        if len(self.cache) >= self.max_size:
            lfu_key = min(self.cache.keys(),
                          key=lambda k: self.hit_count[k])
            del self.cache[lfu_key]
            del self.hit_count[lfu_key]
self.cache[key] = (value, datetime.now())
# ============================================================================
# Enhanced Application Context
# ============================================================================
class EnhancedAppContext:
"""Enhanced context with session management and caching"""
def __init__(self, db: SnowflakeDB):
self.db = db
self.session_manager = SessionManager()
self.query_cache = QueryCache()
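        # Positive feedback captured by save_feedback(), accumulated as training data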
self.training_data = []
def get_session(self, session_id: str) -> SessionState:
"""Get or create session state"""
return self.session_manager.get_or_create(session_id)
# ============================================================================
# Utility Functions
# ============================================================================
def text_block(txt: str) -> list[types.TextContent]:
"""Create text content block"""
return [types.TextContent(type="text", text=txt)]
def timing(fn):
"""Decorator to time function execution"""
@functools.wraps(fn)
async def wrapper(*a, **kw):
start = time.perf_counter()
try:
result = await fn(*a, **kw)
elapsed = time.perf_counter() - start
logger.debug("%s took %.3fs", fn.__name__, elapsed)
# Add timing to session if available
if 'session_id' in kw:
ctx: EnhancedAppContext = mcp_app.state["ctx"]
session = ctx.get_session(kw['session_id'])
session.query_history.append({
'tool': fn.__name__,
'elapsed': elapsed,
'timestamp': datetime.now().isoformat()
})
return result
except Exception as e:
logger.exception("Tool error in %s", fn.__name__)
return text_block(f"Error in {fn.__name__}: {e}")
return wrapper
def with_cache(cache_prefix: str):
"""Decorator to add caching to tools"""
def decorator(fn):
@functools.wraps(fn)
async def wrapper(*a, **kw):
ctx: EnhancedAppContext = mcp_app.state["ctx"]
            # Build the cache key from the keyword args, excluding session_id so
            # identical calls from different sessions can share cached results
            key_args = {k: v for k, v in kw.items() if k != 'session_id'}
            cache_key = f"{cache_prefix}:{json.dumps(key_args, sort_keys=True, default=str)}"
# Check cache
cached = ctx.query_cache.get(cache_key)
if cached is not None:
logger.debug("Cache hit for %s", fn.__name__)
return cached
# Execute and cache
result = await fn(*a, **kw)
ctx.query_cache.set(cache_key, result)
return result
return wrapper
return decorator
# ============================================================================
# Core Database Tools
# ============================================================================
@mcp_app.tool()
@timing
async def list_databases(session_id: Optional[str] = None) -> str:
"""List all available Snowflake databases."""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
dbs = await ctx.db.list_databases()
return "\n".join(str(x) for x in dbs)
@mcp_app.tool()
@timing
async def list_schemas(
database: Optional[str] = None,
session_id: Optional[str] = None
) -> str:
"""List schemas in a database."""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
schemas = await ctx.db.list_schemas(database)
return "\n".join(schemas)
@mcp_app.tool()
@timing
@with_cache("tables")
async def list_tables(
database: Optional[str] = None,
schema: Optional[str] = None,
session_id: Optional[str] = None
) -> str:
"""List tables in a schema."""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
tables = await ctx.db.list_tables(database, schema)
return "\n".join(tables)
@mcp_app.tool()
@timing
async def run_query(
sql: str,
session_id: Optional[str] = None
) -> str:
"""Execute a SELECT query (read-only)."""
if not sql.strip().lower().startswith("select"):
return "Only SELECT queries allowed for safety."
ctx: EnhancedAppContext = mcp_app.state["ctx"]
rows = await ctx.db.query(sql)
    # Convert rows to dicts, capping the output at 1000 rows
    out = [dict(r.asDict()) for r in rows[:1000]]
# Track in session if available
if session_id:
session = ctx.get_session(session_id)
session.query_history.append({
'query': sql,
'row_count': len(out),
'timestamp': datetime.now().isoformat()
})
return json.dumps(out, indent=2, default=str)
# ============================================================================
# Analyst Agent Tools
# ============================================================================
@mcp_app.tool()
@timing
@with_cache("usage_analysis")
async def analyze_usage(
time_period: str = "30_days",
business_unit: Optional[str] = None,
limit: int = 100,
session_id: Optional[str] = None
) -> str:
"""
Analyze usage patterns from AAI_USAGE table.
Performs EDA on user access patterns, table usage, and trends.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
    # Build the WHERE clause; time_period is expected to look like "30_days" (falls back to 30)
    try:
        days = int(str(time_period).split("_")[0])
    except ValueError:
        days = 30
    where_clauses = [f"query_date >= CURRENT_DATE() - {days}"]
    if business_unit:
        where_clauses.append(f"business_unit_name = '{business_unit}'")
    where_sql = f"WHERE {' AND '.join(where_clauses)}"
sql = f"""
SELECT
user_name,
object_name,
user_type,
business_unit_name,
SUM(access_count) as total_accesses,
COUNT(DISTINCT query_id) as unique_queries,
SUM(total_compute_cost) as total_cost,
MAX(query_date) as last_accessed
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_USAGE
{where_sql}
GROUP BY 1, 2, 3, 4
ORDER BY total_accesses DESC
LIMIT {limit}
"""
rows = await ctx.db.query(sql)
results = [dict(r.asDict()) for r in rows]
# Generate analysis summary
analysis = {
"period": time_period,
"top_users": results[:10] if results else [],
"total_unique_users": len(set(r['user_name'] for r in results)),
"total_objects_accessed": len(set(r['object_name'] for r in results)),
"results": results
}
return json.dumps(analysis, indent=2, default=str)
@mcp_app.tool()
@timing
async def get_table_statistics(
table_name: str,
session_id: Optional[str] = None
) -> str:
"""
Get detailed statistics for a specific table from AAI_PROFILER.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
sql = f"""
SELECT
column_name,
data_type,
pct_null,
distinct_count,
MIN_VAL,
MAX_VAL,
top_n_freq
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_PROFILER
WHERE UPPER(table_name) = UPPER('{table_name}')
ORDER BY column_name
"""
rows = await ctx.db.query(sql)
return json.dumps([dict(r.asDict()) for r in rows], indent=2, default=str)
# ============================================================================
# Lineage Expert Tools
# ============================================================================
@mcp_app.tool()
@timing
@with_cache("lineage")
async def get_lineage(
table_name: str,
direction: str = "upstream",
depth: int = 3,
session_id: Optional[str] = None
) -> str:
"""
Get data lineage for a table (upstream sources or downstream targets).
Recursively traces lineage up to specified depth.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
lineage_tree = {}
visited = set()
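    # Depth-first traversal over AAI_LINEAGE; 'visited' guards against cycles and repeats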
async def trace_lineage(table: str, current_depth: int):
if current_depth > depth or table in visited:
return
visited.add(table)
if direction == "upstream":
sql = f"""
SELECT DISTINCT src_feature_nm as related_table
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_LINEAGE
WHERE UPPER(tgt_feature_nm) = UPPER('{table}')
"""
else:
sql = f"""
SELECT DISTINCT tgt_feature_nm as related_table
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_LINEAGE
WHERE UPPER(src_feature_nm) = UPPER('{table}')
"""
rows = await ctx.db.query(sql)
related_tables = [r['related_table'] for r in rows]
lineage_tree[table] = {
'level': current_depth,
'related_tables': related_tables
}
# Recursively trace
for related in related_tables:
await trace_lineage(related, current_depth + 1)
await trace_lineage(table_name, 1)
return json.dumps({
'table': table_name,
'direction': direction,
'depth': depth,
'lineage': lineage_tree
}, indent=2)
@mcp_app.tool()
@timing
async def get_impact_analysis(
table_name: str,
session_id: Optional[str] = None
) -> str:
"""
Analyze the impact of changes to a table by examining all downstream dependencies.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
# Get direct downstream tables
sql = f"""
SELECT
tgt_feature_nm as downstream_table,
COUNT(DISTINCT query_id) as transformation_count
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_LINEAGE
WHERE UPPER(src_feature_nm) = UPPER('{table_name}')
GROUP BY 1
"""
rows = await ctx.db.query(sql)
direct_impact = [dict(r.asDict()) for r in rows]
# Get usage info for impacted tables
if direct_impact:
tables = "','".join([t['downstream_table'] for t in direct_impact])
usage_sql = f"""
SELECT
object_name,
COUNT(DISTINCT user_name) as affected_users,
SUM(access_count) as total_accesses
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_USAGE
WHERE object_name IN ('{tables}')
GROUP BY 1
"""
usage_rows = await ctx.db.query(usage_sql)
usage_map = {r['object_name']: dict(r.asDict()) for r in usage_rows}
# Combine impact and usage
for item in direct_impact:
table = item['downstream_table']
if table in usage_map:
item.update(usage_map[table])
return json.dumps({
'source_table': table_name,
'direct_impact_count': len(direct_impact),
'impacted_tables': direct_impact
}, indent=2, default=str)
# ============================================================================
# Usage Auditor Tools
# ============================================================================
@mcp_app.tool()
@timing
async def identify_heavy_users(
metric: str = "compute_cost",
top_n: int = 10,
session_id: Optional[str] = None
) -> str:
"""
Identify users with highest resource consumption.
Metrics: compute_cost, access_count, query_count
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
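    # Map the requested metric to its SQL aggregate; unknown metrics fall back to compute cost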
metric_column = {
"compute_cost": "SUM(total_compute_cost)",
"access_count": "SUM(access_count)",
"query_count": "COUNT(DISTINCT query_id)"
}.get(metric, "SUM(total_compute_cost)")
sql = f"""
SELECT
user_name,
user_type,
business_unit_name,
{metric_column} as metric_value,
COUNT(DISTINCT object_name) as unique_objects,
COUNT(DISTINCT query_id) as total_queries
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_USAGE
GROUP BY 1, 2, 3
ORDER BY metric_value DESC
LIMIT {top_n}
"""
rows = await ctx.db.query(sql)
return json.dumps({
'metric': metric,
'top_users': [dict(r.asDict()) for r in rows]
}, indent=2, default=str)
@mcp_app.tool()
@timing
async def detect_usage_anomalies(
lookback_days: int = 7,
session_id: Optional[str] = None
) -> str:
"""
Detect anomalies in usage patterns compared to historical baseline.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
# Get recent usage patterns
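    # An access count more than 2 standard deviations above/below the object's
    # mean over the lookback window is flagged as a spike/drop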
sql = f"""
WITH daily_usage AS (
SELECT
DATE(query_date) as usage_date,
object_name,
SUM(access_count) as daily_accesses,
COUNT(DISTINCT user_name) as unique_users
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_USAGE
WHERE query_date >= CURRENT_DATE() - {lookback_days}
GROUP BY 1, 2
),
stats AS (
SELECT
object_name,
AVG(daily_accesses) as avg_accesses,
STDDEV(daily_accesses) as stddev_accesses,
MAX(daily_accesses) as max_accesses,
MIN(daily_accesses) as min_accesses
FROM daily_usage
GROUP BY 1
)
SELECT
d.usage_date,
d.object_name,
d.daily_accesses,
s.avg_accesses,
CASE
WHEN d.daily_accesses > s.avg_accesses + (2 * s.stddev_accesses)
THEN 'spike'
WHEN d.daily_accesses < s.avg_accesses - (2 * s.stddev_accesses)
THEN 'drop'
ELSE 'normal'
END as anomaly_type
FROM daily_usage d
JOIN stats s ON d.object_name = s.object_name
WHERE d.daily_accesses > s.avg_accesses + (2 * s.stddev_accesses)
OR d.daily_accesses < s.avg_accesses - (2 * s.stddev_accesses)
ORDER BY d.usage_date DESC, d.daily_accesses DESC
LIMIT 50
"""
rows = await ctx.db.query(sql)
anomalies = [dict(r.asDict()) for r in rows]
return json.dumps({
'lookback_days': lookback_days,
'anomaly_count': len(anomalies),
'anomalies': anomalies
}, indent=2, default=str)
# ============================================================================
# Query Optimizer Tools
# ============================================================================
@mcp_app.tool()
@timing
async def analyze_slow_queries(
threshold_seconds: int = 10,
limit: int = 20,
session_id: Optional[str] = None
) -> str:
"""
Identify and analyze slow-running queries for optimization opportunities.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
sql = f"""
SELECT
query_id,
uid as user_name,
total_elapsed_time,
rows_produced,
bytes_scanned,
warehouse_name,
query_text,
direct_objects,
base_objects
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_SQL_ANALYZER
WHERE total_elapsed_time > {threshold_seconds * 1000}
ORDER BY total_elapsed_time DESC
LIMIT {limit}
"""
rows = await ctx.db.query(sql)
slow_queries = [dict(r.asDict()) for r in rows]
# Add optimization suggestions
for query in slow_queries:
suggestions = []
        # Check for common issues (treat NULL metrics as zero/empty)
        if (query.get('bytes_scanned') or 0) > 1e10:  # > 10 GB scanned
            suggestions.append("Consider adding filters to reduce data scanned")
        if 'SELECT *' in (query.get('query_text') or '').upper():
            suggestions.append("Replace SELECT * with specific columns")
        if (query.get('rows_produced') or 0) > 1e6:  # > 1M rows
            suggestions.append("Consider pagination or aggregation")
query['optimization_suggestions'] = suggestions
return json.dumps({
'threshold_seconds': threshold_seconds,
'slow_query_count': len(slow_queries),
'queries': slow_queries
}, indent=2, default=str)
@mcp_app.tool()
@timing
@with_cache("query_patterns")
async def identify_query_patterns(
min_frequency: int = 5,
session_id: Optional[str] = None
) -> str:
"""
Identify common query patterns and suggest optimizations or data products.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
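    # Group queries by parameterized hash to surface query shapes executed at least min_frequency times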
sql = f"""
WITH pattern_analysis AS (
SELECT
query_parameterized_hash,
COUNT(*) as execution_count,
AVG(total_elapsed_time) as avg_time,
SUM(bytes_scanned) as total_bytes,
ARRAY_AGG(DISTINCT uid) as users,
ARRAY_AGG(DISTINCT direct_objects) as objects_accessed,
ANY_VALUE(query_text) as sample_query
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_SQL_ANALYZER
GROUP BY 1
HAVING COUNT(*) >= {min_frequency}
)
SELECT *
FROM pattern_analysis
ORDER BY execution_count DESC
LIMIT 50
"""
rows = await ctx.db.query(sql)
patterns = [dict(r.asDict()) for r in rows]
# Analyze patterns for recommendations
recommendations = []
for pattern in patterns:
if pattern['execution_count'] > 100 and pattern['avg_time'] > 5000:
recommendations.append({
'pattern_hash': pattern['query_parameterized_hash'],
'recommendation': 'Create materialized view',
'reason': f"High frequency ({pattern['execution_count']}x) with slow performance"
})
return json.dumps({
'patterns_found': len(patterns),
'patterns': patterns,
'recommendations': recommendations
}, indent=2, default=str)
# ============================================================================
# Metadata and Access Control Tools
# ============================================================================
@mcp_app.tool()
@timing
@with_cache("metadata")
async def get_table_metadata(
table_name: str,
include_columns: bool = True,
session_id: Optional[str] = None
) -> str:
"""
Get comprehensive metadata for a table including description, columns, and data product info.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
# Get table metadata
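    # Single table-level row from AAI_MD (identifiers, size, data-product flag, description)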
sql = f"""
SELECT DISTINCT
MOTS_ID,
DPG,
ITAP_ID,
db_name,
schema_name,
table_name,
size_in_mb,
is_dp,
object_description,
BM_source
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_MD
WHERE UPPER(object_name) = UPPER('{table_name}')
LIMIT 1
"""
rows = await ctx.db.query(sql)
metadata = dict(rows[0].asDict()) if rows else {}
# Get column information if requested
if include_columns and metadata:
col_sql = f"""
SELECT
attribute_name,
attribute_definition
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_MD
WHERE UPPER(object_name) = UPPER('{table_name}')
AND attribute_name IS NOT NULL
ORDER BY attribute_name
"""
col_rows = await ctx.db.query(col_sql)
metadata['columns'] = [dict(r.asDict()) for r in col_rows]
return json.dumps(metadata, indent=2, default=str)
@mcp_app.tool()
@timing
async def get_access_permissions(
table_name: str,
session_id: Optional[str] = None
) -> str:
"""
Get access permissions and roles required for a table.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
sql = f"""
SELECT
upstart_role,
upstart_role_description,
platform,
database_role,
privilege_list
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_ACCESS
WHERE UPPER(object_name) = UPPER('{table_name}')
ORDER BY upstart_role
"""
rows = await ctx.db.query(sql)
return json.dumps({
'table': table_name,
'access_permissions': [dict(r.asDict()) for r in rows]
}, indent=2, default=str)
# ============================================================================
# Recommendation Engine Tools
# ============================================================================
@mcp_app.tool()
@timing
async def recommend_data_products(
analysis_scope: str = "usage_based",
session_id: Optional[str] = None
) -> str:
"""
Recommend new data products based on usage patterns and query analysis.
    Scopes: usage_based, all (performance_based and lineage_based are not yet implemented).
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
recommendations = []
if analysis_scope in ["usage_based", "all"]:
# Find frequently joined tables
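        # Queries that touch more than one object, grouped by the sorted set of objects they reference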
sql = """
WITH table_combinations AS (
SELECT
ARRAY_TO_STRING(ARRAY_SORT(direct_objects), ',') as table_combo,
COUNT(*) as query_count,
AVG(total_elapsed_time) as avg_time,
COUNT(DISTINCT uid) as unique_users
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_SQL_ANALYZER
WHERE ARRAY_SIZE(direct_objects) > 1
GROUP BY 1
HAVING COUNT(*) > 10
)
SELECT *
FROM table_combinations
ORDER BY query_count DESC
LIMIT 10
"""
rows = await ctx.db.query(sql)
for r in rows:
row = dict(r.asDict())
if row['query_count'] > 50:
recommendations.append({
'type': 'new_data_product',
'reason': 'frequently_joined_tables',
'tables': row['table_combo'],
'usage_stats': {
'query_count': row['query_count'],
'unique_users': row['unique_users']
},
'priority': 'high' if row['query_count'] > 100 else 'medium'
})
return json.dumps({
'analysis_scope': analysis_scope,
'recommendations': recommendations
}, indent=2, default=str)
# ============================================================================
# Session and Feedback Management
# ============================================================================
@mcp_app.tool()
@timing
async def save_feedback(
session_id: str,
query: str,
response: str,
feedback_type: str,
feedback_text: Optional[str] = None
) -> str:
"""
Save user feedback for training data collection.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
session = ctx.get_session(session_id)
feedback_data = {
'timestamp': datetime.now().isoformat(),
'query': query,
'response': response,
'feedback_type': feedback_type,
'feedback_text': feedback_text
}
session.feedback_data.append(feedback_data)
# Also save to training data if positive
if feedback_type == 'positive':
ctx.training_data.append(feedback_data)
return json.dumps({
'status': 'success',
'feedback_saved': True
}, indent=2)
@mcp_app.tool()
@timing
async def get_session_history(
session_id: str,
include_cache_stats: bool = False
) -> str:
"""
Get query history and statistics for a session.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
session = ctx.get_session(session_id)
history = {
'session_id': session_id,
'created_at': session.created_at.isoformat(),
'query_count': len(session.query_history),
'queries': session.query_history[-10:], # Last 10 queries
'feedback_count': len(session.feedback_data)
}
if include_cache_stats:
history['cache_stats'] = {
'cache_size': len(ctx.query_cache.cache),
'total_hits': sum(ctx.query_cache.hit_count.values())
}
return json.dumps(history, indent=2, default=str)
# ============================================================================
# SQL Usage Analysis Tool (Legacy Compatibility)
# ============================================================================
@mcp_app.tool()
@timing
async def analyze_sql_usage(
limit: int = 1000,
session_id: Optional[str] = None
) -> str:
"""
Analyze SQL usage from the AAI_SQL_ANALYZER table (legacy compatibility).
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
sql = f"""
SELECT
query_id,
uid,
vp_uid,
start_time,
query_text,
total_elapsed_time,
rows_produced,
bytes_scanned,
warehouse_name,
role_name,
query_tag,
query_parameterized_hash,
direct_objects,
base_objects,
columns_accessed
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_SQL_ANALYZER
LIMIT {limit}
"""
rows = await ctx.db.query(sql)
rows_dict = [dict(r.asDict()) for r in rows]
# Use analyzer logic from query_utils
from .query_utils import explode_usage, compact_usage
usage_rows = explode_usage(rows_dict)
summary = compact_usage(usage_rows)
return json.dumps(summary, indent=2, default=str)
# ============================================================================
# Advanced Analysis Tools
# ============================================================================
@mcp_app.tool()
@timing
async def analyze_data_quality(
table_name: str,
session_id: Optional[str] = None
) -> str:
"""
Analyze data quality metrics for a table using profiling data.
"""
ctx: EnhancedAppContext = mcp_app.state["ctx"]
# Get profiling data
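    # Bucket each column by null percentage and by distinct-count cardinality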
sql = f"""
SELECT
column_name,
data_type,
pct_null,
distinct_count,
CASE
WHEN pct_null > 50 THEN 'high_nulls'
WHEN pct_null > 20 THEN 'moderate_nulls'
ELSE 'low_nulls'
END as null_category,
CASE
WHEN distinct_count = 1 THEN 'constant'
WHEN distinct_count < 10 THEN 'low_cardinality'
WHEN distinct_count < 100 THEN 'medium_cardinality'
ELSE 'high_cardinality'
END as cardinality_category
FROM AZDMAND.SDW_DPE_DASH_DB.AAI_PROFILER
WHERE UPPER(table_name) = UPPER('{table_name}')
"""
rows = await ctx.db.query(sql)
columns = [dict(r.asDict()) for r in rows]
# Calculate quality score
if columns:
        avg_null_pct = sum((c.get('pct_null') or 0) for c in columns) / len(columns)
quality_score = max(0, 100 - avg_null_pct)
issues = []
if avg_null_pct > 30:
issues.append("High average null percentage")
constant_cols = [c['column_name'] for c in columns if c.get('cardinality_category') == 'constant']
if constant_cols:
issues.append(f"Constant columns detected: {', '.join(constant_cols)}")
else:
quality_score = 0
issues = ["No profiling data available"]
return json.dumps({
'table': table_name,
'quality_score': quality_score,
'column_count': len(columns),
'columns': columns,
'issues': issues
}, indent=2, default=str)
# ============================================================================
# Server Initialization
# ============================================================================
async def start_http(host: str, port: int, path: str):
"""Initialize and start the HTTP server"""
cfg = load_config()
db = SnowflakeDB(cfg)
await db.ensure()
    # Initialize enhanced context
    mcp_app.state["ctx"] = EnhancedAppContext(db)
    # Apply bind settings (assumes FastMCP's settings expose host/port/streamable_http_path)
    mcp_app.settings.host, mcp_app.settings.port = host, port
    mcp_app.settings.streamable_http_path = path
    tools = await mcp_app.list_tools()
    logger.info("Enhanced MCP Server initialized with %d tools", len(tools))
    logger.info("Session management and caching enabled")
try:
await mcp_app.run_streamable_http_async()
finally:
await db.close()
async def main(host: str = "0.0.0.0", port: int = 8000, path: str = "/mcp"):
"""Main entry point for the server"""
logger.info(f"Starting Enhanced MCP Server on {host}:{port}{path}")
await start_http(host, port, path)
def run():
"""Run the server"""
asyncio.run(main())
if __name__ == "__main__":
run()