tools.py•4.69 kB
##? Import Libraries
import re

from pydantic import Field

from utils.db import execute_sql_query
from . import mcp
## DEFn_: FIRST TOOL: to ping the database for health check and return a success message to the user
@mcp.tool(name="ping", description="Tool to ping the database for health check and return a success message")
def ping() -> dict[str, str]:
    """
    Objective:
        Health-check tool that tells the caller the MCP server is reachable.
        No SQL is executed here; the reply is a fixed message.
    Returns:
        dict[str, str]: The status message stored under the "result" key
    """
    status_message = "PostgreSQL MCP Server is running"
    return {"result": status_message}
## DEFn_: SECOND TOOL: to list all the schemas in the database
@mcp.tool(name="list_schemas", description="Tool to list all the schemas in the database")
async def list_schemas() -> dict[str, list[str]]:
    """
    Objective:
        Fetch the name of every schema listed in information_schema.schemata
    Returns:
        dict[str, list[str]]: The schema names, in the order the query returned
        them, stored under the "result" key
    """
    schemata_sql = "SELECT schema_name FROM information_schema.schemata"
    rows = await execute_sql_query(schemata_sql)
    ## NOTE: every value is stringified so the payload matches the declared list[str]
    schema_names = [str(record["schema_name"]) for record in rows]
    return {"result": schema_names}
## DEFn_: THIRD TOOL: to list all the tables in the given schema
@mcp.tool(name="list_tables", description="Tool to list all the tables in the given schema")
async def list_tables(
    schema: str = Field(
        ...,
        description=" Schema name to list all the tables from. Example:public",
        min_length=2,
    ),
) -> dict[str, list[str]]:
    """
    Objective:
        List all the tables present in the given schema
    Args:
        schema (str): The name of the schema to list the tables from
    Returns:
        dict[str, list[str]]: The table names, ordered by name, stored under the "result" key
    """
    ## NOTE: `schema` is caller-supplied and ends up inside a SQL string literal, so every
    ## single quote is doubled to keep the value from terminating the literal (SQL injection).
    ## Assumes the server runs with standard_conforming_strings = on (the PostgreSQL default),
    ## where a backslash is not an escape character.
    ## TODO(review): switch to bound parameters if execute_sql_query supports them.
    safe_schema = schema.replace("'", "''")
    query_ = (
        "SELECT table_name FROM information_schema.tables "
        f"WHERE table_schema = '{safe_schema}' ORDER BY table_name"
    )
    results = await execute_sql_query(query_)
    table_dict = {"result": [f"{row['table_name']}" for row in results]}
    return table_dict
## DEFn_: FOURTH TOOL: to get information of a table inside a database under a particular schema
@mcp.tool(name="get_table_info", description="Get information of a table inside a database under a particular schema")
async def get_table_info(
    table_name: str = Field(..., description="table name to get information for", min_length=2),
    schema_name: str = Field(..., description="schema under which table name falls in. Example: public", min_length=2)
) -> list:
    """
    Objective:
        Get the column name, data type and nullability of every column of
        table_name under schema_name, in ordinal (declaration) order
    Args:
        table_name (str): The name of the table to get information for
        schema_name (str): The name of the schema under which the table name falls in. Example: public
    Returns:
        list: The rows returned by execute_sql_query, one per column of the table
    """
    ## NOTE: both arguments are caller-supplied and end up inside SQL string literals, so
    ## every single quote is doubled to keep a value from terminating its literal (SQL injection).
    ## Assumes the server runs with standard_conforming_strings = on (the PostgreSQL default),
    ## where a backslash is not an escape character.
    ## TODO(review): switch to bound parameters if execute_sql_query supports them.
    safe_schema = schema_name.replace("'", "''")
    safe_table = table_name.replace("'", "''")
    query_ = (
        "SELECT column_name, data_type, is_nullable "
        "FROM information_schema.columns "
        f"WHERE table_schema = '{safe_schema}' AND table_name = '{safe_table}' "
        "ORDER BY ordinal_position"
    )
    results = await execute_sql_query(query_)
    return results
## DEFn_: FIFTH TOOL: to run any given sql query for only reading and get results
@mcp.tool(name="run_sql_query", description="To run any given sql query for only reading and get results")
async def run_sql_query(
    query: str = Field(..., description="query to run for reading table data", min_length=4)
) -> list:
    """
    Objective:
        To run only allowed query(only reading) query and get results in a list
    Args:
        query (str): SQL query to run
    Returns:
        list: The results of the query as returned by execute_sql_query
    Raises:
        ValueError: If the query contains a forbidden (writing/DDL) keyword, or
            contains none of the allowed reading keywords
    """
    forbidden = {"delete", "drop", "destroy", "truncate", "update", "insert", "alter", "create"}
    allowed = {"select", "explain", "show", "describe"}
    ## Compare whole words rather than substrings: a substring test rejects harmless reads
    ## that merely mention identifiers such as created_at / updated_at / is_deleted, and
    ## lets words such as "selected" or "showcase" count as an allowed verb.
    ## `\w+` keeps an identifier with underscores together as a single word.
    words = set(re.findall(r"\w+", query.lower()))
    if not words.isdisjoint(forbidden):
        raise ValueError("Error: Forbidden query")
    if words.isdisjoint(allowed):
        raise ValueError("Error: Allowed queries are select, explain, show, describe")
    ## NOTE(review): this is a keyword heuristic, not a SQL parser; a read-only database
    ## role or transaction would be the reliable way to enforce "reading only".
    results = await execute_sql_query(query)
    return results
## DEFn_: SIXTH TOOL: Run Explain Query to get performance metrics of the query
@mcp.tool(name="run_explain_query", description="To run explain query to get performance metrics of the query")
async def run_explain_query(
    query: str = Field(..., description="query to run for explain", min_length=4)
) -> list:
    """
    Objective:
        To run explain query to get performance metrics of the query
    Args:
        query (str): SQL query to prefix with EXPLAIN
    Returns:
        list: The rows returned by execute_sql_query for the EXPLAIN statement
    Raises:
        ValueError: If the query contains a writing/DDL keyword
    """
    ## Same forbidden words as the run_sql_query tool, matched as whole words.
    ## Prefixing with EXPLAIN does not make a statement harmless: a caller can pass
    ## "ANALYZE DELETE ..." (EXPLAIN ANALYZE executes the statement), and a second
    ## statement after a semicolon is not covered by the EXPLAIN prefix at all.
    forbidden = {"delete", "drop", "destroy", "truncate", "update", "insert", "alter", "create"}
    words = set(re.findall(r"\w+", query.lower()))
    if not words.isdisjoint(forbidden):
        raise ValueError("Error: Forbidden query")
    query_ = f"EXPLAIN {query}"
    return await execute_sql_query(query_)