from typing import Dict, Any, Optional
from .base_tool import BaseTool
from src.database.query_executor import query_executor
import logging
# Module logger; QueryTool records the SQL it executes and any execution failure here.
logger = logging.getLogger(__name__)
class QueryTool(BaseTool):
    """Tool that runs a paginated ``SELECT`` against a single MySQL table.

    Callers supply SQL fragments (``where_clause``, ``order_by`` and, optionally,
    a raw ``columns`` string) that are interpolated into the statement. Such
    fragments cannot be bound as parameters, so they are screened for statement
    terminators, comment markers and unbalanced quotes. Those are the constructs
    that would let a fragment stack a second statement or comment out the
    ``LIMIT``/``OFFSET`` this tool appends.

    NOTE(security): the screening is defence in depth only. A fragment can still
    contain arbitrary expressions and subqueries, so run this tool against a
    database account that is read-only and restricted to the intended tables.
    """

    # Used when ``limit`` / ``offset`` are omitted or explicitly passed as None.
    DEFAULT_LIMIT = 100
    DEFAULT_OFFSET = 0

    def __init__(self):
        """Register the tool with BaseTool under the name ``query_data``."""
        super().__init__(
            name="query_data",
            description="Execute SELECT queries on MySQL database with filtering, sorting, and pagination"
        )

    @staticmethod
    def _build_columns(columns: Any) -> str:
        """Render the SELECT list.

        A list/tuple is treated as column names. Each name is backtick-quoted,
        with embedded backticks doubled so a name cannot escape its quotes. Any
        other truthy value is used verbatim as an SQL select list (for example
        ``"id, name"``). Falsy values (None, "", []) select all columns.
        """
        if not columns:
            return '*'
        if isinstance(columns, (list, tuple)):
            return ', '.join('`{}`'.format(str(col).replace('`', '``')) for col in columns)
        return str(columns)

    @staticmethod
    def _find_unsafe_sql(fragment: str) -> Optional[str]:
        """Return a description of the first unsafe construct in *fragment*, or None.

        Outside of quoted regions ('...', "..." and `...`), ``;`` and the MySQL
        comment markers ``#``, ``--`` and ``/*`` are rejected, and so is a quote
        that is never closed. Inside '...' and "..." a backslash escapes the next
        character, which is MySQL's default mode. Doubled quotes such as
        ``'it''s'`` work naturally: the first quote closes the string and the
        second reopens it.
        """
        quote = None
        i = 0
        while i < len(fragment):
            ch = fragment[i]
            if quote is not None:
                if ch == '\\' and quote != '`':
                    i += 2  # skip the escaped character inside a string literal
                    continue
                if ch == quote:
                    quote = None
            elif ch in ("'", '"', '`'):
                quote = ch
            elif ch in (';', '#'):
                return repr(ch)
            elif fragment.startswith(('--', '/*'), i):
                return repr(fragment[i:i + 2])
            i += 1
        if quote is not None:
            return f"unterminated {quote} quote"
        return None

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Run a paginated SELECT and return the rows plus pagination metadata.

        Keyword Args:
            table_name: Required. Letters, digits, ``_`` and ``-`` only.
            columns: List of column names, or a raw select-list string.
                Defaults to ``'*'``.
            where_clause: Optional raw SQL condition (without ``WHERE``).
            order_by: Optional raw SQL ordering (without ``ORDER BY``).
            limit: Maximum rows to return; a non-negative int or int-like value.
                Defaults to 100.
            offset: Rows to skip; a non-negative int or int-like value.
                Defaults to 0.

        Returns:
            On success, ``self.format_response(True, {...})`` with ``results``,
            ``total_count``, ``limit``, ``offset`` and ``has_more``. On a
            validation or execution failure, ``self.format_response(False, error=...)``.
        """
        try:
            table_name = kwargs.get('table_name')
            where_clause = str(kwargs.get('where_clause') or '')
            order_by = str(kwargs.get('order_by') or '')

            # Validate required parameters
            if not table_name:
                return self.format_response(False, error="table_name is required")
            # The name is embedded in backticks, which can never pass isalnum().
            if not isinstance(table_name, str) or not table_name.replace('_', '').replace('-', '').isalnum():
                return self.format_response(False, error="Invalid table name")

            # Normalise pagination once. It is used both in SQL and in arithmetic
            # below, so string inputs such as "10" must become ints here.
            raw_limit = kwargs.get('limit')
            raw_offset = kwargs.get('offset')
            try:
                limit = self.DEFAULT_LIMIT if raw_limit is None else int(raw_limit)
                offset = self.DEFAULT_OFFSET if raw_offset is None else int(raw_offset)
            except (TypeError, ValueError):
                return self.format_response(False, error="limit and offset must be integers")
            if limit < 0 or offset < 0:
                return self.format_response(False, error="limit and offset must be non-negative")

            columns_str = self._build_columns(kwargs.get('columns', '*'))
            for label, fragment in (
                ('columns', columns_str),
                ('where_clause', where_clause),
                ('order_by', order_by),
            ):
                problem = self._find_unsafe_sql(fragment)
                if problem:
                    return self.format_response(
                        False, error=f"{label} contains disallowed SQL: {problem}"
                    )

            # Build query
            where_sql = f" WHERE {where_clause}" if where_clause else ""
            query = f"SELECT {columns_str} FROM `{table_name}`{where_sql}"
            if order_by:
                query += f" ORDER BY {order_by}"
            query += f" LIMIT {limit} OFFSET {offset}"

            logger.info("Executing query: %s", query)
            results = query_executor.execute_custom_query(query)

            # Same FROM/WHERE as the page query, so total_count reflects the filtered set.
            count_query = f"SELECT COUNT(*) as total FROM `{table_name}`{where_sql}"
            total_count = query_executor.execute_custom_query(count_query)[0]['total']

            return self.format_response(True, {
                'results': results,
                'total_count': total_count,
                'limit': limit,
                'offset': offset,
                'has_more': (offset + limit) < total_count,
            })
        except Exception as e:
            # Tool boundary: report any failure to the caller instead of raising.
            logger.exception("Query tool execution failed: %s", e)
            return self.format_response(False, error=str(e))