Skip to main content
Glama
testrag.py7.02 kB
""" RAG module tests """ import platform import unittest from txtai.embeddings import Embeddings from txtai.pipeline import Questions, RAG, Similarity class TestRAG(unittest.TestCase): """ RAG tests. """ @classmethod def setUpClass(cls): """ Create single rag instance. """ cls.data = [ "Giants hit 3 HRs to down Dodgers", "Giants 5 Dodgers 4 final", "Dodgers drop Game 2 against the Giants, 5-4", "Blue Jays beat Red Sox final score 2-1", "Red Sox lost to the Blue Jays, 2-1", "Blue Jays at Red Sox is over. Score: 2-1", "Phillies win over the Braves, 5-0", "Phillies 5 Braves 0 final", "Final: Braves lose to the Phillies in the series opener, 5-0", "Lightning goaltender pulled, lose to Flyers 4-1", "Flyers 4 Lightning 1 final", "Flyers win 4-1", ] # Create embeddings model, backed by sentence-transformers & transformers cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"}) # Create rag instance cls.rag = RAG(cls.embeddings, "distilbert-base-cased-distilled-squad") @classmethod def tearDownClass(cls): """ Cleanup data. """ if cls.embeddings: cls.embeddings.close() def testAnswer(self): """ Test qa extraction with an answer """ questions = ["What team won the game?", "What was score?"] # pylint: disable=C3001 execute = lambda query: self.rag([(question, query, question, False) for question in questions], self.data) answers = execute("Red Sox - Blue Jays") self.assertEqual("Blue Jays", answers[0][1]) self.assertEqual("2-1", answers[1][1]) # Ad-hoc questions question = "What hockey team won?" answers = self.rag([(question, question, question, False)], self.data) self.assertEqual("Flyers", answers[0][1]) def testEmptyQuery(self): """ Test an empty queries list """ self.assertEqual(self.rag.query(None, None), []) def testNoAnswer(self): """ Test qa extraction with no answer """ question = "" answers = self.rag([(question, question, question, False)], self.data) self.assertIsNone(answers[0][1]) question = "abcdef" answers = self.rag([(question, question, question, False)], self.data) self.assertIsNone(answers[0][1]) @unittest.skipIf(platform.system() == "Darwin", "Quantized models not supported on macOS") def testQuantize(self): """ Test qa extraction backed by a quantized model """ rag = RAG(self.embeddings, "distilbert-base-cased-distilled-squad", True) question = "How many home runs?" answers = rag([(question, question, question, True)], self.data) self.assertTrue(answers[0][1].startswith("Giants hit 3 HRs")) def testOutputs(self): """ Test output formatting rules """ question = "How many home runs?" # Test flatten to list of answers rag = RAG(self.embeddings, "distilbert-base-cased-distilled-squad", output="flatten") answers = rag([(question, question, question, True)], self.data) self.assertTrue(answers[0].startswith("Giants hit 3 HRs")) # Test reference field rag = RAG(self.embeddings, "distilbert-base-cased-distilled-squad", output="reference") answers = rag([(question, question, question, True)], self.data) self.assertTrue(self.data[answers[0][2]].startswith("Giants hit 3 HRs")) def testPrompt(self): """ Test a user prompt with templating """ embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": True}) embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)]) rag = RAG( embeddings, "google/flan-t5-small", template=""" Answer the following question and return a number. Question: {question} Context:{context}""", output="flatten", ) self.assertEqual(rag("How many HRs"), "3") def testPromptTemplates(self): """ Test system and user prompt templates """ rag = RAG( self.embeddings, "sshleifer/tiny-gpt2", system="You are a friendly assistant", template=""" Answer the following question and return a number. Question: {question} Context:{context}""", ) prompts = rag.prompts(["How many HRs?"], [self.data])[0] self.assertEqual([x["role"] for x in prompts], ["system", "user"]) def testSearch(self): """ Test qa extraction with an embeddings search for context """ embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": True}) embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)]) rag = RAG(embeddings, "distilbert-base-cased-distilled-squad") question = "How many home runs?" answers = rag([(question, question, question, True)]) self.assertTrue(answers[0][1].startswith("Giants hit 3 HRs")) def testSimilarity(self): """ Test qa extraction using a Similarity pipeline to build context """ # Create rag instance rag = RAG(Similarity("prajjwal1/bert-medium-mnli"), Questions("distilbert-base-cased-distilled-squad")) question = "How many home runs?" answers = rag([(question, "HRs", question, True)], self.data) self.assertTrue(answers[0][1].startswith("Giants hit 3 HRs")) def testSnippet(self): """ Test qa extraction with a full answer snippet """ question = "How many home runs?" answers = self.rag([(question, question, question, True)], self.data) self.assertTrue(answers[0][1].startswith("Giants hit 3 HRs")) def testSnippetEmpty(self): """ Test snippet method can handle empty parameters """ self.assertEqual(self.rag.snippets(["name"], [None], [None], [None]), [("name", None)]) def testStringInput(self): """ Test with single string input """ result = self.rag("How many home runs?", self.data) self.assertEqual(result["answer"], "3") def testTasks(self): """ Test loading models with task parameter """ for task, model in [ ("language-generation", "hf-internal-testing/tiny-random-gpt2"), ("sequence-sequence", "hf-internal-testing/tiny-random-t5"), ]: rag = RAG(self.embeddings, model, task=task) self.assertIsNotNone(rag)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/neuml/txtai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server