#!/usr/bin/env python3
"""
PostgreSQL Database Functions
Function-based operations for Neon PostgreSQL database
"""
import psycopg2
from psycopg2.extras import RealDictCursor
import json
from urllib.parse import urlparse
def parse_db_url(db_url):
    """Parse a PostgreSQL connection URL into ``psycopg2.connect`` keyword arguments.

    Accepts URLs of the form
    ``postgresql://user:password@host:port/database?sslmode=...``.
    Percent-encoded characters in the user name, password and database name
    are decoded, so credentials containing reserved characters such as ``@``
    or ``/`` (written as ``%40`` / ``%2F`` inside a URL) reach the server
    intact.

    Only the ``sslmode`` query parameter is forwarded; any other query
    parameters (e.g. ``channel_binding``) are ignored. When ``sslmode`` is
    absent or empty it defaults to ``'require'``.

    Args:
        db_url: The connection URL.

    Returns:
        dict with ``host``, ``port`` (5432 when not given), ``database``,
        ``user``, ``password`` and ``sslmode`` keys. ``host``, ``user`` and
        ``password`` are ``None`` when missing from the URL.

    Raises:
        ValueError: if the port component is not a valid port number
            (raised by ``urllib.parse`` when ``.port`` is read).
    """
    from urllib.parse import parse_qsl, unquote

    # urlparse already separates the query string from the path, so no
    # manual '?' / '&' splitting is needed.
    parsed = urlparse(db_url)
    # parse_qsl drops blank values by default, so "sslmode=" falls back to
    # the default below; for repeated keys the last occurrence wins.
    query_params = dict(parse_qsl(parsed.query))

    def _decoded(value):
        # urlparse leaves username/password percent-encoded; None means absent.
        return unquote(value) if value is not None else None

    return {
        'host': parsed.hostname,
        'port': parsed.port or 5432,
        'database': unquote(parsed.path.lstrip('/')),
        'user': _decoded(parsed.username),
        'password': _decoded(parsed.password),
        'sslmode': query_params.get('sslmode', 'require'),
    }
def get_connection(db_url):
    """Open a new PostgreSQL connection for *db_url*.

    The URL is converted with ``parse_db_url`` and the resulting keyword
    arguments are passed to ``psycopg2.connect``. The caller owns the
    returned connection and must close it.

    Args:
        db_url: PostgreSQL connection URL.

    Returns:
        An open psycopg2 connection.

    Raises:
        Exception: ``"Connection error: ..."`` wrapping any failure while
            parsing the URL or connecting. The original exception is chained
            as ``__cause__`` so its type and traceback are preserved.
    """
    try:
        conn_params = parse_db_url(db_url)
        return psycopg2.connect(**conn_params)
    except Exception as e:
        # Keep the generic exception type existing callers catch, but chain
        # the cause so psycopg2 / ValueError details are not discarded.
        raise Exception(f"Connection error: {e}") from e
def execute_query(db_url, query, params=None):
    """Execute a row-returning statement (typically SELECT) and return its rows.

    A fresh connection is opened for the call and is always closed before
    returning.

    Args:
        db_url: PostgreSQL connection URL.
        query: SQL text, optionally containing psycopg2 placeholders
            (``%s`` / ``%(name)s``).
        params: Sequence or mapping of placeholder values, or ``None`` when
            the query has no placeholders. With ``None`` the SQL is sent
            verbatim, so a literal ``%`` (e.g. ``LIKE 'abc%'``) needs no
            escaping.

    Returns:
        ``{"success": True, "rows": [dict, ...], "count": n}`` on success, or
        ``{"success": False, "error": message}`` on any failure.
    """
    conn = None
    try:
        conn = get_connection(db_url)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            # Pass params through unchanged: psycopg2 only interpolates
            # placeholders when params is not None. Substituting an empty list
            # would make a literal '%' in the SQL be parsed as a placeholder
            # and fail.
            cursor.execute(query, params)
            result = [dict(row) for row in cursor.fetchall()]
        # Commit so that side effects of row-returning statements (e.g.
        # 'WITH ... INSERT ... RETURNING', which run_custom_sql routes here)
        # persist; for a plain SELECT this is harmless.
        conn.commit()
        return {
            "success": True,
            "rows": result,
            "count": len(result),
        }
    except Exception as e:
        # Uncommitted work is discarded when the connection is closed below.
        return {"success": False, "error": str(e)}
    finally:
        if conn is not None:
            conn.close()
