Skip to main content
Glama

AgentMode

Official
by agentmode
database.py — 14.6 kB
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text, create_engine
import pandas as pd

from agentmode.logs import logger

# Registry of DatabaseConnection subclasses, keyed by each subclass's `platform`
# class attribute. Populated automatically via __init_subclass__.
subclasses = {}


@dataclass
class DatabaseConnection:
    """
    Base class for database connections exposed as MCP tools.

    Subclasses declare three class attributes:
      - platform:         registry key used by `create()` (e.g. 'postgresql')
      - driver_type:      'async' or 'sync' (sync drivers run on a thread pool)
      - driver_interface: 'sqlalchemy' or 'dbapi'
    """
    settings: dict
    # Populated by create() for sync drivers; None for async drivers.
    threadpool: ThreadPoolExecutor = None
    # Mutable default requires default_factory, never a shared literal.
    create_engine_kwargs: dict = field(default_factory=dict)

    def __init_subclass__(cls, **kwargs):
        # Auto-register every subclass under its platform name.
        super().__init_subclass__(**kwargs)
        subclasses[cls.platform] = cls

    @classmethod
    def create(cls, platform, settings, **kwargs):
        """
        Factory: instantiate the registered subclass for `platform`.

        Raises ValueError for an unknown platform. Sync drivers get a
        ThreadPoolExecutor so blocking calls never stall the event loop.
        """
        if platform not in subclasses:
            raise ValueError('Unknown DatabaseConnection platform {}'.format(platform))
        instance = subclasses[platform](settings=settings, **kwargs)
        if instance.driver_type == 'sync':
            instance.threadpool = ThreadPoolExecutor()
        return instance

    def is_read_only_query(self, query_string: str, method="blocklist") -> bool:
        """
        Determine if the query is a read-only query.

        There are 2 approaches: use a blocklist or an allowlist of query types.
        The blocklist is easier to maintain for all database types, but the
        allowlist is more secure. Uses the sqlglot library for fast parsing.
        Returns False on any parse error (fail closed).
        """
        from sqlglot import parse, parse_one, exp
        try:
            if method == "blocklist":
                expression_tree = parse_one(query_string)
                # Expression types that indicate a write operation.
                write_operations = (
                    exp.Insert,
                    exp.Update,
                    exp.Delete,
                    exp.Create,
                    exp.Drop,
                    exp.Alter,
                    exp.Truncate,
                )
                # Check if the root expression is a write operation.
                if isinstance(expression_tree, write_operations):
                    return False
                # Traverse the AST and check for any write operation expressions.
                # NOTE(review): the 3-tuple unpacking matches older sqlglot
                # walk() signatures; newer releases yield bare nodes — confirm
                # against the pinned sqlglot version.
                for node, _, _ in expression_tree.walk():
                    if isinstance(node, write_operations):
                        return False
                # No write operations found: the query is read-only.
                return True
            elif method == "allowlist":
                # Only allows SELECT, ANALYZE, VACUUM, EXPLAIN SELECT, and SHOW queries.
                # NOTE(review): indexing parsed expressions with q['type'] does
                # not match current sqlglot Expression API — verify before
                # enabling this path (default method is "blocklist").
                query = parse(query_string)
                if not query:
                    return False
                for q in query:
                    if q['type'] == 'SELECT':
                        return True
                    elif q['type'] == 'SHOW':
                        return True
                    elif q['type'] == 'ANALYZE':
                        return True
                    elif q['type'] == 'VACUUM':
                        return True
                    elif q['type'] == 'EXPLAIN':
                        if q['subquery']:
                            return True
        except Exception as e:
            logger.error(f"Error parsing query: {e}")
            return False
        return False

    async def connect(self):
        """
        Initialize the database connection.

        We do this in a separate method (not __init__) so we can return a
        boolean value on success/failure.
        """
        try:
            if self.driver_type == 'async':
                # echo is False, as we don't want to log to stdout (it interferes with MCP on stdio)
                self.engine = create_async_engine(self.database_uri, echo=False, future=True, **self.create_engine_kwargs)
                self.session_factory = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
            elif self.driver_type == 'sync':
                def sync_connect():
                    self.engine = create_engine(self.database_uri, echo=False, future=True, **self.create_engine_kwargs)
                    self.connection = self.engine.connect()
                # Run the blocking connect off the event loop.
                await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_connect)
            return True
        except Exception as e:
            logger.error(f"Failed to create database connection: {e}")
            return False

    async def query(self, query_string: str) -> pd.DataFrame:
        """
        Execute `query_string` and return (success, DataFrame) — (False, None)
        on failure or when a write is attempted against a read-only connection.
        """
        try:
            # read_only defaults to True: writes must be opted into explicitly.
            if self.settings.get('read_only', True):
                if not self.is_read_only_query(query_string):
                    logger.error(f"Query is not read-only, not executing: {query_string}")
                    return False, None
            logger.debug(f"Executing query: {query_string}")
            if self.driver_interface == 'sqlalchemy':
                if self.driver_type == 'async':
                    async with self.session_factory() as session:
                        result = await session.execute(text(query_string))
                        # Convert result to pandas DataFrame
                        df = pd.DataFrame(result.fetchall(), columns=result.keys())
                        return True, df
                elif self.driver_type == 'sync':
                    def sync_query():
                        result = self.connection.execute(text(query_string))
                        return pd.DataFrame(result.fetchall(), columns=result.keys())
                    df = await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_query)
                    return True, df
            elif self.driver_interface == 'dbapi':
                if self.driver_type == 'sync':
                    # Offload the blocking DBAPI calls to the thread pool so the
                    # event loop stays responsive; close the cursor even on error.
                    def sync_query():
                        cur = self.connection.cursor()
                        try:
                            cur.execute(query_string)
                            rows = cur.fetchall()
                            columns = [desc[0] for desc in cur.description]
                            return pd.DataFrame(rows, columns=columns)
                        finally:
                            cur.close()
                    df = await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_query)
                    return True, df
                elif self.driver_type == 'async':
                    cur = await self.connection.cursor()
                    await cur.execute(query_string)
                    rows = await cur.fetchall()
                    columns = [desc[0] for desc in await cur.description()]
                    df = pd.DataFrame(rows, columns=columns)
                    await cur.close()
                    return True, df
        except Exception as e:
            # Log the error; callers only see the (False, None) failure tuple.
            logger.error(f"Query failed: {e}")
            return False, None

    async def disconnect(self):
        """Tear down the engine/connection for whichever driver kind is in use."""
        if self.driver_interface == 'sqlalchemy':
            if self.driver_type == 'async':
                await self.engine.dispose()
            else:
                def sync_disconnect():
                    self.connection.close()
                    self.engine.dispose()
                await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_disconnect)
        elif self.driver_interface == 'dbapi':
            if self.driver_type == 'async':
                await self.connection.close()
            else:
                def sync_disconnect():
                    self.connection.close()
                await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_disconnect)

    async def generate_mcp_resources_and_tools(self, connection_name, mcp, connection_name_counter, connection_mapping):
        """
        Generate MCP resources and tools based on the database connection.

        Registers one `database_query_<name>` tool on `mcp`; duplicate
        connection names get a numeric suffix via `connection_name_counter`.
        """
        try:
            tool_name = f"database_query_{connection_name}"
            # Increment the counter for the connection name; suffix duplicates
            # so each registration gets a unique tool name. (Register exactly
            # once — the previous double-registration was a bug.)
            connection_name_counter[tool_name] += 1
            if connection_name_counter[tool_name] > 1:
                tool_name = f"{tool_name}_{connection_name_counter[tool_name]}"
            connection_mapping[tool_name] = self

            # Define a function dynamically using a closure so the tool name is
            # bound at creation time (avoids the late-binding closure pitfall).
            def create_dynamic_tool(fn_name):
                async def dynamic_tool(query: str) -> str:
                    """Run a database query on the connection."""
                    db = connection_mapping.get(fn_name)
                    if not db:
                        logger.error(f"No database connection found for tool: {fn_name}")
                        return None
                    try:
                        logger.debug(f"Executing query: {query} in dynamic tool")
                        success_flag, result = await db.query(query)
                        if success_flag:
                            # convert the result pandas dataframe to a list of dictionaries
                            result = result.to_dict('records')
                            logger.debug(f"Query result: {result}")
                            # we don't need to convert to JSON string here, as the mcp server will handle it for us
                            return result
                        else:
                            logger.error(f"Query execution failed for {fn_name}")
                            return 'error'
                    except Exception as e:
                        logger.error(f"Error executing query: {e}")
                        return 'error' + str(e)
                return dynamic_tool

            # Create the dynamic tool function with the tool name
            tool_function = create_dynamic_tool(tool_name)
            # Register the function as an MCP tool
            tool_function.__name__ = tool_name
            tool_function.__doc__ = f"Run a query on the {connection_name} database."
            mcp.tool()(tool_function)
        except Exception as e:
            logger.error(f"Error generating MCP resources and tools: {e}")
            return None


