from __future__ import annotations
import json
import logging
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias
import faiss
import numpy as np
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
from google.adk.tools.tool_context import ToolContext
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
EmbedderFn: TypeAlias = Callable[[str], list[float]]
class LocalFAISSRetrieval(BaseRetrievalTool):
"""
A robust tool for retrieving documents from a specific FAISS index using vector similarity.
Each instance handles a single index directory.
"""
def __init__(
self,
*,
name: str,
description: str,
index_dir: str,
model_name: str = "jhgan/ko-sroberta-multitask",
top_k: int = 10,
key_text: str = "text",
embedder: Optional[EmbedderFn] = None,
):
super().__init__(name=name, description=description)
self.index_dir = Path(index_dir)
self.index_path = self.index_dir / "rag_index.faiss"
self.documents_path = self.index_dir / "rag_documents.pkl"
self.model_name = model_name
self.top_k = top_k
self.key_text = key_text
# Use provided embedder or create sentence transformer
if embedder:
self.embedder = embedder
self.model = None
else:
self.model = SentenceTransformer(model_name)
self.embedder = self._default_embedder
# Load index
self.index, self.documents = self._load_index()
def _default_embedder(self, text: str) -> list[float]:
"""Default embedder using sentence transformer."""
return self.model.encode([text])[0].tolist()
def _load_index(self) -> tuple[faiss.IndexFlatIP, list[dict]]:
"""Load existing FAISS index and documents from the specified directory."""
if not self.index_path.exists() or not self.documents_path.exists():
raise FileNotFoundError(
f"Index files not found in {self.index_dir}. "
f"Please run build_index.py first to generate the indices."
)
try:
logger.info(f"Loading FAISS index from {self.index_dir}")
index = faiss.read_index(str(self.index_path))
with open(self.documents_path, "rb") as f:
documents = pickle.load(f)
logger.info(f"Loaded index with {index.ntotal} vectors and {len(documents)} documents")
return index, documents
except Exception as e:
raise RuntimeError(f"Failed to load index from {self.index_dir}: {e}") from e
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: Optional[ToolContext] = None,
) -> list[str]:
query = args.get("query")
if not isinstance(query, str) or not query.strip():
logger.warning("Invalid or missing query")
raise ValueError("Query must be a non-empty string.")
logger.info(f"Running local FAISS retrieval for query: {query!r}")
try:
vector = self.embedder(query)
except Exception as e:
logger.error("Failed to generate embedding vector", exc_info=True)
raise RuntimeError("Embedder failed to generate a valid vector.") from e
if not isinstance(vector, list) or not all(
isinstance(x, float) for x in vector
):
raise TypeError("Embedder must return a list of floats.")
logger.debug(f"Embedding vector generated (dim={len(vector)})")
try:
# Search in FAISS index
query_vector = np.array([vector], dtype=np.float32)
# Normalize query vector for cosine similarity
faiss.normalize_L2(query_vector)
distances, indices = self.index.search(query_vector, self.top_k)
# Get relevant documents
texts = []
for i, idx in enumerate(indices[0]):
if idx < len(self.documents):
doc = self.documents[idx]
if isinstance(doc.get(self.key_text), str):
texts.append(doc[self.key_text])
else:
# Fallback to 'text' key if key_text is different
texts.append(doc.get("text", ""))
except Exception as e:
logger.error("FAISS query failed", exc_info=True)
raise RuntimeError("Failed to query FAISS index.") from e
logger.info(f"Retrieved {len(texts)} results from local FAISS index.")
return texts
def create_all_retrieval_tools(
base_index_path: str = "knowledge/index",
model_name: str = "jhgan/ko-sroberta-multitask",
top_k: int = 10,
embedder: Optional[EmbedderFn] = None,
) -> Dict[str, LocalFAISSRetrieval]:
"""
모든 인덱스 디렉토리를 검색하여 각각에 대한 LocalFAISSRetrieval 인스턴스를 생성합니다.
Args:
base_index_path: 인덱스들이 저장된 기본 경로
model_name: 사용할 임베딩 모델명
top_k: 검색할 최대 결과 수
embedder: 사용자 정의 임베더 함수 (선택사항)
Returns:
인덱스명을 키로 하고 LocalFAISSRetrieval 인스턴스를 값으로 하는 딕셔너리
"""
base_path = Path(base_index_path)
if not base_path.exists():
raise FileNotFoundError(f"Base index path does not exist: {base_index_path}")
retrieval_tools = {}
# 인덱스 디렉토리들 찾기
for index_dir in base_path.iterdir():
if index_dir.is_dir():
# 필수 파일들이 존재하는지 확인
index_file = index_dir / "rag_index.faiss"
documents_file = index_dir / "rag_documents.pkl"
if index_file.exists() and documents_file.exists():
try:
# 인덱스 이름을 디렉토리명으로 사용
index_name = index_dir.name
# 설명 생성 (인덱스 유형에 따라)
if "terms" in index_name.lower():
description = f"무역 및 물류 용어집 검색 도구 ({index_name})"
elif "code" in index_name.lower() or "package" in index_name.lower():
description = f"패키지 단위 코드 검색 도구 ({index_name})"
else:
description = f"문서 검색 도구 ({index_name})"
# LocalFAISSRetrieval 인스턴스 생성
tool = LocalFAISSRetrieval(
name=f"rag_tool_{index_name}",
description=description,
index_dir=str(index_dir),
model_name=model_name,
top_k=top_k,
embedder=embedder,
)
retrieval_tools[index_name] = tool
logger.info(f"✅ Created retrieval tool for index: {index_name}")
except Exception as e:
logger.error(f"❌ Failed to create tool for {index_dir.name}: {e}")
else:
logger.warning(f"⚠️ Skipping {index_dir.name}: missing required files")
if not retrieval_tools:
raise ValueError(f"No valid indices found in {base_index_path}")
logger.info(f"🎉 Created {len(retrieval_tools)} retrieval tools: {list(retrieval_tools.keys())}")
return retrieval_tools
def get_retrieval_tool(
index_name: str,
base_index_path: str = "ecom_agent/knowledge/index",
model_name: str = "jhgan/ko-sroberta-multitask",
top_k: int = 10,
embedder: Optional[EmbedderFn] = None,
) -> LocalFAISSRetrieval:
"""
특정 인덱스에 대한 LocalFAISSRetrieval 인스턴스를 생성합니다.
agent.py에서도 사용하는 함수입니다.
Args:
index_name: 사용할 인덱스명 (예: "terms", "code_package")
base_index_path: 인덱스들이 저장된 기본 경로
model_name: 사용할 임베딩 모델명
top_k: 검색할 최대 결과 수
embedder: 사용자 정의 임베더 함수 (선택사항)
Returns:
LocalFAISSRetrieval 인스턴스
"""
index_dir = Path(base_index_path) / index_name
if not index_dir.exists():
available_indices = [d.name for d in Path(base_index_path).iterdir() if d.is_dir()]
raise ValueError(f"Index '{index_name}' not found. Available indices: {available_indices}")
# 설명 생성
if "terms" in index_name.lower():
description = f"무역 및 물류 용어집 검색 도구 ({index_name})"
elif "code" in index_name.lower() or "package" in index_name.lower():
description = f"패키지 단위 코드 검색 도구 ({index_name})"
else:
description = f"문서 검색 도구 ({index_name})"
return LocalFAISSRetrieval(
name=f"rag_tool_{index_name}",
description=description,
index_dir=str(index_dir),
model_name=model_name,
top_k=top_k,
embedder=embedder,
)
if __name__ == "__main__":
import asyncio
async def test_single_index():
"""단일 인덱스 테스트"""
print("🚀 단일 인덱스 테스트 시작")
# 용어집 인덱스 테스트
terms_tool = get_retrieval_tool("terms", top_k=3)
query = "FOB 조건이란?"
print(f"📝 용어집 쿼리: {query}")
results = await terms_tool.run_async(args={"query": query})
print(f"📊 용어집 결과 개수: {len(results)}")
for i, result in enumerate(results, 1):
preview = result[:100] + "..." if len(result) > 100 else result
print(f" {i}. {preview}")
print("\n" + "="*50 + "\n")
# 패키지 코드 인덱스 테스트
code_tool = get_retrieval_tool("code_package", top_k=3)
query2 = "킬로그램 단위"
print(f"📝 패키지 코드 쿼리: {query2}")
results2 = await code_tool.run_async(args={"query": query2})
print(f"📊 패키지 코드 결과 개수: {len(results2)}")
for i, result in enumerate(results2, 1):
preview = result[:100] + "..." if len(result) > 100 else result
print(f" {i}. {preview}")
print("✅ 단일 인덱스 테스트 완료!")
async def test_all_indices():
"""모든 인덱스 테스트"""
print("\n🚀 모든 인덱스 테스트 시작")
# 모든 검색 도구 생성
all_tools = create_all_retrieval_tools(top_k=2)
print(f"📚 생성된 도구들: {list(all_tools.keys())}")
# 각 도구에 대해 테스트 쿼리 실행
test_queries = {
"terms": "CIF 조건",
"code_package": "박스 포장"
}
for index_name, tool in all_tools.items():
query = test_queries.get(index_name, "테스트")
print(f"\n📝 {index_name} 인덱스 쿼리: {query}")
try:
results = await tool.run_async(args={"query": query})
print(f"📊 결과 개수: {len(results)}")
for i, result in enumerate(results, 1):
preview = result[:80] + "..." if len(result) > 80 else result
print(f" {i}. {preview}")
except Exception as e:
print(f"❌ 오류: {e}")
print("✅ 모든 인덱스 테스트 완료!")
async def main():
try:
await test_single_index()
await test_all_indices()
except Exception as e:
print(f"❌ 테스트 실패: {e}")
asyncio.run(main())