# crud.py (7.59 kB)
"""
Database CRUD operations for the Sectional MCP Panel.
"""
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from .models import AuditLog, Panel, Section, Server, Task
# Panel operations
async def get_panel_config(db: AsyncSession) -> Optional[Panel]:
    """Return the panel configuration row, or ``None`` if none exists.

    The query has no ordering, so if several ``Panel`` rows exist, the one
    returned is whichever the database yields first.
    """
    stmt = select(Panel)
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def update_panel_config(db: AsyncSession, config_data: Dict[str, Any]) -> Panel:
    """Create or update the panel configuration, then commit and refresh it.

    When no panel row exists yet, a new ``Panel`` is built from the
    ``name``, ``version``, ``config_schema_version`` and ``global_settings``
    entries of ``config_data``, falling back to built-in defaults. Any other
    keys are ignored in that case. When a panel already exists, each key of
    ``config_data`` that names an attribute present on the instance is
    assigned. Keys that match no attribute are skipped silently.

    NOTE(review): ``hasattr`` accepts any attribute the instance exposes, not
    only data columns (e.g. a primary-key column, if the model has one).
    Confirm that ``config_data`` is validated upstream.

    Returns:
        The persisted ``Panel`` after ``db.refresh``.
    """
    current = await get_panel_config(db)
    if current:
        for attr_name, new_value in config_data.items():
            if not hasattr(current, attr_name):
                continue
            setattr(current, attr_name, new_value)
    else:
        current = Panel(
            name=config_data.get("name", "Sectional MCP Panel"),
            version=config_data.get("version", "0.1.0"),
            config_schema_version=config_data.get("config_schema_version", "1.0"),
            global_settings=config_data.get("global_settings", {}),
        )
        db.add(current)
    await db.commit()
    await db.refresh(current)
    return current
# Section operations
async def get_sections(db: AsyncSession) -> List[Section]:
    """Return every ``Section`` row, unfiltered and in database order."""
    stmt = select(Section)
    rows = await db.execute(stmt)
    sections = rows.scalars().all()
    return sections
async def get_section_by_name(db: AsyncSession, name: str) -> Optional[Section]:
    """Return the first section whose ``name`` equals *name*, or ``None``."""
    stmt = select(Section).where(Section.name == name)
    found = await db.execute(stmt)
    return found.scalars().first()
async def get_section_by_id(db: AsyncSession, section_id: int) -> Optional[Section]:
    """Look up a section by primary key via ``db.get``; ``None`` if absent."""
    section = await db.get(Section, section_id)
    return section
async def create_section(db: AsyncSession, section_data: Dict[str, Any]) -> Section:
    """Insert a new section and return it after commit and refresh.

    ``section_data["name"]`` is required (``KeyError`` if missing).
    ``description`` defaults to ``None`` and ``settings`` to an empty dict.
    """
    name = section_data["name"]
    description = section_data.get("description")
    settings = section_data.get("settings", {})
    section = Section(name=name, description=description, settings=settings)
    db.add(section)
    await db.commit()
    # Reload the row so database-populated columns are visible to the caller.
    await db.refresh(section)
    return section
async def update_section(db: AsyncSession, section_id: int, section_data: Dict[str, Any]) -> Optional[Section]:
    """Apply ``section_data`` onto the section with *section_id* and commit.

    Each key that names an attribute present on the section is assigned.
    Other keys are ignored. Returns the refreshed section, or ``None`` when
    no section has that ID.

    NOTE(review): any attribute the instance exposes is writable this way.
    Confirm that callers restrict the allowed keys.
    """
    section = await get_section_by_id(db, section_id)
    if section:
        writable = ((k, v) for k, v in section_data.items() if hasattr(section, k))
        for attr_name, new_value in writable:
            setattr(section, attr_name, new_value)
        await db.commit()
        await db.refresh(section)
        return section
    return None
async def delete_section(db: AsyncSession, section_id: int) -> bool:
    """Delete the section with *section_id* and commit.

    Returns ``True`` if the section existed and was deleted, ``False``
    otherwise. Whether its servers are removed too depends on the model's
    cascade / foreign-key configuration, which is not visible here.
    """
    section = await get_section_by_id(db, section_id)
    if section:
        await db.delete(section)
        await db.commit()
        return True
    return False
# Server operations
async def get_servers(db: AsyncSession, section_id: Optional[int] = None) -> List[Server]:
    """Return all servers, or only those in *section_id* when it is given."""
    base = select(Server)
    stmt = base if section_id is None else base.where(Server.section_id == section_id)
    rows = await db.execute(stmt)
    return rows.scalars().all()
async def get_server_by_name(db: AsyncSession, section_id: int, name: str) -> Optional[Server]:
    """Return the first server named *name* inside *section_id*, or ``None``."""
    # Successive .where() calls are AND-ed together.
    stmt = (
        select(Server)
        .where(Server.section_id == section_id)
        .where(Server.name == name)
    )
    found = await db.execute(stmt)
    return found.scalars().first()
async def get_server_by_id(db: AsyncSession, server_id: int) -> Optional[Server]:
    """Look up a server by primary key via ``db.get``; ``None`` if absent."""
    server = await db.get(Server, server_id)
    return server
async def create_server(db: AsyncSession, server_data: Dict[str, Any]) -> Server:
    """Insert a new server in the ``"Stopped"`` state and return it.

    ``name``, ``section_id`` and ``runtime_definition`` are required keys
    (``KeyError`` if missing). ``description`` defaults to ``None`` and
    ``settings`` to an empty dict.
    """
    name = server_data["name"]
    section_id = server_data["section_id"]
    description = server_data.get("description")
    runtime_definition = server_data["runtime_definition"]
    settings = server_data.get("settings", {})
    server = Server(
        name=name,
        section_id=section_id,
        description=description,
        runtime_definition=runtime_definition,
        settings=settings,
        status="Stopped",
    )
    db.add(server)
    await db.commit()
    # Reload the row so database-populated columns are visible to the caller.
    await db.refresh(server)
    return server