@dataclass
class PostgreSQLConnection(DatabaseConnection):
    platform = 'postgresql'
    driver_type = 'async'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"postgresql+asyncpg://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/{self.settings['database_name']}"


@dataclass
class MySQLConnection(DatabaseConnection):
    platform = 'mysql'
    driver_type = 'async'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"mysql+aiomysql://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/{self.settings['database_name']}"


@dataclass
class OracleDBConnection(DatabaseConnection):
    platform = 'oracle'
    driver_type = 'async'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"oracle+oracledb://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/?service_name={self.settings['service_name']}"


@dataclass
class SQLServerConnection(DatabaseConnection):
    platform = 'sqlserver'
    driver_type = 'async'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"mssql+aioodbc://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/{self.settings['database_name']}?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"


@dataclass
class HiveConnection(DatabaseConnection):
    platform = 'hive'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"hive+https://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/"


@dataclass
class PrestoConnection(DatabaseConnection):
    """
    PyHive via SQLAlchemy also supports Presto, but it is no longer maintained
    so we use https://github.com/prestodb/presto-python-client via DBAPI
    """
    platform = 'presto'
    driver_type = 'sync'
    driver_interface = 'dbapi'

    async def connect(self):
        """
        Initialize the database connection.

        We do this in a separate method so we can return a boolean value on
        success/failure.
        """
        try:
            def sync_connect():
                import prestodb
                credentials = {
                    'host': self.settings['host'],
                    'port': self.settings['port'],
                    'user': self.settings['username'],
                    'catalog': self.settings.get('catalog', 'hive'),
                    'schema': self.settings.get('schema', 'default'),
                }
                if self.settings.get('username') and self.settings.get('password'):
                    auth = prestodb.auth.BasicAuthentication(self.settings['username'], self.settings['password'])
                    credentials['auth'] = auth
                self.connection = prestodb.dbapi.connect(**credentials)
            await asyncio.get_running_loop().run_in_executor(self.threadpool, sync_connect)
            return True
        except Exception as e:
            logger.error(f"Failed to create database connection: {e}")
            return False


