# manager.py • 33.7 kB
#!/usr/bin/env python3
"""
Database Manager
Manages database connections, sessions, and operations.
"""
import os
import sqlite3
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from datetime import datetime, timedelta
from pathlib import Path
from sqlalchemy import (
create_engine,
text,
func,
and_,
or_,
desc,
asc,
)
from sqlalchemy.orm import (
sessionmaker,
Session,
Query,
)
from sqlalchemy.exc import (
SQLAlchemyError,
IntegrityError,
NoResultFound,
)
from sqlalchemy.engine import Engine
import structlog
from .models import (
Base,
User,
APIKey,
DocumentSource,
Document,
SearchIndex,
SyncLog,
Configuration,
DocumentStatus,
AdapterType,
UserRole,
DocumentInfo,
DocumentSourceInfo,
SearchResult,
create_database_engine,
create_tables,
create_session_factory,
)
# Module-level structured logger for all database-manager events.
logger = structlog.get_logger(__name__)
# Generic type variable bound to the ORM declarative base (for typed helpers).
T = TypeVar('T', bound=Base)
class DatabaseError(Exception):
    """Root of the database exception hierarchy for this package."""
class DatabaseConnectionError(DatabaseError):
    """Raised when a database connection cannot be established."""
class DatabaseIntegrityError(DatabaseError):
    """Raised when a write violates a uniqueness or integrity constraint."""
class DatabaseManager:
"""Database manager for AnyDocs MCP."""
def __init__(
self,
database_url: str,
echo: bool = False,
pool_size: int = 10,
max_overflow: int = 20,
):
"""
Initialize database manager.
Args:
database_url: Database connection URL
echo: Enable SQL query logging
pool_size: Connection pool size
max_overflow: Maximum pool overflow
"""
self.database_url = database_url
self.echo = echo
try:
# Create engine
engine_kwargs = {
'echo': echo,
'pool_pre_ping': True,
'pool_recycle': 3600,
}
# SQLite-specific settings
if database_url.startswith('sqlite'):
engine_kwargs.update({
'pool_size': 1,
'connect_args': {
'check_same_thread': False,
'timeout': 30,
}
})
else:
engine_kwargs.update({
'pool_size': pool_size,
'max_overflow': max_overflow,
})
self.engine = create_database_engine(database_url, **engine_kwargs)
self.SessionLocal = create_session_factory(self.engine)
logger.info(
"Database manager initialized",
database_type=self._get_database_type(),
echo=echo
)
except Exception as e:
logger.error("Failed to initialize database manager", error=str(e))
raise DatabaseConnectionError(f"Failed to connect to database: {e}")
def _get_database_type(self) -> str:
"""Get database type from URL."""
return self.database_url.split('://')[0] if '://' in self.database_url else 'unknown'
def initialize_database(self, force: bool = False) -> None:
"""
Initialize database schema.
Args:
force: Force recreation of tables
"""
try:
if force:
logger.warning("Dropping all tables")
Base.metadata.drop_all(self.engine)
# Create tables
create_tables(self.engine)
# Enable SQLite FTS if using SQLite
if self.database_url.startswith('sqlite'):
self._setup_sqlite_fts()
# Create default configurations
self._create_default_configurations()
logger.info("Database initialized successfully")
except Exception as e:
logger.error("Failed to initialize database", error=str(e))
raise DatabaseError(f"Failed to initialize database: {e}")
    def _setup_sqlite_fts(self) -> None:
        """Create the FTS5 virtual table and sync triggers for SQLite.

        Failures are logged as warnings only, so initialization still
        succeeds on SQLite builds compiled without FTS5 support.
        """
        try:
            with self.engine.connect() as conn:
                # External-content FTS5 table backed by the search_index table.
                conn.execute(text("""
                    CREATE VIRTUAL TABLE IF NOT EXISTS documents_fts USING fts5(
                        document_id,
                        title,
                        content,
                        keywords,
                        content='search_index',
                        content_rowid='rowid'
                    )
                """))
                # Insert trigger: mirror new search_index rows into the FTS table.
                conn.execute(text("""
                    CREATE TRIGGER IF NOT EXISTS search_index_ai AFTER INSERT ON search_index BEGIN
                        INSERT INTO documents_fts(document_id, title, content, keywords)
                        VALUES (new.document_id, new.title, new.content, new.keywords);
                    END
                """))
                # Delete trigger: drop FTS rows for removed search_index rows.
                conn.execute(text("""
                    CREATE TRIGGER IF NOT EXISTS search_index_ad AFTER DELETE ON search_index BEGIN
                        DELETE FROM documents_fts WHERE document_id = old.document_id;
                    END
                """))
                # Update trigger: delete-then-insert keeps FTS content current.
                conn.execute(text("""
                    CREATE TRIGGER IF NOT EXISTS search_index_au AFTER UPDATE ON search_index BEGIN
                        DELETE FROM documents_fts WHERE document_id = old.document_id;
                        INSERT INTO documents_fts(document_id, title, content, keywords)
                        VALUES (new.document_id, new.title, new.content, new.keywords);
                    END
                """))
                conn.commit()
                logger.info("SQLite FTS setup completed")
        except Exception as e:
            # Best-effort: FTS is an optional acceleration, not a requirement.
            logger.warning("Failed to setup SQLite FTS", error=str(e))
