from typing import Dict, Any
from .base_tool import BaseTool
from src.database.connection import db_manager
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class UpdateTool(BaseTool):
    """Tool that updates rows of one table via a parameterized ``UPDATE``.

    Keyword arguments understood by :meth:`execute`:

    * ``table_name`` -- target table; only alphanumerics, ``_`` and ``-``.
    * ``data`` -- ``{column: value}`` mapping; every value is passed as a
      ``%s`` query parameter, never interpolated into the SQL text.
    * ``where_clause`` -- SQL condition inserted verbatim after ``WHERE``;
      it may contain its own ``%s`` placeholders.
    * ``where_params`` -- values for those placeholders: a list, a single
      value, or ``None`` for no parameters.
    """

    def __init__(self):
        super().__init__(
            name="update_data",
            description="Update existing data in a table"
        )

    @staticmethod
    def _is_valid_identifier(name: Any) -> bool:
        """Return True if *name* is a str consisting only of alphanumerics, '_' and '-'."""
        # Guard the type first: a non-str key (e.g. an int column) would
        # otherwise raise AttributeError on .replace() and surface as an
        # unhelpful error message.
        if not isinstance(name, str):
            return False
        return name.replace('_', '').replace('-', '').isalnum()

    @staticmethod
    def _as_param_list(params: Any) -> list:
        """Normalize ``where_params`` into a list of bind values.

        ``None`` means "no parameters" -- binding it as one value would add an
        extra parameter the WHERE clause has no placeholder for. A list is used
        element-wise. Any other value, including a tuple, is kept as a single
        parameter (unchanged behaviour; some MySQL drivers, e.g. PyMySQL,
        expand a tuple bound to ``IN %s``).
        """
        if params is None:
            return []
        if isinstance(params, list):
            return list(params)
        return [params]

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Validate the arguments, run the UPDATE and wrap the outcome.

        Returns:
            ``self.format_response(True, result)`` on success, where ``result``
            has keys ``table_name``, ``updated_rows`` (whatever
            ``db_manager.execute_update`` returned), ``updated_columns`` and
            ``where_clause``. Validation failures and any exception raised
            while building or executing the query are returned as
            ``self.format_response(False, error=...)`` instead of propagating.
        """
        try:
            table_name = kwargs.get('table_name')
            data = kwargs.get('data')  # Dictionary of column: value to update
            where_clause = kwargs.get('where_clause')
            where_params = kwargs.get('where_params')

            # Validate required parameters
            if not table_name:
                return self.format_response(False, error="table_name is required")
            if not data:
                return self.format_response(False, error="data is required")
            # A blank clause would render as "WHERE " (syntax error); a
            # non-string such as 1 would render as "WHERE 1" and match every
            # row, defeating the point of requiring a WHERE clause.
            blank_clause = isinstance(where_clause, str) and not where_clause.strip()
            if not where_clause or blank_clause:
                return self.format_response(False, error="where_clause is required for safety")
            if not isinstance(where_clause, str):
                return self.format_response(False, error="where_clause must be a string")

            # Validate table name
            if not self._is_valid_identifier(table_name):
                return self.format_response(False, error="Invalid table name")
            # Validate data is dictionary
            if not isinstance(data, dict):
                return self.format_response(False, error="data must be a dictionary")

            # Build the SET clause: identifiers are validated and
            # backtick-quoted, values only ever travel as query parameters.
            set_clauses = []
            set_params = []
            for column, value in data.items():
                if not self._is_valid_identifier(column):
                    return self.format_response(False, error=f"Invalid column name: {column}")
                set_clauses.append(f"`{column}` = %s")
                set_params.append(value)

            # NOTE(security): where_clause is interpolated into the SQL text
            # verbatim, so it must come from a trusted source. Untrusted values
            # belong in where_params, which are passed separately as parameters.
            query = f"UPDATE `{table_name}` SET {', '.join(set_clauses)} WHERE {where_clause}"
            all_params = set_params + self._as_param_list(where_params)

            logger.info("Updating data in %s with WHERE: %s", table_name, where_clause)

            # Execute update
            affected_rows = db_manager.execute_update(query, tuple(all_params))

            result = {
                'table_name': table_name,
                'updated_rows': affected_rows,
                'updated_columns': list(data.keys()),
                'where_clause': where_clause
            }

            logger.info("Successfully updated %s rows in %s", affected_rows, table_name)
            return self.format_response(True, result)

        except Exception as e:
            # Tool boundary: report the failure to the caller instead of
            # raising, but keep the full traceback in the log.
            logger.exception("Update tool execution failed: %s", e)
            return self.format_response(False, error=str(e))