"""Multi-Slice Retriever for parallel filtered and unfiltered queries
This module provides a retrieval strategy that runs multiple parallel queries
with different filter configurations to balance precision and recall.
Multi-slice approach:
- Slice 1 (Unfiltered): Baseline vector similarity, no metadata filter
- Slice 2 (Filtered): Apply LLM-generated metadata filter (precision)
- Additional slices can be configured as needed
All slices run in parallel; results are deduplicated by vector ID,
keeping the highest score for duplicates.
"""
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from typing import Any

import boto3
logger = logging.getLogger(__name__)
# Default configuration
DEFAULT_TIMEOUT_SECONDS = 5.0  # per-slice budget passed to as_completed() in retrieve()
DEFAULT_MAX_SLICES = 3  # upper bound on parallel slice queries
DEFAULT_FILTERED_SCORE_BOOST = 1.25  # default max boost ceiling (configurable via settings)
ADAPTIVE_BOOST_MARGIN = 1.10  # 10% margin above parity to ensure filtered beats unfiltered
ADAPTIVE_BOOST_FLOOR = 1.05  # minimum boost even when filtered scores are already higher
@dataclass
class SliceConfig:
    """Configuration for a retrieval slice.

    Attributes:
        name: Slice identifier; used as the key in per-slice result dicts
            (the merge step treats the name "filtered" specially).
        use_filter: Whether this slice applies the LLM-generated metadata filter.
        num_results: Number of results requested for this slice.
        description: Human-readable description for logging/debugging.
    """

    name: str
    use_filter: bool = False
    num_results: int = 5
    description: str = ""
def _get_uri(result: dict) -> str:
"""Extract S3 URI from a retrieval result."""
location = result.get("location", {})
s3_location = location.get("s3Location", {})
return s3_location.get("uri", "")
def deduplicate_results(results: list[dict]) -> list[dict]:
    """
    Collapse retrieval results that share an S3 URI, keeping the best one.

    Results without a URI are treated as unique (keyed by object identity),
    so nothing is dropped merely for missing location data.

    Args:
        results: List of retrieval result dictionaries.

    Returns:
        Deduplicated list, sorted by score descending, with the
        highest-scoring entry retained per unique document.
    """
    if not results:
        return []
    # Track the winning result per URI key
    winners: dict[str, dict] = {}
    for item in results:
        key = _get_uri(item) or f"_no_uri_{id(item)}"
        current = winners.get(key)
        if current is None or item.get("score", 0.0) > current.get("score", 0):
            winners[key] = item
    # Highest scores first in the returned list
    return sorted(winners.values(), key=lambda r: r.get("score", 0), reverse=True)
def compute_adaptive_boost(
    filtered_results: list[dict],
    unfiltered_results: list[dict],
    max_boost: float,
) -> float:
    """
    Derive a score boost from the observed gap between the two slices.

    Rather than always applying a static multiplier, this computes just
    enough boost for the top filtered hit to outrank the top unfiltered
    hit (plus ADAPTIVE_BOOST_MARGIN), clamped to the
    [ADAPTIVE_BOOST_FLOOR, max_boost] band. When either slice is empty,
    or the top filtered score is not positive, the static ``max_boost``
    is returned unchanged.

    Args:
        filtered_results: Results from the filtered slice (sorted by score desc).
        unfiltered_results: Results from the unfiltered slice (sorted by score desc).
        max_boost: Maximum allowed boost (from user config).

    Returns:
        Computed boost multiplier.
    """
    if not filtered_results or not unfiltered_results:
        reason = "empty_filtered" if not filtered_results else "empty_unfiltered"
        logger.info(
            "adaptive_boost fallback=static reason=%s max_boost=%.4f",
            reason,
            max_boost,
        )
        return max_boost
    top_filtered = filtered_results[0].get("score", 0.0)
    top_unfiltered = unfiltered_results[0].get("score", 0.0)
    if top_filtered <= 0:
        logger.info(
            "adaptive_boost fallback=static reason=zero_filtered_score max_boost=%.4f",
            max_boost,
        )
        return max_boost
    required = (top_unfiltered / top_filtered) * ADAPTIVE_BOOST_MARGIN
    # The user's max_boost ceiling wins over the floor — setting max_boost
    # below ADAPTIVE_BOOST_FLOOR effectively disables boosting.
    final = min(max(required, ADAPTIVE_BOOST_FLOOR), max_boost)
    logger.info(
        "adaptive_boost best_filtered=%.4f best_unfiltered=%.4f "
        "required_boost=%.4f final_boost=%.4f max_boost=%.4f",
        top_filtered,
        top_unfiltered,
        required,
        final,
        max_boost,
    )
    return final
