"""Message query and search tools for iMessage."""
import logging
from datetime import datetime
from typing import Any
from ..constants import DEFAULT_MESSAGE_LIMIT
from ..db import (
coredata_to_datetime,
datetime_to_coredata,
get_connection,
get_search_index_connection,
get_sync_metadata,
normalize_handle,
parse_message_text,
)
from ..db.contacts import get_contact_cache
from ..db.queries import format_datetime_iso
from ..exceptions import DatabaseLockedError
from ..search.embeddings import generate_pending_embeddings
from ..search.fts5 import FTSFilters, FTSResult, fts5_search
from ..search.hybrid import rrf_merge
from ..search.sync import fast_sync_check, sync_new_messages
from ..search.vector import VectorResult, get_query_embedding, vector_search
logger = logging.getLogger(__name__)
async def enrich_message(msg: dict) -> dict:
"""Add contact_name field to a message dict.
Args:
msg: Message dictionary with at least 'sender' and 'is_from_me' fields
Returns:
Message dictionary with contact_name field added
"""
cache = get_contact_cache()
sender = msg.get("sender")
if sender and sender != "me":
msg["contact_name"] = await cache.resolve_name(sender)
else:
msg["contact_name"] = None
return msg
async def enrich_messages_batch(msgs: list[dict]) -> list[dict]:
"""Efficiently enrich multiple messages with contact names.
Args:
msgs: List of message dictionaries
Returns:
List of message dictionaries with contact_name fields added
"""
cache = get_contact_cache()
# Collect unique senders (filter out None and "me")
senders: set[str] = set()
for m in msgs:
sender = m.get("sender")
if sender and sender != "me":
senders.add(sender)
# Batch resolve
names = await cache.resolve_names_batch(list(senders))
# Apply to messages
for msg in msgs:
sender = msg.get("sender")
msg["contact_name"] = names.get(sender) if sender and sender != "me" else None
return msgs
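
# Usage sketch (hypothetical data; assumes the contact cache is initialized):
#
#   msgs = [{"sender": "+15551234567", "is_from_me": 0},
#           {"sender": "me", "is_from_me": 1}]
#   msgs = await enrich_messages_batch(msgs)
#   # msgs[0]["contact_name"] is the resolved name (or None); msgs[1]'s is None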
def iso_to_coredata(date_str: str) -> int:
"""Convert ISO8601 date string to CoreData timestamp in nanoseconds.
Args:
date_str: ISO8601 formatted date string (e.g., "2024-01-15T18:30:00Z")
Returns:
CoreData timestamp in nanoseconds since 2001-01-01
(iMessage stores dates as nanoseconds, not seconds)
Raises:
ValueError: If date_str is not a valid ISO8601 format
"""
    # fromisoformat() on Python < 3.11 rejects a trailing "Z"; normalize it to
    # an explicit UTC offset first
    dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
# datetime_to_coredata returns seconds, but iMessage stores nanoseconds
return datetime_to_coredata(dt) * 1_000_000_000
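
# Example (a sketch; assumes datetime_to_coredata returns whole seconds since
# the 2001-01-01 UTC epoch, as its use above implies):
#
#   ns = iso_to_coredata("2024-01-15T18:30:00Z")
#   assert ns % 1_000_000_000 == 0  # whole seconds, scaled to nanoseconds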
def coredata_to_iso(timestamp: int) -> str:
"""Convert CoreData timestamp to ISO8601 string.
Args:
timestamp: CoreData timestamp (seconds or nanoseconds since 2001-01-01)
Returns:
ISO8601 formatted date string
"""
dt = coredata_to_datetime(timestamp)
if dt is None:
return ""
return dt.isoformat()
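
# Round-trip sketch (assumes coredata_to_datetime auto-detects second- vs
# nanosecond-scale inputs, per the docstring above):
#
#   iso = coredata_to_iso(iso_to_coredata("2024-01-15T18:30:00+00:00"))
#   # iso is an ISO8601 string for the same instant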
def generate_imessage_link(sender: str, is_group: bool) -> str:
"""Generate an imessage:// deep link for a contact or group.
Args:
sender: Phone number or email address
is_group: Whether this is a group chat
Returns:
imessage:// deep link URL
"""
    # Group chats have no stable per-recipient address, so no direct link;
    # likewise skip empty senders and our own messages
    if is_group or not sender or sender == "me":
        return ""
    # Individual contacts: link straight to the handle
return f"imessage://{sender}"
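
# Example: generate_imessage_link("+15551234567", is_group=False) returns
# "imessage://+15551234567"; group chats and "me" yield "" (no direct link).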
async def get_messages_by_rowids(rowids: list[int]) -> dict[int, dict[str, Any]]:
"""Fetch message details from message_index for given rowids.
Args:
rowids: List of message rowids to fetch
Returns:
Dictionary mapping rowid to message details
"""
if not rowids:
return {}
with get_search_index_connection() as conn:
        # Build an IN clause with one placeholder per rowid
        placeholders = ",".join("?" * len(rowids))
cursor = conn.execute(
f"""
SELECT
rowid,
text,
sender,
chat_id,
chat_identifier,
is_group,
is_from_me,
service,
date_coredata
FROM message_index
WHERE rowid IN ({placeholders})
""",
rowids,
)
messages = {}
for row in cursor:
messages[row["rowid"]] = {
"rowid": row["rowid"],
"text": row["text"] or "",
"sender": row["sender"] if not row["is_from_me"] else "me",
"is_from_me": bool(row["is_from_me"]),
"date": coredata_to_iso(row["date_coredata"]),
"service": row["service"],
"chat_id": row["chat_id"],
"is_group": bool(row["is_group"]),
"chat_identifier": row["chat_identifier"],
}
# Enrich with contact names
messages_list = list(messages.values())
enriched_list = await enrich_messages_batch(messages_list)
# Rebuild dictionary with enriched messages
return {msg["rowid"]: msg for msg in enriched_list}
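
# Note: SQLite caps host parameters per statement (999 in older builds, 32766
# since 3.32), so a very long rowid list could fail. A hypothetical chunked
# wrapper, sketched here, would sidestep that cap:
#
#   async def get_messages_by_rowids_chunked(rowids, chunk_size=900):
#       out: dict[int, dict[str, Any]] = {}
#       for i in range(0, len(rowids), chunk_size):
#           out.update(await get_messages_by_rowids(rowids[i : i + chunk_size]))
#       return out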
def get_index_status() -> dict[str, Any]:
"""Get the current status of the search index.
Returns:
Dictionary with index status information
"""
with get_search_index_connection() as conn:
# Get total indexed messages
cursor = conn.execute("SELECT COUNT(*) FROM message_index")
total_indexed = cursor.fetchone()[0]
# Get total embedded messages (non-empty embeddings)
cursor = conn.execute(
"SELECT COUNT(*) FROM message_index WHERE embedding IS NOT NULL AND LENGTH(embedding) > 0"
)
total_embedded = cursor.fetchone()[0]
# Get last sync metadata
backfill_status = get_sync_metadata(conn, "backfill_status") or "complete"
        # Determine sync status
        sync_status = "backfilling" if backfill_status == "in_progress" else "current"
return {
"total_indexed": total_indexed,
"total_embedded": total_embedded,
"sync_status": sync_status,
}
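
# Shape sketch of the returned status (illustrative values):
#
#   {"total_indexed": 123456, "total_embedded": 120000, "sync_status": "current"}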
def filter_by_participants(rowids: list[int], participants: list[str]) -> list[int]:
"""Filter rowids to only include messages in chats with specified participants.
Args:
rowids: List of message rowids to filter
participants: List of participant identifiers to filter by
Returns:
Filtered list of rowids
"""
if not rowids or not participants:
return rowids
# Normalize participant identifiers
normalized_participants = [normalize_handle(p) for p in participants]
with get_search_index_connection() as conn:
        # Relational division: keep only rowids whose chat rows match ALL of
        # the specified participants
        placeholders_rowids = ",".join("?" * len(rowids))
        placeholders_participants = ",".join("?" * len(normalized_participants))
cursor = conn.execute(
f"""
            SELECT mp.rowid
FROM message_participants mp
WHERE mp.rowid IN ({placeholders_rowids})
AND mp.participant IN ({placeholders_participants})
GROUP BY mp.rowid
HAVING COUNT(DISTINCT mp.participant) = ?
""",
rowids + normalized_participants + [len(normalized_participants)],
)
return [row[0] for row in cursor.fetchall()]
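
# The GROUP BY / HAVING COUNT(DISTINCT ...) pattern above is relational
# division: a rowid survives only if its chat matches every requested
# participant. A minimal in-memory analogue of the same test:
#
#   def has_all(chat_participants: set[str], wanted: list[str]) -> bool:
#       return set(wanted) <= chat_participants  # every wanted handle present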
async def get_recent_messages(
limit: int = DEFAULT_MESSAGE_LIMIT,
offset: int = 0,
after_date: str | None = None,
) -> dict:
"""Get the most recent messages across all conversations.
Args:
limit: Maximum number of messages to return (default: 50)
offset: Number of messages to skip (default: 0)
after_date: Only return messages after this date (ISO8601 format)
Returns:
Dictionary with messages list and pagination metadata.
"""
with get_connection() as conn:
        # Build query with an optional date filter
        params: list = []
        date_filter = ""
        if after_date:
            try:
                coredata_ns = iso_to_coredata(after_date)
                date_filter = "WHERE m.date > ?"
                params.append(coredata_ns)
            except ValueError:
                logger.warning(f"Ignoring invalid after_date: {after_date!r}")
        # Fetch one extra row so has_more can be detected
        params.extend([limit + 1, offset])
cursor = conn.execute(
f"""
SELECT
m.ROWID,
m.guid,
m.text,
m.attributedBody,
m.date,
m.is_from_me,
m.service,
h.id as sender_handle,
c.ROWID as chat_id,
c.chat_identifier,
c.display_name
FROM message m
LEFT JOIN handle h ON m.handle_id = h.ROWID
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
{date_filter}
ORDER BY m.date DESC
LIMIT ? OFFSET ?
""",
params,
)
messages = []
rows = cursor.fetchall()
for row in rows[:limit]:
text = parse_message_text(row)
msg_date = coredata_to_datetime(row["date"])
messages.append(
{
"rowid": row["ROWID"],
"guid": row["guid"],
"text": text,
"date": format_datetime_iso(msg_date),
"date_raw": row["date"],
"is_from_me": bool(row["is_from_me"]),
"service": row["service"],
"sender": row["sender_handle"] if not row["is_from_me"] else "me",
"chat_id": row["chat_id"],
"chat_identifier": row["chat_identifier"],
"chat_display_name": row["display_name"],
}
)
        # Total count for pagination; reuse only the date param (drop LIMIT/OFFSET)
        count_cursor = conn.execute(
            f"SELECT COUNT(*) FROM message m {date_filter}",
            params[:-2] if date_filter else [],
        )
        total = count_cursor.fetchone()[0]
has_more = len(rows) > limit
# Enrich messages with contact names
messages = await enrich_messages_batch(messages)
return {
"messages": messages,
"pagination": {
"total": total,
"offset": offset,
"limit": limit,
"has_more": has_more,
"next_offset": offset + limit if has_more else None,
},
}
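
# Pagination sketch (hypothetical caller walking every page via next_offset):
#
#   offset = 0
#   while True:
#       page = await get_recent_messages(limit=50, offset=offset)
#       process(page["messages"])  # process() is a stand-in
#       if page["pagination"]["next_offset"] is None:
#           break
#       offset = page["pagination"]["next_offset"]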
async def get_message_context(
rowid: int,
before: int = 5,
after: int = 5,
) -> dict:
"""Get messages surrounding a specific message in the same conversation.
Use this to get context around a message found via search. Returns messages
from the same chat/thread, ordered chronologically, with the target message
marked.
Args:
rowid: The rowid of the message to get context for
before: Number of messages to fetch before the target (default: 5)
after: Number of messages to fetch after the target (default: 5)
Returns:
Dictionary with:
- messages: List of messages in chronological order
- target_index: Index of the target message in the list
- chat_id: The chat ID for the conversation
- chat_identifier: The chat identifier (phone/email/group)
"""
# First, look up the target message from the search index
with get_search_index_connection() as conn:
cursor = conn.execute(
"""
SELECT
rowid,
text,
sender,
chat_id,
chat_identifier,
is_group,
is_from_me,
service,
date_coredata
FROM message_index
WHERE rowid = ?
""",
(rowid,),
)
target_row = cursor.fetchone()
if target_row is None:
return {
"error": f"Message with rowid {rowid} not found in index",
"messages": [],
"target_index": None,
"chat_id": None,
"chat_identifier": None,
}
chat_id = target_row["chat_id"]
target_date = target_row["date_coredata"]
chat_identifier = target_row["chat_identifier"]
# Fetch messages before the target (older messages)
before_cursor = conn.execute(
"""
SELECT
rowid,
text,
sender,
chat_id,
chat_identifier,
is_group,
is_from_me,
service,
date_coredata
FROM message_index
WHERE chat_id = ? AND date_coredata < ?
ORDER BY date_coredata DESC
LIMIT ?
""",
(chat_id, target_date, before),
)
before_rows = list(before_cursor.fetchall())
# Reverse to get chronological order (oldest first)
before_rows.reverse()
# Fetch messages after the target (newer messages)
after_cursor = conn.execute(
"""
SELECT
rowid,
text,
sender,
chat_id,
chat_identifier,
is_group,
is_from_me,
service,
date_coredata
FROM message_index
WHERE chat_id = ? AND date_coredata > ?
ORDER BY date_coredata ASC
LIMIT ?
""",
(chat_id, target_date, after),
)
after_rows = list(after_cursor.fetchall())
# Build the messages list
def row_to_message(row, is_target: bool = False) -> dict:
return {
"rowid": row["rowid"],
"text": row["text"] or "",
"sender": row["sender"] if not row["is_from_me"] else "me",
"is_from_me": bool(row["is_from_me"]),
"date": coredata_to_iso(row["date_coredata"]),
"service": row["service"],
"is_target": is_target,
}
messages = []
# Add messages before
for row in before_rows:
messages.append(row_to_message(row, is_target=False))
# Add target message
target_index = len(messages)
messages.append(row_to_message(target_row, is_target=True))
# Add messages after
for row in after_rows:
messages.append(row_to_message(row, is_target=False))
# Enrich messages with contact names
messages = await enrich_messages_batch(messages)
return {
"messages": messages,
"target_index": target_index,
"chat_id": chat_id,
"chat_identifier": chat_identifier,
"is_group": bool(target_row["is_group"]),
"before_count": len(before_rows),
"after_count": len(after_rows),
}
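
# Usage sketch (the rowid would typically come from search_messages below):
#
#   ctx = await get_message_context(rowid=12345, before=3, after=3)
#   target = ctx["messages"][ctx["target_index"]]
#   assert target["is_target"]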
async def search_messages(
query: str,
sender: str | None = None,
chat_id: int | None = None,
participants: list[str] | None = None,
after_date: str | None = None,
before_date: str | None = None,
service: str | None = None,
search_mode: str = "hybrid",
limit: int = 100,
offset: int = 0,
) -> dict:
"""Search messages using hybrid keyword + semantic search.
Combines FTS5 full-text search with OpenAI embedding-based semantic search,
merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
This function is designed to handle partial failures gracefully:
- If sync fails, search uses existing index
- If semantic search fails, falls back to keyword search
- If keyword search fails, returns empty results rather than crashing
- Returns error information in response for visibility
Args:
query: Search query text
sender: Filter by sender phone/email (exact match after normalization)
chat_id: Filter by specific conversation ID
participants: Filter by chat participants (for group chats)
after_date: Only messages after this date (ISO8601 format)
before_date: Only messages before this date (ISO8601 format)
service: Filter by "iMessage" or "SMS"
search_mode: "hybrid" (default), "keyword" (FTS5 only), or "semantic" (vector only)
limit: Results per page (default 100)
offset: Pagination offset (default 0)
Returns:
Dictionary with messages, pagination, and index_status including:
- messages: List of matching messages with relevance scores
- search_mode: The search mode that was used
- fts5_matches: Number of keyword matches found
- semantic_matches: Number of semantic matches found
- pagination: Pagination metadata (total, limit, offset, has_more, next_offset)
- index_status: Current state of the search index
- errors: List of error messages encountered (if any)
- warning: Warning message if partial results (optional)
"""
    # Validate search_mode, falling back to hybrid for unrecognized values
    valid_modes = ("hybrid", "keyword", "semantic")
    if search_mode not in valid_modes:
        logger.warning(f"Unknown search_mode {search_mode!r}; falling back to 'hybrid'")
        search_mode = "hybrid"
logger.debug(f"Starting search: query='{query}', mode={search_mode}, limit={limit}, offset={offset}")
# Initialize response metadata
fts5_matches = 0
semantic_matches = 0
errors: list[str] = []
    # Step 1: Build search filters (normalized sender, CoreData date bounds, service)
filters = FTSFilters()
if sender:
try:
filters.sender = normalize_handle(sender)
except Exception as e:
logger.warning(f"Failed to normalize sender '{sender}': {e}")
errors.append(f"Invalid sender format: {sender}")
if chat_id is not None:
filters.chat_id = chat_id
if after_date:
try:
filters.after_date = iso_to_coredata(after_date)
except ValueError as e:
logger.warning(f"Invalid after_date format '{after_date}': {e}")
errors.append(f"Invalid date format for after_date: {after_date}")
if before_date:
try:
filters.before_date = iso_to_coredata(before_date)
except ValueError as e:
logger.warning(f"Invalid before_date format '{before_date}': {e}")
errors.append(f"Invalid date format for before_date: {before_date}")
if service:
filters.service = service
# Step 2: Fast sync check and sync if needed
sync_status = "current"
try:
with get_search_index_connection() as index_conn:
with get_connection() as chat_conn:
try:
check = fast_sync_check(index_conn, chat_conn)
if check.needs_sync:
logger.info(f"Syncing {check.new_count} new messages before search")
try:
stats = sync_new_messages(index_conn, chat_conn)
if stats.status == "stale":
sync_status = "stale"
errors.append(f"Sync partially failed: {stats.messages_failed} messages failed to index")
elif stats.status == "partial":
errors.append(f"Sync completed with {stats.messages_failed} failures")
# Generate embeddings for newly synced messages
if stats.messages_synced > 0:
try:
logger.info(f"Generating embeddings for {stats.messages_synced} new messages")
embed_result = generate_pending_embeddings(index_conn, batch_size=100)
logger.info(f"Generated {embed_result['embedded']} embeddings")
except Exception as embed_err:
logger.warning(f"Embedding generation failed (search will use keyword only): {embed_err}")
except Exception as e:
logger.error(f"Error syncing messages: {e}")
sync_status = "stale"
errors.append(f"Sync failed: {e}")
except Exception as e:
logger.warning(f"Fast sync check failed: {e}")
sync_status = "stale"
errors.append(f"Sync check failed: {e}")
except DatabaseLockedError:
logger.warning("Database locked during sync, search will use existing index")
sync_status = "stale"
errors.append("Database locked during sync; using stale index")
except Exception as e:
logger.error(f"Unexpected error during sync: {e}")
sync_status = "stale"
errors.append(f"Sync error: {e}")
# Step 3: Run FTS5 search if needed
fts5_results: list[FTSResult] = []
fts5_error = None
if search_mode in ("hybrid", "keyword"):
try:
            # Over-fetch so RRF merging and Python-side pagination have enough
            # candidates; include offset so deep pages aren't starved
            fetch_limit = (offset + limit) * 3 if search_mode == "hybrid" else offset + limit + 1
logger.debug(f"Running FTS5 search with limit={fetch_limit}")
fts5_results, fts5_total = fts5_search(
query, filters=filters, limit=fetch_limit, offset=0
)
fts5_matches = fts5_total
logger.info(f"FTS5 search returned {fts5_total} matches")
except Exception as e:
logger.error(f"FTS5 search error: {e}")
fts5_error = str(e)
errors.append(f"Keyword search failed: {e}")
fts5_results = []
# Step 4: Run vector search if needed
vector_results: list[VectorResult] = []
semantic_available = True
vector_error = None
if search_mode in ("hybrid", "semantic"):
try:
logger.debug("Checking if semantic search is available")
            # Probe availability: embeds the query, returns None without a valid API key
            test_embedding = get_query_embedding(query)
if test_embedding is None:
semantic_available = False
logger.warning("Semantic search unavailable: OPENAI_API_KEY not set or invalid")
if search_mode == "semantic":
# Semantic-only mode but no API key - return empty results with warning
logger.info("Semantic-only mode requested but semantic search unavailable")
return {
"messages": [],
"search_mode": search_mode,
"fts5_matches": 0,
"semantic_matches": 0,
"pagination": {
"total": 0,
"limit": limit,
"offset": offset,
"has_more": False,
"next_offset": None,
},
"index_status": get_index_status(),
"warning": "Semantic search unavailable: OPENAI_API_KEY not set or invalid. Use search_mode='keyword' for keyword-only search.",
"errors": errors,
}
else:
                # Over-fetch for merging/pagination, as in the FTS5 branch above
                fetch_limit = (
                    (offset + limit) * 3 if search_mode == "hybrid" else offset + limit + 1
                )
logger.debug(f"Running vector search with limit={fetch_limit}")
vector_results, vector_total = vector_search(
query, filters=filters, limit=fetch_limit
)
semantic_matches = vector_total
logger.info(f"Vector search returned {vector_total} matches")
except Exception as e:
logger.error(f"Vector search error: {e}")
semantic_available = False
vector_error = str(e)
errors.append(f"Semantic search failed: {e}")
# If semantic-only mode was requested but failed, we have no results
if search_mode == "semantic" and not vector_results and vector_error:
logger.warning(f"Semantic-only mode failed: {vector_error}")
return {
"messages": [],
"search_mode": search_mode,
"fts5_matches": 0,
"semantic_matches": 0,
"pagination": {
"total": 0,
"limit": limit,
"offset": offset,
"has_more": False,
"next_offset": None,
},
"index_status": get_index_status(),
"warning": "Semantic search failed and is the only search method enabled.",
"errors": errors,
}
# Step 5: Merge results based on search mode
if search_mode == "hybrid":
# Convert FTS5 results to dict format for RRF merge
fts5_dicts = [
{"rowid": r.rowid, "score": r.score, "snippet": r.snippet}
for r in fts5_results
]
# Convert vector results to dict format
vector_dicts = [
{"rowid": r.rowid, "similarity": r.score} for r in vector_results
]
# Merge using RRF
try:
if fts5_dicts or vector_dicts:
merged = rrf_merge(fts5_dicts, vector_dicts)
else:
merged = []
logger.debug(f"RRF merge returned {len(merged)} results")
except Exception as e:
logger.error(f"Error merging results: {e}")
errors.append(f"Result merging failed: {e}")
merged = []
# Extract rowids with RRF scores
result_rowids = [r["rowid"] for r in merged]
rrf_scores = {r["rowid"]: r["rrf_score"] for r in merged}
snippets = {r["rowid"]: r.get("snippet") for r in merged}
keyword_ranks = {r["rowid"]: r.get("keyword_rank") for r in merged}
semantic_ranks = {r["rowid"]: r.get("semantic_rank") for r in merged}
elif search_mode == "keyword":
result_rowids = [r.rowid for r in fts5_results]
        # FTS5 BM25 scores are negative (more negative = better match), so
        # report their magnitude
        rrf_scores = {r.rowid: abs(r.score) for r in fts5_results}
snippets = {r.rowid: r.snippet for r in fts5_results}
keyword_ranks = {r.rowid: i + 1 for i, r in enumerate(fts5_results)}
semantic_ranks = {}
else: # semantic
result_rowids = [r.rowid for r in vector_results]
rrf_scores = {r.rowid: r.score for r in vector_results}
snippets = {}
keyword_ranks = {}
semantic_ranks = {r.rowid: i + 1 for i, r in enumerate(vector_results)}
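    # For reference, RRF scores each rowid as sum(1 / (k + rank)) across the
    # ranked lists (k is commonly 60); rrf_merge is assumed to implement this.
    # A minimal sketch of the scoring idea:
    #
    #   def rrf(ranks: list[int], k: int = 60) -> float:
    #       return sum(1.0 / (k + r) for r in ranks)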
# Step 6: Apply participant filtering if specified
try:
if participants and result_rowids:
result_rowids = filter_by_participants(result_rowids, participants)
except Exception as e:
logger.warning(f"Participant filtering failed: {e}")
errors.append(f"Participant filtering failed: {e}")
# Calculate total before pagination
total = len(result_rowids)
# Step 7: Apply pagination
paginated_rowids = result_rowids[offset : offset + limit]
# Step 8: Fetch message details
try:
message_details = await get_messages_by_rowids(paginated_rowids)
except Exception as e:
logger.error(f"Failed to fetch message details: {e}")
errors.append(f"Failed to fetch message details: {e}")
message_details = {}
# Build response messages maintaining order
messages = []
for rowid in paginated_rowids:
try:
if rowid not in message_details:
logger.warning(f"Message {rowid} not found in details")
continue
msg = message_details[rowid]
# Determine match type
has_keyword = keyword_ranks.get(rowid) is not None
has_semantic = semantic_ranks.get(rowid) is not None
if has_keyword and has_semantic:
match_type = "hybrid"
elif has_keyword:
match_type = "keyword"
else:
match_type = "semantic"
result = {
"rowid": rowid,
"text": msg["text"],
"sender": msg["sender"],
"is_from_me": msg["is_from_me"],
"date": msg["date"],
"service": msg["service"],
"chat_id": msg["chat_id"],
"is_group": msg["is_group"],
"relevance_score": round(rrf_scores.get(rowid, 0), 6),
"match_type": match_type,
"imessage_link": generate_imessage_link(msg["sender"], msg["is_group"]),
}
# Add snippet if available (from FTS5)
if snippets.get(rowid):
result["snippet"] = snippets[rowid]
messages.append(result)
except Exception as e:
logger.warning(f"Error building result for message {rowid}: {e}")
errors.append(f"Error processing message {rowid}: {e}")
continue
# Step 9: Build response
has_more = offset + limit < total
# Get index status and update sync_status
try:
index_status = get_index_status()
index_status["sync_status"] = sync_status
except Exception as e:
logger.warning(f"Failed to get index status: {e}")
errors.append(f"Failed to get index status: {e}")
index_status = {"sync_status": sync_status}
response: dict[str, Any] = {
"messages": messages,
"search_mode": search_mode,
"fts5_matches": fts5_matches,
"semantic_matches": semantic_matches,
"pagination": {
"total": total,
"limit": limit,
"offset": offset,
"has_more": has_more,
"next_offset": offset + limit if has_more else None,
},
"index_status": index_status,
}
    # Add warning if semantic search was unavailable in hybrid mode
    if search_mode == "hybrid" and not semantic_available:
        response["warning"] = (
            "Semantic search unavailable: OPENAI_API_KEY not set or invalid. "
            "Results are keyword-only."
        )
# Add errors to response if any occurred
if errors:
response["errors"] = errors
logger.warning(f"Search completed with {len(errors)} errors")
logger.info(
f"Search complete: returned {len(messages)} messages, total={total}, errors={len(errors)}"
)
return response
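
if __name__ == "__main__":
    # Manual smoke test (a sketch; assumes a synced index, and a valid
    # OPENAI_API_KEY for hybrid/semantic modes).
    import asyncio

    async def _demo() -> None:
        result = await search_messages("dinner plans", search_mode="keyword", limit=5)
        for m in result["messages"]:
            print(m["date"], m.get("contact_name") or m["sender"], "-", m["text"][:60])

    asyncio.run(_demo())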