async def update_server(db: AsyncSession, server_id: int, server_data: Dict[str, Any]) -> Optional[Server]:
    """Apply ``server_data`` onto the server with *server_id* and commit.

    Each key that names an attribute present on the server is assigned.
    Other keys are ignored. Returns the refreshed server, or ``None`` when
    no server has that ID.

    NOTE(review): any attribute the instance exposes is writable this way.
    Confirm that callers restrict the allowed keys.
    """
    server = await get_server_by_id(db, server_id)
    if server:
        writable = ((k, v) for k, v in server_data.items() if hasattr(server, k))
        for attr_name, new_value in writable:
            setattr(server, attr_name, new_value)
        await db.commit()
        await db.refresh(server)
        return server
    return None
async def update_server_status(db: AsyncSession, server_id: int, status: str, process_id: Optional[str] = None) -> Optional[Server]:
    """Set a server's ``status`` (and optionally ``process_id``) and commit.

    ``status`` is always overwritten. ``process_id`` is written only when it
    is not ``None``, so this function cannot clear a previously stored
    process id. Returns the refreshed server, or ``None`` if *server_id* does
    not exist.
    """
    server = await get_server_by_id(db, server_id)
    if server:
        server.status = status
        if process_id is not None:
            server.process_id = process_id
        await db.commit()
        await db.refresh(server)
        return server
    return None
async def delete_server(db: AsyncSession, server_id: int) -> bool:
    """Delete the server with *server_id* and commit.

    Returns ``True`` if the server existed and was deleted, ``False``
    otherwise.
    """
    server = await get_server_by_id(db, server_id)
    if server:
        await db.delete(server)
        await db.commit()
        return True
    return False
# Task operations
async def create_task(
    db: AsyncSession,
    task_type: str,
    details: Optional[Dict[str, Any]] = None,
) -> Task:
    """Create and persist a new task in the ``"Pending"`` state.

    Args:
        db: Active async session; the task is committed on it.
        task_type: Free-form task category stored on the row.
        details: Optional payload stored on the task. ``None`` (or any other
            falsy value) is stored as a fresh empty dict.

    Returns:
        The committed and refreshed ``Task``. Its ``task_id`` is a newly
        generated UUID4 string.
    """
    task = Task(
        task_id=str(uuid.uuid4()),
        task_type=task_type,
        status="Pending",
        # A new dict per call, so tasks never share a mutable default.
        details=details or {},
    )
    db.add(task)
    await db.commit()
    await db.refresh(task)
    return task
async def get_task(db: AsyncSession, task_id: str) -> Optional[Task]:
    """Fetch a task by its ``task_id`` column (the UUID string assigned by
    ``create_task``). Returns ``None`` if no task matches."""
    stmt = select(Task).where(Task.task_id == task_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def update_task_status(
    db: AsyncSession,
    task_id: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
) -> Optional[Task]:
    """Update a task's status and, optionally, its result and error.

    Args:
        db: Active async session; changes are committed on it.
        task_id: The task's ``task_id`` string (as looked up by ``get_task``).
        status: New status; always overwritten.
        result: Stored only when not ``None``. Passing ``None`` leaves the
            existing result untouched, so it cannot be cleared here.
        error: Stored only when not ``None``; same "``None`` = unchanged" rule.

    Returns:
        The refreshed ``Task``, or ``None`` if no task has *task_id*.
    """
    task = await get_task(db, task_id)
    if not task:
        return None
    task.status = status
    if result is not None:
        task.result = result
    if error is not None:
        task.error = error
    await db.commit()
    await db.refresh(task)
    return task
# Audit logging
async def add_audit_log(
    db: AsyncSession,
    action: str,
    entity_type: str,
    entity_id: Optional[str] = None,
    user: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
) -> AuditLog:
    """Record an audit log entry timestamped with the current UTC time.

    Args:
        db: Active async session; the entry is committed on it.
        action: What happened (free-form string).
        entity_type: Kind of entity the action applied to.
        entity_id: Optional identifier of that entity.
        user: Optional actor name.
        details: Optional extra payload; ``None`` (or any falsy value) is
            stored as an empty dict.

    Returns:
        The committed and refreshed ``AuditLog`` row.
    """
    log = AuditLog(
        action=action,
        entity_type=entity_type,
        entity_id=entity_id,
        user=user,
        details=details or {},
        # datetime.utcnow() is deprecated since Python 3.12. This yields the
        # same naive UTC value, so it stays consistent with previously stored
        # timestamps (and with a tz-naive column, if that is what the model
        # declares; confirm).
        timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
    )
    db.add(log)
    await db.commit()
    await db.refresh(log)
    return log
async def get_audit_logs(
    db: AsyncSession,
    limit: int = 100,
    entity_type: Optional[str] = None,
    entity_id: Optional[str] = None,
) -> List[AuditLog]:
    """Return up to *limit* audit entries, newest first.

    ``entity_type`` and ``entity_id`` each narrow the results when truthy.
    ``None`` or an empty string leaves that filter off.
    """
    stmt = select(AuditLog)
    # Filters are added first so the statement reads in SQL clause order;
    # the compiled SQL is the same regardless of call order.
    if entity_type:
        stmt = stmt.where(AuditLog.entity_type == entity_type)
    if entity_id:
        stmt = stmt.where(AuditLog.entity_id == entity_id)
    stmt = stmt.order_by(AuditLog.timestamp.desc()).limit(limit)
    rows = await db.execute(stmt)
    return rows.scalars().all()