Solr MCP
by allenday
Verified
- solr-mcp
- scripts
#!/usr/bin/env python3
"""
Test script for vector search in Solr.
"""
import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional

import httpx
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.embeddings.client import OllamaClient
async def generate_query_embedding(query_text: str) -> List[float]:
    """Turn a search query into an embedding via ``OllamaClient``.

    A new ``OllamaClient`` is constructed with its default arguments on
    every call, a progress line is printed, and the client's
    ``get_embedding`` coroutine is awaited.

    Args:
        query_text: The query text to embed.

    Returns:
        The value produced by ``OllamaClient.get_embedding`` for the query
        (annotated here as a list of floats).
    """
    ollama = OllamaClient()
    print(f"Generating embedding for query: '{query_text}'")
    return await ollama.get_embedding(query_text)
def _knn_query(vector_field: str, k: int, components) -> str:
    """Build a Solr ``{!knn}`` query string from already-formatted components."""
    return f"{{!knn f={vector_field} topK={k}}}[" + ",".join(components) + "]"


async def vector_search(
    query: str,
    collection: str = "testvectors",
    vector_field: str = "embedding",
    k: int = 5,
    filter_query: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Run a KNN vector search against a local Solr collection.

    The query text is embedded with ``generate_query_embedding`` and sent to
    ``http://localhost:8983/solr/<collection>/select`` using the ``{!knn}``
    query parser. The request is attempted as a POST first (the vector can be
    too long for a URL); if the POST itself fails at the HTTP/transport
    level, a GET is attempted with a more compact vector representation.

    Args:
        query: Search query text.
        collection: Solr collection name.
        vector_field: Name of the vector field to search.
        k: Number of nearest neighbours to request (``topK``).
        filter_query: Optional Solr filter query (``fq``); applied on both
            the POST and the GET path.

    Returns:
        The decoded JSON response when Solr answers with HTTP 200, otherwise
        ``None`` (the error is printed rather than raised).
    """
    # Generate embedding for the query
    query_embedding = await generate_query_embedding(query)

    solr_url = f"http://localhost:8983/solr/{collection}/select"

    # Build query parameters; the vector is a bracketed, comma-separated list.
    params = {
        "q": _knn_query(vector_field, k, (str(v) for v in query_embedding)),
        "fl": "id,title,text,score,vector_model",
        "wt": "json",
    }
    if filter_query:
        params["fq"] = filter_query

    print(f"Executing vector search in collection '{collection}'")

    try:
        async with httpx.AsyncClient() as client:
            try:
                # The (potentially very long) q goes in the form body; every
                # other parameter -- including fq, which must not be dropped
                # -- travels in the URL.
                response = await client.post(
                    solr_url,
                    data={"q": params["q"]},
                    params={name: value for name, value in params.items() if name != "q"},
                    timeout=30.0,
                )
            except httpx.HTTPError as post_error:
                print(f"POST request failed, trying GET: {post_error}")
                # Shorten the URL by rounding each component. All components
                # and the surrounding brackets are kept so the vector has the
                # same dimension and syntax as the one sent via POST.
                if len(params["q"]) - len(_knn_query(vector_field, k, ())) + 2 > 800:
                    params["q"] = _knn_query(
                        vector_field, k, (str(round(v, 4)) for v in query_embedding)
                    )
                response = await client.get(solr_url, params=params, timeout=30.0)

            if response.status_code == 200:
                return response.json()

            print(f"Error in vector search: {response.status_code} - {response.text}")
            return None
    except Exception as e:
        # Script-level boundary: report the failure and signal it with None.
        print(f"Error during vector search: {e}")
        return None
def _first_value(value: Any, default: Any) -> Any:
    """Unwrap a list-valued field to its first element (``default`` if empty)."""
    if isinstance(value, list):
        return value[0] if value else default
    return value


def display_results(results: Dict[str, Any]):
    """Print search results in a readable format.

    Args:
        results: Decoded Solr response. It is expected to hold a
            ``'response'`` mapping with ``'docs'`` and ``'numFound'``; a
            missing ``'docs'`` is treated as no matches and a missing
            ``'numFound'`` falls back to the number of documents present.

    Returns:
        None. Everything is written to standard output.
    """
    if not results or 'response' not in results:
        print("No valid results received")
        return

    print("\n=== Vector Search Results ===\n")

    response = results['response'] or {}
    docs = response.get('docs') or []
    if not docs:
        print("No matching documents found.")
        return

    num_found = response.get('numFound', len(docs))
    print(f"Found {num_found} matching document(s):\n")

    for i, doc in enumerate(docs, 1):
        print(f"Result {i}:")
        print(f" ID: {doc.get('id', 'N/A')}")

        # Title may be a plain value or a (possibly empty) list.
        title = _first_value(doc.get('title', 'N/A'), 'N/A')
        print(f" Title: {title}")

        if 'score' in doc:
            print(f" Score: {doc['score']}")

        # Text may be a plain value or a (possibly empty) list.
        text = _first_value(doc.get('text', ''), '')
        if text:
            # Coerce so slicing/len work even if the field is not a string.
            text = str(text)
            preview = text[:150] + "..." if len(text) > 150 else text
            print(f" Preview: {preview}")

        # Print model info if available
        if 'vector_model' in doc:
            print(f" Model: {doc.get('vector_model')}")
        print()
async def main():
    """Command-line entry point: parse arguments, search, print the hits.

    Takes one positional ``query`` plus optional ``--collection/-c``,
    ``--field/-f``, ``--results/-k`` and ``--filter/-fq``. The CLI default
    for the collection is ``"vectors"``, and it is always passed explicitly
    to ``vector_search``. Results are displayed only when the search
    returned something truthy.
    """
    cli = argparse.ArgumentParser(description="Test vector search in Solr")
    cli.add_argument("query", help="Search query")

    # Optional flags, registered in the order they appear in --help.
    option_table = (
        (("--collection", "-c"), {"default": "vectors", "help": "Solr collection name"}),
        (("--field", "-f"), {"default": "embedding", "help": "Vector field name"}),
        (("--results", "-k"), {"type": int, "default": 5, "help": "Number of results to return"}),
        (("--filter", "-fq"), {"help": "Optional filter query"}),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    opts = cli.parse_args()

    hits = await vector_search(
        opts.query,
        opts.collection,
        opts.field,
        opts.results,
        opts.filter,
    )
    if not hits:
        return
    display_results(hits)
if __name__ == "__main__":
    # Executed as a script: drive the async entry point to completion.
    asyncio.run(main())