import logging
import re
from typing import Dict, Any, List, Optional

from src.database.connection import db_manager

from .base_tool import BaseTool

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class ProcedureTool(BaseTool):
    """Tool that runs stored procedures via ``CALL`` and lists the available ones.

    Both public methods catch their own exceptions and report failures through
    ``self.format_response(False, error=...)`` instead of raising.
    """

    # The procedure name is interpolated into the SQL text (identifiers cannot
    # be bound as query parameters), so only plain ASCII word characters are
    # accepted. ``str.isalnum`` alone would also admit arbitrary Unicode
    # letters and digits.
    _PROCEDURE_NAME_RE = re.compile(r"[A-Za-z0-9_]+")

    def __init__(self):
        super().__init__(
            name="execute_procedure",
            description="Execute stored procedures with parameters",
        )

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Execute a stored procedure and collect every result set it returns.

        Keyword Args:
            procedure_name: Name of the procedure to ``CALL`` (required).
                Must consist only of ASCII letters, digits and underscores.
            parameters: Positional arguments for the procedure, as a list or
                tuple. Omitted or ``None`` means the procedure takes none.

        Returns:
            ``self.format_response(True, data)``, where ``data`` has the keys
            ``procedure_name``, ``parameters``, ``result_sets_count`` and
            ``results`` (``None`` when no result set was produced). On a
            validation or database failure it returns
            ``self.format_response(False, error=...)`` instead.
        """
        try:
            procedure_name = kwargs.get('procedure_name')
            parameters = kwargs.get('parameters')
            # An explicit ``parameters=None`` means "no arguments", exactly like
            # leaving the key out, rather than failing later on len(None).
            if parameters is None:
                parameters = []

            if not procedure_name:
                return self.format_response(False, error="procedure_name is required")
            if (
                not isinstance(procedure_name, str)
                or not self._PROCEDURE_NAME_RE.fullmatch(procedure_name)
            ):
                return self.format_response(False, error="Invalid procedure name")
            if not isinstance(parameters, (list, tuple)):
                return self.format_response(False, error="parameters must be a list")

            call_query, params = self._build_call_query(procedure_name, parameters)
            logger.info(
                "Executing procedure: %s with %d parameters",
                procedure_name,
                len(parameters),
            )

            with db_manager.get_connection_context() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(call_query, params)
                    results = self._collect_result_sets(cursor)
                connection.commit()

            result_data = {
                'procedure_name': procedure_name,
                'parameters': parameters,
                'result_sets_count': len(results),
                'results': results if results else None,
            }
            logger.info(
                "Successfully executed procedure %s, got %d result sets",
                procedure_name,
                len(results),
            )
            return self.format_response(True, result_data)
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of
            # raising, but keep the traceback in the log.
            logger.exception("Procedure tool execution failed: %s", e)
            return self.format_response(False, error=str(e))

    @staticmethod
    def _build_call_query(procedure_name: str, parameters) -> tuple:
        """Return ``(sql, params)`` for a ``CALL`` with one ``%s`` per argument.

        ``params`` is ``None`` when there are no arguments, so the driver
        receives no parameter tuple for ``CALL name()``.
        """
        placeholders = ', '.join(['%s'] * len(parameters))
        params = tuple(parameters) if parameters else None
        return f"CALL {procedure_name}({placeholders})", params

    @staticmethod
    def _collect_result_sets(cursor) -> List[Any]:
        """Fetch every result set produced by the statement last run on *cursor*.

        Per DB-API 2.0, ``cursor.description`` is ``None`` for a result that
        carries no rows (for example, a trailing status result). Only results
        that have columns are fetched. A result set with zero rows is still
        kept, so the count reflects what the procedure actually produced.
        """
        result_sets: List[Any] = []
        while True:
            if cursor.description is not None:
                result_sets.append(cursor.fetchall())
            # nextset() returns a falsy value once no further result sets remain.
            if not cursor.nextset():
                return result_sets

    def list_procedures(self) -> Dict[str, Any]:
        """List all stored procedures in the current database.

        Queries ``information_schema.ROUTINES`` for routines of type
        ``PROCEDURE`` in the schema reported by ``DATABASE()``, ordered by name.

        Returns:
            ``self.format_response(True, {'procedures_count': ..., 'procedures': ...})``
            on success, or ``self.format_response(False, error=...)`` on failure.
        """
        query = """
            SELECT
                ROUTINE_NAME as procedure_name,
                ROUTINE_TYPE as routine_type,
                ROUTINE_COMMENT as comment,
                CREATED as created_date,
                LAST_ALTERED as last_modified
            FROM information_schema.ROUTINES
            WHERE ROUTINE_SCHEMA = DATABASE()
            AND ROUTINE_TYPE = 'PROCEDURE'
            ORDER BY ROUTINE_NAME
        """
        try:
            procedures = db_manager.execute_query(query)
            return self.format_response(True, {
                'procedures_count': len(procedures),
                'procedures': procedures,
            })
        except Exception as e:
            logger.exception("List procedures failed: %s", e)
            return self.format_response(False, error=str(e))