"""Comprehensive tests for the hybrid search system."""
import os
import sqlite3
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from jons_mcp_imessage.db.search_index import (
    CURRENT_SCHEMA_VERSION,
    ensure_schema,
    get_schema_version,
    get_search_index_connection,
    get_sync_metadata,
    set_sync_metadata,
)
from jons_mcp_imessage.search.embedding_cache import (
    EmbeddingCache,
    reset_embedding_cache,
)
from jons_mcp_imessage.search.fts5 import FTSFilters, fts5_search
from jons_mcp_imessage.search.hybrid import SearchMode, hybrid_search, rrf_merge
from jons_mcp_imessage.search.participants import filter_by_participants
from jons_mcp_imessage.search.sync import (
    fast_sync_check,
    is_first_run,
    set_last_indexed_rowid,
)
from jons_mcp_imessage.search.vector import vector_search
class TestSearchIndexSchema:
"""Tests for search index schema creation and management."""
def test_schema_creates_tables(self, tmp_path):
"""Test that ensure_schema() creates all required tables."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Verify message_index table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='message_index'"
)
assert cursor.fetchone() is not None
# Verify sync_metadata table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='sync_metadata'"
)
assert cursor.fetchone() is not None
# Verify message_participants table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='message_participants'"
)
assert cursor.fetchone() is not None
def test_schema_creates_fts_table(self, tmp_path):
"""Test that FTS5 virtual table is created."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Verify message_fts virtual table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='message_fts'"
)
assert cursor.fetchone() is not None
            # Verify it's an FTS5 table by inspecting its CREATE statement
cursor = conn.execute(
"SELECT sql FROM sqlite_master WHERE name='message_fts'"
)
sql = cursor.fetchone()[0]
assert "fts5" in sql.lower()
def test_schema_creates_triggers(self, tmp_path):
"""Test that FTS5 sync triggers are created."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Check for insert trigger
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='trigger' AND name='message_fts_insert'"
)
assert cursor.fetchone() is not None
# Check for delete trigger
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='trigger' AND name='message_fts_delete'"
)
assert cursor.fetchone() is not None
# Check for update trigger
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='trigger' AND name='message_fts_update'"
)
assert cursor.fetchone() is not None
def test_schema_initializes_metadata(self, tmp_path):
"""Test that sync metadata is initialized with defaults."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Check schema version
version = get_sync_metadata(conn, "schema_version")
assert version == str(CURRENT_SCHEMA_VERSION)
# Check last_indexed_rowid
last_indexed = get_sync_metadata(conn, "last_indexed_rowid")
assert last_indexed == "0"
# Check last_embedded_rowid
last_embedded = get_sync_metadata(conn, "last_embedded_rowid")
assert last_embedded == "0"
def test_schema_version_tracking(self, tmp_path):
"""Test schema version tracking."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
version = get_schema_version(conn)
assert version == CURRENT_SCHEMA_VERSION
def test_schema_idempotent(self, tmp_path):
"""Test that ensure_schema can be called multiple times safely."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Call ensure_schema multiple times
ensure_schema(conn)
ensure_schema(conn)
# Verify schema is still correct
version = get_schema_version(conn)
assert version == CURRENT_SCHEMA_VERSION
class TestRRFMerge:
"""Tests for Reciprocal Rank Fusion merge algorithm."""
def test_rrf_basic_merge(self):
"""Test basic RRF merge of two result lists."""
fts_results = [
{"rowid": 1, "text": "hello", "rank": -2.5},
{"rowid": 2, "text": "world", "rank": -1.5},
]
vec_results = [
{"rowid": 2, "text": "world", "similarity": 0.95},
{"rowid": 3, "text": "foo", "similarity": 0.85},
]
merged = rrf_merge(fts_results, vec_results, k=60)
# Should have 3 unique results
assert len(merged) == 3
# Verify rowids are unique
rowids = {r["rowid"] for r in merged}
assert len(rowids) == 3
# rowid=2 appears in both, should have highest score
assert merged[0]["rowid"] == 2
assert merged[0]["keyword_rank"] == 2
assert merged[0]["semantic_rank"] == 1
def test_rrf_empty_results(self):
"""Test RRF with empty results from one method."""
vec_results = [
{"rowid": 1, "text": "hello", "similarity": 0.95},
{"rowid": 2, "text": "world", "similarity": 0.85},
]
merged = rrf_merge([], vec_results)
assert len(merged) == 2
assert merged[0]["rowid"] == 1 # Higher similarity first
assert merged[0]["semantic_rank"] == 1
assert merged[0]["keyword_rank"] is None
def test_rrf_identical_results(self):
"""Test RRF when both methods return same items."""
fts_results = [
{"rowid": 1, "text": "hello", "rank": -2.0},
{"rowid": 2, "text": "world", "rank": -1.0},
]
vec_results = [
{"rowid": 1, "text": "hello", "similarity": 0.95},
{"rowid": 2, "text": "world", "similarity": 0.85},
]
merged = rrf_merge(fts_results, vec_results)
# Should have 2 results, not 4
assert len(merged) == 2
# Both results should have scores from both methods
for result in merged:
assert result["keyword_rank"] is not None
assert result["semantic_rank"] is not None
def test_rrf_deterministic_ordering(self):
"""Test that RRF produces deterministic order with ties."""
        # Create results with identical RRF scores (rank = 1-based list position):
        #   rowid=2: FTS rank 2, vector rank 1 -> 1/62 + 1/61
        #   rowid=5: FTS rank 1, vector rank 2 -> 1/61 + 1/62
        # The sums are equal, so ordering should fall back to rowid ASC.
        fts_results = [
            {"rowid": 5, "text": "a", "rank": -1.5},
            {"rowid": 2, "text": "b", "rank": -1.0},
        ]
vec_results = [
{"rowid": 2, "text": "b", "similarity": 0.95},
{"rowid": 5, "text": "a", "similarity": 0.85},
]
merged = rrf_merge(fts_results, vec_results)
# Both have same RRF score, should be ordered by rowid ASC
assert merged[0]["rowid"] == 2
assert merged[1]["rowid"] == 5
# Verify scores are equal
assert abs(merged[0]["rrf_score"] - merged[1]["rrf_score"]) < 0.0001
def test_rrf_pagination_consistency(self):
"""Test that RRF results are consistent for pagination."""
# Create a large result set
fts_results = [
{"rowid": i, "text": f"result {i}", "rank": -float(i)}
for i in range(1, 101)
]
vec_results = [
{"rowid": i * 2, "text": f"result {i * 2}", "similarity": 1.0 - i / 100.0}
for i in range(1, 51)
]
merged = rrf_merge(fts_results, vec_results)
# Verify deterministic ordering (same input -> same output)
merged2 = rrf_merge(fts_results, vec_results)
assert [r["rowid"] for r in merged] == [r["rowid"] for r in merged2]
def test_rrf_preserves_metadata(self):
"""Test that RRF preserves metadata from original results."""
fts_results = [
{
"rowid": 1,
"text": "hello",
"rank": -2.0,
"snippet": "hello <mark>world</mark>",
}
]
vec_results = [{"rowid": 1, "text": "hello", "similarity": 0.95}]
merged = rrf_merge(fts_results, vec_results)
assert len(merged) == 1
result = merged[0]
# Should preserve snippet from FTS
assert result["snippet"] == "hello <mark>world</mark>"
# Should preserve scores
assert result["keyword_score"] == -2.0
assert result["semantic_score"] == 0.95
class TestFTS5Search:
"""Tests for FTS5 full-text search."""
@pytest.fixture
def populated_index(self, tmp_path):
"""Create a search index with test data."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Insert test messages
test_messages = [
(1, "hello world", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000),
(2, "world news today", None, "+1234567890", 1, "+1234567890", 0, 0, "iMessage", 726000001000000),
(3, "hello there friend", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000002000000),
(4, "meeting tomorrow at 3pm", None, "+9876543210", 2, "+9876543210", 0, 0, "SMS", 726000003000000),
(5, "dinner plans for tonight", None, "me", 2, "+9876543210", 0, 1, "iMessage", 726000004000000),
]
for msg in test_messages:
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
msg,
)
conn.commit()
return str(db_path)
    @pytest.fixture
    def patched_connection(self, populated_index):
        """Patch the fts5 module's connection factory to use the test database."""
        # fts5_search opens its own connection, so patch the factory with a
        # context manager that yields a connection to the populated index.
        @contextmanager
        def mock_get_connection():
            with get_search_index_connection(populated_index) as conn:
                yield conn

        with patch(
            "jons_mcp_imessage.search.fts5.get_search_index_connection",
            mock_get_connection,
        ):
            yield

    def test_simple_query(self, patched_connection):
        """Test basic keyword search."""
        results, total = fts5_search("hello")
        assert total == 2
        assert len(results) == 2
        rowids = {r.rowid for r in results}
        assert 1 in rowids
        assert 3 in rowids
    def test_phrase_query(self, patched_connection):
        """Test phrase search with quotes."""
        results, total = fts5_search('"hello world"')
        assert total == 1
        assert len(results) == 1
        assert results[0].rowid == 1
    def test_prefix_query(self, patched_connection):
        """Test prefix matching with a trailing wildcard."""
        results, total = fts5_search("meet*")
        assert total == 1
        assert len(results) == 1
        assert results[0].rowid == 4
    def test_filters_applied(self, patched_connection):
        """Test that metadata filters work correctly."""
        # Filter by sender
        filters = FTSFilters(sender="me")
        results, total = fts5_search("hello", filters=filters)
        assert total == 2
        rowids = {r.rowid for r in results}
        assert 1 in rowids
        assert 3 in rowids

        # Filter by service
        filters = FTSFilters(service="SMS")
        results, total = fts5_search("meeting", filters=filters)
        assert total == 1
        assert results[0].rowid == 4
    def test_pagination(self, patched_connection):
        """Test FTS5 pagination."""
        # First page
        results, total = fts5_search("world", limit=1, offset=0)
        assert total == 2
        assert len(results) == 1

        # Second page
        results, total = fts5_search("world", limit=1, offset=1)
        assert total == 2
        assert len(results) == 1
def test_fts5_table_and_triggers(self, populated_index):
"""Test that FTS5 table and triggers work correctly at the SQLite level."""
# This test verifies the FTS5 setup works, even though the high-level
# fts5_search function has a bug
with get_search_index_connection(populated_index) as conn:
# Verify FTS5 table was populated by triggers
cursor = conn.execute(
"SELECT COUNT(*) FROM message_fts"
)
count = cursor.fetchone()[0]
assert count == 5, "FTS5 table should have 5 entries from triggers"
# Test basic FTS5 MATCH directly (using proper syntax)
cursor = conn.execute(
"SELECT rowid FROM message_fts WHERE text MATCH ?",
("hello",)
)
results = cursor.fetchall()
assert len(results) == 2, "Should find 2 messages with 'hello'"
# Test that bm25 works with proper syntax
cursor = conn.execute(
"SELECT rowid, bm25(message_fts) FROM message_fts WHERE text MATCH ? ORDER BY bm25(message_fts) ASC LIMIT 1",
("hello",)
)
result = cursor.fetchone()
assert result is not None, "Should find at least one result"
assert result[0] in (1, 3), "Result should be from a message containing 'hello'"
class TestEmbeddingCache:
"""Tests for embedding cache."""
@pytest.fixture(autouse=True)
def reset_cache(self):
"""Reset the global cache before each test."""
reset_embedding_cache()
yield
reset_embedding_cache()
def test_cache_loads_embeddings(self, tmp_path):
"""Test that cache loads from database."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Insert test messages with embeddings
embedding1 = np.random.randn(1536).astype(np.float32)
embedding2 = np.random.randn(1536).astype(np.float32)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "hello", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000, embedding1.tobytes()),
)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(2, "world", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000001000000, embedding2.tobytes()),
)
conn.commit()
# Create cache and load
cache = EmbeddingCache(str(db_path))
cache.load()
assert cache.size == 2
assert len(cache.rowids) == 2
assert 1 in cache.rowids
assert 2 in cache.rowids
def test_cache_normalizes_embeddings(self, tmp_path):
"""Test that embeddings are normalized on load."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Create a non-normalized embedding
embedding = np.array([3.0, 4.0] + [0.0] * 1534, dtype=np.float32)
# Norm should be 5.0
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "test", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000, embedding.tobytes()),
)
conn.commit()
# Load cache
cache = EmbeddingCache(str(db_path))
cache.load()
# Verify normalization
loaded_embedding = cache.embeddings[0]
norm = np.linalg.norm(loaded_embedding)
assert abs(norm - 1.0) < 1e-5 # Should be normalized to 1.0
def test_cache_invalidation(self, tmp_path):
"""Test cache invalidates when new embeddings added."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
embedding1 = np.random.randn(1536).astype(np.float32)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "hello", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000, embedding1.tobytes()),
)
conn.commit()
# Load cache
cache = EmbeddingCache(str(db_path))
cache.load()
assert cache.size == 1
assert cache.is_valid()
# Add new embedding to database
with get_search_index_connection(str(db_path)) as conn:
embedding2 = np.random.randn(1536).astype(np.float32)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(2, "world", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000001000000, embedding2.tobytes()),
)
set_sync_metadata(conn, "last_embedded_rowid", "2")
conn.commit()
# Cache should be invalid now
assert not cache.is_valid()
# Reload to get new embeddings
cache.load()
assert cache.size == 2
def test_cosine_similarity(self, tmp_path):
"""Test cosine similarity computation."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Create two embeddings: one similar, one dissimilar
embedding1 = np.array([1.0, 0.0] + [0.0] * 1534, dtype=np.float32)
embedding2 = np.array([0.0, 1.0] + [0.0] * 1534, dtype=np.float32)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "test1", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000, embedding1.tobytes()),
)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(2, "test2", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000001000000, embedding2.tobytes()),
)
conn.commit()
# Load cache
cache = EmbeddingCache(str(db_path))
cache.load()
# Query with embedding similar to embedding1
query = np.array([1.0, 0.0] + [0.0] * 1534, dtype=np.float32)
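        # With unit-normalized embeddings, cosine similarity reduces to a dot
        # product: parallel vectors score ~1.0, orthogonal ones ~0.0.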
results = cache.cosine_similarity(query, top_k=2)
assert len(results) == 2
# First result should be rowid=1 with high similarity
assert results[0][0] == 1
assert results[0][1] > 0.99 # Nearly identical
class TestVectorSearch:
"""Tests for vector similarity search."""
@pytest.fixture(autouse=True)
def reset_cache(self):
"""Reset the global cache before each test."""
reset_embedding_cache()
yield
reset_embedding_cache()
@patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
@patch("jons_mcp_imessage.search.vector.get_openai_client")
@patch("jons_mcp_imessage.search.vector.generate_embeddings")
def test_vector_search_with_mocked_api(
self, mock_generate, mock_get_client, tmp_path
):
"""Test vector search with mocked OpenAI API."""
# Mock OpenAI client
mock_client = MagicMock()
mock_get_client.return_value = mock_client
# Mock embedding generation
query_embedding = np.array([1.0, 0.0] + [0.0] * 1534, dtype=np.float32)
mock_generate.return_value = [query_embedding]
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Create test embeddings
embedding1 = np.array([1.0, 0.0] + [0.0] * 1534, dtype=np.float32)
embedding2 = np.array([0.0, 1.0] + [0.0] * 1534, dtype=np.float32)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "similar", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000, embedding1.tobytes()),
)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata, embedding)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(2, "different", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000001000000, embedding2.tobytes()),
)
set_sync_metadata(conn, "last_embedded_rowid", "2")
conn.commit()
# Patch the cache to use our test database
with patch("jons_mcp_imessage.search.vector.get_embedding_cache") as mock_cache:
cache = EmbeddingCache(str(db_path))
cache.load()
mock_cache.return_value = cache
results, total = vector_search("test query", limit=2)
assert total == 2
assert len(results) == 2
# First result should be more similar
assert results[0].rowid == 1
assert results[0].score > results[1].score
def test_vector_search_no_api_key(self):
"""Test graceful degradation without API key."""
        # patch.dict with clear=True empties os.environ inside the context,
        # so OPENAI_API_KEY is guaranteed absent here
        with patch.dict(os.environ, {}, clear=True):
results, total = vector_search("test query")
# Should return empty results without crashing
assert results == []
assert total == 0
class TestParticipantFiltering:
"""Tests for participant-based filtering."""
@pytest.fixture
def populated_index(self, tmp_path):
"""Create a search index with test data including participants."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Insert test messages
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(1, "hello", None, "me", 1, "+1234567890", 0, 1, "iMessage", 726000000000000),
)
conn.execute(
"""
INSERT INTO message_index
(rowid, text, handle_id, sender, chat_id, chat_identifier,
is_group, is_from_me, service, date_coredata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(2, "world", None, "me", 2, "+9876543210", 0, 1, "iMessage", 726000001000000),
)
# Insert participants
            # Message 1: Bob (+1234567890) and Alice (alice@example.com)
conn.execute(
"INSERT INTO message_participants (rowid, participant) VALUES (?, ?)",
(1, "+1234567890"),
)
conn.execute(
"INSERT INTO message_participants (rowid, participant) VALUES (?, ?)",
(1, "alice@example.com"),
)
            # Message 2: Bob (+1234567890) and Charlie (charlie@example.com)
conn.execute(
"INSERT INTO message_participants (rowid, participant) VALUES (?, ?)",
(2, "+1234567890"),
)
conn.execute(
"INSERT INTO message_participants (rowid, participant) VALUES (?, ?)",
(2, "charlie@example.com"),
)
conn.commit()
return str(db_path)
def test_filter_require_all(self, populated_index):
"""Test filtering with require_all=True."""
with get_search_index_connection(populated_index) as conn:
# Filter for messages with both Alice and Bob
filtered = filter_by_participants(
conn,
[1, 2],
["+1234567890", "alice@example.com"],
require_all=True,
)
# Only message 1 has both
assert filtered == [1]
def test_filter_require_any(self, populated_index):
"""Test filtering with require_all=False."""
with get_search_index_connection(populated_index) as conn:
# Filter for messages with Alice OR Charlie
filtered = filter_by_participants(
conn,
[1, 2],
["alice@example.com", "charlie@example.com"],
require_all=False,
)
# Both messages match (1 has Alice, 2 has Charlie)
assert set(filtered) == {1, 2}
def test_filter_preserves_order(self, populated_index):
"""Test that filtering preserves original rowid order."""
with get_search_index_connection(populated_index) as conn:
# Filter with Bob (appears in both messages)
filtered = filter_by_participants(
conn, [2, 1], ["+1234567890"], require_all=False
)
# Should preserve order: [2, 1]
assert filtered == [2, 1]
class TestIncrementalSync:
"""Tests for incremental sync functionality."""
def test_fast_sync_check_no_sync_needed(self, tmp_path):
"""Test fast sync check when no sync is needed."""
index_db = tmp_path / "index.db"
chat_db = tmp_path / "chat.db"
# Create index database
with get_search_index_connection(str(index_db)) as index_conn:
set_last_indexed_rowid(index_conn, 10)
# Create mock chat database
chat_conn = sqlite3.connect(str(chat_db))
chat_conn.execute("CREATE TABLE message (ROWID INTEGER PRIMARY KEY, text TEXT)")
for i in range(1, 11):
chat_conn.execute("INSERT INTO message (ROWID, text) VALUES (?, ?)", (i, f"msg {i}"))
chat_conn.commit()
# Run fast sync check
with get_search_index_connection(str(index_db)) as index_conn:
check = fast_sync_check(index_conn, chat_conn)
assert not check.needs_sync
assert check.new_count == 0
assert check.index_max_rowid == 10
assert check.chat_max_rowid == 10
chat_conn.close()
def test_fast_sync_check_sync_needed(self, tmp_path):
"""Test fast sync check when sync is needed."""
index_db = tmp_path / "index.db"
chat_db = tmp_path / "chat.db"
# Create index database
with get_search_index_connection(str(index_db)) as index_conn:
set_last_indexed_rowid(index_conn, 5)
# Create mock chat database with more messages
chat_conn = sqlite3.connect(str(chat_db))
chat_conn.execute("CREATE TABLE message (ROWID INTEGER PRIMARY KEY, text TEXT)")
for i in range(1, 11):
chat_conn.execute("INSERT INTO message (ROWID, text) VALUES (?, ?)", (i, f"msg {i}"))
chat_conn.commit()
# Run fast sync check
with get_search_index_connection(str(index_db)) as index_conn:
check = fast_sync_check(index_conn, chat_conn)
assert check.needs_sync
assert check.new_count == 5
assert check.index_max_rowid == 5
assert check.chat_max_rowid == 10
chat_conn.close()
def test_is_first_run(self, tmp_path):
"""Test first run detection."""
index_db = tmp_path / "index.db"
with get_search_index_connection(str(index_db)) as conn:
# Should be first run initially
assert is_first_run(conn)
# After setting last_indexed_rowid, should not be first run
set_last_indexed_rowid(conn, 1)
assert not is_first_run(conn)
class TestHybridSearchIntegration:
"""Integration tests for hybrid search."""
def test_search_hybrid_mode(self):
"""Test hybrid search mode."""
def mock_fts_search(query, limit):
return [
{"rowid": 1, "text": "hello world", "rank": -2.0},
{"rowid": 2, "text": "world news", "rank": -1.0},
]
def mock_vec_search(query, limit):
return [
{"rowid": 2, "text": "world news", "similarity": 0.95},
{"rowid": 3, "text": "greetings", "similarity": 0.85},
]
result = hybrid_search(
"hello",
mode=SearchMode.HYBRID,
limit=10,
offset=0,
fts5_searcher=mock_fts_search,
vector_searcher=mock_vec_search,
)
assert result["mode"] == "hybrid"
assert len(result["results"]) == 3
# rowid=2 appears in both, should be first
assert result["results"][0]["rowid"] == 2
assert "rrf_score" in result["results"][0]
def test_search_keyword_only(self):
"""Test keyword-only mode without API key."""
def mock_fts_search(query, limit):
return [
{"rowid": 1, "text": "hello world", "rank": -2.0},
]
result = hybrid_search(
"hello",
mode=SearchMode.KEYWORD,
limit=10,
offset=0,
fts5_searcher=mock_fts_search,
)
assert result["mode"] == "keyword"
assert len(result["results"]) == 1
assert result["results"][0]["rowid"] == 1
def test_search_pagination(self):
"""Test pagination parameters."""
def mock_fts_search(query, limit):
return [
{"rowid": i, "text": f"result {i}", "rank": -float(i)}
for i in range(1, 21)
]
# First page
result = hybrid_search(
"test",
mode=SearchMode.KEYWORD,
limit=5,
offset=0,
fts5_searcher=mock_fts_search,
)
assert len(result["results"]) == 5
assert result["pagination"]["total"] == 20
assert result["pagination"]["has_more"] is True
assert result["pagination"]["next_offset"] == 5
# Last page
result = hybrid_search(
"test",
mode=SearchMode.KEYWORD,
limit=5,
offset=15,
fts5_searcher=mock_fts_search,
)
assert len(result["results"]) == 5
assert result["pagination"]["has_more"] is False
assert result["pagination"]["next_offset"] is None
    def test_search_with_filters(self):
        """Test various filter combinations."""
        # Exercising filter combinations requires FTS5 and vector mocks with
        # filter support; skip explicitly rather than passing an empty test.
        pytest.skip("placeholder: needs searcher mocks with filter support")
class TestGetMessageContext:
"""Tests for the get_message_context function."""
@pytest.fixture
def populated_index(self, tmp_path):
"""Create a search index with test messages for context queries."""
db_path = tmp_path / "test_index.db"
with get_search_index_connection(str(db_path)) as conn:
# Insert test messages in a single chat, in chronological order
            # Dates are spaced one hour (3600) apart; only their relative
            # ordering matters for these tests
base_date = 700000000 # Arbitrary CoreData timestamp
messages = [
(100, "First message", "+1234567890", 1, "chat1", 0, 0, "iMessage", base_date),
(101, "Second message", "+1234567890", 1, "chat1", 0, 1, "iMessage", base_date + 3600),
(102, "Third message - TARGET", "+1234567890", 1, "chat1", 0, 0, "iMessage", base_date + 7200),
(103, "Fourth message", "+1234567890", 1, "chat1", 0, 1, "iMessage", base_date + 10800),
(104, "Fifth message", "+1234567890", 1, "chat1", 0, 0, "iMessage", base_date + 14400),
]
for rowid, text, sender, chat_id, chat_identifier, is_group, is_from_me, service, date in messages:
conn.execute(
"""
INSERT INTO message_index
(rowid, text, sender, chat_id, chat_identifier, is_group, is_from_me, service, date_coredata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(rowid, text, sender, chat_id, chat_identifier, is_group, is_from_me, service, date),
)
conn.commit()
return str(db_path)
@pytest.mark.asyncio
async def test_get_context_basic(self, populated_index):
"""Test basic context retrieval around a message."""
from jons_mcp_imessage.tools.messages import get_message_context
        # Mock get_search_index_connection to hand out a Row-factory connection
        # to the test database
        with patch(
            "jons_mcp_imessage.tools.messages.get_search_index_connection"
        ) as mock_conn:
            conn = sqlite3.connect(populated_index)
            conn.row_factory = sqlite3.Row
            mock_conn.return_value.__enter__ = lambda s: conn
            mock_conn.return_value.__exit__ = lambda s, *args: None
result = await get_message_context(rowid=102, before=2, after=2)
assert "error" not in result
assert len(result["messages"]) == 5 # 2 before + target + 2 after
assert result["target_index"] == 2 # Target is at index 2 (0-indexed)
assert result["chat_id"] == 1
assert result["before_count"] == 2
assert result["after_count"] == 2
# Verify the target message is marked
assert result["messages"][2]["is_target"] is True
assert "TARGET" in result["messages"][2]["text"]
# Verify chronological order
assert result["messages"][0]["rowid"] == 100
assert result["messages"][1]["rowid"] == 101
assert result["messages"][2]["rowid"] == 102
assert result["messages"][3]["rowid"] == 103
assert result["messages"][4]["rowid"] == 104
@pytest.mark.asyncio
async def test_get_context_message_not_found(self, populated_index):
"""Test error handling when message doesn't exist."""
from jons_mcp_imessage.tools.messages import get_message_context
with patch(
"jons_mcp_imessage.tools.messages.get_search_index_connection"
) as mock_conn:
conn = sqlite3.connect(populated_index)
conn.row_factory = sqlite3.Row
mock_conn.return_value.__enter__ = lambda s: conn
mock_conn.return_value.__exit__ = lambda s, *args: None
result = await get_message_context(rowid=9999, before=5, after=5)
assert "error" in result
assert result["messages"] == []
assert result["target_index"] is None
@pytest.mark.asyncio
async def test_get_context_at_start_of_chat(self, populated_index):
"""Test context when target is first message (no messages before)."""
from jons_mcp_imessage.tools.messages import get_message_context
with patch(
"jons_mcp_imessage.tools.messages.get_search_index_connection"
) as mock_conn:
conn = sqlite3.connect(populated_index)
conn.row_factory = sqlite3.Row
mock_conn.return_value.__enter__ = lambda s: conn
mock_conn.return_value.__exit__ = lambda s, *args: None
result = await get_message_context(rowid=100, before=5, after=2)
assert "error" not in result
assert result["before_count"] == 0 # No messages before first
assert result["after_count"] == 2
assert result["target_index"] == 0 # Target is first in list
assert len(result["messages"]) == 3 # target + 2 after
@pytest.mark.asyncio
async def test_get_context_at_end_of_chat(self, populated_index):
"""Test context when target is last message (no messages after)."""
from jons_mcp_imessage.tools.messages import get_message_context
with patch(
"jons_mcp_imessage.tools.messages.get_search_index_connection"
) as mock_conn:
conn = sqlite3.connect(populated_index)
conn.row_factory = sqlite3.Row
mock_conn.return_value.__enter__ = lambda s: conn
mock_conn.return_value.__exit__ = lambda s, *args: None
result = await get_message_context(rowid=104, before=2, after=5)
assert "error" not in result
assert result["before_count"] == 2
assert result["after_count"] == 0 # No messages after last
assert result["target_index"] == 2 # After 2 before messages
assert len(result["messages"]) == 3 # 2 before + target