connection.go•6.11 kB
package timescale
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"time"
"github.com/FreePeak/db-mcp-server/pkg/db"
"github.com/FreePeak/db-mcp-server/pkg/logger"
)
// DB represents a TimescaleDB database connection
type DB struct {
db.Database // Embed standard Database interface
config DBConfig // TimescaleDB-specific configuration
extVersion string // TimescaleDB extension version
isTimescaleDB bool // Whether the database supports TimescaleDB
}
// NewTimescaleDB creates a new TimescaleDB connection
func NewTimescaleDB(config DBConfig) (*DB, error) {
// Initialize PostgreSQL connection
pgDB, err := db.NewDatabase(config.PostgresConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize PostgreSQL connection: %w", err)
}
return &DB{
Database: pgDB,
config: config,
}, nil
}
// Connect establishes a connection and verifies TimescaleDB availability
func (t *DB) Connect() error {
// Connect to PostgreSQL
if err := t.Database.Connect(); err != nil {
return err
}
// Check for TimescaleDB extension
if t.config.UseTimescaleDB {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var version string
err := t.Database.QueryRow(ctx, "SELECT extversion FROM pg_extension WHERE extname = 'timescaledb'").Scan(&version)
if err != nil {
if err == sql.ErrNoRows {
// Skip logging in tests to avoid nil pointer dereference
if !isTestEnvironment() {
logger.Warn("TimescaleDB extension not found in database. Features will be disabled.")
}
t.isTimescaleDB = false
// Don't return error, just disable TimescaleDB features
return nil
}
return fmt.Errorf("failed to check TimescaleDB extension: %w", err)
}
t.extVersion = version
t.isTimescaleDB = true
// Skip logging in tests to avoid nil pointer dereference
if !isTestEnvironment() {
logger.Info("Connected to TimescaleDB %s", version)
}
}
return nil
}
// isTestEnvironment returns true if the code is running in a test environment
func isTestEnvironment() bool {
for _, arg := range os.Args {
if strings.HasPrefix(arg, "-test.") {
return true
}
}
return false
}
// Close closes the database connection
func (t *DB) Close() error {
return t.Database.Close()
}
// ExtVersion returns the TimescaleDB extension version
func (t *DB) ExtVersion() string {
return t.extVersion
}
// IsTimescaleDB returns true if the database has TimescaleDB extension installed
func (t *DB) IsTimescaleDB() bool {
return t.isTimescaleDB
}
// ApplyConfig applies TimescaleDB-specific configuration options
func (t *DB) ApplyConfig() error {
if !t.isTimescaleDB {
return fmt.Errorf("cannot apply TimescaleDB configuration: TimescaleDB extension not available")
}
// No global configuration to apply for now
return nil
}
// ExecuteSQLWithoutParams executes a SQL query without parameters and returns a result
func (t *DB) ExecuteSQLWithoutParams(ctx context.Context, query string) (interface{}, error) {
// For non-SELECT queries (that don't return rows), use Exec
if !isSelectQuery(query) {
result, err := t.Database.Exec(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
return result, nil
}
// For SELECT queries, process rows into a map
rows, err := t.Database.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
// Log the error or append it to the returned error
if !isTestEnvironment() {
logger.Error("Failed to close rows: %v", closeErr)
}
}
}()
return processRows(rows)
}
// ExecuteSQL executes a SQL query with parameters and returns a result
func (t *DB) ExecuteSQL(ctx context.Context, query string, args ...interface{}) (interface{}, error) {
// For non-SELECT queries (that don't return rows), use Exec
if !isSelectQuery(query) {
result, err := t.Database.Exec(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
return result, nil
}
// For SELECT queries, process rows into a map
rows, err := t.Database.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
// Log the error or append it to the returned error
if !isTestEnvironment() {
logger.Error("Failed to close rows: %v", closeErr)
}
}
}()
return processRows(rows)
}
// Helper function to check if a query is a SELECT query
func isSelectQuery(query string) bool {
// Simple check for now - could be made more robust
for i := 0; i < len(query); i++ {
if query[i] == ' ' || query[i] == '\t' || query[i] == '\n' {
continue
}
return i+6 <= len(query) && (query[i:i+6] == "SELECT" || query[i:i+6] == "select")
}
return false
}
// Helper function to process rows into a map
func processRows(rows *sql.Rows) ([]map[string]interface{}, error) {
// Get column names
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// Create a slice of results
var results []map[string]interface{}
// Create a slice of interface{} to hold the values
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
// Loop through rows
for rows.Next() {
// Set up pointers to each interface{} value
for i := range values {
valuePtrs[i] = &values[i]
}
// Scan the result into the values
if err := rows.Scan(valuePtrs...); err != nil {
return nil, err
}
// Create a map for this row
row := make(map[string]interface{})
for i, col := range columns {
if values[i] == nil {
row[col] = nil
} else {
// Try to handle different types appropriately
switch v := values[i].(type) {
case []byte:
row[col] = string(v)
default:
row[col] = v
}
}
}
results = append(results, row)
}
// Check for errors after we're done iterating
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}