Excel MCP Server

by fish0710
Verified
from typing import Any
import uuid
import logging

from openpyxl import load_workbook
from openpyxl.utils import get_column_letter
from openpyxl.worksheet.table import Table, TableStyleInfo
from openpyxl.styles import Font

from .data import read_excel_range
from .cell_utils import parse_cell_range
from .exceptions import ValidationError, PivotError

logger = logging.getLogger(__name__)

# Supported aggregation names, and the " (name)" suffixes that the summary
# sheet appends to value headers (stripped again when matching field names).
_VALID_AGG_FUNCS = ["sum", "average", "count", "min", "max"]
_AGG_SUFFIXES = tuple(f" ({name})" for name in _VALID_AGG_FUNCS)


def create_pivot_table(
    filepath: str,
    sheet_name: str,
    data_range: str,
    rows: list[str],
    values: list[str],
    columns: list[str] | None = None,
    agg_func: str = "sum"
) -> dict[str, Any]:
    """Create a pivot-style summary sheet backed by an Excel table.

    The summary is written to a sheet named ``"<sheet_name>_pivot"`` (an
    existing sheet of that name is replaced) and the workbook is saved back
    to ``filepath``. One output row is produced for every combination of the
    distinct (stringified) values of the row fields.

    Args:
        filepath: Path to Excel file
        sheet_name: Name of worksheet containing source data
        data_range: Source data range reference in the form 'A1:B2'
        rows: Fields for row labels
        values: Fields for values
        columns: Optional fields for column labels. They are validated
            against the source headers and echoed in the result, but the
            generated layout does not pivot on them.
        agg_func: Aggregation function (sum, count, average, max, min),
            matched case-insensitively

    Returns:
        Dictionary with a status message and details of the summary
        (source range, pivot sheet name, fields, aggregation)

    Raises:
        ValidationError: Unknown sheet, malformed range, unsupported
            aggregation function, or a field that is not a source header.
        PivotError: Reading the data, aggregating, formatting or saving
            failed, or any other unexpected error occurred.
    """
    try:
        wb = load_workbook(filepath)
        if sheet_name not in wb.sheetnames:
            raise ValidationError(f"Sheet '{sheet_name}' not found")

        # Parse ranges
        if ':' not in data_range:
            raise ValidationError("Data range must be in format 'A1:B2'")

        try:
            start_cell, end_cell = data_range.split(':')
            start_row, start_col, end_row, end_col = parse_cell_range(start_cell, end_cell)
        except ValueError as e:
            raise ValidationError(f"Invalid data range format: {str(e)}") from e

        if end_row is None or end_col is None:
            raise ValidationError("Invalid data range format: missing end coordinates")

        # Create range string
        data_range_str = (
            f"{get_column_letter(start_col)}{start_row}:"
            f"{get_column_letter(end_col)}{end_row}"
        )

        # Read source data
        try:
            data = read_excel_range(filepath, sheet_name, start_cell, end_cell)
            if not data:
                raise PivotError("No data found in range")
        except Exception as e:
            raise PivotError(f"Failed to read source data: {str(e)}") from e

        # Validate aggregation function
        if agg_func.lower() not in _VALID_AGG_FUNCS:
            raise ValidationError(
                f"Invalid aggregation function. Must be one of: {', '.join(_VALID_AGG_FUNCS)}"
            )
        # Normalise once so aggregation, headers and the returned details all
        # agree; otherwise e.g. "AVERAGE" would validate but be summed.
        agg_func = agg_func.lower()

        # Validate field names exist in data and resolve each requested
        # field to the actual record key. Validation ignores case and
        # aggregation suffixes, so the lookup key must be resolved the same
        # way or the record lookups below would silently find nothing.
        first_row = data[0]
        header_lookup: dict[str, Any] = {}
        for header in first_row:
            header_lookup.setdefault(_clean_field_name(header).lower(), header)

        row_keys = _resolve_fields(rows, "row", first_row, header_lookup)
        value_keys = _resolve_fields(values, "value", first_row, header_lookup)
        if columns:
            _resolve_fields(columns, "column", first_row, header_lookup)

        # Skip header row if it matches our fields
        # NOTE(review): this drops the first record whenever every source
        # header is one of the requested row/value fields. Whether that
        # record is a duplicated header row depends on read_excel_range,
        # which is not visible here -- confirm before changing.
        requested = {_clean_field_name(field).lower() for field in rows + values}
        if all(_clean_field_name(header).lower() in requested for header in first_row):
            data = data[1:]

        # Clean up row and value field names (used as display labels)
        cleaned_rows = [_clean_field_name(field) for field in rows]
        cleaned_values = [_clean_field_name(field) for field in values]

        # Create pivot sheet
        pivot_sheet_name = f"{sheet_name}_pivot"
        if pivot_sheet_name in wb.sheetnames:
            wb.remove(wb[pivot_sheet_name])
        pivot_ws = wb.create_sheet(pivot_sheet_name)

        # Write headers: row fields first, then "<value> (<agg>)" columns
        header_labels = cleaned_rows + [f"{label} ({agg_func})" for label in cleaned_values]
        for col_index, label in enumerate(header_labels, start=1):
            cell = pivot_ws.cell(row=1, column=col_index, value=label)
            cell.font = Font(bold=True)

        # Get unique values for each row field. Values are stringified so
        # mixed-type columns can still be sorted.
        field_values = {
            key: sorted({str(record.get(key, '')) for record in data})
            for key in row_keys
        }

        # Generate all combinations of row field values
        row_combinations = _get_combinations(field_values)

        # Calculate table dimensions for formatting
        total_rows = len(row_combinations) + 1  # +1 for header
        total_cols = len(cleaned_rows) + len(cleaned_values)

        # Write data rows
        for current_row, combo in enumerate(row_combinations, start=2):
            # Write row field values
            col = 1
            for key in row_keys:
                pivot_ws.cell(row=current_row, column=col, value=combo[key])
                col += 1

            # Filter data for current combination
            filtered_data = _filter_data(data, combo, {})

            # Calculate and write aggregated values
            for value_key in value_keys:
                try:
                    value = _aggregate_values(filtered_data, value_key, agg_func)
                    pivot_ws.cell(row=current_row, column=col, value=value)
                except Exception as e:
                    raise PivotError(
                        f"Failed to aggregate values for field '{value_key}': {str(e)}"
                    ) from e
                col += 1

        # Create a table for the pivot data
        try:
            pivot_range = f"A1:{get_column_letter(total_cols)}{total_rows}"
            pivot_table = Table(
                # Random suffix keeps the table name unique in the workbook
                displayName=f"PivotTable_{uuid.uuid4().hex[:8]}",
                ref=pivot_range
            )
            style = TableStyleInfo(
                name="TableStyleMedium9",
                showFirstColumn=False,
                showLastColumn=False,
                showRowStripes=True,
                showColumnStripes=True
            )
            pivot_table.tableStyleInfo = style
            pivot_ws.add_table(pivot_table)
        except Exception as e:
            raise PivotError(f"Failed to create pivot table formatting: {str(e)}") from e

        try:
            wb.save(filepath)
        except Exception as e:
            raise PivotError(f"Failed to save workbook: {str(e)}") from e

        return {
            "message": "Summary table created successfully",
            "details": {
                "source_range": data_range_str,
                "pivot_sheet": pivot_sheet_name,
                "rows": cleaned_rows,
                "columns": columns or [],
                "values": cleaned_values,
                "aggregation": agg_func
            }
        }

    except (ValidationError, PivotError) as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.error(f"Failed to create pivot table: {e}")
        raise PivotError(str(e)) from e


