"""
Complete RAG (Retrieval-Augmented Generation) system for financial analysis.
Architecture:
1. Data fetching (RAGChain.fetch_and_prepare): Retrieves financial data from yfinance
2. TextChunker: Splits data into semantic chunks
3. SimpleEmbedder: Converts text to vectors (hash-based placeholder; OpenAI planned)
4. VectorStore: In-memory semantic search
5. BM25Retriever: Keyword-based retrieval
6. HybridRetriever: Combines semantic + keyword
7. QueryProcessor: Processes user queries
8. RAGChain: Orchestrates entire pipeline
9. Metrics integration: Compatible with MetricsCollector
Designed to be comparable to FinRobot for fair evaluation; see the usage sketch below.
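Usage sketch (illustrative; the ticker, dates, and query are example values):
    chain = RAGChain(config={"chunk_size": 500, "overlap": 100})
    snapshot = chain.run(
        ticker="AAPL",
        start_date="2024-01-01",
        end_date="2024-03-31",
        query="What is the recent price trend for AAPL?",
        task_name="trend_analysis",
    )
    # snapshot is a MetricSnapshot holding the generated response plus cost/latency metrics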
"""
import json
import time
import hashlib
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from pathlib import Path
from datetime import datetime, timedelta
import os
import numpy as np
from finrobot.logging import get_logger
from finrobot.experiments.metrics_collector import MetricSnapshot
# YFinanceUtils is imported lazily inside RAGChain.fetch_and_prepare to avoid
# pulling the finnhub dependency into tests; this module-level name stays None.
YFinanceUtils = None
logger = get_logger(__name__)
@dataclass
class Document:
"""Represents a chunk of text with metadata."""
content: str
metadata: Dict[str, Any] = field(default_factory=dict)
embedding: Optional[np.ndarray] = None
def to_dict(self) -> Dict[str, Any]:
return {
"content": self.content,
"metadata": self.metadata,
"embedding": self.embedding.tolist() if self.embedding is not None else None,
}
class TextChunker:
"""
Split financial data into semantic chunks.
Strategies:
- Fixed-size chunks with character overlap (chunk_text)
- Semantic boundaries (paragraphs, sections; planned, not yet implemented)
- Financial-data-specific chunks for prices, news, and fundamentals (chunk_stock_data)
"""
def __init__(self, chunk_size: int = 500, overlap: int = 100):
"""
Initialize chunker.
Args:
chunk_size: Characters per chunk
overlap: Overlap between chunks for context
"""
self.chunk_size = chunk_size
self.overlap = overlap
logger.info(f"TextChunker initialized: size={chunk_size}, overlap={overlap}")
def chunk_text(self, text: str) -> List[str]:
"""
Split text into chunks.
Args:
text: Raw text to chunk
Returns:
List of text chunks
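Example (with the defaults chunk_size=500 and overlap=100, a 1200-character
string yields three overlapping chunks):
    chunks = TextChunker().chunk_text("x" * 1200)
    # -> chunks cover [0:500], [400:900], [800:1200]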
"""
if not text:
return []
chunks = []
start = 0
while start < len(text):
    end = min(start + self.chunk_size, len(text))
    chunks.append(text[start:end])
    # Stop once the end of the text is reached (prevents an infinite loop)
    if end == len(text):
        break
    # Step back by the overlap, but always make forward progress
    start = max(end - self.overlap, start + 1)
logger.debug(f"Chunked text into {len(chunks)} pieces")
return chunks
def chunk_stock_data(self, ticker: str, data_dict: Dict[str, Any]) -> List[str]:
"""
Intelligently chunk stock data.
Args:
ticker: Stock symbol
data_dict: Dict with price_history, news, fundamentals, etc.
Returns:
List of semantic chunks
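Example data_dict (only the keys below are inspected; values are illustrative):
    {
        "price_history": "<stringified price DataFrame>",
        "news": ["headline 1", "headline 2"],
        "fundamentals": {"pe_ratio": 28.5, "market_cap": "2.8T"},
    }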
"""
chunks = []
# Price history chunk
if "price_history" in data_dict:
price_text = f"""
Stock: {ticker}
Price Data:
{str(data_dict['price_history'][:100])}
Summary: Recent price movements and trends visible in historical data.
"""
chunks.extend(self.chunk_text(price_text))
# News chunks
if "news" in data_dict and isinstance(data_dict["news"], list):
for i, news in enumerate(data_dict["news"][:5]): # Top 5 news
news_text = f"News #{i + 1}: {str(news)}"
chunks.extend(self.chunk_text(news_text))
# Fundamentals chunk
if "fundamentals" in data_dict:
fund_text = f"""
Financial Fundamentals for {ticker}:
{str(data_dict['fundamentals'])}
"""
chunks.extend(self.chunk_text(fund_text))
logger.info(f"Chunked {ticker} data into {len(chunks)} semantic chunks")
return chunks
class SimpleEmbedder:
"""
Embed text using OpenAI embeddings.
For now, uses a deterministic hash-seeded placeholder embedding (to be replaced with real OpenAI calls).
"""
def __init__(self, model: str = "text-embedding-3-small"):
"""
Initialize embedder.
Args:
model: OpenAI embedding model
"""
self.model = model
logger.info(f"SimpleEmbedder initialized: model={model}")
def embed(self, text: str) -> np.ndarray:
"""
Embed text to vector.
For Phase 3, use simple hash-based embedding.
For production, integrate OpenAI API.
Args:
text: Text to embed
Returns:
Embedding vector (768-dim for compatibility)
"""
# Placeholder: deterministic embedding derived from the text content.
# hashlib (rather than the builtin hash()) keeps the seed stable across
# processes, and a local Generator avoids mutating NumPy's global RNG state.
seed = int(hashlib.md5(text.encode("utf-8")).hexdigest(), 16) % (2**32)
rng = np.random.default_rng(seed)
embedding = rng.standard_normal(768).astype(np.float32)
# Normalize to unit length so dot products equal cosine similarity
embedding = embedding / np.linalg.norm(embedding)
return embedding
def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
"""
Embed multiple texts.
Args:
texts: List of texts
Returns:
List of embeddings
"""
return [self.embed(text) for text in texts]
class VectorStore:
"""
In-memory vector store for semantic search.
Stores documents with embeddings and performs similarity search.
"""
def __init__(self, embedder: Optional[SimpleEmbedder] = None):
"""
Initialize vector store.
Args:
embedder: Embedder to use (default: SimpleEmbedder)
"""
self.embedder = embedder or SimpleEmbedder()
self.documents: List[Document] = []
self.embeddings: np.ndarray = np.array([])
logger.info("VectorStore initialized")
def add_documents(self, documents: List[Document]) -> None:
"""
Add documents to store.
Args:
documents: List of Document objects
"""
for doc in documents:
if doc.embedding is None:
doc.embedding = self.embedder.embed(doc.content)
self.documents.append(doc)
# Rebuild embedding matrix
self.embeddings = np.array([doc.embedding for doc in self.documents])
logger.info(f"Added {len(documents)} documents, total: {len(self.documents)}")
def search(self, query: str, top_k: int = 5) -> List[Tuple[Document, float]]:
"""
Search for similar documents.
Args:
query: Query text
top_k: Number of results
Returns:
List of (Document, similarity_score) tuples
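Example (illustrative):
    store = VectorStore()
    store.add_documents([Document(content="AAPL rose 3% after earnings")])
    results = store.search("Apple earnings report", top_k=1)
    # -> [(Document(...), <similarity score>)]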
"""
if len(self.documents) == 0:
return []
query_embedding = self.embedder.embed(query)
# Dot-product similarity (equals cosine similarity because SimpleEmbedder returns unit-length vectors)
similarities = np.dot(self.embeddings, query_embedding)
# Top K
top_indices = np.argsort(similarities)[-top_k:][::-1]
results = [
(self.documents[i], float(similarities[i])) for i in top_indices
]
logger.debug(f"Search returned {len(results)} results")
return results
class BM25Retriever:
"""
Keyword-based retriever with a BM25-style interface.
Complements semantic search with keyword matching; currently scores documents
with a simple Jaccard term overlap as a lightweight stand-in for full BM25.
"""
def __init__(self):
"""Initialize BM25 retriever."""
self.documents: List[Document] = []
logger.info("BM25Retriever initialized")
def add_documents(self, documents: List[Document]) -> None:
"""
Add documents to index.
Args:
documents: List of Document objects
"""
# Extend rather than replace, so repeated calls accumulate (mirrors VectorStore)
self.documents.extend(documents)
logger.info(f"BM25 indexed {len(documents)} documents")
def search(self, query: str, top_k: int = 5) -> List[Tuple[Document, float]]:
"""
Search using keyword matching.
Args:
query: Query text
top_k: Number of results
Returns:
List of (Document, score) tuples
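Example (Jaccard scoring): the query "AAPL price trend" against a document
containing "AAPL price rose today" shares 2 of 5 distinct terms, so the
score is 2 / 5 = 0.4.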
"""
query_terms = set(query.lower().split())
scores = []
for doc in self.documents:
doc_terms = set(doc.content.lower().split())
# Simple Jaccard similarity
intersection = len(query_terms & doc_terms)
union = len(query_terms | doc_terms)
score = intersection / union if union > 0 else 0
scores.append(score)
# Top K
top_indices = np.argsort(scores)[-top_k:][::-1]
results = [
(self.documents[i], float(scores[i])) for i in top_indices if scores[i] > 0
]
logger.debug(f"BM25 search returned {len(results)} results")
return results
class HybridRetriever:
"""
Combines semantic and keyword-based retrieval.
Uses weighted combination for best of both worlds.
"""
def __init__(
self,
vector_store: Optional[VectorStore] = None,
bm25_retriever: Optional[BM25Retriever] = None,
semantic_weight: float = 0.6,
):
"""
Initialize hybrid retriever.
Args:
vector_store: Vector store for semantic search
bm25_retriever: BM25 for keyword search
semantic_weight: Weight for semantic (0-1)
"""
self.vector_store = vector_store or VectorStore()
self.bm25_retriever = bm25_retriever or BM25Retriever()
self.semantic_weight = semantic_weight
logger.info(f"HybridRetriever initialized: semantic_weight={semantic_weight}")
def add_documents(self, documents: List[Document]) -> None:
"""
Add documents to both retrievers.
Args:
documents: List of Document objects
"""
self.vector_store.add_documents(documents)
self.bm25_retriever.add_documents(documents)
logger.info(f"Hybrid retriever indexed {len(documents)} documents")
def search(self, query: str, top_k: int = 5) -> List[Tuple[Document, float]]:
"""
Search using both methods, combine results.
Args:
query: Query text
top_k: Number of results
Returns:
List of (Document, combined_score) tuples
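Example (with semantic_weight=0.6): a document scoring 0.8 on semantic
search and 0.4 on keyword search gets 0.6 * 0.8 + 0.4 * 0.4 = 0.64.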
"""
# Semantic results
semantic_results = self.vector_store.search(query, top_k=top_k)
# Keyword results
keyword_results = self.bm25_retriever.search(query, top_k=top_k)
# Combine scores
doc_scores = {}
for doc, score in semantic_results:
doc_id = id(doc)
doc_scores[doc_id] = {
"doc": doc,
"semantic": score,
"keyword": 0.0,
}
for doc, score in keyword_results:
doc_id = id(doc)
if doc_id not in doc_scores:
doc_scores[doc_id] = {"doc": doc, "semantic": 0.0, "keyword": score}
else:
doc_scores[doc_id]["keyword"] = score
# Weighted combination
combined = [
(
scores["doc"],
self.semantic_weight * scores["semantic"]
+ (1 - self.semantic_weight) * scores["keyword"],
)
for scores in doc_scores.values()
]
# Sort and return top K
combined.sort(key=lambda x: x[1], reverse=True)
results = combined[:top_k]
logger.debug(f"Hybrid search returned {len(results)} results")
return results
class QueryProcessor:
"""
Process user queries and prepare retrieval context.
Handles:
- Query expansion
- Entity extraction (stock symbols, dates)
- Intent classification
"""
def __init__(self):
logger.info("QueryProcessor initialized")
def extract_entities(self, query: str) -> Dict[str, Any]:
"""
Extract entities from query.
Args:
query: User query
Returns:
Dict with extracted entities
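Example:
    extract_entities("Will AAPL price go up next week?")
    # -> {"ticker": "AAPL", "time_period": "1_week", "metric": "price"}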
"""
entities = {"ticker": None, "time_period": None, "metric": None}
# Extract ticker (1-5 uppercase letters), skipping common all-caps English words
ticker_stopwords = {"I", "A", "AND", "OR", "THE", "FOR", "IS", "IT", "AT", "ON", "IN", "TO"}
for match in re.finditer(r"\b([A-Z]{1,5})\b", query):
    if match.group(1) not in ticker_stopwords:
        entities["ticker"] = match.group(1)
        break
# Extract time period
time_patterns = {
"1_week": r"next\s+week|1\s+week",
"1_month": r"next\s+month|1\s+month",
"1_year": r"next\s+year|1\s+year",
}
for period, pattern in time_patterns.items():
if re.search(pattern, query, re.IGNORECASE):
entities["time_period"] = period
break
# Extract metric
metrics = ["price", "trend", "risk", "opportunity", "earnings", "revenue"]
for metric in metrics:
if metric in query.lower():
entities["metric"] = metric
break
return entities
def expand_query(self, query: str) -> str:
"""
Expand query with synonyms and related terms.
Args:
query: Original query
Returns:
Expanded query
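Example:
    expand_query("Will the price go up?")
    # -> "Will the price go up? price cost value increase rise bull positive"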
"""
synonyms = {
"predict": "forecast predict estimate",
"price": "price cost value",
"up": "increase rise bull positive",
"down": "decrease fall bear negative",
"risk": "risk danger concern issue problem",
}
expanded = query
for word, expansion in synonyms.items():
if word in query.lower():
expanded += f" {expansion}"
return expanded
class RAGChain:
"""
Complete RAG pipeline.
Coordinates: data fetching → chunking → embedding → retrieval → LLM generation.
Outputs MetricSnapshot for comparison with FinRobot.
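Example config (both keys optional; defaults shown):
    {"chunk_size": 500, "overlap": 100}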
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize RAG chain.
Args:
config: Configuration dict
"""
self.config = config or {}
self.chunker = TextChunker(
chunk_size=self.config.get("chunk_size", 500),
overlap=self.config.get("overlap", 100),
)
self.embedder = SimpleEmbedder()
self.vector_store = VectorStore(self.embedder)
self.bm25_retriever = BM25Retriever()
self.retriever = HybridRetriever(self.vector_store, self.bm25_retriever)
self.query_processor = QueryProcessor()
self.doc_cache = {} # Cache fetched data
logger.info("RAGChain initialized")
def fetch_and_prepare(self, ticker: str, start_date: str, end_date: str) -> None:
"""
Fetch financial data and prepare for retrieval.
Args:
ticker: Stock symbol
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
"""
cache_key = f"{ticker}_{start_date}_{end_date}"
if cache_key in self.doc_cache:
logger.debug(f"Using cached data for {ticker}")
return
logger.info(f"Fetching data for {ticker}...")
# Fetch from yfinance
try:
# Lazy import to avoid dependencies
from finrobot.data_source import YFinanceUtils as YF
price_data = YF.get_stock_data(ticker, start_date, end_date)
logger.info(f"Fetched price data for {ticker}: {len(str(price_data))} chars")
except Exception as e:
logger.error(f"Error fetching price data: {e}")
price_data = f"Unable to fetch price data for {ticker}"
# Create chunks
data_dict = {
"price_history": str(price_data),
"ticker": ticker,
}
chunks = self.chunker.chunk_stock_data(ticker, data_dict)
# Create documents
documents = [
Document(
content=chunk,
metadata={
"ticker": ticker,
"source": "yfinance",
"chunk_index": i,
"date_range": f"{start_date} to {end_date}",
},
)
for i, chunk in enumerate(chunks)
]
# Add to retriever
self.retriever.add_documents(documents)
self.doc_cache[cache_key] = documents
logger.info(f"Prepared {len(documents)} documents for {ticker}")
def retrieve_context(self, query: str, top_k: int = 5) -> List[str]:
"""
Retrieve relevant context for query.
Args:
query: User query
top_k: Number of results
Returns:
List of relevant document chunks
"""
# Expand query
expanded_query = self.query_processor.expand_query(query)
# Retrieve
results = self.retriever.search(expanded_query, top_k=top_k)
context = [doc.content for doc, score in results]
logger.info(f"Retrieved {len(context)} context chunks")
return context
def generate_response(
self,
query: str,
context: List[str],
llm_response: Optional[str] = None,
) -> str:
"""
Generate response using context (in Phase 4, use actual LLM).
Args:
query: User query
context: Retrieved context
llm_response: Optional LLM response (for testing)
Returns:
Generated response
"""
if llm_response:
return llm_response
# For now, generate rule-based response
# In Phase 4: Use OpenAI GPT-4
response = f"""Based on analysis of {len(context)} relevant data points:
Query: {query}
Context Summary:
{chr(10).join(['- ' + c[:100] + '...' for c in context[:3]])}
Analysis: The retrieved data suggests relevant trends and patterns.
For a complete analysis, this would be processed through GPT-4 with the retrieved context.
Note: Full LLM integration will be completed in Phase 4.
"""
return response
def run(
self,
ticker: str,
start_date: str,
end_date: str,
query: str,
task_name: str = "analysis",
) -> MetricSnapshot:
"""
Run complete RAG pipeline and return metrics.
Args:
ticker: Stock symbol
start_date: Start date
end_date: End date
query: User query
task_name: Name of task for tracking
Returns:
MetricSnapshot with results and metrics
"""
from finrobot.experiments.metrics_collector import MetricsCollector
collector = MetricsCollector()
metric = collector.start_measurement(
experiment_id=f"rag_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
system_name="rag",
ticker=ticker,
task_name=task_name,
)
try:
# Fetch and prepare
self.fetch_and_prepare(ticker, start_date, end_date)
# Retrieve context
context = self.retrieve_context(query, top_k=5)
# Generate response
response = self.generate_response(query, context)
# Store metrics
metric.set_response(response)
metric.tool_calls_count = 1  # Single combined fetch-and-retrieve step
metric.reasoning_steps = len(context) + 1
# Estimate cost (embedding + retrieval only, no LLM call yet);
# token counts are approximated as roughly 2 tokens per word
metric.set_cost(
prompt_tokens=len(query.split()) * 2,
completion_tokens=len(response.split()) * 2,
model="gpt-4",
)
logger.info(f"RAG pipeline completed for {ticker}")
except Exception as e:
logger.error(f"RAG pipeline error: {e}")
metric.error_occurred = True
metric.error_message = str(e)
finally:
collector.end_measurement(metric)
return metric