rag_core.py
import re
import os
import json
import logging
from typing import Dict, List, Tuple, Optional
from pathlib import Path

import numpy as np
from tqdm import tqdm
from openai import OpenAI, APIConnectionError, RateLimitError, APIError
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# --- Config ---
FAQ_DIR = os.getenv("FAQ_DIR", str(Path(__file__).parent / "faqs"))
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "200"))
TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "4"))
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))

# Initialize the OpenAI client (fail fast if key missing)
_API_KEY = os.getenv("OPENAI_API_KEY")
if not _API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set")
client = OpenAI(api_key=_API_KEY)

# Globals (preloaded at import)
_CHUNKS: List[str] = []
_SOURCES: List[str] = []
_CHUNK_EMBEDS: Optional[np.ndarray] = None  # shape: (N, d)


# ---------------- Chunking Strategies ----------------

def sniff_strategy(text_sample: str) -> str:
    """Detect the best chunking strategy based on content patterns."""
    # Markdown header strategy: lines starting with #, ##, or ###
    if re.search(r"^#{1,3}\s", text_sample, re.MULTILINE):
        return "markdown_splitter"
    # Q&A strategy: lines starting with "Q:" or "Question:"
    elif re.search(r"^(Q:|Question:)\s", text_sample, re.MULTILINE):
        return "regex_splitter"
    # Default fallback
    return "recursive_splitter"


def split_by_recursive(text: str, size: int = CHUNK_SIZE) -> List[str]:
    """Fallback: split text into fixed-size chunks (formerly _chunk_text)."""
    if not text or not text.strip():
        return []
    chunks = []
    for i in range(0, len(text), size):
        chunk = text[i : i + size]
        if chunk.strip():
            chunks.append(chunk)
    return chunks


def split_by_markdown(text: str) -> List[str]:
    """Split text by Markdown headers (#, ##, ###)."""
    if not text:
        return []
    # Split on a newline followed by #, ##, or ### and a space; the lookahead
    # keeps each header line attached to the chunk that follows it.
    chunks = re.split(r"(?=\n#{1,3}\s)", text)
    return [c.strip() for c in chunks if c.strip()]


def split_by_regex(text: str) -> List[str]:
    """Split text by Q&A patterns (Q: or Question:)."""
    if not text:
        return []
    # Split on a newline followed by "Q:" or "Question:"
    chunks = re.split(r"(?=\n(?:Q:|Question:)\s)", text)
    return [c.strip() for c in chunks if c.strip()]


# ---------------- Core utilities ----------------

def _load_and_chunk_faqs(faq_dir: str) -> Tuple[List[str], List[str]]:
    """Load *.md files and chunk each one using the sniffed strategy."""
    if not faq_dir:
        raise ValueError("faq_dir is required")
    faq_path = Path(faq_dir)
    if not faq_path.is_dir():
        logger.error(f"FAQ directory not found: {faq_dir}")
        raise ValueError("faq_dir must be a directory")

    chunks: List[str] = []
    sources: List[str] = []
    try:
        files = list(faq_path.glob("*.md"))
        if not files:
            logger.warning(f"No .md files found in {faq_dir}")
            return [], []
        for faq_file in files:
            try:
                with open(faq_file, "r") as f:
                    text = f.read()
                # Pick a splitter dynamically by peeking at the first 1k chars
                strategy = sniff_strategy(text[:1000])
                logger.info(f"File: {faq_file.name} | Strategy: {strategy}")
                if strategy == "markdown_splitter":
                    file_chunks = split_by_markdown(text)
                elif strategy == "regex_splitter":
                    file_chunks = split_by_regex(text)
                else:
                    file_chunks = split_by_recursive(text)
                chunks.extend(file_chunks)
                sources.extend([faq_file.name] * len(file_chunks))
                logger.debug(f"Loaded {len(file_chunks)} chunks from {faq_file.name}")
            except Exception as e:
                logger.error(f"Failed to read file {faq_file}: {e}")
    except Exception as e:
        logger.error(f"Error accessing FAQ directory: {e}")
        return [], []
    return chunks, sources


def _embed_texts(texts: List[str]) -> np.ndarray:
    """Create embeddings for texts and return a (N, d) float32 numpy array."""
    if not texts:
        return np.array([])
    try:
        # Note: in production, batching is recommended for large lists
        response = client.embeddings.create(input=texts, model=EMBED_MODEL)
        return np.array([data.embedding for data in response.data], dtype=np.float32)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return np.array([])


def _embed_query(q: str) -> np.ndarray:
    """Create an embedding for the query and return a (d,) float32 vector."""
    if not q:
        raise ValueError("q is required")
    try:
        response = client.embeddings.create(input=q, model=EMBED_MODEL)
        return np.array(response.data[0].embedding, dtype=np.float32)
    except Exception as e:
        logger.error(f"Query embedding generation failed: {e}")
        raise


def _generate_answer(context: str, question: str) -> str:
    """Call the chat model to answer using only the context and cite filenames."""
    if not context or not question:
        raise ValueError("context and question are required")
    try:
        response = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a helpful assistant. Answer using only the provided "
                        "context and cite the source filenames you used."
                    ),
                },
                {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
            ],
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"LLM generation failed: {e}")
        return "Sorry, I encountered an error while generating the answer."


# ---------------- Public API ----------------

def ask_faq_core(question: str, top_k: int = TOP_K_DEFAULT) -> Dict[str, object]:
    q = (question or "").strip()
    if not q:
        raise ValueError("question is required")
    if top_k <= 0:
        top_k = TOP_K_DEFAULT

    # If the corpus is not loaded yet, return a safe placeholder so wrappers still run.
    if _CHUNK_EMBEDS is None or len(_CHUNKS) == 0:
        logger.warning("RAG core not initialized or empty corpus.")
        return {
            "answer": "System is not ready or has no knowledge.",
            "sources": [],
        }

    try:
        q_emb = _embed_query(q)
        sims = _CHUNK_EMBEDS @ q_emb  # cosine if rows are normalized
        top_idx = np.argsort(sims)[-top_k:][::-1]

        # Filter by threshold
        valid_idx = [i for i in top_idx if sims[i] >= SIMILARITY_THRESHOLD]
        logger.info(
            f"Query: '{q}' | Top similarity: {sims[top_idx[0]]:.4f} | Matches found: {len(valid_idx)}"
        )

        if not valid_idx:
            return {
                "answer": "I don't have enough information to answer that question.",
                "sources": [],
            }

        top_files = [_SOURCES[i] for i in valid_idx]
        context_parts = [f"From {_SOURCES[i]}:\n{_CHUNKS[i]}" for i in valid_idx]
        context = "\n\n".join(context_parts)

        answer = _generate_answer(context, q)

        # Report at most two distinct source files
        distinct_sources = sorted(set(top_files))
        sources_out = distinct_sources[:2]

        return {"answer": answer, "sources": sources_out}
    except Exception as e:
        logger.error(f"Error in ask_faq_core: {e}")
        return {
            "answer": f"An error occurred: {str(e)}",
            "sources": [],
        }


# ---------------- Module preload ----------------

def _preload() -> None:
    """Load and chunk FAQs, compute embeddings, L2-normalize rows, assign globals."""
    global _CHUNKS, _SOURCES, _CHUNK_EMBEDS
    logger.info("Preloading FAQ corpus...")

    # 1. Load chunks
    chunks, sources = _load_and_chunk_faqs(FAQ_DIR)
    if not chunks:
        logger.warning("No FAQ chunks found.")
        # Ensure globals are at least empty to avoid stale state if re-run
        _CHUNKS = []
        _SOURCES = []
        _CHUNK_EMBEDS = None
        return

    # 2. Embed chunks
    logger.info(f"Embedding {len(chunks)} chunks...")
    embeds = _embed_texts(chunks)
    if embeds.size == 0:
        logger.error("Failed to generate embeddings for chunks.")
        return

    # 3. Normalize for cosine similarity:
    #    (x . y) / (|x| |y|) == (x/|x|) . (y/|y|)
    norms = np.linalg.norm(embeds, axis=1, keepdims=True)
    # Avoid zero-division if any embedding is all-zeros (unlikely but safe)
    norms[norms < 1e-9] = 1.0
    embeds_norm = embeds / norms

    # 4. Assign globals
    _CHUNKS = chunks
    _SOURCES = sources
    _CHUNK_EMBEDS = embeds_norm
    logger.info(f"Preloaded {len(_CHUNKS)} chunks from {len(set(_SOURCES))} files.")


# Run preload at import time
_preload()


# ---------------- Optional CLI runner ----------------

def main_cli():
    try:
        q = input("Enter your question: ")
        print(json.dumps(ask_faq_core(q), indent=2))
    except KeyboardInterrupt:
        print("\nExiting.")
    except Exception as e:
        logger.error(f"CLI Error: {e}")


if __name__ == "__main__":
    main_cli()
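For reference, the simplest way to exercise this module from a test script or an MCP tool wrapper is to import ask_faq_core directly. The snippet below is a hypothetical usage sketch, not part of the repository: the file name example_usage.py and the question text are illustrative, and it assumes OPENAI_API_KEY is set and the faqs/ directory contains at least one .md file, since _preload() runs at import time.

# example_usage.py -- hypothetical consumer of rag_core (illustrative only)
import json

from rag_core import ask_faq_core  # importing triggers _preload()

if __name__ == "__main__":
    # Ask a question; the answer is generated from the preloaded FAQ chunks
    result = ask_faq_core("How do I reset my password?", top_k=4)
    print(result["answer"])
    print("Sources:", result["sources"])
    # Full payload as returned to callers: a dict with "answer" and "sources"
    print(json.dumps(result, indent=2))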
