StarRocks MCP Server

# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import asyncio
import math
import sys
import os
import io
import time
import traceback
from fastmcp import FastMCP, Image
from fastmcp.exceptions import ToolError
from typing import Annotated
from pydantic import Field
from mysql.connector import Error as MySQLError
import mysql.connector
import pandas as pd
import plotly.express as px
import base64

mcp = FastMCP('mcp-server-starrocks')

global_connection = None
default_database = os.getenv('STARROCKS_DB')
# a hint for soft limit, not enforced
overview_length_limit = int(os.getenv('STARROCKS_OVERVIEW_LIMIT', str(20000)))
# Global cache for table overviews: {(db_name, table_name): overview_string}
global_table_overview_cache = {}
mcp_transport = os.getenv('MCP_TRANSPORT_MODE', 'stdio')


def get_connection():
    global global_connection, default_database
    if global_connection is None:
        connection_params = {
            'host': os.getenv('STARROCKS_HOST', 'localhost'),
            'port': os.getenv('STARROCKS_PORT', '9030'),
            'user': os.getenv('STARROCKS_USER', 'root'),
            'password': os.getenv('STARROCKS_PASSWORD', ''),
            'auth_plugin': os.getenv('STARROCKS_MYSQL_AUTH_PLUGIN', 'mysql_native_password')
        }
        # Use default_database if set during initial connection attempt
        if default_database:
            connection_params['database'] = default_database
        try:
            global_connection = mysql.connector.connect(**connection_params)
            # If connection succeeds without db and default_database is set, try USE DB
            if 'database' not in connection_params and default_database:
                try:
                    cursor = global_connection.cursor()
                    cursor.execute(f"USE {default_database}")
                    cursor.close()
                except MySQLError as db_err:
                    # Warn but don't fail connection if USE DB fails
                    print(f"Warning: Could not switch to default database '{default_database}': {db_err}")
        except MySQLError as conn_err:
            # Reset global connection on failure
            global_connection = None
            # Re-raise the exception to be caught by callers
            raise conn_err

    # Ensure connection is alive, reconnect if not
    if global_connection is not None:
        try:
            if not global_connection.is_connected():
                global_connection.reconnect()
                # Re-apply default database if needed after reconnect
                if default_database:
                    try:
                        cursor = global_connection.cursor()
                        cursor.execute(f"USE {default_database}")
                        cursor.close()
                    except MySQLError as db_err:
                        print(f"Warning: Could not switch to default database '{default_database}' after reconnect: {db_err}")
        except MySQLError as check_err:
            print(f"Connection check/reconnect failed: {check_err}")
            reset_connection()  # Force reset if reconnect fails
            raise check_err  # Raise error to indicate connection failure

    return global_connection


def reset_connection():
    global global_connection
    if global_connection is not None:
        try:
            global_connection.close()
        except Exception as e:
            print(f"Error closing connection: {e}")  # Log error but proceed
        finally:
            global_connection = None

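# Configuration summary (added for reference; derived from the code above, not additional behavior):
# the server is configured entirely through environment variables.
#   STARROCKS_HOST               FE host to connect to (default: localhost)
#   STARROCKS_PORT               FE MySQL-protocol port (default: 9030)
#   STARROCKS_USER               user name (default: root)
#   STARROCKS_PASSWORD           password (default: empty)
#   STARROCKS_MYSQL_AUTH_PLUGIN  auth plugin (default: mysql_native_password)
#   STARROCKS_DB                 optional default database for the session
#   STARROCKS_OVERVIEW_LIMIT     soft length limit for overview output (default: 20000)
#   MCP_TRANSPORT_MODE           MCP transport passed to mcp.run_async (default: stdio)
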
def _format_rows_to_string(columns, rows, limit=None):
    """Helper to format rows similar to handle_read_query but without row count."""
    output = io.StringIO()

    def to_csv_line(row):
        return ",".join(
            str(item).replace("\"", "\"\"") if isinstance(item, str) else str(item) for item in row)

    output.write(to_csv_line(columns) + "\n")
    for row in rows:
        l = to_csv_line(row) + "\n"
        if limit is not None and output.tell() + len(l) > limit:
            break
        output.write(l)
    return output.getvalue()


def _get_table_details(conn, db_name, table_name, limit=None):
    """
    Helper function to get description, sample rows, and count for a table.
    Returns a formatted string. Handles DB errors internally and returns error messages.
    """
    global global_table_overview_cache  # Access cache for potential updates
    output_lines = []
    # Use backticks for safety
    full_table_name = f"`{table_name}`"
    if db_name:
        full_table_name = f"`{db_name}`.`{table_name}`"
    else:
        # Should ideally not happen if logic is correct, but handle defensively
        output_lines.append(
            f"Warning: Database name missing for table '{table_name}'. Using potentially incorrect context.")
    count = 0
    output_lines.append(f"--- Overview for {full_table_name} ---")
    cursor = None  # Initialize cursor to None
    try:
        cursor = conn.cursor()
        # 1. Get Row Count
        try:
            # TODO: get estimated row count from statistics if available
            query = f"SELECT COUNT(*) FROM {full_table_name}"
            # print(f"Executing: {query}")  # Debug
            cursor.execute(query)
            count_result = cursor.fetchone()
            if count_result:
                count = count_result[0]
                output_lines.append(f"\nTotal rows: {count}")
            else:
                output_lines.append("\nCould not determine total row count.")
        except MySQLError as e:
            output_lines.append(f"Error getting row count for {full_table_name}: {e}")

        # 2. Get Columns (DESCRIBE)
        if count > 0:
            try:
                query = f"DESCRIBE {full_table_name}"
                # print(f"Executing: {query}")  # Debug
                cursor.execute(query)
                cols = [desc[0] for desc in cursor.description] if cursor.description else []
                rows = cursor.fetchall()
                output_lines.append("\nColumns:")
                if rows:
                    output_lines.append(_format_rows_to_string(cols, rows, limit=limit))
                else:
                    output_lines.append("(Could not retrieve column information or table has no columns).")
            except MySQLError as e:
                output_lines.append(f"Error getting columns for {full_table_name}: {e}")
                # If DESCRIBE fails, the table likely doesn't exist or is not accessible,
                # so return early as other queries will also fail.
                return "\n".join(output_lines)

            # 3. Get Sample Rows (LIMIT 3)
            try:
                query = f"SELECT * FROM {full_table_name} LIMIT 3"
                # print(f"Executing: {query}")  # Debug
                cursor.execute(query)
                cols = [desc[0] for desc in cursor.description] if cursor.description else []
                rows = cursor.fetchall()
                output_lines.append("\nSample rows (limit 3):")
                if rows:
                    output_lines.append(_format_rows_to_string(cols, rows, limit=limit))
                else:
                    output_lines.append(f"(No rows found in {full_table_name}).")
            except MySQLError as e:
                output_lines.append(f"Error getting sample rows for {full_table_name}: {e}")

    except MySQLError as outer_e:
        # Catch errors potentially related to cursor creation or initial connection state
        output_lines.append(f"Database error during overview for {full_table_name}: {outer_e}")
        reset_connection()  # Reset connection on error
    except Exception as gen_e:
        output_lines.append(f"Unexpected error during overview for {full_table_name}: {gen_e}")
    finally:
        if cursor:
            try:
                cursor.close()
            except Exception as close_err:
                print(f"Warning: Error closing cursor for {full_table_name}: {close_err}")  # Log non-critical error

    overview_string = "\n".join(output_lines)
    # Update cache even if there were partial errors, so we cache the error message too
    cache_key = (db_name, table_name)
    global_table_overview_cache[cache_key] = overview_string
    return overview_string

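# Illustrative only (not part of the original source): given columns ['id', 'name'] and
# rows [(1, 'a'), (2, 'b')], _format_rows_to_string() returns the CSV-like text
#   id,name
#   1,a
#   2,b
# and, when a `limit` is given, stops appending rows once the buffer would exceed it.
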
def handle_single_column_query(query):
    # return the first column of the result set, one value per line
    cursor = None  # ensure `cursor` exists for the finally block even if get_connection() fails
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        if rows:
            # Assuming the desired column is the first one
            return "\n".join([str(row[0]) for row in rows])
        else:
            return "None"
    except MySQLError as e:  # Catch specific DB errors
        reset_connection()  # Reset connection on DB error
        return f"Error executing query '{query}': {str(e)}"
    except Exception as e:  # Catch other potential errors
        return f"Unexpected error executing query '{query}': {str(e)}"
    finally:
        if cursor:
            cursor.close()


def read_query(query: Annotated[str, Field(description="SQL query to execute")]) -> str:
    # return csv like result set, with column names as first row
    cursor = None  # ensure `cursor` exists for the finally block even if get_connection() fails
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute(query)
        if cursor.description:  # Check if there's a result set description
            columns = [desc[0] for desc in cursor.description]  # Get column names
            rows = cursor.fetchall()
            output = io.StringIO()

            # Convert rows to CSV-like format
            def to_csv_line(row):
                return ",".join(
                    str(item).replace("\"", "\"\"") if isinstance(item, str) else str(item) for item in row)

            output.write(to_csv_line(columns) + "\n")  # Write column names
            for row in rows:
                output.write(to_csv_line(row) + "\n")  # Write data rows
            output.write(f"\n{len(rows)} rows in set\n")
            return output.getvalue()
        else:
            # Handle commands that don't return rows but might have messages (e.g., USE DB),
            # or commands that succeeded but produced no result set.
            # For simplicity, return a message indicating no result set.
            # More sophisticated handling could check cursor.warning_count etc.
            return "Query executed successfully, but no result set was returned."
    except MySQLError as e:  # Catch specific DB errors
        reset_connection()  # Reset connection on DB error
        return f"Error executing query '{query}': {str(e)}"
    except Exception as e:  # Catch other potential errors
        return f"Unexpected error executing query '{query}': {str(e)}"
    finally:
        if cursor:
            cursor.close()

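# Illustrative only (not part of the original source):
#   read_query("SELECT 1 AS one, 'a' AS letter")
# returns a CSV-like string of the shape
#   one,letter
#   1,a
#
#   1 rows in set
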
def write_query(query: Annotated[str, Field(description="SQL to execute")]) -> str:
    cursor = None  # ensure `cursor` exists for the finally block even if get_connection() fails
    try:
        conn = get_connection()
        cursor = conn.cursor()
        start_time = time.time()
        cursor.execute(query)
        conn.commit()  # Commit changes for DML/DDL
        affected_rows = cursor.rowcount
        elapsed_time = time.time() - start_time
        # Provide a more informative message for DDL/DML
        if affected_rows >= 0:  # rowcount is >= 0 for DML, -1 for DDL or not applicable
            return f"Query OK, {affected_rows} rows affected ({elapsed_time:.2f} sec)"
        else:
            return f"Query OK ({elapsed_time:.2f} sec)"  # For DDL or commands where rowcount is not applicable
    except MySQLError as e:  # Catch specific DB errors
        reset_connection()  # Reset connection on DB error
        try:
            conn.rollback()  # Rollback on error
        except Exception as rb_err:
            print(f"Error during rollback: {rb_err}")  # Log rollback error
        return f"Error executing query '{query}': {str(e)}"
    except Exception as e:  # Catch other potential errors
        return f"Unexpected error executing query '{query}': {str(e)}"
    finally:
        if cursor:
            cursor.close()


SR_PROC_DESC = '''
Internal information exposed by StarRocks, similar to Linux /proc. Common paths:
'/frontends'                          Shows the information of FE nodes.
'/backends'                           Shows the information of BE nodes if this SR is a non cloud native deployment.
'/compute_nodes'                      Shows the information of CN nodes if this SR is a cloud native deployment.
'/dbs'                                Shows the information of databases.
'/dbs/<DB_ID>'                        Shows the information of a database by database ID.
'/dbs/<DB_ID>/<TABLE_ID>'             Shows the information of tables by database ID.
'/dbs/<DB_ID>/<TABLE_ID>/partitions'  Shows the information of partitions by database ID and table ID.
'/transactions'                       Shows the information of transactions by database.
'/transactions/<DB_ID>'               Shows the information of transactions by database ID.
'/transactions/<DB_ID>/running'       Shows the information of running transactions by database ID.
'/transactions/<DB_ID>/finished'      Shows the information of finished transactions by database ID.
'/jobs'                               Shows the information of jobs.
'/statistic'                          Shows the statistics of each database.
'/tasks'                              Shows the total number of all generic tasks and the failed tasks.
'/cluster_balance'                    Shows the load balance information.
'/routine_loads'                      Shows the information of Routine Load.
'/colocation_group'                   Shows the information of Colocate Join groups.
'/catalog'                            Shows the information of catalogs.
'''

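# Note (added for clarity, not in the original source): the proc:/// resource declared below
# simply turns a path such as '/frontends' into the statement `show proc '/frontends'` and
# returns its result set via read_query().
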
@mcp.resource(uri="starrocks:///databases", name="All Databases",
              description="List all databases in StarRocks", mime_type="text/plain")
def get_all_databases() -> str:
    return handle_single_column_query("SHOW DATABASES")


@mcp.resource(uri="starrocks:///{db}/{table}/schema", name="Table Schema",
              description="Get the schema of a table using SHOW CREATE TABLE", mime_type="text/plain")
def get_table_schema(db: str, table: str) -> str:
    return handle_single_column_query(f"SHOW CREATE TABLE {db}.{table}")


@mcp.resource(uri="starrocks:///{db}/tables", name="Database Tables",
              description="List all tables in a specific database", mime_type="text/plain")
def get_database_tables(db: str) -> str:
    return handle_single_column_query(f"SHOW TABLES FROM {db}")


@mcp.resource(uri="proc:///{path*}", name="System internal information",
              description=SR_PROC_DESC, mime_type="text/plain")
def get_system_internal_information(path: str) -> str:
    return read_query(f"show proc '{path}'")


def validate_plotly_expr(expr: str):
    """
    Validates a string to ensure it represents a single call to a method of the 'px' object,
    without containing other statements or imports, and ensures its arguments do not contain
    nested function calls.

    Args:
        expr: The string expression to validate.

    Raises:
        ValueError: If the expression does not meet the security criteria.
        SyntaxError: If the expression is not valid Python syntax.
    """
    # 1. Check for valid Python syntax
    try:
        tree = ast.parse(expr)
    except SyntaxError as e:
        raise SyntaxError(f"Invalid Python syntax in expression: {e}") from e

    # 2. Check that the tree contains exactly one top-level node (statement/expression)
    if len(tree.body) != 1:
        raise ValueError("Expression must be a single statement or expression.")

    node = tree.body[0]

    # 3. Check that the single node is an expression
    if not isinstance(node, ast.Expr):
        raise ValueError(
            "Expression must be a single expression, not a statement (like assignment, function definition, import, etc.).")

    # 4. Get the actual value of the expression and check it's a function call
    expr_value = node.value
    if not isinstance(expr_value, ast.Call):
        raise ValueError("Expression must be a function call.")

    # 5. Check that the function being called is an attribute lookup (like px.scatter)
    if not isinstance(expr_value.func, ast.Attribute):
        raise ValueError("Function call must be on an object attribute (e.g., px.scatter).")

    # 6. Check that the attribute is being accessed on a simple variable name
    if not isinstance(expr_value.func.value, ast.Name):
        raise ValueError("Function call must be on a simple variable name (e.g., px.scatter, not obj.px.scatter).")

    # 7. Check that the simple variable name is 'px'
    if expr_value.func.value.id != 'px':
        raise ValueError("Function call must be on the 'px' object.")

    # Check positional arguments
    for i, arg_node in enumerate(expr_value.args):
        for sub_node in ast.walk(arg_node):
            if isinstance(sub_node, ast.Call):
                raise ValueError(f"Positional argument at index {i} contains a disallowed nested function call.")

    # Check keyword arguments
    for kw in expr_value.keywords:
        for sub_node in ast.walk(kw.value):
            if isinstance(sub_node, ast.Call):
                keyword_name = kw.arg if kw.arg else '<unknown>'
                raise ValueError(f"Keyword argument '{keyword_name}' contains a disallowed nested function call.")

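# Illustrative examples of what validate_plotly_expr() accepts and rejects (added for clarity,
# not in the original source):
#   accepted: px.bar(df, x="year", y="revenue")
#   rejected: px.bar(df, x=compute_x(df))    -- nested function call inside a keyword argument
#   rejected: os.system("whoami")            -- the call is not on the 'px' object
#   rejected: fig = px.bar(df)               -- an assignment statement, not a single expression
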
def query_and_plotly_chart(query: Annotated[str, Field(description="SQL query to execute")],
                           plotly_expr: Annotated[str, Field(
                               description="a single function-call expression with 2 variables bound: "
                                           "`px` as `import plotly.express as px`, and `df` as the dataframe "
                                           "generated by `query`. Example: `px.scatter(df, x=\"sepal_width\", "
                                           "y=\"sepal_length\", color=\"species\", marginal_y=\"violin\", "
                                           "marginal_x=\"box\", trendline=\"ols\", template=\"simple_white\")`")]):
    """
    Executes an SQL query, creates a Pandas DataFrame, generates a Plotly chart
    using the provided expression, encodes the chart as a base64 JPEG image,
    and returns it along with descriptive text.

    Args:
        query: The SQL query string to execute.
        plotly_expr: A Python string expression using 'px' (plotly.express)
                     and 'df' (the DataFrame from the query) to generate a figure.
                     Example: "px.scatter(df, x='col1', y='col2')"

    Returns:
        A list containing text content and image content,
        or just text content in case of an error or no data.

    Raises:
        Exception: Propagates exceptions from database interaction, pandas,
                   plotly expression evaluation, or image generation,
                   after attempting to close the cursor.
    """
    cursor = None  # ensure `cursor` exists for the finally block even if get_connection() fails
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute(query)

        # Check if cursor.description is None (happens for non-SELECT queries)
        if cursor.description is None:
            return f'Query "{query}" did not return data suitable for plotting.'

        column_names = [desc[0] for desc in cursor.description] if cursor.description else []
        rows = cursor.fetchall()
        df = pd.DataFrame(rows, columns=column_names)

        if df.empty:
            return 'Query returned no data to plot.'

        # Evaluate the plotly expression using px and df, and get the resulting figure as `fig`.
        # SECURITY WARNING: eval() can execute arbitrary code. Only use this if
        # 'plotly_expr' comes from a trusted source or is heavily sanitized.
        # In a production scenario with untrusted input, consider safer alternatives
        # like AST parsing or a restricted execution environment.
        local_vars = {'df': df}
        validate_plotly_expr(plotly_expr)
        fig = eval(plotly_expr, {"px": px}, local_vars)  # Pass px in globals, df in locals

        if not hasattr(fig, 'to_image'):
            raise ToolError(f"The evaluated expression did not return a Plotly figure object. Result type: {type(fig)}")

        img_bytes = fig.to_image(format='jpg', width=960, height=720)
        # save to tmp file for debugging
        # with open("chart.jpg", "wb") as f:
        #     f.write(img_bytes)
        # base64 encode the image bytes
        img_base64_bytes = base64.b64encode(img_bytes)
        # Decode bytes to utf-8 string for easier handling (e.g., JSON serialization)
        img_base64_string = img_base64_bytes.decode('utf-8')
        return [
            f'dataframe data:\n{df}\nChart generated but for UI only',
            Image(data=img_base64_string, format="jpg")
        ]
    except (MySQLError, pd.errors.EmptyDataError) as db_pd_err:
        # Handle DB or Pandas specific errors gracefully
        return [f'Error during data fetching or processing: {db_pd_err}']
    except Exception as eval_err:
        # Handle errors during eval or image generation
        return [f'Error during chart generation: {eval_err}']
    finally:
        # Ensure the cursor is always closed
        if cursor:
            cursor.close()

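# Illustrative tool call (the table and column names here are hypothetical, not from the
# original source):
#   query_and_plotly_chart(
#       query="SELECT species, AVG(sepal_length) AS avg_len FROM iris GROUP BY species",
#       plotly_expr='px.bar(df, x="species", y="avg_len")',
#   )
# On success this returns [<dataframe text>, Image(...)]; on failure it returns a
# single-element list containing an error message.
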
async def table_overview(
        table: Annotated[str, Field(
            description="Table name, optionally prefixed with database name (e.g., 'db_name.table_name'). If database is omitted, uses the default database.")],
        refresh: Annotated[
            bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
) -> str:
    try:
        conn = get_connection()
        if not table:
            return "Error: Missing 'table' argument."

        # Parse table argument: [db.]<table>
        parts = table.split('.', 1)
        db_name = None
        table_name = None
        if len(parts) == 2:
            db_name, table_name = parts[0], parts[1]
        elif len(parts) == 1:
            table_name = parts[0]
            db_name = default_database  # Use default if only table name is given

        if not table_name:  # Should not happen if 'table' is non-empty, but check anyway
            return f"Error: Invalid table name format '{table}'."

        if not db_name:
            return f"Error: Database name not specified for table '{table_name}' and no default database is set."

        cache_key = (db_name, table_name)

        # Check cache
        if not refresh and cache_key in global_table_overview_cache:
            # print(f"Cache hit for table overview: {cache_key}")  # Debug
            return global_table_overview_cache[cache_key]

        # Fetch details (will also update cache)
        # print(f"Cache miss or refresh for table overview: {cache_key}")  # Debug
        overview_text = _get_table_details(conn, db_name, table_name, limit=overview_length_limit)
        return overview_text

    except MySQLError as e:
        # Catch DB errors at tool call level
        reset_connection()
        return f"Database Error executing tool 'table_overview': {type(e).__name__}: {e}"
    except Exception as e:
        # Catch any other unexpected errors during tool execution
        reset_connection()  # Also reset connection on unexpected errors
        stack_trace = traceback.format_exc()
        return f"Unexpected Error executing tool 'table_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"

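# Illustrative (the table name and counts are hypothetical, not from the original source):
#   await table_overview("sales.orders")
# returns a text overview roughly shaped like
#   --- Overview for `sales`.`orders` ---
#   Total rows: 1234
#   Columns:
#   <DESCRIBE output as CSV-like lines>
#   Sample rows (limit 3):
#   <up to 3 rows as CSV-like lines>
# and caches the result under the key ('sales', 'orders'); refresh=True bypasses the cache.
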
async def db_overview(
        db: Annotated[str, Field(
            description="Database name. Optional: uses the default database if not provided.")] = default_database,
        refresh: Annotated[
            bool, Field(description="Set to true to force refresh, ignoring cache. Defaults to false.")] = False
) -> str:
    try:
        conn = get_connection()
        db_name = db if db else default_database
        if not db_name:
            return "Error: Database name not provided and no default database is set."

        # List tables in the database
        cursor = None
        try:
            cursor = conn.cursor()
            query = f"SHOW TABLES FROM `{db_name}`"  # Use backticks
            # print(f"Executing: {query}")  # Debug
            cursor.execute(query)
            tables = [row[0] for row in cursor.fetchall()]
        except MySQLError as e:
            print(f"Error listing tables in '{db_name}': {e}")
            reset_connection()
            return f"Database Error listing tables in '{db_name}': {e}"
        except Exception as e:
            print(f"Unexpected error listing tables in '{db_name}': {e}")
            return f"Unexpected error listing tables in '{db_name}': {e}"
        finally:
            if cursor:
                try:
                    cursor.close()
                except Exception as ce:
                    print(f"Warning: error closing cursor: {ce}")

        if not tables:
            return f"No tables found in database '{db_name}'."

        all_overviews = [f"--- Overview for Database: `{db_name}` ({len(tables)} tables) ---"]
        # print(f"Generating overview for {len(tables)} tables in '{db_name}' (refresh={refresh})")  # Debug
        total_length = 0
        limit_per_table = overview_length_limit * (math.log10(len(tables)) + 1) // len(tables)  # Limit per table
        for table_name in tables:
            cache_key = (db_name, table_name)
            overview_text = None
            # Check cache first
            if not refresh and cache_key in global_table_overview_cache:
                # print(f"Cache hit for db overview (table): {cache_key}")  # Debug
                overview_text = global_table_overview_cache[cache_key]
            else:
                # print(f"Cache miss or refresh for db overview (table): {cache_key}")  # Debug
                # Fetch details for this table (will update cache via _get_table_details)
                overview_text = _get_table_details(conn, db_name, table_name, limit=limit_per_table)
            all_overviews.append(overview_text)
            all_overviews.append("\n")  # Add separator
            total_length += len(overview_text) + 1

        return "\n".join(all_overviews)

    except MySQLError as e:
        # Catch DB errors at tool call level
        reset_connection()
        return f"Database Error executing tool 'db_overview': {type(e).__name__}: {e}"
    except Exception as e:
        # Catch any other unexpected errors during tool execution
        reset_connection()  # Also reset connection on unexpected errors
        stack_trace = traceback.format_exc()
        return f"Unexpected Error executing tool 'db_overview': {type(e).__name__}: {e}\nStack Trace:\n{stack_trace}"

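# Worked example of the per-table length budget above (added for clarity, not in the original
# source): with the default overview_length_limit of 20000 and a database of 100 tables,
# limit_per_table = 20000 * (log10(100) + 1) // 100 = 20000 * 3 // 100 = 600 characters,
# so larger databases get shorter per-table overviews while small ones keep more detail.
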
async def main():
    global default_database
    db_suffix = f". db session already in default db `{default_database}`" if default_database else ""
    mcp.add_tool(read_query,
                 description="Execute a SELECT query or commands that return a ResultSet" + db_suffix)
    mcp.add_tool(write_query,
                 description="Execute a DDL/DML or other StarRocks command that does not return a ResultSet" + db_suffix)
    mcp.add_tool(query_and_plotly_chart,
                 description="Use SQL `query` to extract data from the database, then use the Python `plotly_expr` to generate a chart for the UI to display" + db_suffix)
    mcp.add_tool(table_overview,
                 description="Get an overview of a specific table: columns, sample rows (up to 3), and total row count. Uses cache unless refresh=true" + db_suffix)
    mcp.add_tool(db_overview,
                 description="Get an overview (columns, sample rows, row count) for ALL tables in a database. Uses cache unless refresh=true" + db_suffix)
    await mcp.run_async(transport=mcp_transport)


async def run_tool_test():
    result_table = await table_overview("quickstart.crashdata")
    print("Result:")
    print(result_table)
    print("-" * 20)


if __name__ == "__main__":
    # Example usage (requires environment variables set)
    print(f"Default database (STARROCKS_DB): {default_database or 'Not Set'}")
    if len(sys.argv) > 1 and sys.argv[1] == '--test':
        # Run the test function
        try:
            asyncio.run(run_tool_test())
        except Exception as test_err:
            print(f"\nError running test function: {test_err}")
        finally:
            reset_connection()  # Ensure cleanup even if run_tool_test fails badly
    else:
        asyncio.run(main())
        print("MCP Server script loaded. Run via MCP host.")
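
# Illustrative ways to run this script (the filename is hypothetical; assumes the environment
# variables listed near the top are set):
#   python server.py          -- start the MCP server on the transport from MCP_TRANSPORT_MODE
#   python server.py --test   -- run run_tool_test() against the quickstart.crashdata table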
