# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/utils/sqlalchemy_modifier.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Madhav Kandukuri
SQLAlchemy modifiers
- json_contains_expr: handles json_contains logic for different dialects
- json_contains_tag_expr: handles tag filtering for dict-format tags [{id, label}]
"""
# Standard
import itertools
import re
import threading
from typing import Any, Iterable, List, Union
import uuid
# Third-Party
import orjson
from sqlalchemy import and_, func, or_, text
from sqlalchemy.sql.elements import TextClause
# Thread-safe counter for generating unique bind parameter prefixes
_bind_counter = itertools.count()
_bind_counter_lock = threading.Lock()
def _ensure_list(values: Union[str, Iterable[str]]) -> List[str]:
"""
Normalize input into a list of strings.
Args:
values: A single string or any iterable of strings. If `None`, an empty
list is returned.
Returns:
A list of strings. If `values` is a string it will be wrapped in a
single-item list; if it's already an iterable, it will be converted to
a list. If `values` is `None`, returns an empty list.
"""
if values is None:
return []
if isinstance(values, str):
return [values]
return list(values)
def _generate_unique_prefix(col_ref: str) -> str:
"""
Generate a unique SQL bind parameter prefix for a column reference.
Combines a sanitized column name with a thread-safe counter to ensure
unique bind parameter names across all calls, even when:
- The same column is filtered multiple times in one query
- Different column refs sanitize to the same string (e.g., a_b.c vs a.b_c)
Args:
col_ref: Column reference like "resources.tags"
Returns:
Unique prefix like "resources_tags_42"
"""
sanitized = re.sub(r"[^a-zA-Z0-9]", "_", col_ref)
with _bind_counter_lock:
counter = next(_bind_counter)
return f"{sanitized}_{counter}"
def _sqlite_tag_any_template(col_ref: str, prefix: str, n: int) -> TextClause:
"""
Build a SQLite SQL template for matching ANY of the given tags
inside a JSON array column.
This template supports both legacy string tags and object-style tags
(e.g., {"id": "api"}). It safely guards `json_extract` with
`CASE WHEN type = 'object'` to avoid malformed JSON errors on string values.
Design: We only filter by 'id' field, not 'label'. TagValidator ensures
all properly validated tags have an 'id' field. See json_contains_tag_expr
docstring for rationale.
The generated SQL uses unique bind parameters with the provided prefix
(e.g., :resources_tags_42_p0) to avoid collisions when multiple tag
filters are used in the same query.
Args:
col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
n (int): Number of tag values being matched.
Returns:
sqlalchemy.sql.elements.TextClause:
A SQL template for matching ANY of the given tags.
"""
if n == 1:
tmp_ = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p0 OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p0)" # nosec B608
sql = tmp_.strip()
else:
placeholders = ",".join(f":{prefix}_p{i}" for i in range(n))
tmp_ = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders}) OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) IN ({placeholders}))" # nosec B608
sql = tmp_.strip()
return text(sql)
def _sqlite_tag_all_template(col_ref: str, prefix: str, n: int) -> TextClause:
"""
Build a SQLite SQL template for matching ALL of the given tags
inside a JSON array column.
This is implemented as an AND-chain of EXISTS subqueries, where each
subquery ensures the presence of one required tag.
This template supports both legacy string tags and object-style tags
(e.g., {"id": "api"}). It safely guards `json_extract` with
`CASE WHEN type = 'object'` to avoid malformed JSON errors on string values.
The generated SQL uses unique bind parameters with the provided prefix
(e.g., :resources_tags_42_p0) to avoid collisions when multiple tag
filters are used in the same query.
Args:
col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
n (int): Number of tag values being matched.
Returns:
sqlalchemy.sql.elements.TextClause:
A SQL template for matching ALL of the given tags.
"""
clauses = []
for i in range(n):
tmp_ = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p{i} OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p{i})" # nosec B608
clauses.append(tmp_.strip())
return text(" AND ".join(clauses))
def json_contains_tag_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
"""
Return a SQLAlchemy expression that is True when JSON column `col`
contains tags matching the given values. Handles both legacy List[str]
and new List[Dict[str, str]] (with 'id' field) tag formats.
Args:
session: database session
col: column that contains JSON array of tags
values: list of tag IDs to match against
match_any: Boolean to set OR (True) or AND (False) matching
Returns:
Any: SQLAlchemy boolean expression suitable for use in .where()
Raises:
RuntimeError: If dialect is not supported
ValueError: If values is empty
Examples:
SQLite builds a parameterized `TextClause` using `json_each()`:
>>> from unittest.mock import Mock, patch
>>> from sqlalchemy.sql.elements import TextClause
>>> from mcpgateway.utils import sqlalchemy_modifier as mod
>>> session = Mock()
>>> session.get_bind.return_value.dialect.name = "sqlite"
>>> col = Mock()
>>> col.table.name = "resources"
>>> col.name = "tags"
>>> with patch.object(mod, "_generate_unique_prefix", return_value="resources_tags_0"):
... expr = mod.json_contains_tag_expr(session, col, ["api", "db"], match_any=True)
... (isinstance(expr, TextClause), "json_each(resources.tags)" in expr.text, expr.compile().params["resources_tags_0_p0"])
(True, True, 'api')
Unsupported dialects raise:
>>> session.get_bind.return_value.dialect.name = "oracle"
>>> mod.json_contains_tag_expr(session, col, ["x"])
Traceback (most recent call last):
...
RuntimeError: Unsupported dialect for json_contains_tag: oracle
"""
values_list = _ensure_list(values)
if not values_list:
raise ValueError("values must be non-empty")
dialect = session.get_bind().dialect.name
# ---------- MySQL ----------
# For dict-format tags: use JSON_SEARCH to find tags with matching id
# JSON_SEARCH returns path if found, NULL otherwise
if dialect == "mysql":
# Build conditions that check for both string tags and dict tags with matching id
conditions = []
for tag_value in values_list:
# Check if tag exists as plain string OR as dict with matching id
# JSON_SEARCH(col, 'one', value) finds plain string value
# JSON_CONTAINS with path $.*.id checks dict format
string_match = func.json_search(col, "one", tag_value).isnot(None)
dict_match = func.json_contains(col, orjson.dumps([{"id": tag_value}]).decode()) == 1
conditions.append(or_(string_match, dict_match))
if match_any:
return or_(*conditions)
return and_(*conditions)
# ---------- PostgreSQL ----------
# For dict-format tags: use JSON functions that work with both JSON and JSONB types
# Note: .contains() only works with JSONB, but our column is JSON type
#
# Design: We only filter by 'id' field, not 'label'. This is intentional because:
# 1. TagValidator.validate_list() always creates tags with both id and label
# 2. The 'id' is the normalized, canonical identifier for matching
# 3. The 'label' is for display purposes only
# If a tag somehow exists without 'id' (malformed data), it won't match at DB level,
# but _get_tag_id() has a Python-side fallback for graceful handling.
if dialect == "postgresql":
# Third-Party
from sqlalchemy import bindparam, cast, exists, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.sql import literal_column
# Build conditions for each tag value using JSON functions
conditions = []
for tag_value in values_list:
# Generate unique parameter name
param_name = f"tag_{uuid.uuid4().hex[:8]}"
param_dict = f"tag_{uuid.uuid4().hex[:8]}"
# For string tags: use @> operator to check if JSONB array contains the value
# Cast the tag_value to JSONB array and check containment
string_match = cast(col, JSONB).op("@>")(cast(func.jsonb_build_array(bindparam(param_name, value=tag_value)), JSONB))
# For dict tags: use EXISTS with jsonb_array_elements to check 'id' field
# Use table_valued() for explicit column reference (elem.c.value)
# This is the idiomatic SQLAlchemy pattern for table-valued functions
elem_table = func.jsonb_array_elements(cast(col, JSONB)).table_valued("value").alias("elem")
dict_match = exists(select(literal_column("1")).select_from(elem_table).where(elem_table.c.value.op("->>")(literal_column("'id'")) == bindparam(param_dict, value=tag_value)))
conditions.append(or_(string_match, dict_match))
if match_any:
return or_(*conditions)
return and_(*conditions)
# ---------- SQLite (json1) ----------
# For dict-format tags: use json_extract to get the 'id' field
# Use CASE WHEN type = 'object' to avoid "malformed JSON" error on string elements
if dialect == "sqlite":
table_name = getattr(getattr(col, "table", None), "name", None)
column_name = getattr(col, "name", None) or str(col)
col_ref = f"{table_name}.{column_name}" if table_name else column_name
n = len(values_list)
if n == 0:
raise ValueError("values must be non-empty")
# Generate unique prefix to avoid bind name collisions when multiple
# tag filters are combined in the same query (even on the same column
# or when different column refs sanitize to the same string)
prefix = _generate_unique_prefix(col_ref)
params = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}
if match_any:
tmpl = _sqlite_tag_any_template(col_ref, prefix, n)
return tmpl.bindparams(**params)
tmpl = _sqlite_tag_all_template(col_ref, prefix, n)
return tmpl.bindparams(**params)
raise RuntimeError(f"Unsupported dialect for json_contains_tag: {dialect}")
def json_contains_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
"""
Return a SQLAlchemy expression that is True when JSON column `col`
contains the scalar `value`. `session` is used to detect dialect.
Assumes `col` is a JSON/JSONB column (array-of-strings case).
Args:
session: database session
col: column that contains JSON
values: list of values to check for in json
match_any: Boolean to set OR or AND matching
Returns:
Any: SQLAlchemy boolean expression suitable for use in .where()
Raises:
RuntimeError: If dialect is not supported
ValueError: If values is empty
Examples:
SQLite builds a parameterized SQL fragment using `json_each()`:
>>> from unittest.mock import Mock
>>> from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr
>>> session = Mock()
>>> session.get_bind.return_value.dialect.name = "sqlite"
>>> col = Mock()
>>> col.table.name = "resources"
>>> col.name = "scopes"
>>> expr = json_contains_expr(session, col, ["a", "b"], match_any=True)
>>> ("json_each(resources.scopes)" in expr.text, "value IN" in expr.text, len(expr.compile().params))
(True, True, 2)
Unsupported dialects raise:
>>> session.get_bind.return_value.dialect.name = "oracle"
>>> json_contains_expr(session, col, ["x"])
Traceback (most recent call last):
...
RuntimeError: Unsupported dialect for json_contains: oracle
"""
values_list = _ensure_list(values)
if not values_list:
raise ValueError("values must be non-empty")
dialect = session.get_bind().dialect.name
# ---------- MySQL ----------
# - all-of: JSON_CONTAINS(col, '["a","b"]') == 1
# - any-of: prefer JSON_OVERLAPS (MySQL >= 8.0.17), otherwise OR of JSON_CONTAINS for each value
if dialect == "mysql":
try:
if match_any:
# JSON_OVERLAPS exists in modern MySQL; SQLAlchemy will emit func.json_overlaps(...)
return func.json_overlaps(col, orjson.dumps(values_list).decode()) == 1
else:
return func.json_contains(col, orjson.dumps(values_list).decode()) == 1
except Exception:
# Fallback: compose OR of json_contains for each scalar
if match_any:
return or_(*[func.json_contains(col, orjson.dumps(t).decode()) == 1 for t in values_list])
else:
return and_(*[func.json_contains(col, orjson.dumps(t).decode()) == 1 for t in values_list])
# ---------- PostgreSQL ----------
# - all-of: col.contains(list) (works if col is JSONB)
# - any-of: use OR of col.contains([value]) (or use ?| operator if you prefer)
if dialect == "postgresql":
# prefer JSONB .contains for all-of
if not match_any:
return col.contains(values_list)
# match_any: use OR over element-containment
return or_(*[col.contains([t]) for t in values_list])
# ---------- SQLite (json1) ----------
# SQLite doesn't have JSON_CONTAINS. We build safe SQL:
# - any-of: single EXISTS ... WHERE value IN (:p0,:p1,...)
# - all-of: multiple EXISTS with unique bind params (one EXISTS per value) => AND semantics
if dialect == "sqlite":
table_name = getattr(getattr(col, "table", None), "name", None)
column_name = getattr(col, "name", None) or str(col)
col_ref = f"{table_name}.{column_name}" if table_name else column_name
if match_any:
# Build placeholders with unique param names and pass *values* to bindparams
params = {}
placeholders = []
for i, t in enumerate(values_list):
pname = f"t_{uuid.uuid4().hex[:8]}_{i}"
placeholders.append(f":{pname}")
params[pname] = t
placeholders_sql = ",".join(placeholders)
sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))") # nosec B608 - Safe: uses parameterized queries with bindparams()
# IMPORTANT: pass plain values as kwargs to bindparams
return sq.bindparams(**params)
# all-of: return AND of EXISTS(... = :pX) with plain values
exists_clauses = []
for t in values_list:
pname = f"t_{uuid.uuid4().hex[:8]}"
clause = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{pname})").bindparams(**{pname: t}) # nosec B608 - Safe: uses parameterized queries with bindparams()
exists_clauses.append(clause)
if len(exists_clauses) == 1:
return exists_clauses[0]
return and_(*exists_clauses)
raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")