def _create_default_configurations(self) -> None:
"""Create default system configurations."""
default_configs = [
{
'key': 'system.version',
'value': '1.0.0',
'description': 'System version',
'is_system': True,
},
{
'key': 'system.initialized_at',
'value': datetime.utcnow().isoformat(),
'description': 'System initialization timestamp',
'is_system': True,
},
{
'key': 'search.max_results',
'value': 100,
'description': 'Maximum search results per query',
'is_system': False,
},
{
'key': 'sync.default_interval',
'value': 3600,
'description': 'Default sync interval in seconds',
'is_system': False,
},
]
with self.get_session() as session:
for config_data in default_configs:
existing = session.query(Configuration).filter_by(key=config_data['key']).first()
if not existing:
config = Configuration(**config_data)
session.add(config)
session.commit()
@contextmanager
def get_session(self):
"""Get database session context manager."""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error("Database session error", error=str(e))
raise
finally:
session.close()
def health_check(self) -> Dict[str, Any]:
"""Perform database health check."""
try:
with self.get_session() as session:
# Test basic query
result = session.execute(text("SELECT 1")).scalar()
# Get table counts
counts = {}
for table_name in ['users', 'document_sources', 'documents', 'api_keys']:
try:
count = session.execute(text(f"SELECT COUNT(*) FROM {table_name}")).scalar()
counts[table_name] = count
except Exception:
counts[table_name] = -1
return {
'status': 'healthy',
'database_type': self._get_database_type(),
'connection_test': result == 1,
'table_counts': counts,
'timestamp': datetime.utcnow().isoformat(),
}
except Exception as e:
logger.error("Database health check failed", error=str(e))
return {
'status': 'unhealthy',
'error': str(e),
'timestamp': datetime.utcnow().isoformat(),
}
# User management
def create_user(
self,
username: str,
email: str,
password_hash: Optional[str] = None,
full_name: Optional[str] = None,
role: UserRole = UserRole.USER,
**kwargs
) -> User:
"""Create a new user."""
try:
with self.get_session() as session:
user = User(
username=username,
email=email,
password_hash=password_hash,
full_name=full_name,
role=role.value,
**kwargs
)
session.add(user)
session.flush() # Get ID
logger.info("User created", user_id=user.id, username=username)
return user
except IntegrityError as e:
logger.error("User creation failed - integrity error", username=username, error=str(e))
raise DatabaseIntegrityError(f"User with username '{username}' or email '{email}' already exists")
except Exception as e:
logger.error("User creation failed", username=username, error=str(e))
raise DatabaseError(f"Failed to create user: {e}")
def get_user(self, user_id: Optional[str] = None, username: Optional[str] = None, email: Optional[str] = None) -> Optional[User]:
"""Get user by ID, username, or email."""
try:
with self.get_session() as session:
query = session.query(User)
if user_id:
query = query.filter(User.id == user_id)
elif username:
query = query.filter(User.username == username)
elif email:
query = query.filter(User.email == email)
else:
raise ValueError("Must provide user_id, username, or email")
return query.first()
except Exception as e:
logger.error("Failed to get user", user_id=user_id, username=username, error=str(e))
return None
def update_user(self, user_id: str, **kwargs) -> Optional[User]:
"""Update user."""
try:
with self.get_session() as session:
user = session.query(User).filter(User.id == user_id).first()
if not user:
return None
for key, value in kwargs.items():
if hasattr(user, key):
setattr(user, key, value)
user.updated_at = datetime.utcnow()
session.flush()
logger.info("User updated", user_id=user_id)
return user
except Exception as e:
logger.error("Failed to update user", user_id=user_id, error=str(e))
raise DatabaseError(f"Failed to update user: {e}")
# Document source management
def create_document_source(
self,
name: str,
adapter_type: AdapterType,
base_url: str,
created_by_id: str,
auth_config: Optional[Dict] = None,
adapter_config: Optional[Dict] = None,
**kwargs
) -> DocumentSource:
"""Create a new document source."""
try:
with self.get_session() as session:
source = DocumentSource(
name=name,
adapter_type=adapter_type.value,
base_url=base_url,
created_by_id=created_by_id,
auth_config=auth_config,
adapter_config=adapter_config,
**kwargs
)
session.add(source)
session.flush()
logger.info("Document source created", source_id=source.id, name=name)
return source
except IntegrityError as e:
logger.error("Document source creation failed - integrity error", name=name, error=str(e))
raise DatabaseIntegrityError(f"Document source with name '{name}' already exists for this user")
except Exception as e:
logger.error("Document source creation failed", name=name, error=str(e))
raise DatabaseError(f"Failed to create document source: {e}")
def get_document_sources(
self,
user_id: Optional[str] = None,
is_active: Optional[bool] = None,
active_only: bool = False,
adapter_type: Optional[AdapterType] = None
) -> List[DocumentSource]:
"""Get document sources."""
try:
with self.get_session() as session:
query = session.query(DocumentSource)
if user_id:
query = query.filter(DocumentSource.created_by_id == user_id)
if is_active is not None:
query = query.filter(DocumentSource.is_active == is_active)
if active_only:
query = query.filter(DocumentSource.enabled == True)
if adapter_type:
query = query.filter(DocumentSource.adapter_type == adapter_type.value)
sources = query.order_by(DocumentSource.created_at.desc()).all()
# Create detached objects to avoid session issues
detached_sources = []
for source in sources:
detached_source = DocumentSource(
id=source.id,
name=source.name,
adapter_type=source.adapter_type,
base_url=source.base_url,
auth_config=source.auth_config,
adapter_config=source.adapter_config,
sync_enabled=source.sync_enabled,
sync_interval=source.sync_interval,
last_sync_at=source.last_sync_at,
next_sync_at=source.next_sync_at,
is_active=source.is_active,
enabled=source.enabled,
status=source.status,
error_message=source.error_message,
extra_metadata=source.extra_metadata,
created_by_id=source.created_by_id,
created_at=source.created_at,
updated_at=source.updated_at
)
detached_sources.append(detached_source)
return detached_sources
except Exception as e:
logger.error("Failed to get document sources", error=str(e))
return []
def get_documents_by_source(self, source_id: str) -> List[Document]:
"""Get all documents for a specific source."""
try:
with self.get_session() as session:
return session.query(Document).filter(Document.source_id == source_id).all()
except Exception as e:
logger.error("Failed to get documents by source", source_id=source_id, error=str(e))
return []
def get_document_by_path(self, source_name: str, doc_path: str) -> Optional[Document]:
"""Get document by source name and path."""
try:
with self.get_session() as session:
from sqlalchemy.orm import joinedload
doc = session.query(Document).join(DocumentSource).options(joinedload(Document.source)).filter(
DocumentSource.name == source_name,
Document.path == doc_path
).first()
if not doc:
return None
# Create detached document
detached_doc = Document(
id=doc.id,
source_id=doc.source_id,
external_id=doc.external_id,
title=doc.title,
slug=doc.slug,
path=doc.path,
content=doc.content,
processed_content=doc.processed_content,
searchable_text=doc.searchable_text,
author=doc.author,
description=doc.description,
word_count=doc.word_count,
status=doc.status,
created_at=doc.created_at,
updated_at=doc.updated_at
)
# Create detached source object if exists
if doc.source:
detached_source = DocumentSource(
id=doc.source.id,
name=doc.source.name,
adapter_type=doc.source.adapter_type,
base_url=doc.source.base_url,
is_active=doc.source.is_active,
enabled=doc.source.enabled,
status=doc.source.status
)
detached_doc.source = detached_source
return detached_doc
except Exception as e:
logger.error("Failed to get document by path", source_name=source_name, path=doc_path, error=str(e))
return None
def get_document_by_id(self, document_id: str) -> Optional[Document]:
"""Get document by ID."""
try:
with self.get_session() as session:
from sqlalchemy.orm import joinedload
doc = session.query(Document).options(joinedload(Document.source)).filter(Document.id == document_id).first()
if not doc:
return None
# Create detached document
detached_doc = Document(
id=doc.id,
source_id=doc.source_id,
external_id=doc.external_id,
title=doc.title,
slug=doc.slug,
path=doc.path,
content=doc.content,
processed_content=doc.processed_content,
searchable_text=doc.searchable_text,
author=doc.author,
description=doc.description,
word_count=doc.word_count,
status=doc.status,
created_at=doc.created_at,
updated_at=doc.updated_at
)
# Create detached source object if exists
if doc.source:
detached_source = DocumentSource(
id=doc.source.id,
name=doc.source.name,
adapter_type=doc.source.adapter_type,
base_url=doc.source.base_url,
is_active=doc.source.is_active,
enabled=doc.source.enabled,
status=doc.source.status
)
detached_doc.source = detached_source
return detached_doc
except Exception as e:
logger.error("Failed to get document by ID", document_id=document_id, error=str(e))
return None
def get_document_count_by_source(self, source_id: str) -> int:
"""Get document count for a specific source."""
try:
with self.get_session() as session:
return session.query(Document).filter(Document.source_id == source_id).count()
except Exception as e:
logger.error("Failed to get document count", source_id=source_id, error=str(e))
return 0
def get_all_documents(self) -> List[Document]:
"""Get all documents from all sources."""
try:
with self.get_session() as session:
from sqlalchemy.orm import joinedload
docs = session.query(Document).options(joinedload(Document.source)).all()
# Create detached documents to avoid session issues
detached_docs = []
for doc in docs:
detached_doc = Document(
id=doc.id,
source_id=doc.source_id,
external_id=doc.external_id,
title=doc.title,
slug=doc.slug,
path=doc.path,
content=doc.content,
processed_content=doc.processed_content,
searchable_text=doc.searchable_text,
author=doc.author,
description=doc.description,
word_count=doc.word_count,
status=doc.status,
created_at=doc.created_at,
updated_at=doc.updated_at
)
# Create detached source object if exists
if doc.source:
detached_source = DocumentSource(
id=doc.source.id,
name=doc.source.name,
adapter_type=doc.source.adapter_type,
base_url=doc.source.base_url,
is_active=doc.source.is_active,
enabled=doc.source.enabled,
status=doc.source.status
)
detached_doc.source = detached_source
detached_docs.append(detached_doc)
return detached_docs
except Exception as e:
logger.error("Failed to get all documents", error=str(e))
return []
def get_documents_by_source_name(self, source_name: str) -> List[Document]:
"""Get all documents for a specific source by name."""
try:
with self.get_session() as session:
from sqlalchemy.orm import joinedload
docs = session.query(Document).join(DocumentSource).options(joinedload(Document.source)).filter(
DocumentSource.name == source_name
).all()
# Create detached documents to avoid session issues
detached_docs = []
for doc in docs:
detached_doc = Document(
id=doc.id,
source_id=doc.source_id,
external_id=doc.external_id,
title=doc.title,
slug=doc.slug,
path=doc.path,
content=doc.content,
processed_content=doc.processed_content,
searchable_text=doc.searchable_text,
author=doc.author,
description=doc.description,
word_count=doc.word_count,
status=doc.status,
created_at=doc.created_at,
updated_at=doc.updated_at
)
# Create detached source object if exists
if doc.source:
detached_source = DocumentSource(
id=doc.source.id,
name=doc.source.name,
adapter_type=doc.source.adapter_type,
base_url=doc.source.base_url,
is_active=doc.source.is_active,
enabled=doc.source.enabled,
status=doc.source.status
)
detached_doc.source = detached_source
detached_docs.append(detached_doc)
return detached_docs
except Exception as e:
logger.error("Failed to get documents by source name", source_name=source_name, error=str(e))
return []
# Document management
def create_document(
self,
source_id: str,
external_id: str,
title: str,
content: Optional[str] = None,
**kwargs
) -> Document:
"""Create a new document."""
try:
with self.get_session() as session:
document = Document(
source_id=source_id,
external_id=external_id,
title=title,
content=content,
**kwargs
)
session.add(document)
session.flush()
logger.info("Document created", document_id=document.id, title=title)
return document
except IntegrityError as e:
logger.error("Document creation failed - integrity error", external_id=external_id, error=str(e))
raise DatabaseIntegrityError(f"Document with external_id '{external_id}' already exists for this source")
except Exception as e:
logger.error("Document creation failed", external_id=external_id, error=str(e))
raise DatabaseError(f"Failed to create document: {e}")
def search_documents(
self,
query: str,
source_name: Optional[str] = None,
source_ids: Optional[List[str]] = None,
limit: int = 50,
offset: int = 0
) -> List[Document]:
"""Search documents using full-text search."""
try:
with self.get_session() as session:
from sqlalchemy.orm import joinedload
# Simple implementation - in production you'd use proper FTS
db_query = session.query(Document).join(DocumentSource).options(joinedload(Document.source))
# Filter by content containing the search query
db_query = db_query.filter(
or_(
Document.title.ilike(f'%{query}%'),
Document.content.ilike(f'%{query}%'),
Document.processed_content.ilike(f'%{query}%')
)
)
# Filter by source name if provided
if source_name:
db_query = db_query.filter(DocumentSource.name == source_name)
# Filter by source IDs if provided
if source_ids:
db_query = db_query.filter(Document.source_id.in_(source_ids))
# Order by relevance (title matches first, then updated date)
db_query = db_query.order_by(
Document.title.ilike(f'%{query}%').desc(),
Document.updated_at.desc()
)
# Execute query
results = db_query.limit(limit).offset(offset).all()
# Create detached objects to avoid session issues
detached_results = []
for doc in results:
# Create a new document object with source info embedded
detached_doc = Document(
id=doc.id,
source_id=doc.source_id,
external_id=doc.external_id,
title=doc.title,
slug=doc.slug,
path=doc.path,
content=doc.content,
processed_content=doc.processed_content,
searchable_text=doc.searchable_text,
author=doc.author,
description=doc.description,
word_count=doc.word_count,
status=doc.status,
created_at=doc.created_at,
updated_at=doc.updated_at
)
# Create a detached source object
if doc.source:
detached_source = DocumentSource(
id=doc.source.id,
name=doc.source.name,
adapter_type=doc.source.adapter_type,
base_url=doc.source.base_url,
is_active=doc.source.is_active,
enabled=doc.source.enabled,
status=doc.source.status
)
detached_doc.source = detached_source
detached_results.append(detached_doc)
return detached_results
except Exception as e:
logger.error("Document search failed", query=query, error=str(e))
return []
# Configuration management
def get_configuration(self, key: str) -> Optional[Any]:
"""Get configuration value."""
try:
with self.get_session() as session:
config = session.query(Configuration).filter(Configuration.key == key).first()
return config.value if config else None
except Exception as e:
logger.error("Failed to get configuration", key=key, error=str(e))
return None
def set_configuration(self, key: str, value: Any, description: Optional[str] = None, is_sensitive: bool = False) -> None:
"""Set configuration value."""
try:
with self.get_session() as session:
config = session.query(Configuration).filter(Configuration.key == key).first()
if config:
config.value = value
config.updated_at = datetime.utcnow()
if description:
config.description = description
else:
config = Configuration(
key=key,
value=value,
description=description,
is_sensitive=is_sensitive,
)
session.add(config)
session.flush()
logger.info("Configuration updated", key=key)
except Exception as e:
logger.error("Failed to set configuration", key=key, error=str(e))
raise DatabaseError(f"Failed to set configuration: {e}")
    def close(self) -> None:
        """Dispose of the engine's connection pool; errors are logged only."""
        try:
            # Guard: __init__ may have raised before self.engine was assigned.
            if hasattr(self, 'engine'):
                self.engine.dispose()
            logger.info("Database connections closed")
        except Exception as e:
            # Best-effort shutdown: never raise from close().
            logger.error("Error closing database connections", error=str(e))
# Global database manager instance
# (set by initialize_database_manager(); None until then).
_db_manager: Optional[DatabaseManager] = None
def get_database_manager() -> Optional[DatabaseManager]:
    """Get global database manager instance.

    Returns:
        The manager created by initialize_database_manager(), or None if
        it has not been initialized yet.
    """
    return _db_manager
def initialize_database_manager(
    database_url: str,
    echo: bool = False,
    **kwargs
) -> DatabaseManager:
    """Create the global DatabaseManager singleton (idempotent).

    Args:
        database_url: Database connection URL.
        echo: Enable SQL query logging.
        **kwargs: Extra DatabaseManager constructor options.

    Returns:
        The global DatabaseManager; if one already exists it is returned
        unchanged and a warning is logged.
    """
    global _db_manager
    if not _db_manager:
        _db_manager = DatabaseManager(database_url, echo=echo, **kwargs)
        return _db_manager
    logger.warning("Database manager already initialized")
    return _db_manager
def close_database_manager() -> None:
    """Close and clear the global database manager, if one exists."""
    global _db_manager
    if not _db_manager:
        return
    _db_manager.close()
    _db_manager = None
    logger.info("Global database manager closed")