# app.py
import os
import uuid
import gradio as gr
from agent import construct_agent, initialize_agent_llm, initialize_instrumentor
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from openinference.instrumentation import using_session
from opentelemetry.trace import Status, StatusCode
from rag import initialize_vector_store
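
# load_dotenv() reads a local .env file if one is present; the same keys can also
# be supplied through the configuration panel below. A minimal .env (variable
# names taken from the os.environ assignments in initialize_agent) might look like:
#   PHOENIX_API_KEY=<your Phoenix API key>
#   OPENAI_API_KEY=<your OpenAI API key>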
load_dotenv()
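
# System prompt prepended to the message list at the start of every new conversation.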
SYSTEM_MESSAGE_FOR_AGENT_WORKFLOW = """
You are a Retrieval-Augmented Generation (RAG) assistant designed to provide responses by leveraging the provided tools.
Your goal is to ensure the user's query is addressed with a high-quality response. If further clarification is required,
you can request additional input from the user.
"""
def initialize_agent(phoenix_key, project_name, openai_key, user_session_id, vector_source_web_url):
    from tools import initialize_tool_llm
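    # The Phoenix tracer and OpenAI clients read their keys and the collector endpoint from the environment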
os.environ["PHOENIX_API_KEY"] = phoenix_key
os.environ["OPENAI_API_KEY"] = openai_key
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com/v1/traces"
agent_tracer = initialize_instrumentor(project_name)
initialize_agent_llm("gpt-4o-mini")
tool_model = initialize_tool_llm("gpt-4o-mini")
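    # Ingest the source web page into the vector store used for retrieval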
    initialize_vector_store(vector_source_web_url)
    copilot_agent = construct_agent()
    return (
        copilot_agent,
        agent_tracer,
        tool_model,
        user_session_id,
        f"Configuration Set: Project '{project_name}' is Ready!",
    )
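
# Handles one chat turn: builds the message list, invokes the agent inside a
# traced span tied to the user's session, and returns the updated chat display
# and conversation state back into the Gradio gr.State components.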
def chat_with_agent(
    copilot_agent,
    agent_tracer,
    tool_model,
    user_input_message,
    user_session_id,
    user_chat_history,
    conversation_history,
):
    if not copilot_agent:
        error_message = "Error: RAG Agent is not initialized. Please set API keys first."
        user_chat_history.append((user_input_message, error_message))
        return copilot_agent, "", user_chat_history, user_session_id, user_chat_history, conversation_history
    if not conversation_history:
        messages = [SystemMessage(content=SYSTEM_MESSAGE_FOR_AGENT_WORKFLOW)]
    else:
        messages = conversation_history["messages"]
    messages.append(HumanMessage(content=user_input_message))
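    # Group this turn under the user's Phoenix session and trace the agent call as a chain span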
    with using_session(session_id=user_session_id):
        with agent_tracer.start_as_current_span(
            f"agent-{user_session_id}",
            openinference_span_kind="chain",
        ) as span:
            span.set_input(user_input_message)
            conversation_history = copilot_agent.invoke(
                {"messages": messages},
                config={
                    "configurable": {
                        "thread_id": user_session_id,
                        "user_session_id": user_session_id,
                        "tool_model": tool_model,
                    }
                },
            )
            span.set_output(conversation_history["messages"][-1].content)
            span.set_status(Status(StatusCode.OK))
    user_chat_history.append(
        (user_input_message, conversation_history["messages"][-1].content)
    )
    return (
        copilot_agent,
        "",
        user_chat_history,
        user_session_id,
        user_chat_history,
        conversation_history,
    )
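
# Gradio UI: a configuration panel on the left and the chat interface on the right.
# gr.State components hold per-session objects (agent, tracer, tool model,
# chat display history, and the agent's conversation state).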
with gr.Blocks() as demo:
    agent = gr.State(None)
    tracer = gr.State(None)
    openai_tool_model = gr.State(None)
    history = gr.State({})  # Agent conversation state returned by invoke (a dict with a "messages" list)
    session_id = gr.State(str(uuid.uuid4()))
    chat_history = gr.State([])  # Chat display history as a list of (user, assistant) tuples
gr.Markdown("## Chat with RAG Agent 🔥")
with gr.Row():
with gr.Column(scale=1, min_width=250):
gr.Markdown("### Configuration Panel ⚙️")
phoenix_input = gr.Textbox(label="Phoenix API Key", type="password")
project_input = gr.Textbox(label="Project Name", value="Agentic Rag")
openai_input = gr.Textbox(label="OpenAI API Key", type="password")
web_url = gr.Textbox(
label="Vector Source Web URL",
value="https://lilianweng.github.io/posts/2023-06-23-agent/",
)
set_button = gr.Button("Set API Keys & Initialize")
output_message = gr.Textbox(label="Status", interactive=False)
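            # Initializing fills the gr.State slots above so later chat turns can reuse the agent, tracer, and tool model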
            set_button.click(
                fn=initialize_agent,
                inputs=[phoenix_input, project_input, openai_input, session_id, web_url],
                outputs=[agent, tracer, openai_tool_model, session_id, output_message],
            )
        with gr.Column(scale=4):
            gr.Markdown("### Chat with RAG Agent 💬")
            chat_display = gr.Chatbot(label="Chat History", height=400)
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            submit_button = gr.Button("Send")
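            # Each send passes the stateful objects into chat_with_agent and writes the updated histories back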
            submit_button.click(
                fn=chat_with_agent,
                inputs=[
                    agent,
                    tracer,
                    openai_tool_model,
                    user_input,
                    session_id,
                    chat_display,
                    history,
                ],
                outputs=[agent, user_input, chat_display, session_id, chat_history, history],
            )
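
# Bind to all interfaces on the port supplied by the environment (PORT), defaulting to 10000.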
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 10000))
    demo.launch(share=True, server_name="0.0.0.0", server_port=port)