# gateway/app.py
from __future__ import annotations

import os
from pathlib import Path
from typing import List, Dict, Tuple, Any

import yaml
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from openai import OpenAI

# ---------------- Env & clients ----------------
load_dotenv("env/.env")

# Qdrant (embedded local if QDRANT_LOCAL_PATH is set; otherwise host/port)
_LOCAL_PATH = os.getenv("QDRANT_LOCAL_PATH")
if _LOCAL_PATH:
    qc = QdrantClient(path=_LOCAL_PATH)
else:
    qc = QdrantClient(
        host=os.getenv("QDRANT_HOST", "localhost"),
        port=int(os.getenv("QDRANT_PORT", "6333")),
    )

# OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), timeout=30.0)
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-large")
CHAT_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

# RBAC policy (reuse same file as MCP server)
PREFIX = os.getenv("QDRANT_COLLECTION_PREFIX", "finrag")
RBAC_PATH = Path(__file__).resolve().parents[1] / "mcp_server" / "policies" / "rbac.yaml"
RBAC = yaml.safe_load(open(RBAC_PATH, "r", encoding="utf-8")) if RBAC_PATH.exists() else {"roles": {}}


def allowed_domains(role: str) -> List[str]:
    roles = (RBAC or {}).get("roles", {})
    entry = roles.get(role) or roles.get(role.upper()) or {}
    allow = entry.get("allow") or []
    return list(allow) if allow else ["general"]


def embed(text: str) -> List[float]:
    return client.embeddings.create(model=EMBED_MODEL, input=[text]).data[0].embedding


# ---------------- Qdrant helpers ----------------
def _query_points(collection: str, vector: List[float], top_k: int):
    """
    Use legacy `search` for maximum compatibility across qdrant-client versions,
    and to ensure we get objects with payload/score.
    """
    return qc.search(
        collection_name=collection,
        query_vector=vector,
        limit=top_k,
        with_payload=True,
        with_vectors=False,
    )


def _normalize_results(res: Any) -> List[Any]:
    """
    Normalize qdrant results into a flat list of items that (ideally) have .payload/.score.
    Handles QueryResponse, list, tuple, or nested structures.
    """
    if hasattr(res, "points"):  # e.g., QueryResponse
        items = res.points
    else:
        items = list(res or [])
    flat: List[Any] = []
    for it in items:
        if isinstance(it, (list, tuple)):
            flat.extend(list(it))
        else:
            flat.append(it)
    return flat


# ---------------- Retrieval ----------------
def retrieve(
    role: str,
    query: str,
    top_k: int = 50,
    include_general: bool = True,
) -> List[Dict]:
    """
    RBAC-aware retrieval across allowed collections.
    Returns dicts with: score, domain, collection, source, title, doc_type, tags, chunk_id, snippet
    On failure for a collection, appends {'error': ...} so the caller can surface issues.
    """
    domains = allowed_domains(role)

    # Optional filter to hide `general` unless asked
    if not include_general:
        domains = [d for d in domains if d != "general"]

    # Safety: never end up with zero domains
    if not domains:
        domains = allowed_domains(role)

    vector = embed(query)
    hits: List[Dict] = []
    role_domain = role.lower().strip()

    for d in domains:
        col = f"{PREFIX}_{d}"
        try:
            res = _query_points(col, vector, top_k)
            for p in _normalize_results(res):
                # tolerate tuple/dict shapes across client versions
                payload = getattr(p, "payload", None)
                score = float(getattr(p, "score", 0.0))
                if payload is None and isinstance(p, dict):
                    payload = p.get("payload", {})
                    score = float(p.get("score", 0.0))
                if payload is None and isinstance(p, (list, tuple)) and p:
                    maybe = p[0]
                    payload = getattr(maybe, "payload", None) or {}
                    score = float(getattr(maybe, "score", 0.0) or score)

                pl = payload or {}
                dom = (pl.get("domain") or d).lower()
                # Domain prior: boost same-domain results slightly so FINANCE prefers finance docs, etc.
                prior = 1.15 if dom == role_domain else 1.0

                hits.append(
                    {
                        "score": score * float(pl.get("boost", 1.0)) * prior,
                        "domain": pl.get("domain", d),
                        "collection": col,
                        "source": pl.get("source"),
                        "title": pl.get("title"),
                        "doc_type": pl.get("doc_type"),
                        "tags": pl.get("tags", []),
                        "chunk_id": pl.get("chunk_id"),
                        "snippet": (pl.get("text") or "")[:1000],
                    }
                )
        except Exception as e:
            hits.append({"error": f"{col}: {e}"})

    hits.sort(key=lambda x: x.get("score", 0.0), reverse=True)
    return hits[:top_k]


