"""
Automated RAG pipeline validation script.
Exercises end-to-end upload, retrieval, and chat flows, multi-tenant isolation, and anti-hallucination (refusal) behavior.
"""
import httpx
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import sys
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
BASE_URL = "http://localhost:8000"
TEST_TENANT_A = "tenant_A"
TEST_TENANT_B = "tenant_B"
TEST_USER_A = "user_A"
TEST_USER_B = "user_B"
TEST_KB_A = "kb_A"
TEST_KB_B = "kb_B"
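# The two tenants deliberately hold conflicting facts (a 7-day vs 30-day refund
# window, a 499 vs 999 Starter price) so that any cross-tenant leakage surfaces
# as a failed assertion rather than a silent pass.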
# Test documents
TENANT_A_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_A_kb.md"
TENANT_B_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_B_kb.md"
# Test results storage
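# Each entry records the test name, a pass/fail flag, and a short reason;
# the Phase 4 summary iterates over this list.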
test_results: List[Dict[str, Any]] = []
def print_header(text: str):
"""Print a formatted header."""
print("\n" + "=" * 80)
print(f" {text}")
print("=" * 80)
def print_test(test_name: str, passed: bool, reason: str = ""):
"""Print test result."""
status = "[PASS]" if passed else "[FAIL]"
print(f"{status} | {test_name}")
if reason:
print(f" └─ {reason}")
test_results.append({
"test": test_name,
"passed": passed,
"reason": reason
})
def wait_for_server(max_retries: int = 10, delay: int = 2) -> bool:
"""Wait for the server to be ready."""
print("Waiting for server to be ready...")
for i in range(max_retries):
try:
response = httpx.get(f"{BASE_URL}/health", timeout=5)
if response.status_code == 200:
print("[OK] Server is ready")
return True
except Exception:
pass
time.sleep(delay)
print(f" Retry {i+1}/{max_retries}...")
print("[FAIL] Server not ready after max retries")
return False
def upload_document(
client: httpx.Client,
file_path: Path,
tenant_id: str,
user_id: str,
kb_id: str
) -> Dict[str, Any]:
"""Upload a document to the knowledge base."""
try:
with open(file_path, "rb") as f:
files = {"file": (file_path.name, f, "text/markdown")}
data = {
"tenant_id": tenant_id,
"user_id": user_id,
"kb_id": kb_id
}
response = client.post(
f"{BASE_URL}/kb/upload",
files=files,
data=data,
timeout=60
)
if response.status_code == 200:
return {"success": True, "data": response.json()}
else:
return {"success": False, "error": response.text}
except Exception as e:
return {"success": False, "error": str(e)}
def test_retrieval(
client: httpx.Client,
query: str,
tenant_id: str,
user_id: str,
kb_id: str,
expected_keywords: List[str],
    should_not_contain: Optional[List[str]] = None,
top_k: int = 5
) -> Tuple[bool, str]:
"""Test retrieval accuracy."""
try:
# Use GET for search endpoint with headers for dev mode auth
headers = {
"X-Tenant-Id": tenant_id,
"X-User-Id": user_id
}
response = client.get(
f"{BASE_URL}/kb/search",
params={
"query": query,
"kb_id": kb_id,
"top_k": top_k
},
headers=headers,
timeout=30
)
if response.status_code != 200:
return False, f"API returned {response.status_code}: {response.text}"
data = response.json()
results = data.get("results", [])
if not results:
return False, "No results retrieved"
# Check tenant isolation
for result in results:
metadata = result.get("metadata", {})
result_tenant = metadata.get("tenant_id")
if result_tenant != tenant_id:
return False, f"Tenant leak detected! Got tenant_id={result_tenant}, expected {tenant_id}"
# Check for expected keywords
all_content = " ".join([r.get("content", "") for r in results]).lower()
found_keywords = [kw for kw in expected_keywords if kw.lower() in all_content]
if not found_keywords:
return False, f"Expected keywords not found: {expected_keywords}"
# Check for forbidden content
if should_not_contain:
for forbidden in should_not_contain:
if forbidden.lower() in all_content:
return False, f"Forbidden content found: {forbidden}"
return True, f"Retrieved {len(results)} results, found keywords: {found_keywords}"
except Exception as e:
return False, f"Error: {str(e)}"
def test_chat(
client: httpx.Client,
question: str,
tenant_id: str,
user_id: str,
kb_id: str,
    expected_keywords: Optional[List[str]] = None,
    should_refuse: bool = False,
    should_not_contain: Optional[List[str]] = None
) -> Tuple[bool, str, Dict[str, Any]]:
"""Test full chat endpoint."""
try:
# Include headers for dev mode auth
headers = {
"X-Tenant-Id": tenant_id,
"X-User-Id": user_id
}
response = client.post(
f"{BASE_URL}/chat",
json={
"tenant_id": tenant_id,
"user_id": user_id,
"kb_id": kb_id,
"question": question
},
headers=headers,
timeout=60
)
if response.status_code != 200:
return False, f"API returned {response.status_code}: {response.text}", {}
data = response.json()
answer = data.get("answer", "").lower()
citations = data.get("citations", [])
from_kb = data.get("from_knowledge_base", False)
confidence = data.get("confidence", 0.0)
metadata = data.get("metadata", {})
        # The refusal flag may be reported at the top level or inside metadata
        refused = data.get("refused", metadata.get("refused", False))
# Check refusal behavior (STRICT)
if should_refuse:
            # Phrases whose presence in the answer signals a refusal
refusal_keywords = [
"couldn't find", "don't have", "not available", "contact support",
"not in the knowledge base", "could not verify", "not enough information",
"apologize", "couldn't find relevant information"
]
has_refusal_keywords = any(kw in answer for kw in refusal_keywords)
# If answer was generated with citations, it's a FAIL (should have refused)
            if citations:
return False, (
f"Should have refused but generated answer with {len(citations)} citations. "
f"Answer: {answer[:300]}"
), data
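            # 0.30 is the suite's assumed refusal threshold: any non-refusal answer
            # at or above this confidence is treated as a hallucination; adjust it
            # to match the server's actual policy.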
# If confidence is high and answer exists, it's a FAIL
if confidence >= 0.30 and answer and not has_refusal_keywords:
return False, (
f"Should have refused but generated answer with confidence {confidence:.2f}. "
f"Answer: {answer[:300]}"
), data
# If not refused and no refusal keywords, it's a FAIL
if not refused and not has_refusal_keywords:
return False, (
f"Should have refused but didn't. "
f"refused={refused}, confidence={confidence:.2f}, citations={len(citations)}. "
f"Answer: {answer[:300]}"
), data
# If we got here, it properly refused
return True, f"Properly refused (refused={refused}, confidence={confidence:.2f})", data
# Check for expected keywords
if expected_keywords:
found = [kw for kw in expected_keywords if kw.lower() in answer]
if not found:
return False, f"Expected keywords not found: {expected_keywords}. Answer: {answer[:200]}", data
# Check citations
if not should_refuse and from_kb:
if not citations:
return False, "Answer claims to be from KB but has no citations", data
# Check for forbidden content
if should_not_contain:
for forbidden in should_not_contain:
if forbidden.lower() in answer:
return False, f"Forbidden content found in answer: {forbidden}", data
        # Check citation integrity (warn only, do not fail)
        if citations and expected_keywords:
            citation_text = " ".join([c.get("excerpt", "") for c in citations]).lower()
            for kw in expected_keywords:
                if kw.lower() in answer and kw.lower() not in citation_text:
                    # The keyword made it into the answer but is not backed by any citation excerpt
                    print(f"   [WARN] '{kw}' appears in the answer but not in any citation excerpt")
return True, f"Answer generated (confidence: {confidence:.2f}, citations: {len(citations)})", data
except Exception as e:
return False, f"Error: {str(e)}", {}
def main():
"""Run all validation tests."""
print_header("RAG Pipeline Validation Suite")
# Check server
    if not wait_for_server():
        print("[FAIL] Cannot proceed without server")
        return 1
client = httpx.Client(timeout=120.0)
# ========== PHASE 1: Upload Documents ==========
print_header("Phase 1: Upload Test Documents")
# Upload tenant A doc
print(f"\n📤 Uploading {TENANT_A_DOC.name} for {TEST_TENANT_A}...")
result = upload_document(client, TENANT_A_DOC, TEST_TENANT_A, TEST_USER_A, TEST_KB_A)
if result["success"]:
print("[OK] Upload successful")
print("⏳ Waiting for document processing (10 seconds)...")
time.sleep(10) # Wait longer for processing (parsing, chunking, embedding)
    else:
        print(f"[FAIL] Upload failed: {result.get('error')}")
        return 1
# Upload tenant B doc
print(f"\n📤 Uploading {TENANT_B_DOC.name} for {TEST_TENANT_B}...")
result = upload_document(client, TENANT_B_DOC, TEST_TENANT_B, TEST_USER_B, TEST_KB_B)
if result["success"]:
print("[OK] Upload successful")
print("⏳ Waiting for document processing (10 seconds)...")
time.sleep(10) # Wait longer for processing (parsing, chunking, embedding)
    else:
        print(f"[FAIL] Upload failed: {result.get('error')}")
        return 1
# ========== PHASE 2: Retrieval Tests ==========
print_header("Phase 2: Retrieval Accuracy Tests")
# Test 1: Tenant A retrieval
passed, reason = test_retrieval(
client,
"What is the refund window?",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
expected_keywords=["7 days"],
should_not_contain=["30 days"]
)
print_test("Retrieval: Tenant A - Refund Window", passed, reason)
# Test 2: Tenant B retrieval
passed, reason = test_retrieval(
client,
"What is the refund window?",
TEST_TENANT_B,
TEST_USER_B,
TEST_KB_B,
expected_keywords=["30 days"],
should_not_contain=["7 days"]
)
print_test("Retrieval: Tenant B - Refund Window", passed, reason)
# Test 3: Tenant isolation (A should not get B's data)
passed, reason = test_retrieval(
client,
"Starter plan price",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
expected_keywords=["499"],
should_not_contain=["999"]
)
print_test("Retrieval: Tenant A - Starter Plan Price (Isolation)", passed, reason)
# Test 4: Tenant isolation (B should not get A's data)
passed, reason = test_retrieval(
client,
"Starter plan price",
TEST_TENANT_B,
TEST_USER_B,
TEST_KB_B,
expected_keywords=["999"],
should_not_contain=["499"]
)
print_test("Retrieval: Tenant B - Starter Plan Price (Isolation)", passed, reason)
# ========== PHASE 3: Chat Tests ==========
print_header("Phase 3: Chat Endpoint Tests")
# Test 5: Tenant A chat - refund window
passed, reason, data = test_chat(
client,
"What is the refund window?",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
expected_keywords=["7 days"],
should_not_contain=["30 days"]
)
print_test("Chat: Tenant A - Refund Window", passed, reason)
# Test 6: Tenant B chat - refund window
passed, reason, data = test_chat(
client,
"What is the refund window?",
TEST_TENANT_B,
TEST_USER_B,
TEST_KB_B,
expected_keywords=["30 days"],
should_not_contain=["7 days"]
)
print_test("Chat: Tenant B - Refund Window", passed, reason)
# Test 7: Tenant A chat - Starter plan
passed, reason, data = test_chat(
client,
"What is the Starter plan price?",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
expected_keywords=["499"],
should_not_contain=["999"]
)
print_test("Chat: Tenant A - Starter Plan Price", passed, reason)
# Test 8: Tenant B chat - Starter plan
passed, reason, data = test_chat(
client,
"What is the Starter plan price?",
TEST_TENANT_B,
TEST_USER_B,
TEST_KB_B,
expected_keywords=["999"],
should_not_contain=["499"]
)
print_test("Chat: Tenant B - Starter Plan Price", passed, reason)
# Test 9: Hallucination refusal - out of scope
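    # The Shopify question is intended to fall outside tenant A's knowledge base,
    # so the pipeline should refuse rather than invent an answer.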
passed, reason, data = test_chat(
client,
"How to integrate ClientSphere with Shopify?",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
should_refuse=True
)
print_test("Chat: Hallucination Refusal (Out of Scope)", passed, reason)
# Test 10: Citation integrity
passed, reason, data = test_chat(
client,
"How long do password reset links last?",
TEST_TENANT_A,
TEST_USER_A,
TEST_KB_A,
expected_keywords=["15"]
)
if passed:
citations = data.get("citations", [])
if citations:
print_test("Chat: Citation Integrity", True, f"Found {len(citations)} citations")
else:
print_test("Chat: Citation Integrity", False, "No citations provided")
else:
print_test("Chat: Citation Integrity", False, reason)
# ========== PHASE 4: Summary ==========
print_header("Test Summary")
total_tests = len(test_results)
passed_tests = sum(1 for r in test_results if r["passed"])
failed_tests = total_tests - passed_tests
print(f"\nTotal Tests: {total_tests}")
print(f"[PASS] Passed: {passed_tests}")
print(f"[FAIL] Failed: {failed_tests}")
print(f"Success Rate: {(passed_tests/total_tests*100):.1f}%")
if failed_tests > 0:
print("\n[FAIL] Failed Tests:")
for result in test_results:
if not result["passed"]:
print(f" - {result['test']}: {result['reason']}")
# Final verdict
print_header("Final Verdict")
if failed_tests == 0:
print("[PASS] ALL TESTS PASSED - RAG Pipeline is working correctly")
return 0
else:
print(f"[FAIL] {failed_tests} TEST(S) FAILED - Review issues above")
return 1
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code)