Skip to main content
Glama
llm_opt.py (17 kB)
import logging
import math
from dataclasses import dataclass
from typing import Any
from typing import override

import instructor
from openai import OpenAI
from pglast.ast import SelectStmt
from pydantic import BaseModel

from postgres_mcp.artifacts import ErrorResult
from postgres_mcp.explain.explain_plan import ExplainPlanTool
from postgres_mcp.sql import TableAliasVisitor

from ..sql import IndexDefinition
from ..sql import SqlDriver
from .index_opt_base import IndexRecommendation
from .index_opt_base import IndexTuningBase

logger = logging.getLogger(__name__)


# We introduce a Pydantic index class to facilitate communication with the LLM
# via the instructor library.
class Index(BaseModel):
    """A (table, columns) index candidate exchanged with the LLM via instructor.

    Hashable/equatable so candidates can be stored in sets and deduplicated.
    """

    table_name: str
    columns: tuple[str, ...]

    def __hash__(self):
        return hash((self.table_name, self.columns))

    def __eq__(self, other):
        if not isinstance(other, Index):
            return False
        return self.table_name == other.table_name and self.columns == other.columns

    def to_index_recommendation(self) -> IndexRecommendation:
        """Convert to the IndexRecommendation type returned by the tuning API."""
        return IndexRecommendation(table=self.table_name, columns=self.columns)

    def to_index_definition(self) -> IndexDefinition:
        """Convert to an IndexDefinition usable for hypopg evaluation."""
        return IndexDefinition(table=self.table_name, columns=self.columns)


class IndexingAlternative(BaseModel):
    """Structured LLM response: each element is one alternative set of indexes to evaluate."""

    alternatives: list[set[Index]]


@dataclass
class ScoredIndexes:
    """An evaluated index configuration with its estimated cost, size, and combined score."""

    indexes: set[Index]
    execution_cost: float
    index_size: float
    objective_score: float


