from contextlib import contextmanager
from typing import Any, Dict, Optional, Sequence, Union

import pyodbc

from ..interfaces.database_adapter import DatabaseAdapter
class SybaseAdapter(DatabaseAdapter):
    """Database adapter that talks to Sybase (SQL Anywhere style) through pyodbc.

    The adapter opens a fresh ODBC connection for every query and closes it
    afterwards; no connection is kept between calls.
    """

    def __init__(self, config: Dict[str, Any]):
        """Build the ODBC connection string from *config*.

        Recognised keys:
            dsn: ODBC data source name. When set, the connection string is
                ``DSN=<dsn>`` plus optional ``UID``/``PWD`` taken from
                ``user``/``password``.
            host, port, server_name, database, user, password, driver: used
                only when ``dsn`` is absent, to assemble a driver-based
                connection string. Defaults are ``localhost``, ``2638``,
                ``dba``, ``sql`` and ``SQL Anywhere 17``.
        """
        self.dsn = config.get("dsn")
        self.connection_string = self._build_connection_string(config)

    @staticmethod
    def _odbc_value(value: Any) -> str:
        """Return *value* as text that is safe inside an ODBC connection string.

        Values containing ``;``, ``{`` or ``}`` (or leading/trailing
        whitespace) would otherwise terminate the attribute early or be
        misparsed, so they are wrapped in braces with ``}`` doubled. Plain
        values are returned unchanged.
        """
        text = str(value)
        needs_quoting = text != text.strip() or any(ch in text for ch in ";{}")
        if needs_quoting:
            return "{" + text.replace("}", "}}") + "}"
        return text

    @classmethod
    def _build_connection_string(cls, config: Dict[str, Any]) -> str:
        """Assemble the ODBC connection string described in ``__init__``."""
        quote = cls._odbc_value
        dsn = config.get("dsn")

        if dsn:
            parts = [f"DSN={quote(dsn)}"]
            # Only append credentials that actually carry a value; a key that
            # is present but None must not become the literal text "None".
            user = config.get("user")
            password = config.get("password")
            if user is not None:
                parts.append(f"UID={quote(user)}")
            if password is not None:
                parts.append(f"PWD={quote(password)}")
            return ";".join(parts)

        # No DSN: build a driver-based string (SQL Anywhere / SAP IQ style).
        host = config.get("host", "localhost")
        port = config.get("port", 2638)
        server_name = config.get("server_name")
        db_name = config.get("database")
        # NOTE(review): these fall back to well-known default credentials when
        # the caller supplies none; kept for backward compatibility.
        uid = config.get("user", "dba")
        pwd = config.get("password", "sql")
        driver = config.get("driver", "SQL Anywhere 17")

        parts = [f"Driver={{{driver}}}"]
        if uid:
            parts.append(f"UID={quote(uid)}")
        if pwd:
            parts.append(f"PWD={quote(pwd)}")
        if host:
            parts.append(f"Host={quote(f'{host}:{port}')}")
        if server_name:
            parts.append(f"ServerName={quote(server_name)}")
        if db_name:
            parts.append(f"DatabaseName={quote(db_name)}")
        parts.append("Charset=UTF-8")
        return ";".join(parts)

    @contextmanager
    def get_connection(self):
        """Yield a new pyodbc connection and close it when the block exits."""
        conn = pyodbc.connect(self.connection_string)
        try:
            yield conn
        finally:
            conn.close()

    def execute_query(
        self,
        query: str,
        params: Optional[Union[Dict[str, Any], Sequence[Any]]] = None,
    ) -> Dict[str, Any]:
        """Run *query* and return its result as a plain dictionary.

        Args:
            query: SQL text. Parameter markers must be positional ``?``.
            params: Values for the markers. A list or tuple is passed through
                in order. A dict is reduced to its values in insertion order
                (the keys are not matched against the SQL), so the dict must
                be ordered like the markers. Empty or ``None`` means no
                parameters; any other type is ignored.

        Returns:
            A dict with ``columns``, ``rows``, ``row_count`` and
            ``affected_rows``. For statements that produce a result set,
            ``rows`` holds each row as a list and ``affected_rows`` is 0.
            Otherwise the transaction is committed and ``affected_rows`` is
            the cursor's ``rowcount``.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                if params and isinstance(params, dict):
                    cursor.execute(query, list(params.values()))
                elif params and isinstance(params, (list, tuple)):
                    # Tuples are accepted as well as lists so that they are
                    # not silently dropped.
                    cursor.execute(query, list(params))
                else:
                    cursor.execute(query)

                if cursor.description:
                    columns = [col[0] for col in cursor.description]
                    rows = cursor.fetchall()
                    # No commit on this path: statements returning a result
                    # set are treated as reads.
                    return {
                        "columns": columns,
                        "rows": [list(row) for row in rows],
                        "row_count": len(rows),
                        "affected_rows": 0,
                    }

                conn.commit()
                return {
                    "columns": [],
                    "rows": [],
                    "row_count": 0,
                    "affected_rows": cursor.rowcount,
                }
            finally:
                cursor.close()

    def list_tables(self) -> list[str]:
        """Return all base tables as ``owner.table`` strings.

        The SQL Anywhere system catalog is queried first; if that fails for
        any reason, ``INFORMATION_SCHEMA.TABLES`` is tried instead. An error
        from the fallback query propagates to the caller.
        """
        catalog_query = """
            SELECT u.user_name, t.table_name
            FROM sys.systable t
            JOIN sys.sysuser u ON t.creator = u.user_id
            WHERE t.table_type = 'BASE'
            ORDER BY u.user_name, t.table_name
        """
        try:
            result = self.execute_query(catalog_query)
        except Exception:
            # Best-effort fallback. Exception (not a bare except) so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            fallback_query = (
                "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
                "WHERE TABLE_TYPE = 'BASE TABLE'"
            )
            result = self.execute_query(fallback_query)
        return [f"{row[0]}.{row[1]}" for row in result["rows"]]

    def describe_table(self, table_name: str) -> Dict[str, Any]:
        """Return the column metadata of *table_name*.

        Args:
            table_name: Either ``owner.table`` (as produced by
                ``list_tables``) or a bare table name. With a bare name the
                lookup is not restricted to one owner.

        Returns:
            ``{"columns": rows}`` where each row is
            ``[column_name, domain_name, width, nulls]``.
        """
        # Names are bound as parameters rather than interpolated into the SQL,
        # so quotes in a name can neither break nor alter the statement.
        if "." in table_name:
            owner, name = table_name.split(".", 1)
            query = """
                SELECT c.column_name, d.domain_name, c.width, c.nulls
                FROM sys.syscolumn c
                JOIN sys.systable t ON c.table_id = t.table_id
                JOIN sys.sysuser u ON t.creator = u.user_id
                LEFT JOIN sys.sysdomain d ON c.domain_id = d.domain_id
                WHERE t.table_name = ? AND u.user_name = ?
            """
            params = [name, owner]
        else:
            query = """
                SELECT c.column_name, d.domain_name, c.width, c.nulls
                FROM sys.syscolumn c
                JOIN sys.systable t ON c.table_id = t.table_id
                LEFT JOIN sys.sysdomain d ON c.domain_id = d.domain_id
                WHERE t.table_name = ?
            """
            params = [table_name]

        result = self.execute_query(query, params)
        return {"columns": result["rows"]}

    def get_version(self) -> str:
        """Return the server's ``@@VERSION`` string."""
        result = self.execute_query("SELECT @@VERSION")
        return result["rows"][0][0]