from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from utils import (
file_to_base64,
use_router,
upsert_vectors,
create_collection,
get_texts_gather,
)
from typing import List, Any, Dict
import ast
class State(TypedDict):
    """Shared pipeline state flowing between the LangGraph nodes below."""

    # Router/tool configuration dict forwarded to every `use_router` call.
    config: Dict
    # Paths of the input files to ingest.
    file_paths: List[str]
    # Base64 encoding of each file, parallel to `file_paths`.
    file_base64_list: List[str]
    # Extracted markdown text per file (from the Pdf extract server).
    file_texts: List[str]
    # Flattened chunk texts across all documents (see get_chunks).
    chunks_lists: List[Any]
    # Embedding vectors, parallel to `chunks_lists`.
    vector_lists: Any
    # Name of the Qdrant collection the vectors were upserted into.
    qdrant_collection_name: str
def get_base64_list(state: State):
    """Encode every input file as base64 and return the collected list."""
    encoded = [file_to_base64(path) for path in state["file_paths"]]
    return {"file_base64_list": encoded}
async def get_text_from_file(state: State):
    """Async node: convert each base64-encoded file to markdown text.

    Fans out one router call per file and gathers the results concurrently.
    """
    server = "Pdf extract server"
    tool = "document_convert_to_markdown"
    pending = []
    for b64 in state["file_base64_list"]:
        pending.append(
            use_router(state["config"], server, tool, {"file_base64": b64})
        )
    texts = await get_texts_gather(pending)
    return {"file_texts": texts}
async def get_chunks(state: State):
    """Async node: split each extracted text into paragraph chunks.

    Issues one chunking request per document, parses the string replies,
    and flattens every chunk's page_content into a single list.
    """
    server = "MarkUp server"
    tool = "markup_process_text"
    chunker = "paragraph_chunker"
    pending = [
        use_router(
            state["config"],
            server,
            tool,
            {"text": txt, "method_name": chunker},
        )
        for txt in state["file_texts"]
    ]
    replies = await get_texts_gather(pending)
    # Replies arrive as Python-literal strings; parse them safely.
    parsed = [ast.literal_eval(reply) for reply in replies]
    flat_chunks = [
        chunk["page_content"]
        for doc in parsed
        for chunk in doc["chunks"]
    ]
    return {"chunks_lists": flat_chunks}
async def get_vectors(state: State):
    """Async node: embed all chunks in one batch call and keep the vectors."""
    server = "Embedding server"
    tool = "embedding_batch_generate"
    reply = await use_router(
        state["config"], server, tool, {"texts": state["chunks_lists"]}
    )
    # The reply is a Python-literal string; parse it safely.
    payload = ast.literal_eval(reply)
    embeddings = [entry["embedding"] for entry in payload["data"]]
    return {"vector_lists": embeddings}
def create_and_get_qdrant_collection(state: State):
    """Create (or reuse) a Qdrant collection sized to the embeddings and upsert them.

    Returns:
        dict: ``{"qdrant_collection_name": <name>}`` so downstream consumers
        know which collection to query.

    Raises:
        ValueError: if no vectors were produced upstream — the vector size
            cannot be inferred from an empty list (previously this surfaced
            as a bare IndexError on ``vector_lists[0]``).
    """
    vector_lists = state["vector_lists"]
    if not vector_lists:
        raise ValueError("vector_lists is empty; cannot infer vector size")
    server_name = "Qdrant server"
    tool_create_name = "vector_create_collection"
    tool_upsert_name = "vector_upsert_points"
    tool_get_collection_info_name = "vector_get_collection_info"
    # All embeddings share one dimensionality; size the collection from the first.
    vector_size = len(vector_lists[0])
    # Only the collection name is needed here; the raw creation result is unused.
    collection_name, _ = create_collection(
        state["config"],
        server_name,
        tool_get_collection_info_name,
        tool_create_name,
        vector_size,
    )
    upsert_vectors(
        state["config"],
        state["chunks_lists"],
        collection_name,
        vector_lists,
        server_name,
        tool_upsert_name,
    )
    return {"qdrant_collection_name": collection_name}
def build_graph_workflow():
    """Assemble and compile the ingestion pipeline as a linear LangGraph."""
    builder = StateGraph(State)
    # Ordered pipeline stages; the last three node callables are async.
    stages = [
        ("get_base64_list", get_base64_list),
        ("get_text_from_file", get_text_from_file),
        ("get_chunks", get_chunks),
        ("get_vectors", get_vectors),
        ("create_and_get_qdrant_collection", create_and_get_qdrant_collection),
    ]
    for label, node_fn in stages:
        builder.add_node(label, node_fn)
    # Chain START -> stage1 -> ... -> stageN -> END.
    ordering = [START] + [label for label, _ in stages] + [END]
    for src, dst in zip(ordering, ordering[1:]):
        builder.add_edge(src, dst)
    return builder.compile()