sdk_client.py
import asyncio
import base64
import copy
import io
import math
from typing import Any, Dict, Optional, Tuple

# Set matplotlib to a non-interactive backend BEFORE importing pyplot.
# This prevents GUI windows and the Python icon from appearing in the dock.
import matplotlib

import pyarrow as pa
import pyarrow.compute as pc
import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl
from rockfish.remote import glue

matplotlib.use("Agg")

import logging
import uuid

import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)


class RockfishSDKClient:
    def __init__(
        self,
        API_KEY: str,
        API_URL: str,
        ORGANIZATION_ID: Optional[str] = None,
        PROJECT_ID: Optional[str] = None,
    ):
        """Initialize the SDK client with explicit credentials via Connection.remote()."""
        self._conn = rf.Connection.remote(
            API_KEY, api_url=API_URL, organization=ORGANIZATION_ID, project=PROJECT_ID
        )
        self._cache = {}

    async def close(self):
        """Close the SDK connection."""
        if self._conn is not None:
            await self._conn.close()
            self._conn = None

    async def call_endpoint(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Route tool calls to the appropriate SDK methods.

        Raises:
            NotImplementedError: If the operation is not supported by the SDK
            ValueError: If the tool name is unknown
        """
        if tool_name == "obtain_train_config":
            dataset_id = arguments["dataset_id"]
            model_type = arguments["model_type"]
            # dataset_properties = arguments["dataset_properties"]
            # field_properties_map = arguments["field_properties_map"]
            dataset = await get_local_dataset(self._conn, dataset_id)

            # TODO: Add train_config
            column_metadata = {}
            if model_type == "rf_tab_gan":
                train_config, column_metadata = guess_tab_gan_train_config(dataset)
            elif model_type == "rf_time_gan":
                return {
                    "success": False,
                    "message": "Model type 'rf_time_gan' is not yet implemented. Currently only 'rf_tab_gan' is supported.",
                    "dataset_id": dataset_id,
                    "model_type": model_type,
                }
            else:
                return {
                    "success": False,
                    "message": f"Model type '{model_type}' is not supported. Currently only 'rf_tab_gan' is supported.",
                    "dataset_id": dataset_id,
                    "model_type": model_type,
                }

            # Serialize the Config object to a dict using the Rockfish converter
            train_config_dict = rf.converter.unstructure(train_config)
            config_id = f"config_{uuid.uuid4()}"
            self._cache[config_id] = train_config_dict

            # Build the response with column metadata and warnings
            response = {
                "success": True,
                "dataset_id": dataset_id,
                "train_config_id": config_id,
                "train_config": train_config_dict,
            }

            # Add a warning for high cardinality columns
            high_card_cols = column_metadata.get("high_cardinality_columns", [])
            if high_card_cols:
                warning_msg = (
                    f"For now, we are ignoring {len(high_card_cols)} high cardinality columns: {high_card_cols}. "
                    "These columns will be handled in future iterations."
                )
                response["warnings"] = [warning_msg]

            return response

        elif tool_name == "start_training_workflow":
            dataset_id = arguments["dataset_id"]
            train_config_id = arguments["train_config_id"]

            # Check if the config exists in the cache
            if train_config_id not in self._cache:
                return {
                    "success": False,
                    "message": (
                        f"Config ID '{train_config_id}' not found in cache. It may have expired "
                        "or already been used. Please call obtain_train_config again."
                    ),
                    "dataset_id": dataset_id,
                    "train_config_id": train_config_id,
                }

            train_config = self._cache.pop(train_config_id)

            # Detect the model type from the cached config, which is already an
            # unstructured dict.
            # NOTICE: the unstructured config (dict) uses the key "tabular-gan"
            # rather than "tabular_gan".
            if "tabular-gan" in train_config:
                train_config = rf.converter.structure(
                    train_config, ra.TrainTabGAN.Config
                )
                train_action = ra.TrainTabGAN(train_config)
            elif "doppelganger" in train_config:
                train_config = rf.converter.structure(
                    train_config, ra.TrainTimeGAN.Config
                )
                train_action = ra.TrainTimeGAN(train_config)
            else:
                return {
                    "success": False,
                    "message": (
                        "Unsupported training config format. Currently only RF-Tab-GAN "
                        "(tabular-gan) and RF-Time-GAN (doppelganger) models are supported."
                    ),
                    "dataset_id": dataset_id,
                    "train_config_id": train_config_id,
                }

            load_action = ra.DatasetLoad(dataset_id=dataset_id)
            builder = create_workflow([load_action, train_action])
            train_workflow = await builder.start(self._conn)
            return {"success": True, "train_workflow_id": train_workflow.id()}

        elif tool_name == "get_workflow_logs":
            workflow_id = arguments["workflow_id"]
            log_level_str = arguments.get("log_level", "INFO")
            collection_timeout = arguments.get("timeout", 10)

            # Map the string to the rf.LogLevel enum
            log_level_map = {
                "DEBUG": rf.LogLevel.DEBUG,
                "INFO": rf.LogLevel.INFO,
                "WARN": rf.LogLevel.WARN,
                "ERROR": rf.LogLevel.ERROR,
            }
            log_level = log_level_map.get(log_level_str, rf.LogLevel.INFO)

            # Get the workflow and stream its logs
            workflow = await self._conn.get_workflow(workflow_id)
            logs = []

            async def collect_logs():
                nonlocal logs
                async for log in workflow.logs(level=log_level):
                    logs.append(str(log))

            try:
                # Collect logs for the specified duration
                await asyncio.wait_for(collect_logs(), timeout=collection_timeout)
            except asyncio.TimeoutError:
                # Expected - we collected logs for the specified duration
                pass

            # Build the response with helpful messages
            result = {
                "workflow_id": workflow_id,
                "logs": logs,
                "count": len(logs),
                "log_level": log_level_str,
            }
            if len(logs) == 0:
                result["message"] = (
                    f"No {log_level_str} logs collected in {collection_timeout}s. "
                    "Workflow may still be starting. Try waiting longer or increase "
                    "the timeout parameter."
                )
            else:
                result["message"] = (
                    f"Collected {len(logs)} {log_level_str} logs in {collection_timeout}s. "
                    "Call again to get more logs if the workflow is still running."
                )
            return result

        elif tool_name == "get_trained_model_id":
            workflow_id = arguments["workflow_id"]
            workflow = await self._conn.get_workflow(workflow_id)
            status = await workflow.status()
            if status not in {"completed", "finalized"}:
                return {
                    "success": False,
                    "message": (
                        f"Workflow is in '{status}' state. This tool only works on "
                        "COMPLETED or FINALIZED workflows. Please wait for the "
                        "workflow to complete."
                    ),
                    "workflow_id": workflow_id,
                    "status": status,
                }
            model = await workflow.models().last()
            return {"success": True, "workflow_id": workflow_id, "model_id": model.id}

        elif tool_name == "start_generation_workflow":
            model_id = arguments["model_id"]
            generate_rec = rl.steps.GenerateRecommender(self._conn, model=model_id)
            generate_builder = await generate_rec.builder()
            generate_workflow = await generate_builder.start(self._conn)
            return {"generation_workflow_id": generate_workflow.id()}

        elif tool_name == "obtain_synthetic_dataset_id":
            generation_workflow_id = arguments["generation_workflow_id"]
            generation_workflow = await self._conn.get_workflow(generation_workflow_id)
            status = await generation_workflow.status()
            if status not in {"completed", "finalized"}:
                return {
                    "success": False,
                    "message": (
                        f"Generation workflow is in '{status}' state. This tool only "
                        "works on COMPLETED or FINALIZED workflows. Please wait for "
                        "the workflow to complete."
                    ),
                    "generation_workflow_id": generation_workflow_id,
                    "status": status,
                }
            generated_dataset = await generation_workflow.datasets().last()
            return {
                "success": True,
                "generation_workflow_id": generation_workflow_id,
                "generated_dataset_id": generated_dataset.id,
            }

        elif tool_name == "plot_distribution":
            dataset_ids = arguments["dataset_ids"]
            column_name = arguments["column_name"]
            img_base64 = await plot_distribution(self._conn, dataset_ids, column_name)
            return {
                "image": img_base64,
                "mimeType": "image/png",
                "dataset_ids": dataset_ids,
                "column_name": column_name,
            }

        elif tool_name == "get_marginal_distribution_score":
            dataset_ids = arguments["dataset_ids"]
            dataset = await self._conn.get_dataset(dataset_ids[0])
            dataset = await dataset.to_local(self._conn)
            real_columns = dataset.table.column_names

            synthetic = await self._conn.get_dataset(dataset_ids[1])
            synthetic = await synthetic.to_local(self._conn)
            syn_columns = synthetic.table.column_names

            # Find common and differing columns using set operations
            real_columns_set = set(real_columns)
            syn_columns_set = set(syn_columns)
            common_columns = list(real_columns_set & syn_columns_set)
            only_in_real = real_columns_set - syn_columns_set
            only_in_syn = syn_columns_set - real_columns_set

            # Select only the common columns from both datasets
            dataset.table = dataset.table.select(common_columns)
            synthetic.table = synthetic.table.select(common_columns)

            # Build an informative message about excluded columns
            if only_in_real or only_in_syn:
                msg_parts = []
                if only_in_real:
                    msg_parts.append(
                        f"{', '.join(sorted(only_in_real))} only in real data"
                    )
                if only_in_syn:
                    msg_parts.append(
                        f"{', '.join(sorted(only_in_syn))} only in synthetic data"
                    )
                msg = f"Columns excluded from evaluation: {'; '.join(msg_parts)}"
            else:
                msg = "All columns match between datasets"

            marginal_dist_score = rl.metrics.marginal_dist_score(dataset, synthetic)

            # Check if the score is NaN (happens when datasets have missing values)
            if math.isnan(marginal_dist_score):
                return {
                    "success": False,
                    "message": (
                        "Dataset contains missing values. Marginal distribution score "
                        "does not currently support datasets with missing values."
                    ),
                    "marginal_distribution_score": None,
                }
            else:
                return {
                    "success": True,
                    "message": msg,
                    "marginal_distribution_score": marginal_dist_score,
                }

        # NOTE: this tool has not been tested yet - it is still experimental and
        # requires updates. For now, it only supports rf-tab-gan.
        elif tool_name == "update_train_config":
            train_config_id = arguments["train_config_id"]
            updates = arguments["updates"]

            # Retrieve the cached config
            if train_config_id not in self._cache:
                return {
                    "success": False,
                    "message": (
                        f"Config ID '{train_config_id}' not found in cache. It may have "
                        "expired or already been used. Please call obtain_train_config again."
                    ),
                    "train_config_id": train_config_id,
                }

            # Deep copy so a failed or partial update cannot mutate the cached config
            config_dict = copy.deepcopy(self._cache[train_config_id])

            # Detect the model type
            if "tabular-gan" in config_dict:
                model_key = "tabular-gan"
            elif "doppelganger" in config_dict:
                model_key = "doppelganger"
            else:
                return {
                    "success": False,
                    "message": (
                        "Cannot determine model type from config. Supported model "
                        "types: 'tabular-gan', 'doppelganger'"
                    ),
                    "train_config_id": train_config_id,
                }

            changes_applied = {}

            # Update model_config (hyperparameters)
            if "model_config" in updates:
                model_config = updates["model_config"]
                for field, value in model_config.items():
                    if field not in config_dict[model_key]:
                        return {
                            "success": False,
                            "message": (
                                f"Field '{field}' not found in {model_key} config. "
                                f"Available fields: {list(config_dict[model_key].keys())}"
                            ),
                            "train_config_id": train_config_id,
                            "invalid_field": field,
                        }
                    old_value = config_dict[model_key][field]
                    config_dict[model_key][field] = value
                    changes_applied[f"{model_key}.{field}"] = {
                        "old": old_value,
                        "new": value,
                    }

            # Update encoder_config (field classifications)
            if "encoder_config" in updates:
                encoder_config = updates["encoder_config"]

                # Update metadata fields
                if "metadata" in encoder_config:
                    metadata_updates = encoder_config["metadata"]
                    metadata = config_dict["encoder"]["metadata"]
                    for field_name, new_type in metadata_updates.items():
                        # Validate the type
                        valid_types = ["categorical", "continuous", "ignore"]
                        if model_key == "doppelganger":
                            valid_types.append("session")
                        if new_type not in valid_types:
                            return {
                                "success": False,
                                "message": (
                                    f"Invalid type '{new_type}' for field "
                                    f"'{field_name}'. Valid types: {valid_types}"
                                ),
                                "train_config_id": train_config_id,
                                "field_name": field_name,
                                "invalid_type": new_type,
                            }

                        # Find and update the field in the metadata list
                        field_found = False
                        for field_config in metadata:
                            if field_config["field"] == field_name:
                                old_type = field_config["type"]
                                field_config["type"] = new_type
                                changes_applied[f"encoder.metadata.{field_name}"] = {
                                    "old": old_type,
                                    "new": new_type,
                                }
                                field_found = True
                                break
                        if not field_found:
                            available_fields = [f["field"] for f in metadata]
                            return {
                                "success": False,
                                "message": (
                                    f"Field '{field_name}' not found in encoder "
                                    f"metadata. Available fields: {available_fields}"
                                ),
                                "train_config_id": train_config_id,
                                "field_name": field_name,
                            }

            # Update the cache with the modified config
            self._cache[train_config_id] = config_dict
            return {
                "success": True,
                "train_config_id": train_config_id,
                "changes_applied": changes_applied,
                "train_config": config_dict,
            }

        else:
            return {
                "success": False,
                "message": (
                    f"Unknown SDK tool: '{tool_name}'. This tool is not recognized "
                    "by the SDK client."
                ),
                "tool_name": tool_name,
            }


# Helper functions

# Workflow helpers
def create_workflow(actions: list[rf.Action]) -> rf.WorkflowBuilder:
    """Create a workflow builder with a linear path of actions."""
    builder = rf.WorkflowBuilder()
    builder.add_path(*actions)
    return builder


async def get_local_dataset(conn, dataset_id: str) -> rf.dataset.LocalDataset:
    """Fetch a dataset and convert it to a local dataset."""
    dataset = await conn.get_dataset(dataset_id)
    dataset = await dataset.to_local(conn)
    return dataset


# Visualization helpers
async def plot_distribution(conn, dataset_ids: list, column_name: str):
    """Plot a distribution comparison between real and synthetic data for a given column."""

    def _fig_to_base64(fig):
        """Convert a figure (plot) to a base64 string."""
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
        img_str = base64.b64encode(buf.getbuffer()).decode("utf-8")
        buf.close()
        plt.close(fig.fig)  # Close the underlying matplotlib figure to free memory
        return img_str

    # Load both datasets and convert them to LocalDataset
    dataset = await conn.get_dataset(dataset_ids[0])
    dataset = await dataset.to_local(conn)
    synthetic = await conn.get_dataset(dataset_ids[1])
    synthetic = await synthetic.to_local(conn)

    table = dataset.table
    field_type = table[column_name].type

    # Choose the plot type based on data characteristics
    if pa.types.is_string(field_type):
        # Categorical/string data -> bar plot
        fig = rf.labs.vis.plot_bar([dataset, synthetic], column_name)
    else:
        # Numerical data -> KDE plot
        fig = rf.labs.vis.plot_kde([dataset, synthetic], column_name)

    img_base64 = _fig_to_base64(fig)
    return img_base64


# Configuration helpers
def guess_tab_gan_train_config(dataset) -> Tuple[ra.TrainTabGAN.Config, dict]:
    """Generate a TabGAN training configuration with automatic column type detection."""
    table = dataset.table
    columns = table.column_names

    high_cardinality_columns = []
    categorical_columns = []
    continuous_columns = []

    for column in columns:
        dtype = str(table[column].type)
        # mode="only_valid" excludes null values
        nunique = pc.count_distinct(table[column], mode="only_valid").as_py()
        if dtype in {"string", "bool"}:
            if nunique <= 100:
                categorical_columns.append(column)
            else:
                # Cardinality > 100 is likely to cause OOM, so for now we ignore
                # these columns in the train config. Later, we could use resampling
                # or a label encoder to handle them.
                high_cardinality_columns.append(column)
        elif nunique <= 10:
            categorical_columns.append(column)
        else:
            continuous_columns.append(column)

    encoder_config = ra.TrainTabGAN.DatasetConfig(
        metadata=[
            ra.TrainTabGAN.FieldConfig(field=col, type="categorical")
            for col in categorical_columns
        ]
        + [
            ra.TrainTabGAN.FieldConfig(field=col, type="ignore")
            for col in high_cardinality_columns
        ]
        + [
            ra.TrainTabGAN.FieldConfig(field=col, type="continuous")
            for col in continuous_columns
        ]
    )
    model_config = ra.TrainTabGAN.TrainConfig(epochs=100)
    train_config = ra.TrainTabGAN.Config(
        encoder=encoder_config, tabular_gan=model_config
    )

    # Return the config and column metadata
    column_metadata = {
        "categorical_columns": categorical_columns,
        "continuous_columns": continuous_columns,
        "high_cardinality_columns": high_cardinality_columns,
    }
    return train_config, column_metadata
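For context, here is a minimal usage sketch; it is not part of sdk_client.py. It drives the call_endpoint tool router end to end: guess a train config for a dataset, then start a training workflow with the cached config. The API key, API URL, and dataset ID are placeholders, and error handling is reduced to a bare success check; treat this as an illustration of the calling convention rather than an official example.

# usage_sketch.py - hypothetical driver for RockfishSDKClient
# (the credential, endpoint, and dataset ID below are all placeholders)
import asyncio

from sdk_client import RockfishSDKClient


async def main():
    client = RockfishSDKClient(
        API_KEY="rf-xxxx",                  # placeholder credential
        API_URL="https://api.example.com",  # placeholder endpoint
    )
    try:
        # Step 1: auto-detect column types and build a TabGAN train config
        config_resp = await client.call_endpoint(
            "obtain_train_config",
            {"dataset_id": "dataset-123", "model_type": "rf_tab_gan"},
        )
        assert config_resp["success"], config_resp

        # Step 2: start training; the config is looked up (and consumed)
        # from the client's in-memory cache by its train_config_id
        train_resp = await client.call_endpoint(
            "start_training_workflow",
            {
                "dataset_id": "dataset-123",
                "train_config_id": config_resp["train_config_id"],
            },
        )
        print(train_resp)
    finally:
        await client.close()


asyncio.run(main())

The remaining tools (get_workflow_logs, get_trained_model_id, start_generation_workflow, obtain_synthetic_dataset_id, plot_distribution, get_marginal_distribution_score, update_train_config) follow the same pattern: a single call_endpoint invocation with a tool name and an arguments dict.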