def merge_slices_with_guaranteed_minimum(
    slice_results: dict[str, list[dict]],
    min_per_slice: int = 3,
    total_results: int = 10,
    filtered_score_boost: float = 1.0,
) -> list[dict]:
    """
    Merge multi-slice results, boosting and prioritizing the filtered slice.

    Results under the "filtered" key match the user's explicit intent
    (e.g., a person-name filter), so they receive an adaptive score boost
    before the combined ranking.

    Steps:
        1. Compute the adaptive boost from the real score gap between slices.
        2. Multiply filtered scores by that boost.
        3. Sort everything together by (boosted) score.
        4. Deduplicate by URI, first occurrence wins, capped at total_results.

    Note: The ``score`` field of returned results holds the boosted value,
    not the raw KB similarity score — downstream consumers intentionally
    see the ranking-adjusted score.

    Args:
        slice_results: Dict mapping slice name to its result list.
            The "filtered" key (if present) receives the score boost.
        min_per_slice: Unused, kept for API compatibility.
        total_results: Maximum total results to return.
        filtered_score_boost: Max boost ceiling for adaptive computation
            (e.g., 1.25 = up to 25% boost). Configurable at runtime.

    Returns:
        Merged and deduplicated results sorted by boosted score.
    """
    if not slice_results:
        return []
    filtered = slice_results.get("filtered", [])
    # Everything outside the "filtered" slice forms the unfiltered pool
    unfiltered_pool = [
        hit
        for name, hits in slice_results.items()
        if name != "filtered"
        for hit in hits
    ]
    unfiltered_pool.sort(key=lambda h: h.get("score", 0), reverse=True)
    # filtered_score_boost acts as the ceiling for the adaptive computation
    boost = compute_adaptive_boost(filtered, unfiltered_pool, filtered_score_boost)
    # Tag working copies so the originals stay untouched
    candidates: list[dict] = []
    for hit in filtered:
        raw_score = hit.get("score", 0.0)
        tagged = hit.copy()
        tagged["_boosted_score"] = raw_score * boost
        tagged["_is_filtered"] = True
        candidates.append(tagged)
        if boost != 1.0:
            logger.debug("Boosted filtered score: %.4f -> %.4f", raw_score, raw_score * boost)
    for hit in unfiltered_pool:
        tagged = hit.copy()
        tagged["_boosted_score"] = hit.get("score", 0.0)
        tagged["_is_filtered"] = False
        candidates.append(tagged)
    # Highest boosted score first; stable sort keeps filtered ahead on ties
    candidates.sort(key=lambda h: h.get("_boosted_score", 0), reverse=True)
    # First occurrence per URI wins (list is already score-ordered)
    seen: set[str] = set()
    merged: list[dict] = []
    for hit in candidates:
        key = _get_uri(hit) or f"_no_uri_{id(hit)}"
        if key in seen:
            continue
        seen.add(key)
        # Strip internal "_" tags; surface the boosted value as the score
        public = {k: v for k, v in hit.items() if not k.startswith("_")}
        public["score"] = hit.get("_boosted_score", hit.get("score", 0.0))
        merged.append(public)
        if len(merged) >= total_results:
            break
    # Structured log for empirical tuning. The caller sets total_results to
    # 2x the per-slice num_results, so //2 approximates the per-slice
    # request count used for the fill_rate.
    per_slice_request = total_results // 2 if total_results > 0 else 1
    fill_rate = len(filtered) / max(per_slice_request, 1)
    # score_ratio < 1.0 means filtered scored lower; the boost compensates
    if filtered and unfiltered_pool and unfiltered_pool[0].get("score", 0) > 0:
        score_ratio = filtered[0].get("score", 0) / unfiltered_pool[0].get("score", 1)
    else:
        score_ratio = 0.0
    logger.info(
        "merge_slices fill_rate=%.4f score_ratio=%.4f "
        "adaptive_boost=%.4f max_boost=%.4f "
        "filtered_count=%d unfiltered_count=%d merged_count=%d",
        fill_rate,
        score_ratio,
        boost,
        filtered_score_boost,
        len(filtered),
        len(unfiltered_pool),
        len(merged),
    )
    return merged
