import re
import os
import json
import logging
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# --- Config ---
FAQ_DIR = os.getenv("FAQ_DIR", str(Path(__file__).parent / "faqs"))
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "200"))
TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "4"))
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
# Initialize the OpenAI client (fail fast if key missing)
_API_KEY = os.getenv("OPENAI_API_KEY")
if not _API_KEY:
raise RuntimeError("OPENAI_API_KEY is not set")
client = OpenAI(api_key=_API_KEY)
# Globals (preloaded at import)
_CHUNKS: List[str] = []
_SOURCES: List[str] = []
_CHUNK_EMBEDS: Optional[np.ndarray] = None # shape: (N, d)
# ---------------- Chunking Strategies ----------------
def sniff_strategy(text_sample: str) -> str:
"""Detect the best chunking strategy based on content patterns."""
# Markdown Header Strategy: Looks for lines starting with #, ##, or ###
if re.search(r"^#{1,3}\s", text_sample, re.MULTILINE):
return "markdown_splitter"
# Q&A Strategy: Looks for specific "Q:" or "Question:" patterns
elif re.search(r"^(Q:|Question:)\s", text_sample, re.MULTILINE):
return "regex_splitter"
# Default Fallback
return "recursive_splitter"
def split_by_recursive(text: str, size: int = CHUNK_SIZE) -> List[str]:
"""Fallback: Split text into fixed-size chunks (formerly _chunk_text)."""
if not text or not text.strip():
return []
chunks = []
for i in range(0, len(text), size):
chunk = text[i : i + size]
if chunk.strip():
chunks.append(chunk)
return chunks
def split_by_markdown(text: str) -> List[str]:
"""Split text by Markdown headers (#, ##, ###)."""
if not text:
return []
    # Split immediately before each header line (#, ##, or ### followed by a space).
    # The zero-width lookahead keeps each header attached to the chunk that follows it.
    chunks = re.split(r"(?=^#{1,3}\s)", text, flags=re.MULTILINE)
return [c.strip() for c in chunks if c.strip()]
def split_by_regex(text: str) -> List[str]:
"""Split text by Q&A patterns (Q: or Question:)."""
if not text:
return []
    # Split immediately before each line that starts with "Q:" or "Question:".
    chunks = re.split(r"(?=^(?:Q:|Question:)\s)", text, flags=re.MULTILINE)
return [c.strip() for c in chunks if c.strip()]
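# Quick sketch of the Q&A splitter on a toy document (strings are illustrative):
#   split_by_regex("Q: What is RAG?\nA: Retrieval-augmented generation.\nQ: Is it free?\nA: Yes.")
#   -> ["Q: What is RAG?\nA: Retrieval-augmented generation.", "Q: Is it free?\nA: Yes."]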
# ---------------- Core utilities ----------------
def _load_and_chunk_faqs(faq_dir: str) -> Tuple[List[str], List[str]]:
"""Load *.md files, chunk each using sniffed strategy."""
if not faq_dir:
raise ValueError("faq_dir is required")
faq_path = Path(faq_dir)
if not faq_path.is_dir():
logger.error(f"FAQ directory not found: {faq_dir}")
raise ValueError("faq_dir must be a directory")
chunks: List[str] = []
sources: List[str] = []
try:
        files = sorted(faq_path.glob("*.md"))  # deterministic chunk/source ordering across runs
if not files:
logger.warning(f"No .md files found in {faq_dir}")
return [], []
for faq_file in files:
try:
with open(faq_file, "r") as f:
text = f.read()
                # Choose a chunking strategy from a sample of the file's content.
                strategy = sniff_strategy(text[:1000])  # peek at the first 1,000 characters
logger.info(f"File: {faq_file.name} | Strategy: {strategy}")
if strategy == "markdown_splitter":
file_chunks = split_by_markdown(text)
elif strategy == "regex_splitter":
file_chunks = split_by_regex(text)
else:
file_chunks = split_by_recursive(text)
chunks.extend(file_chunks)
sources.extend([faq_file.name] * len(file_chunks))
logger.debug(f"Loaded {len(file_chunks)} chunks from {faq_file.name}")
except Exception as e:
logger.error(f"Failed to read file {faq_file}: {e}")
except Exception as e:
logger.error(f"Error accessing FAQ directory: {e}")
return [], []
return chunks, sources
def _embed_texts(texts: List[str]) -> np.ndarray:
"""Create embeddings for texts and return a (N, d) float32 numpy array."""
if not texts:
return np.array([])
try:
# Note: In production, batching is recommended for large lists
response = client.embeddings.create(input=texts, model=EMBED_MODEL)
return np.array([data.embedding for data in response.data])
except Exception as e:
logger.error(f"Embedding generation failed: {e}")
return np.array([])
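# A minimal batching sketch (not wired into the pipeline; the helper name and batch size
# are assumptions chosen for illustration). The embeddings endpoint accepts a list of
# inputs per call, so a large corpus can be embedded in fixed-size batches.
def _embed_texts_batched(texts: List[str], batch_size: int = 256) -> np.ndarray:
    """Embed texts in batches and return a (N, d) float32 numpy array."""
    parts: List[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        response = client.embeddings.create(input=batch, model=EMBED_MODEL)
        parts.append(np.array([d.embedding for d in response.data], dtype=np.float32))
    return np.vstack(parts) if parts else np.array([], dtype=np.float32)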
def _embed_query(q: str) -> np.ndarray:
"""Create an embedding for the query and return a (d,) float32 vector."""
if not q:
raise ValueError("q is required")
try:
response = client.embeddings.create(input=q, model=EMBED_MODEL)
        return np.array(response.data[0].embedding, dtype=np.float32)
except Exception as e:
logger.error(f"Query embedding generation failed: {e}")
raise
def _generate_answer(context: str, question: str) -> str:
"""Call the chat model to answer using only context and cite filenames."""
if not context or not question:
raise ValueError("context and question are required")
try:
response = client.chat.completions.create(
model=LLM_MODEL,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
],
)
        return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"LLM generation failed: {e}")
return "Sorry, I encountered an error while generating the answer."
# ---------------- Public API ----------------
def ask_faq_core(question: str, top_k: int = TOP_K_DEFAULT) -> Dict[str, object]:
q = (question or "").strip()
if not q:
raise ValueError("question is required")
if top_k <= 0:
top_k = TOP_K_DEFAULT
    # If the corpus was never loaded (or is empty), return a safe placeholder answer.
if _CHUNK_EMBEDS is None or len(_CHUNKS) == 0:
logger.warning("RAG core not initialized or empty corpus.")
return {
"answer": "System is not ready or has no knowledge.",
"sources": [],
}
    try:
        q_emb = _embed_query(q)
        # Normalize the query so the dot product with the L2-normalized chunk rows is a true cosine similarity.
        q_norm = np.linalg.norm(q_emb)
        q_emb = q_emb / q_norm if q_norm > 1e-9 else q_emb
        sims = _CHUNK_EMBEDS @ q_emb
        top_idx = np.argsort(sims)[-top_k:][::-1]
# Filter by threshold
valid_idx = [i for i in top_idx if sims[i] >= SIMILARITY_THRESHOLD]
logger.info(f"Query: '{q}' | Top similarity: {sims[top_idx[0]]:.4f} | Matches found: {len(valid_idx)}")
if not valid_idx:
return {
"answer": "I don't have enough information to answer that question.",
"sources": []
}
top_files = [_SOURCES[i] for i in valid_idx]
context_parts = [f"From {_SOURCES[i]}:\n{_CHUNKS[i]}" for i in valid_idx]
context = "\n\n".join(context_parts)
answer = _generate_answer(context, q)
        distinct_sources = sorted(set(top_files))
        # Report at most two distinct source files alongside the answer.
        sources_out = distinct_sources[:2]
return {"answer": answer, "sources": sources_out}
except Exception as e:
logger.error(f"Error in ask_faq_core: {e}")
return {
"answer": f"An error occurred: {str(e)}",
"sources": []
}
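# Hedged usage sketch of the public API (the question text and filename are illustrative,
# not part of any shipped corpus):
#   result = ask_faq_core("How do I reset my password?", top_k=3)
#   result["answer"]  -> model-generated text grounded in the retrieved chunks
#   result["sources"] -> e.g. ["passwords.md"]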
# ---------------- Module preload ----------------
def _preload() -> None:
"""Load and chunk FAQs, compute embeddings, L2-normalize rows, assign globals."""
global _CHUNKS, _SOURCES, _CHUNK_EMBEDS
logger.info("Preloading FAQ corpus...")
    # 1. Load and chunk the FAQ files.
chunks, sources = _load_and_chunk_faqs(FAQ_DIR)
if not chunks:
logger.warning("No FAQ chunks found.")
# Ensure globals are at least empty lists/None to avoid stale state if re-run
_CHUNKS = []
_SOURCES = []
_CHUNK_EMBEDS = None
return
# 2. Embed chunks
logger.info(f"Embedding {len(chunks)} chunks...")
embeds = _embed_texts(chunks)
if embeds.size == 0:
logger.error("Failed to generate embeddings for chunks.")
return
# 3. Normalize for cosine similarity
# (x . y) / (|x| |y|) == (x/|x|) . (y/|y|)
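    # Worked example of that identity (illustrative numbers):
    #   x = [3, 4], y = [1, 0]  ->  cos = (3*1 + 4*0) / (5 * 1) = 0.6
    #   x/|x| = [0.6, 0.8], y/|y| = [1, 0]  ->  dot product = 0.6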
norms = np.linalg.norm(embeds, axis=1, keepdims=True)
# Avoid zero-division if any embedding is all-zeros (unlikely but safe)
norms[norms < 1e-9] = 1.0
embeds_norm = embeds / norms
# 4. Assign globals
_CHUNKS = chunks
_SOURCES = sources
_CHUNK_EMBEDS = embeds_norm
logger.info(f"Preloaded {len(_CHUNKS)} chunks from {len(set(_SOURCES))} files.")
# Run preload at import time so the module is ready to answer queries immediately.
_preload()
# ---------------- Optional CLI runner ----------------
def main_cli():
try:
q = input("Enter your question: ")
print(json.dumps(ask_faq_core(q), indent=2))
except KeyboardInterrupt:
print("\nExiting.")
except Exception as e:
logger.error(f"CLI Error: {e}")
if __name__ == "__main__":
main_cli()
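# Example invocation (illustrative; the module path and env values are assumptions):
#   FAQ_DIR=./faqs SIMILARITY_THRESHOLD=0.25 python path/to/this_module.py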