from typing import Dict, Any, Optional
import pymssql
from contextlib import contextmanager
from ..interfaces.database_adapter import DatabaseAdapter
class MSSQLAdapter(DatabaseAdapter):
    """DatabaseAdapter implementation for Microsoft SQL Server, backed by pymssql.

    Every query runs on a freshly opened connection (no pooling); see
    ``get_connection``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read connection settings from ``config``.

        Args:
            config: Mapping with optional keys ``server`` (default
                ``"localhost"``), ``port`` (default 1433), ``user`` (default
                ``"sa"``), ``password`` (default ``""``) and ``database``
                (default ``"master"``).
        """
        self.server = config.get("server", "localhost")
        # `or` instead of a .get() default so an explicit None / "" also
        # falls back to 1433 rather than failing in int().
        self.port = int(config.get("port") or 1433)
        self.user = config.get("user", "sa")
        self.password = config.get("password", "")
        self.database = config.get("database", "master")

    @contextmanager
    def get_connection(self):
        """Yield a new pymssql connection and always close it afterwards.

        pymssql has no built-in pool like mysql-connector, so for simplicity
        this v1 opens a fresh connection per call. Rows come back as tuples
        (``as_dict=False``).
        """
        conn = pymssql.connect(
            server=self.server,
            port=self.port,
            user=self.user,
            password=self.password,
            database=self.database,
            charset='UTF-8',
            as_dict=False,
        )
        try:
            yield conn
        finally:
            conn.close()

    def execute_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run ``query`` on a fresh connection, commit, and summarize the result.

        Args:
            query: SQL text. When ``params`` is non-empty, placeholders use
                pymssql's ``%(name)s`` style.
            params: Optional mapping of placeholder values.

        Returns:
            A dict with ``columns`` (column names), ``rows`` (each row as a
            list), ``row_count`` (number of fetched rows) and
            ``affected_rows`` (``cursor.rowcount`` when the statement returns
            no result set, otherwise 0).
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(query, params or {})
                if cursor.description:
                    columns = [col[0] for col in cursor.description]
                    rows = cursor.fetchall()
                    result = {
                        "columns": columns,
                        "rows": [list(row) for row in rows],
                        "row_count": len(rows),
                        "affected_rows": 0,
                    }
                else:
                    result = {
                        "columns": [],
                        "rows": [],
                        "row_count": 0,
                        "affected_rows": cursor.rowcount,
                    }
                # Commit in both branches. A statement that returns rows can
                # still modify data (e.g. INSERT ... OUTPUT, or a stored
                # procedure). The connection is not autocommit, so skipping the
                # commit would let get_connection's close() discard those changes.
                conn.commit()
                return result
            finally:
                cursor.close()

    def list_tables(self) -> list[str]:
        """Return every base table as ``"schema.table"``, sorted by schema then name."""
        query = (
            "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_TYPE = 'BASE TABLE' "
            "ORDER BY TABLE_SCHEMA, TABLE_NAME"
        )
        result = self.execute_query(query)
        return [f"{schema}.{name}" for schema, name in result["rows"]]

    def describe_table(self, table_name: str) -> Dict[str, Any]:
        """Return the column names and data types of ``table_name``.

        Args:
            table_name: Either ``"schema.table"`` (split on the first dot) or a
                bare table name. A bare name is matched in every schema, so
                same-named tables in different schemas are all included.

        Returns:
            ``{"columns": [[COLUMN_NAME, DATA_TYPE], ...]}`` ordered by schema
            and then by column position.
        """
        # Values are passed as parameters, not interpolated, so caller-supplied
        # names cannot inject SQL and names containing quotes still work.
        query = (
            "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS "
            "WHERE TABLE_NAME = %(name)s"
        )
        params: Dict[str, Any] = {}
        if "." in table_name:
            schema, name = table_name.split(".", 1)
            query += " AND TABLE_SCHEMA = %(schema)s"
            params["schema"] = schema
        else:
            name = table_name
        params["name"] = name
        # Without ORDER BY, INFORMATION_SCHEMA row order is unspecified.
        query += " ORDER BY TABLE_SCHEMA, ORDINAL_POSITION"
        result = self.execute_query(query, params)
        return {"columns": result["rows"]}

    def get_version(self) -> str:
        """Return the server's ``@@VERSION`` string."""
        res = self.execute_query("SELECT @@VERSION")
        return res["rows"][0][0]