# ---------------- API models ----------------
class ChatRequest(BaseModel):
    query: str
    role: str = Field(default="EMPLOYEE")
    top_k: int = Field(default=50, ge=1, le=200)   # how many results to return & display
    ctx_k: int = Field(default=20, ge=1, le=100)   # how many chunks to include in the LLM context
    include_general: bool = Field(default=True)    # include `general` alongside role-specific domains


class SourceItem(BaseModel):
    score: float
    domain: str
    source: str | None = None
    snippet: str | None = None
    title: str | None = None
    doc_type: str | None = None
    tags: List[str] = []
    chunk_id: int | None = None


class ChatResponse(BaseModel):
    answer: str
    sources: List[SourceItem] = []


# ---------------- FastAPI app ----------------
app = FastAPI(title="FinRAG Gateway", version="1.3")

# Allow Streamlit (localhost) to call this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8501", "http://127.0.0.1:8501"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    return {"status": "ok", "endpoints": ["/chat", "/health", "/docs"]}


@app.get("/health")
def health():
    try:
        cols = [c.name for c in qc.get_collections().collections]
    except Exception as e:
        cols = [f"error: {e}"]
    return {
        "ok": True,
        "collections": cols,
        "embed_model": EMBED_MODEL,
        "chat_model": CHAT_MODEL,
        "local_qdrant": bool(_LOCAL_PATH),
    }


def _clean_hits(raw_hits: List[Dict]) -> Tuple[List[Dict], List[str]]:
    """Separate out any {'error': ...} rows so response building never crashes."""
    hits, errors = [], []
    for h in raw_hits:
        if isinstance(h, dict) and "error" in h:
            errors.append(str(h["error"]))
        else:
            hits.append(h)
    return hits, errors


@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    # 1) Retrieval
    try:
        raw_hits = retrieve(
            req.role,
            req.query,
            top_k=req.top_k,
            include_general=req.include_general,
        )
    except Exception as e:
        return ChatResponse(answer=f"Sorry, retrieval failed: {e}", sources=[])

    hits, errors = _clean_hits(raw_hits)
    if not hits and errors:
        msg = "Sorry, I couldn't retrieve any context:\n- " + "\n- ".join(errors[:10])
        return ChatResponse(answer=msg, sources=[])

    # 2) Build context from the first ctx_k hits (but return ALL hits for display)
    ctx_hits = hits[: min(req.ctx_k, len(hits))]
    context = "\n\n".join(
        f"[{i+1}] ({h.get('domain','?')}) {h.get('snippet','')}\nSOURCE: {h.get('source','')}"
        for i, h in enumerate(ctx_hits)
    )

    system = (
        "You are an enterprise RAG assistant. Use ONLY the provided context to answer. "
        "If the context is insufficient, say you don't know. Always cite sources as [1], [2], etc."
    )
    user = f"Question: {req.query}\n\nContext:\n{context}"

    # 3) LLM completion (graceful fallback)
    try:
        completion = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=0.2,
        )
        answer = completion.choices[0].message.content
    except Exception as e:
        answer = (
            f"Model call failed: {e}\n\nTop sources:\n"
            + "\n".join(f"[{i+1}] {h.get('source','')}" for i, h in enumerate(ctx_hits))
        )

    sources = [
        SourceItem(
            score=float(h.get("score", 0.0)),
            domain=h.get("domain", "unknown"),
            source=h.get("source"),
            snippet=h.get("snippet"),
            title=h.get("title"),
            doc_type=h.get("doc_type"),
            tags=h.get("tags", []),
            chunk_id=h.get("chunk_id"),
        )
        for h in hits  # return ALL hits (no dedupe)
    ]
    return ChatResponse(answer=answer, sources=sources)
