#!/usr/bin/env python3
"""Benchmark script to compare embedding latency between Ollama and MLX backends.
This script measures embedding performance for both backends and produces
a comparison table with statistics including average, p95, p99, min, and max latency.
Usage:
# Compare both backends (default)
uv run python scripts/benchmark_embeddings.py
# Test specific backend
uv run python scripts/benchmark_embeddings.py --backend ollama
uv run python scripts/benchmark_embeddings.py --backend mlx
# Custom iterations and batch size
uv run python scripts/benchmark_embeddings.py --iterations 50 --batch-size 10
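
    # JSON output and custom warmup (see parse_args below for all flags)
    uv run python scripts/benchmark_embeddings.py --json
    uv run python scripts/benchmark_embeddings.py --warmup 5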

Requirements:
    - Ollama backend: Ollama server running with the mxbai-embed-large model
    - MLX backend: mlx-embeddings installed (Apple Silicon only)
"""
from __future__ import annotations
import argparse
import asyncio
import statistics
import sys
import time
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class BenchmarkResult:
"""Results from a benchmark run."""
backend: str
latencies_ms: List[float] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
@property
def success_count(self) -> int:
"""Number of successful iterations."""
return len(self.latencies_ms)
@property
def error_count(self) -> int:
"""Number of failed iterations."""
return len(self.errors)
@property
def avg_ms(self) -> Optional[float]:
"""Average latency in milliseconds."""
if not self.latencies_ms:
return None
return statistics.mean(self.latencies_ms)
@property
def min_ms(self) -> Optional[float]:
"""Minimum latency in milliseconds."""
if not self.latencies_ms:
return None
return min(self.latencies_ms)
@property
def max_ms(self) -> Optional[float]:
"""Maximum latency in milliseconds."""
if not self.latencies_ms:
return None
return max(self.latencies_ms)
@property
def p95_ms(self) -> Optional[float]:
"""95th percentile latency in milliseconds."""
if not self.latencies_ms or len(self.latencies_ms) < 2:
return None
        return statistics.quantiles(self.latencies_ms, n=20)[18]  # 19th of 19 cut points = 95th percentile
@property
def p99_ms(self) -> Optional[float]:
"""99th percentile latency in milliseconds."""
if not self.latencies_ms or len(self.latencies_ms) < 2:
return None
        return statistics.quantiles(self.latencies_ms, n=100)[98]  # 99th of 99 cut points = 99th percentile
@property
def stddev_ms(self) -> Optional[float]:
"""Standard deviation of latency in milliseconds."""
if not self.latencies_ms or len(self.latencies_ms) < 2:
return None
return statistics.stdev(self.latencies_ms)
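
# Illustrative use of BenchmarkResult (example values are made up): per-iteration
# latencies are appended to latencies_ms and summary statistics are read as properties.
#     r = BenchmarkResult(backend="ollama")
#     r.latencies_ms.extend([12.3, 15.1, 11.8])
#     r.avg_ms  # ~13.07; r.min_ms -> 11.8; r.max_ms -> 15.1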
def check_mlx_available() -> bool:
"""Check if mlx-embeddings is installed."""
try:
import mlx_embeddings # noqa: F401
return True
except ImportError:
return False
def check_ollama_available() -> bool:
"""Check if Ollama server appears reachable."""
try:
import httpx
with httpx.Client(timeout=2.0) as client:
response = client.get("http://localhost:11434/api/tags")
return response.status_code == 200
except Exception:
return False
async def benchmark_ollama(
iterations: int,
batch_size: int,
texts: List[str],
warmup: int = 2,
) -> BenchmarkResult:
"""Benchmark Ollama embedding backend.
Args:
iterations: Number of embedding operations to perform.
batch_size: Number of texts per batch.
texts: Sample texts to embed.
warmup: Number of warmup iterations.
Returns:
BenchmarkResult with timing data.
"""
result = BenchmarkResult(backend="ollama")
try:
from recall.embedding import create_embedding_provider
provider = create_embedding_provider("ollama")
# Warmup iterations (not counted)
print(f" Warming up Ollama ({warmup} iterations)...")
for _ in range(warmup):
try:
await provider.embed_batch(texts[:batch_size])
except Exception:
pass
# Benchmark iterations
print(f" Running {iterations} iterations...")
for i in range(iterations):
batch = texts[: min(batch_size, len(texts))]
start = time.perf_counter()
try:
await provider.embed_batch(batch)
elapsed_ms = (time.perf_counter() - start) * 1000
result.latencies_ms.append(elapsed_ms)
except Exception as e:
result.errors.append(str(e))
# Progress indicator
if (i + 1) % 10 == 0:
print(f" Progress: {i + 1}/{iterations}")
await provider.close()
except ImportError as e:
result.errors.append(f"Import error: {e}")
except Exception as e:
result.errors.append(f"Setup error: {e}")
return result
async def benchmark_mlx(
iterations: int,
batch_size: int,
texts: List[str],
warmup: int = 2,
) -> BenchmarkResult:
"""Benchmark MLX embedding backend.
Args:
iterations: Number of embedding operations to perform.
batch_size: Number of texts per batch.
texts: Sample texts to embed.
warmup: Number of warmup iterations.
Returns:
BenchmarkResult with timing data.
"""
result = BenchmarkResult(backend="mlx")
try:
from recall.embedding import create_embedding_provider
provider = create_embedding_provider("mlx")
# Warmup iterations (not counted)
print(f" Warming up MLX ({warmup} iterations)...")
for _ in range(warmup):
try:
await provider.embed_batch(texts[:batch_size])
except Exception:
pass
# Benchmark iterations
print(f" Running {iterations} iterations...")
for i in range(iterations):
batch = texts[: min(batch_size, len(texts))]
start = time.perf_counter()
try:
await provider.embed_batch(batch)
elapsed_ms = (time.perf_counter() - start) * 1000
result.latencies_ms.append(elapsed_ms)
except Exception as e:
result.errors.append(str(e))
# Progress indicator
if (i + 1) % 10 == 0:
print(f" Progress: {i + 1}/{iterations}")
await provider.close()
except ImportError as e:
result.errors.append(f"Import error: {e}")
except Exception as e:
result.errors.append(f"Setup error: {e}")
return result
def format_ms(value: Optional[float], precision: int = 2) -> str:
"""Format millisecond value for display."""
if value is None:
return "N/A"
return f"{value:.{precision}f}"
def print_comparison_table(results: List[BenchmarkResult]) -> None:
"""Print a formatted comparison table of benchmark results.
Args:
results: List of BenchmarkResult objects to compare.
"""
print("\n" + "=" * 80)
print("EMBEDDING BENCHMARK RESULTS")
print("=" * 80)
# Header
headers = ["Backend", "Avg (ms)", "Min (ms)", "Max (ms)", "P95 (ms)", "P99 (ms)", "StdDev", "Success", "Errors"]
col_widths = [10, 12, 12, 12, 12, 12, 10, 8, 8]
header_line = " | ".join(h.center(w) for h, w in zip(headers, col_widths))
print(header_line)
print("-" * len(header_line))
# Data rows
for r in results:
row = [
r.backend.center(col_widths[0]),
format_ms(r.avg_ms).rjust(col_widths[1]),
format_ms(r.min_ms).rjust(col_widths[2]),
format_ms(r.max_ms).rjust(col_widths[3]),
format_ms(r.p95_ms).rjust(col_widths[4]),
format_ms(r.p99_ms).rjust(col_widths[5]),
format_ms(r.stddev_ms).rjust(col_widths[6]),
str(r.success_count).rjust(col_widths[7]),
str(r.error_count).rjust(col_widths[8]),
]
print(" | ".join(row))
print("=" * 80)
# Comparison summary
if len(results) == 2 and all(r.avg_ms is not None for r in results):
ollama_result = next((r for r in results if r.backend == "ollama"), None)
mlx_result = next((r for r in results if r.backend == "mlx"), None)
if ollama_result and mlx_result and ollama_result.avg_ms and mlx_result.avg_ms:
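            # speedup > 1 means Ollama's average latency is higher, i.e. MLX is faster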
speedup = ollama_result.avg_ms / mlx_result.avg_ms
faster_backend = "MLX" if speedup > 1 else "Ollama"
speedup_factor = speedup if speedup > 1 else 1 / speedup
print(f"\nSummary: {faster_backend} is {speedup_factor:.1f}x faster on average")
if mlx_result.avg_ms < ollama_result.avg_ms:
savings_ms = ollama_result.avg_ms - mlx_result.avg_ms
print(f" MLX saves {savings_ms:.1f}ms per batch operation")
def print_errors(results: List[BenchmarkResult]) -> None:
"""Print any errors encountered during benchmarking."""
for r in results:
if r.errors:
print(f"\n{r.backend.upper()} Errors ({len(r.errors)}):")
for i, err in enumerate(r.errors[:5], 1):
print(f" {i}. {err[:100]}...")
if len(r.errors) > 5:
print(f" ... and {len(r.errors) - 5} more errors")
def get_sample_texts(count: int = 20) -> List[str]:
"""Generate sample texts for embedding benchmarks.
Args:
count: Number of sample texts to generate.
Returns:
List of sample texts with varying lengths.
"""
samples = [
"The quick brown fox jumps over the lazy dog.",
"Python is a high-level programming language.",
"Machine learning enables computers to learn from data.",
"The weather today is sunny with clear skies.",
"Artificial intelligence is transforming industries worldwide.",
"Natural language processing helps computers understand human language.",
"Deep learning uses neural networks with multiple layers.",
"The Recall MCP server provides persistent memory storage.",
"Vector embeddings capture semantic meaning of text.",
"ChromaDB is a vector database for similarity search.",
"SQLite provides lightweight relational database functionality.",
"Async programming improves application responsiveness.",
"API design patterns help create consistent interfaces.",
"Test-driven development improves code quality.",
"Documentation makes software easier to understand.",
"Version control tracks changes to source code.",
"Continuous integration automates build and test processes.",
"Code reviews help catch bugs early in development.",
"Performance optimization reduces resource consumption.",
"Security best practices protect against vulnerabilities.",
]
# Extend with repeated samples if needed
while len(samples) < count:
samples.extend(samples[: count - len(samples)])
return samples[:count]
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Benchmark embedding latency for Ollama and MLX backends.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--backend",
type=str,
choices=["ollama", "mlx", "both"],
default="both",
help="Which backend(s) to benchmark",
)
parser.add_argument(
"--iterations",
type=int,
default=20,
help="Number of benchmark iterations per backend",
)
parser.add_argument(
"--batch-size",
type=int,
default=5,
help="Number of texts per embedding batch",
)
parser.add_argument(
"--warmup",
type=int,
default=2,
help="Number of warmup iterations (not counted)",
)
parser.add_argument(
"--json",
action="store_true",
help="Output results as JSON",
)
return parser.parse_args()
async def main() -> int:
"""Run the benchmark script.
Returns:
Exit code (0 for success, 1 for failure).
"""
args = parse_args()
results: List[BenchmarkResult] = []
sample_texts = get_sample_texts(count=args.batch_size * 2)
print("\n" + "=" * 80)
print("EMBEDDING BACKEND BENCHMARK")
print("=" * 80)
print(f"Configuration:")
print(f" Iterations: {args.iterations}")
print(f" Batch size: {args.batch_size}")
print(f" Warmup: {args.warmup}")
print(f" Backend(s): {args.backend}")
print()
# Check availability
ollama_available = check_ollama_available()
mlx_available = check_mlx_available()
print("Backend Availability:")
print(f" Ollama: {'Available' if ollama_available else 'Not available (server not running)'}")
print(f" MLX: {'Available' if mlx_available else 'Not available (mlx-embeddings not installed)'}")
print()
# Run benchmarks
if args.backend in ("both", "ollama"):
if ollama_available:
print("Benchmarking Ollama...")
result = await benchmark_ollama(
iterations=args.iterations,
batch_size=args.batch_size,
texts=sample_texts,
warmup=args.warmup,
)
results.append(result)
print(f" Completed: {result.success_count} successful, {result.error_count} errors")
else:
print("Skipping Ollama benchmark (server not available)")
results.append(BenchmarkResult(backend="ollama", errors=["Server not available"]))
if args.backend in ("both", "mlx"):
if mlx_available:
print("Benchmarking MLX...")
result = await benchmark_mlx(
iterations=args.iterations,
batch_size=args.batch_size,
texts=sample_texts,
warmup=args.warmup,
)
results.append(result)
print(f" Completed: {result.success_count} successful, {result.error_count} errors")
else:
print("Skipping MLX benchmark (mlx-embeddings not installed)")
results.append(BenchmarkResult(backend="mlx", errors=["mlx-embeddings not installed"]))
# Output results
if args.json:
import json
output = {
"config": {
"iterations": args.iterations,
"batch_size": args.batch_size,
"warmup": args.warmup,
},
"results": [
{
"backend": r.backend,
"avg_ms": r.avg_ms,
"min_ms": r.min_ms,
"max_ms": r.max_ms,
"p95_ms": r.p95_ms,
"p99_ms": r.p99_ms,
"stddev_ms": r.stddev_ms,
"success_count": r.success_count,
"error_count": r.error_count,
}
for r in results
],
}
print(json.dumps(output, indent=2))
else:
print_comparison_table(results)
print_errors(results)
# Return success if at least one backend worked
return 0 if any(r.success_count > 0 for r in results) else 1
if __name__ == "__main__":
sys.exit(asyncio.run(main()))