Skip to main content
Glama
startreedata

StarTree MCP Server for Apache Pinot

Official
by startreedata
pinot_client.py (27.4 kB)
import base64
from fnmatch import fnmatch
import json
import re
from threading import Lock
from typing import Any, Dict, Tuple

import pandas as pd
from pinotdb import connect
import requests

from .config import PinotConfig, get_logger, reload_table_filters_from_file

logger = get_logger()


def get_auth_credentials(config: PinotConfig) -> Tuple[str | None, str | None]:
    """Extract authentication credentials for PinotDB connection.

    Returns:
        Tuple of (username, password): an empty username with the configured
        token as password when a token is set, the explicit username/password
        pair when both are set, otherwise (None, None).
    """
    if config.token:
        # NOTE: the original code branched on `config.token.startswith("Bearer ")`
        # but both branches returned the identical value; collapsed into one.
        return "", config.token  # Empty username, full token as password
    elif config.username and config.password:
        return config.username, config.password
    return None, None


def test_connection_query(connection) -> None:
    """Test connection with a simple query"""
    test_cursor = connection.cursor()
    test_cursor.execute("SELECT 1")
    test_result = test_cursor.fetchall()
    logger.debug(f"Connection test successful: {test_result}")


# URL pattern constants
class PinotEndpoints:
    QUERY_SQL = "query/sql"
    TABLES = "tables"
    SCHEMAS = "schemas"
    TABLE_SIZE = "tables/{}/size"
    SEGMENTS = "segments/{}"
    SEGMENT_METADATA = "segments/{}/metadata"
    SEGMENT_DETAIL = "segments/{}_{}/{}/metadata?columns=*"
    TABLE_CONFIG = "tableConfigs/{}"


def create_connection(config: PinotConfig) -> connect:
    """Create Pinot connection with proper authentication handling.

    Establishes a PinotDB broker connection, verifies it with a trivial
    query, and returns it.

    Raises:
        Exception: re-raised from PinotDB/network failures after logging
            the connection details for diagnosis.
    """
    try:
        auth_username, auth_password = get_auth_credentials(config)
        logger.debug(
            f"Creating connection to {config.broker_host}:{config.broker_port} "
            f"with MSQE={config.use_msqe}"
        )
        auth_method = "token" if config.token else "username/password"
        logger.debug(f"Database: {config.database}, Auth method: {auth_method}")
        connection = connect(
            host=config.broker_host,
            port=config.broker_port,
            path="/query/sql",
            scheme=config.broker_scheme,
            username=auth_username,
            password=auth_password,
            use_multistage_engine=config.use_msqe,
            database=config.database,
            extra_conn_args={
                "timeout": config.query_timeout,
                "verify": True,
                "retries": 3,
                "backoff_factor": 1.0,
            },
        )
        # Fail fast: verify the connection actually works before handing it out.
        test_connection_query(connection)
        return connection
    except Exception as e:
        logger.error(f"Failed to create Pinot connection: {e}")
        logger.error(
            f"Connection details - Host: {config.broker_host}, "
            f"Port: {config.broker_port}, Scheme: {config.broker_scheme}"
        )
        raise


