import { BaseRetriever } from '@langchain/core/retrievers';
import { Document } from '@langchain/core/documents';
import { CallbackManagerForRetrieverRun } from '@langchain/core/callbacks/manager';
import { VectorStore } from './store';
import logger from '../utils/logger';
/**
 * Construction options for `SalesforceRAGRetriever`.
 */
interface SalesforceRAGRetrieverParams {
  /** Store that executes the underlying searches. */
  vectorStore: VectorStore;
  /** Only chunks whose `orgId` equals this value are returned by the retriever. */
  orgId: string;
  /** Result limit passed to the store's search call; the retriever uses 10 when omitted. */
  limit?: number;
  /** Search strategy to run; the retriever uses `'hybrid'` when omitted. */
  searchMode?: 'hybrid' | 'vector' | 'keyword' | 'symbol';
  /** Chunk types to keep; when omitted or empty, no type filtering is applied. */
  types?: string[];
}
/**
 * LangChain retriever backed by the project's `VectorStore`.
 *
 * Runs one of four search strategies against the store, optionally narrows the
 * hits to a set of chunk types, keeps only hits that belong to the configured
 * org, and converts the survivors into LangChain `Document`s.
 */
export class SalesforceRAGRetriever extends BaseRetriever {
  lc_namespace = ['salesforce-rag', 'retrievers'];

  private vectorStore: VectorStore;
  private limit: number;
  private searchMode: 'hybrid' | 'vector' | 'keyword' | 'symbol';
  private types?: string[];
  private orgId: string;

  /**
   * @param params - Retriever configuration. `limit` defaults to 10 and
   *   `searchMode` to `'hybrid'` when omitted.
   */
  constructor(params: SalesforceRAGRetrieverParams) {
    super();
    this.vectorStore = params.vectorStore;
    // `??` rather than `||`, so an explicit 0 is honoured (consistent with updateParams()).
    this.limit = params.limit ?? 10;
    this.searchMode = params.searchMode ?? 'hybrid';
    this.types = params.types;
    this.orgId = params.orgId;
  }

  /**
   * Retrieve documents relevant to `query`.
   *
   * This is invoked by BaseRetriever's public entry points, which already
   * report start/end/error to the run manager. Calling
   * `handleRetrieverEnd` here as well made the end callback fire twice per
   * retrieval, so this method leaves callback reporting to the base class.
   *
   * NOTE(review): the store applies `limit` *before* the type and orgId
   * filters below. Fewer than `limit` documents (possibly none) can come back
   * when chunks from other orgs or types rank higher. Pushing these filters
   * into the VectorStore query would fix this. TODO once the store API
   * supports it.
   *
   * @param query - Free-text search query.
   * @param _runManager - Callback manager supplied by BaseRetriever; not used here.
   * @returns Matching chunks as Documents, in the order the store returned them.
   * @throws Re-throws any error raised while searching, after logging it.
   */
  async _getRelevantDocuments(
    query: string,
    _runManager?: CallbackManagerForRetrieverRun
  ): Promise<Document[]> {
    // Snapshot configuration up front, so an updateParams() call made while
    // the search is awaited cannot make the filters/logs disagree with the
    // search that actually ran.
    const { vectorStore, limit, searchMode, types, orgId } = this;
    try {
      logger.info('LangChain retrieval started', {
        query,
        searchMode,
        limit,
        orgId
      });

      // Execute search based on mode
      let results;
      switch (searchMode) {
        case 'vector':
          results = await vectorStore.vectorSearch(query, limit);
          break;
        case 'keyword':
          results = await vectorStore.keywordSearch(query, limit);
          break;
        case 'symbol':
          results = await vectorStore.symbolSearch(query, limit);
          break;
        case 'hybrid':
        default:
          results = await vectorStore.hybridSearch(query, limit);
          break;
      }

      // Filter by types if specified (an empty list means "no type filter").
      if (types && types.length > 0) {
        results = results.filter(result => types.includes(result.type));
      }

      // Tenant isolation: keep only chunks belonging to the configured org.
      results = results.filter(result => result.orgId === orgId);

      // Convert to LangChain Documents
      const documents = results.map(chunk => {
        const metadata = {
          id: chunk.id,
          type: chunk.type,
          name: chunk.name,
          path: chunk.path,
          orgId: chunk.orgId,
          symbols: chunk.symbols,
          references: chunk.references,
          // `??` so a genuine similarity score of 0 is not replaced by `rank`.
          similarity: chunk.similarity ?? chunk.rank,
          source: 'salesforce-rag'
        };
        return new Document({
          pageContent: chunk.content,
          metadata
        });
      });

      logger.info('LangChain retrieval completed', {
        query,
        resultCount: documents.length,
        searchMode,
        orgId
      });

      return documents;
    } catch (error) {
      logger.error('LangChain retrieval failed', {
        error: error instanceof Error ? error.message : String(error),
        query,
        orgId
      });
      throw error;
    }
  }

  /**
   * Update retriever parameters in place. Only fields that are not
   * `undefined` are applied. `vectorStore` cannot be changed and is ignored if
   * supplied. To clear the type filter, pass `types: []`, because
   * `undefined` leaves it unchanged.
   */
  updateParams(params: Partial<SalesforceRAGRetrieverParams>) {
    if (params.limit !== undefined) this.limit = params.limit;
    if (params.searchMode !== undefined) this.searchMode = params.searchMode;
    if (params.types !== undefined) this.types = params.types;
    if (params.orgId !== undefined) this.orgId = params.orgId;
  }

  /**
   * Return the current configuration (excluding the vector store).
   */
  getConfig() {
    return {
      limit: this.limit,
      searchMode: this.searchMode,
      types: this.types,
      orgId: this.orgId
    };
  }
}
// Factory function for easy instantiation
/**
 * Convenience factory for `SalesforceRAGRetriever`.
 *
 * Omitted `limit` / `searchMode` options are forwarded as `undefined`, so the
 * retriever's constructor alone decides their defaults. Previously the
 * defaults were duplicated here with `||`, which silently turned an explicit
 * `limit: 0` into 10.
 *
 * @param vectorStore - Store that executes the searches.
 * @param orgId - Org whose chunks the retriever may return.
 * @param options - Optional result limit, search mode and chunk-type filter.
 * @returns A newly constructed retriever.
 */
export function createSalesforceRAGRetriever(
  vectorStore: VectorStore,
  orgId: string,
  options?: {
    limit?: number;
    searchMode?: 'hybrid' | 'vector' | 'keyword' | 'symbol';
    types?: string[];
  }
): SalesforceRAGRetriever {
  return new SalesforceRAGRetriever({
    vectorStore,
    orgId,
    limit: options?.limit,
    searchMode: options?.searchMode,
    types: options?.types
  });
}