def execute_write(db_url, query, params=None):
    """Execute a data-modifying statement (INSERT, UPDATE, DELETE, DDL, ...) and commit it.

    A fresh connection is opened for the call and is always closed before
    returning. On failure the transaction is rolled back.

    Args:
        db_url: PostgreSQL connection URL.
        query: SQL text, optionally containing psycopg2 placeholders.
        params: Sequence or mapping of placeholder values, or ``None`` when
            the query has no placeholders. With ``None`` the SQL is sent
            verbatim, so a literal ``%`` needs no escaping.

    Returns:
        ``{"success": True, "rows_affected": n, "message": ...}`` on success.
        ``rows_affected`` is the cursor's ``rowcount``, which psycopg2 reports
        as -1 when it cannot be determined (e.g. for DDL); the message then
        omits the count. On any failure returns
        ``{"success": False, "error": message}``.
    """
    conn = None
    try:
        conn = get_connection(db_url)
        with conn.cursor() as cursor:
            # See execute_query: None must stay None so '%' is not parsed.
            cursor.execute(query, params)
            rows_affected = cursor.rowcount
        conn.commit()
    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except psycopg2.Error:
                # Best effort only: the connection may already be unusable,
                # and close() below discards uncommitted work regardless.
                pass
        return {"success": False, "error": str(e)}
    finally:
        if conn is not None:
            conn.close()

    if rows_affected < 0:
        message = "Query executed successfully."
    else:
        message = f"Query executed successfully. {rows_affected} row(s) affected."
    return {
        "success": True,
        "rows_affected": rows_affected,
        "message": message,
    }
def list_tables(db_url):
    """Return the tables found outside the PostgreSQL system schemas.

    Rows come from ``information_schema.tables`` with ``pg_catalog`` and
    ``information_schema`` excluded, ordered by schema and then table name.

    Args:
        db_url: PostgreSQL connection URL.

    Returns:
        ``{"success": True, "tables": [...], "count": n}`` where every entry
        is a dict with ``table_name`` and ``table_schema`` keys, or
        ``{"success": False, "error": message}`` if anything fails.
    """
    # NOTE(review): information_schema.tables also reports views; filter on
    # table_type if only base tables are wanted.
    catalog_sql = (
        "\n"
        "SELECT table_name, table_schema\n"
        "FROM information_schema.tables\n"
        "WHERE table_schema NOT IN ('pg_catalog', 'information_schema')\n"
        "ORDER BY table_schema, table_name;\n"
    )
    connection = None
    try:
        connection = get_connection(db_url)
        with connection.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(catalog_sql)
            found = list(map(dict, cur.fetchall()))
        connection.close()
    except Exception as exc:
        if connection:
            connection.close()
        return {"success": False, "error": str(exc)}
    return {"success": True, "tables": found, "count": len(found)}