@dataclass
class Trino(DatabaseConnection):
    """Trino connection via the async aiotrino DBAPI client."""
    platform = 'trino'
    driver_type = 'async'
    driver_interface = 'dbapi'

    async def connect(self):
        try:
            import aiotrino
            credentials = {
                'host': self.settings['host'],
                'port': self.settings['port'],
                'user': self.settings['username'],
                'catalog': self.settings.get('catalog', 'hive'),
                'schema': self.settings.get('schema', 'default'),
            }
            if self.settings.get('username') and self.settings.get('password'):
                auth = aiotrino.auth.BasicAuthentication(self.settings['username'], self.settings['password'])
                credentials['auth'] = auth
            self.connection = aiotrino.dbapi.connect(**credentials)
            return True
        except Exception as e:
            logger.error(f"Failed to create database connection: {e}")
            return False


@dataclass
class Snowflake(DatabaseConnection):
    platform = 'snowflake'
    driver_type = 'sync'  # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/218
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        from snowflake.sqlalchemy import URL
        credentials = {key: self.settings.get(key) for key in ['account', 'user', 'password', 'database', 'schema', 'warehouse', 'role', 'timezone'] if key in self.settings}
        # `credentials` is a local, not an attribute (self.credentials was a bug).
        self.database_uri = URL(**credentials)


@dataclass
class BigQuery(DatabaseConnection):
    platform = 'bigquery'
    driver_type = 'sync'  # https://github.com/googleapis/python-bigquery-sqlalchemy/issues/1071
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"bigquery://{self.settings['project_name']}"
        self.create_engine_kwargs = {
            'credentials_path': self.settings.get('credentials_path', None),
        }


@dataclass
class Clickhouse(DatabaseConnection):
    platform = 'clickhouse'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"clickhouse+{self.settings['protocol']}://{self.settings['user']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}/{self.settings['database']}"


@dataclass
class Databricks(DatabaseConnection):
    platform = 'databricks'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"databricks://token:{self.settings['access_token']}@{self.settings['host']}?http_path={self.settings['http_path']}&catalog={self.settings['catalog']}&schema={self.settings['schema']}"


@dataclass
class SAPHana(DatabaseConnection):
    platform = 'sap_hana'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"hana://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}:{self.settings['port']}"
        if self.settings.get('database_name'):
            # Use outer double quotes: nesting the same quote inside an f-string
            # is a SyntaxError on Python < 3.12.
            self.database_uri += f"/{self.settings['database_name']}"


@dataclass
class Teradata(DatabaseConnection):
    platform = 'teradata'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"teradatasql://{self.settings['username']}:{self.settings['password']}@{self.settings['host']}"


@dataclass
class CockroachDB(DatabaseConnection):
    platform = 'cockroachdb'
    driver_type = 'async'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"cockroachdb+asyncpg://{self.settings['username']}@{self.settings['host']}:{self.settings['port']}/defaultdb"


@dataclass
class AWSAthena(DatabaseConnection):
    platform = 'aws_athena'
    driver_type = 'sync'
    driver_interface = 'sqlalchemy'

    def __post_init__(self):
        self.database_uri = f"awsathena+rest://{self.settings['aws_access_key_id']}:{self.settings['aws_secret_access_key']}@athena.{self.settings['region_name']}.amazonaws.com:443/{self.settings['schema_name']}?s3_staging_dir={self.settings['s3_staging_dir']}"

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/agentmode/server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.