create_graph
Generate graphs and plots from data files using matplotlib/seaborn. Specify the columns for the axes, choose a graph type (scatter, line, bar, or histogram), and save the visualization to an image file for analysis.
Instructions
Create a graph/plot from data using matplotlib/seaborn.
Args:
- file_path: Path to the data file
- x_column: Column name for the x-axis (must be numeric)
- y_column: Column name for the y-axis (must be numeric); for histograms, this is the column whose distribution is plotted
- output_path: Path at which to save the graph image
- graph_type: Type of graph (scatter, line, bar, histogram); defaults to scatter
- category_column: Optional categorical column used for grouping/coloring
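Returns: A JSON string describing the created graph (graph type, columns, number of data points, output file, and file size), or an error message beginning with "Error" if the graph could not be created.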
Input Schema
| Name | Required | Description | Default |
|---|---|---|---|
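| file_path | Yes | Path to the data file | |
| x_column | Yes | Column name for the x-axis (must be numeric) | |
| y_column | Yes | Column name for the y-axis (must be numeric) | |
| output_path | Yes | Path at which to save the graph image | |
| graph_type | No | Type of graph: scatter, line, bar, or histogram | scatter |
| category_column | No | Optional categorical column for grouping/coloring | None |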
Implementation Reference
- src/visidata_mcp/server.py:555-684 (handler): The complete create_graph tool handler function, which creates scatter, line, bar, or histogram plots from data using matplotlib/seaborn. It validates inputs, loads data, handles categorical grouping, and saves the plot to a file.
@mcp.tool()
def create_graph(file_path: str, x_column: str, y_column: str, output_path: str, graph_type: str = "scatter", category_column: Optional[str] = None) -> str:
    """
    Create a graph/plot from data using matplotlib/seaborn.

    Args:
        file_path: Path to the data file (.csv, .tsv, .json, .xlsx or .xls;
            any other extension is parsed as CSV)
        x_column: Column name for x-axis (must be numeric; not plotted for
            histograms, but it must still exist in the data)
        y_column: Column name for y-axis (must be numeric). For histograms,
            this is the column whose distribution is plotted.
        output_path: Path where to save the graph image
        graph_type: Type of graph (scatter, line, bar, histogram)
        category_column: Optional categorical column for grouping/coloring

    Returns:
        A JSON string describing the created graph, or a message starting
        with "Error" if the graph could not be created.
    """
    supported_types = ("scatter", "line", "bar", "histogram")
    try:
        if not VISUALIZATION_AVAILABLE:
            return f"Error: {VISUALIZATION_ERROR}"

        import pandas as pd
        from pathlib import Path

        # Reject unknown graph types before reading the file or opening a figure,
        # so this early return can never leave a matplotlib figure open.
        if graph_type not in supported_types:
            return f"Error: Unsupported graph type '{graph_type}'. Use: scatter, line, bar, histogram"

        # Load the data, choosing the pandas reader from the file extension.
        file_extension = Path(file_path).suffix.lower()
        if file_extension == '.json':
            df = pd.read_json(file_path)
        elif file_extension in ('.xlsx', '.xls'):
            df = pd.read_excel(file_path)
        elif file_extension == '.tsv':
            df = pd.read_csv(file_path, sep='\t')
        else:
            # '.csv' and unrecognised extensions are parsed as comma-separated text.
            df = pd.read_csv(file_path)

        # Validate columns exist
        if x_column not in df.columns:
            return f"Error: Column '{x_column}' not found in data"
        if y_column not in df.columns:
            return f"Error: Column '{y_column}' not found in data"
        if category_column and category_column not in df.columns:
            return f"Error: Category column '{category_column}' not found in data"

        # A histogram only plots y_column, so x_column must not cause rows to be dropped.
        is_histogram = graph_type == "histogram"
        numeric_columns = [y_column] if is_histogram else [x_column, y_column]

        # errors='coerce' turns unparseable values into NaN (removed below), but
        # to_numeric can still raise for unsupported element types such as lists.
        try:
            for column in numeric_columns:
                df[column] = pd.to_numeric(df[column], errors='coerce')
        except (TypeError, ValueError):
            return f"Error: Could not convert {x_column} or {y_column} to numeric values"

        # Remove rows with NaN values in the columns that will be plotted.
        # dict.fromkeys de-duplicates (e.g. x_column == y_column) while keeping order.
        plot_columns = list(numeric_columns)
        if category_column:
            plot_columns.append(category_column)
        plot_columns = list(dict.fromkeys(plot_columns))
        df_clean = df[plot_columns].dropna()

        if df_clean.empty:
            return "Error: No valid data points for plotting after removing NaN values"

        fig = plt.figure(figsize=(10, 6))
        try:
            if graph_type == "scatter":
                if category_column:
                    sns.scatterplot(data=df_clean, x=x_column, y=y_column, hue=category_column, alpha=0.7)
                else:
                    plt.scatter(df_clean[x_column], df_clean[y_column], alpha=0.7)
            elif graph_type == "line":
                if category_column:
                    sns.lineplot(data=df_clean, x=x_column, y=y_column, hue=category_column)
                else:
                    plt.plot(df_clean[x_column], df_clean[y_column])
            elif graph_type == "bar":
                if category_column:
                    # Group by category and take mean of y values for each x value
                    grouped = df_clean.groupby([x_column, category_column])[y_column].mean().reset_index()
                    sns.barplot(data=grouped, x=x_column, y=y_column, hue=category_column)
                else:
                    grouped = df_clean.groupby(x_column)[y_column].mean()
                    plt.bar(grouped.index, grouped.values)
            else:  # histogram; graph_type was validated above
                if category_column:
                    for category in df_clean[category_column].unique():
                        subset = df_clean[df_clean[category_column] == category]
                        plt.hist(subset[y_column], alpha=0.7, label=str(category), bins=20)
                    plt.legend()
                else:
                    plt.hist(df_clean[y_column], bins=20, alpha=0.7)

            # Set labels and title. Histograms show the distribution of y_column,
            # so they get value/frequency axes instead of the generic x/y labels.
            if is_histogram:
                plt.xlabel(y_column.replace('_', ' ').title())
                plt.ylabel('Frequency')
                title = f"Histogram of {y_column}"
            else:
                plt.xlabel(x_column.replace('_', ' ').title())
                plt.ylabel(y_column.replace('_', ' ').title())
                title = f"{graph_type.title()} Plot: {y_column} vs {x_column}"
            if category_column:
                title += f" (grouped by {category_column})"
            plt.title(title)

            # Add grid for better readability
            plt.grid(True, alpha=0.3)
            # Adjust layout to prevent label cutoff
            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving raised,
            # so repeated tool calls do not accumulate open figures.
            plt.close(fig)

        output_file = Path(output_path)
        result = {
            "graph_created": True,
            "graph_type": graph_type,
            "x_column": x_column,
            "y_column": y_column,
            "category_column": category_column,
            "data_points": len(df_clean),
            "output_file": output_path,
            "file_size": output_file.stat().st_size if output_file.exists() else 0,
        }
        return json.dumps(result, indent=2)

    except Exception as e:
        return f"Error creating graph: {str(e)}\n{traceback.format_exc()}"

# - src/visidata_mcp/server.py:23-32 (helper) Visualization library imports and availability check.
#   Matplotlib and seaborn are imported with the Agg non-interactive backend to support
#   server-side graph generation.
# Try to import visualization packages early to detect missing dependencies.
# Both flags are always defined after this block:
#   VISUALIZATION_AVAILABLE -- True only if matplotlib and seaborn imported successfully
#   VISUALIZATION_ERROR     -- None on success, otherwise a message suitable for
#                              returning to the tool caller
try:
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend (no display required)
    import matplotlib.pyplot as plt
    import seaborn as sns
    VISUALIZATION_AVAILABLE = True
    VISUALIZATION_ERROR = None
except ImportError as e:
    VISUALIZATION_AVAILABLE = False
    VISUALIZATION_ERROR = f"Visualization libraries not available: {e}. Please install matplotlib and seaborn."