class MultiSliceRetriever:
    """
    Multi-slice retrieval with parallel execution.
    Runs multiple queries with different filter configurations in parallel
    and merges results, keeping the highest score for duplicates.
    Usage:
        retriever = MultiSliceRetriever()
        results = retriever.retrieve(
            query="genealogy documents",
            knowledge_base_id="kb-123",
            data_source_id="ds-456",
            metadata_filter={"topic": {"$eq": "genealogy"}},
            num_results=5
        )
    """

    # S3 Vectors comparison operator -> Bedrock KB filter key.
    # "$exists" is deliberately absent: Bedrock KB has no equivalent,
    # so conditions using it are skipped during conversion.
    _OPERATOR_MAP: dict[str, str] = {
        "$eq": "equals",
        "$ne": "notEquals",
        "$gt": "greaterThan",
        "$gte": "greaterThanOrEquals",
        "$lt": "lessThan",
        "$lte": "lessThanOrEquals",
        "$in": "in",
        "$nin": "notIn",
        "$listContains": "listContains",
    }

    def __init__(
        self,
        bedrock_agent_client=None,
        timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
        max_slices: int = DEFAULT_MAX_SLICES,
        enabled: bool = True,
        filtered_score_boost: float = DEFAULT_FILTERED_SCORE_BOOST,
    ):
        """
        Initialize the multi-slice retriever.
        Args:
            bedrock_agent_client: Bedrock Agent Runtime client. Creates one if not provided.
            timeout_seconds: Timeout per slice in seconds.
            max_slices: Maximum number of slices to execute.
            enabled: Whether multi-slice is enabled. If False, falls back to single query.
            filtered_score_boost: Score multiplier for filtered results (e.g., 1.15 = 15% boost).
        """
        self.bedrock_agent = bedrock_agent_client or boto3.client("bedrock-agent-runtime")
        self.timeout_seconds = timeout_seconds
        self.max_slices = max_slices
        self.enabled = enabled
        self.filtered_score_boost = filtered_score_boost
        logger.info(
            "Initialized MultiSliceRetriever: timeout=%ss, max_slices=%s, "
            "enabled=%s, filtered_boost=%s",
            timeout_seconds,
            max_slices,
            enabled,
            filtered_score_boost,
        )

    def retrieve(
        self,
        query: str,
        knowledge_base_id: str,
        data_source_id: str | None,
        metadata_filter: dict | None = None,
        num_results: int = 5,
    ) -> list[dict]:
        """
        Retrieve documents using multi-slice strategy.
        Args:
            query: Search query text.
            knowledge_base_id: Bedrock Knowledge Base ID.
            data_source_id: Data source ID for filtering (optional).
            metadata_filter: LLM-generated metadata filter (optional).
            num_results: Number of results per slice.
        Returns:
            List of deduplicated retrieval results.
        """
        # If disabled or no filter, fall back to single query
        if not self.enabled or metadata_filter is None:
            return self._single_retrieve(
                query=query,
                knowledge_base_id=knowledge_base_id,
                data_source_id=data_source_id,
                metadata_filter=None,
                num_results=num_results,
            )
        # Build slice configurations
        slices = self._build_slice_configs(metadata_filter, num_results)
        # Execute slices in parallel
        slice_results: dict[str, list[dict]] = {}
        try:
            with ThreadPoolExecutor(max_workers=min(len(slices), self.max_slices)) as executor:
                # Submit all slice retrievals
                futures = {}
                for slice_config in slices[: self.max_slices]:
                    future = executor.submit(
                        self._execute_slice,
                        query=query,
                        knowledge_base_id=knowledge_base_id,
                        data_source_id=data_source_id,
                        slice_config=slice_config,
                        metadata_filter=metadata_filter if slice_config.use_filter else None,
                    )
                    futures[future] = slice_config.name
                # Collect results as they complete
                for future in as_completed(futures, timeout=self.timeout_seconds):
                    slice_name = futures[future]
                    try:
                        # The future is already done when as_completed yields
                        # it, so result() returns immediately — no timeout.
                        results = future.result()
                        slice_results[slice_name] = results
                        logger.info("Slice '%s' returned %d results", slice_name, len(results))
                    except Exception as e:
                        logger.warning("Slice '%s' failed: %s", slice_name, e)
                        # Continue with other slices
        except (TimeoutError, FuturesTimeoutError):
            # as_completed raises concurrent.futures.TimeoutError, which is
            # only an alias of the builtin TimeoutError since Python 3.11 —
            # catch both so the timeout is logged as a warning, not an error.
            logger.warning("Multi-slice retrieval timed out after %ss", self.timeout_seconds)
            # Return whatever we collected so far
        except Exception as e:
            logger.error("Multi-slice retrieval error: %s", e)
            # Return whatever we collected
        # Merge with score boost for filtered results
        total = sum(len(r) for r in slice_results.values())
        merged = merge_slices_with_guaranteed_minimum(
            slice_results,
            min_per_slice=min(3, num_results),
            total_results=num_results * 2,
            filtered_score_boost=self.filtered_score_boost,
        )
        logger.info(
            "Multi-slice retrieval complete: %d total, %d after merge (filtered boost=%s)",
            total,
            len(merged),
            self.filtered_score_boost,
        )
        return merged

    def _build_slice_configs(
        self,
        metadata_filter: dict | None,
        num_results: int,
    ) -> list[SliceConfig]:
        """
        Build slice configurations based on filter availability.
        Args:
            metadata_filter: The LLM-generated filter (if any).
            num_results: Number of results per slice.
        Returns:
            List of SliceConfig objects.
        """
        slices = [
            # Slice 1: Unfiltered (baseline recall)
            SliceConfig(
                name="unfiltered",
                use_filter=False,
                num_results=num_results,
                description="Baseline vector similarity",
            ),
        ]
        if metadata_filter:
            # Slice 2: With filter (precision)
            slices.append(
                SliceConfig(
                    name="filtered",
                    use_filter=True,
                    num_results=num_results,
                    description="LLM-generated metadata filter",
                )
            )
        return slices

    def _execute_slice(
        self,
        query: str,
        knowledge_base_id: str,
        data_source_id: str | None,
        slice_config: SliceConfig,
        metadata_filter: dict | None = None,
    ) -> list[dict]:
        """
        Execute a single retrieval slice.
        Args:
            query: Search query text.
            knowledge_base_id: Knowledge Base ID.
            data_source_id: Data source ID.
            slice_config: Configuration for this slice.
            metadata_filter: Metadata filter to apply (if slice uses filter).
        Returns:
            List of retrieval results.
        Raises:
            Exception: Re-raises any client error so the caller records
                the slice as failed.
        """
        try:
            # Build vector search configuration
            vector_config: dict[str, Any] = {
                "numberOfResults": slice_config.num_results,
            }
            # Build filter expression
            filter_expr = self._build_filter(data_source_id, metadata_filter)
            if filter_expr:
                vector_config["filter"] = filter_expr
            # Execute retrieve
            logger.info("[SLICE RETRIEVE] kb_id=%s, config=%s", knowledge_base_id, vector_config)
            response = self.bedrock_agent.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery={"text": query},
                retrievalConfiguration={"vectorSearchConfiguration": vector_config},
            )
            results = response.get("retrievalResults", [])
            for i, r in enumerate(results):
                uri = r.get("location", {}).get("s3Location", {}).get("uri", "N/A")
                score = r.get("score", "N/A")
                logger.info("[SLICE RESULT] %d: score=%s, uri=%s", i, score, uri)
            return results
        except Exception as e:
            logger.warning("Slice '%s' execution failed: %s", slice_config.name, e)
            raise

    def _build_filter(
        self,
        data_source_id: str | None,
        metadata_filter: dict | None,
    ) -> dict | None:
        """
        Build the complete filter expression.
        Combines data source filter with metadata filter using $and.
        Args:
            data_source_id: Data source ID for filtering.
            metadata_filter: Additional metadata filter.
        Returns:
            Combined filter expression, or None if no filters.
        """
        filters = []
        # Add data source filter if provided
        if data_source_id:
            filters.append(
                {
                    "equals": {
                        "key": "x-amz-bedrock-kb-data-source-id",
                        "value": data_source_id,
                    }
                }
            )
        # Add metadata filter if provided
        if metadata_filter:
            # Convert S3 Vectors format to Bedrock KB format if needed
            converted = self._convert_filter_format(metadata_filter)
            if converted:
                filters.append(converted)
        if not filters:
            return None
        if len(filters) == 1:
            return filters[0]
        return {"andAll": filters}

    def _convert_filter_format(self, filter_expr: dict) -> dict | None:
        """
        Convert S3 Vectors filter format to Bedrock KB filter format.
        S3 Vectors: {"topic": {"$eq": "genealogy"}}
        Bedrock KB: {"equals": {"key": "topic", "value": "genealogy"}}
        Multiple field conditions in one dict are combined with andAll
        (implicit AND), and inconvertible conditions (e.g. $exists) are
        skipped instead of discarding the whole filter.
        Args:
            filter_expr: Filter in S3 Vectors format.
        Returns:
            Filter in Bedrock KB format, or None if nothing is convertible.
        """
        if not filter_expr:
            return None
        # Handle logical operators first — they take precedence over fields
        if "$and" in filter_expr:
            return self._convert_logical(filter_expr["$and"], "andAll")
        if "$or" in filter_expr:
            return self._convert_logical(filter_expr["$or"], "orAll")
        # Handle field conditions; multiple keys form an implicit AND
        conditions = []
        for key, value in filter_expr.items():
            if key.startswith("$"):
                continue  # Skip unrecognized operators at top level
            converted = self._convert_condition(key, value)
            if converted:
                conditions.append(converted)
        if not conditions:
            return None
        if len(conditions) == 1:
            return conditions[0]
        return {"andAll": conditions}

    def _convert_logical(self, conditions: list, bedrock_op: str) -> dict | None:
        """
        Convert the sub-filters of a $and/$or list into andAll/orAll.
        Args:
            conditions: List of sub-filters in S3 Vectors format.
            bedrock_op: Either "andAll" or "orAll".
        Returns:
            Combined filter, the single converted sub-filter, or None.
        """
        converted = [self._convert_filter_format(c) for c in conditions]
        converted = [c for c in converted if c]  # Remove None values
        if not converted:
            return None
        if len(converted) == 1:
            return converted[0]
        return {bedrock_op: converted}

    def _convert_condition(self, key: str, value) -> dict | None:
        """
        Convert a single field condition to Bedrock KB format.
        Args:
            key: Metadata field name.
            value: Either a literal (implicit $eq) or an operator dict.
        Returns:
            Bedrock KB condition, or None if no supported operator is found.
        """
        if not isinstance(value, dict):
            # Implicit $eq
            return {"equals": {"key": key, "value": value}}
        # Use the first supported operator in the dict
        for op, op_value in value.items():
            bedrock_key = self._OPERATOR_MAP.get(op)
            if bedrock_key:
                return {bedrock_key: {"key": key, "value": op_value}}
            if op == "$exists":
                # Bedrock KB doesn't have a direct exists filter — skip
                # just this condition rather than the whole filter
                logger.debug("Skipping $exists filter for key: %s", key)
        return None

    def _single_retrieve(
        self,
        query: str,
        knowledge_base_id: str,
        data_source_id: str | None,
        metadata_filter: dict | None,
        num_results: int,
    ) -> list[dict]:
        """
        Execute a single (non-parallel) retrieval.
        Used when multi-slice is disabled or no filter is provided.
        Args:
            query: Search query text.
            knowledge_base_id: Knowledge Base ID.
            data_source_id: Data source ID.
            metadata_filter: Metadata filter (usually None for single retrieval).
            num_results: Number of results.
        Returns:
            List of retrieval results. Empty on any client error
            (best-effort: failures are logged, not raised).
        """
        try:
            vector_config: dict[str, Any] = {
                "numberOfResults": num_results,
            }
            filter_expr = self._build_filter(data_source_id, metadata_filter)
            if filter_expr:
                vector_config["filter"] = filter_expr
            logger.info("[SINGLE RETRIEVE] kb_id=%s, config=%s", knowledge_base_id, vector_config)
            response = self.bedrock_agent.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery={"text": query},
                retrievalConfiguration={"vectorSearchConfiguration": vector_config},
            )
            results = response.get("retrievalResults", [])
            logger.info("Single retrieval returned %d results", len(results))
            for i, r in enumerate(results):
                uri = r.get("location", {}).get("s3Location", {}).get("uri", "N/A")
                score = r.get("score", "N/A")
                logger.info("[SINGLE RESULT] %d: score=%s, uri=%s", i, score, uri)
            return results
        except Exception as e:
            logger.error("Single retrieval failed: %s", e)
            return []