# expand_vector_db.py • 4.73 kB
#!/usr/bin/env python3
"""
Expand Vector Database with Comprehensive Data
==============================================
This script loads data from multiple sources to create a larger
benchmark vector database with better domain coverage:
1. MMLU-Pro (harder questions), requested with max_samples=5000
2. GPQA Diamond (graduate-level questions)
3. MATH dataset (competition mathematics), requested with max_samples=2000

The combined questions are embedded with the ``all-MiniLM-L6-v2`` model and
indexed into ``./data/benchmark_vector_db_expanded``. Database statistics are
then logged, and a handful of example prompts are queried against the result
as a smoke test.

NOTE(review): an earlier version of this docstring also listed the full MMLU
dataset and a target of 20,000+ questions across 20+ domains. The code below
does not load MMLU, and the requested caps (5,000 + 2,000, plus however many
questions ``load_gpqa_dataset`` returns) cannot reach that target. Add an MMLU
phase or raise the caps if the target still matters.
"""
from pathlib import Path
from benchmark_vector_db import BenchmarkVectorDB
import logging
# Script-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def expand_database(
    *,
    db_path=Path("./data/benchmark_vector_db_expanded"),
    embedding_model="all-MiniLM-L6-v2",
    mmlu_pro_samples=5000,
    math_samples=2000,
    top_domains=20,
):
    """Build the expanded benchmark vector database and log its statistics.

    Loads questions from three sources through ``BenchmarkVectorDB`` (MMLU-Pro,
    GPQA, MATH), indexes the combined list, then logs the counts reported by
    ``db.get_statistics()``.

    Every parameter is keyword-only and defaults to the value this script
    originally hard-coded, so a bare ``expand_database()`` call is unchanged.

    Args:
        db_path: Storage location handed to ``BenchmarkVectorDB``. A ``str`` is
            accepted and converted to ``Path``.
        embedding_model: Embedding model name handed to ``BenchmarkVectorDB``.
        mmlu_pro_samples: ``max_samples`` passed to ``load_mmlu_pro_dataset``.
        math_samples: ``max_samples`` passed to ``load_math_dataset``.
        top_domains: Number of domains (largest count first) to log. ``None``
            logs all of them.

    Returns:
        tuple: ``(db, stats)``, the ``BenchmarkVectorDB`` instance and the
        dict returned by its ``get_statistics()``.
    """
    _log_banner("Expanding Vector Database with Comprehensive Data")

    db = BenchmarkVectorDB(db_path=Path(db_path), embedding_model=embedding_model)

    logger.info("\nPhase 1: Loading MMLU-Pro (harder subset)")
    logger.info("-" * 40)
    mmlu_pro_questions = db.load_mmlu_pro_dataset(max_samples=mmlu_pro_samples)
    logger.info(f"Loaded {len(mmlu_pro_questions)} MMLU-Pro questions")

    logger.info("\nPhase 2: Loading GPQA Diamond (graduate-level)")
    logger.info("-" * 40)
    # fetch_real_scores=False is kept from the original script. What the flag
    # controls is defined inside BenchmarkVectorDB (not visible here).
    gpqa_questions = db.load_gpqa_dataset(fetch_real_scores=False)
    logger.info(f"Loaded {len(gpqa_questions)} GPQA questions")

    logger.info("\nPhase 3: Loading MATH dataset (competition math)")
    logger.info("-" * 40)
    math_questions = db.load_math_dataset(max_samples=math_samples)
    logger.info(f"Loaded {len(math_questions)} MATH questions")

    # The loader results are concatenated with ``+``, so they are assumed to be
    # lists of a common question type. TODO confirm in benchmark_vector_db.
    all_questions = mmlu_pro_questions + gpqa_questions + math_questions
    logger.info(f"\nTotal questions to index: {len(all_questions)}")

    if all_questions:
        logger.info("\nIndexing questions into vector database...")
        logger.info("This may take several minutes...")
        db.index_questions(all_questions)
    else:
        # Make an empty load visible instead of silently skipping indexing.
        logger.warning("No questions were loaded; skipping indexing.")

    _log_banner("Database Statistics", leading_newline=True)
    stats = db.get_statistics()
    _log_statistics(stats, top_domains=top_domains)

    # "\u2705" is U+2705 (check-mark emoji), escaped so the source stays ASCII
    # and cannot be re-mangled by a wrong text encoding.
    _log_banner("\u2705 Database expansion complete!", leading_newline=True)
    return db, stats


