from typing import Dict, Any, List
from .base_tool import BaseTool
from src.database.connection import db_manager
import logging
logger = logging.getLogger(__name__)
class InsertTool(BaseTool):
    """Tool that inserts one or more rows into a database table.

    Registered with the name ``insert_data``. Rows are given as a dictionary
    (one row) or a list of dictionaries (several rows) mapping column names
    to values. Values are passed to the database manager as query parameters
    (``%s`` placeholders); only the table and column identifiers are placed
    in the SQL text, and both are quoted with backticks.
    """

    def __init__(self):
        super().__init__(
            name="insert_data",
            description="Insert new data into a table"
        )

    @staticmethod
    def _is_valid_table_name(table_name: Any) -> bool:
        """Return True if *table_name* is a string of alphanumerics, '_' and '-'."""
        return (
            isinstance(table_name, str)
            and table_name.replace('_', '').replace('-', '').isalnum()
        )

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return ``str(name)`` wrapped in backticks with inner backticks doubled.

        Doubling a backtick is MySQL's escape inside a quoted identifier, so an
        untrusted column name cannot close the quoting and inject SQL.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Insert the rows in ``data`` into ``table_name``.

        Keyword Args:
            table_name: Target table; only letters, digits, '_' and '-' are
                accepted.
            data: A dictionary (one row) or a list of dictionaries (several
                rows). Every row must have the same set of keys; key order may
                differ between rows (the first row's order is used).

        Returns:
            Whatever ``self.format_response`` builds. On success the payload
            has ``table_name``, ``inserted_rows`` (the value returned by
            ``db_manager.execute_update`` / ``execute_many``), ``columns`` and
            ``data_count``. On any validation failure or exception, the
            response is unsuccessful and carries an ``error`` message.
        """
        try:
            table_name = kwargs.get('table_name')
            data = kwargs.get('data')  # Dictionary or list of dictionaries

            # Validate required parameters. Empty {} / [] are falsy, so this
            # also rejects empty data.
            if not table_name:
                return self.format_response(False, error="table_name is required")
            if not data:
                return self.format_response(False, error="data is required")

            if not self._is_valid_table_name(table_name):
                return self.format_response(False, error="Invalid table name")

            # Normalize to a list of rows.
            if isinstance(data, dict):
                data_list: List[Dict[Any, Any]] = [data]
            elif isinstance(data, list):
                data_list = data
            else:
                return self.format_response(False, error="data must be a dictionary or list of dictionaries")

            # A list may still contain non-dict items; reject them up front
            # instead of failing later on `.keys()`.
            if not all(isinstance(item, dict) for item in data_list):
                return self.format_response(False, error="data must be a dictionary or list of dictionaries")

            # Column order comes from the first row; other rows only need the
            # same key set because parameters are looked up by column name.
            columns = list(data_list[0].keys())
            column_set = set(columns)
            if any(set(item) != column_set for item in data_list):
                return self.format_response(False, error="All data items must have the same columns")

            # Build the INSERT query; identifiers are escaped, values are
            # bound as parameters.
            columns_str = ', '.join(self._quote_identifier(col) for col in columns)
            placeholders = ', '.join(['%s'] * len(columns))
            query = (
                f"INSERT INTO {self._quote_identifier(table_name)} "
                f"({columns_str}) VALUES ({placeholders})"
            )

            params_list = [tuple(item[col] for col in columns) for item in data_list]

            logger.info("Inserting %d rows into %s", len(params_list), table_name)

            # A single row goes through execute_update, several through
            # execute_many (batch execution on the database manager).
            if len(params_list) == 1:
                affected_rows = db_manager.execute_update(query, params_list[0])
            else:
                affected_rows = db_manager.execute_many(query, params_list)

            result = {
                'table_name': table_name,
                'inserted_rows': affected_rows,
                'columns': columns,
                'data_count': len(data_list)
            }

            logger.info("Successfully inserted %s rows into %s", affected_rows, table_name)
            return self.format_response(True, result)

        except Exception as e:
            # Tool boundary: convert any failure into an error response, but
            # keep the traceback in the logs.
            logger.exception("Insert tool execution failed: %s", e)
            return self.format_response(False, error=str(e))