"""Participant-based filtering for group chat messages.
This module provides efficient participant filtering using the message_participants
junction table. It enables queries like "find messages in chats that include
person A and person B" without using slow JSON operations.
"""
import sqlite3
from ..db.queries import normalize_handle
def get_chat_participants(
    chat_conn: sqlite3.Connection,
    chat_id: int,
) -> list[str]:
    """Look up the handles that belong to a chat in chat.db.

    Joins ``handle`` to ``chat_handle_join`` and passes every ``handle.id``
    found for the chat through ``normalize_handle``.

    Args:
        chat_conn: SQLite connection to chat.db.
        chat_id: The chat ROWID to look up in ``chat_handle_join``.

    Returns:
        One normalized identifier per joined handle row, in the order the
        query yields them (the query has no ORDER BY). Empty when the chat
        has no joined handles.
    """
    sql = """
        SELECT h.id
        FROM handle h
        JOIN chat_handle_join chj ON h.ROWID = chj.handle_id
        WHERE chj.chat_id = ?
        """
    handle_rows = chat_conn.execute(sql, (chat_id,)).fetchall()
    return [normalize_handle(handle_row[0]) for handle_row in handle_rows]
def filter_by_participants(
    index_conn: sqlite3.Connection,
    rowids: list[int],
    participants: list[str],
    require_all: bool = True,
) -> list[int]:
    """Filter message rowids by participant handles.

    Uses the message_participants junction table instead of JSON LIKE
    queries. Participant handles are normalized and de-duplicated before
    matching, and the rowids are queried in chunks so that large inputs do
    not exceed SQLite's limit on bound parameters per statement.

    Args:
        index_conn: SQLite connection to the search index database.
        rowids: List of message rowids to filter.
        participants: List of participant handles to match (will be
            normalized; duplicates after normalization count once).
        require_all: If True, all participants must be in the chat.
            If False, any participant match is sufficient.

    Returns:
        A new list with the rowids that passed the filter, preserving the
        original order (and any repeated rowids) of ``rowids``. When
        ``rowids`` or ``participants`` is empty, no filtering is applied and
        a copy of ``rowids`` is returned.

    Example:
        >>> # Find messages in chats with both Alice and Bob
        >>> filtered = filter_by_participants(
        ...     conn,
        ...     [1, 2, 3, 4, 5],
        ...     ["+15551234567", "alice@example.com"],
        ...     require_all=True
        ... )
        >>> # Returns rowids where BOTH participants are in the chat
        >>> # Find messages in chats with Alice OR Bob
        >>> filtered = filter_by_participants(
        ...     conn,
        ...     [1, 2, 3, 4, 5],
        ...     ["+15551234567", "alice@example.com"],
        ...     require_all=False
        ... )
        >>> # Returns rowids where ANY participant is in the chat
    """
    if not rowids or not participants:
        # Nothing to filter, or nothing to filter by: every rowid passes.
        # Return a copy so the result never aliases the caller's list.
        return list(rowids)

    # Normalize, then de-duplicate while keeping first-seen order. Without
    # this, two inputs that normalize to the same handle would make the
    # HAVING COUNT(DISTINCT ...) = N check below impossible to satisfy.
    wanted = list(dict.fromkeys(normalize_handle(p) for p in participants))
    participant_placeholders = ",".join("?" * len(wanted))

    if require_all:
        # A rowid qualifies only if every wanted participant has a row:
        # group per rowid and compare the distinct match count to N.
        query_template = """
            SELECT mp.rowid
            FROM message_participants mp
            WHERE mp.rowid IN ({rowids})
            AND mp.participant IN ({participants})
            GROUP BY mp.rowid
            HAVING COUNT(DISTINCT mp.participant) = ?
            """
        trailing_params = [len(wanted)]
    else:
        # Any single matching participant row is enough.
        query_template = """
            SELECT DISTINCT mp.rowid
            FROM message_participants mp
            WHERE mp.rowid IN ({rowids})
            AND mp.participant IN ({participants})
            """
        trailing_params = []

    # 999 was SQLite's default SQLITE_MAX_VARIABLE_NUMBER before 3.32.0;
    # newer builds allow more, so this is a conservative per-statement budget.
    # A participant list approaching this size would still exceed it.
    max_sql_variables = 999
    chunk_size = max(1, max_sql_variables - len(wanted) - len(trailing_params))

    # Each distinct rowid is sent once; GROUP BY is per rowid, so splitting
    # the rowids across statements does not change which ones qualify.
    unique_rowids = list(dict.fromkeys(rowids))
    matched: set[int] = set()
    for start in range(0, len(unique_rowids), chunk_size):
        chunk = unique_rowids[start:start + chunk_size]
        query = query_template.format(
            rowids=",".join("?" * len(chunk)),
            participants=participant_placeholders,
        )
        cursor = index_conn.execute(query, [*chunk, *wanted, *trailing_params])
        matched.update(row[0] for row in cursor)

    # Preserve original order of rowids
    return [rowid for rowid in rowids if rowid in matched]
def get_message_participants(
    index_conn: sqlite3.Connection,
    rowid: int,
) -> list[str]:
    """Return the participant handles stored for one message.

    Args:
        index_conn: SQLite connection to the search index database.
        rowid: Message rowid to look up in ``message_participants``.

    Returns:
        The ``participant`` values recorded for ``rowid``, sorted by the
        query's ``ORDER BY participant``. Empty when the rowid has no rows.
    """
    sql = """
        SELECT participant
        FROM message_participants
        WHERE rowid = ?
        ORDER BY participant
        """
    # Iterate the cursor directly; each row carries a single column.
    return [record[0] for record in index_conn.execute(sql, (rowid,))]
def bulk_get_message_participants(
    index_conn: sqlite3.Connection,
    rowids: list[int],
) -> dict[int, list[str]]:
    """Get participants for multiple messages with batched queries.

    The rowids are de-duplicated and looked up in chunks so that a large
    input does not exceed SQLite's limit on bound parameters per statement
    (which would raise "too many SQL variables").

    Args:
        index_conn: SQLite connection to the search index database.
        rowids: List of message rowids.

    Returns:
        Dictionary mapping each requested rowid to its list of participant
        handles, sorted by participant. Keys appear in first-seen order of
        ``rowids``; a rowid with no rows in ``message_participants`` maps to
        an empty list. Returns an empty dict when ``rowids`` is empty.
    """
    if not rowids:
        return {}

    # Send each rowid once so its participants are never appended twice,
    # even when duplicates would otherwise land in different chunks.
    unique_rowids = list(dict.fromkeys(rowids))
    result: dict[int, list[str]] = {rowid: [] for rowid in unique_rowids}

    # 999 was SQLite's default SQLITE_MAX_VARIABLE_NUMBER before 3.32.0;
    # newer builds allow more, so this is a conservative per-statement budget.
    max_sql_variables = 999
    for start in range(0, len(unique_rowids), max_sql_variables):
        chunk = unique_rowids[start:start + max_sql_variables]
        placeholders = ",".join("?" * len(chunk))
        cursor = index_conn.execute(
            f"""
            SELECT rowid, participant
            FROM message_participants
            WHERE rowid IN ({placeholders})
            ORDER BY rowid, participant
            """,
            chunk,
        )
        # A rowid lives in exactly one chunk, so the ORDER BY still yields
        # its participants in sorted order.
        for row in cursor:
            result[row[0]].append(row[1])
    return result