def _log_banner(title, *, leading_newline=False):
    """Log ``title`` between two 60-character ``=`` rules.

    With ``leading_newline=True`` the first rule is preceded by a blank line,
    which separates the banner from the previous section.
    """
    rule = "=" * 60
    logger.info(("\n" + rule) if leading_newline else rule)
    logger.info(title)
    logger.info(rule)


def _log_statistics(stats, *, top_domains=20):
    """Log the total and the per-source, per-domain and per-difficulty counts.

    Missing keys are treated as zero or empty, so a partial statistics dict
    does not abort the run after indexing has already finished.
    """
    logger.info(f"\nTotal Questions: {stats.get('total_questions', 0)}")

    logger.info("\nSources:")
    for source, count in stats.get('sources', {}).items():
        logger.info(f" {source}: {count}")

    logger.info("\nDomains:")
    # Largest domains first; slicing with None keeps every entry.
    domains = sorted(stats.get('domains', {}).items(), key=lambda item: item[1], reverse=True)
    for domain, count in domains[:top_domains]:
        logger.info(f" {domain}: {count}")

    logger.info("\nDifficulty Levels:")
    for level, count in stats.get('difficulty_levels', {}).items():
        logger.info(f" {level}: {count}")
# Default smoke-test prompts as (category, prompt) pairs.
_DEFAULT_TEST_PROMPTS = (
    # Hard prompts
    ("Graduate-level physics", "Calculate the quantum correction to the partition function for a 3D harmonic oscillator"),
    ("Abstract mathematics", "Prove that every field is also a ring"),
    ("Competition math", "Find all zeros of the polynomial x^3 + 2x + 2 in Z_7"),
    # Easy prompts
    ("Basic arithmetic", "What is 2 + 2?"),
    ("General knowledge", "What is the capital of France?"),
    # Domain-specific
    ("Medical reasoning", "Diagnose a patient with acute chest pain"),
    ("Legal knowledge", "Explain the doctrine of precedent in common law"),
    ("Computer science", "Implement a binary search tree"),
)


def test_expanded_database(db, *, test_prompts=None, k=3):
    """Query ``db`` with example prompts and log the assessment for each.

    Args:
        db: Database exposing ``query_similar_questions(prompt, k=...)``, such
            as the ``BenchmarkVectorDB`` returned by ``expand_database()``.
        test_prompts: Optional iterable of ``(category, prompt)`` pairs.
            Defaults to ``_DEFAULT_TEST_PROMPTS`` (hard, easy and
            domain-specific examples).
        k: Number of similar questions to retrieve per prompt.

    Returns:
        list: ``(category, prompt, result)`` tuples, where ``result`` is the
        mapping returned by ``query_similar_questions``.

    Raises:
        KeyError: If a result lacks ``risk_level``, ``weighted_success_rate``
            or ``recommendation``.
    """
    logger.info("\n" + "=" * 60)
    logger.info("Testing Expanded Database")
    logger.info("=" * 60)

    prompts = _DEFAULT_TEST_PROMPTS if test_prompts is None else test_prompts
    results = []
    for category, prompt in prompts:
        # Append an ellipsis only when the prompt was actually truncated.
        preview = prompt if len(prompt) <= 50 else prompt[:50] + "..."
        logger.info(f"\n{category}: '{preview}'")
        result = db.query_similar_questions(prompt, k=k)
        logger.info(f" Risk Level: {result['risk_level']}")
        logger.info(f" Success Rate: {result['weighted_success_rate']:.1%}")
        logger.info(f" Recommendation: {result['recommendation']}")
        results.append((category, prompt, result))
    return results
if __name__ == "__main__":
    # Build the expanded database, then smoke-test it with example prompts.
    db, _stats = expand_database()
    test_expanded_database(db)

    # "\U0001F680" is the rocket emoji (the previous literal was mojibake of
    # it); escaped so the source stays ASCII.
    logger.info("\n\U0001F680 All done! You can now use the expanded database.")
    logger.info("To switch to the expanded database, update your demo files:")
    logger.info(" db_path=Path('./data/benchmark_vector_db_expanded')")