# Solr MCP
# by allenday
# Verified
# - solr-mcp
# - scripts
#!/usr/bin/env python3
"""
Demo script showing how to use the MCP client to search for information.
"""
import argparse
import asyncio
import os
import sys
import json
from typing import Dict, List, Optional, Any
from mcp import client
from mcp.transport.stdio import StdioClientTransport
from loguru import logger
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.embeddings.client import OllamaClient
async def search_by_text(query: str, collection: Optional[str] = None, rows: int = 5):
    """
    Perform a text search using the MCP client.

    Builds a stdio transport configured with the command
    ``<current interpreter> -m solr_mcp.server``, connects a client to it,
    issues a ``solr_search`` request and prints the response with
    ``display_results``. The client is always closed, even when connecting
    or the request fails.

    Args:
        query: Search query
        collection: Collection name (optional); left out of the request
            when it is None or empty
        rows: Number of results to return
    """
    # Use the interpreter that is running this script instead of whatever
    # "python" happens to resolve to on PATH, so the server is started from
    # the same environment that already imported the project's dependencies.
    mcp_command = [sys.executable, "-m", "solr_mcp.server"]
    transport = StdioClientTransport({"command": mcp_command})

    # Create the client *before* the try block: if construction fails there
    # is nothing to close, and `c` is guaranteed to be bound in `finally`
    # (otherwise the cleanup would raise UnboundLocalError and hide the
    # original error).
    c = client.Client()
    try:
        await c.connect(transport)

        # Arguments for the solr_search tool
        args = {
            "query": query,
            "rows": rows,
        }
        if collection:
            args["collection"] = collection

        logger.info(f"Searching for: {query}")
        result = await c.request(
            {"name": "solr_search", "arguments": args}
        )

        # Display results
        print(f"\n=== Results for text search: '{query}' ===\n")
        display_results(result)
    finally:
        await c.close()
async def search_by_vector(query: str, collection: Optional[str] = None, k: int = 5):
    """
    Perform a vector similarity search using the MCP client.

    Obtains an embedding for ``query`` from ``OllamaClient.get_embedding``,
    then builds a stdio transport configured with the command
    ``<current interpreter> -m solr_mcp.server``, connects a client to it,
    issues a ``solr_vector_search`` request and prints the response with
    ``display_results``. The client is always closed, even when connecting
    or the request fails.

    Args:
        query: Text to generate embedding from
        collection: Collection name (optional); left out of the request
            when it is None or empty
        k: Number of nearest neighbors to return
    """
    # First, generate an embedding for the query. This happens before any
    # MCP client exists, so a failure here needs no cleanup.
    ollama_client = OllamaClient()
    embedding = await ollama_client.get_embedding(query)

    # Use the interpreter that is running this script instead of whatever
    # "python" happens to resolve to on PATH, so the server is started from
    # the same environment that already imported the project's dependencies.
    mcp_command = [sys.executable, "-m", "solr_mcp.server"]
    transport = StdioClientTransport({"command": mcp_command})

    # Create the client *before* the try block: if construction fails there
    # is nothing to close, and `c` is guaranteed to be bound in `finally`
    # (otherwise the cleanup would raise UnboundLocalError and hide the
    # original error).
    c = client.Client()
    try:
        await c.connect(transport)

        # Arguments for the solr_vector_search tool
        args = {
            "vector": embedding,
            "k": k,
        }
        if collection:
            args["collection"] = collection

        logger.info(f"Vector searching for: {query}")
        result = await c.request(
            {"name": "solr_vector_search", "arguments": args}
        )

        # Display results
        print(f"\n=== Results for vector search: '{query}' ===\n")
        display_results(result)
    finally:
        await c.close()
def display_results(result: Dict):
    """
    Display search results in a readable format.

    The happy path expects ``result`` to be a dict whose ``"content"`` entry
    is a non-empty list, whose first item is a dict with a ``"text"`` entry
    holding a JSON object with a ``"docs"`` list (and optionally
    ``"numFound"``). Each document is printed with its title, ID and, when
    present, score, a text preview of at most 150 characters, categories
    and tags.

    Anything that does not match that shape is still shown rather than
    dropped or raising:

    * no usable first content item -> the raw ``result`` is printed;
    * text that is not JSON, or JSON without a ``"docs"`` list -> the text
      is printed unchanged.

    Args:
        result: Response from the MCP server
    """
    # Locate the first content item, tolerating any unexpected shape.
    first_item = None
    if isinstance(result, dict):
        content = result.get("content")
        if isinstance(content, list) and content:
            first_item = content[0]

    if not isinstance(first_item, dict):
        # Nothing we know how to interpret: show the raw response instead of
        # printing nothing or failing on a missing `.get`.
        print(result)
        return

    text_content = first_item.get("text", "")

    # Try to parse the JSON content
    try:
        data = json.loads(text_content)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a "text" value that is not str/bytes at all.
        print(text_content)
        return

    # json.loads may legitimately return a scalar or a list; only a dict can
    # carry the "docs" key, and a bare `"docs" in data` would raise on ints.
    docs = data.get("docs") if isinstance(data, dict) else None
    if not isinstance(docs, list):
        print(text_content)
        return

    if not docs:
        print("No results found.")
        return

    for i, doc in enumerate(docs, 1):
        print(f"Result {i}:")
        print(f" Title: {doc.get('title', 'No title')}")
        print(f" ID: {doc.get('id', 'No ID')}")
        if "score" in doc:
            print(f" Score: {doc['score']}")

        # Show a preview of the text (first 150 chars)
        text = doc.get("text", "")
        if text:
            if not isinstance(text, str):
                # A non-string value (e.g. a list) cannot be concatenated
                # with "..."; preview its string form instead.
                text = str(text)
            preview = text[:150] + "..." if len(text) > 150 else text
            print(f" Preview: {preview}")

        # Categories/tags may be a single value or a list, and the values
        # are not guaranteed to be strings, so convert before joining.
        if "category" in doc:
            categories = doc["category"] if isinstance(doc["category"], list) else [doc["category"]]
            joined_categories = ", ".join(map(str, categories))
            print(f" Categories: {joined_categories}")
        if "tags" in doc:
            tags = doc["tags"] if isinstance(doc["tags"], list) else [doc["tags"]]
            joined_tags = ", ".join(map(str, tags))
            print(f" Tags: {joined_tags}")
        print()

    print(f"Total results: {data.get('numFound', len(docs))}")
async def main():
    """Parse the command-line options and run the selected search."""
    cli = argparse.ArgumentParser(description="Demo search using the MCP client")
    cli.add_argument("query", help="Search query")
    cli.add_argument(
        "--vector", "-v",
        action="store_true",
        help="Use vector search instead of text search",
    )
    cli.add_argument("--collection", "-c", help="Collection name")
    cli.add_argument(
        "--results", "-n",
        type=int,
        default=5,
        help="Number of results to return",
    )
    opts = cli.parse_args()

    # Both search coroutines take (query, collection, count) positionally,
    # so pick one and call it with the same arguments.
    run_search = search_by_vector if opts.vector else search_by_text
    await run_search(opts.query, opts.collection, opts.results)
if __name__ == "__main__":
    # Drop the handlers currently registered on the loguru logger so the
    # INFO-level stderr sink added below is the only log destination.
    logger.remove()
    logger.add(sys.stderr, level="INFO")

    # Run the async entry point to completion.
    asyncio.run(main())