from typing import Dict, Any, Optional
import oracledb
from contextlib import contextmanager
from ..interfaces.database_adapter import DatabaseAdapter
class OracleAdapter(DatabaseAdapter):
    """DatabaseAdapter implementation for Oracle Database using python-oracledb.

    Every operation opens its own connection via ``get_connection`` and closes
    it when the operation finishes; no connection is kept on the instance.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store connection settings taken from ``config``.

        Args:
            config: Connection settings. Recognized keys: ``user`` (default
                ``"system"``), ``password`` (default ``""``) and ``dsn``.
                When ``dsn`` is missing or empty, a DSN of the form
                ``"{host}:{port}/{database}"`` is built from ``host``
                (default ``"localhost"``), ``port`` (default ``1521``) and
                ``database`` (default ``"ORCL"``).
        """
        self.user = config.get("user", "system")
        self.password = config.get("password", "")
        # Data Source Name, e.g. "localhost:1521/XEPDB1".
        self.dsn = config.get("dsn")
        # Fall back to building the DSN from its parts when none is given directly.
        if not self.dsn:
            host = config.get("host", "localhost")
            port = config.get("port", 1521)
            # The generic "database" setting is used as the Oracle service name.
            service_name = config.get("database", "ORCL")
            self.dsn = f"{host}:{port}/{service_name}"
        # No Thick-mode client initialisation is done here, so the driver stays in
        # its default Thin mode. Thick mode might be needed for some legacy setups.

    @contextmanager
    def get_connection(self):
        """Yield a new connection built from the stored settings.

        The connection is closed when the ``with`` block exits, whether it
        exits normally or through an exception.
        """
        conn = oracledb.connect(
            user=self.user,
            password=self.password,
            dsn=self.dsn,
        )
        try:
            yield conn
        finally:
            conn.close()

    @staticmethod
    def _to_plain_value(value: Any) -> Any:
        """Read a LOB locator into ``str``/``bytes``; return any other value unchanged."""
        if isinstance(value, oracledb.LOB):
            return value.read()
        return value

    def execute_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Execute one SQL statement on a fresh connection.

        Args:
            query: SQL text. Named bind placeholders use Oracle ``:name``
                syntax, matched against the keys of ``params``.
            params: Bind values; ``None`` means no binds.

        Returns:
            A dict with keys ``columns``, ``rows``, ``row_count`` and
            ``affected_rows``. If the statement produced a result set
            (``cursor.description`` is set), ``rows`` contains every fetched
            row as a list and ``affected_rows`` is 0. Otherwise the
            transaction is committed, ``columns``/``rows`` are empty and
            ``affected_rows`` is the cursor's ``rowcount``.
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params or {})
                if cursor.description:
                    columns = [col[0] for col in cursor.description]
                    # LOB values are locators tied to this connection, which
                    # get_connection() closes on return, so read them now.
                    rows = [
                        [self._to_plain_value(value) for value in row]
                        for row in cursor.fetchall()
                    ]
                    return {
                        "columns": columns,
                        "rows": rows,
                        "row_count": len(rows),
                        "affected_rows": 0,
                    }
                # No result set: DML/DDL/PL-SQL. Persist it before the connection closes.
                conn.commit()
                return {
                    "columns": [],
                    "rows": [],
                    "row_count": 0,
                    "affected_rows": cursor.rowcount,
                }

    def list_tables(self) -> list[str]:
        """Return the tables listed in ALL_TABLES as ``"OWNER.TABLE_NAME"`` strings.

        Results are ordered by owner, then table name.
        """
        # ALL_TABLES (not USER_TABLES) also shows tables in other schemas
        # that the current user can access.
        query = "SELECT OWNER, TABLE_NAME FROM ALL_TABLES ORDER BY OWNER, TABLE_NAME"
        result = self.execute_query(query)
        return [f"{owner}.{name}" for owner, name in result["rows"]]

    def describe_table(self, table_name: str) -> Dict[str, Any]:
        """Describe the columns of a table.

        Args:
            table_name: Either ``"OWNER.TABLE"`` (looked up in ALL_TAB_COLUMNS)
                or a bare table name (looked up in USER_TAB_COLUMNS). Both parts
                are upper-cased before matching, so quoted mixed-case
                identifiers will not be found.

        Returns:
            ``{"columns": rows}``, where each row is
            ``[COLUMN_NAME, DATA_TYPE, DATA_LENGTH, NULLABLE]``, in COLUMN_ID order.
        """
        # Names are passed as bind variables rather than interpolated, so a
        # caller-supplied name cannot alter the SQL text.
        if "." in table_name:
            owner, name = table_name.split(".", 1)
            query = """
                SELECT COLUMN_NAME, DATA_TYPE, DATA_LENGTH, NULLABLE
                FROM ALL_TAB_COLUMNS
                WHERE TABLE_NAME = :tab_name AND OWNER = :tab_owner
                ORDER BY COLUMN_ID
            """
            params = {"tab_name": name.upper(), "tab_owner": owner.upper()}
        else:
            query = """
                SELECT COLUMN_NAME, DATA_TYPE, DATA_LENGTH, NULLABLE
                FROM USER_TAB_COLUMNS
                WHERE TABLE_NAME = :tab_name
                ORDER BY COLUMN_ID
            """
            params = {"tab_name": table_name.upper()}
        result = self.execute_query(query, params)
        return {"columns": result["rows"]}

    def get_version(self) -> str:
        """Return the server version string reported by a fresh connection."""
        with self.get_connection() as conn:
            return conn.version