class PinotClient:
    """Client for Apache Pinot combining HTTP (controller/broker) access and
    PinotDB connections, with optional pattern-based table access filtering."""

    def __init__(self, config: PinotConfig):
        self.config = config
        self.insights: list[str] = []
        self._conn = None  # lazily-created, reused PinotDB connection
        # Store filters separately to avoid mutating config
        self._included_tables = config.included_tables
        self._config_lock = Lock()  # For thread-safe filter updates

    def reload_table_filters(self) -> dict[str, Any]:
        """Reload table filters from the configured filter file without restarting.

        This allows dynamic updates to the table access list by:
        1. Editing the YAML filter file
        2. Calling this method to reload the configuration

        Returns:
            dict: Status information with previous and new filter counts

        Raises:
            ValueError: If no table filter file is configured
            FileNotFoundError: If the filter file doesn't exist
            yaml.YAMLError: If the file contains invalid YAML
        """
        if not self.config.table_filter_file:
            raise ValueError(
                "No table filter file configured. "
                "Set PINOT_TABLE_FILTER_FILE to enable hot-reload."
            )
        logger.info(f"Reloading table filters from {self.config.table_filter_file}")
        # Load new filters (validates file exists and parses YAML)
        new_filters = reload_table_filters_from_file(self.config.table_filter_file)
        # Atomically update the filters with lock
        with self._config_lock:
            old_filters = self._included_tables
            old_count = len(old_filters) if old_filters else 0
            self._included_tables = new_filters
            new_count = len(new_filters) if new_filters else 0
        logger.info(f"Table filters reloaded: {old_count} -> {new_count} tables")
        return {
            "status": "success",
            "message": "Table filters reloaded successfully",
            "previous_filter_count": old_count,
            "new_filter_count": new_count,
            "previous_filters": old_filters,
            "new_filters": new_filters,
        }

    def _create_auth_headers(self) -> Dict[str, str]:
        """Create HTTP headers with authentication based on configuration"""
        headers = {"accept": "application/json", "Content-Type": "application/json"}
        if self.config.token:
            headers["Authorization"] = self.config.token
        elif self.config.username and self.config.password:
            creds_str = f"{self.config.username}:{self.config.password}"
            credentials = base64.b64encode(creds_str.encode()).decode()
            headers["Authorization"] = f"Basic {credentials}"
        if self.config.database:
            headers["database"] = self.config.database
        return headers

    def http_request(
        self,
        url: str,
        method: str = "GET",
        json_data: Dict | None = None,  # was `Dict = None`; annotation fixed
    ) -> requests.Response:
        """Make HTTP request with authentication headers and timeout handling"""
        headers = self._create_auth_headers()
        try:
            if method.upper() == "POST":
                response = requests.post(
                    url,
                    headers=headers,
                    json=json_data,
                    timeout=(
                        self.config.connection_timeout,
                        self.config.request_timeout,
                    ),
                    verify=True,
                )
            else:
                response = requests.get(
                    url,
                    headers=headers,
                    timeout=(
                        self.config.connection_timeout,
                        self.config.request_timeout,
                    ),
                    verify=True,
                )
            response.raise_for_status()
            return response
        except requests.exceptions.Timeout:
            logger.error(f"HTTP request timeout for {url}")
            raise
        except Exception as e:
            logger.error(f"HTTP request failed for {url}: {e}")
            raise

    def get_connection(self):
        """Get or create a reusable connection.

        Tests the cached connection with a trivial query and transparently
        reconnects if the test fails.
        """
        try:
            if self._conn is None:
                self._conn = create_connection(self.config)
            else:
                # Test if connection is still alive
                test_connection_query(self._conn)
            return self._conn
        except Exception as e:
            logger.warning(f"Connection test failed, creating new connection: {e}")
            self._conn = create_connection(self.config)
            return self._conn

    def test_connection(self) -> dict[str, Any]:
        """Test the connection and return diagnostic information"""
        result = {
            "connection_test": False,
            "query_test": False,
            "tables_test": False,
            "error": None,
            "config": {
                "broker_host": self.config.broker_host,
                "broker_port": self.config.broker_port,
                "broker_scheme": self.config.broker_scheme,
                "controller_url": self.config.controller_url,
                "database": self.config.database,
                "use_msqe": self.config.use_msqe,
                "has_token": bool(self.config.token),
                "has_username": bool(self.config.username),
                "timeout_config": {
                    "connection": self.config.connection_timeout,
                    "request": self.config.request_timeout,
                    "query": self.config.query_timeout,
                },
            },
        }
        try:
            # Test basic connection
            conn = self.get_connection()
            result["connection_test"] = True
            # Test simple query
            curs = conn.cursor()
            curs.execute("SELECT 1 as test_column")
            test_result = curs.fetchall()
            result["query_test"] = True
            result["query_result"] = test_result
            # Test tables listing
            tables = self.get_tables()
            result["tables_test"] = True
            result["tables_count"] = len(tables)
            result["sample_tables"] = tables[:5] if tables else []
        except Exception as e:
            result["error"] = str(e)
            logger.error(f"Connection test failed: {e}")
        return result

    def execute_query_http(self, query: str) -> list[dict[str, Any]]:
        """Alternative query execution using HTTP requests directly to broker"""
        broker_url = f"{self.config.broker_scheme}://{self.config.broker_host}:{self.config.broker_port}/{PinotEndpoints.QUERY_SQL}"
        logger.debug(f"Executing query via HTTP: {query[:100]}...")
        payload = {
            "sql": query,
            "queryOptions": f"timeoutMs={self.config.query_timeout * 1000}",
        }
        response = self.http_request(broker_url, "POST", payload)
        result_data = response.json()
        # Check for query errors in response
        if "exceptions" in result_data and result_data["exceptions"]:
            raise Exception(f"Query error: {result_data['exceptions']}")
        # Parse the result into pandas-like format
        if "resultTable" in result_data:
            columns = result_data["resultTable"]["dataSchema"]["columnNames"]
            rows = result_data["resultTable"]["rows"]
            # Convert to list of dictionaries
            result = [dict(zip(columns, row)) for row in rows]
            logger.debug(f"HTTP query returned {len(result)} rows")
            return result
        else:
            logger.warning("No resultTable in response, returning empty result")
            return []

    def execute_query(
        self,
        query: str,
        params: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Execute a SQL query, preferring HTTP with a PinotDB fallback.

        Raises:
            ValueError: If the query references tables outside the filter.
            Exception: If both HTTP and PinotDB execution fail.
        """
        logger.debug(f"Executing query: {query[:100]}...")  # Log first 100 chars
        # Validate table access authorization
        self._validate_table_access(query)
        # Use HTTP as primary method since it works reliably with authenticated clusters
        try:
            return self.execute_query_http(query)
        except Exception as e:
            logger.warning(f"HTTP query failed: {e}, trying PinotDB fallback")
            try:
                return self.execute_query_pinotdb(query, params)
            except Exception as pinotdb_error:
                error_msg = (
                    f"Both HTTP and PinotDB queries failed. "
                    f"HTTP: {e}, PinotDB: {pinotdb_error}"
                )
                logger.error(error_msg)
                raise

    def preprocess_query(self, query: str) -> str:
        """Preprocess query by removing database prefix and adding timeout options"""
        # Remove database prefix if present
        if self.config.database and f"{self.config.database}." in query:
            query = query.replace(f"{self.config.database}.", "")
            logger.debug(f"Removed database prefix, query now: {query[:100]}...")
        # Add query timeout hint if not present.
        # BUG FIX: the original tested `"SET timeoutMs" not in query.upper()`,
        # which can never match (mixed-case literal vs. upper-cased query), so
        # queries that already set a timeout were double-decorated. Compare an
        # upper-case literal instead.
        if "TIMEOUTMS" not in query.upper() and "OPTION" not in query.upper():
            timeout_ms = self.config.query_timeout * 1000  # Convert to milliseconds
            if query.strip().endswith(";"):
                query = query.rstrip(";")
            query = f"{query} OPTION(timeoutMs={timeout_ms})"
            logger.debug(f"Added timeout option: {timeout_ms}ms")
        return query

    def execute_query_pinotdb(
        self,
        query: str,
        params: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Original pinotdb-based query execution"""
        logger.debug(f"Executing query via PinotDB: {query[:100]}...")
        try:
            current_conn = self.get_connection()
            curs = current_conn.cursor()
            query = self.preprocess_query(query)
            logger.debug(f"Final query: {query}")
            curs.execute(query)
            # Get column names and fetch results
            columns = [item[0] for item in curs.description] if curs.description else []
            df = pd.DataFrame(curs.fetchall(), columns=columns)
            result = df.to_dict(orient="records")
            logger.debug(f"Query executed successfully, returned {len(result)} rows")
            return result
        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            logger.error(f"Query was: {query}")
            # Reset connection on error
            self._conn = None
            raise

    def _matches_patterns(self, table: str, patterns: list[str]) -> bool:
        """Check if table matches any pattern."""
        return any(fnmatch(table, pattern) for pattern in patterns)

    def _is_table_filtering_enabled(self) -> bool:
        """Check if table filtering is configured and enabled.

        Returns:
            bool: True if filtering is enabled (included_tables is configured),
                False otherwise (None, empty list, or any falsy value)
        """
        return bool(self._included_tables)

    def _extract_sql_table_names(self, query: str) -> list[str]:
        """Extract table names from a SQL query.

        Handles table references in FROM, JOIN, and subquery clauses.
        Supports quoted identifiers (double quotes, backticks).

        Args:
            query: SQL query string

        Returns:
            list[str]: Unique list of table names found in the query
        """
        # Remove comments and normalize whitespace
        query = re.sub(r"--.*?$", "", query, flags=re.MULTILINE)
        query = re.sub(r"/\*.*?\*/", "", query, flags=re.DOTALL)
        query = " ".join(query.split())
        matches = []
        # Pattern 1: Unquoted tables (after FROM/JOIN or comma-separated)
        # Matches: FROM table, JOIN table, table1, table2
        # Uses negative lookahead to exclude SQL keywords (LEFT, RIGHT, INNER, etc.)
        unquoted_pattern = (
            r"(?:\b(?:FROM|JOIN)\s+|,\s*)"
            r"(?:[\w.]+\.)?"
            r"(?!(?:LEFT|RIGHT|INNER|OUTER|FULL|CROSS|ON|WHERE|GROUP|ORDER|"
            r"HAVING|LIMIT)\b)"
            r"(\w+)"
        )
        matches.extend(re.findall(unquoted_pattern, query, re.IGNORECASE))
        # Pattern 2: Double-quoted tables (after FROM/JOIN or comma-separated)
        # Matches: FROM "table name", "quoted_table", "another table"
        double_quoted_pattern = r'(?:\b(?:FROM|JOIN)\s+|,\s*)(?:[\w.]+\.)?"([^"]+)"'
        matches.extend(re.findall(double_quoted_pattern, query, re.IGNORECASE))
        # Pattern 3: Backtick-quoted tables (after FROM/JOIN or comma-separated)
        # Matches: FROM `table_name`, `quoted table`, `another table`
        backtick_pattern = r"(?:\b(?:FROM|JOIN)\s+|,\s*)(?:[\w.]+\.)?`([^`]+)`"
        matches.extend(re.findall(backtick_pattern, query, re.IGNORECASE))
        return list(set(matches))

    def _validate_table_name_access(self, table_name: str) -> None:
        """Validate that a table name is allowed by filtering rules.

        Args:
            table_name: Table name to validate

        Raises:
            ValueError: If table is not in included_tables filter
        """
        if not self._is_table_filtering_enabled():
            return
        if not self._matches_patterns(table_name, self._included_tables):
            allowed = ", ".join(self._included_tables)
            raise ValueError(
                f"Access denied to table '{table_name}'. Allowed tables: {allowed}"
            )

    def _extract_and_validate_name_from_json(self, json_str: str, key: str) -> None:
        """Extract and validate table/schema name from JSON.

        Args:
            json_str: JSON string containing table or schema name
            key: JSON key to extract ("tableName" or "schemaName")

        Raises:
            ValueError: If name extraction fails or access is denied
        """
        try:
            data = json.loads(json_str)
            name = data.get(key)
            if not name:
                raise ValueError(f"Missing required field '{key}' in JSON")
            self._validate_table_name_access(name)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}")

    def _validate_table_access(self, query: str) -> None:
        """Validate that query only accesses allowed tables.

        Args:
            query: SQL query string to validate

        Raises:
            ValueError: If query references tables not in included_tables filter
        """
        if not self._is_table_filtering_enabled():
            return
        table_names = self._extract_sql_table_names(query)
        if not table_names:
            return
        unauthorized_tables = [
            table
            for table in table_names
            if not self._matches_patterns(table, self._included_tables)
        ]
        if unauthorized_tables:
            allowed = ", ".join(self._included_tables)
            unauthorized = ", ".join(unauthorized_tables)
            raise ValueError(
                f"Query references unauthorized tables: {unauthorized}. "
                f"Allowed tables: {allowed}"
            )

    def _filter_tables(self, tables: list[str]) -> list[str]:
        """Filter tables based on included_tables configuration."""
        if not tables or not self._is_table_filtering_enabled():
            return tables
        return [t for t in tables if self._matches_patterns(t, self._included_tables)]

    def get_tables(self, params: dict[str, Any] | None = None) -> list[str]:
        """List tables from the controller, filtered by the access list."""
        url = f"{self.config.controller_url}/{PinotEndpoints.TABLES}"
        logger.debug(f"Fetching tables from: {url}")
        response = self.http_request(url)
        tables = response.json()["tables"]
        logger.debug(f"Successfully fetched {len(tables)} tables")
        return self._filter_tables(tables)

    def get_table_detail(
        self,
        tableName: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Fetch size details for a table from the controller."""
        self._validate_table_name_access(tableName)
        endpoint = PinotEndpoints.TABLE_SIZE.format(tableName)
        url = f"{self.config.controller_url}/{endpoint}"
        logger.debug(f"Fetching table details for {tableName} from: {url}")
        response = self.http_request(url)
        return response.json()

    def get_segment_metadata_detail(
        self,
        tableName: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Fetch segment metadata for a table from the controller."""
        self._validate_table_name_access(tableName)
        endpoint = PinotEndpoints.SEGMENT_METADATA.format(tableName)
        url = f"{self.config.controller_url}/{endpoint}"
        logger.debug(f"Fetching segment metadata for {tableName} from: {url}")
        response = self.http_request(url)
        return response.json()

    def get_segments(
        self,
        tableName: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Fetch the segment list for a table from the controller."""
        self._validate_table_name_access(tableName)
        endpoint = PinotEndpoints.SEGMENTS.format(tableName)
        url = f"{self.config.controller_url}/{endpoint}"
        logger.debug(f"Fetching segments for {tableName} from: {url}")
        response = self.http_request(url)
        return response.json()

    def get_index_column_detail(
        self,
        tableName: str,
        segmentName: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Fetch per-column index details for a segment.

        Tries both the REALTIME and OFFLINE table-type suffixes, returning
        the first successful response.

        Raises:
            ValueError: If neither table type yields the segment detail.
        """
        self._validate_table_name_access(tableName)
        for type_suffix in ["REALTIME", "OFFLINE"]:
            endpoint = PinotEndpoints.SEGMENT_DETAIL.format(
                tableName, type_suffix, segmentName
            )
            url = f"{self.config.controller_url}/{endpoint}"
            logger.debug(f"Trying to fetch index column details from: {url}")
            try:
                response = self.http_request(url)
                return response.json()
            except Exception as e:
                error_msg = (
                    f"Failed to fetch index column details for "
                    f"{tableName}_{type_suffix}/{segmentName}: {e}"
                )
                logger.error(error_msg)
                continue
        raise ValueError("Index column detail not found")

    def get_tableconfig_schema_detail(
        self,
        tableName: str,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Fetch the combined table config / schema detail for a table."""
        self._validate_table_name_access(tableName)
        endpoint = PinotEndpoints.TABLE_CONFIG.format(tableName)
        url = f"{self.config.controller_url}/{endpoint}"
        logger.debug(f"Fetching table config for {tableName} from: {url}")
        response = self.http_request(url)
        return response.json()

    def create_schema(
        self,
        schemaJson: str,
        override: bool = True,
        force: bool = False,
    ) -> dict[str, Any]:
        """Create a schema on the controller from a raw JSON document."""
        self._extract_and_validate_name_from_json(schemaJson, "schemaName")
        url = f"{self.config.controller_url}/{PinotEndpoints.SCHEMAS}"
        params = {"override": str(override).lower(), "force": str(force).lower()}
        headers = self._create_auth_headers()
        headers["Content-Type"] = "application/json"
        response = requests.post(
            url,
            headers=headers,
            params=params,
            data=schemaJson,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        try:
            return response.json()
        except requests.exceptions.JSONDecodeError:
            # Controller may reply with a non-JSON body on success.
            return {
                "status": "success",
                "message": "Schema creation request processed.",
                "response_body": response.text,
            }

    def update_schema(
        self,
        schemaName: str,
        schemaJson: str,
        reload: bool = False,
        force: bool = False,
    ) -> dict[str, Any]:
        """Update an existing schema on the controller."""
        self._validate_table_name_access(schemaName)
        url = f"{self.config.controller_url}/{PinotEndpoints.SCHEMAS}/{schemaName}"
        params = {"reload": str(reload).lower(), "force": str(force).lower()}
        headers = self._create_auth_headers()
        headers["Content-Type"] = "application/json"
        response = requests.put(
            url,
            headers=headers,
            params=params,
            data=schemaJson,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        try:
            return response.json()
        except requests.exceptions.JSONDecodeError:
            return {
                "status": "success",
                "message": "Schema update request processed.",
                "response_body": response.text,
            }

    def get_schema(self, schemaName: str) -> dict[str, Any]:
        """Fetch a schema definition from the controller."""
        self._validate_table_name_access(schemaName)
        url = f"{self.config.controller_url}/{PinotEndpoints.SCHEMAS}/{schemaName}"
        headers = self._create_auth_headers()
        response = requests.get(
            url,
            headers=headers,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        return response.json()

    def create_table_config(
        self,
        tableConfigJson: str,
        validationTypesToSkip: str | None = None,
    ) -> dict[str, Any]:
        """Create a table config on the controller from a raw JSON document."""
        self._extract_and_validate_name_from_json(tableConfigJson, "tableName")
        url = f"{self.config.controller_url}/{PinotEndpoints.TABLES}"
        params: dict[str, str] = {}
        if validationTypesToSkip:
            params["validationTypesToSkip"] = validationTypesToSkip
        headers = self._create_auth_headers()
        headers["Content-Type"] = "application/json"
        response = requests.post(
            url,
            headers=headers,
            params=params,
            data=tableConfigJson,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        try:
            return response.json()
        except requests.exceptions.JSONDecodeError:
            return {
                "status": "success",
                "message": "Table config creation request processed.",
                "response_body": response.text,
            }

    def update_table_config(
        self,
        tableName: str,
        tableConfigJson: str,
        validationTypesToSkip: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing table config on the controller."""
        self._validate_table_name_access(tableName)
        url = f"{self.config.controller_url}/{PinotEndpoints.TABLES}/{tableName}"
        params: dict[str, str] = {}
        if validationTypesToSkip:
            params["validationTypesToSkip"] = validationTypesToSkip
        headers = self._create_auth_headers()
        headers["Content-Type"] = "application/json"
        response = requests.put(
            url,
            headers=headers,
            params=params,
            data=tableConfigJson,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        try:
            return response.json()
        except requests.exceptions.JSONDecodeError:
            return {
                "status": "success",
                "message": "Table config update request processed.",
                "response_body": response.text,
            }

    def get_table_config(
        self,
        tableName: str,
        tableType: str | None = None,
    ) -> dict[str, Any]:
        """Fetch a table config, optionally narrowed to OFFLINE or REALTIME."""
        self._validate_table_name_access(tableName)
        url = f"{self.config.controller_url}/{PinotEndpoints.TABLES}/{tableName}"
        params: dict[str, str] = {}
        if tableType:
            params["type"] = tableType
        headers = self._create_auth_headers()
        response = requests.get(
            url,
            headers=headers,
            params=params,
            timeout=(self.config.connection_timeout, self.config.request_timeout),
            verify=True,
        )
        response.raise_for_status()
        raw_response = response.json()
        # Unwrap the requested type's config when present; otherwise return
        # the raw (possibly both-type) payload unchanged.
        if tableType and tableType.upper() in raw_response:
            return raw_response[tableType.upper()]
        if not tableType and ("OFFLINE" in raw_response or "REALTIME" in raw_response):
            return raw_response
        return raw_response

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/startreedata/mcp-pinot'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.