def describe_table(db_url, table_name, schema=None):
    """Return column metadata for *table_name* from ``information_schema.columns``.

    Args:
        db_url: PostgreSQL connection URL.
        table_name: Table name, compared for exact equality with the name
            stored in ``information_schema.columns``.
        schema: Optional schema name. When given, only that schema is
            searched. When ``None`` (the default, matching the previous
            behaviour), every schema is searched, so same-named tables in
            different schemas all contribute columns; results are ordered by
            schema so those column lists are not interleaved.

    Returns:
        ``{"success": True, "table_name": ..., "columns": [...]}`` where each
        column dict has ``column_name``, ``data_type``,
        ``character_maximum_length``, ``is_nullable`` and ``column_default``
        keys; ``{"success": False, "error": ...}`` when no columns match or on
        any failure.
    """
    conditions = ["table_name = %s"]
    params = [table_name]
    if schema is not None:
        conditions.append("table_schema = %s")
        params.append(schema)
    # Only constant SQL fragments are joined here; all values go through
    # placeholders in `params`.
    where_clause = " AND ".join(conditions)
    query = f"""
        SELECT
            column_name,
            data_type,
            character_maximum_length,
            is_nullable,
            column_default
        FROM information_schema.columns
        WHERE {where_clause}
        ORDER BY table_schema, ordinal_position;
    """

    conn = None
    try:
        conn = get_connection(db_url)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query, params)
            columns = [dict(row) for row in cursor.fetchall()]
    except Exception as e:
        return {"success": False, "error": str(e)}
    finally:
        if conn is not None:
            conn.close()

    if not columns:
        label = table_name if schema is None else f"{schema}.{table_name}"
        return {"success": False, "error": f"Table '{label}' not found"}
    return {
        "success": True,
        "table_name": table_name,
        "columns": columns,
    }
def get_table_count(db_url, table_name, schema=None):
    """Return the row count (``COUNT(*)``) of a table.

    The table name (and optional schema) is quoted as an SQL identifier via
    ``psycopg2.sql.Identifier``. A name containing double quotes or other
    special characters therefore cannot break out of the identifier and
    inject SQL.

    Args:
        db_url: PostgreSQL connection URL.
        table_name: Exact (case-sensitive) table name.
        schema: Optional schema name; when given the query targets
            ``"schema"."table_name"``, otherwise the table is resolved via
            the connection's search_path.

    Returns:
        ``{"success": True, "table_name": ..., "count": n}`` on success, or
        ``{"success": False, "error": message}`` on any failure.
    """
    conn = None
    try:
        from psycopg2 import sql as pgsql

        if schema is None:
            target = pgsql.Identifier(table_name)
        else:
            target = pgsql.SQL(".").join(
                [pgsql.Identifier(schema), pgsql.Identifier(table_name)]
            )
        query = pgsql.SQL("SELECT COUNT(*) as count FROM {}").format(target)

        conn = get_connection(db_url)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            result = cursor.fetchone()
    except Exception as e:
        return {"success": False, "error": str(e)}
    finally:
        if conn is not None:
            conn.close()

    return {
        "success": True,
        "table_name": table_name,
        "count": result['count'],
    }
def _strip_leading_sql_noise(sql):
    """Return *sql* without leading whitespace, SQL comments and opening parentheses."""
    text = sql
    while True:
        text = text.lstrip()
        if text.startswith('--'):
            # Line comment: skip to the end of the line.
            newline = text.find('\n')
            text = '' if newline == -1 else text[newline + 1:]
        elif text.startswith('/*'):
            # Block comment: skip past the first closing marker (nested block
            # comments are not handled).
            end = text.find('*/')
            text = '' if end == -1 else text[end + 2:]
        elif text.startswith('('):
            # Parenthesised query, e.g. "(SELECT 1) UNION (SELECT 2)".
            text = text[1:]
        else:
            return text


def run_custom_sql(db_url, sql, params=None):
    """Execute arbitrary SQL, choosing the read or write path by its leading keyword.

    Statements whose first keyword is SELECT, WITH, VALUES, TABLE, SHOW or
    EXPLAIN produce a result set and are run through ``execute_query``. All
    others (INSERT, UPDATE, DELETE, DDL, ...) go through ``execute_write``.
    Leading whitespace, ``--`` / ``/* */`` comments and opening parentheses
    are skipped before the keyword is read.

    Args:
        db_url: PostgreSQL connection URL.
        sql: The SQL statement text.
        params: Optional placeholder values, forwarded unchanged.

    Returns:
        The result dict of ``execute_query`` or ``execute_write``.
    """
    row_returning = ('SELECT', 'WITH', 'VALUES', 'TABLE', 'SHOW', 'EXPLAIN')
    head = _strip_leading_sql_noise(sql).upper()
    if head.startswith(row_returning):
        return execute_query(db_url, sql, params)
    # INSERT, UPDATE, DELETE, DDL, etc.
    return execute_write(db_url, sql, params)