class LLMOptimizerTool(IndexTuningBase):
    """Index tuning strategy that iteratively asks an LLM for candidate index sets.

    Each candidate set is scored via hypopg (execution cost + index size); the loop
    stops after ``max_no_progress_attempts`` consecutive iterations without improvement.
    """

    def __init__(
        self,
        sql_driver: SqlDriver,
        max_no_progress_attempts: int = 5,
        pareto_alpha: float = 2.0,
    ):
        """Initialize the optimizer.

        Args:
            sql_driver: Driver used to run evaluation queries against Postgres.
            max_no_progress_attempts: Consecutive non-improving LLM rounds allowed
                before giving up.
            pareto_alpha: Weight of index size relative to execution cost in the
                objective score (higher penalizes size more).
        """
        super().__init__(sql_driver)
        self.sql_driver = sql_driver
        self.max_no_progress_attempts = max_no_progress_attempts
        self.pareto_alpha = pareto_alpha
        logger.info("Initialized LLMOptimizerTool with max_no_progress_attempts=%d", max_no_progress_attempts)

    def score(self, execution_cost: float, index_size: float) -> float:
        """Combined objective: log(cost) + alpha * log(size); lower is better.

        Inputs are clamped to >= 1.0 so a zero cost or zero total size (e.g. empty
        tables) cannot raise a math domain error.
        """
        return math.log(max(execution_cost, 1.0)) + self.pareto_alpha * math.log(max(index_size, 1.0))

    @override
    async def _generate_recommendations(
        self, query_weights: list[tuple[str, SelectStmt, float]]
    ) -> tuple[set[IndexRecommendation], float]:
        """Generate index tuning queries using optimization by LLM.

        Args:
            query_weights: (sql_text, parsed_statement, weight) triples; currently
                exactly one entry is supported.

        Returns:
            A tuple of (best index recommendations, estimated execution cost).

        Raises:
            ValueError: If more than one query is supplied or the explain plan fails.
        """
        # For now we support only one table at a time
        if len(query_weights) > 1:
            logger.error("LLM optimization currently supports only one query at a time")
            raise ValueError("Optimization by LLM supports only one query at a time.")

        query = query_weights[0][0]
        parsed_query = query_weights[0][1]
        logger.info("Generating index recommendations for query: %s", query)

        # Extract tables from the parsed query
        table_visitor = TableAliasVisitor()
        table_visitor(parsed_query)
        tables = table_visitor.tables
        logger.info("Extracted tables from query: %s", tables)

        # Get the size of the tables
        table_sizes = {}
        for table in tables:
            table_sizes[table] = await self._get_table_size(table)
        total_table_size = sum(table_sizes.values())
        logger.info("Total table size: %s", total_table_size)

        # Generate explain plan for the query
        explain_tool = ExplainPlanTool(self.sql_driver)
        explain_result = await explain_tool.explain(query)
        if isinstance(explain_result, ErrorResult):
            logger.error("Failed to generate explain plan: %s", explain_result.to_text())
            raise ValueError(f"Failed to generate explain plan: {explain_result.to_text()}")

        # Get the explain plan JSON
        explain_plan_json = explain_result.value
        logger.debug("Generated explain plan: %s", explain_plan_json)

        # Extract indexes used in the explain plan
        indexes_used: set[Index] = await self._extract_indexes_from_explain_plan_with_columns(explain_plan_json)

        # Get the current cost
        original_cost = await self._evaluate_configuration_cost(query_weights, frozenset())
        logger.info("Original query cost: %f", original_cost)

        original_config = ScoredIndexes(
            indexes=indexes_used,
            execution_cost=original_cost,
            index_size=total_table_size,
            objective_score=self.score(original_cost, total_table_size),
        )
        best_config = original_config

        # Initialize attempt history for this run
        attempt_history: list[ScoredIndexes] = [original_config]
        no_progress_count = 0

        client = instructor.from_openai(OpenAI())

        # Starting cost
        # TODO should include the size of the starting indexes
        score = self.score(original_cost, total_table_size)
        logger.info("Starting score: %f", score)

        while no_progress_count < self.max_no_progress_attempts:
            logger.info("Requesting index recommendations from LLM")

            # Build history of past attempts so the LLM avoids repeating them
            history_prompt = ""
            if attempt_history:
                history_prompt = "\nPrevious attempts and their costs:\n"
                for attempt in attempt_history:
                    indexes_str = ";".join(idx.to_index_definition().definition for idx in attempt.indexes)
                    history_prompt += f"- Indexes: {indexes_str}, Cost: {attempt.execution_cost}, Index Size: {attempt.index_size}, "
                    history_prompt += f"Objective Score: {attempt.objective_score}\n"

            if no_progress_count > 0:
                remaining_attempts_prompt = f"You have made {no_progress_count} attempts without progress. "
                # Past the halfway mark, nudge the LLM toward less obvious suggestions
                if self.max_no_progress_attempts - no_progress_count < self.max_no_progress_attempts / 2:
                    remaining_attempts_prompt += "Get creative and suggest indexes that are not obvious."
            else:
                remaining_attempts_prompt = ""

            response = client.chat.completions.create(
                model="gpt-4o",
                response_model=IndexingAlternative,
                temperature=1.2,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that generates index recommendations for a given workload."},
                    {
                        "role": "user",
                        "content": f"Here is the query we are optimizing: {query}\n"
                        f"Here is the explain plan: {explain_plan_json}\n"
                        f"Here are the existing indexes: {';'.join(idx.to_index_definition().definition for idx in indexes_used)}\n"
                        f"{history_prompt}\n"
                        "Each indexing suggestion that you provide is a combination of indexes. You can provide multiple alternative suggestions. "
                        "We will evaluate each alternative using hypopg to see how the optimizer will behave with those indexes in place. "
                        "The overall score is based on a combination of execution cost and index size. In all cases, lower is better. "
                        "Prefer fewer indexes to more indexes. Prefer indexes with fewer columns to indexes with more columns. "
                        f"{remaining_attempts_prompt}",
                    },
                ],
            )

            # Convert the response to IndexConfig objects
            index_alternatives: list[set[Index]] = response.alternatives
            logger.info("Received %d alternative index configurations from LLM", len(index_alternatives))

            # If no alternatives were generated, break the loop
            if not index_alternatives:
                logger.warning("No index alternatives were generated by the LLM")
                break

            # Try each alternative
            found_improvement = False
            for i, index_set in enumerate(index_alternatives):
                try:
                    logger.info("Evaluating alternative %d/%d with %d indexes", i + 1, len(index_alternatives), len(index_set))

                    # Evaluate this index configuration
                    execution_cost_estimate = await self._evaluate_configuration_cost(
                        query_weights, frozenset({index.to_index_definition() for index in index_set})
                    )
                    logger.info(
                        "Alternative %d cost: %f (reduction: %.2f%%)",
                        i + 1,
                        execution_cost_estimate,
                        ((best_config.execution_cost - execution_cost_estimate) / best_config.execution_cost) * 100,
                    )

                    # Estimate the size of the indexes
                    index_size_estimate = await self._estimate_index_size_2(
                        {index.to_index_definition() for index in index_set}, 1024 * 1024
                    )
                    logger.info("Estimated index size: %f", index_size_estimate)

                    # Score based on a balance of size and performance; reuse the
                    # single objective definition in score() rather than inlining it.
                    score = self.score(execution_cost_estimate, total_table_size + index_size_estimate)

                    # Record this attempt in history
                    latest_config = ScoredIndexes(
                        indexes={Index(table_name=index.table_name, columns=index.columns) for index in index_set},
                        execution_cost=execution_cost_estimate,
                        index_size=index_size_estimate,
                        objective_score=score,
                    )
                    attempt_history.append(latest_config)
                    logger.info("Latest config: %s", latest_config)

                    # If this is better than what we've seen so far, update our best.
                    # Any strict improvement of the objective score counts.
                    if latest_config.objective_score < best_config.objective_score:
                        best_config = latest_config
                        found_improvement = True
                except Exception as e:
                    # We discard the alternative. We are seeing this happen due to invalid index definitions.
                    logger.error("Error evaluating alternative %d/%d: %s", i + 1, len(index_alternatives), str(e))

            # Keep only the 5 best results in the attempt history
            attempt_history.sort(key=lambda x: x.objective_score)
            attempt_history = attempt_history[:5]

            if found_improvement:
                no_progress_count = 0
            else:
                no_progress_count += 1
                logger.info(
                    "No improvement found in this iteration. Attempts without progress: %d/%d",
                    no_progress_count,
                    self.max_no_progress_attempts,
                )

        if best_config != original_config:
            logger.info(
                "Selected best index configuration with %d indexes, cost reduction: %.2f%%, indexes: %s",
                len(best_config.indexes),
                ((original_cost - best_config.execution_cost) / original_cost) * 100,
                ", ".join(f"{idx.table_name}.({','.join(idx.columns)})" for idx in best_config.indexes),
            )
        else:
            logger.info("No better index configuration found")

        # Convert Index objects to IndexConfig objects for return
        best_index_config_set = {index.to_index_recommendation() for index in best_config.indexes}
        return (best_index_config_set, best_config.execution_cost)

    async def _estimate_index_size_2(self, index_set: set[IndexDefinition], min_size_penalty: float = 1024 * 1024) -> float:
        """
        Estimate the size of a set of indexes using hypopg.

        Args:
            index_set: Set of IndexConfig objects representing the indexes to estimate
            min_size_penalty: Floor applied to each index's estimated size, so tiny
                estimates still carry some cost in the objective.

        Returns:
            Total estimated size of all indexes in bytes
        """
        if not index_set:
            return 0.0

        total_size = 0.0
        for index_config in index_set:
            try:
                # Create a hypothetical index using hypopg
                # Using a tuple to avoid LiteralString type error
                create_index_query = (
                    "WITH hypo_index AS (SELECT indexrelid FROM hypopg_create_index(%s)) "
                    "SELECT hypopg_relation_size(indexrelid) as size, hypopg_drop_index(indexrelid) FROM hypo_index;"
                )

                # Execute the query to get the index size
                result = await self.sql_driver.execute_query(create_index_query, params=[index_config.definition])

                if result and len(result) > 0:
                    # Extract the size from the result
                    size = result[0].cells.get("size", 0)
                    total_size += max(float(size), min_size_penalty)
                    logger.debug(f"Estimated size for index {index_config.name}: {size} bytes")
                else:
                    logger.warning(f"Failed to estimate size for index {index_config.name}")
            except Exception as e:
                # Best-effort: a failed estimate just contributes nothing to the total
                logger.error(f"Error estimating size for index {index_config.name}: {e!s}")

        return total_size

    def _extract_indexes_from_explain_plan(self, explain_plan_json: Any) -> set[tuple[str, str]]:
        """
        Extract indexes used in the explain plan JSON.

        Args:
            explain_plan_json: The explain plan JSON from PostgreSQL

        Returns:
            A set of tuples (table_name, index_name) representing the indexes used in the plan
        """
        indexes_used = set()

        if isinstance(explain_plan_json, dict):
            plan_data = explain_plan_json.get("Plan")
            if plan_data is not None:

                def extract_indexes_from_node(node):
                    # Check if this is an index scan node
                    if node.get("Node Type") in ["Index Scan", "Index Only Scan", "Bitmap Index Scan"]:
                        if "Index Name" in node and "Relation Name" in node:
                            # Add the table name and index name
                            indexes_used.add((node["Relation Name"], node["Index Name"]))

                    # Recursively process child plans
                    if "Plans" in node:
                        for child in node["Plans"]:
                            extract_indexes_from_node(child)

                # Start extraction from the root plan
                extract_indexes_from_node(plan_data)

        logger.info("Extracted %d indexes from explain plan", len(indexes_used))
        return indexes_used

    async def _extract_indexes_from_explain_plan_with_columns(self, explain_plan_json: Any) -> set[Index]:
        """
        Extract indexes used in the explain plan JSON and populate their columns.

        Args:
            explain_plan_json: The explain plan JSON from PostgreSQL

        Returns:
            A set of Index objects representing the indexes used in the plan with their columns
        """
        # First extract the indexes without columns
        index_tuples = self._extract_indexes_from_explain_plan(explain_plan_json)

        # Now populate the columns for each index
        indexes_with_columns = set()
        for table_name, index_name in index_tuples:
            # Get the columns for this index
            columns = await self._get_index_columns(index_name)

            # Create a new Index object with the columns
            index_with_columns = Index(table_name=table_name, columns=columns)
            indexes_with_columns.add(index_with_columns)

        return indexes_with_columns

    async def _get_index_columns(self, index_name: str) -> tuple[str, ...]:
        """
        Get the columns for a specific index by querying the database.

        Args:
            index_name: The name of the index

        Returns:
            A tuple of column names in the index (empty tuple on error or no match)
        """
        try:
            # Query to get index columns, in index-key order
            query = """
                SELECT a.attname
                FROM pg_index i
                JOIN pg_class c ON c.oid = i.indexrelid
                JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
                WHERE c.relname = %s
                ORDER BY array_position(i.indkey, a.attnum)
            """
            result = await self.sql_driver.execute_query(query, [index_name])

            if result and len(result) > 0:
                # Extract column names from the result
                columns = [row.cells.get("attname", "") for row in result if row.cells.get("attname")]
                return tuple(columns)
            else:
                logger.warning(f"No columns found for index {index_name}")
                return tuple()
        except Exception as e:
            logger.error(f"Error getting columns for index {index_name}: {e!s}")
            return tuple()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/crystaldba/postgres-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.