def _clean_field_name(field: Any) -> str:
    """Return the field name stripped of whitespace and one aggregation suffix."""
    name = str(field).strip()
    lowered = name.lower()
    for suffix in _AGG_SUFFIXES:
        if lowered.endswith(suffix):
            return name[:-len(suffix)]
    return name


def _resolve_fields(
    fields: list[str],
    field_type: str,
    first_row: dict,
    header_lookup: dict[str, Any],
) -> list[Any]:
    """Validate requested fields and map each one to its actual record key.

    Raises:
        ValidationError: A field matches no source header.
    """
    resolved = []
    for field in fields:
        cleaned = _clean_field_name(field)
        normalized = cleaned.lower()
        if normalized not in header_lookup:
            raise ValidationError(
                f"Invalid {field_type} field '{field}'. "
                f"Available fields: {', '.join(sorted(header_lookup))}"
            )
        # Prefer the exact cleaned name when it is a real key; otherwise
        # fall back to the header that matched case/suffix-insensitively.
        resolved.append(cleaned if cleaned in first_row else header_lookup[normalized])
    return resolved


def _get_combinations(field_values: dict[str, Any]) -> list[dict]:
    """Get all combinations of field values.

    ``field_values`` maps each field to an iterable of its distinct values.
    The result is the cartesian product, one dict per combination, with each
    field's values taken in sorted order. An empty mapping yields ``[{}]``.
    """
    result: list[dict] = [{}]
    for field, values in field_values.items():
        ordered = sorted(values)  # Sort for consistent ordering
        result = [{**combo, field: value} for combo in result for value in ordered]
    return result


def _field_matches(record: dict, field: Any, expected: Any) -> bool:
    """Return True when the record's field equals the filter value.

    Filter values derived from stringified cell contents must also match the
    original non-string cell (e.g. ``"2020"`` matches the int ``2020``).
    """
    if record.get(field) == expected:
        return True
    return str(record.get(field, '')) == expected


def _filter_data(data: list[dict], row_filters: dict, col_filters: dict) -> list[dict]:
    """Filter data based on row and column filters.

    A record is kept only if it satisfies every row filter and every column
    filter.
    """
    filters = [*row_filters.items(), *col_filters.items()]
    return [
        record
        for record in data
        if all(_field_matches(record, field, value) for field, value in filters)
    ]


def _aggregate_values(data: list[dict], field: str, agg_func: str) -> float:
    """Aggregate values using the specified function.

    Only int/float values of ``field`` take part; records missing the field
    or holding other types are ignored. Returns 0 when nothing numeric is
    found. The function name is matched case-insensitively and unknown names
    fall back to sum.
    """
    numbers = [
        record[field]
        for record in data
        if field in record and isinstance(record[field], (int, float))
    ]
    if not numbers:
        return 0

    func = agg_func.lower()
    if func == "average":
        return sum(numbers) / len(numbers)
    if func == "count":
        return len(numbers)
    if func == "min":
        return min(numbers)
    if func == "max":
        return max(numbers)
    return sum(numbers)  # "sum" and the default for unknown names