"""permission_system_rework
Revision ID: ab7e313804ae
Revises: 1d0bb7fede17
Create Date: 2025-06-16 15:20:43.118246
"""
import uuid
from datetime import datetime, timezone
from typing import Sequence, Union
from uuid import uuid4

from alembic import op
import sqlalchemy as sa
from sqlalchemy import UUID

# revision identifiers, used by Alembic.
revision: str = "ab7e313804ae"
down_revision: Union[str, None] = "1d0bb7fede17"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
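
# This revision moves access control from the document level to the dataset
# level: the old per-document ACL table is dropped and rebuilt with a
# dataset_id target, and every existing dataset owner is backfilled with
# read/write/share/delete grants on their own datasets.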


def _now():
    """Timezone-aware UTC timestamp used for audit columns."""
    return datetime.now(timezone.utc)


def _define_dataset_table() -> sa.Table:
    # Note: we can't rely on the live Cognee model definitions to read this
    # data (they can change over time), so we pin our own table definition
    # matching what is in the database.
    return sa.Table(
        "datasets",
        sa.MetaData(),
        sa.Column("id", UUID, primary_key=True, default=uuid4),
        sa.Column("name", sa.Text),
        sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        sa.Column("updated_at", sa.DateTime(timezone=True), onupdate=_now),
        sa.Column("owner_id", UUID, sa.ForeignKey("principals.id"), index=True),
    )


def _define_data_table() -> sa.Table:
    # Note: same as _define_dataset_table -- use a pinned definition of the
    # "data" table rather than the current Cognee models.
    return sa.Table(
        "data",
        sa.MetaData(),
        sa.Column("id", UUID, primary_key=True, default=uuid4),
        sa.Column("name", sa.String),
        sa.Column("extension", sa.String),
        sa.Column("mime_type", sa.String),
        sa.Column("raw_data_location", sa.String),
        sa.Column("owner_id", UUID, index=True),
        sa.Column("content_hash", sa.String),
        sa.Column("external_metadata", sa.JSON),
        sa.Column("node_set", sa.JSON, nullable=True),  # list of strings
        sa.Column("token_count", sa.Integer),
        sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        sa.Column("updated_at", sa.DateTime(timezone=True), onupdate=_now),
    )


def _ensure_permission(conn, permission_name) -> uuid.UUID:
    """
    Return the permission id for the given name, creating the row if needed.
    """
    permissions_table = sa.Table(
        "permissions",
        sa.MetaData(),
        sa.Column("id", UUID, primary_key=True, index=True, default=uuid4),
        sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        sa.Column("updated_at", sa.DateTime(timezone=True), onupdate=_now),
        sa.Column("name", sa.String, unique=True, nullable=False, index=True),
    )

    row = conn.execute(
        sa.select(permissions_table).where(permissions_table.c.name == permission_name)
    ).fetchone()

    if row is None:
        permission_id = uuid4()
        op.bulk_insert(
            permissions_table,
            [
                {
                    "id": permission_id,
                    "name": permission_name,
                    "created_at": _now(),
                }
            ],
        )
        return permission_id

    return row.id
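

# Note: _ensure_permission (and the backfills below) read query results from
# the live connection, so this data migration assumes it runs online via a
# plain `alembic upgrade`; Alembic's offline --sql mode cannot execute these
# lookups.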


def _build_acl_row(*, user_id, target_id, permission_id, target_col) -> dict:
    """Create a dict with the correct column names for the ACL row."""
    return {
        "id": uuid4(),
        "created_at": _now(),
        "principal_id": user_id,
        target_col: target_id,
        "permission_id": permission_id,
    }
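

# For illustration, a dataset grant produced by the helpers below (via
# _build_acl_row with target_col="dataset_id") has this shape:
#   {"id": <new uuid>, "created_at": <utc now>, "principal_id": <owner id>,
#    "dataset_id": <dataset id>, "permission_id": <permission id>}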


def _create_dataset_permission(conn, user_id, dataset_id, permission_name):
    perm_id = _ensure_permission(conn, permission_name)
    return _build_acl_row(
        user_id=user_id, target_id=dataset_id, permission_id=perm_id, target_col="dataset_id"
    )


def _create_data_permission(conn, user_id, data_id, permission_name):
    perm_id = _ensure_permission(conn, permission_name)
    return _build_acl_row(
        user_id=user_id, target_id=data_id, permission_id=perm_id, target_col="data_id"
    )


def upgrade() -> None:
    conn = op.get_bind()

    # Recreate the ACL table so grants target datasets instead of individual
    # documents.
    op.drop_table("acls")
    acls_table = op.create_table(
        "acls",
        sa.Column("id", UUID, primary_key=True, default=uuid4),
        sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        sa.Column("updated_at", sa.DateTime(timezone=True), onupdate=_now),
        sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
        sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
        sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
    )

    # Backfill: every dataset owner receives the default grants on their own
    # datasets (see the pinned-schema note in _define_dataset_table).
    dataset_table = _define_dataset_table()
    datasets = conn.execute(sa.select(dataset_table)).fetchall()
    if not datasets:
        return

    acl_list = []
    for dataset in datasets:
        for permission_name in ("read", "write", "share", "delete"):
            acl_list.append(
                _create_dataset_permission(conn, dataset.owner_id, dataset.id, permission_name)
            )

    op.bulk_insert(acls_table, acl_list)
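

# Note: the downgrade below is narrower than the upgrade -- it re-grants only
# read/write at the document level, so dataset-level share/delete grants are
# not carried back.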
def downgrade() -> None:
    conn = op.get_bind()

    # Restore the previous document-level ACL shape.
    op.drop_table("acls")
    acls_table = op.create_table(
        "acls",
        sa.Column("id", UUID, primary_key=True, nullable=False, default=uuid4),
        sa.Column("created_at", sa.DateTime(timezone=True), default=_now),
        sa.Column("updated_at", sa.DateTime(timezone=True), onupdate=_now),
        sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
        sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
        sa.Column("data_id", UUID, sa.ForeignKey("data.id", ondelete="CASCADE")),
    )

    # Backfill: each data owner receives read/write grants on their own rows
    # (see the pinned-schema note in _define_data_table).
    data_table = _define_data_table()
    data_rows = conn.execute(sa.select(data_table)).fetchall()
    if not data_rows:
        return

    acl_list = []
    for data_row in data_rows:
        for permission_name in ("read", "write"):
            acl_list.append(
                _create_data_permission(conn, data_row.owner_id, data_row.id, permission_name)
            )

    op.bulk_insert(acls_table, acl_list)
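

# Usage sketch (assuming a standard Alembic setup for this project):
#   alembic upgrade ab7e313804ae      # apply: dataset-level ACLs
#   alembic downgrade 1d0bb7fede17    # revert: document-level ACLs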