train_linear_regression_model
Train a linear regression model on uploaded CSV data to predict values and evaluate performance using RMSE metrics.
Instructions
This function trains a linear regression model.
Args: output_column, the name of the column to predict.
Returns: A string containing the RMSE value, or an error message if the column is missing or training fails.
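To make the returned metric concrete, here is a minimal sketch of how RMSE (root mean squared error) is computed and formatted. It uses only the standard library rather than the NumPy and scikit-learn calls in the handler, and the helper names `rmse` and `format_result` are hypothetical, introduced only for illustration.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y - y_hat) ** 2))."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


def format_result(value):
    # Same success message format as the tool's return string.
    return f"Model trained successfully. RMSE: {value:.4f}"


# Errors are 1, 0 and -2, so the mean squared error is 5 / 3.
print(format_result(rmse([3.0, 5.0, 7.0], [2.0, 5.0, 9.0])))
# prints "Model trained successfully. RMSE: 1.2910"
```

RMSE is expressed in the same units as the target column, so a lower value means the predictions are closer to the held-out values.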
Input Schema
| Name | Required | Description | Default |
|---|---|---|---|
| output_column | Yes | | |
Implementation Reference
- server.py:123-165 (handler): Implements the training of a linear regression model using scikit-learn. It retrieves the data from the shared context, separates the features from the target, splits them into train and test sets (90/10), fits the model, predicts on the test set, then computes and returns the RMSE.

  ```python
  def train_linear_regression_model(output_column: str) -> str:
      """
      This function trains linear regression model.

      Args:
          Takes input for output column name.

      Returns:
          String which contains the RMSE value.
      """
      try:
          data = context.get_data()

          # Check if the output column exists in the dataset
          if output_column not in data.columns:
              return f"Error: '{output_column}' column not found in the dataset."

          # Prepare the features (X) and target variable (y)
          X = data.drop(columns=[output_column])  # Drop the target column for features
          y = data[output_column]  # The target variable is the output column

          # Split the data into training and test sets (90% train, 10% test)
          X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

          # Initialize the Linear Regression model
          model = LinearRegression()

          # Train the model
          model.fit(X_train, y_train)

          # Predict on the test set
          y_pred = model.predict(X_test)

          # Calculate RMSE (Root Mean Squared Error)
          rmse = np.sqrt(mean_squared_error(y_test, y_pred))

          # Return the RMSE value
          return f"Model trained successfully. RMSE: {rmse:.4f}"

      except Exception as e:
          return f"An error occurred while training the model: {str(e)}"
  ```
- server.py:122-122 (registration): Registers the train_linear_regression_model function as an MCP tool using the FastMCP decorator.

  ```python
  @mcp.tool()
  ```
- server.py:14-35 (helper): The DataContext dataclass provides shared storage for the pandas DataFrame used by the tool, through the global context instance. The handler reads the data with context.get_data().

  ```python
  @dataclass
  class DataContext():
      """
      A class that stores the DataFrame in the context.
      """
      _data: pd.DataFrame = None

      def set_data(self, new_data: pd.DataFrame):
          """
          Method to set or update the data.
          """
          self._data = new_data

      def get_data(self) -> pd.DataFrame:
          """
          Method to get the data from the context.
          """
          return self._data

  # Initialize the DataContext instance globally
  context = DataContext()
  ```
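The handler flow described above can be tried outside the MCP server with a small, self-contained sketch. It assumes pandas, NumPy, and scikit-learn are installed. The function name `train_and_score`, the synthetic DataFrame, and its column names (`x1`, `x2`, `price`) are hypothetical stand-ins rather than part of server.py, and the DataFrame is passed in directly instead of being read from the global context.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split


def train_and_score(data: pd.DataFrame, output_column: str) -> str:
    """Hypothetical stand-in for the handler that takes the DataFrame directly."""
    if output_column not in data.columns:
        return f"Error: '{output_column}' column not found in the dataset."

    # Every column except the target is used as a feature.
    X = data.drop(columns=[output_column])
    y = data[output_column]

    # Hold out 10% of the rows for evaluation (90/10 split).
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.1, random_state=42
    )

    model = LinearRegression().fit(X_train, y_train)
    rmse = np.sqrt(mean_squared_error(y_test, model.predict(X_test)))
    return f"Model trained successfully. RMSE: {rmse:.4f}"


# Synthetic, noise-free data (an assumption made purely for illustration).
rng = np.random.default_rng(0)
frame = pd.DataFrame({"x1": rng.normal(size=100), "x2": rng.normal(size=100)})
frame["price"] = 3.0 * frame["x1"] - 2.0 * frame["x2"] + 1.0

result = train_and_score(frame, "price")   # success path
missing = train_and_score(frame, "cost")   # missing-column path
print(result)
print(missing)
```

Because the synthetic target is an exact linear function of the features, the reported RMSE should be effectively zero. Real CSV data will give a larger value, and since every non-target column is passed to the model as a feature, non-numeric columns would need to be encoded or dropped first for the fit to succeed.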