get_correlation_matrix
Calculate correlation matrix to identify relationships between numeric columns in CSV data for statistical analysis.
Instructions
Calculate correlation matrix for numeric columns.
Input Schema
TableJSON Schema
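| Name | Required | Description | Default |
|---|---|---|---|
| session_id | Yes | Session identifier | |
| method | No | Correlation method: 'pearson', 'spearman', or 'kendall' | pearson |
| columns | No | Specific columns to include (omit for all numeric columns) | |
| min_correlation | No | Only report correlations whose absolute value is at least this threshold (diagonal always kept) | |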
Implementation Reference
# Core handler function: loads the CSV session data, selects numeric columns,
# computes the correlation matrix with pandas ``DataFrame.corr()``, filters
# entries by ``min_correlation`` if specified, identifies highly correlated
# column pairs, and returns the formatted results.


def _corr_matrix_to_dict(
    corr_matrix: pd.DataFrame, min_correlation: Optional[float]
) -> Dict[str, Dict[str, float]]:
    """Convert a square correlation matrix into a nested ``{col: {col: r}}`` dict.

    NaN entries are always dropped. When ``min_correlation`` is given,
    off-diagonal entries with ``abs(r) < min_correlation`` are dropped as well;
    diagonal entries are always kept. Values are rounded to 4 decimal places.
    Column labels are assumed unique (the caller checks this).
    """
    labels = list(corr_matrix.columns)
    values = corr_matrix.to_numpy()  # positional access avoids per-cell .loc
    result: Dict[str, Dict[str, float]] = {}
    for i, row_label in enumerate(labels):
        row: Dict[str, float] = {}
        for j, col_label in enumerate(labels):
            value = values[i, j]
            if pd.isna(value):
                continue
            if min_correlation is None or i == j or abs(value) >= min_correlation:
                row[col_label] = round(float(value), 4)
        result[row_label] = row
    return result


def _find_high_correlation_pairs(
    corr_matrix: pd.DataFrame, threshold: float
) -> List[Dict[str, Any]]:
    """List each distinct column pair (upper triangle) with ``abs(r) >= threshold``.

    NaN correlations are skipped. The result is sorted by absolute (rounded)
    correlation, strongest first.
    """
    labels = list(corr_matrix.columns)
    values = corr_matrix.to_numpy()
    pairs: List[Dict[str, Any]] = []
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            value = values[i, j]
            if not pd.isna(value) and abs(value) >= threshold:
                pairs.append({
                    "column1": labels[i],
                    "column2": labels[j],
                    "correlation": round(float(value), 4),
                })
    pairs.sort(key=lambda pair: abs(pair["correlation"]), reverse=True)
    return pairs


async def get_correlation_matrix(
    session_id: str,
    method: str = "pearson",
    columns: Optional[List[str]] = None,
    min_correlation: Optional[float] = None,
    ctx: Context = None,
    *,
    high_correlation_threshold: float = 0.7,
) -> Dict[str, Any]:
    """
    Calculate correlation matrix for numeric columns.

    Args:
        session_id: Session identifier
        method: Correlation method ('pearson', 'spearman', 'kendall')
        columns: Specific columns to include (None for all numeric).
            Repeated names are ignored; requested columns that are not
            numeric are silently excluded.
        min_correlation: Only keep matrix entries whose absolute value is at
            least this threshold (diagonal entries are always kept)
        ctx: FastMCP context (currently unused)
        high_correlation_threshold: Minimum absolute correlation for a column
            pair to be listed in ``high_correlations`` (default 0.7)

    Returns:
        Dict with ``success``; on success also ``method``,
        ``correlation_matrix`` (nested dict of rounded coefficients),
        ``high_correlations`` (pairs sorted strongest first) and
        ``columns_analyzed``. On failure, ``error`` holds a message.
    """
    # Validate arguments before touching session state.
    if method not in ("pearson", "spearman", "kendall"):
        return {"success": False, "error": f"Invalid method: {method}"}

    try:
        manager = get_session_manager()
        session = manager.get_session(session_id)
        if not session or session.df is None:
            return {"success": False, "error": "Invalid session or no data loaded"}

        df = session.df

        # Select columns
        if columns:
            # Order-preserving de-duplication: repeated names would produce
            # duplicate matrix labels and make scalar lookups ambiguous.
            requested = list(dict.fromkeys(columns))
            missing_cols = [col for col in requested if col not in df.columns]
            if missing_cols:
                return {"success": False, "error": f"Columns not found: {missing_cols}"}
            numeric_df = df[requested].select_dtypes(include=[np.number])
        else:
            numeric_df = df.select_dtypes(include=[np.number])

        if numeric_df.empty:
            return {"success": False, "error": "No numeric columns found"}

        if len(numeric_df.columns) < 2:
            return {"success": False, "error": "Need at least 2 numeric columns for correlation"}

        # Duplicate labels in the underlying DataFrame cannot be represented
        # in the nested-dict output; report them explicitly instead of
        # failing later with an opaque pandas error.
        if not numeric_df.columns.is_unique:
            return {"success": False, "error": "Duplicate column names are not supported for correlation"}

        corr_matrix = numeric_df.corr(method=method)

        correlations = _corr_matrix_to_dict(corr_matrix, min_correlation)
        high_correlations = _find_high_correlation_pairs(
            corr_matrix, high_correlation_threshold
        )

        session.record_operation(OperationType.ANALYZE, {
            "type": "correlation",
            "method": method,
            "columns": list(corr_matrix.columns)
        })

        return {
            "success": True,
            "method": method,
            "correlation_matrix": correlations,
            "high_correlations": high_correlations,
            "columns_analyzed": list(corr_matrix.columns)
        }

    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Error calculating correlation: %s", e)
        return {"success": False, "error": str(e)}
# src/csv_editor/server.py:337-345 (registration): MCP tool wrapper, registered
# via the @mcp.tool decorator (the decorator line is not part of this excerpt).
# It defines the public tool interface and delegates to the analytics
# implementation imported as ``_get_correlation_matrix``.
async def get_correlation_matrix(
    session_id: str,
    method: str = "pearson",
    columns: Optional[List[str]] = None,
    min_correlation: Optional[float] = None,
    ctx: Context = None
) -> Dict[str, Any]:
    """Calculate correlation matrix for numeric columns."""
    # NOTE: the docstring above doubles as the tool description shown to MCP
    # clients, so it is kept verbatim. Arguments are forwarded by name so the
    # call does not silently break if the implementation reorders parameters.
    return await _get_correlation_matrix(
        session_id=session_id,
        method=method,
        columns=columns,
        min_correlation=min_correlation,
        ctx=ctx,
    )
- src/csv_editor/server.py:307-315 (registration)Import of the get_correlation_matrix implementation from analytics module, aliased for use in the server tool wrappers.from .tools.analytics import ( get_statistics as _get_statistics, get_column_statistics as _get_column_statistics, get_correlation_matrix as _get_correlation_matrix, group_by_aggregate as _group_by_aggregate, get_value_counts as _get_value_counts, detect_outliers as _detect_outliers, profile_data as _profile_data )