Skip to main content
Glama

Datetime MCP Server

by bossjones
langgraph.txt2.76 MB
Project Path: langgraph Source Tree: ``` langgraph ├── LICENSE ├── requirements.txt ├── Makefile ├── pyproject.toml ├── docs │ ├── overrides │ │ ├── main.html │ │ └── partials │ │ ├── logo.html │ │ └── comments.html │ ├── mkdocs.yml │ ├── _scripts │ │ ├── notebook_convert_templates │ │ │ └── mdoutput │ │ │ ├── conf.json │ │ │ └── index.md.j2 │ │ ├── generate_api_reference_links.py │ │ ├── prepare_notebooks_for_ci.py │ │ ├── notebook_convert.py │ │ ├── notebook_hooks.py │ │ ├── download_tiktoken.py │ │ └── execute_notebooks.sh │ ├── test-compose.yml │ ├── cassettes │ │ ├── LLMCompiler_31854dfd-b82f-4c24-9b58-6bae66777909.msgpack.zlib │ │ ├── introduction_6a385b06-8d34-4a2d-aded-1cf4bb0ca590.msgpack.zlib │ │ ├── wait-user-input_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── state-model_8edb04b9-40b6-46f1-a7a8-4b2d8aba7752.msgpack.zlib │ │ ├── create-react-agent-hitl_83148e08-63e8-49e5-a08b-02dc907bed1d.msgpack.zlib │ │ ├── add-summary-conversation-history_048805a4-3d97-4e76-ac45-8d80d4364c46.msgpack.zlib │ │ ├── introduction_32e4f36e-72ce-4ade-bd7e-94880e0d456b.msgpack.zlib │ │ ├── create-react-agent_fa16de4c-aac0-4ff4-ab69-60d399f75423.msgpack.zlib │ │ ├── customer-support_b7443751.msgpack.zlib │ │ ├── introduction_761d15fb-d5e2-4d50-a630-126d77e77294.msgpack.zlib │ │ ├── delete-messages_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib │ │ ├── persistence_postgres_386b78bc-2f73-49ba-a2a4-47bce6fc49b7.msgpack.zlib │ │ ├── LLMCompiler_38d3ea91-59ba-4267-8060-ed75bbc840c6.msgpack.zlib │ │ ├── introduction_35c8978e-c07d-4dd0-a97b-0ce3a723eea5.msgpack.zlib │ │ ├── disable-streaming_8.msgpack.zlib │ │ ├── introduction_9f318020-ab7e-415b-a5e2-eddec6d9f3a6.msgpack.zlib │ │ ├── create-react-agent-memory_187479f9-32fa-4611-9487-cf816ba2e147.msgpack.zlib │ │ ├── streaming-from-final-node_55d60dfa-96e3-442f-9924-0c99f46baed8.msgpack.zlib │ │ ├── pass-config-to-tools_14.msgpack.zlib │ │ ├── introduction_dba1b168-f8e0-496d-9bd6-37198fb4776e.msgpack.zlib │ │ ├── 
tool-calling_19.msgpack.zlib │ │ ├── reflexion_6fd51f17-c0b0-44b6-90e2-55a66cb8f5a7.msgpack.zlib │ │ ├── persistence_08ae8246-11d5-40e1-8567-361e5bef8917.msgpack.zlib │ │ ├── time-travel_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── customer-support_783a548d-029a-47b7-9ac0-9c5203ec92c7.msgpack.zlib │ │ ├── sql-agent_cf02e843-438d-4168-a27a-8f1e0266f8d7.msgpack.zlib │ │ ├── edit-graph-state_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── reflection_dfbf99a8-3aa0-4e09-936e-8452c35fa84d.msgpack.zlib │ │ ├── retries_746b409c-693d-49af-8c2b-bea0a4b0028d.msgpack.zlib │ │ ├── langgraph_code_assistant_2dacccf0-d73f-4017-aaf0-9806ffe5bd2c.msgpack.zlib │ │ ├── semantic-search_8.msgpack.zlib │ │ ├── langgraph_self_rag_fb69dbb9-91ee-4868-8c3c-93af3cd885be.msgpack.zlib │ │ ├── recursion-limit_5.msgpack.zlib │ │ ├── LLMCompiler_730490c6-6e3a-4173-82a1-9eb9d5eeff20.msgpack.zlib │ │ ├── add-summary-conversation-history_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib │ │ ├── sql-agent_3bf7709f-500c-4f28-bb85-dda317286c63.msgpack.zlib │ │ ├── agent_supervisor_45a92dfd-0e11-47f5-aad4-b68d24990e34.msgpack.zlib │ │ ├── persistence_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── introduction_8b49509c-9d97-457c-a76a-c495fb30ccbc.msgpack.zlib │ │ ├── langgraph_agentic_rag_278d1d83-dda6-4de4-bf8b-be9965c227fa.msgpack.zlib │ │ ├── introduction_9fc99c7e-b61d-4aec-9c62-042798185ec3.msgpack.zlib │ │ ├── introduction_f4009ba6-dc0b-4216-ab0c-fbb104616f73.msgpack.zlib │ │ ├── subgraph_13.msgpack.zlib │ │ ├── cross-thread-persistence_d362350b-d730-48bd-9652-983812fd7811.msgpack.zlib │ │ ├── customer-support_72fceb01-b0ab-4bef-a22f-a2fce6ee33ef.msgpack.zlib │ │ ├── branching_1d0e6c56.msgpack.zlib │ │ ├── LLMCompiler_942dab42-ad42-4ba2-90d5-49edbe4fae68.msgpack.zlib │ │ ├── tool-calling-errors_19.msgpack.zlib │ │ ├── run-id-langsmith_6.msgpack.zlib │ │ ├── introduction_f5447778-53d7-47f3-801b-f47bcf2185a0.msgpack.zlib │ │ ├── semantic-search_19.msgpack.zlib │ │ ├── 
branching_83320227-8ab3-44c0-b6cf-064a7a425b9f.msgpack.zlib │ │ ├── shared-state_d362350b-d730-48bd-9652-983812fd7811.msgpack.zlib │ │ ├── introduction_85f17be3-eaf6-495e-a846-49436916b4ab.msgpack.zlib │ │ ├── introduction_051dc374-67cc-4371-9dd1-221e07593148.msgpack.zlib │ │ ├── configuration_e043a719-f197-46ef-9d45-84740a39aeb0.msgpack.zlib │ │ ├── review-tool-calls_d57d5131-7912-4216-aa87-b7272507fa51.msgpack.zlib │ │ ├── edit-graph-state_85e452f8-f33a-4ead-bb4d-7386cdba8edc.msgpack.zlib │ │ ├── introduction_effb95d9-b7d5-40c5-9253-253d193b23b2.msgpack.zlib │ │ ├── retries_f4752239-2aa3-4367-b777-8478c16b9471.msgpack.zlib │ │ ├── review-tool-calls_3f05f8b6-6128-4de5-8884-862fc93f1227.msgpack.zlib │ │ ├── langgraph_self_rag_bd62276f-bf26-40d0-8cff-e07b10e00321.msgpack.zlib │ │ ├── langgraph_code_assistant_3ba3df70-f6b4-4ea5-a210-e10944960bc6.msgpack.zlib │ │ ├── langgraph_code_assistant_ef7cf662-7a6f-4dee-965c-6309d4045feb.msgpack.zlib │ │ ├── create-react-agent-hitl_740bbaeb.msgpack.zlib │ │ ├── disable-streaming_4.msgpack.zlib │ │ ├── plan-and-execute_746e697a-dec4-4342-a814-9b3456828169.msgpack.zlib │ │ ├── LLMCompiler_55142257-2674-4a47-988e-0d2810917329.msgpack.zlib │ │ ├── create-react-agent-memory_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib │ │ ├── pass-config-to-tools_18.msgpack.zlib │ │ ├── semantic-search_13.msgpack.zlib │ │ ├── subgraphs-manage-state_13.msgpack.zlib │ │ ├── subgraphs-manage-state_46.msgpack.zlib │ │ ├── LLMCompiler_391d6931.msgpack.zlib │ │ ├── stream-updates_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib │ │ ├── streaming-from-final-node_68ac2c7f.msgpack.zlib │ │ ├── review-tool-calls_df4a9900-d953-4465-b8af-bd2858cb63ea.msgpack.zlib │ │ ├── breakpoints_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── configuration_070f11a6-2441-4db5-9df6-e318f110e281.msgpack.zlib │ │ ├── langgraph_agentic_rag_7b5a1d35.msgpack.zlib │ │ ├── create-react-agent-hitl_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib │ │ ├── 
tool-calling_23.msgpack.zlib │ │ ├── reflexion_5922e1fe-7533-4f41-8b1d-d812707c1968.msgpack.zlib │ │ ├── tool-calling-errors_13.msgpack.zlib │ │ ├── langgraph_self_rag_e78931ec-940c-46ad-a0b2-f43f953f1fd7.msgpack.zlib │ │ ├── configuration_f2f7c74b-9fb0-41c6-9728-dcf9d8a3c397.msgpack.zlib │ │ ├── LLMCompiler_0b3a0916-d8ca-4092-b91c-d9e2b05259d8.msgpack.zlib │ │ ├── react-agent-from-scratch_11.msgpack.zlib │ │ ├── persistence_postgres_180d6daf-8fa7-4608-bd2e-bfbf44ed5836.msgpack.zlib │ │ ├── subgraphs-manage-state_30.msgpack.zlib │ │ ├── persistence_redis_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib │ │ ├── sql-agent_4f200d1813897000.msgpack.zlib │ │ ├── delete-messages_3975f34c-c243-40ea-b9d2-424d50a48dc9.msgpack.zlib │ │ ├── semantic-search_15.msgpack.zlib │ │ ├── review-tool-calls_f9a0d5d4-52ff-49e0-a6f4-41f9a0e844d8.msgpack.zlib │ │ ├── breakpoints_6098e5cb.msgpack.zlib │ │ ├── rewoo_badaca52-5d55-433f-8770-1bd50c10bf7f.msgpack.zlib │ │ ├── shared-state_d862be40-1f8a-4057-81c4-b7bf073dc4c1.msgpack.zlib │ │ ├── add-summary-conversation-history_40e5db8e-9db9-4ac7-9d76-a99fd4034bf3.msgpack.zlib │ │ ├── subgraphs-manage-state_40.msgpack.zlib │ │ ├── subgraphs-manage-state_15.msgpack.zlib │ │ ├── persistence_redis_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib │ │ ├── configuration_ef50f048-fc43-40c0-b713-346408fcf052.msgpack.zlib │ │ ├── plan-and-execute_b8ac1f67-e87a-427c-b4f7-44351295b788.msgpack.zlib │ │ ├── langgraph_self_rag_1fafad21-60cc-483e-92a3-6a7edb1838e3.msgpack.zlib │ │ ├── disable-streaming_2.msgpack.zlib │ │ ├── async_8edb04b9-40b6-46f1-a7a8-4b2d8aba7752.msgpack.zlib │ │ ├── agent-simulation-evaluation_6f80669e-aa78-4666-b67c-a539366d5aab.msgpack.zlib │ │ ├── react-agent-structured-output_15.msgpack.zlib │ │ ├── tool-calling-errors_9.msgpack.zlib │ │ ├── subgraphs-manage-state_9.msgpack.zlib │ │ ├── retries_04e93401-50e2-42d0-8373-326006badebb.msgpack.zlib │ │ ├── state-model_e09aaa63.msgpack.zlib │ │ ├── 
persistence_273d56a8-f40f-4a51-a27f-7c6bb2bda0ba.msgpack.zlib │ │ ├── subgraph_9.msgpack.zlib │ │ ├── langgraph_code_assistant_9bcaafe4-ddcf-4fab-8620-2d9b6c508f98.msgpack.zlib │ │ ├── tool-calling_25.msgpack.zlib │ │ ├── langgraph_code_assistant_9f14750f-dddc-485b-ba29-5392cdf4ba43.msgpack.zlib │ │ ├── cross-thread-persistence_d862be40-1f8a-4057-81c4-b7bf073dc4c1.msgpack.zlib │ │ ├── subgraphs-manage-state_42.msgpack.zlib │ │ ├── persistence_postgres_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib │ │ ├── sql-agent_64b0bf1b14c2e902.msgpack.zlib │ │ ├── stream-values_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib │ │ ├── introduction_faa345c6-38a2-42e8-9035-9cf56f7bb5b1.msgpack.zlib │ │ ├── sql-agent_eb814b85-70ba-4699-9038-20266b53efbd.msgpack.zlib │ │ ├── add-summary-conversation-history_7094c5ab-66f8-42ff-b1c3-90c8a9468e62.msgpack.zlib │ │ ├── pass-run-time-values-to-tools_2e3fd1e2-cc19-4023-8ffa-b0fc13da9e09.msgpack.zlib │ │ ├── introduction_acbec099-e5d2-497f-929e-c548d7bcbf77.msgpack.zlib │ │ ├── customer-support_96469e95-5070-4169-bedd-45db94b43d97.msgpack.zlib │ │ ├── semantic-search_17.msgpack.zlib │ │ ├── semantic-search_10.msgpack.zlib │ │ ├── streaming-from-final-node_2ab6d079-ba06-48ba-abe5-e72df24407af.msgpack.zlib │ │ ├── langgraph_self_rag_dcd77cc1-4587-40ec-b633-5364eab9e1ec.msgpack.zlib │ │ ├── breakpoints_51923913-20f7-4ee1-b9ba-d01f5fb2869b.msgpack.zlib │ │ ├── wait-user-input_a9f599b5-1a55-406b-a76b-f52b3ca06975.msgpack.zlib │ │ ├── self-discover_6cbfbe81-f751-42da-843a-f9003ace663d.msgpack.zlib │ │ ├── pass-run-time-values-to-tools_d683986f-faf4-4724-a13f-bac39ea9bafe.msgpack.zlib │ │ ├── agent_supervisor_56ba78e9-d9c1-457c-a073-d606d5d3e013.msgpack.zlib │ │ ├── map-reduce_fd90cace.msgpack.zlib │ │ ├── reflexion_2634a3ea-7423-4579-9f4e-390e439c3209.msgpack.zlib │ │ ├── stream-values_c122bf15-a489-47bf-b482-a744a54e2cc4.msgpack.zlib │ │ ├── reflection_06263a07-8a15-4ec3-b692-1c6cef3b1c1f.msgpack.zlib │ │ ├── 
persistence_postgres_bd235fc7-1e5c-4db6-a90b-ea75462ccf7d.msgpack.zlib │ │ ├── async_4b369a6f.msgpack.zlib │ │ ├── subgraphs-manage-state_33.msgpack.zlib │ │ ├── streaming-subgraphs_7.msgpack.zlib │ │ ├── semantic-search_6.msgpack.zlib │ │ ├── persistence_mongodb_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib │ │ ├── sql-agent_293017e8f05ac2b3.msgpack.zlib │ │ ├── tool-calling-errors_17.msgpack.zlib │ │ ├── langgraph_code_assistant_c2eb35d1-4990-47dc-a5c4-208bae588a82.msgpack.zlib │ │ ├── introduction_03a09bfc-3d90-4e54-878f-22e3cb28a418.msgpack.zlib │ │ ├── introduction_6b32914d-4d60-491f-8e11-1e6867e38ffd.msgpack.zlib │ │ ├── streaming-tokens_96050fba.msgpack.zlib │ │ ├── reflexion_7541f82c.msgpack.zlib │ │ ├── hierarchical_agent_teams_6b8badbf-d728-44bd-a2a7-5b4e587c92fe.msgpack.zlib │ │ ├── add-summary-conversation-history_0a1a0fda-5309-45f0-9465-9f3dff604d74.msgpack.zlib │ │ ├── subgraphs-manage-state_11.msgpack.zlib │ │ ├── reflection_16c5eb2a-8bce-48ab-b87d-9dacb9b64ac6.msgpack.zlib │ │ ├── persistence_mongodb_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib │ │ ├── create-react-agent-system-prompt_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib │ │ ├── dynamic_breakpoints_9a14c8b2-5c25-4201-93ea-e5358ee99bcb.msgpack.zlib │ │ ├── branching_66f52a20.msgpack.zlib │ │ ├── tool-calling_17.msgpack.zlib │ │ ├── LLMCompiler_15dd9639-691f-4906-9012-83fd6e9ac126.msgpack.zlib │ │ ├── LLMCompiler_152eecf3-6bef-4718-af71-a0b3c5a3b009.msgpack.zlib │ │ ├── introduction_69071b02-c011-4b7f-90b1-8e89e032322d.msgpack.zlib │ │ ├── tool-calling-errors_11.msgpack.zlib │ │ ├── return-when-recursion-limit-hits_6.msgpack.zlib │ │ ├── sql-agent_85958809-03c5-4e52-97cc-e7c0ae986f60.msgpack.zlib │ │ ├── persistence_postgres_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib │ │ ├── tool-calling_26.msgpack.zlib │ │ ├── streaming-events-from-within-tools-without-langchain_45c96a79-4147-42e3-89fd-d942b2b49f6c.msgpack.zlib │ │ ├── react-agent-from-scratch_13.msgpack.zlib │ │ ├── 
create-react-agent_187479f9-32fa-4611-9487-cf816ba2e147.msgpack.zlib │ │ ├── langgraph_code_assistant_71d90f9e-9dad-410c-a709-093d275029ae.msgpack.zlib │ │ ├── map-reduce_37ed1f71-63db-416f-b715-4617b33d4b7f.msgpack.zlib │ │ ├── hierarchical_agent_teams_912b0604-a178-4246-a36f-2dedae606680.msgpack.zlib │ │ ├── introduction_b3220ae2-cba0-4447-96d1-eb0be4684e59.msgpack.zlib │ │ ├── stream-multiple_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib │ │ ├── wait-user-input_f5319e01.msgpack.zlib │ │ ├── persistence_6fa9a5e3-7101-43ab-a811-592e222b9580.msgpack.zlib │ │ ├── review-tool-calls_ec77831c-e6b8-4903-9146-e098a4b2fda1.msgpack.zlib │ │ ├── review-tool-calls_85e452f8-f33a-4ead-bb4d-7386cdba8edc.msgpack.zlib │ │ ├── review-tool-calls_1b3aa6fc-c7fb-4819-8d7f-ba6057cc4edf.msgpack.zlib │ │ ├── streaming-events-from-within-tools_ec461f66.msgpack.zlib │ │ ├── cross-thread-persistence_c871a073-a466-46ad-aafe-2b870831057e.msgpack.zlib │ │ ├── plan-and-execute_72d233ca-1dbf-4b43-b680-b3bf39e3691f.msgpack.zlib │ │ ├── introduction_11d5b934-6d8b-4f52-a3bc-b3daa7207e00.msgpack.zlib │ │ ├── self-discover_a18d8f24-5d9a-45c5-9739-6f3c4ed6c9c9.msgpack.zlib │ │ ├── plan-and-execute_67ce37b7-e089-479b-bcb8-c3f5d9874613.msgpack.zlib │ │ ├── langgraph_self_rag_c6f4c70e-1660-4149-82c0-837f19fc9fb5.msgpack.zlib │ │ ├── agent-simulation-evaluation_32848c2e-be82-46f3-81db-b23fea45461c.msgpack.zlib │ │ ├── manage-conversation-history_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib │ │ ├── time-travel_e986f94f-706f-4b6f-b3c4-f95483b9e9b8.msgpack.zlib │ │ ├── shared-state_c871a073-a466-46ad-aafe-2b870831057e.msgpack.zlib │ │ ├── pass-config-to-tools_17.msgpack.zlib │ │ ├── introduction_c1955d79-a1e4-47d0-ba79-b45bd5752a23.msgpack.zlib │ │ ├── breakpoints_9b53f191-1e86-4881-a667-d46a3d66958b.msgpack.zlib │ │ ├── langgraph_self_rag_4138bc51-8c84-4b8a-8d24-f7f470721f6f.msgpack.zlib │ │ ├── pass-run-time-values-to-tools_c0858273-10f2-45c4-a922-b11321ac3fae.msgpack.zlib │ │ ├── 
streaming-tokens-without-langchain_d6ed3df5.msgpack.zlib │ │ ├── langgraph_agentic_rag_7649f05a-cb67-490d-b24a-74d41895139a.msgpack.zlib │ │ ├── review-tool-calls_a30d40ad-611d-4ec3-84be-869ea05acb89.msgpack.zlib │ │ ├── information-gather-prompting_1b1613e0.msgpack.zlib │ │ ├── retries_5d072c9c-9404-4338-88c6-b3e136969aca.msgpack.zlib │ │ ├── streaming-tokens_72785b66.msgpack.zlib │ │ ├── async_f544977e-31f7-41f0-88c4-ec9c27b8cecb.msgpack.zlib │ │ ├── branching_932c497e.msgpack.zlib │ │ ├── manage-conversation-history_52468ebb-4b23-45ac-a98e-b4439f37740a.msgpack.zlib │ │ ├── plan-and-execute_7363e528.msgpack.zlib │ │ ├── configuration_718685f7-4cdd-4181-9fc8-e7762d584727.msgpack.zlib │ │ ├── time-travel_9a92d3da-62e2-45a2-8545-e4f6a64e0ffe.msgpack.zlib │ │ ├── LLMCompiler_5bc4584a-e31c-4065-805e-76a6db30676a.msgpack.zlib │ │ ├── create-react-agent_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib │ │ ├── wait-user-input_58eae42d-be32-48da-8d0a-ab64471657d9.msgpack.zlib │ │ ├── sql-agent_1040233f-3751-4bd3-902f-709fc2e1ecf5.msgpack.zlib │ │ ├── persistence_postgres_4faf6087-73cc-4957-9a4f-f3509a32a740.msgpack.zlib │ │ ├── review-tool-calls_2561a38f-edb5-4b44-b2d7-6a7b70d2e6b7.msgpack.zlib │ │ ├── hierarchical_agent_teams_9860fd46-c24d-40a5-a6ba-e8fddcd43369.msgpack.zlib │ │ ├── agent-simulation-evaluation_f58959bf-2ab5-4330-9ac2-c00f45237e24.msgpack.zlib │ │ ├── react-agent-structured-output_9.msgpack.zlib │ │ ├── multi-agent-collaboration_9f478b05-3f09-447f-a9f4-1b2eae73f5ef.msgpack.zlib │ │ ├── introduction_a7debb4a-2a3a-40b9-a48c-7052ec2c2726.msgpack.zlib │ │ ├── edit-graph-state_51923913-20f7-4ee1-b9ba-d01f5fb2869b.msgpack.zlib │ │ ├── information-gather-prompting_25793988-45a2-4e65-b33c-64e72aadb10e.msgpack.zlib │ │ ├── rewoo_56ecb45b-ea76-4303-a4f3-51406fe8312a.msgpack.zlib │ │ ├── async_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib │ │ ├── subgraphs-manage-state_48.msgpack.zlib │ │ ├── reflection_9bbe25dc-fd1e-4ed5-a3c8-fed830b46d12.msgpack.zlib │ │ ├── 
introduction_4527cf9a-b191-4bde-858a-e33a74a48c55.msgpack.zlib │ │ └── pass-config-to-tools_16.msgpack.zlib │ ├── docs │ │ ├── troubleshooting │ │ │ └── errors │ │ │ ├── INVALID_CHAT_HISTORY.md │ │ │ ├── INVALID_GRAPH_NODE_RETURN_VALUE.md │ │ │ ├── INVALID_CONCURRENT_GRAPH_UPDATE.md │ │ │ ├── index.md │ │ │ ├── MULTIPLE_SUBGRAPHS.md │ │ │ └── GRAPH_RECURSION_LIMIT.md │ │ ├── index.md │ │ ├── static │ │ │ ├── wordmark_dark.svg │ │ │ ├── values_vs_updates.png │ │ │ ├── wordmark_light.svg │ │ │ └── favicon.png │ │ ├── cloud │ │ │ ├── sdk │ │ │ │ └── img │ │ │ │ ├── thread_diagram.png │ │ │ │ └── graph_diagram.png │ │ │ ├── deployment │ │ │ │ ├── graph_rebuild.md │ │ │ │ ├── setup_pyproject.md │ │ │ │ ├── custom_docker.md │ │ │ │ ├── setup.md │ │ │ │ ├── semantic_search.md │ │ │ │ ├── setup_javascript.md │ │ │ │ ├── img │ │ │ │ │ ├── graph_run.png │ │ │ │ │ ├── deployed_page.png │ │ │ │ │ ├── quick_start_studio.png │ │ │ │ │ ├── cloud_deployment.png │ │ │ │ │ └── deployment_page.png │ │ │ │ ├── cloud.md │ │ │ │ └── test_locally.md │ │ │ ├── how-tos │ │ │ │ ├── stream_values.md │ │ │ │ ├── same-thread.md │ │ │ │ ├── configuration_cloud.md │ │ │ │ ├── background_run.md │ │ │ │ ├── stream_debug.md │ │ │ │ ├── human_in_the_loop_edit_state.md │ │ │ │ ├── cron_jobs.md │ │ │ │ ├── human_in_the_loop_breakpoint.md │ │ │ │ ├── enqueue_concurrent.md │ │ │ │ ├── threads_studio.md │ │ │ │ ├── check_thread_status.md │ │ │ │ ├── test_deployment.md │ │ │ │ ├── human_in_the_loop_time_travel.md │ │ │ │ ├── test_local_deployment.md │ │ │ │ ├── stream_messages.md │ │ │ │ ├── assistant_versioning.md │ │ │ │ ├── webhooks.md │ │ │ │ ├── invoke_studio.md │ │ │ │ ├── stream_updates.md │ │ │ │ ├── stream_events.md │ │ │ │ ├── img │ │ │ │ │ ├── studio_threads.mp4 │ │ │ │ │ ├── select_different_version.png │ │ │ │ │ ├── studio_input_poster.png │ │ │ │ │ ├── edit_created_assistant.png │ │ │ │ │ ├── click_create_assistant.png │ │ │ │ │ ├── studio_usage_poster.png │ │ │ │ │ ├── 
create_assistant_view.png │ │ │ │ │ ├── studio_input.mp4 │ │ │ │ │ ├── see_new_version.png │ │ │ │ │ ├── studio_forks_poster.png │ │ │ │ │ ├── create_new_version.png │ │ │ │ │ ├── studio_threads_poster.png │ │ │ │ │ ├── see_version_history.png │ │ │ │ │ ├── create_assistant.png │ │ │ │ │ ├── studio_screenshot.png │ │ │ │ │ ├── studio_forks.mp4 │ │ │ │ │ └── studio_usage.mp4 │ │ │ │ ├── human_in_the_loop_review_tool_calls.md │ │ │ │ ├── reject_concurrent.md │ │ │ │ ├── human_in_the_loop_user_input.md │ │ │ │ ├── langgraph_to_langgraph_cloud.ipynb │ │ │ │ ├── stateless_runs.md │ │ │ │ ├── interrupt_concurrent.md │ │ │ │ ├── stream_multiple.md │ │ │ │ ├── rollback_concurrent.md │ │ │ │ └── copy_threads.md │ │ │ ├── reference │ │ │ │ ├── cli.md │ │ │ │ ├── env_var.md │ │ │ │ ├── sdk │ │ │ │ │ └── python_sdk_ref.md │ │ │ │ └── api │ │ │ │ ├── api_ref.html │ │ │ │ ├── api_ref.md │ │ │ │ └── openapi.json │ │ │ └── quick_start.md │ │ ├── concepts │ │ │ ├── langgraph_cli.md │ │ │ ├── plans.md │ │ │ ├── template_applications.md │ │ │ ├── langgraph_platform.md │ │ │ ├── deployment_options.md │ │ │ ├── low_level.md │ │ │ ├── faq.md │ │ │ ├── streaming.md │ │ │ ├── img │ │ │ │ ├── multi_agent │ │ │ │ │ ├── response.png │ │ │ │ │ ├── architectures.png │ │ │ │ │ └── request.png │ │ │ │ ├── assistants.png │ │ │ │ ├── lg_studio.png │ │ │ │ ├── byoc_architecture.png │ │ │ │ ├── human_in_the_loop │ │ │ │ │ ├── replay.png │ │ │ │ │ ├── approval.png │ │ │ │ │ ├── wait_for_input.png │ │ │ │ │ ├── forking.png │ │ │ │ │ └── edit_graph_state.png │ │ │ │ ├── memory │ │ │ │ │ ├── update-profile.png │ │ │ │ │ ├── hot_path_vs_background.png │ │ │ │ │ ├── short-vs-long.png │ │ │ │ │ ├── filter.png │ │ │ │ │ ├── summary.png │ │ │ │ │ ├── update-instructions.png │ │ │ │ │ └── update-list.png │ │ │ │ ├── double_texting.png │ │ │ │ ├── agent_types.png │ │ │ │ ├── langgraph_cloud_architecture.png │ │ │ │ ├── persistence │ │ │ │ │ ├── re_play.jpg │ │ │ │ │ ├── checkpoints.jpg │ │ │ │ │ ├── 
get_state.jpg │ │ │ │ │ ├── checkpoints_full_story.jpg │ │ │ │ │ └── shared_state.png │ │ │ │ ├── langgraph.png │ │ │ │ ├── tool_call.png │ │ │ │ ├── challenge.png │ │ │ │ └── lg_platform.png │ │ │ ├── sdk.md │ │ │ ├── high_level.md │ │ │ ├── bring_your_own_cloud.md │ │ │ ├── langgraph_cloud.md │ │ │ ├── persistence.md │ │ │ ├── double_texting.md │ │ │ ├── memory.md │ │ │ ├── assistants.md │ │ │ ├── index.md │ │ │ ├── langgraph_studio.md │ │ │ ├── langgraph_server.md │ │ │ ├── human_in_the_loop.md │ │ │ ├── application_structure.md │ │ │ ├── self_hosted.md │ │ │ ├── agentic_concepts.md │ │ │ └── multi_agent.md │ │ ├── tutorials │ │ │ ├── multi_agent │ │ │ │ ├── agent_supervisor_docs.md │ │ │ │ ├── agent_supervisor.py │ │ │ │ ├── hierarchical_agent_teams.ipynb │ │ │ │ ├── hierarchical_agent_teams_docs.md │ │ │ │ ├── agent_supervisor.ipynb │ │ │ │ ├── multi-agent-collaboration_docs.md │ │ │ │ ├── multi-agent-collaboration.py │ │ │ │ ├── multi-agent-collaboration.ipynb │ │ │ │ └── hierarchical_agent_teams.py │ │ │ ├── introduction.ipynb │ │ │ ├── customer-support │ │ │ │ ├── img │ │ │ │ │ ├── customer-support-bot-4.png │ │ │ │ │ ├── part-3-diagram.png │ │ │ │ │ ├── part-4-diagram.png │ │ │ │ │ ├── part-1-diagram.png │ │ │ │ │ └── part-2-diagram.png │ │ │ │ └── customer-support.ipynb │ │ │ ├── plan-and-execute │ │ │ │ ├── plan-and-execute.py │ │ │ │ ├── plan-and-execute_docs.md │ │ │ │ └── plan-and-execute.ipynb │ │ │ ├── tnt-llm │ │ │ │ ├── img │ │ │ │ │ └── tnt_llm.png │ │ │ │ └── tnt-llm.ipynb │ │ │ ├── reflexion │ │ │ │ ├── reflexion.ipynb │ │ │ │ ├── reflexion_docs.md │ │ │ │ └── reflexion.py │ │ │ ├── code_assistant │ │ │ │ ├── langgraph_code_assistant_docs.md │ │ │ │ ├── langgraph_code_assistant.ipynb │ │ │ │ └── langgraph_code_assistant.py │ │ │ ├── lats │ │ │ │ └── lats.ipynb │ │ │ ├── rewoo │ │ │ │ └── rewoo.ipynb │ │ │ ├── web-navigation │ │ │ │ ├── img │ │ │ │ │ └── web-voyager.excalidraw.jpg │ │ │ │ ├── web_voyager.ipynb │ │ │ │ └── mark_page.js │ │ │ ├── 
rag │ │ │ │ ├── langgraph_self_rag_local.ipynb │ │ │ │ ├── langgraph_adaptive_rag_local.ipynb │ │ │ │ ├── langgraph_self_rag.ipynb │ │ │ │ ├── langgraph_crag_local.ipynb │ │ │ │ ├── langgraph_crag.ipynb │ │ │ │ ├── langgraph_adaptive_rag.ipynb │ │ │ │ └── langgraph_agentic_rag.ipynb │ │ │ ├── llm-compiler │ │ │ │ ├── LLMCompiler.ipynb │ │ │ │ ├── output_parser.py │ │ │ │ └── math_tools.py │ │ │ ├── tot │ │ │ │ ├── img │ │ │ │ │ └── tot.png │ │ │ │ └── tot.ipynb │ │ │ ├── sql-agent.ipynb │ │ │ ├── usaco │ │ │ │ ├── img │ │ │ │ │ ├── diagram-part-2.png │ │ │ │ │ ├── diagram-part-1.png │ │ │ │ │ ├── benchmark.png │ │ │ │ │ ├── diagram.png │ │ │ │ │ └── usaco.png │ │ │ │ └── usaco.ipynb │ │ │ ├── index.md │ │ │ ├── chatbots │ │ │ │ └── information-gather-prompting.ipynb │ │ │ ├── chatbot-simulation-evaluation │ │ │ │ ├── agent-simulation-evaluation.ipynb │ │ │ │ └── langsmith-agent-simulation-evaluation.ipynb │ │ │ ├── langgraph-platform │ │ │ │ └── local-server.md │ │ │ ├── extraction │ │ │ │ └── retries.ipynb │ │ │ ├── storm │ │ │ │ └── storm.ipynb │ │ │ ├── self-discover │ │ │ │ └── self-discover.ipynb │ │ │ └── reflection │ │ │ ├── reflection.py │ │ │ ├── reflection.ipynb │ │ │ └── reflection_docs.md │ │ ├── how-tos │ │ │ ├── run-id-langsmith.ipynb │ │ │ ├── react-agent-from-scratch.ipynb │ │ │ ├── streaming-from-final-node.ipynb │ │ │ ├── recursion-limit.ipynb │ │ │ ├── human_in_the_loop │ │ │ │ ├── breakpoints.ipynb │ │ │ │ ├── time-travel.ipynb │ │ │ │ ├── edit-graph-state.ipynb │ │ │ │ ├── dynamic_breakpoints.ipynb │ │ │ │ ├── review-tool-calls.ipynb │ │ │ │ └── wait-user-input.ipynb │ │ │ ├── use-remote-graph.md │ │ │ ├── memory │ │ │ │ ├── semantic-search.ipynb │ │ │ │ ├── add-summary-conversation-history.ipynb │ │ │ │ ├── manage-conversation-history.ipynb │ │ │ │ └── delete-messages.ipynb │ │ │ ├── streaming-tokens.ipynb │ │ │ ├── local-studio.md │ │ │ ├── create-react-agent.ipynb │ │ │ ├── configuration.ipynb │ │ │ ├── visualization.ipynb │ │ │ ├── 
branching.ipynb │ │ │ ├── create-react-agent-memory.ipynb │ │ │ ├── many-tools.ipynb │ │ │ ├── streaming-events-from-within-tools.ipynb │ │ │ ├── async.ipynb │ │ │ ├── streaming-content.ipynb │ │ │ ├── subgraphs-manage-state.ipynb │ │ │ ├── persistence.ipynb │ │ │ ├── node-retries.ipynb │ │ │ ├── persistence_mongodb.ipynb │ │ │ ├── create-react-agent-hitl.ipynb │ │ │ ├── state-model.ipynb │ │ │ ├── persistence_postgres.ipynb │ │ │ ├── deploy-self-hosted.md │ │ │ ├── stream-multiple.ipynb │ │ │ ├── index.md │ │ │ ├── pass-run-time-values-to-tools.ipynb │ │ │ ├── stream-updates.ipynb │ │ │ ├── cross-thread-persistence.ipynb │ │ │ ├── return-when-recursion-limit-hits.ipynb │ │ │ ├── tool-calling-errors.ipynb │ │ │ ├── create-react-agent-system-prompt.ipynb │ │ │ ├── streaming-subgraphs.ipynb │ │ │ ├── command.ipynb │ │ │ ├── tool-calling.ipynb │ │ │ ├── persistence_redis.ipynb │ │ │ ├── pass-config-to-tools.ipynb │ │ │ ├── stream-values.ipynb │ │ │ ├── react-agent-structured-output.ipynb │ │ │ ├── autogen-langgraph-platform.ipynb │ │ │ ├── subgraph.ipynb │ │ │ ├── autogen-integration.ipynb │ │ │ ├── react_diagrams.png │ │ │ ├── streaming-tokens-without-langchain.ipynb │ │ │ ├── map-reduce.ipynb │ │ │ ├── subgraph-persistence.ipynb │ │ │ ├── disable-streaming.ipynb │ │ │ ├── pass_private_state.ipynb │ │ │ ├── subgraph-transform-state.ipynb │ │ │ ├── input_output_schema.ipynb │ │ │ └── streaming-events-from-within-tools-without-langchain.ipynb │ │ └── reference │ │ ├── remote_graph.md │ │ ├── errors.md │ │ ├── channels.md │ │ ├── graphs.md │ │ ├── prebuilt.md │ │ ├── index.md │ │ ├── store.md │ │ ├── checkpoints.md │ │ ├── types.md │ │ └── constants.md │ ├── README.md │ └── codespell_notebooks.sh ├── libs │ ├── checkpoint-postgres │ │ ├── langgraph │ │ │ ├── checkpoint │ │ │ │ └── postgres │ │ │ │ ├── _internal.py │ │ │ │ ├── _ainternal.py │ │ │ │ ├── __init__.py │ │ │ │ ├── aio.py │ │ │ │ ├── py.typed │ │ │ │ └── base.py │ │ │ └── store │ │ │ └── postgres │ │ │ ├── 
__init__.py │ │ │ ├── aio.py │ │ │ ├── py.typed │ │ │ └── base.py │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── tests │ │ │ ├── compose-postgres.yml │ │ │ ├── conftest.py │ │ │ ├── test_sync.py │ │ │ ├── test_async.py │ │ │ ├── __init__.py │ │ │ ├── test_store.py │ │ │ ├── embed_test_utils.py │ │ │ └── test_async_store.py │ │ ├── README.md │ │ └── poetry.lock │ ├── langgraph │ │ ├── langgraph │ │ │ ├── pregel │ │ │ │ ├── loop_docs.md │ │ │ │ ├── runner.py │ │ │ │ ├── remote_docs.md │ │ │ │ ├── algo_docs.md │ │ │ │ ├── write.py │ │ │ │ ├── call_docs.md │ │ │ │ ├── log.py │ │ │ │ ├── remote.py │ │ │ │ ├── runner_docs.md │ │ │ │ ├── protocol.py │ │ │ │ ├── io.py │ │ │ │ ├── __init__.py │ │ │ │ ├── io_docs.md │ │ │ │ ├── types.py │ │ │ │ ├── debug_docs.md │ │ │ │ ├── retry.py │ │ │ │ ├── write_docs.md │ │ │ │ ├── protocol_docs.md │ │ │ │ ├── executor_docs.md │ │ │ │ ├── validate.py │ │ │ │ ├── utils_docs.md │ │ │ │ ├── read_docs.md │ │ │ │ ├── utils.py │ │ │ │ ├── debug.py │ │ │ │ ├── manager_docs.md │ │ │ │ ├── messages_docs.md │ │ │ │ ├── messages.py │ │ │ │ ├── loop.py │ │ │ │ ├── retry_docs.md │ │ │ │ ├── log_docs.md │ │ │ │ ├── types_docs.md │ │ │ │ ├── validate_docs.md │ │ │ │ ├── algo.py │ │ │ │ ├── __init___docs.md │ │ │ │ ├── manager.py │ │ │ │ ├── executor.py │ │ │ │ └── read.py │ │ │ ├── managed │ │ │ │ ├── context_docs.md │ │ │ │ ├── is_last_step_docs.md │ │ │ │ ├── __init__.py │ │ │ │ ├── shared_value.py │ │ │ │ ├── context.py │ │ │ │ ├── base_docs.md │ │ │ │ ├── is_last_step.py │ │ │ │ ├── shared_value_docs.md │ │ │ │ ├── __init___docs.md │ │ │ │ └── base.py │ │ │ ├── version_docs.md │ │ │ ├── version.py │ │ │ ├── constants_docs.md │ │ │ ├── constants.py │ │ │ ├── graph │ │ │ │ ├── state_docs.md │ │ │ │ ├── graph_docs.md │ │ │ │ ├── graph.py │ │ │ │ ├── __init__.py │ │ │ │ ├── message.py │ │ │ │ ├── __init___docs.md │ │ │ │ ├── message_docs.md │ │ │ │ └── state.py │ │ │ ├── _api │ │ │ │ ├── deprecation.py │ │ │ │ ├── __init__.py │ │ │ │ ├── 
deprecation_docs.md │ │ │ │ └── __init___docs.md │ │ │ ├── utils │ │ │ │ ├── queue.py │ │ │ │ ├── fields_docs.md │ │ │ │ ├── config.py │ │ │ │ ├── fields.py │ │ │ │ ├── config_docs.md │ │ │ │ ├── __init__.py │ │ │ │ ├── pydantic.py │ │ │ │ ├── future_docs.md │ │ │ │ ├── runnable_docs.md │ │ │ │ ├── pydantic_docs.md │ │ │ │ ├── runnable.py │ │ │ │ ├── __init___docs.md │ │ │ │ └── queue_docs.md │ │ │ ├── types.py │ │ │ ├── prebuilt │ │ │ │ ├── chat_agent_executor_docs.md │ │ │ │ ├── __init__.py │ │ │ │ ├── tool_validator.py │ │ │ │ ├── tool_executor_docs.md │ │ │ │ ├── chat_agent_executor.py │ │ │ │ ├── tool_validator_docs.md │ │ │ │ ├── tool_node_docs.md │ │ │ │ ├── tool_node.py │ │ │ │ ├── __init___docs.md │ │ │ │ └── tool_executor.py │ │ │ ├── errors_docs.md │ │ │ ├── py.typed │ │ │ ├── types_docs.md │ │ │ ├── errors.py │ │ │ ├── channels │ │ │ │ ├── context_docs.md │ │ │ │ ├── untracked_value_docs.md │ │ │ │ ├── dynamic_barrier_value_docs.md │ │ │ │ ├── last_value.py │ │ │ │ ├── binop_docs.md │ │ │ │ ├── named_barrier_value_docs.md │ │ │ │ ├── topic_docs.md │ │ │ │ ├── __init__.py │ │ │ │ ├── untracked_value.py │ │ │ │ ├── last_value_docs.md │ │ │ │ ├── any_value.py │ │ │ │ ├── named_barrier_value.py │ │ │ │ ├── binop.py │ │ │ │ ├── ephemeral_value.py │ │ │ │ ├── dynamic_barrier_value.py │ │ │ │ ├── context.py │ │ │ │ ├── topic.py │ │ │ │ ├── base_docs.md │ │ │ │ ├── ephemeral_value_docs.md │ │ │ │ ├── __init___docs.md │ │ │ │ ├── any_value_docs.md │ │ │ │ └── base.py │ │ │ └── func │ │ │ └── __init___docs.md │ │ ├── bench │ │ │ ├── fanout_to_subgraph.py │ │ │ ├── react_agent.py │ │ │ ├── __init__.py │ │ │ ├── wide_state.py │ │ │ └── __main__.py │ │ ├── LICENSE │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── tests │ │ │ ├── fake_tracer_docs.md │ │ │ ├── compose-postgres.yml │ │ │ ├── fake_chat_docs.md │ │ │ ├── any_int_docs.md │ │ │ ├── test_pregel_docs.md │ │ │ ├── test_utils.py │ │ │ ├── conftest.py │ │ │ ├── test_interruption_docs.md │ │ │ ├── 
test_remote_graph.py │ │ │ ├── fake_chat.py │ │ │ ├── test_channels_docs.md │ │ │ ├── test_large_cases_docs.md │ │ │ ├── conftest_docs.md │ │ │ ├── any_int.py │ │ │ ├── test_remote_graph_docs.md │ │ │ ├── test_interruption.py │ │ │ ├── test_utils_docs.md │ │ │ ├── test_io_docs.md │ │ │ ├── memory_assert.py │ │ │ ├── test_prebuilt_docs.md │ │ │ ├── __init__.py │ │ │ ├── test_pregel_async.py │ │ │ ├── test_prebuilt.py │ │ │ ├── test_io.py │ │ │ ├── test_state.py │ │ │ ├── test_messages_state.py │ │ │ ├── test_channels.py │ │ │ ├── test_runnable.py │ │ │ ├── test_algo_docs.md │ │ │ ├── agents_docs.md │ │ │ ├── test_pregel_async_docs.md │ │ │ ├── test_tracing_interops.py │ │ │ ├── memory_assert_docs.md │ │ │ ├── messages_docs.md │ │ │ ├── test_messages_state_docs.md │ │ │ ├── messages.py │ │ │ ├── test_pregel.py │ │ │ ├── any_str_docs.md │ │ │ ├── __init___docs.md │ │ │ ├── any_str.py │ │ │ ├── test_algo.py │ │ │ ├── test_tracing_interops_docs.md │ │ │ ├── test_large_cases_async_docs.md │ │ │ ├── fake_tracer.py │ │ │ ├── test_runnable_docs.md │ │ │ ├── __snapshots__ │ │ │ │ ├── test_pregel.ambr │ │ │ │ └── test_pregel_async.ambr │ │ │ └── test_state_docs.md │ │ ├── README.md │ │ ├── poetry.toml │ │ └── poetry.lock │ ├── checkpoint │ │ ├── langgraph │ │ │ ├── checkpoint │ │ │ │ ├── memory │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __init__.py.md │ │ │ │ │ └── py.typed │ │ │ │ ├── serde │ │ │ │ │ ├── jsonplus.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── types.py │ │ │ │ │ ├── py.typed │ │ │ │ │ └── base.py │ │ │ │ └── base │ │ │ │ ├── id.py │ │ │ │ ├── __init__.py │ │ │ │ ├── __init__.py.md │ │ │ │ └── py.typed │ │ │ └── store │ │ │ ├── memory │ │ │ │ ├── __init__.py │ │ │ │ ├── __init__.py.md │ │ │ │ └── py.typed │ │ │ └── base │ │ │ ├── batch.py.md │ │ │ ├── embed.py │ │ │ ├── batch.py │ │ │ ├── __init__.py │ │ │ ├── __init__.py.md │ │ │ ├── py.typed │ │ │ └── embed.py.md │ │ ├── LICENSE │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── 
test_jsonplus.py │ │ │ ├── test_store.py │ │ │ ├── test_memory.py │ │ │ ├── test_memory.py.md │ │ │ └── embed_test_utils.py │ │ ├── README.md │ │ └── poetry.lock │ ├── checkpoint-duckdb │ │ ├── langgraph │ │ │ ├── checkpoint │ │ │ │ └── duckdb │ │ │ │ ├── __init__.py │ │ │ │ ├── aio.py │ │ │ │ ├── py.typed │ │ │ │ └── base.py │ │ │ └── store │ │ │ └── duckdb │ │ │ ├── __init__.py │ │ │ ├── aio.py │ │ │ ├── py.typed │ │ │ └── base.py │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── tests │ │ │ ├── test_sync.py │ │ │ ├── test_async.py │ │ │ ├── test_store.py │ │ │ └── test_async_store.py │ │ ├── README.md │ │ └── poetry.lock │ ├── scheduler-kafka │ │ ├── langgraph │ │ │ └── scheduler │ │ │ └── kafka │ │ │ ├── default_sync.py │ │ │ ├── default_async.py │ │ │ ├── serde.py │ │ │ ├── __init__.py │ │ │ ├── types.py │ │ │ ├── retry.py │ │ │ ├── orchestrator.py │ │ │ ├── py.typed │ │ │ └── executor.py │ │ ├── LICENSE │ │ ├── langgraph-distributed.png │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── tests │ │ │ ├── conftest.py │ │ │ ├── test_fanout_sync.py │ │ │ ├── __init__.py │ │ │ ├── test_fanout.py │ │ │ ├── test_subgraph.py │ │ │ ├── messages.py │ │ │ ├── test_push.py │ │ │ ├── test_subgraph_sync.py │ │ │ ├── any.py │ │ │ ├── drain.py │ │ │ ├── test_push_sync.py │ │ │ └── compose.yml │ │ ├── README.md │ │ └── poetry.lock │ ├── sdk-js │ │ ├── LICENSE │ │ ├── langchain.config.js │ │ ├── README.md │ │ ├── yarn.lock │ │ ├── package.json │ │ ├── tsconfig.cjs.json │ │ ├── tsconfig.json │ │ └── src │ │ ├── schema.ts │ │ ├── utils │ │ │ ├── async_caller.ts │ │ │ ├── eventsource-parser │ │ │ │ ├── LICENSE │ │ │ │ ├── parse.ts │ │ │ │ ├── stream.ts │ │ │ │ ├── types.ts │ │ │ │ └── index.ts │ │ │ ├── stream.ts │ │ │ ├── env.ts │ │ │ └── signals.ts │ │ ├── types.ts │ │ ├── client.ts │ │ └── index.ts │ ├── sdk-py │ │ ├── LICENSE │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── README.md │ │ ├── langgraph_sdk │ │ │ ├── client.py │ │ │ ├── __init__.py │ │ │ ├── sse.py │ │ │ ├── py.typed 
│ │ │ └── schema.py │ │ └── poetry.lock │ ├── cli │ │ ├── LICENSE │ │ ├── langgraph_cli │ │ │ ├── config.py │ │ │ ├── version.py │ │ │ ├── util.py │ │ │ ├── exec.py │ │ │ ├── constants.py │ │ │ ├── __init__.py │ │ │ ├── docker.py │ │ │ ├── templates.py │ │ │ ├── cli.py │ │ │ ├── py.typed │ │ │ ├── analytics.py │ │ │ └── progress.py │ │ ├── Makefile │ │ ├── pyproject.toml │ │ ├── js-examples │ │ │ ├── LICENSE │ │ │ ├── jest.config.js │ │ │ ├── tests │ │ │ │ ├── graph.int.test.ts │ │ │ │ └── agent.test.ts │ │ │ ├── README.md │ │ │ ├── yarn.lock │ │ │ ├── langgraph.json │ │ │ ├── package.json │ │ │ ├── static │ │ │ │ └── studio.png │ │ │ ├── tsconfig.json │ │ │ └── src │ │ │ └── agent │ │ │ ├── state.ts │ │ │ └── graph.ts │ │ ├── tests │ │ │ ├── unit_tests │ │ │ │ ├── test_config.json │ │ │ │ ├── conftest.py │ │ │ │ ├── __init__.py │ │ │ │ ├── cli │ │ │ │ │ ├── test_templates.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── test_cli.py │ │ │ │ ├── pipconfig.txt │ │ │ │ ├── agent.py │ │ │ │ ├── test_config.py │ │ │ │ ├── graphs │ │ │ │ │ └── agent.py │ │ │ │ ├── helpers.py │ │ │ │ └── test_docker.py │ │ │ ├── integration_tests │ │ │ │ ├── __init__.py │ │ │ │ └── test_cli.py │ │ │ └── __init__.py │ │ ├── README.md │ │ ├── examples │ │ │ ├── Makefile │ │ │ ├── pipconf.txt │ │ │ ├── pyproject.toml │ │ │ ├── graphs_reqs_b │ │ │ │ ├── hello.py │ │ │ │ ├── requirements.txt │ │ │ │ ├── graphs_submod │ │ │ │ │ ├── subprompt.txt │ │ │ │ │ └── agent.py │ │ │ │ ├── prompt.txt │ │ │ │ ├── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── greeter.py │ │ │ │ └── langgraph.json │ │ │ ├── langgraph.json │ │ │ ├── graphs │ │ │ │ ├── langgraph.json │ │ │ │ ├── storm.py │ │ │ │ └── agent.py │ │ │ ├── poetry.lock │ │ │ └── graphs_reqs_a │ │ │ ├── hello.py │ │ │ ├── requirements.txt │ │ │ ├── graphs_submod │ │ │ │ ├── __init__.py │ │ │ │ ├── subprompt.txt │ │ │ │ └── agent.py │ │ │ ├── __init__.py │ │ │ ├── prompt.txt │ │ │ └── langgraph.json │ │ └── poetry.lock │ └── checkpoint-sqlite │ ├── 
langgraph │ │ └── checkpoint │ │ └── sqlite │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── aio.py │ │ └── py.typed │ ├── Makefile │ ├── pyproject.toml │ ├── tests │ │ ├── __init__.py │ │ ├── test_aiosqlite.py │ │ └── test_sqlite.py │ ├── README.md │ └── poetry.lock ├── README.md ├── CONTRIBUTING.md ├── examples │ ├── multi_agent │ │ ├── hierarchical_agent_teams.ipynb │ │ ├── agent_supervisor.ipynb │ │ └── multi-agent-collaboration.ipynb │ ├── introduction.ipynb │ ├── customer-support │ │ └── customer-support.ipynb │ ├── plan-and-execute │ │ └── plan-and-execute.ipynb │ ├── run-id-langsmith.ipynb │ ├── react-agent-from-scratch.ipynb │ ├── streaming-from-final-node.ipynb │ ├── recursion-limit.ipynb │ ├── human_in_the_loop │ │ ├── breakpoints.ipynb │ │ ├── time-travel.ipynb │ │ ├── edit-graph-state.ipynb │ │ ├── dynamic_breakpoints.ipynb │ │ ├── review-tool-calls.ipynb │ │ └── wait-user-input.ipynb │ ├── memory │ │ ├── add-summary-conversation-history.ipynb │ │ ├── manage-conversation-history.ipynb │ │ └── delete-messages.ipynb │ ├── streaming-tokens.ipynb │ ├── create-react-agent.ipynb │ ├── reflexion │ │ └── reflexion.ipynb │ ├── configuration.ipynb │ ├── code_assistant │ │ ├── langgraph_code_assistant.ipynb │ │ └── langgraph_code_assistant_mistral.ipynb │ ├── lats │ │ └── lats.ipynb │ ├── visualization.ipynb │ ├── rewoo │ │ └── rewoo.ipynb │ ├── branching.ipynb │ ├── create-react-agent-memory.ipynb │ ├── streaming-events-from-within-tools.ipynb │ ├── async.ipynb │ ├── web-navigation │ │ └── web_voyager.ipynb │ ├── cloud_examples │ │ └── langgraph_to_langgraph_cloud.ipynb │ ├── rag │ │ ├── langgraph_self_rag_pinecone_movies.ipynb │ │ ├── langgraph_self_rag_local.ipynb │ │ ├── langgraph_adaptive_rag_local.ipynb │ │ ├── langgraph_self_rag.ipynb │ │ ├── langgraph_crag_local.ipynb │ │ ├── langgraph_crag.ipynb │ │ ├── langgraph_adaptive_rag.ipynb │ │ ├── langgraph_agentic_rag.ipynb │ │ └── langgraph_adaptive_rag_cohere.ipynb │ ├── llm-compiler │ │ └── LLMCompiler.ipynb │ 
├── streaming-content.ipynb │ ├── README.md │ ├── subgraphs-manage-state.ipynb │ ├── persistence.ipynb │ ├── node-retries.ipynb │ ├── persistence_mongodb.ipynb │ ├── create-react-agent-hitl.ipynb │ ├── state-model.ipynb │ ├── persistence_postgres.ipynb │ ├── usaco │ │ └── usaco.ipynb │ ├── stream-multiple.ipynb │ ├── pass-run-time-values-to-tools.ipynb │ ├── stream-updates.ipynb │ ├── chatbots │ │ └── information-gather-prompting.ipynb │ ├── tool-calling-errors.ipynb │ ├── create-react-agent-system-prompt.ipynb │ ├── streaming-subgraphs.ipynb │ ├── tutorials │ │ ├── tnt-llm │ │ │ └── tnt-llm.ipynb │ │ └── sql-agent.ipynb │ ├── tool-calling.ipynb │ ├── persistence_redis.ipynb │ ├── pass-config-to-tools.ipynb │ ├── stream-values.ipynb │ ├── chatbot-simulation-evaluation │ │ ├── agent-simulation-evaluation.ipynb │ │ ├── langsmith-agent-simulation-evaluation.ipynb │ │ └── simulation_utils.py │ ├── react-agent-structured-output.ipynb │ ├── subgraph.ipynb │ ├── extraction │ │ └── retries.ipynb │ ├── streaming-tokens-without-langchain.ipynb │ ├── map-reduce.ipynb │ ├── storm │ │ └── storm.ipynb │ ├── pass_private_state.ipynb │ ├── self-discover │ │ └── self-discover.ipynb │ ├── subgraph-transform-state.ipynb │ ├── reflection │ │ └── reflection.ipynb │ ├── input_output_schema.ipynb │ └── streaming-events-from-within-tools-without-langchain.ipynb ├── poetry.lock └── security.md ``` `/Users/malcolm/dev/langchain-ai/langgraph/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-monorepo" version = "0.0.1" description = "LangGraph monorepo" authors = [] license = "MIT" readme = "README.md" [tool.poetry.dependencies] python = "^3.10" aiohappyeyeballs = "2.4.3" [tool.poetry.group.docs.dependencies] langgraph = { path = "libs/langgraph/", develop = true } langgraph-checkpoint = { path = "libs/checkpoint/", develop = true } langgraph-checkpoint-sqlite = { path = "libs/checkpoint-sqlite", develop = true } langgraph-checkpoint-postgres = { path = "libs/checkpoint-postgres", 
develop = true } langgraph-sdk = {path = "libs/sdk-py", develop = true} mkdocs = "^1.6.0" mkdocs-autorefs = ">=1.0.1,<1.1.0" mkdocstrings = "^0.25.1" mkdocstrings-python = "^1.10.4" mkdocs-redirects = "^1.2.1" mkdocs-minify-plugin = "^0.8.0" mkdocs-rss-plugin = "^1.13.1" mkdocs-git-committers-plugin-2 = "^2.3.0" mkdocs-material = {extras = ["imaging"], version = "^9.5.27"} markdown-include = "^0.8.1" markdown-callouts = "^0.4.0" mkdocs-exclude = "^1.0.2" vcrpy = "^6.0.1" click = "^8.1.7" ruff = "^0.6.8" jupyter = "^1.1.1" [tool.poetry.group.test.dependencies] langchain = "^0.3.8" langchain-openai = "^0.2.0" langchain-anthropic = "^0.2.1" langchain-nomic = "^0.1.3" langchain-fireworks = "^0.2.0" langchain-community = "^0.3.0" langchain-experimental = "^0.3.2" langsmith = "^0.1.129" chromadb = "^0.5.5" gpt4all = "^2.8.2" scikit-learn = "^1.5.2" numexpr = "^2.10.1" numpy = "^1.26.4" matplotlib = "^3.9.2" redis = "^5.0.8" pymongo = "^4.8.0" motor = "^3.5.1" grandalf = "^0.8" pyppeteer = "^2.0.0" networkx = "^3.3" autogen = { version = "^0.3.0", python = "<3.13,>=3.8" } [tool.poetry.group.test] optional = true [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] extend-include = ["*.ipynb"] [tool.ruff.lint.per-file-ignores] "docs/*" = [ "E402", # allow imports to appear anywhere in docs "F401", # allow "imported but unused" example code "F811", # allow re-importing the same module, so that cells can stay independent "F841", # allow assignments to variables that are never read -- it's example code # The issues below should be cleaned up when there's time "E722", # allow bare excepts in notebooks ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/generate_api_reference_links.py`: ```py import importlib import inspect import logging import os import re from typing import List, Literal, Optional from typing_extensions import TypedDict import nbformat from nbconvert.preprocessors import Preprocessor
logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Base URL for all class documentation _LANGCHAIN_API_REFERENCE = "https://python.langchain.com/api_reference/" _LANGGRAPH_API_REFERENCE = "https://langchain-ai.github.io/langgraph/reference/" # (alias/re-exported modules, source module, class, docs namespace) MANUAL_API_REFERENCES_LANGGRAPH = [ ( ["langgraph.prebuilt"], "langgraph.prebuilt.chat_agent_executor", "create_react_agent", "prebuilt", ), (["langgraph.prebuilt"], "langgraph.prebuilt.tool_node", "ToolNode", "prebuilt"), ( ["langgraph.prebuilt"], "langgraph.prebuilt.tool_node", "tools_condition", "prebuilt", ), ( ["langgraph.prebuilt"], "langgraph.prebuilt.tool_node", "InjectedState", "prebuilt", ), # Graph (["langgraph.graph"], "langgraph.graph.message", "add_messages", "graphs"), (["langgraph.graph"], "langgraph.graph.state", "StateGraph", "graphs"), (["langgraph.graph"], "langgraph.graph.state", "CompiledStateGraph", "graphs"), ([], "langgraph.types", "StreamMode", "types"), (["langgraph.graph"], "langgraph.constants", "START", "constants"), (["langgraph.graph"], "langgraph.constants", "END", "constants"), (["langgraph.constants"], "langgraph.types", "Send", "types"), (["langgraph.constants"], "langgraph.types", "Interrupt", "types"), ([], "langgraph.types", "RetryPolicy", "types"), ([], "langgraph.checkpoint.base", "Checkpoint", "checkpoints"), ([], "langgraph.checkpoint.base", "CheckpointMetadata", "checkpoints"), ([], "langgraph.checkpoint.base", "BaseCheckpointSaver", "checkpoints"), ([], "langgraph.checkpoint.base", "SerializerProtocol", "checkpoints"), ([], "langgraph.checkpoint.serde.jsonplus", "JsonPlusSerializer", "checkpoints"), ([], "langgraph.checkpoint.memory", "MemorySaver", "checkpoints"), ([], "langgraph.checkpoint.sqlite.aio", "AsyncSqliteSaver", "checkpoints"), ([], "langgraph.checkpoint.sqlite", "SqliteSaver", "checkpoints"), ([], "langgraph.checkpoint.postgres.aio", "AsyncPostgresSaver", "checkpoints"), ([], 
"langgraph.checkpoint.postgres", "PostgresSaver", "checkpoints"), ] WELL_KNOWN_LANGGRAPH_OBJECTS = { (module_, class_): (source_module, namespace) for (modules, source_module, class_, namespace) in MANUAL_API_REFERENCES_LANGGRAPH for module_ in modules + [source_module] } def _make_regular_expression(pkg_prefix: str) -> re.Pattern: if not pkg_prefix.isidentifier(): raise ValueError(f"Invalid package prefix: {pkg_prefix}") return re.compile( r"from\s+(" + pkg_prefix + "(?:_\w+)?(?:\.\w+)*?)\s+import\s+" r"((?:\w+(?:,\s*)?)*" # Match zero or more words separated by a comma+optional ws r"(?:\s*\(.*?\))?)", # Match optional parentheses block re.DOTALL, # Match newlines as well ) # Regular expression to match langchain import lines _IMPORT_LANGCHAIN_RE = _make_regular_expression("langchain") _IMPORT_LANGGRAPH_RE = _make_regular_expression("langgraph") def _get_full_module_name(module_path, class_name) -> Optional[str]: """Get full module name using inspect""" try: module = importlib.import_module(module_path) class_ = getattr(module, class_name) module = inspect.getmodule(class_) if module is None: # For constants, inspect.getmodule() might return None # In this case, we'll return the original module_path return module_path return module.__name__ except AttributeError as e: logger.warning(f"Could not find module for {class_name}, {e}") return None except ImportError as e: logger.warning(f"Failed to load for class {class_name}, {e}") return None def _get_doc_title(data: str, file_name: str) -> str: try: return re.findall(r"^#\s*(.*)", data, re.MULTILINE)[0] except IndexError: pass # Parse the rst-style titles try: return re.findall(r"^(.*)\n=+\n", data, re.MULTILINE)[0] except IndexError: return file_name class ImportInformation(TypedDict): imported: str # imported class name source: str # module path docs: str # URL to the documentation title: str # Title of the document def _get_imports( code: str, doc_title: str, package_ecosystem: Literal["langchain", "langgraph"] ) 
-> List[ImportInformation]: """Get imports from the given code block. Args: code: Python code block from which to extract imports doc_title: Title of the document package_ecosystem: "langchain" or "langgraph". The two live in different repositories and have separate documentation sites. Returns: List of import information for the given code block """ imports = [] if package_ecosystem == "langchain": pattern = _IMPORT_LANGCHAIN_RE elif package_ecosystem == "langgraph": pattern = _IMPORT_LANGGRAPH_RE else: raise ValueError(f"Invalid package ecosystem: {package_ecosystem}") for import_match in pattern.finditer(code): module = import_match.group(1) if "pydantic_v1" in module: continue imports_str = ( import_match.group(2).replace("(\n", "").replace("\n)", "") ) # Handle newlines within parentheses # remove any newline and spaces, then split by comma imported_classes = [ imp.strip() for imp in re.split(r",\s*", imports_str.replace("\n", "")) if imp.strip() ] for class_name in imported_classes: module_path = _get_full_module_name(module, class_name) if not module_path: continue if len(module_path.split(".")) < 2: continue if package_ecosystem == "langchain": pkg = module_path.split(".")[0].replace("langchain_", "") top_level_mod = module_path.split(".")[1] url = ( _LANGCHAIN_API_REFERENCE + pkg + "/" + top_level_mod + "/" + module_path + "." + class_name + ".html" ) elif package_ecosystem == "langgraph": if (module, class_name) not in WELL_KNOWN_LANGGRAPH_OBJECTS: # Likely not documented yet continue source_module, namespace = WELL_KNOWN_LANGGRAPH_OBJECTS[ (module, class_name) ] url = ( _LANGGRAPH_API_REFERENCE + namespace + "/#" + source_module + "." 
+ class_name ) else: raise ValueError(f"Invalid package ecosystem: {package_ecosystem}") # Add the import information to our list imports.append( { "imported": class_name, "source": module, "docs": url, "title": doc_title, } ) return imports class ImportPreprocessor(Preprocessor): """A preprocessor to replace imports in each Python code cell with links to their documentation and append the import info in a comment.""" def preprocess(self, nb, resources): self.all_imports = [] file_name = os.path.basename(resources.get("metadata", {}).get("name", "")) _DOC_TITLE = _get_doc_title(nb.cells[0].source, file_name) cells = [] for cell in nb.cells: if cell.cell_type == "code": cells.append(cell) imports = _get_imports( cell.source, _DOC_TITLE, "langchain" ) + _get_imports(cell.source, _DOC_TITLE, "langgraph") if not imports: continue cells.append( nbformat.v4.new_markdown_cell( source=f""" <div> <b>API Reference:</b> {' | '.join(f'<a href="{imp["docs"]}">{imp["imported"]}</a>' for imp in imports)} </div> """ ) ) else: cells.append(cell) nb.cells = cells return nb, resources ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/prepare_notebooks_for_ci.py`: ```py """Preprocess notebooks for CI. 
Currently adds VCR cassettes and optionally removes pip install cells.""" import logging import os import json import click import nbformat logger = logging.getLogger(__name__) NOTEBOOK_DIRS = ("docs/docs/how-tos","docs/docs/tutorials") DOCS_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) CASSETTES_PATH = os.path.join(DOCS_PATH, "cassettes") BLOCKLIST_COMMANDS = ( # skip if has WebBaseLoader to avoid caching web pages "WebBaseLoader", # skip if has draw_mermaid_png to avoid generating mermaid images via API "draw_mermaid_png", ) NOTEBOOKS_NO_CASSETTES = ( "docs/docs/how-tos/visualization.ipynb", "docs/docs/how-tos/many-tools.ipynb" ) NOTEBOOKS_NO_EXECUTION = [ # this uses a user provided project name for langsmith "docs/docs/tutorials/tnt-llm/tnt-llm.ipynb", # this uses langsmith datasets "docs/docs/tutorials/chatbot-simulation-evaluation/langsmith-agent-simulation-evaluation.ipynb", # this uses browser APIs "docs/docs/tutorials/web-navigation/web_voyager.ipynb", # these RAG guides use an ollama model "docs/docs/tutorials/rag/langgraph_adaptive_rag_local.ipynb", "docs/docs/tutorials/rag/langgraph_crag_local.ipynb", "docs/docs/tutorials/rag/langgraph_self_rag_local.ipynb", # this loads a massive dataset from gcp "docs/docs/tutorials/usaco/usaco.ipynb", # TODO: figure out why autogen notebook is not runnable (they are just hanging. possible due to code execution?) 
"docs/docs/how-tos/autogen-integration.ipynb", # TODO: need to update these notebooks to make sure they are runnable in CI "docs/docs/tutorials/storm/storm.ipynb", # issues only when running with VCR "docs/docs/tutorials/lats/lats.ipynb", # issues only when running with VCR "docs/docs/tutorials/rag/langgraph_crag.ipynb", # flakiness from tavily "docs/docs/tutorials/rag/langgraph_adaptive_rag.ipynb", # Cannot create a consistent method resolution error from VCR "docs/docs/how-tos/map-reduce.ipynb" # flakiness from structured output, only when running with VCR ] def comment_install_cells(notebook: nbformat.NotebookNode) -> nbformat.NotebookNode: for cell in notebook.cells: if cell.cell_type != "code": continue if "pip install" in cell.source: # Comment out the lines in cells containing "pip install" cell.source = "\n".join( f"# {line}" if line.strip() else line for line in cell.source.splitlines() ) return notebook def is_magic_command(code: str) -> bool: return code.strip().startswith("%") or code.strip().startswith("!") def is_comment(code: str) -> bool: return code.strip().startswith("#") def has_blocklisted_command(code: str, metadata: dict) -> bool: if 'hide_from_vcr' in metadata: return True code = code.strip() for blocklisted_pattern in BLOCKLIST_COMMANDS: if blocklisted_pattern in code: return True return False def add_vcr_to_notebook( notebook: nbformat.NotebookNode, cassette_prefix: str ) -> nbformat.NotebookNode: """Inject `with vcr.cassette` into each code cell of the notebook.""" # Inject VCR context manager into each code cell for idx, cell in enumerate(notebook.cells): if cell.cell_type != "code": continue lines = cell.source.splitlines() # skip if empty cell if not lines: continue are_magic_lines = [is_magic_command(line) for line in lines] # skip if all magic if all(are_magic_lines): continue if any(are_magic_lines): raise ValueError( "Cannot process code cells with mixed magic and non-magic code." 
) # skip if just comments if all(is_comment(line) or not line.strip() for line in lines): continue if has_blocklisted_command(cell.source, cell.metadata): continue cell_id = cell.get("id", idx) cassette_name = f"{cassette_prefix}_{cell_id}.msgpack.zlib" cell.source = f"with custom_vcr.use_cassette('{cassette_name}', filter_headers=['x-api-key', 'authorization'], record_mode='once', serializer='advanced_compressed'):\n" + "\n".join( f" {line}" for line in lines ) # Add import statement vcr_import_lines = [ "import nest_asyncio", "nest_asyncio.apply()", "import vcr", "import msgpack", "import base64", "import zlib", "import os", "os.environ.pop(\"LANGCHAIN_TRACING_V2\", None)", "custom_vcr = vcr.VCR()", "", "def compress_data(data, compression_level=9):", " packed = msgpack.packb(data, use_bin_type=True)", " compressed = zlib.compress(packed, level=compression_level)", " return base64.b64encode(compressed).decode('utf-8')", "", "def decompress_data(compressed_string):", " decoded = base64.b64decode(compressed_string)", " decompressed = zlib.decompress(decoded)", " return msgpack.unpackb(decompressed, raw=False)", "", "class AdvancedCompressedSerializer:", " def serialize(self, cassette_dict):", " return compress_data(cassette_dict)", "", " def deserialize(self, cassette_string):", " return decompress_data(cassette_string)", "", "custom_vcr.register_serializer('advanced_compressed', AdvancedCompressedSerializer())", "custom_vcr.serializer = 'advanced_compressed'", ] import_cell = nbformat.v4.new_code_cell(source="\n".join(vcr_import_lines)) import_cell.pop("id", None) notebook.cells.insert(0, import_cell) return notebook def process_notebooks(should_comment_install_cells: bool) -> None: for directory in NOTEBOOK_DIRS: for root, _, files in os.walk(directory): for file in files: if not file.endswith(".ipynb") or "ipynb_checkpoints" in root: continue notebook_path = os.path.join(root, file) try: notebook = nbformat.read(notebook_path, as_version=4) if 
should_comment_install_cells: notebook = comment_install_cells(notebook) base_filename = os.path.splitext(os.path.basename(file))[0] cassette_prefix = os.path.join(CASSETTES_PATH, base_filename) if notebook_path not in NOTEBOOKS_NO_CASSETTES: notebook = add_vcr_to_notebook( notebook, cassette_prefix=cassette_prefix ) if notebook_path in NOTEBOOKS_NO_EXECUTION: # Add a cell at the beginning to indicate that this notebook should not be executed warning_cell = nbformat.v4.new_markdown_cell( source="**Warning:** This notebook is not meant to be executed automatically." ) notebook.cells.insert(0, warning_cell) # Add a special tag to the first code cell if notebook.cells and notebook.cells[1].cell_type == "code": notebook.cells[1].metadata["tags"] = notebook.cells[1].metadata.get("tags", []) + ["no_execution"] nbformat.write(notebook, notebook_path) logger.info(f"Processed: {notebook_path}") except Exception as e: logger.error(f"Error processing {notebook_path}: {e}") with open(os.path.join(DOCS_PATH, "notebooks_no_execution.json"), "w") as f: json.dump(NOTEBOOKS_NO_EXECUTION, f) @click.command() @click.option( "--comment-install-cells", is_flag=True, default=False, help="Whether to comment out install cells", ) def main(comment_install_cells): process_notebooks(should_comment_install_cells=comment_install_cells) logger.info("All notebooks processed successfully.") if __name__ == "__main__": main() ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/notebook_convert.py`: ```py import os import re from pathlib import Path import nbformat from nbconvert.exporters import MarkdownExporter from nbconvert.preprocessors import Preprocessor from generate_api_reference_links import ImportPreprocessor class EscapePreprocessor(Preprocessor): def preprocess_cell(self, cell, resources, cell_index): if cell.cell_type == "markdown": # rewrite markdown links to html links (excluding image links) cell.source = re.sub( r"(?<!!)\[([^\]]*)\]\((?![^\)]*//)([^)]*)(?:\.ipynb)?\)", 
r'<a href="\2">\1</a>', cell.source, ) # Fix image paths in <img> tags cell.source = re.sub( r'<img\s+src="\.?/img/([^"]+)"', r'<img src="../img/\1"', cell.source ) elif cell.cell_type == "code": # escape ``` in code cell.source = cell.source.replace("```", r"\`\`\`") # escape ``` in output if "outputs" in cell: filter_out = set() for i, output in enumerate(cell["outputs"]): if "text" in output: if not output["text"].strip(): filter_out.add(i) continue value = output["text"].replace("```", r"\`\`\`") # handle a funky case w/ references in text value = re.sub(r"\[(\d+)\](?=\[(\d+)\])", r"[\1]\\", value) output["text"] = value elif "data" in output: for key, value in output["data"].items(): if isinstance(value, str): value = value.replace("```", r"\`\`\`") # handle a funky case w/ references in text output["data"][key] = re.sub( r"\[(\d+)\](?=\[(\d+)\])", r"[\1]\\", value ) cell["outputs"] = [ output for i, output in enumerate(cell["outputs"]) if i not in filter_out ] return cell, resources class ExtractAttachmentsPreprocessor(Preprocessor): """ Extracts all of the outputs from the notebook file. The extracted outputs are returned in the 'resources' dictionary. """ def preprocess_cell(self, cell, resources, cell_index): """ Apply a transformation on each cell, Parameters ---------- cell : NotebookNode cell Notebook cell being processed resources : dictionary Additional resources used in the conversion process. Allows preprocessors to pass variables into the Jinja engine. cell_index : int Index of the cell being processed (see base.py) """ # Get files directory if it has been specified # Make sure outputs key exists if not isinstance(resources["outputs"], dict): resources["outputs"] = {} # Loop through all of the attachments in the cell for name, attach in cell.get("attachments", {}).items(): for mime, data in attach.items(): if mime not in { "image/png", "image/jpeg", "image/svg+xml", "application/pdf", }: continue # attachments are pre-rendered. 
Only replace markdown-formatted # images with the following logic attach_str = f"({name})" if attach_str in cell.source: data = f"(data:{mime};base64,{data})" cell.source = cell.source.replace(attach_str, data) return cell, resources exporter = MarkdownExporter( preprocessors=[ EscapePreprocessor, ExtractAttachmentsPreprocessor, ImportPreprocessor, ], template_name="mdoutput", extra_template_basedirs=[ os.path.join(os.path.dirname(__file__), "notebook_convert_templates") ], ) def convert_notebook( notebook_path: Path, ) -> Path: with open(notebook_path) as f: nb = nbformat.read(f, as_version=4) body, _ = exporter.from_notebook_node(nb) return body ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/notebook_hooks.py`: ```py import logging from typing import Any, Dict from mkdocs.structure.pages import Page from mkdocs.structure.files import Files, File from notebook_convert import convert_notebook logger = logging.getLogger(__name__) logging.basicConfig() logger.setLevel(logging.INFO) class NotebookFile(File): def is_documentation_page(self): return True def on_files(files: Files, **kwargs: Dict[str, Any]): new_files = Files([]) for file in files: if file.src_path.endswith(".ipynb"): new_file = NotebookFile( path=file.src_path, src_dir=file.src_dir, dest_dir=file.dest_dir, use_directory_urls=file.use_directory_urls, ) new_files.append(new_file) else: new_files.append(file) return new_files def on_page_markdown(markdown: str, page: Page, **kwargs: Dict[str, Any]): if page.file.src_path.endswith(".ipynb"): logger.info("Processing Jupyter notebook: %s", page.file.src_path) body = convert_notebook(page.file.abs_src_path) return body return markdown ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/download_tiktoken.py`: ```py import tiktoken # This will trigger the download and caching of the necessary files for encoding in ("gpt2", "gpt-3.5"): tiktoken.encoding_for_model(encoding) ``` 
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/agent_supervisor.py`: ```py # %% [markdown] # # Multi-agent supervisor # # The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent. # # We can also choose to use an [LLM to orchestrate](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#supervisor) the different agents. # # Below, we will create an agent group, with an agent supervisor to help delegate tasks. # # ![diagram](attachment:8ee0a8ce-f0a8-4019-b5bf-b20933e40956.png) # # To simplify the code in each agent node, we will use LangGraph's prebuilt [create_react_agent](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent). This and other "advanced agent" notebooks are designed to show how you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance. # # ## Setup # # First, let's install required packages and set our API keys # %% %%capture --no-stderr %pip install -U langgraph langchain_community langchain_anthropic langchain_experimental # %% import getpass import os def _set_if_undefined(var: str): if not os.environ.get(var): os.environ[var] = getpass.getpass(f"Please provide your {var}") _set_if_undefined("ANTHROPIC_API_KEY") _set_if_undefined("TAVILY_API_KEY") # %% [markdown] # <div class="admonition tip"> # <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p> # <p style="padding-top: 5px;"> # Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. 
LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. # </p> # </div> # %% [markdown] # ## Create tools # # For this example, you will make an agent to do web research with a search engine, and one agent to create plots. Define the tools they'll use below: # %% from typing import Annotated from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.tools import tool from langchain_experimental.utilities import PythonREPL tavily_tool = TavilySearchResults(max_results=5) # This executes code locally, which can be unsafe repl = PythonREPL() @tool def python_repl_tool( code: Annotated[str, "The python code to execute to generate your chart."], ): """Use this to execute python code and do math. If you want to see the output of a value, you should print it out with `print(...)`. This is visible to the user.""" try: result = repl.run(code) except BaseException as e: return f"Failed to execute. Error: {repr(e)}" result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}" return result_str # %% [markdown] # ### Create Agent Supervisor # # It will use LLM with structured output to choose the next worker node OR finish processing. # %% from typing import Literal from typing_extensions import TypedDict from langchain_anthropic import ChatAnthropic from langgraph.graph import MessagesState from langgraph.types import Command members = ["researcher", "coder"] # Our team supervisor is an LLM node. It just picks the next agent to process # and decides when the work is completed options = members + ["FINISH"] system_prompt = ( "You are a supervisor tasked with managing a conversation between the" f" following workers: {members}. Given the following user request," " respond with the worker to act next. Each worker will perform a" " task and respond with their results and status. 
When finished," " respond with FINISH." ) class Router(TypedDict): """Worker to route to next. If no workers needed, route to FINISH.""" next: Literal[*options] llm = ChatAnthropic(model="claude-3-5-sonnet-latest") def supervisor_node(state: MessagesState) -> Command[Literal[*members, "__end__"]]: messages = [ {"role": "system", "content": system_prompt}, ] + state["messages"] response = llm.with_structured_output(Router).invoke(messages) goto = response["next"] if goto == "FINISH": goto = END return Command(goto=goto) # %% [markdown] # ## Construct Graph # # We're ready to start building the graph. Below, define the state and worker nodes using the function we just defined. # %% from langchain_core.messages import HumanMessage from langgraph.graph import StateGraph, START, END from langgraph.prebuilt import create_react_agent research_agent = create_react_agent( llm, tools=[tavily_tool], state_modifier="You are a researcher. DO NOT do any math." ) def research_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = research_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="researcher") ] }, goto="supervisor", ) # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED code_agent = create_react_agent(llm, tools=[python_repl_tool]) def code_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = code_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="coder") ] }, goto="supervisor", ) builder = StateGraph(MessagesState) builder.add_edge(START, "supervisor") builder.add_node("supervisor", supervisor_node) builder.add_node("researcher", research_node) builder.add_node("coder", code_node) graph = builder.compile() # %% from IPython.display import display, Image display(Image(graph.get_graph().draw_mermaid_png())) # %% [markdown] # ## Invoke the team # # With the graph created, we 
can now invoke it and see how it performs! # %% for s in graph.stream( {"messages": [("user", "What's the square root of 42?")]}, subgraphs=True ): print(s) print("----") # %% for s in graph.stream( { "messages": [ ( "user", "Find the latest GDP of New York and California, then calculate the average", ) ] }, subgraphs=True, ): print(s) print("----") ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/multi-agent-collaboration.py`: ```py # %% [markdown] # # Multi-agent network # # A single agent can usually operate effectively using a handful of tools within a single domain, but even using powerful models like `gpt-4`, it can be less effective at using many tools. # # One way to approach complicated tasks is through a "divide-and-conquer" approach: create an specialized agent for each task or domain and route tasks to the correct "expert". This is an example of a [multi-agent network](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#network) architecture. # # This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu, et. al.) shows one way to do this using LangGraph. # # The resulting graph will look something like the following diagram: # # ![multi_agent diagram](attachment:8088306a-da20-4f95-bb07-c3fbd546762c.png) # # Before we get started, a quick note: this and other multi-agent notebooks are designed to show _how_ you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance. 
# # ## Setup # # First, let's install our required packages and set our API keys: # %% # %%capture --no-stderr # %pip install -U langchain_community langchain_anthropic langchain_experimental matplotlib langgraph # %% import getpass import os def _set_if_undefined(var: str): if not os.environ.get(var): os.environ[var] = getpass.getpass(f"Please provide your {var}") _set_if_undefined("ANTHROPIC_API_KEY") _set_if_undefined("TAVILY_API_KEY") # %% [markdown] # <div class="admonition tip"> # <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p> # <p style="padding-top: 5px;"> # Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. # </p> # </div> # %% [markdown] # ## Define tools # # We will also define some tools that our agents will use in the future # %% from typing import Annotated from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.tools import tool from langchain_experimental.utilities import PythonREPL tavily_tool = TavilySearchResults(max_results=5) # Warning: This executes code locally, which can be unsafe when not sandboxed repl = PythonREPL() @tool def python_repl_tool( code: Annotated[str, "The python code to execute to generate your chart."], ): """Use this to execute python code. If you want to see the output of a value, you should print it out with `print(...)`. This is visible to the user.""" try: result = repl.run(code) except BaseException as e: return f"Failed to execute. Error: {repr(e)}" result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}" return ( result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER." 
) # %% [markdown] # ## Create graph # # Now that we've defined our tools and made some helper functions, we will create the individual agents below and tell them how to talk to each other using LangGraph. # %% [markdown] # ### Define Agent Nodes # # We now need to define the nodes. # # First, we'll create a utility to create a system prompt for each agent. # %% def make_system_prompt(suffix: str) -> str: return ( "You are a helpful AI assistant, collaborating with other assistants." " Use the provided tools to progress towards answering the question." " If you are unable to fully answer, that's OK, another assistant with different tools " " will help where you left off. Execute what you can to make progress." " If you or any of the other assistants have the final answer or deliverable," " prefix your response with FINAL ANSWER so the team knows to stop." f"\n{suffix}" ) # %% from typing import Literal from langchain_core.messages import BaseMessage, HumanMessage from langchain_anthropic import ChatAnthropic from langgraph.prebuilt import create_react_agent from langgraph.graph import MessagesState, END from langgraph.types import Command llm = ChatAnthropic(model="claude-3-5-sonnet-latest") def get_next_node(last_message: BaseMessage, goto: str): if "FINAL ANSWER" in last_message.content: # Any agent decided the work is done return END return goto # Research agent and node research_agent = create_react_agent( llm, tools=[tavily_tool], state_modifier=make_system_prompt( "You can only do research. You are working with a chart generator colleague."
), ) def research_node( state: MessagesState, ) -> Command[Literal["chart_generator", END]]: result = research_agent.invoke(state) goto = get_next_node(result["messages"][-1], "chart_generator") # wrap in a human message, as not all providers allow # AI message at the last position of the input messages list result["messages"][-1] = HumanMessage( content=result["messages"][-1].content, name="researcher" ) return Command( update={ # share internal message history of research agent with other agents "messages": result["messages"], }, goto=goto, ) # Chart generator agent and node # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED chart_agent = create_react_agent( llm, [python_repl_tool], state_modifier=make_system_prompt( "You can only generate charts. You are working with a researcher colleague." ), ) def chart_node(state: MessagesState) -> Command[Literal["researcher", END]]: result = chart_agent.invoke(state) goto = get_next_node(result["messages"][-1], "researcher") # wrap in a human message, as not all providers allow # AI message at the last position of the input messages list result["messages"][-1] = HumanMessage( content=result["messages"][-1].content, name="chart_generator" ) return Command( update={ # share internal message history of chart agent with other agents "messages": result["messages"], }, goto=goto, ) # %% [markdown] # ### Define the Graph # # We can now put it all together and define the graph! # %% from langgraph.graph import StateGraph, START workflow = StateGraph(MessagesState) workflow.add_node("researcher", research_node) workflow.add_node("chart_generator", chart_node) workflow.add_edge(START, "researcher") graph = workflow.compile() # %% from IPython.display import Image, display try: display(Image(graph.get_graph().draw_mermaid_png())) except Exception: # This requires some extra dependencies and is optional pass # %% [markdown] # ## Invoke # # With the graph created, you can invoke it! 
Let's have it chart some stats for us. # %% events = graph.stream( { "messages": [ ( "user", "First, get the UK's GDP over the past 5 years, then make a line chart of it. " "Once you make the chart, finish.", ) ], }, # Maximum number of steps to take in the graph {"recursion_limit": 150}, ) for s in events: print(s) print("----") ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/hierarchical_agent_teams.py`: ```py # %% [markdown] # # Hierarchical Agent Teams # # In our previous example ([Agent Supervisor](../agent_supervisor)), we introduced the concept of a single [supervisor node](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#supervisor) to route work between different worker nodes. # # But what if the job for a single worker becomes too complex? What if the number of workers becomes too large? # # For some applications, the system may be more effective if work is distributed _hierarchically_. # # You can do this by composing different subgraphs and creating a top-level supervisor, along with mid-level supervisors. # # To do this, let's build a simple research assistant! The graph will look something like the following: # # ![diagram](attachment:d98ed25c-51cb-441f-a6f4-016921d59fc3.png) # # This notebook is inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu, et. al. In the rest of this notebook, you will: # # 1. Define the agents' tools to access the web and write files # 2. Define some utilities to help create the graph and agents # 3. Create and define each team (web research + doc writing) # 4. Compose everything together. 
#
# ## Setup
#
# First, let's install our required packages and set our API keys

# %%
# Magics stay commented (as in the other tutorial scripts) so this file remains
# valid Python; jupytext restores them when converting back to a notebook.
# langchain_openai is required because every team below uses ChatOpenAI.
# %%capture --no-stderr
# %pip install -U langgraph langchain_community langchain_anthropic langchain_experimental langchain_openai

# %%
import getpass
import os


def _set_if_undefined(var: str) -> None:
    """Prompt for environment variable ``var`` if it is unset or empty.

    The value typed at the (hidden-input) prompt is stored in ``os.environ``.
    An already non-empty value is left untouched.
    """
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}")


_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>

# %% [markdown]
# ## Create Tools
#
# Each team will be composed of one or more agents each with one or more tools. Below, define all the tools to be used by your different teams.
#
# We'll start with the research team.
#
# **ResearchTeam tools**
#
# The research team can use a search engine and url scraper to find information on the web. Feel free to add additional functionality below to boost the team performance!
# %% from typing import Annotated, List from langchain_community.document_loaders import WebBaseLoader from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.tools import tool tavily_tool = TavilySearchResults(max_results=5) @tool def scrape_webpages(urls: List[str]) -> str: """Use requests and bs4 to scrape the provided web pages for detailed information.""" loader = WebBaseLoader(urls) docs = loader.load() return "\n\n".join( [ f'<Document name="{doc.metadata.get("title", "")}">\n{doc.page_content}\n</Document>' for doc in docs ] ) # %% [markdown] # **Document writing team tools** # # Next up, we will give some tools for the doc writing team to use. # We define some bare-bones file-access tools below. # # Note that this gives the agents access to your file-system, which can be unsafe. We also haven't optimized the tool descriptions for performance. # %% from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional from langchain_experimental.utilities import PythonREPL from typing_extensions import TypedDict _TEMP_DIRECTORY = TemporaryDirectory() WORKING_DIRECTORY = Path(_TEMP_DIRECTORY.name) @tool def create_outline( points: Annotated[List[str], "List of main points or sections."], file_name: Annotated[str, "File path to save the outline."], ) -> Annotated[str, "Path of the saved outline file."]: """Create and save an outline.""" with (WORKING_DIRECTORY / file_name).open("w") as file: for i, point in enumerate(points): file.write(f"{i + 1}. {point}\n") return f"Outline saved to {file_name}" @tool def read_document( file_name: Annotated[str, "File path to read the document from."], start: Annotated[Optional[int], "The start line. Default is 0"] = None, end: Annotated[Optional[int], "The end line. 
Default is None"] = None, ) -> str: """Read the specified document.""" with (WORKING_DIRECTORY / file_name).open("r") as file: lines = file.readlines() if start is not None: start = 0 return "\n".join(lines[start:end]) @tool def write_document( content: Annotated[str, "Text content to be written into the document."], file_name: Annotated[str, "File path to save the document."], ) -> Annotated[str, "Path of the saved document file."]: """Create and save a text document.""" with (WORKING_DIRECTORY / file_name).open("w") as file: file.write(content) return f"Document saved to {file_name}" @tool def edit_document( file_name: Annotated[str, "Path of the document to be edited."], inserts: Annotated[ Dict[int, str], "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.", ], ) -> Annotated[str, "Path of the edited document file."]: """Edit a document by inserting text at specific line numbers.""" with (WORKING_DIRECTORY / file_name).open("r") as file: lines = file.readlines() sorted_inserts = sorted(inserts.items()) for line_number, text in sorted_inserts: if 1 <= line_number <= len(lines) + 1: lines.insert(line_number - 1, text + "\n") else: return f"Error: Line number {line_number} is out of range." with (WORKING_DIRECTORY / file_name).open("w") as file: file.writelines(lines) return f"Document edited and saved to {file_name}" # Warning: This executes code locally, which can be unsafe when not sandboxed repl = PythonREPL() @tool def python_repl_tool( code: Annotated[str, "The python code to execute to generate your chart."], ): """Use this to execute python code. If you want to see the output of a value, you should print it out with `print(...)`. This is visible to the user.""" try: result = repl.run(code) except BaseException as e: return f"Failed to execute. 
Error: {repr(e)}" return f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}" # %% [markdown] # ## Helper Utilities # # We are going to create a few utility functions to make it more concise when we want to: # # 1. Create a worker agent. # 2. Create a supervisor for the sub-graph. # # These will simplify the graph compositional code at the end for us so it's easier to see what's going on. # %% from typing import List, Optional, Literal from langchain_core.language_models.chat_models import BaseChatModel from langgraph.graph import StateGraph, MessagesState, START, END from langgraph.types import Command from langchain_core.messages import HumanMessage, trim_messages def make_supervisor_node(llm: BaseChatModel, members: list[str]) -> str: options = ["FINISH"] + members system_prompt = ( "You are a supervisor tasked with managing a conversation between the" f" following workers: {members}. Given the following user request," " respond with the worker to act next. Each worker will perform a" " task and respond with their results and status. When finished," " respond with FINISH." ) class Router(TypedDict): """Worker to route to next. If no workers needed, route to FINISH.""" next: Literal[*options] def supervisor_node(state: MessagesState) -> Command[Literal[*members, "__end__"]]: """An LLM-based router.""" messages = [ {"role": "system", "content": system_prompt}, ] + state["messages"] response = llm.with_structured_output(Router).invoke(messages) goto = response["next"] if goto == "FINISH": goto = END return Command(goto=goto) return supervisor_node # %% [markdown] # ## Define Agent Teams # # Now we can get to define our hierarchical teams. "Choose your player!" # # ### Research Team # # The research team will have a search agent and a web scraping "research_agent" as the two worker nodes. Let's create those, as well as the team supervisor. 
# %% from langchain_core.messages import HumanMessage from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent llm = ChatOpenAI(model="gpt-4o") search_agent = create_react_agent(llm, tools=[tavily_tool]) def search_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = search_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="search") ] }, # We want our workers to ALWAYS "report back" to the supervisor when done goto="supervisor", ) web_scraper_agent = create_react_agent(llm, tools=[scrape_webpages]) def web_scraper_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = web_scraper_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="web_scraper") ] }, # We want our workers to ALWAYS "report back" to the supervisor when done goto="supervisor", ) research_supervisor_node = make_supervisor_node(llm, ["search", "web_scraper"]) # %% [markdown] # Now that we've created the necessary components, defining their interactions is easy. Add the nodes to the team graph, and define the edges, which determine the transition criteria. # %% research_builder = StateGraph(MessagesState) research_builder.add_node("supervisor", research_supervisor_node) research_builder.add_node("search", search_node) research_builder.add_node("web_scraper", web_scraper_node) research_builder.add_edge(START, "supervisor") research_graph = research_builder.compile() # %% from IPython.display import Image, display display(Image(research_graph.get_graph().draw_mermaid_png())) # %% [markdown] # We can give this team work directly. Try it out below. # %% for s in research_graph.stream( {"messages": [("user", "when is Taylor Swift's next tour?")]}, {"recursion_limit": 100}, ): print(s) print("---") # %% [markdown] # ### Document Writing Team # # Create the document writing team below using a similar approach. 
This time, we will give each agent access to different file-writing tools. # # Note that we are giving file-system access to our agent here, which is not safe in all cases. # %% llm = ChatOpenAI(model="gpt-4o") doc_writer_agent = create_react_agent( llm, tools=[write_document, edit_document, read_document], state_modifier=( "You can read, write and edit documents based on note-taker's outlines. " "Don't ask follow-up questions." ), ) def doc_writing_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = doc_writer_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="doc_writer") ] }, # We want our workers to ALWAYS "report back" to the supervisor when done goto="supervisor", ) note_taking_agent = create_react_agent( llm, tools=[create_outline, read_document], state_modifier=( "You can read documents and create outlines for the document writer. " "Don't ask follow-up questions." ), ) def note_taking_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = note_taking_agent.invoke(state) return Command( update={ "messages": [ HumanMessage(content=result["messages"][-1].content, name="note_taker") ] }, # We want our workers to ALWAYS "report back" to the supervisor when done goto="supervisor", ) chart_generating_agent = create_react_agent( llm, tools=[read_document, python_repl_tool] ) def chart_generating_node(state: MessagesState) -> Command[Literal["supervisor"]]: result = chart_generating_agent.invoke(state) return Command( update={ "messages": [ HumanMessage( content=result["messages"][-1].content, name="chart_generator" ) ] }, # We want our workers to ALWAYS "report back" to the supervisor when done goto="supervisor", ) doc_writing_supervisor_node = make_supervisor_node( llm, ["doc_writer", "note_taker", "chart_generator"] ) # %% [markdown] # With the objects themselves created, we can form the graph. 
# %% # Create the graph here paper_writing_builder = StateGraph(MessagesState) paper_writing_builder.add_node("supervisor", doc_writing_supervisor_node) paper_writing_builder.add_node("doc_writer", doc_writing_node) paper_writing_builder.add_node("note_taker", note_taking_node) paper_writing_builder.add_node("chart_generator", chart_generating_node) paper_writing_builder.add_edge(START, "supervisor") paper_writing_graph = paper_writing_builder.compile() # %% from IPython.display import Image, display display(Image(paper_writing_graph.get_graph().draw_mermaid_png())) # %% for s in paper_writing_graph.stream( { "messages": [ ( "user", "Write an outline for poem about cats and then write the poem to disk.", ) ] }, {"recursion_limit": 100}, ): print(s) print("---") # %% [markdown] # ## Add Layers # # In this design, we are enforcing a top-down planning policy. We've created two graphs already, but we have to decide how to route work between the two. # # We'll create a _third_ graph to orchestrate the previous two, and add some connectors to define how this top-level state is shared between the different graphs. # %% from langchain_core.messages import BaseMessage llm = ChatOpenAI(model="gpt-4o") teams_supervisor_node = make_supervisor_node(llm, ["research_team", "writing_team"]) # %% def call_research_team(state: MessagesState) -> Command[Literal["supervisor"]]: response = research_graph.invoke({"messages": state["messages"][-1]}) return Command( update={ "messages": [ HumanMessage( content=response["messages"][-1].content, name="research_team" ) ] }, goto="supervisor", ) def call_paper_writing_team(state: MessagesState) -> Command[Literal["supervisor"]]: response = paper_writing_graph.invoke({"messages": state["messages"][-1]}) return Command( update={ "messages": [ HumanMessage( content=response["messages"][-1].content, name="writing_team" ) ] }, goto="supervisor", ) # Define the graph. 
super_builder = StateGraph(MessagesState) super_builder.add_node("supervisor", teams_supervisor_node) super_builder.add_node("research_team", call_research_team) super_builder.add_node("writing_team", call_paper_writing_team) super_builder.add_edge(START, "supervisor") super_graph = super_builder.compile() # %% from IPython.display import Image, display display(Image(super_graph.get_graph().draw_mermaid_png())) # %% for s in super_graph.stream( { "messages": [ ("user", "Research AI agents and write a brief report about them.") ], }, {"recursion_limit": 150}, ): print(s) print("---") ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/plan-and-execute/plan-and-execute.py`: ```py # %% [markdown] # # Plan-and-Execute # # This notebook shows how to create a "plan-and-execute" style agent. This is heavily inspired by the [Plan-and-Solve](https://arxiv.org/abs/2305.04091) paper as well as the [Baby-AGI](https://github.com/yoheinakajima/babyagi) project. # # The core idea is to first come up with a multi-step plan, and then go through that plan one item at a time. # After accomplishing a particular task, you can then revisit the plan and modify as appropriate. # # # The general computational graph looks like the following: # # ![plan-and-execute diagram](attachment:86cf6404-3d9b-41cb-ab97-5e451f576620.png) # # # This compares to a typical [ReAct](https://arxiv.org/abs/2210.03629) style agent where you think one step at a time. # The advantages of this "plan-and-execute" style agent are: # # 1. Explicit long term planning (which even really strong LLMs can struggle with) # 2. Ability to use smaller/weaker models for the execution step, only using larger/better models for the planning step # # # The following walkthrough demonstrates how to do so in LangGraph. The resulting agent will leave a trace like the following example: ([link](https://smith.langchain.com/public/d46e24d3-dda6-44d5-9550-b618fca4e0d4/r)). 
# %% [markdown] # ## Setup # # First, we need to install the packages required. # %% # %%capture --no-stderr # %pip install --quiet -U langgraph langchain-community langchain-openai tavily-python # %% [markdown] # Next, we need to set API keys for OpenAI (the LLM we will use) and Tavily (the search tool we will use) # %% import getpass import os def _set_env(var: str): if not os.environ.get(var): os.environ[var] = getpass.getpass(f"{var}: ") _set_env("OPENAI_API_KEY") _set_env("TAVILY_API_KEY") # %% [markdown] # <div class="admonition tip"> # <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p> # <p style="padding-top: 5px;"> # Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. # </p> # </div> # %% [markdown] # ## Define Tools # # We will first define the tools we want to use. For this simple example, we will use a built-in search tool via Tavily. However, it is really easy to create your own tools - see documentation [here](https://python.langchain.com/docs/how_to/custom_tools) on how to do that. # %% from langchain_community.tools.tavily_search import TavilySearchResults tools = [TavilySearchResults(max_results=3)] # %% [markdown] # ## Define our Execution Agent # # Now we will create the execution agent we want to use to execute tasks. # Note that for this example, we will be using the same execution agent for each task, but this doesn't HAVE to be the case. # %% from langchain import hub from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent # Get the prompt to use - you can modify this! 
prompt = hub.pull("ih/ih-react-agent-executor") prompt.pretty_print() # Choose the LLM that will drive the agent llm = ChatOpenAI(model="gpt-4-turbo-preview") agent_executor = create_react_agent(llm, tools, state_modifier=prompt) # %% agent_executor.invoke({"messages": [("user", "who is the winnner of the us open")]}) # %% [markdown] # ## Define the State # # Let's now start by defining the state to track for this agent. # # First, we will need to track the current plan. Let's represent that as a list of strings. # # Next, we should track previously executed steps. Let's represent that as a list of tuples (these tuples will contain the step and then the result). # # Finally, we need to have some state to represent the final response as well as the original input. # %% import operator from typing import Annotated, List, Tuple from typing_extensions import TypedDict class PlanExecute(TypedDict): input: str plan: List[str] past_steps: Annotated[List[Tuple], operator.add] response: str # %% [markdown] # ## Planning Step # # Let's now think about creating the planning step. This will use function calling to create a plan. # %% [markdown] # <div class="admonition note"> # <p class="admonition-title">Using Pydantic with LangChain</p> # <p> # This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>. # </p> # </div> # %% from pydantic import BaseModel, Field class Plan(BaseModel): """Plan to follow in future""" steps: List[str] = Field( description="different steps to follow, should be in sorted order" ) # %% from langchain_core.prompts import ChatPromptTemplate planner_prompt = ChatPromptTemplate.from_messages( [ ( "system", """For the given objective, come up with a simple step by step plan. \ This plan should involve individual tasks, that if executed correctly will yield the correct answer.
Do not add any superfluous steps. \ The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.""", ), ("placeholder", "{messages}"), ] ) planner = planner_prompt | ChatOpenAI( model="gpt-4o", temperature=0 ).with_structured_output(Plan) # %% planner.invoke( { "messages": [ ("user", "what is the hometown of the current Australia open winner?") ] } ) # %% [markdown] # ## Re-Plan Step # # Now, let's create a step that re-does the plan based on the result of the previous step. # %% from typing import Union class Response(BaseModel): """Response to user.""" response: str class Act(BaseModel): """Action to perform.""" action: Union[Response, Plan] = Field( description="Action to perform. If you want to respond to user, use Response. " "If you need to further use tools to get the answer, use Plan." ) replanner_prompt = ChatPromptTemplate.from_template( """For the given objective, come up with a simple step by step plan. \ This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. \ The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps. Your objective was this: {input} Your original plan was this: {plan} You have currently done the follow steps: {past_steps} Update your plan accordingly. If no more steps are needed and you can return to the user, then respond with that. Otherwise, fill out the plan. Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan.""" ) replanner = replanner_prompt | ChatOpenAI( model="gpt-4o", temperature=0 ).with_structured_output(Act) # %% [markdown] # ## Create the Graph # # We can now create the graph! # %% from typing import Literal from langgraph.graph import END async def execute_step(state: PlanExecute): plan = state["plan"] plan_str = "\n".join(f"{i+1}. 
{step}" for i, step in enumerate(plan)) task = plan[0] task_formatted = f"""For the following plan: {plan_str}\n\nYou are tasked with executing step {1}, {task}.""" agent_response = await agent_executor.ainvoke( {"messages": [("user", task_formatted)]} ) return { "past_steps": [(task, agent_response["messages"][-1].content)], } async def plan_step(state: PlanExecute): plan = await planner.ainvoke({"messages": [("user", state["input"])]}) return {"plan": plan.steps} async def replan_step(state: PlanExecute): output = await replanner.ainvoke(state) if isinstance(output.action, Response): return {"response": output.action.response} else: return {"plan": output.action.steps} def should_end(state: PlanExecute): if "response" in state and state["response"]: return END else: return "agent" # %% from langgraph.graph import StateGraph, START workflow = StateGraph(PlanExecute) # Add the plan node workflow.add_node("planner", plan_step) # Add the execution step workflow.add_node("agent", execute_step) # Add a replan node workflow.add_node("replan", replan_step) workflow.add_edge(START, "planner") # From plan we go to agent workflow.add_edge("planner", "agent") # From agent, we replan workflow.add_edge("agent", "replan") workflow.add_conditional_edges( "replan", # Next, we pass in the function that will determine which node is called next. should_end, ["agent", END], ) # Finally, we compile it! # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile() # %% from IPython.display import Image, display display(Image(app.get_graph(xray=True).draw_mermaid_png())) # %% config = {"recursion_limit": 50} inputs = {"input": "what is the hometown of the mens 2024 Australia open winner?"} async for event in app.astream(inputs, config=config): for k, v in event.items(): if k != "__end__": print(v) # %% [markdown] # ## Conclusion # # Congrats on making a plan-and-execute agent! 
One known limitation of the above design is that each task is still executed in sequence, meaning embarrassingly parallel operations all add to the total execution time. You could improve on this by having each task represented as a DAG (similar to LLMCompiler), rather than a regular list. ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/reflexion/reflexion.py`: ```py # %% [markdown] # # Reflexion # # [Reflexion](https://arxiv.org/abs/2303.11366) by Shinn et al. is an architecture designed to learn through verbal feedback and self-reflection. The agent explicitly critiques its responses for tasks to generate a higher-quality final response, at the expense of longer execution time. # # ![reflexion diagram](attachment:2f424259-8d89-4f4e-94c4-d668a36d8ca2.png) # # The paper outlines 3 main components: # # 1. Actor (agent) with self-reflection # 2. External evaluator (task-specific, e.g. code compilation steps) # 3. Episodic memory that stores the reflections from (1). # # In their code, the last two components are very task-specific, so in this notebook, you will build the _actor_ in LangGraph. # # To skip to the graph definition, see the [Construct Graph section](#Construct-Graph) below. # %% [markdown] # ## Setup # # Install `langgraph` (for the framework), `langchain_anthropic` (for the LLM), and `langchain` + `tavily-python` (for the search engine). # # We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing.
# %% # %pip install -U --quiet langgraph langchain_anthropic tavily-python # %% import getpass import os def _set_if_undefined(var: str) -> None: if os.environ.get(var): return os.environ[var] = getpass.getpass(var) _set_if_undefined("ANTHROPIC_API_KEY") _set_if_undefined("TAVILY_API_KEY") # %% [markdown] # <div class="admonition tip"> # <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p> # <p style="padding-top: 5px;"> # Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. # </p> # </div> # # ### Define our LLM # %% from langchain_anthropic import ChatAnthropic llm = ChatAnthropic(model="claude-3-5-sonnet-20240620") # You could also use OpenAI or another provider # from langchain_openai import ChatOpenAI # llm = ChatOpenAI(model="gpt-4-turbo-preview") # %% [markdown] # ## Actor (with reflection) # # The main component of Reflexion is the "actor", which is an agent that reflects on its response and re-executes to improve based on self-critique. Its main sub-components include: # 1. Tools/tool execution # 2. Initial responder: generate an initial response (and self-reflection) # 3. Revisor: re-respond (and reflect) based on previous reflections # # We'll first define the tool execution context.
#
# #### Construct tools

# %%
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)

# %% [markdown]
# #### Initial responder

# %% [markdown]
# <div class="admonition note">
#     <p class="admonition-title">Using Pydantic with LangChain</p>
#     <p>
#         This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
#     </p>
# </div>

# %%
import json

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import ValidationError

from pydantic import BaseModel, Field


# NOTE: these models are bound to the LLM as tool schemas, so their docstrings and
# Field descriptions are part of the prompt. Document them with comments only.
class Reflection(BaseModel):
    missing: str = Field(description="Critique of what is missing.")
    superfluous: str = Field(description="Critique of what is superfluous")


class AnswerQuestion(BaseModel):
    """Answer the question. Provide an answer, reflection, and then follow up with search queries to improve the answer."""

    answer: str = Field(description="~250 word detailed answer to the question.")
    reflection: Reflection = Field(description="Your reflection on the initial answer.")
    search_queries: list[str] = Field(
        description="1-3 search queries for researching improvements to address the critique of your current answer."
    )


class ResponderWithRetries:
    """Call a tool-calling chain and retry when its tool call fails schema validation.

    After a failed attempt, the invalid AI message plus a ``ToolMessage`` carrying
    the validation error and the expected JSON schema are appended to the
    conversation, so the model can correct itself on the next attempt.
    """

    def __init__(self, runnable, validator, max_attempts: int = 3):
        # runnable: chain returning an AIMessage with tool calls (prompt | llm.bind_tools(...)).
        # validator: parser whose ``invoke`` raises ValidationError for invalid tool arguments.
        self.runnable = runnable
        self.validator = validator
        self.max_attempts = max_attempts

    def _schema_hint(self) -> str:
        """Return the JSON schema of every tool model the validator checks against."""
        # The parser object's own schema is not what the model must follow; use the tool models.
        return json.dumps(
            [tool.model_json_schema() for tool in getattr(self.validator, "tools", [])]
        )

    def respond(self, state: dict):
        """Return ``{"messages": ai_message}`` for the conversation in ``state["messages"]``.

        Makes up to ``max_attempts`` attempts. If none passes validation, the last
        (invalid) response is returned anyway so the graph can continue.
        """
        messages = list(state["messages"])
        response = []
        for attempt in range(self.max_attempts):
            response = self.runnable.invoke(
                {"messages": messages}, {"tags": [f"attempt:{attempt}"]}
            )
            try:
                self.validator.invoke(response)
                return {"messages": response}
            except ValidationError as e:
                # `state` is a dict, so extend a separate message list rather than `state` itself.
                messages = messages + [
                    response,
                    ToolMessage(
                        content=f"{repr(e)}\n\nPay close attention to the function schema.\n\n"
                        + self._schema_hint()
                        + " Respond by fixing all validation errors.",
                        tool_call_id=response.tool_calls[0]["id"],
                    ),
                ]
        return {"messages": response}


# %%
import datetime

actor_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are expert researcher.
Current time: {time}

1. {first_instruction}
2. Reflect and critique your answer. Be severe to maximize improvement.
3. Recommend search queries to research information and improve your answer.""",
        ),
        MessagesPlaceholder(variable_name="messages"),
        (
            "user",
            # Opening and closing tags must match (previously opened with <system>).
            "\n\n<reminder>Reflect on the user's original question and the"
            " actions taken thus far. Respond using the {function_name} function.</reminder>",
        ),
    ]
).partial(
    # A callable partial is evaluated when the prompt is formatted, so the time is current.
    time=lambda: datetime.datetime.now().isoformat(),
)
initial_answer_chain = actor_prompt_template.partial(
    first_instruction="Provide a detailed ~250 word answer.",
    function_name=AnswerQuestion.__name__,
) | llm.bind_tools(tools=[AnswerQuestion])
validator = PydanticToolsParser(tools=[AnswerQuestion])

first_responder = ResponderWithRetries(
    runnable=initial_answer_chain, validator=validator
)

# %%
example_question = "Why is reflection useful in AI?"
initial = first_responder.respond(
    {"messages": [HumanMessage(content=example_question)]}
)

# %% [markdown]
# #### Revision
#
# The second part of the actor is a revision step.

# %%
revise_instructions = """Revise your previous answer using the new information.
    - You should use the previous critique to add important information to your answer.
        - You MUST include numerical citations in your revised answer to ensure it can be verified.
        - Add a "References" section to the bottom of your answer (which does not count towards the word limit). In form of:
            - [1] https://example.com
            - [2] https://example.com
    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.
"""


# Extend the initial answer schema to include references.
# Forcing citation in the model encourages grounded responses
class ReviseAnswer(AnswerQuestion):
    """Revise your original answer to your question. Provide an answer, reflection, cite your reflection with references, and finally add search queries to improve the answer."""

    references: list[str] = Field(
        description="Citations motivating your updated answer."
    )


revision_chain = actor_prompt_template.partial(
    first_instruction=revise_instructions,
    function_name=ReviseAnswer.__name__,
) | llm.bind_tools(tools=[ReviseAnswer])
revision_validator = PydanticToolsParser(tools=[ReviseAnswer])

revisor = ResponderWithRetries(runnable=revision_chain, validator=revision_validator)

# %%
import json

revised = revisor.respond(
    {
        "messages": [
            HumanMessage(content=example_question),
            initial["messages"],
            ToolMessage(
                tool_call_id=initial["messages"].tool_calls[0]["id"],
                content=json.dumps(
                    tavily_tool.invoke(
                        {
                            "query": initial["messages"].tool_calls[0]["args"][
                                "search_queries"
                            ][0]
                        }
                    )
                ),
            ),
        ]
    }
)
revised["messages"]

# %% [markdown]
# ## Create Tool Node
#
# Next, create a node to execute the tool calls. While we give the LLMs different schema names (and use those for validation), we want them both to route to the same tool.
# %% from langchain_core.tools import StructuredTool from langgraph.prebuilt import ToolNode def run_queries(search_queries: list[str], **kwargs): """Run the generated queries.""" return tavily_tool.batch([{"query": query} for query in search_queries]) tool_node = ToolNode( [ StructuredTool.from_function(run_queries, name=AnswerQuestion.__name__), StructuredTool.from_function(run_queries, name=ReviseAnswer.__name__), ] ) # %% [markdown] # ## Construct Graph # # # Now we can wire all our components together. # %% from typing import Literal from langgraph.graph import END, StateGraph, START from langgraph.graph.message import add_messages from typing import Annotated from typing_extensions import TypedDict class State(TypedDict): messages: Annotated[list, add_messages] MAX_ITERATIONS = 5 builder = StateGraph(State) builder.add_node("draft", first_responder.respond) builder.add_node("execute_tools", tool_node) builder.add_node("revise", revisor.respond) # draft -> execute_tools builder.add_edge("draft", "execute_tools") # execute_tools -> revise builder.add_edge("execute_tools", "revise") # Define looping logic: def _get_num_iterations(state: list): i = 0 for m in state[::-1]: if m.type not in {"tool", "ai"}: break i += 1 return i def event_loop(state: list): # in our case, we'll just stop after N plans num_iterations = _get_num_iterations(state["messages"]) if num_iterations > MAX_ITERATIONS: return END return "execute_tools" # revise -> execute_tools OR end builder.add_conditional_edges("revise", event_loop, ["execute_tools", END]) builder.add_edge(START, "draft") graph = builder.compile() # %% from IPython.display import Image, display try: display(Image(graph.get_graph().draw_mermaid_png())) except Exception: # This requires some extra dependencies and is optional pass # %% events = graph.stream( {"messages": [("user", "How should we handle the climate crisis?")]}, stream_mode="values", ) for i, step in enumerate(events): print(f"Step {i}") 
step["messages"][-1].pretty_print() # %% [markdown] # ## Conclusion # # Congrats on building a Reflexion actor! I'll leave you with a few observations to save you some time when choosing which parts of this agent to adapt to your workflow: # 1. This agent trades off execution time for quality. It explicitly forces the agent to critique and revise the output over several steps, which usually (not always) increases the response quality but takes much longer to return a final answer # 2. The 'reflections' can be paired with additional external feedback (such as validators), to further guide the actor. # 3. In the paper, 1 environment (AlfWorld) uses external memory. It does this by storing summaries of the reflections to an external store and using them in subsequent trials/invocations. ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/code_assistant/langgraph_code_assistant.py`: ```py # %% [markdown] # # Code generation with RAG and self-correction # # AlphaCodium presented an approach for code generation that uses control flow. # # Main idea: [construct an answer to a coding question iteratively.](https://x.com/karpathy/status/1748043513156272416?s=20). # # [AlphaCodium](https://github.com/Codium-ai/AlphaCodium) iteravely tests and improves an answer on public and AI-generated tests for a particular question. # # We will implement some of these ideas from scratch using [LangGraph](https://langchain-ai.github.io/langgraph/): # # 1. We start with a set of documentation specified by a user # 2. We use a long context LLM to ingest it and perform RAG to answer a question based upon it # 3. We will invoke a tool to produce a structured output # 4. 
We will perform two unit tests (check imports and code execution) prior returning the solution to the user # # ![Screenshot 2024-05-23 at 2.17.42 PM.png](attachment:67b615fe-0c25-4410-9d58-835982547001.png) # %% [markdown] # ## Setup # # First, let's install our required packages and set the API keys we will need # %% # ! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4 # %% import getpass import os def _set_env(var: str): if not os.environ.get(var): os.environ[var] = getpass.getpass(f"{var}: ") _set_env("OPENAI_API_KEY") _set_env("ANTHROPIC_API_KEY") # %% [markdown] # <div class="admonition tip"> # <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p> # <p style="padding-top: 5px;"> # Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. # </p> # </div> # %% [markdown] # ## Docs # # Load [LangChain Expression Language](https://python.langchain.com/docs/concepts/#langchain-expression-language-lcel) (LCEL) docs as an example. 
# %%
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# LCEL docs
url = "https://python.langchain.com/docs/concepts/lcel/"
loader = RecursiveUrlLoader(
    url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
    [doc.page_content for doc in d_reversed]
)

# %% [markdown]
# ## LLMs
#
# ### Code solution
#
# First, we will try OpenAI and [Claude3](https://docs.anthropic.com/en/docs/about-claude/models) with function calling.
#
# We will create a `code_gen_chain` w/ either OpenAI or Claude and test them here.

# %% [markdown]
# <div class="admonition note">
#     <p class="admonition-title">Using Pydantic with LangChain</p>
#     <p>
#         This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
#     </p>
# </div>

# %%
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

### OpenAI

# Code generation prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant with expertise in LCEL, LangChain expression language. \n
    Here is a full set of LCEL documentation: \n ------- \n {context} \n ------- \n Answer the user
    question based on the above provided documentation. Ensure any code you provide can be executed \n
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block. Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)


# Data model (also the structured-output schema, so its docstring/descriptions reach the LLM)
class code(BaseModel):
    """Schema for code solutions to questions about LCEL."""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")


expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain_oai = code_gen_prompt | llm.with_structured_output(code)
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain_oai.invoke(
    {"context": concatenated_content, "messages": [("user", question)]}
)
solution

# %%
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate

### Anthropic

# Prompt to enforce tool use
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """<instructions> You are a coding assistant with expertise in LCEL, LangChain expression language. \n
    Here is the LCEL documentation: \n ------- \n {context} \n ------- \n Answer the user question based on the \n
    above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)


# LLM
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
    model=expt_llm,
    default_headers={"anthropic-beta": "tools-2024-04-04"},
)

# include_raw=True returns {"raw", "parsed", "parsing_error"} instead of raising on parse failure.
structured_llm_claude = llm.with_structured_output(code, include_raw=True)


# Optional: Check for errors in case tool use is flaky
def check_claude_output(tool_output):
    """Check for parse error or failure to call the tool.

    Raises ValueError (which triggers the fallback chain) when parsing failed or
    no tool call was produced; otherwise returns ``tool_output`` unchanged.
    """

    # Error with parsing
    if tool_output["parsing_error"]:
        # Report back output and parsing errors
        print("Parsing error!")
        raw_output = str(tool_output["raw"].content)
        error = tool_output["parsing_error"]
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {error}"
        )

    # Tool was not invoked
    elif not tool_output["parsed"]:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )
    return tool_output


# Chain with output check
code_chain_claude_raw = (
    code_gen_prompt_claude | structured_llm_claude | check_claude_output
)


def insert_errors(inputs):
    """Build retry inputs: the original messages plus an assistant note with the error.

    ``inputs["error"]`` is set by ``with_fallbacks(exception_key="error")``. A new
    list is built so the caller's message list is not mutated across retries.
    """
    error = inputs["error"]
    messages = list(inputs["messages"]) + [
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        )
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }


# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_claude_raw
N = 3  # Max re-tries
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(
    fallbacks=[fallback_chain] * N, exception_key="error"
)


def parse_output(solution):
    """When we add 'include_raw=True' to structured output,
    it will return a dict w 'raw', 'parsed', 'parsing_error'."""

    return solution["parsed"]


# Optional: With re-try to correct for failure to invoke tool
code_gen_chain = code_gen_chain_re_try | parse_output

# No re-try
# NOTE: this assignment overrides the re-try chain above; delete it to use re-tries.
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output

# %%
# Test
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain.invoke(
    {"context": concatenated_content, "messages": [("user", question)]}
)
solution

# %% [markdown]
# ## State
#
# Our state is a dict that will contain keys (errors, question, code generation) relevant to code generation.

# %%
from typing import List
from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Control-flow flag: "yes" if a test failed, "no" if all passed ("" initially)
        messages : With user question, error messages, reasoning
        generation : Code solution (a `code` instance)
        iterations : Number of tries
    """

    error: str
    messages: List
    generation: code
    iterations: int


# %% [markdown]
# ## Graph
#
# Our graph lays out the logical flow shown in the figure above.
# %%
### Parameter

# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = "do not reflect"

### Nodes


def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updated generation, messages, and incremented iterations
    """

    print("---GENERATING CODE SOLUTION---")

    # State (build new lists instead of mutating the state's message list in place)
    messages = list(state["messages"])
    iterations = state["iterations"]
    error = state["error"]

    # We have been routed back to generation with an error
    if error == "yes":
        messages = messages + [
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        ]

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages = messages + [
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    ]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}


def code_check(state: GraphState):
    """
    Check code by executing the imports, then the imports plus the code block.

    SECURITY: this runs LLM-generated code in-process with full privileges.
    Only run it in an environment (sandbox/container) you are willing to risk.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Same generation/iterations, messages (plus any error), and error "yes"/"no"
    """

    print("---CHECKING CODE---")

    # State
    messages = list(state["messages"])
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    imports = code_solution.imports
    code_block = code_solution.code

    # Check imports. A fresh namespace (one dict for globals and locals) keeps the
    # check independent of this notebook's globals and lets functions defined in the
    # snippet see its own top-level names.
    try:
        exec(imports, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the import test: {e}")]
        return {
            "generation": code_solution,
            "messages": messages + error_message,
            "iterations": iterations,
            "error": "yes",
        }

    # Check execution
    try:
        exec(imports + "\n" + code_block, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the code execution test: {e}")]
        return {
            "generation": code_solution,
            "messages": messages + error_message,
            "iterations": iterations,
            "error": "yes",
        }

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": "no",
    }


def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Messages extended with the reflection; generation/iterations unchanged
    """

    print("---REFLECTING ON ERRORS---")

    # State
    messages = list(state["messages"])
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages = messages + [
        ("assistant", f"Here are reflections on the error: {reflections}")
    ]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}


### Edges


def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call ("end", "reflect", or "generate")
    """
    error = state["error"]
    iterations = state["iterations"]

    # `>=` so the loop still ends if iterations ever overshoots the cap.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        if flag == "reflect":
            return "reflect"
        else:
            return "generate"


# %%
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate)  # generation solution
workflow.add_node("check_code", code_check)  # check code
workflow.add_node("reflect", reflect)  # reflect

# Build graph
workflow.add_edge(START, "generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()

# %%
question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?"
solution = app.invoke({"messages": [("user", question)], "iterations": 0, "error": ""})

# %%
solution["generation"]

# %% [markdown]
# ## Eval

# %% [markdown]
# [Here](https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d) is a public dataset of LCEL questions.
#
# I saved this as `lcel-teacher-eval`.
#
# You can also find the csv [here](https://github.com/langchain-ai/lcel-teacher/blob/main/eval/eval.csv).

# %%
import langsmith

client = langsmith.Client()

# %%
# Clone the dataset to your tenant to use it
try:
    public_dataset = (
        "https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
    )
    client.clone_public_dataset(public_dataset)
except Exception:
    print("Please setup LangSmith")

# %% [markdown]
# Custom evals.

# %%
from langsmith.schemas import Example, Run


def check_import(run: Run, example: Example) -> dict:
    """Score 1 if the predicted ``imports`` execute without raising, else 0."""
    imports = run.outputs.get("imports")
    try:
        # Fresh namespace: the score must not depend on names this notebook defined.
        # SECURITY: executes model-generated code in-process.
        exec(imports, {"__name__": "__main__"})
    except Exception:
        return {"key": "import_check", "score": 0}
    return {"key": "import_check", "score": 1}


def check_execution(run: Run, example: Example) -> dict:
    """Score 1 if the predicted imports plus code execute without raising, else 0."""
    imports = run.outputs.get("imports")
    code = run.outputs.get("code")
    try:
        # A missing field (None) raises TypeError here and scores 0.
        exec(imports + "\n" + code, {"__name__": "__main__"})
    except Exception:
        return {"key": "code_execution_check", "score": 0}
    return {"key": "code_execution_check", "score": 1}


# %% [markdown]
# Compare LangGraph to Context Stuffing.
# %% def predict_base_case(example: dict): """Context stuffing""" solution = code_gen_chain.invoke( {"context": concatenated_content, "messages": [("user", example["question"])]} ) return {"imports": solution.imports, "code": solution.code} def predict_langgraph(example: dict): """LangGraph""" graph = app.invoke( {"messages": [("user", example["question"])], "iterations": 0, "error": ""} ) solution = graph["generation"] return {"imports": solution.imports, "code": solution.code} # %% from langsmith.evaluation import evaluate # Evaluator code_evalulator = [check_import, check_execution] # Dataset dataset_name = "lcel-teacher-eval" # %% # Run base case try: experiment_results_ = evaluate( predict_base_case, data=dataset_name, evaluators=code_evalulator, experiment_prefix=f"test-without-langgraph-{expt_llm}", max_concurrency=2, metadata={ "llm": expt_llm, }, ) except: print("Please setup LangSmith") # %% # Run with langgraph try: experiment_results = evaluate( predict_langgraph, data=dataset_name, evaluators=code_evalulator, experiment_prefix=f"test-with-langgraph-{expt_llm}-{flag}", max_concurrency=2, metadata={ "llm": expt_llm, "feedback": flag, }, ) except: print("Please setup LangSmith") # %% [markdown] # `Results:` # # * `LangGraph outperforms base case`: adding re-try loop improve performance # * `Reflection did not help`: reflection prior to re-try regression vs just passing errors directly back to the LLM # * `GPT-4 outperforms Claude3`: Claude3 had 3 and 1 run fail due to tool-use error for Opus and Haiku, respectively # # https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/llm-compiler/output_parser.py`: ```py import ast import re from typing import ( Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, ) from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage from langchain_core.output_parsers.transform import 
BaseTransformOutputParser from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from typing_extensions import TypedDict THOUGHT_PATTERN = r"Thought: ([^\n]*)" ACTION_PATTERN = r"\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?" # $1 or ${1} -> 1 ID_PATTERN = r"\$\{?(\d+)\}?" END_OF_PLAN = "<END_OF_PLAN>" ### Helper functions def _ast_parse(arg: str) -> Any: try: return ast.literal_eval(arg) except: # noqa return arg def _parse_llm_compiler_action_args(args: str, tool: Union[str, BaseTool]) -> list[Any]: """Parse arguments from a string.""" if args == "": return () if isinstance(tool, str): return () extracted_args = {} tool_key = None prev_idx = None for key in tool.args.keys(): # Split if present if f"{key}=" in args: idx = args.index(f"{key}=") if prev_idx is not None: extracted_args[tool_key] = _ast_parse( args[prev_idx:idx].strip().rstrip(",") ) args = args.split(f"{key}=", 1)[1] tool_key = key prev_idx = 0 if prev_idx is not None: extracted_args[tool_key] = _ast_parse( args[prev_idx:].strip().rstrip(",").rstrip(")") ) return extracted_args def default_dependency_rule(idx, args: str): matches = re.findall(ID_PATTERN, args) numbers = [int(match) for match in matches] return idx in numbers def _get_dependencies_from_graph( idx: int, tool_name: str, args: Dict[str, Any] ) -> dict[str, list[str]]: """Get dependencies from a graph.""" if tool_name == "join": return list(range(1, idx)) return [i for i in range(1, idx) if default_dependency_rule(i, str(args))] class Task(TypedDict): idx: int tool: BaseTool args: list dependencies: Dict[str, list] thought: Optional[str] def instantiate_task( tools: Sequence[BaseTool], idx: int, tool_name: str, args: Union[str, Any], thought: Optional[str] = None, ) -> Task: if tool_name == "join": tool = "join" else: try: tool = tools[[tool.name for tool in tools].index(tool_name)] except ValueError as e: raise OutputParserException(f"Tool {tool_name} not found.") from e tool_args = 
_parse_llm_compiler_action_args(args, tool) dependencies = _get_dependencies_from_graph(idx, tool_name, tool_args) return Task( idx=idx, tool=tool, args=tool_args, dependencies=dependencies, thought=thought, ) class LLMCompilerPlanParser(BaseTransformOutputParser[dict], extra="allow"): """Planning output parser.""" tools: List[BaseTool] def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Task]: texts = [] # TODO: Cleanup tuple state tracking here. thought = None for chunk in input: # Assume input is str. TODO: support vision/other formats text = chunk if isinstance(chunk, str) else str(chunk.content) for task, thought in self.ingest_token(text, texts, thought): yield task # Final possible task if texts: task, _ = self._parse_task("".join(texts), thought) if task: yield task def parse(self, text: str) -> List[Task]: return list(self._transform([text])) def stream( self, input: str | BaseMessage, config: RunnableConfig | None = None, **kwargs: Any | None, ) -> Iterator[Task]: yield from self.transform([input], config, **kwargs) def ingest_token( self, token: str, buffer: List[str], thought: Optional[str] ) -> Iterator[Tuple[Optional[Task], str]]: buffer.append(token) if "\n" in token: buffer_ = "".join(buffer).split("\n") suffix = buffer_[-1] for line in buffer_[:-1]: task, thought = self._parse_task(line, thought) if task: yield task, thought buffer.clear() buffer.append(suffix) def _parse_task(self, line: str, thought: Optional[str] = None): task = None if match := re.match(THOUGHT_PATTERN, line): # Optionally, action can be preceded by a thought thought = match.group(1) elif match := re.match(ACTION_PATTERN, line): # if action is parsed, return the task, and clear the buffer idx, tool_name, args, _ = match.groups() idx = int(idx) task = instantiate_task( tools=self.tools, idx=idx, tool_name=tool_name, args=args, thought=thought, ) thought = None # Else it is just dropped return task, thought ``` 
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/llm-compiler/math_tools.py`: ```py import math import re from typing import List, Optional import numexpr from langchain.chains.openai_functions import create_structured_output_runnable from langchain_core.messages import SystemMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field _MATH_DESCRIPTION = ( "math(problem: str, context: Optional[list[str]]) -> float:\n" " - Solves the provided math problem.\n" ' - `problem` can be either a simple math problem (e.g. "1 + 3") or a word problem (e.g. "how many apples are there if there are 3 apples and 2 apples").\n' " - You cannot calculate multiple expressions in one call. For instance, `math('1 + 3, 2 + 4')` does not work. " "If you need to calculate multiple expressions, you need to call them separately like `math('1 + 3')` and then `math('2 + 4')`\n" " - Minimize the number of `math` actions as much as possible. For instance, instead of calling " '2. math("what is the 10% of $1") and then call 3. math("$1 + $2"), ' 'you MUST call 2. math("what is the 110% of $1") instead, which will reduce the number of math actions.\n' # Context specific rules below " - You can optionally provide a list of strings as `context` to help the agent solve the problem. " "If there are multiple contexts you need to answer the question, you can provide them as a list of strings.\n" " - `math` action will not see the output of the previous actions unless you provide it as `context`. " "You MUST provide the output of the previous actions as `context` if you need to do math on it.\n" " - You MUST NEVER provide `search` type action's outputs as a variable in the `problem` argument. 
" "This is because `search` returns a text blob that contains the information about the entity, not a number or value. " "Therefore, when you need to provide an output of `search` action, you MUST provide it as a `context` argument to `math` action. " 'For example, 1. search("Barack Obama") and then 2. math("age of $1") is NEVER allowed. ' 'Use 2. math("age of Barack Obama", context=["$1"]) instead.\n' " - When you ask a question about `context`, specify the units. " 'For instance, "what is xx in height?" or "what is xx in millions?" instead of "what is xx?"\n' ) _SYSTEM_PROMPT = """Translate a math problem into a expression that can be executed using Python's numexpr library. Use the output of running this code to answer the question. Question: ${{Question with math problem.}} ```text ${{single line mathematical expression that solves the problem}} ``` ...numexpr.evaluate(text)... ```output ${{Output of running the code}} ``` Answer: ${{Answer}} Begin. Question: What is 37593 * 67? ExecuteCode({{code: "37593 * 67"}}) ...numexpr.evaluate("37593 * 67")... ```output 2518731 ``` Answer: 2518731 Question: 37593^(1/5) ExecuteCode({{code: "37593**(1/5)"}}) ...numexpr.evaluate("37593**(1/5)")... 
```output 8.222831614237718 ``` Answer: 8.222831614237718 """ _ADDITIONAL_CONTEXT_PROMPT = """The following additional context is provided from other functions.\ Use it to substitute into any ${{#}} variables or other words in the problem.\ \n\n${context}\n\nNote that context variables are not defined in code yet.\ You must extract the relevant numbers and directly put them in code.""" class ExecuteCode(BaseModel): """The input to the numexpr.evaluate() function.""" reasoning: str = Field( ..., description="The reasoning behind the code expression, including how context is included, if applicable.", ) code: str = Field( ..., description="The simple code expression to execute by numexpr.evaluate().", ) def _evaluate_expression(expression: str) -> str: try: local_dict = {"pi": math.pi, "e": math.e} output = str( numexpr.evaluate( expression.strip(), global_dict={}, # restrict access to globals local_dict=local_dict, # add common mathematical functions ) ) except Exception as e: raise ValueError( f'Failed to evaluate "{expression}". Raised error: {repr(e)}.' 
" Please try again with a valid numerical expression" ) # Remove any leading and trailing brackets from the output return re.sub(r"^\[|\]$", "", output) def get_math_tool(llm: ChatOpenAI): prompt = ChatPromptTemplate.from_messages( [ ("system", _SYSTEM_PROMPT), ("user", "{problem}"), MessagesPlaceholder(variable_name="context", optional=True), ] ) extractor = prompt | llm.with_structured_output(ExecuteCode) def calculate_expression( problem: str, context: Optional[List[str]] = None, config: Optional[RunnableConfig] = None, ): chain_input = {"problem": problem} if context: context_str = "\n".join(context) if context_str.strip(): context_str = _ADDITIONAL_CONTEXT_PROMPT.format( context=context_str.strip() ) chain_input["context"] = [SystemMessage(content=context_str)] code_model = extractor.invoke(chain_input, config) try: return _evaluate_expression(code_model.code) except Exception as e: return repr(e) return StructuredTool.from_function( name="math", func=calculate_expression, description=_MATH_DESCRIPTION, ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/reflection/reflection.py`: ```py # %% [markdown] # # Reflection # # # In the context of LLM agent building, reflection refers to the process of prompting an LLM to observe its past steps (along with potential observations from tools/the environment) to assess the quality of the chosen actions. # This is then used downstream for things like re-planning, search, or evaluation. # # ![Reflection](attachment:fc393f72-3401-4b86-b0d3-e4789b640a27.png) # # This notebook demonstrates a very simple form of reflection in LangGraph. 
# %% [markdown]
# ## Setup
#
# First, let's install our required packages and set our API keys

# %%
# %pip install -U --quiet langgraph langchain-fireworks
# %pip install -U --quiet tavily-python

# %%
import getpass
import os


def _set_if_undefined(var: str) -> None:
    """Prompt for ``var`` and store it, unless the environment already has a non-empty value."""
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(var)


_set_if_undefined("TAVILY_API_KEY")
_set_if_undefined("FIREWORKS_API_KEY")

# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>

# %% [markdown]
# ## Generate
#
# For our example, we will create a "5 paragraph essay" generator. First, create the generator:
#

# %%
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_fireworks import ChatFireworks

# The three adjacent literals below are joined into one system instruction.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an essay assistant tasked with writing excellent 5-paragraph essays."
            " Generate the best essay possible for the user's request."
            " If the user provides critique, respond with a revised version of your previous attempts.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
llm = ChatFireworks(
    model="accounts/fireworks/models/mixtral-8x7b-instruct",
    max_tokens=32768,
)
generate = prompt | llm

# %%
request = HumanMessage(
    content="Write an essay on why the little prince is relevant in modern childhood"
)
essay = ""
for piece in generate.stream({"messages": [request]}):
    # Echo each streamed piece immediately while accumulating the full essay.
    print(piece.content, end="")
    essay += piece.content

# %% [markdown]
# ### Reflect

# %%
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a teacher grading an essay submission. Generate critique and recommendations for the user's submission."
            " Provide detailed recommendations, including requests for length, depth, style, etc.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
reflect = reflection_prompt | llm

# %%
reflection = ""
for piece in reflect.stream({"messages": [request, HumanMessage(content=essay)]}):
    # The essay is presented to the grader as a human turn.
    print(piece.content, end="")
    reflection += piece.content

# %% [markdown]
# ### Repeat
#
# And... that's all there is to it! You can repeat in a loop for a fixed number of steps, or use an LLM (or other check) to decide when the finished product is good enough.

# %%
for piece in generate.stream(
    {"messages": [request, AIMessage(content=essay), HumanMessage(content=reflection)]}
):
    print(piece.content, end="")

# %% [markdown]
# ## Define graph
#
# Now that we've shown each step in isolation, we can wire it up in a graph.
# %% from typing import Annotated, List, Sequence from langgraph.graph import END, StateGraph, START from langgraph.graph.message import add_messages from langgraph.checkpoint.memory import MemorySaver from typing_extensions import TypedDict class State(TypedDict): messages: Annotated[list, add_messages] async def generation_node(state: State) -> State: return {"messages": [await generate.ainvoke(state["messages"])]} async def reflection_node(state: State) -> State: # Other messages we need to adjust cls_map = {"ai": HumanMessage, "human": AIMessage} # First message is the original user request. We hold it the same for all nodes translated = [state["messages"][0]] + [ cls_map[msg.type](content=msg.content) for msg in state["messages"][1:] ] res = await reflect.ainvoke(translated) # We treat the output of this as human feedback for the generator return {"messages": [HumanMessage(content=res.content)]} builder = StateGraph(State) builder.add_node("generate", generation_node) builder.add_node("reflect", reflection_node) builder.add_edge(START, "generate") def should_continue(state: State): if len(state["messages"]) > 6: # End after 3 iterations return END return "reflect" builder.add_conditional_edges("generate", should_continue) builder.add_edge("reflect", "generate") memory = MemorySaver() graph = builder.compile(checkpointer=memory) # %% config = {"configurable": {"thread_id": "1"}} # %% async for event in graph.astream( { "messages": [ HumanMessage( content="Generate an essay on the topicality of The Little Prince and its message in modern life" ) ], }, config, ): print(event) print("---") # %% state = graph.get_state(config) # %% ChatPromptTemplate.from_messages(state.values["messages"]).pretty_print() # %% [markdown] # ## Conclusion # # Now that you've applied reflection to an LLM agent, I'll note one thing: self-reflection is inherently cyclic: it is much more effective if the reflection step has additional context or feedback (from tool observations, checks, 
etc.). If, like in the scenario above, the reflection step simply prompts the LLM to reflect on its output, it can still benefit the output quality (since the LLM then has multiple "shots" at getting a good output), but it's less guaranteed. # ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/_internal.py`: ```py """Shared utility functions for the Postgres checkpoint & storage classes.""" from collections.abc import Iterator from contextlib import contextmanager from typing import Union from psycopg import Connection from psycopg.rows import DictRow from psycopg_pool import ConnectionPool Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]] @contextmanager def get_connection(conn: Conn) -> Iterator[Connection[DictRow]]: if isinstance(conn, Connection): yield conn elif isinstance(conn, ConnectionPool): with conn.connection() as conn: yield conn else: raise TypeError(f"Invalid connection type: {type(conn)}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/_ainternal.py`: ```py """Shared async utility functions for the Postgres checkpoint & storage classes.""" from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Union from psycopg import AsyncConnection from psycopg.rows import DictRow from psycopg_pool import AsyncConnectionPool Conn = Union[AsyncConnection[DictRow], AsyncConnectionPool[AsyncConnection[DictRow]]] @asynccontextmanager async def get_connection( conn: Conn, ) -> AsyncIterator[AsyncConnection[DictRow]]: if isinstance(conn, AsyncConnection): yield conn elif isinstance(conn, AsyncConnectionPool): async with conn.connection() as conn: yield conn else: raise TypeError(f"Invalid connection type: {type(conn)}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py`: ```py import threading from collections.abc import Iterator, 
Sequence from contextlib import contextmanager from typing import Any, Optional from langchain_core.runnables import RunnableConfig from psycopg import Capabilities, Connection, Cursor, Pipeline from psycopg.rows import DictRow, dict_row from psycopg.types.json import Jsonb from psycopg_pool import ConnectionPool from langgraph.checkpoint.base import ( WRITES_IDX_MAP, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, get_checkpoint_id, ) from langgraph.checkpoint.postgres import _internal from langgraph.checkpoint.postgres.base import BasePostgresSaver from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _internal.Conn # For backward compatibility class PostgresSaver(BasePostgresSaver): lock: threading.Lock def __init__( self, conn: _internal.Conn, pipe: Optional[Pipeline] = None, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) if isinstance(conn, ConnectionPool) and pipe is not None: raise ValueError( "Pipeline should be used only with a single Connection, not ConnectionPool." ) self.conn = conn self.pipe = pipe self.lock = threading.Lock() self.supports_pipeline = Capabilities().has_pipeline() @classmethod @contextmanager def from_conn_string( cls, conn_string: str, *, pipeline: bool = False ) -> Iterator["PostgresSaver"]: """Create a new PostgresSaver instance from a connection string. Args: conn_string (str): The Postgres connection info string. pipeline (bool): whether to use Pipeline Returns: PostgresSaver: A new PostgresSaver instance. """ with Connection.connect( conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row ) as conn: if pipeline: with conn.pipeline() as pipe: yield cls(conn, pipe) else: yield cls(conn) def setup(self) -> None: """Set up the checkpoint database asynchronously. This method creates the necessary tables in the Postgres database if they don't already exist and runs database migrations. 
It MUST be called directly by the user the first time checkpointer is used. """ with self._cursor() as cur: cur.execute(self.MIGRATIONS[0]) results = cur.execute( "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" ) row = results.fetchone() if row is None: version = -1 else: version = row["v"] for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), self.MIGRATIONS[version + 1 :], ): cur.execute(migration) cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") if self.pipe: self.pipe.sync() def list( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. This method retrieves a list of checkpoint tuples from the Postgres database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (RunnableConfig): The config to use for listing the checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. Yields: Iterator[CheckpointTuple]: An iterator of checkpoint tuples. Examples: >>> from langgraph.checkpoint.postgres import PostgresSaver >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" >>> with PostgresSaver.from_conn_string(DB_URI) as memory: ... 
# Run a graph, then list the checkpoints >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoints = list(memory.list(config, limit=2)) >>> print(checkpoints) [CheckpointTuple(...), CheckpointTuple(...)] >>> config = {"configurable": {"thread_id": "1"}} >>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}} >>> with PostgresSaver.from_conn_string(DB_URI) as memory: ... # Run a graph, then list the checkpoints >>> checkpoints = list(memory.list(config, before=before)) >>> print(checkpoints) [CheckpointTuple(...), ...] """ where, args = self._search_where(config, filter, before) query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC" if limit: query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor with self._cursor() as cur: cur.execute(query, args, binary=True) for value in cur: yield CheckpointTuple( { "configurable": { "thread_id": value["thread_id"], "checkpoint_ns": value["checkpoint_ns"], "checkpoint_id": value["checkpoint_id"], } }, self._load_checkpoint( value["checkpoint"], value["channel_values"], value["pending_sends"], ), self._load_metadata(value["metadata"]), ( { "configurable": { "thread_id": value["thread_id"], "checkpoint_ns": value["checkpoint_ns"], "checkpoint_id": value["parent_checkpoint_id"], } } if value["parent_checkpoint_id"] else None ), self._load_writes(value["pending_writes"]), ) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the Postgres database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. 
Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. Examples: Basic: >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) With timestamp: >>> config = { ... "configurable": { ... "thread_id": "1", ... "checkpoint_ns": "", ... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", ... } ... } >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) """ # noqa thread_id = config["configurable"]["thread_id"] checkpoint_id = get_checkpoint_id(config) checkpoint_ns = config["configurable"].get("checkpoint_ns", "") if checkpoint_id: args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id) where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s" else: args = (thread_id, checkpoint_ns) where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1" with self._cursor() as cur: cur.execute( self.SELECT_SQL + where, args, binary=True, ) for value in cur: return CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": value["checkpoint_id"], } }, self._load_checkpoint( value["checkpoint"], value["channel_values"], value["pending_sends"], ), self._load_metadata(value["metadata"]), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": value["parent_checkpoint_id"], } } if value["parent_checkpoint_id"] else None ), self._load_writes(value["pending_writes"]), ) def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. This method saves a checkpoint to the Postgres database. The checkpoint is associated with the provided config and its parent config (if any). 
Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. Examples: >>> from langgraph.checkpoint.postgres import PostgresSaver >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" >>> with PostgresSaver.from_conn_string(DB_URI) as memory: >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) >>> print(saved_config) {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} """ configurable = config["configurable"].copy() thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") checkpoint_id = configurable.pop( "checkpoint_id", configurable.pop("thread_ts", None) ) copy = checkpoint.copy() next_config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } with self._cursor(pipeline=True) as cur: cur.executemany( self.UPSERT_CHECKPOINT_BLOBS_SQL, self._dump_blobs( thread_id, checkpoint_ns, copy.pop("channel_values"), # type: ignore[misc] new_versions, ), ) cur.execute( self.UPSERT_CHECKPOINTS_SQL, ( thread_id, checkpoint_ns, checkpoint["id"], checkpoint_id, Jsonb(self._dump_checkpoint(copy)), self._dump_metadata(metadata), ), ) return next_config def put_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. 
This method saves intermediate writes associated with a checkpoint to the Postgres database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (List[Tuple[str, Any]]): List of writes to store. task_id (str): Identifier for the task creating the writes. """ query = ( self.UPSERT_CHECKPOINT_WRITES_SQL if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) with self._cursor(pipeline=True) as cur: cur.executemany( query, self._dump_writes( config["configurable"]["thread_id"], config["configurable"]["checkpoint_ns"], config["configurable"]["checkpoint_id"], task_id, writes, ), ) @contextmanager def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: """Create a database cursor as a context manager. Args: pipeline (bool): whether to use pipeline for the DB operations inside the context manager. Will be applied regardless of whether the PostgresSaver instance was initialized with a pipeline. If pipeline mode is not supported, will fall back to using transaction context manager. 
""" with _internal.get_connection(self.conn) as conn: if self.pipe: # a connection in pipeline mode can be used concurrently # in multiple threads/coroutines, but only one cursor can be # used at a time try: with conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur finally: if pipeline: self.pipe.sync() elif pipeline: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: with ( self.lock, conn.pipeline(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: # Use connection's transaction context manager when pipeline mode not supported with ( self.lock, conn.transaction(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur __all__ = ["PostgresSaver", "BasePostgresSaver", "Conn"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py`: ```py import asyncio from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager from typing import Any, Optional from langchain_core.runnables import RunnableConfig from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities from psycopg.rows import DictRow, dict_row from psycopg.types.json import Jsonb from psycopg_pool import AsyncConnectionPool from langgraph.checkpoint.base import ( WRITES_IDX_MAP, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, get_checkpoint_id, ) from langgraph.checkpoint.postgres import _ainternal from langgraph.checkpoint.postgres.base import BasePostgresSaver from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _ainternal.Conn # For backward compatibility class AsyncPostgresSaver(BasePostgresSaver): lock: asyncio.Lock def __init__( self, conn: _ainternal.Conn, pipe: Optional[AsyncPipeline] = None, serde: Optional[SerializerProtocol] = None, 
) -> None: super().__init__(serde=serde) if isinstance(conn, AsyncConnectionPool) and pipe is not None: raise ValueError( "Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool." ) self.conn = conn self.pipe = pipe self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() self.supports_pipeline = Capabilities().has_pipeline() @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str, *, pipeline: bool = False, serde: Optional[SerializerProtocol] = None, ) -> AsyncIterator["AsyncPostgresSaver"]: """Create a new AsyncPostgresSaver instance from a connection string. Args: conn_string (str): The Postgres connection info string. pipeline (bool): whether to use AsyncPipeline Returns: AsyncPostgresSaver: A new AsyncPostgresSaver instance. """ async with await AsyncConnection.connect( conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row ) as conn: if pipeline: async with conn.pipeline() as pipe: yield cls(conn=conn, pipe=pipe, serde=serde) else: yield cls(conn=conn, serde=serde) async def setup(self) -> None: """Set up the checkpoint database asynchronously. This method creates the necessary tables in the Postgres database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time checkpointer is used. 
""" async with self._cursor() as cur: await cur.execute(self.MIGRATIONS[0]) results = await cur.execute( "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" ) row = await results.fetchone() if row is None: version = -1 else: version = row["v"] for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), self.MIGRATIONS[version + 1 :], ): await cur.execute(migration) await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") if self.pipe: await self.pipe.sync() async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """List checkpoints from the database asynchronously. This method retrieves a list of checkpoint tuples from the Postgres database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. 
""" where, args = self._search_where(config, filter, before) query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC" if limit: query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor async with self._cursor() as cur: await cur.execute(query, args, binary=True) async for value in cur: yield CheckpointTuple( { "configurable": { "thread_id": value["thread_id"], "checkpoint_ns": value["checkpoint_ns"], "checkpoint_id": value["checkpoint_id"], } }, await asyncio.to_thread( self._load_checkpoint, value["checkpoint"], value["channel_values"], value["pending_sends"], ), self._load_metadata(value["metadata"]), ( { "configurable": { "thread_id": value["thread_id"], "checkpoint_ns": value["checkpoint_ns"], "checkpoint_id": value["parent_checkpoint_id"], } } if value["parent_checkpoint_id"] else None ), await asyncio.to_thread(self._load_writes, value["pending_writes"]), ) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. This method retrieves a checkpoint tuple from the Postgres database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ thread_id = config["configurable"]["thread_id"] checkpoint_id = get_checkpoint_id(config) checkpoint_ns = config["configurable"].get("checkpoint_ns", "") if checkpoint_id: args: tuple[Any, ...] 
= (thread_id, checkpoint_ns, checkpoint_id) where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s" else: args = (thread_id, checkpoint_ns) where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1" async with self._cursor() as cur: await cur.execute( self.SELECT_SQL + where, args, binary=True, ) async for value in cur: return CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": value["checkpoint_id"], } }, await asyncio.to_thread( self._load_checkpoint, value["checkpoint"], value["channel_values"], value["pending_sends"], ), self._load_metadata(value["metadata"]), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": value["parent_checkpoint_id"], } } if value["parent_checkpoint_id"] else None ), await asyncio.to_thread(self._load_writes, value["pending_writes"]), ) async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database asynchronously. This method saves a checkpoint to the Postgres database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. 
""" configurable = config["configurable"].copy() thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") checkpoint_id = configurable.pop( "checkpoint_id", configurable.pop("thread_ts", None) ) copy = checkpoint.copy() next_config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } async with self._cursor(pipeline=True) as cur: await cur.executemany( self.UPSERT_CHECKPOINT_BLOBS_SQL, await asyncio.to_thread( self._dump_blobs, thread_id, checkpoint_ns, copy.pop("channel_values"), # type: ignore[misc] new_versions, ), ) await cur.execute( self.UPSERT_CHECKPOINTS_SQL, ( thread_id, checkpoint_ns, checkpoint["id"], checkpoint_id, Jsonb(self._dump_checkpoint(copy)), self._dump_metadata(metadata), ), ) return next_config async def aput_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint asynchronously. This method saves intermediate writes associated with a checkpoint to the database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. """ query = ( self.UPSERT_CHECKPOINT_WRITES_SQL if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) params = await asyncio.to_thread( self._dump_writes, config["configurable"]["thread_id"], config["configurable"]["checkpoint_ns"], config["configurable"]["checkpoint_id"], task_id, writes, ) async with self._cursor(pipeline=True) as cur: await cur.executemany(query, params) @asynccontextmanager async def _cursor( self, *, pipeline: bool = False ) -> AsyncIterator[AsyncCursor[DictRow]]: """Create a database cursor as a context manager. Args: pipeline (bool): whether to use pipeline for the DB operations inside the context manager. 
Will be applied regardless of whether the AsyncPostgresSaver instance was initialized with a pipeline. If pipeline mode is not supported, will fall back to using transaction context manager. """ async with _ainternal.get_connection(self.conn) as conn: if self.pipe: # a connection in pipeline mode can be used concurrently # in multiple threads/coroutines, but only one cursor can be # used at a time try: async with conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur finally: if pipeline: await self.pipe.sync() elif pipeline: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: async with ( self.lock, conn.pipeline(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: # Use connection's transaction context manager when pipeline mode not supported async with ( self.lock, conn.transaction(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: async with ( self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur def list( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. This method retrieves a list of checkpoint tuples from the Postgres database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. 
""" aiter_ = self.alist(config, filter=filter, before=before, limit=limit) while True: try: yield asyncio.run_coroutine_threadsafe( anext(aiter_), # noqa: F821 self.loop, ).result() except StopAsyncIteration: break def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the Postgres database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ try: # check if we are in the main thread, only bg threads can block # we don't check in other methods to avoid the overhead if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to AsyncPostgresSaver are only allowed from a " "different thread. From the main thread, use the async interface." "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." ) except RuntimeError: pass return asyncio.run_coroutine_threadsafe( self.aget_tuple(config), self.loop ).result() def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. This method saves a checkpoint to the Postgres database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. 
Returns: RunnableConfig: Updated configuration after storing the checkpoint. """ return asyncio.run_coroutine_threadsafe( self.aput(config, checkpoint, metadata, new_versions), self.loop ).result() def put_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. This method saves intermediate writes associated with a checkpoint to the database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. """ return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop ).result() __all__ = ["AsyncPostgresSaver", "Conn"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py`: ```py import random from collections.abc import Sequence from typing import Any, Optional, cast from langchain_core.runnables import RunnableConfig from psycopg.types.json import Jsonb from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol MetadataInput = Optional[dict[str, Any]] """ To add a new migration, add a new string to the MIGRATIONS list. The position of the migration in the list is the version number. 
""" MIGRATIONS = [ """CREATE TABLE IF NOT EXISTS checkpoint_migrations ( v INTEGER PRIMARY KEY );""", """CREATE TABLE IF NOT EXISTS checkpoints ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, type TEXT, checkpoint JSONB NOT NULL, metadata JSONB NOT NULL DEFAULT '{}', PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) );""", """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', channel TEXT NOT NULL, version TEXT NOT NULL, type TEXT NOT NULL, blob BYTEA, PRIMARY KEY (thread_id, checkpoint_ns, channel, version) );""", """CREATE TABLE IF NOT EXISTS checkpoint_writes ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, task_id TEXT NOT NULL, idx INTEGER NOT NULL, channel TEXT NOT NULL, type TEXT, blob BYTEA NOT NULL, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) );""", "ALTER TABLE checkpoint_blobs ALTER COLUMN blob DROP not null;", ] SELECT_SQL = f""" select thread_id, checkpoint, checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, ( select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob]) from jsonb_each_text(checkpoint -> 'channel_versions') inner join checkpoint_blobs bl on bl.thread_id = checkpoints.thread_id and bl.checkpoint_ns = checkpoints.checkpoint_ns and bl.channel = jsonb_each_text.key and bl.version = jsonb_each_text.value ) as channel_values, ( select array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns and cw.checkpoint_id = checkpoints.checkpoint_id ) as pending_writes, ( select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = 
checkpoints.checkpoint_ns and cw.checkpoint_id = checkpoints.parent_checkpoint_id and cw.channel = '{TASKS}' ) as pending_sends from checkpoints """ UPSERT_CHECKPOINT_BLOBS_SQL = """ INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob) VALUES (%s, %s, %s, %s, %s, %s) ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING """ UPSERT_CHECKPOINTS_SQL = """ INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata) VALUES (%s, %s, %s, %s, %s, %s) ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata; """ UPSERT_CHECKPOINT_WRITES_SQL = """ INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET channel = EXCLUDED.channel, type = EXCLUDED.type, blob = EXCLUDED.blob; """ INSERT_CHECKPOINT_WRITES_SQL = """ INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING """ class BasePostgresSaver(BaseCheckpointSaver[str]): SELECT_SQL = SELECT_SQL MIGRATIONS = MIGRATIONS UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL jsonplus_serde = JsonPlusSerializer() supports_pipeline: bool def _load_checkpoint( self, checkpoint: dict[str, Any], channel_values: list[tuple[bytes, bytes, bytes]], pending_sends: list[tuple[bytes, bytes]], ) -> Checkpoint: return { **checkpoint, "pending_sends": [ self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or [] ], "channel_values": 
self._load_blobs(channel_values), } def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: return {**checkpoint, "pending_sends": []} def _load_blobs( self, blob_values: list[tuple[bytes, bytes, bytes]] ) -> dict[str, Any]: if not blob_values: return {} return { k.decode(): self.serde.loads_typed((t.decode(), v)) for k, t, v in blob_values if t.decode() != "empty" } def _dump_blobs( self, thread_id: str, checkpoint_ns: str, values: dict[str, Any], versions: ChannelVersions, ) -> list[tuple[str, str, str, str, str, Optional[bytes]]]: if not versions: return [] return [ ( thread_id, checkpoint_ns, k, cast(str, ver), *( self.serde.dumps_typed(values[k]) if k in values else ("empty", None) ), ) for k, ver in versions.items() ] def _load_writes( self, writes: list[tuple[bytes, bytes, bytes, bytes]] ) -> list[tuple[str, str, Any]]: return ( [ ( tid.decode(), channel.decode(), self.serde.loads_typed((t.decode(), v)), ) for tid, channel, t, v in writes ] if writes else [] ) def _dump_writes( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, writes: Sequence[tuple[str, Any]], ) -> list[tuple[str, str, str, str, int, str, str, bytes]]: return [ ( thread_id, checkpoint_ns, checkpoint_id, task_id, WRITES_IDX_MAP.get(channel, idx), channel, *self.serde.dumps_typed(value), ) for idx, (channel, value) in enumerate(writes) ] def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata: return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata)) def _dump_metadata(self, metadata: CheckpointMetadata) -> str: serialized_metadata = self.jsonplus_serde.dumps(metadata) # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing return serialized_metadata.decode().replace("\\u0000", "") def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: current_v = 0 elif isinstance(current, int): current_v = current else: current_v = 
int(current.split(".")[0]) next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" def _search_where( self, config: Optional[RunnableConfig], filter: MetadataInput, before: Optional[RunnableConfig] = None, ) -> tuple[str, list[Any]]: """Return WHERE clause predicates for alist() given config, filter, before. This method returns a tuple of a string and a tuple of values. The string is the parametered WHERE clause predicate (including the WHERE keyword): "WHERE column1 = $1 AND column2 IS $2". The list of values contains the values for each of the corresponding parameters. """ wheres = [] param_values = [] # construct predicate for config filter if config: wheres.append("thread_id = %s ") param_values.append(config["configurable"]["thread_id"]) checkpoint_ns = config["configurable"].get("checkpoint_ns") if checkpoint_ns is not None: wheres.append("checkpoint_ns = %s") param_values.append(checkpoint_ns) if checkpoint_id := get_checkpoint_id(config): wheres.append("checkpoint_id = %s ") param_values.append(checkpoint_id) # construct predicate for metadata filter if filter: wheres.append("metadata @> %s ") param_values.append(Jsonb(filter)) # construct predicate for `before` if before is not None: wheres.append("checkpoint_id < %s ") param_values.append(get_checkpoint_id(before)) return ( "WHERE " + " AND ".join(wheres) if wheres else "", param_values, ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/__init__.py`: ```py from langgraph.store.postgres.aio import AsyncPostgresStore from langgraph.store.postgres.base import PostgresStore __all__ = ["AsyncPostgresStore", "PostgresStore"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/aio.py`: ```py import asyncio import logging from collections.abc import AsyncIterator, Iterable, Sequence from contextlib import asynccontextmanager from typing import Any, Callable, Optional, Union, cast import orjson 
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities from psycopg.errors import UndefinedTable from psycopg.rows import DictRow, dict_row from psycopg_pool import AsyncConnectionPool from langgraph.checkpoint.postgres import _ainternal from langgraph.store.base import ( GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp, ) from langgraph.store.base.batch import AsyncBatchedBaseStore from langgraph.store.postgres.base import ( _PLACEHOLDER, BasePostgresStore, PoolConfig, PostgresIndexConfig, Row, _decode_ns_bytes, _ensure_index_config, _group_ops, _row_to_item, _row_to_search_item, ) logger = logging.getLogger(__name__) class AsyncPostgresStore(AsyncBatchedBaseStore, BasePostgresStore[_ainternal.Conn]): """Asynchronous Postgres-backed store with optional vector search using pgvector. !!! example "Examples" Basic setup and key-value storage: ```python from langgraph.store.postgres import AsyncPostgresStore async with AsyncPostgresStore.from_conn_string( "postgresql://user:pass@localhost:5432/dbname" ) as store: await store.setup() # Store and retrieve data await store.aput(("users", "123"), "prefs", {"theme": "dark"}) item = await store.aget(("users", "123"), "prefs") ``` Vector search using LangChain embeddings: ```python from langchain.embeddings import init_embeddings from langgraph.store.postgres import AsyncPostgresStore async with AsyncPostgresStore.from_conn_string( "postgresql://user:pass@localhost:5432/dbname", index={ "dims": 1536, "embed": init_embeddings("openai:text-embedding-3-small"), "fields": ["text"] # specify which fields to embed. 
Default is the whole serialized value } ) as store: await store.setup() # Do this once to run migrations # Store documents await store.aput(("docs",), "doc1", {"text": "Python tutorial"}) await store.aput(("docs",), "doc2", {"text": "TypeScript guide"}) # Don't index the following await store.aput(("docs",), "doc3", {"text": "Other guide"}, index=False) # Search by similarity results = await store.asearch(("docs",), query="python programming") ``` Using connection pooling for better performance: ```python from langgraph.store.postgres import AsyncPostgresStore, PoolConfig async with AsyncPostgresStore.from_conn_string( "postgresql://user:pass@localhost:5432/dbname", pool_config=PoolConfig( min_size=5, max_size=20 ) ) as store: await store.setup() # Use store with connection pooling... ``` Warning: Make sure to: 1. Call `setup()` before first use to create necessary tables and indexes 2. Have the pgvector extension available to use vector search 3. Use Python 3.10+ for async functionality Note: Semantic search is disabled by default. You can enable it by providing an `index` configuration when creating the store. Without this configuration, all `index` arguments passed to `put` or `aput`will have no effect. """ __slots__ = ( "_deserializer", "pipe", "lock", "supports_pipeline", "index_config", "embeddings", ) def __init__( self, conn: _ainternal.Conn, *, pipe: Optional[AsyncPipeline] = None, deserializer: Optional[ Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]] ] = None, index: Optional[PostgresIndexConfig] = None, ) -> None: if isinstance(conn, AsyncConnectionPool) and pipe is not None: raise ValueError( "Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool." 
) super().__init__() self._deserializer = deserializer self.conn = conn self.pipe = pipe self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() self.supports_pipeline = Capabilities().has_pipeline() self.index_config = index if self.index_config: self.embeddings, self.index_config = _ensure_index_config(self.index_config) else: self.embeddings = None async def abatch(self, ops: Iterable[Op]) -> list[Result]: grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops async with _ainternal.get_connection(self.conn) as conn: if self.pipe: async with self.pipe: await self._execute_batch(grouped_ops, results, conn) else: await self._execute_batch(grouped_ops, results, conn) return results def batch(self, ops: Iterable[Op]) -> list[Result]: return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str, *, pipeline: bool = False, pool_config: Optional[PoolConfig] = None, index: Optional[PostgresIndexConfig] = None, ) -> AsyncIterator["AsyncPostgresStore"]: """Create a new AsyncPostgresStore instance from a connection string. Args: conn_string (str): The Postgres connection info string. pipeline (bool): Whether to use AsyncPipeline (only for single connections) pool_config (Optional[PoolConfig]): Configuration for the connection pool. If provided, will create a connection pool and use it instead of a single connection. This overrides the `pipeline` argument. index (Optional[PostgresIndexConfig]): The embedding config. Returns: AsyncPostgresStore: A new AsyncPostgresStore instance. 
""" if pool_config is not None: pc = pool_config.copy() async with cast( AsyncConnectionPool[AsyncConnection[DictRow]], AsyncConnectionPool( conn_string, min_size=pc.pop("min_size", 1), max_size=pc.pop("max_size", None), kwargs={ "autocommit": True, "prepare_threshold": 0, "row_factory": dict_row, **(pc.pop("kwargs", None) or {}), }, **cast(dict, pc), ), ) as pool: yield cls(conn=pool, index=index) else: async with await AsyncConnection.connect( conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row ) as conn: if pipeline: async with conn.pipeline() as pipe: yield cls(conn=conn, pipe=pipe, index=index) else: yield cls(conn=conn, index=index) async def setup(self) -> None: """Set up the store database asynchronously. This method creates the necessary tables in the Postgres database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time the store is used. """ async def _get_version(cur: AsyncCursor[DictRow], table: str) -> int: try: await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1") row = await cur.fetchone() if row is None: version = -1 else: version = row["v"] except UndefinedTable: version = -1 await cur.execute( f""" CREATE TABLE IF NOT EXISTS {table} ( v INTEGER PRIMARY KEY ) """ ) return version async with self._cursor() as cur: version = await _get_version(cur, table="store_migrations") for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1): await cur.execute(sql) await cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,)) if self.index_config: version = await _get_version(cur, table="vector_migrations") for v, migration in enumerate( self.VECTOR_MIGRATIONS[version + 1 :], start=version + 1 ): sql = migration.sql if migration.params: params = { k: v(self) if v is not None and callable(v) else v for k, v in migration.params.items() } sql = sql % params await cur.execute(sql) await cur.execute( "INSERT INTO vector_migrations (v) VALUES 
(%s)", (v,) ) async def _execute_batch( self, grouped_ops: dict, results: list[Result], conn: AsyncConnection[DictRow], ) -> None: async with self._cursor(pipeline=True) as cur: if GetOp in grouped_ops: await self._batch_get_ops( cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results, cur, ) if SearchOp in grouped_ops: await self._batch_search_ops( cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), results, cur, ) if ListNamespacesOp in grouped_ops: await self._batch_list_namespaces_ops( cast( Sequence[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ), results, cur, ) if PutOp in grouped_ops: await self._batch_put_ops( cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), cur, ) async def _batch_get_ops( self, get_ops: Sequence[tuple[int, GetOp]], results: list[Result], cur: AsyncCursor[DictRow], ) -> None: for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): await cur.execute(query, params) rows = cast(list[Row], await cur.fetchall()) key_to_row = {row["key"]: row for row in rows} for idx, key in items: row = key_to_row.get(key) if row: results[idx] = _row_to_item( namespace, row, loader=self._deserializer ) else: results[idx] = None async def _batch_put_ops( self, put_ops: Sequence[tuple[int, PutOp]], cur: AsyncCursor[DictRow], ) -> None: queries, embedding_request = self._prepare_batch_PUT_queries(put_ops) if embedding_request: if self.embeddings is None: # Should not get here since the embedding config is required # to return an embedding_request above raise ValueError( "Embedding configuration is required for vector operations " f"(for semantic search). " f"Please provide an EmbeddingConfig when initializing the {self.__class__.__name__}." 
) query, txt_params = embedding_request vectors = await self.embeddings.aembed_documents( [param[-1] for param in txt_params] ) queries.append( ( query, [ p for (ns, k, pathname, _), vector in zip(txt_params, vectors) for p in (ns, k, pathname, vector) ], ) ) for query, params in queries: await cur.execute(query, params) async def _batch_search_ops( self, search_ops: Sequence[tuple[int, SearchOp]], results: list[Result], cur: AsyncCursor[DictRow], ) -> None: queries, embedding_requests = self._prepare_batch_search_queries(search_ops) if embedding_requests and self.embeddings: vectors = await self.embeddings.aembed_documents( [query for _, query in embedding_requests] ) for (idx, _), vector in zip(embedding_requests, vectors): _paramslist = queries[idx][1] for i in range(len(_paramslist)): if _paramslist[i] is _PLACEHOLDER: _paramslist[i] = vector for (idx, _), (query, params) in zip(search_ops, queries): await cur.execute(query, params) rows = cast(list[Row], await cur.fetchall()) items = [ _row_to_search_item( _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer ) for row in rows ] results[idx] = items async def _batch_list_namespaces_ops( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], results: list[Result], cur: AsyncCursor[DictRow], ) -> None: queries = self._get_batch_list_namespaces_queries(list_ops) for (query, params), (idx, _) in zip(queries, list_ops): await cur.execute(query, params) rows = cast(list[dict], await cur.fetchall()) namespaces = [_decode_ns_bytes(row["truncated_prefix"]) for row in rows] results[idx] = namespaces @asynccontextmanager async def _cursor( self, *, pipeline: bool = False ) -> AsyncIterator[AsyncCursor[DictRow]]: """Create a database cursor as a context manager. Args: pipeline: whether to use pipeline for the DB operations inside the context manager. Will be applied regardless of whether the PostgresStore instance was initialized with a pipeline. 
If pipeline mode is not supported, will fall back to using transaction context manager. """ async with _ainternal.get_connection(self.conn) as conn: if self.pipe: # a connection in pipeline mode can be used concurrently # in multiple threads/coroutines, but only one cursor can be # used at a time try: async with conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur finally: if pipeline: await self.pipe.sync() elif pipeline: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: async with ( self.lock, conn.pipeline(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: async with ( self.lock, conn.transaction(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: async with ( self.lock, conn.cursor(binary=True) as cur, ): yield cur ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/base.py`: ```py import asyncio import json import logging import threading from collections import defaultdict from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from datetime import datetime from typing import ( TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union, cast, ) import orjson from psycopg import Capabilities, Connection, Cursor, Pipeline from psycopg.errors import UndefinedTable from psycopg.rows import DictRow, dict_row from psycopg.types.json import Jsonb from psycopg_pool import ConnectionPool from typing_extensions import TypedDict from langgraph.checkpoint.postgres import _ainternal as _ainternal from langgraph.checkpoint.postgres import _internal as _pg_internal from langgraph.store.base import ( BaseStore, GetOp, IndexConfig, Item, ListNamespacesOp, Op, PutOp, Result, SearchItem, SearchOp, ensure_embeddings, get_text_at_path, tokenize_path, ) if TYPE_CHECKING: from langchain_core.embeddings import Embeddings 
logger = logging.getLogger(__name__) class Migration(NamedTuple): """A database migration with optional conditions and parameters.""" sql: str params: Optional[dict[str, Any]] = None condition: Optional[Callable[["BasePostgresStore"], bool]] = None MIGRATIONS: Sequence[str] = [ """ CREATE TABLE IF NOT EXISTS store ( -- 'prefix' represents the doc's 'namespace' prefix text NOT NULL, key text NOT NULL, value jsonb NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (prefix, key) ); """, """ -- For faster lookups by prefix CREATE INDEX IF NOT EXISTS store_prefix_idx ON store USING btree (prefix text_pattern_ops); """, ] VECTOR_MIGRATIONS: Sequence[Migration] = [ Migration( """ CREATE EXTENSION IF NOT EXISTS vector; """, ), Migration( """ CREATE TABLE IF NOT EXISTS store_vectors ( prefix text NOT NULL, key text NOT NULL, field_name text NOT NULL, embedding %(vector_type)s(%(dims)s), created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (prefix, key, field_name), FOREIGN KEY (prefix, key) REFERENCES store(prefix, key) ON DELETE CASCADE ); """, params={ "dims": lambda store: store.index_config["dims"], "vector_type": lambda store: ( cast(PostgresIndexConfig, store.index_config) .get("ann_index_config", {}) .get("vector_type", "vector") ), }, ), Migration( """ CREATE INDEX IF NOT EXISTS store_vectors_embedding_idx ON store_vectors USING %(index_type)s (embedding %(ops)s)%(index_params)s; """, condition=lambda store: bool( store.index_config and _get_index_params(store)[0] != "flat" ), params={ "index_type": lambda store: _get_index_params(store)[0], "ops": lambda store: _get_vector_type_ops(store), "index_params": lambda store: ( " WITH (" + ", ".join(f"{k}={v}" for k, v in _get_index_params(store)[1].items()) + ")" if _get_index_params(store)[1] else "" ), }, ), ] C = TypeVar("C", 
bound=Union[_pg_internal.Conn, _ainternal.Conn]) class PoolConfig(TypedDict, total=False): """Connection pool settings for PostgreSQL connections. Controls connection lifecycle and resource utilization: - Small pools (1-5) suit low-concurrency workloads - Larger pools handle concurrent requests but consume more resources - Setting max_size prevents resource exhaustion under load """ min_size: int """Minimum number of connections maintained in the pool. Defaults to 1.""" max_size: Optional[int] """Maximum number of connections allowed in the pool. None means unlimited.""" kwargs: dict """Additional connection arguments passed to each connection in the pool. Default kwargs set automatically: - autocommit: True - prepare_threshold: 0 - row_factory: dict_row """ class ANNIndexConfig(TypedDict, total=False): """Configuration for vector index in PostgreSQL store.""" kind: Literal["hnsw", "ivfflat", "flat"] """Type of index to use: 'hnsw' for Hierarchical Navigable Small World, or 'ivfflat' for Inverted File Flat.""" vector_type: Literal["vector", "halfvec"] """Type of vector storage to use. Options: - 'vector': Regular vectors (default) - 'halfvec': Half-precision vectors for reduced memory usage """ class HNSWConfig(ANNIndexConfig, total=False): """Configuration for HNSW (Hierarchical Navigable Small World) index.""" kind: Literal["hnsw"] # type: ignore[misc] m: int """Maximum number of connections per layer. Default is 16.""" ef_construction: int """Size of dynamic candidate list for index construction. Default is 64.""" class IVFFlatConfig(ANNIndexConfig, total=False): """IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff). Three keys to achieving good recall are: 1. Create the index after the table has some data 2. 
Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows 3. When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists) """ kind: Literal["ivfflat"] # type: ignore[misc] nlist: int """Number of inverted lists (clusters) for IVF index. Determines the number of clusters used in the index structure. Higher values can improve search speed but increase index size and build time. Typically set to the square root of the number of vectors in the index. """ class PostgresIndexConfig(IndexConfig, total=False): """Configuration for vector embeddings in PostgreSQL store with pgvector-specific options. Extends EmbeddingConfig with additional configuration for pgvector index and vector types. """ ann_index_config: ANNIndexConfig """Specific configuration for the chosen index type (HNSW or IVF Flat).""" distance_type: Literal["l2", "inner_product", "cosine"] """Distance metric to use for vector similarity search: - 'l2': Euclidean distance - 'inner_product': Dot product - 'cosine': Cosine similarity """ class BasePostgresStore(Generic[C]): MIGRATIONS = MIGRATIONS VECTOR_MIGRATIONS = VECTOR_MIGRATIONS conn: C _deserializer: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] index_config: Optional[PostgresIndexConfig] def _get_batch_GET_ops_queries( self, get_ops: Sequence[tuple[int, GetOp]], ) -> list[tuple[str, tuple, tuple[str, ...], list]]: namespace_groups = defaultdict(list) for idx, op in get_ops: namespace_groups[op.namespace].append((idx, op.key)) results = [] for namespace, items in namespace_groups.items(): _, keys = zip(*items) keys_to_query = ",".join(["%s"] * len(keys)) query = f""" SELECT key, value, created_at, updated_at FROM store WHERE prefix = %s AND key IN ({keys_to_query}) """ params = (_namespace_to_text(namespace), *keys) results.append((query, params, namespace, items)) return 
results def _prepare_batch_PUT_queries( self, put_ops: Sequence[tuple[int, PutOp]], ) -> tuple[ list[tuple[str, Sequence]], Optional[tuple[str, Sequence[tuple[str, str, str, str]]]], ]: # Last-write wins dedupped_ops: dict[tuple[tuple[str, ...], str], PutOp] = {} for _, op in put_ops: dedupped_ops[(op.namespace, op.key)] = op inserts: list[PutOp] = [] deletes: list[PutOp] = [] for op in dedupped_ops.values(): if op.value is None: deletes.append(op) else: inserts.append(op) queries: list[tuple[str, Sequence]] = [] if deletes: namespace_groups: dict[tuple[str, ...], list[str]] = defaultdict(list) for op in deletes: namespace_groups[op.namespace].append(op.key) for namespace, keys in namespace_groups.items(): placeholders = ",".join(["%s"] * len(keys)) query = ( f"DELETE FROM store WHERE prefix = %s AND key IN ({placeholders})" ) params = (_namespace_to_text(namespace), *keys) queries.append((query, params)) embedding_request: Optional[tuple[str, Sequence[tuple[str, str, str, str]]]] = ( None ) if inserts: values = [] insertion_params = [] vector_values = [] embedding_request_params = [] # First handle main store insertions for op in inserts: values.append("(%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)") insertion_params.extend( [ _namespace_to_text(op.namespace), op.key, Jsonb(cast(dict, op.value)), ] ) # Then handle embeddings if configured if self.index_config: for op in inserts: if op.index is False: continue value = op.value ns = _namespace_to_text(op.namespace) k = op.key if op.index is None: paths = self.index_config["__tokenized_fields"] else: paths = [(ix, tokenize_path(ix)) for ix in op.index] for path, tokenized_path in paths: texts = get_text_at_path(value, tokenized_path) for i, text in enumerate(texts): pathname = f"{path}.{i}" if len(texts) > 1 else path vector_values.append( "(%s, %s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)" ) embedding_request_params.append((ns, k, pathname, text)) values_str = ",".join(values) query = f""" INSERT INTO 
store (prefix, key, value, created_at, updated_at) VALUES {values_str} ON CONFLICT (prefix, key) DO UPDATE SET value = EXCLUDED.value, updated_at = CURRENT_TIMESTAMP """ queries.append((query, insertion_params)) if vector_values: values_str = ",".join(vector_values) query = f""" INSERT INTO store_vectors (prefix, key, field_name, embedding, created_at, updated_at) VALUES {values_str} ON CONFLICT (prefix, key, field_name) DO UPDATE SET embedding = EXCLUDED.embedding, updated_at = CURRENT_TIMESTAMP """ embedding_request = (query, embedding_request_params) return queries, embedding_request def _prepare_batch_search_queries( self, search_ops: Sequence[tuple[int, SearchOp]], ) -> tuple[ list[tuple[str, list[Union[None, str, list[float]]]]], # queries, params list[tuple[int, str]], # idx, query_text pairs to embed ]: queries = [] embedding_requests = [] for idx, (_, op) in enumerate(search_ops): # Build filter conditions first filter_params = [] filter_conditions = [] if op.filter: for key, value in op.filter.items(): if isinstance(value, dict): for op_name, val in value.items(): condition, filter_params_ = self._get_filter_condition( key, op_name, val ) filter_conditions.append(condition) filter_params.extend(filter_params_) else: filter_conditions.append("value->%s = %s::jsonb") filter_params.extend([key, json.dumps(value)]) # Vector search branch if op.query and self.index_config: embedding_requests.append((idx, op.query)) score_operator, post_operator = _get_distance_operator(self) vector_type = ( cast(PostgresIndexConfig, self.index_config) .get("ann_index_config", {}) .get("vector_type", "vector") ) if ( vector_type == "bit" and self.index_config.get("distance_type") == "hamming" ): score_operator = score_operator % ( "%s", self.index_config["dims"], ) else: score_operator = score_operator % ( "%s", vector_type, ) vectors_per_doc_estimate = self.index_config["__estimated_num_vectors"] expanded_limit = (op.limit * vectors_per_doc_estimate * 2) + 1 # Vector search 
with CTE for proper score handling filter_str = ( "" if not filter_conditions else " AND " + " AND ".join(filter_conditions) ) if op.namespace_prefix: prefix_filter_str = f"WHERE s.prefix LIKE %s {filter_str} " ns_args: Sequence = (f"{_namespace_to_text(op.namespace_prefix)}%",) else: ns_args = () if filter_str: prefix_filter_str = f"WHERE {filter_str} " else: prefix_filter_str = "" base_query = f""" WITH scored AS ( SELECT s.prefix, s.key, s.value, s.created_at, s.updated_at, {score_operator} AS neg_score FROM store s JOIN store_vectors sv ON s.prefix = sv.prefix AND s.key = sv.key {prefix_filter_str} ORDER BY {score_operator} ASC LIMIT %s ) SELECT * FROM ( SELECT DISTINCT ON (prefix, key) prefix, key, value, created_at, updated_at, {post_operator} as score FROM scored ORDER BY prefix, key, score DESC ) AS unique_docs ORDER BY score DESC LIMIT %s OFFSET %s """ params = [ _PLACEHOLDER, # Vector placeholder *ns_args, *filter_params, _PLACEHOLDER, expanded_limit, op.limit, op.offset, ] # Regular search branch else: base_query = """ SELECT prefix, key, value, created_at, updated_at FROM store WHERE prefix LIKE %s """ params = [f"{_namespace_to_text(op.namespace_prefix)}%"] if filter_conditions: params.extend(filter_params) base_query += " AND " + " AND ".join(filter_conditions) base_query += " ORDER BY updated_at DESC" base_query += " LIMIT %s OFFSET %s" params.extend([op.limit, op.offset]) queries.append((base_query, params)) return queries, embedding_requests def _get_batch_list_namespaces_queries( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], ) -> list[tuple[str, Sequence]]: queries: list[tuple[str, Sequence]] = [] for _, op in list_ops: query = """ SELECT DISTINCT ON (truncated_prefix) truncated_prefix, prefix FROM ( SELECT prefix, CASE WHEN %s::integer IS NOT NULL THEN (SELECT STRING_AGG(part, '.' 
ORDER BY idx) FROM ( SELECT part, ROW_NUMBER() OVER () AS idx FROM UNNEST(REGEXP_SPLIT_TO_ARRAY(prefix, '\.')) AS part LIMIT %s::integer ) subquery ) ELSE prefix END AS truncated_prefix FROM store """ params: list[Any] = [op.max_depth, op.max_depth] conditions = [] if op.match_conditions: for condition in op.match_conditions: if condition.match_type == "prefix": conditions.append("prefix LIKE %s") params.append( f"{_namespace_to_text(condition.path, handle_wildcards=True)}%" ) elif condition.match_type == "suffix": conditions.append("prefix LIKE %s") params.append( f"%{_namespace_to_text(condition.path, handle_wildcards=True)}" ) else: logger.warning( f"Unknown match_type in list_namespaces: {condition.match_type}" ) if conditions: query += " WHERE " + " AND ".join(conditions) query += ") AS subquery " query += " ORDER BY truncated_prefix LIMIT %s OFFSET %s" params.extend([op.limit, op.offset]) queries.append((query, tuple(params))) return queries def _get_filter_condition(self, key: str, op: str, value: Any) -> tuple[str, list]: """Helper to generate filter conditions.""" if op == "$eq": return "value->%s = %s::jsonb", [key, json.dumps(value)] elif op == "$gt": return "value->>%s > %s", [key, str(value)] elif op == "$gte": return "value->>%s >= %s", [key, str(value)] elif op == "$lt": return "value->>%s < %s", [key, str(value)] elif op == "$lte": return "value->>%s <= %s", [key, str(value)] elif op == "$ne": return "value->%s != %s::jsonb", [key, json.dumps(value)] else: raise ValueError(f"Unsupported operator: {op}") class PostgresStore(BaseStore, BasePostgresStore[_pg_internal.Conn]): """Postgres-backed store with optional vector search using pgvector. !!! 
example "Examples"
        Basic setup and key-value storage:
        ```python
        from langgraph.store.postgres import PostgresStore

        with PostgresStore.from_conn_string(
            "postgresql://user:pass@localhost:5432/dbname"
        ) as store:
            store.setup()

            # Store and retrieve data
            store.put(("users", "123"), "prefs", {"theme": "dark"})
            item = store.get(("users", "123"), "prefs")
        ```

        Vector search using LangChain embeddings:
        ```python
        from langchain.embeddings import init_embeddings
        from langgraph.store.postgres import PostgresStore

        with PostgresStore.from_conn_string(
            "postgresql://user:pass@localhost:5432/dbname",
            index={
                "dims": 1536,
                "embed": init_embeddings("openai:text-embedding-3-small"),
                # Which fields to embed. Default is the whole serialized value.
                "text_fields": ["text"],
            },
        ) as store:
            store.setup()  # Do this once to run migrations

            # Store documents
            store.put(("docs",), "doc1", {"text": "Python tutorial"})
            store.put(("docs",), "doc2", {"text": "TypeScript guide"})
            store.put(("docs",), "doc3", {"text": "Other guide"}, index=False)  # don't index

            # Search by similarity
            results = store.search(("docs",), query="python programming")
        ```

    Note:
        Semantic search is disabled by default. You can enable it by providing an
        `index` configuration when creating the store. Without this configuration,
        all `index` arguments passed to `put` or `aput` will have no effect.

    Warning:
        Make sure to call `setup()` before first use to create the necessary tables
        and indexes. The pgvector extension must be available to use vector search.
""" __slots__ = ( "_deserializer", "pipe", "lock", "supports_pipeline", "index_config", "embeddings", ) def __init__( self, conn: _pg_internal.Conn, *, pipe: Optional[Pipeline] = None, deserializer: Optional[ Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]] ] = None, index: Optional[PostgresIndexConfig] = None, ) -> None: super().__init__() self._deserializer = deserializer self.conn = conn self.pipe = pipe self.supports_pipeline = Capabilities().has_pipeline() self.lock = threading.Lock() self.index_config = index if self.index_config: self.embeddings, self.index_config = _ensure_index_config(self.index_config) else: self.embeddings = None @classmethod @contextmanager def from_conn_string( cls, conn_string: str, *, pipeline: bool = False, pool_config: Optional[PoolConfig] = None, index: Optional[PostgresIndexConfig] = None, ) -> Iterator["PostgresStore"]: """Create a new PostgresStore instance from a connection string. Args: conn_string (str): The Postgres connection info string. pipeline (bool): whether to use Pipeline pool_config (Optional[PoolArgs]): Configuration for the connection pool. If provided, will create a connection pool and use it instead of a single connection. This overrides the `pipeline` argument. index (Optional[PostgresIndexConfig]): The index configuration for the store. Returns: PostgresStore: A new PostgresStore instance. 
""" if pool_config is not None: pc = pool_config.copy() with cast( ConnectionPool[Connection[DictRow]], ConnectionPool( conn_string, min_size=pc.pop("min_size", 1), max_size=pc.pop("max_size", None), kwargs={ "autocommit": True, "prepare_threshold": 0, "row_factory": dict_row, **(pc.pop("kwargs", None) or {}), }, **cast(dict, pc), ), ) as pool: yield cls(conn=pool, index=index) else: with Connection.connect( conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row ) as conn: if pipeline: with conn.pipeline() as pipe: yield cls(conn, pipe=pipe, index=index) else: yield cls(conn, index=index) @contextmanager def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: """Create a database cursor as a context manager. Args: pipeline (bool): whether to use pipeline for the DB operations inside the context manager. Will be applied regardless of whether the PostgresStore instance was initialized with a pipeline. If pipeline mode is not supported, will fall back to using transaction context manager. 
""" with _pg_internal.get_connection(self.conn) as conn: if self.pipe: # a connection in pipeline mode can be used concurrently # in multiple threads/coroutines, but only one cursor can be # used at a time try: with conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur finally: if pipeline: self.pipe.sync() elif pipeline: # a connection not in pipeline mode can only be used by one # thread/coroutine at a time, so we acquire a lock if self.supports_pipeline: with ( self.lock, conn.pipeline(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: with ( self.lock, conn.transaction(), conn.cursor(binary=True, row_factory=dict_row) as cur, ): yield cur else: with conn.cursor(binary=True, row_factory=dict_row) as cur: yield cur def batch(self, ops: Iterable[Op]) -> list[Result]: grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops with self._cursor(pipeline=True) as cur: if GetOp in grouped_ops: self._batch_get_ops( cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results, cur ) if SearchOp in grouped_ops: self._batch_search_ops( cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), results, cur, ) if ListNamespacesOp in grouped_ops: self._batch_list_namespaces_ops( cast( Sequence[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ), results, cur, ) if PutOp in grouped_ops: self._batch_put_ops( cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), cur ) return results def _batch_get_ops( self, get_ops: Sequence[tuple[int, GetOp]], results: list[Result], cur: Cursor[DictRow], ) -> None: for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): cur.execute(query, params) rows = cast(list[Row], cur.fetchall()) key_to_row = {row["key"]: row for row in rows} for idx, key in items: row = key_to_row.get(key) if row: results[idx] = _row_to_item( namespace, row, loader=self._deserializer ) else: results[idx] = None def _batch_put_ops( self, put_ops: Sequence[tuple[int, 
PutOp]], cur: Cursor[DictRow], ) -> None: queries, embedding_request = self._prepare_batch_PUT_queries(put_ops) if embedding_request: if self.embeddings is None: # Should not get here since the embedding config is required # to return an embedding_request above raise ValueError( "Embedding configuration is required for vector operations " f"(for semantic search). " f"Please provide an Embeddings when initializing the {self.__class__.__name__}." ) query, txt_params = embedding_request # Update the params to replace the raw text with the vectors vectors = self.embeddings.embed_documents( [param[-1] for param in txt_params] ) queries.append( ( query, [ p for (ns, k, pathname, _), vector in zip(txt_params, vectors) for p in (ns, k, pathname, vector) ], ) ) for query, params in queries: cur.execute(query, params) def _batch_search_ops( self, search_ops: Sequence[tuple[int, SearchOp]], results: list[Result], cur: Cursor[DictRow], ) -> None: queries, embedding_requests = self._prepare_batch_search_queries(search_ops) if embedding_requests and self.embeddings: embeddings = self.embeddings.embed_documents( [query for _, query in embedding_requests] ) for (idx, _), embedding in zip(embedding_requests, embeddings): _paramslist = queries[idx][1] for i in range(len(_paramslist)): if _paramslist[i] is _PLACEHOLDER: _paramslist[i] = embedding for (idx, _), (query, params) in zip(search_ops, queries): cur.execute(query, params) rows = cast(list[Row], cur.fetchall()) results[idx] = [ _row_to_search_item( _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer ) for row in rows ] def _batch_list_namespaces_ops( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], results: list[Result], cur: Cursor[DictRow], ) -> None: for (query, params), (idx, _) in zip( self._get_batch_list_namespaces_queries(list_ops), list_ops ): cur.execute(query, params) results[idx] = [_decode_ns_bytes(row["truncated_prefix"]) for row in cur] async def abatch(self, ops: Iterable[Op]) -> 
list[Result]: return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops) def setup(self) -> None: """Set up the store database. This method creates the necessary tables in the Postgres database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time the store is used. """ def _get_version(cur: Cursor[dict[str, Any]], table: str) -> int: try: cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1") row = cast(dict, cur.fetchone()) if row is None: version = -1 else: version = row["v"] except UndefinedTable: version = -1 cur.execute( f""" CREATE TABLE IF NOT EXISTS {table} ( v INTEGER PRIMARY KEY ) """ ) return version with self._cursor() as cur: version = _get_version(cur, table="store_migrations") for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1): cur.execute(sql) cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,)) if self.index_config: version = _get_version(cur, table="vector_migrations") for v, migration in enumerate( self.VECTOR_MIGRATIONS[version + 1 :], start=version + 1 ): if migration.condition and not migration.condition(self): continue sql = migration.sql if migration.params: params = { k: v(self) if v is not None and callable(v) else v for k, v in migration.params.items() } sql = sql % params cur.execute(sql) cur.execute("INSERT INTO vector_migrations (v) VALUES (%s)", (v,)) class Row(TypedDict): key: str value: Any prefix: str created_at: datetime updated_at: datetime # Private utilities _DEFAULT_ANN_CONFIG = ANNIndexConfig( vector_type="vector", ) def _get_vector_type_ops(store: BasePostgresStore) -> str: """Get the vector type operator class based on config.""" if not store.index_config: return "vector_cosine_ops" config = cast(PostgresIndexConfig, store.index_config) index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy() vector_type = cast(str, index_config.get("vector_type", "vector")) if vector_type not in 
("vector", "halfvec"): raise ValueError( f"Vector type must be 'vector' or 'halfvec', got {vector_type}" ) distance_type = config.get("distance_type", "cosine") # For regular vectors type_prefix = {"vector": "vector", "halfvec": "halfvec"}[vector_type] if distance_type not in ("l2", "inner_product", "cosine"): raise ValueError( f"Vector type {vector_type} only supports 'l2', 'inner_product', or 'cosine' distance, got {distance_type}" ) distance_suffix = { "l2": "l2_ops", "inner_product": "ip_ops", "cosine": "cosine_ops", }[distance_type] return f"{type_prefix}_{distance_suffix}" def _get_index_params(store: Any) -> tuple[str, dict[str, Any]]: """Get the index type and configuration based on config.""" if not store.index_config: return "hnsw", {} config = cast(PostgresIndexConfig, store.index_config) index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy() kind = index_config.pop("kind", "hnsw") index_config.pop("vector_type", None) return kind, index_config def _namespace_to_text( namespace: tuple[str, ...], handle_wildcards: bool = False ) -> str: """Convert namespace tuple to text string.""" if handle_wildcards: namespace = tuple("%" if val == "*" else val for val in namespace) return ".".join(namespace) def _row_to_item( namespace: tuple[str, ...], row: Row, *, loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None, ) -> Item: """Convert a row from the database into an Item. 
Args: namespace: Item namespace row: Database row loader: Optional value loader for non-dict values """ val = row["value"] if not isinstance(val, dict): val = (loader or _json_loads)(val) kwargs = { "key": row["key"], "namespace": namespace, "value": val, "created_at": row["created_at"], "updated_at": row["updated_at"], } return Item(**kwargs) def _row_to_search_item( namespace: tuple[str, ...], row: Row, *, loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None, ) -> SearchItem: """Convert a row from the database into an Item.""" loader = loader or _json_loads val = row["value"] score = row.get("score") if score is not None: try: score = float(score) # type: ignore[arg-type] except ValueError: logger.warning("Invalid score: %s", score) score = None return SearchItem( value=val if isinstance(val, dict) else loader(val), key=row["key"], namespace=namespace, created_at=row["created_at"], updated_at=row["updated_at"], score=score, ) def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]: grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list) tot = 0 for idx, op in enumerate(ops): grouped_ops[type(op)].append((idx, op)) tot += 1 return grouped_ops, tot def _json_loads(content: Union[bytes, orjson.Fragment]) -> Any: if isinstance(content, orjson.Fragment): if hasattr(content, "buf"): content = content.buf else: if isinstance(content.contents, bytes): content = content.contents else: content = content.contents.encode() return orjson.loads(cast(bytes, content)) def _decode_ns_bytes(namespace: Union[str, bytes, list]) -> tuple[str, ...]: if isinstance(namespace, list): return tuple(namespace) if isinstance(namespace, bytes): namespace = namespace.decode()[1:] return tuple(namespace.split(".")) def _get_distance_operator(store: Any) -> tuple[str, str]: """Get the distance operator and score expression based on config.""" # Note: Today, we are not using ANN indices due to restrictions # on PGVector's support 
for mixing vector and non-vector filters # To use the index, PGVector expects: # - ORDER BY the operator NOT an expression (even negation blocks it) # - ASCENDING order # - Any WHERE clause should be over a partial index. # If we violate any of these, it will use a sequential scan # See https://github.com/pgvector/pgvector/issues/216 and the # pgvector documentation for more details. if not store.index_config: raise ValueError( "Embedding configuration is required for vector operations " f"(for semantic search). " f"Please provide an Embeddings when initializing the {store.__class__.__name__}." ) config = cast(PostgresIndexConfig, store.index_config) distance_type = config.get("distance_type", "cosine") # Return the operator and the score expression # The operator is used in the CTE and will be compatible with an ASCENDING ORDER # sort clause. # The score expression is used in the final query and will be compatible with # a DESCENDING ORDER sort clause and the user's expectations of what the similarity score # should be. if distance_type == "l2": # Final: "-(sv.embedding <-> %s::%s)" # We return the "l2 similarity" so that the sorting order is the same return "sv.embedding <-> %s::%s", "-scored.neg_score" elif distance_type == "inner_product": # Final: "-(sv.embedding <#> %s::%s)" return "sv.embedding <#> %s::%s", "-(scored.neg_score)" else: # cosine similarity # Final: "1 - (sv.embedding <=> %s::%s)" return "sv.embedding <=> %s::%s", "1 - scored.neg_score" def _ensure_index_config( index_config: PostgresIndexConfig, ) -> tuple[Optional["Embeddings"], PostgresIndexConfig]: index_config = index_config.copy() tokenized: list[tuple[str, Union[Literal["$"], list[str]]]] = [] tot = 0 text_fields = index_config.get("text_fields") or ["$"] if isinstance(text_fields, str): text_fields = [text_fields] if not isinstance(text_fields, list): raise ValueError(f"Text fields must be a list or a string. 
Got {text_fields}") for p in text_fields: if p == "$": tokenized.append((p, "$")) tot += 1 else: toks = tokenize_path(p) tokenized.append((p, toks)) tot += len(toks) index_config["__tokenized_fields"] = tokenized index_config["__estimated_num_vectors"] = tot embeddings = ensure_embeddings( index_config.get("embed"), ) return embeddings, index_config _PLACEHOLDER = object() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-checkpoint-postgres" version = "2.0.7" description = "Library with a Postgres implementation of LangGraph checkpoint saver." authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph" }] [tool.poetry.dependencies] python = "^3.9.0,<4.0" langgraph-checkpoint = "^2.0.7" orjson = ">=3.10.1" psycopg = "^3.2.0" psycopg-pool = "^3.2.0" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" anyio = "^4.4.0" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watch = "^4.2.0" mypy = "^1.10.0" psycopg = {extras = ["binary"], version = ">=3.0.0"} langgraph-checkpoint = {path = "../checkpoint", develop = true} [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. 
addopts = "--strict-markers --strict-config --durations=5 -vv" asyncio_mode = "auto" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ "E", # pycodestyle "F", # Pyflakes "UP", # pyupgrade "B", # flake8-bugbear "I", # isort ] lint.ignore = ["E501", "B008", "UP007", "UP006"] [tool.mypy] # https://mypy.readthedocs.io/en/stable/config_file.html disallow_untyped_defs = "True" explicit_package_bases = "True" warn_no_return = "False" warn_unused_ignores = "True" warn_redundant_casts = "True" allow_redefinition = "True" disable_error_code = "typeddict-item, return-value" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/conftest.py`: ```py from collections.abc import AsyncIterator import pytest from psycopg import AsyncConnection from psycopg.errors import UndefinedTable from psycopg.rows import DictRow, dict_row from tests.embed_test_utils import CharacterEmbeddings DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5441/" DEFAULT_URI = "postgres://postgres:postgres@localhost:5441/postgres?sslmode=disable" @pytest.fixture(scope="function") async def conn() -> AsyncIterator[AsyncConnection[DictRow]]: async with await AsyncConnection.connect( DEFAULT_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row ) as conn: yield conn @pytest.fixture(scope="function", autouse=True) async def clear_test_db(conn: AsyncConnection[DictRow]) -> None: """Delete all tables before each test.""" try: await conn.execute("DELETE FROM checkpoints") await conn.execute("DELETE FROM checkpoint_blobs") await conn.execute("DELETE FROM checkpoint_writes") await conn.execute("DELETE FROM checkpoint_migrations") except UndefinedTable: pass try: await conn.execute("DELETE FROM store_migrations") await conn.execute("DELETE FROM store") except UndefinedTable: pass @pytest.fixture def fake_embeddings() -> CharacterEmbeddings: return CharacterEmbeddings(dims=500) VECTOR_TYPES = ["vector", 
"halfvec"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_sync.py`: ```py # type: ignore from contextlib import contextmanager from typing import Any from uuid import uuid4 import pytest from langchain_core.runnables import RunnableConfig from psycopg import Connection from psycopg.rows import dict_row from psycopg_pool import ConnectionPool from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.postgres import PostgresSaver from tests.conftest import DEFAULT_POSTGRES_URI @contextmanager def _pool_saver(): """Fixture for pool mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer with ConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True, "row_factory": dict_row}, ) as pool: checkpointer = PostgresSaver(pool) checkpointer.setup() yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @contextmanager def _pipe_saver(): """Fixture for pipeline mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: with Connection.connect( DEFAULT_POSTGRES_URI + database, autocommit=True, prepare_threshold=0, row_factory=dict_row, ) as conn: with conn.pipeline() as pipe: checkpointer = PostgresSaver(conn, pipe=pipe) checkpointer.setup() with conn.pipeline() as pipe: checkpointer = PostgresSaver(conn, pipe=pipe) yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @contextmanager def _base_saver(): """Fixture for regular connection 
mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: with Connection.connect( DEFAULT_POSTGRES_URI + database, autocommit=True, prepare_threshold=0, row_factory=dict_row, ) as conn: checkpointer = PostgresSaver(conn) checkpointer.setup() yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @contextmanager def _saver(name: str): if name == "base": with _base_saver() as saver: yield saver elif name == "pool": with _pool_saver() as saver: yield saver elif name == "pipe": with _pipe_saver() as saver: yield saver @pytest.fixture def test_data(): """Fixture providing test data for checkpoint tests.""" config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } chkpnt_1: Checkpoint = empty_checkpoint() chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1) chkpnt_3: Checkpoint = empty_checkpoint() metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } metadata_3: CheckpointMetadata = {} return { "configs": [config_1, config_2, config_3], "checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3], "metadata": [metadata_1, metadata_2, metadata_3], } @pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) def test_search(saver_name: str, test_data) -> None: with _saver(saver_name) as saver: configs = test_data["configs"] checkpoints = 
test_data["checkpoints"] metadata = test_data["metadata"] saver.put(configs[0], checkpoints[0], metadata[0], {}) saver.put(configs[1], checkpoints[1], metadata[1], {}) saver.put(configs[2], checkpoints[2], metadata[2], {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = list(saver.list(None, filter=query_1)) assert len(search_results_1) == 1 assert search_results_1[0].metadata == metadata[0] search_results_2 = list(saver.list(None, filter=query_2)) assert len(search_results_2) == 1 assert search_results_2[0].metadata == metadata[1] search_results_3 = list(saver.list(None, filter=query_3)) assert len(search_results_3) == 3 search_results_4 = list(saver.list(None, filter=query_4)) assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = list(saver.list({"configurable": {"thread_id": "thread-2"}})) assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} @pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) def test_null_chars(saver_name: str, test_data) -> None: with _saver(saver_name) as saver: config = saver.put( test_data["configs"][0], test_data["checkpoints"][0], {"my_key": "\x00abc"}, {}, ) assert saver.get_tuple(config).metadata["my_key"] == "abc" # type: ignore assert ( list(saver.list(None, filter={"my_key": "abc"}))[0].metadata["my_key"] == "abc" ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_async.py`: ```py # type: ignore from contextlib import asynccontextmanager from typing import Any from uuid import uuid4 import pytest from langchain_core.runnables import RunnableConfig 
from psycopg import AsyncConnection from psycopg.rows import dict_row from psycopg_pool import AsyncConnectionPool from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from tests.conftest import DEFAULT_POSTGRES_URI @asynccontextmanager async def _pool_saver(): """Fixture for pool mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer async with AsyncConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True, "row_factory": dict_row}, ) as pool: checkpointer = AsyncPostgresSaver(pool) await checkpointer.setup() yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _pipe_saver(): """Fixture for pipeline mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI + database, autocommit=True, prepare_threshold=0, row_factory=dict_row, ) as conn: async with conn.pipeline() as pipe: checkpointer = AsyncPostgresSaver(conn, pipe=pipe) await checkpointer.setup() async with conn.pipeline() as pipe: checkpointer = AsyncPostgresSaver(conn, pipe=pipe) yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _base_saver(): """Fixture for regular connection mode testing.""" database = f"test_{uuid4().hex[:16]}" # 
create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI + database, autocommit=True, prepare_threshold=0, row_factory=dict_row, ) as conn: checkpointer = AsyncPostgresSaver(conn) await checkpointer.setup() yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _saver(name: str): if name == "base": async with _base_saver() as saver: yield saver elif name == "pool": async with _pool_saver() as saver: yield saver elif name == "pipe": async with _pipe_saver() as saver: yield saver @pytest.fixture def test_data(): """Fixture providing test data for checkpoint tests.""" config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } chkpnt_1: Checkpoint = empty_checkpoint() chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1) chkpnt_3: Checkpoint = empty_checkpoint() metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } metadata_3: CheckpointMetadata = {} return { "configs": [config_1, config_2, config_3], "checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3], "metadata": [metadata_1, metadata_2, metadata_3], } @pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) async def test_asearch(request, saver_name: str, test_data) -> None: async with _saver(saver_name) as 
saver: configs = test_data["configs"] checkpoints = test_data["checkpoints"] metadata = test_data["metadata"] await saver.aput(configs[0], checkpoints[0], metadata[0], {}) await saver.aput(configs[1], checkpoints[1], metadata[1], {}) await saver.aput(configs[2], checkpoints[2], metadata[2], {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = [c async for c in saver.alist(None, filter=query_1)] assert len(search_results_1) == 1 assert search_results_1[0].metadata == metadata[0] search_results_2 = [c async for c in saver.alist(None, filter=query_2)] assert len(search_results_2) == 1 assert search_results_2[0].metadata == metadata[1] search_results_3 = [c async for c in saver.alist(None, filter=query_3)] assert len(search_results_3) == 3 search_results_4 = [c async for c in saver.alist(None, filter=query_4)] assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = [ c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}}) ] assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} @pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) async def test_null_chars(request, saver_name: str, test_data) -> None: async with _saver(saver_name) as saver: config = await saver.aput( test_data["configs"][0], test_data["checkpoints"][0], {"my_key": "\x00abc"}, {}, ) assert (await saver.aget_tuple(config)).metadata["my_key"] == "abc" # type: ignore assert [c async for c in saver.alist(None, filter={"my_key": "abc"})][ 0 ].metadata["my_key"] == "abc" ``` 
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_store.py`: ```py # type: ignore from contextlib import contextmanager from typing import Any, Optional from uuid import uuid4 import pytest from langchain_core.embeddings import Embeddings from psycopg import Connection from langgraph.store.base import ( GetOp, Item, ListNamespacesOp, MatchCondition, PutOp, SearchOp, ) from langgraph.store.postgres import PostgresStore from tests.conftest import ( DEFAULT_URI, VECTOR_TYPES, CharacterEmbeddings, ) @pytest.fixture(scope="function", params=["default", "pipe", "pool"]) def store(request) -> PostgresStore: database = f"test_{uuid4().hex[:16]}" uri_parts = DEFAULT_URI.split("/") uri_base = "/".join(uri_parts[:-1]) query_params = "" if "?" in uri_parts[-1]: db_name, query_params = uri_parts[-1].split("?", 1) query_params = "?" + query_params conn_string = f"{uri_base}/{database}{query_params}" admin_conn_string = DEFAULT_URI with Connection.connect(admin_conn_string, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: with PostgresStore.from_conn_string(conn_string) as store: store.setup() if request.param == "pipe": with PostgresStore.from_conn_string(conn_string, pipeline=True) as store: yield store elif request.param == "pool": with PostgresStore.from_conn_string( conn_string, pool_config={"min_size": 1, "max_size": 10} ) as store: yield store else: # default with PostgresStore.from_conn_string(conn_string) as store: yield store finally: with Connection.connect(admin_conn_string, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") def test_batch_order(store: PostgresStore) -> None: # Setup test data store.put(("test", "foo"), "key1", {"data": "value1"}) store.put(("test", "bar"), "key2", {"data": "value2"}) ops = [ GetOp(namespace=("test", "foo"), key="key1"), PutOp(namespace=("test", "bar"), key="key2", value={"data": "value2"}), SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, 
limit=10, offset=0 ), ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), GetOp(namespace=("test",), key="key3"), ] results = store.batch(ops) assert len(results) == 5 assert isinstance(results[0], Item) assert isinstance(results[0].value, dict) assert results[0].value == {"data": "value1"} assert results[0].key == "key1" assert results[1] is None # Put operation returns None assert isinstance(results[2], list) assert len(results[2]) == 1 assert isinstance(results[3], list) assert len(results[3]) > 0 # Should contain at least our test namespaces assert results[4] is None # Non-existent key returns None # Test reordered operations ops_reordered = [ SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), GetOp(namespace=("test", "bar"), key="key2"), ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), PutOp(namespace=("test",), key="key3", value={"data": "value3"}), GetOp(namespace=("test", "foo"), key="key1"), ] results_reordered = store.batch(ops_reordered) assert len(results_reordered) == 5 assert isinstance(results_reordered[0], list) assert len(results_reordered[0]) >= 2 # Should find at least our two test items assert isinstance(results_reordered[1], Item) assert results_reordered[1].value == {"data": "value2"} assert results_reordered[1].key == "key2" assert isinstance(results_reordered[2], list) assert len(results_reordered[2]) > 0 assert results_reordered[3] is None # Put operation returns None assert isinstance(results_reordered[4], Item) assert results_reordered[4].value == {"data": "value1"} assert results_reordered[4].key == "key1" def test_batch_get_ops(store: PostgresStore) -> None: # Setup test data store.put(("test",), "key1", {"data": "value1"}) store.put(("test",), "key2", {"data": "value2"}) ops = [ GetOp(namespace=("test",), key="key1"), GetOp(namespace=("test",), key="key2"), GetOp(namespace=("test",), key="key3"), # Non-existent key ] results = store.batch(ops) assert len(results) == 
3 assert results[0] is not None assert results[1] is not None assert results[2] is None assert results[0].key == "key1" assert results[1].key == "key2" def test_batch_put_ops(store: PostgresStore) -> None: ops = [ PutOp(namespace=("test",), key="key1", value={"data": "value1"}), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), PutOp(namespace=("test",), key="key3", value=None), # Delete operation ] results = store.batch(ops) assert len(results) == 3 assert all(result is None for result in results) # Verify the puts worked item1 = store.get(("test",), "key1") item2 = store.get(("test",), "key2") item3 = store.get(("test",), "key3") assert item1 and item1.value == {"data": "value1"} assert item2 and item2.value == {"data": "value2"} assert item3 is None def test_batch_search_ops(store: PostgresStore) -> None: # Setup test data test_data = [ (("test", "foo"), "key1", {"data": "value1", "tag": "a"}), (("test", "bar"), "key2", {"data": "value2", "tag": "a"}), (("test", "baz"), "key3", {"data": "value3", "tag": "b"}), ] for namespace, key, value in test_data: store.put(namespace, key, value) ops = [ SearchOp(namespace_prefix=("test",), filter={"tag": "a"}, limit=10, offset=0), SearchOp(namespace_prefix=("test",), filter=None, limit=2, offset=0), SearchOp(namespace_prefix=("test", "foo"), filter=None, limit=10, offset=0), ] results = store.batch(ops) assert len(results) == 3 # First search should find items with tag "a" assert len(results[0]) == 2 assert all(item.value["tag"] == "a" for item in results[0]) # Second search should return first 2 items assert len(results[1]) == 2 # Third search should only find items in test/foo namespace assert len(results[2]) == 1 assert results[2][0].namespace == ("test", "foo") def test_batch_list_namespaces_ops(store: PostgresStore) -> None: # Setup test data with various namespaces test_data = [ (("test", "documents", "public"), "doc1", {"content": "public doc"}), (("test", "documents", "private"), "doc2", {"content": 
"private doc"}), (("test", "images", "public"), "img1", {"content": "public image"}), (("prod", "documents", "public"), "doc3", {"content": "prod doc"}), ] for namespace, key, value in test_data: store.put(namespace, key, value) ops = [ ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), ListNamespacesOp(match_conditions=None, max_depth=2, limit=10, offset=0), ListNamespacesOp( match_conditions=[MatchCondition("suffix", "public")], max_depth=None, limit=10, offset=0, ), ] results = store.batch(ops) assert len(results) == 3 # First operation should list all namespaces assert len(results[0]) == len(test_data) # Second operation should only return namespaces up to depth 2 assert all(len(ns) <= 2 for ns in results[1]) # Third operation should only return namespaces ending with "public" assert all(ns[-1] == "public" for ns in results[2]) class TestPostgresStore: @pytest.fixture(autouse=True) def setup(self) -> None: with PostgresStore.from_conn_string(DEFAULT_URI) as store: store.setup() def test_basic_store_ops(self) -> None: with PostgresStore.from_conn_string(DEFAULT_URI) as store: namespace = ("test", "documents") item_id = "doc1" item_value = {"title": "Test Document", "content": "Hello, World!"} store.put(namespace, item_id, item_value) item = store.get(namespace, item_id) assert item assert item.namespace == namespace assert item.key == item_id assert item.value == item_value # Test update updated_value = {"title": "Updated Document", "content": "Hello, Updated!"} store.put(namespace, item_id, updated_value) updated_item = store.get(namespace, item_id) assert updated_item.value == updated_value assert updated_item.updated_at > item.updated_at # Test get from non-existent namespace different_namespace = ("test", "other_documents") item_in_different_namespace = store.get(different_namespace, item_id) assert item_in_different_namespace is None # Test delete store.delete(namespace, item_id) deleted_item = store.get(namespace, item_id) assert 
deleted_item is None def test_list_namespaces(self) -> None: with PostgresStore.from_conn_string(DEFAULT_URI) as store: # Create test data with various namespaces test_namespaces = [ ("test", "documents", "public"), ("test", "documents", "private"), ("test", "images", "public"), ("test", "images", "private"), ("prod", "documents", "public"), ("prod", "documents", "private"), ] # Insert test data for namespace in test_namespaces: store.put(namespace, "dummy", {"content": "dummy"}) # Test listing with various filters all_namespaces = store.list_namespaces() assert len(all_namespaces) == len(test_namespaces) # Test prefix filtering test_prefix_namespaces = store.list_namespaces(prefix=["test"]) assert len(test_prefix_namespaces) == 4 assert all(ns[0] == "test" for ns in test_prefix_namespaces) # Test suffix filtering public_namespaces = store.list_namespaces(suffix=["public"]) assert len(public_namespaces) == 3 assert all(ns[-1] == "public" for ns in public_namespaces) # Test max depth depth_2_namespaces = store.list_namespaces(max_depth=2) assert all(len(ns) <= 2 for ns in depth_2_namespaces) # Test pagination paginated_namespaces = store.list_namespaces(limit=3) assert len(paginated_namespaces) == 3 # Cleanup for namespace in test_namespaces: store.delete(namespace, "dummy") def test_search(self) -> None: with PostgresStore.from_conn_string(DEFAULT_URI) as store: # Create test data test_data = [ ( ("test", "docs"), "doc1", {"title": "First Doc", "author": "Alice", "tags": ["important"]}, ), ( ("test", "docs"), "doc2", {"title": "Second Doc", "author": "Bob", "tags": ["draft"]}, ), ( ("test", "images"), "img1", {"title": "Image 1", "author": "Alice", "tags": ["final"]}, ), ] for namespace, key, value in test_data: store.put(namespace, key, value) # Test basic search all_items = store.search(["test"]) assert len(all_items) == 3 # Test namespace filtering docs_items = store.search(["test", "docs"]) assert len(docs_items) == 2 assert all(item.namespace == ("test", 
"docs") for item in docs_items) # Test value filtering alice_items = store.search(["test"], filter={"author": "Alice"}) assert len(alice_items) == 2 assert all(item.value["author"] == "Alice" for item in alice_items) # Test pagination paginated_items = store.search(["test"], limit=2) assert len(paginated_items) == 2 offset_items = store.search(["test"], offset=2) assert len(offset_items) == 1 # Cleanup for namespace, key, _ in test_data: store.delete(namespace, key) @contextmanager def _create_vector_store( vector_type: str, distance_type: str, fake_embeddings: Embeddings, text_fields: Optional[list[str]] = None, ) -> PostgresStore: """Create a store with vector search enabled.""" database = f"test_{uuid4().hex[:16]}" uri_parts = DEFAULT_URI.split("/") uri_base = "/".join(uri_parts[:-1]) query_params = "" if "?" in uri_parts[-1]: db_name, query_params = uri_parts[-1].split("?", 1) query_params = "?" + query_params conn_string = f"{uri_base}/{database}{query_params}" admin_conn_string = DEFAULT_URI index_config = { "dims": fake_embeddings.dims, "embed": fake_embeddings, "ann_index_config": { "vector_type": vector_type, }, "distance_type": distance_type, "text_fields": text_fields, } with Connection.connect(admin_conn_string, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: with PostgresStore.from_conn_string( conn_string, index=index_config, ) as store: store.setup() yield store finally: with Connection.connect(admin_conn_string, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture( scope="function", params=[ (vector_type, distance_type) for vector_type in VECTOR_TYPES for distance_type in ( ["hamming"] if vector_type == "bit" else ["l2", "inner_product", "cosine"] ) ], ids=lambda p: f"{p[0]}_{p[1]}", ) def vector_store( request, fake_embeddings: Embeddings, ) -> PostgresStore: """Create a store with vector search enabled.""" vector_type, distance_type = request.param with _create_vector_store(vector_type, 
distance_type, fake_embeddings) as store: yield store def test_vector_store_initialization( vector_store: PostgresStore, fake_embeddings: CharacterEmbeddings ) -> None: """Test store initialization with embedding config.""" # Store should be initialized with embedding config assert vector_store.index_config is not None assert vector_store.index_config["dims"] == fake_embeddings.dims assert vector_store.index_config["embed"] == fake_embeddings def test_vector_insert_with_auto_embedding(vector_store: PostgresStore) -> None: """Test inserting items that get auto-embedded.""" docs = [ ("doc1", {"text": "short text"}), ("doc2", {"text": "longer text document"}), ("doc3", {"text": "longest text document here"}), ("doc4", {"description": "text in description field"}), ("doc5", {"content": "text in content field"}), ("doc6", {"body": "text in body field"}), ] for key, value in docs: vector_store.put(("test",), key, value) results = vector_store.search(("test",), query="long text") assert len(results) > 0 doc_order = [r.key for r in results] assert "doc2" in doc_order assert "doc3" in doc_order def test_vector_update_with_embedding(vector_store: PostgresStore) -> None: """Test that updating items properly updates their embeddings.""" vector_store.put(("test",), "doc1", {"text": "zany zebra Xerxes"}) vector_store.put(("test",), "doc2", {"text": "something about dogs"}) vector_store.put(("test",), "doc3", {"text": "text about birds"}) results_initial = vector_store.search(("test",), query="Zany Xerxes") assert len(results_initial) > 0 assert results_initial[0].key == "doc1" initial_score = results_initial[0].score vector_store.put(("test",), "doc1", {"text": "new text about dogs"}) results_after = vector_store.search(("test",), query="Zany Xerxes") after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0) assert after_score < initial_score results_new = vector_store.search(("test",), query="new text about dogs") for r in results_new: if r.key == "doc1": 
assert r.score > after_score # Don't index this one vector_store.put(("test",), "doc4", {"text": "new text about dogs"}, index=False) results_new = vector_store.search(("test",), query="new text about dogs", limit=3) assert not any(r.key == "doc4" for r in results_new) def test_vector_search_with_filters(vector_store: PostgresStore) -> None: """Test combining vector search with filters.""" # Insert test documents docs = [ ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), ("doc2", {"text": "red car", "color": "red", "score": 3.0}), ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), ("doc4", {"text": "blue car", "color": "blue", "score": 3.5}), ] for key, value in docs: vector_store.put(("test",), key, value) results = vector_store.search(("test",), query="apple", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc1" results = vector_store.search(("test",), query="car", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc2" results = vector_store.search( ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} ) assert len(results) == 3 assert results[0].key == "doc4" # Multiple filters results = vector_store.search( ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} ) assert len(results) == 1 assert results[0].key == "doc3" def test_vector_search_pagination(vector_store: PostgresStore) -> None: """Test pagination with vector search.""" # Insert multiple similar documents for i in range(5): vector_store.put(("test",), f"doc{i}", {"text": f"test document number {i}"}) # Test with different page sizes results_page1 = vector_store.search(("test",), query="test", limit=2) results_page2 = vector_store.search(("test",), query="test", limit=2, offset=2) assert len(results_page1) == 2 assert len(results_page2) == 2 assert results_page1[0].key != results_page2[0].key # Get all results all_results = vector_store.search(("test",), query="test", limit=10) assert 
len(all_results) == 5 def test_vector_search_edge_cases(vector_store: PostgresStore) -> None: """Test edge cases in vector search.""" vector_store.put(("test",), "doc1", {"text": "test document"}) results = vector_store.search(("test",), query="") assert len(results) == 1 results = vector_store.search(("test",), query=None) assert len(results) == 1 long_query = "test " * 100 results = vector_store.search(("test",), query=long_query) assert len(results) == 1 special_query = "test!@#$%^&*()" results = vector_store.search(("test",), query=special_query) assert len(results) == 1 @pytest.mark.parametrize( "vector_type,distance_type", [ ("vector", "cosine"), ("vector", "inner_product"), ("halfvec", "cosine"), ("halfvec", "inner_product"), ], ) def test_embed_with_path_sync( request: Any, fake_embeddings: CharacterEmbeddings, vector_type: str, distance_type: str, ) -> None: """Test vector search with specific text fields in Postgres store.""" with _create_vector_store( vector_type, distance_type, fake_embeddings, text_fields=["key0", "key1", "key3"], ) as store: # This will have 2 vectors representing it doc1 = { # Omit key0 - check it doesn't raise an error "key1": "xxx", "key2": "yyy", "key3": "zzz", } # This will have 3 vectors representing it doc2 = { "key0": "uuu", "key1": "vvv", "key2": "www", "key3": "xxx", } store.put(("test",), "doc1", doc1) store.put(("test",), "doc2", doc2) # doc2.key3 and doc1.key1 both would have the highest score results = store.search(("test",), query="xxx") assert len(results) == 2 assert results[0].key != results[1].key ascore = results[0].score bscore = results[1].score assert ascore == pytest.approx(bscore, abs=1e-3) # ~Only match doc2 results = store.search(("test",), query="uuu") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].key == "doc2" assert results[0].score > results[1].score assert ascore == pytest.approx(results[0].score, abs=1e-3) # ~Only match doc1 results = store.search(("test",), 
query="zzz") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].key == "doc1" assert results[0].score > results[1].score assert ascore == pytest.approx(results[0].score, abs=1e-3) # Un-indexed - will have low results for both. Not zero (because we're projecting) # but less than the above. results = store.search(("test",), query="www") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].score < ascore assert results[1].score < ascore @pytest.mark.parametrize( "vector_type,distance_type", [ ("vector", "cosine"), ("vector", "inner_product"), ("halfvec", "cosine"), ("halfvec", "inner_product"), ], ) def test_embed_with_path_operation_config( request: Any, fake_embeddings: CharacterEmbeddings, vector_type: str, distance_type: str, ) -> None: """Test operation-level field configuration for vector search.""" with _create_vector_store( vector_type, distance_type, fake_embeddings, text_fields=["key17"], # Default fields that won't match our test data ) as store: doc3 = { "key0": "aaa", "key1": "bbb", "key2": "ccc", "key3": "ddd", } doc4 = { "key0": "eee", "key1": "bbb", # Same as doc3.key1 "key2": "fff", "key3": "ggg", } store.put(("test",), "doc3", doc3, index=["key0", "key1"]) store.put(("test",), "doc4", doc4, index=["key1", "key3"]) results = store.search(("test",), query="aaa") assert len(results) == 2 assert results[0].key == "doc3" assert len(set(r.key for r in results)) == 2 assert results[0].score > results[1].score results = store.search(("test",), query="ggg") assert len(results) == 2 assert results[0].key == "doc4" assert results[0].score > results[1].score results = store.search(("test",), query="bbb") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].score == pytest.approx(results[1].score, abs=1e-3) results = store.search(("test",), query="ccc") assert len(results) == 2 assert all( r.score < 0.9 for r in results ) # Unindexed field should have low scores # Test 
index=False behavior doc5 = { "key0": "hhh", "key1": "iii", } store.put(("test",), "doc5", doc5, index=False) results = store.search(("test",)) assert len(results) == 3 assert all(r.score is None for r in results) assert any(r.key == "doc5" for r in results) results = store.search(("test",), query="hhh") # TODO: We don't currently fill in additional results if there are not enough # returned during vector search. # assert len(results) == 3 # doc5_result = next(r for r in results if r.key == "doc5") # assert doc5_result.score is None def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]: """ Compute cosine similarity between a vector X and a matrix Y. Lazy import numpy for efficiency. """ similarities = [] for y in Y: dot_product = sum(a * b for a, b in zip(X, y)) norm1 = sum(a * a for a in X) ** 0.5 norm2 = sum(a * a for a in y) ** 0.5 similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0 similarities.append(similarity) return similarities def _inner_product(X: list[float], Y: list[list[float]]) -> list[float]: """ Compute inner product between a vector X and a matrix Y. Lazy import numpy for efficiency. """ similarities = [] for y in Y: similarity = sum(a * b for a, b in zip(X, y)) similarities.append(similarity) return similarities def _neg_l2_distance(X: list[float], Y: list[list[float]]) -> list[float]: """ Compute l2 distance between a vector X and a matrix Y. Lazy import numpy for efficiency. 
""" similarities = [] for y in Y: similarity = sum((a - b) ** 2 for a, b in zip(X, y)) ** 0.5 similarities.append(-similarity) return similarities @pytest.mark.parametrize( "vector_type,distance_type", [ ("vector", "cosine"), ("vector", "inner_product"), ("halfvec", "l2"), ], ) @pytest.mark.parametrize("query", ["aaa", "bbb", "ccc", "abcd", "poisson"]) def test_scores( fake_embeddings: CharacterEmbeddings, vector_type: str, distance_type: str, query: str, ) -> None: """Test operation-level field configuration for vector search.""" with _create_vector_store( vector_type, distance_type, fake_embeddings, text_fields=["key0"], ) as store: doc = { "key0": "aaa", } store.put(("test",), "doc", doc, index=["key0", "key1"]) results = store.search((), query=query) vec0 = fake_embeddings.embed_query(doc["key0"]) vec1 = fake_embeddings.embed_query(query) if distance_type == "cosine": similarities = _cosine_similarity(vec1, [vec0]) elif distance_type == "inner_product": similarities = _inner_product(vec1, [vec0]) elif distance_type == "l2": similarities = _neg_l2_distance(vec1, [vec0]) assert len(results) == 1 assert results[0].score == pytest.approx(similarities[0], abs=1e-3) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/embed_test_utils.py`: ```py """Embedding utilities for testing.""" import math import random from collections import Counter, defaultdict from typing import Any from langchain_core.embeddings import Embeddings class CharacterEmbeddings(Embeddings): """Simple character-frequency based embeddings using random projections.""" def __init__(self, dims: int = 50, seed: int = 42): """Initialize with embedding dimensions and random seed.""" self._rng = random.Random(seed) self.dims = dims # Create projection vector for each character lazily self._char_projections: defaultdict[str, list[float]] = defaultdict( lambda: [ self._rng.gauss(0, 1 / math.sqrt(self.dims)) for _ in range(self.dims) ] ) def _embed_one(self, text: str) -> 
list[float]: """Embed a single text.""" counts = Counter(text) total = sum(counts.values()) if total == 0: return [0.0] * self.dims embedding = [0.0] * self.dims for char, count in counts.items(): weight = count / total char_proj = self._char_projections[char] for i, proj in enumerate(char_proj): embedding[i] += weight * proj norm = math.sqrt(sum(x * x for x in embedding)) if norm > 0: embedding = [x / norm for x in embedding] return embedding def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed a list of documents.""" return [self._embed_one(text) for text in texts] def embed_query(self, text: str) -> list[float]: """Embed a query string.""" return self._embed_one(text) def __eq__(self, other: Any) -> bool: return isinstance(other, CharacterEmbeddings) and self.dims == other.dims ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_async_store.py`: ```py # type: ignore import itertools import sys import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Optional import pytest from langchain_core.embeddings import Embeddings from psycopg import AsyncConnection from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp from langgraph.store.postgres import AsyncPostgresStore from tests.conftest import ( DEFAULT_URI, VECTOR_TYPES, CharacterEmbeddings, ) @pytest.fixture(scope="function", params=["default", "pipe", "pool"]) async def store(request) -> AsyncIterator[AsyncPostgresStore]: if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid.uuid4().hex[:16]}" uri_parts = DEFAULT_URI.split("/") uri_base = "/".join(uri_parts[:-1]) query_params = "" if "?" in uri_parts[-1]: db_name, query_params = uri_parts[-1].split("?", 1) query_params = "?" 
+ query_params conn_string = f"{uri_base}/{database}{query_params}" admin_conn_string = DEFAULT_URI async with await AsyncConnection.connect( admin_conn_string, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with AsyncPostgresStore.from_conn_string(conn_string) as store: await store.setup() if request.param == "pipe": async with AsyncPostgresStore.from_conn_string( conn_string, pipeline=True ) as store: yield store elif request.param == "pool": async with AsyncPostgresStore.from_conn_string( conn_string, pool_config={"min_size": 1, "max_size": 10} ) as store: yield store else: # default async with AsyncPostgresStore.from_conn_string(conn_string) as store: yield store finally: async with await AsyncConnection.connect( admin_conn_string, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") async def test_abatch_order(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) await store.aput(("test", "bar"), "key2", {"data": "value2"}) ops = [ GetOp(namespace=("test", "foo"), key="key1"), PutOp(namespace=("test", "bar"), key="key2", value={"data": "value2"}), SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), GetOp(namespace=("test",), key="key3"), ] results = await store.abatch(ops) assert len(results) == 5 assert isinstance(results[0], Item) assert isinstance(results[0].value, dict) assert results[0].value == {"data": "value1"} assert results[0].key == "key1" assert results[1] is None assert isinstance(results[2], list) assert len(results[2]) == 1 assert isinstance(results[3], list) assert ("test", "foo") in results[3] and ("test", "bar") in results[3] assert results[4] is None ops_reordered = [ SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), GetOp(namespace=("test", "bar"), key="key2"), 
ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), PutOp(namespace=("test",), key="key3", value={"data": "value3"}), GetOp(namespace=("test", "foo"), key="key1"), ] results_reordered = await store.abatch(ops_reordered) assert len(results_reordered) == 5 assert isinstance(results_reordered[0], list) assert len(results_reordered[0]) == 2 assert isinstance(results_reordered[1], Item) assert results_reordered[1].value == {"data": "value2"} assert results_reordered[1].key == "key2" assert isinstance(results_reordered[2], list) assert ("test", "foo") in results_reordered[2] and ( "test", "bar", ) in results_reordered[2] assert results_reordered[3] is None assert isinstance(results_reordered[4], Item) assert results_reordered[4].value == {"data": "value1"} assert results_reordered[4].key == "key1" async def test_batch_get_ops(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test",), "key1", {"data": "value1"}) await store.aput(("test",), "key2", {"data": "value2"}) ops = [ GetOp(namespace=("test",), key="key1"), GetOp(namespace=("test",), key="key2"), GetOp(namespace=("test",), key="key3"), ] results = await store.abatch(ops) assert len(results) == 3 assert results[0] is not None assert results[1] is not None assert results[2] is None assert results[0].key == "key1" assert results[1].key == "key2" async def test_batch_put_ops(store: AsyncPostgresStore) -> None: ops = [ PutOp(namespace=("test",), key="key1", value={"data": "value1"}), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), PutOp(namespace=("test",), key="key3", value=None), ] results = await store.abatch(ops) assert len(results) == 3 assert all(result is None for result in results) # Verify the puts worked items = await store.asearch(["test"], limit=10) assert len(items) == 2 # key3 had None value so wasn't stored async def test_batch_search_ops(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": 
"value1"}) await store.aput(("test", "bar"), "key2", {"data": "value2"}) ops = [ SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), ] results = await store.abatch(ops) assert len(results) == 2 assert len(results[0]) == 1 # Filtered results assert len(results[1]) == 2 # All results async def test_batch_list_namespaces_ops(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "namespace1"), "key1", {"data": "value1"}) await store.aput(("test", "namespace2"), "key2", {"data": "value2"}) ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)] results = await store.abatch(ops) assert len(results) == 1 assert len(results[0]) == 2 assert ("test", "namespace1") in results[0] assert ("test", "namespace2") in results[0] @asynccontextmanager async def _create_vector_store( vector_type: str, distance_type: str, fake_embeddings: CharacterEmbeddings, text_fields: Optional[list[str]] = None, ) -> AsyncIterator[AsyncPostgresStore]: """Create a store with vector search enabled.""" if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid.uuid4().hex[:16]}" uri_parts = DEFAULT_URI.split("/") uri_base = "/".join(uri_parts[:-1]) query_params = "" if "?" in uri_parts[-1]: db_name, query_params = uri_parts[-1].split("?", 1) query_params = "?" 
+ query_params conn_string = f"{uri_base}/{database}{query_params}" admin_conn_string = DEFAULT_URI index_config = { "dims": fake_embeddings.dims, "embed": fake_embeddings, "ann_index_config": { "vector_type": vector_type, }, "distance_type": distance_type, "text_fields": text_fields, } async with await AsyncConnection.connect( admin_conn_string, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with AsyncPostgresStore.from_conn_string( conn_string, index=index_config, ) as store: await store.setup() yield store finally: async with await AsyncConnection.connect( admin_conn_string, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @pytest.fixture( scope="function", params=[ (vector_type, distance_type) for vector_type in VECTOR_TYPES for distance_type in ( ["hamming"] if vector_type == "bit" else ["l2", "inner_product", "cosine"] ) ], ids=lambda p: f"{p[0]}_{p[1]}", ) async def vector_store( request, fake_embeddings: CharacterEmbeddings, ) -> AsyncIterator[AsyncPostgresStore]: """Create a store with vector search enabled.""" vector_type, distance_type = request.param async with _create_vector_store( vector_type, distance_type, fake_embeddings ) as store: yield store async def test_vector_store_initialization( vector_store: AsyncPostgresStore, fake_embeddings: CharacterEmbeddings ) -> None: """Test store initialization with embedding config.""" assert vector_store.index_config is not None assert vector_store.index_config["dims"] == fake_embeddings.dims if isinstance(vector_store.index_config["embed"], Embeddings): assert vector_store.index_config["embed"] == fake_embeddings async def test_vector_insert_with_auto_embedding( vector_store: AsyncPostgresStore, ) -> None: """Test inserting items that get auto-embedded.""" docs = [ ("doc1", {"text": "short text"}), ("doc2", {"text": "longer text document"}), ("doc3", {"text": "longest text document here"}), ("doc4", {"description": "text in description 
field"}), ("doc5", {"content": "text in content field"}), ("doc6", {"body": "text in body field"}), ] for key, value in docs: await vector_store.aput(("test",), key, value) results = await vector_store.asearch(("test",), query="long text") assert len(results) > 0 doc_order = [r.key for r in results] assert "doc2" in doc_order assert "doc3" in doc_order async def test_vector_update_with_embedding(vector_store: AsyncPostgresStore) -> None: """Test that updating items properly updates their embeddings.""" await vector_store.aput(("test",), "doc1", {"text": "zany zebra Xerxes"}) await vector_store.aput(("test",), "doc2", {"text": "something about dogs"}) await vector_store.aput(("test",), "doc3", {"text": "text about birds"}) results_initial = await vector_store.asearch(("test",), query="Zany Xerxes") assert len(results_initial) > 0 assert results_initial[0].key == "doc1" initial_score = results_initial[0].score await vector_store.aput(("test",), "doc1", {"text": "new text about dogs"}) results_after = await vector_store.asearch(("test",), query="Zany Xerxes") after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0) assert after_score < initial_score results_new = await vector_store.asearch(("test",), query="new text about dogs") for r in results_new: if r.key == "doc1": assert r.score > after_score # Don't index this one await vector_store.aput( ("test",), "doc4", {"text": "new text about dogs"}, index=False ) results_new = await vector_store.asearch( ("test",), query="new text about dogs", limit=3 ) assert not any(r.key == "doc4" for r in results_new) async def test_vector_search_with_filters(vector_store: AsyncPostgresStore) -> None: """Test combining vector search with filters.""" docs = [ ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), ("doc2", {"text": "red car", "color": "red", "score": 3.0}), ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), ("doc4", {"text": "blue car", "color": "blue", "score": 3.5}), ] for 
key, value in docs: await vector_store.aput(("test",), key, value) results = await vector_store.asearch( ("test",), query="apple", filter={"color": "red"} ) assert len(results) == 2 assert results[0].key == "doc1" results = await vector_store.asearch( ("test",), query="car", filter={"color": "red"} ) assert len(results) == 2 assert results[0].key == "doc2" results = await vector_store.asearch( ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} ) assert len(results) == 3 assert results[0].key == "doc4" results = await vector_store.asearch( ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} ) assert len(results) == 1 assert results[0].key == "doc3" async def test_vector_search_pagination(vector_store: AsyncPostgresStore) -> None: """Test pagination with vector search.""" for i in range(5): await vector_store.aput( ("test",), f"doc{i}", {"text": f"test document number {i}"} ) results_page1 = await vector_store.asearch(("test",), query="test", limit=2) results_page2 = await vector_store.asearch( ("test",), query="test", limit=2, offset=2 ) assert len(results_page1) == 2 assert len(results_page2) == 2 assert results_page1[0].key != results_page2[0].key all_results = await vector_store.asearch(("test",), query="test", limit=10) assert len(all_results) == 5 async def test_vector_search_edge_cases(vector_store: AsyncPostgresStore) -> None: """Test edge cases in vector search.""" await vector_store.aput(("test",), "doc1", {"text": "test document"}) perfect_match = await vector_store.asearch(("test",), query="text test document") perfect_score = perfect_match[0].score results = await vector_store.asearch(("test",), query="") assert len(results) == 1 assert results[0].score is None results = await vector_store.asearch(("test",), query=None) assert len(results) == 1 assert results[0].score is None long_query = "foo " * 100 results = await vector_store.asearch(("test",), query=long_query) assert len(results) == 1 assert results[0].score < 
perfect_score special_query = "test!@#$%^&*()" results = await vector_store.asearch(("test",), query=special_query) assert len(results) == 1 assert results[0].score < perfect_score @pytest.mark.parametrize( "vector_type,distance_type", [ *itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]), ], ) async def test_embed_with_path( request: Any, fake_embeddings: CharacterEmbeddings, vector_type: str, distance_type: str, ) -> None: """Test vector search with specific text fields in Postgres store.""" async with _create_vector_store( vector_type, distance_type, fake_embeddings, text_fields=["key0", "key1", "key3"], ) as store: # This will have 2 vectors representing it doc1 = { # Omit key0 - check it doesn't raise an error "key1": "xxx", "key2": "yyy", "key3": "zzz", } # This will have 3 vectors representing it doc2 = { "key0": "uuu", "key1": "vvv", "key2": "www", "key3": "xxx", } await store.aput(("test",), "doc1", doc1) await store.aput(("test",), "doc2", doc2) # doc2.key3 and doc1.key1 both would have the highest score results = await store.asearch(("test",), query="xxx") assert len(results) == 2 assert results[0].key != results[1].key ascore = results[0].score bscore = results[1].score assert ascore == pytest.approx(bscore, abs=1e-3) results = await store.asearch(("test",), query="uuu") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].key == "doc2" assert results[0].score > results[1].score assert ascore == pytest.approx(results[0].score, abs=1e-3) # Un-indexed - will have low results for both. Not zero (because we're projecting) # but less than the above. 
results = await store.asearch(("test",), query="www") assert len(results) == 2 assert results[0].score < ascore assert results[1].score < ascore @pytest.mark.parametrize( "vector_type,distance_type", [ *itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]), ], ) async def test_search_sorting( request: Any, fake_embeddings: CharacterEmbeddings, vector_type: str, distance_type: str, ) -> None: """Test operation-level field configuration for vector search.""" async with _create_vector_store( vector_type, distance_type, fake_embeddings, text_fields=["key1"], # Default fields that won't match our test data ) as store: amatch = { "key1": "mmm", } await store.aput(("test", "M"), "M", amatch) N = 100 for i in range(N): await store.aput(("test", "A"), f"A{i}", {"key1": "no"}) for i in range(N): await store.aput(("test", "Z"), f"Z{i}", {"key1": "no"}) results = await store.asearch(("test",), query="mmm", limit=10) assert len(results) == 10 assert len(set(r.key for r in results)) == 10 assert results[0].key == "M" assert results[0].score > results[1].score ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/runner.py`: ```py import asyncio import concurrent.futures import time from typing import ( Any, AsyncIterator, Callable, Iterable, Iterator, Optional, Sequence, Type, Union, cast, ) from langgraph.constants import ( CONF, CONFIG_KEY_SEND, ERROR, INTERRUPT, NO_WRITES, PUSH, RESUME, TAG_HIDDEN, ) from langgraph.errors import GraphBubbleUp, GraphInterrupt from langgraph.pregel.executor import Submit from langgraph.pregel.retry import arun_with_retry, run_with_retry from langgraph.types import PregelExecutableTask, RetryPolicy class PregelRunner: """Responsible for executing a set of Pregel tasks concurrently, committing their writes, yielding control to caller when there is output to emit, and interrupting other tasks if appropriate.""" def __init__( self, *, submit: Submit, put_writes: Callable[[str, Sequence[tuple[str, 
Any]]], None], schedule_task: Callable[ [PregelExecutableTask, int], Optional[PregelExecutableTask] ], use_astream: bool = False, node_finished: Optional[Callable[[str], None]] = None, ) -> None: self.submit = submit self.put_writes = put_writes self.use_astream = use_astream self.node_finished = node_finished self.schedule_task = schedule_task def tick( self, tasks: Iterable[PregelExecutableTask], *, reraise: bool = True, timeout: Optional[float] = None, retry_policy: Optional[RetryPolicy] = None, get_waiter: Optional[Callable[[], concurrent.futures.Future[None]]] = None, ) -> Iterator[None]: def writer( task: PregelExecutableTask, writes: Sequence[tuple[str, Any]] ) -> None: prev_length = len(task.writes) # delegate to the underlying writer task.config[CONF][CONFIG_KEY_SEND](writes) for idx, w in enumerate(task.writes): # find the index for the newly inserted writes if idx < prev_length: continue assert writes[idx - prev_length] is w # bail if not a PUSH write if w[0] != PUSH: continue # schedule the next task, if the callback returns one if next_task := self.schedule_task(task, idx): # if the parent task was retried, # the next task might already be running if any( t == next_task.id for t in futures.values() if t is not None ): continue # schedule the next task futures[ self.submit( run_with_retry, next_task, retry_policy, writer=writer, __reraise_on_exit__=reraise, ) ] = next_task tasks = tuple(tasks) futures: dict[concurrent.futures.Future, Optional[PregelExecutableTask]] = {} # give control back to the caller yield # fast path if single task with no timeout and no waiter if len(tasks) == 1 and timeout is None and get_waiter is None: t = tasks[0] try: run_with_retry(t, retry_policy, writer=writer) self.commit(t, None) except Exception as exc: self.commit(t, exc) if reraise: raise if not futures: # maybe `t` schuduled another task return # add waiter task if requested if get_waiter is not None: futures[get_waiter()] = None # execute tasks, and wait for one to 
fail or all to finish. # each task is independent from all other concurrent tasks # yield updates/debug output as each task finishes for t in tasks: if not t.writes: futures[ self.submit( run_with_retry, t, retry_policy, writer=writer, __reraise_on_exit__=reraise, ) ] = t done_futures: set[concurrent.futures.Future] = set() end_time = timeout + time.monotonic() if timeout else None while len(futures) > (1 if get_waiter is not None else 0): done, inflight = concurrent.futures.wait( futures, return_when=concurrent.futures.FIRST_COMPLETED, timeout=(max(0, end_time - time.monotonic()) if end_time else None), ) if not done: break # timed out for fut in done: task = futures.pop(fut) if task is None: # waiter task finished, schedule another if inflight and get_waiter is not None: futures[get_waiter()] = None else: # store for panic check done_futures.add(fut) # task finished, commit writes self.commit(task, _exception(fut)) else: # remove references to loop vars del fut, task # maybe stop other tasks if _should_stop_others(done): break # give control back to the caller yield # panic on failure or timeout _panic_or_proceed( done_futures.union(f for f, t in futures.items() if t is not None), panic=reraise, ) async def atick( self, tasks: Iterable[PregelExecutableTask], *, reraise: bool = True, timeout: Optional[float] = None, retry_policy: Optional[RetryPolicy] = None, get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None, ) -> AsyncIterator[None]: def writer( task: PregelExecutableTask, writes: Sequence[tuple[str, Any]] ) -> None: prev_length = len(task.writes) # delegate to the underlying writer task.config[CONF][CONFIG_KEY_SEND](writes) for idx, w in enumerate(task.writes): # find the index for the newly inserted writes if idx < prev_length: continue assert writes[idx - prev_length] is w # bail if not a PUSH write if w[0] != PUSH: continue # schedule the next task, if the callback returns one if next_task := self.schedule_task(task, idx): # if the parent task 
was retried, # the next task might already be running if any( t == next_task.id for t in futures.values() if t is not None ): continue # schedule the next task futures[ cast( asyncio.Future, self.submit( arun_with_retry, next_task, retry_policy, stream=self.use_astream, writer=writer, __name__=t.name, __cancel_on_exit__=True, __reraise_on_exit__=reraise, ), ) ] = next_task loop = asyncio.get_event_loop() tasks = tuple(tasks) futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = {} # give control back to the caller yield # fast path if single task with no waiter and no timeout if len(tasks) == 1 and get_waiter is None and timeout is None: t = tasks[0] try: await arun_with_retry( t, retry_policy, stream=self.use_astream, writer=writer ) self.commit(t, None) except Exception as exc: self.commit(t, exc) if reraise: raise if not futures: # maybe `t` schuduled another task return # add waiter task if requested if get_waiter is not None: futures[get_waiter()] = None # execute tasks, and wait for one to fail or all to finish. 
# each task is independent from all other concurrent tasks # yield updates/debug output as each task finishes for t in tasks: if not t.writes: futures[ cast( asyncio.Future, self.submit( arun_with_retry, t, retry_policy, stream=self.use_astream, writer=writer, __name__=t.name, __cancel_on_exit__=True, __reraise_on_exit__=reraise, ), ) ] = t done_futures: set[asyncio.Future] = set() end_time = timeout + loop.time() if timeout else None while len(futures) > (1 if get_waiter is not None else 0): done, inflight = await asyncio.wait( futures, return_when=asyncio.FIRST_COMPLETED, timeout=(max(0, end_time - loop.time()) if end_time else None), ) if not done: break # timed out for fut in done: task = futures.pop(fut) if task is None: # waiter task finished, schedule another if inflight and get_waiter is not None: futures[get_waiter()] = None else: # store for panic check done_futures.add(fut) # task finished, commit writes self.commit(task, _exception(fut)) else: # remove references to loop vars del fut, task # maybe stop other tasks if _should_stop_others(done): break # give control back to the caller yield # cancel waiter task for fut in futures: fut.cancel() # panic on failure or timeout _panic_or_proceed( done_futures.union(f for f, t in futures.items() if t is not None), timeout_exc_cls=asyncio.TimeoutError, panic=reraise, ) def commit( self, task: PregelExecutableTask, exception: Optional[BaseException] ) -> None: if exception: if isinstance(exception, GraphInterrupt): # save interrupt to checkpointer if interrupts := [(INTERRUPT, i) for i in exception.args[0]]: if resumes := [w for w in task.writes if w[0] == RESUME]: interrupts.extend(resumes) self.put_writes(task.id, interrupts) elif isinstance(exception, GraphBubbleUp): raise exception else: # save error to checkpointer self.put_writes(task.id, [(ERROR, exception)]) else: if self.node_finished and ( task.config is None or TAG_HIDDEN not in task.config.get("tags", []) ): self.node_finished(task.name) if not 
task.writes: # add no writes marker task.writes.append((NO_WRITES, None)) # save task writes to checkpointer self.put_writes(task.id, task.writes) def _should_stop_others( done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Future[Any]]], ) -> bool: """Check if any task failed, if so, cancel all other tasks. GraphInterrupts are not considered failures.""" for fut in done: if fut.cancelled(): return True if exc := fut.exception(): return not isinstance(exc, GraphBubbleUp) else: return False def _exception( fut: Union[concurrent.futures.Future[Any], asyncio.Future[Any]], ) -> Optional[BaseException]: """Return the exception from a future, without raising CancelledError.""" if fut.cancelled(): if isinstance(fut, asyncio.Future): return asyncio.CancelledError() else: return concurrent.futures.CancelledError() else: return fut.exception() def _panic_or_proceed( futs: Union[set[concurrent.futures.Future], set[asyncio.Future]], *, timeout_exc_cls: Type[Exception] = TimeoutError, panic: bool = True, ) -> None: """Cancel remaining tasks if any failed, re-raise exception if panic is True.""" done: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() inflight: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() for fut in futs: if fut.done(): done.add(fut) else: inflight.add(fut) while done: # if any task failed if exc := _exception(done.pop()): # cancel all pending tasks while inflight: inflight.pop().cancel() # raise the exception if panic: raise exc else: return if inflight: # if we got here means we timed out while inflight: # cancel all pending tasks inflight.pop().cancel() # raise timeout error raise timeout_exc_cls("Timed out") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/write.py`: ```py from __future__ import annotations from typing import ( Any, Callable, NamedTuple, Optional, Sequence, TypeVar, Union, cast, ) from langchain_core.runnables import Runnable, RunnableConfig from 
langchain_core.runnables.utils import ConfigurableFieldSpec from langgraph.constants import CONF, CONFIG_KEY_SEND, FF_SEND_V2, PUSH, TASKS, Send from langgraph.errors import InvalidUpdateError from langgraph.utils.runnable import RunnableCallable TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None] R = TypeVar("R", bound=Runnable) SKIP_WRITE = object() PASSTHROUGH = object() class ChannelWriteEntry(NamedTuple): channel: str """Channel name to write to.""" value: Any = PASSTHROUGH """Value to write, or PASSTHROUGH to use the input.""" skip_none: bool = False """Whether to skip writing if the value is None.""" mapper: Optional[Callable] = None """Function to transform the value before writing.""" class ChannelWrite(RunnableCallable): """Implements th logic for sending writes to CONFIG_KEY_SEND. Can be used as a runnable or as a static method to call imperatively.""" writes: list[Union[ChannelWriteEntry, Send]] """Sequence of write entries or Send objects to write.""" require_at_least_one_of: Optional[Sequence[str]] """If defined, at least one of these channels must be written to.""" def __init__( self, writes: Sequence[Union[ChannelWriteEntry, Send]], *, tags: Optional[Sequence[str]] = None, require_at_least_one_of: Optional[Sequence[str]] = None, ): super().__init__(func=self._write, afunc=self._awrite, name=None, tags=tags) self.writes = cast(list[Union[ChannelWriteEntry, Send]], writes) self.require_at_least_one_of = require_at_least_one_of def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: if not name: name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else w.node for w in self.writes)}>" return super().get_name(suffix, name=name) @property def config_specs(self) -> list[ConfigurableFieldSpec]: return [ ConfigurableFieldSpec( id=CONFIG_KEY_SEND, name=CONFIG_KEY_SEND, description=None, default=None, annotation=None, ), ] def _write(self, input: Any, config: RunnableConfig) -> None: writes = [ 
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper) if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH else write for write in self.writes ] self.do_write( config, writes, self.require_at_least_one_of if input is not None else None, ) return input async def _awrite(self, input: Any, config: RunnableConfig) -> None: writes = [ ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper) if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH else write for write in self.writes ] self.do_write( config, writes, self.require_at_least_one_of if input is not None else None, ) return input @staticmethod def do_write( config: RunnableConfig, writes: Sequence[Union[ChannelWriteEntry, Send]], require_at_least_one_of: Optional[Sequence[str]] = None, ) -> None: # validate for w in writes: if isinstance(w, ChannelWriteEntry): if w.channel in (TASKS, PUSH): raise InvalidUpdateError( "Cannot write to the reserved channel TASKS" ) if w.value is PASSTHROUGH: raise InvalidUpdateError("PASSTHROUGH value must be replaced") # split packets and entries sends = [ (PUSH if FF_SEND_V2 else TASKS, packet) for packet in writes if isinstance(packet, Send) ] entries = [write for write in writes if isinstance(write, ChannelWriteEntry)] # process entries into values values = [ write.mapper(write.value) if write.mapper is not None else write.value for write in entries ] values = [ (write.channel, val) for val, write in zip(values, entries) if not write.skip_none or val is not None ] # filter out SKIP_WRITE values filtered = [(chan, val) for chan, val in values if val is not SKIP_WRITE] if require_at_least_one_of is not None: if not {chan for chan, _ in filtered} & set(require_at_least_one_of): raise InvalidUpdateError( f"Must write to at least one of {require_at_least_one_of}" ) write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND] write(sends + filtered) @staticmethod def is_writer(runnable: Runnable) -> bool: """Used by PregelNode to 
distinguish between writers and other runnables.""" return ( isinstance(runnable, ChannelWrite) or getattr(runnable, "_is_channel_writer", False) is True ) @staticmethod def register_writer(runnable: R) -> R: """Used to mark a runnable as a writer, so that it can be detected by is_writer. Instances of ChannelWrite are automatically marked as writers.""" # using object.__setattr__ to work around objects that override __setattr__ # eg. pydantic models and dataclasses object.__setattr__(runnable, "_is_channel_writer", True) return runnable ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/log.py`: ```py import logging logger = logging.getLogger("langgraph") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/remote.py`: ```py from dataclasses import asdict from typing import ( Any, AsyncIterator, Iterator, Literal, Optional, Sequence, Union, cast, ) import orjson from langchain_core.runnables import RunnableConfig from langchain_core.runnables.graph import ( Edge as DrawableEdge, ) from langchain_core.runnables.graph import ( Graph as DrawableGraph, ) from langchain_core.runnables.graph import ( Node as DrawableNode, ) from langgraph_sdk.client import ( LangGraphClient, SyncLangGraphClient, get_client, get_sync_client, ) from langgraph_sdk.schema import Checkpoint, ThreadState from langgraph_sdk.schema import Command as CommandSDK from langgraph_sdk.schema import StreamMode as StreamModeSDK from typing_extensions import Self from langgraph.checkpoint.base import CheckpointMetadata from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_STREAM, INTERRUPT, NS_SEP, ) from langgraph.errors import GraphInterrupt from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode from langgraph.types import Command, Interrupt, StreamProtocol from langgraph.utils.config import merge_configs class RemoteException(Exception): """Exception 
raised when an error occurs in the remote graph.""" pass class RemoteGraph(PregelProtocol): """The `RemoteGraph` class is a client implementation for calling remote APIs that implement the LangGraph Server API specification. For example, the `RemoteGraph` class can be used to call APIs from deployments on LangGraph Cloud. `RemoteGraph` behaves the same way as a `Graph` and can be used directly as a node in another `Graph`. """ name: str def __init__( self, name: str, # graph_id /, *, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, client: Optional[LangGraphClient] = None, sync_client: Optional[SyncLangGraphClient] = None, config: Optional[RunnableConfig] = None, ): """Specify `url`, `api_key`, and/or `headers` to create default sync and async clients. If `client` or `sync_client` are provided, they will be used instead of the default clients. See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. At least one of `url`, `client`, or `sync_client` must be provided. Args: name: The name of the graph. url: The URL of the remote API. api_key: The API key to use for authentication. If not provided, it will be read from the environment (`LANGGRAPH_API_KEY`, `LANGSMITH_API_KEY`, or `LANGCHAIN_API_KEY`). headers: Additional headers to include in the requests. client: A `LangGraphClient` instance to use instead of creating a default client. sync_client: A `SyncLangGraphClient` instance to use instead of creating a default client. config: An optional `RunnableConfig` instance with additional configuration. 
""" self.name = name self.config = config if client is None and url is not None: client = get_client(url=url, api_key=api_key, headers=headers) self.client = client if sync_client is None and url is not None: sync_client = get_sync_client(url=url, api_key=api_key, headers=headers) self.sync_client = sync_client def _validate_client(self) -> LangGraphClient: if self.client is None: raise ValueError( "Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`." ) return self.client def _validate_sync_client(self) -> SyncLangGraphClient: if self.sync_client is None: raise ValueError( "Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`." ) return self.sync_client def copy(self, update: dict[str, Any]) -> Self: attrs = {**self.__dict__, **update} return self.__class__(attrs.pop("name"), **attrs) def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Self: return self.copy( {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))} ) def _get_drawable_nodes( self, graph: dict[str, list[dict[str, Any]]] ) -> dict[str, DrawableNode]: nodes = {} for node in graph["nodes"]: node_id = str(node["id"]) node_data = node.get("data", {}) # Get node name from node_data if available. If not, use node_id. node_name = node.get("name") if node_name is None: if isinstance(node_data, dict): node_name = node_data.get("name", node_id) else: node_name = node_id nodes[node_id] = DrawableNode( id=node_id, name=node_name, data=node_data, metadata=node.get("metadata"), ) return nodes def get_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: """Get graph by graph name. This method calls `GET /assistants/{assistant_id}/graph`. Args: config: This parameter is not used. xray: Include graph representation of subgraphs. 
If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included. Returns: The graph information for the assistant in JSON format. """ sync_client = self._validate_sync_client() graph = sync_client.assistants.get_graph( assistant_id=self.name, xray=xray, ) return DrawableGraph( nodes=self._get_drawable_nodes(graph), edges=[DrawableEdge(**edge) for edge in graph["edges"]], ) async def aget_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: """Get graph by graph name. This method calls `GET /assistants/{assistant_id}/graph`. Args: config: This parameter is not used. xray: Include graph representation of subgraphs. If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included. Returns: The graph information for the assistant in JSON format. """ client = self._validate_client() graph = await client.assistants.get_graph( assistant_id=self.name, xray=xray, ) return DrawableGraph( nodes=self._get_drawable_nodes(graph), edges=[DrawableEdge(**edge) for edge in graph["edges"]], ) def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot: tasks = [] for task in state["tasks"]: interrupts = [] for interrupt in task["interrupts"]: interrupts.append(Interrupt(**interrupt)) tasks.append( PregelTask( id=task["id"], name=task["name"], path=tuple(), error=Exception(task["error"]) if task["error"] else None, interrupts=tuple(interrupts), state=self._create_state_snapshot(task["state"]) if task["state"] else cast(RunnableConfig, {"configurable": task["checkpoint"]}) if task["checkpoint"] else None, result=task.get("result"), ) ) return StateSnapshot( values=state["values"], next=tuple(state["next"]) if state["next"] else tuple(), config={ "configurable": { "thread_id": state["checkpoint"]["thread_id"], "checkpoint_ns": state["checkpoint"]["checkpoint_ns"], "checkpoint_id": state["checkpoint"]["checkpoint_id"], 
"checkpoint_map": state["checkpoint"].get("checkpoint_map", {}), } }, metadata=CheckpointMetadata(**state["metadata"]), created_at=state["created_at"], parent_config={ "configurable": { "thread_id": state["parent_checkpoint"]["thread_id"], "checkpoint_ns": state["parent_checkpoint"]["checkpoint_ns"], "checkpoint_id": state["parent_checkpoint"]["checkpoint_id"], "checkpoint_map": state["parent_checkpoint"].get( "checkpoint_map", {} ), } } if state["parent_checkpoint"] else None, tasks=tuple(tasks), ) def _get_checkpoint(self, config: Optional[RunnableConfig]) -> Optional[Checkpoint]: if config is None: return None checkpoint = {} if "thread_id" in config["configurable"]: checkpoint["thread_id"] = config["configurable"]["thread_id"] if "checkpoint_ns" in config["configurable"]: checkpoint["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] if "checkpoint_id" in config["configurable"]: checkpoint["checkpoint_id"] = config["configurable"]["checkpoint_id"] if "checkpoint_map" in config["configurable"]: checkpoint["checkpoint_map"] = config["configurable"]["checkpoint_map"] return checkpoint if checkpoint else None def _get_config(self, checkpoint: Checkpoint) -> RunnableConfig: return { "configurable": { "thread_id": checkpoint["thread_id"], "checkpoint_ns": checkpoint["checkpoint_ns"], "checkpoint_id": checkpoint["checkpoint_id"], "checkpoint_map": checkpoint.get("checkpoint_map", {}), } } def _sanitize_config(self, config: RunnableConfig) -> RunnableConfig: reserved_configurable_keys = frozenset( [ "callbacks", "checkpoint_map", "checkpoint_id", "checkpoint_ns", ] ) def _sanitize_obj(obj: Any) -> Any: """Remove non-JSON serializable fields from the given object.""" if isinstance(obj, dict): return {k: _sanitize_obj(v) for k, v in obj.items()} elif isinstance(obj, list): return [_sanitize_obj(v) for v in obj] else: try: orjson.dumps(obj) return obj except orjson.JSONEncodeError: return None # Remove non-JSON serializable fields from the config. 
config = _sanitize_obj(config) # Only include configurable keys that are not reserved and # not starting with "__pregel_" prefix. new_configurable = { k: v for k, v in config["configurable"].items() if k not in reserved_configurable_keys and not k.startswith("__pregel_") } return { "tags": config.get("tags") or [], "metadata": config.get("metadata") or {}, "configurable": new_configurable, } def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the state of a thread. This method calls `POST /threads/{thread_id}/state/checkpoint` if a checkpoint is specified in the config or `GET /threads/{thread_id}/state` if no checkpoint is specified. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. subgraphs: Include subgraphs in the state. Returns: The latest state of the thread. """ sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) state = sync_client.threads.get_state( thread_id=merged_config["configurable"]["thread_id"], checkpoint=self._get_checkpoint(merged_config), subgraphs=subgraphs, ) return self._create_state_snapshot(state) async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the state of a thread. This method calls `POST /threads/{thread_id}/state/checkpoint` if a checkpoint is specified in the config or `GET /threads/{thread_id}/state` if no checkpoint is specified. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. subgraphs: Include subgraphs in the state. Returns: The latest state of the thread. 
""" client = self._validate_client() merged_config = merge_configs(self.config, config) state = await client.threads.get_state( thread_id=merged_config["configurable"]["thread_id"], checkpoint=self._get_checkpoint(merged_config), subgraphs=subgraphs, ) return self._create_state_snapshot(state) def get_state_history( self, config: RunnableConfig, *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: """Get the state history of a thread. This method calls `POST /threads/{thread_id}/history`. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. filter: Metadata to filter on. before: A `RunnableConfig` that includes checkpoint metadata. limit: Max number of states to return. Returns: States of the thread. """ sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) states = sync_client.threads.get_history( thread_id=merged_config["configurable"]["thread_id"], limit=limit if limit else 10, before=self._get_checkpoint(before), metadata=filter, checkpoint=self._get_checkpoint(merged_config), ) for state in states: yield self._create_state_snapshot(state) async def aget_state_history( self, config: RunnableConfig, *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: """Get the state history of a thread. This method calls `POST /threads/{thread_id}/history`. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. filter: Metadata to filter on. before: A `RunnableConfig` that includes checkpoint metadata. limit: Max number of states to return. Returns: States of the thread. 
""" client = self._validate_client() merged_config = merge_configs(self.config, config) states = await client.threads.get_history( thread_id=merged_config["configurable"]["thread_id"], limit=limit if limit else 10, before=self._get_checkpoint(before), metadata=filter, checkpoint=self._get_checkpoint(merged_config), ) for state in states: yield self._create_state_snapshot(state) def update_state( self, config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: """Update the state of a thread. This method calls `POST /threads/{thread_id}/state`. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. values: Values to update to the state. as_node: Update the state as if this node had just executed. Returns: `RunnableConfig` for the updated thread. """ sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) response: dict = sync_client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], values=values, as_node=as_node, checkpoint=self._get_checkpoint(merged_config), ) return self._get_config(response["checkpoint"]) async def aupdate_state( self, config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: """Update the state of a thread. This method calls `POST /threads/{thread_id}/state`. Args: config: A `RunnableConfig` that includes `thread_id` in the `configurable` field. values: Values to update to the state. as_node: Update the state as if this node had just executed. Returns: `RunnableConfig` for the updated thread. 
""" client = self._validate_client() merged_config = merge_configs(self.config, config) response: dict = await client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], values=values, as_node=as_node, checkpoint=self._get_checkpoint(merged_config), ) return self._get_config(response["checkpoint"]) def _get_stream_modes( self, stream_mode: Optional[Union[StreamMode, list[StreamMode]]], config: Optional[RunnableConfig], default: StreamMode = "updates", ) -> tuple[ list[StreamModeSDK], list[StreamModeSDK], bool, Optional[StreamProtocol] ]: """Return a tuple of the final list of stream modes sent to the remote graph and a boolean flag indicating if stream mode 'updates' was present in the original list of stream modes. 'updates' mode is added to the list of stream modes so that interrupts can be detected in the remote graph. """ updated_stream_modes: list[StreamModeSDK] = [] req_single = True # coerce to list, or add default stream mode if stream_mode: if isinstance(stream_mode, str): updated_stream_modes.append(stream_mode) else: req_single = False updated_stream_modes.extend(stream_mode) else: updated_stream_modes.append(default) requested_stream_modes = updated_stream_modes.copy() # add any from parent graph stream: Optional[StreamProtocol] = ( (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM) ) if stream: updated_stream_modes.extend(stream.modes) # map "messages" to "messages-tuple" if "messages" in updated_stream_modes: updated_stream_modes.remove("messages") updated_stream_modes.append("messages-tuple") # if requested "messages-tuple", # map to "messages" in requested_stream_modes if "messages-tuple" in requested_stream_modes: requested_stream_modes.remove("messages-tuple") requested_stream_modes.append("messages") # add 'updates' mode if not present if "updates" not in updated_stream_modes: updated_stream_modes.append("updates") # remove 'events', as it's not supported in Pregel if "events" in updated_stream_modes: 
updated_stream_modes.remove("events") return (updated_stream_modes, requested_stream_modes, req_single, stream) def stream( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, **kwargs: Any, ) -> Iterator[Union[dict[str, Any], Any]]: """Create a run and stream the results. This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id` is speciffed in the `configurable` field of the config or `POST /runs/stream` otherwise. Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. stream_mode: Stream mode(s) to use. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. subgraphs: Stream from subgraphs. **kwargs: Additional params to pass to client.runs.stream. Yields: The output of the graph. 
""" sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) if isinstance(input, Command): command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) input = None else: command = None for chunk in sync_client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", **kwargs, ): # split mode and ns if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () # prepend caller ns (as it is not passed to remote graph) if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS): caller_ns = tuple(caller_ns.split(NS_SEP)) ns = caller_ns + ns # stream to parent stream if stream is not None and mode in stream.modes: stream((ns, mode, chunk.data)) # raise interrupt or errors if chunk.event.startswith("updates"): if isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) elif chunk.event.startswith("error"): raise RemoteException(chunk.data) # filter for what was actually requested if mode not in requested: continue # emit chunk if subgraphs: if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () if req_single: yield ns, chunk.data else: yield ns, mode, chunk.data elif req_single: yield chunk.data else: yield chunk async def astream( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, interrupt_before: Optional[Union[All, 
Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, **kwargs: Any, ) -> AsyncIterator[Union[dict[str, Any], Any]]: """Create a run and stream the results. This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id` is speciffed in the `configurable` field of the config or `POST /runs/stream` otherwise. Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. stream_mode: Stream mode(s) to use. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. subgraphs: Stream from subgraphs. **kwargs: Additional params to pass to client.runs.stream. Yields: The output of the graph. """ client = self._validate_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) if isinstance(input, Command): command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) input = None else: command = None async for chunk in client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", **kwargs, ): # split mode and ns if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () # prepend caller ns (as it is not passed to remote graph) if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS): caller_ns = tuple(caller_ns.split(NS_SEP)) ns = caller_ns + ns # stream to parent stream if stream is not None and mode in stream.modes: stream((ns, mode, chunk.data)) # raise interrupt or errors if chunk.event.startswith("updates"): if 
isinstance(chunk.data, dict) and INTERRUPT in chunk.data: raise GraphInterrupt(chunk.data[INTERRUPT]) elif chunk.event.startswith("error"): raise RemoteException(chunk.data) # filter for what was actually requested if mode not in requested: continue # emit chunk if subgraphs: if NS_SEP in chunk.event: mode, ns_ = chunk.event.split(NS_SEP, 1) ns = tuple(ns_.split(NS_SEP)) else: mode, ns = chunk.event, () if req_single: yield ns, chunk.data else: yield ns, mode, chunk.data elif req_single: yield chunk.data else: yield chunk async def astream_events( self, input: Any, config: Optional[RunnableConfig] = None, *, version: Literal["v1", "v2"], include_names: Optional[Sequence[All]] = None, include_types: Optional[Sequence[All]] = None, include_tags: Optional[Sequence[All]] = None, exclude_names: Optional[Sequence[All]] = None, exclude_types: Optional[Sequence[All]] = None, exclude_tags: Optional[Sequence[All]] = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: raise NotImplementedError def invoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Create a run, wait until it finishes and return the final state. Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. **kwargs: Additional params to pass to RemoteGraph.stream. Returns: The output of the graph. 
""" for chunk in self.stream( input, config=config, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_mode="values", **kwargs, ): pass try: return chunk except UnboundLocalError: return None async def ainvoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Create a run, wait until it finishes and return the final state. Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. **kwargs: Additional params to pass to RemoteGraph.astream. Returns: The output of the graph. """ async for chunk in self.astream( input, config=config, interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_mode="values", **kwargs, ): pass try: return chunk except UnboundLocalError: return None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/protocol.py`: ```py from abc import ABC, abstractmethod from typing import ( Any, AsyncIterator, Iterator, Optional, Sequence, Union, ) from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.graph import Graph as DrawableGraph from typing_extensions import Self from langgraph.pregel.types import All, StateSnapshot, StreamMode class PregelProtocol( Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC ): @abstractmethod def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Self: ... @abstractmethod def get_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: ... 
@abstractmethod async def aget_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: ... @abstractmethod def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... @abstractmethod async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: ... @abstractmethod def get_state_history( self, config: RunnableConfig, *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: ... @abstractmethod def aget_state_history( self, config: RunnableConfig, *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: ... @abstractmethod def update_state( self, config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: ... @abstractmethod async def aupdate_state( self, config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: ... @abstractmethod def stream( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, ) -> Iterator[Union[dict[str, Any], Any]]: ... @abstractmethod def astream( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, ) -> AsyncIterator[Union[dict[str, Any], Any]]: ... 
@abstractmethod def invoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: ... @abstractmethod async def ainvoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: ... ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/io.py`: ```py from typing import Any, Iterator, Literal, Mapping, Optional, Sequence, TypeVar, Union from uuid import UUID from langchain_core.runnables.utils import AddableDict from langgraph.channels.base import BaseChannel, EmptyChannelError from langgraph.checkpoint.base import PendingWrite from langgraph.constants import ( EMPTY_SEQ, ERROR, FF_SEND_V2, INTERRUPT, NULL_TASK_ID, PUSH, RESUME, TAG_HIDDEN, TASKS, ) from langgraph.errors import InvalidUpdateError from langgraph.pregel.log import logger from langgraph.types import Command, PregelExecutableTask, Send def is_task_id(task_id: str) -> bool: """Check if a string is a valid task id.""" try: UUID(task_id) except ValueError: return False return True def read_channel( channels: Mapping[str, BaseChannel], chan: str, *, catch: bool = True, return_exception: bool = False, ) -> Any: try: return channels[chan].get() except EmptyChannelError as exc: if return_exception: return exc elif catch: return None else: raise def read_channels( channels: Mapping[str, BaseChannel], select: Union[Sequence[str], str], *, skip_empty: bool = True, ) -> Union[dict[str, Any], Any]: if isinstance(select, str): return read_channel(channels, select) else: values: dict[str, Any] = {} for k in select: try: values[k] = read_channel(channels, k, catch=not skip_empty) except EmptyChannelError: pass return values def 
map_command( cmd: Command, pending_writes: list[PendingWrite] ) -> Iterator[tuple[str, str, Any]]: """Map input chunk to a sequence of pending writes in the form (channel, value).""" if cmd.graph == Command.PARENT: raise InvalidUpdateError("There is not parent graph") if cmd.goto: if isinstance(cmd.goto, (tuple, list)): sends = cmd.goto else: sends = [cmd.goto] for send in sends: if not isinstance(send, Send): raise TypeError( f"In Command.goto, expected Send, got {type(send).__name__}" ) yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send) # TODO handle goto str for state graph if cmd.resume: if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume): for tid, resume in cmd.resume.items(): existing: list[Any] = next( (w[2] for w in pending_writes if w[0] == tid and w[1] == RESUME), [] ) existing.append(resume) yield (tid, RESUME, existing) else: yield (NULL_TASK_ID, RESUME, cmd.resume) if cmd.update: if not isinstance(cmd.update, dict): raise TypeError( f"Expected cmd.update to be a dict mapping channel names to update values, got {type(cmd.update).__name__}" ) for k, v in cmd.update.items(): yield (NULL_TASK_ID, k, v) def map_input( input_channels: Union[str, Sequence[str]], chunk: Optional[Union[dict[str, Any], Any]], ) -> Iterator[tuple[str, Any]]: """Map input chunk to a sequence of pending writes in the form (channel, value).""" if chunk is None: return elif isinstance(input_channels, str): yield (input_channels, chunk) else: if not isinstance(chunk, dict): raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}") for k in chunk: if k in input_channels: yield (k, chunk[k]) else: logger.warning(f"Input channel {k} not found in {input_channels}") class AddableValuesDict(AddableDict): def __add__(self, other: dict[str, Any]) -> "AddableValuesDict": return self | other def __radd__(self, other: dict[str, Any]) -> "AddableValuesDict": return other | self def map_output_values( output_channels: Union[str, Sequence[str]], 
pending_writes: Union[Literal[True], Sequence[tuple[str, Any]]], channels: Mapping[str, BaseChannel], ) -> Iterator[Union[dict[str, Any], Any]]: """Map pending writes (a sequence of tuples (channel, value)) to output chunk.""" if isinstance(output_channels, str): if pending_writes is True or any( chan == output_channels for chan, _ in pending_writes ): yield read_channel(channels, output_channels) else: if pending_writes is True or { c for c, _ in pending_writes if c in output_channels }: yield AddableValuesDict(read_channels(channels, output_channels)) class AddableUpdatesDict(AddableDict): def __add__(self, other: dict[str, Any]) -> "AddableUpdatesDict": return [self, other] def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict": raise TypeError("AddableUpdatesDict does not support right-side addition") def map_output_updates( output_channels: Union[str, Sequence[str]], tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]], cached: bool = False, ) -> Iterator[dict[str, Union[Any, dict[str, Any]]]]: """Map pending writes (a sequence of tuples (channel, value)) to output chunk.""" output_tasks = [ (t, ww) for t, ww in tasks if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ)) and ww[0][0] != ERROR and ww[0][0] != INTERRUPT ] if not output_tasks: return if isinstance(output_channels, str): updated = ( (task.name, value) for task, writes in output_tasks for chan, value in writes if chan == output_channels ) else: updated = ( ( task.name, {chan: value for chan, value in writes if chan in output_channels}, ) for task, writes in output_tasks if any(chan in output_channels for chan, _ in writes) ) grouped: dict[str, list[Any]] = {t.name: [] for t, _ in output_tasks} for node, value in updated: grouped[node].append(value) for node, value in grouped.items(): if len(value) == 0: grouped[node] = None # type: ignore[assignment] if len(value) == 1: grouped[node] = value[0] if cached: grouped["__metadata__"] = {"cached": cached} # 
type: ignore[assignment] yield AddableUpdatesDict(grouped) T = TypeVar("T") def single(iter: Iterator[T]) -> Optional[T]: for item in iter: return item ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/__init__.py`: ```py from __future__ import annotations import asyncio import concurrent import concurrent.futures import queue from collections import deque from functools import partial from typing import ( Any, AsyncIterator, Callable, Dict, Iterator, Mapping, Optional, Sequence, Type, Union, cast, get_type_hints, overload, ) from uuid import UUID, uuid5 from langchain_core.globals import get_debug from langchain_core.runnables import ( RunnableSequence, ) from langchain_core.runnables.base import Input, Output from langchain_core.runnables.config import ( RunnableConfig, get_async_callback_manager_for_config, get_callback_manager_for_config, ) from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( ConfigurableFieldSpec, get_unique_config_specs, ) from langchain_core.tracers._streaming import _StreamingCallbackHandler from pydantic import BaseModel from typing_extensions import Self from langgraph.channels.base import ( BaseChannel, ) from langgraph.checkpoint.base import ( BaseCheckpointSaver, CheckpointTuple, copy_checkpoint, create_checkpoint, empty_checkpoint, ) from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_NODE_FINISHED, CONFIG_KEY_READ, CONFIG_KEY_RESUMING, CONFIG_KEY_SEND, CONFIG_KEY_STORE, CONFIG_KEY_STREAM, CONFIG_KEY_STREAM_WRITER, CONFIG_KEY_TASK_ID, END, ERROR, INPUT, INTERRUPT, NS_END, NS_SEP, NULL_TASK_ID, PUSH, SCHEDULED, ) from langgraph.errors import ( ErrorCode, GraphRecursionError, InvalidUpdateError, create_error_message, ) from langgraph.managed.base import ManagedValueSpec from langgraph.pregel.algo import ( PregelTaskWrites, apply_writes, local_read, local_write, prepare_next_tasks, ) from 
langgraph.pregel.debug import tasks_w_writes from langgraph.pregel.io import read_channels from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.messages import StreamMessagesHandler from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.read import PregelNode from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.runner import PregelRunner from langgraph.pregel.utils import find_subgraph_pregel, get_new_channel_versions from langgraph.pregel.validate import validate_graph, validate_keys from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore from langgraph.types import ( All, Checkpointer, LoopProtocol, StateSnapshot, StreamChunk, StreamMode, ) from langgraph.utils.config import ( ensure_config, merge_configs, patch_checkpoint_map, patch_config, patch_configurable, ) from langgraph.utils.pydantic import create_model from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined] WriteValue = Union[Callable[[Input], Output], Any] class Channel: @overload @classmethod def subscribe_to( cls, channels: str, *, key: Optional[str] = None, tags: Optional[list[str]] = None, ) -> PregelNode: ... @overload @classmethod def subscribe_to( cls, channels: Sequence[str], *, key: None = None, tags: Optional[list[str]] = None, ) -> PregelNode: ... 
@classmethod def subscribe_to( cls, channels: Union[str, Sequence[str]], *, key: Optional[str] = None, tags: Optional[list[str]] = None, ) -> PregelNode: """Runs process.invoke() each time channels are updated, with a dict of the channel values as input.""" if not isinstance(channels, str) and key is not None: raise ValueError( "Can't specify a key when subscribing to multiple channels" ) return PregelNode( channels=cast( Union[list[str], Mapping[str, str]], ( {key: channels} if isinstance(channels, str) and key is not None else ( [channels] if isinstance(channels, str) else {chan: chan for chan in channels} ) ), ), triggers=[channels] if isinstance(channels, str) else channels, tags=tags, ) @classmethod def write_to( cls, *channels: str, **kwargs: WriteValue, ) -> ChannelWrite: """Writes to channels the result of the lambda, or None to skip writing.""" return ChannelWrite( [ChannelWriteEntry(c) for c in channels] + [ ( ChannelWriteEntry(k, mapper=v) if callable(v) else ChannelWriteEntry(k, value=v) ) for k, v in kwargs.items() ] ) class Pregel(PregelProtocol): nodes: dict[str, PregelNode] channels: dict[str, Union[BaseChannel, ManagedValueSpec]] stream_mode: StreamMode = "values" """Mode to stream output, defaults to 'values'.""" output_channels: Union[str, Sequence[str]] stream_channels: Optional[Union[str, Sequence[str]]] = None """Channels to stream, defaults to all channels not in reserved channels""" interrupt_after_nodes: Union[All, Sequence[str]] interrupt_before_nodes: Union[All, Sequence[str]] input_channels: Union[str, Sequence[str]] step_timeout: Optional[float] = None """Maximum time to wait for a step to complete, in seconds. Defaults to None.""" debug: bool """Whether to print debug information during execution. Defaults to False.""" checkpointer: Checkpointer = None """Checkpointer used to save and load graph state. Defaults to None.""" store: Optional[BaseStore] = None """Memory store to use for SharedValues. 
Defaults to None.""" retry_policy: Optional[RetryPolicy] = None """Retry policy to use when running tasks. Set to None to disable.""" config_type: Optional[Type[Any]] = None config: Optional[RunnableConfig] = None name: str = "LangGraph" def __init__( self, *, nodes: dict[str, PregelNode], channels: Optional[dict[str, Union[BaseChannel, ManagedValueSpec]]], auto_validate: bool = True, stream_mode: StreamMode = "values", output_channels: Union[str, Sequence[str]], stream_channels: Optional[Union[str, Sequence[str]]] = None, interrupt_after_nodes: Union[All, Sequence[str]] = (), interrupt_before_nodes: Union[All, Sequence[str]] = (), input_channels: Union[str, Sequence[str]], step_timeout: Optional[float] = None, debug: Optional[bool] = None, checkpointer: Optional[BaseCheckpointSaver] = None, store: Optional[BaseStore] = None, retry_policy: Optional[RetryPolicy] = None, config_type: Optional[Type[Any]] = None, config: Optional[RunnableConfig] = None, name: str = "LangGraph", ) -> None: self.nodes = nodes self.channels = channels or {} self.stream_mode = stream_mode self.output_channels = output_channels self.stream_channels = stream_channels self.interrupt_after_nodes = interrupt_after_nodes self.interrupt_before_nodes = interrupt_before_nodes self.input_channels = input_channels self.step_timeout = step_timeout self.debug = debug if debug is not None else get_debug() self.checkpointer = checkpointer self.store = store self.retry_policy = retry_policy self.config_type = config_type self.config = config self.name = name if auto_validate: self.validate() def get_graph( self, config: RunnableConfig | None = None, *, xray: int | bool = False ) -> Graph: raise NotImplementedError async def aget_graph( self, config: RunnableConfig | None = None, *, xray: int | bool = False ) -> Graph: raise NotImplementedError def copy(self, update: dict[str, Any] | None = None) -> Self: attrs = {**self.__dict__, **(update or {})} return self.__class__(**attrs) def with_config(self, 
config: RunnableConfig | None = None, **kwargs: Any) -> Self: return self.copy( {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))} ) def validate(self) -> Self: validate_graph( self.nodes, {k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)}, self.input_channels, self.output_channels, self.stream_channels, self.interrupt_after_nodes, self.interrupt_before_nodes, ) return self @property def config_specs(self) -> list[ConfigurableFieldSpec]: return [ spec for spec in get_unique_config_specs( [spec for node in self.nodes.values() for spec in node.config_specs] + ( self.checkpointer.config_specs if isinstance(self.checkpointer, BaseCheckpointSaver) else [] ) + ( [ ConfigurableFieldSpec(id=name, annotation=typ) for name, typ in get_type_hints(self.config_type).items() ] if self.config_type is not None else [] ) ) # these are provided by the Pregel class if spec.id not in [ CONFIG_KEY_READ, CONFIG_KEY_SEND, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_RESUMING, ] ] @property def InputType(self) -> Any: if isinstance(self.input_channels, str): channel = self.channels[self.input_channels] if isinstance(channel, BaseChannel): return channel.UpdateType def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: config = merge_configs(self.config, config) if isinstance(self.input_channels, str): return super().get_input_schema(config) else: return create_model( self.get_name("Input"), field_definitions={ k: (c.UpdateType, None) for k in self.input_channels or self.channels.keys() if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) def get_input_jsonschema( self, config: Optional[RunnableConfig] = None ) -> Dict[All, Any]: schema = self.get_input_schema(config) if hasattr(schema, "model_json_schema"): return schema.model_json_schema() else: return schema.schema() @property def OutputType(self) -> Any: if isinstance(self.output_channels, str): channel = self.channels[self.output_channels] if 
isinstance(channel, BaseChannel): return channel.ValueType def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: config = merge_configs(self.config, config) if isinstance(self.output_channels, str): return super().get_output_schema(config) else: return create_model( self.get_name("Output"), field_definitions={ k: (c.ValueType, None) for k in self.output_channels if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) def get_output_jsonschema( self, config: Optional[RunnableConfig] = None ) -> Dict[All, Any]: schema = self.get_output_schema(config) if hasattr(schema, "model_json_schema"): return schema.model_json_schema() else: return schema.schema() @property def stream_channels_list(self) -> Sequence[str]: stream_channels = self.stream_channels_asis return ( [stream_channels] if isinstance(stream_channels, str) else stream_channels ) @property def stream_channels_asis(self) -> Union[str, Sequence[str]]: return self.stream_channels or [ k for k in self.channels if isinstance(self.channels[k], BaseChannel) ] def get_subgraphs( self, *, namespace: Optional[str] = None, recurse: bool = False ) -> Iterator[tuple[str, Pregel]]: for name, node in self.nodes.items(): # filter by prefix if namespace is not None: if not namespace.startswith(name): continue # find the subgraph, if any graph = cast(Optional[Pregel], find_subgraph_pregel(node.bound)) # if found, yield recursively if graph: if name == namespace: yield name, graph return # we found it, stop searching if namespace is None: yield name, graph if recurse: if namespace is not None: namespace = namespace[len(name) + 1 :] yield from ( (f"{name}{NS_SEP}{n}", s) for n, s in graph.get_subgraphs( namespace=namespace, recurse=recurse ) ) async def aget_subgraphs( self, *, namespace: Optional[str] = None, recurse: bool = False ) -> AsyncIterator[tuple[str, Pregel]]: for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse): yield name, node def 
_prepare_state_snapshot( self, config: RunnableConfig, saved: Optional[CheckpointTuple], recurse: Optional[BaseCheckpointSaver] = None, apply_pending_writes: bool = False, ) -> StateSnapshot: if not saved: return StateSnapshot( values={}, next=(), config=config, metadata=None, created_at=None, parent_config=None, tasks=(), ) with ChannelsManager( self.channels, saved.checkpoint, LoopProtocol( config=saved.config, step=saved.metadata.get("step", -1) + 1, stop=saved.metadata.get("step", -1) + 2, ), skip_context=True, ) as (channels, managed): # tasks for this checkpoint next_tasks = prepare_next_tasks( saved.checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # get the subgraphs subgraphs = dict(self.get_subgraphs()) parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {} for task in next_tasks.values(): if task.name not in subgraphs: continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" if not recurse: # set config as signal that subgraph checkpoints exist config = { CONF: { "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { CONF: { CONFIG_KEY_CHECKPOINTER: recurse, "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = subgraphs[task.name].get_state( config, subgraphs=True ) # apply pending writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) if apply_pending_writes and saved.pending_writes: for tid, k, v in saved.pending_writes: if k in 
(ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes(saved.checkpoint, channels, tasks, None) # assemble the state snapshot return StateSnapshot( read_channels(channels, self.stream_channels_asis), tuple(t.name for t in next_tasks.values() if not t.writes), patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, saved.checkpoint["ts"], patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_w_writes( next_tasks.values(), saved.pending_writes, task_states, self.stream_channels_asis, ), ) async def _aprepare_state_snapshot( self, config: RunnableConfig, saved: Optional[CheckpointTuple], recurse: Optional[BaseCheckpointSaver] = None, apply_pending_writes: bool = False, ) -> StateSnapshot: if not saved: return StateSnapshot( values={}, next=(), config=config, metadata=None, created_at=None, parent_config=None, tasks=(), ) async with AsyncChannelsManager( self.channels, saved.checkpoint, LoopProtocol( config=saved.config, step=saved.metadata.get("step", -1) + 1, stop=saved.metadata.get("step", -1) + 2, ), skip_context=True, ) as ( channels, managed, ): # tasks for this checkpoint next_tasks = prepare_next_tasks( saved.checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # get the subgraphs subgraphs = {n: g async for n, g in self.aget_subgraphs()} parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {} for task in next_tasks.values(): if task.name not in subgraphs: continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" if not recurse: # set config as signal that subgraph checkpoints exist config = { 
CONF: { "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { CONF: { CONFIG_KEY_CHECKPOINTER: recurse, "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = await subgraphs[task.name].aget_state( config, subgraphs=True ) # apply pending writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) if apply_pending_writes and saved.pending_writes: for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes(saved.checkpoint, channels, tasks, None) # assemble the state snapshot return StateSnapshot( read_channels(channels, self.stream_channels_asis), tuple(t.name for t in next_tasks.values() if not t.writes), patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, saved.checkpoint["ts"], patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_w_writes( next_tasks.values(), saved.pending_writes, task_states, self.stream_channels_asis, ), ) def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name for _, pregel in self.get_subgraphs( 
namespace=recast_checkpoint_ns, recurse=True ): return pregel.get_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), subgraphs=subgraphs, ) else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs(self.config, config) if self.config else config saved = checkpointer.get_tuple(config) return self._prepare_state_snapshot( config, saved, recurse=checkpointer if subgraphs else None, apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF], ) async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs( namespace=recast_checkpoint_ns, recurse=True ): return await pregel.aget_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), subgraphs=subgraphs, ) else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs(self.config, config) if self.config else config saved = await checkpointer.aget_tuple(config) return await self._aprepare_state_snapshot( config, saved, recurse=checkpointer if subgraphs else None, apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF], ) def get_state_history( self, config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: config = ensure_config(config) """Get the history of the state of the graph.""" 
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name for _, pregel in self.get_subgraphs( namespace=recast_checkpoint_ns, recurse=True ): yield from pregel.get_state_history( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), filter=filter, before=before, limit=limit, ) return else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs( self.config, config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}}, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in list( checkpointer.list(config, before=before, limit=limit, filter=filter) ): yield self._prepare_state_snapshot( checkpoint_tuple.config, checkpoint_tuple ) async def aget_state_history( self, config: RunnableConfig, *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: config = ensure_config(config) """Get the history of the state of the graph.""" checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs( namespace=recast_checkpoint_ns, recurse=True 
): async for state in pregel.aget_state_history( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), filter=filter, before=before, limit=limit, ): yield state return else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs( self.config, config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}}, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in [ c async for c in checkpointer.alist( config, before=before, limit=limit, filter=filter ) ]: yield await self._aprepare_state_snapshot( checkpoint_tuple.config, checkpoint_tuple ) def update_state( self, config: RunnableConfig, values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: """Update the state of the graph with the given values, as if they came from node `as_node`. If `as_node` is not provided, it will be set to the last node that updated the state, if not ambiguous. """ checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") # delegate to subgraph if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name for _, pregel in self.get_subgraphs( namespace=recast_checkpoint_ns, recurse=True ): return pregel.update_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), values, as_node, ) else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") # get last checkpoint config = ensure_config(self.config, config) saved = checkpointer.get_tuple(config) checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint() checkpoint_previous_versions = ( 
saved.checkpoint["channel_versions"].copy() if saved else {} ) step = saved.metadata.get("step", -1) if saved else -1 # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")}, ) checkpoint_metadata = config["metadata"] if saved: checkpoint_config = patch_configurable(config, saved.config[CONF]) checkpoint_metadata = {**saved.metadata, **checkpoint_metadata} with ChannelsManager( self.channels, checkpoint, LoopProtocol(config=config, step=step + 1, stop=step + 2), ) as (channels, managed): # no values as END, just clear all tasks if values is None and as_node == END: if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) # apply writes from tasks that already ran for tid, k, v in saved.pending_writes or []: if k in (ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) # clear all current tasks apply_writes(checkpoint, channels, next_tasks.values(), None) # save checkpoint next_config = checkpointer.put( checkpoint_config, create_checkpoint(checkpoint, None, step), { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {}, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # no values, copy checkpoint if values is None and as_node is None: next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = 
checkpointer.put( checkpoint_config, next_checkpoint, { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {}, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) if values is None and as_node == "__copy__": next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = checkpointer.put( saved.parent_config or saved.config if saved else checkpoint_config, next_checkpoint, { **checkpoint_metadata, "source": "fork", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] and saved is not None and saved.pending_writes ): # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes, self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) # apply writes for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes(checkpoint, channels, tasks, None) # find last node that updated the state, if not provided if as_node is None and not any( v for vv in checkpoint["versions_seen"].values() for v in vv.values() ): if ( isinstance(self.input_channels, str) and self.input_channels in self.nodes ): as_node = self.input_channels elif as_node is None: last_seen_by_node = sorted( (v, n) for n, seen in 
checkpoint["versions_seen"].items() if n in self.nodes for v in seen.values() ) # if two nodes updated the state at the same time, it's ambiguous if last_seen_by_node: if len(last_seen_by_node) == 1: as_node = last_seen_by_node[0][1] elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]: as_node = last_seen_by_node[-1][1] if as_node is None: raise InvalidUpdateError("Ambiguous update, specify as_node") if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") # create task to run all writers of the chosen node writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task run.invoke( values, patch_config( config, run_name=self.name + "UpdateState", configurable={ # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, writes.extend, self.nodes.keys(), ), CONFIG_KEY_READ: partial( local_read, step + 1, checkpoint, channels, managed, task, config, ), }, ), ) # save task writes # channel writes are saved to current checkpoint # push writes are saved to next checkpoint channel_writes, push_writes = ( [w for w in task.writes if w[0] != PUSH], [w for w in task.writes if w[0] == PUSH], ) if saved and channel_writes: checkpointer.put_writes(checkpoint_config, channel_writes, task_id) # apply to checkpoint and save mv_writes = apply_writes( checkpoint, channels, [task], checkpointer.get_next_version ) assert not mv_writes, "Can't write to SharedValues from update_state" checkpoint = create_checkpoint(checkpoint, channels, step + 1) next_config = checkpointer.put( checkpoint_config, checkpoint, { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {as_node: values}, "parents": saved.metadata.get("parents", {}) if saved 
else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"] ), ) if push_writes: checkpointer.put_writes(next_config, push_writes, task_id) return patch_checkpoint_map(next_config, saved.metadata if saved else None) async def aupdate_state( self, config: RunnableConfig, values: dict[str, Any] | Any, as_node: Optional[str] = None, ) -> RunnableConfig: checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") # delegate to subgraph if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs( namespace=recast_checkpoint_ns, recurse=True ): return await pregel.aupdate_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), values, as_node, ) else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") # get last checkpoint config = ensure_config(self.config, config) saved = await checkpointer.aget_tuple(config) checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint() checkpoint_previous_versions = ( saved.checkpoint["channel_versions"].copy() if saved else {} ) step = saved.metadata.get("step", -1) if saved else -1 # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")}, ) checkpoint_metadata = config["metadata"] if saved: checkpoint_config = patch_configurable(config, saved.config[CONF]) checkpoint_metadata = {**saved.metadata, **checkpoint_metadata} async with AsyncChannelsManager( self.channels, checkpoint, LoopProtocol(config=config, step=step + 1, 
stop=step + 2), ) as ( channels, managed, ): # no values, just clear all tasks if values is None and as_node == END: if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) # apply writes from tasks that already ran for tid, k, v in saved.pending_writes or []: if k in (ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) # clear all current tasks apply_writes(checkpoint, channels, next_tasks.values(), None) # save checkpoint next_config = await checkpointer.aput( checkpoint_config, create_checkpoint(checkpoint, None, step), { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {}, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # no values, copy checkpoint if values is None and as_node is None: next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = await checkpointer.aput( checkpoint_config, next_checkpoint, { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {}, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) if values is None and as_node == "__copy__": next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = await checkpointer.aput( saved.parent_config or saved.config if saved else checkpoint_config, next_checkpoint, { **checkpoint_metadata, "source": 
"fork", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, {}, ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] and saved is not None and saved.pending_writes ): # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes, self.nodes, channels, managed, saved.config, saved.metadata.get("step", -1) + 1, for_execution=True, store=self.store, checkpointer=self.checkpointer or None, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, ) for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT, SCHEDULED): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes(checkpoint, channels, tasks, None) # find last node that updated the state, if not provided if as_node is None and not saved: if ( isinstance(self.input_channels, str) and self.input_channels in self.nodes ): as_node = self.input_channels elif as_node is None: last_seen_by_node = sorted( (v, n) for n, seen in checkpoint["versions_seen"].items() if n in self.nodes for v in seen.values() ) # if two nodes updated the state at the same time, it's ambiguous if last_seen_by_node: if len(last_seen_by_node) == 1: as_node = last_seen_by_node[0][1] elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]: as_node = last_seen_by_node[-1][1] if as_node is None: raise InvalidUpdateError("Ambiguous update, specify as_node") if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") # create task to run all writers of the chosen node writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node 
{as_node} has no writers") writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task await run.ainvoke( values, patch_config( config, run_name=self.name + "UpdateState", configurable={ # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, writes.extend, self.nodes.keys(), ), CONFIG_KEY_READ: partial( local_read, step + 1, checkpoint, channels, managed, task, config, ), }, ), ) # save task writes # channel writes are saved to current checkpoint # push writes are saved to next checkpoint channel_writes, push_writes = ( [w for w in task.writes if w[0] != PUSH], [w for w in task.writes if w[0] == PUSH], ) if saved and channel_writes: await checkpointer.aput_writes( checkpoint_config, channel_writes, task_id ) # apply to checkpoint and save mv_writes = apply_writes( checkpoint, channels, [task], checkpointer.get_next_version ) assert not mv_writes, "Can't write to SharedValues from update_state" checkpoint = create_checkpoint(checkpoint, channels, step + 1) # save checkpoint, after applying writes next_config = await checkpointer.aput( checkpoint_config, checkpoint, { **checkpoint_metadata, "source": "update", "step": step + 1, "writes": {as_node: values}, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"] ), ) # save push writes if push_writes: await checkpointer.aput_writes(next_config, push_writes, task_id) return patch_checkpoint_map(next_config, saved.metadata if saved else None)
    def _defaults(
        self,
        config: RunnableConfig,
        *,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]],
        output_keys: Optional[Union[str, Sequence[str]]],
        interrupt_before: Optional[Union[All, Sequence[str]]],
        interrupt_after: Optional[Union[All, Sequence[str]]],
        debug: Optional[bool],
    ) -> tuple[
        bool,
        set[StreamMode],
        Union[str, Sequence[str]],
        Union[All, Sequence[str]],
        Union[All, Sequence[str]],
        Optional[BaseCheckpointSaver],
        Optional[BaseStore],
    ]:
        """Resolve per-call run options against the graph's own defaults.

        Returns:
            A tuple ``(debug, stream_modes, output_keys, interrupt_before,
            interrupt_after, checkpointer, store)``. ``stream_modes`` is always a
            set, even if a single mode was requested.

        Raises:
            ValueError: if ``config["recursion_limit"]`` is below 1, or if a
                checkpointer is in use but the config has no ``configurable`` section.
        """
        if config["recursion_limit"] < 1:
            raise ValueError("recursion_limit must be at least 1")
        debug = debug if debug is not None else self.debug
        if output_keys is None:
            output_keys = self.stream_channels_asis
        else:
            validate_keys(output_keys, self.channels)
        # `or`: both None and an empty sequence fall back to the graph defaults
        interrupt_before = interrupt_before or self.interrupt_before_nodes
        interrupt_after = interrupt_after or self.interrupt_after_nodes
        stream_mode = stream_mode if stream_mode is not None else self.stream_mode
        if not isinstance(stream_mode, list):
            stream_mode = [stream_mode]
        if CONFIG_KEY_TASK_ID in config.get(CONF, {}):
            # if being called as a node in another graph, always use values mode
            stream_mode = ["values"]
        # checkpointer precedence: explicitly disabled (False) > one injected via
        # the config > the graph's own attribute
        if self.checkpointer is False:
            checkpointer: Optional[BaseCheckpointSaver] = None
        elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}):
            checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER]
        else:
            checkpointer = self.checkpointer
        if checkpointer and not config.get(CONF):
            raise ValueError(
                f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}"
            )
        # a store injected through the config overrides the graph's own store
        if CONFIG_KEY_STORE in config.get(CONF, {}):
            store: Optional[BaseStore] = config[CONF][CONFIG_KEY_STORE]
        else:
            store = self.store
        return (
            debug,
            set(stream_mode),
            output_keys,
            interrupt_before,
            interrupt_after,
            checkpointer,
            store,
        )

    def stream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        *,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        output_keys: Optional[Union[str, Sequence[str]]] = None,
        interrupt_before: Optional[Union[All, Sequence[str]]] = None,
        interrupt_after: Optional[Union[All, Sequence[str]]] = None,
        debug: Optional[bool] = None,
        subgraphs: bool = False,
    ) -> Iterator[Union[dict[str, Any], Any]]:
        """Stream graph steps for a single input.

        Args:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode (or list of modes) to stream output in, defaults
                to self.stream_mode. Options are 'values', 'updates', 'debug',
                'messages' and 'custom'. When the graph runs as a node of another
                graph, the mode is forced to 'values'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated
                        values as value.
                    debug: Emit debug events for each step.
                    messages: Attach a StreamMessagesHandler to the run callbacks
                        and emit whatever it puts on the output stream.
                    custom: Emit values passed to the stream writer installed in
                        the config under CONFIG_KEY_STREAM_WRITER.
            output_keys: The keys to stream, defaults to the graph's
                stream_channels_asis.
            interrupt_before: Nodes to interrupt before; when not given (or empty),
                defaults to the graph's interrupt_before_nodes.
            interrupt_after: Nodes to interrupt after; when not given (or empty),
                defaults to the graph's interrupt_after_nodes.
            debug: Whether to print debug information during execution, defaults
                to self.debug.
            subgraphs: Whether to stream subgraphs, defaults to False.

        Yields:
            The output of each step in the graph. Each item is ``payload``, or
            ``(mode, payload)`` when stream_mode was passed as a list, or
            ``(namespace, payload)`` when subgraphs=True, or
            ``(namespace, mode, payload)`` when both apply.

        Raises:
            GraphRecursionError: if the run stops with status "out_of_steps",
                i.e. the recursion limit was reached.

        Examples:
            Using different stream modes with a graph:
            ```pycon
            >>> import operator
            >>> from typing_extensions import Annotated, TypedDict
            >>> from langgraph.graph import StateGraph
            >>> from langgraph.constants import START
            ...
            >>> class State(TypedDict):
            ...     alist: Annotated[list, operator.add]
            ...     another_list: Annotated[list, operator.add]
            ...
            >>> builder = StateGraph(State)
            >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
            >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
            >>> builder.add_edge("a", "b")
            >>> builder.add_edge(START, "a")
            >>> graph = builder.compile()
            ```
            With stream_mode="values":

            ```pycon
            >>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
            ...     print(event)
            {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
            {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
            {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
            ```
            With stream_mode="updates":

            ```pycon
            >>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
            ...     print(event)
            {'a': {'another_list': ['hi']}}
            {'b': {'alist': ['there']}}
            ```
            With stream_mode="debug":

            ```pycon
            >>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
            ...     print(event)
            {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
            {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
            {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
            {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
            ```
        """

        # every producer (loop, messages handler, custom writer) pushes
        # (namespace, mode, payload) triples onto this queue
        stream = SyncQueue()

        def output() -> Iterator:
            # Drain whatever is queued right now without blocking. The shape of
            # each yielded item depends on the *caller's* stream_mode argument
            # (list or not) and on the subgraphs flag.
            while True:
                try:
                    ns, mode, payload = stream.get(block=False)
                except queue.Empty:
                    break
                if subgraphs and isinstance(stream_mode, list):
                    yield (ns, mode, payload)
                elif isinstance(stream_mode, list):
                    yield (mode, payload)
                elif subgraphs:
                    yield (ns, payload)
                else:
                    yield payload

        config = ensure_config(self.config, config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            None,
            input,
            name=config.get("run_name", self.get_name()),
            run_id=config.get("run_id"),
        )
        try:
            # assign defaults
            (
                debug,
                stream_modes,
                output_keys,
                interrupt_before_,
                interrupt_after_,
                checkpointer,
                store,
            ) = self._defaults(
                config,
                stream_mode=stream_mode,
                output_keys=output_keys,
                interrupt_before=interrupt_before,
                interrupt_after=interrupt_after,
                debug=debug,
            )
            # set up messages stream mode
            if "messages" in stream_modes:
                run_manager.inheritable_handlers.append(
                    StreamMessagesHandler(stream.put)
                )
            # set up custom stream mode
            if "custom" in stream_modes:
                config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put(
                    ((), "custom", c)
                )
            with SyncPregelLoop(
                input,
                stream=StreamProtocol(stream.put, stream_modes),
                config=config,
                store=store,
                checkpointer=checkpointer,
                nodes=self.nodes,
                specs=self.channels,
                output_keys=output_keys,
                stream_keys=self.stream_channels_asis,
                interrupt_before=interrupt_before_,
                interrupt_after=interrupt_after_,
                manager=run_manager,
                debug=debug,
            ) as loop:
                # create runner
                runner = PregelRunner(
                    submit=loop.submit,
                    put_writes=loop.put_writes,
                    schedule_task=loop.accept_push,
                    node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
                )
                # enable subgraph streaming
                if subgraphs:
                    loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
                # enable concurrent streaming
                if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
                    # we are careful to have a single waiter live at any one time
                    # because on exit we increment semaphore count by exactly 1
                    waiter: Optional[concurrent.futures.Future] = None
                    # because sync futures cannot be cancelled, we instead
                    # release the stream semaphore on exit, which will cause
                    # a pending waiter to return immediately
                    loop.stack.callback(stream._count.release)

                    def get_waiter() -> concurrent.futures.Future[None]:
                        # reuse the pending waiter; only submit a new one once
                        # the previous one has completed
                        nonlocal waiter
                        if waiter is None or waiter.done():
                            waiter = loop.submit(stream.wait)
                            return waiter
                        else:
                            return waiter
                else:
                    get_waiter = None  # type: ignore[assignment]
                # Similarly to Bulk Synchronous Parallel / Pregel model
                # computation proceeds in steps, while there are channel updates
                # channel updates from step N are only visible in step N+1
                # channels are guaranteed to be immutable for the duration of the step,
                # with channel updates applied only at the transition between steps
                while loop.tick(input_keys=self.input_channels):
                    for _ in runner.tick(
                        loop.tasks.values(),
                        timeout=self.step_timeout,
                        retry_policy=self.retry_policy,
                        get_waiter=get_waiter,
                    ):
                        # emit output
                        yield from output()
            # emit output (anything queued while the loop context was closing)
            yield from output()
            # handle exit
            if loop.status == "out_of_steps":
                msg = create_error_message(
                    message=(
                        f"Recursion limit of {config['recursion_limit']} reached "
                        "without hitting a stop condition. You can increase the "
                        "limit by setting the `recursion_limit` config key."
                    ),
                    error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
                )
                raise GraphRecursionError(msg)
            # set final channel values as run output
            run_manager.on_chain_end(loop.output)
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise

    async def astream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        *,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        output_keys: Optional[Union[str, Sequence[str]]] = None,
        interrupt_before: Optional[Union[All, Sequence[str]]] = None,
        interrupt_after: Optional[Union[All, Sequence[str]]] = None,
        debug: Optional[bool] = None,
        subgraphs: bool = False,
    ) -> AsyncIterator[Union[dict[str, Any], Any]]:
        """Stream graph steps for a single input.

        Args:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode (or list of modes) to stream output in, defaults
                to self.stream_mode. Options are 'values', 'updates', 'debug',
                'messages' and 'custom'. When the graph runs as a node of another
                graph, the mode is forced to 'values'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated
                        values as value.
                    debug: Emit debug events for each step.
                    messages: Attach a StreamMessagesHandler to the run callbacks
                        and emit whatever it puts on the output stream.
                    custom: Emit values passed to the stream writer installed in
                        the config under CONFIG_KEY_STREAM_WRITER.
            output_keys: The keys to stream, defaults to the graph's
                stream_channels_asis.
            interrupt_before: Nodes to interrupt before; when not given (or empty),
                defaults to the graph's interrupt_before_nodes.
            interrupt_after: Nodes to interrupt after; when not given (or empty),
                defaults to the graph's interrupt_after_nodes.
            debug: Whether to print debug information during execution, defaults
                to self.debug.
            subgraphs: Whether to stream subgraphs, defaults to False.

        Yields:
            The output of each step in the graph. Each item is ``payload``, or
            ``(mode, payload)`` when stream_mode was passed as a list, or
            ``(namespace, payload)`` when subgraphs=True, or
            ``(namespace, mode, payload)`` when both apply.

        Raises:
            GraphRecursionError: if the run stops with status "out_of_steps",
                i.e. the recursion limit was reached.

        Examples:
            Using different stream modes with a graph:
            ```pycon
            >>> import operator
            >>> from typing_extensions import Annotated, TypedDict
            >>> from langgraph.graph import StateGraph
            >>> from langgraph.constants import START
            ...
            >>> class State(TypedDict):
            ...     alist: Annotated[list, operator.add]
            ...     another_list: Annotated[list, operator.add]
            ...
            >>> builder = StateGraph(State)
            >>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
            >>> builder.add_node("b", lambda _state: {"alist": ["there"]})
            >>> builder.add_edge("a", "b")
            >>> builder.add_edge(START, "a")
            >>> graph = builder.compile()
            ```
            With stream_mode="values":

            ```pycon
            >>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
            ...     print(event)
            {'alist': ['Ex for stream_mode="values"'], 'another_list': []}
            {'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
            {'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
            ```
            With stream_mode="updates":

            ```pycon
            >>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
            ...     print(event)
            {'a': {'another_list': ['hi']}}
            {'b': {'alist': ['there']}}
            ```
            With stream_mode="debug":

            ```pycon
            >>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
            ...     print(event)
            {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
            {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
            {'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
            {'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
            ```
        """

        stream = AsyncQueue()
        aioloop = asyncio.get_running_loop()
        # enqueue via call_soon_threadsafe so producers running outside this
        # event loop's thread can push chunks onto the queue
        stream_put = cast(
            Callable[[StreamChunk], None],
            partial(aioloop.call_soon_threadsafe, stream.put_nowait),
        )

        def output() -> Iterator:
            # Drain whatever is queued right now without awaiting. The shape of
            # each yielded item depends on the *caller's* stream_mode argument
            # (list or not) and on the subgraphs flag.
            while True:
                try:
                    ns, mode, payload = stream.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if subgraphs and isinstance(stream_mode, list):
                    yield (ns, mode, payload)
                elif isinstance(stream_mode, list):
                    yield (mode, payload)
                elif subgraphs:
                    yield (ns, payload)
                else:
                    yield payload

        config = ensure_config(self.config, config)
        callback_manager = get_async_callback_manager_for_config(config)
        run_manager = await callback_manager.on_chain_start(
            None,
            input,
            name=config.get("run_name", self.get_name()),
            run_id=config.get("run_id"),
        )
        # if running from astream_log() run each proc with streaming
        do_stream = next(
            (
                cast(_StreamingCallbackHandler, h)
                for h in run_manager.handlers
                if isinstance(h, _StreamingCallbackHandler)
            ),
            None,
        )
        try:
            # assign defaults
            (
                debug,
                stream_modes,
                output_keys,
                interrupt_before_,
                interrupt_after_,
                checkpointer,
                store,
            ) = self._defaults(
                config,
                stream_mode=stream_mode,
                output_keys=output_keys,
                interrupt_before=interrupt_before,
                interrupt_after=interrupt_after,
                debug=debug,
            )
            # set up messages stream mode
            if "messages" in stream_modes:
                run_manager.inheritable_handlers.append(
                    StreamMessagesHandler(stream_put)
                )
            # set up custom stream mode
            if "custom" in stream_modes:
                config[CONF][CONFIG_KEY_STREAM_WRITER] = (
                    lambda c: aioloop.call_soon_threadsafe(
                        stream.put_nowait, ((), "custom", c)
                    )
                )
            async with AsyncPregelLoop(
                input,
                stream=StreamProtocol(stream.put_nowait, stream_modes),
                config=config,
                store=store,
                checkpointer=checkpointer,
                nodes=self.nodes,
                specs=self.channels,
                output_keys=output_keys,
                stream_keys=self.stream_channels_asis,
                interrupt_before=interrupt_before_,
                interrupt_after=interrupt_after_,
                manager=run_manager,
                debug=debug,
            ) as loop:
                # create runner
                runner = PregelRunner(
                    submit=loop.submit,
                    put_writes=loop.put_writes,
                    schedule_task=loop.accept_push,
                    use_astream=do_stream is not None,
                    node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
                )
                # enable subgraph streaming
                if subgraphs:
                    loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
                        stream_put, stream_modes
                    )
                # enable concurrent streaming
                if subgraphs or "messages" in stream_modes or "custom" in stream_modes:

                    def get_waiter() -> asyncio.Task[None]:
                        return aioloop.create_task(stream.wait())

                else:
                    get_waiter = None  # type: ignore[assignment]
                # Similarly to Bulk Synchronous Parallel / Pregel model
                # computation proceeds in steps, while there are channel updates
                # channel updates from step N are only visible in step N+1
                # channels are guaranteed to be immutable for the duration of the step,
                # with channel updates applied only at the transition between steps
                while loop.tick(input_keys=self.input_channels):
                    async for _ in runner.atick(
                        loop.tasks.values(),
                        timeout=self.step_timeout,
                        retry_policy=self.retry_policy,
                        get_waiter=get_waiter,
                    ):
                        # emit output
                        for o in output():
                            yield o
            # emit output (anything queued while the loop context was closing)
            for o in output():
                yield o
            # handle exit
            if loop.status == "out_of_steps":
                msg = create_error_message(
                    message=(
                        f"Recursion limit of {config['recursion_limit']} reached "
                        "without hitting a stop condition.
You can increase the " "limit by setting the `recursion_limit` config key." ), error_code=ErrorCode.GRAPH_RECURSION_LIMIT, ) raise GraphRecursionError(msg) # set final channel values as run output await run_manager.on_chain_end(loop.output) except BaseException as e: await asyncio.shield(run_manager.on_chain_error(e)) raise def invoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = "values", output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Run the graph with a single input and config. Args: input: The input data for the graph. It can be a dictionary or any other type. config: Optional. The configuration for the graph run. stream_mode: Optional[str]. The stream mode for the graph run. Default is "values". output_keys: Optional. The output keys to retrieve from the graph run. interrupt_before: Optional. The nodes to interrupt the graph run before. interrupt_after: Optional. The nodes to interrupt the graph run after. debug: Optional. Enable debug mode for the graph run. **kwargs: Additional keyword arguments to pass to the graph run. Returns: The output of the graph run. If stream_mode is "values", it returns the latest output. If stream_mode is not "values", it returns a list of output chunks. 
""" output_keys = output_keys if output_keys is not None else self.output_channels if stream_mode == "values": latest: Union[dict[str, Any], Any] = None else: chunks = [] for chunk in self.stream( input, config, stream_mode=stream_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, **kwargs, ): if stream_mode == "values": latest = chunk else: chunks.append(chunk) if stream_mode == "values": return latest else: return chunks async def ainvoke( self, input: Union[dict[str, Any], Any], config: Optional[RunnableConfig] = None, *, stream_mode: StreamMode = "values", output_keys: Optional[Union[str, Sequence[str]]] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, debug: Optional[bool] = None, **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Asynchronously invoke the graph on a single input. Args: input: The input data for the computation. It can be a dictionary or any other type. config: Optional. The configuration for the computation. stream_mode: Optional. The stream mode for the computation. Default is "values". output_keys: Optional. The output keys to include in the result. Default is None. interrupt_before: Optional. The nodes to interrupt before. Default is None. interrupt_after: Optional. The nodes to interrupt after. Default is None. debug: Optional. Whether to enable debug mode. Default is None. **kwargs: Additional keyword arguments. Returns: The result of the computation. If stream_mode is "values", it returns the latest value. If stream_mode is "chunks", it returns a list of chunks. 
""" output_keys = output_keys if output_keys is not None else self.output_channels if stream_mode == "values": latest: Union[dict[str, Any], Any] = None else: chunks = [] async for chunk in self.astream( input, config, stream_mode=stream_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, **kwargs, ): if stream_mode == "values": latest = chunk else: chunks.append(chunk) if stream_mode == "values": return latest else: return chunks ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/types.py`: ```py """Re-export types moved to langgraph.types""" from langgraph.types import ( All, CachePolicy, PregelExecutableTask, PregelTask, RetryPolicy, StateSnapshot, StreamMode, StreamWriter, default_retry_on, ) __all__ = [ "All", "CachePolicy", "PregelExecutableTask", "PregelTask", "RetryPolicy", "StateSnapshot", "StreamMode", "StreamWriter", "default_retry_on", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/retry.py`: ```py import asyncio import logging import random import sys import time from dataclasses import replace from functools import partial from typing import Any, Callable, Optional, Sequence from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_RESUMING, CONFIG_KEY_SEND, NS_SEP, ) from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphBubbleUp, ParentCommand from langgraph.types import Command, PregelExecutableTask, RetryPolicy from langgraph.utils.config import patch_configurable logger = logging.getLogger(__name__) SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11) def run_with_retry( task: PregelExecutableTask, retry_policy: Optional[RetryPolicy], writer: Optional[ Callable[[PregelExecutableTask, Sequence[tuple[str, Any]]], None] ] = None, ) -> None: """Run a task with retries.""" retry_policy = task.retry_policy or retry_policy interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 config = task.config 
if writer is not None: config = patch_configurable(config, {CONFIG_KEY_SEND: partial(writer, task)}) while True: try: # clear any writes from previous attempts task.writes.clear() # run the task task.proc.invoke(task.input, config) # if successful, end break except ParentCommand as exc: ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS] cmd = exc.args[0] if cmd.graph == ns: # this command is for the current graph, handle it for w in task.writers: w.invoke(cmd, config) break elif cmd.graph == Command.PARENT: # this command is for the parent graph, assign it to the parent parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1]) exc.args = (replace(cmd, graph=parent_ns),) # bubble up raise except GraphBubbleUp: # if interrupted, end raise except Exception as exc: if SUPPORTS_EXC_NOTES: exc.add_note(f"During task with name '{task.name}' and id '{task.id}'") if retry_policy is None: raise # increment attempts attempts += 1 # check if we should retry if isinstance(retry_policy.retry_on, Sequence): if not isinstance(exc, tuple(retry_policy.retry_on)): raise elif isinstance(retry_policy.retry_on, type) and issubclass( retry_policy.retry_on, Exception ): if not isinstance(exc, retry_policy.retry_on): raise elif callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise else: raise TypeError( "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" ) # check if we should give up if attempts >= retry_policy.max_attempts: raise # sleep before retrying interval = min( retry_policy.max_interval, interval * retry_policy.backoff_factor, ) time.sleep( interval + random.uniform(0, 1) if retry_policy.jitter else interval ) # log the retry logger.info( f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) # signal subgraphs to resume (if available) config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) # clear checkpoint_ns seen (for 
subgraph detection) if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) finally: # clear checkpoint_ns seen (for subgraph detection) if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) async def arun_with_retry( task: PregelExecutableTask, retry_policy: Optional[RetryPolicy], stream: bool = False, writer: Optional[ Callable[[PregelExecutableTask, Sequence[tuple[str, Any]]], None] ] = None, ) -> None: """Run a task asynchronously with retries.""" retry_policy = task.retry_policy or retry_policy interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 config = task.config if writer is not None: config = patch_configurable(config, {CONFIG_KEY_SEND: partial(writer, task)}) while True: try: # clear any writes from previous attempts task.writes.clear() # run the task if stream: async for _ in task.proc.astream(task.input, config): pass else: await task.proc.ainvoke(task.input, config) # if successful, end break except ParentCommand as exc: ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS] cmd = exc.args[0] if cmd.graph == ns: # this command is for the current graph, handle it for w in task.writers: w.invoke(cmd, config) break elif cmd.graph == Command.PARENT: # this command is for the parent graph, assign it to the parent parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1]) exc.args = (replace(cmd, graph=parent_ns),) # bubble up raise except GraphBubbleUp: # if interrupted, end raise except Exception as exc: if SUPPORTS_EXC_NOTES: exc.add_note(f"During task with name '{task.name}' and id '{task.id}'") if retry_policy is None: raise # increment attempts attempts += 1 # check if we should retry if isinstance(retry_policy.retry_on, Sequence): if not isinstance(exc, tuple(retry_policy.retry_on)): raise elif isinstance(retry_policy.retry_on, type) and issubclass( retry_policy.retry_on, Exception ): if not isinstance(exc, retry_policy.retry_on): raise elif 
callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise else: raise TypeError( "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" ) # check if we should give up if attempts >= retry_policy.max_attempts: raise # sleep before retrying interval = min( retry_policy.max_interval, interval * retry_policy.backoff_factor, ) await asyncio.sleep( interval + random.uniform(0, 1) if retry_policy.jitter else interval ) # log the retry logger.info( f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) # signal subgraphs to resume (if available) config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) # clear checkpoint_ns seen (for subgraph detection) if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) finally: # clear checkpoint_ns seen (for subgraph detection) if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/validate.py`: ```py from typing import Any, Mapping, Optional, Sequence, Union from langgraph.channels.base import BaseChannel from langgraph.constants import RESERVED from langgraph.pregel.read import PregelNode from langgraph.types import All def validate_graph( nodes: Mapping[str, PregelNode], channels: dict[str, BaseChannel], input_channels: Union[str, Sequence[str]], output_channels: Union[str, Sequence[str]], stream_channels: Optional[Union[str, Sequence[str]]], interrupt_after_nodes: Union[All, Sequence[str]], interrupt_before_nodes: Union[All, Sequence[str]], ) -> None: for chan in channels: if chan in RESERVED: raise ValueError(f"Channel names {chan} are reserved") subscribed_channels = set[str]() for name, node in nodes.items(): if name in RESERVED: raise ValueError(f"Node names {RESERVED} are reserved") if 
isinstance(node, PregelNode): subscribed_channels.update(node.triggers) else: raise TypeError( f"Invalid node type {type(node)}, expected Channel.subscribe_to()" ) for chan in subscribed_channels: if chan not in channels: raise ValueError(f"Subscribed channel '{chan}' not in 'channels'") if isinstance(input_channels, str): if input_channels not in channels: raise ValueError(f"Input channel '{input_channels}' not in 'channels'") if input_channels not in subscribed_channels: raise ValueError( f"Input channel {input_channels} is not subscribed to by any node" ) else: for chan in input_channels: if chan not in channels: raise ValueError(f"Input channel '{chan}' not in 'channels'") if all(chan not in subscribed_channels for chan in input_channels): raise ValueError( f"None of the input channels {input_channels} are subscribed to by any node" ) all_output_channels = set[str]() if isinstance(output_channels, str): all_output_channels.add(output_channels) else: all_output_channels.update(output_channels) if isinstance(stream_channels, str): all_output_channels.add(stream_channels) elif stream_channels is not None: all_output_channels.update(stream_channels) for chan in all_output_channels: if chan not in channels: raise ValueError(f"Output channel '{chan}' not in 'channels'") if interrupt_after_nodes != "*": for n in interrupt_after_nodes: if n not in nodes: raise ValueError(f"Node {n} not in nodes") if interrupt_before_nodes != "*": for n in interrupt_before_nodes: if n not in nodes: raise ValueError(f"Node {n} not in nodes") def validate_keys( keys: Optional[Union[str, Sequence[str]]], channels: Mapping[str, Any], ) -> None: if isinstance(keys, str): if keys not in channels: raise ValueError(f"Key {keys} not in channels") elif keys is not None: for chan in keys: if chan not in channels: raise ValueError(f"Key {chan} not in channels") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/utils.py`: ```py from typing import Optional from 
langchain_core.runnables import RunnableLambda, RunnableSequence from langchain_core.runnables.utils import get_function_nonlocals from langgraph.checkpoint.base import ChannelVersions from langgraph.pregel.protocol import PregelProtocol from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq def get_new_channel_versions( previous_versions: ChannelVersions, current_versions: ChannelVersions ) -> ChannelVersions: """Get subset of current_versions that are newer than previous_versions.""" if previous_versions: version_type = type(next(iter(current_versions.values()), None)) null_version = version_type() # type: ignore[misc] new_versions = { k: v for k, v in current_versions.items() if v > previous_versions.get(k, null_version) # type: ignore[operator] } else: new_versions = current_versions return new_versions def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]: from langgraph.pregel import Pregel candidates: list[Runnable] = [candidate] for c in candidates: if ( isinstance(c, PregelProtocol) # subgraphs that disabled checkpointing are not considered and (not isinstance(c, Pregel) or c.checkpointer is not False) ): return c elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq): candidates.extend(c.steps) elif isinstance(c, RunnableLambda): candidates.extend(c.deps) elif isinstance(c, RunnableCallable): if c.func is not None: candidates.extend( nl.__self__ if hasattr(nl, "__self__") else nl for nl in get_function_nonlocals(c.func) ) elif c.afunc is not None: candidates.extend( nl.__self__ if hasattr(nl, "__self__") else nl for nl in get_function_nonlocals(c.afunc) ) return None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/debug.py`: ```py from collections import defaultdict from dataclasses import asdict from datetime import datetime, timezone from pprint import pformat from typing import ( Any, Iterable, Iterator, Literal, Mapping, Optional, Sequence, TypedDict, Union, ) from uuid 
import UUID from langchain_core.runnables.config import RunnableConfig from langchain_core.utils.input import get_bolded_text, get_colored_text from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, PendingWrite from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_NS, ERROR, INTERRUPT, NS_END, NS_SEP, TAG_HIDDEN, ) from langgraph.pregel.io import read_channels from langgraph.pregel.utils import find_subgraph_pregel from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot from langgraph.utils.config import patch_checkpoint_map class TaskPayload(TypedDict): id: str name: str input: Any triggers: list[str] class TaskResultPayload(TypedDict): id: str name: str error: Optional[str] interrupts: list[dict] result: list[tuple[str, Any]] class CheckpointTask(TypedDict): id: str name: str error: Optional[str] interrupts: list[dict] state: Optional[RunnableConfig] class CheckpointPayload(TypedDict): config: Optional[RunnableConfig] metadata: CheckpointMetadata values: dict[str, Any] next: list[str] parent_config: Optional[RunnableConfig] tasks: list[CheckpointTask] class DebugOutputBase(TypedDict): timestamp: str step: int class DebugOutputTask(DebugOutputBase): type: Literal["task"] payload: TaskPayload class DebugOutputTaskResult(DebugOutputBase): type: Literal["task_result"] payload: TaskResultPayload class DebugOutputCheckpoint(DebugOutputBase): type: Literal["checkpoint"] payload: CheckpointPayload DebugOutput = Union[DebugOutputTask, DebugOutputTaskResult, DebugOutputCheckpoint] TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8") def map_debug_tasks( step: int, tasks: Iterable[PregelExecutableTask] ) -> Iterator[DebugOutputTask]: """Produce "task" events for stream_mode=debug.""" ts = datetime.now(timezone.utc).isoformat() for task in tasks: if task.config is not None and TAG_HIDDEN in task.config.get("tags", []): continue yield { "type": "task", "timestamp": ts, "step": 
step, "payload": { "id": task.id, "name": task.name, "input": task.input, "triggers": task.triggers, }, } def map_debug_task_results( step: int, task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]], stream_keys: Union[str, Sequence[str]], ) -> Iterator[DebugOutputTaskResult]: """Produce "task_result" events for stream_mode=debug.""" stream_channels_list = ( [stream_keys] if isinstance(stream_keys, str) else stream_keys ) task, writes = task_tup yield { "type": "task_result", "timestamp": datetime.now(timezone.utc).isoformat(), "step": step, "payload": { "id": task.id, "name": task.name, "error": next((w[1] for w in writes if w[0] == ERROR), None), "result": [w for w in writes if w[0] in stream_channels_list], "interrupts": [asdict(w[1]) for w in writes if w[0] == INTERRUPT], }, } def map_debug_checkpoint( step: int, config: RunnableConfig, channels: Mapping[str, BaseChannel], stream_channels: Union[str, Sequence[str]], metadata: CheckpointMetadata, checkpoint: Checkpoint, tasks: Iterable[PregelExecutableTask], pending_writes: list[PendingWrite], parent_config: Optional[RunnableConfig], output_keys: Union[str, Sequence[str]], ) -> Iterator[DebugOutputCheckpoint]: """Produce "checkpoint" events for stream_mode=debug.""" parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {} for task in tasks: if not find_subgraph_pregel(task.proc): continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" # set config as signal that subgraph checkpoints exist task_states[task.id] = { CONF: { "thread_id": config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } yield { "type": "checkpoint", "timestamp": checkpoint["ts"], "step": step, "payload": { "config": patch_checkpoint_map(config, metadata), "parent_config": patch_checkpoint_map(parent_config, metadata), "values": read_channels(channels, 
stream_channels), "metadata": metadata, "next": [t.name for t in tasks], "tasks": [ { "id": t.id, "name": t.name, "error": t.error, "state": t.state, } if t.error else { "id": t.id, "name": t.name, "result": t.result, "interrupts": tuple(asdict(i) for i in t.interrupts), "state": t.state, } if t.result else { "id": t.id, "name": t.name, "interrupts": tuple(asdict(i) for i in t.interrupts), "state": t.state, } for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys) ], }, } def print_step_tasks(step: int, next_tasks: list[PregelExecutableTask]) -> None: n_tasks = len(next_tasks) print( f"{get_colored_text(f'[{step}:tasks]', color='blue')} " + get_bolded_text( f"Starting {n_tasks} task{'s' if n_tasks != 1 else ''} for step {step}:\n" ) + "\n".join( f"- {get_colored_text(task.name, 'green')} -> {pformat(task.input)}" for task in next_tasks ) ) def print_step_writes( step: int, writes: Sequence[tuple[str, Any]], whitelist: Sequence[str] ) -> None: by_channel: dict[str, list[Any]] = defaultdict(list) for channel, value in writes: if channel in whitelist: by_channel[channel].append(value) print( f"{get_colored_text(f'[{step}:writes]', color='blue')} " + get_bolded_text( f"Finished step {step} with writes to {len(by_channel)} channel{'s' if len(by_channel) != 1 else ''}:\n" ) + "\n".join( f"- {get_colored_text(name, 'yellow')} -> {', '.join(pformat(v) for v in vals)}" for name, vals in by_channel.items() ) ) def print_step_checkpoint( metadata: CheckpointMetadata, channels: Mapping[str, BaseChannel], whitelist: Sequence[str], ) -> None: step = metadata["step"] print( f"{get_colored_text(f'[{step}:checkpoint]', color='blue')} " + get_bolded_text(f"State at the end of step {step}:\n") + pformat(read_channels(channels, whitelist), depth=3) ) def tasks_w_writes( tasks: Iterable[Union[PregelTask, PregelExecutableTask]], pending_writes: Optional[list[PendingWrite]], states: Optional[dict[str, Union[RunnableConfig, StateSnapshot]]], output_keys: Union[str, 
Sequence[str]], ) -> tuple[PregelTask, ...]: """Apply writes / subgraph states to tasks to be returned in a StateSnapshot.""" pending_writes = pending_writes or [] return tuple( PregelTask( task.id, task.name, task.path, next( ( exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR ), None, ), tuple( v for tid, n, v in pending_writes if tid == task.id and n == INTERRUPT ), states.get(task.id) if states else None, ( next( ( val for tid, chan, val in pending_writes if tid == task.id and chan == output_keys ), None, ) if isinstance(output_keys, str) else { chan: val for tid, chan, val in pending_writes if tid == task.id and ( chan == output_keys if isinstance(output_keys, str) else chan in output_keys ) } ) if any( w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes ) else None, ) for task in tasks ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/messages.py`: ```py from typing import ( Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast, ) from uuid import UUID, uuid4 from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGenerationChunk, LLMResult from langchain_core.tracers._streaming import T, _StreamingCallbackHandler from langgraph.constants import NS_SEP, TAG_HIDDEN, TAG_NOSTREAM from langgraph.types import StreamChunk Meta = tuple[tuple[str, ...], dict[str, Any]] class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler): """A callback handler that implements stream_mode=messages. 
Collects messages from (1) chat model stream events and (2) node outputs.""" run_inline = True """We want this callback to run in the main thread, to avoid order/locking issues.""" def __init__(self, stream: Callable[[StreamChunk], None]): self.stream = stream self.metadata: dict[UUID, Meta] = {} self.seen: set[Union[int, str]] = set() def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None: if dedupe and message.id in self.seen: return else: if message.id is None: message.id = str(uuid4()) self.seen.add(message.id) self.stream((meta[0], "messages", (message, meta[1]))) def tap_output_aiter( self, run_id: UUID, output: AsyncIterator[T] ) -> AsyncIterator[T]: return output def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]: return output def on_chat_model_start( self, serialized: dict[str, Any], messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[list[str]] = None, metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: if metadata and (not tags or TAG_NOSTREAM not in tags): self.metadata[run_id] = ( tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)), metadata, ) def on_llm_new_token( self, token: str, *, chunk: Optional[ChatGenerationChunk] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: if not isinstance(chunk, ChatGenerationChunk): return if meta := self.metadata.get(run_id): self._emit(meta, chunk.message) def on_llm_end( self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: self.metadata.pop(run_id, None) def on_llm_error( self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: self.metadata.pop(run_id, None) def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: 
Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: if ( metadata and kwargs.get("name") == metadata.get("langgraph_node") and (not tags or TAG_HIDDEN not in tags) ): self.metadata[run_id] = ( tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)), metadata, ) def on_chain_end( self, response: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: if meta := self.metadata.pop(run_id, None): if isinstance(response, BaseMessage): self._emit(meta, response, dedupe=True) elif isinstance(response, Sequence): for value in response: if isinstance(value, BaseMessage): self._emit(meta, value, dedupe=True) elif isinstance(response, dict): for value in response.values(): if isinstance(value, BaseMessage): self._emit(meta, value, dedupe=True) elif isinstance(value, Sequence): for item in value: if isinstance(item, BaseMessage): self._emit(meta, item, dedupe=True) elif hasattr(response, "__dir__") and callable(response.__dir__): for key in dir(response): try: value = getattr(response, key) if isinstance(value, BaseMessage): self._emit(meta, value, dedupe=True) elif isinstance(value, Sequence): for item in value: if isinstance(item, BaseMessage): self._emit(meta, item, dedupe=True) except AttributeError: pass def on_chain_error( self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: self.metadata.pop(run_id, None) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/loop.py`: ```py import asyncio import concurrent.futures from collections import defaultdict, deque from contextlib import AsyncExitStack, ExitStack from types import TracebackType from typing import ( Any, AsyncContextManager, Callable, ContextManager, Iterator, List, Literal, Mapping, Optional, Sequence, Type, TypeVar, Union, cast, ) from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager from langchain_core.runnables import RunnableConfig from typing_extensions import 
ParamSpec, Self from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, PendingWrite, copy_checkpoint, create_checkpoint, empty_checkpoint, ) from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_DEDUPE_TASKS, CONFIG_KEY_DELEGATE, CONFIG_KEY_ENSURE_LATEST, CONFIG_KEY_RESUMING, CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, EMPTY_SEQ, ERROR, INPUT, INTERRUPT, NS_SEP, NULL_TASK_ID, PUSH, RESUME, SCHEDULED, TAG_HIDDEN, ) from langgraph.errors import ( _SEEN_CHECKPOINT_NS, CheckpointNotLatest, EmptyInputError, GraphDelegate, GraphInterrupt, MultipleSubgraphsError, ) from langgraph.managed.base import ( ManagedValueMapping, ManagedValueSpec, WritableManagedValue, ) from langgraph.pregel.algo import ( GetNextVersion, PregelTaskWrites, apply_writes, increment, prepare_next_tasks, prepare_single_task, should_interrupt, ) from langgraph.pregel.debug import ( map_debug_checkpoint, map_debug_task_results, map_debug_tasks, print_step_checkpoint, print_step_tasks, print_step_writes, ) from langgraph.pregel.executor import ( AsyncBackgroundExecutor, BackgroundExecutor, Submit, ) from langgraph.pregel.io import ( map_command, map_input, map_output_updates, map_output_values, read_channels, single, ) from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.read import PregelNode from langgraph.pregel.utils import get_new_channel_versions from langgraph.store.base import BaseStore from langgraph.types import ( All, Command, LoopProtocol, PregelExecutableTask, StreamChunk, StreamProtocol, ) from langgraph.utils.config import patch_configurable V = TypeVar("V") P = ParamSpec("P") INPUT_DONE = object() INPUT_RESUMING = object() SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) def DuplexStream(*streams: StreamProtocol) -> StreamProtocol: def 
__call__(value: StreamChunk) -> None: for stream in streams: if value[1] in stream.modes: stream(value) return StreamProtocol(__call__, {mode for s in streams for mode in s.modes}) class PregelLoop(LoopProtocol): input: Optional[Any] checkpointer: Optional[BaseCheckpointSaver] nodes: Mapping[str, PregelNode] specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]] output_keys: Union[str, Sequence[str]] stream_keys: Union[str, Sequence[str]] skip_done_tasks: bool is_nested: bool manager: Union[None, AsyncParentRunManager, ParentRunManager] interrupt_after: Union[All, Sequence[str]] interrupt_before: Union[All, Sequence[str]] checkpointer_get_next_version: GetNextVersion checkpointer_put_writes: Optional[ Callable[[RunnableConfig, Sequence[tuple[str, Any]], str], Any] ] _checkpointer_put_after_previous: Optional[ Callable[ [ Optional[concurrent.futures.Future], RunnableConfig, Sequence[tuple[str, Any]], str, ChannelVersions, ], Any, ] ] submit: Submit channels: Mapping[str, BaseChannel] managed: ManagedValueMapping checkpoint: Checkpoint checkpoint_ns: tuple[str, ...] 
checkpoint_config: RunnableConfig checkpoint_metadata: CheckpointMetadata checkpoint_pending_writes: List[PendingWrite] checkpoint_previous_versions: dict[str, Union[str, float, int]] prev_checkpoint_config: Optional[RunnableConfig] status: Literal[ "pending", "done", "interrupt_before", "interrupt_after", "out_of_steps" ] tasks: dict[str, PregelExecutableTask] to_interrupt: list[PregelExecutableTask] output: Union[None, dict[str, Any], Any] = None # public def __init__( self, input: Optional[Any], *, stream: Optional[StreamProtocol], config: RunnableConfig, store: Optional[BaseStore], checkpointer: Optional[BaseCheckpointSaver], nodes: Mapping[str, PregelNode], specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], output_keys: Union[str, Sequence[str]], stream_keys: Union[str, Sequence[str]], interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ, interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ, manager: Union[None, AsyncParentRunManager, ParentRunManager] = None, check_subgraphs: bool = True, debug: bool = False, ) -> None: super().__init__( step=0, stop=0, config=config, stream=stream, store=store, ) self.input = input self.checkpointer = checkpointer self.nodes = nodes self.specs = specs self.output_keys = output_keys self.stream_keys = stream_keys self.interrupt_after = interrupt_after self.interrupt_before = interrupt_before self.manager = manager self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {}) self.skip_done_tasks = ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] or CONFIG_KEY_DEDUPE_TASKS in config[CONF] ) self.debug = debug if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]: self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM]) if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): self.config = patch_configurable( self.config, {CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None}, ) if check_subgraphs and self.is_nested and self.checkpointer is not None: if 
self.config[CONF][CONFIG_KEY_CHECKPOINT_NS] in _SEEN_CHECKPOINT_NS: raise MultipleSubgraphsError( "Multiple subgraphs called inside the same node\n\n" "Troubleshooting URL: https://python.langchain.com/docs" "/troubleshooting/errors/MULTIPLE_SUBGRAPHS/" ) else: _SEEN_CHECKPOINT_NS.add(self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]) if ( CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF] and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP] ): self.checkpoint_config = patch_configurable( self.config, { CONFIG_KEY_CHECKPOINT_ID: config[CONF][CONFIG_KEY_CHECKPOINT_MAP][ self.config[CONF][CONFIG_KEY_CHECKPOINT_NS] ] }, ) else: self.checkpoint_config = config self.checkpoint_ns = ( tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP)) if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) else () ) self.prev_checkpoint_config = None def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: """Put writes for a task, to be read by the next tick.""" if not writes: return # deduplicate writes to special channels, last write wins if all(w[0] in WRITES_IDX_MAP for w in writes): writes = list({w[0]: w for w in writes}.values()) # save writes for c, v in writes: if ( c in WRITES_IDX_MAP and ( idx := next( ( i for i, w in enumerate(self.checkpoint_pending_writes) if w[0] == task_id and w[1] == c ), None, ) ) is not None ): self.checkpoint_pending_writes[idx] = (task_id, c, v) else: self.checkpoint_pending_writes.append((task_id, c, v)) if self.checkpointer_put_writes is not None: self.submit( self.checkpointer_put_writes, { **self.checkpoint_config, CONF: { **self.checkpoint_config[CONF], CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get( CONFIG_KEY_CHECKPOINT_NS, "" ), CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"], }, }, writes, task_id, ) # output writes if hasattr(self, "tasks"): self._output_writes(task_id, writes) def accept_push( self, task: PregelExecutableTask, write_idx: int ) -> 
Optional[PregelExecutableTask]: """Accept a PUSH from a task, potentially returning a new task to start.""" # don't start if an earlier PUSH has already triggered an interrupt if self.to_interrupt: return # don't start if we should interrupt *after* the original task if should_interrupt(self.checkpoint, self.interrupt_after, [task]): self.to_interrupt.append(task) return if pushed := cast( Optional[PregelExecutableTask], prepare_single_task( (PUSH, task.path, write_idx, task.id), None, checkpoint=self.checkpoint, pending_writes=[(task.id, *w) for w in task.writes], processes=self.nodes, channels=self.channels, managed=self.managed, config=self.config, step=self.step, for_execution=True, store=self.store, checkpointer=self.checkpointer, manager=self.manager, ), ): # don't start if we should interrupt *before* the new task if should_interrupt(self.checkpoint, self.interrupt_before, [pushed]): self.to_interrupt.append(pushed) return # produce debug output self._emit("debug", map_debug_tasks, self.step, [pushed]) # debug flag if self.debug: print_step_tasks(self.step, [pushed]) # save the new task self.tasks[pushed.id] = pushed # match any pending writes to the new task if self.skip_done_tasks: self._match_writes({pushed.id: pushed}) # return the new task, to be started, if not run before if not pushed.writes: return pushed def tick( self, *, input_keys: Union[str, Sequence[str]], ) -> bool: """Execute a single iteration of the Pregel loop. 
Returns True if more iterations are needed.""" if self.status != "pending": raise RuntimeError("Cannot tick when status is no longer 'pending'") if self.input not in (INPUT_DONE, INPUT_RESUMING): self._first(input_keys=input_keys) elif self.to_interrupt: # if we need to interrupt, do so self.status = "interrupt_before" raise GraphInterrupt() elif all(task.writes for task in self.tasks.values()): writes = [w for t in self.tasks.values() for w in t.writes] # debug flag if self.debug: print_step_writes( self.step, writes, ( [self.stream_keys] if isinstance(self.stream_keys, str) else self.stream_keys ), ) # all tasks have finished mv_writes = apply_writes( self.checkpoint, self.channels, self.tasks.values(), self.checkpointer_get_next_version, ) # apply writes to managed values for key, values in mv_writes.items(): self._update_mv(key, values) # produce values output self._emit( "values", map_output_values, self.output_keys, writes, self.channels ) # clear pending writes self.checkpoint_pending_writes.clear() # "not skip_done_tasks" only applies to first tick after resuming self.skip_done_tasks = True # save checkpoint self._put_checkpoint( { "source": "loop", "writes": single( map_output_updates( self.output_keys, [(t, t.writes) for t in self.tasks.values()], ) ), } ) # after execution, check if we should interrupt if should_interrupt( self.checkpoint, self.interrupt_after, self.tasks.values() ): self.status = "interrupt_after" raise GraphInterrupt() else: return False # check if iteration limit is reached if self.step > self.stop: self.status = "out_of_steps" return False # apply NULL writes if null_writes := [ w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID ]: mv_writes = apply_writes( self.checkpoint, self.channels, [PregelTaskWrites((), INPUT, null_writes, [])], self.checkpointer_get_next_version, ) for key, values in mv_writes.items(): self._update_mv(key, values) # prepare next tasks self.tasks = prepare_next_tasks( self.checkpoint, 
self.checkpoint_pending_writes, self.nodes, self.channels, self.managed, self.config, self.step, for_execution=True, manager=self.manager, store=self.store, checkpointer=self.checkpointer, ) self.to_interrupt = [] # produce debug output if self._checkpointer_put_after_previous is not None: self._emit( "debug", map_debug_checkpoint, self.step - 1, # printing checkpoint for previous step self.checkpoint_config, self.channels, self.stream_keys, self.checkpoint_metadata, self.checkpoint, self.tasks.values(), self.checkpoint_pending_writes, self.prev_checkpoint_config, self.output_keys, ) # if no more tasks, we're done if not self.tasks: self.status = "done" return False # check if we should delegate (used by subgraphs in distributed mode) if self.config[CONF].get(CONFIG_KEY_DELEGATE): assert self.input is INPUT_RESUMING raise GraphDelegate( { "config": patch_configurable( self.config, {CONFIG_KEY_DELEGATE: False} ), "input": None, } ) # if there are pending writes from a previous loop, apply them if self.skip_done_tasks and self.checkpoint_pending_writes: self._match_writes(self.tasks) # if all tasks have finished, re-tick if all(task.writes for task in self.tasks.values()): return self.tick(input_keys=input_keys) # before execution, check if we should interrupt if should_interrupt( self.checkpoint, self.interrupt_before, self.tasks.values() ): self.status = "interrupt_before" raise GraphInterrupt() # produce debug output self._emit("debug", map_debug_tasks, self.step, self.tasks.values()) # debug flag if self.debug: print_step_tasks(self.step, list(self.tasks.values())) # print output for any tasks we applied previous writes to for task in self.tasks.values(): if task.writes: self._output_writes(task.id, task.writes, cached=True) return True # private def _match_writes(self, tasks: Mapping[str, PregelExecutableTask]) -> None: for tid, k, v in self.checkpoint_pending_writes: if k in (ERROR, INTERRUPT, RESUME): continue if task := tasks.get(tid): if k == SCHEDULED: if v 
== max( self.checkpoint["versions_seen"].get(INTERRUPT, {}).values(), default=None, ): self.tasks[tid] = task._replace(scheduled=True) else: task.writes.append((k, v)) def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: # resuming from previous checkpoint requires # - finding a previous checkpoint # - receiving None input (outer graph) or RESUMING flag (subgraph) configurable = self.config.get(CONF, {}) is_resuming = bool(self.checkpoint["channel_versions"]) and bool( configurable.get(CONFIG_KEY_RESUMING, self.input is None) ) # proceed past previous checkpoint if is_resuming: self.checkpoint["versions_seen"].setdefault(INTERRUPT, {}) for k in self.channels: if k in self.checkpoint["channel_versions"]: version = self.checkpoint["channel_versions"][k] self.checkpoint["versions_seen"][INTERRUPT][k] = version # produce values output self._emit( "values", map_output_values, self.output_keys, True, self.channels ) # map command to writes elif isinstance(self.input, Command): writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list) # group writes by task ID for tid, c, v in map_command(self.input, self.checkpoint_pending_writes): writes[tid].append((c, v)) if not writes: raise EmptyInputError("Received empty Command input") # save writes for tid, ws in writes.items(): self.put_writes(tid, ws) # map inputs to channel updates elif input_writes := deque(map_input(input_keys, self.input)): # TODO shouldn't these writes be passed to put_writes too? 
# check if we should delegate (used by subgraphs in distributed mode) if self.config[CONF].get(CONFIG_KEY_DELEGATE): raise GraphDelegate( { "config": patch_configurable( self.config, {CONFIG_KEY_DELEGATE: False} ), "input": self.input, } ) # discard any unfinished tasks from previous checkpoint discard_tasks = prepare_next_tasks( self.checkpoint, self.checkpoint_pending_writes, self.nodes, self.channels, self.managed, self.config, self.step, for_execution=True, store=None, checkpointer=None, manager=None, ) # apply input writes mv_writes = apply_writes( self.checkpoint, self.channels, [ *discard_tasks.values(), PregelTaskWrites((), INPUT, input_writes, []), ], self.checkpointer_get_next_version, ) assert not mv_writes, "Can't write to SharedValues in graph input" # save input checkpoint self._put_checkpoint({"source": "input", "writes": dict(input_writes)}) elif CONFIG_KEY_RESUMING not in configurable: raise EmptyInputError(f"Received no input for {input_keys}") # done with input self.input = INPUT_RESUMING if is_resuming else INPUT_DONE # update config if not self.is_nested: self.config = patch_configurable( self.config, {CONFIG_KEY_RESUMING: is_resuming} ) def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: for k, v in self.config["metadata"].items(): metadata.setdefault(k, v) # type: ignore # assign step and parents metadata["step"] = self.step metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {}) # debug flag if self.debug: print_step_checkpoint( metadata, self.channels, ( [self.stream_keys] if isinstance(self.stream_keys, str) else self.stream_keys ), ) # create new checkpoint self.checkpoint = create_checkpoint(self.checkpoint, self.channels, self.step) # bail if no checkpointer if self._checkpointer_put_after_previous is not None: self.checkpoint_metadata = metadata self.prev_checkpoint_config = ( self.checkpoint_config if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF] and 
self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] else None ) self.checkpoint_config = { **self.checkpoint_config, CONF: { **self.checkpoint_config[CONF], CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get( CONFIG_KEY_CHECKPOINT_NS, "" ), }, } channel_versions = self.checkpoint["channel_versions"].copy() new_versions = get_new_channel_versions( self.checkpoint_previous_versions, channel_versions ) self.checkpoint_previous_versions = channel_versions # save it, without blocking # if there's a previous checkpoint save in progress, wait for it # ensuring checkpointers receive checkpoints in order self._put_checkpoint_fut = self.submit( self._checkpointer_put_after_previous, getattr(self, "_put_checkpoint_fut", None), self.checkpoint_config, copy_checkpoint(self.checkpoint), self.checkpoint_metadata, new_versions, ) self.checkpoint_config = { **self.checkpoint_config, CONF: { **self.checkpoint_config[CONF], CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"], }, } # increment step self.step += 1 def _update_mv(self, key: str, values: Sequence[Any]) -> None: raise NotImplementedError def _suppress_interrupt( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested if suppress or exc_type is None: # save final output self.output = read_channels(self.channels, self.output_keys) if suppress: # emit one last "values" event, with pending writes applied if ( hasattr(self, "tasks") and self.checkpoint_pending_writes and any(task.writes for task in self.tasks.values()) ): mv_writes = apply_writes( self.checkpoint, self.channels, self.tasks.values(), self.checkpointer_get_next_version, ) for key, values in mv_writes.items(): self._update_mv(key, values) self._emit( "values", map_output_values, self.output_keys, [w for t in self.tasks.values() for w in t.writes], self.channels, ) # emit INTERRUPT event self._emit( "updates", 
lambda: iter([{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]), ) # suppress interrupt return True def _emit( self, mode: str, values: Callable[P, Iterator[Any]], *args: P.args, **kwargs: P.kwargs, ) -> None: if self.stream is None: return if mode not in self.stream.modes: return for v in values(*args, **kwargs): self.stream((self.checkpoint_ns, mode, v)) def _output_writes( self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False ) -> None: if task := self.tasks.get(task_id): if task.config is not None and TAG_HIDDEN in task.config.get( "tags", EMPTY_SEQ ): return if writes[0][0] != ERROR and writes[0][0] != INTERRUPT: self._emit( "updates", map_output_updates, self.output_keys, [(task, writes)], cached, ) if not cached: self._emit( "debug", map_debug_task_results, self.step, (task, writes), self.stream_keys, ) class SyncPregelLoop(PregelLoop, ContextManager): def __init__( self, input: Optional[Any], *, stream: Optional[StreamProtocol], config: RunnableConfig, store: Optional[BaseStore], checkpointer: Optional[BaseCheckpointSaver], nodes: Mapping[str, PregelNode], specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], manager: Union[None, AsyncParentRunManager, ParentRunManager] = None, interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ, interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ, output_keys: Union[str, Sequence[str]] = EMPTY_SEQ, stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ, check_subgraphs: bool = True, debug: bool = False, ) -> None: super().__init__( input, stream=stream, config=config, checkpointer=checkpointer, store=store, nodes=nodes, specs=specs, output_keys=output_keys, stream_keys=stream_keys, interrupt_after=interrupt_after, interrupt_before=interrupt_before, check_subgraphs=check_subgraphs, manager=manager, debug=debug, ) self.stack = ExitStack() if checkpointer: self.checkpointer_get_next_version = checkpointer.get_next_version self.checkpointer_put_writes = checkpointer.put_writes else: 
self.checkpointer_get_next_version = increment self._checkpointer_put_after_previous = None # type: ignore[assignment] self.checkpointer_put_writes = None def _checkpointer_put_after_previous( self, prev: Optional[concurrent.futures.Future], config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: try: if prev is not None: prev.result() finally: cast(BaseCheckpointSaver, self.checkpointer).put( config, checkpoint, metadata, new_versions ) def _update_mv(self, key: str, values: Sequence[Any]) -> None: return self.submit(cast(WritableManagedValue, self.managed[key]).update, values) # context manager def __enter__(self) -> Self: if self.config.get(CONF, {}).get( CONFIG_KEY_ENSURE_LATEST ) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID): if self.checkpointer is None: raise RuntimeError( "Cannot ensure latest checkpoint without checkpointer" ) saved = self.checkpointer.get_tuple( patch_configurable( self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None} ) ) if ( saved is None or saved.checkpoint["id"] != self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] ): raise CheckpointNotLatest elif self.checkpointer: saved = self.checkpointer.get_tuple(self.checkpoint_config) else: saved = None if saved is None: saved = CheckpointTuple( self.config, empty_checkpoint(), {"step": -2}, None, [] ) self.checkpoint_config = { **self.config, **saved.config, CONF: { CONFIG_KEY_CHECKPOINT_NS: "", **self.config.get(CONF, {}), **saved.config.get(CONF, {}), }, } self.prev_checkpoint_config = saved.parent_config self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( [(str(tid), k, v) for tid, k, v in saved.pending_writes] if saved.pending_writes is not None else [] ) self.submit = self.stack.enter_context(BackgroundExecutor(self.config)) self.channels, self.managed = self.stack.enter_context( ChannelsManager(self.specs, self.checkpoint, self) ) 
self.stack.push(self._suppress_interrupt) self.status = "pending" self.step = self.checkpoint_metadata["step"] + 1 self.stop = self.step + self.config["recursion_limit"] + 1 self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: # unwind stack return self.stack.__exit__(exc_type, exc_value, traceback) class AsyncPregelLoop(PregelLoop, AsyncContextManager): def __init__( self, input: Optional[Any], *, stream: Optional[StreamProtocol], config: RunnableConfig, store: Optional[BaseStore], checkpointer: Optional[BaseCheckpointSaver], nodes: Mapping[str, PregelNode], specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ, interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ, manager: Union[None, AsyncParentRunManager, ParentRunManager] = None, output_keys: Union[str, Sequence[str]] = EMPTY_SEQ, stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ, check_subgraphs: bool = True, debug: bool = False, ) -> None: super().__init__( input, stream=stream, config=config, checkpointer=checkpointer, store=store, nodes=nodes, specs=specs, output_keys=output_keys, stream_keys=stream_keys, interrupt_after=interrupt_after, interrupt_before=interrupt_before, check_subgraphs=check_subgraphs, manager=manager, debug=debug, ) self.stack = AsyncExitStack() if checkpointer: self.checkpointer_get_next_version = checkpointer.get_next_version self.checkpointer_put_writes = checkpointer.aput_writes else: self.checkpointer_get_next_version = increment self._checkpointer_put_after_previous = None # type: ignore[assignment] self.checkpointer_put_writes = None async def _checkpointer_put_after_previous( self, prev: Optional[asyncio.Task], config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> 
RunnableConfig: try: if prev is not None: await prev finally: await cast(BaseCheckpointSaver, self.checkpointer).aput( config, checkpoint, metadata, new_versions ) def _update_mv(self, key: str, values: Sequence[Any]) -> None: return self.submit( cast(WritableManagedValue, self.managed[key]).aupdate, values ) # context manager async def __aenter__(self) -> Self: if self.config.get(CONF, {}).get( CONFIG_KEY_ENSURE_LATEST ) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID): if self.checkpointer is None: raise RuntimeError( "Cannot ensure latest checkpoint without checkpointer" ) saved = await self.checkpointer.aget_tuple( patch_configurable( self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None} ) ) if ( saved is None or saved.checkpoint["id"] != self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] ): raise CheckpointNotLatest elif self.checkpointer: saved = await self.checkpointer.aget_tuple(self.checkpoint_config) else: saved = None if saved is None: saved = CheckpointTuple( self.config, empty_checkpoint(), {"step": -2}, None, [] ) self.checkpoint_config = { **self.config, **saved.config, CONF: { CONFIG_KEY_CHECKPOINT_NS: "", **self.config.get(CONF, {}), **saved.config.get(CONF, {}), }, } self.prev_checkpoint_config = saved.parent_config self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( [(str(tid), k, v) for tid, k, v in saved.pending_writes] if saved.pending_writes is not None else [] ) self.submit = await self.stack.enter_async_context( AsyncBackgroundExecutor(self.config) ) self.channels, self.managed = await self.stack.enter_async_context( AsyncChannelsManager(self.specs, self.checkpoint, self) ) self.stack.push(self._suppress_interrupt) self.status = "pending" self.step = self.checkpoint_metadata["step"] + 1 self.stop = self.step + self.config["recursion_limit"] + 1 self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy() return self async def __aexit__( self, 
exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: # unwind stack return await asyncio.shield( self.stack.__aexit__(exc_type, exc_value, traceback) ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/algo.py`: ```py import sys from collections import defaultdict, deque from functools import partial from hashlib import sha1 from typing import ( Any, Callable, Iterable, Iterator, Literal, Mapping, NamedTuple, Optional, Protocol, Sequence, Union, cast, overload, ) from uuid import UUID from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager from langchain_core.runnables.config import RunnableConfig from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import ( BaseCheckpointSaver, Checkpoint, PendingWrite, V, copy_checkpoint, ) from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_SCRATCHPAD, CONFIG_KEY_SEND, CONFIG_KEY_STORE, CONFIG_KEY_TASK_ID, CONFIG_KEY_WRITES, EMPTY_SEQ, INTERRUPT, NO_WRITES, NS_END, NS_SEP, NULL_TASK_ID, PULL, PUSH, RESERVED, RESUME, TAG_HIDDEN, TASKS, Send, ) from langgraph.errors import EmptyChannelError, InvalidUpdateError from langgraph.managed.base import ManagedValueMapping from langgraph.pregel.io import read_channel, read_channels from langgraph.pregel.log import logger from langgraph.pregel.manager import ChannelsManager from langgraph.pregel.read import PregelNode from langgraph.store.base import BaseStore from langgraph.types import All, LoopProtocol, PregelExecutableTask, PregelTask from langgraph.utils.config import merge_configs, patch_config GetNextVersion = Callable[[Optional[V], BaseChannel], V] SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11) class WritesProtocol(Protocol): """Protocol for objects containing writes to be applied to checkpoint. 
Implemented by PregelTaskWrites and PregelExecutableTask.""" @property def path(self) -> tuple[Union[str, int, tuple], ...]: ... @property def name(self) -> str: ... @property def writes(self) -> Sequence[tuple[str, Any]]: ... @property def triggers(self) -> Sequence[str]: ... class PregelTaskWrites(NamedTuple): """Simplest implementation of WritesProtocol, for usage with writes that don't originate from a runnable task, eg. graph input, update_state, etc.""" path: tuple[Union[str, int, tuple], ...] name: str writes: Sequence[tuple[str, Any]] triggers: Sequence[str] def should_interrupt( checkpoint: Checkpoint, interrupt_nodes: Union[All, Sequence[str]], tasks: Iterable[PregelExecutableTask], ) -> list[PregelExecutableTask]: """Check if the graph should be interrupted based on current state.""" version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) null_version = version_type() # type: ignore[misc] seen = checkpoint["versions_seen"].get(INTERRUPT, {}) # interrupt if any channel has been updated since last interrupt any_updates_since_prev_interrupt = any( version > seen.get(chan, null_version) # type: ignore[operator] for chan, version in checkpoint["channel_versions"].items() ) # and any triggered node is in interrupt_nodes list return ( [ task for task in tasks if ( ( not task.config or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ) ) if interrupt_nodes == "*" else task.name in interrupt_nodes ) ] if any_updates_since_prev_interrupt else [] ) def local_read( step: int, checkpoint: Checkpoint, channels: Mapping[str, BaseChannel], managed: ManagedValueMapping, task: WritesProtocol, config: RunnableConfig, select: Union[list[str], str], fresh: bool = False, ) -> Union[dict[str, Any], Any]: """Function injected under CONFIG_KEY_READ in task config, to read current state. 
Used by conditional edges to read a copy of the state with reflecting the writes from that node only.""" if isinstance(select, str): managed_keys = [] for c, _ in task.writes: if c == select: updated = {c} break else: updated = set() else: managed_keys = [k for k in select if k in managed] select = [k for k in select if k not in managed] updated = set(select).intersection(c for c, _ in task.writes) if fresh and updated: with ChannelsManager( {k: v for k, v in channels.items() if k in updated}, checkpoint, LoopProtocol(config=config, step=step, stop=step + 1), skip_context=True, ) as (local_channels, _): apply_writes(copy_checkpoint(checkpoint), local_channels, [task], None) values = read_channels({**channels, **local_channels}, select) else: values = read_channels(channels, select) if managed_keys: values.update({k: managed[k]() for k in managed_keys}) return values def local_write( commit: Callable[[Sequence[tuple[str, Any]]], None], process_keys: Iterable[str], writes: Sequence[tuple[str, Any]], ) -> None: """Function injected under CONFIG_KEY_SEND in task config, to write to channels. 
Validates writes and forwards them to `commit` function.""" for chan, value in writes: if chan in (PUSH, TASKS): if not isinstance(value, Send): raise InvalidUpdateError(f"Expected Send, got {value}") if value.node not in process_keys: raise InvalidUpdateError(f"Invalid node name {value.node} in packet") commit(writes) def increment(current: Optional[int], channel: BaseChannel) -> int: """Default channel versioning function, increments the current int version.""" return current + 1 if current is not None else 1 def apply_writes( checkpoint: Checkpoint, channels: Mapping[str, BaseChannel], tasks: Iterable[WritesProtocol], get_next_version: Optional[GetNextVersion], ) -> dict[str, list[Any]]: """Apply writes from a set of tasks (usually the tasks from a Pregel step) to the checkpoint and channels, and return managed values writes to be applied externally.""" # sort tasks on path, to ensure deterministic order for update application # any path parts after the 3rd are ignored for sorting # (we use them for eg. 
task ids which aren't good for sorting) tasks = sorted(tasks, key=lambda t: t.path[:3]) # if no task has triggers this is applying writes from the null task only # so we don't do anything other than update the channels written to bump_step = any(t.triggers for t in tasks) # update seen versions for task in tasks: checkpoint["versions_seen"].setdefault(task.name, {}).update( { chan: checkpoint["channel_versions"][chan] for chan in task.triggers if chan in checkpoint["channel_versions"] } ) # Find the highest version of all channels if checkpoint["channel_versions"]: max_version = max(checkpoint["channel_versions"].values()) else: max_version = None # Consume all channels that were read for chan in { chan for task in tasks for chan in task.triggers if chan not in RESERVED and chan in channels }: if channels[chan].consume() and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( max_version, channels[chan], ) # clear pending sends if checkpoint["pending_sends"] and bump_step: checkpoint["pending_sends"].clear() # Group writes by channel pending_writes_by_channel: dict[str, list[Any]] = defaultdict(list) pending_writes_by_managed: dict[str, list[Any]] = defaultdict(list) for task in tasks: for chan, val in task.writes: if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT): pass elif chan == TASKS: # TODO: remove branch in 1.0 checkpoint["pending_sends"].append(val) elif chan in channels: pending_writes_by_channel[chan].append(val) else: pending_writes_by_managed[chan].append(val) # Find the highest version of all channels if checkpoint["channel_versions"]: max_version = max(checkpoint["channel_versions"].values()) else: max_version = None # Apply writes to channels updated_channels: set[str] = set() for chan, vals in pending_writes_by_channel.items(): if chan in channels: if channels[chan].update(vals) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( max_version, channels[chan], ) 
updated_channels.add(chan) # Channels that weren't updated in this step are notified of a new step if bump_step: for chan in channels: if chan not in updated_channels: if channels[chan].update([]) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( max_version, channels[chan], ) # Return managed values writes to be applied externally return pending_writes_by_managed @overload def prepare_next_tasks( checkpoint: Checkpoint, pending_writes: Sequence[PendingWrite], processes: Mapping[str, PregelNode], channels: Mapping[str, BaseChannel], managed: ManagedValueMapping, config: RunnableConfig, step: int, *, for_execution: Literal[False], store: Literal[None] = None, checkpointer: Literal[None] = None, manager: Literal[None] = None, ) -> dict[str, PregelTask]: ... @overload def prepare_next_tasks( checkpoint: Checkpoint, pending_writes: Sequence[PendingWrite], processes: Mapping[str, PregelNode], channels: Mapping[str, BaseChannel], managed: ManagedValueMapping, config: RunnableConfig, step: int, *, for_execution: Literal[True], store: Optional[BaseStore], checkpointer: Optional[BaseCheckpointSaver], manager: Union[None, ParentRunManager, AsyncParentRunManager], ) -> dict[str, PregelExecutableTask]: ... def prepare_next_tasks( checkpoint: Checkpoint, pending_writes: Sequence[PendingWrite], processes: Mapping[str, PregelNode], channels: Mapping[str, BaseChannel], managed: ManagedValueMapping, config: RunnableConfig, step: int, *, for_execution: bool, store: Optional[BaseStore] = None, checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]: """Prepare the set of tasks that will make up the next Pregel step. 
This is the union of all PUSH tasks (Sends) and PULL tasks (nodes triggered by edges).""" tasks: list[Union[PregelTask, PregelExecutableTask]] = [] # Consume pending_sends from previous step (legacy version of Send) for idx, _ in enumerate(checkpoint["pending_sends"]): # TODO: remove branch in 1.0 if task := prepare_single_task( (PUSH, idx), None, checkpoint=checkpoint, pending_writes=pending_writes, processes=processes, channels=channels, managed=managed, config=config, step=step, for_execution=for_execution, store=store, checkpointer=checkpointer, manager=manager, ): tasks.append(task) # Check if any processes should be run in next step # If so, prepare the values to be passed to them for name in processes: if task := prepare_single_task( (PULL, name), None, checkpoint=checkpoint, pending_writes=pending_writes, processes=processes, channels=channels, managed=managed, config=config, step=step, for_execution=for_execution, store=store, checkpointer=checkpointer, manager=manager, ): tasks.append(task) # Consume pending Sends from this step (new version of Send) if any(c == PUSH for _, c, _ in pending_writes): # group writes by task id grouped_by_task = defaultdict(list) for tid, c, _ in pending_writes: grouped_by_task[tid].append(c) # prepare send tasks from grouped writes # 1. start from sends originating from existing tasks tidx = 0 while tidx < len(tasks): task = tasks[tidx] if twrites := grouped_by_task.pop(task.id, None): for idx, c in enumerate(twrites): if c != PUSH: continue if next_task := prepare_single_task( (PUSH, task.path, idx, task.id), None, checkpoint=checkpoint, pending_writes=pending_writes, processes=processes, channels=channels, managed=managed, config=config, step=step, for_execution=for_execution, store=store, checkpointer=checkpointer, manager=manager, ): tasks.append(next_task) tidx += 1 # key tasks by id task_map = {t.id: t for t in tasks} # 2. create new tasks for remaining sends (eg. 
from update_state) for tid, writes in grouped_by_task.items(): task = task_map.get(tid) for idx, c in enumerate(writes): if c != PUSH: continue if next_task := prepare_single_task( (PUSH, task.path if task else (), idx, tid), None, checkpoint=checkpoint, pending_writes=pending_writes, processes=processes, channels=channels, managed=managed, config=config, step=step, for_execution=for_execution, store=store, checkpointer=checkpointer, manager=manager, ): task_map[next_task.id] = next_task else: task_map = {t.id: t for t in tasks} return task_map def prepare_single_task( task_path: tuple[Union[str, int, tuple], ...], task_id_checksum: Optional[str], *, checkpoint: Checkpoint, pending_writes: Sequence[PendingWrite], processes: Mapping[str, PregelNode], channels: Mapping[str, BaseChannel], managed: ManagedValueMapping, config: RunnableConfig, step: int, for_execution: bool, store: Optional[BaseStore] = None, checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[None, PregelTask, PregelExecutableTask]: """Prepares a single task for the next Pregel step, given a task path, which uniquely identifies a PUSH or PULL task within the graph.""" checkpoint_id = UUID(checkpoint["id"]).bytes configurable = config.get(CONF, {}) parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "") if task_path[0] == PUSH: if len(task_path) == 2: # TODO: remove branch in 1.0 # legacy SEND tasks, executed in superstep n+1 # (PUSH, idx of pending send) idx = cast(int, task_path[1]) if idx >= len(checkpoint["pending_sends"]): return packet = checkpoint["pending_sends"][idx] if not isinstance(packet, Send): logger.warning( f"Ignoring invalid packet type {type(packet)} in pending sends" ) return if packet.node not in processes: logger.warning( f"Ignoring unknown node name {packet.node} in pending sends" ) return # create task id triggers = [PUSH] checkpoint_ns = ( f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else 
packet.node ) task_id = _uuid5_str( checkpoint_id, checkpoint_ns, str(step), packet.node, PUSH, str(idx), ) elif len(task_path) == 4: # new PUSH tasks, executed in superstep n # (PUSH, parent task path, idx of PUSH write, id of parent task) task_path_t = cast(tuple[str, tuple, int, str], task_path) writes_for_path = [w for w in pending_writes if w[0] == task_path_t[3]] if task_path_t[2] >= len(writes_for_path): logger.warning( f"Ignoring invalid write index {task_path[2]} in pending writes" ) return packet = writes_for_path[task_path_t[2]][2] if not isinstance(packet, Send): logger.warning( f"Ignoring invalid packet type {type(packet)} in pending writes" ) return if packet.node not in processes: logger.warning( f"Ignoring unknown node name {packet.node} in pending writes" ) return # create task id triggers = [PUSH] checkpoint_ns = ( f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node ) task_id = _uuid5_str( checkpoint_id, checkpoint_ns, str(step), packet.node, PUSH, _tuple_str(task_path[1]), str(task_path[2]), ) else: logger.warning(f"Ignoring invalid PUSH task path {task_path}") return task_checkpoint_ns = f"{checkpoint_ns}:{task_id}" metadata = { "langgraph_step": step, "langgraph_node": packet.node, "langgraph_triggers": triggers, "langgraph_path": task_path, "langgraph_checkpoint_ns": task_checkpoint_ns, } if task_id_checksum is not None: assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}" if for_execution: proc = processes[packet.node] if node := proc.node: if proc.metadata: metadata.update(proc.metadata) writes: deque[tuple[str, Any]] = deque() return PregelExecutableTask( packet.node, packet.arg, node, writes, patch_config( merge_configs( config, {"metadata": metadata, "tags": proc.tags} ), run_name=packet.node, callbacks=( manager.get_child(f"graph:step:{step}") if manager else None ), configurable={ CONFIG_KEY_TASK_ID: task_id, # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, writes.extend, 
processes.keys(), ), CONFIG_KEY_READ: partial( local_read, step, checkpoint, channels, managed, PregelTaskWrites( task_path, packet.node, writes, triggers ), config, ), CONFIG_KEY_STORE: ( store or configurable.get(CONFIG_KEY_STORE) ), CONFIG_KEY_CHECKPOINTER: ( checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER) ), CONFIG_KEY_CHECKPOINT_MAP: { **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}), parent_ns: checkpoint["id"], }, CONFIG_KEY_CHECKPOINT_ID: None, CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns, CONFIG_KEY_WRITES: [ w for w in pending_writes + configurable.get(CONFIG_KEY_WRITES, []) if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, }, ), triggers, proc.retry_policy, None, task_id, task_path, writers=proc.flat_writers, ) else: return PregelTask(task_id, packet.node, task_path) elif task_path[0] == PULL: # (PULL, node name) name = cast(str, task_path[1]) if name not in processes: return proc = processes[name] version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) null_version = version_type() # type: ignore[misc] if null_version is None: return seen = checkpoint["versions_seen"].get(name, {}) # If any of the channels read by this process were updated if triggers := sorted( chan for chan in proc.triggers if not isinstance( read_channel(channels, chan, return_exception=True), EmptyChannelError ) and checkpoint["channel_versions"].get(chan, null_version) # type: ignore[operator] > seen.get(chan, null_version) ): try: val = next( _proc_input(proc, managed, channels, for_execution=for_execution) ) except StopIteration: return except Exception as exc: if SUPPORTS_EXC_NOTES: exc.add_note( f"Before task with name '{name}' and path '{task_path[:3]}'" ) raise # create task id checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name task_id = _uuid5_str( checkpoint_id, checkpoint_ns, str(step), name, PULL, *triggers, ) task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}" metadata = { "langgraph_step": step, 
"langgraph_node": name, "langgraph_triggers": triggers, "langgraph_path": task_path, "langgraph_checkpoint_ns": task_checkpoint_ns, } if task_id_checksum is not None: assert task_id == task_id_checksum if for_execution: if node := proc.node: if proc.metadata: metadata.update(proc.metadata) writes = deque() return PregelExecutableTask( name, val, node, writes, patch_config( merge_configs( config, {"metadata": metadata, "tags": proc.tags} ), run_name=name, callbacks=( manager.get_child(f"graph:step:{step}") if manager else None ), configurable={ CONFIG_KEY_TASK_ID: task_id, # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, writes.extend, processes.keys(), ), CONFIG_KEY_READ: partial( local_read, step, checkpoint, channels, managed, PregelTaskWrites(task_path, name, writes, triggers), config, ), CONFIG_KEY_STORE: ( store or configurable.get(CONFIG_KEY_STORE) ), CONFIG_KEY_CHECKPOINTER: ( checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER) ), CONFIG_KEY_CHECKPOINT_MAP: { **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}), parent_ns: checkpoint["id"], }, CONFIG_KEY_CHECKPOINT_ID: None, CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns, CONFIG_KEY_WRITES: [ w for w in pending_writes + configurable.get(CONFIG_KEY_WRITES, []) if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, }, ), triggers, proc.retry_policy, None, task_id, task_path, writers=proc.flat_writers, ) else: return PregelTask(task_id, name, task_path) def _proc_input( proc: PregelNode, managed: ManagedValueMapping, channels: Mapping[str, BaseChannel], *, for_execution: bool, ) -> Iterator[Any]: """Prepare input for a PULL task, based on the process's channels and triggers.""" # If all trigger channels subscribed by this process are not empty # then invoke the process with the values of all non-empty channels if isinstance(proc.channels, dict): try: val: dict[str, Any] = {} for k, chan in proc.channels.items(): if chan in proc.triggers: val[k] = read_channel(channels, chan, 
catch=False) elif chan in channels: try: val[k] = read_channel(channels, chan, catch=False) except EmptyChannelError: continue else: val[k] = managed[k]() except EmptyChannelError: return elif isinstance(proc.channels, list): for chan in proc.channels: try: val = read_channel(channels, chan, catch=False) break except EmptyChannelError: pass else: return else: raise RuntimeError( "Invalid channels type, expected list or dict, got {proc.channels}" ) # If the process has a mapper, apply it to the value if for_execution and proc.mapper is not None: val = proc.mapper(val) yield val def _uuid5_str(namespace: bytes, *parts: str) -> str: """Generate a UUID from the SHA-1 hash of a namespace UUID and a name.""" sha = sha1(namespace, usedforsecurity=False) sha.update(b"".join(p.encode() for p in parts)) hex = sha.hexdigest() return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}" def _tuple_str(tup: Union[str, int, tuple]) -> str: """Generate a string representation of a tuple.""" return ( f"({', '.join(_tuple_str(x) for x in tup)})" if isinstance(tup, (tuple, list)) else str(tup) ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/manager.py`: ```py import asyncio from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from typing import AsyncIterator, Iterator, Mapping, Union from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import Checkpoint from langgraph.managed.base import ( ConfiguredManagedValue, ManagedValueMapping, ManagedValueSpec, ) from langgraph.managed.context import Context from langgraph.types import LoopProtocol @contextmanager def ChannelsManager( specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], checkpoint: Checkpoint, loop: LoopProtocol, *, skip_context: bool = False, ) -> Iterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" channel_specs: dict[str, 
BaseChannel] = {} managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): if isinstance(v, BaseChannel): channel_specs[k] = v elif ( skip_context and isinstance(v, ConfiguredManagedValue) and v.cls is Context ): managed_specs[k] = Context.of(noop_context) else: managed_specs[k] = v with ExitStack() as stack: yield ( { k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, ManagedValueMapping( { key: stack.enter_context( value.cls.enter(loop, **value.kwargs) if isinstance(value, ConfiguredManagedValue) else value.enter(loop) ) for key, value in managed_specs.items() } ), ) @asynccontextmanager async def AsyncChannelsManager( specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]], checkpoint: Checkpoint, loop: LoopProtocol, *, skip_context: bool = False, ) -> AsyncIterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" channel_specs: dict[str, BaseChannel] = {} managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): if isinstance(v, BaseChannel): channel_specs[k] = v elif ( skip_context and isinstance(v, ConfiguredManagedValue) and v.cls is Context ): managed_specs[k] = Context.of(noop_context) else: managed_specs[k] = v async with AsyncExitStack() as stack: # managed: create enter tasks with reference to spec, await them if tasks := { asyncio.create_task( stack.enter_async_context( value.cls.aenter(loop, **value.kwargs) if isinstance(value, ConfiguredManagedValue) else value.aenter(loop) ) ): key for key, value in managed_specs.items() }: done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) else: done = set() yield ( # channels: enter each channel with checkpoint { k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, # managed: build mapping from spec to result ManagedValueMapping({tasks[task]: task.result() for task in done}), ) 
@contextmanager def noop_context() -> Iterator[None]: yield None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/executor.py`: ```py import asyncio import concurrent.futures import sys from contextlib import ExitStack from contextvars import copy_context from types import TracebackType from typing import ( AsyncContextManager, Awaitable, Callable, ContextManager, Coroutine, Optional, Protocol, TypeVar, cast, ) from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import get_executor_for_config from typing_extensions import ParamSpec from langgraph.errors import GraphBubbleUp P = ParamSpec("P") T = TypeVar("T") class Submit(Protocol[P, T]): def __call__( self, fn: Callable[P, T], *args: P.args, __name__: Optional[str] = None, __cancel_on_exit__: bool = False, __reraise_on_exit__: bool = True, **kwargs: P.kwargs, ) -> concurrent.futures.Future[T]: ... class BackgroundExecutor(ContextManager): """A context manager that runs sync tasks in the background. Uses a thread pool executor to delegate tasks to separate threads. 
On exit, - cancels any (not yet started) tasks with `__cancel_on_exit__=True` - waits for all tasks to finish - re-raises the first exception from tasks with `__reraise_on_exit__=True`""" def __init__(self, config: RunnableConfig) -> None: self.stack = ExitStack() self.executor = self.stack.enter_context(get_executor_for_config(config)) self.tasks: dict[concurrent.futures.Future, tuple[bool, bool]] = {} def submit( # type: ignore[valid-type] self, fn: Callable[P, T], *args: P.args, __name__: Optional[str] = None, # currently not used in sync version __cancel_on_exit__: bool = False, # for sync, can cancel only if not started __reraise_on_exit__: bool = True, **kwargs: P.kwargs, ) -> concurrent.futures.Future[T]: task = self.executor.submit(fn, *args, **kwargs) self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__) task.add_done_callback(self.done) return task def done(self, task: concurrent.futures.Future) -> None: try: task.result() except GraphBubbleUp: # This exception is an interruption signal, not an error # so we don't want to re-raise it on exit self.tasks.pop(task) except BaseException: pass else: self.tasks.pop(task) def __enter__(self) -> Submit: return self.submit def __exit__( self, exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: # copy the tasks as done() callback may modify the dict tasks = self.tasks.copy() # cancel all tasks that should be cancelled for task, (cancel, _) in tasks.items(): if cancel: task.cancel() # wait for all tasks to finish if pending := {t for t in tasks if not t.done()}: concurrent.futures.wait(pending) # shutdown the executor self.stack.__exit__(exc_type, exc_value, traceback) # re-raise the first exception that occurred in a task if exc_type is None: # if there's already an exception being raised, don't raise another one for task, (_, reraise) in tasks.items(): if not reraise: continue try: task.result() except 
concurrent.futures.CancelledError: pass class AsyncBackgroundExecutor(AsyncContextManager): """A context manager that runs async tasks in the background. Uses the current event loop to delegate tasks to asyncio tasks. On exit, - cancels any tasks with `__cancel_on_exit__=True` - waits for all tasks to finish - re-raises the first exception from tasks with `__reraise_on_exit__=True` ignoring CancelledError""" def __init__(self, config: RunnableConfig) -> None: self.context_not_supported = sys.version_info < (3, 11) self.tasks: dict[asyncio.Task, tuple[bool, bool]] = {} self.sentinel = object() self.loop = asyncio.get_running_loop() if max_concurrency := config.get("max_concurrency"): self.semaphore: Optional[asyncio.Semaphore] = asyncio.Semaphore( max_concurrency ) else: self.semaphore = None def submit( # type: ignore[valid-type] self, fn: Callable[P, Awaitable[T]], *args: P.args, __name__: Optional[str] = None, __cancel_on_exit__: bool = False, __reraise_on_exit__: bool = True, **kwargs: P.kwargs, ) -> asyncio.Task[T]: coro = cast(Coroutine[None, None, T], fn(*args, **kwargs)) if self.semaphore: coro = gated(self.semaphore, coro) if self.context_not_supported: task = self.loop.create_task(coro, name=__name__) else: task = self.loop.create_task(coro, name=__name__, context=copy_context()) self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__) task.add_done_callback(self.done) return task def done(self, task: asyncio.Task) -> None: try: if exc := task.exception(): # This exception is an interruption signal, not an error # so we don't want to re-raise it on exit if isinstance(exc, GraphBubbleUp): self.tasks.pop(task) else: self.tasks.pop(task) except asyncio.CancelledError: self.tasks.pop(task) async def __aenter__(self) -> Submit: return self.submit async def __aexit__( self, exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: # copy the tasks as done() callback may modify the dict tasks = 
self.tasks.copy() # cancel all tasks that should be cancelled for task, (cancel, _) in tasks.items(): if cancel: task.cancel(self.sentinel) # wait for all tasks to finish if tasks: await asyncio.wait(tasks) # if there's already an exception being raised, don't raise another one if exc_type is None: # re-raise the first exception that occurred in a task for task, (_, reraise) in tasks.items(): if not reraise: continue try: if exc := task.exception(): raise exc except asyncio.CancelledError: pass async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[None, None, T]) -> T: """A coroutine that waits for a semaphore before running another coroutine.""" async with semaphore: return await coro ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/read.py`: ```py from __future__ import annotations from functools import cached_property from typing import ( Any, AsyncIterator, Callable, Iterator, Mapping, Optional, Sequence, Union, ) from langchain_core.runnables import ( Runnable, RunnableConfig, RunnablePassthrough, RunnableSerializable, ) from langchain_core.runnables.base import Input, Other, coerce_to_runnable from langchain_core.runnables.utils import ConfigurableFieldSpec from langgraph.constants import CONF, CONFIG_KEY_READ from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.write import ChannelWrite from langgraph.utils.config import merge_configs from langgraph.utils.runnable import RunnableCallable, RunnableSeq READ_TYPE = Callable[[Union[str, Sequence[str]], bool], Union[Any, dict[str, Any]]] class ChannelRead(RunnableCallable): """Implements the logic for reading state from CONFIG_KEY_READ. 
Usable both as a runnable as well as a static method to call imperatively.""" channel: Union[str, list[str]] fresh: bool = False mapper: Optional[Callable[[Any], Any]] = None @property def config_specs(self) -> list[ConfigurableFieldSpec]: return [ ConfigurableFieldSpec( id=CONFIG_KEY_READ, name=CONFIG_KEY_READ, description=None, default=None, annotation=None, ), ] def __init__( self, channel: Union[str, list[str]], *, fresh: bool = False, mapper: Optional[Callable[[Any], Any]] = None, tags: Optional[list[str]] = None, ) -> None: super().__init__(func=self._read, afunc=self._aread, tags=tags, name=None) self.fresh = fresh self.mapper = mapper self.channel = channel def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: if name: pass elif isinstance(self.channel, str): name = f"ChannelRead<{self.channel}>" else: name = f"ChannelRead<{','.join(self.channel)}>" return super().get_name(suffix, name=name) def _read(self, _: Any, config: RunnableConfig) -> Any: return self.do_read( config, select=self.channel, fresh=self.fresh, mapper=self.mapper ) async def _aread(self, _: Any, config: RunnableConfig) -> Any: return self.do_read( config, select=self.channel, fresh=self.fresh, mapper=self.mapper ) @staticmethod def do_read( config: RunnableConfig, *, select: Union[str, list[str]], fresh: bool = False, mapper: Optional[Callable[[Any], Any]] = None, ) -> Any: try: read: READ_TYPE = config[CONF][CONFIG_KEY_READ] except KeyError: raise RuntimeError( "Not configured with a read function" "Make sure to call in the context of a Pregel process" ) if mapper: return mapper(read(select, fresh)) else: return read(select, fresh) DEFAULT_BOUND: RunnablePassthrough = RunnablePassthrough() class PregelNode(Runnable): """A node in a Pregel graph. 
This won't be invoked as a runnable by the graph itself, but instead acts as a container for the components necessary to make a PregelExecutableTask for a node.""" channels: Union[list[str], Mapping[str, str]] """The channels that will be passed as input to `bound`. If a list, the node will be invoked with the first of that isn't empty. If a dict, the keys are the names of the channels, and the values are the keys to use in the input to `bound`.""" triggers: list[str] """If any of these channels is written to, this node will be triggered in the next step.""" mapper: Optional[Callable[[Any], Any]] """A function to transform the input before passing it to `bound`.""" writers: list[Runnable] """A list of writers that will be executed after `bound`, responsible for taking the output of `bound` and writing it to the appropriate channels.""" bound: Runnable[Any, Any] """The main logic of the node. This will be invoked with the input from `channels`.""" retry_policy: Optional[RetryPolicy] """The retry policy to use when invoking the node.""" tags: Optional[Sequence[str]] """Tags to attach to the node for tracing.""" metadata: Optional[Mapping[str, Any]] """Metadata to attach to the node for tracing.""" def __init__( self, *, channels: Union[list[str], Mapping[str, str]], triggers: Sequence[str], mapper: Optional[Callable[[Any], Any]] = None, writers: Optional[list[Runnable]] = None, tags: Optional[list[str]] = None, metadata: Optional[Mapping[str, Any]] = None, bound: Optional[Runnable[Any, Any]] = None, retry_policy: Optional[RetryPolicy] = None, ) -> None: self.channels = channels self.triggers = list(triggers) self.mapper = mapper self.writers = writers or [] self.bound = bound if bound is not None else DEFAULT_BOUND self.retry_policy = retry_policy self.tags = tags self.metadata = metadata def copy(self, update: dict[str, Any]) -> PregelNode: attrs = {**self.__dict__, **update} return PregelNode(**attrs) @cached_property def flat_writers(self) -> list[Runnable]: 
"""Get writers with optimizations applied. Dedupes consecutive ChannelWrites.""" writers = self.writers.copy() while ( len(writers) > 1 and isinstance(writers[-1], ChannelWrite) and isinstance(writers[-2], ChannelWrite) ): # we can combine writes if they are consecutive # careful to not modify the original writers list or ChannelWrite writers[-2] = ChannelWrite( writes=writers[-2].writes + writers[-1].writes, tags=writers[-2].tags, require_at_least_one_of=writers[-2].require_at_least_one_of, ) writers.pop() return writers @cached_property def node(self) -> Optional[Runnable[Any, Any]]: """Get a runnable that combines `bound` and `writers`.""" writers = self.flat_writers if self.bound is DEFAULT_BOUND and not writers: return None elif self.bound is DEFAULT_BOUND and len(writers) == 1: return writers[0] elif self.bound is DEFAULT_BOUND: return RunnableSeq(*writers) elif writers: return RunnableSeq(self.bound, *writers) else: return self.bound def join(self, channels: Sequence[str]) -> PregelNode: assert isinstance(channels, list) or isinstance( channels, tuple ), "channels must be a list or tuple" assert isinstance( self.channels, dict ), "all channels must be named when using .join()" return self.copy( update=dict( channels={ **self.channels, **{chan: chan for chan in channels}, } ), ) def __or__( self, other: Union[ Runnable[Any, Other], Callable[[Any], Other], Mapping[str, Runnable[Any, Other] | Callable[[Any], Other]], ], ) -> PregelNode: if isinstance(other, Runnable) and ChannelWrite.is_writer(other): return self.copy(update=dict(writers=[*self.writers, other])) elif self.bound is DEFAULT_BOUND: return self.copy(update=dict(bound=coerce_to_runnable(other))) else: return self.copy(update=dict(bound=RunnableSeq(self.bound, other))) def pipe( self, *others: Runnable[Any, Other] | Callable[[Any], Other], name: Optional[str] = None, ) -> RunnableSerializable[Any, Other]: for other in others: self = self | other return self def __ror__( self, other: Union[ 
Runnable[Other, Any], Callable[[Any], Other], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], ], ) -> RunnableSerializable: raise NotImplementedError() def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Any: return self.bound.invoke( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), **kwargs, ) async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Any: return await self.bound.ainvoke( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), **kwargs, ) def stream( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Any]: yield from self.bound.stream( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), **kwargs, ) async def astream( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Any]: async for item in self.bound.astream( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), **kwargs, ): yield item ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/__init__.py`: ```py from langgraph.managed.is_last_step import IsLastStep, RemainingSteps __all__ = ["IsLastStep", "RemainingSteps"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/shared_value.py`: ```py import collections.abc from contextlib import asynccontextmanager, contextmanager from typing import ( Any, AsyncIterator, Iterator, Optional, Sequence, Type, ) from typing_extensions import NotRequired, Required, Self from langgraph.constants import CONF from langgraph.errors import InvalidUpdateError from langgraph.managed.base import ( ChannelKeyPlaceholder, ChannelTypePlaceholder, ConfiguredManagedValue, WritableManagedValue, ) from langgraph.store.base import PutOp from langgraph.types import LoopProtocol V = dict[str, Any] Value 
= dict[str, V] Update = dict[str, Optional[V]] # Adapted from typing_extensions def _strip_extras(t): # type: ignore[no-untyped-def] """Strips Annotated, Required and NotRequired from a given type.""" if hasattr(t, "__origin__"): return _strip_extras(t.__origin__) if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): return _strip_extras(t.__args__[0]) return t class SharedValue(WritableManagedValue[Value, Update]): @staticmethod def on(scope: str) -> ConfiguredManagedValue: return ConfiguredManagedValue( SharedValue, { "scope": scope, "key": ChannelKeyPlaceholder, "typ": ChannelTypePlaceholder, }, ) @classmethod @contextmanager def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: with super().enter(loop, **kwargs) as value: if loop.store is not None: saved = loop.store.search(value.ns) value.value = {it.key: it.value for it in saved} yield value @classmethod @asynccontextmanager async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: async with super().aenter(loop, **kwargs) as value: if loop.store is not None: saved = await loop.store.asearch(value.ns) value.value = {it.key: it.value for it in saved} yield value def __init__( self, loop: LoopProtocol, *, typ: Type[Any], scope: str, key: str ) -> None: super().__init__(loop) if typ := _strip_extras(typ): if typ not in ( dict, collections.abc.Mapping, collections.abc.MutableMapping, ): raise ValueError("SharedValue must be a dict") self.scope = scope self.value: Value = {} if self.loop.store is None: pass elif scope_value := self.loop.config[CONF].get(self.scope): self.ns = ("scoped", scope, key, scope_value) else: raise ValueError( f"Scope {scope} for shared state key not in config.configurable" ) def __call__(self) -> Value: return self.value def _process_update(self, values: Sequence[Update]) -> list[PutOp]: writes: list[PutOp] = [] for vv in values: for k, v in vv.items(): if v is None: if k in self.value: del self.value[k] writes.append(PutOp(self.ns, 
k, None)) elif not isinstance(v, dict): raise InvalidUpdateError("Received a non-dict value") else: self.value[k] = v writes.append(PutOp(self.ns, k, v)) return writes def update(self, values: Sequence[Update]) -> None: if self.loop.store is None: self._process_update(values) else: return self.loop.store.batch(self._process_update(values)) async def aupdate(self, writes: Sequence[Update]) -> None: if self.loop.store is None: self._process_update(writes) else: return await self.loop.store.abatch(self._process_update(writes)) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/context.py`: ```py from contextlib import asynccontextmanager, contextmanager from inspect import signature from typing import ( Any, AsyncContextManager, AsyncIterator, Callable, ContextManager, Generic, Iterator, Optional, Type, Union, ) from typing_extensions import Self from langgraph.managed.base import ConfiguredManagedValue, ManagedValue, V from langgraph.types import LoopProtocol class Context(ManagedValue[V], Generic[V]): runtime = True value: V @staticmethod def of( ctx: Union[ None, Callable[..., ContextManager[V]], Type[ContextManager[V]], Callable[..., AsyncContextManager[V]], Type[AsyncContextManager[V]], ] = None, actx: Optional[ Union[ Callable[..., AsyncContextManager[V]], Type[AsyncContextManager[V]], ] ] = None, ) -> ConfiguredManagedValue: if ctx is None and actx is None: raise ValueError("Must provide either sync or async context manager.") return ConfiguredManagedValue(Context, {"ctx": ctx, "actx": actx}) @classmethod @contextmanager def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: with super().enter(loop, **kwargs) as self: if self.ctx is None: raise ValueError( "Synchronous context manager not found. Please initialize Context value with a sync context manager, or invoke your graph asynchronously." 
) ctx = ( self.ctx(loop.config) # type: ignore[call-arg] if signature(self.ctx).parameters.get("config") else self.ctx() ) with ctx as v: # type: ignore[union-attr] self.value = v yield self @classmethod @asynccontextmanager async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: async with super().aenter(loop, **kwargs) as self: if self.actx is not None: ctx = ( self.actx(loop.config) # type: ignore[call-arg] if signature(self.actx).parameters.get("config") else self.actx() ) elif self.ctx is not None: ctx = ( self.ctx(loop.config) # type: ignore if signature(self.ctx).parameters.get("config") else self.ctx() ) else: raise ValueError( "Asynchronous context manager not found. Please initialize Context value with an async context manager, or invoke your graph synchronously." ) if hasattr(ctx, "__aenter__"): async with ctx as v: self.value = v yield self elif hasattr(ctx, "__enter__") and hasattr(ctx, "__exit__"): with ctx as v: self.value = v yield self else: raise ValueError( "Context manager must have either __enter__ or __aenter__ method." 
) def __init__( self, loop: LoopProtocol, *, ctx: Union[None, Type[ContextManager[V]], Type[AsyncContextManager[V]]] = None, actx: Optional[Type[AsyncContextManager[V]]] = None, ) -> None: self.ctx = ctx self.actx = actx def __call__(self) -> V: return self.value ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/is_last_step.py`: ```py from typing import Annotated from langgraph.managed.base import ManagedValue class IsLastStepManager(ManagedValue[bool]): def __call__(self) -> bool: return self.loop.step == self.loop.stop - 1 IsLastStep = Annotated[bool, IsLastStepManager] class RemainingStepsManager(ManagedValue[int]): def __call__(self) -> int: return self.loop.stop - self.loop.step RemainingSteps = Annotated[int, RemainingStepsManager] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/base.py`: ```py from abc import ABC, abstractmethod from contextlib import asynccontextmanager, contextmanager from inspect import isclass from typing import ( Any, AsyncIterator, Generic, Iterator, NamedTuple, Sequence, Type, TypeVar, Union, ) from typing_extensions import Self, TypeGuard from langgraph.types import LoopProtocol V = TypeVar("V") U = TypeVar("U") class ManagedValue(ABC, Generic[V]): def __init__(self, loop: LoopProtocol) -> None: self.loop = loop @classmethod @contextmanager def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]: try: value = cls(loop, **kwargs) yield value finally: # because managed value and Pregel have reference to each other # let's make sure to break the reference on exit try: del value except UnboundLocalError: pass @classmethod @asynccontextmanager async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]: try: value = cls(loop, **kwargs) yield value finally: # because managed value and Pregel have reference to each other # let's make sure to break the reference on exit try: del value except UnboundLocalError: pass @abstractmethod def 
__call__(self) -> V: ... class WritableManagedValue(Generic[V, U], ManagedValue[V], ABC): @abstractmethod def update(self, writes: Sequence[U]) -> None: ... @abstractmethod async def aupdate(self, writes: Sequence[U]) -> None: ... class ConfiguredManagedValue(NamedTuple): cls: Type[ManagedValue] kwargs: dict[str, Any] ManagedValueSpec = Union[Type[ManagedValue], ConfiguredManagedValue] def is_managed_value(value: Any) -> TypeGuard[ManagedValueSpec]: return (isclass(value) and issubclass(value, ManagedValue)) or isinstance( value, ConfiguredManagedValue ) def is_readonly_managed_value(value: Any) -> TypeGuard[Type[ManagedValue]]: return ( isclass(value) and issubclass(value, ManagedValue) and not issubclass(value, WritableManagedValue) ) or ( isinstance(value, ConfiguredManagedValue) and not issubclass(value.cls, WritableManagedValue) ) def is_writable_managed_value(value: Any) -> TypeGuard[Type[WritableManagedValue]]: return (isclass(value) and issubclass(value, WritableManagedValue)) or ( isinstance(value, ConfiguredManagedValue) and issubclass(value.cls, WritableManagedValue) ) ChannelKeyPlaceholder = object() ChannelTypePlaceholder = object() ManagedValueMapping = dict[str, ManagedValue] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/version.py`: ```py """Exports package version.""" from importlib import metadata try: __version__ = metadata.version(__package__) except metadata.PackageNotFoundError: # Case where package metadata is not available. 
__version__ = "" del metadata # optional, avoids polluting the results of dir(__package__) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/constants.py`: ```py import sys from os import getenv from types import MappingProxyType from typing import Any, Literal, Mapping, cast from langgraph.types import Interrupt, Send # noqa: F401 # Interrupt, Send re-exported for backwards compatibility # --- Empty read-only containers --- EMPTY_MAP: Mapping[str, Any] = MappingProxyType({}) EMPTY_SEQ: tuple[str, ...] = tuple() MISSING = object() # --- Public constants --- TAG_NOSTREAM = sys.intern("langsmith:nostream") """Tag to disable streaming for a chat model.""" TAG_HIDDEN = sys.intern("langsmith:hidden") """Tag to hide a node/edge from certain tracing/streaming environments.""" START = sys.intern("__start__") """The first (maybe virtual) node in graph-style Pregel.""" END = sys.intern("__end__") """The last (maybe virtual) node in graph-style Pregel.""" SELF = sys.intern("__self__") """The implicit branch that handles each node's Control values.""" # --- Reserved write keys --- INPUT = sys.intern("__input__") # for values passed as input to the graph INTERRUPT = sys.intern("__interrupt__") # for dynamic interrupts raised by nodes RESUME = sys.intern("__resume__") # for values passed to resume a node after an interrupt ERROR = sys.intern("__error__") # for errors raised by nodes NO_WRITES = sys.intern("__no_writes__") # marker to signal node didn't write anything SCHEDULED = sys.intern("__scheduled__") # marker to signal node was scheduled (in distributed mode) TASKS = sys.intern("__pregel_tasks") # for Send objects returned by nodes/edges, corresponds to PUSH below # --- Reserved config.configurable keys --- CONFIG_KEY_SEND = sys.intern("__pregel_send") # holds the `write` function that accepts writes to state/edges/reserved keys CONFIG_KEY_READ = sys.intern("__pregel_read") # holds the `read` function that returns a copy of the current state 
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer") # holds a `BaseCheckpointSaver` passed from parent graph to child graphs CONFIG_KEY_STREAM = sys.intern("__pregel_stream") # holds a `StreamProtocol` passed from parent graph to child graphs CONFIG_KEY_STREAM_WRITER = sys.intern("__pregel_stream_writer") # holds a `StreamWriter` for stream_mode=custom CONFIG_KEY_STORE = sys.intern("__pregel_store") # holds a `BaseStore` made available to managed values CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming") # holds a boolean indicating if subgraphs should resume from a previous checkpoint CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id") # holds the task ID for the current task CONFIG_KEY_DEDUPE_TASKS = sys.intern("__pregel_dedupe_tasks") # holds a boolean indicating if tasks should be deduplicated (for distributed mode) CONFIG_KEY_ENSURE_LATEST = sys.intern("__pregel_ensure_latest") # holds a boolean indicating whether to assert the requested checkpoint is the latest # (for distributed mode) CONFIG_KEY_DELEGATE = sys.intern("__pregel_delegate") # holds a boolean indicating whether to delegate subgraphs (for distributed mode) CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map") # holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id") # holds the current checkpoint_id, if any CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns") # holds the current checkpoint_ns, "" for root graph CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished") # holds the value that "answers" an interrupt() call CONFIG_KEY_WRITES = sys.intern("__pregel_writes") # read-only list of existing task writes CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad") # holds a mutable dict for temporary storage scoped to the current task # --- Other constants --- PUSH = sys.intern("__pregel_push") # denotes push-style tasks, ie. 
those created by Send objects PULL = sys.intern("__pregel_pull") # denotes pull-style tasks, ie. those triggered by edges NS_SEP = sys.intern("|") # for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph) NS_END = sys.intern(":") # for checkpoint_ns, for each level, separates the namespace from the task_id CONF = cast(Literal["configurable"], sys.intern("configurable")) # key for the configurable dict in RunnableConfig FF_SEND_V2 = getenv("LANGGRAPH_FF_SEND_V2", "false").lower() == "true" # temporary flag to enable new Send semantics NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000") # the task_id to use for writes that are not associated with a task RESERVED = { TAG_HIDDEN, # reserved write keys INPUT, INTERRUPT, RESUME, ERROR, NO_WRITES, SCHEDULED, TASKS, # reserved config.configurable keys CONFIG_KEY_SEND, CONFIG_KEY_READ, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_STREAM, CONFIG_KEY_STREAM_WRITER, CONFIG_KEY_STORE, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_RESUMING, CONFIG_KEY_TASK_ID, CONFIG_KEY_DEDUPE_TASKS, CONFIG_KEY_ENSURE_LATEST, CONFIG_KEY_DELEGATE, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_NS, # other constants PUSH, PULL, NS_SEP, NS_END, CONF, } ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/graph.py`: ```py import asyncio import logging from collections import defaultdict from typing import ( Any, Awaitable, Callable, Hashable, Literal, NamedTuple, Optional, Sequence, Union, cast, get_args, get_origin, get_type_hints, overload, ) from langchain_core.runnables import Runnable from langchain_core.runnables.base import RunnableLike from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.graph import Graph as DrawableGraph from langchain_core.runnables.graph import Node as DrawableNode from typing_extensions import Self from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.constants import ( EMPTY_SEQ, END, NS_END, 
NS_SEP, START, TAG_HIDDEN, Send, ) from langgraph.errors import InvalidUpdateError from langgraph.pregel import Channel, Pregel from langgraph.pregel.read import PregelNode from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.types import All, Checkpointer from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable logger = logging.getLogger(__name__) class NodeSpec(NamedTuple): runnable: Runnable metadata: Optional[dict[str, Any]] = None ends: Optional[tuple[str, ...]] = EMPTY_SEQ class Branch(NamedTuple): path: Runnable[Any, Union[Hashable, list[Hashable]]] ends: Optional[dict[Hashable, str]] then: Optional[str] = None def run( self, writer: Callable[ [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], reader: Optional[Callable[[RunnableConfig], Any]] = None, ) -> RunnableCallable: return ChannelWrite.register_writer( RunnableCallable( func=self._route, afunc=self._aroute, writer=writer, reader=reader, name=None, trace=False, ) ) def _route( self, input: Any, config: RunnableConfig, *, reader: Optional[Callable[[RunnableConfig], Any]], writer: Callable[ [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], ) -> Runnable: if reader: value = reader(config) # passthrough additional keys from node to branch # only doable when using dict states if isinstance(value, dict) and isinstance(input, dict): value = {**input, **value} else: value = input result = self.path.invoke(value, config) return self._finish(writer, input, result, config) async def _aroute( self, input: Any, config: RunnableConfig, *, reader: Optional[Callable[[RunnableConfig], Any]], writer: Callable[ [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], ) -> Runnable: if reader: value = await asyncio.to_thread(reader, config) # passthrough additional keys from node to branch # only doable when using dict states if isinstance(value, dict) and isinstance(input, dict): value = {**input, **value} else: value 
= input result = await self.path.ainvoke(value, config) return self._finish(writer, input, result, config) def _finish( self, writer: Callable[ [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], input: Any, result: Any, config: RunnableConfig, ) -> Union[Runnable, Any]: if not isinstance(result, (list, tuple)): result = [result] if self.ends: destinations: Sequence[Union[Send, str]] = [ r if isinstance(r, Send) else self.ends[r] for r in result ] else: destinations = cast(Sequence[Union[Send, str]], result) if any(dest is None or dest == START for dest in destinations): raise ValueError("Branch did not return a valid destination") if any(p.node == END for p in destinations if isinstance(p, Send)): raise InvalidUpdateError("Cannot send a packet to the END node") return writer(destinations, config) or input class Graph: def __init__(self) -> None: self.nodes: dict[str, NodeSpec] = {} self.edges = set[tuple[str, str]]() self.branches: defaultdict[str, dict[str, Branch]] = defaultdict(dict) self.support_multiple_edges = False self.compiled = False @property def _all_edges(self) -> set[tuple[str, str]]: return self.edges @overload def add_node( self, node: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, ) -> Self: ... @overload def add_node( self, node: str, action: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, ) -> Self: ... def add_node( self, node: Union[str, RunnableLike], action: Optional[RunnableLike] = None, *, metadata: Optional[dict[str, Any]] = None, ) -> Self: if isinstance(node, str): for character in (NS_SEP, NS_END): if character in node: raise ValueError( f"'{character}' is a reserved character and is not allowed in the node names." ) if self.compiled: logger.warning( "Adding a node to a graph that has already been compiled. This will " "not be reflected in the compiled graph." 
) if not isinstance(node, str): action = node node = getattr(action, "name", getattr(action, "__name__")) if node is None: raise ValueError( "Node name must be provided if action is not a function" ) if action is None: raise RuntimeError( "Expected a function or Runnable action in add_node. Received None." ) if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: raise ValueError(f"Node `{node}` is reserved.") self.nodes[cast(str, node)] = NodeSpec( coerce_to_runnable(action, name=cast(str, node), trace=False), metadata ) return self def add_edge(self, start_key: str, end_key: str) -> Self: if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." ) if start_key == END: raise ValueError("END cannot be a start node") if end_key == START: raise ValueError("START cannot be an end node") # run this validation only for non-StateGraph graphs if not hasattr(self, "channels") and start_key in set( start for start, _ in self.edges ): raise ValueError( f"Already found path for node '{start_key}'.\n" "For multiple edges, use StateGraph with an Annotated state key." ) self.edges.add((start_key, end_key)) return self def add_conditional_edges( self, source: str, path: Union[ Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]], ], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None, ) -> Self: """Add a conditional edge from the starting node to any number of destination nodes. Args: source (str): The starting node. This conditional edge will run when exiting this node. path (Union[Callable, Runnable]): The callable that determines the next node or nodes. If not specifying `path_map` it should return one or more nodes. If it returns END, the graph will stop execution. 
path_map (Optional[dict[Hashable, str]]): Optional mapping of paths to node names. If omitted the paths returned by `path` should be node names. then (Optional[str]): The name of a node to execute after the nodes selected by `path`. Returns: None Note: Without typehints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`) or a path_map, the graph visualization assumes the edge could transition to any node in the graph. """ # noqa: E501 if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." ) # coerce path_map to a dictionary try: if isinstance(path_map, dict): path_map_ = path_map.copy() elif isinstance(path_map, list): path_map_ = {name: name for name in path_map} elif isinstance(path, Runnable): path_map_ = None elif rtn_type := get_type_hints(path.__call__).get( # type: ignore[operator] "return" ) or get_type_hints(path).get("return"): if get_origin(rtn_type) is Literal: path_map_ = {name: name for name in get_args(rtn_type)} else: path_map_ = None else: path_map_ = None except Exception: path_map_ = None # find a name for the condition path = coerce_to_runnable(path, name=None, trace=True) name = path.name or "condition" # validate the condition if name in self.branches[source]: raise ValueError( f"Branch with name `{path.name}` already exists for node " f"`{source}`" ) # save it self.branches[source][name] = Branch(path, path_map_, then) return self def set_entry_point(self, key: str) -> Self: """Specifies the first node to be called in the graph. Equivalent to calling `add_edge(START, key)`. Parameters: key (str): The key of the node to set as the entry point. 
Returns: None """ return self.add_edge(START, key) def set_conditional_entry_point( self, path: Union[ Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]], ], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None, ) -> Self: """Sets a conditional entry point in the graph. Args: path (Union[Callable, Runnable]): The callable that determines the next node or nodes. If not specifying `path_map` it should return one or more nodes. If it returns END, the graph will stop execution. path_map (Optional[dict[str, str]]): Optional mapping of paths to node names. If omitted the paths returned by `path` should be node names. then (Optional[str]): The name of a node to execute after the nodes selected by `path`. Returns: None """ return self.add_conditional_edges(START, path, path_map, then) def set_finish_point(self, key: str) -> Self: """Marks a node as a finish point of the graph. If the graph reaches this node, it will cease execution. Parameters: key (str): The key of the node to set as the finish point. 
Returns: None """ return self.add_edge(key, END) def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self: # assemble sources all_sources = {src for src, _ in self._all_edges} for start, branches in self.branches.items(): all_sources.add(start) for cond, branch in branches.items(): if branch.then is not None: if branch.ends is not None: for end in branch.ends.values(): if end != END: all_sources.add(end) else: for node in self.nodes: if node != start and node != branch.then: all_sources.add(node) for name, spec in self.nodes.items(): if spec.ends: all_sources.add(name) # validate sources for source in all_sources: if source not in self.nodes and source != START: raise ValueError(f"Found edge starting at unknown node '{source}'") if START not in all_sources: raise ValueError( "Graph must have an entrypoint: add at least one edge from START to another node" ) # assemble targets all_targets = {end for _, end in self._all_edges} for start, branches in self.branches.items(): for cond, branch in branches.items(): if branch.then is not None: all_targets.add(branch.then) if branch.ends is not None: for end in branch.ends.values(): if end not in self.nodes and end != END: raise ValueError( f"At '{start}' node, '{cond}' branch found unknown target '{end}'" ) all_targets.add(end) else: all_targets.add(END) for node in self.nodes: if node != start and node != branch.then: all_targets.add(node) for name, spec in self.nodes.items(): if spec.ends: all_targets.update(spec.ends) for target in all_targets: if target not in self.nodes and target != END: raise ValueError(f"Found edge ending at unknown node `{target}`") # validate interrupts if interrupt: for node in interrupt: if node not in self.nodes: raise ValueError(f"Interrupt node `{node}` not found") self.compiled = True return self def compile( self, checkpointer: Checkpointer = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, debug: bool = 
False, ) -> "CompiledGraph": # assign default values interrupt_before = interrupt_before or [] interrupt_after = interrupt_after or [] # validate the graph self.validate( interrupt=( (interrupt_before if interrupt_before != "*" else []) + interrupt_after if interrupt_after != "*" else [] ) ) # create empty compiled graph compiled = CompiledGraph( builder=self, nodes={}, channels={START: EphemeralValue(Any), END: EphemeralValue(Any)}, input_channels=START, output_channels=END, stream_mode="values", stream_channels=[], checkpointer=checkpointer, interrupt_before_nodes=interrupt_before, interrupt_after_nodes=interrupt_after, auto_validate=False, debug=debug, ) # attach nodes, edges, and branches for key, node in self.nodes.items(): compiled.attach_node(key, node) for start, end in self.edges: compiled.attach_edge(start, end) for start, branches in self.branches.items(): for name, branch in branches.items(): compiled.attach_branch(start, name, branch) # validate the compiled graph return compiled.validate() class CompiledGraph(Pregel): builder: Graph def __init__(self, *, builder: Graph, **kwargs: Any) -> None: super().__init__(**kwargs) self.builder = builder def attach_node(self, key: str, node: NodeSpec) -> None: self.channels[key] = EphemeralValue(Any) self.nodes[key] = ( PregelNode(channels=[], triggers=[], metadata=node.metadata) | node.runnable | ChannelWrite([ChannelWriteEntry(key)], tags=[TAG_HIDDEN]) ) cast(list[str], self.stream_channels).append(key) def attach_edge(self, start: str, end: str) -> None: if end == END: # publish to end channel self.nodes[start].writers.append( ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN]) ) else: # subscribe to start channel self.nodes[end].triggers.append(start) cast(list[str], self.nodes[end].channels).append(start) def attach_branch(self, start: str, name: str, branch: Branch) -> None: def branch_writer( packets: Sequence[Union[str, Send]], config: RunnableConfig ) -> Optional[ChannelWrite]: writes = [ ( 
ChannelWriteEntry(f"branch:{start}:{name}:{p}" if p != END else END) if not isinstance(p, Send) else p ) for p in packets ] return ChannelWrite( cast(Sequence[Union[ChannelWriteEntry, Send]], writes), tags=[TAG_HIDDEN], ) # add hidden start node if start == START and start not in self.nodes: self.nodes[start] = Channel.subscribe_to(START, tags=[TAG_HIDDEN]) # attach branch writer self.nodes[start] |= branch.run(branch_writer) # attach branch readers ends = branch.ends.values() if branch.ends else [node for node in self.nodes] for end in ends: if end != END: channel_name = f"branch:{start}:{name}:{end}" self.channels[channel_name] = EphemeralValue(Any) self.nodes[end].triggers.append(channel_name) cast(list[str], self.nodes[end].channels).append(channel_name) async def aget_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: return self.get_graph(config, xray=xray) def get_graph( self, config: Optional[RunnableConfig] = None, *, xray: Union[int, bool] = False, ) -> DrawableGraph: """Returns a drawable representation of the computation graph.""" graph = DrawableGraph() start_nodes: dict[str, DrawableNode] = { START: graph.add_node(self.get_input_schema(config), START) } end_nodes: dict[str, DrawableNode] = {} if xray: subgraphs = { k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph) } else: subgraphs = {} def add_edge( start: str, end: str, label: Optional[Hashable] = None, conditional: bool = False, ) -> None: if end == END and END not in end_nodes: end_nodes[END] = graph.add_node(self.get_output_schema(config), END) return graph.add_edge( start_nodes[start], end_nodes[end], str(label) if label is not None else None, conditional, ) for key, n in self.builder.nodes.items(): node = n.runnable metadata = n.metadata or {} if key in self.interrupt_before_nodes and key in self.interrupt_after_nodes: metadata["__interrupt"] = "before,after" elif key in self.interrupt_before_nodes: 
metadata["__interrupt"] = "before" elif key in self.interrupt_after_nodes: metadata["__interrupt"] = "after" if xray and key in subgraphs: subgraph = subgraphs[key].get_graph( config=config, xray=xray - 1 if isinstance(xray, int) and not isinstance(xray, bool) and xray > 0 else xray, ) subgraph.trim_first_node() subgraph.trim_last_node() if len(subgraph.nodes) > 1: e, s = graph.extend(subgraph, prefix=key) if e is None: raise ValueError( f"Could not extend subgraph '{key}' due to missing entrypoint" ) if s is not None: start_nodes[key] = s end_nodes[key] = e else: nn = graph.add_node(node, key, metadata=metadata or None) start_nodes[key] = nn end_nodes[key] = nn else: nn = graph.add_node(node, key, metadata=metadata or None) start_nodes[key] = nn end_nodes[key] = nn for start, end in sorted(self.builder._all_edges): add_edge(start, end) for start, branches in self.builder.branches.items(): default_ends = { **{k: k for k in self.builder.nodes if k != start}, END: END, } for _, branch in branches.items(): if branch.ends is not None: ends = branch.ends elif branch.then is not None: ends = {k: k for k in default_ends if k not in (END, branch.then)} else: ends = cast(dict[Hashable, str], default_ends) for label, end in ends.items(): add_edge( start, end, label if label != end else None, conditional=True, ) if branch.then is not None: add_edge(end, branch.then) for key, n in self.builder.nodes.items(): if n.ends: for end in n.ends: add_edge(key, end, conditional=True) return graph ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/__init__.py`: ```py from langgraph.graph.graph import END, START, Graph from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.graph.state import StateGraph __all__ = [ "END", "START", "Graph", "StateGraph", "MessageGraph", "add_messages", "MessagesState", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/message.py`: ```py import uuid from typing 
import Annotated, TypedDict, Union, cast from langchain_core.messages import ( AnyMessage, BaseMessageChunk, MessageLikeRepresentation, RemoveMessage, convert_to_messages, message_chunk_to_message, ) from langgraph.graph.state import StateGraph Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation] def add_messages(left: Messages, right: Messages) -> Messages: """Merges two lists of messages, updating existing messages by ID. By default, this ensures the state is "append-only", unless the new message has the same ID as an existing message. Args: left: The base list of messages. right: The list of messages (or single message) to merge into the base list. Returns: A new list of messages with the messages from `right` merged into `left`. If a message in `right` has the same ID as a message in `left`, the message from `right` will replace the message from `left`. Examples: ```pycon >>> from langchain_core.messages import AIMessage, HumanMessage >>> msgs1 = [HumanMessage(content="Hello", id="1")] >>> msgs2 = [AIMessage(content="Hi there!", id="2")] >>> add_messages(msgs1, msgs2) [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')] >>> msgs1 = [HumanMessage(content="Hello", id="1")] >>> msgs2 = [HumanMessage(content="Hello again", id="1")] >>> add_messages(msgs1, msgs2) [HumanMessage(content='Hello again', id='1')] >>> from typing import Annotated >>> from typing_extensions import TypedDict >>> from langgraph.graph import StateGraph >>> >>> class State(TypedDict): ... messages: Annotated[list, add_messages] ... 
>>> builder = StateGraph(State) >>> builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]}) >>> builder.set_entry_point("chatbot") >>> builder.set_finish_point("chatbot") >>> graph = builder.compile() >>> graph.invoke({}) {'messages': [AIMessage(content='Hello', id=...)]} ``` """ # coerce to list if not isinstance(left, list): left = [left] # type: ignore[assignment] if not isinstance(right, list): right = [right] # type: ignore[assignment] # coerce to message left = [ message_chunk_to_message(cast(BaseMessageChunk, m)) for m in convert_to_messages(left) ] right = [ message_chunk_to_message(cast(BaseMessageChunk, m)) for m in convert_to_messages(right) ] # assign missing ids for m in left: if m.id is None: m.id = str(uuid.uuid4()) for m in right: if m.id is None: m.id = str(uuid.uuid4()) # merge left_idx_by_id = {m.id: i for i, m in enumerate(left)} merged = left.copy() ids_to_remove = set() for m in right: if (existing_idx := left_idx_by_id.get(m.id)) is not None: if isinstance(m, RemoveMessage): ids_to_remove.add(m.id) else: merged[existing_idx] = m else: if isinstance(m, RemoveMessage): raise ValueError( f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')" ) merged.append(m) merged = [m for m in merged if m.id not in ids_to_remove] return merged class MessageGraph(StateGraph): """A StateGraph where every node receives a list of messages as input and returns one or more messages as output. MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages. Each node in a MessageGraph takes a list of messages as input and returns zero or more messages as output. The `add_messages` function is used to merge the output messages from each node into the existing list of messages in the graph's state. Examples: ```pycon >>> from langgraph.graph.message import MessageGraph ... 
>>> builder = MessageGraph() >>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")]) >>> builder.set_entry_point("chatbot") >>> builder.set_finish_point("chatbot") >>> builder.compile().invoke([("user", "Hi there.")]) [HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')] ``` ```pycon >>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage >>> from langgraph.graph.message import MessageGraph ... >>> builder = MessageGraph() >>> builder.add_node( ... "chatbot", ... lambda state: [ ... AIMessage( ... content="Hello!", ... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}], ... ) ... ], ... ) >>> builder.add_node( ... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")] ... ) >>> builder.set_entry_point("chatbot") >>> builder.add_edge("chatbot", "search") >>> builder.set_finish_point("search") >>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")]) {'messages': [HumanMessage(content="Hi there. 
Can you search for X?", id='b8b7d8f4-7f4d-4f4d-9c1d-f8b8d8f4d9c1'), AIMessage(content="Hello!", id='f4d9c1d8-8d8f-4d9c-b8b7-d8f4f4d9c1d8'), ToolMessage(content="Searching...", id='d8f4f4d9-c1d8-4f4d-b8b7-d8f4f4d9c1d8', tool_call_id="123")]} ``` """ def __init__(self) -> None: super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type] class MessagesState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/state.py`: ```py import inspect import logging import typing import warnings from functools import partial from inspect import isclass, isfunction, ismethod, signature from types import FunctionType from typing import ( Any, Callable, Literal, NamedTuple, Optional, Sequence, Type, Union, cast, get_args, get_origin, get_type_hints, overload, ) from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.base import RunnableLike from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self from langgraph._api.deprecation import LangGraphDeprecationWarning from langgraph.channels.base import BaseChannel from langgraph.channels.binop import BinaryOperatorAggregate from langgraph.channels.dynamic_barrier_value import DynamicBarrierValue, WaitForNames from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.last_value import LastValue from langgraph.channels.named_barrier_value import NamedBarrierValue from langgraph.constants import EMPTY_SEQ, NS_END, NS_SEP, SELF, TAG_HIDDEN from langgraph.errors import ( ErrorCode, InvalidUpdateError, ParentCommand, create_error_message, ) from langgraph.graph.graph import END, START, Branch, CompiledGraph, Graph, Send from langgraph.managed.base import ( ChannelKeyPlaceholder, ChannelTypePlaceholder, ConfiguredManagedValue, ManagedValueSpec, is_managed_value, is_writable_managed_value, ) from 
langgraph.pregel.read import ChannelRead, PregelNode from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore from langgraph.types import All, Checkpointer, Command, RetryPolicy from langgraph.utils.fields import get_field_default from langgraph.utils.pydantic import create_model from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable logger = logging.getLogger(__name__) def _warn_invalid_state_schema(schema: Union[Type[Any], Any]) -> None: if isinstance(schema, type): return if typing.get_args(schema): return warnings.warn( f"Invalid state_schema: {schema}. Expected a type or Annotated[type, reducer]. " "Please provide a valid schema to ensure correct updates.\n" " See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph" ) def _get_node_name(node: RunnableLike) -> str: if isinstance(node, Runnable): return node.get_name() elif callable(node): return getattr(node, "__name__", node.__class__.__name__) else: raise TypeError(f"Unsupported node type: {type(node)}") class StateNodeSpec(NamedTuple): runnable: Runnable metadata: Optional[dict[str, Any]] input: Type[Any] retry_policy: Optional[RetryPolicy] ends: Optional[tuple[str, ...]] = EMPTY_SEQ class StateGraph(Graph): """A graph whose nodes communicate by reading and writing to a shared state. The signature of each node is State -> Partial<State>. Each state key can optionally be annotated with a reducer function that will be used to aggregate the values of that key received from multiple nodes. The signature of a reducer function is (Value, Value) -> Value. Args: state_schema (Type[Any]): The schema class that defines the state. config_schema (Optional[Type[Any]]): The schema class that defines the configuration. Use this to expose configurable parameters in your API. 
Examples: >>> from langchain_core.runnables import RunnableConfig >>> from typing_extensions import Annotated, TypedDict >>> from langgraph.checkpoint.memory import MemorySaver >>> from langgraph.graph import StateGraph >>> >>> def reducer(a: list, b: int | None) -> list: ... if b is not None: ... return a + [b] ... return a >>> >>> class State(TypedDict): ... x: Annotated[list, reducer] >>> >>> class ConfigSchema(TypedDict): ... r: float >>> >>> graph = StateGraph(State, config_schema=ConfigSchema) >>> >>> def node(state: State, config: RunnableConfig) -> dict: ... r = config["configurable"].get("r", 1.0) ... x = state["x"][-1] ... next_value = x * r * (1 - x) ... return {"x": next_value} >>> >>> graph.add_node("A", node) >>> graph.set_entry_point("A") >>> graph.set_finish_point("A") >>> compiled = graph.compile() >>> >>> print(compiled.config_specs) [ConfigurableFieldSpec(id='r', annotation=<class 'float'>, name=None, description=None, default=None, is_shared=False, dependencies=None)] >>> >>> step1 = compiled.invoke({"x": 0.5}, {"configurable": {"r": 3.0}}) >>> print(step1) {'x': [0.5, 0.75]}""" nodes: dict[str, StateNodeSpec] # type: ignore[assignment] channels: dict[str, BaseChannel] managed: dict[str, ManagedValueSpec] schemas: dict[Type[Any], dict[str, Union[BaseChannel, ManagedValueSpec]]] def __init__( self, state_schema: Optional[Type[Any]] = None, config_schema: Optional[Type[Any]] = None, *, input: Optional[Type[Any]] = None, output: Optional[Type[Any]] = None, ) -> None: super().__init__() if state_schema is None: if input is None or output is None: raise ValueError("Must provide state_schema or input and output") state_schema = input warnings.warn( "Initializing StateGraph without state_schema is deprecated. 
" "Please pass in an explicit state_schema instead of just an input and output schema.", LangGraphDeprecationWarning, stacklevel=2, ) else: if input is None: input = state_schema if output is None: output = state_schema self.schemas = {} self.channels = {} self.managed = {} self.schema = state_schema self.input = input self.output = output self._add_schema(state_schema) self._add_schema(input, allow_managed=False) self._add_schema(output, allow_managed=False) self.config_schema = config_schema self.waiting_edges: set[tuple[tuple[str, ...], str]] = set() @property def _all_edges(self) -> set[tuple[str, str]]: return self.edges | { (start, end) for starts, end in self.waiting_edges for start in starts } def _add_schema(self, schema: Type[Any], /, allow_managed: bool = True) -> None: if schema not in self.schemas: _warn_invalid_state_schema(schema) channels, managed = _get_channels(schema) if managed and not allow_managed: names = ", ".join(managed) schema_name = getattr(schema, "__name__", "") raise ValueError( f"Invalid managed channels detected in {schema_name}: {names}." " Managed channels are not permitted in Input/Output schema." ) self.schemas[schema] = {**channels, **managed} for key, channel in channels.items(): if key in self.channels: if self.channels[key] != channel: if isinstance(channel, LastValue): pass else: raise ValueError( f"Channel '{key}' already exists with a different type" ) else: self.channels[key] = channel for key, managed in managed.items(): if key in self.managed: if self.managed[key] != managed: raise ValueError( f"Managed value '{key}' already exists with a different type" ) else: self.managed[key] = managed @overload def add_node( self, node: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, ) -> Self: """Adds a new node to the state graph. Will take the name of the function/runnable as the node name. 
Args: node (RunnableLike): The function or runnable this node will run. Raises: ValueError: If the key is already being used as a state key. Returns: StateGraph """ ... @overload def add_node( self, node: str, action: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, ) -> Self: """Adds a new node to the state graph. Args: node (str): The key of the node. action (RunnableLike): The action associated with the node. Raises: ValueError: If the key is already being used as a state key. Returns: StateGraph """ ... def add_node( self, node: Union[str, RunnableLike], action: Optional[RunnableLike] = None, *, metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, ) -> Self: """Adds a new node to the state graph. Will take the name of the function/runnable as the node name. Args: node (Union[str, RunnableLike)]: The function or runnable this node will run. action (Optional[RunnableLike]): The action associated with the node. (default: None) metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None) input (Optional[Type[Any]]): The input schema for the node. (default: the graph's input schema) retry (Optional[RetryPolicy]): The policy for retrying the node. (default: None) Raises: ValueError: If the key is already being used as a state key. Examples: ```pycon >>> from langgraph.graph import START, StateGraph ... >>> def my_node(state, config): ... return {"x": state["x"] + 1} ... 
>>> builder = StateGraph(dict) >>> builder.add_node(my_node) # node name will be 'my_node' >>> builder.add_edge(START, "my_node") >>> graph = builder.compile() >>> graph.invoke({"x": 1}) {'x': 2} ``` Customize the name: ```pycon >>> builder = StateGraph(dict) >>> builder.add_node("my_fair_node", my_node) >>> builder.add_edge(START, "my_fair_node") >>> graph = builder.compile() >>> graph.invoke({"x": 1}) {'x': 2} ``` Returns: StateGraph """ if not isinstance(node, str): action = node if isinstance(action, Runnable): node = action.get_name() else: node = getattr(action, "__name__", action.__class__.__name__) if node is None: raise ValueError( "Node name must be provided if action is not a function" ) if node in self.channels: raise ValueError(f"'{node}' is already being used as a state key") if self.compiled: logger.warning( "Adding a node to a graph that has already been compiled. This will " "not be reflected in the compiled graph." ) if not isinstance(node, str): action = node node = cast(str, getattr(action, "name", getattr(action, "__name__", None))) if node is None: raise ValueError( "Node name must be provided if action is not a function" ) if action is None: raise RuntimeError if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: raise ValueError(f"Node `{node}` is reserved.") for character in (NS_SEP, NS_END): if character in cast(str, node): raise ValueError( f"'{character}' is a reserved character and is not allowed in the node names." 
) ends = EMPTY_SEQ try: if (isfunction(action) or ismethod(getattr(action, "__call__", None))) and ( hints := get_type_hints(getattr(action, "__call__")) or get_type_hints(action) ): if input is None: first_parameter_name = next( iter( inspect.signature( cast(FunctionType, action) ).parameters.keys() ) ) if input_hint := hints.get(first_parameter_name): if isinstance(input_hint, type) and get_type_hints(input_hint): input = input_hint if ( (rtn := hints.get("return")) and get_origin(rtn) is Command and (rargs := get_args(rtn)) and get_origin(rargs[0]) is Literal and (vals := get_args(rargs[0])) ): ends = vals except (TypeError, StopIteration): pass if input is not None: self._add_schema(input) self.nodes[cast(str, node)] = StateNodeSpec( coerce_to_runnable(action, name=cast(str, node), trace=False), metadata, input=input or self.schema, retry_policy=retry, ends=ends, ) return self def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self: """Adds a directed edge from the start node to the end node. If the graph transitions to the start_key node, it will always transition to the end_key node next. Args: start_key (Union[str, list[str]]): The key(s) of the start node(s) of the edge. end_key (str): The key of the end node of the edge. Raises: ValueError: If the start key is 'END' or if the start key or end key is not present in the graph. Returns: StateGraph """ if isinstance(start_key, str): return super().add_edge(start_key, end_key) if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." 
) for start in start_key: if start == END: raise ValueError("END cannot be a start node") if start not in self.nodes: raise ValueError(f"Need to add_node `{start}` first") if end_key == START: raise ValueError("START cannot be an end node") if end_key != END and end_key not in self.nodes: raise ValueError(f"Need to add_node `{end_key}` first") self.waiting_edges.add((tuple(start_key), end_key)) return self def add_sequence( self, nodes: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]], ) -> Self: """Add a sequence of nodes that will be executed in the provided order. Args: nodes: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples. If no names are provided, the name will be inferred from the node object (e.g. a runnable or a callable name). Each node will be executed in the order provided. Raises: ValueError: if the sequence is empty. ValueError: if the sequence contains duplicate node names. Returns: StateGraph """ if len(nodes) < 1: raise ValueError("Sequence requires at least one node.") previous_name: Optional[str] = None for node in nodes: if isinstance(node, tuple) and len(node) == 2: name, node = node else: name = _get_node_name(node) if name in self.nodes: raise ValueError( f"Node names must be unique: node with the name '{name}' already exists. " "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)." ) self.add_node(name, node) if previous_name is not None: self.add_edge(previous_name, name) previous_name = name return self def compile( self, checkpointer: Checkpointer = None, *, store: Optional[BaseStore] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, debug: bool = False, ) -> "CompiledStateGraph": """Compiles the state graph into a `CompiledGraph` object. 
The compiled graph implements the `Runnable` interface and can be invoked, streamed, batched, and run asynchronously. Args: checkpointer (Optional[Union[Checkpointer, Literal[False]]]): A checkpoint saver object or flag. If provided, this Checkpointer serves as a fully versioned "short-term memory" for the graph, allowing it to be paused, resumed, and replayed from any point. If None, it may inherit the parent graph's checkpointer when used as a subgraph. If False, it will not use or inherit any checkpointer. interrupt_before (Optional[Sequence[str]]): An optional list of node names to interrupt before. interrupt_after (Optional[Sequence[str]]): An optional list of node names to interrupt after. debug (bool): A flag indicating whether to enable debug mode. Returns: CompiledStateGraph: The compiled state graph. """ # assign default values interrupt_before = interrupt_before or [] interrupt_after = interrupt_after or [] # validate the graph self.validate( interrupt=( (interrupt_before if interrupt_before != "*" else []) + interrupt_after if interrupt_after != "*" else [] ) ) # prepare output channels output_channels = ( "__root__" if len(self.schemas[self.output]) == 1 and "__root__" in self.schemas[self.output] else [ key for key, val in self.schemas[self.output].items() if not is_managed_value(val) ] ) stream_channels = ( "__root__" if len(self.channels) == 1 and "__root__" in self.channels else [ key for key, val in self.channels.items() if not is_managed_value(val) ] ) compiled = CompiledStateGraph( builder=self, config_type=self.config_schema, nodes={}, channels={ **self.channels, **self.managed, START: EphemeralValue(self.input), }, input_channels=START, stream_mode="updates", output_channels=output_channels, stream_channels=stream_channels, checkpointer=checkpointer, interrupt_before_nodes=interrupt_before, interrupt_after_nodes=interrupt_after, auto_validate=False, debug=debug, store=store, ) compiled.attach_node(START, None) for key, node in 
self.nodes.items(): compiled.attach_node(key, node) for key, node in self.nodes.items(): compiled.attach_branch(key, SELF, CONTROL_BRANCH, with_reader=False) for start, end in self.edges: compiled.attach_edge(start, end) for starts, end in self.waiting_edges: compiled.attach_edge(starts, end) for start, branches in self.branches.items(): for name, branch in branches.items(): compiled.attach_branch(start, name, branch) return compiled.validate() class CompiledStateGraph(CompiledGraph): builder: StateGraph def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: return _get_schema( typ=self.builder.input, schemas=self.builder.schemas, channels=self.builder.channels, name=self.get_name("Input"), ) def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: return _get_schema( typ=self.builder.output, schemas=self.builder.schemas, channels=self.builder.channels, name=self.get_name("Output"), ) def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None: if key == START: output_keys = [ k for k, v in self.builder.schemas[self.builder.input].items() if not is_managed_value(v) ] else: output_keys = list(self.builder.channels) + [ k for k, v in self.builder.managed.items() if is_writable_managed_value(v) ] def _get_root(input: Any) -> Any: if isinstance(input, Command): if input.graph == Command.PARENT: return SKIP_WRITE return input.update else: return input # to avoid name collision below node_key = key def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any: if input is None: return SKIP_WRITE elif isinstance(input, dict): if all(k not in output_keys for k in input): raise InvalidUpdateError( f"Expected node {node_key} to update at least one of {output_keys}, got {input}" ) return input.get(key, SKIP_WRITE) elif isinstance(input, Command): if input.graph == Command.PARENT: return SKIP_WRITE return _get_state_key(input.update, key=key) elif get_type_hints(type(input)): value = 
getattr(input, key, SKIP_WRITE) return value if value is not None else SKIP_WRITE else: msg = create_error_message( message=f"Expected dict, got {input}", error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE, ) raise InvalidUpdateError(msg) # state updaters write_entries = ( [ChannelWriteEntry("__root__", skip_none=True, mapper=_get_root)] if output_keys == ["__root__"] else [ ChannelWriteEntry(key, mapper=partial(_get_state_key, key=key)) for key in output_keys ] ) # add node and output channel if key == START: self.nodes[key] = PregelNode( tags=[TAG_HIDDEN], triggers=[START], channels=[START], writers=[ ChannelWrite( write_entries, tags=[TAG_HIDDEN], require_at_least_one_of=output_keys, ), ], ) elif node is not None: input_schema = node.input if node else self.builder.schema input_values = {k: k for k in self.builder.schemas[input_schema]} is_single_input = len(input_values) == 1 and "__root__" in input_values self.channels[key] = EphemeralValue(Any, guard=False) self.nodes[key] = PregelNode( triggers=[], # read state keys and managed values channels=(list(input_values) if is_single_input else input_values), # coerce state dict to schema class (eg. 
pydantic model) mapper=( None if is_single_input or issubclass(input_schema, dict) else partial(_coerce_state, input_schema) ), writers=[ # publish to this channel and state keys ChannelWrite( [ChannelWriteEntry(key, key)] + write_entries, tags=[TAG_HIDDEN], ), ], metadata=node.metadata, retry_policy=node.retry_policy, bound=node.runnable, ) else: raise RuntimeError def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None: if isinstance(starts, str): if starts == START: channel_name = f"start:{end}" # register channel self.channels[channel_name] = EphemeralValue(Any) # subscribe to channel self.nodes[end].triggers.append(channel_name) # publish to channel self.nodes[START] |= ChannelWrite( [ChannelWriteEntry(channel_name, START)], tags=[TAG_HIDDEN] ) elif end != END: # subscribe to start channel self.nodes[end].triggers.append(starts) elif end != END: channel_name = f"join:{'+'.join(starts)}:{end}" # register channel self.channels[channel_name] = NamedBarrierValue(str, set(starts)) # subscribe to channel self.nodes[end].triggers.append(channel_name) # publish to channel for start in starts: self.nodes[start] |= ChannelWrite( [ChannelWriteEntry(channel_name, start)], tags=[TAG_HIDDEN] ) def attach_branch( self, start: str, name: str, branch: Branch, *, with_reader: bool = True ) -> None: def branch_writer( packets: Sequence[Union[str, Send]], config: RunnableConfig ) -> None: if filtered := [p for p in packets if p != END]: writes = [ ( ChannelWriteEntry(f"branch:{start}:{name}:{p}", start) if not isinstance(p, Send) else p ) for p in filtered ] if branch.then and branch.then != END: writes.append( ChannelWriteEntry( f"branch:{start}:{name}::then", WaitForNames( {p.node if isinstance(p, Send) else p for p in filtered} ), ) ) ChannelWrite.do_write( config, cast(Sequence[Union[Send, ChannelWriteEntry]], writes) ) # attach branch publisher schema = ( self.builder.nodes[start].input if start in self.builder.nodes else self.builder.schema ) 
self.nodes[start] |= branch.run( branch_writer, _get_state_reader(self.builder, schema) if with_reader else None, ) # attach branch subscribers ends = ( branch.ends.values() if branch.ends else [node for node in self.builder.nodes if node != branch.then] ) for end in ends: if end != END: channel_name = f"branch:{start}:{name}:{end}" self.channels[channel_name] = EphemeralValue(Any, guard=False) self.nodes[end].triggers.append(channel_name) # attach then subscriber if branch.then and branch.then != END: channel_name = f"branch:{start}:{name}::then" self.channels[channel_name] = DynamicBarrierValue(str) self.nodes[branch.then].triggers.append(channel_name) for end in ends: if end != END: self.nodes[end] |= ChannelWrite( [ChannelWriteEntry(channel_name, end)], tags=[TAG_HIDDEN] ) def _get_state_reader( builder: StateGraph, schema: Type[Any] ) -> Callable[[RunnableConfig], Any]: state_keys = list(builder.channels) select = list(builder.schemas[schema]) return partial( ChannelRead.do_read, select=select[0] if select == ["__root__"] else select, fresh=True, # coerce state dict to schema class (eg. 
pydantic model) mapper=( None if state_keys == ["__root__"] or issubclass(schema, dict) else partial(_coerce_state, schema) ), ) def _coerce_state(schema: Type[Any], input: dict[str, Any]) -> dict[str, Any]: return schema(**input) def _control_branch(value: Any) -> Sequence[Union[str, Send]]: if isinstance(value, Send): return [value] if not isinstance(value, Command): return EMPTY_SEQ if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] if isinstance(value.goto, Send): rtn.append(value.goto) elif isinstance(value.goto, str): rtn.append(value.goto) else: rtn.extend(value.goto) return rtn async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]: if isinstance(value, Send): return [value] if not isinstance(value, Command): return EMPTY_SEQ if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] if isinstance(value.goto, Send): rtn.append(value.goto) elif isinstance(value.goto, str): rtn.append(value.goto) else: rtn.extend(value.goto) return rtn CONTROL_BRANCH_PATH = RunnableCallable( _control_branch, _acontrol_branch, tags=[TAG_HIDDEN], trace=False, recurse=False ) CONTROL_BRANCH = Branch(CONTROL_BRANCH_PATH, None) def _get_channels( schema: Type[dict], ) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec]]: if not hasattr(schema, "__annotations__"): return {"__root__": _get_channel("__root__", schema, allow_managed=False)}, {} all_keys = { name: _get_channel(name, typ) for name, typ in get_type_hints(schema, include_extras=True).items() if name != "__slots__" } return ( {k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)}, {k: v for k, v in all_keys.items() if is_managed_value(v)}, ) @overload def _get_channel( name: str, annotation: Any, *, allow_managed: Literal[False] ) -> BaseChannel: ... @overload def _get_channel( name: str, annotation: Any, *, allow_managed: Literal[True] = True ) -> Union[BaseChannel, ManagedValueSpec]: ... 
def _get_channel( name: str, annotation: Any, *, allow_managed: bool = True ) -> Union[BaseChannel, ManagedValueSpec]: if manager := _is_field_managed_value(name, annotation): if allow_managed: return manager else: raise ValueError(f"This {annotation} not allowed in this position") elif channel := _is_field_channel(annotation): channel.key = name return channel elif channel := _is_field_binop(annotation): channel.key = name return channel fallback: LastValue = LastValue(annotation) fallback.key = name return fallback def _is_field_channel(typ: Type[Any]) -> Optional[BaseChannel]: if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1 and isinstance(meta[-1], BaseChannel): return meta[-1] elif len(meta) >= 1 and isclass(meta[-1]) and issubclass(meta[-1], BaseChannel): return meta[-1](typ.__origin__ if hasattr(typ, "__origin__") else typ) return None def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]: if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1 and callable(meta[-1]): sig = signature(meta[-1]) params = list(sig.parameters.values()) if len(params) == 2 and all( p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params ): return BinaryOperatorAggregate(typ, meta[-1]) else: raise ValueError( f"Invalid reducer signature. Expected (a, b) -> c. 
Got {sig}" ) return None def _is_field_managed_value(name: str, typ: Type[Any]) -> Optional[ManagedValueSpec]: if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1: decoration = get_origin(meta[-1]) or meta[-1] if is_managed_value(decoration): if isinstance(decoration, ConfiguredManagedValue): for k, v in decoration.kwargs.items(): if v is ChannelKeyPlaceholder: decoration.kwargs[k] = name if v is ChannelTypePlaceholder: decoration.kwargs[k] = typ.__origin__ return decoration return None def _get_schema( typ: Type, schemas: dict, channels: dict, name: str, ) -> type[BaseModel]: if isclass(typ) and issubclass(typ, (BaseModel, BaseModelV1)): return typ else: keys = list(schemas[typ].keys()) if len(keys) == 1 and keys[0] == "__root__": return create_model( name, root=(channels[keys[0]].UpdateType, None), ) else: return create_model( name, field_definitions={ k: ( channels[k].UpdateType, ( get_field_default( k, channels[k].UpdateType, typ, ) ), ) for k in schemas[typ] if k in channels and isinstance(channels[k], BaseChannel) }, ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/_api/deprecation.py`: ```py import functools import warnings from typing import Any, Callable, Type, TypeVar, Union, cast class LangGraphDeprecationWarning(DeprecationWarning): pass F = TypeVar("F", bound=Callable[..., Any]) C = TypeVar("C", bound=Type[Any]) def deprecated( since: str, alternative: str, *, removal: str = "", example: str = "" ) -> Callable[[F], F]: def decorator(obj: Union[F, C]) -> Union[F, C]: removal_str = removal if removal else "a future version" message = ( f"{obj.__name__} is deprecated as of version {since} and will be" f" removed in {removal_str}. 
Use {alternative} instead.{example}" ) if isinstance(obj, type): original_init = obj.__init__ # type: ignore[misc] @functools.wraps(original_init) def new_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def] warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2) original_init(self, *args, **kwargs) obj.__init__ = new_init # type: ignore[misc] docstring = ( f"**Deprecated**: This class is deprecated as of version {since}. " f"Use `{alternative}` instead." ) if obj.__doc__: docstring = docstring + f"\n\n{obj.__doc__}" obj.__doc__ = docstring return cast(C, obj) elif callable(obj): @functools.wraps(obj) def wrapper(*args: Any, **kwargs: Any) -> Any: warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2) return obj(*args, **kwargs) docstring = ( f"**Deprecated**: This function is deprecated as of version {since}. " f"Use `{alternative}` instead." ) if obj.__doc__: docstring = docstring + f"\n\n{obj.__doc__}" wrapper.__doc__ = docstring return cast(F, wrapper) else: raise TypeError( f"Can only add deprecation decorator to classes or callables, got '{type(obj)}' instead." ) return decorator def deprecated_parameter( arg_name: str, since: str, alternative: str, *, removal: str ) -> Callable[[F], F]: def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] if arg_name in kwargs: warnings.warn( f"Parameter '{arg_name}' in function '{func.__name__}' is " f"deprecated as of version {since} and will be removed in version {removal}. 
" f"Use '{alternative}' parameter instead.", category=LangGraphDeprecationWarning, stacklevel=2, ) return func(*args, **kwargs) return cast(F, wrapper) return decorator ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/queue.py`: ```py # type: ignore import asyncio import queue import sys import threading import types from collections import deque from time import monotonic from typing import Optional PY_310 = sys.version_info >= (3, 10) class AsyncQueue(asyncio.Queue): """Async unbounded FIFO queue with a wait() method. Subclassed from asyncio.Queue, adding a wait() method.""" async def wait(self) -> None: """If queue is empty, wait until an item is available. Copied from Queue.get(), removing the call to .get_nowait(), ie. this doesn't consume the item, just waits for it. """ while self.empty(): if PY_310: getter = self._get_loop().create_future() else: getter = self._loop.create_future() self._getters.append(getter) try: await getter except: getter.cancel() # Just in case getter is not done yet. try: # Clean self._getters from canceled getters. self._getters.remove(getter) except ValueError: # The getter could be removed from self._getters by a # previous put_nowait call. pass if not self.empty() and not getter.cancelled(): # We were woken up by put_nowait(), but can't take # the call. Wake up the next in line. 
self._wakeup_next(self._getters) raise class Semaphore(threading.Semaphore): """Semaphore subclass with a wait() method.""" def wait(self, blocking: bool = True, timeout: Optional[float] = None): """Block until the semaphore can be acquired, but don't acquire it.""" if not blocking and timeout is not None: raise ValueError("can't specify timeout for non-blocking acquire") rc = False endtime = None with self._cond: while self._value == 0: if not blocking: break if timeout is not None: if endtime is None: endtime = monotonic() + timeout else: timeout = endtime - monotonic() if timeout <= 0: break self._cond.wait(timeout) else: rc = True return rc class SyncQueue: """Unbounded FIFO queue with a wait() method. Adapted from pure Python implementation of queue.SimpleQueue. """ def __init__(self): self._queue = deque() self._count = Semaphore(0) def put(self, item, block=True, timeout=None): """Put the item on the queue. The optional 'block' and 'timeout' arguments are ignored, as this method never blocks. They are provided for compatibility with the Queue class. """ self._queue.append(item) self._count.release() def get(self, block=True, timeout=None): """Remove and return an item from the queue. If optional args 'block' is true and 'timeout' is None (the default), block if necessary until an item is available. If 'timeout' is a non-negative number, it blocks at most 'timeout' seconds and raises the Empty exception if no item was available within that time. Otherwise ('block' is false), return an item if one is immediately available, else raise the Empty exception ('timeout' is ignored in that case). """ if timeout is not None and timeout < 0: raise ValueError("'timeout' must be a non-negative number") if not self._count.acquire(block, timeout): raise queue.Empty try: return self._queue.popleft() except IndexError: raise queue.Empty def wait(self, block=True, timeout=None): """If queue is empty, wait until an item maybe is available, but don't consume it. 
""" if timeout is not None and timeout < 0: raise ValueError("'timeout' must be a non-negative number") self._count.wait(block, timeout) def empty(self): """Return True if the queue is empty, False otherwise (not reliable!).""" return len(self._queue) == 0 def qsize(self): """Return the approximate size of the queue (not reliable!).""" return len(self._queue) __class_getitem__ = classmethod(types.GenericAlias) __all__ = ["AsyncQueue", "SyncQueue"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/config.py`: ```py import asyncio import sys from collections import ChainMap from typing import Any, Optional, Sequence, cast from langchain_core.callbacks import ( AsyncCallbackManager, BaseCallbackManager, CallbackManager, Callbacks, ) from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ( CONFIG_KEYS, COPIABLE_KEYS, DEFAULT_RECURSION_LIMIT, var_child_runnable_config, ) from langgraph.checkpoint.base import CheckpointMetadata from langgraph.constants import ( CONF, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINT_NS, ) def patch_configurable( config: Optional[RunnableConfig], patch: dict[str, Any] ) -> RunnableConfig: if config is None: return {CONF: patch} elif CONF not in config: return {**config, CONF: patch} else: return {**config, CONF: {**config[CONF], **patch}} def patch_checkpoint_map( config: Optional[RunnableConfig], metadata: Optional[CheckpointMetadata] ) -> RunnableConfig: if config is None: return config elif parents := (metadata.get("parents") if metadata else None): conf = config[CONF] return patch_configurable( config, { CONFIG_KEY_CHECKPOINT_MAP: { **parents, conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID], }, }, ) else: return config def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: """Merge multiple configs into one. Args: *configs (Optional[RunnableConfig]): The configs to merge. 
Returns: RunnableConfig: The merged config. """ base: RunnableConfig = {} # Even though the keys aren't literals, this is correct # because both dicts are the same type for config in configs: if config is None: continue for key, value in config.items(): if not value: continue if key == "metadata": if base_value := base.get(key): base[key] = {**base_value, **value} # type: ignore else: base[key] = value # type: ignore[literal-required] elif key == "tags": if base_value := base.get(key): base[key] = [*base_value, *value] # type: ignore else: base[key] = value # type: ignore[literal-required] elif key == CONF: if base_value := base.get(key): base[key] = {**base_value, **value} # type: ignore[dict-item] else: base[key] = value elif key == "callbacks": base_callbacks = base.get("callbacks") # callbacks can be either None, list[handler] or manager # so merging two callbacks values has 6 cases if isinstance(value, list): if base_callbacks is None: base["callbacks"] = value.copy() elif isinstance(base_callbacks, list): base["callbacks"] = base_callbacks + value else: # base_callbacks is a manager mngr = base_callbacks.copy() for callback in value: mngr.add_handler(callback, inherit=True) base["callbacks"] = mngr elif isinstance(value, BaseCallbackManager): # value is a manager if base_callbacks is None: base["callbacks"] = value.copy() elif isinstance(base_callbacks, list): mngr = value.copy() for callback in base_callbacks: mngr.add_handler(callback, inherit=True) base["callbacks"] = mngr else: # base_callbacks is also a manager base["callbacks"] = base_callbacks.merge(value) else: raise NotImplementedError elif key == "recursion_limit": if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT: base["recursion_limit"] = config["recursion_limit"] else: base[key] = config[key] # type: ignore[literal-required] if CONF not in base: base[CONF] = {} return base def patch_config( config: Optional[RunnableConfig], *, callbacks: Optional[Callbacks] = None, recursion_limit: 
Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, configurable: Optional[dict[str, Any]] = None, ) -> RunnableConfig: """Patch a config with new values. Args: config (Optional[RunnableConfig]): The config to patch. callbacks (Optional[BaseCallbackManager], optional): The callbacks to set. Defaults to None. recursion_limit (Optional[int], optional): The recursion limit to set. Defaults to None. max_concurrency (Optional[int], optional): The max concurrency to set. Defaults to None. run_name (Optional[str], optional): The run name to set. Defaults to None. configurable (Optional[Dict[str, Any]], optional): The configurable to set. Defaults to None. Returns: RunnableConfig: The patched config. """ config = config.copy() if config is not None else {} if callbacks is not None: # If we're replacing callbacks, we need to unset run_name # As that should apply only to the same run as the original callbacks config["callbacks"] = callbacks if "run_name" in config: del config["run_name"] if "run_id" in config: del config["run_id"] if recursion_limit is not None: config["recursion_limit"] = recursion_limit if max_concurrency is not None: config["max_concurrency"] = max_concurrency if run_name is not None: config["run_name"] = run_name if configurable is not None: config[CONF] = {**config.get(CONF, {}), **configurable} return config def get_callback_manager_for_config( config: RunnableConfig, tags: Optional[Sequence[str]] = None ) -> CallbackManager: """Get a callback manager for a config. Args: config (RunnableConfig): The config. Returns: CallbackManager: The callback manager. 
""" from langchain_core.callbacks.manager import CallbackManager # merge tags all_tags = config.get("tags") if all_tags is not None and tags is not None: all_tags = [*all_tags, *tags] elif tags is not None: all_tags = list(tags) # use existing callbacks if they exist if (callbacks := config.get("callbacks")) and isinstance( callbacks, CallbackManager ): if all_tags: callbacks.add_tags(all_tags) if metadata := config.get("metadata"): callbacks.add_metadata(metadata) return callbacks else: # otherwise create a new manager return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=all_tags, inheritable_metadata=config.get("metadata"), ) def get_async_callback_manager_for_config( config: RunnableConfig, tags: Optional[Sequence[str]] = None, ) -> AsyncCallbackManager: """Get an async callback manager for a config. Args: config (RunnableConfig): The config. Returns: AsyncCallbackManager: The async callback manager. """ from langchain_core.callbacks.manager import AsyncCallbackManager # merge tags all_tags = config.get("tags") if all_tags is not None and tags is not None: all_tags = [*all_tags, *tags] elif tags is not None: all_tags = list(tags) # use existing callbacks if they exist if (callbacks := config.get("callbacks")) and isinstance( callbacks, AsyncCallbackManager ): if all_tags: callbacks.add_tags(all_tags) if metadata := config.get("metadata"): callbacks.add_metadata(metadata) return callbacks else: # otherwise create a new manager return AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), inheritable_metadata=config.get("metadata"), ) def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig: """Ensure that a config is a dict with all keys present. Args: config (Optional[RunnableConfig], optional): The config to ensure. Defaults to None. Returns: RunnableConfig: The ensured config. 
""" empty = RunnableConfig( tags=[], metadata=ChainMap(), callbacks=None, recursion_limit=DEFAULT_RECURSION_LIMIT, configurable={}, ) if var_config := var_child_runnable_config.get(): empty.update( { k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined] for k, v in var_config.items() if v is not None }, ) for config in configs: if config is None: continue for k, v in config.items(): if v is not None and k in CONFIG_KEYS: if k == CONF: empty[k] = cast(dict, v).copy() else: empty[k] = v # type: ignore[literal-required] for k, v in config.items(): if v is not None and k not in CONFIG_KEYS: empty[CONF][k] = v for key, value in empty[CONF].items(): if ( not key.startswith("__") and isinstance(value, (str, int, float, bool)) and key not in empty["metadata"] ): empty["metadata"][key] = value return empty def get_configurable() -> dict[str, Any]: if sys.version_info < (3, 11): try: if asyncio.current_task(): raise RuntimeError( "Python 3.11 or later required to use this in an async context" ) except RuntimeError: pass if var_config := var_child_runnable_config.get(): return var_config[CONF] else: raise RuntimeError("Called get_configurable outside of a runnable context") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/fields.py`: ```py import dataclasses from typing import Any, Optional, Type, Union from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin def _is_optional_type(type_: Any) -> bool: """Check if a type is Optional.""" if hasattr(type_, "__origin__") and hasattr(type_, "__args__"): origin = get_origin(type_) if origin is Optional: return True if origin is Union: return any( arg is type(None) or _is_optional_type(arg) for arg in type_.__args__ ) if origin is Annotated: return _is_optional_type(type_.__args__[0]) return origin is None if hasattr(type_, "__bound__") and type_.__bound__ is not None: return _is_optional_type(type_.__bound__) return type_ is None def _is_required_type(type_: 
Any) -> Optional[bool]: """Check if an annotation is marked as Required/NotRequired. Returns: - True if required - False if not required - None if not annotated with either """ origin = get_origin(type_) if origin is Required: return True if origin is NotRequired: return False if origin is Annotated or getattr(origin, "__args__", None): # See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated return _is_required_type(type_.__args__[0]) return None def _is_readonly_type(type_: Any) -> bool: """Check if an annotation is marked as ReadOnly. Returns: - True if is read only - False if not read only """ # See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier origin = get_origin(type_) if origin is Annotated: return _is_readonly_type(type_.__args__[0]) if origin is ReadOnly: return True return False _DEFAULT_KEYS: frozenset[str] = frozenset() def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any: """Determine the default value for a field in a state schema. This is based on: If TypedDict: - Required/NotRequired - total=False -> everything optional - Type annotation (Optional/Union[None]) """ optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS) irq = _is_required_type(type_) if name in optional_keys: # Either total=False or explicit NotRequired. # No type annotation trumps this. if irq: # Unless it's earlier versions of python & explicit Required return ... return None if irq is not None: if irq: # Handle Required[<type>] # (we already handled NotRequired and total=False) return ... # Handle NotRequired[<type>] for earlier versions of python return None if dataclasses.is_dataclass(schema): field_info = next( (f for f in dataclasses.fields(schema) if f.name == name), None ) if field_info: if ( field_info.default is not dataclasses.MISSING and field_info.default is not ... 
): return field_info.default elif field_info.default_factory is not dataclasses.MISSING: return field_info.default_factory() # Note, we ignore ReadOnly attributes, # as they don't make much sense. (we don't care if you mutate the state in your node) # and mutating state in your node has no effect on our graph state. # Base case is the annotation if _is_optional_type(type_): return None return ... ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/pydantic.py`: ```py from typing import Any, Dict, Optional, Union from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 def create_model( model_name: str, *, field_definitions: Optional[Dict[str, Any]] = None, root: Optional[Any] = None, ) -> Union[BaseModel, BaseModelV1]: """Create a pydantic model with the given field definitions. Args: model_name: The name of the model. field_definitions: The field definitions for the model. root: Type for a root model (RootModel) """ try: # for langchain-core >= 0.3.0 from langchain_core.utils.pydantic import create_model_v2 return create_model_v2( model_name, field_definitions=field_definitions, root=root, ) except ImportError: # for langchain-core < 0.3.0 from langchain_core.runnables.utils import create_model v1_kwargs = {} if root is not None: v1_kwargs["__root__"] = root return create_model(model_name, **v1_kwargs, **(field_definitions or {})) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/runnable.py`: ```py import asyncio import enum import inspect import sys from contextlib import AsyncExitStack from contextvars import copy_context from functools import partial, wraps from typing import ( Any, AsyncIterator, Awaitable, Callable, Coroutine, Iterator, Optional, Sequence, Union, cast, ) from langchain_core.runnables.base import ( Runnable, RunnableConfig, RunnableLambda, RunnableLike, RunnableParallel, RunnableSequence, ) from langchain_core.runnables.config import ( run_in_executor, 
var_child_runnable_config, ) from langchain_core.runnables.utils import Input from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, patch_config, ) try: from langchain_core.runnables.config import _set_config_context except ImportError: # For forwards compatibility def _set_config_context(context: RunnableConfig) -> None: # type: ignore """Set the context for the current thread.""" var_child_runnable_config.set(context) # Before Python 3.11 native StrEnum is not available class StrEnum(str, enum.Enum): """A string enum.""" ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11) KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = ( ( sys.intern("writer"), (StreamWriter, "StreamWriter", inspect.Parameter.empty), CONFIG_KEY_STREAM_WRITER, lambda _: None, ), ( sys.intern("store"), (BaseStore, "BaseStore", inspect.Parameter.empty), CONFIG_KEY_STORE, inspect.Parameter.empty, ), ) """List of kwargs that can be passed to functions, and their corresponding config keys, default values and type annotations.""" VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) class RunnableCallable(Runnable): """A much simpler version of RunnableLambda that requires sync and async functions.""" def __init__( self, func: Optional[Callable[..., Union[Any, Runnable]]], afunc: Optional[Callable[..., Awaitable[Union[Any, Runnable]]]] = None, *, name: Optional[str] = None, tags: Optional[Sequence[str]] = None, trace: bool = True, recurse: bool = True, **kwargs: Any, ) -> None: self.name = name if self.name is None: if func: try: if func.__name__ != "<lambda>": self.name = func.__name__ except 
AttributeError: pass elif afunc: try: self.name = afunc.__name__ except AttributeError: pass self.func = func self.afunc = afunc self.tags = tags self.kwargs = kwargs self.trace = trace self.recurse = recurse # check signature if func is None and afunc is None: raise ValueError("At least one of func or afunc must be provided.") params = inspect.signature(cast(Callable, func or afunc)).parameters self.func_accepts_config = "config" in params self.func_accepts: dict[str, bool] = {} for kw, typ, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) self.func_accepts[kw] = ( p is not None and p.annotation in typ and p.kind in VALID_KINDS ) def __repr__(self) -> str: repr_args = { k: v for k, v in self.__dict__.items() if k not in {"name", "func", "afunc", "config", "kwargs", "trace"} } return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})" def invoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: if self.func is None: raise TypeError( f'No synchronous function provided to "{self.name}".' "\nEither initialize with a synchronous function or invoke" " via the async API (ainvoke, astream, etc.)" ) if config is None: config = ensure_config() kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: raise ValueError( f"Missing required config key '{ck}' for '{self.name}'." 
) elif kwargs.get(kw) is None: kwargs[kw] = _conf.get(ck, defv) context = copy_context() if self.trace: callback_manager = get_callback_manager_for_config(config, self.tags) run_manager = callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) ret = context.run(self.func, input, **kwargs) except BaseException as e: run_manager.on_chain_error(e) raise else: run_manager.on_chain_end(ret) else: context.run(_set_config_context, config) ret = context.run(self.func, input, **kwargs) if isinstance(ret, Runnable) and self.recurse: return ret.invoke(input, config) return ret async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: if not self.afunc: return self.invoke(input, config) if config is None: config = ensure_config() kwargs = {**self.kwargs, **kwargs} if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: raise ValueError( f"Missing required config key '{ck}' for '{self.name}'." 
) elif kwargs.get(kw) is None: kwargs[kw] = _conf.get(ck, defv) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.name, run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) context.run(_set_config_context, child_config) coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs)) if ASYNCIO_ACCEPTS_CONTEXT: ret = await asyncio.create_task(coro, context=context) else: ret = await coro except BaseException as e: await run_manager.on_chain_error(e) raise else: await run_manager.on_chain_end(ret) else: context.run(_set_config_context, config) if ASYNCIO_ACCEPTS_CONTEXT: coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs)) ret = await asyncio.create_task(coro, context=context) else: ret = await self.afunc(input, **kwargs) if isinstance(ret, Runnable) and self.recurse: return await ret.ainvoke(input, config) return ret def is_async_callable( func: Any, ) -> TypeGuard[Callable[..., Awaitable]]: """Check if a function is async.""" return ( asyncio.iscoroutinefunction(func) or hasattr(func, "__call__") and asyncio.iscoroutinefunction(func.__call__) ) def is_async_generator( func: Any, ) -> TypeGuard[Callable[..., AsyncIterator]]: """Check if a function is an async generator.""" return ( inspect.isasyncgenfunction(func) or hasattr(func, "__call__") and inspect.isasyncgenfunction(func.__call__) ) def coerce_to_runnable( thing: RunnableLike, *, name: Optional[str], trace: bool ) -> Runnable: """Coerce a runnable-like object into a Runnable. Args: thing: A runnable-like object. Returns: A Runnable. 
""" if isinstance(thing, Runnable): return thing elif is_async_generator(thing) or inspect.isgeneratorfunction(thing): return RunnableLambda(thing, name=name) elif callable(thing): if is_async_callable(thing): return RunnableCallable(None, thing, name=name, trace=trace) else: return RunnableCallable( thing, wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type] name=name, trace=trace, ) elif isinstance(thing, dict): return RunnableParallel(thing) else: raise TypeError( f"Expected a Runnable, callable or dict." f"Instead got an unsupported type: {type(thing)}" ) class RunnableSeq(Runnable): """A simpler version of RunnableSequence.""" def __init__( self, *steps: RunnableLike, name: Optional[str] = None, ) -> None: """Create a new RunnableSequence. Args: steps: The steps to include in the sequence. name: The name of the Runnable. Defaults to None. first: The first Runnable in the sequence. Defaults to None. middle: The middle Runnables in the sequence. Defaults to None. last: The last Runnable in the sequence. Defaults to None. Raises: ValueError: If the sequence has less than 2 steps. 
""" steps_flat: list[Runnable] = [] for step in steps: if isinstance(step, RunnableSequence): steps_flat.extend(step.steps) elif isinstance(step, RunnableSeq): steps_flat.extend(step.steps) else: steps_flat.append(coerce_to_runnable(step, name=None, trace=True)) if len(steps_flat) < 2: raise ValueError( f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}" ) self.steps = steps_flat self.name = name def __or__( self, other: Any, ) -> Runnable: if isinstance(other, RunnableSequence): return RunnableSeq( *self.steps, other.first, *other.middle, other.last, name=self.name or other.name, ) elif isinstance(other, RunnableSeq): return RunnableSeq( *self.steps, *other.steps, name=self.name or other.name, ) else: return RunnableSeq( *self.steps, coerce_to_runnable(other, name=None, trace=True), name=self.name, ) def __ror__( self, other: Any, ) -> Runnable: if isinstance(other, RunnableSequence): return RunnableSequence( other.first, *other.middle, other.last, *self.steps, name=other.name or self.name, ) elif isinstance(other, RunnableSeq): return RunnableSeq( *other.steps, *self.steps, name=other.name or self.name, ) else: return RunnableSequence( coerce_to_runnable(other, name=None, trace=True), *self.steps, name=self.name, ) def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: if config is None: config = ensure_config() # setup callbacks and context callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) # invoke all steps in sequence try: for i, step in enumerate(self.steps): # mark each step as a child run config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) context = copy_context() context.run(_set_config_context, config) if i == 0: input = context.run(step.invoke, input, config, **kwargs) else: input = 
context.run(step.invoke, input, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) raise else: run_manager.on_chain_end(input) return input async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Any: if config is None: config = ensure_config() # setup callbacks callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) # invoke all steps in sequence try: for i, step in enumerate(self.steps): # mark each step as a child run config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) context = copy_context() context.run(_set_config_context, config) if i == 0: coro = step.ainvoke(input, config, **kwargs) else: coro = step.ainvoke(input, config) if ASYNCIO_ACCEPTS_CONTEXT: input = await asyncio.create_task(coro, context=context) else: input = await asyncio.create_task(coro) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) raise else: await run_manager.on_chain_end(input) return input def stream( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Any]: if config is None: config = ensure_config() # setup callbacks callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) try: # stream the last steps # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output for idx, step in enumerate(self.steps): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{idx+1}"), ) if idx == 0: 
iterator = step.stream(input, config, **kwargs) else: iterator = step.transform(iterator, config) if stream_handler := next( ( cast(_StreamingCallbackHandler, h) for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) ), None, ): # populates streamed_output in astream_log() output if needed iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator) output: Any = None add_supported = False for chunk in iterator: yield chunk # collect final output if output is None: output = chunk elif add_supported: try: output = output + chunk except TypeError: output = chunk add_supported = False else: output = chunk except BaseException as e: run_manager.on_chain_error(e) raise else: run_manager.on_chain_end(output) async def astream( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Any]: if config is None: config = ensure_config() # setup callbacks callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) try: async with AsyncExitStack() as stack: # stream the last steps # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output for idx, step in enumerate(self.steps): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{idx+1}"), ) if idx == 0: aiterator = step.astream(input, config, **kwargs) else: aiterator = step.atransform(aiterator, config) if hasattr(aiterator, "aclose"): stack.push_async_callback(aiterator.aclose) if stream_handler := next( ( cast(_StreamingCallbackHandler, h) for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) ), None, ): # populates streamed_output in astream_log() output if needed aiterator = 
stream_handler.tap_output_aiter( run_manager.run_id, aiterator ) output: Any = None add_supported = False async for chunk in aiterator: yield chunk # collect final output if add_supported: try: output = output + chunk except TypeError: output = chunk add_supported = False else: output = chunk except BaseException as e: await run_manager.on_chain_error(e) raise else: await run_manager.on_chain_end(output) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/types.py`: ```py import dataclasses import sys from collections import deque from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, Generic, Hashable, Literal, NamedTuple, Optional, Sequence, Type, TypedDict, TypeVar, Union, cast, ) from langchain_core.runnables import Runnable, RunnableConfig from typing_extensions import Self from langgraph.checkpoint.base import ( BaseCheckpointSaver, CheckpointMetadata, PendingWrite, ) if TYPE_CHECKING: from langgraph.store.base import BaseStore All = Literal["*"] """Special value to indicate that graph should interrupt on all nodes.""" Checkpointer = Union[None, Literal[False], BaseCheckpointSaver] """Type of the checkpointer to use for a subgraph. False disables checkpointing, even if the parent graph has a checkpointer. None inherits checkpointer.""" StreamMode = Literal["values", "updates", "debug", "messages", "custom"] """How the stream method should emit outputs. - 'values': Emit all values of the state for each step. - 'updates': Emit only the node name(s) and updates that were returned by the node(s) **after** each step. - 'debug': Emit debug events for each step. - 'messages': Emit LLM messages token-by-token. - 'custom': Emit custom output `write: StreamWriter` kwarg of each node. """ StreamWriter = Callable[[Any], None] """Callable that accepts a single argument and writes it to the output stream. 
Always injected into nodes if requested as a keyword argument, but it's a no-op when not using stream_mode="custom".""" if sys.version_info >= (3, 10): _DC_KWARGS = {"kw_only": True, "slots": True, "frozen": True} else: _DC_KWARGS = {"frozen": True} def default_retry_on(exc: Exception) -> bool: import httpx import requests if isinstance(exc, ConnectionError): return True if isinstance( exc, ( ValueError, TypeError, ArithmeticError, ImportError, LookupError, NameError, SyntaxError, RuntimeError, ReferenceError, StopIteration, StopAsyncIteration, OSError, ), ): return False if isinstance(exc, httpx.HTTPStatusError): return 500 <= exc.response.status_code < 600 if isinstance(exc, requests.HTTPError): return 500 <= exc.response.status_code < 600 if exc.response else True return True class RetryPolicy(NamedTuple): """Configuration for retrying nodes.""" initial_interval: float = 0.5 """Amount of time that must elapse before the first retry occurs. In seconds.""" backoff_factor: float = 2.0 """Multiplier by which the interval increases after each retry.""" max_interval: float = 128.0 """Maximum amount of time that may elapse between retries. In seconds.""" max_attempts: int = 3 """Maximum number of attempts to make before giving up, including the first.""" jitter: bool = True """Whether to add random jitter to the interval between retries.""" retry_on: Union[ Type[Exception], Sequence[Type[Exception]], Callable[[Exception], bool] ] = default_retry_on """List of exception classes that should trigger a retry, or a callable that returns True for exceptions that should trigger a retry.""" class CachePolicy(NamedTuple): """Configuration for caching nodes.""" pass @dataclasses.dataclass(**_DC_KWARGS) class Interrupt: value: Any resumable: bool = False ns: Optional[Sequence[str]] = None when: Literal["during"] = "during" class PregelTask(NamedTuple): id: str name: str path: tuple[Union[str, int, tuple], ...] error: Optional[Exception] = None interrupts: tuple[Interrupt, ...] 
= () state: Union[None, RunnableConfig, "StateSnapshot"] = None result: Optional[dict[str, Any]] = None class PregelExecutableTask(NamedTuple): name: str input: Any proc: Runnable writes: deque[tuple[str, Any]] config: RunnableConfig triggers: list[str] retry_policy: Optional[RetryPolicy] cache_policy: Optional[CachePolicy] id: str path: tuple[Union[str, int, tuple], ...] scheduled: bool = False writers: Sequence[Runnable] = () class StateSnapshot(NamedTuple): """Snapshot of the state of the graph at the beginning of a step.""" values: Union[dict[str, Any], Any] """Current values of channels""" next: tuple[str, ...] """The name of the node to execute in each task for this step.""" config: RunnableConfig """Config used to fetch this snapshot""" metadata: Optional[CheckpointMetadata] """Metadata associated with this snapshot""" created_at: Optional[str] """Timestamp of snapshot creation""" parent_config: Optional[RunnableConfig] """Config used to fetch the parent snapshot, if any""" tasks: tuple[PregelTask, ...] """Tasks to execute in this step. If already attempted, may contain an error.""" class Send: """A message or packet to send to a specific node in the graph. The `Send` class is used within a `StateGraph`'s conditional edges to dynamically invoke a node with a custom state at the next step. Importantly, the sent state can differ from the core graph's state, allowing for flexible and dynamic workflow management. One such example is a "map-reduce" workflow where your graph invokes the same node multiple times in parallel with different states, before aggregating the results back into the main graph's state. Attributes: node (str): The name of the target node to send the message to. arg (Any): The state or message to send to the target node. Examples: >>> from typing import Annotated >>> import operator >>> class OverallState(TypedDict): ... subjects: list[str] ... jokes: Annotated[list[str], operator.add] ... 
>>> from langgraph.types import Send >>> from langgraph.graph import END, START >>> def continue_to_jokes(state: OverallState): ... return [Send("generate_joke", {"subject": s}) for s in state['subjects']] ... >>> from langgraph.graph import StateGraph >>> builder = StateGraph(OverallState) >>> builder.add_node("generate_joke", lambda state: {"jokes": [f"Joke about {state['subject']}"]}) >>> builder.add_conditional_edges(START, continue_to_jokes) >>> builder.add_edge("generate_joke", END) >>> graph = builder.compile() >>> >>> # Invoking with two subjects results in a generated joke for each >>> graph.invoke({"subjects": ["cats", "dogs"]}) {'subjects': ['cats', 'dogs'], 'jokes': ['Joke about cats', 'Joke about dogs']} """ __slots__ = ("node", "arg") node: str arg: Any def __init__(self, /, node: str, arg: Any) -> None: """ Initialize a new instance of the Send class. Args: node (str): The name of the target node to send the message to. arg (Any): The state or message to send to the target node. """ self.node = node self.arg = arg def __hash__(self) -> int: return hash((self.node, self.arg)) def __repr__(self) -> str: return f"Send(node={self.node!r}, arg={self.arg!r})" def __eq__(self, value: object) -> bool: return ( isinstance(value, Send) and self.node == value.node and self.arg == value.arg ) N = TypeVar("N", bound=Hashable) @dataclasses.dataclass(**_DC_KWARGS) class Command(Generic[N]): """One or more commands to update the graph's state and send messages to nodes. Args: graph: graph to send the command to. Supported values are: - None: the current graph (default) - GraphCommand.PARENT: closest parent graph update: update to apply to the graph's state. resume: value to resume execution with. To be used together with [`interrupt()`][langgraph.types.interrupt]. 
goto: can be one of the following: - name of the node to navigate to next (any node that belongs to the specified `graph`) - sequence of node names to navigate to next - `Send` object (to execute a node with the input provided) - sequence of `Send` objects """ graph: Optional[str] = None update: Optional[dict[str, Any]] = None resume: Optional[Union[Any, dict[str, Any]]] = None goto: Union[Send, Sequence[Union[Send, str]], str] = () def __repr__(self) -> str: # get all non-None values contents = ", ".join( f"{key}={value!r}" for key, value in dataclasses.asdict(self).items() if value ) return f"Command({contents})" PARENT: ClassVar[Literal["__parent__"]] = "__parent__" StreamChunk = tuple[tuple[str, ...], str, Any] class StreamProtocol: __slots__ = ("modes", "__call__") modes: set[StreamMode] __call__: Callable[[Self, StreamChunk], None] def __init__( self, __call__: Callable[[StreamChunk], None], modes: set[StreamMode], ) -> None: self.__call__ = cast(Callable[[Self, StreamChunk], None], __call__) self.modes = modes class LoopProtocol: config: RunnableConfig store: Optional["BaseStore"] stream: Optional[StreamProtocol] step: int stop: int def __init__( self, *, step: int, stop: int, config: RunnableConfig, store: Optional["BaseStore"] = None, stream: Optional[StreamProtocol] = None, ) -> None: self.stream = stream self.config = config self.store = store self.step = step self.stop = stop class PregelScratchpad(TypedDict, total=False): interrupt_counter: int used_null_resume: bool resume: list[Any] def interrupt(value: Any) -> Any: from langgraph.constants import ( CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_SCRATCHPAD, CONFIG_KEY_SEND, CONFIG_KEY_TASK_ID, CONFIG_KEY_WRITES, NS_SEP, NULL_TASK_ID, RESUME, ) from langgraph.errors import GraphInterrupt from langgraph.utils.config import get_configurable conf = get_configurable() # track interrupt index scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD] if "interrupt_counter" not in scratchpad: 
scratchpad["interrupt_counter"] = 0 else: scratchpad["interrupt_counter"] += 1 idx = scratchpad["interrupt_counter"] # find previous resume values task_id = conf[CONFIG_KEY_TASK_ID] writes: list[PendingWrite] = conf[CONFIG_KEY_WRITES] scratchpad.setdefault( "resume", next((w[2] for w in writes if w[0] == task_id and w[1] == RESUME), []) ) if scratchpad["resume"]: if idx < len(scratchpad["resume"]): return scratchpad["resume"][idx] # find current resume value if not scratchpad.get("used_null_resume"): scratchpad["used_null_resume"] = True for tid, c, v in sorted(writes, key=lambda x: x[0], reverse=True): if tid == NULL_TASK_ID and c == RESUME: assert len(scratchpad["resume"]) == idx, (scratchpad["resume"], idx) scratchpad["resume"].append(v) conf[CONFIG_KEY_SEND]([(RESUME, scratchpad["resume"])]) return v # no resume value found raise GraphInterrupt( ( Interrupt( value=value, resumable=True, ns=cast(str, conf[CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP), ), ) ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/__init__.py`: ```py """langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools.""" from langgraph.prebuilt.chat_agent_executor import create_react_agent from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation from langgraph.prebuilt.tool_node import ( InjectedState, InjectedStore, ToolNode, tools_condition, ) from langgraph.prebuilt.tool_validator import ValidationNode __all__ = [ "create_react_agent", "ToolExecutor", "ToolInvocation", "ToolNode", "tools_condition", "ValidationNode", "InjectedState", "InjectedStore", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_validator.py`: ```py """This module provides a ValidationNode class that can be used to validate tool calls in a langchain graph. It applies a pydantic schema to tool_calls in the models' outputs, and returns a ToolMessage with the validated content. 
If the schema is not valid, it returns a ToolMessage with the error message. The ValidationNode can be used in a StateGraph with a "messages" key or in a MessageGraph. If multiple tool calls are requested, they will be run in parallel. """ from typing import ( Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union, cast, ) from langchain_core.messages import ( AIMessage, AnyMessage, ToolCall, ToolMessage, ) from langchain_core.runnables import ( RunnableConfig, ) from langchain_core.runnables.config import get_executor_for_config from langchain_core.tools import BaseTool, create_schema_from_function from pydantic import BaseModel, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from langgraph.utils.runnable import RunnableCallable def _default_format_error( error: BaseException, call: ToolCall, schema: Union[Type[BaseModel], Type[BaseModelV1]], ) -> str: """Default error formatting function.""" return f"{repr(error)}\n\nRespond after fixing all validation errors." class ValidationNode(RunnableCallable): """A node that validates all tools requests from the last AIMessage. It can be used either in StateGraph with a "messages" key or in MessageGraph. !!! note This node does not actually **run** the tools, it only validates the tool calls, which is useful for extraction and other use cases where you need to generate structured output that conforms to a complex schema without losing the original messages and tool IDs (for use in multi-turn conversations). Args: schemas: A list of schemas to validate the tool calls with. These can be any of the following: - A pydantic BaseModel class - A BaseTool instance (the args_schema will be used) - A function (a schema will be created from the function signature) format_error: A function that takes an exception, a ToolCall, and a schema and returns a formatted error string. 
By default, it returns the exception repr and a message to respond after fixing validation errors. name: The name of the node. tags: A list of tags to add to the node. Returns: (Union[Dict[str, List[ToolMessage]], Sequence[ToolMessage]]): A list of ToolMessages with the validated content or error messages. Examples: Example usage for re-prompting the model to generate a valid response: >>> from typing import Literal, Annotated, TypedDict ... >>> from langchain_anthropic import ChatAnthropic >>> from pydantic import BaseModel, validator ... >>> from langgraph.graph import END, START, StateGraph >>> from langgraph.prebuilt import ValidationNode >>> from langgraph.graph.message import add_messages ... ... >>> class SelectNumber(BaseModel): ... a: int ... ... @validator("a") ... def a_must_be_meaningful(cls, v): ... if v != 37: ... raise ValueError("Only 37 is allowed") ... return v ... ... >>> class State(TypedDict): ... messages: Annotated[list, add_messages] ... >>> builder = StateGraph(State) >>> llm = ChatAnthropic(model="claude-3-haiku-20240307").bind_tools([SelectNumber]) >>> builder.add_node("model", llm) >>> builder.add_node("validation", ValidationNode([SelectNumber])) >>> builder.add_edge(START, "model") ... ... >>> def should_validate(state: list) -> Literal["validation", "__end__"]: ... if state[-1].tool_calls: ... return "validation" ... return END ... ... >>> builder.add_conditional_edges("model", should_validate) ... ... >>> def should_reprompt(state: list) -> Literal["model", "__end__"]: ... for msg in state[::-1]: ... # None of the tool calls were errors ... if msg.type == "ai": ... return END ... if msg.additional_kwargs.get("is_error"): ... return "model" ... return END ... ... >>> builder.add_conditional_edges("validation", should_reprompt) ... ... >>> graph = builder.compile() >>> res = graph.invoke(("user", "Select a number, any number")) >>> # Show the retry logic >>> for msg in res: ... 
msg.pretty_print() ================================ Human Message ================================= Select a number, any number ================================== Ai Message ================================== [{'id': 'toolu_01JSjT9Pq8hGmTgmMPc6KnvM', 'input': {'a': 42}, 'name': 'SelectNumber', 'type': 'tool_use'}] Tool Calls: SelectNumber (toolu_01JSjT9Pq8hGmTgmMPc6KnvM) Call ID: toolu_01JSjT9Pq8hGmTgmMPc6KnvM Args: a: 42 ================================= Tool Message ================================= Name: SelectNumber ValidationError(model='SelectNumber', errors=[{'loc': ('a',), 'msg': 'Only 37 is allowed', 'type': 'value_error'}]) Respond after fixing all validation errors. ================================== Ai Message ================================== [{'id': 'toolu_01PkxSVxNxc5wqwCPW1FiSmV', 'input': {'a': 37}, 'name': 'SelectNumber', 'type': 'tool_use'}] Tool Calls: SelectNumber (toolu_01PkxSVxNxc5wqwCPW1FiSmV) Call ID: toolu_01PkxSVxNxc5wqwCPW1FiSmV Args: a: 37 ================================= Tool Message ================================= Name: SelectNumber {"a": 37} """ def __init__( self, schemas: Sequence[Union[BaseTool, Type[BaseModel], Callable]], *, format_error: Optional[ Callable[[BaseException, ToolCall, Type[BaseModel]], str] ] = None, name: str = "validation", tags: Optional[list[str]] = None, ) -> None: super().__init__(self._func, None, name=name, tags=tags, trace=False) self._format_error = format_error or _default_format_error self.schemas_by_name: Dict[str, Type[BaseModel]] = {} for schema in schemas: if isinstance(schema, BaseTool): if schema.args_schema is None: raise ValueError( f"Tool {schema.name} does not have an args_schema defined." 
) self.schemas_by_name[schema.name] = schema.args_schema elif isinstance(schema, type) and issubclass( schema, (BaseModel, BaseModelV1) ): self.schemas_by_name[schema.__name__] = cast(Type[BaseModel], schema) elif callable(schema): base_model = create_schema_from_function("Validation", schema) self.schemas_by_name[schema.__name__] = base_model else: raise ValueError( f"Unsupported input to ValidationNode. Expected BaseModel, tool or function. Got: {type(schema)}." ) def _get_message( self, input: Union[list[AnyMessage], dict[str, Any]] ) -> Tuple[str, AIMessage]: """Extract the last AIMessage from the input.""" if isinstance(input, list): output_type = "list" messages: list = input elif messages := input.get("messages", []): output_type = "dict" else: raise ValueError("No message found in input") message: AnyMessage = messages[-1] if not isinstance(message, AIMessage): raise ValueError("Last message is not an AIMessage") return output_type, message def _func( self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig ) -> Any: """Validate and run tool calls synchronously.""" output_type, message = self._get_message(input) def run_one(call: ToolCall) -> ToolMessage: schema = self.schemas_by_name[call["name"]] try: if issubclass(schema, BaseModel): output = schema.model_validate(call["args"]) content = output.model_dump_json() elif issubclass(schema, BaseModelV1): output = schema.validate(call["args"]) content = output.json() else: raise ValueError( f"Unsupported schema type: {type(schema)}. Expected BaseModel or BaseModelV1." 
) return ToolMessage( content=content, name=call["name"], tool_call_id=cast(str, call["id"]), ) except (ValidationError, ValidationErrorV1) as e: return ToolMessage( content=self._format_error(e, call, schema), name=call["name"], tool_call_id=cast(str, call["id"]), additional_kwargs={"is_error": True}, ) with get_executor_for_config(config) as executor: outputs = [*executor.map(run_one, message.tool_calls)] if output_type == "list": return outputs else: return {"messages": outputs} ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py`: ```py from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast from langchain_core.language_models import BaseChatModel, LanguageModelLike from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.runnables import ( Runnable, RunnableBinding, RunnableConfig, ) from langchain_core.tools import BaseTool from typing_extensions import Annotated, TypedDict from langgraph._api.deprecation import deprecated_parameter from langgraph.errors import ErrorCode, create_error_message from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import add_messages from langgraph.managed import IsLastStep, RemainingSteps from langgraph.prebuilt.tool_executor import ToolExecutor from langgraph.prebuilt.tool_node import ToolNode from langgraph.store.base import BaseStore from langgraph.types import Checkpointer from langgraph.utils.runnable import RunnableCallable # We create the AgentState that we will pass around # This simply involves a list of messages # We want steps to return messages to append to the list # So we annotate the messages attribute with operator.add class AgentState(TypedDict): """The state of the agent.""" messages: Annotated[Sequence[BaseMessage], add_messages] is_last_step: IsLastStep remaining_steps: RemainingSteps StateSchema = 
TypeVar("StateSchema", bound=AgentState) StateSchemaType = Type[StateSchema] STATE_MODIFIER_RUNNABLE_NAME = "StateModifier" MessagesModifier = Union[ SystemMessage, str, Callable[[Sequence[BaseMessage]], Sequence[BaseMessage]], Runnable[Sequence[BaseMessage], Sequence[BaseMessage]], ] StateModifier = Union[ SystemMessage, str, Callable[[StateSchema], Sequence[BaseMessage]], Runnable[StateSchema, Sequence[BaseMessage]], ] def _get_state_modifier_runnable( state_modifier: Optional[StateModifier], store: Optional[BaseStore] = None ) -> Runnable: state_modifier_runnable: Runnable if state_modifier is None: state_modifier_runnable = RunnableCallable( lambda state: state["messages"], name=STATE_MODIFIER_RUNNABLE_NAME ) elif isinstance(state_modifier, str): _system_message: BaseMessage = SystemMessage(content=state_modifier) state_modifier_runnable = RunnableCallable( lambda state: [_system_message] + state["messages"], name=STATE_MODIFIER_RUNNABLE_NAME, ) elif isinstance(state_modifier, SystemMessage): state_modifier_runnable = RunnableCallable( lambda state: [state_modifier] + state["messages"], name=STATE_MODIFIER_RUNNABLE_NAME, ) elif callable(state_modifier): state_modifier_runnable = RunnableCallable( state_modifier, name=STATE_MODIFIER_RUNNABLE_NAME, ) elif isinstance(state_modifier, Runnable): state_modifier_runnable = state_modifier else: raise ValueError( f"Got unexpected type for `state_modifier`: {type(state_modifier)}" ) return state_modifier_runnable def _convert_messages_modifier_to_state_modifier( messages_modifier: MessagesModifier, ) -> StateModifier: state_modifier: StateModifier if isinstance(messages_modifier, (str, SystemMessage)): return messages_modifier elif callable(messages_modifier): def state_modifier(state: AgentState) -> Sequence[BaseMessage]: return messages_modifier(state["messages"]) return state_modifier elif isinstance(messages_modifier, Runnable): state_modifier = (lambda state: state["messages"]) | messages_modifier return 
state_modifier raise ValueError( f"Got unexpected type for `messages_modifier`: {type(messages_modifier)}" ) def _get_model_preprocessing_runnable( state_modifier: Optional[StateModifier], messages_modifier: Optional[MessagesModifier], store: Optional[BaseStore], ) -> Runnable: # Add the state or message modifier, if exists if state_modifier is not None and messages_modifier is not None: raise ValueError( "Expected value for either state_modifier or messages_modifier, got values for both" ) if state_modifier is None and messages_modifier is not None: state_modifier = _convert_messages_modifier_to_state_modifier(messages_modifier) return _get_state_modifier_runnable(state_modifier, store) def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool: if not isinstance(model, RunnableBinding): return True if "tools" not in model.kwargs: return True bound_tools = model.kwargs["tools"] if len(tools) != len(bound_tools): raise ValueError( "Number of tools in the model.bind_tools() and tools passed to create_react_agent must match" ) tool_names = set(tool.name for tool in tools) bound_tool_names = set() for bound_tool in bound_tools: # OpenAI-style tool if bound_tool.get("type") == "function": bound_tool_name = bound_tool["function"]["name"] # Anthropic-style tool elif bound_tool.get("name"): bound_tool_name = bound_tool["name"] else: # unknown tool type so we'll ignore it continue bound_tool_names.add(bound_tool_name) if missing_tools := tool_names - bound_tool_names: raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()") return False def _validate_chat_history( messages: Sequence[BaseMessage], ) -> None: """Validate that all tool calls in AIMessages have a corresponding ToolMessage.""" all_tool_calls = [ tool_call for message in messages if isinstance(message, AIMessage) for tool_call in message.tool_calls ] tool_call_ids_with_results = { message.tool_call_id for message in messages if isinstance(message, ToolMessage) } 
tool_calls_without_results = [ tool_call for tool_call in all_tool_calls if tool_call["id"] not in tool_call_ids_with_results ] if not tool_calls_without_results: return error_message = create_error_message( message="Found AIMessages with tool_calls that do not have a corresponding ToolMessage. " f"Here are the first few of those tool calls: {tool_calls_without_results[:3]}.\n\n" "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage " "(result of a tool invocation to return to the LLM) - this is required by most LLM providers.", error_code=ErrorCode.INVALID_CHAT_HISTORY, ) raise ValueError(error_message) @deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0") def create_react_agent( model: LanguageModelLike, tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode], *, state_schema: Optional[StateSchemaType] = None, messages_modifier: Optional[MessagesModifier] = None, state_modifier: Optional[StateModifier] = None, checkpointer: Optional[Checkpointer] = None, store: Optional[BaseStore] = None, interrupt_before: Optional[list[str]] = None, interrupt_after: Optional[list[str]] = None, debug: bool = False, ) -> CompiledGraph: """Creates a graph that works with a chat model that utilizes tool calling. Args: model: The `LangChain` chat model that supports tool calling. tools: A list of tools, a ToolExecutor, or a ToolNode instance. If an empty list is provided, the agent will consist of a single LLM node without tool calling. state_schema: An optional state schema that defines graph state. Must have `messages` and `is_last_step` keys. Defaults to `AgentState` that defines those two keys. messages_modifier: An optional messages modifier. This applies to messages BEFORE they are passed into the LLM. Can take a few different forms: - SystemMessage: this is added to the beginning of the list of messages. 
- str: This is converted to a SystemMessage and added to the beginning of the list of messages. - Callable: This function should take in a list of messages and the output is then passed to the language model. - Runnable: This runnable should take in a list of messages and the output is then passed to the language model. !!! Warning `messages_modifier` parameter is deprecated as of version 0.1.9 and will be removed in 0.2.0 state_modifier: An optional state modifier. This takes full graph state BEFORE the LLM is called and prepares the input to LLM. Can take a few different forms: - SystemMessage: this is added to the beginning of the list of messages in state["messages"]. - str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"]. - Callable: This function should take in full graph state and the output is then passed to the language model. - Runnable: This runnable should take in full graph state and the output is then passed to the language model. checkpointer: An optional checkpoint saver object. This is used for persisting the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation). store: An optional store object. This is used for persisting data across multiple threads (e.g., multiple conversations / users). interrupt_before: An optional list of node names to interrupt before. Should be one of the following: "agent", "tools". This is useful if you want to add a user confirmation or other interrupt before taking an action. interrupt_after: An optional list of node names to interrupt after. Should be one of the following: "agent", "tools". This is useful if you want to return directly or run additional processing on an output. debug: A flag indicating whether to enable debug mode. Returns: A compiled LangChain runnable that can be used for chat interactions. 
The resulting graph looks like this: ``` mermaid stateDiagram-v2 [*] --> Start Start --> Agent Agent --> Tools : continue Tools --> Agent Agent --> End : end End --> [*] classDef startClass fill:#ffdfba; classDef endClass fill:#baffc9; classDef otherClass fill:#fad7de; class Start startClass class End endClass class Agent,Tools otherClass ``` The "agent" node calls the language model with the messages list (after applying the messages modifier). If the resulting AIMessage contains `tool_calls`, the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode]. The "tools" node executes the tools (1 tool per `tool_call`) and adds the responses to the messages list as `ToolMessage` objects. The agent node then calls the language model again. The process repeats until no more `tool_calls` are present in the response. The agent then returns the full list of messages as a dictionary containing the key "messages". ``` mermaid sequenceDiagram participant U as User participant A as Agent (LLM) participant T as Tools U->>A: Initial input Note over A: Messages modifier + LLM loop while tool_calls present A->>T: Execute tools T-->>A: ToolMessage for each tool_calls end A->>U: Return final state ``` Examples: Use with a simple tool: ```pycon >>> from datetime import datetime >>> from langchain_openai import ChatOpenAI >>> from langgraph.prebuilt import create_react_agent ... def check_weather(location: str, at_time: datetime | None = None) -> str: ... '''Return the weather forecast for the specified location.''' ... return f"It's always sunny in {location}" >>> >>> tools = [check_weather] >>> model = ChatOpenAI(model="gpt-4o") >>> graph = create_react_agent(model, tools=tools) >>> inputs = {"messages": [("user", "what is the weather in sf")]} >>> for s in graph.stream(inputs, stream_mode="values"): ... message = s["messages"][-1] ... if isinstance(message, tuple): ... print(message) ... else: ... 
message.pretty_print() ('user', 'what is the weather in sf') ================================== Ai Message ================================== Tool Calls: check_weather (call_LUzFvKJRuaWQPeXvBOzwhQOu) Call ID: call_LUzFvKJRuaWQPeXvBOzwhQOu Args: location: San Francisco ================================= Tool Message ================================= Name: check_weather It's always sunny in San Francisco ================================== Ai Message ================================== The weather in San Francisco is sunny. ``` Add a system prompt for the LLM: ```pycon >>> system_prompt = "You are a helpful bot named Fred." >>> graph = create_react_agent(model, tools, state_modifier=system_prompt) >>> inputs = {"messages": [("user", "What's your name? And what's the weather in SF?")]} >>> for s in graph.stream(inputs, stream_mode="values"): ... message = s["messages"][-1] ... if isinstance(message, tuple): ... print(message) ... else: ... message.pretty_print() ('user', "What's your name? And what's the weather in SF?") ================================== Ai Message ================================== Hi, my name is Fred. Let me check the weather in San Francisco for you. Tool Calls: check_weather (call_lqhj4O0hXYkW9eknB4S41EXk) Call ID: call_lqhj4O0hXYkW9eknB4S41EXk Args: location: San Francisco ================================= Tool Message ================================= Name: check_weather It's always sunny in San Francisco ================================== Ai Message ================================== The weather in San Francisco is currently sunny. If you need any more details or have other questions, feel free to ask! ``` Add a more complex prompt for the LLM: ```pycon >>> from langchain_core.prompts import ChatPromptTemplate >>> prompt = ChatPromptTemplate.from_messages([ ... ("system", "You are a helpful bot named Fred."), ... ("placeholder", "{messages}"), ... ("user", "Remember, always be polite!"), ... ]) >>> def format_for_model(state: AgentState): ... 
# You can do more complex modifications here ... return prompt.invoke({"messages": state["messages"]}) >>> >>> graph = create_react_agent(model, tools, state_modifier=format_for_model) >>> inputs = {"messages": [("user", "What's your name? And what's the weather in SF?")]} >>> for s in graph.stream(inputs, stream_mode="values"): ... message = s["messages"][-1] ... if isinstance(message, tuple): ... print(message) ... else: ... message.pretty_print() ``` Add complex prompt with custom graph state: ```pycon >>> from typing import TypedDict >>> prompt = ChatPromptTemplate.from_messages( ... [ ... ("system", "Today is {today}"), ... ("placeholder", "{messages}"), ... ] ... ) >>> >>> class CustomState(TypedDict): ... today: str ... messages: Annotated[list[BaseMessage], add_messages] ... is_last_step: str >>> >>> graph = create_react_agent(model, tools, state_schema=CustomState, state_modifier=prompt) >>> inputs = {"messages": [("user", "What's today's date? And what's the weather in SF?")], "today": "July 16, 2004"} >>> for s in graph.stream(inputs, stream_mode="values"): ... message = s["messages"][-1] ... if isinstance(message, tuple): ... print(message) ... else: ... message.pretty_print() ``` Add thread-level "chat memory" to the graph: ```pycon >>> from langgraph.checkpoint.memory import MemorySaver >>> graph = create_react_agent(model, tools, checkpointer=MemorySaver()) >>> config = {"configurable": {"thread_id": "thread-1"}} >>> def print_stream(graph, inputs, config): ... for s in graph.stream(inputs, config, stream_mode="values"): ... message = s["messages"][-1] ... if isinstance(message, tuple): ... print(message) ... else: ... 
message.pretty_print() >>> inputs = {"messages": [("user", "What's the weather in SF?")]} >>> print_stream(graph, inputs, config) >>> inputs2 = {"messages": [("user", "Cool, so then should i go biking today?")]} >>> print_stream(graph, inputs2, config) ('user', "What's the weather in SF?") ================================== Ai Message ================================== Tool Calls: check_weather (call_ChndaktJxpr6EMPEB5JfOFYc) Call ID: call_ChndaktJxpr6EMPEB5JfOFYc Args: location: San Francisco ================================= Tool Message ================================= Name: check_weather It's always sunny in San Francisco ================================== Ai Message ================================== The weather in San Francisco is sunny. Enjoy your day! ================================ Human Message ================================= Cool, so then should i go biking today? ================================== Ai Message ================================== Since the weather in San Francisco is sunny, it sounds like a great day for biking! Enjoy your ride! ``` Add an interrupt to let the user confirm before taking an action: ```pycon >>> graph = create_react_agent( ... model, tools, interrupt_before=["tools"], checkpointer=MemorySaver() >>> ) >>> config = {"configurable": {"thread_id": "thread-1"}} >>> inputs = {"messages": [("user", "What's the weather in SF?")]} >>> print_stream(graph, inputs, config) >>> snapshot = graph.get_state(config) >>> print("Next step: ", snapshot.next) >>> print_stream(graph, None, config) ``` Add cross-thread memory to the graph: ```pycon >>> from langgraph.prebuilt import InjectedStore >>> from langgraph.store.base import BaseStore >>> def save_memory(memory: str, *, config: RunnableConfig, store: Annotated[BaseStore, InjectedStore()]) -> str: ... '''Save the given memory for the current user.''' ... # This is a **tool** the model can use to save memories to storage ... user_id = config.get("configurable", {}).get("user_id") ... 
namespace = ("memories", user_id) ... store.put(namespace, f"memory_{len(store.search(namespace))}", {"data": memory}) ... return f"Saved memory: {memory}" >>> def prepare_model_inputs(state: AgentState, config: RunnableConfig, store: BaseStore): ... # Retrieve user memories and add them to the system message ... # This function is called **every time** the model is prompted. It converts the state to a prompt ... user_id = config.get("configurable", {}).get("user_id") ... namespace = ("memories", user_id) ... memories = [m.value["data"] for m in store.search(namespace)] ... system_msg = f"User memories: {', '.join(memories)}" ... return [{"role": "system", "content": system_msg)] + state["messages"] >>> from langgraph.checkpoint.memory import MemorySaver >>> from langgraph.store.memory import InMemoryStore >>> store = InMemoryStore() >>> graph = create_react_agent(model, [save_memory], state_modifier=prepare_model_inputs, store=store, checkpointer=MemorySaver()) >>> config = {"configurable": {"thread_id": "thread-1", "user_id": "1"}} >>> inputs = {"messages": [("user", "Hey I'm Will, how's it going?")]} >>> print_stream(graph, inputs, config) ('user', "Hey I'm Will, how's it going?") ================================== Ai Message ================================== Hello Will! It's nice to meet you. I'm doing well, thank you for asking. How are you doing today? >>> inputs2 = {"messages": [("user", "I like to bike")]} >>> print_stream(graph, inputs2, config) ================================ Human Message ================================= I like to bike ================================== Ai Message ================================== That's great to hear, Will! Biking is an excellent hobby and form of exercise. It's a fun way to stay active and explore your surroundings. Do you have any favorite biking routes or trails you enjoy? Or perhaps you're into a specific type of biking, like mountain biking or road cycling? 
>>> config = {"configurable": {"thread_id": "thread-2", "user_id": "1"}} >>> inputs3 = {"messages": [("user", "Hi there! Remember me?")]} >>> print_stream(graph, inputs3, config) ================================ Human Message ================================= Hi there! Remember me? ================================== Ai Message ================================== User memories: Hello! Of course, I remember you, Will! You mentioned earlier that you like to bike. It's great to hear from you again. How have you been? Have you been on any interesting bike rides lately? ``` Add a timeout for a given step: ```pycon >>> import time ... def check_weather(location: str, at_time: datetime | None = None) -> float: ... '''Return the weather forecast for the specified location.''' ... time.sleep(2) ... return f"It's always sunny in {location}" >>> >>> tools = [check_weather] >>> graph = create_react_agent(model, tools) >>> graph.step_timeout = 1 # Seconds >>> for s in graph.stream({"messages": [("user", "what is the weather in sf")]}): ... 
print(s) TimeoutError: Timed out at step 2 ``` """ if state_schema is not None: if missing_keys := {"messages", "is_last_step"} - set( state_schema.__annotations__ ): raise ValueError(f"Missing required key(s) {missing_keys} in state_schema") if isinstance(tools, ToolExecutor): tool_classes: Sequence[BaseTool] = tools.tools tool_node = ToolNode(tool_classes) elif isinstance(tools, ToolNode): tool_classes = list(tools.tools_by_name.values()) tool_node = tools else: tool_node = ToolNode(tools) # get the tool functions wrapped in a tool class from the ToolNode tool_classes = list(tool_node.tools_by_name.values()) tool_calling_enabled = len(tool_classes) > 0 if _should_bind_tools(model, tool_classes) and tool_calling_enabled: model = cast(BaseChatModel, model).bind_tools(tool_classes) # we're passing store here for validation preprocessor = _get_model_preprocessing_runnable( state_modifier, messages_modifier, store ) model_runnable = preprocessor | model # Define the function that calls the model def call_model(state: AgentState, config: RunnableConfig) -> AgentState: _validate_chat_history(state["messages"]) response = model_runnable.invoke(state, config) has_tool_calls = isinstance(response, AIMessage) and response.tool_calls all_tools_return_direct = ( all(call["name"] in should_return_direct for call in response.tool_calls) if isinstance(response, AIMessage) else False ) if ( ( "remaining_steps" not in state and state["is_last_step"] and has_tool_calls ) or ( "remaining_steps" in state and state["remaining_steps"] < 1 and all_tools_return_direct ) or ( "remaining_steps" in state and state["remaining_steps"] < 2 and has_tool_calls ) ): return { "messages": [ AIMessage( id=response.id, content="Sorry, need more steps to process this request.", ) ] } # We return a list, because this will get added to the existing list return {"messages": [response]} async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: 
_validate_chat_history(state["messages"]) response = await model_runnable.ainvoke(state, config) has_tool_calls = isinstance(response, AIMessage) and response.tool_calls all_tools_return_direct = ( all(call["name"] in should_return_direct for call in response.tool_calls) if isinstance(response, AIMessage) else False ) if ( ( "remaining_steps" not in state and state["is_last_step"] and has_tool_calls ) or ( "remaining_steps" in state and state["remaining_steps"] < 1 and all_tools_return_direct ) or ( "remaining_steps" in state and state["remaining_steps"] < 2 and has_tool_calls ) ): return { "messages": [ AIMessage( id=response.id, content="Sorry, need more steps to process this request.", ) ] } # We return a list, because this will get added to the existing list return {"messages": [response]} if not tool_calling_enabled: # Define a new graph workflow = StateGraph(state_schema or AgentState) workflow.add_node("agent", RunnableCallable(call_model, acall_model)) workflow.set_entry_point("agent") return workflow.compile( checkpointer=checkpointer, store=store, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, ) # Define the function that determines whether to continue or not def should_continue(state: AgentState) -> Literal["tools", "__end__"]: messages = state["messages"] last_message = messages[-1] # If there is no function call, then we finish if not isinstance(last_message, AIMessage) or not last_message.tool_calls: return "__end__" # Otherwise if there is, we continue else: return "tools" # Define a new graph workflow = StateGraph(state_schema or AgentState) # Define the two nodes we will cycle between workflow.add_node("agent", RunnableCallable(call_model, acall_model)) workflow.add_node("tools", tool_node) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges( # First, we define the start node. 
We use `agent`. # This means these are the edges taken after the `agent` node is called. "agent", # Next, we pass in the function that will determine which node is called next. should_continue, ) # If any of the tools are configured to return_directly after running, # our graph needs to check if these were called should_return_direct = {t.name for t in tool_classes if t.return_direct} def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]: for m in reversed(state["messages"]): if not isinstance(m, ToolMessage): break if m.name in should_return_direct: return "__end__" return "agent" if should_return_direct: workflow.add_conditional_edges("tools", route_tool_responses) else: workflow.add_edge("tools", "agent") # Finally, we compile it! # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable return workflow.compile( checkpointer=checkpointer, store=store, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, ) # Keep for backwards compatibility create_tool_calling_executor = create_react_agent __all__ = [ "create_react_agent", "create_tool_calling_executor", "AgentState", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_node.py`: ```py from __future__ import annotations import asyncio import inspect import json from copy import copy from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union, cast, get_type_hints, ) from langchain_core.messages import ( AIMessage, AnyMessage, ToolCall, ToolMessage, ) from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ( get_config_list, get_executor_for_config, ) from langchain_core.runnables.utils import Input from langchain_core.tools import BaseTool, InjectedToolArg from langchain_core.tools import tool as create_tool from langchain_core.tools.base import get_all_basemodel_annotations from typing_extensions 
import Annotated, get_args, get_origin from langgraph.errors import GraphBubbleUp from langgraph.store.base import BaseStore from langgraph.utils.runnable import RunnableCallable if TYPE_CHECKING: from pydantic import BaseModel INVALID_TOOL_NAME_ERROR_TEMPLATE = ( "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]." ) TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." def msg_content_output(output: Any) -> str | List[dict]: recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output elif all( [ isinstance(x, dict) and x.get("type") in recognized_content_block_types for x in output ] ): return output # Technically a list of strings is also valid message content but it's not currently # well tested that all chat models support this. And for backwards compatibility # we want to make sure we don't break any existing ToolNode usage. else: try: return json.dumps(output, ensure_ascii=False) except Exception: return str(output) def _handle_tool_error( e: Exception, *, flag: Union[ bool, str, Callable[..., str], tuple[type[Exception], ...], ], ) -> str: if isinstance(flag, (bool, tuple)): content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) elif isinstance(flag, str): content = flag elif callable(flag): content = flag(e) else: raise ValueError( f"Got unexpected type of `handle_tool_error`. Expected bool, str " f"or callable. 
Received: {flag}" ) return content def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]: sig = inspect.signature(handler) params = list(sig.parameters.values()) if params: # If it's a method, the first argument is typically 'self' or 'cls' if params[0].name in ["self", "cls"] and len(params) == 2: first_param = params[1] else: first_param = params[0] type_hints = get_type_hints(handler) if first_param.name in type_hints: origin = get_origin(first_param.annotation) if origin is Union: args = get_args(first_param.annotation) if all(issubclass(arg, Exception) for arg in args): return tuple(args) else: raise ValueError( "All types in the error handler error annotation must be Exception types. " "For example, `def custom_handler(e: Union[ValueError, TypeError])`. " f"Got '{first_param.annotation}' instead." ) exception_type = type_hints[first_param.name] if Exception in exception_type.__mro__: return (exception_type,) else: raise ValueError( f"Arbitrary types are not supported in the error handler signature. " "Please annotate the error with either a specific Exception type or a union of Exception types. " "For example, `def custom_handler(e: ValueError)` or `def custom_handler(e: Union[ValueError, TypeError])`. " f"Got '{exception_type}' instead." ) # If no type information is available, return (Exception,) for backwards compatibility. return (Exception,) class ToolNode(RunnableCallable): """A node that runs the tools called in the last AIMessage. It can be used either in StateGraph with a "messages" state key (or a custom key passed via ToolNode's 'messages_key'). If multiple tool calls are requested, they will be run in parallel. The output will be a list of ToolMessages, one for each tool call. Args: tools: A sequence of tools that can be invoked by the ToolNode. name: The name of the ToolNode in the graph. Defaults to "tools". tags: Optional tags to associate with the node. Defaults to None. 
handle_tool_errors: How to handle tool errors raised by tools inside the node. Defaults to True. Must be one of the following: - True: all errors will be caught and a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned. - str: all errors will be caught and a ToolMessage with the string value of 'handle_tool_errors' will be returned. - tuple[type[Exception], ...]: exceptions in the tuple will be caught and a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned. - Callable[..., str]: exceptions from the signature of the callable will be caught and a ToolMessage with the string value of the result of the 'handle_tool_errors' callable will be returned. - False: none of the errors raised by the tools will be caught messages_key: The state key in the input that contains the list of messages. The same key will be used for the output from the ToolNode. Defaults to "messages". The `ToolNode` is roughly analogous to: ```python tools_by_name = {tool.name: tool for tool in tools} def tool_node(state: dict): result = [] for tool_call in state["messages"][-1].tool_calls: tool = tools_by_name[tool_call["name"]] observation = tool.invoke(tool_call["args"]) result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"])) return {"messages": result} ``` Important: - The state MUST contain a list of messages. - The last message MUST be an `AIMessage`. - The `AIMessage` MUST have `tool_calls` populated. """ name: str = "ToolNode" def __init__( self, tools: Sequence[Union[BaseTool, Callable]], *, name: str = "tools", tags: Optional[list[str]] = None, handle_tool_errors: Union[ bool, str, Callable[..., str], tuple[type[Exception], ...] 
] = True, messages_key: str = "messages", ) -> None: super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) self.tools_by_name: Dict[str, BaseTool] = {} self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {} self.tool_to_store_arg: Dict[str, Optional[str]] = {} self.handle_tool_errors = handle_tool_errors self.messages_key = messages_key for tool_ in tools: if not isinstance(tool_, BaseTool): tool_ = create_tool(tool_) self.tools_by_name[tool_.name] = tool_ self.tool_to_state_args[tool_.name] = _get_state_args(tool_) self.tool_to_store_arg[tool_.name] = _get_store_arg(tool_) def _func( self, input: Union[ list[AnyMessage], dict[str, Any], BaseModel, ], config: RunnableConfig, *, store: BaseStore, ) -> Any: tool_calls, output_type = self._parse_input(input, store) config_list = get_config_list(config, len(tool_calls)) with get_executor_for_config(config) as executor: outputs = [*executor.map(self._run_one, tool_calls, config_list)] # TypedDict, pydantic, dataclass, etc. should all be able to load from dict return outputs if output_type == "list" else {self.messages_key: outputs} def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: if "store" not in kwargs: kwargs["store"] = None return super().invoke(input, config, **kwargs) async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: if "store" not in kwargs: kwargs["store"] = None return await super().ainvoke(input, config, **kwargs) async def _afunc( self, input: Union[ list[AnyMessage], dict[str, Any], BaseModel, ], config: RunnableConfig, *, store: BaseStore, ) -> Any: tool_calls, output_type = self._parse_input(input, store) outputs = await asyncio.gather( *(self._arun_one(call, config) for call in tool_calls) ) # TypedDict, pydantic, dataclass, etc. 
should all be able to load from dict return outputs if output_type == "list" else {self.messages_key: outputs} def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message try: input = {**call, **{"type": "tool_call"}} tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( input, config ) tool_message.content = cast( Union[str, list], msg_content_output(tool_message.content) ) return tool_message # GraphInterrupt is a special exception that will always be raised. # It can be triggered in the following scenarios: # (1) a NodeInterrupt is raised inside a tool # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) except GraphBubbleUp as e: raise e except Exception as e: if isinstance(self.handle_tool_errors, tuple): handled_types: tuple = self.handle_tool_errors elif callable(self.handle_tool_errors): handled_types = _infer_handled_types(self.handle_tool_errors) else: # default behavior is catching all exceptions handled_types = (Exception,) # Unhandled if not self.handle_tool_errors or not isinstance(e, handled_types): raise e # Handled else: content = _handle_tool_error(e, flag=self.handle_tool_errors) return ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error" ) async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message try: input = {**call, **{"type": "tool_call"}} tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) tool_message.content = cast( Union[str, list], msg_content_output(tool_message.content) ) return tool_message # GraphInterrupt is a special exception that will 
always be raised. # It can be triggered in the following scenarios: # (1) a NodeInterrupt is raised inside a tool # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture) except GraphBubbleUp as e: raise e except Exception as e: if isinstance(self.handle_tool_errors, tuple): handled_types: tuple = self.handle_tool_errors elif callable(self.handle_tool_errors): handled_types = _infer_handled_types(self.handle_tool_errors) else: # default behavior is catching all exceptions handled_types = (Exception,) # Unhandled if not self.handle_tool_errors or not isinstance(e, handled_types): raise e # Handled else: content = _handle_tool_error(e, flag=self.handle_tool_errors) return ToolMessage( content=content, name=call["name"], tool_call_id=call["id"], status="error" ) def _parse_input( self, input: Union[ list[AnyMessage], dict[str, Any], BaseModel, ], store: BaseStore, ) -> Tuple[List[ToolCall], Literal["list", "dict"]]: if isinstance(input, list): output_type = "list" message: AnyMessage = input[-1] elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])): output_type = "dict" message = messages[-1] elif messages := getattr(input, self.messages_key, None): # Assume dataclass-like state that can coerce from dict output_type = "dict" message = messages[-1] else: raise ValueError("No message found in input") if not isinstance(message, AIMessage): raise ValueError("Last message is not an AIMessage") tool_calls = [ self._inject_tool_args(call, input, store) for call in message.tool_calls ] return tool_calls, output_type def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]: if (requested_tool := call["name"]) not in self.tools_by_name: content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format( requested_tool=requested_tool, available_tools=", 
".join(self.tools_by_name.keys()), ) return ToolMessage( content, name=requested_tool, tool_call_id=call["id"], status="error" ) else: return None def _inject_state( self, tool_call: ToolCall, input: Union[ list[AnyMessage], dict[str, Any], BaseModel, ], ) -> ToolCall: state_args = self.tool_to_state_args[tool_call["name"]] if state_args and isinstance(input, list): required_fields = list(state_args.values()) if ( len(required_fields) == 1 and required_fields[0] == self.messages_key or required_fields[0] is None ): input = {self.messages_key: input} else: err_msg = ( f"Invalid input to ToolNode. Tool {tool_call['name']} requires " f"graph state dict as input." ) if any(state_field for state_field in state_args.values()): required_fields_str = ", ".join(f for f in required_fields if f) err_msg += f" State should contain fields {required_fields_str}." raise ValueError(err_msg) if isinstance(input, dict): tool_state_args = { tool_arg: input[state_field] if state_field else input for tool_arg, state_field in state_args.items() } else: tool_state_args = { tool_arg: getattr(input, state_field) if state_field else input for tool_arg, state_field in state_args.items() } tool_call["args"] = { **tool_call["args"], **tool_state_args, } return tool_call def _inject_store(self, tool_call: ToolCall, store: BaseStore) -> ToolCall: store_arg = self.tool_to_store_arg[tool_call["name"]] if not store_arg: return tool_call if store is None: raise ValueError( "Cannot inject store into tools with InjectedStore annotations - " "please compile your graph with a store." 
) tool_call["args"] = { **tool_call["args"], store_arg: store, } return tool_call def _inject_tool_args( self, tool_call: ToolCall, input: Union[ list[AnyMessage], dict[str, Any], BaseModel, ], store: BaseStore, ) -> ToolCall: if tool_call["name"] not in self.tools_by_name: return tool_call tool_call_copy: ToolCall = copy(tool_call) tool_call_with_state = self._inject_state(tool_call_copy, input) tool_call_with_store = self._inject_store(tool_call_with_state, store) return tool_call_with_store def tools_condition( state: Union[list[AnyMessage], dict[str, Any], BaseModel], messages_key: str = "messages", ) -> Literal["tools", "__end__"]: """Use in the conditional_edge to route to the ToolNode if the last message has tool calls. Otherwise, route to the end. Args: state (Union[list[AnyMessage], dict[str, Any], BaseModel]): The state to check for tool calls. Must have a list of messages (MessageGraph) or have the "messages" key (StateGraph). Returns: The next node to route to. Examples: Create a custom ReAct-style agent with tools. ```pycon >>> from langchain_anthropic import ChatAnthropic >>> from langchain_core.tools import tool ... >>> from langgraph.graph import StateGraph >>> from langgraph.prebuilt import ToolNode, tools_condition >>> from langgraph.graph.message import add_messages ... >>> from typing import TypedDict, Annotated ... >>> @tool >>> def divide(a: float, b: float) -> int: ... \"\"\"Return a / b.\"\"\" ... return a / b ... >>> llm = ChatAnthropic(model="claude-3-haiku-20240307") >>> tools = [divide] ... >>> class State(TypedDict): ... messages: Annotated[list, add_messages] >>> >>> graph_builder = StateGraph(State) >>> graph_builder.add_node("tools", ToolNode(tools)) >>> graph_builder.add_node("chatbot", lambda state: {"messages":llm.bind_tools(tools).invoke(state['messages'])}) >>> graph_builder.add_edge("tools", "chatbot") >>> graph_builder.add_conditional_edges( ... "chatbot", tools_condition ... 
) >>> graph_builder.set_entry_point("chatbot") >>> graph = graph_builder.compile() >>> graph.invoke({"messages": {"role": "user", "content": "What's 329993 divided by 13662?"}}) ``` """ if isinstance(state, list): ai_message = state[-1] elif isinstance(state, dict) and (messages := state.get(messages_key, [])): ai_message = messages[-1] elif messages := getattr(state, messages_key, []): ai_message = messages[-1] else: raise ValueError(f"No messages found in input state to tool_edge: {state}") if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: return "tools" return "__end__" class InjectedState(InjectedToolArg): """Annotation for a Tool arg that is meant to be populated with the graph state. Any Tool argument annotated with InjectedState will be hidden from a tool-calling model, so that the model doesn't attempt to generate the argument. If using ToolNode, the appropriate graph state field will be automatically injected into the model-generated tool args. Args: field: The key from state to insert. If None, the entire state is expected to be passed in. 
Example: ```python from typing import List from typing_extensions import Annotated, TypedDict from langchain_core.messages import BaseMessage, AIMessage from langchain_core.tools import tool from langgraph.prebuilt import InjectedState, ToolNode class AgentState(TypedDict): messages: List[BaseMessage] foo: str @tool def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str: '''Do something with state.''' if len(state["messages"]) > 2: return state["foo"] + str(x) else: return "not enough messages" @tool def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str: '''Do something else with state.''' return foo + str(x + 1) node = ToolNode([state_tool, foo_tool]) tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"} tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"} state = { "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])], "foo": "bar", } node.invoke(state) ``` ```pycon [ ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'), ToolMessage(content='bar2', name='foo_tool', tool_call_id='2') ] ``` """ # noqa: E501 def __init__(self, field: Optional[str] = None) -> None: self.field = field class InjectedStore(InjectedToolArg): """Annotation for a Tool arg that is meant to be populated with LangGraph store. Any Tool argument annotated with InjectedStore will be hidden from a tool-calling model, so that the model doesn't attempt to generate the argument. If using ToolNode, the appropriate store field will be automatically injected into the model-generated tool args. Note: if a graph is compiled with a store object, the store will be automatically propagated to the tools with InjectedStore args when using ToolNode. !!! 
Warning `InjectedStore` annotation requires `langchain-core >= 0.3.8` Example: ```python from typing import Any from typing_extensions import Annotated from langchain_core.messages import AIMessage from langchain_core.tools import tool from langgraph.store.memory import InMemoryStore from langgraph.prebuilt import InjectedStore, ToolNode store = InMemoryStore() store.put(("values",), "foo", {"bar": 2}) @tool def store_tool(x: int, my_store: Annotated[Any, InjectedStore()]) -> str: '''Do something with store.''' stored_value = my_store.get(("values",), "foo").value["bar"] return stored_value + x node = ToolNode([store_tool]) tool_call = {"name": "store_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"} state = { "messages": [AIMessage("", tool_calls=[tool_call])], } node.invoke(state, store=store) ``` ```pycon { "messages": [ ToolMessage(content='3', name='store_tool', tool_call_id='1'), ] } ``` """ # noqa: E501 def _is_injection( type_arg: Any, injection_type: Union[Type[InjectedState], Type[InjectedStore]] ) -> bool: if isinstance(type_arg, injection_type) or ( isinstance(type_arg, type) and issubclass(type_arg, injection_type) ): return True origin_ = get_origin(type_arg) if origin_ is Union or origin_ is Annotated: return any(_is_injection(ta, injection_type) for ta in get_args(type_arg)) return False def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]: full_schema = tool.get_input_schema() tool_args_to_state_fields: Dict = {} for name, type_ in get_all_basemodel_annotations(full_schema).items(): injections = [ type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedState) ] if len(injections) > 1: raise ValueError( "A tool argument should not be annotated with InjectedState more than " f"once. Received arg {name} with annotations {injections}." 
) elif len(injections) == 1: injection = injections[0] if isinstance(injection, InjectedState) and injection.field: tool_args_to_state_fields[name] = injection.field else: tool_args_to_state_fields[name] = None else: pass return tool_args_to_state_fields def _get_store_arg(tool: BaseTool) -> Optional[str]: full_schema = tool.get_input_schema() for name, type_ in get_all_basemodel_annotations(full_schema).items(): injections = [ type_arg for type_arg in get_args(type_) if _is_injection(type_arg, InjectedStore) ] if len(injections) > 1: ValueError( "A tool argument should not be annotated with InjectedStore more than " f"once. Received arg {name} with annotations {injections}." ) elif len(injections) == 1: return name else: pass return None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_executor.py`: ```py from typing import Any, Callable, Sequence, Union from langchain_core.load.serializable import Serializable from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langchain_core.tools import tool as create_tool from langgraph._api.deprecation import deprecated from langgraph.utils.runnable import RunnableCallable INVALID_TOOL_MSG_TEMPLATE = ( "{requested_tool_name} is not a valid tool, " "try one of [{available_tool_names_str}]." ) @deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0") class ToolInvocationInterface: """Interface for invoking a tool. Attributes: tool (str): The name of the tool to invoke. tool_input (Union[str, dict]): The input to pass to the tool. """ tool: str tool_input: Union[str, dict] @deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0") class ToolInvocation(Serializable): """Information about how to invoke a tool. Attributes: tool (str): The name of the Tool to execute. tool_input (Union[str, dict]): The input to pass in to the Tool. Examples: Basic usage: ```pycon >>> invocation = ToolInvocation( ... tool="search", ... 
tool_input="What is the capital of France?" ... ) ``` """ tool: str tool_input: Union[str, dict] @deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0") class ToolExecutor(RunnableCallable): """Executes a tool invocation. Args: tools (Sequence[BaseTool]): A sequence of tools that can be invoked. invalid_tool_msg_template (str, optional): The template for the error message when an invalid tool is requested. Defaults to INVALID_TOOL_MSG_TEMPLATE. Examples: Basic usage: ```pycon >>> from langchain_core.tools import tool >>> from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation ... ... >>> @tool ... def search(query: str) -> str: ... \"\"\"Search engine.\"\"\" ... return f"Searching for: {query}" ... ... >>> tools = [search] >>> executor = ToolExecutor(tools) ... >>> invocation = ToolInvocation(tool="search", tool_input="What is the capital of France?") >>> result = executor.invoke(invocation) >>> print(result) "Searching for: What is the capital of France?" ``` Handling invalid tool: ```pycon >>> invocation = ToolInvocation( ... tool="nonexistent", tool_input="What is the capital of France?" ... ) >>> result = executor.invoke(invocation) >>> print(result) "nonexistent is not a valid tool, try one of [search]." 
``` """ def __init__( self, tools: Sequence[Union[BaseTool, Callable]], *, invalid_tool_msg_template: str = INVALID_TOOL_MSG_TEMPLATE, ) -> None: super().__init__(self._execute, afunc=self._aexecute, trace=False) tools_ = [ tool if isinstance(tool, BaseTool) else create_tool(tool) for tool in tools ] self.tools = tools_ self.tool_map = {t.name: t for t in tools_} self.invalid_tool_msg_template = invalid_tool_msg_template def _execute( self, tool_invocation: ToolInvocationInterface, config: RunnableConfig ) -> Any: if tool_invocation.tool not in self.tool_map: return self.invalid_tool_msg_template.format( requested_tool_name=tool_invocation.tool, available_tool_names_str=", ".join([t.name for t in self.tools]), ) else: tool = self.tool_map[tool_invocation.tool] output = tool.invoke(tool_invocation.tool_input, config) return output async def _aexecute( self, tool_invocation: ToolInvocationInterface, config: RunnableConfig ) -> Any: if tool_invocation.tool not in self.tool_map: return self.invalid_tool_msg_template.format( requested_tool_name=tool_invocation.tool, available_tool_names_str=", ".join([t.name for t in self.tools]), ) else: tool = self.tool_map[tool_invocation.tool] output = await tool.ainvoke(tool_invocation.tool_input, config) return output ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/errors.py`: ```py from enum import Enum from typing import Any, Sequence from langgraph.checkpoint.base import EmptyChannelError # noqa: F401 from langgraph.types import Command, Interrupt # EmptyChannelError re-exported for backwards compatibility class ErrorCode(Enum): GRAPH_RECURSION_LIMIT = "GRAPH_RECURSION_LIMIT" INVALID_CONCURRENT_GRAPH_UPDATE = "INVALID_CONCURRENT_GRAPH_UPDATE" INVALID_GRAPH_NODE_RETURN_VALUE = "INVALID_GRAPH_NODE_RETURN_VALUE" MULTIPLE_SUBGRAPHS = "MULTIPLE_SUBGRAPHS" INVALID_CHAT_HISTORY = "INVALID_CHAT_HISTORY" def create_error_message(*, message: str, error_code: ErrorCode) -> str: return ( f"{message}\n" "For 
troubleshooting, visit: https://python.langchain.com/docs/" f"troubleshooting/errors/{error_code.value}" ) class GraphRecursionError(RecursionError): """Raised when the graph has exhausted the maximum number of steps. This prevents infinite loops. To increase the maximum number of steps, run your graph with a config specifying a higher `recursion_limit`. Troubleshooting Guides: - [GRAPH_RECURSION_LIMIT](https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT) Examples: graph = builder.compile() graph.invoke( {"messages": [("user", "Hello, world!")]}, # The config is the second positional argument {"recursion_limit": 1000}, ) """ pass class InvalidUpdateError(Exception): """Raised when attempting to update a channel with an invalid set of updates. Troubleshooting Guides: - [INVALID_CONCURRENT_GRAPH_UPDATE](https://python.langchain.com/docs/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE) - [INVALID_GRAPH_NODE_RETURN_VALUE](https://python.langchain.com/docs/troubleshooting/errors/INVALID_GRAPH_NODE_RETURN_VALUE) """ pass class GraphBubbleUp(Exception): pass class GraphInterrupt(GraphBubbleUp): """Raised when a subgraph is interrupted, suppressed by the root graph. 
Never raised directly, or surfaced to the user.""" def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None: super().__init__(interrupts) class NodeInterrupt(GraphInterrupt): """Raised by a node to interrupt execution.""" def __init__(self, value: Any) -> None: super().__init__([Interrupt(value=value)]) class GraphDelegate(GraphBubbleUp): """Raised when a graph is delegated (for distributed mode).""" def __init__(self, *args: dict[str, Any]) -> None: super().__init__(*args) class ParentCommand(GraphBubbleUp): args: tuple[Command] def __init__(self, command: Command) -> None: super().__init__(command) class EmptyInputError(Exception): """Raised when graph receives an empty input.""" pass class TaskNotFound(Exception): """Raised when the executor is unable to find a task (for distributed mode).""" pass class CheckpointNotLatest(Exception): """Raised when the checkpoint is not the latest version (for distributed mode).""" pass class MultipleSubgraphsError(Exception): """Raised when multiple subgraphs are called inside the same node. 
Troubleshooting guides: - [MULTIPLE_SUBGRAPHS](https://python.langchain.com/docs/troubleshooting/errors/MULTIPLE_SUBGRAPHS) """ pass _SEEN_CHECKPOINT_NS: set[str] = set() """Used for subgraph detection.""" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/last_value.py`: ```py from typing import Generic, Optional, Sequence, Type from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import ( EmptyChannelError, ErrorCode, InvalidUpdateError, create_error_message, ) class LastValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, can receive at most one value per step.""" __slots__ = ("value",) def __eq__(self, value: object) -> bool: return isinstance(value, LastValue) @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) empty.key = self.key if checkpoint is not None: empty.value = checkpoint return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: return False if len(values) != 1: msg = create_error_message( message=f"At key '{self.key}': Can receive only one value per step. 
Use an Annotated key to handle multiple values.", error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE, ) raise InvalidUpdateError(msg) self.value = values[-1] return True def get(self) -> Value: try: return self.value except AttributeError: raise EmptyChannelError() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/__init__.py`: ```py from langgraph.channels.any_value import AnyValue from langgraph.channels.binop import BinaryOperatorAggregate from langgraph.channels.context import Context from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.last_value import LastValue from langgraph.channels.topic import Topic from langgraph.channels.untracked_value import UntrackedValue __all__ = [ "LastValue", "Topic", "Context", "BinaryOperatorAggregate", "UntrackedValue", "EphemeralValue", "AnyValue", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/untracked_value.py`: ```py from typing import Generic, Optional, Sequence, Type from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError, InvalidUpdateError class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, never checkpointed.""" __slots__ = ("value", "guard") def __init__(self, typ: Type[Value], guard: bool = True) -> None: super().__init__(typ) self.guard = guard def __eq__(self, value: object) -> bool: return isinstance(value, UntrackedValue) and value.guard == self.guard @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def checkpoint(self) -> Value: raise EmptyChannelError() def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) empty.key = self.key return empty def 
update(self, values: Sequence[Value]) -> bool: if len(values) == 0: return False if len(values) != 1 and self.guard: raise InvalidUpdateError( f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values." ) self.value = values[-1] return True def get(self) -> Value: try: return self.value except AttributeError: raise EmptyChannelError() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/any_value.py`: ```py from typing import Generic, Optional, Sequence, Type from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, assumes that if multiple values are received, they are all equal.""" __slots__ = ("typ", "value") def __eq__(self, value: object) -> bool: return isinstance(value, AnyValue) @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) empty.key = self.key if checkpoint is not None: empty.value = checkpoint return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: try: del self.value return True except AttributeError: return False self.value = values[-1] return True def get(self) -> Value: try: return self.value except AttributeError: raise EmptyChannelError() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/named_barrier_value.py`: ```py from typing import Generic, Optional, Sequence, Type from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError, 
InvalidUpdateError class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]): """A channel that waits until all named values are received before making the value available.""" __slots__ = ("names", "seen") names: set[Value] seen: set[Value] def __init__(self, typ: Type[Value], names: set[Value]) -> None: super().__init__(typ) self.names = names self.seen: set[str] = set() def __eq__(self, value: object) -> bool: return isinstance(value, NamedBarrierValue) and value.names == self.names @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def checkpoint(self) -> set[Value]: return self.seen def from_checkpoint(self, checkpoint: Optional[set[Value]]) -> Self: empty = self.__class__(self.typ, self.names) empty.key = self.key if checkpoint is not None: empty.seen = checkpoint return empty def update(self, values: Sequence[Value]) -> bool: updated = False for value in values: if value in self.names: if value not in self.seen: self.seen.add(value) updated = True else: raise InvalidUpdateError( f"At key '{self.key}': Value {value} not in {self.names}" ) return updated def get(self) -> Value: if self.seen != self.names: raise EmptyChannelError() return None def consume(self) -> bool: if self.seen == self.names: self.seen = set() return True return False ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/binop.py`: ```py import collections.abc from typing import ( Callable, Generic, Optional, Sequence, Type, ) from typing_extensions import NotRequired, Required, Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError # Adapted from typing_extensions def _strip_extras(t): # type: ignore[no-untyped-def] """Strips Annotated, Required and NotRequired from a given type.""" if hasattr(t, "__origin__"): 
return _strip_extras(t.__origin__) if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): return _strip_extras(t.__args__[0]) return t class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the result of applying a binary operator to the current value and each new value. ```python import operator total = Channels.BinaryOperatorAggregate(int, operator.add) ``` """ __slots__ = ("value", "operator") def __init__(self, typ: Type[Value], operator: Callable[[Value, Value], Value]): super().__init__(typ) self.operator = operator # special forms from typing or collections.abc are not instantiable # so we need to replace them with their concrete counterparts typ = _strip_extras(typ) if typ in (collections.abc.Sequence, collections.abc.MutableSequence): typ = list if typ in (collections.abc.Set, collections.abc.MutableSet): typ = set if typ in (collections.abc.Mapping, collections.abc.MutableMapping): typ = dict try: self.value = typ() except Exception: pass def __eq__(self, value: object) -> bool: return isinstance(value, BinaryOperatorAggregate) and ( value.operator is self.operator if value.operator.__name__ != "<lambda>" and self.operator.__name__ != "<lambda>" else True ) @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.operator) empty.key = self.key if checkpoint is not None: empty.value = checkpoint return empty def update(self, values: Sequence[Value]) -> bool: if not values: return False if not hasattr(self, "value"): self.value = values[0] values = values[1:] for value in values: self.value = self.operator(self.value, value) return True def get(self) -> Value: try: return self.value except AttributeError: raise EmptyChannelError() ``` 
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/ephemeral_value.py`: ```py from typing import Any, Generic, Optional, Sequence, Type from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError, InvalidUpdateError class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the value received in the step immediately preceding, clears after.""" __slots__ = ("value", "guard") def __init__(self, typ: Any, guard: bool = True) -> None: super().__init__(typ) self.guard = guard def __eq__(self, value: object) -> bool: return isinstance(value, EphemeralValue) and value.guard == self.guard @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) empty.key = self.key if checkpoint is not None: empty.value = checkpoint return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: try: del self.value return True except AttributeError: return False if len(values) != 1 and self.guard: raise InvalidUpdateError( f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values." 
) self.value = values[-1] return True def get(self) -> Value: try: return self.value except AttributeError: raise EmptyChannelError() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/dynamic_barrier_value.py`: ```py from typing import Any, Generic, NamedTuple, Optional, Sequence, Type, Union from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError, InvalidUpdateError class WaitForNames(NamedTuple): names: set[Any] class DynamicBarrierValue( Generic[Value], BaseChannel[Value, Union[Value, WaitForNames], set[Value]] ): """A channel that switches between two states - in the "priming" state it can't be read from. - if it receives a WaitForNames update, it switches to the "waiting" state. - in the "waiting" state it collects named values until all are received. - once all named values are received, it can be read once, and it switches back to the "priming" state. """ __slots__ = ("names", "seen") names: Optional[set[Value]] seen: set[Value] def __init__(self, typ: Type[Value]) -> None: super().__init__(typ) self.names = None self.seen = set() def __eq__(self, value: object) -> bool: return isinstance(value, DynamicBarrierValue) and value.names == self.names @property def ValueType(self) -> Type[Value]: """The type of the value stored in the channel.""" return self.typ @property def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]: return (self.names, self.seen) def from_checkpoint( self, checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]], ) -> Self: empty = self.__class__(self.typ) empty.key = self.key if checkpoint is not None: names, seen = checkpoint empty.names = names if names is not None else None empty.seen = seen return empty def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool: if wait_for_names := [v for v 
in values if isinstance(v, WaitForNames)]: if len(wait_for_names) > 1: raise InvalidUpdateError( f"At key '{self.key}': Received multiple WaitForNames updates in the same step." ) self.names = wait_for_names[0].names return True elif self.names is not None: updated = False for value in values: assert not isinstance(value, WaitForNames) if value in self.names: if value not in self.seen: self.seen.add(value) updated = True else: raise InvalidUpdateError(f"Value {value} not in {self.names}") return updated def get(self) -> Value: if self.seen != self.names: raise EmptyChannelError() return None def consume(self) -> bool: if self.seen == self.names: self.seen = set() self.names = None return True return False ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/context.py`: ```py from langgraph.managed.context import Context as ContextManagedValue Context = ContextManagedValue.of __all__ = ["Context"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/topic.py`: ```py from typing import Any, Generic, Iterator, Optional, Sequence, Type, Union from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value from langgraph.errors import EmptyChannelError def flatten(values: Sequence[Union[Value, list[Value]]]) -> Iterator[Value]: for value in values: if isinstance(value, list): yield from value else: yield value class Topic( Generic[Value], BaseChannel[ Sequence[Value], Union[Value, list[Value]], tuple[set[Value], list[Value]] ], ): """A configurable PubSub Topic. Args: typ: The type of the value stored in the channel. accumulate: Whether to accumulate values across steps. If False, the channel will be emptied after each step. 
""" __slots__ = ("values", "accumulate") def __init__(self, typ: Type[Value], accumulate: bool = False) -> None: super().__init__(typ) # attrs self.accumulate = accumulate # state self.values = list[Value]() def __eq__(self, value: object) -> bool: return isinstance(value, Topic) and value.accumulate == self.accumulate @property def ValueType(self) -> Any: """The type of the value stored in the channel.""" return Sequence[self.typ] # type: ignore[name-defined] @property def UpdateType(self) -> Any: """The type of the update received by the channel.""" return Union[self.typ, list[self.typ]] # type: ignore[name-defined] def checkpoint(self) -> tuple[set[Value], list[Value]]: return self.values def from_checkpoint(self, checkpoint: Optional[list[Value]]) -> Self: empty = self.__class__(self.typ, self.accumulate) empty.key = self.key if checkpoint is not None: if isinstance(checkpoint, tuple): empty.values = checkpoint[1] else: empty.values = checkpoint return empty def update(self, values: Sequence[Union[Value, list[Value]]]) -> None: current = list(self.values) if not self.accumulate: self.values = list[Value]() if flat_values := flatten(values): self.values.extend(flat_values) return self.values != current def get(self) -> Sequence[Value]: if self.values: return list(self.values) else: raise EmptyChannelError ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/base.py`: ```py from abc import ABC, abstractmethod from typing import Any, Generic, Optional, Sequence, Type, TypeVar from typing_extensions import Self from langgraph.errors import EmptyChannelError, InvalidUpdateError Value = TypeVar("Value") Update = TypeVar("Update") C = TypeVar("C") class BaseChannel(Generic[Value, Update, C], ABC): __slots__ = ("key", "typ") def __init__(self, typ: Type[Any], key: str = "") -> None: self.typ = typ self.key = key @property @abstractmethod def ValueType(self) -> Any: """The type of the value stored in the channel.""" @property @abstractmethod 
def UpdateType(self) -> Any: """The type of the update received by the channel.""" # serialize/deserialize methods def checkpoint(self) -> Optional[C]: """Return a serializable representation of the channel's current state. Raises EmptyChannelError if the channel is empty (never updated yet), or doesn't support checkpoints.""" return self.get() @abstractmethod def from_checkpoint(self, checkpoint: Optional[C]) -> Self: """Return a new identical channel, optionally initialized from a checkpoint. If the checkpoint contains complex data structures, they should be copied.""" # state methods @abstractmethod def update(self, values: Sequence[Update]) -> bool: """Update the channel's value with the given sequence of updates. The order of the updates in the sequence is arbitrary. This method is called by Pregel for all channels at the end of each step. If there are no updates, it is called with an empty sequence. Raises InvalidUpdateError if the sequence of updates is invalid. Returns True if the channel was updated, False otherwise.""" @abstractmethod def get(self) -> Value: """Return the current value of the channel. Raises EmptyChannelError if the channel is empty (never updated yet).""" def consume(self) -> bool: """Mark the current value of the channel as consumed. By default, no-op. This is called by Pregel before the start of the next step, for all channels that triggered a node. If the channel was updated, return True. 
""" return False __all__ = [ "BaseChannel", "EmptyChannelError", "InvalidUpdateError", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/fanout_to_subgraph.py`: ```py import operator from typing import Annotated, TypedDict from langgraph.constants import END, START, Send from langgraph.graph.state import StateGraph def fanout_to_subgraph() -> StateGraph: class OverallState(TypedDict): subjects: list[str] jokes: Annotated[list[str], operator.add] async def continue_to_jokes(state: OverallState): return [Send("generate_joke", {"subject": s}) for s in state["subjects"]] class JokeInput(TypedDict): subject: str class JokeOutput(TypedDict): jokes: list[str] async def bump(state: JokeOutput): return {"jokes": [state["jokes"][0] + " a"]} async def generate(state: JokeInput): return {"jokes": [f"Joke about {state['subject']}"]} async def edit(state: JokeInput): subject = state["subject"] return {"subject": f"{subject} - hohoho"} async def bump_loop(state: JokeOutput): return END if state["jokes"][0].endswith(" a" * 10) else "bump" # subgraph subgraph = StateGraph(input=JokeInput, output=JokeOutput) subgraph.add_node("edit", edit) subgraph.add_node("generate", generate) subgraph.add_node("bump", bump) subgraph.set_entry_point("edit") subgraph.add_edge("edit", "generate") subgraph.add_edge("generate", "bump") subgraph.add_conditional_edges("bump", bump_loop) subgraph.set_finish_point("generate") subgraphc = subgraph.compile() # parent graph builder = StateGraph(OverallState) builder.add_node("generate_joke", subgraphc) builder.add_conditional_edges(START, continue_to_jokes) builder.add_edge("generate_joke", END) return builder def fanout_to_subgraph_sync() -> StateGraph: class OverallState(TypedDict): subjects: list[str] jokes: Annotated[list[str], operator.add] def continue_to_jokes(state: OverallState): return [Send("generate_joke", {"subject": s}) for s in state["subjects"]] class JokeInput(TypedDict): subject: str class JokeOutput(TypedDict): jokes: 
list[str] def bump(state: JokeOutput): return {"jokes": [state["jokes"][0] + " a"]} def generate(state: JokeInput): return {"jokes": [f"Joke about {state['subject']}"]} def edit(state: JokeInput): subject = state["subject"] return {"subject": f"{subject} - hohoho"} def bump_loop(state: JokeOutput): return END if state["jokes"][0].endswith(" a" * 10) else "bump" # subgraph subgraph = StateGraph(input=JokeInput, output=JokeOutput) subgraph.add_node("edit", edit) subgraph.add_node("generate", generate) subgraph.add_node("bump", bump) subgraph.set_entry_point("edit") subgraph.add_edge("edit", "generate") subgraph.add_edge("generate", "bump") subgraph.add_conditional_edges("bump", bump_loop) subgraph.set_finish_point("generate") subgraphc = subgraph.compile() # parent graph builder = StateGraph(OverallState) builder.add_node("generate_joke", subgraphc) builder.add_conditional_edges(START, continue_to_jokes) builder.add_edge("generate_joke", END) return builder if __name__ == "__main__": import asyncio import random import uvloop from langgraph.checkpoint.memory import MemorySaver graph = fanout_to_subgraph().compile(checkpointer=MemorySaver()) input = { "subjects": [ random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(1000) ] } config = {"configurable": {"thread_id": "1"}} async def run(): len([c async for c in graph.astream(input, config=config)]) uvloop.install() asyncio.run(run()) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/react_agent.py`: ```py from typing import Any, Optional from uuid import uuid4 from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.tools import StructuredTool from langgraph.checkpoint.base import BaseCheckpointSaver from 
langgraph.prebuilt.chat_agent_executor import create_react_agent from langgraph.pregel import Pregel def react_agent(n_tools: int, checkpointer: Optional[BaseCheckpointSaver]) -> Pregel: class FakeFuntionChatModel(FakeMessagesListChatModel): def bind_tools(self, functions: list): return self def _generate( self, messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: response = self.responses[self.i].copy() if self.i < len(self.responses) - 1: self.i += 1 else: self.i = 0 generation = ChatGeneration(message=response) return ChatResult(generations=[generation]) tool = StructuredTool.from_function( lambda query: f"result for query: {query}" * 10, name=str(uuid4()), description="", ) model = FakeFuntionChatModel( responses=[ AIMessage( content="", tool_calls=[ { "id": str(uuid4()), "name": tool.name, "args": {"query": str(uuid4()) * 100}, } ], id=str(uuid4()), ) for _ in range(n_tools) ] + [ AIMessage(content="answer" * 100, id=str(uuid4())), ] ) return create_react_agent(model, [tool], checkpointer=checkpointer) if __name__ == "__main__": import asyncio import uvloop from langgraph.checkpoint.memory import MemorySaver graph = react_agent(100, checkpointer=MemorySaver()) input = {"messages": [HumanMessage("hi?")]} config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000} async def run(): len([c async for c in graph.astream(input, config=config)]) uvloop.install() asyncio.run(run()) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/wide_state.py`: ```py import operator from dataclasses import dataclass, field from functools import partial from typing import Annotated, Optional, Sequence from langgraph.constants import END, START from langgraph.graph.state import StateGraph def wide_state(n: int) -> StateGraph: @dataclass(kw_only=True) class State: messages: Annotated[list, operator.add] = field(default_factory=list) trigger_events: 
Annotated[list, operator.add] = field(default_factory=list) """The external events that are converted by the graph.""" primary_issue_medium: Annotated[str, lambda x, y: y or x] = field( default="email" ) autoresponse: Annotated[Optional[dict], lambda _, y: y] = field( default=None ) # Always overwrite issue: Annotated[dict | None, lambda x, y: y if y else x] = field(default=None) relevant_rules: Optional[list[dict]] = field(default=None) """SOPs fetched from the rulebook that are relevant to the current conversation.""" memory_docs: Optional[list[dict]] = field(default=None) """Memory docs fetched from the memory service that are relevant to the current conversation.""" categorizations: Annotated[list[dict], operator.add] = field( default_factory=list ) """The issue categorizations auto-generated by the AI.""" responses: Annotated[list[dict], operator.add] = field(default_factory=list) """The draft responses recommended by the AI.""" user_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = ( field(default=None) ) """The current user state (by email).""" crm_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = ( field(default=None) ) """The CRM information for organization the current user is from.""" email_thread_id: Annotated[ Optional[str], lambda x, y: y if y is not None else x ] = field(default=None) """The current email thread ID.""" slack_participants: Annotated[dict, operator.or_] = field(default_factory=dict) """The growing list of current slack participants.""" bot_id: Optional[str] = field(default=None) """The ID of the bot user in the slack channel.""" notified_assignees: Annotated[dict, operator.or_] = field(default_factory=dict) def read_write(read: str, write: Sequence[str], input: State) -> dict: val = getattr(input, read) val_single = val[-1] if isinstance(val, list) else val val_list = val if isinstance(val, list) else [val] return { k: val_list if isinstance(getattr(input, k), list) else val_single for k 
in write } builder = StateGraph(State) builder.add_edge(START, "one") builder.add_node( "one", partial(read_write, "messages", ["trigger_events", "primary_issue_medium"]), ) builder.add_edge("one", "two") builder.add_node( "two", partial(read_write, "trigger_events", ["autoresponse", "issue"]), ) builder.add_edge("two", "three") builder.add_edge("two", "four") builder.add_node( "three", partial(read_write, "autoresponse", ["relevant_rules"]), ) builder.add_node( "four", partial( read_write, "trigger_events", ["categorizations", "responses", "memory_docs"], ), ) builder.add_node( "five", partial( read_write, "categorizations", [ "user_info", "crm_info", "email_thread_id", "slack_participants", "bot_id", "notified_assignees", ], ), ) builder.add_edge(["three", "four"], "five") builder.add_edge("five", "six") builder.add_node( "six", partial(read_write, "responses", ["messages"]), ) builder.add_conditional_edges( "six", lambda state: END if len(state.messages) > n else "one" ) return builder if __name__ == "__main__": import asyncio import uvloop from langgraph.checkpoint.memory import MemorySaver graph = wide_state(1000).compile(checkpointer=MemorySaver()) input = { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" 
* 10, True, 1, 6327816386138, None] * 5 for j in range(5) } for i in range(5) } ] } config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000} async def run(): async for c in graph.astream(input, config=config): print(c.keys()) uvloop.install() asyncio.run(run()) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/__main__.py`: ```py import random from uuid import uuid4 from langchain_core.messages import HumanMessage from pyperf._runner import Runner from uvloop import new_event_loop from bench.fanout_to_subgraph import fanout_to_subgraph, fanout_to_subgraph_sync from bench.react_agent import react_agent from bench.wide_state import wide_state from langgraph.checkpoint.memory import MemorySaver from langgraph.pregel import Pregel async def arun(graph: Pregel, input: dict): len( [ c async for c in graph.astream( input, { "configurable": {"thread_id": str(uuid4())}, "recursion_limit": 1000000000, }, ) ] ) def run(graph: Pregel, input: dict): len( [ c for c in graph.stream( input, { "configurable": {"thread_id": str(uuid4())}, "recursion_limit": 1000000000, }, ) ] ) benchmarks = ( ( "fanout_to_subgraph_10x", fanout_to_subgraph().compile(checkpointer=None), fanout_to_subgraph_sync().compile(checkpointer=None), { "subjects": [ random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10) ] }, ), ( "fanout_to_subgraph_10x_checkpoint", fanout_to_subgraph().compile(checkpointer=MemorySaver()), fanout_to_subgraph_sync().compile(checkpointer=MemorySaver()), { "subjects": [ random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10) ] }, ), ( "fanout_to_subgraph_100x", fanout_to_subgraph().compile(checkpointer=None), fanout_to_subgraph_sync().compile(checkpointer=None), { "subjects": [ random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100) ] }, ), ( "fanout_to_subgraph_100x_checkpoint", fanout_to_subgraph().compile(checkpointer=MemorySaver()), 
fanout_to_subgraph_sync().compile(checkpointer=MemorySaver()), { "subjects": [ random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100) ] }, ), ( "react_agent_10x", react_agent(10, checkpointer=None), react_agent(10, checkpointer=None), {"messages": [HumanMessage("hi?")]}, ), ( "react_agent_10x_checkpoint", react_agent(10, checkpointer=MemorySaver()), react_agent(10, checkpointer=MemorySaver()), {"messages": [HumanMessage("hi?")]}, ), ( "react_agent_100x", react_agent(100, checkpointer=None), react_agent(100, checkpointer=None), {"messages": [HumanMessage("hi?")]}, ), ( "react_agent_100x_checkpoint", react_agent(100, checkpointer=MemorySaver()), react_agent(100, checkpointer=MemorySaver()), {"messages": [HumanMessage("hi?")]}, ), ( "wide_state_25x300", wide_state(300).compile(checkpointer=None), wide_state(300).compile(checkpointer=None), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 for j in range(5) } for i in range(5) } ] }, ), ( "wide_state_25x300_checkpoint", wide_state(300).compile(checkpointer=MemorySaver()), wide_state(300).compile(checkpointer=MemorySaver()), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 for j in range(5) } for i in range(5) } ] }, ), ( "wide_state_15x600", wide_state(600).compile(checkpointer=None), wide_state(600).compile(checkpointer=None), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 for j in range(5) } for i in range(3) } ] }, ), ( "wide_state_15x600_checkpoint", wide_state(600).compile(checkpointer=MemorySaver()), wide_state(600).compile(checkpointer=MemorySaver()), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 for j in range(5) } for i in range(3) } ] }, ), ( "wide_state_9x1200", wide_state(1200).compile(checkpointer=None), wide_state(1200).compile(checkpointer=None), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" 
* 10, True, 1, 6327816386138, None] * 5 for j in range(3) } for i in range(3) } ] }, ), ( "wide_state_9x1200_checkpoint", wide_state(1200).compile(checkpointer=MemorySaver()), wide_state(1200).compile(checkpointer=MemorySaver()), { "messages": [ { str(i) * 10: { str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 for j in range(3) } for i in range(3) } ] }, ), ) r = Runner() for name, agraph, graph, input in benchmarks: r.bench_async_func(name, arun, agraph, input, loop_factory=new_event_loop) if graph is not None: r.bench_func(name + "_sync", run, graph, input) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/pyproject.toml`: ```toml [tool.poetry] name = "langgraph" version = "0.2.56" description = "Building stateful, multi-actor applications with LLMs" authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" [tool.poetry.dependencies] python = ">=3.9.0,<4.0" langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14" langgraph-checkpoint = "^2.0.4" langgraph-sdk = "^0.1.42" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" pytest-cov = "^4.0.0" pytest-dotenv = "^0.5.2" pytest-mock = "^3.10.0" syrupy = "^4.0.2" httpx = "^0.26.0" pytest-watcher = "^0.4.1" mypy = "^1.6.0" ruff = "^0.6.2" jupyter = "^1.0.0" pytest-xdist = {extras = ["psutil"], version = "^3.6.1"} pytest-repeat = "^0.9.3" langgraph-checkpoint = {path = "../checkpoint", develop = true} langgraph-checkpoint-duckdb = {path = "../checkpoint-duckdb", develop = true} langgraph-checkpoint-sqlite = {path = "../checkpoint-sqlite", develop = true} langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true} langgraph-sdk = {path = "../sdk-py", develop = true} psycopg = {extras = ["binary"], version = ">=3.0.0", python = ">=3.10"} uvloop = "0.21.0beta1" pyperf = "^2.7.0" py-spy = "^0.3.14" types-requests = 
"^2.32.0.20240914" [tool.ruff] lint.select = [ "E", "F", "I" ] lint.ignore = [ "E501" ] line-length = 88 indent-width = 4 extend-include = ["*.ipynb"] [tool.ruff.format] quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" docstring-code-format = false docstring-code-line-length = "dynamic" [tool.mypy] # https://mypy.readthedocs.io/en/stable/config_file.html disallow_untyped_defs = "True" explicit_package_bases = "True" warn_no_return = "False" warn_unused_ignores = "True" warn_redundant_casts = "True" allow_redefinition = "True" disable_error_code = "typeddict-item, return-value, override, has-type" [tool.coverage.run] omit = ["tests/*"] [tool.pytest-watcher] now = true delay = 0.1 patterns = ["*.py"] [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. addopts = "--full-trace --strict-markers --strict-config --durations=5 --snapshot-warn-unused" # Registering custom markers. 
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_utils.py`: ```py
"""Tests for langgraph utility helpers.

Covers async-callable / async-generator detection, optional-type and
TypedDict field-default introspection, and propagation of the LangSmith run
tree into graph nodes.
"""

import functools
import sys
import uuid
from typing import (
    Any,
    Callable,
    Dict,
    ForwardRef,
    List,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)
from unittest.mock import patch

import langsmith
import pytest
from typing_extensions import Annotated, NotRequired, Required

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.utils.fields import _is_optional_type, get_field_default
from langgraph.utils.runnable import is_async_callable, is_async_generator

pytestmark = pytest.mark.anyio


def test_is_async() -> None:
    """is_async_callable accepts coroutine functions and objects with an async __call__."""

    # NOTE: functools.wraps(f)(f) updates f's metadata in place and returns f
    # itself, so each wrapped_* check below re-tests the same object (now
    # carrying __wrapped__), not a distinct wrapper function.
    async def func() -> None:
        pass

    assert is_async_callable(func)
    wrapped_func = functools.wraps(func)(func)
    assert is_async_callable(wrapped_func)

    def sync_func() -> None:
        pass

    assert not is_async_callable(sync_func)
    wrapped_sync_func = functools.wraps(sync_func)(sync_func)
    assert not is_async_callable(wrapped_sync_func)

    class AsyncFuncCallable:
        async def __call__(self) -> None:
            pass

    runnable = AsyncFuncCallable()
    assert is_async_callable(runnable)
    wrapped_runnable = functools.wraps(runnable)(runnable)
    assert is_async_callable(wrapped_runnable)

    class SyncFuncCallable:
        def __call__(self) -> None:
            pass

    sync_runnable = SyncFuncCallable()
    assert not is_async_callable(sync_runnable)
    wrapped_sync_runnable = functools.wraps(sync_runnable)(sync_runnable)
    assert not is_async_callable(wrapped_sync_runnable)


def test_is_generator() -> None:
    """is_async_generator accepts async generator functions and async-generator __call__s."""

    # Same caveat as test_is_async: wraps(f)(f) returns f itself.
    async def gen():
        yield

    assert is_async_generator(gen)
    wrapped_gen = functools.wraps(gen)(gen)
    assert is_async_generator(wrapped_gen)

    def sync_gen():
        yield

    assert not is_async_generator(sync_gen)
    wrapped_sync_gen = functools.wraps(sync_gen)(sync_gen)
    assert not is_async_generator(wrapped_sync_gen)

    class AsyncGenCallable:
        async def __call__(self):
            yield

    runnable = AsyncGenCallable()
    assert is_async_generator(runnable)
    wrapped_runnable = functools.wraps(runnable)(runnable)
    assert is_async_generator(wrapped_runnable)

    class SyncGenCallable:
        def __call__(self):
            yield

    sync_runnable = SyncGenCallable()
    assert not is_async_generator(sync_runnable)
    wrapped_sync_runnable = functools.wraps(sync_runnable)(sync_runnable)
    assert not is_async_generator(wrapped_sync_runnable)


@pytest.fixture
def rt_graph() -> CompiledGraph:
    """Compile a one-node graph whose node records the id of the active LangSmith run tree."""

    class State(TypedDict):
        foo: int
        # Holds get_current_run_tree().id, which the tests assert is a UUID;
        # previously mis-annotated as int.
        node_run_id: uuid.UUID

    def node(_: State):
        from langsmith import get_current_run_tree  # type: ignore

        return {"node_run_id": get_current_run_tree().id}  # type: ignore

    graph = StateGraph(State)
    graph.add_node(node)
    graph.set_entry_point("node")
    graph.add_edge("node", END)
    return graph.compile()


def test_runnable_callable_tracing_nested(rt_graph: CompiledGraph) -> None:
    """With tracing enabled (client patched out), the node sees a run tree with a UUID id."""
    # Patch the LangSmith client and the tracer's get_client so no real client is built.
    with patch("langsmith.client.Client", spec=langsmith.Client) as mock_client:
        with patch("langchain_core.tracers.langchain.get_client") as mock_get_client:
            mock_get_client.return_value = mock_client
            with langsmith.tracing_context(enabled=True):
                res = rt_graph.invoke({"foo": 1})
            assert isinstance(res["node_run_id"], uuid.UUID)


@pytest.mark.skipif(
    sys.version_info < (3, 11),
    reason="Python 3.11+ is required for async contextvars support",
)
async def test_runnable_callable_tracing_nested_async(rt_graph: CompiledGraph) -> None:
    """Async counterpart of test_runnable_callable_tracing_nested using ainvoke."""
    with patch("langsmith.client.Client", spec=langsmith.Client) as mock_client:
        with patch("langchain_core.tracers.langchain.get_client") as mock_get_client:
            mock_get_client.return_value = mock_client
            with langsmith.tracing_context(enabled=True):
                res = await rt_graph.ainvoke({"foo": 1})
            assert isinstance(res["node_run_id"], uuid.UUID)


def test_is_optional_type():
    """_is_optional_type is true exactly when None is a member of the (possibly nested) union."""
    assert _is_optional_type(None)
    assert not _is_optional_type(type(None))
    assert _is_optional_type(Optional[list])
    assert not _is_optional_type(int)
    assert _is_optional_type(Optional[Literal[1, 2, 3]])
    assert not _is_optional_type(Literal[1, 2, 3])
    assert _is_optional_type(Optional[List[int]])
    assert _is_optional_type(Optional[Dict[str, int]])
    assert not _is_optional_type(List[Optional[int]])
    assert _is_optional_type(Union[Optional[str], Optional[int]])
    assert _is_optional_type(
        Union[
            Union[Optional[str], Optional[int]], Union[Optional[float], Optional[dict]]
        ]
    )
    assert not _is_optional_type(Union[Union[str, int], Union[float, dict]])

    assert _is_optional_type(Union[int, None])
    assert _is_optional_type(Union[str, None, int])
    assert _is_optional_type(Union[None, str, int])
    assert not _is_optional_type(Union[int, str])

    # NOTE(review): bare Any admits None but is reported as not optional;
    # this pins current behavior — confirm it is intended.
    assert not _is_optional_type(Any)
    assert _is_optional_type(Optional[Any])

    class MyClass:
        pass

    assert _is_optional_type(Optional[MyClass])
    assert not _is_optional_type(MyClass)
    assert _is_optional_type(Optional[ForwardRef("MyClass")])
    assert not _is_optional_type(ForwardRef("MyClass"))

    # Optionality nested inside container arguments does not count.
    assert _is_optional_type(Optional[Union[List[int], Dict[str, Optional[int]]]])
    assert not _is_optional_type(Union[List[int], Dict[str, Optional[int]]])

    assert _is_optional_type(Optional[Callable[[int], str]])
    assert not _is_optional_type(Callable[[int], Optional[str]])

    T = TypeVar("T")
    assert _is_optional_type(Optional[T])
    assert not _is_optional_type(T)

    # A TypeVar bounded by an Optional type is itself treated as optional.
    U = TypeVar("U", bound=Optional[T])  # type: ignore
    assert _is_optional_type(U)


def test_is_required():
    """get_field_default returns ... for required TypedDict keys and None otherwise.

    Explicit Required[...] wins over Optional; NotRequired, total=False and
    bare Optional[...] fields all report a None default.
    """

    class MyBaseTypedDict(TypedDict):
        val_1: Required[Optional[str]]
        val_2: Required[str]
        val_3: NotRequired[str]
        val_4: NotRequired[Optional[str]]
        val_5: Annotated[NotRequired[int], "foo"]
        val_6: NotRequired[Annotated[int, "foo"]]
        val_7: Annotated[Required[int], "foo"]
        val_8: Required[Annotated[int, "foo"]]
        val_9: Optional[str]
        val_10: str

    annos = MyBaseTypedDict.__annotations__
    assert get_field_default("val_1", annos["val_1"], MyBaseTypedDict) == ...
    assert get_field_default("val_2", annos["val_2"], MyBaseTypedDict) == ...
    assert get_field_default("val_3", annos["val_3"], MyBaseTypedDict) is None
    assert get_field_default("val_4", annos["val_4"], MyBaseTypedDict) is None
    # See https://peps.python.org/pep-0655/#interaction-with-annotated
    assert get_field_default("val_5", annos["val_5"], MyBaseTypedDict) is None
    assert get_field_default("val_6", annos["val_6"], MyBaseTypedDict) is None
    assert get_field_default("val_7", annos["val_7"], MyBaseTypedDict) == ...
    assert get_field_default("val_8", annos["val_8"], MyBaseTypedDict) == ...
    assert get_field_default("val_9", annos["val_9"], MyBaseTypedDict) is None
    assert get_field_default("val_10", annos["val_10"], MyBaseTypedDict) == ...

    class MyChildDict(MyBaseTypedDict):
        val_11: int
        val_11b: Optional[int]
        val_11c: Union[int, None, str]

    # total=False makes this class's own keys optional unless marked Required.
    class MyGrandChildDict(MyChildDict, total=False):
        val_12: int
        val_13: Required[str]

    cannos = MyChildDict.__annotations__
    gcannos = MyGrandChildDict.__annotations__
    assert get_field_default("val_11", cannos["val_11"], MyChildDict) == ...
    assert get_field_default("val_11b", cannos["val_11b"], MyChildDict) is None
    assert get_field_default("val_11c", cannos["val_11c"], MyChildDict) is None
    assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None
    assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None
    assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ...
``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/conftest.py`: ```py import sys from contextlib import asynccontextmanager from typing import AsyncIterator, Optional from uuid import UUID, uuid4 import pytest from langchain_core import __version__ as core_version from packaging import version from psycopg import AsyncConnection, Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool from pytest_mock import MockerFixture from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.duckdb import DuckDBSaver from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver from langgraph.checkpoint.postgres import PostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from langgraph.store.base import BaseStore from langgraph.store.duckdb import AsyncDuckDBStore, DuckDBStore from langgraph.store.memory import InMemoryStore from langgraph.store.postgres import AsyncPostgresStore, PostgresStore pytest.register_assert_rewrite("tests.memory_assert") DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5442/" # TODO: fix this once core is released IS_LANGCHAIN_CORE_030_OR_GREATER = version.parse(core_version) >= version.parse( "0.3.0.dev0" ) SHOULD_CHECK_SNAPSHOTS = IS_LANGCHAIN_CORE_030_OR_GREATER @pytest.fixture def anyio_backend(): return "asyncio" @pytest.fixture() def deterministic_uuids(mocker: MockerFixture) -> MockerFixture: side_effect = ( UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000) ) return mocker.patch("uuid.uuid4", side_effect=side_effect) # checkpointer fixtures @pytest.fixture(scope="function") def checkpointer_memory(): from tests.memory_assert import MemorySaverAssertImmutable yield MemorySaverAssertImmutable() @pytest.fixture(scope="function") def checkpointer_sqlite(): with SqliteSaver.from_conn_string(":memory:") as 
checkpointer: yield checkpointer @asynccontextmanager async def _checkpointer_sqlite_aio(): async with AsyncSqliteSaver.from_conn_string(":memory:") as checkpointer: yield checkpointer @pytest.fixture(scope="function") def checkpointer_duckdb(): with DuckDBSaver.from_conn_string(":memory:") as checkpointer: checkpointer.setup() yield checkpointer @asynccontextmanager async def _checkpointer_duckdb_aio(): async with AsyncDuckDBSaver.from_conn_string(":memory:") as checkpointer: await checkpointer.setup() yield checkpointer @pytest.fixture(scope="function") def checkpointer_postgres(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer with PostgresSaver.from_conn_string( DEFAULT_POSTGRES_URI + database ) as checkpointer: checkpointer.setup() yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture(scope="function") def checkpointer_postgres_pipe(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer with PostgresSaver.from_conn_string( DEFAULT_POSTGRES_URI + database ) as checkpointer: checkpointer.setup() # setup can't run inside pipeline because of implicit transaction with checkpointer.conn.pipeline() as pipe: checkpointer.pipe = pipe yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture(scope="function") def checkpointer_postgres_pool(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer with 
ConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True} ) as pool: checkpointer = PostgresSaver(pool) checkpointer.setup() yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _checkpointer_postgres_aio(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid4().hex[:16]}" # create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer async with AsyncPostgresSaver.from_conn_string( DEFAULT_POSTGRES_URI + database ) as checkpointer: await checkpointer.setup() yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _checkpointer_postgres_aio_pipe(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid4().hex[:16]}" # create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer async with AsyncPostgresSaver.from_conn_string( DEFAULT_POSTGRES_URI + database ) as checkpointer: await checkpointer.setup() # setup can't run inside pipeline because of implicit transaction async with checkpointer.conn.pipeline() as pipe: checkpointer.pipe = pipe yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _checkpointer_postgres_aio_pool(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = 
f"test_{uuid4().hex[:16]}" # create unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer async with AsyncConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True} ) as pool: checkpointer = AsyncPostgresSaver(pool) await checkpointer.setup() yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def awith_checkpointer( checkpointer_name: Optional[str], ) -> AsyncIterator[BaseCheckpointSaver]: if checkpointer_name is None: yield None elif checkpointer_name == "memory": from tests.memory_assert import MemorySaverAssertImmutable yield MemorySaverAssertImmutable() elif checkpointer_name == "sqlite_aio": async with _checkpointer_sqlite_aio() as checkpointer: yield checkpointer elif checkpointer_name == "duckdb_aio": async with _checkpointer_duckdb_aio() as checkpointer: yield checkpointer elif checkpointer_name == "postgres_aio": async with _checkpointer_postgres_aio() as checkpointer: yield checkpointer elif checkpointer_name == "postgres_aio_pipe": async with _checkpointer_postgres_aio_pipe() as checkpointer: yield checkpointer elif checkpointer_name == "postgres_aio_pool": async with _checkpointer_postgres_aio_pool() as checkpointer: yield checkpointer else: raise NotImplementedError(f"Unknown checkpointer: {checkpointer_name}") @asynccontextmanager async def _store_postgres_aio(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid4().hex[:16]}" async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with AsyncPostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database ) as store: await store.setup() 
yield store finally: async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _store_postgres_aio_pipe(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid4().hex[:16]}" async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with AsyncPostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database ) as store: await store.setup() # Run in its own transaction async with AsyncPostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database, pipeline=True ) as store: yield store finally: async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _store_postgres_aio_pool(): if sys.version_info < (3, 10): pytest.skip("Async Postgres tests require Python 3.10+") database = f"test_{uuid4().hex[:16]}" async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: async with AsyncPostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database, pool_config={"max_size": 10}, ) as store: await store.setup() yield store finally: async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @asynccontextmanager async def _store_duckdb_aio(): async with AsyncDuckDBStore.from_conn_string(":memory:") as store: await store.setup() yield store @pytest.fixture(scope="function") def store_postgres(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield store with PostgresStore.from_conn_string(DEFAULT_POSTGRES_URI + 
database) as store: store.setup() yield store finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture(scope="function") def store_postgres_pipe(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield store with PostgresStore.from_conn_string(DEFAULT_POSTGRES_URI + database) as store: store.setup() # Run in its own transaction with PostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database, pipeline=True ) as store: yield store finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture(scope="function") def store_postgres_pool(): database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield store with PostgresStore.from_conn_string( DEFAULT_POSTGRES_URI + database, pool_config={"max_size": 10} ) as store: store.setup() yield store finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") @pytest.fixture(scope="function") def store_duckdb(): with DuckDBStore.from_conn_string(":memory:") as store: store.setup() yield store @pytest.fixture(scope="function") def store_in_memory(): yield InMemoryStore() @asynccontextmanager async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: if store_name is None: yield None elif store_name == "in_memory": yield InMemoryStore() elif store_name == "postgres_aio": async with _store_postgres_aio() as store: yield store elif store_name == "postgres_aio_pipe": async with _store_postgres_aio_pipe() as store: yield store elif store_name == "postgres_aio_pool": async with 
_store_postgres_aio_pool() as store: yield store elif store_name == "duckdb_aio": async with _store_duckdb_aio() as store: yield store else: raise NotImplementedError(f"Unknown store {store_name}") ALL_CHECKPOINTERS_SYNC = [ "memory", "sqlite", "postgres", "postgres_pipe", "postgres_pool", ] ALL_CHECKPOINTERS_ASYNC = [ "memory", "sqlite_aio", "postgres_aio", "postgres_aio_pipe", "postgres_aio_pool", ] ALL_CHECKPOINTERS_ASYNC_PLUS_NONE = [ *ALL_CHECKPOINTERS_ASYNC, None, ] ALL_STORES_SYNC = [ "in_memory", "postgres", "postgres_pipe", "postgres_pool", "duckdb", ] ALL_STORES_ASYNC = [ "in_memory", "postgres_aio", "postgres_aio_pipe", "postgres_aio_pool", "duckdb_aio", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_remote_graph.py`: ```py from unittest.mock import AsyncMock, MagicMock import pytest from langchain_core.runnables.graph import ( Edge as DrawableEdge, ) from langchain_core.runnables.graph import ( Node as DrawableNode, ) from langgraph_sdk.schema import StreamPart from langgraph.errors import GraphInterrupt from langgraph.pregel.remote import RemoteGraph from langgraph.pregel.types import StateSnapshot def test_with_config(): # set up test remote_pregel = RemoteGraph( "test_graph_id", config={ "configurable": { "foo": "bar", "thread_id": "thread_id_1", } }, ) # call method / assertions config = {"configurable": {"hello": "world"}} remote_pregel_copy = remote_pregel.with_config(config) # assert that a copy was returned assert remote_pregel_copy != remote_pregel # assert that configs were merged assert remote_pregel_copy.config == { "configurable": { "foo": "bar", "thread_id": "thread_id_1", "hello": "world", } } def test_get_graph(): # set up test mock_sync_client = MagicMock() mock_sync_client.assistants.get_graph.return_value = { "nodes": [ {"id": "__start__", "type": "schema", "data": "__start__"}, {"id": "__end__", "type": "schema", "data": "__end__"}, { "id": "agent", "type": "runnable", "data": { "id": ["langgraph", 
"utils", "RunnableCallable"], "name": "agent_1", }, }, ], "edges": [ {"source": "__start__", "target": "agent"}, {"source": "agent", "target": "__end__"}, ], } remote_pregel = RemoteGraph("test_graph_id", sync_client=mock_sync_client) # call method / assertions drawable_graph = remote_pregel.get_graph() assert drawable_graph.nodes == { "__start__": DrawableNode( id="__start__", name="__start__", data="__start__", metadata=None ), "__end__": DrawableNode( id="__end__", name="__end__", data="__end__", metadata=None ), "agent": DrawableNode( id="agent", name="agent_1", data={"id": ["langgraph", "utils", "RunnableCallable"], "name": "agent_1"}, metadata=None, ), } assert drawable_graph.edges == [ DrawableEdge(source="__start__", target="agent"), DrawableEdge(source="agent", target="__end__"), ] @pytest.mark.anyio async def test_aget_graph(): # set up test mock_async_client = AsyncMock() mock_async_client.assistants.get_graph.return_value = { "nodes": [ {"id": "__start__", "type": "schema", "data": "__start__"}, {"id": "__end__", "type": "schema", "data": "__end__"}, { "id": "agent", "type": "runnable", "data": { "id": ["langgraph", "utils", "RunnableCallable"], "name": "agent_1", }, }, ], "edges": [ {"source": "__start__", "target": "agent"}, {"source": "agent", "target": "__end__"}, ], } remote_pregel = RemoteGraph("test_graph_id", client=mock_async_client) # call method / assertions drawable_graph = await remote_pregel.aget_graph() assert drawable_graph.nodes == { "__start__": DrawableNode( id="__start__", name="__start__", data="__start__", metadata=None ), "__end__": DrawableNode( id="__end__", name="__end__", data="__end__", metadata=None ), "agent": DrawableNode( id="agent", name="agent_1", data={"id": ["langgraph", "utils", "RunnableCallable"], "name": "agent_1"}, metadata=None, ), } assert drawable_graph.edges == [ DrawableEdge(source="__start__", target="agent"), DrawableEdge(source="agent", target="__end__"), ] def test_get_state(): # set up test 
mock_sync_client = MagicMock() mock_sync_client.threads.get_state.return_value = { "values": {"messages": [{"type": "human", "content": "hello"}]}, "next": None, "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, }, "metadata": {}, "created_at": "timestamp", "parent_checkpoint": None, "tasks": [], } # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", sync_client=mock_sync_client, ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = remote_pregel.get_state(config) assert state_snapshot == StateSnapshot( values={"messages": [{"type": "human", "content": "hello"}]}, next=(), config={ "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } }, metadata={}, created_at="timestamp", parent_config=None, tasks=(), ) @pytest.mark.anyio async def test_aget_state(): mock_async_client = AsyncMock() mock_async_client.threads.get_state.return_value = { "values": {"messages": [{"type": "human", "content": "hello"}]}, "next": None, "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_2", "checkpoint_map": {}, }, "metadata": {}, "created_at": "timestamp", "parent_checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, }, "tasks": [], } # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", client=mock_async_client, ) config = {"configurable": {"thread_id": "thread1"}} state_snapshot = await remote_pregel.aget_state(config) assert state_snapshot == StateSnapshot( values={"messages": [{"type": "human", "content": "hello"}]}, next=(), config={ "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_2", "checkpoint_map": {}, } }, metadata={}, created_at="timestamp", parent_config={ "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", 
"checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } }, tasks=(), ) def test_get_state_history(): # set up test mock_sync_client = MagicMock() mock_sync_client.threads.get_history.return_value = [ { "values": {"messages": [{"type": "human", "content": "hello"}]}, "next": None, "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, }, "metadata": {}, "created_at": "timestamp", "parent_checkpoint": None, "tasks": [], } ] # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", sync_client=mock_sync_client, ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = list( remote_pregel.get_state_history(config, filter=None, before=None, limit=None) ) assert len(state_history_snapshot) == 1 assert state_history_snapshot[0] == StateSnapshot( values={"messages": [{"type": "human", "content": "hello"}]}, next=(), config={ "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } }, metadata={}, created_at="timestamp", parent_config=None, tasks=(), ) @pytest.mark.anyio async def test_aget_state_history(): # set up test mock_async_client = AsyncMock() mock_async_client.threads.get_history.return_value = [ { "values": {"messages": [{"type": "human", "content": "hello"}]}, "next": None, "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, }, "metadata": {}, "created_at": "timestamp", "parent_checkpoint": None, "tasks": [], } ] # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", client=mock_async_client, ) config = {"configurable": {"thread_id": "thread1"}} state_history_snapshot = [] async for state_snapshot in remote_pregel.aget_state_history( config, filter=None, before=None, limit=None ): state_history_snapshot.append(state_snapshot) assert len(state_history_snapshot) == 1 assert state_history_snapshot[0] == StateSnapshot( 
values={"messages": [{"type": "human", "content": "hello"}]}, next=(), config={ "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } }, metadata={}, created_at="timestamp", parent_config=None, tasks=(), ) def test_update_state(): # set up test mock_sync_client = MagicMock() mock_sync_client.threads.update_state.return_value = { "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } } # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", sync_client=mock_sync_client, ) config = {"configurable": {"thread_id": "thread1"}} response = remote_pregel.update_state(config, {"key": "value"}) assert response == { "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } } @pytest.mark.anyio async def test_aupdate_state(): # set up test mock_async_client = AsyncMock() mock_async_client.threads.update_state.return_value = { "checkpoint": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } } # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", client=mock_async_client, ) config = {"configurable": {"thread_id": "thread1"}} response = await remote_pregel.aupdate_state(config, {"key": "value"}) assert response == { "configurable": { "thread_id": "thread_1", "checkpoint_ns": "ns", "checkpoint_id": "checkpoint_1", "checkpoint_map": {}, } } def test_stream(): # set up test mock_sync_client = MagicMock() mock_sync_client.runs.stream.return_value = [ StreamPart(event="values", data={"chunk": "data1"}), StreamPart(event="values", data={"chunk": "data2"}), StreamPart(event="values", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), StreamPart(event="updates", data={"__interrupt__": ()}), ] # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", 
sync_client=mock_sync_client, ) # stream modes doesn't include 'updates' stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ {"chunk": "data1"}, {"chunk": "data2"}, {"chunk": "data3"}, ] mock_sync_client.runs.stream.return_value = [ StreamPart(event="updates", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), StreamPart(event="updates", data={"__interrupt__": ()}), ] # default stream_mode is updates stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, ): stream_parts.append(stream_part) assert stream_parts == [ {"chunk": "data3"}, {"chunk": "data4"}, ] # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode=["updates"], ): stream_parts.append(stream_part) assert stream_parts == [ ("updates", {"chunk": "data3"}), ("updates", {"chunk": "data4"}), ] # subgraphs + list modes stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode=["updates"], subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ ((), "updates", {"chunk": "data3"}), ((), "updates", {"chunk": "data4"}), ] # subgraphs + single mode stream_parts = [] with pytest.raises(GraphInterrupt): for stream_part in remote_pregel.stream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ ((), {"chunk": "data3"}), ((), {"chunk": "data4"}), ] @pytest.mark.anyio async def test_astream(): # 
set up test mock_async_client = MagicMock() async_iter = MagicMock() async_iter.__aiter__.return_value = [ StreamPart(event="values", data={"chunk": "data1"}), StreamPart(event="values", data={"chunk": "data2"}), StreamPart(event="values", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), StreamPart(event="updates", data={"__interrupt__": ()}), ] mock_async_client.runs.stream.return_value = async_iter # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", client=mock_async_client, ) # stream modes doesn't include 'updates' stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode="values", ): stream_parts.append(stream_part) assert stream_parts == [ {"chunk": "data1"}, {"chunk": "data2"}, {"chunk": "data3"}, ] async_iter = MagicMock() async_iter.__aiter__.return_value = [ StreamPart(event="updates", data={"chunk": "data3"}), StreamPart(event="updates", data={"chunk": "data4"}), StreamPart(event="updates", data={"__interrupt__": ()}), ] mock_async_client.runs.stream.return_value = async_iter # default stream_mode is updates stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, ): stream_parts.append(stream_part) assert stream_parts == [ {"chunk": "data3"}, {"chunk": "data4"}, ] # list stream_mode includes mode names stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode=["updates"], ): stream_parts.append(stream_part) assert stream_parts == [ ("updates", {"chunk": "data3"}), ("updates", {"chunk": "data4"}), ] # subgraphs + list modes stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": 
"data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode=["updates"], subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ ((), "updates", {"chunk": "data3"}), ((), "updates", {"chunk": "data4"}), ] # subgraphs + single mode stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ ((), {"chunk": "data3"}), ((), {"chunk": "data4"}), ] async_iter = MagicMock() async_iter.__aiter__.return_value = [ StreamPart(event="updates|my|subgraph", data={"chunk": "data3"}), StreamPart(event="updates|hello|subgraph", data={"chunk": "data4"}), StreamPart(event="updates|bye|subgraph", data={"__interrupt__": ()}), ] mock_async_client.runs.stream.return_value = async_iter # subgraphs + list modes stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, stream_mode=["updates"], subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ (("my", "subgraph"), "updates", {"chunk": "data3"}), (("hello", "subgraph"), "updates", {"chunk": "data4"}), ] # subgraphs + single mode stream_parts = [] with pytest.raises(GraphInterrupt): async for stream_part in remote_pregel.astream( {"input": "data"}, config={"configurable": {"thread_id": "thread_1"}}, subgraphs=True, ): stream_parts.append(stream_part) assert stream_parts == [ (("my", "subgraph"), {"chunk": "data3"}), (("hello", "subgraph"), {"chunk": "data4"}), ] def test_invoke(): # set up test mock_sync_client = MagicMock() mock_sync_client.runs.stream.return_value = [ StreamPart(event="values", data={"chunk": "data1"}), StreamPart(event="values", data={"chunk": "data2"}), StreamPart( event="values", data={"messages": [{"type": "human", "content": "world"}]} ), ] # call method / 
assertions remote_pregel = RemoteGraph( "test_graph_id", sync_client=mock_sync_client, ) config = {"configurable": {"thread_id": "thread_1"}} result = remote_pregel.invoke( {"input": {"messages": [{"type": "human", "content": "hello"}]}}, config ) assert result == {"messages": [{"type": "human", "content": "world"}]} @pytest.mark.anyio async def test_ainvoke(): # set up test mock_async_client = MagicMock() async_iter = MagicMock() async_iter.__aiter__.return_value = [ StreamPart(event="values", data={"chunk": "data1"}), StreamPart(event="values", data={"chunk": "data2"}), StreamPart( event="values", data={"messages": [{"type": "human", "content": "world"}]} ), ] mock_async_client.runs.stream.return_value = async_iter # call method / assertions remote_pregel = RemoteGraph( "test_graph_id", client=mock_async_client, ) config = {"configurable": {"thread_id": "thread_1"}} result = await remote_pregel.ainvoke( {"input": {"messages": [{"type": "human", "content": "hello"}]}}, config ) assert result == {"messages": [{"type": "human", "content": "world"}]} @pytest.mark.skip("Unskip this test to manually test the LangGraph Cloud integration") @pytest.mark.anyio async def test_langgraph_cloud_integration(): from langgraph_sdk.client import get_client, get_sync_client from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, MessagesState, StateGraph # create RemotePregel instance client = get_client() sync_client = get_sync_client() remote_pregel = RemoteGraph( "agent", client=client, sync_client=sync_client, ) # define graph workflow = StateGraph(MessagesState) workflow.add_node("agent", remote_pregel) workflow.add_edge(START, "agent") workflow.add_edge("agent", END) app = workflow.compile(checkpointer=MemorySaver()) # test invocation input = { "messages": [ { "role": "human", "content": "What's the weather in SF?", } ] } # test invoke response = app.invoke( input, config={"configurable": {"thread_id": 
"39a6104a-34e7-4f83-929c-d9eb163003c9"}}, interrupt_before=["agent"], ) print("response:", response["messages"][-1].content) # test stream async for chunk in app.astream( input, config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}}, subgraphs=True, stream_mode=["debug", "messages"], ): print("chunk:", chunk) # test stream events async for chunk in remote_pregel.astream_events( input, config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}}, version="v2", subgraphs=True, stream_mode=[], ): print("chunk:", chunk) # test get state state_snapshot = await remote_pregel.aget_state( config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}}, subgraphs=True, ) print("state snapshot:", state_snapshot) # test update state response = await remote_pregel.aupdate_state( config={"configurable": {"thread_id": "6645e002-ed50-4022-92a3-d0d186fdf812"}}, values={ "messages": [ { "role": "ai", "content": "Hello world again!", } ] }, ) print("response:", response) # test get history async for state in remote_pregel.aget_state_history( config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}}, ): print("state snapshot:", state) # test get graph remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID graph = await remote_pregel.aget_graph(xray=True) print("graph:", graph) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/fake_chat.py`: ```py import re from typing import Any, AsyncIterator, Iterator, List, Optional, cast from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult class FakeChatModel(GenericFakeChatModel): messages: list[BaseMessage] i: int = 0 def bind_tools(self, 
functions: list): return self def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: """Top Level call""" if self.i >= len(self.messages): self.i = 0 message = self.messages[self.i] self.i += 1 if isinstance(message, str): message_ = AIMessage(content=message) else: if hasattr(message, "model_copy"): message_ = message.model_copy() else: message_ = message.copy() generation = ChatGeneration(message=message_) return ChatResult(generations=[generation]) def _stream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model.""" chat_result = self._generate( messages, stop=stop, run_manager=run_manager, **kwargs ) if not isinstance(chat_result, ChatResult): raise ValueError( f"Expected generate to return a ChatResult, " f"but got {type(chat_result)} instead." ) message = chat_result.generations[0].message if not isinstance(message, AIMessage): raise ValueError( f"Expected invoke to return an AIMessage, " f"but got {type(message)} instead." ) content = message.content if content: # Use a regular expression to split on whitespace with a capture group # so that we can preserve the whitespace in the output. 
assert isinstance(content, str) content_chunks = cast(list[str], re.split(r"(\s)", content)) for token in content_chunks: chunk = ChatGenerationChunk( message=AIMessageChunk(content=token, id=message.id) ) if run_manager: run_manager.on_llm_new_token(token, chunk=chunk) yield chunk else: args = message.__dict__ args.pop("type") chunk = ChatGenerationChunk(message=AIMessageChunk(**args)) if run_manager: run_manager.on_llm_new_token("", chunk=chunk) yield chunk async def _astream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: """Stream the output of the model.""" chat_result = self._generate( messages, stop=stop, run_manager=run_manager, **kwargs ) if not isinstance(chat_result, ChatResult): raise ValueError( f"Expected generate to return a ChatResult, " f"but got {type(chat_result)} instead." ) message = chat_result.generations[0].message if not isinstance(message, AIMessage): raise ValueError( f"Expected invoke to return an AIMessage, " f"but got {type(message)} instead." ) content = message.content if content: # Use a regular expression to split on whitespace with a capture group # so that we can preserve the whitespace in the output. 
assert isinstance(content, str) content_chunks = cast(list[str], re.split(r"(\s)", content)) for token in content_chunks: chunk = ChatGenerationChunk( message=AIMessageChunk(content=token, id=message.id) ) if run_manager: run_manager.on_llm_new_token(token, chunk=chunk) yield chunk else: args = message.__dict__ args.pop("type") chunk = ChatGenerationChunk(message=AIMessageChunk(**args)) if run_manager: await run_manager.on_llm_new_token("", chunk=chunk) yield chunk ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/any_int.py`: ```py class AnyInt(int): def __init__(self) -> None: super().__init__() def __eq__(self, other: object) -> bool: return isinstance(other, int) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_interruption.py`: ```py from typing import TypedDict import pytest from pytest_mock import MockerFixture from langgraph.graph import END, START, StateGraph from tests.conftest import ( ALL_CHECKPOINTERS_ASYNC, ALL_CHECKPOINTERS_SYNC, awith_checkpointer, ) pytestmark = pytest.mark.anyio @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_interruption_without_state_updates( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: """Test interruption without state updates. 
This test confirms that interrupting doesn't require a state key having been updated in the prev step""" class State(TypedDict): input: str def noop(_state): pass builder = StateGraph(State) builder.add_node("step_1", noop) builder.add_node("step_2", noop) builder.add_node("step_3", noop) builder.add_edge(START, "step_1") builder.add_edge("step_1", "step_2") builder.add_edge("step_2", "step_3") builder.add_edge("step_3", END) checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") graph = builder.compile(checkpointer=checkpointer, interrupt_after="*") initial_input = {"input": "hello world"} thread = {"configurable": {"thread_id": "1"}} graph.invoke(initial_input, thread, debug=True) assert graph.get_state(thread).next == ("step_2",) graph.invoke(None, thread, debug=True) assert graph.get_state(thread).next == ("step_3",) graph.invoke(None, thread, debug=True) assert graph.get_state(thread).next == () @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_interruption_without_state_updates_async( checkpointer_name: str, mocker: MockerFixture ): """Test interruption without state updates. 
This test confirms that interrupting doesn't require a state key having been updated in the prev step""" class State(TypedDict): input: str async def noop(_state): pass builder = StateGraph(State) builder.add_node("step_1", noop) builder.add_node("step_2", noop) builder.add_node("step_3", noop) builder.add_edge(START, "step_1") builder.add_edge("step_1", "step_2") builder.add_edge("step_2", "step_3") builder.add_edge("step_3", END) async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer, interrupt_after="*") initial_input = {"input": "hello world"} thread = {"configurable": {"thread_id": "1"}} await graph.ainvoke(initial_input, thread, debug=True) assert (await graph.aget_state(thread)).next == ("step_2",) await graph.ainvoke(None, thread, debug=True) assert (await graph.aget_state(thread)).next == ("step_3",) await graph.ainvoke(None, thread, debug=True) assert (await graph.aget_state(thread)).next == () ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/memory_assert.py`: ```py import asyncio import os import tempfile from collections import defaultdict from functools import partial from typing import Any, Optional from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, SerializerProtocol, copy_checkpoint, ) from langgraph.checkpoint.memory import MemorySaver, PersistentDict class NoopSerializer(SerializerProtocol): def loads_typed(self, data: tuple[str, bytes]) -> Any: return data[1] def dumps_typed(self, obj: Any) -> tuple[str, bytes]: return "type", obj class MemorySaverAssertImmutable(MemorySaver): storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]] def __init__( self, *, serde: Optional[SerializerProtocol] = None, put_sleep: Optional[float] = None, ) -> None: _, filename = tempfile.mkstemp() super().__init__( serde=serde, factory=partial(PersistentDict, 
filename=filename) ) self.storage_for_copies = defaultdict(lambda: defaultdict(dict)) self.put_sleep = put_sleep self.stack.callback(os.remove, filename) def put( self, config: dict, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> None: if self.put_sleep: import time time.sleep(self.put_sleep) # assert checkpoint hasn't been modified since last written thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"]["checkpoint_ns"] if saved := super().get(config): assert ( self.serde.loads_typed( self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]] ) == saved ) self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = ( self.serde.dumps_typed(copy_checkpoint(checkpoint)) ) # call super to write checkpoint return super().put(config, checkpoint, metadata, new_versions) class MemorySaverAssertCheckpointMetadata(MemorySaver): """This custom checkpointer is for verifying that a run's configurable fields are merged with the previous checkpoint config for each step in the run. This is the desired behavior. Because the checkpointer's (a)put() method is called for each step, the implementation of this checkpointer should produce a side effect that can be asserted. """ def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> None: """The implementation of put() merges config["configurable"] (a run's configurable fields) with the metadata field. The state of the checkpoint metadata can be asserted to confirm that the run's configurable fields were merged with the previous checkpoint config. 
""" configurable = config["configurable"].copy() # remove checkpoint_id to make testing simpler checkpoint_id = configurable.pop("checkpoint_id", None) thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"]["checkpoint_ns"] self.storage[thread_id][checkpoint_ns].update( { checkpoint["id"]: ( self.serde.dumps_typed(checkpoint), # merge configurable fields and metadata self.serde.dumps_typed({**configurable, **metadata}), checkpoint_id, ) } ) return { "configurable": { "thread_id": config["configurable"]["thread_id"], "checkpoint_id": checkpoint["id"], } } async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: return await asyncio.get_running_loop().run_in_executor( None, self.put, config, checkpoint, metadata, new_versions ) class MemorySaverNoPending(MemorySaver): def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: result = super().get_tuple(config) if result: return CheckpointTuple(result.config, result.checkpoint, result.metadata) return result ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_pregel_async.py`: ```py import asyncio import logging import operator import random import re import sys import uuid from collections import Counter from contextlib import asynccontextmanager, contextmanager from dataclasses import replace from time import perf_counter from typing import ( Annotated, Any, AsyncGenerator, AsyncIterator, Dict, Generator, List, Literal, Optional, Tuple, TypedDict, Union, cast, ) from uuid import UUID import httpx import pytest from langchain_core.messages import ToolCall from langchain_core.runnables import ( RunnableConfig, RunnableLambda, RunnablePassthrough, RunnablePick, ) from langchain_core.utils.aiter import aclosing from pydantic import BaseModel from pytest_mock import MockerFixture from syrupy import SnapshotAssertion from langgraph.channels.base import 
BaseChannel from langgraph.channels.binop import BinaryOperatorAggregate from langgraph.channels.context import Context from langgraph.channels.last_value import LastValue from langgraph.channels.topic import Topic from langgraph.channels.untracked_value import UntrackedValue from langgraph.checkpoint.base import ( ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, ) from langgraph.checkpoint.memory import MemorySaver from langgraph.constants import ( CONFIG_KEY_NODE_FINISHED, ERROR, FF_SEND_V2, PULL, PUSH, START, ) from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph, StateGraph from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor from langgraph.prebuilt.tool_node import ToolNode from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot from langgraph.pregel.retry import RetryPolicy from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore from langgraph.types import ( Command, Interrupt, PregelTask, Send, StreamWriter, interrupt, ) from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence from tests.conftest import ( ALL_CHECKPOINTERS_ASYNC, ALL_CHECKPOINTERS_ASYNC_PLUS_NONE, ALL_STORES_ASYNC, SHOULD_CHECK_SNAPSHOTS, awith_checkpointer, awith_store, ) from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.memory_assert import ( MemorySaverAssertCheckpointMetadata, MemorySaverNoPending, ) from tests.messages import ( _AnyIdAIMessage, _AnyIdAIMessageChunk, _AnyIdHumanMessage, _AnyIdToolMessage, ) logger = logging.getLogger(__name__) pytestmark = pytest.mark.anyio async def test_checkpoint_errors() -> None: class FaultyGetCheckpointer(MemorySaver): async def aget_tuple(self, config: RunnableConfig) -> 
Optional[CheckpointTuple]: raise ValueError("Faulty get_tuple") class FaultyPutCheckpointer(MemorySaver): async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: raise ValueError("Faulty put") class FaultyPutWritesCheckpointer(MemorySaver): async def aput_writes( self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str ) -> RunnableConfig: raise ValueError("Faulty put_writes") class FaultyVersionCheckpointer(MemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version") def logic(inp: str) -> str: return "" builder = StateGraph(Annotated[str, operator.add]) builder.add_node("agent", logic) builder.add_edge(START, "agent") graph = builder.compile(checkpointer=FaultyGetCheckpointer()) with pytest.raises(ValueError, match="Faulty get_tuple"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) with pytest.raises(ValueError, match="Faulty get_tuple"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass with pytest.raises(ValueError, match="Faulty get_tuple"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ): pass graph = builder.compile(checkpointer=FaultyPutCheckpointer()) with pytest.raises(ValueError, match="Faulty put"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) with pytest.raises(ValueError, match="Faulty put"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass with pytest.raises(ValueError, match="Faulty put"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ): pass graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) with pytest.raises(ValueError, match="Faulty get_next_version"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) with 
pytest.raises(ValueError, match="Faulty get_next_version"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass with pytest.raises(ValueError, match="Faulty get_next_version"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ): pass # add a parallel node builder.add_node("parallel", logic) builder.add_edge(START, "parallel") graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) with pytest.raises(ValueError, match="Faulty put_writes"): await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}}) with pytest.raises(ValueError, match="Faulty put_writes"): async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}): pass with pytest.raises(ValueError, match="Faulty put_writes"): async for _ in graph.astream_events( "", {"configurable": {"thread_id": "thread-3"}}, version="v2" ): pass async def test_node_cancellation_on_external_cancel() -> None: inner_task_cancelled = False async def awhile(input: Any) -> None: try: await asyncio.sleep(1) except asyncio.CancelledError: nonlocal inner_task_cancelled inner_task_cancelled = True raise builder = Graph() builder.add_node("agent", awhile) builder.set_entry_point("agent") builder.set_finish_point("agent") graph = builder.compile() with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(graph.ainvoke(1), 0.5) assert inner_task_cancelled async def test_node_cancellation_on_other_node_exception() -> None: inner_task_cancelled = False async def awhile(input: Any) -> None: try: await asyncio.sleep(1) except asyncio.CancelledError: nonlocal inner_task_cancelled inner_task_cancelled = True raise async def iambad(input: Any) -> None: raise ValueError("I am bad") builder = Graph() builder.add_node("agent", awhile) builder.add_node("bad", iambad) builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END) graph = builder.compile() with pytest.raises(ValueError, match="I am bad"): # This will 
raise ValueError, not TimeoutError await asyncio.wait_for(graph.ainvoke(1), 0.5) assert inner_task_cancelled async def test_node_cancellation_on_other_node_exception_two() -> None: async def awhile(input: Any) -> None: await asyncio.sleep(1) async def iambad(input: Any) -> None: raise ValueError("I am bad") builder = Graph() builder.add_node("agent", awhile) builder.add_node("bad", iambad) builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END) graph = builder.compile() with pytest.raises(ValueError, match="I am bad"): # This will raise ValueError, not CancelledError await graph.ainvoke(1) @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_dynamic_interrupt(checkpointer_name: str) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_node_count = 0 async def tool_two_node(s: State) -> State: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) tool_two_graph.add_edge(START, "tool_two") tool_two = tool_two_graph.compile() tracer = FakeTracer() assert await tool_two.ainvoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value"} assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { "my_key": "value all good", "market": "US", } async with awith_checkpointer(checkpointer_name) as checkpointer: tool_two = tool_two_graph.compile(checkpointer=checkpointer) # 
missing thread_id with pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value ⛰️", "market": "DE"}, thread2 ) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ) }, ] # resume with answer assert [ c async for c in tool_two.astream(Command(resume=" my answer"), thread2) ] == [ {"tool_two": {"my_key": " my answer"}}, ] # flow: interrupt -> clear thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value ⛰️", "market": "DE"}, thread1 ) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ) }, ] assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ), ), ), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) # clear the interrupt and next tasks await tool_two.aupdate_state(thread1, None, as_node=END) # interrupt is cleared, as well as the next tasks tup = await 
tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None: class SubgraphState(TypedDict): my_key: str market: str tool_two_node_count = 0 def tool_two_node(s: SubgraphState) -> SubgraphState: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} subgraph = StateGraph(SubgraphState) subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) subgraph.add_edge(START, "do") class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", subgraph.compile()) tool_two_graph.add_edge(START, "tool_two") tool_two = tool_two_graph.compile() tracer = FakeTracer() assert await tool_two.ainvoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value"} assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { "my_key": "value all good", "market": "US", } async with awith_checkpointer(checkpointer_name) as checkpointer: tool_two = tool_two_graph.compile(checkpointer=checkpointer) # missing thread_id with 
pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value ⛰️", "market": "DE"}, thread2 ) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ) }, ] # resume with answer assert [ c async for c in tool_two.astream(Command(resume=" my answer"), thread2) ] == [ {"tool_two": {"my_key": " my answer", "market": "DE"}}, ] # flow: interrupt -> clear thread1 = {"configurable": {"thread_id": "1"}} thread1root = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value ⛰️", "market": "DE"}, thread1 ) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ) }, ] assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("tool_two:"), } }, ), ), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1root, limit=2) 
][-1].config, ) # clear the interrupt and next tasks await tool_two.aupdate_state(thread1, None, as_node=END) # interrupt is cleared, as well as the next tasks tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1root, limit=2) ][-1].config, ) @pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_copy_checkpoint(checkpointer_name: str) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] market: str def tool_one(s: State) -> State: return {"my_key": " one"} tool_two_node_count = 0 def tool_two_node(s: State) -> State: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} def start(state: State) -> list[Union[Send, str]]: return ["tool_two", Send("tool_one", state)] tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) tool_two_graph.add_node("tool_one", tool_one) tool_two_graph.set_conditional_entry_point(start) tool_two = tool_two_graph.compile() tracer = FakeTracer() assert await tool_two.ainvoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value one", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value one"} assert 
await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { "my_key": "value one all good", "market": "US", } async with awith_checkpointer(checkpointer_name) as checkpointer: tool_two = tool_two_graph.compile(checkpointer=checkpointer) # missing thread_id with pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value ⛰️", "market": "DE"}, thread2 ) ] == [ { "tool_one": {"my_key": " one"}, }, { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ) }, ] # resume with answer assert [ c async for c in tool_two.astream(Command(resume=" my answer"), thread2) ] == [ {"tool_two": {"my_key": " my answer"}}, ] # flow: interrupt -> clear tasks thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert await tool_two.ainvoke( {"my_key": "value ⛰️", "market": "DE"}, thread1 ) == { "my_key": "value ⛰️ one", "market": "DE", } assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ), ), ), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, parent_config=[ c 
async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) # clear the interrupt and next tasks await tool_two.aupdate_state(thread1, None) # interrupt is cleared, next task is kept tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=(), ), ), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_node_not_cancelled_on_other_node_interrupted( checkpointer_name: str, ) -> None: class State(TypedDict): hello: Annotated[str, operator.add] awhiles = 0 inner_task_cancelled = False async def awhile(input: State) -> None: nonlocal awhiles awhiles += 1 try: await asyncio.sleep(1) return {"hello": " again"} except asyncio.CancelledError: nonlocal inner_task_cancelled inner_task_cancelled = True raise async def iambad(input: State) -> None: return {"hello": interrupt("I am bad")} builder = StateGraph(State) builder.add_node("agent", awhile) builder.add_node("bad", iambad) builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END) async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread = {"configurable": {"thread_id": "1"}} # writes from "awhile" are applied to last chunk assert await graph.ainvoke({"hello": "world"}, thread) == { "hello": "world again" } assert not inner_task_cancelled assert awhiles == 1 assert await graph.ainvoke(None, thread, debug=True) == {"hello": "world again"} assert not 
inner_task_cancelled assert awhiles == 1 # resume with answer assert await graph.ainvoke(Command(resume=" okay"), thread) == { "hello": "world again okay" } assert not inner_task_cancelled assert awhiles == 1 async def test_step_timeout_on_stream_hang() -> None: inner_task_cancelled = False async def awhile(input: Any) -> None: try: await asyncio.sleep(1.5) except asyncio.CancelledError: nonlocal inner_task_cancelled inner_task_cancelled = True raise async def alittlewhile(input: Any) -> None: await asyncio.sleep(0.6) return "1" builder = Graph() builder.add_node(awhile) builder.add_node(alittlewhile) builder.set_conditional_entry_point(lambda _: ["awhile", "alittlewhile"], then=END) graph = builder.compile() graph.step_timeout = 1 with pytest.raises(asyncio.TimeoutError): async for chunk in graph.astream(1, stream_mode="updates"): assert chunk == {"alittlewhile": {"alittlewhile": "1"}} await asyncio.sleep(0.6) assert inner_task_cancelled @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC_PLUS_NONE) async def test_cancel_graph_astream(checkpointer_name: str) -> None: class State(TypedDict): value: Annotated[int, operator.add] class AwhileMaker: def __init__(self) -> None: self.reset() async def __call__(self, input: State) -> Any: self.started = True try: await asyncio.sleep(1.5) except asyncio.CancelledError: self.cancelled = True raise def reset(self): self.started = False self.cancelled = False async def alittlewhile(input: State) -> None: await asyncio.sleep(0.6) return {"value": 2} awhile = AwhileMaker() aparallelwhile = AwhileMaker() builder = StateGraph(State) builder.add_node("awhile", awhile) builder.add_node("aparallelwhile", aparallelwhile) builder.add_node(alittlewhile) builder.add_edge(START, "alittlewhile") builder.add_edge(START, "aparallelwhile") builder.add_edge("alittlewhile", "awhile") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) # test interrupting astream 
got_event = False thread1: RunnableConfig = {"configurable": {"thread_id": "1"}} async with aclosing(graph.astream({"value": 1}, thread1)) as stream: async for chunk in stream: assert chunk == {"alittlewhile": {"value": 2}} got_event = True break assert got_event # node aparallelwhile should start, but be cancelled assert aparallelwhile.started is True assert aparallelwhile.cancelled is True # node "awhile" should never start assert awhile.started is False # checkpoint with output of "alittlewhile" should not be saved # but we should have applied pending writes if checkpointer is not None: state = await graph.aget_state(thread1) assert state is not None assert state.values == {"value": 3} # 1 + 2 assert state.next == ("aparallelwhile",) assert state.metadata == { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC_PLUS_NONE) async def test_cancel_graph_astream_events_v2(checkpointer_name: Optional[str]) -> None: class State(TypedDict): value: int class AwhileMaker: def __init__(self) -> None: self.reset() async def __call__(self, input: State) -> Any: self.started = True try: await asyncio.sleep(1.5) except asyncio.CancelledError: self.cancelled = True raise def reset(self): self.started = False self.cancelled = False async def alittlewhile(input: State) -> None: await asyncio.sleep(0.6) return {"value": 2} awhile = AwhileMaker() anotherwhile = AwhileMaker() builder = StateGraph(State) builder.add_node(alittlewhile) builder.add_node("awhile", awhile) builder.add_node("anotherwhile", anotherwhile) builder.add_edge(START, "alittlewhile") builder.add_edge("alittlewhile", "awhile") builder.add_edge("awhile", "anotherwhile") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) # test interrupting astream_events v2 got_event = False thread2: RunnableConfig = {"configurable": {"thread_id": "2"}} async with 
aclosing( graph.astream_events({"value": 1}, thread2, version="v2") ) as stream: async for chunk in stream: if chunk["event"] == "on_chain_stream" and not chunk["parent_ids"]: got_event = True assert chunk["data"]["chunk"] == {"alittlewhile": {"value": 2}} await asyncio.sleep(0.1) break # did break assert got_event # node "awhile" maybe starts (impl detail of astream_events) # if it does start, it must be cancelled if awhile.started: assert awhile.cancelled is True # node "anotherwhile" should never start assert anotherwhile.started is False # checkpoint with output of "alittlewhile" should not be saved if checkpointer is not None: state = await graph.aget_state(thread2) assert state is not None assert state.values == {"value": 2} assert state.next == ("awhile",) assert state.metadata == { "parents": {}, "source": "loop", "step": 1, "writes": {"alittlewhile": {"value": 2}}, "thread_id": "2", } async def test_node_schemas_custom_output() -> None: class State(TypedDict): hello: str bye: str messages: Annotated[list[str], add_messages] class Output(TypedDict): messages: list[str] class StateForA(TypedDict): hello: str messages: Annotated[list[str], add_messages] async def node_a(state: StateForA): assert state == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } class StateForB(TypedDict): bye: str now: int async def node_b(state: StateForB): assert state == { "bye": "world", } return { "now": 123, "hello": "again", } class StateForC(TypedDict): hello: str now: int async def node_c(state: StateForC): assert state == { "hello": "again", "now": 123, } builder = StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_node("c", node_c) builder.add_edge(START, "a") builder.add_edge("a", "b") builder.add_edge("b", "c") graph = builder.compile() assert await graph.ainvoke( {"hello": "there", "bye": "world", "messages": "hello"} ) == { "messages": [_AnyIdHumanMessage(content="hello")], } builder = 
StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_node("c", node_c) builder.add_edge(START, "a") builder.add_edge("a", "b") builder.add_edge("b", "c") graph = builder.compile() assert await graph.ainvoke( { "hello": "there", "bye": "world", "messages": "hello", "now": 345, # ignored because not in input schema } ) == { "messages": [_AnyIdHumanMessage(content="hello")], } assert [ c async for c in graph.astream( { "hello": "there", "bye": "world", "messages": "hello", "now": 345, # ignored because not in input schema } ) ] == [ {"a": None}, {"b": {"hello": "again", "now": 123}}, {"c": None}, ] async def test_invoke_single_process_in_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={ "one": chain, }, channels={ "input": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", ) graph = Graph() graph.add_node("add_one", add_one) graph.set_entry_point("add_one") graph.set_finish_point("add_one") gapp = graph.compile() if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "integer", } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "integer", } assert await app.ainvoke(2) == 3 assert await app.ainvoke(2, output_keys=["output"]) == {"output": 3} assert await gapp.ainvoke(2) == 3 @pytest.mark.parametrize( "falsy_value", [None, False, 0, "", [], {}, set(), frozenset(), 0.0, 0j], ) async def test_invoke_single_process_in_out_falsy_values(falsy_value: Any) -> None: graph = Graph() graph.add_node("return_falsy_const", lambda *args, **kwargs: falsy_value) graph.set_entry_point("return_falsy_const") graph.set_finish_point("return_falsy_const") gapp = graph.compile() assert falsy_value == await gapp.ainvoke(1) async def 
test_invoke_single_process_in_write_kwargs(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = ( Channel.subscribe_to("input") | add_one | Channel.write_to("output", fixed=5, output_plus_one=lambda x: x + 1) ) app = Pregel( nodes={"one": chain}, channels={ "input": LastValue(int), "output": LastValue(int), "fixed": LastValue(int), "output_plus_one": LastValue(int), }, output_channels=["output", "fixed", "output_plus_one"], input_channels="input", ) if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "integer", } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "object", "properties": { "output": {"title": "Output", "type": "integer", "default": None}, "fixed": {"title": "Fixed", "type": "integer", "default": None}, "output_plus_one": { "title": "Output Plus One", "type": "integer", "default": None, }, }, } assert await app.ainvoke(2) == {"output": 3, "fixed": 5, "output_plus_one": 4}


async def test_invoke_single_process_in_out_dict(mocker: MockerFixture) -> None:
    """Scalar input on channel "input"; a list of output channels yields a dict."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")

    app = Pregel(
        nodes={"one": chain},
        channels={"input": LastValue(int), "output": LastValue(int)},
        input_channels="input",
        output_channels=["output"],
    )

    if SHOULD_CHECK_SNAPSHOTS:
        assert app.input_schema.model_json_schema() == {
            "title": "LangGraphInput",
            "type": "integer",
        }
        assert app.output_schema.model_json_schema() == {
            "title": "LangGraphOutput",
            "type": "object",
            "properties": {
                "output": {"title": "Output", "type": "integer", "default": None}
            },
        }
    assert await app.ainvoke(2) == {"output": 3}


async def test_invoke_single_process_in_dict_out_dict(mocker: MockerFixture) -> None:
    """Dict input keyed by channel name (``input_channels=["input"]``), dict output."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")

    app = Pregel(
        nodes={"one": chain},
        channels={"input": LastValue(int), "output": LastValue(int)},
        input_channels=["input"],
        output_channels=["output"],
    )

    if SHOULD_CHECK_SNAPSHOTS:
        assert app.input_schema.model_json_schema() == {
            "title": "LangGraphInput",
            "type": "object",
            "properties": {
                "input": {"title": "Input", "type": "integer", "default": None}
            },
        }
        assert app.output_schema.model_json_schema() == {
            "title": "LangGraphOutput",
            "type": "object",
            "properties": {
                "output": {"title": "Output", "type": "integer", "default": None}
            },
        }
    assert await app.ainvoke({"input": 2}) == {"output": 3}


async def test_invoke_two_processes_in_out(mocker: MockerFixture) -> None:
    """Two chained nodes (input -> inbox -> output), via Pregel and via Graph."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
    two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")

    app = Pregel(
        nodes={"one": one, "two": two},
        channels={
            "inbox": LastValue(int),
            "output": LastValue(int),
            "input": LastValue(int),
        },
        input_channels="input",
        output_channels="output",
        stream_channels=["inbox", "output"],
    )

    assert await app.ainvoke(2) == 4

    # the run takes two steps (see the stream below), so a limit of 1 must fail
    with pytest.raises(GraphRecursionError):
        await app.ainvoke(2, {"recursion_limit": 1})

    step = 0
    async for values in app.astream(2):
        step += 1
        if step == 1:
            assert values == {
                "inbox": 3,
            }
        elif step == 2:
            assert values == {
                "inbox": 3,
                "output": 4,
            }
    assert step == 2

    # the same two-node pipeline expressed as a Graph
    graph = Graph()
    graph.add_node("add_one", add_one)
    graph.add_node("add_one_more", add_one)
    graph.set_entry_point("add_one")
    graph.set_finish_point("add_one_more")
    graph.add_edge("add_one", "add_one_more")
    gapp = graph.compile()

    assert await gapp.ainvoke(2) == 4

    step = 0
    async for values in gapp.astream(2):
        step += 1
        if step == 1:
            assert values == {
                "add_one": 3,
            }
        elif step == 2:
            assert values == {
                "add_one_more": 4,
            }
    assert step == 2


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_two_processes_in_out_interrupt(
    checkpointer_name: str, mocker: MockerFixture
) -> None:
    """Interrupt after node "one", resume, update state, then fork from history."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
    two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")
    async with awith_checkpointer(checkpointer_name) as checkpointer:
        app = Pregel(
            nodes={"one": one, "two": two},
            channels={
                "inbox": LastValue(int),
                "output": LastValue(int),
                "input": LastValue(int),
            },
            input_channels="input",
            output_channels="output",
            checkpointer=checkpointer,
            interrupt_after_nodes=["one"],
        )

        thread1 = {"configurable": {"thread_id": "1"}}
        thread2 = {"configurable": {"thread_id": "2"}}

        # start execution, stop at inbox
        assert await app.ainvoke(2, thread1) is None

        # inbox == 3
        checkpoint = await checkpointer.aget(thread1)
        assert checkpoint is not None
        assert checkpoint["channel_values"]["inbox"] == 3

        # resume execution, finish
        assert await app.ainvoke(None, thread1) == 4

        # start execution again, stop at inbox
        assert await app.ainvoke(20, thread1) is None

        # inbox == 21
        checkpoint = await checkpointer.aget(thread1)
        assert checkpoint is not None
        assert checkpoint["channel_values"]["inbox"] == 21

        # send a new value in, interrupting the previous execution
        assert await app.ainvoke(3, thread1) is None
        assert await app.ainvoke(None, thread1) == 5

        # start execution again, stopping at inbox
        assert await app.ainvoke(20, thread2) is None

        # inbox == 21
        snapshot = await app.aget_state(thread2)
        assert snapshot.values["inbox"] == 21
        assert snapshot.next == ("two",)

        # update the state, resume
        await app.aupdate_state(thread2, 25, as_node="one")
        assert await app.ainvoke(None, thread2) == 26

        # no pending tasks
        snapshot = await app.aget_state(thread2)
        assert snapshot.next == ()

        # list history (newest first; each entry's parent is the entry after it)
        history = [c async for c in app.aget_state_history(thread1)]
        assert history == [
            StateSnapshot(
                values={"inbox": 4, "output": 5, "input": 3},
                tasks=(),
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 6,
                    "writes": {"two": 5},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[1].config,
            ),
            StateSnapshot(
                values={"inbox": 4, "output": 4, "input": 3},
                tasks=(
                    PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 5}),
                ),
                next=("two",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 5,
                    "writes": {"one": None},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[2].config,
            ),
            StateSnapshot(
                values={"inbox": 21, "output": 4, "input": 3},
                tasks=(
                    PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 4}),
                ),
                next=("one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "input",
                    "step": 4,
                    "writes": {"input": 3},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[3].config,
            ),
            StateSnapshot(
                values={"inbox": 21, "output": 4, "input": 20},
                tasks=(PregelTask(AnyStr(), "two", (PULL, "two")),),
                next=("two",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 3,
                    "writes": {"one": None},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[4].config,
            ),
            StateSnapshot(
                values={"inbox": 3, "output": 4, "input": 20},
                tasks=(
                    PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 21}),
                ),
                next=("one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "input",
                    "step": 2,
                    "writes": {"input": 20},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[5].config,
            ),
            StateSnapshot(
                values={"inbox": 3, "output": 4, "input": 2},
                tasks=(),
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 1,
                    "writes": {"two": 4},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[6].config,
            ),
            StateSnapshot(
                values={"inbox": 3, "input": 2},
                tasks=(
                    PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 4}),
                ),
                next=("two",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 0,
                    "writes": {"one": None},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[7].config,
            ),
            StateSnapshot(
                values={"input": 2},
                tasks=(
                    PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 3}),
                ),
                next=("one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "input",
                    "step": -1,
                    "writes": {"input": 2},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=None,
            ),
        ]

        # forking from any previous checkpoint should re-run nodes
        # (the newest checkpoint has no pending tasks, so nothing runs)
        assert [
            c async for c in app.astream(None, history[0].config, stream_mode="updates")
        ] == []
        assert [
            c async for c in app.astream(None, history[1].config, stream_mode="updates")
        ] == [
            {"two": {"output": 5}},
        ]
        assert [
            c async for c in app.astream(None, history[2].config, stream_mode="updates")
        ] == [
            {"one": {"inbox": 4}},
            {"__interrupt__": ()},
        ]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_fork_always_re_runs_nodes(
    checkpointer_name: str, mocker: MockerFixture
) -> None:
    """Streaming from an earlier checkpoint re-executes the nodes after it."""
    add_one = mocker.Mock(side_effect=lambda _: 1)
    builder = StateGraph(Annotated[int, operator.add])
    builder.add_node("add_one", add_one)
    builder.add_edge(START, "add_one")
    builder.add_conditional_edges("add_one", lambda cnt: "add_one" if cnt < 6 else END)

    async with awith_checkpointer(checkpointer_name) as checkpointer:
        graph = builder.compile(checkpointer=checkpointer)

        thread1 = {"configurable": {"thread_id": "1"}}

        # run to completion: add_one returns 1 each step and the operator.add
        # reducer sums it into the state, looping until the state reaches 6
        assert [
            c async for c in graph.astream(1, thread1, stream_mode=["values", "updates"])
        ] == [
            ("values", 1),
            ("updates", {"add_one": 1}),
            ("values", 2),
            ("updates", {"add_one": 1}),
            ("values", 3),
            ("updates", {"add_one": 1}),
            ("values", 4),
            ("updates", {"add_one": 1}),
            ("values", 5),
            ("updates", {"add_one": 1}),
            ("values", 6),
        ]

        # list history
        history = [c async for c in graph.aget_state_history(thread1)]
        assert history == [
            StateSnapshot(
                values=6,
                next=(),
                tasks=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 5,
                    "writes": {"add_one": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[1].config,
            ),
            StateSnapshot(
                values=5,
                tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
                next=("add_one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 4,
                    "writes": {"add_one": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[2].config,
            ),
            StateSnapshot(
                values=4,
                tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
                next=("add_one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 3,
                    "writes": {"add_one": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[3].config,
            ),
            StateSnapshot(
                values=3,
                tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
                next=("add_one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 2,
                    "writes": {"add_one": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[4].config,
            ),
            StateSnapshot(
                values=2,
                tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
                next=("add_one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 1,
                    "writes": {"add_one": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[5].config,
            ),
            StateSnapshot(
                values=1,
                tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
                next=("add_one",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 0,
                    "writes": None,
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=history[6].config,
            ),
            StateSnapshot(
                values=0,
                tasks=(
                    PregelTask(AnyStr(), "__start__", (PULL, "__start__"), result=1),
                ),
                next=("__start__",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "input",
                    "step": -1,
                    "writes": {"__start__": 1},
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=None,
            ),
        ]

        # forking from any previous checkpoint should re-run nodes
        assert [
            c async for c in graph.astream(None, history[0].config, stream_mode="updates")
        ] == []
        assert [
            c async for c in graph.astream(None, history[1].config, stream_mode="updates")
        ] == [
            {"add_one": 1},
        ]
        assert [
            c async for c in graph.astream(None, history[2].config, stream_mode="updates")
        ] == [
            {"add_one": 1},
            {"add_one": 1},
        ]


async def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None:
    """Topic channel "inbox" fed by both the input and node "one"; all stream modes."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
    # node "two" processes the Topic's list of values element-wise via abatch
    two = (
        Channel.subscribe_to("inbox")
        | RunnableLambda(add_one).abatch
        | Channel.write_to("output").abatch
    )

    app = Pregel(
        nodes={"one": one, "two": two},
        channels={
            "inbox": Topic(int),
            "output": LastValue(int),
            "input": LastValue(int),
        },
        input_channels=["input", "inbox"],
        stream_channels=["output", "inbox"],
        output_channels=["output"],
    )

    # [12 + 1, 2 + 1 + 1]
    assert [
        c
        async for c in app.astream(
            {"input": 2, "inbox": 12}, output_keys="output", stream_mode="updates"
        )
    ] == [
        {"one": None},
        {"two": 13},
        {"two": 4},
    ]
    assert [
        c async for c in app.astream({"input": 2, "inbox": 12}, output_keys="output")
    ] == [13, 4]
    assert [
        c async for c in app.astream({"input": 2, "inbox": 12}, stream_mode="updates")
    ] == [
        {"one": {"inbox": 3}},
        {"two": {"output": 13}},
        {"two": {"output": 4}},
    ]
    assert [c async for c in app.astream({"input": 2, "inbox": 12})] == [
        {"inbox": [3], "output": 13},
        {"output": 4},
    ]
    assert [
        c async for c in app.astream({"input": 2, "inbox": 12}, stream_mode="debug")
    ] == [
        {
            "type": "task",
            "timestamp": AnyStr(),
            "step": 0,
            "payload": {
                "id": AnyStr(),
                "name": "one",
                "input": 2,
                "triggers": ["input"],
            },
        },
        {
            "type": "task",
            "timestamp": AnyStr(),
            "step": 0,
            "payload": {
                "id": AnyStr(),
                "name": "two",
                "input": [12],
                "triggers": ["inbox"],
            },
        },
        {
            "type": "task_result",
            "timestamp": AnyStr(),
            "step": 0,
            "payload": {
                "id": AnyStr(),
                "name": "one",
                "result": [("inbox", 3)],
                "error": None,
                "interrupts": [],
            },
        },
        {
            "type": "task_result",
            "timestamp": AnyStr(),
            "step": 0,
            "payload": {
                "id": AnyStr(),
                "name": "two",
                "result": [("output", 13)],
                "error": None,
                "interrupts": [],
            },
        },
        {
            "type": "task",
            "timestamp": AnyStr(),
            "step": 1,
            "payload": {
                "id": AnyStr(),
                "name": "two",
                "input": [3],
                "triggers": ["inbox"],
            },
        },
        {
            "type": "task_result",
            "timestamp": AnyStr(),
            "step": 1,
            "payload": {
                "id": AnyStr(),
                "name": "two",
                "result": [("output", 4)],
                "error": None,
                "interrupts": [],
            },
        },
    ]


async def test_batch_two_processes_in_out() -> None:
    """abatch over a two-node chain; results stay in input order despite delays."""

    async def add_one_with_delay(inp: int) -> int:
        # larger inputs sleep longer, so batch items finish out of order
        await asyncio.sleep(inp / 10)
        return inp + 1

    one = Channel.subscribe_to("input") | add_one_with_delay | Channel.write_to("one")
    two = Channel.subscribe_to("one") | add_one_with_delay | Channel.write_to("output")

    app = Pregel(
        nodes={"one": one, "two": two},
        channels={
            "one": LastValue(int),
            "output": LastValue(int),
            "input": LastValue(int),
        },
        input_channels="input",
        output_channels="output",
    )

    assert await app.abatch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]
    assert await app.abatch([3, 2, 1, 3, 5], output_keys=["output"]) == [
        {"output": 5},
        {"output": 4},
        {"output": 3},
        {"output": 5},
        {"output": 7},
    ]

    graph = Graph()
    graph.add_node("add_one", add_one_with_delay)
    graph.add_node("add_one_more", add_one_with_delay)
    graph.set_entry_point("add_one")
    graph.set_finish_point("add_one_more")
    graph.add_edge("add_one", "add_one_more")

    gapp = graph.compile()

    assert await gapp.abatch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]


async def test_invoke_many_processes_in_out(mocker: MockerFixture) -> None:
    """A chain of ``test_size`` add-one nodes; repeated/concurrent ainvoke stay independent."""
    test_size = 100
    add_one = mocker.Mock(side_effect=lambda x: x + 1)

    nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
    for i in range(test_size - 2):
        nodes[str(i)] = (
            Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
        )
    # NOTE(review): relies on ``i`` leaking from the loop (its last value, test_size - 3)
    nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")

    app = Pregel(
        nodes=nodes,
        channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
        | {"input": LastValue(int), "output": LastValue(int)},
        input_channels="input",
        output_channels="output",
    )

    # No state is left over from previous invocations
    for _ in range(10):
        assert await app.ainvoke(2, {"recursion_limit": test_size}) == 2 + test_size

    # Concurrent invocations do not interfere with each other
    assert await asyncio.gather(
        *(app.ainvoke(2, {"recursion_limit": test_size}) for _ in range(10))
    ) == [2 + test_size for _ in range(10)]


async def test_batch_many_processes_in_out(mocker: MockerFixture) -> None:
    """A chain of ``test_size`` add-one nodes; repeated/concurrent abatch stay independent."""
    test_size = 100
    add_one = mocker.Mock(side_effect=lambda x: x + 1)

    nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
    for i in range(test_size - 2):
        nodes[str(i)] = (
            Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
        )
    # NOTE(review): relies on ``i`` leaking from the loop (its last value, test_size - 3)
    nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")

    app = Pregel(
        nodes=nodes,
        channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
        | {"input": LastValue(int), "output": LastValue(int)},
        input_channels="input",
        output_channels="output",
    )

    # No state is left over from previous invocations
    for _ in range(3):
        # Then invoke pubsub
        assert await app.abatch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) == [
            2 + test_size,
            1 + test_size,
            3 + test_size,
            4 + test_size,
            5 + test_size,
        ]

    # Concurrent invocations do not interfere with each other
    assert await asyncio.gather(
        *(app.abatch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) for _ in range(3))
    ) == [
        [2 + test_size, 1 + test_size, 3 + test_size, 4 + test_size, 5 + test_size]
        for _ in range(3)
    ]


async def test_invoke_two_processes_two_in_two_out_invalid(
    mocker: MockerFixture,
) -> None:
    """Two nodes writing one LastValue channel in the same step raise InvalidUpdateError."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)

    one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
    two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")

    app = Pregel(
        nodes={"one": one, "two": two},
        channels={"output": LastValue(int), "input": LastValue(int)},
        input_channels="input",
        output_channels="output",
    )

    with pytest.raises(InvalidUpdateError):
        # LastValue channels can only be updated once per iteration
        await app.ainvoke(2)


async def test_invoke_two_processes_two_in_two_out_valid(mocker: MockerFixture) -> None:
    """Two nodes writing one Topic channel in the same step; both values are kept."""
    add_one = mocker.Mock(side_effect=lambda x: x + 1)

    one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
    two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")

    app = Pregel(
        nodes={"one": one, "two": two},
        channels={
            "input": LastValue(int),
            "output": Topic(int),
        },
        input_channels="input",
        output_channels="output",
    )

    # A Topic channel accumulates updates into a sequence
    assert await app.ainvoke(2) == [3, 3]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_checkpoint(mocker: MockerFixture, checkpointer_name: str) -> None:
    """Accumulated "total" is checkpointed per thread; a failed run leaves it unchanged."""
    add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"])
    errored_once = False

    def raise_if_above_10(input: int) -> int:
        # The first value above 4 raises ConnectionError (expected to be retried by
        # the Pregel's RetryPolicy()); after that, values above 10 raise ValueError.
        nonlocal errored_once
        if input > 4:
            if errored_once:
                pass
            else:
                errored_once = True
                raise ConnectionError("I will be retried")
        if input > 10:
            raise ValueError("Input is too large")
        return input

    one = (
        Channel.subscribe_to(["input"]).join(["total"])
        | add_one
        | Channel.write_to("output", "total")
        | raise_if_above_10
    )

    async with awith_checkpointer(checkpointer_name) as checkpointer:
        app = Pregel(
            nodes={"one": one},
            channels={
                "total": BinaryOperatorAggregate(int, operator.add),
                "input": LastValue(int),
                "output": LastValue(int),
            },
            input_channels="input",
            output_channels="output",
            checkpointer=checkpointer,
            retry_policy=RetryPolicy(),
        )

        # total starts out as 0, so output is 0+2=2
        assert await app.ainvoke(2, {"configurable": {"thread_id": "1"}}) == 2
        checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
        assert checkpoint is not None
        assert checkpoint["channel_values"].get("total") == 2
        # total is now 2, so output is 2+3=5
        assert await app.ainvoke(3, {"configurable": {"thread_id": "1"}}) == 5
        assert errored_once, "errored and retried"
        checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
        assert checkpoint is not None
        assert checkpoint["channel_values"].get("total") == 7
        # total is now 2+5=7, so output would be 7+4=11, but raises ValueError
        with pytest.raises(ValueError):
            await app.ainvoke(4, {"configurable": {"thread_id": "1"}})
        # checkpoint is not updated
        checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
        assert checkpoint is not None
        assert checkpoint["channel_values"].get("total") == 7
        # on a new thread, total starts out as 0, so output is 0+5=5
        assert await app.ainvoke(5, {"configurable": {"thread_id": "2"}}) == 5
        checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
        assert checkpoint is not None
        assert checkpoint["channel_values"].get("total") == 7
        checkpoint = await checkpointer.aget({"configurable": {"thread_id": "2"}})
        assert checkpoint is not None
        assert checkpoint["channel_values"].get("total") == 5
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_pending_writes_resume( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: class State(TypedDict): value: Annotated[int, operator.add] class AwhileMaker: def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None: self.sleep = sleep self.rtn = rtn self.reset() async def __call__(self, input: State) -> Any: self.calls += 1 await asyncio.sleep(self.sleep) if isinstance(self.rtn, Exception): raise self.rtn else: return self.rtn def reset(self): self.calls = 0 one = AwhileMaker(0.1, {"value": 2}) two = AwhileMaker(0.3, ConnectionError("I'm not good")) builder = StateGraph(State) builder.add_node("one", one) builder.add_node("two", two, retry=RetryPolicy(max_attempts=2)) builder.add_edge(START, "one") builder.add_edge(START, "two") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread1: RunnableConfig = {"configurable": {"thread_id": "1"}} with pytest.raises(ConnectionError, match="I'm not good"): await graph.ainvoke({"value": 1}, thread1) # both nodes should have been called once assert one.calls == 1 assert two.calls == 2 # latest checkpoint should be before nodes "one", "two" # but we should have applied pending writes from "one" state = await graph.aget_state(thread1) assert state is not None assert state.values == {"value": 3} assert state.next == ("two",) assert state.tasks == ( PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}), PregelTask( AnyStr(), "two", (PULL, "two"), 'ConnectionError("I\'m not good")', ), ) assert state.metadata == { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", } # get_state with checkpoint_id should not apply any pending writes state = await graph.aget_state(state.config) assert state is not None assert state.values == {"value": 1} assert state.next == ("one", "two") # should contain pending write of "one" checkpoint 
= await checkpointer.aget_tuple(thread1) assert checkpoint is not None # should contain error from "two" expected_writes = [ (AnyStr(), "one", "one"), (AnyStr(), "value", 2), (AnyStr(), ERROR, 'ConnectionError("I\'m not good")'), ] assert len(checkpoint.pending_writes) == 3 assert all(w in expected_writes for w in checkpoint.pending_writes) # both non-error pending writes come from same task non_error_writes = [w for w in checkpoint.pending_writes if w[1] != ERROR] assert non_error_writes[0][0] == non_error_writes[1][0] # error write is from the other task error_write = next(w for w in checkpoint.pending_writes if w[1] == ERROR) assert error_write[0] != non_error_writes[0][0] # resume execution with pytest.raises(ConnectionError, match="I'm not good"): await graph.ainvoke(None, thread1) # node "one" succeeded previously, so shouldn't be called again assert one.calls == 1 # node "two" should have been called once again assert two.calls == 4 # confirm no new checkpoints saved state_two = await graph.aget_state(thread1) assert state_two.metadata == state.metadata # resume execution, without exception two.rtn = {"value": 3} # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert await graph.ainvoke(None, thread1) == {"value": 6} # check all final checkpoints checkpoints = [c async for c in checkpointer.alist(thread1)] # we should have 3 assert len(checkpoints) == 3 # the last one not too interesting for this test assert checkpoints[0] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": { "one": { "start:one": AnyVersion(), }, "two": { "start:two": AnyVersion(), }, "__input__": {}, "__start__": { "__start__": AnyVersion(), }, "__interrupt__": { "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, }, "channel_versions": { "one": AnyVersion(), 
"two": AnyVersion(), "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, "channel_values": {"one": "one", "two": "two", "value": 6}, }, metadata={ "parents": {}, "step": 1, "source": "loop", "writes": {"one": {"value": 2}, "two": {"value": 3}}, "thread_id": "1", }, parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": checkpoints[1].config["configurable"][ "checkpoint_id" ], } }, pending_writes=[], ) # the previous one we assert that pending writes contains both # - original error # - successful writes from resuming after preventing error assert checkpoints[1] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": { "__input__": {}, "__start__": { "__start__": AnyVersion(), }, }, "channel_versions": { "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, "channel_values": { "value": 1, "start:one": "__start__", "start:two": "__start__", }, }, metadata={ "parents": {}, "step": 0, "source": "loop", "writes": None, "thread_id": "1", }, parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": checkpoints[2].config["configurable"][ "checkpoint_id" ], } }, pending_writes=UnsortedSequence( (AnyStr(), "one", "one"), (AnyStr(), "value", 2), (AnyStr(), "__error__", 'ConnectionError("I\'m not good")'), (AnyStr(), "two", "two"), (AnyStr(), "value", 3), ), ) assert checkpoints[2] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": {"__input__": {}}, "channel_versions": { "__start__": AnyVersion(), }, "channel_values": {"__start__": {"value": 1}}, }, metadata={ "parents": {}, "step": -1, "source": 
"input", "writes": {"__start__": {"value": 1}}, "thread_id": "1", }, parent_config=None, pending_writes=UnsortedSequence( (AnyStr(), "value", 1), (AnyStr(), "start:one", "__start__"), (AnyStr(), "start:two", "__start__"), ), ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: class MyState(TypedDict): myval: Annotated[int, operator.add] otherval: bool class Anode: def __init__(self): self.switch = False async def __call__(self, state: MyState): self.switch = not self.switch return {"myval": 2 if self.switch else 1, "otherval": self.switch} builder = StateGraph(MyState) thenode = Anode() # Fun. builder.add_node("node_one", thenode) builder.add_node("node_two", thenode) builder.add_edge(START, "node_one") def _getedge(src: str): swap = "node_one" if src == "node_two" else "node_two" def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]: if st["myval"] > 3: return END if st["otherval"]: return swap return src return _edge builder.add_conditional_edges("node_one", _getedge("node_one")) builder.add_conditional_edges("node_two", _getedge("node_two")) async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread_id = uuid.uuid4() thread1 = {"configurable": {"thread_id": str(thread_id)}} result = await graph.ainvoke({"myval": 1}, thread1) assert result["myval"] == 4 history = [c async for c in graph.aget_state_history(thread1)] assert len(history) == 4 assert history[-1].values == {"myval": 0} assert history[0].values == {"myval": 4, "otherval": False} second_run_config = { **thread1, "configurable": { **thread1["configurable"], "checkpoint_id": history[1].config["configurable"]["checkpoint_id"], }, } second_result = await graph.ainvoke(None, second_run_config) assert second_result == {"myval": 5, "otherval": True} new_history 
= [ c async for c in graph.aget_state_history( {"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}} ) ] assert len(new_history) == len(history) + 1 for original, new in zip(history, new_history[1:]): assert original.values == new.values assert original.next == new.next assert original.metadata["step"] == new.metadata["step"] def _get_tasks(hist: list, start: int): return [h.tasks for h in hist[start:]] assert _get_tasks(new_history, 1) == _get_tasks(history, 0) async def test_cond_edge_after_send() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) async def __call__(self, state): return [self.name] async def send_for_fun(state): return [Send("2", state), Send("2", state)] async def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"] async def test_concurrent_emit_sends() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) async def __call__(self, state): return ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) async def send_for_fun(state): return [Send("2", 1), Send("2", 2), "3.1"] async def send_for_profit(state): return [Send("2", 3), Send("2", 4)] async def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("1.1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_edge(START, "1") builder.add_edge(START, "1.1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("1.1", send_for_profit) 
builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert await graph.ainvoke(["0"]) == ( [ "0", "1", "1.1", "2|1", "2|2", "2|3", "2|4", "3", "3.1", ] if FF_SEND_V2 else [ "0", "1", "1.1", "3.1", "2|1", "2|2", "2|3", "2|4", "3", ] ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_send_sequences(checkpointer_name: str) -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) async def __call__(self, state): update = ( [self.name] if isinstance(state, list) # or isinstance(state, Control) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return replace(state, update=update) else: return update async def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("2", 4))), "3.1", ] async def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert ( await graph.ainvoke(["0"]) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", "3.1", ] if FF_SEND_V2 else [ "0", "1", "3.1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "3", "2|3", "2|4", "3", ] ) if not FF_SEND_V2: return async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer, interrupt_before=["3.1"]) thread1 = {"configurable": {"thread_id": "1"}} assert await graph.ainvoke(["0"], thread1) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", ] assert await graph.ainvoke(None, thread1) == [ "0", "1", 
"2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", "3.1", ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_send_dedupe_on_resume(checkpointer_name: str) -> None: if not FF_SEND_V2: pytest.skip("Send deduplication is only available in Send V2") class InterruptOnce: ticks: int = 0 def __call__(self, state): self.ticks += 1 if self.ticks == 1: raise NodeInterrupt("Bahh") return ["|".join(("flaky", str(state)))] class Node: def __init__(self, name: str): self.name = name self.ticks = 0 setattr(self, "__name__", name) def __call__(self, state): self.ticks += 1 update = ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return replace(state, update=update) else: return update def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("flaky", 4))), "3.1", ] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_node("flaky", InterruptOnce()) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert await graph.ainvoke(["0"], thread1, debug=1) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", ] assert builder.nodes["2"].runnable.func.ticks == 3 assert builder.nodes["flaky"].runnable.func.ticks == 1 # resume execution assert await graph.ainvoke(None, thread1, debug=1) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ] # node "2" 
doesn't get called again, as we recover writes saved before assert builder.nodes["2"].runnable.func.ticks == 3 # node "flaky" gets called again, as it was interrupted assert builder.nodes["flaky"].runnable.func.ticks == 2 # check history history = [c async for c in graph.aget_state_history(thread1)] assert history == [ StateSnapshot( values=[ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ], next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": {"3": ["3"], "3.1": ["3.1"]}, "thread_id": "1", "step": 2, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ), StateSnapshot( values=[ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", ], next=("3", "3.1"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": { "1": ["1"], "2": [ ["2|Command(goto=Send(node='2', arg=3))"], ["2|Command(goto=Send(node='flaky', arg=4))"], ["2|3"], ], "flaky": ["flaky|4"], }, "thread_id": "1", "step": 1, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="3", path=("__pregel_pull", "3"), error=None, interrupts=(), state=None, result=["3"], ), PregelTask( id=AnyStr(), name="3.1", path=("__pregel_pull", "3.1"), error=None, interrupts=(), state=None, result=["3.1"], ), ), ), StateSnapshot( values=["0"], next=("1", "2", "2", "2", "flaky"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": None, "thread_id": "1", "step": 0, "parents": {}, }, created_at=AnyStr(), 
parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="1", path=("__pregel_pull", "1"), error=None, interrupts=(), state=None, result=["1"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ("__pregel_pull", "1"), 2, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|Command(goto=Send(node='2', arg=3))"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ("__pregel_pull", "1"), 3, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|Command(goto=Send(node='flaky', arg=4))"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ( "__pregel_push", ("__pregel_pull", "1"), 2, AnyStr(), ), 2, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|3"], ), PregelTask( id=AnyStr(), name="flaky", path=( "__pregel_push", ( "__pregel_push", ("__pregel_pull", "1"), 3, AnyStr(), ), 2, AnyStr(), ), error=None, interrupts=(Interrupt(value="Bahh", when="during"),), state=None, result=["flaky|4"], ), ), ), StateSnapshot( values=[], next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "input", "writes": {"__start__": ["0"]}, "thread_id": "1", "step": -1, "parents": {}, }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( id=AnyStr(), name="__start__", path=("__pregel_pull", "__start__"), error=None, interrupts=(), state=None, result=["0"], ), ), ), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_send_react_interrupt(checkpointer_name: str) -> None: from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage ai_message = AIMessage( "", id="ai1", tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], ) async def agent(state): return {"messages": ai_message} def route(state): if isinstance(state["messages"][-1], AIMessage): return [ Send(call["name"], call) for 
call in state["messages"][-1].tool_calls ] foo_called = 0 async def foo(call: ToolCall): nonlocal foo_called foo_called += 1 return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} builder = StateGraph(MessagesState) builder.add_node(agent) builder.add_node(foo) builder.add_edge(START, "agent") builder.add_conditional_edges("agent", route) graph = builder.compile() assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 async with awith_checkpointer(checkpointer_name) as checkpointer: # simple interrupt-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "1"}} assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 assert await graph.ainvoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 # interrupt-update-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "2"}} assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } 
], ), ] } assert foo_called == 0 if not FF_SEND_V2: return # get state should show the pending task state = await graph.aget_state(thread1) assert state == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # remove the tool call, clearing the pending task await graph.aupdate_state( thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} ) # tool call no longer in pending tasks assert await graph.aget_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="Bye now", tool_calls=[], ), ] }, next=(), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="Bye now", tool_calls=[], ) } }, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) # tool call not executed assert await 
graph.ainvoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage(content="Bye now"), ] } assert foo_called == 0 # interrupt-update-resume flow, creating new Send in update call foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "3"}} assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 # get state should show the pending task state = await graph.aget_state(thread1) assert state == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "3", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( "", id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # replace the tool call, should clear previous send, create new one await graph.aupdate_state( thread1, { "messages": AIMessage( "", id=ai_message.id, tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ) }, ) # prev tool call no longer in pending tasks, 
new tool call is assert await graph.aget_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ) } }, "parents": {}, "thread_id": "3", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", (), 0, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # prev tool call not executed, new tool call is assert await graph.ainvoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), AIMessage( "", id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ), _AnyIdToolMessage(content="{'hi': [4, 5, 6]}", tool_call_id="tool1"), ] } assert foo_called == 1 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_send_react_interrupt_control( checkpointer_name: str, snapshot: SnapshotAssertion ) -> None: from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage ai_message = AIMessage( "", id="ai1", tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], ) async def agent(state) -> Command[Literal["foo"]]: return Command( update={"messages": ai_message}, goto=[Send(call["name"], call) for call in ai_message.tool_calls], ) foo_called = 0 async def foo(call: ToolCall): nonlocal foo_called foo_called += 1 return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} builder = 
StateGraph(MessagesState) builder.add_node(agent) builder.add_node(foo) builder.add_edge(START, "agent") graph = builder.compile() assert graph.get_graph().draw_mermaid() == snapshot assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 async with awith_checkpointer(checkpointer_name) as checkpointer: # simple interrupt-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "1"}} assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 assert await graph.ainvoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 # interrupt-update-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "2"}} assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 if not FF_SEND_V2: return # get state should show the pending task state = await graph.aget_state(thread1) assert state == StateSnapshot( values={ "messages": [ 
_AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # remove the tool call, clearing the pending task await graph.aupdate_state( thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} ) # tool call no longer in pending tasks assert await graph.aget_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="Bye now", tool_calls=[], ), ] }, next=(), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="Bye now", tool_calls=[], ) } }, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) # tool call not executed assert await graph.ainvoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage(content="Bye now"), ] } assert foo_called == 0 @pytest.mark.parametrize("checkpointer_name", 
ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency(checkpointer_name: str) -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) self.currently = 0 self.max_currently = 0 async def __call__(self, state): self.currently += 1 if self.currently > self.max_currently: self.max_currently = self.currently await asyncio.sleep(random.random() / 10) self.currently -= 1 return [state] def one(state): return ["1"] def three(state): return ["3"] async def send_to_many(state): return [Send("2", idx) for idx in range(100)] async def route_to_three(state) -> Literal["3"]: return "3" node2 = Node("2") builder = StateGraph(Annotated[list, operator.add]) builder.add_node("1", one) builder.add_node(node2) builder.add_node("3", three) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_to_many) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert await graph.ainvoke(["0"]) == ["0", "1", *range(100), "3"] assert node2.max_currently == 100 assert node2.currently == 0 node2.max_currently = 0 assert await graph.ainvoke(["0"], {"max_concurrency": 10}) == [ "0", "1", *range(100), "3", ] assert node2.max_currently == 10 assert node2.currently == 0 async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer, interrupt_before=["2"]) thread1 = {"max_concurrency": 10, "configurable": {"thread_id": "1"}} assert await graph.ainvoke(["0"], thread1, debug=True) == ["0", "1"] state = await graph.aget_state(thread1) assert state.values == ["0", "1"] assert await graph.ainvoke(None, thread1) == ["0", "1", *range(100), "3"] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency_control(checkpointer_name: str) -> None: async def node1(state) -> Command[Literal["2"]]: return Command(update=["1"], goto=[Send("2", idx) for idx in range(100)]) node2_currently = 0 node2_max_currently = 0 async def node2(state) 
-> Command[Literal["3"]]: nonlocal node2_currently, node2_max_currently node2_currently += 1 if node2_currently > node2_max_currently: node2_max_currently = node2_currently await asyncio.sleep(0.1) node2_currently -= 1 return Command(update=[state], goto="3") async def node3(state) -> Literal["3"]: return ["3"] builder = StateGraph(Annotated[list, operator.add]) builder.add_node("1", node1) builder.add_node("2", node2) builder.add_node("3", node3) builder.add_edge(START, "1") graph = builder.compile() assert ( graph.get_graph().draw_mermaid() == """%%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; __start__([<p>__start__</p>]):::first 1(1) 2(2) 3([3]):::last __start__ --> 1; 1 -.-> 2; 2 -.-> 3; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc """ ) assert await graph.ainvoke(["0"], debug=True) == ["0", "1", *range(100), "3"] assert node2_max_currently == 100 assert node2_currently == 0 node2_max_currently = 0 assert await graph.ainvoke(["0"], {"max_concurrency": 10}) == [ "0", "1", *range(100), "3", ] assert node2_max_currently == 10 assert node2_currently == 0 async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer, interrupt_before=["2"]) thread1 = {"max_concurrency": 10, "configurable": {"thread_id": "1"}} assert await graph.ainvoke(["0"], thread1) == ["0", "1"] assert await graph.ainvoke(None, thread1) == ["0", "1", *range(100), "3"] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_invoke_checkpoint_three( mocker: MockerFixture, checkpointer_name: str ) -> None: add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) def raise_if_above_10(input: int) -> int: if input > 10: raise ValueError("Input is too large") return input one = ( Channel.subscribe_to(["input"]).join(["total"]) | add_one | Channel.write_to("output", "total") | raise_if_above_10 ) async with 
awith_checkpointer(checkpointer_name) as checkpointer: app = Pregel( nodes={"one": one}, channels={ "total": BinaryOperatorAggregate(int, operator.add), "input": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", checkpointer=checkpointer, debug=True, ) thread_1 = {"configurable": {"thread_id": "1"}} # total starts out as 0, so output is 0+2=2 assert await app.ainvoke(2, thread_1) == 2 state = await app.aget_state(thread_1) assert state is not None assert state.values.get("total") == 2 assert ( state.config["configurable"]["checkpoint_id"] == (await checkpointer.aget(thread_1))["id"] ) # total is now 2, so output is 2+3=5 assert await app.ainvoke(3, thread_1) == 5 state = await app.aget_state(thread_1) assert state is not None assert state.values.get("total") == 7 assert ( state.config["configurable"]["checkpoint_id"] == (await checkpointer.aget(thread_1))["id"] ) # total is now 2+5=7, so output would be 7+4=11, but raises ValueError with pytest.raises(ValueError): await app.ainvoke(4, thread_1) # checkpoint is not updated state = await app.aget_state(thread_1) assert state is not None assert state.values.get("total") == 7 assert state.next == ("one",) """we checkpoint inputs and it failed on "one", so the next node is one""" # we can recover from error by sending new inputs assert await app.ainvoke(2, thread_1) == 9 state = await app.aget_state(thread_1) assert state is not None assert state.values.get("total") == 16, "total is now 7+9=16" assert state.next == () thread_2 = {"configurable": {"thread_id": "2"}} # on a new thread, total starts out as 0, so output is 0+5=5 assert await app.ainvoke(5, thread_2) == 5 state = await app.aget_state({"configurable": {"thread_id": "1"}}) assert state is not None assert state.values.get("total") == 16 assert state.next == () state = await app.aget_state(thread_2) assert state is not None assert state.values.get("total") == 5 assert state.next == () assert len([c async for c in 
app.aget_state_history(thread_1, limit=1)]) == 1 # list all checkpoints for thread 1 thread_1_history = [c async for c in app.aget_state_history(thread_1)] # there are 7 checkpoints assert len(thread_1_history) == 7 assert Counter(c.metadata["source"] for c in thread_1_history) == { "input": 4, "loop": 3, } # sorted descending assert ( thread_1_history[0].config["configurable"]["checkpoint_id"] > thread_1_history[1].config["configurable"]["checkpoint_id"] ) # cursor pagination cursored = [ c async for c in app.aget_state_history( thread_1, limit=1, before=thread_1_history[0].config ) ] assert len(cursored) == 1 assert cursored[0].config == thread_1_history[1].config # the last checkpoint assert thread_1_history[0].values["total"] == 16 # the first "loop" checkpoint assert thread_1_history[-2].values["total"] == 2 # can get each checkpoint using aget with config assert (await checkpointer.aget(thread_1_history[0].config))[ "id" ] == thread_1_history[0].config["configurable"]["checkpoint_id"] assert (await checkpointer.aget(thread_1_history[1].config))[ "id" ] == thread_1_history[1].config["configurable"]["checkpoint_id"] thread_1_next_config = await app.aupdate_state(thread_1_history[1].config, 10) # update creates a new checkpoint assert ( thread_1_next_config["configurable"]["checkpoint_id"] > thread_1_history[0].config["configurable"]["checkpoint_id"] ) # 1 more checkpoint in history assert len([c async for c in app.aget_state_history(thread_1)]) == 8 assert Counter( [c.metadata["source"] async for c in app.aget_state_history(thread_1)] ) == { "update": 1, "input": 4, "loop": 3, } # the latest checkpoint is the updated one assert await app.aget_state(thread_1) == await app.aget_state( thread_1_next_config ) async def test_invoke_two_processes_two_in_join_two_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) add_10_each = mocker.Mock(side_effect=lambda x: sorted(y + 10 for y in x)) one = Channel.subscribe_to("input") | add_one 
| Channel.write_to("inbox") chain_three = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") chain_four = ( Channel.subscribe_to("inbox") | add_10_each | Channel.write_to("output") ) app = Pregel( nodes={ "one": one, "chain_three": chain_three, "chain_four": chain_four, }, channels={ "inbox": Topic(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) # Then invoke app # We get a single array result as chain_four waits for all publishers to finish # before operating on all elements published to topic_two as an array for _ in range(100): assert await app.ainvoke(2) == [13, 13] assert await asyncio.gather(*(app.ainvoke(2) for _ in range(100))) == [ [13, 13] for _ in range(100) ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_invoke_join_then_call_other_pregel( mocker: MockerFixture, checkpointer_name: str ) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) add_10_each = mocker.Mock(side_effect=lambda x: [y + 10 for y in x]) inner_app = Pregel( nodes={ "one": Channel.subscribe_to("input") | add_one | Channel.write_to("output") }, channels={ "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) one = ( Channel.subscribe_to("input") | add_10_each | Channel.write_to("inbox_one").map() ) two = ( Channel.subscribe_to("inbox_one") | inner_app.map() | sorted | Channel.write_to("outbox_one") ) chain_three = Channel.subscribe_to("outbox_one") | sum | Channel.write_to("output") app = Pregel( nodes={ "one": one, "two": two, "chain_three": chain_three, }, channels={ "inbox_one": Topic(int), "outbox_one": LastValue(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) # Then invoke pubsub for _ in range(10): assert await app.ainvoke([2, 3]) == 27 assert await asyncio.gather(*(app.ainvoke([2, 3]) for _ in range(10))) == [ 27 for _ in range(10) ] async 
with awith_checkpointer(checkpointer_name) as checkpointer: # add checkpointer app.checkpointer = checkpointer # subgraph is called twice in the same node, through .map(), so raises with pytest.raises(MultipleSubgraphsError): await app.ainvoke([2, 3], {"configurable": {"thread_id": "1"}}) # set inner graph checkpointer NeverCheckpoint inner_app.checkpointer = False # subgraph still called twice, but checkpointing for inner graph is disabled assert await app.ainvoke([2, 3], {"configurable": {"thread_id": "1"}}) == 27 async def test_invoke_two_processes_one_in_two_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = ( Channel.subscribe_to("input") | add_one | Channel.write_to("output", "between") ) two = Channel.subscribe_to("between") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "between": LastValue(int), "output": LastValue(int), }, stream_channels=["output", "between"], input_channels="input", output_channels="output", ) # Then invoke pubsub assert [c async for c in app.astream(2)] == [ {"between": 3, "output": 3}, {"between": 3, "output": 4}, ] async def test_invoke_two_processes_no_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("between") two = Channel.subscribe_to("between") | add_one app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "between": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", ) # It finishes executing (once no more messages being published) # but returns nothing, as nothing was published to "output" topic assert await app.ainvoke(2) is None async def test_channel_enter_exit_timing(mocker: MockerFixture) -> None: setup_sync = mocker.Mock() cleanup_sync = mocker.Mock() setup_async = mocker.Mock() cleanup_async = mocker.Mock() @contextmanager def 
an_int() -> Generator[int, None, None]: setup_sync() try: yield 5 finally: cleanup_sync() @asynccontextmanager async def an_int_async() -> AsyncGenerator[int, None]: setup_async() try: yield 5 finally: cleanup_async() add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") two = ( Channel.subscribe_to("inbox") | RunnableLambda(add_one).abatch | Channel.write_to("output").abatch ) app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "output": LastValue(int), "inbox": Topic(int), "ctx": Context(an_int, an_int_async), }, input_channels="input", output_channels=["inbox", "output"], stream_channels=["inbox", "output"], ) async def aenumerate(aiter: AsyncIterator[Any]) -> AsyncIterator[tuple[int, Any]]: i = 0 async for chunk in aiter: yield i, chunk i += 1 assert setup_sync.call_count == 0 assert cleanup_sync.call_count == 0 assert setup_async.call_count == 0 assert cleanup_async.call_count == 0 async for i, chunk in aenumerate(app.astream(2)): assert setup_sync.call_count == 0, "Sync context manager should not be used" assert cleanup_sync.call_count == 0, "Sync context manager should not be used" assert setup_async.call_count == 1, "Expected setup to be called once" if i == 0: assert chunk == {"inbox": [3]} elif i == 1: assert chunk == {"output": 4} else: assert False, "Expected only two chunks" assert setup_sync.call_count == 0 assert cleanup_sync.call_count == 0 assert setup_async.call_count == 1, "Expected setup to be called once" assert cleanup_async.call_count == 1, "Expected cleanup to be called once" @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_conditional_graph(checkpointer_name: str) -> None: from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models.fake import FakeStreamingListLLM from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnablePassthrough 
from langchain_core.tools import tool # Assemble the tools @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] # Construct the agent prompt = PromptTemplate.from_template("Hello!") llm = FakeStreamingListLLM( responses=[ "tool:search_api:query", "tool:search_api:another", "finish:answer", ] ) async def agent_parser(input: str) -> Union[AgentAction, AgentFinish]: if input.startswith("finish"): _, answer = input.split(":") return AgentFinish(return_values={"answer": answer}, log=input) else: _, tool_name, tool_input = input.split(":") return AgentAction(tool=tool_name, tool_input=tool_input, log=input) agent = RunnablePassthrough.assign(agent_outcome=prompt | llm | agent_parser) # Define tool execution logic async def execute_tools(data: dict) -> dict: data = data.copy() agent_action: AgentAction = data.pop("agent_outcome") observation = await {t.name: t for t in tools}[agent_action.tool].ainvoke( agent_action.tool_input ) if data.get("intermediate_steps") is None: data["intermediate_steps"] = [] else: data["intermediate_steps"] = data["intermediate_steps"].copy() data["intermediate_steps"].append([agent_action, observation]) return data # Define decision-making logic async def should_continue(data: dict, config: RunnableConfig) -> str: # Logic to decide whether to continue in the loop or exit if isinstance(data["agent_outcome"], AgentFinish): return "exit" else: return "continue" # Define a new graph workflow = Graph() workflow.add_node("agent", agent) workflow.add_node("tools", execute_tools) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) workflow.add_edge("tools", "agent") app = workflow.compile() assert await app.ainvoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), 
"result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } assert [c async for c in app.astream({"input": "what is weather in sf"})] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] patches = [c async for c in app.astream_log({"input": "what is weather in sf"})] patch_paths = {op["path"] for log in patches for op in log.ops} # Check that agent (one of the nodes) has its output streamed to the logs assert "/logs/agent/streamed_output/-" in patch_paths assert "/logs/agent:2/streamed_output/-" in patch_paths assert 
"/logs/agent:3/streamed_output/-" in patch_paths # Check that agent (one of the nodes) has its final output set in the logs assert "/logs/agent/final_output" in patch_paths assert "/logs/agent:2/final_output" in patch_paths assert "/logs/agent:3/final_output" in patch_paths assert [ p["value"] for log in patches for p in log.ops if p["path"] == "/logs/agent/final_output" or p["path"] == "/logs/agent:2/final_output" or p["path"] == "/logs/agent:3/final_output" ] == [ { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), }, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), }, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), }, ] async with awith_checkpointer(checkpointer_name) as checkpointer: # test state get/update methods with interrupt_after app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"input": "what is weather in sf"}, config ) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", 
log="tool:search_api:query", ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", 
tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] await app_w_interrupt.aupdate_state( config, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, }, tasks=(), next=(), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 4, "writes": { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # test state get/update methods with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": 
{"thread_id": "2"}} llm.i = 0 assert [ c async for c in app_w_interrupt.astream( {"input": "what is weather in sf"}, config ) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", } 
}, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] await app_w_interrupt.aupdate_state( config, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, }, tasks=(), next=(), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 4, "writes": { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a 
different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # test re-invoke to continue with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": {"thread_id": "3"}} llm.i = 0 # reset the llm assert [ c async for c in app_w_interrupt.astream( {"input": "what is weather in sf"}, config ) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } }, "thread_id": "3", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), }, }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { 
"input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_conditional_graph_state( mocker: MockerFixture, checkpointer_name: str ) -> None: from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models.fake import FakeStreamingListLLM from langchain_core.prompts import PromptTemplate from langchain_core.tools import tool setup = mocker.Mock() teardown = mocker.Mock() @asynccontextmanager async def assert_ctx_once() -> AsyncIterator[None]: assert setup.call_count == 0 assert teardown.call_count == 0 try: yield finally: assert setup.call_count == 1 assert teardown.call_count == 1 
setup.reset_mock() teardown.reset_mock() class MyPydanticContextModel(BaseModel, arbitrary_types_allowed=True): session: httpx.AsyncClient something_else: str @asynccontextmanager async def make_context( config: RunnableConfig, ) -> AsyncIterator[MyPydanticContextModel]: assert isinstance(config, dict) setup() session = httpx.AsyncClient() try: yield MyPydanticContextModel(session=session, something_else="hello") finally: await session.aclose() teardown() class AgentState(TypedDict): input: Annotated[str, UntrackedValue] agent_outcome: Optional[Union[AgentAction, AgentFinish]] intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add] context: Annotated[MyPydanticContextModel, Context(make_context)] # Assemble the tools @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] # Construct the agent prompt = PromptTemplate.from_template("Hello!") llm = FakeStreamingListLLM( responses=[ "tool:search_api:query", "tool:search_api:another", "finish:answer", ] ) def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]: if input.startswith("finish"): _, answer = input.split(":") return { "agent_outcome": AgentFinish( return_values={"answer": answer}, log=input ) } else: _, tool_name, tool_input = input.split(":") return { "agent_outcome": AgentAction( tool=tool_name, tool_input=tool_input, log=input ) } agent = prompt | llm | agent_parser # Define tool execution logic def execute_tools(data: AgentState) -> dict: # check we have httpx session in AgentState assert isinstance(data["context"], MyPydanticContextModel) # execute the tool agent_action: AgentAction = data.pop("agent_outcome") observation = {t.name: t for t in tools}[agent_action.tool].invoke( agent_action.tool_input ) return {"intermediate_steps": [[agent_action, observation]]} # Define decision-making logic def should_continue(data: AgentState) -> str: # check we have httpx session in AgentState assert 
isinstance(data["context"], MyPydanticContextModel) # Logic to decide whether to continue in the loop or exit if isinstance(data["agent_outcome"], AgentFinish): return "exit" else: return "continue" # Define a new graph workflow = StateGraph(AgentState) workflow.add_node("agent", agent) workflow.add_node("tools", execute_tools) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) workflow.add_edge("tools", "agent") app = workflow.compile() async with assert_ctx_once(): assert await app.ainvoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } async with assert_ctx_once(): assert [c async for c in app.astream({"input": "what is weather in sf"})] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] async with assert_ctx_once(): patches = [c async for c in app.astream_log({"input": "what is weather in sf"})] patch_paths = {op["path"] for log in patches for op in log.ops} # Check that agent (one of the nodes) has its output 
streamed to the logs assert "/logs/agent/streamed_output/-" in patch_paths # Check that agent (one of the nodes) has its final output set in the logs assert "/logs/agent/final_output" in patch_paths assert [ p["value"] for log in patches for p in log.ops if p["path"] == "/logs/agent/final_output" or p["path"] == "/logs/agent:2/final_output" or p["path"] == "/logs/agent:3/final_output" ] == [ { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ) }, { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another" ) }, { "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), }, ] async with awith_checkpointer(checkpointer_name) as checkpointer: # test state get/update methods with interrupt_after app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} async with assert_ctx_once(): assert [ c async for c in app_w_interrupt.astream( {"input": "what is weather in sf"}, config ) ] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) async with 
assert_ctx_once(): await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) async with assert_ctx_once(): assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] async with assert_ctx_once(): await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], }, tasks=(), next=(), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, 
created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # test state get/update methods with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": {"thread_id": "2"}} llm.i = 0 # reset the llm assert [ c async for c in app_w_interrupt.astream( {"input": "what is weather in sf"}, config ) ] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "intermediate_steps": [], }, 
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] await app_w_interrupt.aupdate_state( config, { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) }, ) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], }, tasks=(), next=(), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) async def test_conditional_entrypoint_graph() -> 
None: async def left(data: str) -> str: return data + "->left" async def right(data: str) -> str: return data + "->right" def should_start(data: str) -> str: # Logic to decide where to start if len(data) > 10: return "go-right" else: return "go-left" # Define a new graph workflow = Graph() workflow.add_node("left", left) workflow.add_node("right", right) workflow.set_conditional_entry_point( should_start, {"go-left": "left", "go-right": "right"} ) workflow.add_conditional_edges("left", lambda data: END) workflow.add_edge("right", END) app = workflow.compile() assert await app.ainvoke("what is weather in sf") == "what is weather in sf->right" assert [c async for c in app.astream("what is weather in sf")] == [ {"right": "what is weather in sf->right"}, ] async def test_conditional_entrypoint_graph_state() -> None: class AgentState(TypedDict, total=False): input: str output: str steps: Annotated[list[str], operator.add] async def left(data: AgentState) -> AgentState: return {"output": data["input"] + "->left"} async def right(data: AgentState) -> AgentState: return {"output": data["input"] + "->right"} def should_start(data: AgentState) -> str: assert data["steps"] == [], "Expected input to be read from the state" # Logic to decide where to start if len(data["input"]) > 10: return "go-right" else: return "go-left" # Define a new graph workflow = StateGraph(AgentState) workflow.add_node("left", left) workflow.add_node("right", right) workflow.set_conditional_entry_point( should_start, {"go-left": "left", "go-right": "right"} ) workflow.add_conditional_edges("left", lambda data: END) workflow.add_edge("right", END) app = workflow.compile() assert await app.ainvoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "output": "what is weather in sf->right", "steps": [], } assert [c async for c in app.astream({"input": "what is weather in sf"})] == [ {"right": {"output": "what is weather in sf->right"}}, ] async def test_prebuilt_tool_chat() -> 
None: from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool model = FakeChatModel( messages=[ AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), AIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ), AIMessage(content="answer"), ] ) @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] app = create_tool_calling_executor(model, tools) assert await app.ainvoke( {"messages": [HumanMessage(content="what is weather in sf")]} ) == { "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ), _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", id=AnyStr(), ), _AnyIdAIMessage(content="answer"), ] } assert [ c async for c in app.astream( {"messages": [HumanMessage(content="what is weather in sf")]}, stream_mode="messages", ) ] == [ ( _AnyIdAIMessageChunk( content="", tool_calls=[ { "name": "search_api", "args": {"query": "query"}, "id": "tool_call123", "type": "tool_call", } ], tool_call_chunks=[ { "name": "search_api", "args": '{"query": "query"}', "id": "tool_call123", "index": None, "type": "tool_call_chunk", } ], ), { "langgraph_step": 1, "langgraph_node": 
"agent", "langgraph_triggers": ["start:agent"], "langgraph_path": ("__pregel_pull", "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ( _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), { "langgraph_step": 2, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": ("__pregel_pull", "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdAIMessageChunk( content="", tool_calls=[ { "name": "search_api", "args": {"query": "another"}, "id": "tool_call234", "type": "tool_call", }, { "name": "search_api", "args": {"query": "a third one"}, "id": "tool_call567", "type": "tool_call", }, ], tool_call_chunks=[ { "name": "search_api", "args": '{"query": "another"}', "id": "tool_call234", "index": None, "type": "tool_call_chunk", }, { "name": "search_api", "args": '{"query": "a third one"}', "id": "tool_call567", "index": None, "type": "tool_call_chunk", }, ], ), { "langgraph_step": 3, "langgraph_node": "agent", "langgraph_triggers": ["tools"], "langgraph_path": ("__pregel_pull", "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ( _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), { "langgraph_step": 4, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": ("__pregel_pull", "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), { "langgraph_step": 4, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": ("__pregel_pull", "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdAIMessageChunk( 
content="answer", ), { "langgraph_step": 5, "langgraph_node": "agent", "langgraph_triggers": ["tools"], "langgraph_path": ("__pregel_pull", "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ] assert [ c async for c in app.astream( {"messages": [HumanMessage(content="what is weather in sf")]} ) ] == [ { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ) ] } }, { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), ] } }, {"agent": {"messages": [_AnyIdAIMessage(content="answer")]}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_state_graph_packets(checkpointer_name: str) -> None: from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import ( AIMessage, BaseMessage, HumanMessage, ToolMessage, ) from langchain_core.tools import tool class AgentState(TypedDict): messages: Annotated[list[BaseMessage], add_messages] session: Annotated[httpx.AsyncClient, Context(httpx.AsyncClient)] @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] tools_by_name = {t.name: t for t in tools} model = FakeMessagesListChatModel( responses=[ AIMessage( id="ai1", content="", 
tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ), AIMessage(id="ai3", content="answer"), ] ) # Define decision-making logic def should_continue(data: AgentState) -> str: assert isinstance(data["session"], httpx.AsyncClient) # Logic to decide whether to continue in the loop or exit if tool_calls := data["messages"][-1].tool_calls: return [Send("tools", tool_call) for tool_call in tool_calls] else: return END async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: await asyncio.sleep(input["args"].get("idx", 0) / 10) output = await tools_by_name[input["name"]].ainvoke(input["args"], config) return { "messages": ToolMessage( content=output, name=input["name"], tool_call_id=input["id"] ) } # Define a new graph workflow = StateGraph(AgentState) # Define the two nodes we will cycle between workflow.add_node("agent", {"messages": RunnablePick("messages") | model}) workflow.add_node("tools", tools_node) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges("agent", should_continue) # We now add a normal edge from `tools` to `agent`. # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("tools", "agent") # Finally, we compile it! 
# This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile() assert await app.ainvoke( {"messages": HumanMessage(content="what is weather in sf")} ) == { "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ), _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), AIMessage(content="answer", id="ai3"), ] } assert [ c async for c in app.astream( {"messages": [HumanMessage(content="what is weather in sf")]} ) ] == [ { "agent": { "messages": AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) }, }, { "tools": { "messages": _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ) } }, { "agent": { "messages": AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ) } }, { "tools": { "messages": _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ) }, }, { "tools": { "messages": _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), }, }, {"agent": {"messages": AIMessage(content="answer", id="ai3")}}, ] async with 
awith_checkpointer(checkpointer_name) as checkpointer: # interrupt after agent app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"messages": HumanMessage(content="what is weather in sf")}, config ) ] == [ { "agent": { "messages": AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) } }, {"__interrupt__": ()}, ] if not FF_SEND_V2: return assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), ] }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( "", id="ai1", tool_calls=[ { "name": "search_api", "args": {"query": "query"}, "id": "tool_call123", "type": "tool_call", } ], ) }, ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr()) ), ), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # modify ai message last_message = (await app_w_interrupt.aget_state(config)).values["messages"][-1] last_message.tool_calls[0]["args"]["query"] = "a different query" await app_w_interrupt.aupdate_state(config, {"messages": last_message}) # message was replaced instead of appended tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ 
_AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ), ] }, tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),), next=("tools",), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "messages": AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ) } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "tools": { "messages": _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ) } }, { "agent": { "messages": AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ) }, }, {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ), ] }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, 
result={ "messages": AIMessage( "", id="ai2", tool_calls=[ { "name": "search_api", "args": {"query": "another", "idx": 0}, "id": "tool_call234", "type": "tool_call", }, { "name": "search_api", "args": {"query": "a third one", "idx": 1}, "id": "tool_call567", "type": "tool_call", }, ], ) }, ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr()) ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr()) ), ), next=("tools", "tools"), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": { "tools": { "messages": _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), }, }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, {"messages": AIMessage(content="answer", id="ai2")}, ) # replaces message even if object identity is different, as long as id is the same tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), ] }, tasks=(), next=(), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 3, "writes": { "agent": { "messages": AIMessage(content="answer", id="ai2"), } }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # interrupt before tools app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config 
= {"configurable": {"thread_id": "2"}} model.i = 0 assert [ c async for c in app_w_interrupt.astream( {"messages": HumanMessage(content="what is weather in sf")}, config ) ] == [ { "agent": { "messages": AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) } }, {"__interrupt__": ()}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), ] }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai1", tool_calls=[ { "name": "search_api", "args": {"query": "query"}, "id": "tool_call123", "type": "tool_call", } ], ) }, ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr()) ), ), next=("tools",), config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, created_at=( await app_w_interrupt.checkpointer.aget_tuple(config) ).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # modify ai message last_message = (await app_w_interrupt.aget_state(config)).values["messages"][-1] last_message.tool_calls[0]["args"]["query"] = "a different query" await app_w_interrupt.aupdate_state(config, {"messages": last_message}) # message was replaced instead of appended tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": 
"search_api", "args": {"query": "a different query"}, }, ], ), ] }, tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),), next=("tools",), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "messages": AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ) } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "tools": { "messages": _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ) } }, { "agent": { "messages": AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ) }, }, {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage( id="ai2", content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another", "idx": 0}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one", "idx": 1}, }, ], ), ] }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai2", tool_calls=[ { "name": 
"search_api", "args": {"query": "another", "idx": 0}, "id": "tool_call234", "type": "tool_call", }, { "name": "search_api", "args": {"query": "a third one", "idx": 1}, "id": "tool_call567", "type": "tool_call", }, ], ) }, ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr()) ), PregelTask( AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr()) ), ), next=("tools", "tools"), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": { "tools": { "messages": _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), }, }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, {"messages": AIMessage(content="answer", id="ai2")}, ) # replaces message even if object identity is different, as long as id is the same tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( id="ai1", content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, }, ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), ] }, tasks=(), next=(), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 3, "writes": { "agent": { "messages": AIMessage(content="answer", id="ai2"), } }, "thread_id": "2", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_message_graph(checkpointer_name: str) -> None: from langchain_core.language_models.fake_chat_models import 
( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool class FakeFuntionChatModel(FakeMessagesListChatModel): def bind_functions(self, functions: list): return self @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] model = FakeFuntionChatModel( responses=[ AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), AIMessage(content="answer", id="ai3"), ] ) # Define the function that determines whether to continue or not def should_continue(messages): last_message = messages[-1] # If there is no function call, then we finish if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: return "continue" # Define a new graph workflow = MessageGraph() # Define the two nodes we will cycle between workflow.add_node("agent", model) workflow.add_node("tools", ToolNode(tools)) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges( # First, we define the start node. We use `agent`. # This means these are the edges taken after the `agent` node is called. "agent", # Next, we pass in the function that will determine which node is called next. should_continue, # Finally we pass in a mapping. # The keys are strings, and the values are other nodes. # END is a special node marking that the graph should finish. # What will happen is we will call `should_continue`, and then the output of that # will be matched against the keys in this mapping. # Based on which one it matches, that node will then be called. { # If `tools`, then we call the tool node. 
"continue": "tools", # Otherwise we finish. "end": END, }, ) # We now add a normal edge from `tools` to `agent`. # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("tools", "agent") # Finally, we compile it! # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile() assert await app.ainvoke(HumanMessage(content="what is weather in sf")) == [ _AnyIdHumanMessage( content="what is weather in sf", ), AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", # respects ids passed in ), _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call456", ), AIMessage(content="answer", id="ai3"), ] assert [ c async for c in app.astream([HumanMessage(content="what is weather in sf")]) ] == [ { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, { "tools": [ _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ) ] }, { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, { "tools": [ _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call456", ) ] }, {"agent": AIMessage(content="answer", id="ai3")}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( HumanMessage(content="what is weather in sf"), 
config ) ] == [ { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) # modify ai message last_message = (await app_w_interrupt.aget_state(config)).values[-1] last_message.tool_calls[0]["args"] = {"query": "a different query"} await app_w_interrupt.aupdate_state(config, last_message) # message was replaced instead of appended tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], id="ai1", ) }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) 
][-1].config, ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ { "tools": [ _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ) ] }, { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, {"__interrupt__": ()}, ] tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 4, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) await app_w_interrupt.aupdate_state( config, AIMessage(content="answer", id="ai2"), ) # replaces message even if object identity is different, as long as id is the same tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", 
tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), ], tasks=(), next=(), config=tup.config, created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, parent_config=[ c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) ][-1].config, ) async def test_in_one_fan_out_out_one_graph_state() -> None: def sorted_add(x: list[str], y: list[str]) -> list[str]: return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], operator.add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def retriever_one(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge("retriever_one", "qa") workflow.add_edge("retriever_two", "qa") workflow.set_finish_point("qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}) == { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] assert [ c async for c in app.astream( {"query": 
"what is weather in sf"}, stream_mode="values" ) ] == [ {"query": "what is weather in sf", "docs": []}, {"query": "query: what is weather in sf", "docs": []}, { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", }, ] assert [ c async for c in app.astream( {"query": "what is weather in sf"}, stream_mode=["values", "updates", "debug"], ) ] == [ ("values", {"query": "what is weather in sf", "docs": []}), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "rewrite_query", "input": {"query": "what is weather in sf", "docs": []}, "triggers": ["start:rewrite_query"], }, }, ), ("updates", {"rewrite_query": {"query": "query: what is weather in sf"}}), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "rewrite_query", "result": [("query", "query: what is weather in sf")], "error": None, "interrupts": [], }, }, ), ("values", {"query": "query: what is weather in sf", "docs": []}), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_one", "input": {"query": "query: what is weather in sf", "docs": []}, "triggers": ["rewrite_query"], }, }, ), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_two", "input": {"query": "query: what is weather in sf", "docs": []}, "triggers": ["rewrite_query"], }, }, ), ( "updates", {"retriever_two": {"docs": ["doc3", "doc4"]}}, ), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_two", "result": [("docs", ["doc3", "doc4"])], "error": None, "interrupts": [], }, }, ), ( "updates", {"retriever_one": {"docs": ["doc1", "doc2"]}}, ), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { 
"id": AnyStr(), "name": "retriever_one", "result": [("docs", ["doc1", "doc2"])], "error": None, "interrupts": [], }, }, ), ( "values", { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, ), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "qa", "input": { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, "triggers": ["retriever_one", "retriever_two"], }, }, ), ("updates", {"qa": {"answer": "doc1,doc2,doc3,doc4"}}), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "qa", "result": [("answer", "doc1,doc2,doc3,doc4")], "error": None, "interrupts": [], }, }, ), ( "values", { "query": "query: what is weather in sf", "answer": "doc1,doc2,doc3,doc4", "docs": ["doc1", "doc2", "doc3", "doc4"], }, ), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_start_branch_then(checkpointer_name: str) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] market: str shared: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")] other: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")] def assert_shared_value(data: State, config: RunnableConfig) -> State: assert "shared" in data if thread_id := config["configurable"].get("thread_id"): if thread_id == "1": # this is the first thread, so should not see a value assert data["shared"] == {} return {"shared": {"1": {"hello": "world"}}, "other": {"2": {1: 2}}} elif thread_id == "2": # this should get value saved by thread 1 assert data["shared"] == {"1": {"hello": "world"}} elif thread_id == "3": # this is a different assistant, so should not see previous value assert data["shared"] == {} return {} def tool_two_slow(data: State, config: RunnableConfig) -> State: return {"my_key": " slow", **assert_shared_value(data, config)} def tool_two_fast(data: State, config: RunnableConfig) -> 
State: return {"my_key": " fast", **assert_shared_value(data, config)} tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two_slow", tool_two_slow) tool_two_graph.add_node("tool_two_fast", tool_two_fast) tool_two_graph.set_conditional_entry_point( lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then=END ) tool_two = tool_two_graph.compile() assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}) == { "my_key": "value slow", "market": "DE", } assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { "my_key": "value fast", "market": "US", } async with awith_checkpointer(checkpointer_name) as checkpointer: tool_two = tool_two_graph.compile( store=InMemoryStore(), checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"], ) # missing thread_id with pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "1", "assistant_id": "a"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, thread1) == { "my_key": "value", "market": "DE", } assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value", "market": "DE"}}, "assistant_id": "a", "thread_id": "1", }, ] assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "1", }, parent_config=[ c async for c in 
tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { "my_key": "value slow", "market": "DE", } assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value slow", "market": "DE"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"tool_two_slow": {"my_key": " slow"}}, "assistant_id": "a", "thread_id": "1", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": "value", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "2", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { "my_key": "value fast", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"tool_two_fast": {"my_key": " 
fast"}}, "assistant_id": "a", "thread_id": "2", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread3) == { "my_key": "value", "market": "US", } assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=(await tool_two.checkpointer.aget_tuple(thread3)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "b", "thread_id": "3", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread3, limit=2) ][-1].config, ) # update state await tool_two.aupdate_state(thread3, {"my_key": "key"}) # appends to my_key assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=(await tool_two.checkpointer.aget_tuple(thread3)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {START: {"my_key": "key"}}, "assistant_id": "b", "thread_id": "3", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread3, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { "my_key": "valuekey fast", "market": "US", } assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread3)).config, created_at=(await 
tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": {"tool_two_fast": {"my_key": " fast"}}, "assistant_id": "b", "thread_id": "3", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread3, limit=2) ][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_branch_then(checkpointer_name: str) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_graph = StateGraph(State) tool_two_graph.set_entry_point("prepare") tool_two_graph.set_finish_point("finish") tool_two_graph.add_conditional_edges( source="prepare", path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then="finish", ) tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"}) tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"}) tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"}) tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"}) tool_two = tool_two_graph.compile() assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { "my_key": "value prepared fast finished", "market": "US", } async with awith_checkpointer(checkpointer_name) as checkpointer: # test stream_mode=debug tool_two = tool_two_graph.compile(checkpointer=checkpointer) thread10 = {"configurable": {"thread_id": "10"}} assert [ c async for c in tool_two.astream( {"my_key": "value", "market": "DE"}, thread10, stream_mode="debug" ) ] == [ { "type": "checkpoint", "timestamp": AnyStr(), "step": -1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": {"my_key": ""}, "metadata": { 
"parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value", "market": "DE"}}, "thread_id": "10", }, "parent_config": None, "next": ["__start__"], "tasks": [ { "id": AnyStr(), "name": "__start__", "interrupts": (), "state": None, } ], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 0, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["prepare"], "tasks": [ { "id": AnyStr(), "name": "prepare", "interrupts": (), "state": None, } ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "input": {"my_key": "value", "market": "DE"}, "triggers": ["start:prepare"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "result": [("my_key", " prepared")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, 
"next": ["tool_two_slow"], "tasks": [ { "id": AnyStr(), "name": "tool_two_slow", "interrupts": (), "state": None, } ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "tool_two_slow", "input": {"my_key": "value prepared", "market": "DE"}, "triggers": ["branch:prepare:condition:tool_two_slow"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "tool_two_slow", "result": [("my_key", " slow")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 2, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared slow", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 2, "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["finish"], "tasks": [ { "id": AnyStr(), "name": "finish", "interrupts": (), "state": None, } ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "finish", "input": {"my_key": "value prepared slow", "market": "DE"}, "triggers": ["branch:prepare:condition::then"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "finish", "result": [("my_key", " finished")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 3, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared slow 
finished", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": [], "tasks": [], }, }, ] tool_two = tool_two_graph.compile( checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"], ) # missing thread_id with pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "11"}} # stop when about to enter node assert [ c async for c in tool_two.astream( {"my_key": "value", "market": "DE"}, thread1, stream_mode="debug" ) ] == [ { "type": "checkpoint", "timestamp": AnyStr(), "step": -1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "11"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "11", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": {"my_key": ""}, "metadata": { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value", "market": "DE"}}, "thread_id": "11", }, "parent_config": None, "next": ["__start__"], "tasks": [ { "id": AnyStr(), "name": "__start__", "interrupts": (), "state": None, } ], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 0, "payload": { "config": { "tags": [], "metadata": {"thread_id": "11"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "11", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "11", }, "parent_config": { "tags": [], "metadata": {"thread_id": "11"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "11", "checkpoint_ns": "", "checkpoint_id": 
AnyStr(), }, }, "next": ["prepare"], "tasks": [ { "id": AnyStr(), "name": "prepare", "interrupts": (), "state": None, } ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "input": {"my_key": "value", "market": "DE"}, "triggers": ["start:prepare"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "result": [("my_key", " prepared")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "11"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "11", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "11", }, "parent_config": { "tags": [], "metadata": {"thread_id": "11"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "11", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["tool_two_slow"], "tasks": [ { "id": AnyStr(), "name": "tool_two_slow", "interrupts": (), "state": None, } ], }, }, ] assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "11", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } 
assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "11", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) thread2 = {"configurable": {"thread_id": "12"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": "value prepared", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "12", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { "my_key": "value prepared fast finished", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "12", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) tool_two = tool_two_graph.compile( 
checkpointer=checkpointer, interrupt_after=["prepare"] ) # missing thread_id with pytest.raises(ValueError, match="thread_id"): await tool_two.ainvoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "21"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, thread1) == { "my_key": "value prepared", "market": "DE", } assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread1)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread1, limit=2) ][-1].config, ) thread2 = {"configurable": {"thread_id": "22"}} # stop when about to enter node assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": "value prepared", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), 
"tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { "my_key": "value prepared fast finished", "market": "US", } assert await tool_two.aget_state(thread2) == StateSnapshot( values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread2)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread2, limit=2) ][-1].config, ) thread3 = {"configurable": {"thread_id": "23"}} # update an empty thread before first run uconfig = await tool_two.aupdate_state( thread3, {"my_key": "key", "market": "DE"} ) # check current state assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "key", "market": "DE"}, tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),), next=("prepare",), config=uconfig, created_at=AnyStr(), metadata={ "parents": {}, "source": "update", "step": 0, "writes": {START: {"my_key": "key", "market": "DE"}}, "thread_id": "23", }, parent_config=None, ) # run from this point assert await tool_two.ainvoke(None, thread3) == { "my_key": "key prepared", "market": "DE", } # get state after first node assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "key prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, 
"tool_two_slow")),), next=("tool_two_slow",), config=(await tool_two.checkpointer.aget_tuple(thread3)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, parent_config=uconfig, ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { "my_key": "key prepared slow finished", "market": "DE", } assert await tool_two.aget_state(thread3) == StateSnapshot( values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), config=(await tool_two.checkpointer.aget_tuple(thread3)).config, created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ "ts" ], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, parent_config=[ c async for c in tool_two.checkpointer.alist(thread3, limit=2) ][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_in_one_fan_out_state_graph_waiting_edge(checkpointer_name: str) -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) 
workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"query": "what is weather in sf"}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_in_one_fan_out_state_graph_waiting_edge_via_branch( snapshot: SnapshotAssertion, checkpointer_name: str ) -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> 
list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges( "rewrite_query", lambda _: "retriever_two", {"retriever_two": "retriever_two"} ) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}, debug=True) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, 
interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"query": "what is weather in sf"}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class( snapshot: SnapshotAssertion, mocker: MockerFixture, checkpointer_name: str ) -> None: from pydantic.v1 import BaseModel, ValidationError setup = mocker.Mock() teardown = mocker.Mock() @asynccontextmanager async def assert_ctx_once() -> AsyncIterator[None]: assert setup.call_count == 0 assert teardown.call_count == 0 try: yield finally: assert setup.call_count == 1 assert teardown.call_count == 1 setup.reset_mock() teardown.reset_mock() @asynccontextmanager async def make_httpx_client() -> AsyncIterator[httpx.AsyncClient]: setup() async with httpx.AsyncClient() as client: try: yield client finally: teardown() def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(BaseModel): class Config: arbitrary_types_allowed = True query: str answer: Optional[str] = None docs: Annotated[list[str], sorted_add] client: Annotated[httpx.AsyncClient, Context(make_httpx_client)] class Input(BaseModel): query: str class Output(BaseModel): answer: str docs: list[str] class StateUpdate(BaseModel): query: Optional[str] = None answer: Optional[str] = None docs: Optional[list[str]] = None async def rewrite_query(data: State) -> State: return 
{"query": f"query: {data.query}"} async def analyzer_one(data: State) -> State: return StateUpdate(query=f"analyzed: {data.query}") async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data.docs)} async def decider(data: State) -> str: assert isinstance(data, State) return "retriever_two" workflow = StateGraph(State, input=Input, output=Output) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges( "rewrite_query", decider, {"retriever_two": "retriever_two"} ) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() async with assert_ctx_once(): with pytest.raises(ValidationError): await app.ainvoke({"query": {}}) async with assert_ctx_once(): assert await app.ainvoke({"query": "what is weather in sf"}) == { "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } async with assert_ctx_once(): assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} async 
with assert_ctx_once(): assert [ c async for c in app_w_interrupt.astream( {"query": "what is weather in sf"}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] async with assert_ctx_once(): assert [c async for c in app_w_interrupt.astream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "query": "analyzed: query: what is weather in sf", "answer": "doc1,doc2,doc3,doc4", "docs": ["doc1", "doc2", "doc3", "doc4"], }, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, "step": 4, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) async with assert_ctx_once(): assert await app_w_interrupt.aupdate_state( config, {"docs": ["doc5"]}, as_node="rewrite_query" ) == { "configurable": { "thread_id": "1", "checkpoint_id": AnyStr(), "checkpoint_ns": "", } } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2( snapshot: SnapshotAssertion, checkpointer_name: str ) -> None: from pydantic import BaseModel, ValidationError def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class InnerObject(BaseModel): yo: int class State(BaseModel): query: str inner: InnerObject answer: Optional[str] = None docs: Annotated[list[str], sorted_add] class StateUpdate(BaseModel): query: 
Optional[str] = None answer: Optional[str] = None docs: Optional[list[str]] = None async def rewrite_query(data: State) -> State: return {"query": f"query: {data.query}"} async def analyzer_one(data: State) -> State: return StateUpdate(query=f"analyzed: {data.query}") async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data.docs)} async def decider(data: State) -> str: assert isinstance(data, State) return "retriever_two" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges( "rewrite_query", decider, {"retriever_two": "retriever_two"} ) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.get_input_schema().model_json_schema() == snapshot assert app.get_output_schema().model_json_schema() == snapshot with pytest.raises(ValidationError): await app.ainvoke({"query": {}}) assert await app.ainvoke( {"query": "what is weather in sf", "inner": {"yo": 1}} ) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", "inner": {"yo": 1}, } assert [ c async for c in app.astream( {"query": "what is weather in sf", "inner": {"yo": 1}} ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, 
{"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"query": "what is weather in sf", "inner": {"yo": 1}}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] assert await app_w_interrupt.aupdate_state( config, {"docs": ["doc5"]}, as_node="rewrite_query" ) == { "configurable": { "thread_id": "1", "checkpoint_id": AnyStr(), "checkpoint_ns": "", } } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_in_one_fan_out_state_graph_waiting_edge_plus_regular( checkpointer_name: str, ) -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def analyzer_one(data: State) -> State: await asyncio.sleep(0.1) return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.2) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = 
StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") # silly edge, to make sure having been triggered before doesn't break # semantics of named barrier (== waiting edges) workflow.add_edge("rewrite_query", "qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"qa": {"answer": ""}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async with awith_checkpointer(checkpointer_name) as checkpointer: app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c async for c in app_w_interrupt.astream( {"query": "what is weather in sf"}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"qa": {"answer": ""}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c async for c in app_w_interrupt.astream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] async def 
test_in_one_fan_out_state_graph_waiting_edge_multiple() -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} async def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("decider", decider) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "decider") workflow.add_conditional_edges("decider", decider_cond) workflow.set_finish_point("qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ 
{"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}}, { "analyzer_one": { "query": "analyzed: query: analyzed: query: what is weather in sf" } }, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}}, ] async def test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge() -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def retriever_picker(data: State) -> list[str]: return ["analyzer_one", "retriever_two"] async def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} async def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("decider", decider) workflow.add_node("qa", qa) 
workflow.set_entry_point("rewrite_query") workflow.add_conditional_edges("rewrite_query", retriever_picker) workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge(["retriever_one", "retriever_two"], "decider") workflow.add_conditional_edges("decider", decider_cond) workflow.set_finish_point("qa") app = workflow.compile() assert await app.ainvoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], } assert [c async for c in app.astream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}}, { "analyzer_one": { "query": "analyzed: query: analyzed: query: what is weather in sf" } }, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}}, ] async def test_nested_graph(snapshot: SnapshotAssertion) -> None: def never_called_fn(state: Any): assert 0, "This function should never be called" never_called = RunnableLambda(never_called_fn) class InnerState(TypedDict): my_key: str my_other_key: str def up(state: InnerState): return {"my_key": state["my_key"] + " there", "my_other_key": state["my_key"]} inner = StateGraph(InnerState) inner.add_node("up", up) inner.set_entry_point("up") inner.set_finish_point("up") class State(TypedDict): my_key: str never_called: Any async def side(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("inner", inner.compile()) graph.add_node("side", side) 
graph.set_entry_point("inner") graph.add_edge("inner", "side") graph.set_finish_point("side") app = graph.compile() assert await app.ainvoke({"my_key": "my value", "never_called": never_called}) == { "my_key": "my value there and back again", "never_called": never_called, } assert [ chunk async for chunk in app.astream( {"my_key": "my value", "never_called": never_called} ) ] == [ {"inner": {"my_key": "my value there"}}, {"side": {"my_key": "my value there and back again"}}, ] assert [ chunk async for chunk in app.astream( {"my_key": "my value", "never_called": never_called}, stream_mode="values" ) ] == [ {"my_key": "my value", "never_called": never_called}, {"my_key": "my value there", "never_called": never_called}, {"my_key": "my value there and back again", "never_called": never_called}, ] times_called = 0 async for event in app.astream_events( {"my_key": "my value", "never_called": never_called}, version="v2", config={"run_id": UUID(int=0)}, stream_mode="values", ): if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)): times_called += 1 assert event["data"] == { "output": { "my_key": "my value there and back again", "never_called": never_called, } } assert times_called == 1 times_called = 0 async for event in app.astream_events( {"my_key": "my value", "never_called": never_called}, version="v2", config={"run_id": UUID(int=0)}, ): if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)): times_called += 1 assert event["data"] == { "output": { "my_key": "my value there and back again", "never_called": never_called, } } assert times_called == 1 chain = app | RunnablePassthrough() assert await chain.ainvoke( {"my_key": "my value", "never_called": never_called} ) == { "my_key": "my value there and back again", "never_called": never_called, } assert [ chunk async for chunk in chain.astream( {"my_key": "my value", "never_called": never_called} ) ] == [ {"inner": {"my_key": "my value there"}}, {"side": {"my_key": "my value 
there and back again"}}, ] times_called = 0 async for event in chain.astream_events( {"my_key": "my value", "never_called": never_called}, version="v2", config={"run_id": UUID(int=0)}, ): if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)): times_called += 1 assert event["data"] == { "output": [ {"inner": {"my_key": "my value there"}}, {"side": {"my_key": "my value there and back again"}}, ] } assert times_called == 1 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_stream_subgraphs_during_execution(checkpointer_name: str) -> None: class InnerState(TypedDict): my_key: Annotated[str, operator.add] my_other_key: str async def inner_1(state: InnerState): return {"my_key": "got here", "my_other_key": state["my_key"]} async def inner_2(state: InnerState): await asyncio.sleep(0.5) return { "my_key": " and there", "my_other_key": state["my_key"], } inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: Annotated[str, operator.add] async def outer_1(state: State): await asyncio.sleep(0.2) return {"my_key": " and parallel"} async def outer_2(state: State): return {"my_key": " and back again"} graph = StateGraph(State) graph.add_node("inner", inner.compile()) graph.add_node("outer_1", outer_1) graph.add_node("outer_2", outer_2) graph.add_edge(START, "inner") graph.add_edge(START, "outer_1") graph.add_edge(["inner", "outer_1"], "outer_2") graph.add_edge("outer_2", END) async with awith_checkpointer(checkpointer_name) as checkpointer: app = graph.compile(checkpointer=checkpointer) start = perf_counter() chunks: list[tuple[float, Any]] = [] config = {"configurable": {"thread_id": "2"}} async for c in app.astream({"my_key": ""}, config, subgraphs=True): chunks.append((round(perf_counter() - start, 1), c)) for idx in range(len(chunks)): 
elapsed, c = chunks[idx] chunks[idx] = (round(elapsed - chunks[0][0], 1), c) assert chunks == [ # arrives before "inner" finishes ( FloatBetween(0.0, 0.1), ( (AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}, ), ), (FloatBetween(0.2, 0.4), ((), {"outer_1": {"my_key": " and parallel"}})), ( FloatBetween(0.5, 0.7), ( (AnyStr("inner:"),), {"inner_2": {"my_key": " and there", "my_other_key": "got here"}}, ), ), (FloatBetween(0.5, 0.7), ((), {"inner": {"my_key": "got here and there"}})), (FloatBetween(0.5, 0.7), ((), {"outer_2": {"my_key": " and back again"}})), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_stream_buffering_single_node(checkpointer_name: str) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] async def node(state: State, writer: StreamWriter): writer("Before sleep") await asyncio.sleep(0.2) writer("After sleep") return {"my_key": "got here"} builder = StateGraph(State) builder.add_node("node", node) builder.add_edge(START, "node") builder.add_edge("node", END) async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) start = perf_counter() chunks: list[tuple[float, Any]] = [] config = {"configurable": {"thread_id": "2"}} async for c in graph.astream({"my_key": ""}, config, stream_mode="custom"): chunks.append((round(perf_counter() - start, 1), c)) assert chunks == [ (FloatBetween(0.0, 0.1), "Before sleep"), (FloatBetween(0.2, 0.3), "After sleep"), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_nested_graph_interrupts_parallel(checkpointer_name: str) -> None: class InnerState(TypedDict): my_key: Annotated[str, operator.add] my_other_key: str async def inner_1(state: InnerState): await asyncio.sleep(0.1) return {"my_key": "got here", "my_other_key": state["my_key"]} async def inner_2(state: InnerState): return { "my_key": " and there", "my_other_key": state["my_key"], } 
inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: Annotated[str, operator.add] async def outer_1(state: State): return {"my_key": " and parallel"} async def outer_2(state: State): return {"my_key": " and back again"} graph = StateGraph(State) graph.add_node( "inner", inner.compile(interrupt_before=["inner_2"]), ) graph.add_node("outer_1", outer_1) graph.add_node("outer_2", outer_2) graph.add_edge(START, "inner") graph.add_edge(START, "outer_1") graph.add_edge(["inner", "outer_1"], "outer_2") graph.set_finish_point("outer_2") async with awith_checkpointer(checkpointer_name) as checkpointer: app = graph.compile(checkpointer=checkpointer) # test invoke w/ nested interrupt config = {"configurable": {"thread_id": "1"}} assert await app.ainvoke({"my_key": ""}, config, debug=True) == { "my_key": " and parallel", } assert await app.ainvoke(None, config, debug=True) == { "my_key": "got here and there and parallel and back again", } # below combo of assertions is asserting two things # - outer_1 finishes before inner interrupts (because we see its output in stream, which only happens after node finishes) # - the writes of outer are persisted in 1st call and used in 2nd call, ie outer isn't called again (because we dont see outer_1 output again in 2nd stream) # test stream updates w/ nested interrupt config = {"configurable": {"thread_id": "2"}} assert [ c async for c in app.astream({"my_key": ""}, config, subgraphs=True) ] == [ # we got to parallel node first ((), {"outer_1": {"my_key": " and parallel"}}), ( (AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}, ), ((), {"__interrupt__": ()}), ] assert [c async for c in app.astream(None, config)] == [ {"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}}, {"inner": {"my_key": "got here and 
there"}}, {"outer_2": {"my_key": " and back again"}}, ] # test stream values w/ nested interrupt config = {"configurable": {"thread_id": "3"}} assert [ c async for c in app.astream({"my_key": ""}, config, stream_mode="values") ] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] # # test interrupts BEFORE the parallel node app = graph.compile(checkpointer=checkpointer, interrupt_before=["outer_1"]) config = {"configurable": {"thread_id": "4"}} assert [ c async for c in app.astream({"my_key": ""}, config, stream_mode="values") ] == [ {"my_key": ""}, ] # while we're waiting for the node w/ interrupt inside to finish assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] # test interrupts AFTER the parallel node app = graph.compile(checkpointer=checkpointer, interrupt_after=["outer_1"]) config = {"configurable": {"thread_id": "5"}} assert [ c async for c in app.astream({"my_key": ""}, config, stream_mode="values") ] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, ] assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_doubly_nested_graph_interrupts(checkpointer_name: str) -> None: class State(TypedDict): my_key: str class 
ChildState(TypedDict): my_key: str class GrandChildState(TypedDict): my_key: str async def grandchild_1(state: ChildState): return {"my_key": state["my_key"] + " here"} async def grandchild_2(state: ChildState): return { "my_key": state["my_key"] + " and there", } grandchild = StateGraph(GrandChildState) grandchild.add_node("grandchild_1", grandchild_1) grandchild.add_node("grandchild_2", grandchild_2) grandchild.add_edge("grandchild_1", "grandchild_2") grandchild.set_entry_point("grandchild_1") grandchild.set_finish_point("grandchild_2") child = StateGraph(ChildState) child.add_node( "child_1", grandchild.compile(interrupt_before=["grandchild_2"]), ) child.set_entry_point("child_1") child.set_finish_point("child_1") async def parent_1(state: State): return {"my_key": "hi " + state["my_key"]} async def parent_2(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("parent_1", parent_1) graph.add_node("child", child.compile()) graph.add_node("parent_2", parent_2) graph.set_entry_point("parent_1") graph.add_edge("parent_1", "child") graph.add_edge("child", "parent_2") graph.set_finish_point("parent_2") async with awith_checkpointer(checkpointer_name) as checkpointer: app = graph.compile(checkpointer=checkpointer) # test invoke w/ nested interrupt config = {"configurable": {"thread_id": "1"}} assert await app.ainvoke({"my_key": "my value"}, config, debug=True) == { "my_key": "hi my value", } assert await app.ainvoke(None, config, debug=True) == { "my_key": "hi my value here and there and back again", } # test stream updates w/ nested interrupt nodes: list[str] = [] config = { "configurable": {"thread_id": "2", CONFIG_KEY_NODE_FINISHED: nodes.append} } assert [c async for c in app.astream({"my_key": "my value"}, config)] == [ {"parent_1": {"my_key": "hi my value"}}, {"__interrupt__": ()}, ] assert nodes == ["parent_1", "grandchild_1"] assert [c async for c in app.astream(None, config)] == [ {"child": {"my_key": 
"hi my value here and there"}}, {"parent_2": {"my_key": "hi my value here and there and back again"}}, ] assert nodes == [ "parent_1", "grandchild_1", "grandchild_2", "child_1", "child", "parent_2", ] # test stream values w/ nested interrupt config = {"configurable": {"thread_id": "3"}} assert [ c async for c in app.astream( {"my_key": "my value"}, config, stream_mode="values" ) ] == [ {"my_key": "my value"}, {"my_key": "hi my value"}, ] assert [c async for c in app.astream(None, config, stream_mode="values")] == [ {"my_key": "hi my value"}, {"my_key": "hi my value here and there"}, {"my_key": "hi my value here and there and back again"}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_nested_graph_state(checkpointer_name: str) -> None: class InnerState(TypedDict): my_key: str my_other_key: str def inner_1(state: InnerState): return { "my_key": state["my_key"] + " here", "my_other_key": state["my_key"], } def inner_2(state: InnerState): return { "my_key": state["my_key"] + " and there", "my_other_key": state["my_key"], } inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: str other_parent_key: str def outer_1(state: State): return {"my_key": "hi " + state["my_key"]} def outer_2(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("outer_1", outer_1) graph.add_node( "inner", inner.compile(interrupt_before=["inner_2"]), ) graph.add_node("outer_2", outer_2) graph.set_entry_point("outer_1") graph.add_edge("outer_1", "inner") graph.add_edge("inner", "outer_2") graph.set_finish_point("outer_2") async with awith_checkpointer(checkpointer_name) as checkpointer: app = graph.compile(checkpointer=checkpointer) config = {"configurable": {"thread_id": "1"}} await app.ainvoke({"my_key": "my value"}, 
config, debug=True) # test state w/ nested subgraph state (right after interrupt) # first get_state without subgraph state assert await app.aget_state(config) == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={ "configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()} }, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # now, get_state with subgraphs state assert await app.aget_state(config, subgraphs=True) == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state=StateSnapshot( values={ "my_key": "hi my value here", "my_other_key": "hi my value", }, tasks=( PregelTask( AnyStr(), "inner_2", (PULL, "inner_2"), ), ), next=("inner_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "parents": { "": AnyStr(), }, "source": "loop", "writes": { "inner_1": { "my_key": "hi my value here", "my_other_key": "hi my value", } }, "step": 1, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, ), ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, 
metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # get_state_history returns outer graph checkpoints history = [c async for c in app.aget_state_history(config)] assert history == [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), } }, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "my value"}, tasks=( PregelTask( AnyStr(), "outer_1", (PULL, "outer_1"), result={"my_key": "hi my value"}, ), ), next=("outer_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={}, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "my value"}, ), ), next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"my_key": "my value"}}, "step": -1, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] # get_state_history for a subgraph returns its checkpoints child_history = [ c async for c in app.aget_state_history(history[0].tasks[0].state) ] 
assert child_history == [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": { "inner_1": { "my_key": "hi my value here", "my_other_key": "hi my value", } }, "step": 1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( values={"my_key": "hi my value"}, next=("inner_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": None, "step": 0, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=( PregelTask( AnyStr(), "inner_1", (PULL, "inner_1"), result={ "my_key": "hi my value here", "my_other_key": "hi my value", }, ), ), ), StateSnapshot( values={}, next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": 
AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "input", "writes": {"__start__": {"my_key": "hi my value"}}, "step": -1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "hi my value"}, ), ), ), ] # resume await app.ainvoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) assert await app.aget_state(config) == StateSnapshot( values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "outer_2": {"my_key": "hi my value here and there and back again"} }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # test full history at the end actual_history = [c async for c in app.aget_state_history(config)] expected_history = [ StateSnapshot( values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "outer_2": { "my_key": "hi my value here and there and back again" } }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "hi my value here and there"}, tasks=( PregelTask( AnyStr(), "outer_2", (PULL, "outer_2"), 
result={"my_key": "hi my value here and there and back again"}, ), ), next=("outer_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"inner": {"my_key": "hi my value here and there"}}, "step": 2, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), } }, result={"my_key": "hi my value here and there"}, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "my value"}, tasks=( PregelTask( AnyStr(), "outer_1", (PULL, "outer_1"), result={"my_key": "hi my value"}, ), ), next=("outer_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={}, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "my value"}, ), ), next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"my_key": "my value"}}, "step": -1, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] assert actual_history == 
expected_history

        # test looking up parent state by checkpoint ID
        for actual_snapshot, expected_snapshot in zip(actual_history, expected_history):
            assert await app.aget_state(actual_snapshot.config) == expected_snapshot


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_doubly_nested_graph_state(checkpointer_name: str) -> None:
    """Check state, history and replay for a parent -> child -> grandchild graph.

    The grandchild graph is compiled with ``interrupt_before=["grandchild_2"]``,
    so the first run pauses after ``grandchild_1``. The test inspects the paused
    state at every nesting level (with and without ``subgraphs=True``), resumes
    to completion, walks the checkpoint history of all three graphs, and finally
    replays from the grandchild's step-0 checkpoint, expecting the same pause.
    """

    class State(TypedDict):
        my_key: str

    class ChildState(TypedDict):
        my_key: str

    class GrandChildState(TypedDict):
        my_key: str

    # NOTE(review): these grandchild nodes are annotated with ChildState although
    # they are registered on the GrandChildState graph; both schemas declare only
    # `my_key: str`. Left as-is because StateGraph may inspect node annotations;
    # confirm before switching them to GrandChildState.
    def grandchild_1(state: ChildState):
        return {"my_key": state["my_key"] + " here"}

    def grandchild_2(state: ChildState):
        return {
            "my_key": state["my_key"] + " and there",
        }

    # innermost graph: grandchild_1 -> grandchild_2
    grandchild = StateGraph(GrandChildState)
    grandchild.add_node("grandchild_1", grandchild_1)
    grandchild.add_node("grandchild_2", grandchild_2)
    grandchild.add_edge("grandchild_1", "grandchild_2")
    grandchild.set_entry_point("grandchild_1")
    grandchild.set_finish_point("grandchild_2")

    # middle graph: a single node wrapping the grandchild graph, which is
    # compiled to interrupt before grandchild_2
    child = StateGraph(ChildState)
    child.add_node(
        "child_1",
        grandchild.compile(interrupt_before=["grandchild_2"]),
    )
    child.set_entry_point("child_1")
    child.set_finish_point("child_1")

    def parent_1(state: State):
        return {"my_key": "hi " + state["my_key"]}

    def parent_2(state: State):
        return {"my_key": state["my_key"] + " and back again"}

    # outer graph: parent_1 -> child -> parent_2
    graph = StateGraph(State)
    graph.add_node("parent_1", parent_1)
    graph.add_node("child", child.compile())
    graph.add_node("parent_2", parent_2)
    graph.set_entry_point("parent_1")
    graph.add_edge("parent_1", "child")
    graph.add_edge("child", "parent_2")
    graph.set_finish_point("parent_2")

    async with awith_checkpointer(checkpointer_name) as checkpointer:
        app = graph.compile(checkpointer=checkpointer)

        # test invoke w/ nested interrupt
        config = {"configurable": {"thread_id": "1"}}
        assert [
            c async for c in app.astream({"my_key": "my value"}, config, subgraphs=True)
        ] == [
            ((), {"parent_1": {"my_key": "hi my value"}}),
            (
                (AnyStr("child:"), AnyStr("child_1:")),
                {"grandchild_1": {"my_key": "hi my value here"}},
            ),
            ((), {"__interrupt__": ()}),
        ]

        # get state without subgraphs
        outer_state = await app.aget_state(config)
        assert outer_state == StateSnapshot(
            values={"my_key": "hi my value"},
            tasks=(
                PregelTask(
                    AnyStr(),
                    "child",
                    (PULL, "child"),
                    state={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("child"),
                        }
                    },
                ),
            ),
            next=("child",),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "parents": {},
                "source": "loop",
                "writes": {"parent_1": {"my_key": "hi my value"}},
                "step": 1,
                "thread_id": "1",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
        )
        # follow the state config exposed on the parent's pending "child" task
        # to read the child subgraph's state
        child_state = await app.aget_state(outer_state.tasks[0].state)
        # NOTE(review): only `tasks[0]` of the expected snapshot is compared; the
        # remaining fields of this expected StateSnapshot (config, metadata, ...)
        # are not asserted.
        assert (
            child_state.tasks[0]
            == StateSnapshot(
                values={"my_key": "hi my value"},
                tasks=(
                    PregelTask(
                        AnyStr(),
                        "child_1",
                        (PULL, "child_1"),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr(),
                            }
                        },
                    ),
                ),
                next=("child_1",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {"": AnyStr()},
                    "source": "loop",
                    "writes": None,
                    "step": 0,
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                    }
                },
            ).tasks[0]
        )
        # one level deeper: the child's "child_1" task exposes the grandchild's
        # state config.
        # NOTE(review): the `child1:` alternative in the checkpoint_map regex
        # below may have been meant as `child_1:`; matching presumably relies on
        # the `child:.+` branch. Confirm.
        grandchild_state = await app.aget_state(child_state.tasks[0].state)
        assert grandchild_state == StateSnapshot(
            values={"my_key": "hi my value here"},
            tasks=(
                PregelTask(
                    AnyStr(),
                    "grandchild_2",
                    (PULL, "grandchild_2"),
                ),
            ),
            next=("grandchild_2",),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr(),
                    "checkpoint_id": AnyStr(),
                    "checkpoint_map": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                            AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                        }
                    ),
                }
            },
            metadata={
                "parents": AnyDict(
                    {
                        "": AnyStr(),
                        AnyStr("child:"): AnyStr(),
                    }
                ),
                "source": "loop",
                "writes": {"grandchild_1": {"my_key": "hi my value here"}},
                "step": 1,
                "thread_id": "1",
                "checkpoint_ns": AnyStr("child:"),
                "langgraph_checkpoint_ns": AnyStr("child:"),
                "langgraph_node": "child_1",
                "langgraph_path": [PULL, AnyStr("child_1")],
                "langgraph_step": 1,
                "langgraph_triggers": [AnyStr("start:child_1")],
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr(),
                    "checkpoint_id": AnyStr(),
                    "checkpoint_map": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                            AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                        }
                    ),
                }
            },
        )
        # get state with subgraphs
        assert await app.aget_state(config, subgraphs=True) == StateSnapshot(
            values={"my_key": "hi my value"},
            tasks=(
                PregelTask(
                    AnyStr(),
                    "child",
                    (PULL, "child"),
                    state=StateSnapshot(
                        values={"my_key": "hi my value"},
                        tasks=(
                            PregelTask(
                                AnyStr(),
                                "child_1",
                                (PULL, "child_1"),
                                state=StateSnapshot(
                                    values={"my_key": "hi my value here"},
                                    tasks=(
                                        PregelTask(
                                            AnyStr(),
                                            "grandchild_2",
                                            (PULL, "grandchild_2"),
                                        ),
                                    ),
                                    next=("grandchild_2",),
                                    config={
                                        "configurable": {
                                            "thread_id": "1",
                                            "checkpoint_ns": AnyStr(),
                                            "checkpoint_id": AnyStr(),
                                            "checkpoint_map": AnyDict(
                                                {
                                                    "": AnyStr(),
                                                    AnyStr("child:"): AnyStr(),
                                                    AnyStr(
                                                        re.compile(r"child:.+|child1:")
                                                    ): AnyStr(),
                                                }
                                            ),
                                        }
                                    },
                                    metadata={
                                        "parents": AnyDict(
                                            {
                                                "": AnyStr(),
                                                AnyStr("child:"): AnyStr(),
                                            }
                                        ),
                                        "source": "loop",
                                        "writes": {
                                            "grandchild_1": {
                                                "my_key": "hi my value here"
                                            }
                                        },
                                        "step": 1,
                                        "thread_id": "1",
                                        "checkpoint_ns": AnyStr("child:"),
                                        "langgraph_checkpoint_ns": AnyStr("child:"),
                                        "langgraph_node": "child_1",
                                        "langgraph_path": [
                                            PULL,
                                            AnyStr("child_1"),
                                        ],
                                        "langgraph_step": 1,
                                        "langgraph_triggers": [AnyStr("start:child_1")],
                                    },
                                    created_at=AnyStr(),
                                    parent_config={
                                        "configurable": {
                                            "thread_id": "1",
                                            "checkpoint_ns": AnyStr(),
                                            "checkpoint_id": AnyStr(),
                                            "checkpoint_map": AnyDict(
                                                {
                                                    "": AnyStr(),
                                                    AnyStr("child:"): AnyStr(),
                                                    AnyStr(
                                                        re.compile(r"child:.+|child1:")
                                                    ): AnyStr(),
                                                }
                                            ),
                                        }
                                    },
                                ),
                            ),
                        ),
                        next=("child_1",),
                        config={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("child:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {"": AnyStr(), AnyStr("child:"): AnyStr()}
                                ),
                            }
                        },
                        metadata={
                            "parents": {"": AnyStr()},
                            "source": "loop",
                            "writes": None,
                            "step": 0,
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("child:"),
                            "langgraph_node": "child",
                            "langgraph_path": [PULL, AnyStr("child")],
                            "langgraph_step": 2,
                            "langgraph_triggers": [AnyStr("parent_1")],
                            "langgraph_checkpoint_ns": AnyStr("child:"),
                        },
                        created_at=AnyStr(),
                        parent_config={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("child:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {"": AnyStr(), AnyStr("child:"): AnyStr()}
                                ),
                            }
                        },
                    ),
                ),
            ),
            next=("child",),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "parents": {},
                "source": "loop",
                "writes": {"parent_1": {"my_key": "hi my value"}},
                "step": 1,
                "thread_id": "1",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
        )
        # resume
        assert [c async for c in app.astream(None, config, subgraphs=True)] == [
            (
                (AnyStr("child:"), AnyStr("child_1:")),
                {"grandchild_2": {"my_key": "hi my value here and there"}},
            ),
            (
                (AnyStr("child:"),),
                {"child_1": {"my_key": "hi my value here and there"}},
            ),
            ((), {"child": {"my_key": "hi my value here and there"}}),
            ((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),
        ]
        # get state with and without subgraphs
        # (once the run has finished there are no pending tasks, so both calls
        # are expected to return the same snapshot)
        assert (
            await app.aget_state(config)
            == await app.aget_state(config, subgraphs=True)
            == StateSnapshot(
                values={"my_key": "hi my value here and there and back again"},
                tasks=(),
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "writes": {
                        "parent_2": {
                            "my_key": "hi my value here and there and back again"
                        }
                    },
                    "step": 3,
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
            )
        )
        # get outer graph history
        outer_history = [c async for c in app.aget_state_history(config)]
        # NOTE(review): only the newest entry (`[0]`) is compared; the older
        # expected entries in this list are not asserted (e.g. the last one's
        # "writes" is `{"my_key": "my value"}` rather than the
        # `{"__start__": {...}}` form used by the child/grandchild histories).
        assert (
            outer_history[0]
            == [
                StateSnapshot(
                    values={"my_key": "hi my value here and there and back again"},
                    tasks=(),
                    next=(),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "writes": {
                            "parent_2": {
                                "my_key": "hi my value here and there and back again"
                            }
                        },
                        "step": 3,
                        "thread_id": "1",
                    },
                    created_at=AnyStr(),
                    parent_config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                ),
                StateSnapshot(
                    values={"my_key": "hi my value here and there"},
                    next=("parent_2",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "writes": {"child": {"my_key": "hi my value here and there"}},
                        "step": 2,
                        "thread_id": "1",
                    },
                    created_at=AnyStr(),
                    parent_config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    tasks=(
                        PregelTask(
                            id=AnyStr(), name="parent_2", path=(PULL, "parent_2")
                        ),
                    ),
                ),
                StateSnapshot(
                    values={"my_key": "hi my value"},
                    tasks=(
                        PregelTask(
                            AnyStr(),
                            "child",
                            (PULL, "child"),
                            state={
                                "configurable": {
                                    "thread_id": "1",
                                    "checkpoint_ns": AnyStr("child"),
                                }
                            },
                        ),
                    ),
                    next=("child",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "writes": {"parent_1": {"my_key": "hi my value"}},
                        "step": 1,
                        "thread_id": "1",
                    },
                    created_at=AnyStr(),
                    parent_config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                ),
                StateSnapshot(
                    values={"my_key": "my value"},
                    next=("parent_1",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "writes": None,
                        "step": 0,
                        "thread_id": "1",
                    },
                    created_at=AnyStr(),
                    parent_config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    tasks=(
                        PregelTask(
                            id=AnyStr(), name="parent_1", path=(PULL, "parent_1")
                        ),
                    ),
                ),
                StateSnapshot(
                    values={},
                    next=("__start__",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "input",
                        "writes": {"my_key": "my value"},
                        "step": -1,
                        "thread_id": "1",
                    },
                    created_at=AnyStr(),
                    parent_config=None,
                    tasks=(
                        PregelTask(
                            id=AnyStr(), name="__start__", path=(PULL, "__start__")
                        ),
                    ),
                ),
            ][0]
        )
        # get child graph history
        # outer_history[2] is the parent's step-1 checkpoint; its pending "child"
        # task carries the child subgraph's state config.
        child_history = [
            c async for c in app.aget_state_history(outer_history[2].tasks[0].state)
        ]
        assert child_history == [
            StateSnapshot(
                values={"my_key": "hi my value here and there"},
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {"": AnyStr(), AnyStr("child:"): AnyStr()}
                        ),
                    }
                },
                metadata={
                    "source": "loop",
                    "writes": {"child_1": {"my_key": "hi my value here and there"}},
                    "step": 1,
                    "parents": {"": AnyStr()},
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child",
                    "langgraph_path": [PULL, AnyStr("child")],
                    "langgraph_step": 2,
                    "langgraph_triggers": [AnyStr("parent_1")],
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {"": AnyStr(), AnyStr("child:"): AnyStr()}
                        ),
                    }
                },
                tasks=(),
            ),
            StateSnapshot(
                values={"my_key": "hi my value"},
                next=("child_1",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {"": AnyStr(), AnyStr("child:"): AnyStr()}
                        ),
                    }
                },
                metadata={
                    "source": "loop",
                    "writes": None,
                    "step": 0,
                    "parents": {"": AnyStr()},
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child",
                    "langgraph_path": [PULL, AnyStr("child")],
                    "langgraph_step": 2,
                    "langgraph_triggers": [AnyStr("parent_1")],
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {"": AnyStr(), AnyStr("child:"): AnyStr()}
                        ),
                    }
                },
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="child_1",
                        path=(PULL, "child_1"),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("child:"),
                            }
                        },
                        result={"my_key": "hi my value here and there"},
                    ),
                ),
            ),
            StateSnapshot(
                values={},
                next=("__start__",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {"": AnyStr(), AnyStr("child:"): AnyStr()}
                        ),
                    }
                },
                metadata={
                    "source": "input",
                    "writes": {"__start__": {"my_key": "hi my value"}},
                    "step": -1,
                    "parents": {"": AnyStr()},
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child",
                    "langgraph_path": [PULL, AnyStr("child")],
                    "langgraph_step": 2,
                    "langgraph_triggers": [AnyStr("parent_1")],
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                },
                created_at=AnyStr(),
                parent_config=None,
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="__start__",
                        path=(PULL, "__start__"),
                        result={"my_key": "hi my value"},
                    ),
                ),
            ),
        ]
        # get grandchild graph history
        # child_history[1] is the child's step-0 checkpoint; its "child_1" task
        # carries the grandchild subgraph's state config.
        grandchild_history = [
            c async for c in app.aget_state_history(child_history[1].tasks[0].state)
        ]
        assert grandchild_history == [
            StateSnapshot(
                values={"my_key": "hi my value here and there"},
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "source": "loop",
                    "writes": {
                        "grandchild_2": {"my_key": "hi my value here and there"}
                    },
                    "step": 2,
                    "parents": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                        }
                    ),
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child_1",
                    "langgraph_path": [
                        PULL,
                        AnyStr("child_1"),
                    ],
                    "langgraph_step": 1,
                    "langgraph_triggers": [AnyStr("start:child_1")],
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                tasks=(),
            ),
            StateSnapshot(
                values={"my_key": "hi my value here"},
                next=("grandchild_2",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "source": "loop",
                    "writes": {"grandchild_1": {"my_key": "hi my value here"}},
                    "step": 1,
                    "parents": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                        }
                    ),
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child_1",
                    "langgraph_path": [
                        PULL,
                        AnyStr("child_1"),
                    ],
                    "langgraph_step": 1,
                    "langgraph_triggers": [AnyStr("start:child_1")],
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="grandchild_2",
                        path=(PULL, "grandchild_2"),
                        result={"my_key": "hi my value here and there"},
                    ),
                ),
            ),
            StateSnapshot(
                values={"my_key": "hi my value"},
                next=("grandchild_1",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "source": "loop",
                    "writes": None,
                    "step": 0,
                    "parents": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                        }
                    ),
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child_1",
                    "langgraph_path": [
                        PULL,
                        AnyStr("child_1"),
                    ],
                    "langgraph_step": 1,
                    "langgraph_triggers": [AnyStr("start:child_1")],
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="grandchild_1",
                        path=(PULL, "grandchild_1"),
                        result={"my_key": "hi my value here"},
                    ),
                ),
            ),
            StateSnapshot(
                values={},
                next=("__start__",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "source": "input",
                    "writes": {"__start__": {"my_key": "hi my value"}},
                    "step": -1,
                    "parents": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                        }
                    ),
                    "thread_id": "1",
                    "checkpoint_ns": AnyStr("child:"),
                    "langgraph_checkpoint_ns": AnyStr("child:"),
                    "langgraph_node": "child_1",
                    "langgraph_path": [
                        PULL,
                        AnyStr("child_1"),
                    ],
                    "langgraph_step": 1,
                    "langgraph_triggers": [AnyStr("start:child_1")],
                },
                created_at=AnyStr(),
                parent_config=None,
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="__start__",
                        path=(PULL, "__start__"),
                        result={"my_key": "hi my value"},
                    ),
                ),
            ),
        ]
        # replay grandchild checkpoint
        # grandchild_history[2] is the grandchild's step-0 checkpoint
        # (next=("grandchild_1",)); replaying from it re-runs grandchild_1 and
        # is expected to stop again before grandchild_2.
        assert [
            c
            async for c in app.astream(
                None, grandchild_history[2].config, subgraphs=True
            )
        ] == [
            (
                (AnyStr("child:"), AnyStr("child_1:")),
                {"grandchild_1": {"my_key": "hi my value here"}},
            ),
            ((), {"__interrupt__": ()}),
        ]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_to_nested_graphs(checkpointer_name: str) -> None:
    """Check ``Send`` fan-out into an interrupting subgraph, then a per-task update.

    ``continue_to_jokes`` emits one ``Send("generate_joke", ...)`` per subject;
    every subgraph invocation pauses before ``generate``. The test updates the
    paused state of the "dogs" invocation to "turtles - hohoho", resumes, and
    checks the final values, final snapshot and checkpoint history. When
    ``FF_SEND_V2`` is off, only the update/resume outcome is checked and the
    test returns early.
    """

    class OverallState(TypedDict):
        subjects: list[str]
        # operator.add reducer: jokes from parallel subgraph runs are concatenated
        jokes: Annotated[list[str], operator.add]

    async def continue_to_jokes(state: OverallState):
        return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]

    class JokeState(TypedDict):
        subject: str

    async def edit(state: JokeState):
        subject = state["subject"]
        return {"subject": f"{subject} - hohoho"}

    # subgraph
    subgraph = StateGraph(JokeState, output=OverallState)
    subgraph.add_node("edit", edit)
    subgraph.add_node(
        "generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]}
    )
    subgraph.set_entry_point("edit")
    subgraph.add_edge("edit", "generate")
    subgraph.set_finish_point("generate")

    # parent graph
    builder = StateGraph(OverallState)
    builder.add_node(
        "generate_joke",
        subgraph.compile(interrupt_before=["generate"]),
    )
    builder.add_conditional_edges(START, continue_to_jokes)
    builder.add_edge("generate_joke", END)

    async with awith_checkpointer(checkpointer_name) as checkpointer:
        graph = builder.compile(checkpointer=checkpointer)
        config = {"configurable": {"thread_id": "1"}}
        tracer = FakeTracer()

        # invoke and pause at nested interrupt
        assert await graph.ainvoke(
            {"subjects": ["cats", "dogs"]},
            config={**config, "callbacks": [tracer]},
        ) == {
            "subjects": ["cats", "dogs"],
            "jokes": [],
        }
        assert len(tracer.runs) == 1, "Should produce exactly 1 root run"

        # check state
        outer_state = await graph.aget_state(config)

        if not FF_SEND_V2:
            # Without FF_SEND_V2 the "dogs" task sits at index 1; with it, a
            # "__start__" task occupies index 0 (see the snapshot asserted
            # below), shifting the "dogs" task to index 2.
            # update state of dogs joke graph
            await graph.aupdate_state(
                outer_state.tasks[1].state, {"subject": "turtles - hohoho"}
            )

            # continue past interrupt
            assert await graph.ainvoke(None, config=config) == {
                "subjects": ["cats", "dogs"],
                "jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
            }
            return

        assert outer_state == StateSnapshot(
            values={"subjects": ["cats", "dogs"], "jokes": []},
            tasks=(
                PregelTask(
                    id=AnyStr(),
                    name="__start__",
                    path=("__pregel_pull", "__start__"),
                    error=None,
                    interrupts=(),
                    state=None,
                    result={"subjects": ["cats", "dogs"]},
                ),
                PregelTask(
                    AnyStr(),
                    "generate_joke",
                    (PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
                    state={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("generate_joke:"),
                        }
                    },
                ),
                PregelTask(
                    AnyStr(),
                    "generate_joke",
                    (PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
                    state={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("generate_joke:"),
                        }
                    },
                ),
            ),
            next=("generate_joke", "generate_joke"),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "parents": {},
                "source": "input",
                "writes": {
                    "__start__": {
                        "subjects": [
                            "cats",
                            "dogs",
                        ],
                    }
                },
                "step": -1,
                "thread_id": "1",
            },
            created_at=AnyStr(),
            parent_config=None,
        )

        # update state of dogs joke graph
        await graph.aupdate_state(
            outer_state.tasks[2].state, {"subject": "turtles - hohoho"}
        )

        # continue past interrupt
        assert await graph.ainvoke(None, config=config) == {
            "subjects": ["cats", "dogs"],
            "jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
        }

        actual_snapshot = await graph.aget_state(config)
        expected_snapshot = StateSnapshot(
            values={
                "subjects": ["cats", "dogs"],
                "jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
            },
            tasks=(),
            next=(),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "parents": {},
                "source": "loop",
                "writes": {
                    "generate_joke": [
                        {"jokes": ["Joke about cats - hohoho"]},
                        {"jokes": ["Joke about turtles - hohoho"]},
                    ]
                },
                "step": 0,
                "thread_id": "1",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
        )
        assert actual_snapshot == expected_snapshot

        # test full history
        actual_history = [c async for c in graph.aget_state_history(config)]
        expected_history = [
            StateSnapshot(
                values={
                    "subjects": ["cats", "dogs"],
                    "jokes": [
                        "Joke about cats - hohoho",
                        "Joke about turtles - hohoho",
                    ],
                },
                tasks=(),
                next=(),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "writes": {
                        "generate_joke": [
                            {"jokes": ["Joke about cats - hohoho"]},
                            {"jokes": ["Joke about turtles - hohoho"]},
                        ]
                    },
                    "step": 0,
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
            ),
            StateSnapshot(
                values={"jokes": []},
                next=("__start__", "generate_joke", "generate_joke"),
                tasks=(
                    PregelTask(
                        id=AnyStr(),
                        name="__start__",
                        path=("__pregel_pull", "__start__"),
                        error=None,
                        interrupts=(),
                        state=None,
                        result={"subjects": ["cats", "dogs"]},
                    ),
                    PregelTask(
                        AnyStr(),
                        "generate_joke",
                        (PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("generate_joke:"),
                            }
                        },
                        result={"jokes": ["Joke about cats - hohoho"]},
                    ),
                    PregelTask(
                        AnyStr(),
                        "generate_joke",
                        (PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("generate_joke:"),
                            }
                        },
                        result={"jokes": ["Joke about turtles - hohoho"]},
                    ),
                ),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "input",
                    "writes": {"__start__": {"subjects": ["cats", "dogs"]}},
                    "step": -1,
                    "thread_id": "1",
                },
                created_at=AnyStr(),
                parent_config=None,
            ),
        ]
        assert actual_history == expected_history


@pytest.mark.skipif(
    sys.version_info < (3, 11),
    reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_weather_subgraph(
    checkpointer_name: str, snapshot: SnapshotAssertion
) -> None:
    """Check a subgraph invoked synchronously from inside a parent graph node.

    The ``weather_graph`` node calls ``subgraph.invoke`` (sync) while the parent
    runs asynchronously, and the subgraph pauses before ``weather_node``.
    The test covers the following:

    - ``StreamWriter`` output via ``stream_mode="custom"``;
    - inspecting and updating the paused subgraph through its task state config;
    - reading history through the sync API from a worker thread;
    - updating the subgraph while acting as ``weather_node``.
    """
    from langchain_core.language_models.fake_chat_models import (
        FakeMessagesListChatModel,
    )
    from langchain_core.messages import AIMessage, ToolCall
    from langchain_core.tools import tool

    from langgraph.graph import MessagesState

    # setup subgraph

    # NOTE(review): the tool docstring below spans a line break ("Get the
    # weather" / "for a specific city"); kept verbatim since it is runtime text
    # (the tool description).
    @tool
    def get_weather(city: str):
        """Get the weather
for a specific city"""
        return f"I'ts sunny in {city}!"

    # fake model that always answers with a single get_weather tool call
    weather_model = FakeMessagesListChatModel(
        responses=[
            AIMessage(
                content="",
                tool_calls=[
                    ToolCall(
                        id="tool_call123",
                        name="get_weather",
                        args={"city": "San Francisco"},
                    )
                ],
            )
        ]
    )

    class SubGraphState(MessagesState):
        city: str

    def model_node(state: SubGraphState, writer: StreamWriter):
        writer(" very")
        result = weather_model.invoke(state["messages"])
        return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}

    def weather_node(state: SubGraphState, writer: StreamWriter):
        writer(" good")
        result = get_weather.invoke({"city": state["city"]})
        return {"messages": [{"role": "assistant", "content": result}]}

    subgraph = StateGraph(SubGraphState)
    subgraph.add_node(model_node)
    subgraph.add_node(weather_node)
    subgraph.add_edge(START, "model_node")
    subgraph.add_edge("model_node", "weather_node")
    subgraph.add_edge("weather_node", END)
    subgraph = subgraph.compile(interrupt_before=["weather_node"])

    # setup main graph
    class RouterState(MessagesState):
        route: Literal["weather", "other"]

    # NOTE(review): Router is not referenced anywhere else in this test.
    class Router(TypedDict):
        route: Literal["weather", "other"]

    # fake router model whose tool call always routes to "weather"
    router_model = FakeMessagesListChatModel(
        responses=[
            AIMessage(
                content="",
                tool_calls=[
                    ToolCall(
                        id="tool_call123",
                        name="router",
                        args={"dest": "weather"},
                    )
                ],
            )
        ]
    )

    def router_node(state: RouterState, writer: StreamWriter):
        writer("I'm")
        system_message = "Classify the incoming query as either about weather or not."
        messages = [{"role": "system", "content": system_message}] + state["messages"]
        route = router_model.invoke(messages)
        return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]}

    def normal_llm_node(state: RouterState):
        return {"messages": [AIMessage("Hello!")]}

    def route_after_prediction(state: RouterState):
        if state["route"] == "weather":
            return "weather_graph"
        else:
            return "normal_llm_node"

    def weather_graph(state: RouterState):
        # this tests that all async checkpointers tested also implement sync methods
        # as the subgraph called with sync invoke will use sync checkpointer methods
        return subgraph.invoke(state)

    graph = StateGraph(RouterState)
    graph.add_node(router_node)
    graph.add_node(normal_llm_node)
    graph.add_node("weather_graph", weather_graph)
    graph.add_edge(START, "router_node")
    graph.add_conditional_edges("router_node", route_after_prediction)
    graph.add_edge("normal_llm_node", END)
    graph.add_edge("weather_graph", END)

    # Closure over `graph` and `config`: both names are rebound below (to the
    # compiled graph and the thread-1 config) before this helper is called.
    def get_first_in_list():
        return [*graph.get_state_history(config, limit=1)][0]

    async with awith_checkpointer(checkpointer_name) as checkpointer:
        graph = graph.compile(checkpointer=checkpointer)

        assert graph.get_graph(xray=1).draw_mermaid() == snapshot

        config = {"configurable": {"thread_id": "1"}}
        thread2 = {"configurable": {"thread_id": "2"}}
        inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}

        # run with custom output
        # (first run stops at the subgraph interrupt; resuming emits the rest)
        assert [
            c async for c in graph.astream(inputs, thread2, stream_mode="custom")
        ] == [
            "I'm",
            " very",
        ]
        assert [
            c async for c in graph.astream(None, thread2, stream_mode="custom")
        ] == [
            " good",
        ]

        # run until interrupt
        assert [
            c
            async for c in graph.astream(
                inputs, config=config, stream_mode="updates", subgraphs=True
            )
        ] == [
            ((), {"router_node": {"route": "weather"}}),
            ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
            ((), {"__interrupt__": ()}),
        ]

        # check current state
        state = await graph.aget_state(config)
        assert state == StateSnapshot(
            values={
                "messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
                "route": "weather",
            },
            next=("weather_graph",),
            config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "source": "loop",
                "writes": {"router_node": {"route": "weather"}},
                "step": 1,
                "parents": {},
                "thread_id": "1",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "1",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            tasks=(
                PregelTask(
                    id=AnyStr(),
                    name="weather_graph",
                    path=(PULL, "weather_graph"),
                    state={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("weather_graph:"),
                        }
                    },
                ),
            ),
        )
        # confirm that list() delegates to alist() correctly
        # (the sync history call runs in a worker thread via asyncio.to_thread)
        assert await asyncio.to_thread(get_first_in_list) == state

        # update
        # (writes "la" into the paused subgraph through its task state config)
        await graph.aupdate_state(state.tasks[0].state, {"city": "la"})

        # run after update
        assert [
            c
            async for c in graph.astream(
                None, config=config, stream_mode="updates", subgraphs=True
            )
        ] == [
            (
                (AnyStr("weather_graph:"),),
                {
                    "weather_node": {
                        "messages": [
                            {"role": "assistant", "content": "I'ts sunny in la!"}
                        ]
                    }
                },
            ),
            (
                (),
                {
                    "weather_graph": {
                        "messages": [
                            _AnyIdHumanMessage(content="what's the weather in sf"),
                            _AnyIdAIMessage(content="I'ts sunny in la!"),
                        ]
                    }
                },
            ),
        ]

        # try updating acting as weather node
        config = {"configurable": {"thread_id": "14"}}
        inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
        assert [
            c
            async for c in graph.astream(
                inputs, config=config, stream_mode="updates", subgraphs=True
            )
        ] == [
            ((), {"router_node": {"route": "weather"}}),
            ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
            ((), {"__interrupt__": ()}),
        ]
        state = await graph.aget_state(config, subgraphs=True)
        assert state == StateSnapshot(
            values={
                "messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
                "route": "weather",
            },
            next=("weather_graph",),
            config={
                "configurable": {
                    "thread_id": "14",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "source": "loop",
                "writes": {"router_node": {"route": "weather"}},
                "step": 1,
                "parents": {},
                "thread_id": "14",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "14",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            tasks=(
                PregelTask(
                    id=AnyStr(),
                    name="weather_graph",
                    path=(PULL, "weather_graph"),
                    state=StateSnapshot(
                        values={
                            "messages": [
                                _AnyIdHumanMessage(content="what's the weather in sf")
                            ],
                            "city": "San Francisco",
                        },
                        next=("weather_node",),
                        config={
                            "configurable": {
                                "thread_id": "14",
                                "checkpoint_ns": AnyStr("weather_graph:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {
                                        "": AnyStr(),
                                        AnyStr("weather_graph:"): AnyStr(),
                                    }
                                ),
                            }
                        },
                        metadata={
                            "source": "loop",
                            "writes": {"model_node": {"city": "San Francisco"}},
                            "step": 1,
                            "parents": {"": AnyStr()},
                            "thread_id": "14",
                            "checkpoint_ns": AnyStr("weather_graph:"),
                            "langgraph_node": "weather_graph",
                            "langgraph_path": [PULL, "weather_graph"],
                            "langgraph_step": 2,
                            "langgraph_triggers": [
                                "branch:router_node:route_after_prediction:weather_graph"
                            ],
                            "langgraph_checkpoint_ns": AnyStr("weather_graph:"),
                        },
                        created_at=AnyStr(),
                        parent_config={
                            "configurable": {
                                "thread_id": "14",
                                "checkpoint_ns": AnyStr("weather_graph:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {
                                        "": AnyStr(),
                                        AnyStr("weather_graph:"): AnyStr(),
                                    }
                                ),
                            }
                        },
                        tasks=(
                            PregelTask(
                                id=AnyStr(),
                                name="weather_node",
                                path=(PULL, "weather_node"),
                            ),
                        ),
                    ),
                ),
            ),
        )
        # Write the answer into the subgraph as if weather_node had produced it;
        # afterwards the subgraph snapshot has no remaining `next` nodes.
        await graph.aupdate_state(
            state.tasks[0].state.config,
            {"messages": [{"role": "assistant", "content": "rainy"}]},
            as_node="weather_node",
        )
        state = await graph.aget_state(config, subgraphs=True)
        assert state == StateSnapshot(
            values={
                "messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
                "route": "weather",
            },
            next=("weather_graph",),
            config={
                "configurable": {
                    "thread_id": "14",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            metadata={
                "source": "loop",
                "writes": {"router_node": {"route": "weather"}},
                "step": 1,
                "parents": {},
                "thread_id": "14",
            },
            created_at=AnyStr(),
            parent_config={
                "configurable": {
                    "thread_id": "14",
                    "checkpoint_ns": "",
                    "checkpoint_id": AnyStr(),
                }
            },
            tasks=(
                PregelTask(
                    id=AnyStr(),
                    name="weather_graph",
                    path=(PULL, "weather_graph"),
                    state=StateSnapshot(
                        values={
                            "messages": [
                                _AnyIdHumanMessage(content="what's the weather in sf"),
                                _AnyIdAIMessage(content="rainy"),
                            ],
                            "city": "San Francisco",
                        },
                        next=(),
                        config={
                            "configurable": {
                                "thread_id": "14",
                                "checkpoint_ns": AnyStr("weather_graph:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {
                                        "": AnyStr(),
                                        AnyStr("weather_graph:"): AnyStr(),
                                    }
                                ),
                            }
                        },
                        metadata={
                            "step": 2,
                            "source": "update",
                            "writes": {
                                "weather_node": {
                                    "messages": [
                                        {"role": "assistant", "content": "rainy"}
                                    ]
                                }
                            },
                            "parents": {"": AnyStr()},
                            "thread_id": "14",
                            "checkpoint_id": AnyStr(),
                            "checkpoint_ns": AnyStr("weather_graph:"),
                            "langgraph_node": "weather_graph",
                            "langgraph_path": [PULL, "weather_graph"],
                            "langgraph_step": 2,
                            "langgraph_triggers": [
                                "branch:router_node:route_after_prediction:weather_graph"
                            ],
                            "langgraph_checkpoint_ns": AnyStr("weather_graph:"),
                        },
                        created_at=AnyStr(),
                        parent_config={
                            "configurable": {
                                "thread_id": "14",
                                "checkpoint_ns": AnyStr("weather_graph:"),
                                "checkpoint_id": AnyStr(),
                                "checkpoint_map": AnyDict(
                                    {
                                        "": AnyStr(),
                                        AnyStr("weather_graph:"): AnyStr(),
                                    }
                                ),
                            }
                        },
                        tasks=(),
                    ),
                ),
            ),
        )
        # resuming the parent now only emits the finished subgraph's output
        assert [
            c
            async for c in graph.astream(
                None, config=config, stream_mode="updates", subgraphs=True
            )
        ] == [
            (
                (),
                {
                    "weather_graph": {
                        "messages": [
                            _AnyIdHumanMessage(content="what's the weather in sf"),
                            _AnyIdAIMessage(content="rainy"),
                        ]
                    }
                },
            ),
        ]


async def test_checkpoint_metadata() -> None:
    """This test verifies that a run's configurable fields are merged with the previous checkpoint config for each step in the run.
""" # set up test from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, AnyMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool # graph state class BaseState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] # initialize graph nodes @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] prompt = ChatPromptTemplate.from_messages( [ ("system", "You are a nice assistant."), ("placeholder", "{messages}"), ] ) model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), AIMessage(content="answer"), ] ) def agent(state: BaseState, config: RunnableConfig) -> BaseState: formatted = prompt.invoke(state) response = model.invoke(formatted) return {"messages": response} def should_continue(data: BaseState) -> str: # Logic to decide whether to continue in the loop or exit if not data["messages"][-1].tool_calls: return "exit" else: return "continue" # define graphs w/ and w/o interrupt workflow = StateGraph(BaseState) workflow.add_node("agent", agent) workflow.add_node("tools", ToolNode(tools)) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) workflow.add_edge("tools", "agent") # graph w/o interrupt checkpointer_1 = MemorySaverAssertCheckpointMetadata() app = workflow.compile(checkpointer=checkpointer_1) # graph w/ interrupt checkpointer_2 = MemorySaverAssertCheckpointMetadata() app_w_interrupt = workflow.compile( checkpointer=checkpointer_2, interrupt_before=["tools"] ) # assertions # invoke graph w/o interrupt await app.ainvoke( {"messages": ["what is weather in sf"]}, { "configurable": { "thread_id": "1", "test_config_1": "foo", "test_config_2": "bar", }, }, ) config = 
{"configurable": {"thread_id": "1"}} # assert that checkpoint metadata contains the run's configurable fields chkpnt_metadata_1 = (await checkpointer_1.aget_tuple(config)).metadata assert chkpnt_metadata_1["thread_id"] == "1" assert chkpnt_metadata_1["test_config_1"] == "foo" assert chkpnt_metadata_1["test_config_2"] == "bar" # Verify that all checkpoint metadata have the expected keys. This check # is needed because a run may have an arbitrary number of steps depending # on how the graph is constructed. chkpnt_tuples_1 = checkpointer_1.alist(config) async for chkpnt_tuple in chkpnt_tuples_1: assert chkpnt_tuple.metadata["thread_id"] == "1" assert chkpnt_tuple.metadata["test_config_1"] == "foo" assert chkpnt_tuple.metadata["test_config_2"] == "bar" # invoke graph, but interrupt before tool call await app_w_interrupt.ainvoke( {"messages": ["what is weather in sf"]}, { "configurable": { "thread_id": "2", "test_config_3": "foo", "test_config_4": "bar", }, }, ) config = {"configurable": {"thread_id": "2"}} # assert that checkpoint metadata contains the run's configurable fields chkpnt_metadata_2 = (await checkpointer_2.aget_tuple(config)).metadata assert chkpnt_metadata_2["thread_id"] == "2" assert chkpnt_metadata_2["test_config_3"] == "foo" assert chkpnt_metadata_2["test_config_4"] == "bar" # resume graph execution await app_w_interrupt.ainvoke( input=None, config={ "configurable": { "thread_id": "2", "test_config_3": "foo", "test_config_4": "bar", } }, ) # assert that checkpoint metadata contains the run's configurable fields chkpnt_metadata_3 = (await checkpointer_2.aget_tuple(config)).metadata assert chkpnt_metadata_3["thread_id"] == "2" assert chkpnt_metadata_3["test_config_3"] == "foo" assert chkpnt_metadata_3["test_config_4"] == "bar" # Verify that all checkpoint metadata have the expected keys. This check # is needed because a run may have an arbitrary number of steps depending # on how the graph is constructed. 
chkpnt_tuples_2 = checkpointer_2.alist(config) async for chkpnt_tuple in chkpnt_tuples_2: assert chkpnt_tuple.metadata["thread_id"] == "2" assert chkpnt_tuple.metadata["test_config_3"] == "foo" assert chkpnt_tuple.metadata["test_config_4"] == "bar" async def test_checkpointer_null_pending_writes() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) def __call__(self, state): return [self.name] builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_edge(START, "1") graph = builder.compile(checkpointer=MemorySaverNoPending()) assert graph.invoke([], {"configurable": {"thread_id": "foo"}}) == ["1"] assert graph.invoke([], {"configurable": {"thread_id": "foo"}}) == ["1"] * 2 assert (await graph.ainvoke([], {"configurable": {"thread_id": "foo"}})) == [ "1" ] * 3 assert (await graph.ainvoke([], {"configurable": {"thread_id": "foo"}})) == [ "1" ] * 4 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) @pytest.mark.parametrize("store_name", ALL_STORES_ASYNC) async def test_store_injected_async(checkpointer_name: str, store_name: str) -> None: class State(TypedDict): count: Annotated[int, operator.add] doc_id = str(uuid.uuid4()) doc = {"some-key": "this-is-a-val"} uid = uuid.uuid4().hex namespace = (f"foo-{uid}", "bar") thread_1 = str(uuid.uuid4()) thread_2 = str(uuid.uuid4()) class Node: def __init__(self, i: Optional[int] = None): self.i = i async def __call__( self, inputs: State, config: RunnableConfig, store: BaseStore ): assert isinstance(store, BaseStore) await store.aput( namespace if self.i is not None and config["configurable"]["thread_id"] in (thread_1, thread_2) else (f"foo_{self.i}", "bar"), doc_id, { **doc, "from_thread": config["configurable"]["thread_id"], "some_val": inputs["count"], }, ) return {"count": 1} builder = StateGraph(State) builder.add_node("node", Node()) builder.add_edge("__start__", "node") N = 500 M = 1 if "duckdb" in store_name: 
logger.warning( "DuckDB store implementation has a known issue that does not" " support concurrent writes, so we're reducing the test scope" ) N = M = 1 for i in range(N): builder.add_node(f"node_{i}", Node(i)) builder.add_edge("__start__", f"node_{i}") async with awith_checkpointer(checkpointer_name) as checkpointer, awith_store( store_name ) as the_store: graph = builder.compile(store=the_store, checkpointer=checkpointer) # Test batch operations with multiple threads results = await graph.abatch( [{"count": 0}] * M, ([{"configurable": {"thread_id": str(uuid.uuid4())}}] * (M - 1)) + [{"configurable": {"thread_id": thread_1}}], ) result = results[-1] assert result == {"count": N + 1} returned_doc = (await the_store.aget(namespace, doc_id)).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} assert len((await the_store.asearch(namespace))) == 1 # Check results after another turn of the same thread result = await graph.ainvoke( {"count": 0}, {"configurable": {"thread_id": thread_1}} ) assert result == {"count": (N + 1) * 2} returned_doc = (await the_store.aget(namespace, doc_id)).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": N + 1} assert len((await the_store.asearch(namespace))) == 1 # Test with a different thread result = await graph.ainvoke( {"count": 0}, {"configurable": {"thread_id": thread_2}} ) assert result == {"count": N + 1} returned_doc = (await the_store.aget(namespace, doc_id)).value assert returned_doc == { **doc, "from_thread": thread_2, "some_val": 0, } # Overwrites the whole doc assert ( len((await the_store.asearch(namespace))) == 1 ) # still overwriting the same one async def test_debug_retry(): class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): async def _node(state: State): return {"messages": [f"entered {name} node"]} return _node builder = StateGraph(State) builder.add_node("one", node("one")) builder.add_node("two", node("two")) builder.add_edge(START, 
"one") builder.add_edge("one", "two") builder.add_edge("two", END) saver = MemorySaver() graph = builder.compile(checkpointer=saver) config = {"configurable": {"thread_id": "1"}} await graph.ainvoke({"messages": []}, config=config) # re-run step: 1 async for c in saver.alist(config): if c.metadata["step"] == 1: target_config = c.parent_config break assert target_config is not None update_config = await graph.aupdate_state(target_config, values=None) events = [ c async for c in graph.astream(None, config=update_config, stream_mode="debug") ] checkpoint_events = list( reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) ) checkpoint_history = { c.config["configurable"]["checkpoint_id"]: c async for c in graph.aget_state_history(config) } def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None return config["configurable"] for stream in checkpoint_events: stream_conf = lax_normalize_config(stream["config"]) stream_parent_conf = lax_normalize_config(stream["parent_config"]) assert stream_conf != stream_parent_conf # ensure the streamed checkpoint == checkpoint from checkpointer.list() history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]] history_conf = lax_normalize_config(history.config) assert stream_conf == history_conf history_parent_conf = lax_normalize_config(history.parent_config) assert stream_parent_conf == history_parent_conf async def test_debug_subgraphs(): class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): async def _node(state: State): return {"messages": [f"entered {name} node"]} return _node parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge(START, "c_one") child.add_edge("c_one", "c_two") child.add_edge("c_two", END) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge(START, "p_one") 
parent.add_edge("p_one", "p_two") parent.add_edge("p_two", END) graph = parent.compile(checkpointer=MemorySaver()) config = {"configurable": {"thread_id": "1"}} events = [ c async for c in graph.astream( {"messages": []}, config=config, stream_mode="debug", ) ] checkpoint_events = list( reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) ) checkpoint_history = [c async for c in graph.aget_state_history(config)] assert len(checkpoint_events) == len(checkpoint_history) def normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None return config["configurable"] for stream, history in zip(checkpoint_events, checkpoint_history): assert stream["values"] == history.values assert stream["next"] == list(history.next) assert normalize_config(stream["config"]) == normalize_config(history.config) assert normalize_config(stream["parent_config"]) == normalize_config( history.parent_config ) assert len(stream["tasks"]) == len(history.tasks) for stream_task, history_task in zip(stream["tasks"], history.tasks): assert stream_task["id"] == history_task.id assert stream_task["name"] == history_task.name assert stream_task["interrupts"] == history_task.interrupts assert stream_task.get("error") == history_task.error assert stream_task.get("state") == history_task.state async def test_debug_nested_subgraphs(): from collections import defaultdict class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): async def _node(state: State): return {"messages": [f"entered {name} node"]} return _node grand_parent = StateGraph(State) parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge(START, "c_one") child.add_edge("c_one", "c_two") child.add_edge("c_two", END) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge(START, "p_one") parent.add_edge("p_one", "p_two") 
parent.add_edge("p_two", END) grand_parent.add_node("gp_one", node("gp_one")) grand_parent.add_node("gp_two", parent.compile()) grand_parent.add_edge(START, "gp_one") grand_parent.add_edge("gp_one", "gp_two") grand_parent.add_edge("gp_two", END) graph = grand_parent.compile(checkpointer=MemorySaver()) config = {"configurable": {"thread_id": "1"}} events = [ c async for c in graph.astream( {"messages": []}, config=config, stream_mode="debug", subgraphs=True, ) ] stream_ns: dict[tuple, dict] = defaultdict(list) for ns, e in events: if e["type"] == "checkpoint": stream_ns[ns].append(e["payload"]) assert list(stream_ns.keys()) == [ (), (AnyStr("gp_two:"),), (AnyStr("gp_two:"), AnyStr("p_two:")), ] history_ns = {} for ns in stream_ns.keys(): async def get_history(): history = [ c async for c in graph.aget_state_history( {"configurable": {"thread_id": "1", "checkpoint_ns": "|".join(ns)}} ) ] return history[::-1] history_ns[ns] = await get_history() def normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None clean_config = {} clean_config["thread_id"] = config["configurable"]["thread_id"] clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"] clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] if "checkpoint_map" in config["configurable"]: clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"] return clean_config for checkpoint_events, checkpoint_history in zip( stream_ns.values(), history_ns.values() ): for stream, history in zip(checkpoint_events, checkpoint_history): assert stream["values"] == history.values assert stream["next"] == list(history.next) assert normalize_config(stream["config"]) == normalize_config( history.config ) assert normalize_config(stream["parent_config"]) == normalize_config( history.parent_config ) assert len(stream["tasks"]) == len(history.tasks) for stream_task, history_task in zip(stream["tasks"], history.tasks): assert stream_task["id"] == 
history_task.id assert stream_task["name"] == history_task.name assert stream_task["interrupts"] == history_task.interrupts assert stream_task.get("error") == history_task.error assert stream_task.get("state") == history_task.state @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_parent_command(checkpointer_name: str) -> None: from langchain_core.messages import BaseMessage from langchain_core.tools import tool @tool(return_direct=True) def get_user_name() -> Command: """Retrieve user name""" return Command(update={"user_name": "Meow"}, graph=Command.PARENT) subgraph_builder = StateGraph(MessagesState) subgraph_builder.add_node("tool", get_user_name) subgraph_builder.add_edge(START, "tool") subgraph = subgraph_builder.compile() class CustomParentState(TypedDict): messages: Annotated[list[BaseMessage], add_messages] # this key is not available to the child graph user_name: str builder = StateGraph(CustomParentState) builder.add_node("alice", subgraph) builder.add_edge(START, "alice") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) config = {"configurable": {"thread_id": "1"}} assert await graph.ainvoke( {"messages": [("user", "get user name")]}, config ) == { "messages": [ _AnyIdHumanMessage( content="get user name", additional_kwargs={}, response_metadata={} ), ], "user_name": "Meow", } assert await graph.aget_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage( content="get user name", additional_kwargs={}, response_metadata={}, ), ], "user_name": "Meow", }, next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": { "alice": { "user_name": "Meow", } }, "thread_id": "1", "step": 1, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) 
@pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_interrupt_subgraph(checkpointer_name: str): class State(TypedDict): baz: str def foo(state): return {"baz": "foo"} def bar(state): value = interrupt("Please provide baz value:") return {"baz": value} child_builder = StateGraph(State) child_builder.add_node(bar) child_builder.add_edge(START, "bar") builder = StateGraph(State) builder.add_node(foo) builder.add_node("bar", child_builder.compile()) builder.add_edge(START, "foo") builder.add_edge("foo", "bar") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} # First run, interrupted at bar assert await graph.ainvoke({"baz": ""}, thread1) # Resume with answer assert await graph.ainvoke(Command(resume="bar"), thread1) @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_interrupt_multiple(checkpointer_name: str): class State(TypedDict): my_key: Annotated[str, operator.add] async def node(s: State) -> State: answer = interrupt({"value": 1}) answer2 = interrupt({"value": 2}) return {"my_key": answer + " " + answer2} builder = StateGraph(State) builder.add_node("node", node) builder.add_edge(START, "node") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert [ e async for e in graph.astream({"my_key": "DE", "market": "DE"}, thread1) ] == [ { "__interrupt__": ( Interrupt( value={"value": 1}, resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event async for event in graph.astream( Command(resume="answer 1", update={"my_key": 
"foofoo"}), thread1, stream_mode="updates", ) ] == [ { "__interrupt__": ( Interrupt( value={"value": 2}, resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event async for event in graph.astream( Command(resume="answer 2"), thread1, stream_mode="updates" ) ] == [ {"node": {"my_key": "answer 1 answer 2"}}, ] @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support", ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_interrupt_loop(checkpointer_name: str): class State(TypedDict): age: int other: str async def ask_age(s: State): """Ask an expert for help.""" question = "How old are you?" value = None for _ in range(10): value: str = interrupt(question) if not value.isdigit() or int(value) < 18: question = "invalid response" value = None else: break return {"age": int(value)} builder = StateGraph(State) builder.add_node("node", ask_age) builder.add_edge(START, "node") async with awith_checkpointer(checkpointer_name) as checkpointer: graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert [e async for e in graph.astream({"other": ""}, thread1)] == [ { "__interrupt__": ( Interrupt( value="How old are you?", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event async for event in graph.astream( Command(resume="13"), thread1, ) ] == [ { "__interrupt__": ( Interrupt( value="invalid response", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event async for event in graph.astream( Command(resume="15"), thread1, ) ] == [ { "__interrupt__": ( Interrupt( value="invalid response", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event async for event in graph.astream(Command(resume="19"), thread1) ] == [ {"node": {"age": 19}}, ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_prebuilt.py`: ```py import dataclasses import json from 
functools import partial from typing import ( Annotated, Any, Callable, Dict, List, Literal, Optional, Sequence, Type, TypeVar, Union, ) import pytest from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import ( AIMessage, AnyMessage, BaseMessage, HumanMessage, SystemMessage, ToolCall, ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.tools import BaseTool, ToolException from langchain_core.tools import tool as dec_tool from pydantic import BaseModel, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from typing_extensions import TypedDict from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.memory import MemorySaver from langgraph.errors import NodeInterrupt from langgraph.graph import START, MessagesState, StateGraph, add_messages from langgraph.prebuilt import ( ToolNode, ValidationNode, create_react_agent, tools_condition, ) from langgraph.prebuilt.chat_agent_executor import _validate_chat_history from langgraph.prebuilt.tool_node import ( TOOL_CALL_ERROR_TEMPLATE, InjectedState, InjectedStore, _get_state_args, _infer_handled_types, ) from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore from langgraph.types import Interrupt from tests.conftest import ( ALL_CHECKPOINTERS_ASYNC, ALL_CHECKPOINTERS_SYNC, IS_LANGCHAIN_CORE_030_OR_GREATER, awith_checkpointer, ) from tests.messages import _AnyIdHumanMessage, _AnyIdToolMessage pytestmark = pytest.mark.anyio class FakeToolCallingModel(BaseChatModel): tool_calls: Optional[list[list[ToolCall]]] = None index: int = 0 tool_style: Literal["openai", "anthropic"] = "openai" def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, 
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: """Top Level call""" messages_string = "-".join([m.content for m in messages]) tool_calls = ( self.tool_calls[self.index % len(self.tool_calls)] if self.tool_calls else [] ) message = AIMessage( content=messages_string, id=str(self.index), tool_calls=tool_calls.copy() ) self.index += 1 return ChatResult(generations=[ChatGeneration(message=message)]) @property def _llm_type(self) -> str: return "fake-tool-call-model" def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: if len(tools) == 0: raise ValueError("Must provide at least one tool") tool_dicts = [] for tool in tools: if not isinstance(tool, BaseTool): raise TypeError( "Only BaseTool is supported by FakeToolCallingModel.bind_tools" ) # NOTE: this is a simplified tool spec for testing purposes only if self.tool_style == "openai": tool_dicts.append( { "type": "function", "function": { "name": tool.name, }, } ) elif self.tool_style == "anthropic": tool_dicts.append( { "name": tool.name, } ) return self.bind(tools=tool_dicts) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_no_modifier(request: pytest.FixtureRequest, checkpointer_name: str) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( "checkpointer_" + checkpointer_name ) model = FakeToolCallingModel() agent = create_react_agent(model, [], checkpointer=checkpointer) inputs = [HumanMessage("hi?")] thread = {"configurable": {"thread_id": "123"}} response = agent.invoke({"messages": inputs}, thread, debug=True) expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]} assert response == expected_response if checkpointer: saved = checkpointer.get_tuple(thread) assert saved is not None assert saved.checkpoint["channel_values"] == { "messages": [ _AnyIdHumanMessage(content="hi?"), 
AIMessage(content="hi?", id="0"), ], "agent": "agent", } assert saved.metadata == { "parents": {}, "source": "loop", "writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}}, "step": 1, "thread_id": "123", } assert saved.pending_writes == [] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_no_modifier_async(checkpointer_name: str) -> None: async with awith_checkpointer(checkpointer_name) as checkpointer: model = FakeToolCallingModel() agent = create_react_agent(model, [], checkpointer=checkpointer) inputs = [HumanMessage("hi?")] thread = {"configurable": {"thread_id": "123"}} response = await agent.ainvoke({"messages": inputs}, thread, debug=True) expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]} assert response == expected_response if checkpointer: saved = await checkpointer.aget_tuple(thread) assert saved is not None assert saved.checkpoint["channel_values"] == { "messages": [ _AnyIdHumanMessage(content="hi?"), AIMessage(content="hi?", id="0"), ], "agent": "agent", } assert saved.metadata == { "parents": {}, "source": "loop", "writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}}, "step": 1, "thread_id": "123", } assert saved.pending_writes == [] def test_passing_two_modifiers(): model = FakeToolCallingModel() with pytest.raises(ValueError): create_react_agent(model, [], messages_modifier="Foo", state_modifier="Bar") def test_system_message_modifier(): messages_modifier = SystemMessage(content="Foo") agent_1 = create_react_agent( FakeToolCallingModel(), [], messages_modifier=messages_modifier ) agent_2 = create_react_agent( FakeToolCallingModel(), [], state_modifier=messages_modifier ) for agent in [agent_1, agent_2]: inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = { "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])] } assert response == expected_response def test_system_message_string_modifier(): 
messages_modifier = "Foo" agent_1 = create_react_agent( FakeToolCallingModel(), [], messages_modifier=messages_modifier ) agent_2 = create_react_agent( FakeToolCallingModel(), [], state_modifier=messages_modifier ) for agent in [agent_1, agent_2]: inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = { "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])] } assert response == expected_response def test_callable_messages_modifier(): model = FakeToolCallingModel() def messages_modifier(messages): modified_message = f"Bar {messages[-1].content}" return [HumanMessage(content=modified_message)] agent = create_react_agent(model, [], messages_modifier=messages_modifier) inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = {"messages": inputs + [AIMessage(content="Bar hi?", id="0")]} assert response == expected_response def test_callable_state_modifier(): model = FakeToolCallingModel() def state_modifier(state): modified_message = f"Bar {state['messages'][-1].content}" return [HumanMessage(content=modified_message)] agent = create_react_agent(model, [], state_modifier=state_modifier) inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = {"messages": inputs + [AIMessage(content="Bar hi?", id="0")]} assert response == expected_response def test_runnable_messages_modifier(): model = FakeToolCallingModel() messages_modifier = RunnableLambda( lambda messages: [HumanMessage(content=f"Baz {messages[-1].content}")] ) agent = create_react_agent(model, [], messages_modifier=messages_modifier) inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = {"messages": inputs + [AIMessage(content="Baz hi?", id="0")]} assert response == expected_response def test_runnable_state_modifier(): model = FakeToolCallingModel() state_modifier = RunnableLambda( lambda state: [HumanMessage(content=f"Baz 
{state['messages'][-1].content}")] ) agent = create_react_agent(model, [], state_modifier=state_modifier) inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = {"messages": inputs + [AIMessage(content="Baz hi?", id="0")]} assert response == expected_response def test_state_modifier_with_store(): def add(a: int, b: int): """Adds a and b""" return a + b in_memory_store = InMemoryStore() in_memory_store.put(("memories", "1"), "user_name", {"data": "User name is Alice"}) in_memory_store.put(("memories", "2"), "user_name", {"data": "User name is Bob"}) def modify(state, config, *, store): user_id = config["configurable"]["user_id"] system_str = store.get(("memories", user_id), "user_name").value["data"] return [SystemMessage(system_str)] + state["messages"] def modify_no_store(state, config): return SystemMessage("foo") + state["messages"] model = FakeToolCallingModel() # test state modifier that uses store works agent = create_react_agent( model, [add], state_modifier=modify, store=in_memory_store ) response = agent.invoke( {"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}} ) assert response["messages"][-1].content == "User name is Alice-hi" # test state modifier that doesn't use store works agent = create_react_agent( model, [add], state_modifier=modify_no_store, store=in_memory_store ) response = agent.invoke( {"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}} ) assert response["messages"][-1].content == "foo-hi" @pytest.mark.parametrize("tool_style", ["openai", "anthropic"]) def test_model_with_tools(tool_style: str): model = FakeToolCallingModel(tool_style=tool_style) @dec_tool def tool1(some_val: int) -> str: """Tool 1 docstring.""" return f"Tool 1: {some_val}" @dec_tool def tool2(some_val: int) -> str: """Tool 2 docstring.""" return f"Tool 2: {some_val}" # check valid agent constructor agent = create_react_agent(model.bind_tools([tool1, tool2]), [tool1, tool2]) result = 
agent.nodes["tools"].invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 2}, "id": "some 1", }, { "name": "tool2", "args": {"some_val": 2}, "id": "some 2", }, ], ) ] } ) tool_messages: ToolMessage = result["messages"][-2:] for tool_message in tool_messages: assert tool_message.type == "tool" assert tool_message.content in {"Tool 1: 2", "Tool 2: 2"} assert tool_message.tool_call_id in {"some 1", "some 2"} # test mismatching tool lengths with pytest.raises(ValueError): create_react_agent(model.bind_tools([tool1]), [tool1, tool2]) # test missing bound tools with pytest.raises(ValueError): create_react_agent(model.bind_tools([tool1]), [tool2]) def test__validate_messages(): # empty input _validate_chat_history([]) # single human message _validate_chat_history( [ HumanMessage(content="What's the weather?"), ] ) # human + AI _validate_chat_history( [ HumanMessage(content="What's the weather?"), AIMessage(content="The weather is sunny and 75°F."), ] ) # Answered tool calls _validate_chat_history( [ HumanMessage(content="What's the weather?"), AIMessage( content="Let me check that for you.", tool_calls=[{"id": "call1", "name": "get_weather", "args": {}}], ), ToolMessage(content="Sunny, 75°F", tool_call_id="call1"), AIMessage(content="The weather is sunny and 75°F."), ] ) # Unanswered tool calls with pytest.raises(ValueError): _validate_chat_history( [ AIMessage( content="I'll check that for you.", tool_calls=[ {"id": "call1", "name": "get_weather", "args": {}}, {"id": "call2", "name": "get_time", "args": {}}, ], ) ] ) with pytest.raises(ValueError): _validate_chat_history( [ HumanMessage(content="What's the weather and time?"), AIMessage( content="I'll check that for you.", tool_calls=[ {"id": "call1", "name": "get_weather", "args": {}}, {"id": "call2", "name": "get_time", "args": {}}, ], ), ToolMessage(content="Sunny, 75°F", tool_call_id="call1"), AIMessage( content="The weather is sunny and 75°F. Let me check the time." 
), ] ) def test__infer_handled_types() -> None: def handle(e): # type: ignore return "" def handle2(e: Exception) -> str: return "" def handle3(e: Union[ValueError, ToolException]) -> str: return "" class Handler: def handle(self, e: ValueError) -> str: return "" handle4 = Handler().handle def handle5(e: Union[Union[TypeError, ValueError], ToolException]): return "" expected: tuple = (Exception,) actual = _infer_handled_types(handle) assert expected == actual expected = (Exception,) actual = _infer_handled_types(handle2) assert expected == actual expected = (ValueError, ToolException) actual = _infer_handled_types(handle3) assert expected == actual expected = (ValueError,) actual = _infer_handled_types(handle4) assert expected == actual expected = (TypeError, ValueError, ToolException) actual = _infer_handled_types(handle5) assert expected == actual with pytest.raises(ValueError): def handler(e: str): return "" _infer_handled_types(handler) with pytest.raises(ValueError): def handler(e: list[Exception]): return "" _infer_handled_types(handler) with pytest.raises(ValueError): def handler(e: Union[str, int]): return "" _infer_handled_types(handler) # tools for testing Too def tool1(some_val: int, some_other_val: str) -> str: """Tool 1 docstring.""" if some_val == 0: raise ValueError("Test error") return f"{some_val} - {some_other_val}" async def tool2(some_val: int, some_other_val: str) -> str: """Tool 2 docstring.""" if some_val == 0: raise ToolException("Test error") return f"tool2: {some_val} - {some_other_val}" async def tool3(some_val: int, some_other_val: str) -> str: """Tool 3 docstring.""" return [ {"key_1": some_val, "key_2": "foo"}, {"key_1": some_other_val, "key_2": "baz"}, ] async def tool4(some_val: int, some_other_val: str) -> str: """Tool 4 docstring.""" return [ {"type": "image_url", "image_url": {"url": "abdc"}}, ] @dec_tool def tool5(some_val: int): """Tool 5 docstring.""" raise ToolException("Test error") tool5.handle_tool_error = "foo" async def 
test_tool_node(): result = ToolNode([tool1]).invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 1, "some_other_val": "foo"}, "id": "some 0", } ], ) ] } ) tool_message: ToolMessage = result["messages"][-1] assert tool_message.type == "tool" assert tool_message.content == "1 - foo" assert tool_message.tool_call_id == "some 0" result2 = await ToolNode([tool2]).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool2", "args": {"some_val": 2, "some_other_val": "bar"}, "id": "some 1", } ], ) ] } ) tool_message: ToolMessage = result2["messages"][-1] assert tool_message.type == "tool" assert tool_message.content == "tool2: 2 - bar" # list of dicts tool content result3 = await ToolNode([tool3]).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool3", "args": {"some_val": 2, "some_other_val": "bar"}, "id": "some 2", } ], ) ] } ) tool_message: ToolMessage = result3["messages"][-1] assert tool_message.type == "tool" assert ( tool_message.content == '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]' ) assert tool_message.tool_call_id == "some 2" # list of content blocks tool content result4 = await ToolNode([tool4]).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool4", "args": {"some_val": 2, "some_other_val": "bar"}, "id": "some 3", } ], ) ] } ) tool_message: ToolMessage = result4["messages"][-1] assert tool_message.type == "tool" assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}] assert tool_message.tool_call_id == "some 3" async def test_tool_node_error_handling(): def handle_all(e: Union[ValueError, ToolException, ValidationError]): return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) # test catching all exceptions, via: # - handle_tool_errors = True # - passing a tuple of all exceptions # - passing a callable with all exceptions in the signature for handle_tool_errors in ( True, (ValueError, ToolException, ValidationError), 
handle_all, ): result_error = await ToolNode( [tool1, tool2, tool3], handle_tool_errors=handle_tool_errors ).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0, "some_other_val": "foo"}, "id": "some id", }, { "name": "tool2", "args": {"some_val": 0, "some_other_val": "bar"}, "id": "some other id", }, { "name": "tool3", "args": {"some_val": 0}, "id": "another id", }, ], ) ] } ) assert all(m.type == "tool" for m in result_error["messages"]) assert all(m.status == "error" for m in result_error["messages"]) assert ( result_error["messages"][0].content == f"Error: {repr(ValueError('Test error'))}\n Please fix your mistakes." ) assert ( result_error["messages"][1].content == f"Error: {repr(ToolException('Test error'))}\n Please fix your mistakes." ) assert ( "ValidationError" in result_error["messages"][2].content or "validation error" in result_error["messages"][2].content ) assert result_error["messages"][0].tool_call_id == "some id" assert result_error["messages"][1].tool_call_id == "some other id" assert result_error["messages"][2].tool_call_id == "another id" async def test_tool_node_error_handling_callable(): def handle_value_error(e: ValueError): return "Value error" def handle_tool_exception(e: ToolException): return "Tool exception" for handle_tool_errors in ("Value error", handle_value_error): result_error = await ToolNode( [tool1], handle_tool_errors=handle_tool_errors ).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0, "some_other_val": "foo"}, "id": "some id", }, ], ) ] } ) tool_message: ToolMessage = result_error["messages"][-1] assert tool_message.type == "tool" assert tool_message.status == "error" assert tool_message.content == "Value error" # test raising for an unhandled exception, via: # - passing a tuple of all exceptions # - passing a callable with all exceptions in the signature for handle_tool_errors in ((ValueError,), handle_value_error): with 
pytest.raises(ToolException) as exc_info: await ToolNode( [tool1, tool2], handle_tool_errors=handle_tool_errors ).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0, "some_other_val": "foo"}, "id": "some id", }, { "name": "tool2", "args": {"some_val": 0, "some_other_val": "bar"}, "id": "some other id", }, ], ) ] } ) assert str(exc_info.value) == "Test error" for handle_tool_errors in ((ToolException,), handle_tool_exception): with pytest.raises(ValueError) as exc_info: await ToolNode( [tool1, tool2], handle_tool_errors=handle_tool_errors ).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0, "some_other_val": "foo"}, "id": "some id", }, { "name": "tool2", "args": {"some_val": 0, "some_other_val": "bar"}, "id": "some other id", }, ], ) ] } ) assert str(exc_info.value) == "Test error" async def test_tool_node_handle_tool_errors_false(): with pytest.raises(ValueError) as exc_info: ToolNode([tool1], handle_tool_errors=False).invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0, "some_other_val": "foo"}, "id": "some id", } ], ) ] } ) assert str(exc_info.value) == "Test error" with pytest.raises(ToolException): await ToolNode([tool2], handle_tool_errors=False).ainvoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool2", "args": {"some_val": 0, "some_other_val": "bar"}, "id": "some id", } ], ) ] } ) assert str(exc_info.value) == "Test error" # test validation errors get raised if handle_tool_errors is False with pytest.raises((ValidationError, ValidationErrorV1)): ToolNode([tool1], handle_tool_errors=False).invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool1", "args": {"some_val": 0}, "id": "some id", } ], ) ] } ) def test_tool_node_individual_tool_error_handling(): # test error handling on individual tools (and that it overrides overall error handling!) 
result_individual_tool_error_handler = ToolNode( [tool5], handle_tool_errors="bar" ).invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool5", "args": {"some_val": 0}, "id": "some 0", } ], ) ] } ) tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1] assert tool_message.type == "tool" assert tool_message.status == "error" assert tool_message.content == "foo" assert tool_message.tool_call_id == "some 0" def test_tool_node_incorrect_tool_name(): result_incorrect_name = ToolNode([tool1, tool2]).invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool3", "args": {"some_val": 1, "some_other_val": "foo"}, "id": "some 0", } ], ) ] } ) tool_message: ToolMessage = result_incorrect_name["messages"][-1] assert tool_message.type == "tool" assert tool_message.status == "error" assert ( tool_message.content == "Error: tool3 is not a valid tool, try one of [tool1, tool2]." ) assert tool_message.tool_call_id == "some 0" def test_tool_node_node_interrupt(): def tool_normal(some_val: int) -> str: """Tool docstring.""" return "normal" def tool_interrupt(some_val: int) -> str: """Tool docstring.""" raise NodeInterrupt("foo") def handle(e: NodeInterrupt): return "handled" for handle_tool_errors in (True, (NodeInterrupt,), "handled", handle, False): node = ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors) with pytest.raises(NodeInterrupt) as exc_info: node.invoke( { "messages": [ AIMessage( "hi?", tool_calls=[ { "name": "tool_interrupt", "args": {"some_val": 0}, "id": "some 0", } ], ) ] } ) assert exc_info.value == "foo" # test inside react agent model = FakeToolCallingModel( tool_calls=[ [ ToolCall(name="tool_interrupt", args={"some_val": 0}, id="1"), ToolCall(name="tool_normal", args={"some_val": 1}, id="2"), ], [], ] ) checkpointer = MemorySaver() config = {"configurable": {"thread_id": "1"}} agent = create_react_agent( model, [tool_interrupt, tool_normal], checkpointer=checkpointer ) result = 
agent.invoke({"messages": [HumanMessage("hi?")]}, config) assert result["messages"] == [ _AnyIdHumanMessage( content="hi?", ), AIMessage( content="hi?", id="0", tool_calls=[ { "name": "tool_interrupt", "args": {"some_val": 0}, "id": "1", "type": "tool_call", }, { "name": "tool_normal", "args": {"some_val": 1}, "id": "2", "type": "tool_call", }, ], ), ] state = agent.get_state(config) assert state.next == ("tools",) task = state.tasks[0] assert task.name == "tools" assert task.interrupts == (Interrupt(value="foo", when="during"),) def my_function(some_val: int, some_other_val: str) -> str: return f"{some_val} - {some_other_val}" class MyModel(BaseModel): some_val: int some_other_val: str class MyModelV1(BaseModelV1): some_val: int some_other_val: str @dec_tool def my_tool(some_val: int, some_other_val: str) -> str: """Cool.""" return f"{some_val} - {some_other_val}" @pytest.mark.parametrize( "tool_schema", [ my_function, MyModel, MyModelV1, my_tool, ], ) @pytest.mark.parametrize("use_message_key", [True, False]) async def test_validation_node(tool_schema: Any, use_message_key: bool): validation_node = ValidationNode([tool_schema]) tool_name = getattr(tool_schema, "name", getattr(tool_schema, "__name__", None)) inputs = [ AIMessage( "hi?", tool_calls=[ { "name": tool_name, "args": {"some_val": 1, "some_other_val": "foo"}, "id": "some 0", }, { "name": tool_name, # Wrong type for some_val "args": {"some_val": "bar", "some_other_val": "foo"}, "id": "some 1", }, ], ), ] if use_message_key: inputs = {"messages": inputs} result = await validation_node.ainvoke(inputs) if use_message_key: result = result["messages"] def check_results(messages: list): assert len(messages) == 2 assert all(m.type == "tool" for m in messages) assert not messages[0].additional_kwargs.get("is_error") assert messages[1].additional_kwargs.get("is_error") check_results(result) result_sync = validation_node.invoke(inputs) if use_message_key: result_sync = result_sync["messages"] 
check_results(result_sync) class _InjectStateSchema(TypedDict): messages: list foo: str class _InjectedStatePydanticSchema(BaseModelV1): messages: list foo: str class _InjectedStatePydanticV2Schema(BaseModel): messages: list foo: str @dataclasses.dataclass class _InjectedStateDataclassSchema: messages: list foo: str T = TypeVar("T") @pytest.mark.parametrize( "schema_", [ _InjectStateSchema, _InjectedStatePydanticSchema, _InjectedStatePydanticV2Schema, _InjectedStateDataclassSchema, ], ) def test_tool_node_inject_state(schema_: Type[T]) -> None: def tool1(some_val: int, state: Annotated[T, InjectedState]) -> str: """Tool 1 docstring.""" if isinstance(state, dict): return state["foo"] else: return getattr(state, "foo") def tool2(some_val: int, state: Annotated[T, InjectedState()]) -> str: """Tool 2 docstring.""" if isinstance(state, dict): return state["foo"] else: return getattr(state, "foo") def tool3( some_val: int, foo: Annotated[str, InjectedState("foo")], msgs: Annotated[List[AnyMessage], InjectedState("messages")], ) -> str: """Tool 1 docstring.""" return foo def tool4( some_val: int, msgs: Annotated[List[AnyMessage], InjectedState("messages")] ) -> str: """Tool 1 docstring.""" return msgs[0].content node = ToolNode([tool1, tool2, tool3, tool4]) for tool_name in ("tool1", "tool2", "tool3"): tool_call = { "name": tool_name, "args": {"some_val": 1}, "id": "some 0", "type": "tool_call", } msg = AIMessage("hi?", tool_calls=[tool_call]) result = node.invoke(schema_(**{"messages": [msg], "foo": "bar"})) tool_message = result["messages"][-1] assert tool_message.content == "bar", f"Failed for tool={tool_name}" if tool_name == "tool3": failure_input = None try: failure_input = schema_(**{"messages": [msg], "notfoo": "bar"}) except Exception: pass if failure_input is not None: with pytest.raises(KeyError): node.invoke(failure_input) with pytest.raises(ValueError): node.invoke([msg]) else: failure_input = None try: failure_input = schema_(**{"messages": [msg], "notfoo": 
"bar"}) except Exception: # We'd get a validation error from pydantic state and wouldn't make it to the node # anyway pass if failure_input is not None: messages_ = node.invoke(failure_input) tool_message = messages_["messages"][-1] assert "KeyError" in tool_message.content tool_message = node.invoke([msg])[-1] assert "KeyError" in tool_message.content tool_call = { "name": "tool4", "args": {"some_val": 1}, "id": "some 0", "type": "tool_call", } msg = AIMessage("hi?", tool_calls=[tool_call]) result = node.invoke(schema_(**{"messages": [msg], "foo": ""})) tool_message = result["messages"][-1] assert tool_message.content == "hi?" result = node.invoke([msg]) tool_message = result[-1] assert tool_message.content == "hi?" @pytest.mark.skipif( not IS_LANGCHAIN_CORE_030_OR_GREATER, reason="Langchain core 0.3.0 or greater is required", ) def test_tool_node_inject_store() -> None: store = InMemoryStore() namespace = ("test",) def tool1(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: """Tool 1 docstring.""" store_val = store.get(namespace, "test_key").value["foo"] return f"Some val: {some_val}, store val: {store_val}" def tool2(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str: """Tool 2 docstring.""" store_val = store.get(namespace, "test_key").value["foo"] return f"Some val: {some_val}, store val: {store_val}" def tool3( some_val: int, bar: Annotated[str, InjectedState("bar")], store: Annotated[BaseStore, InjectedStore()], ) -> str: """Tool 3 docstring.""" store_val = store.get(namespace, "test_key").value["foo"] return f"Some val: {some_val}, store val: {store_val}, state val: {bar}" node = ToolNode([tool1, tool2, tool3], handle_tool_errors=True) store.put(namespace, "test_key", {"foo": "bar"}) class State(MessagesState): bar: str builder = StateGraph(State) builder.add_node("tools", node) builder.add_edge(START, "tools") graph = builder.compile(store=store) for tool_name in ("tool1", "tool2"): tool_call = { "name": tool_name, 
"args": {"some_val": 1}, "id": "some 0", "type": "tool_call", } msg = AIMessage("hi?", tool_calls=[tool_call]) node_result = node.invoke({"messages": [msg]}, store=store) graph_result = graph.invoke({"messages": [msg]}) for result in (node_result, graph_result): result["messages"][-1] tool_message = result["messages"][-1] assert ( tool_message.content == "Some val: 1, store val: bar" ), f"Failed for tool={tool_name}" tool_call = { "name": "tool3", "args": {"some_val": 1}, "id": "some 0", "type": "tool_call", } msg = AIMessage("hi?", tool_calls=[tool_call]) node_result = node.invoke({"messages": [msg], "bar": "baz"}, store=store) graph_result = graph.invoke({"messages": [msg], "bar": "baz"}) for result in (node_result, graph_result): result["messages"][-1] tool_message = result["messages"][-1] assert ( tool_message.content == "Some val: 1, store val: bar, state val: baz" ), f"Failed for tool={tool_name}" # test injected store without passing store to compiled graph failing_graph = builder.compile() with pytest.raises(ValueError): failing_graph.invoke({"messages": [msg], "bar": "baz"}) def test_tool_node_ensure_utf8() -> None: @dec_tool def get_day_list(days: list[str]) -> list[str]: """choose days""" return days data = ["星期一", "水曜日", "목요일", "Friday"] tools = [get_day_list] tool_calls = [ToolCall(name=get_day_list.name, args={"days": data}, id="test_id")] outputs: list[ToolMessage] = ToolNode(tools).invoke( [AIMessage(content="", tool_calls=tool_calls)] ) assert outputs[0].content == json.dumps(data, ensure_ascii=False) def test_tool_node_messages_key() -> None: @dec_tool def add(a: int, b: int): """Adds a and b.""" return a + b model = FakeToolCallingModel( tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]] ) class State(TypedDict): subgraph_messages: Annotated[list[AnyMessage], add_messages] def call_model(state: State): response = model.invoke(state["subgraph_messages"]) model.tool_calls = [] return {"subgraph_messages": response} builder 
= StateGraph(State) builder.add_node("agent", call_model) builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages")) builder.add_conditional_edges( "agent", partial(tools_condition, messages_key="subgraph_messages") ) builder.add_edge(START, "agent") builder.add_edge("tools", "agent") graph = builder.compile() result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]}) assert result["subgraph_messages"] == [ _AnyIdHumanMessage(content="hi"), AIMessage( content="hi", id="0", tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")], ), _AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"), AIMessage(content="hi-hi-3", id="1"), ] async def test_return_direct() -> None: @dec_tool(return_direct=True) def tool_return_direct(input: str) -> str: """A tool that returns directly.""" return f"Direct result: {input}" @dec_tool def tool_normal(input: str) -> str: """A normal tool.""" return f"Normal result: {input}" first_tool_call = [ ToolCall( name="tool_return_direct", args={"input": "Test direct"}, id="1", ), ] expected_ai = AIMessage( content="Test direct", id="0", tool_calls=first_tool_call, ) model = FakeToolCallingModel(tool_calls=[first_tool_call, []]) agent = create_react_agent(model, [tool_return_direct, tool_normal]) # Test direct return for tool_return_direct result = agent.invoke( {"messages": [HumanMessage(content="Test direct", id="hum0")]} ) assert result["messages"] == [ HumanMessage(content="Test direct", id="hum0"), expected_ai, ToolMessage( content="Direct result: Test direct", name="tool_return_direct", tool_call_id="1", id=result["messages"][2].id, ), ] second_tool_call = [ ToolCall( name="tool_normal", args={"input": "Test normal"}, id="2", ), ] model = FakeToolCallingModel(tool_calls=[second_tool_call, []]) agent = create_react_agent(model, [tool_return_direct, tool_normal]) result = agent.invoke( {"messages": [HumanMessage(content="Test normal", id="hum1")]} ) assert result["messages"] 
== [ HumanMessage(content="Test normal", id="hum1"), AIMessage(content="Test normal", id="0", tool_calls=second_tool_call), ToolMessage( content="Normal result: Test normal", name="tool_normal", tool_call_id="2", id=result["messages"][2].id, ), AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"), ] both_tool_calls = [ ToolCall( name="tool_return_direct", args={"input": "Test both direct"}, id="3", ), ToolCall( name="tool_normal", args={"input": "Test both normal"}, id="4", ), ] model = FakeToolCallingModel(tool_calls=[both_tool_calls, []]) agent = create_react_agent(model, [tool_return_direct, tool_normal]) result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]}) assert result["messages"] == [ HumanMessage(content="Test both", id="hum2"), AIMessage(content="Test both", id="0", tool_calls=both_tool_calls), ToolMessage( content="Direct result: Test both direct", name="tool_return_direct", tool_call_id="3", id=result["messages"][2].id, ), ToolMessage( content="Normal result: Test both normal", name="tool_normal", tool_call_id="4", id=result["messages"][3].id, ), ] def test__get_state_args() -> None: class Schema1(BaseModel): a: Annotated[str, InjectedState] class Schema2(Schema1): b: Annotated[int, InjectedState("bar")] @dec_tool(args_schema=Schema2) def foo(a: str, b: int) -> float: """return""" return 0.0 assert _get_state_args(foo) == {"a": None, "b": "bar"} ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_io.py`: ```py from typing import Iterator from langgraph.pregel.io import single def test_single() -> None: closed = False def myiter() -> Iterator[int]: try: yield 1 yield 2 finally: nonlocal closed closed = True assert single(myiter()) == 1 assert closed ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_state.py`: ```py import inspect import warnings from dataclasses import dataclass, field from typing import Annotated as Annotated2 from typing import Any, 
Optional import pytest from langchain_core.runnables import RunnableConfig, RunnableLambda from pydantic.v1 import BaseModel from typing_extensions import Annotated, NotRequired, Required, TypedDict from langgraph.graph.state import StateGraph, _get_node_name, _warn_invalid_state_schema from langgraph.managed.shared_value import SharedValue class State(BaseModel): foo: str bar: int class State2(TypedDict): foo: str bar: int @pytest.mark.parametrize( "schema", [ {"foo": "bar"}, ["hi", lambda x, y: x + y], State(foo="bar", bar=1), State2(foo="bar", bar=1), ], ) def test_warns_invalid_schema(schema: Any): with pytest.warns(UserWarning): _warn_invalid_state_schema(schema) @pytest.mark.parametrize( "schema", [ Annotated[dict, lambda x, y: y], Annotated2[list, lambda x, y: y], dict, State, State2, ], ) def test_doesnt_warn_valid_schema(schema: Any): # Assert the function does not raise a warning with warnings.catch_warnings(): warnings.simplefilter("error") _warn_invalid_state_schema(schema) def test_state_schema_with_type_hint(): class InputState(TypedDict): question: str class OutputState(TypedDict): input_state: InputState class FooState(InputState): foo: str def complete_hint(state: InputState) -> OutputState: return {"input_state": state} def miss_first_hint(state, config: RunnableConfig) -> OutputState: return {"input_state": state} def only_return_hint(state, config) -> OutputState: return {"input_state": state} def miss_all_hint(state, config): return {"input_state": state} def pre_foo(_) -> FooState: return {"foo": "bar"} class Foo: def __call__(self, state: FooState) -> OutputState: assert state.pop("foo") == "bar" return {"input_state": state} graph = StateGraph(InputState, output=OutputState) actions = [ complete_hint, miss_first_hint, only_return_hint, miss_all_hint, pre_foo, Foo(), ] for action in actions: graph.add_node(action) def get_name(action) -> str: return getattr(action, "__name__", action.__class__.__name__) 
graph.set_entry_point(get_name(actions[0])) for i in range(len(actions) - 1): graph.add_edge(get_name(actions[i]), get_name(actions[i + 1])) graph.set_finish_point(get_name(actions[-1])) graph = graph.compile() input_state = InputState(question="Hello World!") output_state = OutputState(input_state=input_state) foo_state = FooState(foo="bar") for i, c in enumerate(graph.stream(input_state, stream_mode="updates")): node_name = get_name(actions[i]) if node_name == get_name(pre_foo): assert c[node_name] == foo_state else: assert c[node_name] == output_state @pytest.mark.parametrize("total_", [True, False]) def test_state_schema_optional_values(total_: bool): class SomeParentState(TypedDict): val0a: str val0b: Optional[str] class InputState(SomeParentState, total=total_): # type: ignore val1: str val2: Optional[str] val3: Required[str] val4: NotRequired[dict] val5: Annotated[Required[str], "foo"] val6: Annotated[NotRequired[str], "bar"] class OutputState(SomeParentState, total=total_): # type: ignore out_val1: str out_val2: Optional[str] out_val3: Required[str] out_val4: NotRequired[dict] out_val5: Annotated[Required[str], "foo"] out_val6: Annotated[NotRequired[str], "bar"] class State(InputState): # this would be ignored val4: dict some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field( default="foo" ) builder = StateGraph(State, input=InputState, output=OutputState) builder.add_node("n", lambda x: x) builder.add_edge("__start__", "n") graph = builder.compile() json_schema = graph.get_input_jsonschema() if total_ is False: expected_required = set() expected_optional = {"val2", "val1"} else: expected_required = {"val1"} expected_optional = {"val2"} # The others should always have precedence based on the required annotation expected_required |= {"val0a", "val3", "val5"} expected_optional |= {"val0b", "val4", "val6"} assert set(json_schema.get("required", set())) == expected_required assert ( set(json_schema["properties"].keys()) == expected_required 
| expected_optional ) # Check output schema. Should be the same process output_schema = graph.get_output_jsonschema() if total_ is False: expected_required = set() expected_optional = {"out_val2", "out_val1"} else: expected_required = {"out_val1"} expected_optional = {"out_val2"} expected_required |= {"val0a", "out_val3", "out_val5"} expected_optional |= {"val0b", "out_val4", "out_val6"} assert set(output_schema.get("required", set())) == expected_required assert ( set(output_schema["properties"].keys()) == expected_required | expected_optional ) @pytest.mark.parametrize("kw_only_", [False, True]) def test_state_schema_default_values(kw_only_: bool): kwargs = {} if "kw_only" in inspect.signature(dataclass).parameters: kwargs = {"kw_only": kw_only_} @dataclass(**kwargs) class InputState: val1: str val2: Optional[int] val3: Annotated[Optional[float], "optional annotated"] val4: Optional[str] = None val5: list[int] = field(default_factory=lambda: [1, 2, 3]) val6: dict[str, int] = field(default_factory=lambda: {"a": 1}) val7: str = field(default=...) 
val8: Annotated[int, "some metadata"] = 42 val9: Annotated[str, "more metadata"] = field(default="some foo") val10: str = "default" val11: Annotated[list[str], "annotated list"] = field( default_factory=lambda: ["a", "b"] ) some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field( default="foo" ) builder = StateGraph(InputState) builder.add_node("n", lambda x: x) builder.add_edge("__start__", "n") graph = builder.compile() for json_schema in [graph.get_input_jsonschema(), graph.get_output_jsonschema()]: expected_required = {"val1", "val7"} expected_optional = { "val2", "val3", "val4", "val5", "val6", "val8", "val9", "val10", "val11", } assert set(json_schema.get("required", set())) == expected_required assert ( set(json_schema["properties"].keys()) == expected_required | expected_optional ) def test_raises_invalid_managed(): class BadInputState(TypedDict): some_thing: str some_input_channel: Annotated[str, SharedValue.on("assistant_id")] class InputState(TypedDict): some_thing: str some_input_channel: str class BadOutputState(TypedDict): some_thing: str some_output_channel: Annotated[str, SharedValue.on("assistant_id")] class OutputState(TypedDict): some_thing: str some_output_channel: str class State(TypedDict): some_thing: str some_channel: Annotated[str, SharedValue.on("assistant_id")] # All OK StateGraph(State, input=InputState, output=OutputState) StateGraph(State) StateGraph(State, input=State, output=State) StateGraph(State, input=InputState) StateGraph(State, input=InputState) bad_input_examples = [ (State, BadInputState, OutputState), (State, BadInputState, BadOutputState), (State, BadInputState, State), (State, BadInputState, None), ] for _state, _inp, _outp in bad_input_examples: with pytest.raises( ValueError, match="Invalid managed channels detected in BadInputState: some_input_channel. 
Managed channels are not permitted in Input/Output schema.", ): StateGraph(_state, input=_inp, output=_outp) bad_output_examples = [ (State, InputState, BadOutputState), (State, None, BadOutputState), ] for _state, _inp, _outp in bad_output_examples: with pytest.raises( ValueError, match="Invalid managed channels detected in BadOutputState: some_output_channel. Managed channels are not permitted in Input/Output schema.", ): StateGraph(_state, input=_inp, output=_outp) def test__get_node_name() -> None: # default runnable name assert _get_node_name(RunnableLambda(func=lambda x: x)) == "RunnableLambda" # custom runnable name assert ( _get_node_name(RunnableLambda(name="my_runnable", func=lambda x: x)) == "my_runnable" ) # lambda assert _get_node_name(lambda x: x) == "<lambda>" # regular function def func(state): return assert _get_node_name(func) == "func" class MyClass: def __call__(self, state): return def class_method(self, state): return # callable class assert _get_node_name(MyClass()) == "MyClass" # class method assert _get_node_name(MyClass().class_method) == "class_method" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_messages_state.py`: ```py from typing import Annotated from uuid import UUID import pytest from langchain_core.messages import ( AIMessage, AnyMessage, HumanMessage, RemoveMessage, SystemMessage, ) from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 from langgraph.graph import add_messages from langgraph.graph.message import MessagesState from langgraph.graph.state import END, START, StateGraph from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER from tests.messages import _AnyIdHumanMessage def test_add_single_message(): left = [HumanMessage(content="Hello", id="1")] right = AIMessage(content="Hi there!", id="2") result = add_messages(left, right) expected_result = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] assert result == expected_result def 
test_add_multiple_messages(): left = [HumanMessage(content="Hello", id="1")] right = [ AIMessage(content="Hi there!", id="2"), SystemMessage(content="System message", id="3"), ] result = add_messages(left, right) expected_result = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), SystemMessage(content="System message", id="3"), ] assert result == expected_result def test_update_existing_message(): left = [HumanMessage(content="Hello", id="1")] right = HumanMessage(content="Hello again", id="1") result = add_messages(left, right) expected_result = [HumanMessage(content="Hello again", id="1")] assert result == expected_result def test_missing_ids(): left = [HumanMessage(content="Hello")] right = [AIMessage(content="Hi there!")] result = add_messages(left, right) assert len(result) == 2 assert all(isinstance(m.id, str) and UUID(m.id, version=4) for m in result) def test_remove_message(): left = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] right = RemoveMessage(id="2") result = add_messages(left, right) expected_result = [HumanMessage(content="Hello", id="1")] assert result == expected_result def test_duplicate_remove_message(): left = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] right = [RemoveMessage(id="2"), RemoveMessage(id="2")] result = add_messages(left, right) expected_result = [HumanMessage(content="Hello", id="1")] assert result == expected_result def test_remove_nonexistent_message(): left = [HumanMessage(content="Hello", id="1")] right = RemoveMessage(id="2") with pytest.raises( ValueError, match="Attempting to delete a message with an ID that doesn't exist" ): add_messages(left, right) def test_mixed_operations(): left = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] right = [ HumanMessage(content="Updated hello", id="1"), RemoveMessage(id="2"), SystemMessage(content="New message", id="3"), ] result = 
add_messages(left, right) expected_result = [ HumanMessage(content="Updated hello", id="1"), SystemMessage(content="New message", id="3"), ] assert result == expected_result def test_empty_inputs(): assert add_messages([], []) == [] assert add_messages([], [HumanMessage(content="Hello", id="1")]) == [ HumanMessage(content="Hello", id="1") ] assert add_messages([HumanMessage(content="Hello", id="1")], []) == [ HumanMessage(content="Hello", id="1") ] def test_non_list_inputs(): left = HumanMessage(content="Hello", id="1") right = AIMessage(content="Hi there!", id="2") result = add_messages(left, right) expected_result = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] assert result == expected_result def test_delete_all(): left = [ HumanMessage(content="Hello", id="1"), AIMessage(content="Hi there!", id="2"), ] right = [ RemoveMessage(id="1"), RemoveMessage(id="2"), ] result = add_messages(left, right) expected_result = [] assert result == expected_result MESSAGES_STATE_SCHEMAS = [MessagesState] if IS_LANGCHAIN_CORE_030_OR_GREATER: class MessagesStatePydantic(BaseModel): messages: Annotated[list[AnyMessage], add_messages] MESSAGES_STATE_SCHEMAS.append(MessagesStatePydantic) else: class MessagesStatePydanticV1(BaseModelV1): messages: Annotated[list[AnyMessage], add_messages] MESSAGES_STATE_SCHEMAS.append(MessagesStatePydanticV1) @pytest.mark.parametrize("state_schema", MESSAGES_STATE_SCHEMAS) def test_messages_state(state_schema): def foo(state): return {"messages": [HumanMessage("foo")]} graph = StateGraph(state_schema) graph.add_edge(START, "foo") graph.add_edge("foo", END) graph.add_node(foo) app = graph.compile() assert app.invoke({"messages": [("user", "meow")]}) == { "messages": [ _AnyIdHumanMessage(content="meow"), _AnyIdHumanMessage(content="foo"), ] } ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_channels.py`: ```py import operator from typing import Sequence, Union import pytest from 
langgraph.channels.binop import BinaryOperatorAggregate from langgraph.channels.last_value import LastValue from langgraph.channels.topic import Topic from langgraph.errors import EmptyChannelError, InvalidUpdateError pytestmark = pytest.mark.anyio def test_last_value() -> None: channel = LastValue(int).from_checkpoint(None) assert channel.ValueType is int assert channel.UpdateType is int with pytest.raises(EmptyChannelError): channel.get() with pytest.raises(InvalidUpdateError): channel.update([5, 6]) channel.update([3]) assert channel.get() == 3 channel.update([4]) assert channel.get() == 4 checkpoint = channel.checkpoint() channel = LastValue(int).from_checkpoint(checkpoint) assert channel.get() == 4 def test_topic() -> None: channel = Topic(str).from_checkpoint(None) assert channel.ValueType is Sequence[str] assert channel.UpdateType is Union[str, list[str]] assert channel.update(["a", "b"]) assert channel.get() == ["a", "b"] assert channel.update([["c", "d"], "d"]) assert channel.get() == ["c", "d", "d"] assert channel.update([]) with pytest.raises(EmptyChannelError): channel.get() assert not channel.update([]), "channel already empty" assert channel.update(["e"]) assert channel.get() == ["e"] checkpoint = channel.checkpoint() channel = Topic(str).from_checkpoint(checkpoint) assert channel.get() == ["e"] channel_copy = Topic(str).from_checkpoint(checkpoint) channel_copy.update(["f"]) assert channel_copy.get() == ["f"] assert channel.get() == ["e"] def test_topic_accumulate() -> None: channel = Topic(str, accumulate=True).from_checkpoint(None) assert channel.ValueType is Sequence[str] assert channel.UpdateType is Union[str, list[str]] assert channel.update(["a", "b"]) assert channel.get() == ["a", "b"] assert channel.update(["b", ["c", "d"], "d"]) assert channel.get() == ["a", "b", "b", "c", "d", "d"] assert not channel.update([]) assert channel.get() == ["a", "b", "b", "c", "d", "d"] checkpoint = channel.checkpoint() channel = Topic(str, 
accumulate=True).from_checkpoint(checkpoint) assert channel.get() == ["a", "b", "b", "c", "d", "d"] assert channel.update(["e"]) assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] def test_binop() -> None: channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(None) assert channel.ValueType is int assert channel.UpdateType is int assert channel.get() == 0 channel.update([1, 2, 3]) assert channel.get() == 6 channel.update([4]) assert channel.get() == 10 checkpoint = channel.checkpoint() channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(checkpoint) assert channel.get() == 10 ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_runnable.py`: ```py from __future__ import annotations from typing import Any import pytest from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.runnable import RunnableCallable pytestmark = pytest.mark.anyio def test_runnable_callable_func_accepts(): def sync_func(x: Any) -> str: return f"{x}" async def async_func(x: Any) -> str: return f"{x}" def func_with_store(x: Any, store: BaseStore) -> str: return f"{x}" def func_with_writer(x: Any, writer: StreamWriter) -> str: return f"{x}" async def afunc_with_store(x: Any, store: BaseStore) -> str: return f"{x}" async def afunc_with_writer(x: Any, writer: StreamWriter) -> str: return f"{x}" runnables = { "sync": RunnableCallable(sync_func), "async": RunnableCallable(func=None, afunc=async_func), "with_store": RunnableCallable(func_with_store), "with_writer": RunnableCallable(func_with_writer), "awith_store": RunnableCallable(afunc_with_store), "awith_writer": RunnableCallable(afunc_with_writer), } expected_store = {"with_store": True, "awith_store": True} expected_writer = {"with_writer": True, "awith_writer": True} for name, runnable in runnables.items(): assert runnable.func_accepts["writer"] == expected_writer.get(name, False) assert runnable.func_accepts["store"] == 
expected_store.get(name, False) async def test_runnable_callable_basic(): def sync_func(x: Any) -> str: return f"{x}" async def async_func(x: Any) -> str: return f"{x}" runnable_sync = RunnableCallable(sync_func) runnable_async = RunnableCallable(func=None, afunc=async_func) result_sync = runnable_sync.invoke("test") assert result_sync == "test" # Test asynchronous ainvoke result_async = await runnable_async.ainvoke("test") assert result_async == "test" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_tracing_interops.py`: ```py import json import sys import time from typing import Any, Callable, Tuple, TypedDict, TypeVar from unittest.mock import MagicMock import langsmith as ls import pytest from langchain_core.runnables import RunnableConfig from langchain_core.tracers import LangChainTracer from langgraph.graph import StateGraph pytestmark = pytest.mark.anyio def _get_mock_client(**kwargs: Any) -> ls.Client: mock_session = MagicMock() return ls.Client(session=mock_session, api_key="test", **kwargs) def _get_calls( mock_client: Any, verbs: set[str] = {"POST"}, ) -> list: return [ c for c in mock_client.session.request.mock_calls if c.args and c.args[0] in verbs ] T = TypeVar("T") def wait_for( condition: Callable[[], Tuple[T, bool]], max_sleep_time: int = 10, sleep_time: int = 3, ) -> T: """Wait for a condition to be true.""" start_time = time.time() last_e = None while time.time() - start_time < max_sleep_time: try: res, cond = condition() if cond: return res except Exception as e: last_e = e time.sleep(sleep_time) total_time = time.time() - start_time if last_e is not None: raise last_e raise ValueError(f"Callable did not return within {total_time}") @pytest.mark.skip("This test times out in CI") async def test_nested_tracing(): lt_py_311 = sys.version_info < (3, 11) mock_client = _get_mock_client() class State(TypedDict): value: str @ls.traceable async def some_traceable(content: State): return await child_graph.ainvoke(content) async 
def parent_node(state: State, config: RunnableConfig) -> State: if lt_py_311: result = await some_traceable(state, langsmith_extra={"config": config}) else: result = await some_traceable(state) return {"value": f"parent_{result['value']}"} async def child_node(state: State) -> State: return {"value": f"child_{state['value']}"} child_builder = StateGraph(State) child_builder.add_node(child_node) child_builder.add_edge("__start__", "child_node") child_graph = child_builder.compile().with_config(run_name="child_graph") parent_builder = StateGraph(State) parent_builder.add_node(parent_node) parent_builder.add_edge("__start__", "parent_node") parent_graph = parent_builder.compile() tracer = LangChainTracer(client=mock_client) result = await parent_graph.ainvoke({"value": "input"}, {"callbacks": [tracer]}) assert result == {"value": "parent_child_input"} def get_posts(): post_calls = _get_calls(mock_client, verbs={"POST"}) posts = [p for c in post_calls for p in json.loads(c.kwargs["data"])["post"]] names = [p.get("name") for p in posts] if "child_node" in names: return posts, True return None, False posts = wait_for(get_posts) # If the callbacks weren't propagated correctly, we'd # end up with broken dotted_orders parent_run = next(data for data in posts if data["name"] == "parent_node") child_run = next(data for data in posts if data["name"] == "child_graph") traceable_run = next(data for data in posts if data["name"] == "some_traceable") assert child_run["dotted_order"].startswith(traceable_run["dotted_order"]) assert traceable_run["dotted_order"].startswith(parent_run["dotted_order"]) assert child_run["parent_run_id"] == traceable_run["id"] assert traceable_run["parent_run_id"] == parent_run["id"] assert parent_run["trace_id"] == child_run["trace_id"] == traceable_run["trace_id"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/messages.py`: ```py """Redefined messages as a work-around for pydantic issue with AnyStr. 
The code below creates version of pydantic models that will work in unit tests with AnyStr as id field Please note that the `id` field is assigned AFTER the model is created to workaround an issue with pydantic ignoring the __eq__ method on subclassed strings. """ from typing import Any from langchain_core.documents import Document from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage from tests.any_str import AnyStr def _AnyIdDocument(**kwargs: Any) -> Document: """Create a document with an id field.""" message = Document(**kwargs) message.id = AnyStr() return message def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" message = AIMessage(**kwargs) message.id = AnyStr() return message def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: """Create ai message with an any id field.""" message = AIMessageChunk(**kwargs) message.id = AnyStr() return message def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: """Create a human message with an any id field.""" message = HumanMessage(**kwargs) message.id = AnyStr() return message def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage: """Create a tool message with an any id field.""" message = ToolMessage(**kwargs) message.id = AnyStr() return message ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_pregel.py`: ```py import enum import json import logging import operator import re import time import uuid import warnings from collections import Counter from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from dataclasses import replace from random import randrange from typing import ( Annotated, Any, Dict, Generator, Iterator, List, Literal, Optional, Sequence, Tuple, TypedDict, Union, cast, get_type_hints, ) import httpx import pytest from langchain_core.runnables import ( RunnableConfig, RunnableLambda, RunnableMap, RunnablePassthrough, RunnablePick, ) from langsmith import 
traceable from pydantic import BaseModel from pytest_mock import MockerFixture from syrupy import SnapshotAssertion from langgraph.channels.base import BaseChannel from langgraph.channels.binop import BinaryOperatorAggregate from langgraph.channels.context import Context from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.last_value import LastValue from langgraph.channels.topic import Topic from langgraph.channels.untracked_value import UntrackedValue from langgraph.checkpoint.base import ( BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuple, ) from langgraph.checkpoint.memory import MemorySaver from langgraph.constants import ( CONFIG_KEY_NODE_FINISHED, ERROR, FF_SEND_V2, PULL, PUSH, START, ) from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt from langgraph.graph import END, Graph, StateGraph from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor from langgraph.prebuilt.tool_node import ToolNode from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot from langgraph.pregel.retry import RetryPolicy from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore from langgraph.types import ( Command, Interrupt, PregelTask, Send, StreamWriter, interrupt, ) from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence from tests.conftest import ( ALL_CHECKPOINTERS_SYNC, ALL_STORES_SYNC, SHOULD_CHECK_SNAPSHOTS, ) from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.memory_assert import MemorySaverAssertCheckpointMetadata from tests.messages import ( _AnyIdAIMessage, _AnyIdAIMessageChunk, _AnyIdHumanMessage, _AnyIdToolMessage, ) logger = logging.getLogger(__name__) # define these objects to avoid importing 
langchain_core.agents # and therefore avoid relying on core Pydantic version class AgentAction(BaseModel): tool: str tool_input: Union[str, dict] log: str type: Literal["AgentAction"] = "AgentAction" model_config = { "json_schema_extra": { "description": ( """Represents a request to execute an action by an agent. The action consists of the name of the tool to execute and the input to pass to the tool. The log is used to pass along extra information about the action.""" ) } } class AgentFinish(BaseModel): """Final return value of an ActionAgent. Agents return an AgentFinish when they have reached a stopping condition. """ return_values: dict log: str type: Literal["AgentFinish"] = "AgentFinish" model_config = { "json_schema_extra": { "description": ( """Final return value of an ActionAgent. Agents return an AgentFinish when they have reached a stopping condition.""" ) } } def test_graph_validation() -> None: def logic(inp: str) -> str: return "" workflow = Graph() workflow.add_node("agent", logic) workflow.set_entry_point("agent") workflow.set_finish_point("agent") assert workflow.compile(), "valid graph" # Accept a dead-end workflow = Graph() workflow.add_node("agent", logic) workflow.set_entry_point("agent") workflow.compile() workflow = Graph() workflow.add_node("agent", logic) workflow.set_finish_point("agent") with pytest.raises(ValueError, match="must have an entrypoint"): workflow.compile() workflow = Graph() workflow.add_node("agent", logic) workflow.add_node("tools", logic) workflow.set_entry_point("agent") workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END}) workflow.add_edge("tools", "agent") assert workflow.compile(), "valid graph" workflow = Graph() workflow.add_node("agent", logic) workflow.add_node("tools", logic) workflow.set_entry_point("tools") workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END}) workflow.add_edge("tools", "agent") assert workflow.compile(), "valid graph" workflow = 
Graph() workflow.set_entry_point("tools") workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END}) workflow.add_edge("tools", "agent") workflow.add_node("agent", logic) workflow.add_node("tools", logic) assert workflow.compile(), "valid graph" workflow = Graph() workflow.set_entry_point("tools") workflow.add_conditional_edges( "agent", logic, {"continue": "tools", "exit": END, "hmm": "extra"} ) workflow.add_edge("tools", "agent") workflow.add_node("agent", logic) workflow.add_node("tools", logic) with pytest.raises(ValueError, match="unknown"): # extra is not defined workflow.compile() workflow = Graph() workflow.set_entry_point("agent") workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END}) workflow.add_edge("tools", "extra") workflow.add_node("agent", logic) workflow.add_node("tools", logic) with pytest.raises(ValueError, match="unknown"): # extra is not defined workflow.compile() workflow = Graph() workflow.add_node("agent", logic) workflow.add_node("tools", logic) workflow.add_node("extra", logic) workflow.set_entry_point("agent") workflow.add_conditional_edges("agent", logic) workflow.add_edge("tools", "agent") # Accept, even though extra is dead-end workflow.compile() class State(TypedDict): hello: str def node_a(state: State) -> State: # typo return {"hell": "world"} builder = StateGraph(State) builder.add_node("a", node_a) builder.set_entry_point("a") builder.set_finish_point("a") graph = builder.compile() with pytest.raises(InvalidUpdateError): graph.invoke({"hello": "there"}) graph = StateGraph(State) graph.add_node("start", lambda x: x) graph.add_edge("__start__", "start") graph.add_edge("unknown", "start") graph.add_edge("start", "__end__") with pytest.raises(ValueError, match="Found edge starting at unknown node "): graph.compile() def bad_reducer(a): ... 
class BadReducerState(TypedDict): hello: Annotated[str, bad_reducer] with pytest.raises(ValueError, match="Invalid reducer"): StateGraph(BadReducerState) def node_b(state: State) -> State: return {"hello": "world"} builder = StateGraph(State) builder.add_node("a", node_b) builder.add_node("b", node_b) builder.add_node("c", node_b) builder.set_entry_point("a") builder.add_edge("a", "b") builder.add_edge("a", "c") graph = builder.compile() with pytest.raises(InvalidUpdateError, match="At key 'hello'"): graph.invoke({"hello": "there"}) def test_graph_validation_with_command() -> None: class State(TypedDict): foo: str bar: str def node_a(state: State): return Command(goto="b", update={"foo": "bar"}) def node_b(state: State): return Command(goto=END, update={"bar": "baz"}) builder = StateGraph(State) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_edge(START, "a") graph = builder.compile() assert graph.invoke({"foo": ""}) == {"foo": "bar", "bar": "baz"} def test_checkpoint_errors() -> None: class FaultyGetCheckpointer(MemorySaver): def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: raise ValueError("Faulty get_tuple") class FaultyPutCheckpointer(MemorySaver): def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: Optional[dict[str, Union[str, int, float]]] = None, ) -> RunnableConfig: raise ValueError("Faulty put") class FaultyPutWritesCheckpointer(MemorySaver): def put_writes( self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str ) -> RunnableConfig: raise ValueError("Faulty put_writes") class FaultyVersionCheckpointer(MemorySaver): def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int: raise ValueError("Faulty get_next_version") def logic(inp: str) -> str: return "" builder = StateGraph(Annotated[str, operator.add]) builder.add_node("agent", logic) builder.add_edge(START, "agent") graph = 
builder.compile(checkpointer=FaultyGetCheckpointer()) with pytest.raises(ValueError, match="Faulty get_tuple"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) graph = builder.compile(checkpointer=FaultyPutCheckpointer()) with pytest.raises(ValueError, match="Faulty put"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) graph = builder.compile(checkpointer=FaultyVersionCheckpointer()) with pytest.raises(ValueError, match="Faulty get_next_version"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) # add parallel node builder.add_node("parallel", logic) builder.add_edge(START, "parallel") graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer()) with pytest.raises(ValueError, match="Faulty put_writes"): graph.invoke("", {"configurable": {"thread_id": "thread-1"}}) def test_node_schemas_custom_output() -> None: class State(TypedDict): hello: str bye: str messages: Annotated[list[str], add_messages] class Output(TypedDict): messages: list[str] class StateForA(TypedDict): hello: str messages: Annotated[list[str], add_messages] def node_a(state: StateForA) -> State: assert state == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } class StateForB(TypedDict): bye: str now: int def node_b(state: StateForB): assert state == { "bye": "world", } return { "now": 123, "hello": "again", } class StateForC(TypedDict): hello: str now: int def node_c(state: StateForC) -> StateForC: assert state == { "hello": "again", "now": 123, } builder = StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_node("c", node_c) builder.add_edge(START, "a") builder.add_edge("a", "b") builder.add_edge("b", "c") graph = builder.compile() assert graph.invoke({"hello": "there", "bye": "world", "messages": "hello"}) == { "messages": [_AnyIdHumanMessage(content="hello")], } builder = StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) 
builder.add_node("c", node_c) builder.add_edge(START, "a") builder.add_edge("a", "b") builder.add_edge("b", "c") graph = builder.compile() assert graph.invoke( { "hello": "there", "bye": "world", "messages": "hello", "now": 345, # ignored because not in input schema } ) == { "messages": [_AnyIdHumanMessage(content="hello")], } assert [ c for c in graph.stream( { "hello": "there", "bye": "world", "messages": "hello", "now": 345, # ignored because not in input schema } ) ] == [ {"a": None}, {"b": {"hello": "again", "now": 123}}, {"c": None}, ] def test_reducer_before_first_node() -> None: class State(TypedDict): hello: str messages: Annotated[list[str], add_messages] def node_a(state: State) -> State: assert state == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } builder = StateGraph(State) builder.add_node("a", node_a) builder.set_entry_point("a") builder.set_finish_point("a") graph = builder.compile() assert graph.invoke({"hello": "there", "messages": "hello"}) == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } class State(TypedDict): hello: str messages: Annotated[List[str], add_messages] def node_a(state: State) -> State: assert state == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } builder = StateGraph(State) builder.add_node("a", node_a) builder.set_entry_point("a") builder.set_finish_point("a") graph = builder.compile() assert graph.invoke({"hello": "there", "messages": "hello"}) == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } class State(TypedDict): hello: str messages: Annotated[Sequence[str], add_messages] def node_a(state: State) -> State: assert state == { "hello": "there", "messages": [_AnyIdHumanMessage(content="hello")], } builder = StateGraph(State) builder.add_node("a", node_a) builder.set_entry_point("a") builder.set_finish_point("a") graph = builder.compile() assert graph.invoke({"hello": "there", "messages": "hello"}) == { "hello": "there", 
"messages": [_AnyIdHumanMessage(content="hello")], } def test_invoke_single_process_in_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={ "one": chain, }, channels={ "input": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", ) graph = Graph() graph.add_node("add_one", add_one) graph.set_entry_point("add_one") graph.set_finish_point("add_one") gapp = graph.compile() if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "integer", } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "integer", } with warnings.catch_warnings(): warnings.simplefilter("error") # raise warnings as errors assert app.config_schema().model_json_schema() == { "properties": {}, "title": "LangGraphConfig", "type": "object", } assert app.invoke(2) == 3 assert app.invoke(2, output_keys=["output"]) == {"output": 3} assert repr(app), "does not raise recursion error" assert gapp.invoke(2, debug=True) == 3 @pytest.mark.parametrize( "falsy_value", [None, False, 0, "", [], {}, set(), frozenset(), 0.0, 0j], ) def test_invoke_single_process_in_out_falsy_values(falsy_value: Any) -> None: graph = Graph() graph.add_node("return_falsy_const", lambda *args, **kwargs: falsy_value) graph.set_entry_point("return_falsy_const") graph.set_finish_point("return_falsy_const") gapp = graph.compile() assert gapp.invoke(1) == falsy_value def test_invoke_single_process_in_write_kwargs(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = ( Channel.subscribe_to("input") | add_one | Channel.write_to("output", fixed=5, output_plus_one=lambda x: x + 1) ) app = Pregel( nodes={"one": chain}, channels={ "input": LastValue(int), "output": LastValue(int), "fixed": LastValue(int), "output_plus_one": LastValue(int), }, 
output_channels=["output", "fixed", "output_plus_one"], input_channels="input", ) if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "integer", } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "object", "properties": { "output": {"title": "Output", "type": "integer", "default": None}, "fixed": {"title": "Fixed", "type": "integer", "default": None}, "output_plus_one": { "title": "Output Plus One", "type": "integer", "default": None, }, }, } assert app.invoke(2) == {"output": 3, "fixed": 5, "output_plus_one": 4} def test_invoke_single_process_in_out_dict(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": chain}, channels={"input": LastValue(int), "output": LastValue(int)}, input_channels="input", output_channels=["output"], ) if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "integer", } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "object", "properties": { "output": {"title": "Output", "type": "integer", "default": None} }, } assert app.invoke(2) == {"output": 3} def test_invoke_single_process_in_dict_out_dict(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": chain}, channels={"input": LastValue(int), "output": LastValue(int)}, input_channels=["input"], output_channels=["output"], ) if SHOULD_CHECK_SNAPSHOTS: assert app.input_schema.model_json_schema() == { "title": "LangGraphInput", "type": "object", "properties": { "input": {"title": "Input", "type": "integer", "default": None} }, } assert app.output_schema.model_json_schema() == { "title": "LangGraphOutput", "type": "object", "properties": { 
"output": {"title": "Output", "type": "integer", "default": None} }, } assert app.invoke({"input": 2}) == {"output": 3} def test_invoke_two_processes_in_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={ "inbox": LastValue(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) assert app.invoke(2) == 4 with pytest.raises(GraphRecursionError): app.invoke(2, {"recursion_limit": 1}, debug=1) graph = Graph() graph.add_node("add_one", add_one) graph.add_node("add_one_more", add_one) graph.set_entry_point("add_one") graph.set_finish_point("add_one_more") graph.add_edge("add_one", "add_one_more") gapp = graph.compile() assert gapp.invoke(2) == 4 for step, values in enumerate(gapp.stream(2, debug=1), start=1): if step == 1: assert values == { "add_one": 3, } elif step == 2: assert values == { "add_one_more": 4, } else: assert 0, f"{step}:{values}" assert step == 2 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_two_processes_in_out_interrupt( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={ "inbox": LastValue(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", checkpointer=checkpointer, interrupt_after_nodes=["one"], ) thread1 = {"configurable": {"thread_id": "1"}} thread2 = {"configurable": {"thread_id": "2"}} # start 
execution, stop at inbox assert app.invoke(2, thread1) is None # inbox == 3 checkpoint = checkpointer.get(thread1) assert checkpoint is not None assert checkpoint["channel_values"]["inbox"] == 3 # resume execution, finish assert app.invoke(None, thread1) == 4 # start execution again, stop at inbox assert app.invoke(20, thread1) is None # inbox == 21 checkpoint = checkpointer.get(thread1) assert checkpoint is not None assert checkpoint["channel_values"]["inbox"] == 21 # send a new value in, interrupting the previous execution assert app.invoke(3, thread1) is None assert app.invoke(None, thread1) == 5 # start execution again, stopping at inbox assert app.invoke(20, thread2) is None # inbox == 21 snapshot = app.get_state(thread2) assert snapshot.values["inbox"] == 21 assert snapshot.next == ("two",) # update the state, resume app.update_state(thread2, 25, as_node="one") assert app.invoke(None, thread2) == 26 # no pending tasks snapshot = app.get_state(thread2) assert snapshot.next == () # list history history = [c for c in app.get_state_history(thread1)] assert history == [ StateSnapshot( values={"inbox": 4, "output": 5, "input": 3}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 6, "writes": {"two": 5}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[1].config, ), StateSnapshot( values={"inbox": 4, "output": 4, "input": 3}, tasks=(PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 5}),), next=("two",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 5, "writes": {"one": None}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[2].config, ), StateSnapshot( values={"inbox": 21, "output": 4, "input": 3}, tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 4}),), next=("one",), config={ 
"configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "step": 4, "writes": {"input": 3}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[3].config, ), StateSnapshot( values={"inbox": 21, "output": 4, "input": 20}, tasks=(PregelTask(AnyStr(), "two", (PULL, "two")),), next=("two",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"one": None}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[4].config, ), StateSnapshot( values={"inbox": 3, "output": 4, "input": 20}, tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 21}),), next=("one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "step": 2, "writes": {"input": 20}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[5].config, ), StateSnapshot( values={"inbox": 3, "output": 4, "input": 2}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"two": 4}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[6].config, ), StateSnapshot( values={"inbox": 3, "input": 2}, tasks=(PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 4}),), next=("two",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 0, "writes": {"one": None}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[7].config, ), StateSnapshot( values={"input": 2}, tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 3}),), next=("one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, 
metadata={ "parents": {}, "source": "input", "step": -1, "writes": {"input": 2}, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] # re-running from any previous checkpoint should re-run nodes assert [c for c in app.stream(None, history[0].config, stream_mode="updates")] == [] assert [c for c in app.stream(None, history[1].config, stream_mode="updates")] == [ {"two": {"output": 5}}, ] assert [c for c in app.stream(None, history[2].config, stream_mode="updates")] == [ {"one": {"inbox": 4}}, {"__interrupt__": ()}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_fork_always_re_runs_nodes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") add_one = mocker.Mock(side_effect=lambda _: 1) builder = StateGraph(Annotated[int, operator.add]) builder.add_node("add_one", add_one) builder.add_edge(START, "add_one") builder.add_conditional_edges("add_one", lambda cnt: "add_one" if cnt < 6 else END) graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} # start execution, stop at inbox assert [*graph.stream(1, thread1, stream_mode=["values", "updates"])] == [ ("values", 1), ("updates", {"add_one": 1}), ("values", 2), ("updates", {"add_one": 1}), ("values", 3), ("updates", {"add_one": 1}), ("values", 4), ("updates", {"add_one": 1}), ("values", 5), ("updates", {"add_one": 1}), ("values", 6), ] # list history history = [c for c in graph.get_state_history(thread1)] assert history == [ StateSnapshot( values=6, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 5, "writes": {"add_one": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[1].config, ), StateSnapshot( values=5, tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),), 
next=("add_one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 4, "writes": {"add_one": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[2].config, ), StateSnapshot( values=4, tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),), next=("add_one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"add_one": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[3].config, ), StateSnapshot( values=3, tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),), next=("add_one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 2, "writes": {"add_one": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[4].config, ), StateSnapshot( values=2, tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),), next=("add_one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"add_one": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[5].config, ), StateSnapshot( values=1, tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),), next=("add_one",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, created_at=AnyStr(), parent_config=history[6].config, ), StateSnapshot( values=0, tasks=(PregelTask(AnyStr(), "__start__", (PULL, "__start__"), result=1),), next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, 
"source": "input", "step": -1, "writes": {"__start__": 1}, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] # forking from any previous checkpoint should re-run nodes assert [ c for c in graph.stream(None, history[0].config, stream_mode="updates") ] == [] assert [ c for c in graph.stream(None, history[1].config, stream_mode="updates") ] == [ {"add_one": 1}, ] assert [ c for c in graph.stream(None, history[2].config, stream_mode="updates") ] == [ {"add_one": 1}, {"add_one": 1}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class MyState(TypedDict): myval: Annotated[int, operator.add] otherval: bool class Anode: def __init__(self): self.switch = False def __call__(self, state: MyState): self.switch = not self.switch return {"myval": 2 if self.switch else 1, "otherval": self.switch} builder = StateGraph(MyState) thenode = Anode() # Fun. 
builder.add_node("node_one", thenode) builder.add_node("node_two", thenode) builder.add_edge(START, "node_one") def _getedge(src: str): swap = "node_one" if src == "node_two" else "node_two" def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]: if st["myval"] > 3: return END if st["otherval"]: return swap return src return _edge builder.add_conditional_edges("node_one", _getedge("node_one")) builder.add_conditional_edges("node_two", _getedge("node_two")) graph = builder.compile(checkpointer=checkpointer) thread_id = uuid.uuid4() thread1 = {"configurable": {"thread_id": str(thread_id)}} result = graph.invoke({"myval": 1}, thread1) assert result["myval"] == 4 history = [c for c in graph.get_state_history(thread1)] assert len(history) == 4 assert history[-1].values == {"myval": 0} assert history[0].values == {"myval": 4, "otherval": False} second_run_config = { **thread1, "configurable": { **thread1["configurable"], "checkpoint_id": history[1].config["configurable"]["checkpoint_id"], }, } second_result = graph.invoke(None, second_run_config) assert second_result == {"myval": 5, "otherval": True} new_history = [ c for c in graph.get_state_history( {"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}} ) ] assert len(new_history) == len(history) + 1 for original, new in zip(history, new_history[1:]): assert original.values == new.values assert original.next == new.next assert original.metadata["step"] == new.metadata["step"] def _get_tasks(hist: list, start: int): return [h.tasks for h in hist[start:]] assert _get_tasks(new_history, 1) == _get_tasks(history, 0) def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") two = ( Channel.subscribe_to("inbox") | RunnableLambda(add_one).batch | RunnablePassthrough(lambda _: time.sleep(0.1)) | Channel.write_to("output").batch ) app = Pregel( nodes={"one": 
one, "two": two}, channels={ "inbox": Topic(int), "output": LastValue(int), "input": LastValue(int), }, input_channels=["input", "inbox"], stream_channels=["output", "inbox"], output_channels=["output"], ) # [12 + 1, 2 + 1 + 1] assert [ *app.stream( {"input": 2, "inbox": 12}, output_keys="output", stream_mode="updates" ) ] == [ {"one": None}, {"two": 13}, {"two": 4}, ] assert [*app.stream({"input": 2, "inbox": 12}, output_keys="output")] == [ 13, 4, ] assert [*app.stream({"input": 2, "inbox": 12}, stream_mode="updates")] == [ {"one": {"inbox": 3}}, {"two": {"output": 13}}, {"two": {"output": 4}}, ] assert [*app.stream({"input": 2, "inbox": 12})] == [ {"inbox": [3], "output": 13}, {"output": 4}, ] assert [*app.stream({"input": 2, "inbox": 12}, stream_mode="debug")] == [ { "type": "task", "timestamp": AnyStr(), "step": 0, "payload": { "id": AnyStr(), "name": "one", "input": 2, "triggers": ["input"], }, }, { "type": "task", "timestamp": AnyStr(), "step": 0, "payload": { "id": AnyStr(), "name": "two", "input": [12], "triggers": ["inbox"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 0, "payload": { "id": AnyStr(), "name": "one", "result": [("inbox", 3)], "error": None, "interrupts": [], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 0, "payload": { "id": AnyStr(), "name": "two", "result": [("output", 13)], "error": None, "interrupts": [], }, }, { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "two", "input": [3], "triggers": ["inbox"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "two", "result": [("output", 4)], "error": None, "interrupts": [], }, }, ] def test_batch_two_processes_in_out() -> None: def add_one_with_delay(inp: int) -> int: time.sleep(inp / 10) return inp + 1 one = Channel.subscribe_to("input") | add_one_with_delay | Channel.write_to("one") two = Channel.subscribe_to("one") | add_one_with_delay | Channel.write_to("output") app 
= Pregel( nodes={"one": one, "two": two}, channels={ "one": LastValue(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) assert app.batch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7] assert app.batch([3, 2, 1, 3, 5], output_keys=["output"]) == [ {"output": 5}, {"output": 4}, {"output": 3}, {"output": 5}, {"output": 7}, ] graph = Graph() graph.add_node("add_one", add_one_with_delay) graph.add_node("add_one_more", add_one_with_delay) graph.set_entry_point("add_one") graph.set_finish_point("add_one_more") graph.add_edge("add_one", "add_one_more") gapp = graph.compile() assert gapp.batch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7] def test_invoke_many_processes_in_out(mocker: MockerFixture) -> None: test_size = 100 add_one = mocker.Mock(side_effect=lambda x: x + 1) nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")} for i in range(test_size - 2): nodes[str(i)] = ( Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i)) ) nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output") app = Pregel( nodes=nodes, channels={str(i): LastValue(int) for i in range(-1, test_size - 2)} | {"input": LastValue(int), "output": LastValue(int)}, input_channels="input", output_channels="output", ) for _ in range(10): assert app.invoke(2, {"recursion_limit": test_size}) == 2 + test_size with ThreadPoolExecutor() as executor: assert [ *executor.map(app.invoke, [2] * 10, [{"recursion_limit": test_size}] * 10) ] == [2 + test_size] * 10 def test_batch_many_processes_in_out(mocker: MockerFixture) -> None: test_size = 100 add_one = mocker.Mock(side_effect=lambda x: x + 1) nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")} for i in range(test_size - 2): nodes[str(i)] = ( Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i)) ) nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output") app = Pregel( nodes=nodes, 
channels={str(i): LastValue(int) for i in range(-1, test_size - 2)} | {"input": LastValue(int), "output": LastValue(int)}, input_channels="input", output_channels="output", ) for _ in range(3): assert app.batch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) == [ 2 + test_size, 1 + test_size, 3 + test_size, 4 + test_size, 5 + test_size, ] with ThreadPoolExecutor() as executor: assert [ *executor.map( app.batch, [[2, 1, 3, 4, 5]] * 3, [{"recursion_limit": test_size}] * 3 ) ] == [ [2 + test_size, 1 + test_size, 3 + test_size, 4 + test_size, 5 + test_size] ] * 3 def test_invoke_two_processes_two_in_two_out_invalid(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("output") two = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={"output": LastValue(int), "input": LastValue(int)}, input_channels="input", output_channels="output", ) with pytest.raises(InvalidUpdateError): # LastValue channels can only be updated once per iteration app.invoke(2) class State(TypedDict): hello: str def my_node(input: State) -> State: return {"hello": "world"} builder = StateGraph(State) builder.add_node("one", my_node) builder.add_node("two", my_node) builder.set_conditional_entry_point(lambda _: ["one", "two"]) graph = builder.compile() with pytest.raises(InvalidUpdateError, match="At key 'hello'"): graph.invoke({"hello": "there"}, debug=True) def test_invoke_two_processes_two_in_two_out_valid(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("output") two = Channel.subscribe_to("input") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "output": Topic(int), }, input_channels="input", output_channels="output", ) # An Inbox channel accumulates updates 
into a sequence assert app.invoke(2) == [3, 3] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_checkpoint_two( mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) errored_once = False def raise_if_above_10(input: int) -> int: nonlocal errored_once if input > 4: if errored_once: pass else: errored_once = True raise ConnectionError("I will be retried") if input > 10: raise ValueError("Input is too large") return input one = ( Channel.subscribe_to(["input"]).join(["total"]) | add_one | Channel.write_to("output", "total") | raise_if_above_10 ) app = Pregel( nodes={"one": one}, channels={ "total": BinaryOperatorAggregate(int, operator.add), "input": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", checkpointer=checkpointer, retry_policy=RetryPolicy(), ) # total starts out as 0, so output is 0+2=2 assert app.invoke(2, {"configurable": {"thread_id": "1"}}) == 2 checkpoint = checkpointer.get({"configurable": {"thread_id": "1"}}) assert checkpoint is not None assert checkpoint["channel_values"].get("total") == 2 # total is now 2, so output is 2+3=5 assert app.invoke(3, {"configurable": {"thread_id": "1"}}) == 5 assert errored_once, "errored and retried" checkpoint_tup = checkpointer.get_tuple({"configurable": {"thread_id": "1"}}) assert checkpoint_tup is not None assert checkpoint_tup.checkpoint["channel_values"].get("total") == 7 # total is now 2+5=7, so output would be 7+4=11, but raises ValueError with pytest.raises(ValueError): app.invoke(4, {"configurable": {"thread_id": "1"}}) # checkpoint is not updated, error is recorded checkpoint_tup = checkpointer.get_tuple({"configurable": {"thread_id": "1"}}) assert checkpoint_tup is not None assert 
checkpoint_tup.checkpoint["channel_values"].get("total") == 7 assert checkpoint_tup.pending_writes == [ (AnyStr(), ERROR, "ValueError('Input is too large')") ] # on a new thread, total starts out as 0, so output is 0+5=5 assert app.invoke(5, {"configurable": {"thread_id": "2"}}) == 5 checkpoint = checkpointer.get({"configurable": {"thread_id": "1"}}) assert checkpoint is not None assert checkpoint["channel_values"].get("total") == 7 checkpoint = checkpointer.get({"configurable": {"thread_id": "2"}}) assert checkpoint is not None assert checkpoint["channel_values"].get("total") == 5 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_pending_writes_resume( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) class State(TypedDict): value: Annotated[int, operator.add] class AwhileMaker: def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None: self.sleep = sleep self.rtn = rtn self.reset() def __call__(self, input: State) -> Any: self.calls += 1 time.sleep(self.sleep) if isinstance(self.rtn, Exception): raise self.rtn else: return self.rtn def reset(self): self.calls = 0 one = AwhileMaker(0.1, {"value": 2}) two = AwhileMaker(0.3, ConnectionError("I'm not good")) builder = StateGraph(State) builder.add_node("one", one) builder.add_node("two", two, retry=RetryPolicy(max_attempts=2)) builder.add_edge(START, "one") builder.add_edge(START, "two") graph = builder.compile(checkpointer=checkpointer) thread1: RunnableConfig = {"configurable": {"thread_id": "1"}} with pytest.raises(ConnectionError, match="I'm not good"): graph.invoke({"value": 1}, thread1) # both nodes should have been called once assert one.calls == 1 assert two.calls == 2 # two attempts # latest checkpoint should be before nodes "one", "two" # but we should have applied the write from "one" state = graph.get_state(thread1) assert state is not None assert 
state.values == {"value": 3} assert state.next == ("two",) assert state.tasks == ( PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}), PregelTask(AnyStr(), "two", (PULL, "two"), 'ConnectionError("I\'m not good")'), ) assert state.metadata == { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", } # get_state with checkpoint_id should not apply any pending writes state = graph.get_state(state.config) assert state is not None assert state.values == {"value": 1} assert state.next == ("one", "two") # should contain pending write of "one" checkpoint = checkpointer.get_tuple(thread1) assert checkpoint is not None # should contain error from "two" expected_writes = [ (AnyStr(), "one", "one"), (AnyStr(), "value", 2), (AnyStr(), ERROR, 'ConnectionError("I\'m not good")'), ] assert len(checkpoint.pending_writes) == 3 assert all(w in expected_writes for w in checkpoint.pending_writes) # both non-error pending writes come from same task non_error_writes = [w for w in checkpoint.pending_writes if w[1] != ERROR] assert non_error_writes[0][0] == non_error_writes[1][0] # error write is from the other task error_write = next(w for w in checkpoint.pending_writes if w[1] == ERROR) assert error_write[0] != non_error_writes[0][0] # resume execution with pytest.raises(ConnectionError, match="I'm not good"): graph.invoke(None, thread1) # node "one" succeeded previously, so shouldn't be called again assert one.calls == 1 # node "two" should have been called once again assert two.calls == 4 # two attempts before + two attempts now # confirm no new checkpoints saved state_two = graph.get_state(thread1) assert state_two.metadata == state.metadata # resume execution, without exception two.rtn = {"value": 3} # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert graph.invoke(None, thread1) == {"value": 6} # check all final checkpoints checkpoints = [c for c in checkpointer.list(thread1)] # we should have 3 assert len(checkpoints) == 
3 # the last one not too interesting for this test assert checkpoints[0] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": { "one": { "start:one": AnyVersion(), }, "two": { "start:two": AnyVersion(), }, "__input__": {}, "__start__": { "__start__": AnyVersion(), }, "__interrupt__": { "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, }, "channel_versions": { "one": AnyVersion(), "two": AnyVersion(), "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, "channel_values": {"one": "one", "two": "two", "value": 6}, }, metadata={ "parents": {}, "step": 1, "source": "loop", "writes": {"one": {"value": 2}, "two": {"value": 3}}, "thread_id": "1", }, parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": checkpoints[1].config["configurable"]["checkpoint_id"], } }, pending_writes=[], ) # the previous one we assert that pending writes contains both # - original error # - successful writes from resuming after preventing error assert checkpoints[1] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": { "__input__": {}, "__start__": { "__start__": AnyVersion(), }, }, "channel_versions": { "value": AnyVersion(), "__start__": AnyVersion(), "start:one": AnyVersion(), "start:two": AnyVersion(), }, "channel_values": { "value": 1, "start:one": "__start__", "start:two": "__start__", }, }, metadata={ "parents": {}, "step": 0, "source": "loop", "writes": None, "thread_id": "1", }, parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": checkpoints[2].config["configurable"]["checkpoint_id"], } }, 
pending_writes=UnsortedSequence( (AnyStr(), "one", "one"), (AnyStr(), "value", 2), (AnyStr(), "__error__", 'ConnectionError("I\'m not good")'), (AnyStr(), "two", "two"), (AnyStr(), "value", 3), ), ) assert checkpoints[2] == CheckpointTuple( config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, checkpoint={ "v": 1, "id": AnyStr(), "ts": AnyStr(), "pending_sends": [], "versions_seen": {"__input__": {}}, "channel_versions": { "__start__": AnyVersion(), }, "channel_values": {"__start__": {"value": 1}}, }, metadata={ "parents": {}, "step": -1, "source": "input", "writes": {"__start__": {"value": 1}}, "thread_id": "1", }, parent_config=None, pending_writes=UnsortedSequence( (AnyStr(), "value", 1), (AnyStr(), "start:one", "__start__"), (AnyStr(), "start:two", "__start__"), ), ) def test_cond_edge_after_send() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) def __call__(self, state): return [self.name] def send_for_fun(state): return [Send("2", state), Send("2", state)] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"] def test_concurrent_emit_sends() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) def __call__(self, state): return ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) def send_for_fun(state): return [Send("2", 1), Send("2", 2), "3.1"] def send_for_profit(state): return [Send("2", 3), Send("2", 4)] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) 
builder.add_node(Node("1.1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_edge(START, "1") builder.add_edge(START, "1.1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("1.1", send_for_profit) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert graph.invoke(["0"]) == ( [ "0", "1", "1.1", "2|1", "2|2", "2|3", "2|4", "3", "3.1", ] if FF_SEND_V2 else [ "0", "1", "1.1", "3.1", "2|1", "2|2", "2|3", "2|4", "3", ] ) def test_send_sequences() -> None: class Node: def __init__(self, name: str): self.name = name setattr(self, "__name__", name) def __call__(self, state): update = ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return replace(state, update=update) else: return update def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("2", 4))), "3.1", ] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile() assert ( graph.invoke(["0"]) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", "3.1", ] if FF_SEND_V2 else [ "0", "1", "3.1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='2', arg=4))", "3", "2|3", "2|4", "3", ] ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_send_dedupe_on_resume( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: if not FF_SEND_V2: pytest.skip("Send deduplication is only available in Send V2") checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class 
InterruptOnce: ticks: int = 0 def __call__(self, state): self.ticks += 1 if self.ticks == 1: raise NodeInterrupt("Bahh") return ["|".join(("flaky", str(state)))] class Node: def __init__(self, name: str): self.name = name self.ticks = 0 setattr(self, "__name__", name) def __call__(self, state): self.ticks += 1 update = ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return replace(state, update=update) else: return update def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("flaky", 4))), "3.1", ] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_node("flaky", InterruptOnce()) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert graph.invoke(["0"], thread1, debug=1) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", ] assert builder.nodes["2"].runnable.func.ticks == 3 assert builder.nodes["flaky"].runnable.func.ticks == 1 # check state state = graph.get_state(thread1) assert state.next == ("flaky",) # check history history = [c for c in graph.get_state_history(thread1)] assert len(history) == 2 # resume execution assert graph.invoke(None, thread1, debug=1) == [ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ] # node "2" doesn't get called again, as we recover writes saved before assert builder.nodes["2"].runnable.func.ticks == 3 # node "flaky" gets called again, as it was interrupted assert builder.nodes["flaky"].runnable.func.ticks == 2 # check state state = 
graph.get_state(thread1) assert state.next == () # check history history = [c for c in graph.get_state_history(thread1)] assert ( history[1] == [ StateSnapshot( values=[ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ], next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": {"3": ["3"], "3.1": ["3.1"]}, "thread_id": "1", "step": 2, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ), StateSnapshot( values=[ "0", "1", "2|Command(goto=Send(node='2', arg=3))", "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", ], next=("3", "3.1"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": { "1": ["1"], "2": [ ["2|Command(goto=Send(node='2', arg=3))"], ["2|Command(goto=Send(node='flaky', arg=4))"], ["2|3"], ], "flaky": ["flaky|4"], }, "thread_id": "1", "step": 1, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="3", path=("__pregel_pull", "3"), error=None, interrupts=(), state=None, result=["3"], ), PregelTask( id=AnyStr(), name="3.1", path=("__pregel_pull", "3.1"), error=None, interrupts=(), state=None, result=["3.1"], ), ), ), StateSnapshot( values=["0"], next=("1", "2", "2", "2", "flaky"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": None, "thread_id": "1", "step": 0, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="1", path=("__pregel_pull", "1"), 
error=None, interrupts=(), state=None, result=["1"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ("__pregel_pull", "1"), 2, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|Command(goto=Send(node='2', arg=3))"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ("__pregel_pull", "1"), 3, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|Command(goto=Send(node='flaky', arg=4))"], ), PregelTask( id=AnyStr(), name="2", path=( "__pregel_push", ( "__pregel_push", ("__pregel_pull", "1"), 2, AnyStr(), ), 2, AnyStr(), ), error=None, interrupts=(), state=None, result=["2|3"], ), PregelTask( id=AnyStr(), name="flaky", path=( "__pregel_push", ( "__pregel_push", ("__pregel_pull", "1"), 3, AnyStr(), ), 2, AnyStr(), ), error=None, interrupts=(Interrupt(value="Bahh", when="during"),), state=None, result=["flaky|4"], ), ), ), StateSnapshot( values=[], next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "input", "writes": {"__start__": ["0"]}, "thread_id": "1", "step": -1, "parents": {}, }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( id=AnyStr(), name="__start__", path=("__pregel_pull", "__start__"), error=None, interrupts=(), state=None, result=["0"], ), ), ), ][1] ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_send_react_interrupt( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") ai_message = AIMessage( "", id="ai1", tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], ) def agent(state): return {"messages": ai_message} def route(state): if isinstance(state["messages"][-1], AIMessage): return [ Send(call["name"], call) for call in state["messages"][-1].tool_calls ] foo_called = 0 def foo(call: 
ToolCall): nonlocal foo_called foo_called += 1 return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} builder = StateGraph(MessagesState) builder.add_node(agent) builder.add_node(foo) builder.add_edge(START, "agent") builder.add_conditional_edges("agent", route) graph = builder.compile() assert graph.invoke({"messages": [HumanMessage("hello")]}) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 # simple interrupt-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "1"}} assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 assert graph.invoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 # interrupt-update-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "2"}} assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 if not FF_SEND_V2: return # get state should show the pending task state = graph.get_state(thread1) assert state == StateSnapshot( 
values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # remove the tool call, clearing the pending task graph.update_state( thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} ) # tool call no longer in pending tasks assert graph.get_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="Bye now", tool_calls=[], ), ] }, next=(), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="Bye now", tool_calls=[], ) } }, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) # tool call not executed assert graph.invoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage(content="Bye now"), ] } assert foo_called == 0 # interrupt-update-resume flow, creating new 
Send in update call foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "3"}} assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 # get state should show the pending task state = graph.get_state(thread1) assert state == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "3", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( "", id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # replace the tool call, should clear previous send, create new one graph.update_state( thread1, { "messages": AIMessage( "", id=ai_message.id, tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ) }, ) # prev tool call no longer in pending tasks, new tool call is assert graph.get_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, 
"id": "tool1", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ) } }, "parents": {}, "thread_id": "3", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "3", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", (), 0, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # prev tool call not executed, new tool call is assert graph.invoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), AIMessage( "", id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [4, 5, 6]}, "id": "tool1", "type": "tool_call", } ], ), _AnyIdToolMessage(content="{'hi': [4, 5, 6]}", tool_call_id="tool1"), ] } assert foo_called == 1 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_send_react_interrupt_control( request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion ) -> None: from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") ai_message = AIMessage( "", id="ai1", tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], ) def agent(state) -> Command[Literal["foo"]]: return Command( update={"messages": ai_message}, goto=[Send(call["name"], call) for call in ai_message.tool_calls], ) foo_called = 0 def foo(call: ToolCall): nonlocal foo_called foo_called += 1 return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])} builder = StateGraph(MessagesState) builder.add_node(agent) builder.add_node(foo) builder.add_edge(START, "agent") graph = builder.compile() assert 
graph.get_graph().draw_mermaid() == snapshot assert graph.invoke({"messages": [HumanMessage("hello")]}) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 # simple interrupt-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "1"}} assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 assert graph.invoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), _AnyIdToolMessage( content="{'hi': [1, 2, 3]}", tool_call_id=AnyStr(), ), ] } assert foo_called == 1 if not FF_SEND_V2: return # interrupt-update-resume flow foo_called = 0 graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"]) thread1 = {"configurable": {"thread_id": "2"}} assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] } assert foo_called == 0 # get state should show the pending task state = graph.get_state(thread1) assert state == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ), ] }, next=("foo",), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", 
"checkpoint_id": AnyStr(), } }, metadata={ "step": 0, "source": "loop", "writes": None, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="agent", path=("__pregel_pull", "agent"), error=None, interrupts=(), state=None, result={ "messages": AIMessage( content="", additional_kwargs={}, response_metadata={}, id="ai1", tool_calls=[ { "name": "foo", "args": {"hi": [1, 2, 3]}, "id": "", "type": "tool_call", } ], ) }, ), PregelTask( id=AnyStr(), name="foo", path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()), error=None, interrupts=(), state=None, result=None, ), ), ) # remove the tool call, clearing the pending task graph.update_state( thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])} ) # tool call no longer in pending tasks assert graph.get_state(thread1) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage( content="Bye now", tool_calls=[], ), ] }, next=(), config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "step": 1, "source": "update", "writes": { "agent": { "messages": _AnyIdAIMessage( content="Bye now", tool_calls=[], ) } }, "parents": {}, "thread_id": "2", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "2", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) # tool call not executed assert graph.invoke(None, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), _AnyIdAIMessage(content="Bye now"), ] } assert foo_called == 0 # interrupt-update-resume flow, creating new Send in update call # TODO add here test with invoke(Command()) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_checkpoint_three( mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = 
request.getfixturevalue(f"checkpointer_{checkpointer_name}") adder = mocker.Mock(side_effect=lambda x: x["total"] + x["input"]) def raise_if_above_10(input: int) -> int: if input > 10: raise ValueError("Input is too large") return input one = ( Channel.subscribe_to(["input"]).join(["total"]) | adder | Channel.write_to("output", "total") | raise_if_above_10 ) app = Pregel( nodes={"one": one}, channels={ "total": BinaryOperatorAggregate(int, operator.add), "input": LastValue(int), "output": LastValue(int), }, input_channels="input", output_channels="output", checkpointer=checkpointer, ) thread_1 = {"configurable": {"thread_id": "1"}} # total starts out as 0, so output is 0+2=2 assert app.invoke(2, thread_1, debug=1) == 2 state = app.get_state(thread_1) assert state is not None assert state.values.get("total") == 2 assert state.next == () assert ( state.config["configurable"]["checkpoint_id"] == checkpointer.get(thread_1)["id"] ) # total is now 2, so output is 2+3=5 assert app.invoke(3, thread_1) == 5 state = app.get_state(thread_1) assert state is not None assert state.values.get("total") == 7 assert ( state.config["configurable"]["checkpoint_id"] == checkpointer.get(thread_1)["id"] ) # total is now 2+5=7, so output would be 7+4=11, but raises ValueError with pytest.raises(ValueError): app.invoke(4, thread_1) # checkpoint is updated with new input state = app.get_state(thread_1) assert state is not None assert state.values.get("total") == 7 assert state.next == ("one",) """we checkpoint inputs and it failed on "one", so the next node is one""" # we can recover from error by sending new inputs assert app.invoke(2, thread_1) == 9 state = app.get_state(thread_1) assert state is not None assert state.values.get("total") == 16, "total is now 7+9=16" assert state.next == () thread_2 = {"configurable": {"thread_id": "2"}} # on a new thread, total starts out as 0, so output is 0+5=5 assert app.invoke(5, thread_2, debug=True) == 5 state = app.get_state({"configurable": 
{"thread_id": "1"}}) assert state is not None assert state.values.get("total") == 16 assert state.next == (), "checkpoint of other thread not touched" state = app.get_state(thread_2) assert state is not None assert state.values.get("total") == 5 assert state.next == () assert len(list(app.get_state_history(thread_1, limit=1))) == 1 # list all checkpoints for thread 1 thread_1_history = [c for c in app.get_state_history(thread_1)] # there are 7 checkpoints assert len(thread_1_history) == 7 assert Counter(c.metadata["source"] for c in thread_1_history) == { "input": 4, "loop": 3, } # sorted descending assert ( thread_1_history[0].config["configurable"]["checkpoint_id"] > thread_1_history[1].config["configurable"]["checkpoint_id"] ) # cursor pagination cursored = list( app.get_state_history(thread_1, limit=1, before=thread_1_history[0].config) ) assert len(cursored) == 1 assert cursored[0].config == thread_1_history[1].config # the last checkpoint assert thread_1_history[0].values["total"] == 16 # the first "loop" checkpoint assert thread_1_history[-2].values["total"] == 2 # can get each checkpoint using aget with config assert ( checkpointer.get(thread_1_history[0].config)["id"] == thread_1_history[0].config["configurable"]["checkpoint_id"] ) assert ( checkpointer.get(thread_1_history[1].config)["id"] == thread_1_history[1].config["configurable"]["checkpoint_id"] ) thread_1_next_config = app.update_state(thread_1_history[1].config, 10) # update creates a new checkpoint assert ( thread_1_next_config["configurable"]["checkpoint_id"] > thread_1_history[0].config["configurable"]["checkpoint_id"] ) # update makes new checkpoint child of the previous one assert ( app.get_state(thread_1_next_config).parent_config == thread_1_history[1].config ) # 1 more checkpoint in history assert len(list(app.get_state_history(thread_1))) == 8 assert Counter(c.metadata["source"] for c in app.get_state_history(thread_1)) == { "update": 1, "input": 4, "loop": 3, } # the latest checkpoint is 
the updated one assert app.get_state(thread_1) == app.get_state(thread_1_next_config) def test_invoke_two_processes_two_in_join_two_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) add_10_each = mocker.Mock(side_effect=lambda x: sorted(y + 10 for y in x)) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") chain_three = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") chain_four = ( Channel.subscribe_to("inbox") | add_10_each | Channel.write_to("output") ) app = Pregel( nodes={ "one": one, "chain_three": chain_three, "chain_four": chain_four, }, channels={ "inbox": Topic(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) # Then invoke app # We get a single array result as chain_four waits for all publishers to finish # before operating on all elements published to topic_two as an array for _ in range(100): assert app.invoke(2) == [13, 13] with ThreadPoolExecutor() as executor: assert [*executor.map(app.invoke, [2] * 100)] == [[13, 13]] * 100 @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_invoke_join_then_call_other_pregel( mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") add_one = mocker.Mock(side_effect=lambda x: x + 1) add_10_each = mocker.Mock(side_effect=lambda x: [y + 10 for y in x]) inner_app = Pregel( nodes={ "one": Channel.subscribe_to("input") | add_one | Channel.write_to("output") }, channels={ "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) one = ( Channel.subscribe_to("input") | add_10_each | Channel.write_to("inbox_one").map() ) two = ( Channel.subscribe_to("inbox_one") | inner_app.map() | sorted | Channel.write_to("outbox_one") ) chain_three = Channel.subscribe_to("outbox_one") | sum | Channel.write_to("output") 
app = Pregel( nodes={ "one": one, "two": two, "chain_three": chain_three, }, channels={ "inbox_one": Topic(int), "outbox_one": LastValue(int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels="output", ) for _ in range(10): assert app.invoke([2, 3]) == 27 with ThreadPoolExecutor() as executor: assert [*executor.map(app.invoke, [[2, 3]] * 10)] == [27] * 10 # add checkpointer app.checkpointer = checkpointer # subgraph is called twice in the same node, through .map(), so raises with pytest.raises(MultipleSubgraphsError): app.invoke([2, 3], {"configurable": {"thread_id": "1"}}) # set inner graph checkpointer NeverCheckpoint inner_app.checkpointer = False # subgraph still called twice, but checkpointing for inner graph is disabled assert app.invoke([2, 3], {"configurable": {"thread_id": "1"}}) == 27 def test_invoke_two_processes_one_in_two_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = ( Channel.subscribe_to("input") | add_one | Channel.write_to("output", "between") ) two = Channel.subscribe_to("between") | add_one | Channel.write_to("output") app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "between": LastValue(int), "output": LastValue(int), }, stream_channels=["output", "between"], input_channels="input", output_channels="output", ) assert [c for c in app.stream(2, stream_mode="updates")] == [ {"one": {"between": 3, "output": 3}}, {"two": {"output": 4}}, ] assert [c for c in app.stream(2)] == [ {"between": 3, "output": 3}, {"between": 3, "output": 4}, ] def test_invoke_two_processes_no_out(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("between") two = Channel.subscribe_to("between") | add_one app = Pregel( nodes={"one": one, "two": two}, channels={ "input": LastValue(int), "between": LastValue(int), "output": LastValue(int), }, 
input_channels="input", output_channels="output", ) # It finishes executing (once no more messages being published) # but returns nothing, as nothing was published to OUT topic assert app.invoke(2) is None def test_invoke_two_processes_no_in(mocker: MockerFixture) -> None: add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("between") | add_one | Channel.write_to("output") two = Channel.subscribe_to("between") | add_one with pytest.raises(TypeError): Pregel(nodes={"one": one, "two": two}) def test_channel_enter_exit_timing(mocker: MockerFixture) -> None: setup = mocker.Mock() cleanup = mocker.Mock() @contextmanager def an_int() -> Generator[int, None, None]: setup() try: yield 5 finally: cleanup() add_one = mocker.Mock(side_effect=lambda x: x + 1) one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox") two = ( Channel.subscribe_to("inbox") | RunnableLambda(add_one).batch | Channel.write_to("output").batch ) app = Pregel( nodes={"one": one, "two": two}, channels={ "inbox": Topic(int), "ctx": Context(an_int), "output": LastValue(int), "input": LastValue(int), }, input_channels="input", output_channels=["inbox", "output"], stream_channels=["inbox", "output"], ) assert setup.call_count == 0 assert cleanup.call_count == 0 for i, chunk in enumerate(app.stream(2)): assert setup.call_count == 1, "Expected setup to be called once" if i == 0: assert chunk == {"inbox": [3]} elif i == 1: assert chunk == {"output": 4} else: assert False, "Expected only two chunks" assert cleanup.call_count == 1, "Expected cleanup to be called once" @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_conditional_graph( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: from langchain_core.language_models.fake import FakeStreamingListLLM from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnablePassthrough from langchain_core.tools import tool 
checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) # Assemble the tools @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] # Construct the agent prompt = PromptTemplate.from_template("Hello!") llm = FakeStreamingListLLM( responses=[ "tool:search_api:query", "tool:search_api:another", "finish:answer", ] ) def agent_parser(input: str) -> Union[AgentAction, AgentFinish]: if input.startswith("finish"): _, answer = input.split(":") return AgentFinish(return_values={"answer": answer}, log=input) else: _, tool_name, tool_input = input.split(":") return AgentAction(tool=tool_name, tool_input=tool_input, log=input) agent = RunnablePassthrough.assign(agent_outcome=prompt | llm | agent_parser) # Define tool execution logic def execute_tools(data: dict) -> dict: data = data.copy() agent_action: AgentAction = data.pop("agent_outcome") observation = {t.name: t for t in tools}[agent_action.tool].invoke( agent_action.tool_input ) if data.get("intermediate_steps") is None: data["intermediate_steps"] = [] else: data["intermediate_steps"] = data["intermediate_steps"].copy() data["intermediate_steps"].append([agent_action, observation]) return data # Define decision-making logic def should_continue(data: dict) -> str: # Logic to decide whether to continue in the loop or exit if isinstance(data["agent_outcome"], AgentFinish): return "exit" else: return "continue" # Define a new graph workflow = Graph() workflow.add_node("agent", agent) workflow.add_node( "tools", execute_tools, metadata={"parents": {}, "version": 2, "variant": "b"}, ) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) workflow.add_edge("tools", "agent") app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert 
app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.get_graph().draw_mermaid() == snapshot assert json.dumps(app.get_graph(xray=True).to_json(), indent=2) == snapshot assert app.get_graph(xray=True).draw_mermaid(with_styles=False) == snapshot assert app.invoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } assert [c for c in app.stream({"input": "what is weather in sf"})] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], 
], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] # test state get/update methods with interrupt_after app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} if SHOULD_CHECK_SNAPSHOTS: assert app_w_interrupt.get_graph().to_json() == snapshot assert app_w_interrupt.get_graph().draw_mermaid() == snapshot assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } } ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], config=app_w_interrupt.checkpointer.get_tuple(config).config, metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert ( app_w_interrupt.checkpointer.get_tuple(config).config["configurable"][ "checkpoint_id" ] is not None ) app_w_interrupt.update_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, 
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] app_w_interrupt.update_state( config, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", 
), }, }, tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 4, "writes": { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), } }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # test state get/update methods with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": {"thread_id": "2"}} llm.i = 0 # reset the llm assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } } ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) app_w_interrupt.update_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", 
log="tool:search_api:a different query", ), "input": "what is weather in sf", }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", } }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "input": "what is weather in sf", }, }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] app_w_interrupt.update_state( config, { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( 
values={ "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), }, }, tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 4, "writes": { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), } }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # test re-invoke to continue with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": {"thread_id": "3"}} llm.i = 0 # reset the llm assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } } ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), }, }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": { "agent": { "agent": { "input": "what is weather in sf", 
"agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } } }, "thread_id": "3", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "input": "what is weather in sf", "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), }, }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] def 
test_conditional_entrypoint_graph(snapshot: SnapshotAssertion) -> None: def left(data: str) -> str: return data + "->left" def right(data: str) -> str: return data + "->right" def should_start(data: str) -> str: # Logic to decide where to start if len(data) > 10: return "go-right" else: return "go-left" # Define a new graph workflow = Graph() workflow.add_node("left", left) workflow.add_node("right", right) workflow.set_conditional_entry_point( should_start, {"go-left": "left", "go-right": "right"} ) workflow.add_conditional_edges("left", lambda data: END, {END: END}) workflow.add_edge("right", END) app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert ( app.invoke("what is weather in sf", debug=True) == "what is weather in sf->right" ) assert [*app.stream("what is weather in sf")] == [ {"right": "what is weather in sf->right"}, ] def test_conditional_entrypoint_to_multiple_state_graph( snapshot: SnapshotAssertion, ) -> None: class OverallState(TypedDict): locations: list[str] results: Annotated[list[str], operator.add] def get_weather(state: OverallState) -> OverallState: location = state["location"] weather = "sunny" if len(location) > 2 else "cloudy" return {"results": [f"It's {weather} in {location}"]} def continue_to_weather(state: OverallState) -> list[Send]: return [ Send("get_weather", {"location": location}) for location in state["locations"] ] workflow = StateGraph(OverallState) workflow.add_node("get_weather", get_weather) workflow.add_edge("get_weather", END) workflow.set_conditional_entry_point(continue_to_weather) app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert 
json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke({"locations": ["sf", "nyc"]}, debug=True) == { "locations": ["sf", "nyc"], "results": ["It's cloudy in sf", "It's sunny in nyc"], } assert [*app.stream({"locations": ["sf", "nyc"]}, stream_mode="values")][-1] == { "locations": ["sf", "nyc"], "results": ["It's cloudy in sf", "It's sunny in nyc"], } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_conditional_state_graph( snapshot: SnapshotAssertion, mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str, ) -> None: from langchain_core.language_models.fake import FakeStreamingListLLM from langchain_core.prompts import PromptTemplate from langchain_core.tools import tool checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) setup = mocker.Mock() teardown = mocker.Mock() @contextmanager def assert_ctx_once() -> Iterator[None]: assert setup.call_count == 0 assert teardown.call_count == 0 try: yield finally: assert setup.call_count == 1 assert teardown.call_count == 1 setup.reset_mock() teardown.reset_mock() @contextmanager def make_httpx_client() -> Iterator[httpx.Client]: setup() with httpx.Client() as client: try: yield client finally: teardown() class AgentState(TypedDict, total=False): input: Annotated[str, UntrackedValue] agent_outcome: Optional[Union[AgentAction, AgentFinish]] intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add] session: Annotated[httpx.Client, Context(make_httpx_client)] class ToolState(TypedDict, total=False): agent_outcome: Union[AgentAction, AgentFinish] session: Annotated[httpx.Client, Context(make_httpx_client)] # Assemble the tools @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = 
[search_api] # Construct the agent prompt = PromptTemplate.from_template("Hello!") llm = FakeStreamingListLLM( responses=[ "tool:search_api:query", "tool:search_api:another", "finish:answer", ] ) def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]: if input.startswith("finish"): _, answer = input.split(":") return { "agent_outcome": AgentFinish( return_values={"answer": answer}, log=input ) } else: _, tool_name, tool_input = input.split(":") return { "agent_outcome": AgentAction( tool=tool_name, tool_input=tool_input, log=input ) } agent = prompt | llm | agent_parser # Define tool execution logic def execute_tools(data: ToolState) -> dict: # check session in data assert isinstance(data["session"], httpx.Client) assert "input" not in data assert "intermediate_steps" not in data # execute the tool agent_action: AgentAction = data.pop("agent_outcome") observation = {t.name: t for t in tools}[agent_action.tool].invoke( agent_action.tool_input ) return {"intermediate_steps": [[agent_action, observation]]} # Define decision-making logic def should_continue(data: AgentState) -> str: # check session in data assert isinstance(data["session"], httpx.Client) # Logic to decide whether to continue in the loop or exit if isinstance(data["agent_outcome"], AgentFinish): return "exit" else: return "continue" # Define a new graph workflow = StateGraph(AgentState) workflow.add_node("agent", agent) workflow.add_node("tools", execute_tools, input=ToolState) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) workflow.add_edge("tools", "agent") app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot with 
assert_ctx_once(): assert app.invoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ], [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } with assert_ctx_once(): assert [*app.stream({"input": "what is weather in sf"})] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ], ], } }, { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), } }, ] # test state get/update methods with interrupt_after app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} with assert_ctx_once(): assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, 
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) with assert_ctx_once(): app_w_interrupt.update_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) }, }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) with assert_ctx_once(): assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] with assert_ctx_once(): app_w_interrupt.update_state( config, { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, 
log="finish:a really nice answer", ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], }, tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) } }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # test state get/update methods with interrupt_before app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], debug=True, ) config = {"configurable": {"thread_id": "2"}} llm.i = 0 # reset the llm assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) app_w_interrupt.update_state( config, { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) }, ) assert 
app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ) } }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], } }, { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] app_w_interrupt.update_state( config, { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) }, ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:a different query", ), "result for query", ] ], }, tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": { "agent": { "agent_outcome": AgentFinish( return_values={"answer": "a really nice answer"}, log="finish:a really nice answer", ) } }, "thread_id": "2", }, 
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # test w interrupt before all app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before="*", debug=True, ) config = {"configurable": {"thread_id": "3"}} llm.i = 0 # reset the llm assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "3", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "3", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } 
}, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, "thread_id": "3", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] # test w interrupt after all app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after="*", ) config = {"configurable": {"thread_id": "4"}} llm.i = 0 # reset the llm assert [ c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config) ] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), } }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [], }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": { 
"agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), } }, "thread_id": "4", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "agent_outcome": AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": { "tools": { "intermediate_steps": [ [ AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query", ), "result for query", ] ], } }, "thread_id": "4", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "agent": { "agent_outcome": AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), } }, {"__interrupt__": ()}, ] def test_conditional_state_graph_with_list_edge_inputs(snapshot: SnapshotAssertion): class State(TypedDict): foo: Annotated[list[str], operator.add] graph_builder = StateGraph(State) graph_builder.add_node("A", lambda x: {"foo": ["A"]}) graph_builder.add_node("B", lambda x: {"foo": ["B"]}) graph_builder.add_edge(START, "A") graph_builder.add_edge(START, "B") graph_builder.add_edge(["A", "B"], END) app = graph_builder.compile() assert app.invoke({"foo": []}) == {"foo": ["A", "B"]} assert 
json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot def test_state_graph_w_config_inherited_state_keys(snapshot: SnapshotAssertion) -> None: from langchain_core.language_models.fake import FakeStreamingListLLM from langchain_core.prompts import PromptTemplate from langchain_core.tools import tool class BaseState(TypedDict): input: str agent_outcome: Optional[Union[AgentAction, AgentFinish]] class AgentState(BaseState, total=False): intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add] assert get_type_hints(AgentState).keys() == { "input", "agent_outcome", "intermediate_steps", } class Config(TypedDict, total=False): tools: list[str] # Assemble the tools @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] # Construct the agent prompt = PromptTemplate.from_template("Hello!") llm = FakeStreamingListLLM( responses=[ "tool:search_api:query", "tool:search_api:another", "finish:answer", ] ) def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]: if input.startswith("finish"): _, answer = input.split(":") return { "agent_outcome": AgentFinish( return_values={"answer": answer}, log=input ) } else: _, tool_name, tool_input = input.split(":") return { "agent_outcome": AgentAction( tool=tool_name, tool_input=tool_input, log=input ) } agent = prompt | llm | agent_parser # Define tool execution logic def execute_tools(data: AgentState) -> dict: agent_action: AgentAction = data.pop("agent_outcome") observation = {t.name: t for t in tools}[agent_action.tool].invoke( agent_action.tool_input ) return {"intermediate_steps": [(agent_action, observation)]} # Define decision-making logic def should_continue(data: AgentState) -> str: # Logic to decide whether to continue in the loop or exit if isinstance(data["agent_outcome"], AgentFinish): return "exit" else: return "continue" # Define a new 
graph builder = StateGraph(AgentState, Config) builder.add_node("agent", agent) builder.add_node("tools", execute_tools) builder.set_entry_point("agent") builder.add_conditional_edges( "agent", should_continue, {"continue": "tools", "exit": END} ) builder.add_edge("tools", "agent") app = builder.compile() if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.config_schema().model_json_schema()) == snapshot assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert builder.channels.keys() == {"input", "agent_outcome", "intermediate_steps"} assert app.invoke({"input": "what is weather in sf"}) == { "agent_outcome": AgentFinish( return_values={"answer": "answer"}, log="finish:answer" ), "input": "what is weather in sf", "intermediate_steps": [ ( AgentAction( tool="search_api", tool_input="query", log="tool:search_api:query" ), "result for query", ), ( AgentAction( tool="search_api", tool_input="another", log="tool:search_api:another", ), "result for another", ), ], } def test_conditional_entrypoint_graph_state(snapshot: SnapshotAssertion) -> None: class AgentState(TypedDict, total=False): input: str output: str steps: Annotated[list[str], operator.add] def left(data: AgentState) -> AgentState: return {"output": data["input"] + "->left"} def right(data: AgentState) -> AgentState: return {"output": data["input"] + "->right"} def should_start(data: AgentState) -> str: assert data["steps"] == [], "Expected input to be read from the state" # Logic to decide where to start if len(data["input"]) > 10: return "go-right" else: return "go-left" # Define a new graph workflow = StateGraph(AgentState) workflow.add_node("left", left) workflow.add_node("right", right) workflow.set_conditional_entry_point( should_start, {"go-left": "left", "go-right": "right"} ) workflow.add_conditional_edges("left", lambda data: END, {END: END}) workflow.add_edge("right", END) app = workflow.compile() if 
SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke({"input": "what is weather in sf"}) == { "input": "what is weather in sf", "output": "what is weather in sf->right", "steps": [], } assert [*app.stream({"input": "what is weather in sf"})] == [ {"right": {"output": "what is weather in sf->right"}}, ] def test_prebuilt_tool_chat(snapshot: SnapshotAssertion) -> None: from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] model = FakeChatModel( messages=[ AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), AIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ), AIMessage(content="answer"), ] ) app = create_tool_calling_executor(model, tools) if SHOULD_CHECK_SNAPSHOTS: assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke( {"messages": [HumanMessage(content="what is weather in sf")]} ) == { "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ), _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), 
_AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ), _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", id=AnyStr(), ), _AnyIdAIMessage(content="answer"), ] } assert [ c for c in app.stream( {"messages": [HumanMessage(content="what is weather in sf")]}, stream_mode="messages", ) ] == [ ( _AnyIdAIMessageChunk( content="", tool_calls=[ { "name": "search_api", "args": {"query": "query"}, "id": "tool_call123", "type": "tool_call", } ], tool_call_chunks=[ { "name": "search_api", "args": '{"query": "query"}', "id": "tool_call123", "index": None, "type": "tool_call_chunk", } ], ), { "langgraph_step": 1, "langgraph_node": "agent", "langgraph_triggers": ["start:agent"], "langgraph_path": (PULL, "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ( _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), { "langgraph_step": 2, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": (PULL, "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdAIMessageChunk( content="", tool_calls=[ { "name": "search_api", "args": {"query": "another"}, "id": "tool_call234", "type": "tool_call", }, { "name": "search_api", "args": {"query": "a third one"}, "id": "tool_call567", "type": "tool_call", }, ], tool_call_chunks=[ { "name": "search_api", "args": '{"query": "another"}', "id": "tool_call234", "index": None, "type": "tool_call_chunk", }, { "name": "search_api", "args": '{"query": "a third one"}', "id": "tool_call567", "index": None, "type": "tool_call_chunk", }, ], ), { 
"langgraph_step": 3, "langgraph_node": "agent", "langgraph_triggers": ["tools"], "langgraph_path": (PULL, "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ( _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), { "langgraph_step": 4, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": (PULL, "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), { "langgraph_step": 4, "langgraph_node": "tools", "langgraph_triggers": ["branch:agent:should_continue:tools"], "langgraph_path": (PULL, "tools"), "langgraph_checkpoint_ns": AnyStr("tools:"), }, ), ( _AnyIdAIMessageChunk( content="answer", ), { "langgraph_step": 5, "langgraph_node": "agent", "langgraph_triggers": ["tools"], "langgraph_path": (PULL, "agent"), "langgraph_checkpoint_ns": AnyStr("agent:"), "checkpoint_ns": AnyStr("agent:"), "ls_provider": "fakechatmodel", "ls_model_type": "chat", }, ), ] assert app.invoke( {"messages": [HumanMessage(content="what is weather in sf")]}, {"recursion_limit": 2}, debug=True, ) == { "messages": [ _AnyIdHumanMessage(content="what is weather in sf"), _AnyIdAIMessage(content="Sorry, need more steps to process this request."), ] } model.i = 0 # reset the model assert ( app.invoke( {"messages": [HumanMessage(content="what is weather in sf")]}, stream_mode="updates", )[0]["agent"]["messages"] == [ { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ) ] } }, { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call234", 
"name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), ] } }, {"agent": {"messages": [_AnyIdAIMessage(content="answer")]}}, ][0]["agent"]["messages"] ) assert [ *app.stream({"messages": [HumanMessage(content="what is weather in sf")]}) ] == [ { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ) ] } }, { "agent": { "messages": [ _AnyIdAIMessage( content="", tool_calls=[ { "id": "tool_call234", "name": "search_api", "args": {"query": "another"}, }, { "id": "tool_call567", "name": "search_api", "args": {"query": "a third one"}, }, ], ) ] } }, { "tools": { "messages": [ _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call234", ), _AnyIdToolMessage( content="result for a third one", name="search_api", tool_call_id="tool_call567", ), ] } }, {"agent": {"messages": [_AnyIdAIMessage(content="answer")]}}, ]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_state_graph_packets(
    request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
    """Agent/tools loop where the conditional edge fans out `Send` packets.

    `should_continue` returns one `Send("tools", tool_call)` per tool call on
    the last AI message, so each tool call runs as its own `tools` task.

    The test first checks plain `invoke`/`stream` output. It then uses the
    parametrized checkpointer to check state snapshots, pending tasks and
    checkpoint metadata for `interrupt_after=["agent"]` (thread "1") and
    `interrupt_before=["tools"]` (thread "2"). Both variants include editing
    state through `update_state`.
    """
    # NOTE(review): the `mocker` fixture is not used anywhere in this test body.
    from langchain_core.language_models.fake_chat_models import (
        FakeMessagesListChatModel,
    )
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        ToolCall,
        ToolMessage,
    )
    from langchain_core.tools import tool

    checkpointer: BaseCheckpointSaver = request.getfixturevalue(
        f"checkpointer_{checkpointer_name}"
    )

    class AgentState(TypedDict):
        # Message history; updates are merged by the `add_messages` reducer.
        messages: Annotated[list[BaseMessage], add_messages]
        # No input below supplies a session, yet `agent` and `should_continue`
        # both assert an httpx.Client is present. Presumably the graph builds it
        # from `Context(httpx.Client)`; confirm against langgraph's Context docs.
        session: Annotated[httpx.Client, Context(httpx.Client)]

    @tool()
    def search_api(query: str) -> str:
        """Searches the API for the query."""
        return f"result for {query}"

    tools = [search_api]
    tools_by_name = {t.name: t for t in tools}

    # Canned responses, replayed in order: one tool call, then two parallel
    # tool calls (tagged with `idx`), then the final answer.
    model = FakeMessagesListChatModel(
        responses=[
            AIMessage(
                id="ai1",
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    },
                ],
            ),
            AIMessage(
                id="ai2",
                content="",
                tool_calls=[
                    {
                        "id": "tool_call234",
                        "name": "search_api",
                        "args": {"query": "another", "idx": 0},
                    },
                    {
                        "id": "tool_call567",
                        "name": "search_api",
                        "args": {"query": "a third one", "idx": 1},
                    },
                ],
            ),
            AIMessage(id="ai3", content="answer"),
        ]
    )

    def agent(data: AgentState) -> AgentState:
        assert isinstance(data["session"], httpx.Client)
        # `something_extra` is not a field of AgentState. It is returned here so
        # `should_continue` can check that it still receives it.
        return {
            "messages": model.invoke(data["messages"]),
            "something_extra": "hi there",
        }

    # Define decision-making logic
    # NOTE(review): annotated `-> str`, but it returns a list of `Send` packets
    # when there are tool calls. The annotation is left as-is because
    # `add_conditional_edges` may inspect the return annotation when no path
    # map is given. Verify before changing it.
    def should_continue(data: AgentState) -> str:
        assert isinstance(data["session"], httpx.Client)
        assert (
            data["something_extra"] == "hi there"
        ), "nodes can pass extra data to their cond edges, which isn't saved in state"
        # Logic to decide whether to continue in the loop or exit
        if tool_calls := data["messages"][-1].tool_calls:
            # One packet per tool call: each becomes a separate `tools` task.
            return [Send("tools", tool_call) for tool_call in tool_calls]
        else:
            return END

    def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
        """Run a single tool call (delivered as a `Send` packet) and wrap the result."""
        # Sleep idx/10 s (0 s or 0.1 s here). Presumably this makes the parallel
        # tool tasks finish in a fixed order: the stream assertions below expect
        # tool_call234's result before tool_call567's.
        time.sleep(input["args"].get("idx", 0) / 10)
        output = tools_by_name[input["name"]].invoke(input["args"], config)
        return {
            "messages": ToolMessage(
                content=output, name=input["name"], tool_call_id=input["id"]
            )
        }

    # Define a new graph
    workflow = StateGraph(AgentState)

    # Define the two nodes we will cycle between
    workflow.add_node("agent", agent)
    workflow.add_node("tools", tools_node)

    # Set the entrypoint as `agent`
    # This means that this node is the first one called
    workflow.set_entry_point("agent")

    # We now add a conditional edge
    workflow.add_conditional_edges("agent", should_continue)

    # We now add a normal edge from `tools` to `agent`.
    # This means that after `tools` is called, `agent` node is called next.
    workflow.add_edge("tools", "agent")

    # Finally, we compile it!
    # This compiles it into a LangChain Runnable,
    # meaning you can use it as you would any other runnable
    app = workflow.compile()

    # Full run without a checkpointer: final state holds the whole conversation.
    assert app.invoke({"messages": HumanMessage(content="what is weather in sf")}) == {
        "messages": [
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                id="ai1",
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    },
                ],
            ),
            _AnyIdToolMessage(
                content="result for query",
                name="search_api",
                tool_call_id="tool_call123",
            ),
            AIMessage(
                id="ai2",
                content="",
                tool_calls=[
                    {
                        "id": "tool_call234",
                        "name": "search_api",
                        "args": {"query": "another", "idx": 0},
                    },
                    {
                        "id": "tool_call567",
                        "name": "search_api",
                        "args": {"query": "a third one", "idx": 1},
                    },
                ],
            ),
            _AnyIdToolMessage(
                content="result for another",
                name="search_api",
                tool_call_id="tool_call234",
            ),
            _AnyIdToolMessage(
                content="result for a third one",
                name="search_api",
                tool_call_id="tool_call567",
            ),
            AIMessage(content="answer", id="ai3"),
        ]
    }

    # Streaming: each `Send`-spawned `tools` task yields its own update.
    assert [
        c
        for c in app.stream(
            {"messages": [HumanMessage(content="what is weather in sf")]}
        )
    ] == [
        {
            "agent": {
                "messages": AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        },
                    ],
                )
            },
        },
        {
            "tools": {
                "messages": _AnyIdToolMessage(
                    content="result for query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            }
        },
        {
            "agent": {
                "messages": AIMessage(
                    id="ai2",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call234",
                            "name": "search_api",
                            "args": {"query": "another", "idx": 0},
                        },
                        {
                            "id": "tool_call567",
                            "name": "search_api",
                            "args": {"query": "a third one", "idx": 1},
                        },
                    ],
                )
            }
        },
        {
            "tools": {
                "messages": _AnyIdToolMessage(
                    content="result for another",
                    name="search_api",
                    tool_call_id="tool_call234",
                )
            },
        },
        {
            "tools": {
                "messages": _AnyIdToolMessage(
                    content="result for a third one",
                    name="search_api",
                    tool_call_id="tool_call567",
                ),
            },
        },
        {"agent": {"messages": AIMessage(content="answer", id="ai3")}},
    ]

    # interrupt after agent
    app_w_interrupt = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["agent"],
    )
    config = {"configurable": {"thread_id": "1"}}

    assert [
        c
        for c in app_w_interrupt.stream(
            {"messages": HumanMessage(content="what is weather in sf")}, config
        )
    ] == [
        {
            "agent": {
                "messages": AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        },
                    ],
                )
            }
        },
        {"__interrupt__": ()},
    ]

    # The remaining checkpoint-state assertions run only when FF_SEND_V2 is
    # truthy. It is presumably a feature flag defined elsewhere in this module.
    if not FF_SEND_V2:
        return

    # Paused after `agent`: the finished `agent` task (with its result) and the
    # pending `Send`-spawned `tools` task both appear in `tasks`.
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        },
                    ],
                ),
            ]
        },
        tasks=(
            PregelTask(
                id=AnyStr(),
                name="agent",
                path=("__pregel_pull", "agent"),
                error=None,
                interrupts=(),
                state=None,
                result={
                    "messages": AIMessage(
                        content="",
                        additional_kwargs={},
                        response_metadata={},
                        id="ai1",
                        tool_calls=[
                            {
                                "name": "search_api",
                                "args": {"query": "query"},
                                "id": "tool_call123",
                                "type": "tool_call",
                            }
                        ],
                    )
                },
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
            ),
        ),
        next=("tools",),
        config=(app_w_interrupt.checkpointer.get_tuple(config)).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 0,
            "writes": None,
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # modify ai message
    last_message = (app_w_interrupt.get_state(config)).values["messages"][-1]
    last_message.tool_calls[0]["args"]["query"] = "a different query"
    app_w_interrupt.update_state(
        config, {"messages": last_message, "something_extra": "hi there"}
    )

    # message was replaced instead of appended
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
            ]
        },
        tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 1,
            "writes": {
                "agent": {
                    "messages": AIMessage(
                        id="ai1",
                        content="",
                        tool_calls=[
                            {
                                "id": "tool_call123",
                                "name": "search_api",
                                "args": {"query": "a different query"},
                            },
                        ],
                    ),
                    "something_extra": "hi there",
                }
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # Resume: the tool runs with the edited args, then `agent` runs and pauses again.
    assert [c for c in app_w_interrupt.stream(None, config)] == [
        {
            "tools": {
                "messages": _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            }
        },
        {
            "agent": {
                "messages": AIMessage(
                    id="ai2",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call234",
                            "name": "search_api",
                            "args": {"query": "another", "idx": 0},
                        },
                        {
                            "id": "tool_call567",
                            "name": "search_api",
                            "args": {"query": "a third one", "idx": 1},
                        },
                    ],
                )
            },
        },
        {"__interrupt__": ()},
    ]

    # Two tool calls on ai2, so two pending `tools` tasks.
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                ),
                AIMessage(
                    id="ai2",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call234",
                            "name": "search_api",
                            "args": {"query": "another", "idx": 0},
                        },
                        {
                            "id": "tool_call567",
                            "name": "search_api",
                            "args": {"query": "a third one", "idx": 1},
                        },
                    ],
                ),
            ]
        },
        tasks=(
            PregelTask(
                id=AnyStr(),
                name="agent",
                path=("__pregel_pull", "agent"),
                error=None,
                interrupts=(),
                state=None,
                result={
                    "messages": AIMessage(
                        "",
                        id="ai2",
                        tool_calls=[
                            {
                                "name": "search_api",
                                "args": {"query": "another", "idx": 0},
                                "id": "tool_call234",
                                "type": "tool_call",
                            },
                            {
                                "name": "search_api",
                                "args": {"query": "a third one", "idx": 1},
                                "id": "tool_call567",
                                "type": "tool_call",
                            },
                        ],
                    )
                },
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
            ),
        ),
        next=("tools", "tools"),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 2,
            "writes": {
                "tools": {
                    "messages": _AnyIdToolMessage(
                        content="result for a different query",
                        name="search_api",
                        tool_call_id="tool_call123",
                    ),
                },
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # Overwrite ai2 with a final answer; no tool calls remain, so nothing is pending.
    app_w_interrupt.update_state(
        config,
        {
            "messages": AIMessage(content="answer", id="ai2"),
            "something_extra": "hi there",
        },
    )

    # replaces message even if object identity is different, as long as id is the same
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                ),
                AIMessage(content="answer", id="ai2"),
            ]
        },
        tasks=(),
        next=(),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 3,
            "writes": {
                "agent": {
                    "messages": AIMessage(content="answer", id="ai2"),
                    "something_extra": "hi there",
                }
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # interrupt before tools
    app_w_interrupt = workflow.compile(
        checkpointer=checkpointer,
        interrupt_before=["tools"],
    )
    config = {"configurable": {"thread_id": "2"}}
    # Reset the fake model's response cursor so thread "2" replays from ai1.
    model.i = 0

    assert [
        c
        for c in app_w_interrupt.stream(
            {"messages": HumanMessage(content="what is weather in sf")}, config
        )
    ] == [
        {
            "agent": {
                "messages": AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        },
                    ],
                )
            }
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        },
                    ],
                ),
            ]
        },
        tasks=(
            PregelTask(
                id=AnyStr(),
                name="agent",
                path=("__pregel_pull", "agent"),
                error=None,
                interrupts=(),
                state=None,
                result={
                    "messages": AIMessage(
                        "",
                        id="ai1",
                        tool_calls=[
                            {
                                "name": "search_api",
                                "args": {"query": "query"},
                                "id": "tool_call123",
                                "type": "tool_call",
                            }
                        ],
                    )
                },
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
            ),
        ),
        next=("tools",),
        config=(app_w_interrupt.checkpointer.get_tuple(config)).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 0,
            "writes": None,
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # modify ai message
    last_message = (app_w_interrupt.get_state(config)).values["messages"][-1]
    last_message.tool_calls[0]["args"]["query"] = "a different query"
    app_w_interrupt.update_state(
        config, {"messages": last_message, "something_extra": "hi there"}
    )

    # message was replaced instead of appended
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
            ]
        },
        tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 1,
            "writes": {
                "agent": {
                    "messages": AIMessage(
                        id="ai1",
                        content="",
                        tool_calls=[
                            {
                                "id": "tool_call123",
                                "name": "search_api",
                                "args": {"query": "a different query"},
                            },
                        ],
                    ),
                    "something_extra": "hi there",
                }
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    assert [c for c in app_w_interrupt.stream(None, config)] == [
        {
            "tools": {
                "messages": _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            }
        },
        {
            "agent": {
                "messages": AIMessage(
                    id="ai2",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call234",
                            "name": "search_api",
                            "args": {"query": "another", "idx": 0},
                        },
                        {
                            "id": "tool_call567",
                            "name": "search_api",
                            "args": {"query": "a third one", "idx": 1},
                        },
                    ],
                )
            },
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                ),
                AIMessage(
                    id="ai2",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call234",
                            "name": "search_api",
                            "args": {"query": "another", "idx": 0},
                        },
                        {
                            "id": "tool_call567",
                            "name": "search_api",
                            "args": {"query": "a third one", "idx": 1},
                        },
                    ],
                ),
            ]
        },
        tasks=(
            PregelTask(
                id=AnyStr(),
                name="agent",
                path=("__pregel_pull", "agent"),
                error=None,
                interrupts=(),
                state=None,
                result={
                    "messages": AIMessage(
                        "",
                        id="ai2",
                        tool_calls=[
                            {
                                "name": "search_api",
                                "args": {"query": "another", "idx": 0},
                                "id": "tool_call234",
                                "type": "tool_call",
                            },
                            {
                                "name": "search_api",
                                "args": {"query": "a third one", "idx": 1},
                                "id": "tool_call567",
                                "type": "tool_call",
                            },
                        ],
                    )
                },
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
            ),
            PregelTask(
                AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
            ),
        ),
        next=("tools", "tools"),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 2,
            "writes": {
                "tools": {
                    "messages": _AnyIdToolMessage(
                        content="result for a different query",
                        name="search_api",
                        tool_call_id="tool_call123",
                    ),
                },
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    app_w_interrupt.update_state(
        config,
        {
            "messages": AIMessage(content="answer", id="ai2"),
            "something_extra": "hi there",
        },
    )

    # replaces message even if object identity is different, as long as id is the same
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values={
            "messages": [
                _AnyIdHumanMessage(content="what is weather in sf"),
                AIMessage(
                    id="ai1",
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        },
                    ],
                ),
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                ),
                AIMessage(content="answer", id="ai2"),
            ]
        },
        tasks=(),
        next=(),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 3,
            "writes": {
                "agent": {
                    "messages": AIMessage(content="answer", id="ai2"),
                    "something_extra": "hi there",
                }
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_message_graph(
    snapshot: SnapshotAssertion,
    deterministic_uuids: MockerFixture,
    request: pytest.FixtureRequest,
    checkpointer_name: str,
) -> None:
    """MessageGraph agent/tools loop driven by a fake chat model.

    The test first checks plain `invoke`/`stream` output. When
    SHOULD_CHECK_SNAPSHOTS is set, it also checks the input/output schema and
    graph snapshots.

    It then uses the parametrized checkpointer to check state snapshots for
    `interrupt_after=["agent"]` (thread "1") and `interrupt_before=["tools"]`
    (thread "2"). Covered behaviours: replacing a message by id through
    `update_state`, and appending a message via `as_node="tools"`, after which
    the next node is `agent`.
    """
    from copy import deepcopy

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.language_models.fake_chat_models import (
        FakeMessagesListChatModel,
    )
    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
    from langchain_core.outputs import ChatGeneration, ChatResult
    from langchain_core.tools import tool

    checkpointer: BaseCheckpointSaver = request.getfixturevalue(
        f"checkpointer_{checkpointer_name}"
    )

    class FakeFuntionChatModel(FakeMessagesListChatModel):
        """Fake model that replays `responses` in order, wrapping back to the first."""

        def bind_functions(self, functions: list):
            # No-op binding: the same model instance is returned.
            return self

        def _generate(
            self,
            messages: list[BaseMessage],
            stop: Optional[list[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            # Deep copy, presumably so that the in-place edits this test makes
            # to messages cannot leak back into the canned responses.
            response = deepcopy(self.responses[self.i])
            # Advance the cursor, wrapping to 0 after the last response.
            if self.i < len(self.responses) - 1:
                self.i += 1
            else:
                self.i = 0
            generation = ChatGeneration(message=response)
            return ChatResult(generations=[generation])

    @tool()
    def search_api(query: str) -> str:
        """Searches the API for the query."""
        return f"result for {query}"

    tools = [search_api]

    model = FakeFuntionChatModel(
        responses=[
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            ),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            ),
            AIMessage(content="answer", id="ai3"),
        ]
    )

    # Define the function that determines whether to continue or not
    def should_continue(messages):
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return "end"
        # Otherwise if there is, we continue
        else:
            return "continue"

    # Define a new graph
    workflow = MessageGraph()

    # Define the two nodes we will cycle between
    workflow.add_node("agent", model)
    workflow.add_node("tools", ToolNode(tools))

    # Set the entrypoint as `agent`
    # This means that this node is the first one called
    workflow.set_entry_point("agent")

    # We now add a conditional edge
    workflow.add_conditional_edges(
        # First, we define the start node. We use `agent`.
        # This means these are the edges taken after the `agent` node is called.
        "agent",
        # Next, we pass in the function that will determine which node is called next.
        should_continue,
        # Finally we pass in a mapping.
        # The keys are strings, and the values are other nodes.
        # END is a special node marking that the graph should finish.
        # What will happen is we will call `should_continue`, and then the output of that
        # will be matched against the keys in this mapping.
        # Based on which one it matches, that node will then be called.
        {
            # If `continue`, then we call the tool node.
            "continue": "tools",
            # Otherwise we finish.
            "end": END,
        },
    )

    # We now add a normal edge from `tools` to `agent`.
    # This means that after `tools` is called, `agent` node is called next.
    workflow.add_edge("tools", "agent")

    # Finally, we compile it!
    # This compiles it into a LangChain Runnable,
    # meaning you can use it as you would any other runnable
    app = workflow.compile()

    if SHOULD_CHECK_SNAPSHOTS:
        assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
        assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
        assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
        assert app.get_graph().draw_mermaid(with_styles=False) == snapshot

    # Full run without a checkpointer: the output is the whole message list.
    assert app.invoke(HumanMessage(content="what is weather in sf")) == [
        _AnyIdHumanMessage(
            content="what is weather in sf",
        ),
        AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "tool_call123",
                    "name": "search_api",
                    "args": {"query": "query"},
                }
            ],
            id="ai1",  # respects ids passed in
        ),
        _AnyIdToolMessage(
            content="result for query",
            name="search_api",
            tool_call_id="tool_call123",
        ),
        AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "tool_call456",
                    "name": "search_api",
                    "args": {"query": "another"},
                }
            ],
            id="ai2",
        ),
        _AnyIdToolMessage(
            content="result for another",
            name="search_api",
            tool_call_id="tool_call456",
        ),
        AIMessage(content="answer", id="ai3"),
    ]

    assert [*app.stream([HumanMessage(content="what is weather in sf")])] == [
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            )
        },
        {
            "tools": [
                _AnyIdToolMessage(
                    content="result for query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            ]
        },
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            )
        },
        {
            "tools": [
                _AnyIdToolMessage(
                    content="result for another",
                    name="search_api",
                    tool_call_id="tool_call456",
                )
            ]
        },
        {"agent": AIMessage(content="answer", id="ai3")},
    ]

    # interrupt after agent (thread "1")
    app_w_interrupt = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["agent"],
    )
    config = {"configurable": {"thread_id": "1"}}

    # Input is a ("human", text) tuple; the state below shows it stored as a
    # human message.
    assert [
        c for c in app_w_interrupt.stream(("human", "what is weather in sf"), config)
    ] == [
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            )
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 1,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        }
                    ],
                    id="ai1",
                )
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # modify ai message
    last_message = app_w_interrupt.get_state(config).values[-1]
    last_message.tool_calls[0]["args"] = {"query": "a different query"}
    # update_state returns the config of the checkpoint it wrote; used below.
    next_config = app_w_interrupt.update_state(config, last_message)

    # message was replaced instead of appended
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=next_config,
        created_at=AnyStr(),
        metadata={
            "parents": {},
            "source": "update",
            "step": 2,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        }
                    ],
                    id="ai1",
                )
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # Resume: the tool runs with the edited args, then `agent` runs and pauses again.
    assert [c for c in app_w_interrupt.stream(None, config)] == [
        {
            "tools": [
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            ]
        },
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            )
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for a different query",
                name="search_api",
                tool_call_id="tool_call123",
            ),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 4,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call456",
                            "name": "search_api",
                            "args": {"query": "another"},
                        }
                    ],
                    id="ai2",
                )
            },
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    app_w_interrupt.update_state(
        config,
        AIMessage(content="answer", id="ai2"),  # replace existing message
    )

    # replaces message even if object identity is different, as long as id is the same
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for a different query",
                name="search_api",
                tool_call_id="tool_call123",
            ),
            AIMessage(content="answer", id="ai2"),
        ],
        tasks=(),
        next=(),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 5,
            "writes": {"agent": AIMessage(content="answer", id="ai2")},
            "thread_id": "1",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # interrupt before tools (thread "2")
    app_w_interrupt = workflow.compile(
        checkpointer=checkpointer,
        interrupt_before=["tools"],
    )
    config = {"configurable": {"thread_id": "2"}}
    model.i = 0  # reset the llm

    # Input is a bare string here; the state below shows it stored as a human message.
    assert [c for c in app_w_interrupt.stream("what is weather in sf", config)] == [
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            )
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    }
                ],
                id="ai1",
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 1,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "query"},
                        }
                    ],
                    id="ai1",
                )
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # modify ai message
    last_message = app_w_interrupt.get_state(config).values[-1]
    last_message.tool_calls[0]["args"] = {"query": "a different query"}
    app_w_interrupt.update_state(config, last_message)

    # message was replaced instead of appended
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 2,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call123",
                            "name": "search_api",
                            "args": {"query": "a different query"},
                        }
                    ],
                    id="ai1",
                )
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    assert [c for c in app_w_interrupt.stream(None, config)] == [
        {
            "tools": [
                _AnyIdToolMessage(
                    content="result for a different query",
                    name="search_api",
                    tool_call_id="tool_call123",
                )
            ]
        },
        {
            "agent": AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            )
        },
        {"__interrupt__": ()},
    ]

    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for a different query",
                name="search_api",
                tool_call_id="tool_call123",
            ),
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call456",
                        "name": "search_api",
                        "args": {"query": "another"},
                    }
                ],
                id="ai2",
            ),
        ],
        tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
        next=("tools",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "loop",
            "step": 4,
            "writes": {
                "agent": AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "id": "tool_call456",
                            "name": "search_api",
                            "args": {"query": "another"},
                        }
                    ],
                    id="ai2",
                )
            },
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    app_w_interrupt.update_state(
        config,
        AIMessage(content="answer", id="ai2"),
    )

    # replaces message even if object identity is different, as long as id is the same
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for a different query",
                name="search_api",
                tool_call_id="tool_call123",
                id=AnyStr(),
            ),
            AIMessage(content="answer", id="ai2"),
        ],
        tasks=(),
        next=(),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 5,
            "writes": {"agent": AIMessage(content="answer", id="ai2")},
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )

    # add an extra message as if it came from "tools" node
    app_w_interrupt.update_state(config, ("ai", "an extra message"), as_node="tools")

    # extra message is coerced to a BaseMessage and appended
    # now the next node is "agent" per the graph edges
    assert app_w_interrupt.get_state(config) == StateSnapshot(
        values=[
            _AnyIdHumanMessage(content="what is weather in sf"),
            AIMessage(
                content="",
                id="ai1",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "a different query"},
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for a different query",
                name="search_api",
                tool_call_id="tool_call123",
                id=AnyStr(),
            ),
            AIMessage(content="answer", id="ai2"),
            _AnyIdAIMessage(content="an extra message"),
        ],
        tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
        next=("agent",),
        config=app_w_interrupt.checkpointer.get_tuple(config).config,
        created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
        metadata={
            "parents": {},
            "source": "update",
            "step": 6,
            # NOTE(review): UnsortedSequence is a helper defined elsewhere; it
            # presumably compares the recorded tuple without regard to order.
            "writes": {"tools": UnsortedSequence("ai", "an extra message")},
            "thread_id": "2",
        },
        parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
    )


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_root_graph( deterministic_uuids: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str, ) -> None: from copy import deepcopy from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import ( AIMessage, BaseMessage, HumanMessage, ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.tools import tool checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) class FakeFuntionChatModel(FakeMessagesListChatModel): def bind_functions(self, functions: list): return self def _generate( self, messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: response = deepcopy(self.responses[self.i]) if self.i < len(self.responses) - 1: self.i += 1 else: self.i = 0 generation = ChatGeneration(message=response) return ChatResult(generations=[generation]) @tool() def search_api(query: str) -> str: """Searches the API for the query.""" return f"result for {query}" tools = [search_api] model = FakeFuntionChatModel( responses=[ AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), AIMessage(content="answer", id="ai3"), ] ) # Define the function that determines whether to continue or not def should_continue(messages): last_message = messages[-1] # If there is no function call, then we finish if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: return "continue" class
State(TypedDict): __root__: Annotated[list[BaseMessage], add_messages] # Define a new graph workflow = StateGraph(State) # Define the two nodes we will cycle between workflow.add_node("agent", model) workflow.add_node("tools", ToolNode(tools)) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges( # First, we define the start node. We use `agent`. # This means these are the edges taken after the `agent` node is called. "agent", # Next, we pass in the function that will determine which node is called next. should_continue, # Finally we pass in a mapping. # The keys are strings, and the values are other nodes. # END is a special node marking that the graph should finish. # What will happen is we will call `should_continue`, and then the output of that # will be matched against the keys in this mapping. # Based on which one it matches, that node will then be called. { # If `tools`, then we call the tool node. "continue": "tools", # Otherwise we finish. "end": END, }, ) # We now add a normal edge from `tools` to `agent`. # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("tools", "agent") # Finally, we compile it! 
# This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile() assert app.invoke(HumanMessage(content="what is weather in sf")) == [ _AnyIdHumanMessage( content="what is weather in sf", ), AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", # respects ids passed in ), _AnyIdToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), _AnyIdToolMessage( content="result for another", name="search_api", tool_call_id="tool_call456", ), AIMessage(content="answer", id="ai3"), ] assert [*app.stream([HumanMessage(content="what is weather in sf")])] == [ { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, { "tools": [ ToolMessage( content="result for query", name="search_api", tool_call_id="tool_call123", id="00000000-0000-4000-8000-000000000033", ) ] }, { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, { "tools": [ ToolMessage( content="result for another", name="search_api", tool_call_id="tool_call456", id="00000000-0000-4000-8000-000000000041", ) ] }, {"agent": AIMessage(content="answer", id="ai3")}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["agent"], ) config = {"configurable": {"thread_id": "1"}} assert [ c for c in app_w_interrupt.stream(("human", "what is weather in sf"), config) ] == [ { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is 
weather in sf"), AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # modify ai message last_message = app_w_interrupt.get_state(config).values[-1] last_message.tool_calls[0]["args"] = {"query": "a different query"} next_config = app_w_interrupt.update_state(config, last_message) # message was replaced instead of appended assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=next_config, created_at=AnyStr(), metadata={ "parents": {}, "source": "update", "step": 2, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], id="ai1", ) }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": [ _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ) ] }, { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, {"__interrupt__": ()}, ] assert 
app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", id=AnyStr(), ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 4, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, "thread_id": "1", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) app_w_interrupt.update_state( config, AIMessage(content="answer", id="ai2"), # replace existing message ) # replaces message even if object identity is different, as long as id is the same assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", id=AnyStr(), ), AIMessage(content="answer", id="ai2"), ], tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, 
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["tools"], ) config = {"configurable": {"thread_id": "2"}} model.i = 0 # reset the llm assert [c for c in app_w_interrupt.stream("what is weather in sf", config)] == [ { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "query"}, } ], id="ai1", ) }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # modify ai message last_message = app_w_interrupt.get_state(config).values[-1] last_message.tool_calls[0]["args"] = {"query": "a different query"} app_w_interrupt.update_state(config, last_message) # message was replaced instead of appended assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ 
"parents": {}, "source": "update", "step": 2, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], id="ai1", ) }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config)] == [ { "tools": [ _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ) ] }, { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, {"__interrupt__": ()}, ] assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", id=AnyStr(), ), AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ), ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 4, "writes": { "agent": AIMessage( content="", tool_calls=[ { "id": "tool_call456", "name": "search_api", "args": {"query": "another"}, } ], id="ai2", ) }, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) app_w_interrupt.update_state( config, AIMessage(content="answer", id="ai2"), ) # replaces message even if object identity is different, as long as id is the same assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), 
AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), ], tasks=(), next=(), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 5, "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # add an extra message as if it came from "tools" node app_w_interrupt.update_state(config, ("ai", "an extra message"), as_node="tools") # extra message is coerced BaseMessge and appended # now the next node is "agent" per the graph edges assert app_w_interrupt.get_state(config) == StateSnapshot( values=[ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", id=AnyStr(), ), AIMessage(content="answer", id="ai2"), _AnyIdAIMessage(content="an extra message"), ], tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 6, "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # create new graph with one more state key, reuse previous thread history def simple_add(left, right): if not isinstance(right, list): right = [right] return left + right class 
MoreState(TypedDict): __root__: Annotated[list[BaseMessage], simple_add] something_else: str # Define a new graph new_workflow = StateGraph(MoreState) new_workflow.add_node( "agent", RunnableMap(__root__=RunnablePick("__root__") | model) ) new_workflow.add_node( "tools", RunnableMap(__root__=RunnablePick("__root__") | ToolNode(tools)) ) new_workflow.set_entry_point("agent") new_workflow.add_conditional_edges( "agent", RunnablePick("__root__") | should_continue, { # If `tools`, then we call the tool node. "continue": "tools", # Otherwise we finish. "end": END, }, ) new_workflow.add_edge("tools", "agent") new_app = new_workflow.compile(checkpointer=checkpointer) model.i = 0 # reset the llm # previous state is converted to new schema assert new_app.get_state(config) == StateSnapshot( values={ "__root__": [ _AnyIdHumanMessage(content="what is weather in sf"), AIMessage( content="", id="ai1", tool_calls=[ { "id": "tool_call123", "name": "search_api", "args": {"query": "a different query"}, } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), _AnyIdAIMessage(content="an extra message"), ] }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 6, "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) # new input is merged to old state assert new_app.invoke( { "__root__": [HumanMessage(content="what is weather in la")], "something_else": "value", }, config, interrupt_before=["agent"], ) == { "__root__": [ HumanMessage( content="what is weather in sf", id="00000000-0000-4000-8000-000000000070", ), AIMessage( content="", id="ai1", tool_calls=[ { 
"name": "search_api", "args": {"query": "a different query"}, "id": "tool_call123", } ], ), _AnyIdToolMessage( content="result for a different query", name="search_api", tool_call_id="tool_call123", ), AIMessage(content="answer", id="ai2"), AIMessage( content="an extra message", id="00000000-0000-4000-8000-000000000092" ), HumanMessage(content="what is weather in la"), ], "something_else": "value", } def test_in_one_fan_out_out_one_graph_state() -> None: def sorted_add(x: list[str], y: list[str]) -> list[str]: return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def retriever_one(data: State) -> State: # timer ensures stream output order is stable # also, it confirms that the update order is not dependent on finishing order # instead being defined by the order of the nodes/edges in the graph definition # ie. stable between invocations time.sleep(0.1) return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge("retriever_one", "qa") workflow.add_edge("retriever_two", "qa") workflow.set_finish_point("qa") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in 
sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] assert [*app.stream({"query": "what is weather in sf"}, stream_mode="values")] == [ {"query": "what is weather in sf", "docs": []}, {"query": "query: what is weather in sf", "docs": []}, { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", }, ] assert [ *app.stream( {"query": "what is weather in sf"}, stream_mode=["values", "updates", "debug"], ) ] == [ ("values", {"query": "what is weather in sf", "docs": []}), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "rewrite_query", "input": {"query": "what is weather in sf", "docs": []}, "triggers": ["start:rewrite_query"], }, }, ), ("updates", {"rewrite_query": {"query": "query: what is weather in sf"}}), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "rewrite_query", "result": [("query", "query: what is weather in sf")], "error": None, "interrupts": [], }, }, ), ("values", {"query": "query: what is weather in sf", "docs": []}), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_one", "input": {"query": "query: what is weather in sf", "docs": []}, "triggers": ["rewrite_query"], }, }, ), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_two", "input": {"query": "query: what is weather in sf", "docs": []}, "triggers": ["rewrite_query"], }, }, ), ( "updates", {"retriever_two": {"docs": ["doc3", "doc4"]}}, ), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_two", "result": [("docs", ["doc3", "doc4"])], "error": None, "interrupts": [], 
}, }, ), ( "updates", {"retriever_one": {"docs": ["doc1", "doc2"]}}, ), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "retriever_one", "result": [("docs", ["doc1", "doc2"])], "error": None, "interrupts": [], }, }, ), ( "values", { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, ), ( "debug", { "type": "task", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "qa", "input": { "query": "query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], }, "triggers": ["retriever_one", "retriever_two"], }, }, ), ("updates", {"qa": {"answer": "doc1,doc2,doc3,doc4"}}), ( "debug", { "type": "task_result", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "qa", "result": [("answer", "doc1,doc2,doc3,doc4")], "error": None, "interrupts": [], }, }, ), ( "values", { "query": "query: what is weather in sf", "answer": "doc1,doc2,doc3,doc4", "docs": ["doc1", "doc2", "doc3", "doc4"], }, ), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_dynamic_interrupt( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_node_count = 0 def tool_two_node(s: State) -> State: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) tool_two_graph.add_edge(START, "tool_two") tool_two = tool_two_graph.compile() tracer = FakeTracer() assert tool_two.invoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert 
len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value"} assert tool_two.invoke({"my_key": "value", "market": "US"}) == { "my_key": "value all good", "market": "US", } tool_two = tool_two_graph.compile(checkpointer=checkpointer) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ) }, ] # resume with answer assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ {"tool_two": {"my_key": " my answer"}}, ] # flow: interrupt -> clear tasks thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { "my_key": "value ⛰️", "market": "DE", } assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ), ), ), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, 
limit=2)][-1].config, ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) # interrupt and next tasks are cleared assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) @pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_copy_checkpoint( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): my_key: Annotated[str, operator.add] market: str def tool_one(s: State) -> State: return {"my_key": " one"} tool_two_node_count = 0 def tool_two_node(s: State) -> State: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} def start(state: State) -> list[Union[Send, str]]: return ["tool_two", Send("tool_one", state)] tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) tool_two_graph.add_node("tool_one", tool_one) tool_two_graph.set_conditional_entry_point(start) tool_two = tool_two_graph.compile() tracer = FakeTracer() assert tool_two.invoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value one", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value one"} assert tool_two.invoke({"my_key": "value", 
"market": "US"}) == { "my_key": "value one all good", "market": "US", } tool_two = tool_two_graph.compile(checkpointer=checkpointer) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) ] == [ { "tool_one": {"my_key": " one"}, }, { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ) }, ] # resume with answer assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ {"tool_two": {"my_key": " my answer"}}, ] # flow: interrupt -> clear tasks thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { "my_key": "value ⛰️ one", "market": "DE", } assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:")], ), ), ), ), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # clear the interrupt and next tasks tool_two.update_state(thread1, None) # 
interrupt is cleared, next task is kept assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=(), ), ), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_dynamic_interrupt_subgraph( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class SubgraphState(TypedDict): my_key: str market: str tool_two_node_count = 0 def tool_two_node(s: SubgraphState) -> SubgraphState: nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" return {"my_key": answer} subgraph = StateGraph(SubgraphState) subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) subgraph.add_edge(START, "do") class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two", subgraph.compile()) tool_two_graph.add_edge(START, "tool_two") tool_two = tool_two_graph.compile() tracer = FakeTracer() assert tool_two.invoke( {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} ) == { "my_key": "value", "market": "DE", } assert tool_two_node_count == 1, "interrupts aren't retried" assert len(tracer.runs) == 1 run = tracer.runs[0] assert run.end_time is not None assert run.error is None assert run.outputs == {"market": "DE", "my_key": "value"} assert tool_two.invoke({"my_key": "value", "market": "US"}) == { "my_key": "value all good", "market": "US", } tool_two = 
tool_two_graph.compile(checkpointer=checkpointer) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) # flow: interrupt -> resume with answer thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert [ c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) ] == [ { "__interrupt__": ( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ) }, ] # resume with answer assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ {"tool_two": {"my_key": " my answer", "market": "DE"}}, ] # flow: interrupt -> clear tasks thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { "my_key": "value ⛰️", "market": "DE", } assert [ c.metadata for c in tool_two.checkpointer.list( {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} ) ] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "thread_id": "1", }, ] assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), tasks=( PregelTask( AnyStr(), "tool_two", (PULL, "tool_two"), interrupts=( Interrupt( value="Just because...", resumable=True, ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("tool_two:"), } }, ), ), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "1", }, parent_config=[ *tool_two.checkpointer.list( {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 ) ][-1].config, ) # clear the interrupt 
and next tasks tool_two.update_state(thread1, None, as_node=END) # interrupt and next tasks are cleared assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {}, "thread_id": "1", }, parent_config=[ *tool_two.checkpointer.list( {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 ) ][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_start_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): my_key: Annotated[str, operator.add] market: str shared: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")] def assert_shared_value(data: State, config: RunnableConfig) -> State: assert "shared" in data if thread_id := config["configurable"].get("thread_id"): if thread_id == "1": # this is the first thread, so should not see a value assert data["shared"] == {} return {"shared": {"1": {"hello": "world"}}} elif thread_id == "2": # this should get value saved by thread 1 assert data["shared"] == {"1": {"hello": "world"}} elif thread_id == "3": # this is a different assistant, so should not see previous value assert data["shared"] == {} return {} def tool_two_slow(data: State, config: RunnableConfig) -> State: return {"my_key": " slow", **assert_shared_value(data, config)} def tool_two_fast(data: State, config: RunnableConfig) -> State: return {"my_key": " fast", **assert_shared_value(data, config)} tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two_slow", tool_two_slow) tool_two_graph.add_node("tool_two_fast", tool_two_fast) tool_two_graph.set_conditional_entry_point( lambda s: 
"tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then=END ) tool_two = tool_two_graph.compile() assert tool_two.get_graph().draw_mermaid() == snapshot assert tool_two.invoke({"my_key": "value", "market": "DE"}) == { "my_key": "value slow", "market": "DE", } assert tool_two.invoke({"my_key": "value", "market": "US"}) == { "my_key": "value fast", "market": "US", } tool_two = tool_two_graph.compile( store=InMemoryStore(), checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"], ) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "1", "assistant_id": "a"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { "my_key": "value ⛰️", "market": "DE", } assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ { "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "1", }, { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, "assistant_id": "a", "thread_id": "1", }, ] assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { "my_key": "value ⛰️ slow", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ slow", "market": "DE"}, tasks=(), next=(), 
config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"tool_two_slow": {"my_key": " slow"}}, "assistant_id": "a", "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": "value", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "a", "thread_id": "2", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { "my_key": "value fast", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"tool_two_fast": {"my_key": " fast"}}, "assistant_id": "a", "thread_id": "2", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "US"}, thread3) == { "my_key": "value", "market": "US", } assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "value", "market": "US"}, 
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=tool_two.checkpointer.get_tuple(thread3).config, created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 0, "writes": None, "assistant_id": "b", "thread_id": "3", }, parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, ) # update state tool_two.update_state(thread3, {"my_key": "key"}) # appends to my_key assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=tool_two.checkpointer.get_tuple(thread3).config, created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 1, "writes": {START: {"my_key": "key"}}, "assistant_id": "b", "thread_id": "3", }, parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { "my_key": "valuekey fast", "market": "US", } assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread3).config, created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": {"tool_two_fast": {"my_key": " fast"}}, "assistant_id": "b", "thread_id": "3", }, parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): my_key: Annotated[str, operator.add] market: str tool_two_graph = 
StateGraph(State) tool_two_graph.set_entry_point("prepare") tool_two_graph.set_finish_point("finish") tool_two_graph.add_conditional_edges( source="prepare", path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then="finish", ) tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"}) tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"}) tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"}) tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"}) tool_two = tool_two_graph.compile() assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot assert tool_two.get_graph().draw_mermaid() == snapshot assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } assert tool_two.invoke({"my_key": "value", "market": "US"}) == { "my_key": "value prepared fast finished", "market": "US", } # test stream_mode=debug tool_two = tool_two_graph.compile(checkpointer=checkpointer) thread10 = {"configurable": {"thread_id": "10"}} res = [ *tool_two.stream( {"my_key": "value", "market": "DE"}, thread10, stream_mode="debug" ) ] assert res == [ { "type": "checkpoint", "timestamp": AnyStr(), "step": -1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": {"my_key": ""}, "metadata": { "parents": {}, "source": "input", "step": -1, "writes": {"__start__": {"my_key": "value", "market": "DE"}}, "thread_id": "10", }, "parent_config": None, "next": ["__start__"], "tasks": [ { "id": AnyStr(), "name": "__start__", "interrupts": (), "state": None, } ], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 0, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": 
"", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 0, "writes": None, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["prepare"], "tasks": [ {"id": AnyStr(), "name": "prepare", "interrupts": (), "state": None} ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "input": {"my_key": "value", "market": "DE"}, "triggers": ["start:prepare"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 1, "payload": { "id": AnyStr(), "name": "prepare", "result": [("my_key", " prepared")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 1, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["tool_two_slow"], "tasks": [ { "id": AnyStr(), "name": "tool_two_slow", "interrupts": (), "state": None, } ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "tool_two_slow", "input": {"my_key": "value prepared", "market": "DE"}, "triggers": ["branch:prepare:condition:tool_two_slow"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 2, "payload": { "id": AnyStr(), "name": "tool_two_slow", "result": [("my_key", " slow")], 
"error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 2, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared slow", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 2, "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": ["finish"], "tasks": [ {"id": AnyStr(), "name": "finish", "interrupts": (), "state": None} ], }, }, { "type": "task", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "finish", "input": {"my_key": "value prepared slow", "market": "DE"}, "triggers": ["branch:prepare:condition::then"], }, }, { "type": "task_result", "timestamp": AnyStr(), "step": 3, "payload": { "id": AnyStr(), "name": "finish", "result": [("my_key", " finished")], "error": None, "interrupts": [], }, }, { "type": "checkpoint", "timestamp": AnyStr(), "step": 3, "payload": { "config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "values": { "my_key": "value prepared slow finished", "market": "DE", }, "metadata": { "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "10", }, "parent_config": { "tags": [], "metadata": {"thread_id": "10"}, "callbacks": None, "recursion_limit": 25, "configurable": { "thread_id": "10", "checkpoint_ns": "", "checkpoint_id": AnyStr(), }, }, "next": [], "tasks": [], }, }, ] tool_two = tool_two_graph.compile( checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"] 
) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "1"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == { "my_key": "value prepared", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "1", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) thread2 = {"configurable": {"thread_id": "2"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": "value prepared", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], 
metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "2", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { "my_key": "value prepared fast finished", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "2", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) tool_two = tool_two_graph.compile( checkpointer=checkpointer, interrupt_before=["finish"] ) thread1 = {"configurable": {"thread_id": "11"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == { "my_key": "value prepared slow", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={ "my_key": "value prepared slow", "market": "DE", }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 2, "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "11", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # update state tool_two.update_state(thread1, {"my_key": "er"}) assert tool_two.get_state(thread1) == StateSnapshot( values={ "my_key": "value prepared slower", "market": "DE", }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), config=tool_two.checkpointer.get_tuple(thread1).config, 
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 3, "writes": {"tool_two_slow": {"my_key": "er"}}, "thread_id": "11", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) tool_two = tool_two_graph.compile( checkpointer=checkpointer, interrupt_after=["prepare"] ) # missing thread_id with pytest.raises(ValueError, match="thread_id"): tool_two.invoke({"my_key": "value", "market": "DE"}) thread1 = {"configurable": {"thread_id": "21"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == { "my_key": "value prepared", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { "my_key": "value prepared slow finished", "market": "DE", } assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread1).config, created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) thread2 = {"configurable": {"thread_id": "22"}} # stop when about to enter node assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == { "my_key": 
"value prepared", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { "my_key": "value prepared fast finished", "market": "US", } assert tool_two.get_state(thread2) == StateSnapshot( values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread2).config, created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, ) thread3 = {"configurable": {"thread_id": "23"}} # update an empty thread before first run uconfig = tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) # check current state assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key", "market": "DE"}, tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),), next=("prepare",), config=uconfig, created_at=AnyStr(), metadata={ "parents": {}, "source": "update", "step": 0, "writes": {START: {"my_key": "key", "market": "DE"}}, "thread_id": "23", }, parent_config=None, ) # run from this point assert tool_two.invoke(None, thread3) == { "my_key": "key prepared", "market": "DE", } # get state after first node assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key prepared", "market": "DE"}, 
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), config=tool_two.checkpointer.get_tuple(thread3).config, created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 1, "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, parent_config=uconfig, ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { "my_key": "key prepared slow finished", "market": "DE", } assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), config=tool_two.checkpointer.get_tuple(thread3).config, created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], metadata={ "parents": {}, "source": "loop", "step": 3, "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_in_one_fan_out_state_graph_waiting_edge( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] workflow = StateGraph(State) @workflow.add_node def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) # to ensure stream order return {"docs": 
["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow.add_node(analyzer_one) workflow.add_node(retriever_one) workflow.add_node(retriever_two) workflow.add_node(qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_before=["qa"], ) config = {"configurable": {"thread_id": "2"}} assert [ c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config) ] == [ {"rewrite_query": {"query": "query: what is weather in 
sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] app_w_interrupt.update_state(config, {"docs": ["doc5"]}) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4", "doc5"], }, tasks=(PregelTask(AnyStr(), "qa", (PULL, "qa")),), next=("qa",), config=app_w_interrupt.checkpointer.get_tuple(config).config, created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], metadata={ "parents": {}, "source": "update", "step": 4, "writes": {"retriever_one": {"docs": ["doc5"]}}, "thread_id": "2", }, parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, ) assert [c for c in app_w_interrupt.stream(None, config, debug=1)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4,doc5"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_in_one_fan_out_state_graph_waiting_edge_via_branch( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} 
def rewrite_query_then(data: State) -> Literal["retriever_two"]: return "retriever_two" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges("rewrite_query", rewrite_query_then) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke({"query": "what is weather in sf"}, debug=True) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def 
test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1( snapshot: SnapshotAssertion, mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str, ) -> None: from pydantic.v1 import BaseModel, ValidationError checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") setup = mocker.Mock() teardown = mocker.Mock() @contextmanager def assert_ctx_once() -> Iterator[None]: assert setup.call_count == 0 assert teardown.call_count == 0 try: yield finally: assert setup.call_count == 1 assert teardown.call_count == 1 setup.reset_mock() teardown.reset_mock() @contextmanager def make_httpx_client() -> Iterator[httpx.Client]: setup() with httpx.Client() as client: try: yield client finally: teardown() def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class InnerObject(BaseModel): yo: int class State(BaseModel): class Config: arbitrary_types_allowed = True query: str inner: InnerObject answer: Optional[str] = None docs: Annotated[list[str], sorted_add] client: Annotated[httpx.Client, Context(make_httpx_client)] class Input(BaseModel): query: str inner: InnerObject class Output(BaseModel): answer: str docs: list[str] class StateUpdate(BaseModel): query: Optional[str] = None answer: Optional[str] = None docs: Optional[list[str]] = None def rewrite_query(data: State) -> State: return {"query": f"query: {data.query}"} def analyzer_one(data: State) -> State: return StateUpdate(query=f"analyzed: {data.query}") def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data.docs)} def decider(data: State) -> str: assert isinstance(data, State) return "retriever_two" workflow = StateGraph(State, input=Input, 
output=Output) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges( "rewrite_query", decider, {"retriever_two": "retriever_two"} ) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.get_input_jsonschema() == snapshot assert app.get_output_jsonschema() == snapshot with pytest.raises(ValidationError), assert_ctx_once(): app.invoke({"query": {}}) with assert_ctx_once(): assert app.invoke({"query": "what is weather in sf", "inner": {"yo": 1}}) == { "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } with assert_ctx_once(): assert [ *app.stream({"query": "what is weather in sf", "inner": {"yo": 1}}) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} with assert_ctx_once(): assert [ c for c in app_w_interrupt.stream( {"query": "what is weather in sf", "inner": {"yo": 1}}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] with assert_ctx_once(): assert [c for c in 
app_w_interrupt.stream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] with assert_ctx_once(): assert app_w_interrupt.update_state( config, {"docs": ["doc5"]}, as_node="rewrite_query" ) == { "configurable": { "thread_id": "1", "checkpoint_id": AnyStr(), "checkpoint_ns": "", } } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2( snapshot: SnapshotAssertion, mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str, ) -> None: from pydantic import BaseModel, ConfigDict, ValidationError checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") setup = mocker.Mock() teardown = mocker.Mock() @contextmanager def assert_ctx_once() -> Iterator[None]: assert setup.call_count == 0 assert teardown.call_count == 0 try: yield finally: assert setup.call_count == 1 assert teardown.call_count == 1 setup.reset_mock() teardown.reset_mock() @contextmanager def make_httpx_client() -> Iterator[httpx.Client]: setup() with httpx.Client() as client: try: yield client finally: teardown() def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class InnerObject(BaseModel): yo: int class State(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) query: str inner: InnerObject answer: Optional[str] = None docs: Annotated[list[str], sorted_add] client: Annotated[httpx.Client, Context(make_httpx_client)] class StateUpdate(BaseModel): query: Optional[str] = None answer: Optional[str] = None docs: Optional[list[str]] = None class Input(BaseModel): query: str inner: InnerObject class Output(BaseModel): answer: str docs: list[str] def rewrite_query(data: State) -> State: return {"query": f"query: {data.query}"} def analyzer_one(data: State) -> State: return StateUpdate(query=f"analyzed: 
{data.query}") def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data.docs)} def decider(data: State) -> str: assert isinstance(data, State) return "retriever_two" workflow = StateGraph(State, input=Input, output=Output) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_conditional_edges( "rewrite_query", decider, {"retriever_two": "retriever_two"} ) workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") app = workflow.compile() if SHOULD_CHECK_SNAPSHOTS: assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.get_input_schema().model_json_schema() == snapshot assert app.get_output_schema().model_json_schema() == snapshot with pytest.raises(ValidationError), assert_ctx_once(): app.invoke({"query": {}}) with assert_ctx_once(): assert app.invoke({"query": "what is weather in sf", "inner": {"yo": 1}}) == { "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } with assert_ctx_once(): assert [ *app.stream({"query": "what is weather in sf", "inner": {"yo": 1}}) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} with 
assert_ctx_once(): assert [ c for c in app_w_interrupt.stream( {"query": "what is weather in sf", "inner": {"yo": 1}}, config ) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] with assert_ctx_once(): assert [c for c in app_w_interrupt.stream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] with assert_ctx_once(): assert app_w_interrupt.update_state( config, {"docs": ["doc5"]}, as_node="rewrite_query" ) == { "configurable": { "thread_id": "1", "checkpoint_id": AnyStr(), "checkpoint_ns": "", } } @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_in_one_fan_out_state_graph_waiting_edge_plus_regular( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer: BaseCheckpointSaver = request.getfixturevalue( f"checkpointer_{checkpointer_name}" ) def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyzer_one(data: State) -> State: time.sleep(0.1) return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.2) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) 
workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "qa") workflow.set_finish_point("qa") # silly edge, to make sure having been triggered before doesn't break # semantics of named barrier (== waiting edges) workflow.add_edge("rewrite_query", "qa") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", "docs": ["doc1", "doc2", "doc3", "doc4"], "answer": "doc1,doc2,doc3,doc4", } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"qa": {"answer": ""}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] app_w_interrupt = workflow.compile( checkpointer=checkpointer, interrupt_after=["retriever_one"], ) config = {"configurable": {"thread_id": "1"}} assert [ c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config) ] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"qa": {"answer": ""}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"__interrupt__": ()}, ] assert [c for c in app_w_interrupt.stream(None, config)] == [ {"qa": {"answer": "doc1,doc2,doc3,doc4"}}, ] def test_in_one_fan_out_state_graph_waiting_edge_multiple() -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: 
str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("decider", decider) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_edge("rewrite_query", "analyzer_one") workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge("rewrite_query", "retriever_two") workflow.add_edge(["retriever_one", "retriever_two"], "decider") workflow.add_conditional_edges("decider", decider_cond) workflow.set_finish_point("qa") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}}, { "analyzer_one": { "query": "analyzed: query: analyzed: query: what is 
weather in sf" } }, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}}, ] def test_callable_in_conditional_edges_with_no_path_map() -> None: class State(TypedDict, total=False): query: str def rewrite(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyze(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} class ChooseAnalyzer: def __call__(self, data: State) -> str: return "analyzer" workflow = StateGraph(State) workflow.add_node("rewriter", rewrite) workflow.add_node("analyzer", analyze) workflow.add_conditional_edges("rewriter", ChooseAnalyzer()) workflow.set_entry_point("rewriter") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", } def test_function_in_conditional_edges_with_no_path_map() -> None: class State(TypedDict, total=False): query: str def rewrite(data: State) -> State: return {"query": f'query: {data["query"]}'} def analyze(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def choose_analyzer(data: State) -> str: return "analyzer" workflow = StateGraph(State) workflow.add_node("rewriter", rewrite) workflow.add_node("analyzer", analyze) workflow.add_conditional_edges("rewriter", choose_analyzer) workflow.set_entry_point("rewriter") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: what is weather in sf", } def test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge() -> None: def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: 
return {"query": f'query: {data["query"]}'} def retriever_picker(data: State) -> list[str]: return ["analyzer_one", "retriever_two"] def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" workflow = StateGraph(State) workflow.add_node("rewrite_query", rewrite_query) workflow.add_node("analyzer_one", analyzer_one) workflow.add_node("retriever_one", retriever_one) workflow.add_node("retriever_two", retriever_two) workflow.add_node("decider", decider) workflow.add_node("qa", qa) workflow.set_entry_point("rewrite_query") workflow.add_conditional_edges("rewrite_query", retriever_picker) workflow.add_edge("analyzer_one", "retriever_one") workflow.add_edge(["retriever_one", "retriever_two"], "decider") workflow.add_conditional_edges("decider", decider_cond) workflow.set_finish_point("qa") app = workflow.compile() assert app.invoke({"query": "what is weather in sf"}) == { "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], } assert [*app.stream({"query": "what is weather in sf"})] == [ {"rewrite_query": {"query": "query: what is weather in sf"}}, {"analyzer_one": {"query": "analyzed: query: what is weather in sf"}}, {"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}}, { "analyzer_one": { "query": "analyzed: query: analyzed: query: what is weather in sf" } }, 
{"retriever_two": {"docs": ["doc3", "doc4"]}}, {"retriever_one": {"docs": ["doc1", "doc2"]}}, {"decider": None}, {"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}}, ] def test_simple_multi_edge(snapshot: SnapshotAssertion) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] def up(state: State): pass def side(state: State): pass def other(state: State): return {"my_key": "_more"} def down(state: State): pass graph = StateGraph(State) graph.add_node("up", up) graph.add_node("side", side) graph.add_node("other", other) graph.add_node("down", down) graph.set_entry_point("up") graph.add_edge("up", "side") graph.add_edge("up", "other") graph.add_edge(["up", "side"], "down") graph.set_finish_point("down") app = graph.compile() assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.invoke({"my_key": "my_value"}) == {"my_key": "my_value_more"} assert [*app.stream({"my_key": "my_value"})] in ( [ {"up": None}, {"side": None}, {"other": {"my_key": "_more"}}, {"down": None}, ], [ {"up": None}, {"other": {"my_key": "_more"}}, {"side": None}, {"down": None}, ], ) def test_nested_graph_xray(snapshot: SnapshotAssertion) -> None: class State(TypedDict): my_key: Annotated[str, operator.add] market: str def logic(state: State): pass tool_two_graph = StateGraph(State) tool_two_graph.add_node("tool_two_slow", logic) tool_two_graph.add_node("tool_two_fast", logic) tool_two_graph.set_conditional_entry_point( lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then=END, ) tool_two = tool_two_graph.compile() graph = StateGraph(State) graph.add_node("tool_one", logic) graph.add_node("tool_two", tool_two) graph.add_node("tool_three", logic) graph.set_conditional_entry_point(lambda s: "tool_one", then=END) app = graph.compile() assert app.get_graph(xray=True).to_json() == snapshot assert app.get_graph(xray=True).draw_mermaid() == snapshot def test_nested_graph(snapshot: SnapshotAssertion) -> None: def 
never_called_fn(state: Any): assert 0, "This function should never be called" never_called = RunnableLambda(never_called_fn) class InnerState(TypedDict): my_key: str my_other_key: str def up(state: InnerState): return {"my_key": state["my_key"] + " there", "my_other_key": state["my_key"]} inner = StateGraph(InnerState) inner.add_node("up", up) inner.set_entry_point("up") inner.set_finish_point("up") class State(TypedDict): my_key: str never_called: Any def side(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("inner", inner.compile()) graph.add_node("side", side) graph.set_entry_point("inner") graph.add_edge("inner", "side") graph.set_finish_point("side") app = graph.compile() assert app.get_graph().draw_mermaid(with_styles=False) == snapshot assert app.get_graph(xray=True).draw_mermaid() == snapshot assert app.invoke( {"my_key": "my value", "never_called": never_called}, debug=True ) == { "my_key": "my value there and back again", "never_called": never_called, } assert [*app.stream({"my_key": "my value", "never_called": never_called})] == [ {"inner": {"my_key": "my value there"}}, {"side": {"my_key": "my value there and back again"}}, ] assert [ *app.stream( {"my_key": "my value", "never_called": never_called}, stream_mode="values" ) ] == [ { "my_key": "my value", "never_called": never_called, }, { "my_key": "my value there", "never_called": never_called, }, { "my_key": "my value there and back again", "never_called": never_called, }, ] chain = app | RunnablePassthrough() assert chain.invoke({"my_key": "my value", "never_called": never_called}) == { "my_key": "my value there and back again", "never_called": never_called, } assert [*chain.stream({"my_key": "my value", "never_called": never_called})] == [ {"inner": {"my_key": "my value there"}}, {"side": {"my_key": "my value there and back again"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def 
test_stream_subgraphs_during_execution( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class InnerState(TypedDict): my_key: Annotated[str, operator.add] my_other_key: str def inner_1(state: InnerState): return {"my_key": "got here", "my_other_key": state["my_key"]} def inner_2(state: InnerState): time.sleep(0.5) return { "my_key": " and there", "my_other_key": state["my_key"], } inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: Annotated[str, operator.add] def outer_1(state: State): time.sleep(0.2) return {"my_key": " and parallel"} def outer_2(state: State): return {"my_key": " and back again"} graph = StateGraph(State) graph.add_node("inner", inner.compile()) graph.add_node("outer_1", outer_1) graph.add_node("outer_2", outer_2) graph.add_edge(START, "inner") graph.add_edge(START, "outer_1") graph.add_edge(["inner", "outer_1"], "outer_2") graph.add_edge("outer_2", END) app = graph.compile(checkpointer=checkpointer) start = time.perf_counter() chunks: list[tuple[float, Any]] = [] config = {"configurable": {"thread_id": "2"}} for c in app.stream({"my_key": ""}, config, subgraphs=True): chunks.append((round(time.perf_counter() - start, 1), c)) for idx in range(len(chunks)): elapsed, c = chunks[idx] chunks[idx] = (round(elapsed - chunks[0][0], 1), c) assert chunks == [ # arrives before "inner" finishes ( FloatBetween(0.0, 0.1), ( (AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}, ), ), (FloatBetween(0.2, 0.3), ((), {"outer_1": {"my_key": " and parallel"}})), ( FloatBetween(0.5, 0.8), ( (AnyStr("inner:"),), {"inner_2": {"my_key": " and there", "my_other_key": "got here"}}, ), ), (FloatBetween(0.5, 0.8), ((), {"inner": {"my_key": "got here and there"}})), 
(FloatBetween(0.5, 0.8), ((), {"outer_2": {"my_key": " and back again"}})), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_stream_buffering_single_node( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class State(TypedDict): my_key: Annotated[str, operator.add] def node(state: State, writer: StreamWriter): writer("Before sleep") time.sleep(0.2) writer("After sleep") return {"my_key": "got here"} builder = StateGraph(State) builder.add_node("node", node) builder.add_edge(START, "node") builder.add_edge("node", END) graph = builder.compile(checkpointer=checkpointer) start = time.perf_counter() chunks: list[tuple[float, Any]] = [] config = {"configurable": {"thread_id": "2"}} for c in graph.stream({"my_key": ""}, config, stream_mode="custom"): chunks.append((round(time.perf_counter() - start, 1), c)) assert chunks == [ (FloatBetween(0.0, 0.1), "Before sleep"), (FloatBetween(0.2, 0.3), "After sleep"), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_nested_graph_interrupts_parallel( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class InnerState(TypedDict): my_key: Annotated[str, operator.add] my_other_key: str def inner_1(state: InnerState): time.sleep(0.1) return {"my_key": "got here", "my_other_key": state["my_key"]} def inner_2(state: InnerState): return { "my_key": " and there", "my_other_key": state["my_key"], } inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: Annotated[str, operator.add] def outer_1(state: State): return {"my_key": " and parallel"} def outer_2(state: State): return {"my_key": " and back again"} graph = 
StateGraph(State) graph.add_node("inner", inner.compile(interrupt_before=["inner_2"])) graph.add_node("outer_1", outer_1) graph.add_node("outer_2", outer_2) graph.add_edge(START, "inner") graph.add_edge(START, "outer_1") graph.add_edge(["inner", "outer_1"], "outer_2") graph.set_finish_point("outer_2") app = graph.compile(checkpointer=checkpointer) # test invoke w/ nested interrupt config = {"configurable": {"thread_id": "1"}} assert app.invoke({"my_key": ""}, config, debug=True) == { "my_key": " and parallel", } assert app.invoke(None, config, debug=True) == { "my_key": "got here and there and parallel and back again", } # below combo of assertions is asserting two things # - outer_1 finishes before inner interrupts (because we see its output in stream, which only happens after node finishes) # - the writes of outer are persisted in 1st call and used in 2nd call, ie outer isn't called again (because we dont see outer_1 output again in 2nd stream) # test stream updates w/ nested interrupt config = {"configurable": {"thread_id": "2"}} assert [*app.stream({"my_key": ""}, config, subgraphs=True)] == [ # we got to parallel node first ((), {"outer_1": {"my_key": " and parallel"}}), ((AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}), ((), {"__interrupt__": ()}), ] assert [*app.stream(None, config)] == [ {"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}}, {"inner": {"my_key": "got here and there"}}, {"outer_2": {"my_key": " and back again"}}, ] # test stream values w/ nested interrupt config = {"configurable": {"thread_id": "3"}} assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] # test interrupts BEFORE the parallel node app = graph.compile(checkpointer=checkpointer, 
interrupt_before=["outer_1"]) config = {"configurable": {"thread_id": "4"}} assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [ {"my_key": ""} ] # while we're waiting for the node w/ interrupt inside to finish assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] # test interrupts AFTER the parallel node app = graph.compile(checkpointer=checkpointer, interrupt_after=["outer_1"]) config = {"configurable": {"thread_id": "5"}} assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": " and parallel"}, ] assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": ""}, {"my_key": "got here and there and parallel"}, ] assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": "got here and there and parallel"}, {"my_key": "got here and there and parallel and back again"}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_doubly_nested_graph_interrupts( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class State(TypedDict): my_key: str class ChildState(TypedDict): my_key: str class GrandChildState(TypedDict): my_key: str def grandchild_1(state: ChildState): return {"my_key": state["my_key"] + " here"} def grandchild_2(state: ChildState): return { "my_key": state["my_key"] + " and there", } grandchild = StateGraph(GrandChildState) grandchild.add_node("grandchild_1", grandchild_1) grandchild.add_node("grandchild_2", grandchild_2) grandchild.add_edge("grandchild_1", "grandchild_2") grandchild.set_entry_point("grandchild_1") grandchild.set_finish_point("grandchild_2") child = StateGraph(ChildState) child.add_node( "child_1", 
grandchild.compile(interrupt_before=["grandchild_2"]), ) child.set_entry_point("child_1") child.set_finish_point("child_1") def parent_1(state: State): return {"my_key": "hi " + state["my_key"]} def parent_2(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("parent_1", parent_1) graph.add_node("child", child.compile()) graph.add_node("parent_2", parent_2) graph.set_entry_point("parent_1") graph.add_edge("parent_1", "child") graph.add_edge("child", "parent_2") graph.set_finish_point("parent_2") app = graph.compile(checkpointer=checkpointer) # test invoke w/ nested interrupt config = {"configurable": {"thread_id": "1"}} assert app.invoke({"my_key": "my value"}, config, debug=True) == { "my_key": "hi my value", } assert app.invoke(None, config, debug=True) == { "my_key": "hi my value here and there and back again", } # test stream updates w/ nested interrupt nodes: list[str] = [] config = { "configurable": {"thread_id": "2", CONFIG_KEY_NODE_FINISHED: nodes.append} } assert [*app.stream({"my_key": "my value"}, config)] == [ {"parent_1": {"my_key": "hi my value"}}, {"__interrupt__": ()}, ] assert nodes == ["parent_1", "grandchild_1"] assert [*app.stream(None, config)] == [ {"child": {"my_key": "hi my value here and there"}}, {"parent_2": {"my_key": "hi my value here and there and back again"}}, ] assert nodes == [ "parent_1", "grandchild_1", "grandchild_2", "child_1", "child", "parent_2", ] # test stream values w/ nested interrupt config = {"configurable": {"thread_id": "3"}} assert [*app.stream({"my_key": "my value"}, config, stream_mode="values")] == [ {"my_key": "my value"}, {"my_key": "hi my value"}, ] assert [*app.stream(None, config, stream_mode="values")] == [ {"my_key": "hi my value"}, {"my_key": "hi my value here and there"}, {"my_key": "hi my value here and there and back again"}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_nested_graph_state( request: 
pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class InnerState(TypedDict): my_key: str my_other_key: str def inner_1(state: InnerState): return { "my_key": state["my_key"] + " here", "my_other_key": state["my_key"], } def inner_2(state: InnerState): return { "my_key": state["my_key"] + " and there", "my_other_key": state["my_key"], } inner = StateGraph(InnerState) inner.add_node("inner_1", inner_1) inner.add_node("inner_2", inner_2) inner.add_edge("inner_1", "inner_2") inner.set_entry_point("inner_1") inner.set_finish_point("inner_2") class State(TypedDict): my_key: str other_parent_key: str def outer_1(state: State): return {"my_key": "hi " + state["my_key"]} def outer_2(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("outer_1", outer_1) graph.add_node( "inner", inner.compile(interrupt_before=["inner_2"]), ) graph.add_node("outer_2", outer_2) graph.set_entry_point("outer_1") graph.add_edge("outer_1", "inner") graph.add_edge("inner", "outer_2") graph.set_finish_point("outer_2") app = graph.compile(checkpointer=checkpointer) config = {"configurable": {"thread_id": "1"}} app.invoke({"my_key": "my value"}, config, debug=True) # test state w/ nested subgraph state (right after interrupt) # first get_state without subgraph state assert app.get_state(config) == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={"configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()}}, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # now, get_state with 
subgraphs state assert app.get_state(config, subgraphs=True) == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state=StateSnapshot( values={ "my_key": "hi my value here", "my_other_key": "hi my value", }, tasks=( PregelTask( AnyStr(), "inner_2", (PULL, "inner_2"), ), ), next=("inner_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "parents": { "": AnyStr(), }, "source": "loop", "writes": { "inner_1": { "my_key": "hi my value here", "my_other_key": "hi my value", } }, "step": 1, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, ), ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # get_state_history returns outer graph checkpoints history = list(app.get_state_history(config)) assert history == [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), } }, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"outer_1": 
{"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "my value"}, tasks=( PregelTask( AnyStr(), "outer_1", (PULL, "outer_1"), result={"my_key": "hi my value"}, ), ), next=("outer_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={}, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "my value"}, ), ), next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"my_key": "my value"}}, "step": -1, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] # get_state_history for a subgraph returns its checkpoints child_history = [*app.get_state_history(history[0].tasks[0].state)] assert child_history == [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": { "inner_1": { "my_key": "hi my value here", "my_other_key": "hi my value", } }, "step": 1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", 
"checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( values={"my_key": "hi my value"}, next=("inner_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": None, "step": 0, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=( PregelTask( AnyStr(), "inner_1", (PULL, "inner_1"), result={ "my_key": "hi my value here", "my_other_key": "hi my value", }, ), ), ), StateSnapshot( values={}, next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "input", "writes": {"__start__": {"my_key": "hi my value"}}, "step": -1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "langgraph_node": "inner", "langgraph_path": [PULL, "inner"], "langgraph_step": 2, "langgraph_triggers": ["outer_1"], "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "hi my value"}, ), ), ), ] # resume app.invoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) assert app.get_state(config) == StateSnapshot( 
values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "outer_2": {"my_key": "hi my value here and there and back again"} }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # test full history at the end actual_history = list(app.get_state_history(config)) expected_history = [ StateSnapshot( values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "outer_2": {"my_key": "hi my value here and there and back again"} }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "hi my value here and there"}, tasks=( PregelTask( AnyStr(), "outer_2", (PULL, "outer_2"), result={"my_key": "hi my value here and there and back again"}, ), ), next=("outer_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"inner": {"my_key": "hi my value here and there"}}, "step": 2, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "inner", (PULL, "inner"), state={ "configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()} }, result={"my_key": "hi my value here and there"}, ), ), next=("inner",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ 
"parents": {}, "source": "loop", "writes": {"outer_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "my value"}, tasks=( PregelTask( AnyStr(), "outer_1", (PULL, "outer_1"), result={"my_key": "hi my value"}, ), ), next=("outer_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={}, tasks=( PregelTask( AnyStr(), "__start__", (PULL, "__start__"), result={"my_key": "my value"}, ), ), next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"my_key": "my value"}}, "step": -1, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] assert actual_history == expected_history # test looking up parent state by checkpoint ID for actual_snapshot, expected_snapshot in zip(actual_history, expected_history): assert app.get_state(actual_snapshot.config) == expected_snapshot @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_doubly_nested_graph_state( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class State(TypedDict): my_key: str class ChildState(TypedDict): my_key: str class GrandChildState(TypedDict): my_key: str def grandchild_1(state: ChildState): return {"my_key": state["my_key"] + " here"} def grandchild_2(state: ChildState): return { "my_key": state["my_key"] + " and there", } grandchild = StateGraph(GrandChildState) 
grandchild.add_node("grandchild_1", grandchild_1) grandchild.add_node("grandchild_2", grandchild_2) grandchild.add_edge("grandchild_1", "grandchild_2") grandchild.set_entry_point("grandchild_1") grandchild.set_finish_point("grandchild_2") child = StateGraph(ChildState) child.add_node( "child_1", grandchild.compile(interrupt_before=["grandchild_2"]), ) child.set_entry_point("child_1") child.set_finish_point("child_1") def parent_1(state: State): return {"my_key": "hi " + state["my_key"]} def parent_2(state: State): return {"my_key": state["my_key"] + " and back again"} graph = StateGraph(State) graph.add_node("parent_1", parent_1) graph.add_node("child", child.compile()) graph.add_node("parent_2", parent_2) graph.set_entry_point("parent_1") graph.add_edge("parent_1", "child") graph.add_edge("child", "parent_2") graph.set_finish_point("parent_2") app = graph.compile(checkpointer=checkpointer) # test invoke w/ nested interrupt config = {"configurable": {"thread_id": "1"}} assert [c for c in app.stream({"my_key": "my value"}, config, subgraphs=True)] == [ ((), {"parent_1": {"my_key": "hi my value"}}), ( (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, ), ((), {"__interrupt__": ()}), ] # get state without subgraphs outer_state = app.get_state(config) assert outer_state == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "child", (PULL, "child"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child"), } }, ), ), next=("child",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"parent_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) child_state = app.get_state(outer_state.tasks[0].state) assert ( child_state.tasks[0] == StateSnapshot( 
values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "child_1", (PULL, "child_1"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), } }, ), ), next=("child_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), } }, metadata={ "parents": {"": AnyStr()}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), } }, ).tasks[0] ) grandchild_state = app.get_state(child_state.tasks[0].state) assert grandchild_state == StateSnapshot( values={"my_key": "hi my value here"}, tasks=( PregelTask( AnyStr(), "grandchild_2", (PULL, "grandchild_2"), ), ), next=("grandchild_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, metadata={ "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "source": "loop", "writes": {"grandchild_1": {"my_key": "hi my value here"}}, "step": 1, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [PULL, AnyStr("child_1")], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, ) # get state with subgraphs assert app.get_state(config, subgraphs=True) == StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "child", (PULL, "child"), state=StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "child_1", (PULL, 
"child_1"), state=StateSnapshot( values={"my_key": "hi my value here"}, tasks=( PregelTask( AnyStr(), "grandchild_2", (PULL, "grandchild_2"), ), ), next=("grandchild_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr( re.compile(r"child:.+|child1:") ): AnyStr(), } ), } }, metadata={ "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "source": "loop", "writes": { "grandchild_1": {"my_key": "hi my value here"} }, "step": 1, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [ PULL, AnyStr("child_1"), ], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr( re.compile(r"child:.+|child1:") ): AnyStr(), } ), } }, ), ), ), next=("child_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "parents": {"": AnyStr()}, "source": "loop", "writes": None, "step": 0, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_node": "child", "langgraph_path": [PULL, AnyStr("child")], "langgraph_step": 2, "langgraph_triggers": [AnyStr("parent_1")], "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, ), ), ), next=("child",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", 
"writes": {"parent_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) # # resume assert [c for c in app.stream(None, config, subgraphs=True)] == [ ( (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_2": {"my_key": "hi my value here and there"}}, ), ((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}), ((), {"child": {"my_key": "hi my value here and there"}}), ((), {"parent_2": {"my_key": "hi my value here and there and back again"}}), ] # get state with and without subgraphs assert ( app.get_state(config) == app.get_state(config, subgraphs=True) == StateSnapshot( values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "parent_2": {"my_key": "hi my value here and there and back again"} }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) ) # get outer graph history outer_history = list(app.get_state_history(config)) assert outer_history == [ StateSnapshot( values={"my_key": "hi my value here and there and back again"}, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "parent_2": {"my_key": "hi my value here and there and back again"} }, "step": 3, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "hi my value here and there"}, next=("parent_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": 
"loop", "writes": {"child": {"my_key": "hi my value here and there"}}, "step": 2, "parents": {}, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="parent_2", path=(PULL, "parent_2"), result={"my_key": "hi my value here and there and back again"}, ), ), ), StateSnapshot( values={"my_key": "hi my value"}, tasks=( PregelTask( AnyStr(), "child", (PULL, "child"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child"), } }, result={"my_key": "hi my value here and there"}, ), ), next=("child",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": {"parent_1": {"my_key": "hi my value"}}, "step": 1, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"my_key": "my value"}, next=("parent_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": None, "step": 0, "parents": {}, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="parent_1", path=(PULL, "parent_1"), result={"my_key": "hi my value"}, ), ), ), StateSnapshot( values={}, next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "input", "writes": {"__start__": {"my_key": "my value"}}, "step": -1, "parents": {}, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( id=AnyStr(), name="__start__", path=(PULL, "__start__"), result={"my_key": "my value"}, ), ), ), ] # get child graph history child_history = 
list(app.get_state_history(outer_history[2].tasks[0].state)) assert child_history == [ StateSnapshot( values={"my_key": "hi my value here and there"}, next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": {"child_1": {"my_key": "hi my value here and there"}}, "step": 1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_node": "child", "langgraph_path": [PULL, AnyStr("child")], "langgraph_step": 2, "langgraph_triggers": [AnyStr("parent_1")], "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=(), ), StateSnapshot( values={"my_key": "hi my value"}, next=("child_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "loop", "writes": None, "step": 0, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_node": "child", "langgraph_path": [PULL, AnyStr("child")], "langgraph_step": 2, "langgraph_triggers": [AnyStr("parent_1")], "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, tasks=( PregelTask( id=AnyStr(), name="child_1", path=(PULL, "child_1"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), } }, result={"my_key": "hi my value here and there"}, ), ), ), StateSnapshot( values={}, next=("__start__",), config={ 
"configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( {"": AnyStr(), AnyStr("child:"): AnyStr()} ), } }, metadata={ "source": "input", "writes": {"__start__": {"my_key": "hi my value"}}, "step": -1, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_node": "child", "langgraph_path": [PULL, AnyStr("child")], "langgraph_step": 2, "langgraph_triggers": [AnyStr("parent_1")], "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( id=AnyStr(), name="__start__", path=(PULL, "__start__"), result={"my_key": "hi my value"}, ), ), ), ] # get grandchild graph history grandchild_history = list(app.get_state_history(child_history[1].tasks[0].state)) assert grandchild_history == [ StateSnapshot( values={"my_key": "hi my value here and there"}, next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, metadata={ "source": "loop", "writes": {"grandchild_2": {"my_key": "hi my value here and there"}}, "step": 2, "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [ PULL, AnyStr("child_1"), ], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, tasks=(), ), StateSnapshot( values={"my_key": "hi my value here"}, next=("grandchild_2",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), 
"checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, metadata={ "source": "loop", "writes": {"grandchild_1": {"my_key": "hi my value here"}}, "step": 1, "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [ PULL, AnyStr("child_1"), ], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, tasks=( PregelTask( id=AnyStr(), name="grandchild_2", path=(PULL, "grandchild_2"), result={"my_key": "hi my value here and there"}, ), ), ), StateSnapshot( values={"my_key": "hi my value"}, next=("grandchild_1",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, metadata={ "source": "loop", "writes": None, "step": 0, "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [ PULL, AnyStr("child_1"), ], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, tasks=( PregelTask( id=AnyStr(), name="grandchild_1", path=(PULL, "grandchild_1"), 
result={"my_key": "hi my value here"}, ), ), ), StateSnapshot( values={}, next=("__start__",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), } ), } }, metadata={ "source": "input", "writes": {"__start__": {"my_key": "hi my value"}}, "step": -1, "parents": AnyDict( { "": AnyStr(), AnyStr("child:"): AnyStr(), } ), "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "langgraph_checkpoint_ns": AnyStr("child:"), "langgraph_node": "child_1", "langgraph_path": [ PULL, AnyStr("child_1"), ], "langgraph_step": 1, "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), parent_config=None, tasks=( PregelTask( id=AnyStr(), name="__start__", path=(PULL, "__start__"), result={"my_key": "hi my value"}, ), ), ), ] # replay grandchild checkpoint assert [ c for c in app.stream(None, grandchild_history[2].config, subgraphs=True) ] == [ ( (AnyStr("child:"), AnyStr("child_1:")), {"grandchild_1": {"my_key": "hi my value here"}}, ), ((), {"__interrupt__": ()}), ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_send_to_nested_graphs( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) class OverallState(TypedDict): subjects: list[str] jokes: Annotated[list[str], operator.add] def continue_to_jokes(state: OverallState): return [Send("generate_joke", {"subject": s}) for s in state["subjects"]] class JokeState(TypedDict): subject: str def edit(state: JokeState): subject = state["subject"] return {"subject": f"{subject} - hohoho"} # subgraph subgraph = StateGraph(JokeState, output=OverallState) subgraph.add_node("edit", edit) subgraph.add_node( "generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]} ) subgraph.set_entry_point("edit") subgraph.add_edge("edit", 
"generate") subgraph.set_finish_point("generate") # parent graph builder = StateGraph(OverallState) builder.add_node( "generate_joke", subgraph.compile(interrupt_before=["generate"]), ) builder.add_conditional_edges(START, continue_to_jokes) builder.add_edge("generate_joke", END) graph = builder.compile(checkpointer=checkpointer) config = {"configurable": {"thread_id": "1"}} tracer = FakeTracer() # invoke and pause at nested interrupt assert graph.invoke( {"subjects": ["cats", "dogs"]}, config={**config, "callbacks": [tracer]} ) == { "subjects": ["cats", "dogs"], "jokes": [], } assert len(tracer.runs) == 1, "Should produce exactly 1 root run" # check state outer_state = graph.get_state(config) if not FF_SEND_V2: # update state of dogs joke graph graph.update_state(outer_state.tasks[1].state, {"subject": "turtles - hohoho"}) # continue past interrupt assert sorted( graph.stream(None, config=config), key=lambda d: d["generate_joke"]["jokes"][0], ) == [ {"generate_joke": {"jokes": ["Joke about cats - hohoho"]}}, {"generate_joke": {"jokes": ["Joke about turtles - hohoho"]}}, ] return assert outer_state == StateSnapshot( values={"subjects": ["cats", "dogs"], "jokes": []}, tasks=( PregelTask( id=AnyStr(), name="__start__", path=("__pregel_pull", "__start__"), error=None, interrupts=(), state=None, result={"subjects": ["cats", "dogs"]}, ), PregelTask( AnyStr(), "generate_joke", (PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), } }, ), PregelTask( AnyStr(), "generate_joke", (PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), } }, ), ), next=("generate_joke", "generate_joke"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"subjects": ["cats", "dogs"]}}, "step": -1, "thread_id": 
"1", }, created_at=AnyStr(), parent_config=None, ) # check state of each of the inner tasks assert graph.get_state(outer_state.tasks[1].state) == StateSnapshot( values={"subject": "cats - hohoho", "jokes": []}, next=("generate",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("generate_joke:"): AnyStr(), } ), } }, metadata={ "step": 1, "source": "loop", "writes": {"edit": None}, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "langgraph_checkpoint_ns": AnyStr("generate_joke:"), "langgraph_node": "generate_joke", "langgraph_path": [PUSH, ["__pregel_pull", "__start__"], 1, AnyStr()], "langgraph_step": 0, "langgraph_triggers": [PUSH], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("generate_joke:"): AnyStr(), } ), } }, tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) assert graph.get_state(outer_state.tasks[2].state) == StateSnapshot( values={"subject": "dogs - hohoho", "jokes": []}, next=("generate",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("generate_joke:"): AnyStr(), } ), } }, metadata={ "step": 1, "source": "loop", "writes": {"edit": None}, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "langgraph_checkpoint_ns": AnyStr("generate_joke:"), "langgraph_node": "generate_joke", "langgraph_path": [PUSH, ["__pregel_pull", "__start__"], 2, AnyStr()], "langgraph_step": 0, "langgraph_triggers": [PUSH], }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": 
AnyStr(), AnyStr("generate_joke:"): AnyStr(), } ), } }, tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) # update state of dogs joke graph graph.update_state( outer_state.tasks[2 if FF_SEND_V2 else 1].state, {"subject": "turtles - hohoho"} ) # continue past interrupt assert sorted( graph.stream(None, config=config), key=lambda d: d["generate_joke"]["jokes"][0] ) == [ {"generate_joke": {"jokes": ["Joke about cats - hohoho"]}}, {"generate_joke": {"jokes": ["Joke about turtles - hohoho"]}}, ] actual_snapshot = graph.get_state(config) expected_snapshot = StateSnapshot( values={ "subjects": ["cats", "dogs"], "jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"], }, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "generate_joke": [ {"jokes": ["Joke about cats - hohoho"]}, {"jokes": ["Joke about turtles - hohoho"]}, ] }, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ) assert actual_snapshot == expected_snapshot # test full history actual_history = list(graph.get_state_history(config)) # get subgraph node state for expected history expected_history = [ StateSnapshot( values={ "subjects": ["cats", "dogs"], "jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"], }, tasks=(), next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "loop", "writes": { "generate_joke": [ {"jokes": ["Joke about cats - hohoho"]}, {"jokes": ["Joke about turtles - hohoho"]}, ] }, "step": 0, "thread_id": "1", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, ), StateSnapshot( values={"jokes": []}, tasks=( PregelTask( id=AnyStr(), 
name="__start__", path=("__pregel_pull", "__start__"), error=None, interrupts=(), state=None, result={"subjects": ["cats", "dogs"]}, ), PregelTask( AnyStr(), "generate_joke", (PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), } }, result={"jokes": ["Joke about cats - hohoho"]}, ), PregelTask( AnyStr(), "generate_joke", (PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), } }, result={"jokes": ["Joke about turtles - hohoho"]}, ), ), next=("__start__", "generate_joke", "generate_joke"), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "parents": {}, "source": "input", "writes": {"__start__": {"subjects": ["cats", "dogs"]}}, "step": -1, "thread_id": "1", }, created_at=AnyStr(), parent_config=None, ), ] assert actual_history == expected_history @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_weather_subgraph( request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion ) -> None: from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, ToolCall from langchain_core.tools import tool from langgraph.graph import MessagesState checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name) # setup subgraph @tool def get_weather(city: str): """Get the weather for a specific city""" return f"I'ts sunny in {city}!" 
weather_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ ToolCall( id="tool_call123", name="get_weather", args={"city": "San Francisco"}, ) ], ) ] ) class SubGraphState(MessagesState): city: str def model_node(state: SubGraphState, writer: StreamWriter): writer(" very") result = weather_model.invoke(state["messages"]) return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]} def weather_node(state: SubGraphState, writer: StreamWriter): writer(" good") result = get_weather.invoke({"city": state["city"]}) return {"messages": [{"role": "assistant", "content": result}]} subgraph = StateGraph(SubGraphState) subgraph.add_node(model_node) subgraph.add_node(weather_node) subgraph.add_edge(START, "model_node") subgraph.add_edge("model_node", "weather_node") subgraph.add_edge("weather_node", END) subgraph = subgraph.compile(interrupt_before=["weather_node"]) # setup main graph class RouterState(MessagesState): route: Literal["weather", "other"] router_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ ToolCall( id="tool_call123", name="router", args={"dest": "weather"}, ) ], ) ] ) def router_node(state: RouterState, writer: StreamWriter): writer("I'm") system_message = "Classify the incoming query as either about weather or not." 
messages = [{"role": "system", "content": system_message}] + state["messages"] route = router_model.invoke(messages) return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]} def normal_llm_node(state: RouterState): return {"messages": [AIMessage("Hello!")]} def route_after_prediction(state: RouterState): if state["route"] == "weather": return "weather_graph" else: return "normal_llm_node" def weather_graph(state: RouterState): return subgraph.invoke(state) graph = StateGraph(RouterState) graph.add_node(router_node) graph.add_node(normal_llm_node) graph.add_node("weather_graph", weather_graph) graph.add_edge(START, "router_node") graph.add_conditional_edges("router_node", route_after_prediction) graph.add_edge("normal_llm_node", END) graph.add_edge("weather_graph", END) graph = graph.compile(checkpointer=checkpointer) assert graph.get_graph(xray=1).draw_mermaid() == snapshot config = {"configurable": {"thread_id": "1"}} thread2 = {"configurable": {"thread_id": "2"}} inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]} # run with custom output assert [c for c in graph.stream(inputs, thread2, stream_mode="custom")] == [ "I'm", " very", ] assert [c for c in graph.stream(None, thread2, stream_mode="custom")] == [ " good", ] # run until interrupt assert [ c for c in graph.stream( inputs, config=config, stream_mode="updates", subgraphs=True ) ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), ((), {"__interrupt__": ()}), ] # check current state state = graph.get_state(config) assert state == StateSnapshot( values={ "messages": [_AnyIdHumanMessage(content="what's the weather in sf")], "route": "weather", }, next=("weather_graph",), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": {"router_node": {"route": "weather"}}, "step": 1, "parents": {}, "thread_id": "1", }, 
created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="weather_graph", path=(PULL, "weather_graph"), state={ "configurable": { "thread_id": "1", "checkpoint_ns": AnyStr("weather_graph:"), } }, ), ), ) # update graph.update_state(state.tasks[0].state, {"city": "la"}) # run after update assert [ c for c in graph.stream( None, config=config, stream_mode="updates", subgraphs=True ) ] == [ ( (AnyStr("weather_graph:"),), { "weather_node": { "messages": [{"role": "assistant", "content": "I'ts sunny in la!"}] } }, ), ( (), { "weather_graph": { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf"), _AnyIdAIMessage(content="I'ts sunny in la!"), ] } }, ), ] # try updating acting as weather node config = {"configurable": {"thread_id": "14"}} inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]} assert [ c for c in graph.stream( inputs, config=config, stream_mode="updates", subgraphs=True ) ] == [ ((), {"router_node": {"route": "weather"}}), ((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}), ((), {"__interrupt__": ()}), ] state = graph.get_state(config, subgraphs=True) assert state == StateSnapshot( values={ "messages": [_AnyIdHumanMessage(content="what's the weather in sf")], "route": "weather", }, next=("weather_graph",), config={ "configurable": { "thread_id": "14", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": {"router_node": {"route": "weather"}}, "step": 1, "parents": {}, "thread_id": "14", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "14", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="weather_graph", path=(PULL, "weather_graph"), state=StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what's the weather in sf") ], "city": "San Francisco", }, next=("weather_node",), 
config={ "configurable": { "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("weather_graph:"): AnyStr(), } ), } }, metadata={ "source": "loop", "writes": {"model_node": {"city": "San Francisco"}}, "step": 1, "parents": {"": AnyStr()}, "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "langgraph_node": "weather_graph", "langgraph_path": [PULL, "weather_graph"], "langgraph_step": 2, "langgraph_triggers": [ "branch:router_node:route_after_prediction:weather_graph" ], "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("weather_graph:"): AnyStr(), } ), } }, tasks=( PregelTask( id=AnyStr(), name="weather_node", path=(PULL, "weather_node"), ), ), ), ), ), ) graph.update_state( state.tasks[0].state.config, {"messages": [{"role": "assistant", "content": "rainy"}]}, as_node="weather_node", ) state = graph.get_state(config, subgraphs=True) assert state == StateSnapshot( values={ "messages": [_AnyIdHumanMessage(content="what's the weather in sf")], "route": "weather", }, next=("weather_graph",), config={ "configurable": { "thread_id": "14", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": {"router_node": {"route": "weather"}}, "step": 1, "parents": {}, "thread_id": "14", }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "14", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=( PregelTask( id=AnyStr(), name="weather_graph", path=(PULL, "weather_graph"), state=StateSnapshot( values={ "messages": [ _AnyIdHumanMessage(content="what's the weather in sf"), _AnyIdAIMessage(content="rainy"), ], "city": "San Francisco", }, next=(), config={ "configurable": { "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), 
"checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("weather_graph:"): AnyStr(), } ), } }, metadata={ "step": 2, "source": "update", "writes": { "weather_node": { "messages": [{"role": "assistant", "content": "rainy"}] } }, "parents": {"": AnyStr()}, "thread_id": "14", "checkpoint_id": AnyStr(), "checkpoint_ns": AnyStr("weather_graph:"), "langgraph_node": "weather_graph", "langgraph_path": [PULL, "weather_graph"], "langgraph_step": 2, "langgraph_triggers": [ "branch:router_node:route_after_prediction:weather_graph" ], "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), "checkpoint_map": AnyDict( { "": AnyStr(), AnyStr("weather_graph:"): AnyStr(), } ), } }, tasks=(), ), ), ), ) assert [ c for c in graph.stream( None, config=config, stream_mode="updates", subgraphs=True ) ] == [ ( (), { "weather_graph": { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf"), _AnyIdAIMessage(content="rainy"), ] } }, ), ]


def test_repeat_condition(snapshot: SnapshotAssertion) -> None:
    """Snapshot the mermaid drawing of a graph with overlapping conditional edges.

    "Researcher" has a self-loop ("redo"), and "Researcher" / "Chart Generator"
    both fan out to "Call Tool", which routes back by the ``sender`` key. The
    graph is only compiled and drawn, never invoked, so ``router`` returning
    "hmm" (a key absent from every path map) is never exercised.
    """

    class AgentState(TypedDict):
        hello: str

    def router(state: AgentState) -> str:
        # Placeholder router: only its path maps matter for drawing.
        return "hmm"

    # NOTE: node/edge insertion order affects the drawn output that the
    # snapshot pins, so do not reorder these calls.
    workflow = StateGraph(AgentState)
    workflow.add_node("Researcher", lambda x: x)
    workflow.add_node("Chart Generator", lambda x: x)
    workflow.add_node("Call Tool", lambda x: x)
    workflow.add_conditional_edges(
        "Researcher",
        router,
        {
            "redo": "Researcher",
            "continue": "Chart Generator",
            "call_tool": "Call Tool",
            "end": END,
        },
    )
    workflow.add_conditional_edges(
        "Chart Generator",
        router,
        {"continue": "Researcher", "call_tool": "Call Tool", "end": END},
    )
    workflow.add_conditional_edges(
        "Call Tool",
        # Each agent node updates the 'sender' field
        # the tool calling node does not, meaning
        # this edge will route back to the original agent
        # who invoked the tool
        lambda x: x["sender"],
        {
            "Researcher": "Researcher",
            "Chart Generator": "Chart Generator",
        },
    )
    workflow.set_entry_point("Researcher")

    app = workflow.compile()
    assert app.get_graph().draw_mermaid(with_styles=False) == snapshot


def test_checkpoint_metadata() -> None:
    """Verify that a run's configurable fields end up in checkpoint metadata.

    Every checkpoint written during a run (and during a resumed run after an
    ``interrupt_before``) must carry the ``thread_id`` and the extra
    ``test_config_*`` keys passed in ``config["configurable"]``.
    """
    # set up test
    from langchain_core.language_models.fake_chat_models import (
        FakeMessagesListChatModel,
    )
    from langchain_core.messages import AIMessage, AnyMessage
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.tools import tool

    # graph state
    class BaseState(TypedDict):
        messages: Annotated[list[AnyMessage], add_messages]

    # initialize graph nodes
    @tool()
    def search_api(query: str) -> str:
        """Searches the API for the query."""
        return f"result for {query}"

    tools = [search_api]

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a nice assistant."),
            ("placeholder", "{messages}"),
        ]
    )

    # Scripted model: first reply requests the search_api tool, second is final.
    model = FakeMessagesListChatModel(
        responses=[
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "id": "tool_call123",
                        "name": "search_api",
                        "args": {"query": "query"},
                    },
                ],
            ),
            AIMessage(content="answer"),
        ]
    )

    @traceable(run_type="llm")
    def agent(state: BaseState) -> BaseState:
        formatted = prompt.invoke(state)
        response = model.invoke(formatted)
        # NOTE(review): "usage_metadata" is not a BaseState key; the expected
        # final state below does not contain it.
        return {"messages": response, "usage_metadata": {"total_tokens": 123}}

    def should_continue(data: BaseState) -> str:
        # Logic to decide whether to continue in the loop or exit
        if not data["messages"][-1].tool_calls:
            return "exit"
        else:
            return "continue"

    # define graphs w/ and w/o interrupt
    workflow = StateGraph(BaseState)
    workflow.add_node("agent", agent)
    workflow.add_node("tools", ToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent", should_continue, {"continue": "tools", "exit": END}
    )
    workflow.add_edge("tools", "agent")

    # graph w/o interrupt
    # (MemorySaverAssertCheckpointMetadata is defined elsewhere in the test suite.)
    checkpointer_1 = MemorySaverAssertCheckpointMetadata()
    app = workflow.compile(checkpointer=checkpointer_1)

    # graph w/ interrupt
    checkpointer_2 = MemorySaverAssertCheckpointMetadata()
    app_w_interrupt = workflow.compile(
        checkpointer=checkpointer_2, interrupt_before=["tools"]
    )

    # assertions

    # invoke graph w/o interrupt
    assert app.invoke(
        {"messages": ["what is weather in sf"]},
        {
            "configurable": {
                "thread_id": "1",
                "test_config_1": "foo",
                "test_config_2": "bar",
            },
        },
    ) == {
        "messages": [
            _AnyIdHumanMessage(content="what is weather in sf"),
            _AnyIdAIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "search_api",
                        "args": {"query": "query"},
                        "id": "tool_call123",
                        "type": "tool_call",
                    }
                ],
            ),
            _AnyIdToolMessage(
                content="result for query",
                name="search_api",
                tool_call_id="tool_call123",
            ),
            _AnyIdAIMessage(content="answer"),
        ]
    }

    # Look up the latest checkpoint using only the thread_id.
    config = {"configurable": {"thread_id": "1"}}

    # assert that checkpoint metadata contains the run's configurable fields
    chkpnt_metadata_1 = checkpointer_1.get_tuple(config).metadata
    assert chkpnt_metadata_1["thread_id"] == "1"
    assert chkpnt_metadata_1["test_config_1"] == "foo"
    assert chkpnt_metadata_1["test_config_2"] == "bar"

    # Verify that all checkpoint metadata have the expected keys. This check
    # is needed because a run may have an arbitrary number of steps depending
    # on how the graph is constructed.
    chkpnt_tuples_1 = checkpointer_1.list(config)
    for chkpnt_tuple in chkpnt_tuples_1:
        assert chkpnt_tuple.metadata["thread_id"] == "1"
        assert chkpnt_tuple.metadata["test_config_1"] == "foo"
        assert chkpnt_tuple.metadata["test_config_2"] == "bar"

    # invoke graph, but interrupt before tool call
    app_w_interrupt.invoke(
        {"messages": ["what is weather in sf"]},
        {
            "configurable": {
                "thread_id": "2",
                "test_config_3": "foo",
                "test_config_4": "bar",
            },
        },
    )

    config = {"configurable": {"thread_id": "2"}}

    # assert that checkpoint metadata contains the run's configurable fields
    chkpnt_metadata_2 = checkpointer_2.get_tuple(config).metadata
    assert chkpnt_metadata_2["thread_id"] == "2"
    assert chkpnt_metadata_2["test_config_3"] == "foo"
    assert chkpnt_metadata_2["test_config_4"] == "bar"

    # resume graph execution
    app_w_interrupt.invoke(
        input=None,
        config={
            "configurable": {
                "thread_id": "2",
                "test_config_3": "foo",
                "test_config_4": "bar",
            }
        },
    )

    # assert that checkpoint metadata contains the run's configurable fields
    chkpnt_metadata_3 = checkpointer_2.get_tuple(config).metadata
    assert chkpnt_metadata_3["thread_id"] == "2"
    assert chkpnt_metadata_3["test_config_3"] == "foo"
    assert chkpnt_metadata_3["test_config_4"] == "bar"

    # Verify that all checkpoint metadata have the expected keys. This check
    # is needed because a run may have an arbitrary number of steps depending
    # on how the graph is constructed.
    chkpnt_tuples_2 = checkpointer_2.list(config)
    for chkpnt_tuple in chkpnt_tuples_2:
        assert chkpnt_tuple.metadata["thread_id"] == "2"
        assert chkpnt_tuple.metadata["test_config_3"] == "foo"
        assert chkpnt_tuple.metadata["test_config_4"] == "bar"


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_remove_message_via_state_update(
    request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
    """A RemoveMessage passed to ``update_state`` deletes that message by id.

    After one chatbot turn the thread holds [Hi, AI reply]; removing the AI
    reply's id must leave only the human "Hi" message.
    """
    from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage

    workflow = MessageGraph()
    workflow.add_node(
        "chatbot",
        lambda state: [
            AIMessage(
                content="Hello! How can I help you",
            )
        ],
    )

    workflow.set_entry_point("chatbot")
    workflow.add_edge("chatbot", END)

    checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
    app = workflow.compile(checkpointer=checkpointer)
    config = {"configurable": {"thread_id": "1"}}
    output = app.invoke([HumanMessage(content="Hi")], config=config)
    # Remove the last (AI) message from the persisted thread state.
    app.update_state(config, values=[RemoveMessage(id=output[-1].id)])

    updated_state = app.get_state(config)

    assert len(updated_state.values) == 1
    assert updated_state.values[-1].content == "Hi"


def test_remove_message_from_node():
    """A node may return RemoveMessage to delete an earlier message.

    "chatbot" appends two AI messages; "delete_messages" removes the
    second-to-last one ("Hello!"), so the output is [Hi, "How can I help you?"].
    """
    from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage

    workflow = MessageGraph()
    workflow.add_node(
        "chatbot",
        lambda state: [
            AIMessage(
                content="Hello!",
            ),
            AIMessage(
                content="How can I help you?",
            ),
        ],
    )
    # state[-2] is the "Hello!" message appended by "chatbot".
    workflow.add_node("delete_messages", lambda state: [RemoveMessage(id=state[-2].id)])
    workflow.set_entry_point("chatbot")
    workflow.add_edge("chatbot", "delete_messages")
    workflow.add_edge("delete_messages", END)

    app = workflow.compile()
    output = app.invoke([HumanMessage(content="Hi")])
    assert len(output) == 2
    assert output[-1].content == "How can I help you?"
def test_xray_lance(snapshot: SnapshotAssertion): from langchain_core.messages import AnyMessage, HumanMessage from pydantic import BaseModel, Field class Analyst(BaseModel): affiliation: str = Field( description="Primary affiliation of the investment analyst.", ) name: str = Field( description="Name of the investment analyst.", pattern=r"^[a-zA-Z0-9_-]{1,64}$", ) role: str = Field( description="Role of the investment analyst in the context of the topic.", ) description: str = Field( description="Description of the investment analyst focus, concerns, and motives.", ) @property def persona(self) -> str: return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n" class Perspectives(BaseModel): analysts: List[Analyst] = Field( description="Comprehensive list of investment analysts with their roles and affiliations.", ) class Section(BaseModel): section_title: str = Field(..., title="Title of the section") context: str = Field( ..., title="Provide a clear summary of the focus area that you researched." ) findings: str = Field( ..., title="Give a clear and detailed overview of your findings based upon the expert interview.", ) thesis: str = Field( ..., title="Give a clear and specific investment thesis based upon these findings.", ) class InterviewState(TypedDict): messages: Annotated[List[AnyMessage], add_messages] analyst: Analyst section: Section class ResearchGraphState(TypedDict): analysts: List[Analyst] topic: str max_analysts: int sections: List[Section] interviews: Annotated[list, operator.add] # Conditional edge def route_messages(state): return "ask_question" def generate_question(state): return ... def generate_answer(state): return ... 
# Add nodes and edges interview_builder = StateGraph(InterviewState) interview_builder.add_node("ask_question", generate_question) interview_builder.add_node("answer_question", generate_answer) # Flow interview_builder.add_edge(START, "ask_question") interview_builder.add_edge("ask_question", "answer_question") interview_builder.add_conditional_edges("answer_question", route_messages) # Set up memory memory = MemorySaver() # Interview interview_graph = interview_builder.compile(checkpointer=memory).with_config( run_name="Conduct Interviews" ) # View assert interview_graph.get_graph().to_json() == snapshot def run_all_interviews(state: ResearchGraphState): """Edge to run the interview sub-graph using Send""" return [ Send( "conduct_interview", { "analyst": Analyst(), "messages": [ HumanMessage( content="So you said you were writing an article on ...?" ) ], }, ) for s in state["analysts"] ] def generate_sections(state: ResearchGraphState): return ... def generate_analysts(state: ResearchGraphState): return ... 
builder = StateGraph(ResearchGraphState) builder.add_node("generate_analysts", generate_analysts) builder.add_node("conduct_interview", interview_builder.compile()) builder.add_node("generate_sections", generate_sections) builder.add_edge(START, "generate_analysts") builder.add_conditional_edges( "generate_analysts", run_all_interviews, ["conduct_interview"] ) builder.add_edge("conduct_interview", "generate_sections") builder.add_edge("generate_sections", END) graph = builder.compile() # View assert graph.get_graph().to_json() == snapshot assert graph.get_graph(xray=1).to_json() == snapshot @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_channel_values(request: pytest.FixtureRequest, checkpointer_name: str) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") config = {"configurable": {"thread_id": "1"}} chain = Channel.subscribe_to("input") | Channel.write_to("output") app = Pregel( nodes={ "one": chain, }, channels={ "ephemeral": EphemeralValue(Any), "input": LastValue(int), "output": LastValue(int), }, input_channels=["input", "ephemeral"], output_channels="output", checkpointer=checkpointer, ) app.invoke({"input": 1, "ephemeral": "meow"}, config) assert checkpointer.get(config)["channel_values"] == {"input": 1, "output": 1} def test_xray_issue(snapshot: SnapshotAssertion) -> None: class State(TypedDict): messages: Annotated[list, add_messages] def node(name): def _node(state: State): return {"messages": [("human", f"entered {name} node")]} return _node parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge("__start__", "c_one") child.add_edge("c_two", "c_one") child.add_conditional_edges( "c_one", lambda x: str(randrange(0, 2)), {"0": "c_two", "1": "__end__"} ) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge("__start__", "p_one") parent.add_edge("p_two", 
"p_one") parent.add_conditional_edges( "p_one", lambda x: str(randrange(0, 2)), {"0": "p_two", "1": "__end__"} ) app = parent.compile() assert app.get_graph(xray=True).draw_mermaid() == snapshot def test_xray_bool(snapshot: SnapshotAssertion) -> None: class State(TypedDict): messages: Annotated[list, add_messages] def node(name): def _node(state: State): return {"messages": [("human", f"entered {name} node")]} return _node grand_parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge("__start__", "c_one") child.add_edge("c_two", "c_one") child.add_conditional_edges( "c_one", lambda x: str(randrange(0, 2)), {"0": "c_two", "1": "__end__"} ) parent = StateGraph(State) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge("__start__", "p_one") parent.add_edge("p_two", "p_one") parent.add_conditional_edges( "p_one", lambda x: str(randrange(0, 2)), {"0": "p_two", "1": "__end__"} ) grand_parent.add_node("gp_one", node("gp_one")) grand_parent.add_node("gp_two", parent.compile()) grand_parent.add_edge("__start__", "gp_one") grand_parent.add_edge("gp_two", "gp_one") grand_parent.add_conditional_edges( "gp_one", lambda x: str(randrange(0, 2)), {"0": "gp_two", "1": "__end__"} ) app = grand_parent.compile() assert app.get_graph(xray=True).draw_mermaid() == snapshot def test_multiple_sinks_subgraphs(snapshot: SnapshotAssertion) -> None: class State(TypedDict): messages: Annotated[list, add_messages] subgraph_builder = StateGraph(State) subgraph_builder.add_node("one", lambda x: x) subgraph_builder.add_node("two", lambda x: x) subgraph_builder.add_node("three", lambda x: x) subgraph_builder.add_edge("__start__", "one") subgraph_builder.add_conditional_edges("one", lambda x: "two", ["two", "three"]) subgraph = subgraph_builder.compile() builder = StateGraph(State) builder.add_node("uno", lambda x: x) builder.add_node("dos", lambda x: x) 
builder.add_node("subgraph", subgraph) builder.add_edge("__start__", "uno") builder.add_conditional_edges("uno", lambda x: "dos", ["dos", "subgraph"]) app = builder.compile() assert app.get_graph(xray=True).draw_mermaid() == snapshot def test_subgraph_retries(): class State(TypedDict): count: int class ChildState(State): some_list: Annotated[list, operator.add] called_times = 0 class RandomError(ValueError): """This will be retried on.""" def parent_node(state: State): return {"count": state["count"] + 1} def child_node_a(state: ChildState): nonlocal called_times # We want it to retry only on node_b # NOT re-compute the whole graph. assert not called_times called_times += 1 return {"some_list": ["val"]} def child_node_b(state: ChildState): raise RandomError("First attempt fails") child = StateGraph(ChildState) child.add_node(child_node_a) child.add_node(child_node_b) child.add_edge("__start__", "child_node_a") child.add_edge("child_node_a", "child_node_b") parent = StateGraph(State) parent.add_node("parent_node", parent_node) parent.add_node( "child_graph", child.compile(), retry=RetryPolicy( max_attempts=3, retry_on=(RandomError,), backoff_factor=0.0001, initial_interval=0.0001, ), ) parent.add_edge("parent_node", "child_graph") parent.set_entry_point("parent_node") checkpointer = MemorySaver() app = parent.compile(checkpointer=checkpointer) with pytest.raises(RandomError): app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}}) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) @pytest.mark.parametrize("store_name", ALL_STORES_SYNC) def test_store_injected( request: pytest.FixtureRequest, checkpointer_name: str, store_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") the_store = request.getfixturevalue(f"store_{store_name}") class State(TypedDict): count: Annotated[int, operator.add] doc_id = str(uuid.uuid4()) doc = {"some-key": "this-is-a-val"} uid = uuid.uuid4().hex namespace = 
(f"foo-{uid}", "bar") thread_1 = str(uuid.uuid4()) thread_2 = str(uuid.uuid4()) class Node: def __init__(self, i: Optional[int] = None): self.i = i def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore): assert isinstance(store, BaseStore) store.put( namespace if self.i is not None and config["configurable"]["thread_id"] in (thread_1, thread_2) else (f"foo_{self.i}", "bar"), doc_id, { **doc, "from_thread": config["configurable"]["thread_id"], "some_val": inputs["count"], }, ) return {"count": 1} builder = StateGraph(State) builder.add_node("node", Node()) builder.add_edge("__start__", "node") N = 500 M = 1 if "duckdb" in store_name: logger.warning( "DuckDB store implementation has a known issue that does not" " support concurrent writes, so we're reducing the test scope" ) N = M = 1 for i in range(N): builder.add_node(f"node_{i}", Node(i)) builder.add_edge("__start__", f"node_{i}") graph = builder.compile(store=the_store, checkpointer=checkpointer) results = graph.batch( [{"count": 0}] * M, ([{"configurable": {"thread_id": str(uuid.uuid4())}}] * (M - 1)) + [{"configurable": {"thread_id": thread_1}}], ) result = results[-1] assert result == {"count": N + 1} returned_doc = the_store.get(namespace, doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0} assert len(the_store.search(namespace)) == 1 # Check results after another turn of the same thread result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}}) assert result == {"count": (N + 1) * 2} returned_doc = the_store.get(namespace, doc_id).value assert returned_doc == {**doc, "from_thread": thread_1, "some_val": N + 1} assert len(the_store.search(namespace)) == 1 result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_2}}) assert result == {"count": N + 1} returned_doc = the_store.get(namespace, doc_id).value assert returned_doc == { **doc, "from_thread": thread_2, "some_val": 0, } # Overwrites the whole doc assert 
len(the_store.search(namespace)) == 1 # still overwriting the same one def test_enum_node_names(): class NodeName(str, enum.Enum): BAZ = "baz" class State(TypedDict): foo: str bar: str def baz(state: State): return {"bar": state["foo"] + "!"} graph = StateGraph(State) graph.add_node(NodeName.BAZ, baz) graph.add_edge(START, NodeName.BAZ) graph.add_edge(NodeName.BAZ, END) graph = graph.compile() assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"} def test_debug_retry(): class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): def _node(state: State): return {"messages": [f"entered {name} node"]} return _node builder = StateGraph(State) builder.add_node("one", node("one")) builder.add_node("two", node("two")) builder.add_edge(START, "one") builder.add_edge("one", "two") builder.add_edge("two", END) saver = MemorySaver() graph = builder.compile(checkpointer=saver) config = {"configurable": {"thread_id": "1"}} graph.invoke({"messages": []}, config=config) # re-run step: 1 target_config = next( c.parent_config for c in saver.list(config) if c.metadata["step"] == 1 ) update_config = graph.update_state(target_config, values=None) events = [*graph.stream(None, config=update_config, stream_mode="debug")] checkpoint_events = list( reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) ) checkpoint_history = { c.config["configurable"]["checkpoint_id"]: c for c in graph.get_state_history(config) } def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None return config["configurable"] for stream in checkpoint_events: stream_conf = lax_normalize_config(stream["config"]) stream_parent_conf = lax_normalize_config(stream["parent_config"]) assert stream_conf != stream_parent_conf # ensure the streamed checkpoint == checkpoint from checkpointer.list() history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]] history_conf = 
lax_normalize_config(history.config) assert stream_conf == history_conf history_parent_conf = lax_normalize_config(history.parent_config) assert stream_parent_conf == history_parent_conf def test_debug_subgraphs(): class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): def _node(state: State): return {"messages": [f"entered {name} node"]} return _node parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge(START, "c_one") child.add_edge("c_one", "c_two") child.add_edge("c_two", END) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge(START, "p_one") parent.add_edge("p_one", "p_two") parent.add_edge("p_two", END) graph = parent.compile(checkpointer=MemorySaver()) config = {"configurable": {"thread_id": "1"}} events = [ *graph.stream( {"messages": []}, config=config, stream_mode="debug", ) ] checkpoint_events = list( reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) ) checkpoint_history = list(graph.get_state_history(config)) assert len(checkpoint_events) == len(checkpoint_history) def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None return config["configurable"] for stream, history in zip(checkpoint_events, checkpoint_history): assert stream["values"] == history.values assert stream["next"] == list(history.next) assert lax_normalize_config(stream["config"]) == lax_normalize_config( history.config ) assert lax_normalize_config(stream["parent_config"]) == lax_normalize_config( history.parent_config ) assert len(stream["tasks"]) == len(history.tasks) for stream_task, history_task in zip(stream["tasks"], history.tasks): assert stream_task["id"] == history_task.id assert stream_task["name"] == history_task.name assert stream_task["interrupts"] == history_task.interrupts assert stream_task.get("error") == history_task.error assert 
stream_task.get("state") == history_task.state def test_debug_nested_subgraphs(): from collections import defaultdict class State(TypedDict): messages: Annotated[list[str], operator.add] def node(name): def _node(state: State): return {"messages": [f"entered {name} node"]} return _node grand_parent = StateGraph(State) parent = StateGraph(State) child = StateGraph(State) child.add_node("c_one", node("c_one")) child.add_node("c_two", node("c_two")) child.add_edge(START, "c_one") child.add_edge("c_one", "c_two") child.add_edge("c_two", END) parent.add_node("p_one", node("p_one")) parent.add_node("p_two", child.compile()) parent.add_edge(START, "p_one") parent.add_edge("p_one", "p_two") parent.add_edge("p_two", END) grand_parent.add_node("gp_one", node("gp_one")) grand_parent.add_node("gp_two", parent.compile()) grand_parent.add_edge(START, "gp_one") grand_parent.add_edge("gp_one", "gp_two") grand_parent.add_edge("gp_two", END) graph = grand_parent.compile(checkpointer=MemorySaver()) config = {"configurable": {"thread_id": "1"}} events = [ *graph.stream( {"messages": []}, config=config, stream_mode="debug", subgraphs=True, ) ] stream_ns: dict[tuple, dict] = defaultdict(list) for ns, e in events: if e["type"] == "checkpoint": stream_ns[ns].append(e["payload"]) assert list(stream_ns.keys()) == [ (), (AnyStr("gp_two:"),), (AnyStr("gp_two:"), AnyStr("p_two:")), ] history_ns = { ns: list( graph.get_state_history( {"configurable": {"thread_id": "1", "checkpoint_ns": "|".join(ns)}} ) )[::-1] for ns in stream_ns.keys() } def normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None clean_config = {} clean_config["thread_id"] = config["configurable"]["thread_id"] clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"] clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] if "checkpoint_map" in config["configurable"]: clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"] return clean_config 
for checkpoint_events, checkpoint_history in zip( stream_ns.values(), history_ns.values() ): for stream, history in zip(checkpoint_events, checkpoint_history): assert stream["values"] == history.values assert stream["next"] == list(history.next) assert normalize_config(stream["config"]) == normalize_config( history.config ) assert normalize_config(stream["parent_config"]) == normalize_config( history.parent_config ) assert len(stream["tasks"]) == len(history.tasks) for stream_task, history_task in zip(stream["tasks"], history.tasks): assert stream_task["id"] == history_task.id assert stream_task["name"] == history_task.name assert stream_task["interrupts"] == history_task.interrupts assert stream_task.get("error") == history_task.error assert stream_task.get("state") == history_task.state def test_add_sequence(): class State(TypedDict): foo: Annotated[list[str], operator.add] bar: str def step1(state: State): return {"foo": ["step1"], "bar": "baz"} def step2(state: State): return {"foo": ["step2"]} # test raising if less than 1 steps with pytest.raises(ValueError): StateGraph(State).add_sequence([]) # test raising if duplicate step names with pytest.raises(ValueError): StateGraph(State).add_sequence([step1, step1]) with pytest.raises(ValueError): StateGraph(State).add_sequence([("foo", step1), ("foo", step1)]) # test unnamed steps builder = StateGraph(State) builder.add_sequence([step1, step2]) builder.add_edge(START, "step1") graph = builder.compile() result = graph.invoke({"foo": []}) assert result == {"foo": ["step1", "step2"], "bar": "baz"} stream_chunks = list(graph.stream({"foo": []})) assert stream_chunks == [ {"step1": {"foo": ["step1"], "bar": "baz"}}, {"step2": {"foo": ["step2"]}}, ] # test named steps builder_named_steps = StateGraph(State) builder_named_steps.add_sequence([("meow1", step1), ("meow2", step2)]) builder_named_steps.add_edge(START, "meow1") graph_named_steps = builder_named_steps.compile() result = graph_named_steps.invoke({"foo": []}) 
stream_chunks = list(graph_named_steps.stream({"foo": []})) assert result == {"foo": ["step1", "step2"], "bar": "baz"} assert stream_chunks == [ {"meow1": {"foo": ["step1"], "bar": "baz"}}, {"meow2": {"foo": ["step2"]}}, ] builder_named_steps = StateGraph(State) builder_named_steps.add_sequence( [ ("meow1", lambda state: {"foo": ["foo"]}), ("meow2", lambda state: {"bar": state["foo"][0] + "bar"}), ], ) builder_named_steps.add_edge(START, "meow1") graph_named_steps = builder_named_steps.compile() result = graph_named_steps.invoke({"foo": []}) stream_chunks = list(graph_named_steps.stream({"foo": []})) # filtered by output schema assert result == {"bar": "foobar", "foo": ["foo"]} assert stream_chunks == [ {"meow1": {"foo": ["foo"]}}, {"meow2": {"bar": "foobar"}}, ] # test two sequences def a(state: State): return {"foo": ["a"]} def b(state: State): return {"foo": ["b"]} builder_two_sequences = StateGraph(State) builder_two_sequences.add_sequence([a]) builder_two_sequences.add_sequence([b]) builder_two_sequences.add_edge(START, "a") builder_two_sequences.add_edge("a", "b") graph_two_sequences = builder_two_sequences.compile() result = graph_two_sequences.invoke({"foo": []}) assert result == {"foo": ["a", "b"]} stream_chunks = list(graph_two_sequences.stream({"foo": []})) assert stream_chunks == [ {"a": {"foo": ["a"]}}, {"b": {"foo": ["b"]}}, ] # test mixed nodes and sequences def c(state: State): return {"foo": ["c"]} def d(state: State): return {"foo": ["d"]} def e(state: State): return {"foo": ["e"]} def foo(state: State): if state["foo"][0] == "a": return "d" else: return "c" builder_complex = StateGraph(State) builder_complex.add_sequence([a, b]) builder_complex.add_conditional_edges("b", foo) builder_complex.add_node(c) builder_complex.add_sequence([d, e]) builder_complex.add_edge(START, "a") graph_complex = builder_complex.compile() result = graph_complex.invoke({"foo": []}) assert result == {"foo": ["a", "b", "d", "e"]} result = graph_complex.invoke({"foo": 
["start"]}) assert result == {"foo": ["start", "a", "b", "c"]} stream_chunks = list(graph_complex.stream({"foo": []})) assert stream_chunks == [ {"a": {"foo": ["a"]}}, {"b": {"foo": ["b"]}}, {"d": {"foo": ["d"]}}, {"e": {"foo": ["e"]}}, ] def test_runnable_passthrough_node_graph() -> None: class State(TypedDict): changeme: str async def dummy(state): return state agent = dummy | RunnablePassthrough.assign(prediction=RunnableLambda(lambda x: x)) graph_builder = StateGraph(State) graph_builder.add_node("agent", agent) graph_builder.add_edge(START, "agent") graph = graph_builder.compile() assert graph.get_graph(xray=True).to_json() == graph.get_graph(xray=False).to_json() @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str) -> None: from langchain_core.messages import BaseMessage from langchain_core.tools import tool @tool(return_direct=True) def get_user_name() -> Command: """Retrieve user name""" return Command(update={"user_name": "Meow"}, graph=Command.PARENT) subgraph_builder = StateGraph(MessagesState) subgraph_builder.add_node("tool", get_user_name) subgraph_builder.add_edge(START, "tool") subgraph = subgraph_builder.compile() class CustomParentState(TypedDict): messages: Annotated[list[BaseMessage], add_messages] # this key is not available to the child graph user_name: str builder = StateGraph(CustomParentState) builder.add_node("alice", subgraph) builder.add_edge(START, "alice") checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") graph = builder.compile(checkpointer=checkpointer) config = {"configurable": {"thread_id": "1"}} assert graph.invoke({"messages": [("user", "get user name")]}, config) == { "messages": [ _AnyIdHumanMessage( content="get user name", additional_kwargs={}, response_metadata={} ), ], "user_name": "Meow", } assert graph.get_state(config) == StateSnapshot( values={ "messages": [ _AnyIdHumanMessage( content="get user 
name", additional_kwargs={}, response_metadata={} ), ], "user_name": "Meow", }, next=(), config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, metadata={ "source": "loop", "writes": { "alice": { "user_name": "Meow", } }, "thread_id": "1", "step": 1, "parents": {}, }, created_at=AnyStr(), parent_config={ "configurable": { "thread_id": "1", "checkpoint_ns": "", "checkpoint_id": AnyStr(), } }, tasks=(), ) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str): checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): baz: str def foo(state): return {"baz": "foo"} def bar(state): value = interrupt("Please provide baz value:") return {"baz": value} child_builder = StateGraph(State) child_builder.add_node(bar) child_builder.add_edge(START, "bar") builder = StateGraph(State) builder.add_node(foo) builder.add_node("bar", child_builder.compile()) builder.add_edge(START, "foo") builder.add_edge("foo", "bar") graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} # First run, interrupted at bar assert graph.invoke({"baz": ""}, thread1) # Resume with answer assert graph.invoke(Command(resume="bar"), thread1) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_interrupt_multiple(request: pytest.FixtureRequest, checkpointer_name: str): checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): my_key: Annotated[str, operator.add] def node(s: State) -> State: answer = interrupt({"value": 1}) answer2 = interrupt({"value": 2}) return {"my_key": answer + " " + answer2} builder = StateGraph(State) builder.add_node("node", node) builder.add_edge(START, "node") graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert [e for e in 
graph.stream({"my_key": "DE", "market": "DE"}, thread1)] == [ { "__interrupt__": ( Interrupt( value={"value": 1}, resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event for event in graph.stream( Command(resume="answer 1", update={"my_key": "foofoo"}), thread1 ) ] == [ { "__interrupt__": ( Interrupt( value={"value": 2}, resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [event for event in graph.stream(Command(resume="answer 2"), thread1)] == [ {"node": {"my_key": "answer 1 answer 2"}}, ] @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_interrupt_loop(request: pytest.FixtureRequest, checkpointer_name: str): checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): age: int other: str def ask_age(s: State): """Ask an expert for help.""" question = "How old are you?" value = None for _ in range(10): value: str = interrupt(question) if not value.isdigit() or int(value) < 18: question = "invalid response" value = None else: break return {"age": int(value)} builder = StateGraph(State) builder.add_node("node", ask_age) builder.add_edge(START, "node") graph = builder.compile(checkpointer=checkpointer) thread1 = {"configurable": {"thread_id": "1"}} assert [e for e in graph.stream({"other": ""}, thread1)] == [ { "__interrupt__": ( Interrupt( value="How old are you?", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event for event in graph.stream( Command(resume="13"), thread1, ) ] == [ { "__interrupt__": ( Interrupt( value="invalid response", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [ event for event in graph.stream( Command(resume="15"), thread1, ) ] == [ { "__interrupt__": ( Interrupt( value="invalid response", resumable=True, ns=[AnyStr("node:")], when="during", ), ) } ] assert [event for event in graph.stream(Command(resume="19"), thread1)] == [ {"node": {"age": 19}}, ] ``` 
# `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/any_str.py`:
"""Loose-equality matchers for the expected side of test assertions.

Each class overrides ``__eq__`` so that one instance compares equal to a whole
family of values (a numeric range, a string prefix, any version, ...) instead
of a single exact value.
"""

import re
from typing import Any, Sequence, Union

try:
    from typing import Self
except ImportError:  # Python < 3.11
    from typing_extensions import Self

# Sentinel for "no matching key"; unlike ``None`` it cannot collide with a real key.
_MISSING = object()


class FloatBetween(float):
    """Compares equal to any ``float`` in the closed range ``[min_value, max_value]``.

    The instance's own float value is ``min_value``.
    """

    def __new__(cls, min_value: float, max_value: float) -> Self:
        return super().__new__(cls, min_value)

    def __init__(self, min_value: float, max_value: float) -> None:
        super().__init__()
        self.min_value = min_value
        self.max_value = max_value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, float) and self.min_value <= other <= self.max_value

    def __ne__(self, other: object) -> bool:
        # Without this, ``!=`` resolves to float.__ne__ and compares the raw
        # float value, so it would not be the negation of ``==``.
        return not self.__eq__(other)

    def __hash__(self) -> int:
        return hash((float(self), self.min_value, self.max_value))


class AnyStr(str):
    """Compares equal to any ``str`` matching ``prefix``.

    Args:
        prefix: either a plain string, checked with ``str.startswith``, or a
            compiled regex, checked with ``Pattern.match`` (anchored at the
            start of the string). Defaults to ``""``, which matches any string.
    """

    def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
        super().__init__()
        self.prefix = prefix

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, str):
            return False
        if isinstance(self.prefix, str):
            return other.startswith(self.prefix)
        # Pattern.match returns a Match object or None; normalise to a bool.
        return self.prefix.match(other) is not None

    def __ne__(self, other: object) -> bool:
        # Without this, ``!=`` resolves to str.__ne__ and compares string content.
        return not self.__eq__(other)

    def __hash__(self) -> int:
        return hash((str(self), self.prefix))


class AnyDict(dict):
    """A dict compared entry by entry with ``==``.

    Keys and values may be matcher objects (e.g. ``AnyStr``). Equality requires
    ``other`` to be a dict of the same length in which every key of ``self``
    has an equal key whose value is equal to the expected value.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, dict) or len(self) != len(other):
            return False
        for key, expected in self.items():
            # Keys may be matchers, so find the key by equality, not by hash lookup.
            found = next(
                (candidate for candidate in other if candidate == key), _MISSING
            )
            if found is _MISSING:
                return False
            if not (expected == other[found]):
                return False
        return True

    def __ne__(self, other: object) -> bool:
        # Without this, ``!=`` resolves to dict.__ne__ (hash-based key lookup).
        return not self.__eq__(other)


class AnyVersion:
    """Compares equal to any ``str``, ``int`` or ``float`` value."""

    def __init__(self) -> None:
        super().__init__()

    def __eq__(self, other: object) -> bool:
        return isinstance(other, (str, int, float))

    def __hash__(self) -> int:
        return hash(str(self))


def _unordered_match(expected: Sequence[Any], actual: Sequence[Any]) -> bool:
    """Return True if each item of ``expected`` pairs with a distinct equal item of ``actual``.

    Uses augmenting-path bipartite matching. Duplicates are therefore counted,
    and a permissive matcher (e.g. ``AnyStr("")``) cannot claim the only item
    that a stricter expected value could match.
    """
    # candidates[i]: indices of items in `actual` equal to expected[i]
    # (identity is checked first, as the ``in`` operator does).
    candidates = [
        [j for j, item in enumerate(actual) if item is want or want == item]
        for want in expected
    ]
    owner: dict[int, int] = {}  # index into `actual` -> index into `expected`

    def assign(i: int, visited: set[int]) -> bool:
        for j in candidates[i]:
            if j in visited:
                continue
            visited.add(j)
            if j not in owner or assign(owner[j], visited):
                owner[j] = i
                return True
        return False

    return all(assign(i, set()) for i in range(len(expected)))


class UnsortedSequence:
    """Compares equal to any sequence holding the same items in any order.

    Items are compared with ``==``, so they may themselves be matchers.
    Multiplicity matters: ``UnsortedSequence(1, 1)`` does not equal ``[1, 2]``.
    """

    def __init__(self, *values: Any) -> None:
        self.seq = values

    def __eq__(self, value: object) -> bool:
        return (
            isinstance(value, Sequence)
            and len(self.seq) == len(value)
            and _unordered_match(self.seq, value)
        )

    def __hash__(self) -> int:
        return hash(frozenset(self.seq))

    def __repr__(self) -> str:
        return repr(self.seq)


# `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_algo.py`:
```py from langgraph.checkpoint.base import empty_checkpoint from langgraph.pregel.algo import prepare_next_tasks from langgraph.pregel.manager import ChannelsManager def test_prepare_next_tasks() -> None: config = {} processes = {} checkpoint = empty_checkpoint() with ChannelsManager({}, checkpoint, config) as (channels, managed): assert ( prepare_next_tasks( checkpoint, {}, processes, channels, managed, config, 0, for_execution=False, ) == {} ) assert ( prepare_next_tasks( checkpoint, {}, processes, channels, managed, config, 0, for_execution=True, checkpointer=None, store=None, manager=None, ) == {} ) # TODO: add more tests ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/fake_tracer.py`: ```py from typing import Any, Optional from uuid import UUID from langchain_core.messages.base import BaseMessage from langchain_core.outputs.chat_generation import ChatGeneration from langchain_core.outputs.llm_result import LLMResult from langchain_core.tracers import BaseTracer, Run class FakeTracer(BaseTracer): """Fake tracer that records LangChain execution. 
It replaces run ids with deterministic UUIDs for snapshotting.""" def __init__(self) -> None: """Initialize the tracer.""" super().__init__() self.runs: list[Run] = [] self.uuids_map: dict[UUID, UUID] = {} self.uuids_generator = ( UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000) ) def _replace_uuid(self, uuid: UUID) -> UUID: if uuid not in self.uuids_map: self.uuids_map[uuid] = next(self.uuids_generator) return self.uuids_map[uuid] def _replace_message_id(self, maybe_message: Any) -> Any: if isinstance(maybe_message, BaseMessage): maybe_message.id = str(next(self.uuids_generator)) if isinstance(maybe_message, ChatGeneration): maybe_message.message.id = str(next(self.uuids_generator)) if isinstance(maybe_message, LLMResult): for i, gen_list in enumerate(maybe_message.generations): for j, gen in enumerate(gen_list): maybe_message.generations[i][j] = self._replace_message_id(gen) if isinstance(maybe_message, dict): for k, v in maybe_message.items(): maybe_message[k] = self._replace_message_id(v) if isinstance(maybe_message, list): for i, v in enumerate(maybe_message): maybe_message[i] = self._replace_message_id(v) return maybe_message def _copy_run(self, run: Run) -> Run: if run.dotted_order: levels = run.dotted_order.split(".") processed_levels = [] for level in levels: timestamp, run_id = level.split("Z") new_run_id = self._replace_uuid(UUID(run_id)) processed_level = f"{timestamp}Z{new_run_id}" processed_levels.append(processed_level) new_dotted_order = ".".join(processed_levels) else: new_dotted_order = None return run.copy( update={ "id": self._replace_uuid(run.id), "parent_run_id": ( self.uuids_map[run.parent_run_id] if run.parent_run_id else None ), "child_runs": [self._copy_run(child) for child in run.child_runs], "trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None, "dotted_order": new_dotted_order, "inputs": self._replace_message_id(run.inputs), "outputs": self._replace_message_id(run.outputs), } ) def 
_persist_run(self, run: Run) -> None: """Persist a run.""" self.runs.append(self._copy_run(run)) def flattened_runs(self) -> list[Run]: q = [] + self.runs result = [] while q: parent = q.pop() result.append(parent) if parent.child_runs: q.extend(parent.child_runs) return result @property def run_ids(self) -> list[Optional[UUID]]: runs = self.flattened_runs() uuids_map = {v: k for k, v in self.uuids_map.items()} return [uuids_map.get(r.id) for r in runs] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/poetry.toml`: ```toml [virtualenvs] in-project = true [installer] modern-installation = false ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/memory/__init__.py`: ```py import asyncio import logging import os import pickle import random import shutil from collections import defaultdict from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack from functools import partial from types import TracebackType from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, SerializerProtocol, get_checkpoint_id, ) from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol logger = logging.getLogger(__name__) class MemorySaver( BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager ): """An in-memory checkpoint saver. This checkpoint saver stores checkpoints in memory using a defaultdict. Note: Only use `MemorySaver` for debugging or testing purposes. For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`. Args: serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. 
Defaults to None. Examples: import asyncio from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import StateGraph builder = StateGraph(int) builder.add_node("add_one", lambda x: x + 1) builder.set_entry_point("add_one") builder.set_finish_point("add_one") memory = MemorySaver() graph = builder.compile(checkpointer=memory) coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}}) asyncio.run(coro) # Output: 2 """ # thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping storage: defaultdict[ str, dict[ str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], Optional[str]]] ], ] writes: defaultdict[ tuple[str, str, str], dict[tuple[str, int], tuple[str, str, tuple[str, bytes]]] ] def __init__( self, *, serde: Optional[SerializerProtocol] = None, factory: Type[defaultdict] = defaultdict, ) -> None: super().__init__(serde=serde) self.storage = factory(lambda: defaultdict(dict)) self.writes = factory(dict) self.stack = ExitStack() if factory is not defaultdict: self.stack.enter_context(self.storage) # type: ignore[arg-type] self.stack.enter_context(self.writes) # type: ignore[arg-type] def __enter__(self) -> "MemorySaver": return self.stack.__enter__() def __exit__( self, exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: return self.stack.__exit__(exc_type, exc_value, traceback) async def __aenter__(self) -> "MemorySaver": return self.stack.__enter__() async def __aexit__( self, __exc_type: Optional[type[BaseException]], __exc_value: Optional[BaseException], __traceback: Optional[TracebackType], ) -> Optional[bool]: return self.stack.__exit__(__exc_type, __exc_value, __traceback) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the in-memory storage. This method retrieves a checkpoint tuple from the in-memory storage based on the provided config. 
If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") if checkpoint_id := get_checkpoint_id(config): if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id): checkpoint, metadata, parent_checkpoint_id = saved writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values() if parent_checkpoint_id: sends = [ w[2] for w in self.writes[ (thread_id, checkpoint_ns, parent_checkpoint_id) ].values() if w[1] == TASKS ] else: sends = [] return CheckpointTuple( config=config, checkpoint={ **self.serde.loads_typed(checkpoint), "pending_sends": [self.serde.loads_typed(s) for s in sends], }, metadata=self.serde.loads_typed(metadata), pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], parent_config={ "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None, ) else: if checkpoints := self.storage[thread_id][checkpoint_ns]: checkpoint_id = max(checkpoints.keys()) checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id] writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values() if parent_checkpoint_id: sends = [ w[2] for w in self.writes[ (thread_id, checkpoint_ns, parent_checkpoint_id) ].values() if w[1] == TASKS ] else: sends = [] return CheckpointTuple( config={ "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, checkpoint={ **self.serde.loads_typed(checkpoint), "pending_sends": [self.serde.loads_typed(s) for s in 
sends], }, metadata=self.serde.loads_typed(metadata), pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], parent_config={ "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None, ) def list( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the in-memory storage. This method retrieves a list of checkpoint tuples from the in-memory storage based on the provided criteria. Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): List checkpoints created before this configuration. limit (Optional[int]): Maximum number of checkpoints to return. Yields: Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. 
""" thread_ids = (config["configurable"]["thread_id"],) if config else self.storage config_checkpoint_ns = ( config["configurable"].get("checkpoint_ns") if config else None ) config_checkpoint_id = get_checkpoint_id(config) if config else None for thread_id in thread_ids: for checkpoint_ns in self.storage[thread_id].keys(): if ( config_checkpoint_ns is not None and checkpoint_ns != config_checkpoint_ns ): continue for checkpoint_id, ( checkpoint, metadata_b, parent_checkpoint_id, ) in sorted( self.storage[thread_id][checkpoint_ns].items(), key=lambda x: x[0], reverse=True, ): # filter by checkpoint ID from config if config_checkpoint_id and checkpoint_id != config_checkpoint_id: continue # filter by checkpoint ID from `before` config if ( before and (before_checkpoint_id := get_checkpoint_id(before)) and checkpoint_id >= before_checkpoint_id ): continue # filter by metadata metadata = self.serde.loads_typed(metadata_b) if filter and not all( query_value == metadata.get(query_key) for query_key, query_value in filter.items() ): continue # limit search results if limit is not None and limit <= 0: break elif limit is not None: limit -= 1 writes = self.writes[ (thread_id, checkpoint_ns, checkpoint_id) ].values() if parent_checkpoint_id: sends = [ w[2] for w in self.writes[ (thread_id, checkpoint_ns, parent_checkpoint_id) ].values() if w[1] == TASKS ] else: sends = [] yield CheckpointTuple( config={ "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, checkpoint={ **self.serde.loads_typed(checkpoint), "pending_sends": [self.serde.loads_typed(s) for s in sends], }, metadata=metadata, parent_config={ "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None, pending_writes=[ (id, c, self.serde.loads_typed(v)) for id, c, v in writes ], ) def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: 
CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the in-memory storage. This method saves a checkpoint to the in-memory storage. The checkpoint is associated with the provided config. Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (dict): New versions as of this write Returns: RunnableConfig: The updated config containing the saved checkpoint's timestamp. """ c = checkpoint.copy() c.pop("pending_sends") # type: ignore[misc] thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"]["checkpoint_ns"] self.storage[thread_id][checkpoint_ns].update( { checkpoint["id"]: ( self.serde.dumps_typed(c), self.serde.dumps_typed(metadata), config["configurable"].get("checkpoint_id"), # parent ) } ) return { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } def put_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Save a list of writes to the in-memory storage. This method saves a list of writes to the in-memory storage. The writes are associated with the provided config. Args: config (RunnableConfig): The config to associate with the writes. writes (list[tuple[str, Any]]): The writes to save. task_id (str): Identifier for the task creating the writes. Returns: RunnableConfig: The updated config containing the saved writes' timestamp. 
""" thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") checkpoint_id = config["configurable"]["checkpoint_id"] outer_key = (thread_id, checkpoint_ns, checkpoint_id) outer_writes_ = self.writes.get(outer_key) for idx, (c, v) in enumerate(writes): inner_key = (task_id, WRITES_IDX_MAP.get(c, idx)) if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_: continue self.writes[outer_key][inner_key] = (task_id, c, self.serde.dumps_typed(v)) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Asynchronous version of get_tuple. This method is an asynchronous wrapper around get_tuple that runs the synchronous method in a separate thread using asyncio. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ return await asyncio.get_running_loop().run_in_executor( None, self.get_tuple, config ) async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """Asynchronous version of list. This method is an asynchronous wrapper around list that runs the synchronous method in a separate thread using asyncio. Args: config (RunnableConfig): The config to use for listing the checkpoints. Yields: AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples. 
""" loop = asyncio.get_running_loop() iter = await loop.run_in_executor( None, partial( self.list, before=before, limit=limit, filter=filter, ), config, ) while True: # handling StopIteration exception inside coroutine won't work # as expected, so using next() with default value to break the loop if item := await loop.run_in_executor(None, next, iter, None): yield item else: break async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Asynchronous version of put. Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (dict): New versions as of this write Returns: RunnableConfig: The updated config containing the saved checkpoint's timestamp. """ return await asyncio.get_running_loop().run_in_executor( None, self.put, config, checkpoint, metadata, new_versions ) async def aput_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Asynchronous version of put_writes. This method is an asynchronous wrapper around put_writes that runs the synchronous method in a separate thread using asyncio. Args: config (RunnableConfig): The config to associate with the writes. writes (List[Tuple[str, Any]]): The writes to save, each as a (channel, value) pair. task_id (str): Identifier for the task creating the writes. 
""" return await asyncio.get_running_loop().run_in_executor( None, self.put_writes, config, writes, task_id ) def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: current_v = 0 elif isinstance(current, int): current_v = current else: current_v = int(current.split(".")[0]) next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" class PersistentDict(defaultdict): """Persistent dictionary with an API compatible with shelve and anydbm. The dict is kept in memory, so the dictionary operations run as fast as a regular dictionary. Write to disk is delayed until close or sync (similar to gdbm's fast mode). Input file format is automatically discovered. Output file format is selectable between pickle, json, and csv. All three serialization formats are backed by fast C implementations. Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/ """ def __init__(self, *args: Any, filename: str, **kwds: Any) -> None: self.flag = "c" # r=readonly, c=create, or n=new self.mode = None # None or an octal triple like 0644 self.format = "pickle" # 'csv', 'json', or 'pickle' self.filename = filename super().__init__(*args, **kwds) def sync(self) -> None: "Write dict to disk" if self.flag == "r": return tempname = self.filename + ".tmp" fileobj = open(tempname, "wb" if self.format == "pickle" else "w") try: self.dump(fileobj) except Exception: os.remove(tempname) raise finally: fileobj.close() shutil.move(tempname, self.filename) # atomic commit if self.mode is not None: os.chmod(self.filename, self.mode) def close(self) -> None: self.sync() self.clear() def __enter__(self) -> "PersistentDict": return self def __exit__(self, *exc_info: Any) -> None: self.close() def dump(self, fileobj: Any) -> None: if self.format == "pickle": pickle.dump(dict(self), fileobj, 2) else: raise NotImplementedError("Unknown format: " + repr(self.format)) def load(self) -> None: # 
try formats from most restrictive to least restrictive if self.flag == "n": return with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj: for loader in (pickle.load,): fileobj.seek(0) try: return self.update(loader(fileobj)) except EOFError: return except Exception: logging.error(f"Failed to load file: {fileobj.name}") raise raise ValueError("File not in a supported f ormat") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py`: ```py import dataclasses import decimal import importlib import json import pathlib import re from collections import deque from datetime import date, datetime, time, timedelta, timezone from enum import Enum from inspect import isclass from ipaddress import ( IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network, ) from typing import Any, Callable, Optional, Sequence, Union, cast from uuid import UUID import msgpack # type: ignore[import-untyped] from langchain_core.load.load import Reviver from langchain_core.load.serializable import Serializable from zoneinfo import ZoneInfo from langgraph.checkpoint.serde.base import SerializerProtocol from langgraph.checkpoint.serde.types import SendProtocol from langgraph.store.base import Item LC_REVIVER = Reviver() class JsonPlusSerializer(SerializerProtocol): def _encode_constructor_args( self, constructor: Union[Callable, type[Any]], *, method: Union[None, str, Sequence[Union[None, str]]] = None, args: Optional[Sequence[Any]] = None, kwargs: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: out = { "lc": 2, "type": "constructor", "id": (*constructor.__module__.split("."), constructor.__name__), } if method is not None: out["method"] = method if args is not None: out["args"] = args if kwargs is not None: out["kwargs"] = kwargs return out def _default(self, obj: Any) -> Union[str, dict[str, Any]]: if isinstance(obj, Serializable): return cast(dict[str, Any], obj.to_json()) elif hasattr(obj, 
"model_dump") and callable(obj.model_dump): return self._encode_constructor_args( obj.__class__, method=(None, "model_construct"), kwargs=obj.model_dump() ) elif hasattr(obj, "dict") and callable(obj.dict): return self._encode_constructor_args( obj.__class__, method=(None, "construct"), kwargs=obj.dict() ) elif hasattr(obj, "_asdict") and callable(obj._asdict): return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict()) elif isinstance(obj, pathlib.Path): return self._encode_constructor_args(pathlib.Path, args=obj.parts) elif isinstance(obj, re.Pattern): return self._encode_constructor_args( re.compile, args=(obj.pattern, obj.flags) ) elif isinstance(obj, UUID): return self._encode_constructor_args(UUID, args=(obj.hex,)) elif isinstance(obj, decimal.Decimal): return self._encode_constructor_args(decimal.Decimal, args=(str(obj),)) elif isinstance(obj, (set, frozenset, deque)): return self._encode_constructor_args(type(obj), args=(tuple(obj),)) elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)): return self._encode_constructor_args(obj.__class__, args=(str(obj),)) elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)): return self._encode_constructor_args(obj.__class__, args=(str(obj),)) elif isinstance(obj, datetime): return self._encode_constructor_args( datetime, method="fromisoformat", args=(obj.isoformat(),) ) elif isinstance(obj, timezone): return self._encode_constructor_args( timezone, args=obj.__getinitargs__(), # type: ignore[attr-defined] ) elif isinstance(obj, ZoneInfo): return self._encode_constructor_args(ZoneInfo, args=(obj.key,)) elif isinstance(obj, timedelta): return self._encode_constructor_args( timedelta, args=(obj.days, obj.seconds, obj.microseconds) ) elif isinstance(obj, date): return self._encode_constructor_args( date, args=(obj.year, obj.month, obj.day) ) elif isinstance(obj, time): return self._encode_constructor_args( time, args=(obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo), 
kwargs={"fold": obj.fold}, ) elif dataclasses.is_dataclass(obj): return self._encode_constructor_args( obj.__class__, kwargs={ field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) }, ) elif isinstance(obj, Enum): return self._encode_constructor_args(obj.__class__, args=(obj.value,)) elif isinstance(obj, SendProtocol): return self._encode_constructor_args( obj.__class__, kwargs={"node": obj.node, "arg": obj.arg} ) elif isinstance(obj, (bytes, bytearray)): return self._encode_constructor_args( obj.__class__, method="fromhex", args=(obj.hex(),) ) elif isinstance(obj, BaseException): return repr(obj) else: raise TypeError( f"Object of type {obj.__class__.__name__} is not JSON serializable" ) def _reviver(self, value: dict[str, Any]) -> Any: if ( value.get("lc", None) == 2 and value.get("type", None) == "constructor" and value.get("id", None) is not None ): try: # Get module and class name [*module, name] = value["id"] # Import module mod = importlib.import_module(".".join(module)) # Import class cls = getattr(mod, name) # Instantiate class method = value.get("method") if isinstance(method, str): methods = [getattr(cls, method)] elif isinstance(method, list): methods = [ cls if method is None else getattr(cls, method) for method in method ] else: methods = [cls] args = value.get("args") kwargs = value.get("kwargs") for method in methods: try: if isclass(method) and issubclass(method, BaseException): return None if args and kwargs: return method(*args, **kwargs) elif args: return method(*args) elif kwargs: return method(**kwargs) else: return method() except Exception: continue except Exception: return None return LC_REVIVER(value) def dumps(self, obj: Any) -> bytes: return json.dumps(obj, default=self._default, ensure_ascii=False).encode( "utf-8", "ignore" ) def dumps_typed(self, obj: Any) -> tuple[str, bytes]: if isinstance(obj, bytes): return "bytes", obj elif isinstance(obj, bytearray): return "bytearray", obj else: try: return "msgpack", 
_msgpack_enc(obj) except UnicodeEncodeError: return "json", self.dumps(obj) def loads(self, data: bytes) -> Any: return json.loads(data, object_hook=self._reviver) def loads_typed(self, data: tuple[str, bytes]) -> Any: type_, data_ = data if type_ == "bytes": return data_ elif type_ == "bytearray": return bytearray(data_) elif type_ == "json": return self.loads(data_) elif type_ == "msgpack": return msgpack.unpackb( data_, ext_hook=_msgpack_ext_hook, strict_map_key=False ) else: raise NotImplementedError(f"Unknown serialization type: {type_}") # --- msgpack --- EXT_CONSTRUCTOR_SINGLE_ARG = 0 EXT_CONSTRUCTOR_POS_ARGS = 1 EXT_CONSTRUCTOR_KW_ARGS = 2 EXT_METHOD_SINGLE_ARG = 3 EXT_PYDANTIC_V1 = 4 EXT_PYDANTIC_V2 = 5 def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]: if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2 return msgpack.ExtType( EXT_PYDANTIC_V2, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj.model_dump(), "model_validate_json", ), ), ) elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj.get_secret_value(), ), ), ) elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1 return msgpack.ExtType( EXT_PYDANTIC_V1, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj.dict(), ), ), ) elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple return msgpack.ExtType( EXT_CONSTRUCTOR_KW_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj._asdict(), ), ), ) elif isinstance(obj, pathlib.Path): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, obj.parts), ), ) elif isinstance(obj, re.Pattern): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( ("re", "compile", (obj.pattern, obj.flags)), ), ) elif isinstance(obj, UUID): return 
msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, obj.hex), ), ) elif isinstance(obj, decimal.Decimal): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, str(obj)), ), ) elif isinstance(obj, (set, frozenset, deque)): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, tuple(obj)), ), ) elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, str(obj)), ), ) elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, str(obj)), ), ) elif isinstance(obj, datetime): return msgpack.ExtType( EXT_METHOD_SINGLE_ARG, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj.isoformat(), "fromisoformat", ), ), ) elif isinstance(obj, timedelta): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, (obj.days, obj.seconds, obj.microseconds), ), ), ) elif isinstance(obj, date): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, (obj.year, obj.month, obj.day), ), ), ) elif isinstance(obj, time): return msgpack.ExtType( EXT_CONSTRUCTOR_KW_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, { "hour": obj.hour, "minute": obj.minute, "second": obj.second, "microsecond": obj.microsecond, "tzinfo": obj.tzinfo, "fold": obj.fold, }, ), ), ) elif isinstance(obj, timezone): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, obj.__getinitargs__(), # type: ignore[attr-defined] ), ), ) elif isinstance(obj, ZoneInfo): return msgpack.ExtType( 
EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, obj.key), ), ) elif isinstance(obj, Enum): return msgpack.ExtType( EXT_CONSTRUCTOR_SINGLE_ARG, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, obj.value), ), ) elif isinstance(obj, SendProtocol): return msgpack.ExtType( EXT_CONSTRUCTOR_POS_ARGS, _msgpack_enc( (obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)), ), ) elif dataclasses.is_dataclass(obj): # doesn't use dataclasses.asdict to avoid deepcopy and recursion return msgpack.ExtType( EXT_CONSTRUCTOR_KW_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, { field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) }, ), ), ) elif isinstance(obj, Item): return msgpack.ExtType( EXT_CONSTRUCTOR_KW_ARGS, _msgpack_enc( ( obj.__class__.__module__, obj.__class__.__name__, {k: getattr(obj, k) for k in obj.__slots__}, ), ), ) elif isinstance(obj, BaseException): return repr(obj) else: raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable") def _msgpack_ext_hook(code: int, data: bytes) -> Any: if code == EXT_CONSTRUCTOR_SINGLE_ARG: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, arg return getattr(importlib.import_module(tup[0]), tup[1])(tup[2]) except Exception: return elif code == EXT_CONSTRUCTOR_POS_ARGS: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, args return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2]) except Exception: return elif code == EXT_CONSTRUCTOR_KW_ARGS: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, args return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2]) except Exception: return elif code == EXT_METHOD_SINGLE_ARG: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, arg, method return getattr(getattr(importlib.import_module(tup[0]), tup[1]), tup[3])( tup[2] ) except Exception: 
return elif code == EXT_PYDANTIC_V1: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, kwargs cls = getattr(importlib.import_module(tup[0]), tup[1]) try: return cls(**tup[2]) except Exception: return cls.construct(**tup[2]) except Exception: return elif code == EXT_PYDANTIC_V2: try: tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook) # module, name, kwargs, method cls = getattr(importlib.import_module(tup[0]), tup[1]) try: return cls(**tup[2]) except Exception: return cls.model_construct(**tup[2]) except Exception: return ENC_POOL: deque[msgpack.Packer] = deque(maxlen=32) def _msgpack_enc(data: Any) -> bytes: try: enc = ENC_POOL.popleft() except IndexError: enc = msgpack.Packer(default=_msgpack_default) try: return enc.pack(data) finally: ENC_POOL.append(enc) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/types.py`: ```py from typing import ( Any, Optional, Protocol, Sequence, TypeVar, runtime_checkable, ) from typing_extensions import Self ERROR = "__error__" SCHEDULED = "__scheduled__" INTERRUPT = "__interrupt__" RESUME = "__resume__" TASKS = "__pregel_tasks" Value = TypeVar("Value", covariant=True) Update = TypeVar("Update", contravariant=True) C = TypeVar("C") class ChannelProtocol(Protocol[Value, Update, C]): # Mirrors langgraph.channels.base.BaseChannel @property def ValueType(self) -> Any: ... @property def UpdateType(self) -> Any: ... def checkpoint(self) -> Optional[C]: ... def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ... def update(self, values: Sequence[Update]) -> bool: ... def get(self) -> Value: ... def consume(self) -> bool: ... @runtime_checkable class SendProtocol(Protocol): # Mirrors langgraph.constants.Send node: str arg: Any def __hash__(self) -> int: ... def __repr__(self) -> str: ... def __eq__(self, value: object) -> bool: ... 
``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/base.py`: ```py from typing import Any, Protocol class SerializerProtocol(Protocol): """Protocol for serialization and deserialization of objects. - `dumps`: Serialize an object to bytes. - `dumps_typed`: Serialize an object to a tuple (type, bytes). - `loads`: Deserialize an object from bytes. - `loads_typed`: Deserialize an object from a tuple (type, bytes). Valid implementations include the `pickle`, `json` and `orjson` modules. """ def dumps(self, obj: Any) -> bytes: ... def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ... def loads(self, data: bytes) -> Any: ... def loads_typed(self, data: tuple[str, bytes]) -> Any: ... class SerializerCompat(SerializerProtocol): def __init__(self, serde: SerializerProtocol) -> None: self.serde = serde def dumps(self, obj: Any) -> bytes: return self.serde.dumps(obj) def loads(self, data: bytes) -> Any: return self.serde.loads(data) def dumps_typed(self, obj: Any) -> tuple[str, bytes]: return type(obj).__name__, self.serde.dumps(obj) def loads_typed(self, data: tuple[str, bytes]) -> Any: return self.serde.loads(data[1]) def maybe_add_typed_methods(serde: SerializerProtocol) -> SerializerProtocol: """Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility.""" if not hasattr(serde, "loads_typed") or not hasattr(serde, "dumps_typed"): return SerializerCompat(serde) return serde ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/base/id.py`: ```py """Adapted from https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95 Bundled in to avoid install issues with uuid6 package """ import random import time import uuid from typing import Optional, Tuple _last_v6_timestamp = None class UUID(uuid.UUID): r"""UUID draft version objects""" __slots__ = () def __init__( self, hex: Optional[str] = None, bytes: Optional[bytes] = None, bytes_le: 
Optional[bytes] = None, fields: Optional[Tuple[int, int, int, int, int, int]] = None, int: Optional[int] = None, version: Optional[int] = None, *, is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown, ) -> None: r"""Create a UUID.""" if int is None or [hex, bytes, bytes_le, fields].count(None) != 4: return super().__init__( hex=hex, bytes=bytes, bytes_le=bytes_le, fields=fields, int=int, version=version, is_safe=is_safe, ) if not 0 <= int < 1 << 128: raise ValueError("int is out of range (need a 128-bit value)") if version is not None: if not 6 <= version <= 8: raise ValueError("illegal version number") # Set the variant to RFC 4122. int &= ~(0xC000 << 48) int |= 0x8000 << 48 # Set the version number. int &= ~(0xF000 << 64) int |= version << 76 super().__init__(int=int, is_safe=is_safe) @property def subsec(self) -> int: return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF) @property def time(self) -> int: if self.version == 6: return ( (self.time_low << 28) | (self.time_mid << 12) | (self.time_hi_version & 0x0FFF) ) if self.version == 7: return self.int >> 80 if self.version == 8: return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec) return super().time def _subsec_decode(value: int) -> int: return -(-value * 10**6 // 2**20) def uuid6(node: Optional[int] = None, clock_seq: Optional[int] = None) -> UUID: r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for improved DB locality. It is expected that UUIDv6 will primarily be used in contexts where there are existing v1 UUIDs. Systems that do not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead. If 'node' is not given, a random 48-bit number is chosen. If 'clock_seq' is given, it is used as the sequence number; otherwise a random 14-bit sequence number is chosen.""" global _last_v6_timestamp nanoseconds = time.time_ns() # 0x01b21dd213814000 is the number of 100-ns intervals between the # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. 
timestamp = nanoseconds // 100 + 0x01B21DD213814000 if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp: timestamp = _last_v6_timestamp + 1 _last_v6_timestamp = timestamp if clock_seq is None: clock_seq = random.getrandbits(14) # instead of stable storage if node is None: node = random.getrandbits(48) time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF time_low_and_version = timestamp & 0x0FFF uuid_int = time_high_and_time_mid << 80 uuid_int |= time_low_and_version << 64 uuid_int |= (clock_seq & 0x3FFF) << 48 uuid_int |= node & 0xFFFFFFFFFFFF return UUID(int=uuid_int, version=6) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/base/__init__.py`: ```py from datetime import datetime, timezone from typing import ( Any, AsyncIterator, Dict, Generic, Iterator, List, Literal, Mapping, NamedTuple, Optional, Sequence, Tuple, TypedDict, TypeVar, Union, ) from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig from langgraph.checkpoint.base.id import uuid6 from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import ( ERROR, INTERRUPT, RESUME, SCHEDULED, ChannelProtocol, SendProtocol, ) V = TypeVar("V", int, float, str) PendingWrite = Tuple[str, str, Any] # Marked as total=False to allow for future expansion. class CheckpointMetadata(TypedDict, total=False): """Metadata associated with a checkpoint.""" source: Literal["input", "loop", "update", "fork"] """The source of the checkpoint. - "input": The checkpoint was created from an input to invoke/stream/batch. - "loop": The checkpoint was created from inside the pregel loop. - "update": The checkpoint was created from a manual state update. - "fork": The checkpoint was created as a copy of another checkpoint. """ step: int """The step number of the checkpoint. -1 for the first "input" checkpoint. 
0 for the first "loop" checkpoint. ... for the nth checkpoint afterwards. """ writes: dict[str, Any] """The writes that were made between the previous checkpoint and this one. Mapping from node name to writes emitted by that node. """ parents: dict[str, str] """The IDs of the parent checkpoints. Mapping from checkpoint namespace to checkpoint ID. """ class TaskInfo(TypedDict): status: Literal["scheduled", "success", "error"] ChannelVersions = dict[str, Union[str, int, float]] class Checkpoint(TypedDict): """State snapshot at a given point in time.""" v: int """The version of the checkpoint format. Currently 1.""" id: str """The ID of the checkpoint. This is both unique and monotonically increasing, so can be used for sorting checkpoints from first to last.""" ts: str """The timestamp of the checkpoint in ISO 8601 format.""" channel_values: dict[str, Any] """The values of the channels at the time of the checkpoint. Mapping from channel name to deserialized channel snapshot value. """ channel_versions: ChannelVersions """The versions of the channels at the time of the checkpoint. The keys are channel names and the values are monotonically increasing version strings for each channel. """ versions_seen: dict[str, ChannelVersions] """Map from node ID to map from channel name to version seen. This keeps track of the versions of the channels that each node has seen. Used to determine which nodes to execute next. """ pending_sends: List[SendProtocol] """List of inputs pushed to nodes but not yet processed. 
Cleared by the next checkpoint.""" def empty_checkpoint() -> Checkpoint: return Checkpoint( v=1, id=str(uuid6(clock_seq=-2)), ts=datetime.now(timezone.utc).isoformat(), channel_values={}, channel_versions={}, versions_seen={}, pending_sends=[], ) def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint: return Checkpoint( v=checkpoint["v"], ts=checkpoint["ts"], id=checkpoint["id"], channel_values=checkpoint["channel_values"].copy(), channel_versions=checkpoint["channel_versions"].copy(), versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()}, pending_sends=checkpoint.get("pending_sends", []).copy(), ) def create_checkpoint( checkpoint: Checkpoint, channels: Optional[Mapping[str, ChannelProtocol]], step: int, *, id: Optional[str] = None, ) -> Checkpoint: """Create a checkpoint for the given channels.""" ts = datetime.now(timezone.utc).isoformat() if channels is None: values = checkpoint["channel_values"] else: values = {} for k, v in channels.items(): if k not in checkpoint["channel_versions"]: continue try: values[k] = v.checkpoint() except EmptyChannelError: pass return Checkpoint( v=1, ts=ts, id=id or str(uuid6(clock_seq=step)), channel_values=values, channel_versions=checkpoint["channel_versions"], versions_seen=checkpoint["versions_seen"], pending_sends=checkpoint.get("pending_sends", []), ) class CheckpointTuple(NamedTuple): """A tuple containing a checkpoint and its associated data.""" config: RunnableConfig checkpoint: Checkpoint metadata: CheckpointMetadata parent_config: Optional[RunnableConfig] = None pending_writes: Optional[List[PendingWrite]] = None CheckpointThreadId = ConfigurableFieldSpec( id="thread_id", annotation=str, name="Thread ID", description=None, default="", is_shared=True, ) CheckpointNS = ConfigurableFieldSpec( id="checkpoint_ns", annotation=str, name="Checkpoint NS", description='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. 
`"child|grandchild"`. Defaults to "" (root graph).', default="", is_shared=True, ) CheckpointId = ConfigurableFieldSpec( id="checkpoint_id", annotation=Optional[str], name="Checkpoint ID", description="Pass to fetch a past checkpoint. If None, fetches the latest checkpoint.", default=None, is_shared=True, ) class BaseCheckpointSaver(Generic[V]): """Base class for creating a graph checkpointer. Checkpointers allow LangGraph agents to persist their state within and across multiple interactions. Attributes: serde (SerializerProtocol): Serializer for encoding/decoding checkpoints. Note: When creating a custom checkpoint saver, consider implementing async versions to avoid blocking the main thread. """ serde: SerializerProtocol = JsonPlusSerializer() def __init__( self, *, serde: Optional[SerializerProtocol] = None, ) -> None: self.serde = maybe_add_typed_methods(serde or self.serde) @property def config_specs(self) -> list[ConfigurableFieldSpec]: """Define the configuration options for the checkpoint saver. Returns: list[ConfigurableFieldSpec]: List of configuration field specs. """ return [CheckpointThreadId, CheckpointNS, CheckpointId] def get(self, config: RunnableConfig) -> Optional[Checkpoint]: """Fetch a checkpoint using the given configuration. Args: config (RunnableConfig): Configuration specifying which checkpoint to retrieve. Returns: Optional[Checkpoint]: The requested checkpoint, or None if not found. """ if value := self.get_tuple(config): return value.checkpoint def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Fetch a checkpoint tuple using the given configuration. Args: config (RunnableConfig): Configuration specifying which checkpoint to retrieve. Returns: Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. 
""" raise NotImplementedError def list( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints that match the given criteria. Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria. before (Optional[RunnableConfig]): List checkpoints created before this configuration. limit (Optional[int]): Maximum number of checkpoints to return. Returns: Iterator[CheckpointTuple]: Iterator of matching checkpoint tuples. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Store a checkpoint with its configuration and metadata. Args: config (RunnableConfig): Configuration for the checkpoint. checkpoint (Checkpoint): The checkpoint to store. metadata (CheckpointMetadata): Additional metadata for the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError def put_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (List[Tuple[str, Any]]): List of writes to store. task_id (str): Identifier for the task creating the writes. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. 
""" raise NotImplementedError async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]: """Asynchronously fetch a checkpoint using the given configuration. Args: config (RunnableConfig): Configuration specifying which checkpoint to retrieve. Returns: Optional[Checkpoint]: The requested checkpoint, or None if not found. """ if value := await self.aget_tuple(config): return value.checkpoint async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Asynchronously fetch a checkpoint tuple using the given configuration. Args: config (RunnableConfig): Configuration specifying which checkpoint to retrieve. Returns: Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """Asynchronously list checkpoints that match the given criteria. Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): List checkpoints created before this configuration. limit (Optional[int]): Maximum number of checkpoints to return. Returns: AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError yield async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Asynchronously store a checkpoint with its configuration and metadata. Args: config (RunnableConfig): Configuration for the checkpoint. checkpoint (Checkpoint): The checkpoint to store. 
metadata (CheckpointMetadata): Additional metadata for the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError async def aput_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Asynchronously store intermediate writes linked to a checkpoint. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (List[Tuple[str, Any]]): List of writes to store. task_id (str): Identifier for the task creating the writes. Raises: NotImplementedError: Implement this method in your custom checkpoint saver. """ raise NotImplementedError def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V: """Generate the next version ID for a channel. Default is to use integer versions, incrementing by 1. If you override, you can use str/int/float versions, as long as they are monotonically increasing. Args: current (Optional[V]): The current version identifier (int, float, or str). channel (BaseChannel): The channel being versioned. Returns: V: The next version identifier, which must be increasing. """ if isinstance(current, str): raise NotImplementedError elif current is None: return 1 else: return current + 1 class EmptyChannelError(Exception): """Raised when attempting to get the value of a channel that hasn't been updated for the first time yet.""" pass def get_checkpoint_id(config: RunnableConfig) -> Optional[str]: """Get checkpoint ID in a backwards-compatible manner (fallback on thread_ts).""" return config["configurable"].get( "checkpoint_id", config["configurable"].get("thread_ts") ) """ Mapping from error type to error index. Regular writes just map to their index in the list of writes being saved. Special writes (e.g. 
errors) map to negative indices, to avoid those writes from conflicting with regular writes. Each Checkpointer implementation should use this mapping in put_writes. """ WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4} ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/memory/__init__.py`: ```py """In-memory dictionary-backed store with optional vector search. !!! example "Examples" Basic key-value storage: ```python from langgraph.store.memory import InMemoryStore store = InMemoryStore() store.put(("users", "123"), "prefs", {"theme": "dark"}) item = store.get(("users", "123"), "prefs") ``` Vector search using LangChain embeddings: ```python from langchain.embeddings import init_embeddings from langgraph.store.memory import InMemoryStore store = InMemoryStore( index={ "dims": 1536, "embed": init_embeddings("openai:text-embedding-3-small") } ) # Store documents store.put(("docs",), "doc1", {"text": "Python tutorial"}) store.put(("docs",), "doc2", {"text": "TypeScript guide"}) # Search by similarity results = store.search(("docs",), query="python programming") ``` Vector search using OpenAI SDK directly: ```python from openai import OpenAI from langgraph.store.memory import InMemoryStore client = OpenAI() def embed_texts(texts: list[str]) -> list[list[float]]: response = client.embeddings.create( model="text-embedding-3-small", input=texts ) return [e.embedding for e in response.data] store = InMemoryStore( index={ "dims": 1536, "embed": embed_texts } ) # Store documents store.put(("docs",), "doc1", {"text": "Python tutorial"}) store.put(("docs",), "doc2", {"text": "TypeScript guide"}) # Search by similarity results = store.search(("docs",), query="python programming") ``` Async vector search using OpenAI SDK: ```python from openai import AsyncOpenAI from langgraph.store.memory import InMemoryStore client = AsyncOpenAI() async def aembed_texts(texts: list[str]) -> list[list[float]]: response = await 
client.embeddings.create( model="text-embedding-3-small", input=texts ) return [e.embedding for e in response.data] store = InMemoryStore( index={ "dims": 1536, "embed": aembed_texts } ) # Store documents await store.aput(("docs",), "doc1", {"text": "Python tutorial"}) await store.aput(("docs",), "doc2", {"text": "TypeScript guide"}) # Search by similarity results = await store.asearch(("docs",), query="python programming") ``` Warning: This store keeps all data in memory. Data is lost when the process exits. For persistence, use a database-backed store like PostgresStore. Tip: For vector search, install numpy for better performance: ```bash pip install numpy ``` """ import asyncio import concurrent.futures as cf import functools import logging from collections import defaultdict from datetime import datetime, timezone from importlib import util from typing import Any, Iterable, Optional from langchain_core.embeddings import Embeddings from langgraph.store.base import ( BaseStore, GetOp, IndexConfig, Item, ListNamespacesOp, MatchCondition, Op, PutOp, Result, SearchItem, SearchOp, ensure_embeddings, get_text_at_path, tokenize_path, ) logger = logging.getLogger(__name__) class InMemoryStore(BaseStore): """In-memory dictionary-backed store with optional vector search. !!! example "Examples" Basic key-value storage: store = InMemoryStore() store.put(("users", "123"), "prefs", {"theme": "dark"}) item = store.get(("users", "123"), "prefs") Vector search with embeddings: from langchain.embeddings import init_embeddings store = InMemoryStore(index={ "dims": 1536, "embed": init_embeddings("openai:text-embedding-3-small"), "fields": ["text"], }) # Store documents store.put(("docs",), "doc1", {"text": "Python tutorial"}) store.put(("docs",), "doc2", {"text": "TypeScript guide"}) # Search by similarity results = store.search(("docs",), query="python programming") Note: Semantic search is disabled by default. 
You can enable it by providing an `index` configuration when creating the store. Without this configuration, all `index` arguments passed to `put` or `aput`will have no effect. Warning: This store keeps all data in memory. Data is lost when the process exits. For persistence, use a database-backed store like PostgresStore. Tip: For vector search, install numpy for better performance: ```bash pip install numpy ``` """ __slots__ = ( "_data", "_vectors", "index_config", "embeddings", ) def __init__(self, *, index: Optional[IndexConfig] = None) -> None: # Both _data and _vectors are wrapped in the In-memory API # Do not change their names self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict) # [ns][key][path] self._vectors: dict[tuple[str, ...], dict[str, dict[str, list[float]]]] = ( defaultdict(lambda: defaultdict(dict)) ) self.index_config = index if self.index_config: self.index_config = self.index_config.copy() self.embeddings: Optional[Embeddings] = ensure_embeddings( self.index_config.get("embed"), ) self.index_config["__tokenized_fields"] = [ (p, tokenize_path(p)) if p != "$" else (p, p) for p in (self.index_config.get("fields") or ["$"]) ] else: self.index_config = None self.embeddings = None def batch(self, ops: Iterable[Op]) -> list[Result]: # The batch/abatch methods are treated as internal. # Users should access via put/search/get/list_namespaces/etc. results, put_ops, search_ops = self._prepare_ops(ops) if search_ops: queryinmem_store = self._embed_search_queries(search_ops) self._batch_search(search_ops, queryinmem_store, results) to_embed = self._extract_texts(put_ops) if to_embed and self.index_config and self.embeddings: embeddings = self.embeddings.embed_documents(list(to_embed)) self._insertinmem_store(to_embed, embeddings) self._apply_put_ops(put_ops) return results async def abatch(self, ops: Iterable[Op]) -> list[Result]: # The batch/abatch methods are treated as internal. 
# Users should access via put/search/get/list_namespaces/etc. results, put_ops, search_ops = self._prepare_ops(ops) if search_ops: queryinmem_store = await self._aembed_search_queries(search_ops) self._batch_search(search_ops, queryinmem_store, results) to_embed = self._extract_texts(put_ops) if to_embed and self.index_config and self.embeddings: embeddings = await self.embeddings.aembed_documents(list(to_embed)) self._insertinmem_store(to_embed, embeddings) self._apply_put_ops(put_ops) return results # Helpers def _filter_items(self, op: SearchOp) -> list[tuple[Item, list[list[float]]]]: """Filter items by namespace and filter function, return items with their embeddings.""" namespace_prefix = op.namespace_prefix def filter_func(item: Item) -> bool: if not op.filter: return True return all( _compare_values(item.value.get(key), filter_value) for key, filter_value in op.filter.items() ) filtered = [] for namespace in self._data: if not ( namespace[: len(namespace_prefix)] == namespace_prefix if len(namespace) >= len(namespace_prefix) else False ): continue for key, item in self._data[namespace].items(): if filter_func(item): if op.query and (embeddings := self._vectors[namespace].get(key)): filtered.append((item, list(embeddings.values()))) else: filtered.append((item, [])) return filtered def _embed_search_queries( self, search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], ) -> dict[str, list[float]]: queryinmem_store = {} if self.index_config and self.embeddings and search_ops: queries = {op.query for (op, _) in search_ops.values() if op.query} if queries: with cf.ThreadPoolExecutor() as executor: futures = { q: executor.submit(self.embeddings.embed_query, q) for q in list(queries) } for query, future in futures.items(): queryinmem_store[query] = future.result() return queryinmem_store async def _aembed_search_queries( self, search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]], ) -> dict[str, list[float]]: 
queryinmem_store = {}
        if self.index_config and self.embeddings and search_ops:
            queries = {op.query for (op, _) in search_ops.values() if op.query}

            if queries:
                coros = [self.embeddings.aembed_query(q) for q in list(queries)]
                results = await asyncio.gather(*coros)
                # Iterating the unmodified set again yields the same order as the
                # list above, so each query lines up with its gathered result.
                queryinmem_store = dict(zip(queries, results))

        return queryinmem_store

    def _batch_search(
        self,
        ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
        queryinmem_store: dict[str, list[float]],
        results: list[Result],
    ) -> None:
        """Perform batch similarity search for multiple queries.

        Writes ``results[i]`` in place for each search op index ``i``. When the op
        has a query and query embeddings are available, candidates are ranked by
        their best (max-pooled) cosine similarity over their stored vectors, and
        unembedded candidates may fill remaining slots with ``score=None``.
        Otherwise candidates are returned unscored, in candidate order, sliced by
        ``offset``/``limit``.
        """
        for i, (op, candidates) in ops.items():
            if not candidates:
                results[i] = []
                continue
            if op.query and queryinmem_store:
                query_embedding = queryinmem_store[op.query]
                # One row per (item, vector) pair; items without vectors are
                # collected separately because they cannot be scored.
                flat_items, flat_vectors = [], []
                scoreless = []
                for item, vectors in candidates:
                    for vector in vectors:
                        flat_items.append(item)
                        flat_vectors.append(vector)
                    if not vectors:
                        scoreless.append(item)

                scores = _cosine_similarity(query_embedding, flat_vectors)
                sorted_results = sorted(
                    zip(scores, flat_items), key=lambda x: x[0], reverse=True
                )
                # max pooling
                # Results are sorted best-first, so the first occurrence of an
                # item carries its highest score; later duplicates are skipped.
                seen: set[tuple[tuple[str, ...], str]] = set()
                kept: list[tuple[Optional[float], Item]] = []
                for score, item in sorted_results:
                    key = (item.namespace, item.key)
                    if key in seen:
                        continue
                    ix = len(seen)
                    seen.add(key)
                    if ix >= op.offset + op.limit:
                        break
                    if ix < op.offset:
                        continue

                    kept.append((score, item))
                if scoreless and len(kept) < op.limit:
                    # Corner case: if we request more items than what we have embedded,
                    # fill the rest with non-scored items
                    kept.extend(
                        (None, item) for item in scoreless[: op.limit - len(kept)]
                    )

                results[i] = [
                    SearchItem(
                        namespace=item.namespace,
                        key=item.key,
                        value=item.value,
                        created_at=item.created_at,
                        updated_at=item.updated_at,
                        score=float(score) if score is not None else None,
                    )
                    for score, item in kept
                ]
            else:
                results[i] = [
                    SearchItem(
                        namespace=item.namespace,
                        key=item.key,
                        value=item.value,
                        created_at=item.created_at,
                        updated_at=item.updated_at,
                    )
                    for (item, _) in candidates[op.offset : op.offset + op.limit]
                ]

    def _prepare_ops(
        self, ops: Iterable[Op]
    ) -> tuple[
        list[Result],
        dict[tuple[tuple[str, ...], str], PutOp],
        dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
    ]:
        """Split ``ops`` into immediate results, pending puts and pending searches.

        Gets and namespace listings are answered immediately. Searches get their
        filtered candidates now and a ``None`` placeholder in ``results``. Puts
        are keyed by ``(namespace, key)`` so a later put to the same key replaces
        an earlier one; they also get a ``None`` result.

        Raises:
            ValueError: For an operation of an unknown type.
        """
        results: list[Result] = []
        put_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
        search_ops: dict[
            int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]
        ] = {}
        for i, op in enumerate(ops):
            if isinstance(op, GetOp):
                # NOTE: indexing the _data defaultdict inserts an empty entry for
                # a namespace that does not exist yet.
                item = self._data[op.namespace].get(op.key)
                results.append(item)
            elif isinstance(op, SearchOp):
                search_ops[i] = (op, self._filter_items(op))
                results.append(None)
            elif isinstance(op, ListNamespacesOp):
                results.append(self._handle_list_namespaces(op))
            elif isinstance(op, PutOp):
                put_ops[(op.namespace, op.key)] = op
                results.append(None)
            else:
                raise ValueError(f"Unknown operation type: {type(op)}")

        return results, put_ops, search_ops

    def _apply_put_ops(self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]) -> None:
        """Apply pending puts to ``_data`` (and drop vectors of deleted items).

        A ``None`` value deletes the item and its vectors. Any other value
        (re)creates the Item with fresh UTC timestamps; note that ``created_at``
        is reset as well when an existing key is overwritten.
        """
        for (namespace, key), op in put_ops.items():
            if op.value is None:
                self._data[namespace].pop(key, None)
                self._vectors[namespace].pop(key, None)
            else:
                self._data[namespace][key] = Item(
                    value=op.value,
                    key=key,
                    namespace=namespace,
                    created_at=datetime.now(timezone.utc),
                    updated_at=datetime.now(timezone.utc),
                )

    def _extract_texts(
        self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]
    ) -> dict[str, list[tuple[tuple[str, ...], str, str]]]:
        """Collect the texts to embed for the pending puts.

        Returns a mapping ``text -> [(namespace, key, path), ...]`` of every
        location that uses that text. When a path yields several texts, each
        location's path gets an index suffix (``f"{path}.{i}"``). Deletes and
        puts with ``index=False`` are skipped; ``index=None`` uses the store's
        default fields. Returns ``{}`` when there are no puts or no index is
        configured.
        """
        if put_ops and self.index_config and self.embeddings:
            to_embed = defaultdict(list)

            for op in put_ops.values():
                if op.value is not None and op.index is not False:
                    if op.index is None:
                        paths = self.index_config["__tokenized_fields"]
                    else:
                        paths = [(ix, tokenize_path(ix)) for ix in op.index]

                    for path, field in paths:
                        texts = get_text_at_path(op.value, field)
                        if texts:
                            if len(texts) > 1:
                                for i, text in enumerate(texts):
                                    to_embed[text].append(
                                        (op.namespace, op.key, f"{path}.{i}")
                                    )

                            else:
                                to_embed[texts[0]].append((op.namespace, op.key, path))

            return to_embed

        return {}

    def
_insertinmem_store( self, to_embed: dict[str, list[tuple[tuple[str, ...], str, str]]], embeddings: list[list[float]], ) -> None: indices = [index for indices in to_embed.values() for index in indices] if len(indices) != len(embeddings): raise ValueError( f"Number of embeddings ({len(embeddings)}) does not" f" match number of indices ({len(indices)})" ) for embedding, (ns, key, path) in zip(embeddings, indices): self._vectors[ns][key][path] = embedding def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]: all_namespaces = list( self._data.keys() ) # Avoid collection size changing while iterating namespaces = all_namespaces if op.match_conditions: namespaces = [ ns for ns in namespaces if all(_does_match(condition, ns) for condition in op.match_conditions) ] if op.max_depth is not None: namespaces = sorted({ns[: op.max_depth] for ns in namespaces}) else: namespaces = sorted(namespaces) return namespaces[op.offset : op.offset + op.limit] @functools.lru_cache(maxsize=1) def _check_numpy() -> bool: if bool(util.find_spec("numpy")): return True logger.warning( "NumPy not found in the current Python environment. " "The InMemoryStore will use a pure Python implementation for vector operations, " "which may significantly impact performance, especially for large datasets or frequent searches. " "For optimal speed and efficiency, consider installing NumPy: " "pip install numpy" ) return False def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]: """ Compute cosine similarity between a vector X and a matrix Y. Lazy import numpy for efficiency. 
""" if not Y: return [] if _check_numpy(): import numpy as np # type: ignore X_arr = np.array(X) if not isinstance(X, np.ndarray) else X Y_arr = np.array(Y) if not isinstance(Y, np.ndarray) else Y X_norm = np.linalg.norm(X_arr) Y_norm = np.linalg.norm(Y_arr, axis=1) # Avoid division by zero mask = Y_norm != 0 similarities = np.zeros_like(Y_norm) similarities[mask] = np.dot(Y_arr[mask], X_arr) / (Y_norm[mask] * X_norm) return similarities.tolist() similarities = [] for y in Y: dot_product = sum(a * b for a, b in zip(X, y)) norm1 = sum(a * a for a in X) ** 0.5 norm2 = sum(a * a for a in y) ** 0.5 similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0 similarities.append(similarity) return similarities def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool: """Whether a namespace key matches a match condition.""" match_type = match_condition.match_type path = match_condition.path if len(key) < len(path): return False if match_type == "prefix": for k_elem, p_elem in zip(key, path): if p_elem == "*": continue # Wildcard matches any element if k_elem != p_elem: return False return True elif match_type == "suffix": for k_elem, p_elem in zip(reversed(key), reversed(path)): if p_elem == "*": continue # Wildcard matches any element if k_elem != p_elem: return False return True else: raise ValueError(f"Unsupported match type: {match_type}") def _compare_values(item_value: Any, filter_value: Any) -> bool: """Compare values in a JSONB-like way, handling nested objects.""" if isinstance(filter_value, dict): if any(k.startswith("$") for k in filter_value): return all( _apply_operator(item_value, op_key, op_value) for op_key, op_value in filter_value.items() ) if not isinstance(item_value, dict): return False return all( _compare_values(item_value.get(k), v) for k, v in filter_value.items() ) elif isinstance(filter_value, (list, tuple)): return ( isinstance(item_value, (list, tuple)) and len(item_value) == len(filter_value) and 
all(_compare_values(iv, fv) for iv, fv in zip(item_value, filter_value)) ) else: return item_value == filter_value def _apply_operator(value: Any, operator: str, op_value: Any) -> bool: """Apply a comparison operator, matching PostgreSQL's JSONB behavior.""" if operator == "$eq": return value == op_value elif operator == "$gt": return float(value) > float(op_value) elif operator == "$gte": return float(value) >= float(op_value) elif operator == "$lt": return float(value) < float(op_value) elif operator == "$lte": return float(value) <= float(op_value) elif operator == "$ne": return value != op_value else: raise ValueError(f"Unsupported operator: {operator}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/embed.py`: ```py """Utilities for working with embedding functions and LangChain's Embeddings interface. This module provides tools to wrap arbitrary embedding functions (both sync and async) into LangChain's Embeddings interface. This enables using custom embedding functions with LangChain-compatible tools while maintaining support for both synchronous and asynchronous operations. """ import asyncio import json from typing import Any, Awaitable, Callable, Optional, Sequence, Union from langchain_core.embeddings import Embeddings EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]] """Type for synchronous embedding functions. The function should take a sequence of strings and return a list of embeddings, where each embedding is a list of floats. The dimensionality of the embeddings should be consistent for all inputs. """ AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]] """Type for asynchronous embedding functions. Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings. """ def ensure_embeddings( embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, None], ) -> Embeddings: """Ensure that an embedding function conforms to LangChain's Embeddings interface. 
This function wraps arbitrary embedding functions to make them compatible with
    LangChain's Embeddings interface. It handles both synchronous and asynchronous
    functions.

    Args:
        embed: Either an existing Embeddings instance, or a function that converts
            text to embeddings. If the function is async, only the async methods
            (`aembed_documents` / `aembed_query`) can be used; the sync methods
            raise ValueError.

    Returns:
        An Embeddings instance that wraps the provided function(s). An existing
        Embeddings instance is returned unchanged.

    Raises:
        ValueError: If `embed` is None.

    ??? example "Examples"
        Wrap a synchronous embedding function:
        ```python
        def my_embed_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        ```

        Wrap an asynchronous embedding function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_async_fn)
        result = await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        ```
    """
    if embed is None:
        raise ValueError("embed must be provided")
    if isinstance(embed, Embeddings):
        return embed
    return EmbeddingsLambda(embed)


class EmbeddingsLambda(Embeddings):
    """Wrapper to convert embedding functions into LangChain's Embeddings interface.

    This class allows arbitrary embedding functions to be used with LangChain-compatible tools.
    It supports both synchronous and asynchronous operations, and can handle:
    1. A synchronous function for sync operations (async operations will use sync function)
    2. An async function for async operations only (sync operations will raise an error)

    The embedding functions should convert text into fixed-dimensional vectors that
    capture the semantic meaning of the text.

    Args:
        func: Function that converts text to embeddings. Can be sync or async.
            If async, it will be used for async operations, but sync operations will raise an error.
            If sync, it will be used for both sync and async operations.

    ??? example "Examples"
        With a sync function:
        ```python
        def my_embed_fn(texts):
            # Return 2D embeddings for each text
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        await embeddings.aembed_query("hello")  # Also returns [0.1, 0.2]
        ```

        With an async function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_async_fn)
        await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        # Note: embed_query() would raise an error
        ```
    """

    def __init__(
        self,
        func: Union[EmbeddingsFunc, AEmbeddingsFunc],
    ) -> None:
        if func is None:
            raise ValueError("func must be provided")
        # Exactly one of `afunc` / `func` is set; the embed methods probe for
        # them with getattr(self, ..., None).
        if _is_async_callable(func):
            self.afunc = func
        else:
            self.func = func

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts into vectors.

        Args:
            texts: list of texts to convert to embeddings.

        Returns:
            list of embeddings, one per input text. Each embedding is a list of floats.

        Raises:
            ValueError: If the instance was initialized with only an async function.
        """
        func = getattr(self, "func", None)
        if func is None:
            raise ValueError(
                "EmbeddingsLambda was initialized with an async function but no sync function. "
                "Use aembed_documents for async operation or provide a sync function."
            )
        return func(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single piece of text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            Embedding vector as a list of floats.

        Raises:
            ValueError: If the instance was initialized with only an async function.

        Note:
            This is equivalent to calling embed_documents with a single text
            and taking the first result.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed a list of texts into vectors.

        Args:
            texts: list of texts to convert to embeddings.

        Returns:
            list of embeddings, one per input text. Each embedding is a list of floats.
Note:
            If no async function was provided, this falls back to the sync implementation.
        """
        afunc = getattr(self, "afunc", None)
        if afunc is None:
            return await super().aembed_documents(texts)
        return await afunc(texts)

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed a single piece of text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            Embedding vector as a list of floats.

        Note:
            This is equivalent to calling aembed_documents with a single text
            and taking the first result. If no async function was provided, the
            base-class implementation is used.
        """
        afunc = getattr(self, "afunc", None)
        if afunc is None:
            return await super().aembed_query(text)
        return (await afunc([text]))[0]


def get_text_at_path(obj: Any, path: Union[str, list[str]]) -> list[str]:
    """Extract text from an object using a path expression or pre-tokenized path.

    Args:
        obj: The object to extract text from
        path: Either a path string or pre-tokenized path list.

    Returns:
        A list of strings, one per matched value: str/int/float/bool values are
        converted with ``str``, lists and dicts are serialized with
        ``json.dumps(..., sort_keys=True)``, and ``None`` or unmatched paths
        contribute nothing. An empty path or ``"$"`` serializes the whole object.

    !!! info "Path types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
        - Nested paths in multi-field: "{field1,nested.field2}"
    """
    if not path or path == "$":
        return [json.dumps(obj, sort_keys=True)]

    tokens = tokenize_path(path) if isinstance(path, str) else path

    def _extract_from_obj(obj: Any, tokens: list[str], pos: int) -> list[str]:
        # Path fully consumed: convert the value reached into text.
        if pos >= len(tokens):
            if isinstance(obj, (str, int, float, bool)):
                return [str(obj)]
            elif obj is None:
                return []
            elif isinstance(obj, (list, dict)):
                return [json.dumps(obj, sort_keys=True)]
            return []

        token = tokens[pos]
        results = []

        if token.startswith("[") and token.endswith("]"):
            if not isinstance(obj, list):
                return []

            index = token[1:-1]
            if index == "*":
                for item in obj:
                    results.extend(_extract_from_obj(item, tokens, pos + 1))
            else:
                try:
                    idx = int(index)
                    # Negative indices count from the end, as in Python.
                    if idx < 0:
                        idx = len(obj) + idx
                    if 0 <= idx < len(obj):
                        results.extend(_extract_from_obj(obj[idx], tokens, pos + 1))
                except (ValueError, IndexError):
                    return []

        elif token.startswith("{") and token.endswith("}"):
            if not isinstance(obj, dict):
                return []

            # NOTE: each field is followed through plain dict keys only, and any
            # tokens after this {...} group are not applied (no pos + 1 recursion).
            fields = [f.strip() for f in token[1:-1].split(",")]
            for field in fields:
                nested_tokens = tokenize_path(field)
                if nested_tokens:
                    current_obj: Optional[dict] = obj
                    for nested_token in nested_tokens:
                        if (
                            isinstance(current_obj, dict)
                            and nested_token in current_obj
                        ):
                            current_obj = current_obj[nested_token]
                        else:
                            current_obj = None
                            break

                    if current_obj is not None:
                        if isinstance(current_obj, (str, int, float, bool)):
                            results.append(str(current_obj))
                        elif isinstance(current_obj, (list, dict)):
                            results.append(json.dumps(current_obj, sort_keys=True))

        # Handle wildcard
        elif token == "*":
            if isinstance(obj, dict):
                for value in obj.values():
                    results.extend(_extract_from_obj(value, tokens, pos + 1))
            elif isinstance(obj, list):
                for item in obj:
                    results.extend(_extract_from_obj(item, tokens, pos + 1))

        # Handle regular field
        else:
            if isinstance(obj, dict) and token in obj:
                results.extend(_extract_from_obj(obj[token], tokens, pos + 1))

        return results

    return _extract_from_obj(obj, tokens, 0)


# Private utility functions


def tokenize_path(path: str) -> list[str]:
    """Tokenize a path into components.

    !!!
info "Types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"

    Bracketed ``[...]`` and braced ``{...}`` groups become single tokens with
    their delimiters kept (nested brackets/braces are balanced); dots separate
    plain field names. An empty path yields an empty list.
    """
    if not path:
        return []

    tokens = []
    current: list[str] = []
    i = 0
    while i < len(path):
        char = path[i]

        if char == "[":  # Handle array index
            if current:
                tokens.append("".join(current))
                current = []
            bracket_count = 1
            index_chars = ["["]
            i += 1
            while i < len(path) and bracket_count > 0:
                if path[i] == "[":
                    bracket_count += 1
                elif path[i] == "]":
                    bracket_count -= 1
                index_chars.append(path[i])
                i += 1
            tokens.append("".join(index_chars))
            # `i` already points past the closing bracket.
            continue

        elif char == "{":  # Handle multi-field selection
            if current:
                tokens.append("".join(current))
                current = []
            brace_count = 1
            field_chars = ["{"]
            i += 1
            while i < len(path) and brace_count > 0:
                if path[i] == "{":
                    brace_count += 1
                elif path[i] == "}":
                    brace_count -= 1
                field_chars.append(path[i])
                i += 1
            tokens.append("".join(field_chars))
            # `i` already points past the closing brace.
            continue

        elif char == ".":  # Handle regular field
            if current:
                tokens.append("".join(current))
                current = []
        else:
            current.append(char)
        i += 1

    if current:
        tokens.append("".join(current))

    return tokens


def _is_async_callable(
    func: Any,
) -> bool:
    """Check if a function is async.

    This includes both async def functions and classes with async __call__ methods.

    Args:
        func: Function or callable object to check.

    Returns:
        True if the function is async, False otherwise.
    """
    # `and` binds tighter than `or`: the __call__ check only applies when the
    # object itself is not a coroutine function.
    return (
        asyncio.iscoroutinefunction(func)
        or hasattr(func, "__call__")  # noqa: B004
        and asyncio.iscoroutinefunction(func.__call__)
    )


__all__ = [
    "ensure_embeddings",
    "EmbeddingsFunc",
    "AEmbeddingsFunc",
]
```

`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/batch.py`:

```py
import asyncio
import weakref
from typing import Any, Literal, Optional, Union

from langgraph.store.base import (
    BaseStore,
    GetOp,
    Item,
    ListNamespacesOp,
    MatchCondition,
    NamespacePath,
    Op,
    PutOp,
    SearchItem,
    SearchOp,
    _validate_namespace,
)


class AsyncBatchedBaseStore(BaseStore):
    """Efficiently batch operations in a background task.

    Each async method enqueues its op together with a fresh future in
    ``_aqueue`` and awaits that future; the background task created in
    ``__init__`` (``_run``) executes queued ops through ``abatch`` and resolves
    the futures. Must be instantiated while an event loop is running
    (``asyncio.get_running_loop`` raises ``RuntimeError`` otherwise).
    """

    __slots__ = ("_loop", "_aqueue", "_task")

    def __init__(self) -> None:
        self._loop = asyncio.get_running_loop()
        self._aqueue: dict[asyncio.Future, Op] = {}
        # Only a weak reference goes to the task so it does not keep the store
        # alive; the loop exits once the store has been garbage-collected.
        self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))

    def __del__(self) -> None:
        # Stop the background task together with the store.
        self._task.cancel()

    async def aget(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> Optional[Item]:
        fut = self._loop.create_future()
        self._aqueue[fut] = GetOp(namespace, key)
        return await fut

    async def asearch(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: Optional[str] = None,
        filter: Optional[dict[str, Any]] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> list[SearchItem]:
        fut = self._loop.create_future()
        self._aqueue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query)
        return await fut

    async def aput(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Optional[Union[Literal[False], list[str]]] = None,
    ) -> None:
        # Validate eagerly so a bad namespace fails in the caller, not in the batch.
        _validate_namespace(namespace)
        fut = self._loop.create_future()
        self._aqueue[fut] = PutOp(namespace, key, value, index)
        return await fut

    async def adelete(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> None:
        fut = self._loop.create_future()
        # A PutOp whose value is None marks the item for deletion.
        self._aqueue[fut] = PutOp(namespace, key, None)
        return await fut

    async def alist_namespaces(
        self,
        *,
        prefix: Optional[NamespacePath] = None,
        suffix: Optional[NamespacePath] = None,
max_depth: Optional[int] = None, limit: int = 100, offset: int = 0, ) -> list[tuple[str, ...]]: fut = self._loop.create_future() match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) if suffix: match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) op = ListNamespacesOp( match_conditions=tuple(match_conditions), max_depth=max_depth, limit=limit, offset=offset, ) self._aqueue[fut] = op return await fut def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: """Dedupe operations while preserving order for results. Args: values: List of operations to dedupe Returns: Tuple of (listen indices, deduped operations) where listen indices map deduped operation results back to original positions """ if len(values) <= 1: return None, list(values) dedupped: list[Op] = [] listen: list[int] = [] puts: dict[tuple[tuple[str, ...], str], int] = {} for op in values: if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)): try: listen.append(dedupped.index(op)) except ValueError: listen.append(len(dedupped)) dedupped.append(op) elif isinstance(op, PutOp): putkey = (op.namespace, op.key) if putkey in puts: # Overwrite previous put ix = puts[putkey] dedupped[ix] = op listen.append(ix) else: puts[putkey] = len(dedupped) listen.append(len(dedupped)) dedupped.append(op) else: # Any new ops will be treated regularly listen.append(len(dedupped)) dedupped.append(op) return listen, dedupped async def _run( aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore] ) -> None: while True: await asyncio.sleep(0) if not aqueue: continue if s := store(): # get the operations to run taken = aqueue.copy() # action each operation try: values = list(taken.values()) listen, dedupped = _dedupe_ops(values) results = await s.abatch(dedupped) if listen is not None: results = [results[ix] for ix in listen] # set the results of each operation for fut, result in zip(taken, results): fut.set_result(result) 
except Exception as e: for fut in taken: fut.set_exception(e) # remove the operations from the queue for fut in taken: del aqueue[fut] else: break # remove strong ref to store del s ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/__init__.py`: ```py """Base classes and types for persistent key-value stores. Stores provide long-term memory that persists across threads and conversations. Supports hierarchical namespaces, key-value storage, and optional vector search. Core types: - BaseStore: Store interface with sync/async operations - Item: Stored key-value pairs with metadata - Op: Get/Put/Search/List operations """ from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Iterable, Literal, NamedTuple, Optional, TypedDict, Union, cast from langchain_core.embeddings import Embeddings from langgraph.store.base.embed import ( AEmbeddingsFunc, EmbeddingsFunc, ensure_embeddings, get_text_at_path, tokenize_path, ) class Item: """Represents a stored item with metadata. Args: value (dict[str, Any]): The stored data as a dictionary. Keys are filterable. key (str): Unique identifier within the namespace. namespace (tuple[str, ...]): Hierarchical path defining the collection in which this document resides. Represented as a tuple of strings, allowing for nested categorization. For example: ("documents", 'user123') created_at (datetime): Timestamp of item creation. updated_at (datetime): Timestamp of last update. """ __slots__ = ("value", "key", "namespace", "created_at", "updated_at") def __init__( self, *, value: dict[str, Any], key: str, namespace: tuple[str, ...], created_at: datetime, updated_at: datetime, ): self.value = value self.key = key # The casting from json-like types is for if this object is # deserialized. 
self.namespace = tuple(namespace) self.created_at = ( datetime.fromisoformat(cast(str, created_at)) if isinstance(created_at, str) else created_at ) self.updated_at = ( datetime.fromisoformat(cast(str, created_at)) if isinstance(updated_at, str) else updated_at ) def __eq__(self, other: object) -> bool: if not isinstance(other, Item): return False return ( self.value == other.value and self.key == other.key and self.namespace == other.namespace and self.created_at == other.created_at and self.updated_at == other.updated_at ) def __hash__(self) -> int: return hash((self.namespace, self.key)) def dict(self) -> dict: return { "value": self.value, "key": self.key, "namespace": list(self.namespace), "created_at": self.created_at.isoformat(), "updated_at": self.updated_at.isoformat(), } class SearchItem(Item): """Represents an item returned from a search operation with additional metadata.""" __slots__ = ("score",) def __init__( self, namespace: tuple[str, ...], key: str, value: dict[str, Any], created_at: datetime, updated_at: datetime, score: Optional[float] = None, ) -> None: """Initialize a result item. Args: namespace: Hierarchical path to the item. key: Unique identifier within the namespace. value: The stored value. created_at: When the item was first created. updated_at: When the item was last updated. score: Relevance/similarity score if from a ranked operation. """ super().__init__( value=value, key=key, namespace=namespace, created_at=created_at, updated_at=updated_at, ) self.score = score def dict(self) -> dict: result = super().dict() result["score"] = self.score return result class GetOp(NamedTuple): """Operation to retrieve a specific item by its namespace and key. This operation allows precise retrieval of stored items using their full path (namespace) and unique identifier (key) combination. 
???+ example "Examples"
        Basic item retrieval:
        ```python
        GetOp(namespace=("users", "profiles"), key="user123")
        GetOp(namespace=("cache", "embeddings"), key="doc456")
        ```
    """

    namespace: tuple[str, ...]
    """Hierarchical path that uniquely identifies the item's location.

    ???+ example "Examples"
        ```python
        ("users",)  # Root level users namespace
        ("users", "profiles")  # Profiles within users namespace
        ```
    """

    key: str
    """Unique identifier for the item within its specific namespace.

    ???+ example "Examples"
        ```python
        "user123"  # For a user profile
        "doc456"  # For a document
        ```
    """


class SearchOp(NamedTuple):
    """Operation to search for items within a specified namespace hierarchy.

    This operation supports both structured filtering and natural language search
    within a given namespace prefix. It provides pagination through limit and offset
    parameters.

    Note:
        Natural language search support depends on your store implementation.

    ???+ example "Examples"
        Search with filters and pagination:
        ```python
        SearchOp(
            namespace_prefix=("documents",),
            filter={"type": "report", "status": "active"},
            limit=5,
            offset=10
        )
        ```

        Natural language search:
        ```python
        SearchOp(
            namespace_prefix=("users", "content"),
            query="technical documentation about APIs",
            limit=20
        )
        ```
    """

    namespace_prefix: tuple[str, ...]
    """Hierarchical path prefix defining the search scope.

    ???+ example "Examples"
        ```python
        ()  # Search entire store
        ("documents",)  # Search all documents
        ("users", "content")  # Search within user content
        ```
    """

    filter: Optional[dict[str, Any]] = None
    """Key-value pairs for filtering results based on exact matches or comparison operators.

    The filter supports both exact matches and operator-based comparisons.

    Supported Operators:
        - $eq: Equal to (same as direct value comparison)
        - $ne: Not equal to
        - $gt: Greater than
        - $gte: Greater than or equal to
        - $lt: Less than
        - $lte: Less than or equal to

    ???+ example "Examples"
        Simple exact match:

        ```python
        {"status": "active"}
        ```

        Comparison operators:

        ```python
        {"score": {"$gt": 4.99}}  # Score greater than 4.99
        ```

        Multiple conditions:

        ```python
        {
            "score": {"$gte": 3.0},
            "color": "red"
        }
        ```
    """

    limit: int = 10
    """Maximum number of items to return in the search results."""

    offset: int = 0
    """Number of matching items to skip for pagination."""

    query: Optional[str] = None
    """Natural language search query for semantic search capabilities.

    ???+ example "Examples"
        - "technical documentation about REST APIs"
        - "machine learning papers from 2023"
    """


# Type representing a namespace path that can include wildcards
NamespacePath = tuple[Union[str, Literal["*"]], ...]
"""A tuple representing a namespace path that can include wildcards.

???+ example "Examples"
    ```python
    ("users",)  # Exact users namespace
    ("documents", "*")  # Any sub-namespace under documents
    ("cache", "*", "v1")  # Any cache category with v1 version
    ```
"""

# Type for specifying how to match namespaces
NamespaceMatchType = Literal["prefix", "suffix"]
"""Specifies how to match namespace paths.

Values:
    "prefix": Match from the start of the namespace
    "suffix": Match from the end of the namespace
"""


class MatchCondition(NamedTuple):
    """Represents a pattern for matching namespaces in the store.

    This class combines a match type (prefix or suffix) with a namespace path
    pattern that can include wildcards to flexibly match different namespace
    hierarchies.

    ???+ example "Examples"
        Prefix matching:
        ```python
        MatchCondition(match_type="prefix", path=("users", "profiles"))
        ```

        Suffix matching with wildcard:
        ```python
        MatchCondition(match_type="suffix", path=("cache", "*"))
        ```

        Simple suffix matching:
        ```python
        MatchCondition(match_type="suffix", path=("v1",))
        ```
    """

    match_type: NamespaceMatchType
    """Type of namespace matching to perform."""

    path: NamespacePath
    """Namespace path pattern that can include wildcards."""


class ListNamespacesOp(NamedTuple):
    """Operation to list and filter namespaces in the store.

    This operation allows exploring the organization of data, finding specific
    collections, and navigating the namespace hierarchy.

    ???+ example "Examples"
        List all namespaces under the "documents" path:
        ```python
        ListNamespacesOp(
            match_conditions=(MatchCondition(match_type="prefix", path=("documents",)),),
            max_depth=2
        )
        ```

        List all namespaces that end with "v1":
        ```python
        ListNamespacesOp(
            match_conditions=(MatchCondition(match_type="suffix", path=("v1",)),),
            limit=50
        )
        ```
    """

    match_conditions: Optional[tuple[MatchCondition, ...]] = None
    """Optional conditions for filtering namespaces.

    ???+ example "Examples"
        All user namespaces:
        ```python
        (MatchCondition(match_type="prefix", path=("users",)),)
        ```

        All namespaces that start with "docs" and end with "draft":
        ```python
        (
            MatchCondition(match_type="prefix", path=("docs",)),
            MatchCondition(match_type="suffix", path=("draft",))
        )
        ```
    """

    max_depth: Optional[int] = None
    """Maximum depth of namespace hierarchy to return.

    Note:
        Namespaces deeper than this level will be truncated to their first
        `max_depth` elements (InMemoryStore also de-duplicates the truncated
        results).
    """

    limit: int = 100
    """Maximum number of namespaces to return."""

    offset: int = 0
    """Number of namespaces to skip for pagination."""


class PutOp(NamedTuple):
    """Operation to store, update, or delete an item in the store.

    This class represents a single operation to modify the store's contents,
    whether adding new items, updating existing ones, or removing them.
    """

    namespace: tuple[str, ...]
"""Hierarchical path that identifies the location of the item.

    The namespace acts as a folder-like structure to organize items.
    Each element in the tuple represents one level in the hierarchy.

    ???+ example "Examples"
        Root level documents
        ```python
        ("documents",)
        ```

        User-specific documents
        ```python
        ("documents", "user123")
        ```

        Nested cache structure
        ```python
        ("cache", "embeddings", "v1")
        ```
    """

    key: str
    """Unique identifier for the item within its namespace.

    The key must be unique within the specific namespace to avoid conflicts.
    Together with the namespace, it forms a complete path to the item.

    Example:
        If namespace is ("documents", "user123") and key is "report1",
        the full path would effectively be "documents/user123/report1"
    """

    value: Optional[dict[str, Any]]
    """The data to store, or None to mark the item for deletion.

    The value must be a dictionary with string keys and JSON-serializable values.
    Setting this to None signals that the item should be deleted.

    Example:
        {
            "field1": "string value",
            "field2": 123,
            "nested": {"can": "contain", "any": "serializable data"}
        }
    """

    index: Optional[Union[Literal[False], list[str]]] = None  # type: ignore[assignment]
    """Controls how the item's fields are indexed for search operations.

    Indexing configuration determines how the item can be found through search:
        - None (default): Uses the store's default indexing configuration (if provided)
        - False: Disables indexing for this item
        - list[str]: Specifies which json path fields to index for search

    The item remains accessible through direct get() operations regardless of indexing.
    When indexed, fields can be searched using natural language queries through
    vector similarity search (if supported by the store implementation).

    Path Syntax:
        - Simple field access: "field"
        - Nested fields: "parent.child.grandchild"
        - Array indexing:
          - Specific index: "array[0]"
          - Last element: "array[-1]"
          - All elements (each individually): "array[*]"

    ???+ example "Examples"
        - None - Use store defaults (whole item)
        - list[str] - List of fields to index

        ```python
        [
            "metadata.title",                    # Nested field access
            "context[*].content",                # Index content from all context as separate vectors
            "authors[0].name",                   # First author's name
            "revisions[-1].changes",             # Most recent revision's changes
            "sections[*].paragraphs[*].text",    # All text from all paragraphs in all sections
            "metadata.tags[*]",                  # All tags in metadata
        ]
        ```
    """


Op = Union[GetOp, SearchOp, PutOp, ListNamespacesOp]
Result = Union[Item, list[Item], list[SearchItem], list[tuple[str, ...]], None]


class InvalidNamespaceError(ValueError):
    """Provided namespace is invalid."""


class IndexConfig(TypedDict, total=False):
    """Configuration for indexing documents for semantic search in the store.

    If not provided to the store, the store will not support vector search.
    In that case, all `index` arguments to put() and `aput()` operations will be ignored.
    """

    dims: int
    """Number of dimensions in the embedding vectors.

    Common embedding models have the following dimensions:
        - openai:text-embedding-3-large: 3072
        - openai:text-embedding-3-small: 1536
        - openai:text-embedding-ada-002: 1536
        - cohere:embed-english-v3.0: 1024
        - cohere:embed-english-light-v3.0: 384
        - cohere:embed-multilingual-v3.0: 1024
        - cohere:embed-multilingual-light-v3.0: 384
    """

    embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc]
    """Function to generate embeddings from text.

    Although the key is optional at the type level (``total=False``), it is
    required in practice: InMemoryStore passes it to ``ensure_embeddings``, which
    raises ``ValueError`` when it is missing.

    Can be specified in three ways:
        1. A LangChain Embeddings instance
        2. A synchronous embedding function (EmbeddingsFunc)
        3. An asynchronous embedding function (AEmbeddingsFunc)

    ???+ example "Examples"
        Using LangChain's initialization with InMemoryStore:
        ```python
        from langchain.embeddings import init_embeddings
        from langgraph.store.memory import InMemoryStore

        store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": init_embeddings("openai:text-embedding-3-small")
            }
        )
        ```

        Using a custom embedding function with InMemoryStore:
        ```python
        from openai import OpenAI
        from langgraph.store.memory import InMemoryStore

        client = OpenAI()

        def embed_texts(texts: list[str]) -> list[list[float]]:
            response = client.embeddings.create(
                model="text-embedding-3-small",
                input=texts
            )
            return [e.embedding for e in response.data]

        store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": embed_texts
            }
        )
        ```

        Using an asynchronous embedding function with InMemoryStore:
        ```python
        from openai import AsyncOpenAI
        from langgraph.store.memory import InMemoryStore

        client = AsyncOpenAI()

        async def aembed_texts(texts: list[str]) -> list[list[float]]:
            response = await client.embeddings.create(
                model="text-embedding-3-small",
                input=texts
            )
            return [e.embedding for e in response.data]

        store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": aembed_texts
            }
        )
        ```
    """

    fields: Optional[list[str]]
    """Fields to extract text from for embedding generation.

    Controls which parts of stored items are embedded for semantic search. Follows JSON path syntax:

        - ["$"]: Embeds the entire JSON object as one vector (default)
        - ["field1", "field2"]: Embeds specific top-level fields
        - ["parent.child"]: Embeds nested fields using dot notation
        - ["array[*].field"]: Embeds field from each array element separately

    Note:
        You can always override this behavior when storing an item using the
        `index` parameter in the `put` or `aput` operations.
???+ example "Examples"
        ```python
        # Embed entire document (default)
        fields=["$"]

        # Embed specific fields
        fields=["text", "summary"]

        # Embed nested fields
        fields=["metadata.title", "content.body"]

        # Embed from arrays
        fields=["messages[*].content"]  # Each message content separately
        fields=["context[0].text"]      # First context item's text
        ```

    Note:
        - Fields missing from a document are skipped
        - Array notation creates separate embeddings for each element
        - Complex nested paths are supported (e.g., "a.b[*].c.d")
    """


class BaseStore(ABC):
    """Abstract base class for persistent key-value stores.

    Stores enable persistence and memory that can be shared across threads,
    scoped to user IDs, assistant IDs, or other arbitrary namespaces.
    Some implementations may support semantic search capabilities through
    an optional `index` configuration.

    Note:
        Semantic search capabilities vary by implementation and are typically
        disabled by default. Stores that support this feature can be configured
        by providing an `index` configuration at creation time. Without this
        configuration, semantic search is disabled and any `index` arguments
        to storage operations will have no effect.
    """

    # Keep instances weak-referenceable without adding a per-instance
    # __dict__ at this level of the hierarchy.
    __slots__ = ("__weakref__",)

    @abstractmethod
    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute multiple operations synchronously in a single batch.

        Args:
            ops: An iterable of operations to execute.

        Returns:
            A list of results, where each result corresponds to an operation in the input.
            The order of results matches the order of input operations.
        """

    @abstractmethod
    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute multiple operations asynchronously in a single batch.

        Args:
            ops: An iterable of operations to execute.

        Returns:
            A list of results, where each result corresponds to an operation in the input.
            The order of results matches the order of input operations.
        """

    def get(self, namespace: tuple[str, ...], key: str) -> Optional[Item]:
        """Retrieve a single item.

        Args:
            namespace: Hierarchical path for the item.
            key: Unique identifier within the namespace.

        Returns:
            The retrieved item or None if not found.
        """
        return self.batch([GetOp(namespace, key)])[0]

    def search(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: Optional[str] = None,
        filter: Optional[dict[str, Any]] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> list[SearchItem]:
        """Search for items within a namespace prefix.

        Args:
            namespace_prefix: Hierarchical path prefix to search within.
            query: Optional query for natural language search.
            filter: Key-value pairs to filter results.
            limit: Maximum number of items to return.
            offset: Number of items to skip before returning results.

        Returns:
            List of items matching the search criteria.

        ???+ example "Examples"
            Basic filtering:
            ```python
            # Search for documents with specific metadata
            results = store.search(
                ("docs",),
                filter={"type": "article", "status": "published"}
            )
            ```

            Natural language search (requires vector store implementation):
            ```python
            # Initialize store with embedding configuration
            store = YourStore( # e.g., InMemoryStore, AsyncPostgresStore
                index={
                    "dims": 1536,  # embedding dimensions
                    "embed": your_embedding_function,  # function to create embeddings
                    "fields": ["text"]  # fields to embed. Defaults to ["$"]
                }
            )

            # Search for semantically similar documents
            results = store.search(
                ("docs",),
                query="machine learning applications in healthcare",
                filter={"type": "research_paper"},
                limit=5
            )
            ```

            Note: Natural language search support depends on your store implementation
            and requires proper embedding configuration.
        """
        # SearchOp is built positionally: (namespace_prefix, filter, limit, offset, query).
        return self.batch([SearchOp(namespace_prefix, filter, limit, offset, query)])[0]

    def put(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Optional[Union[Literal[False], list[str]]] = None,
    ) -> None:
        """Store or update an item in the store.

        Args:
            namespace: Hierarchical path for the item, represented as a tuple of strings.
                Example: ("documents", "user123")
            key: Unique identifier within the namespace. Together with namespace forms
                the complete path to the item.
            value: Dictionary containing the item's data. Must contain string keys
                and JSON-serializable values.
            index: Controls how the item's fields are indexed for search:

                - None (default): Use `fields` you configured when creating the store (if any)
                    If you do not initialize the store with indexing capabilities,
                    the `index` parameter will be ignored
                - False: Disable indexing for this item
                - list[str]: List of field paths to index, supporting:
                    - Nested fields: "metadata.title"
                    - Array access: "chapters[*].content" (each indexed separately)
                    - Specific indices: "authors[0].name"

        Note:
            Indexing support depends on your store implementation.
            If you do not initialize the store with indexing capabilities,
            the `index` parameter will be ignored.

        ???+ example "Examples"
            Store item. Indexing depends on how you configure the store.
            ```python
            store.put(("docs",), "report", {"memory": "Will likes ai"})
            ```

            Do not index item for semantic search. Still accessible through get()
            and search() operations but won't have a vector representation.
            ```python
            store.put(("docs",), "report", {"memory": "Will likes ai"}, index=False)
            ```

            Index specific fields for search.
            ```python
            store.put(("docs",), "report", {"memory": "Will likes ai"}, index=["memory"])
            ```
        """
        # Only the public put()/aput() entry points validate the namespace;
        # PutOps handed straight to batch() are not checked.
        _validate_namespace(namespace)
        self.batch([PutOp(namespace, key, value, index=index)])

    def delete(self, namespace: tuple[str, ...], key: str) -> None:
        """Delete an item.

        Args:
            namespace: Hierarchical path for the item.
            key: Unique identifier within the namespace.
        """
        # A PutOp whose value is None is the deletion marker (see PutOp.value).
        # Unlike put(), the namespace is not validated here.
        self.batch([PutOp(namespace, key, None)])

    def list_namespaces(
        self,
        *,
        prefix: Optional[NamespacePath] = None,
        suffix: Optional[NamespacePath] = None,
        max_depth: Optional[int] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """List and filter namespaces in the store.
Used to explore the organization of data,
        find specific collections, or navigate the namespace hierarchy.

        Args:
            prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path.
            suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path.
            max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy.
                Namespaces deeper than this level will be truncated.
            limit (int): Maximum number of namespaces to return (default 100).
            offset (int): Number of namespaces to skip for pagination (default 0).

        Returns:
            List[Tuple[str, ...]]: A list of namespace tuples that match the criteria.
            Each tuple represents a full namespace path up to `max_depth`.

        ???+ example "Examples"
            Setting max_depth=3. Given the namespaces:
            ```python
            # Example if you have the following namespaces:
            # ("a", "b", "c")
            # ("a", "b", "d", "e")
            # ("a", "b", "d", "i")
            # ("a", "b", "f")
            # ("a", "c", "f")
            store.list_namespaces(prefix=("a", "b"), max_depth=3)
            # [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")]
            ```
        """
        # A falsy prefix/suffix (None or an empty tuple) contributes no condition.
        match_conditions = []
        if prefix:
            match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
        if suffix:
            match_conditions.append(MatchCondition(match_type="suffix", path=suffix))

        op = ListNamespacesOp(
            match_conditions=tuple(match_conditions),
            max_depth=max_depth,
            limit=limit,
            offset=offset,
        )
        return self.batch([op])[0]

    async def aget(self, namespace: tuple[str, ...], key: str) -> Optional[Item]:
        """Asynchronously retrieve a single item.

        Args:
            namespace: Hierarchical path for the item.
            key: Unique identifier within the namespace.

        Returns:
            The retrieved item or None if not found.
        """
        return (await self.abatch([GetOp(namespace, key)]))[0]

    async def asearch(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: Optional[str] = None,
        filter: Optional[dict[str, Any]] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> list[SearchItem]:
        """Asynchronously search for items within a namespace prefix.

        Args:
            namespace_prefix: Hierarchical path prefix to search within.
            query: Optional query for natural language search.
            filter: Key-value pairs to filter results.
            limit: Maximum number of items to return.
            offset: Number of items to skip before returning results.

        Returns:
            List of items matching the search criteria.

        ???+ example "Examples"
            Basic filtering:
            ```python
            # Search for documents with specific metadata
            results = await store.asearch(
                ("docs",),
                filter={"type": "article", "status": "published"}
            )
            ```

            Natural language search (requires vector store implementation):
            ```python
            # Initialize store with embedding configuration
            store = YourStore( # e.g., InMemoryStore, AsyncPostgresStore
                index={
                    "dims": 1536,  # embedding dimensions
                    "embed": your_embedding_function,  # function to create embeddings
                    "fields": ["text"]  # fields to embed
                }
            )

            # Search for semantically similar documents
            results = await store.asearch(
                ("docs",),
                query="machine learning applications in healthcare",
                filter={"type": "research_paper"},
                limit=5
            )
            ```

            Note: Natural language search support depends on your store implementation
            and requires proper embedding configuration.
        """
        # SearchOp is built positionally: (namespace_prefix, filter, limit, offset, query).
        return (
            await self.abatch(
                [SearchOp(namespace_prefix, filter, limit, offset, query)]
            )
        )[0]

    async def aput(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Optional[Union[Literal[False], list[str]]] = None,
    ) -> None:
        """Asynchronously store or update an item in the store.

        Args:
            namespace: Hierarchical path for the item, represented as a tuple of strings.
                Example: ("documents", "user123")
            key: Unique identifier within the namespace. Together with namespace forms
                the complete path to the item.
            value: Dictionary containing the item's data. Must contain string keys
                and JSON-serializable values.
            index: Controls how the item's fields are indexed for search:

                - None (default): Use `fields` you configured when creating the store (if any)
                    If you do not initialize the store with indexing capabilities,
                    the `index` parameter will be ignored
                - False: Disable indexing for this item
                - list[str]: List of field paths to index, supporting:
                    - Nested fields: "metadata.title"
                    - Array access: "chapters[*].content" (each indexed separately)
                    - Specific indices: "authors[0].name"

        Note:
            Indexing support depends on your store implementation.
            If you do not initialize the store with indexing capabilities,
            the `index` parameter will be ignored.

        ???+ example "Examples"
            Store item. Indexing depends on how you configure the store.
            ```python
            await store.aput(("docs",), "report", {"memory": "Will likes ai"})
            ```

            Do not index item for semantic search. Still accessible through get()
            and search() operations but won't have a vector representation.
            ```python
            await store.aput(("docs",), "report", {"memory": "Will likes ai"}, index=False)
            ```

            Index specific fields for search (if store configured to index items):
            ```python
            await store.aput(
                ("docs",),
                "report",
                {
                    "memory": "Will likes ai",
                    "context": [{"content": "..."}, {"content": "..."}]
                },
                index=["memory", "context[*].content"]
            )
            ```
        """
        # Only the public put()/aput() entry points validate the namespace;
        # PutOps handed straight to abatch() are not checked.
        _validate_namespace(namespace)
        await self.abatch([PutOp(namespace, key, value, index=index)])

    async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
        """Asynchronously delete an item.

        Args:
            namespace: Hierarchical path for the item.
            key: Unique identifier within the namespace.
        """
        # A PutOp whose value is None is the deletion marker (see PutOp.value).
        await self.abatch([PutOp(namespace, key, None)])

    async def alist_namespaces(
        self,
        *,
        prefix: Optional[NamespacePath] = None,
        suffix: Optional[NamespacePath] = None,
        max_depth: Optional[int] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """List and filter namespaces in the store asynchronously.
Used to explore the organization of data, find specific collections, or navigate the namespace hierarchy. Args: prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path. suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path. max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy. Namespaces deeper than this level will be truncated to this depth. limit (int): Maximum number of namespaces to return (default 100). offset (int): Number of namespaces to skip for pagination (default 0). Returns: List[Tuple[str, ...]]: A list of namespace tuples that match the criteria. Each tuple represents a full namespace path up to `max_depth`. ???+ example "Examples" Setting max_depth=3 with existing namespaces: ```python # Given the following namespaces: # ("a", "b", "c") # ("a", "b", "d", "e") # ("a", "b", "d", "i") # ("a", "b", "f") # ("a", "c", "f") await store.alist_namespaces(prefix=("a", "b"), max_depth=3) # Returns: [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")] ``` """ match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) if suffix: match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) op = ListNamespacesOp( match_conditions=tuple(match_conditions), max_depth=max_depth, limit=limit, offset=offset, ) return (await self.abatch([op]))[0] def _validate_namespace(namespace: tuple[str, ...]) -> None: if not namespace: raise InvalidNamespaceError("Namespace cannot be empty.") for label in namespace: if not isinstance(label, str): raise InvalidNamespaceError( f"Invalid namespace label '{label}' found in {namespace}. Namespace labels" f" must be strings, but got {type(label).__name__}." ) if "." in label: raise InvalidNamespaceError( f"Invalid namespace label '{label}' found in {namespace}. Namespace labels cannot contain periods ('.')." ) elif not label: raise InvalidNamespaceError( f"Namespace labels cannot be empty strings. 
Got {label} in {namespace}" ) if namespace[0] == "langgraph": raise InvalidNamespaceError( f'Root label for namespace cannot be "langgraph". Got: {namespace}' ) __all__ = [ "BaseStore", "Item", "Op", "PutOp", "GetOp", "SearchOp", "ListNamespacesOp", "MatchCondition", "NamespacePath", "NamespaceMatchType", "Embeddings", "ensure_embeddings", "tokenize_path", "get_text_at_path", ] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-checkpoint" version = "2.0.8" description = "Library with base interfaces for LangGraph checkpoint savers." authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph" }] [tool.poetry.dependencies] python = "^3.9.0,<4.0" langchain-core = ">=0.2.38,<0.4" msgpack = "^1.1.0" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watcher = "^0.4.1" mypy = "^1.10.0" dataclasses-json = "^0.6.7" [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. 
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.ruff]
lint.select = [
  "E",  # pycodestyle
  "F",  # Pyflakes
  "UP", # pyupgrade
  "B",  # flake8-bugbear
  "I",  # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]

[tool.pytest-watcher]
now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short"]
patterns = ["*.py"]

[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
```

`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_jsonplus.py`:

```py
import dataclasses
import pathlib
import re
import sys
import uuid
from collections import deque
from datetime import date, datetime, time, timezone
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address

import dataclasses_json
from pydantic import BaseModel, SecretStr
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import SecretStr as SecretStrV1
from zoneinfo import ZoneInfo

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.store.base import Item


class InnerPydantic(BaseModel):
    hello: str


class MyPydantic(BaseModel):
    foo: str
    bar: int
    inner: InnerPydantic


class InnerPydanticV1(BaseModelV1):
    hello: str


class MyPydanticV1(BaseModelV1):
    foo: str
    bar: int
    inner: InnerPydanticV1


@dataclasses.dataclass
class InnerDataclass:
    hello: str


@dataclasses.dataclass
class MyDataclass:
    foo: str
    bar: int
    inner: InnerDataclass

    def something(self) -> None:
        pass


# dataclass(slots=True) is only accepted on Python 3.10+, so older
# interpreters get a plain subclass under the same name instead.
if sys.version_info < (3, 10):

    class MyDataclassWSlots(MyDataclass):
        pass

else:

    @dataclasses.dataclass(slots=True)
    class MyDataclassWSlots:
        foo: str
        bar: int
        inner: InnerDataclass

        def something(self) -> None:
            pass


class MyEnum(Enum):
    FOO = "foo"
    BAR = "bar"


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class Person:
    name: str


def test_serde_jsonplus() -> None:
    """Round-trip a broad mix of stdlib, pydantic, dataclass and store values."""
    uid = uuid.UUID(int=1)
    deque_instance = deque([1, 2, 3])
    tzn = ZoneInfo("America/New_York")
    ip4 = IPv4Address("192.168.0.1")
    current_date = date(2024, 4, 19)
    current_time = time(23, 4, 57, 51022, timezone.max)
    current_timestamp = datetime(2024, 4, 19, 23, 4, 57, 51022, timezone.max)

    to_serialize = {
        "path": pathlib.Path("foo", "bar"),
        "re": re.compile(r"foo", re.DOTALL),
        "decimal": Decimal("1.10101"),
        "set": {1, 2, frozenset({1, 2})},
        "frozen_set": frozenset({1, 2, 3}),
        "ip4": ip4,
        "deque": deque_instance,
        "tzn": tzn,
        "date": current_date,
        "time": current_time,
        "uid": uid,
        "timestamp": current_timestamp,
        "my_slotted_class": MyDataclassWSlots("bar", 2, InnerDataclass("hello")),
        "my_dataclass": MyDataclass("foo", 1, InnerDataclass("hello")),
        "my_enum": MyEnum.FOO,
        "my_pydantic": MyPydantic(foo="foo", bar=1, inner=InnerPydantic(hello="hello")),
        "my_pydantic_v1": MyPydanticV1(
            foo="foo", bar=1, inner=InnerPydanticV1(hello="hello")
        ),
        "my_secret_str": SecretStr("meow"),
        "my_secret_str_v1": SecretStrV1("meow"),
        "person": Person(name="foo"),
        "a_bool": True,
        "a_none": None,
        "a_str": "foo",
        "a_str_nuc": "foo\u0000",
        "a_str_uc": "foo ⛰️",
        "a_str_ucuc": "foo \u26f0\ufe0f\u0000",
        "a_str_ucucuc": "foo \\u26f0\\ufe0f",
        "an_int": 1,
        "a_float": 1.1,
        "a_bytes": b"my bytes",
        "a_bytearray": bytearray([42]),
        "my_item": Item(
            value={},
            key="my-key",
            namespace=("a", "name", " "),
            created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
            updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
        ),
    }

    serde = JsonPlusSerializer()

    dumped = serde.dumps_typed(to_serialize)

    # The whole mapping is encoded with msgpack and must round-trip unchanged.
    assert dumped[0] == "msgpack"
    assert serde.loads_typed(dumped) == to_serialize

    # Each value must also round-trip on its own, not just inside the mapping.
    for value in to_serialize.values():
        assert serde.loads_typed(serde.dumps_typed(value)) == value

    surrogates = [
        "Hello\ud83d\ude00",
        "Python\ud83d\udc0d",
        "Surrogate\ud834\udd1e",
        "Example\ud83c\udf89",
"String\ud83c\udfa7",
        "With\ud83c\udf08",
        "Surrogates\ud83d\ude0e",
        "Embedded\ud83d\udcbb",
        "In\ud83c\udf0e",
        "The\ud83d\udcd6",
        "Text\ud83d\udcac",
        "收花🙄·到",
    ]

    # Surrogate code points cannot be encoded as UTF-8; the round trip is
    # expected to drop them, mirroring str.encode("utf-8", "ignore").
    assert serde.loads_typed(serde.dumps_typed(surrogates)) == [
        v.encode("utf-8", "ignore").decode() for v in surrogates
    ]


def test_serde_jsonplus_bytes() -> None:
    """bytes values are tagged "bytes" and stored unchanged."""
    serde = JsonPlusSerializer()

    some_bytes = b"my bytes"
    dumped = serde.dumps_typed(some_bytes)

    assert dumped == ("bytes", some_bytes)
    assert serde.loads_typed(dumped) == some_bytes


def test_serde_jsonplus_bytearray() -> None:
    """bytearray values are tagged "bytearray" and stored unchanged."""
    serde = JsonPlusSerializer()

    some_bytearray = bytearray([42])
    dumped = serde.dumps_typed(some_bytearray)

    assert dumped == ("bytearray", some_bytearray)
    assert serde.loads_typed(dumped) == some_bytearray


def test_loads_cannot_find() -> None:
    """loads_typed returns None when the referenced class or module is missing."""
    serde = JsonPlusSerializer()

    # Class name does not exist in the (existing) module.
    dumped = (
        "json",
        b'{"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydanticccc"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}',
    )

    assert serde.loads_typed(dumped) is None, "Should return None if cannot find class"

    # Module name does not exist.
    dumped = (
        "json",
        b'{"lc": 2, "type": "constructor", "id": ["tests", "test_jsonpluss", "MyPydantic"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}',
    )

    assert serde.loads_typed(dumped) is None, "Should return None if cannot find module"
```

`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_store.py`:

```py
# mypy: disable-error-code="operator"
import asyncio
import json
from datetime import datetime
from typing import Any, Iterable

import pytest
from pytest_mock import MockerFixture

from langgraph.store.base import (
    GetOp,
    InvalidNamespaceError,
    Item,
    Op,
    PutOp,
    Result,
    get_text_at_path,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.memory import InMemoryStore
from tests.embed_test_utils import CharacterEmbeddings


class MockAsyncBatchedStore(AsyncBatchedBaseStore):
    """AsyncBatchedBaseStore test double backed by an InMemoryStore."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # kwargs are forwarded unchanged to the wrapped InMemoryStore.
        self._store = InMemoryStore(**kwargs)

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        return self._store.batch(ops)

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        # Delegates to the synchronous InMemoryStore.batch; nothing is awaited.
        return self._store.batch(ops)


def test_get_text_at_path() -> None:
    """get_text_at_path over nested data, wildcards, edge values and bad paths."""
    nested_data = {
        "name": "test",
        "info": {
            "age": 25,
            "tags": ["a", "b", "c"],
            "metadata": {"created": "2024-01-01", "updated": "2024-01-02"},
        },
        "items": [
            {"id": 1, "value": "first", "tags": ["x", "y"]},
            {"id": 2, "value": "second", "tags": ["y", "z"]},
            {"id": 3, "value": "third", "tags": ["z", "w"]},
        ],
        "empty": None,
        "zeros": [0, 0.0, "0"],
        "empty_list": [],
        "empty_dict": {},
    }

    # "$" serializes the whole document.
    assert get_text_at_path(nested_data, "$") == [
        json.dumps(nested_data, sort_keys=True)
    ]

    # Plain and nested field access, plus positive/negative list indices.
    assert get_text_at_path(nested_data, "name") == ["test"]
    assert get_text_at_path(nested_data, "info.age") == ["25"]

    assert get_text_at_path(nested_data, "info.metadata.created") == ["2024-01-01"]

    assert get_text_at_path(nested_data, "items[0].value") == ["first"]
    assert get_text_at_path(nested_data, "items[-1].value") == ["third"]
    assert get_text_at_path(nested_data, "items[1].tags[0]") == ["y"]

    # Wildcards and multi-field selection.
    values = get_text_at_path(nested_data, "items[*].value")
    assert set(values) == {"first", "second", "third"}

    metadata_dates = get_text_at_path(nested_data, "info.metadata.*")
    assert set(metadata_dates) == {"2024-01-01", "2024-01-02"}
    name_and_age = get_text_at_path(nested_data, "{name,info.age}")
    assert set(name_and_age) == {"test", "25"}

    item_fields = get_text_at_path(nested_data, "items[*].{id,value}")
    assert set(item_fields) == {"1", "2", "3", "first", "second", "third"}

    all_tags = get_text_at_path(nested_data, "items[*].tags[*]")
    assert set(all_tags) == {"x", "y", "z", "w"}

    # Missing data and empty paths.
    assert get_text_at_path(None, "any.path") == []
    assert get_text_at_path({}, "any.path") == []
    assert get_text_at_path(nested_data, "") == [
        json.dumps(nested_data, sort_keys=True)
    ]
    assert get_text_at_path(nested_data, "nonexistent") == []
    assert get_text_at_path(nested_data, "items[99].value") == []
    assert get_text_at_path(nested_data, "items[*].nonexistent") == []

    # None yields nothing, while empty containers and falsy scalars are kept.
    assert get_text_at_path(nested_data, "empty") == []
    assert get_text_at_path(nested_data, "empty_list") == ["[]"]
    assert get_text_at_path(nested_data, "empty_dict") == ["{}"]

    zeros = get_text_at_path(nested_data, "zeros[*]")
    assert set(zeros) == {"0", "0.0"}

    # Malformed paths return an empty list instead of raising.
    assert get_text_at_path(nested_data, "items[].value") == []
    assert get_text_at_path(nested_data, "items[abc].value") == []
    assert get_text_at_path(nested_data, "{unclosed") == []
    assert get_text_at_path(nested_data, "nested[{invalid}]") == []


async def test_async_batch_store(mocker: MockerFixture) -> None:
    """Concurrent aget() calls on an AsyncBatchedBaseStore reach abatch once."""
    abatch = mocker.stub()

    class MockStore(AsyncBatchedBaseStore):
        def batch(self, ops: Iterable[Op]) -> list[Result]:
            raise NotImplementedError

        async def abatch(self, ops: Iterable[Op]) -> list[Result]:
            assert all(isinstance(op, GetOp) for op in ops)
            # Record each batch so the test can count calls and inspect the ops.
            abatch(ops)
            return [
                Item(
                    value={},
                    key=getattr(op, "key", ""),
                    namespace=getattr(op, "namespace", ()),
                    created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
                    updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
                )
                for op in ops
            ]

    store = MockStore()

    # concurrent calls are batched
    results = await asyncio.gather(
        store.aget(namespace=("a",), key="b"),
        store.aget(namespace=("c",), key="d"),
    )
    assert results == [
        Item(
            value={},
            key="b",
            namespace=("a",),
            created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
            updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
        ),
        Item(
            value={},
            key="d",
            namespace=("c",),
            created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
            updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
        ),
    ]
    assert abatch.call_count == 1
    assert [tuple(c.args[0]) for c in abatch.call_args_list] == [
        (
            GetOp(("a",), "b"),
            GetOp(("c",), "d"),
        ),
    ]


def test_list_namespaces_basic() -> None:
    """Prefix/suffix filters, max_depth, limit and offset on list_namespaces."""
    store = InMemoryStore()

    namespaces = [
        ("a", "b", "c"),
        ("a", "b", "d", "e"),
        ("a", "b", "d", "i"),
        ("a", "b", "f"),
        ("a", "c", "f"),
        ("b", "a", "f"),
        ("users", "123"),
        ("users", "456", "settings"),
        ("admin", "users",
"789"), ] for i, ns in enumerate(namespaces): store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) result = store.list_namespaces(prefix=("a", "b")) expected = [ ("a", "b", "c"), ("a", "b", "d", "e"), ("a", "b", "d", "i"), ("a", "b", "f"), ] assert sorted(result) == sorted(expected) result = store.list_namespaces(suffix=("f",)) expected = [ ("a", "b", "f"), ("a", "c", "f"), ("b", "a", "f"), ] assert sorted(result) == sorted(expected) result = store.list_namespaces(prefix=("a",), suffix=("f",)) expected = [ ("a", "b", "f"), ("a", "c", "f"), ] assert sorted(result) == sorted(expected) # Test max_depth result = store.list_namespaces(prefix=("a", "b"), max_depth=3) expected = [ ("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f"), ] assert sorted(result) == sorted(expected) # Test limit and offset result = store.list_namespaces(prefix=("a", "b"), limit=2) expected = [ ("a", "b", "c"), ("a", "b", "d", "e"), ] assert result == expected result = store.list_namespaces(prefix=("a", "b"), offset=2) expected = [ ("a", "b", "d", "i"), ("a", "b", "f"), ] assert result == expected result = store.list_namespaces(prefix=("a", "*", "f")) expected = [ ("a", "b", "f"), ("a", "c", "f"), ] assert sorted(result) == sorted(expected) result = store.list_namespaces(suffix=("*", "f")) expected = [ ("a", "b", "f"), ("a", "c", "f"), ("b", "a", "f"), ] assert sorted(result) == sorted(expected) result = store.list_namespaces(prefix=("nonexistent",)) assert result == [] result = store.list_namespaces(prefix=("users", "123")) expected = [("users", "123")] assert result == expected def test_list_namespaces_with_wildcards() -> None: store = InMemoryStore() namespaces = [ ("users", "123"), ("users", "456"), ("users", "789", "settings"), ("admin", "users", "789"), ("guests", "123"), ("guests", "456", "preferences"), ] for i, ns in enumerate(namespaces): store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) result = store.list_namespaces(prefix=("users", "*")) 
expected = [ ("users", "123"), ("users", "456"), ("users", "789", "settings"), ] assert sorted(result) == sorted(expected) result = store.list_namespaces(suffix=("*", "preferences")) expected = [ ("guests", "456", "preferences"), ] assert result == expected result = store.list_namespaces(prefix=("*", "users"), suffix=("*", "settings")) assert result == [] store.put( namespace=("admin", "users", "settings", "789"), key="foo", value={"data": "some_val"}, ) expected = [ ("admin", "users", "settings", "789"), ] def test_list_namespaces_pagination() -> None: store = InMemoryStore() for i in range(20): ns = ("namespace", f"sub_{i:02d}") store.put(namespace=ns, key=f"id_{i:02d}", value={"data": f"value_{i:02d}"}) result = store.list_namespaces(prefix=("namespace",), limit=5, offset=0) expected = [("namespace", f"sub_{i:02d}") for i in range(5)] assert result == expected result = store.list_namespaces(prefix=("namespace",), limit=5, offset=5) expected = [("namespace", f"sub_{i:02d}") for i in range(5, 10)] assert result == expected result = store.list_namespaces(prefix=("namespace",), limit=5, offset=15) expected = [("namespace", f"sub_{i:02d}") for i in range(15, 20)] assert result == expected def test_list_namespaces_max_depth() -> None: store = InMemoryStore() namespaces = [ ("a", "b", "c", "d"), ("a", "b", "c", "e"), ("a", "b", "f"), ("a", "g"), ("h", "i", "j", "k"), ] for i, ns in enumerate(namespaces): store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) result = store.list_namespaces(max_depth=2) expected = [ ("a", "b"), ("a", "g"), ("h", "i"), ] assert sorted(result) == sorted(expected) def test_list_namespaces_no_conditions() -> None: store = InMemoryStore() namespaces = [ ("a", "b"), ("c", "d"), ("e", "f", "g"), ] for i, ns in enumerate(namespaces): store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"}) result = store.list_namespaces() expected = namespaces assert sorted(result) == sorted(expected) def 
test_list_namespaces_empty_store() -> None: store = InMemoryStore() result = store.list_namespaces() assert result == [] async def test_cannot_put_empty_namespace() -> None: store = InMemoryStore() doc = {"foo": "bar"} with pytest.raises(InvalidNamespaceError): store.put((), "foo", doc) with pytest.raises(InvalidNamespaceError): await store.aput((), "foo", doc) with pytest.raises(InvalidNamespaceError): store.put(("the", "thing.about"), "foo", doc) with pytest.raises(InvalidNamespaceError): await store.aput(("the", "thing.about"), "foo", doc) with pytest.raises(InvalidNamespaceError): store.put(("some", "fun", ""), "foo", doc) with pytest.raises(InvalidNamespaceError): await store.aput(("some", "fun", ""), "foo", doc) with pytest.raises(InvalidNamespaceError): await store.aput(("langgraph", "foo"), "bar", doc) with pytest.raises(InvalidNamespaceError): store.put(("langgraph", "foo"), "bar", doc) await store.aput(("foo", "langgraph", "foo"), "bar", doc) assert (await store.aget(("foo", "langgraph", "foo"), "bar")).value == doc # type: ignore[union-attr] assert (await store.asearch(("foo", "langgraph", "foo"), query="bar"))[ 0 ].value == doc await store.adelete(("foo", "langgraph", "foo"), "bar") assert (await store.aget(("foo", "langgraph", "foo"), "bar")) is None store.put(("foo", "langgraph", "foo"), "bar", doc) assert store.get(("foo", "langgraph", "foo"), "bar").value == doc # type: ignore[union-attr] assert store.search(("foo", "langgraph", "foo"), query="bar")[0].value == doc store.delete(("foo", "langgraph", "foo"), "bar") assert store.get(("foo", "langgraph", "foo"), "bar") is None # Do the same but go past the public put api await store.abatch([PutOp(("langgraph", "foo"), "bar", doc)]) assert (await store.aget(("langgraph", "foo"), "bar")).value == doc # type: ignore[union-attr] assert (await store.asearch(("langgraph", "foo")))[0].value == doc await store.adelete(("langgraph", "foo"), "bar") assert (await store.aget(("langgraph", "foo"), "bar")) is None 
store.batch([PutOp(("langgraph", "foo"), "bar", doc)]) assert store.get(("langgraph", "foo"), "bar").value == doc # type: ignore[union-attr] assert store.search(("langgraph", "foo"))[0].value == doc store.delete(("langgraph", "foo"), "bar") assert store.get(("langgraph", "foo"), "bar") is None async_store = MockAsyncBatchedStore() doc = {"foo": "bar"} with pytest.raises(InvalidNamespaceError): await async_store.aput((), "foo", doc) with pytest.raises(InvalidNamespaceError): await async_store.aput(("the", "thing.about"), "foo", doc) with pytest.raises(InvalidNamespaceError): await async_store.aput(("some", "fun", ""), "foo", doc) with pytest.raises(InvalidNamespaceError): await async_store.aput(("langgraph", "foo"), "bar", doc) await async_store.aput(("foo", "langgraph", "foo"), "bar", doc) val = await async_store.aget(("foo", "langgraph", "foo"), "bar") assert val is not None assert val.value == doc assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc assert (await async_store.asearch(("foo", "langgraph", "foo"), query="bar"))[ 0 ].value == doc await async_store.adelete(("foo", "langgraph", "foo"), "bar") assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)]) val = await async_store.aget(("valid", "namespace"), "key") assert val is not None assert val.value == doc assert (await async_store.asearch(("valid", "namespace")))[0].value == doc await async_store.adelete(("valid", "namespace"), "key") assert (await async_store.aget(("valid", "namespace"), "key")) is None async def test_async_batch_store_deduplication(mocker: MockerFixture) -> None: abatch = mocker.spy(InMemoryStore, "batch") store = MockAsyncBatchedStore() same_doc = {"value": "same"} diff_doc = {"value": "different"} await asyncio.gather( store.aput(namespace=("test",), key="same", value=same_doc), store.aput(namespace=("test",), key="different", value=diff_doc), ) abatch.reset_mock() 
results = await asyncio.gather( store.aget(namespace=("test",), key="same"), store.aget(namespace=("test",), key="same"), store.aget(namespace=("test",), key="different"), ) assert len(results) == 3 assert results[0] == results[1] assert results[0] != results[2] assert results[0].value == same_doc # type: ignore assert results[2].value == diff_doc # type: ignore assert len(abatch.call_args_list) == 1 ops = list(abatch.call_args_list[0].args[1]) assert len(ops) == 2 assert GetOp(("test",), "same") in ops assert GetOp(("test",), "different") in ops abatch.reset_mock() doc1 = {"value": 1} doc2 = {"value": 2} results = await asyncio.gather( store.aput(namespace=("test",), key="key", value=doc1), store.aput(namespace=("test",), key="key", value=doc2), ) assert len(abatch.call_args_list) == 1 ops = list(abatch.call_args_list[0].args[1]) assert len(ops) == 1 assert ops[0] == PutOp(("test",), "key", doc2) assert len(results) == 2 assert all(result is None for result in results) result = await store.aget(namespace=("test",), key="key") assert result is not None assert result.value == doc2 abatch.reset_mock() results = await asyncio.gather( store.asearch(("test",), filter={"value": 2}), store.asearch(("test",), filter={"value": 2}), ) assert len(abatch.call_args_list) == 1 ops = list(abatch.call_args_list[0].args[1]) assert len(ops) == 1 assert len(results) == 2 assert results[0] == results[1] assert len(results[0]) == 1 assert results[0][0].value == doc2 abatch.reset_mock() @pytest.fixture def fake_embeddings() -> CharacterEmbeddings: return CharacterEmbeddings(dims=500) def test_vector_store_initialization(fake_embeddings: CharacterEmbeddings) -> None: """Test store initialization with embedding config.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) assert store.index_config is not None assert store.index_config["dims"] == fake_embeddings.dims assert store.index_config["embed"] == fake_embeddings def 
test_vector_insert_with_auto_embedding( fake_embeddings: CharacterEmbeddings, ) -> None: """Test inserting items that get auto-embedded.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) docs = [ ("doc1", {"text": "short text"}), ("doc2", {"text": "longer text document"}), ("doc3", {"text": "longest text document here"}), ("doc4", {"description": "text in description field"}), ("doc5", {"content": "text in content field"}), ("doc6", {"body": "text in body field"}), ] for key, value in docs: store.put(("test",), key, value) results = store.search(("test",), query="long text") assert len(results) > 0 doc_order = [r.key for r in results] assert "doc2" in doc_order assert "doc3" in doc_order async def test_async_vector_insert_with_auto_embedding( fake_embeddings: CharacterEmbeddings, ) -> None: """Test inserting items that get auto-embedded using async methods.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) docs = [ ("doc1", {"text": "short text"}), ("doc2", {"text": "longer text document"}), ("doc3", {"text": "longest text document here"}), ("doc4", {"description": "text in description field"}), ("doc5", {"content": "text in content field"}), ("doc6", {"body": "text in body field"}), ] for key, value in docs: await store.aput(("test",), key, value) results = await store.asearch(("test",), query="long text") assert len(results) > 0 doc_order = [r.key for r in results] assert "doc2" in doc_order assert "doc3" in doc_order def test_vector_update_with_embedding(fake_embeddings: CharacterEmbeddings) -> None: """Test that updating items properly updates their embeddings.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) store.put(("test",), "doc1", {"text": "zany zebra Xerxes"}) store.put(("test",), "doc2", {"text": "something about dogs"}) store.put(("test",), "doc3", {"text": "text about birds"}) results_initial = store.search(("test",), 
query="Zany Xerxes") assert len(results_initial) > 0 assert results_initial[0].key == "doc1" initial_score = results_initial[0].score assert initial_score is not None store.put(("test",), "doc1", {"text": "new text about dogs"}) results_after = store.search(("test",), query="Zany Xerxes") after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0) assert after_score is not None assert after_score < initial_score results_new = store.search(("test",), query="new text about dogs") for r in results_new: if r.key == "doc1": assert r.score > after_score # Don't index this one store.put(("test",), "doc4", {"text": "new text about dogs"}, index=False) results_new = store.search(("test",), query="new text about dogs", limit=3) assert not any(r.key == "doc4" for r in results_new) async def test_async_vector_update_with_embedding( fake_embeddings: CharacterEmbeddings, ) -> None: """Test that updating items properly updates their embeddings using async methods.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) await store.aput(("test",), "doc1", {"text": "zany zebra Xerxes"}) await store.aput(("test",), "doc2", {"text": "something about dogs"}) await store.aput(("test",), "doc3", {"text": "text about birds"}) results_initial = await store.asearch(("test",), query="Zany Xerxes") assert len(results_initial) > 0 assert results_initial[0].key == "doc1" initial_score = results_initial[0].score await store.aput(("test",), "doc1", {"text": "new text about dogs"}) results_after = await store.asearch(("test",), query="Zany Xerxes") after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0) assert after_score is not None assert after_score < initial_score results_new = await store.asearch(("test",), query="new text about dogs") for r in results_new: if r.key == "doc1": assert r.score is not None assert r.score > after_score # Don't index this one await store.aput(("test",), "doc4", {"text": "new text about 
dogs"}, index=False) results_new = await store.asearch(("test",), query="new text about dogs", limit=3) assert not any(r.key == "doc4" for r in results_new) def test_vector_search_with_filters(fake_embeddings: CharacterEmbeddings) -> None: """Test combining vector search with filters.""" inmem_store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) # Insert test documents docs = [ ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), ("doc2", {"text": "red car", "color": "red", "score": 3.0}), ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), ("doc4", {"text": "blue car", "color": "blue", "score": 3.5}), ] for key, value in docs: inmem_store.put(("test",), key, value) results = inmem_store.search(("test",), query="apple", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc1" results = inmem_store.search(("test",), query="car", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc2" results = inmem_store.search( ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} ) assert len(results) == 3 assert results[0].key == "doc4" # Multiple filters results = inmem_store.search( ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} ) assert len(results) == 1 assert results[0].key == "doc3" async def test_async_vector_search_with_filters( fake_embeddings: CharacterEmbeddings, ) -> None: """Test combining vector search with filters using async methods.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) # Insert test documents docs = [ ("doc1", {"text": "red apple", "color": "red", "score": 4.5}), ("doc2", {"text": "red car", "color": "red", "score": 3.0}), ("doc3", {"text": "green apple", "color": "green", "score": 4.0}), ("doc4", {"text": "blue car", "color": "blue", "score": 3.5}), ] for key, value in docs: await store.aput(("test",), key, value) results = await store.asearch(("test",), 
query="apple", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc1" results = await store.asearch(("test",), query="car", filter={"color": "red"}) assert len(results) == 2 assert results[0].key == "doc2" results = await store.asearch( ("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}} ) assert len(results) == 3 assert results[0].key == "doc4" # Multiple filters results = await store.asearch( ("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"} ) assert len(results) == 1 assert results[0].key == "doc3" async def test_async_batched_vector_search_concurrent( fake_embeddings: CharacterEmbeddings, ) -> None: """Test concurrent vector search operations using async batched store.""" store = MockAsyncBatchedStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) colors = ["red", "blue", "green", "yellow", "purple"] items = ["apple", "car", "house", "book", "phone"] scores = [3.0, 3.5, 4.0, 4.5, 5.0] docs = [] for i in range(50): color = colors[i % len(colors)] item = items[i % len(items)] score = scores[i % len(scores)] docs.append( ( f"doc{i}", {"text": f"{color} {item}", "color": color, "score": score, "index": i}, ) ) coros = [ *[store.aput(("test",), key, value) for key, value in docs], *[store.adelete(("test",), key) for key, value in docs], *[store.aput(("test",), key, value) for key, value in docs], ] await asyncio.gather(*coros) # Prepare multiple search queries with different filters search_queries: list[tuple[str, dict[str, Any]]] = [ ("apple", {"color": "red"}), ("car", {"color": "blue"}), ("house", {"color": "green"}), ("phone", {"score": {"$gt": 4.99}}), ("book", {"score": {"$lte": 3.5}}), ("apple", {"score": {"$gte": 3.0}, "color": "red"}), ("car", {"score": {"$lt": 5.1}, "color": "blue"}), ("house", {"index": {"$gt": 25}}), ("phone", {"index": {"$lte": 10}}), ] all_results = await asyncio.gather( *[ store.asearch(("test",), query=query, filter=filter_) for query, filter_ in 
search_queries ] ) for results, (query, filter_) in zip(all_results, search_queries): assert len(results) > 0, f"No results for query '{query}' with filter {filter_}" for result in results: if "color" in filter_: assert result.value["color"] == filter_["color"] if "score" in filter_: score = result.value["score"] for op, value in filter_["score"].items(): if op == "$gt": assert score > value elif op == "$gte": assert score >= value elif op == "$lt": assert score < value elif op == "$lte": assert score <= value if "index" in filter_: index = result.value["index"] for op, value in filter_["index"].items(): if op == "$gt": assert index > value elif op == "$gte": assert index >= value elif op == "$lt": assert index < value elif op == "$lte": assert index <= value def test_vector_search_pagination(fake_embeddings: CharacterEmbeddings) -> None: """Test pagination with vector search.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) for i in range(5): store.put(("test",), f"doc{i}", {"text": f"test document number {i}"}) results_page1 = store.search(("test",), query="test", limit=2) results_page2 = store.search(("test",), query="test", limit=2, offset=2) assert len(results_page1) == 2 assert len(results_page2) == 2 assert results_page1[0].key != results_page2[0].key all_results = store.search(("test",), query="test", limit=10) assert len(all_results) == 5 async def test_async_vector_search_pagination( fake_embeddings: CharacterEmbeddings, ) -> None: """Test pagination with vector search using async methods.""" store = InMemoryStore( index={"dims": fake_embeddings.dims, "embed": fake_embeddings} ) for i in range(5): await store.aput(("test",), f"doc{i}", {"text": f"test document number {i}"}) results_page1 = await store.asearch(("test",), query="test", limit=2) results_page2 = await store.asearch(("test",), query="test", limit=2, offset=2) assert len(results_page1) == 2 assert len(results_page2) == 2 assert results_page1[0].key != 
results_page2[0].key all_results = await store.asearch(("test",), query="test", limit=10) assert len(all_results) == 5 async def test_embed_with_path(fake_embeddings: CharacterEmbeddings) -> None: # Test store-level field configuration store = InMemoryStore( index={ "dims": fake_embeddings.dims, "embed": fake_embeddings, # Key 2 isn't included. Don't index it. "fields": ["key0", "key1", "key3"], } ) # This will have 2 vectors representing it doc1 = { # Omit key0 - check it doesn't raise an error "key1": "xxx", "key2": "yyy", "key3": "zzz", } # This will have 3 vectors representing it doc2 = { "key0": "uuu", "key1": "vvv", "key2": "www", "key3": "xxx", } await store.aput(("test",), "doc1", doc1) await store.aput(("test",), "doc2", doc2) # doc2.key3 and doc1.key1 both would have the highest score results = await store.asearch(("test",), query="xxx") assert len(results) == 2 assert results[0].key != results[1].key ascore = results[0].score bscore = results[1].score assert ascore == bscore assert ascore is not None and bscore is not None results = await store.asearch(("test",), query="uuu") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].key == "doc2" assert results[0].score is not None and results[0].score > results[1].score assert ascore == pytest.approx(results[0].score, abs=1e-5) # Un-indexed - will have low results for both. Not zero (because we're projecting) # but less than the above. 
results = await store.asearch(("test",), query="www") assert len(results) == 2 assert results[0].score < ascore assert results[1].score < ascore # Test operation-level field configuration store_no_defaults = InMemoryStore( index={ "dims": fake_embeddings.dims, "embed": fake_embeddings, "fields": ["key17"], } ) doc3 = { "key0": "aaa", "key1": "bbb", "key2": "ccc", "key3": "ddd", } doc4 = { "key0": "eee", "key1": "bbb", # Same as doc3.key1 "key2": "fff", "key3": "ggg", } await store_no_defaults.aput(("test",), "doc3", doc3, index=["key0", "key1"]) await store_no_defaults.aput(("test",), "doc4", doc4, index=["key1", "key3"]) results = await store_no_defaults.asearch(("test",), query="aaa") assert len(results) == 2 assert results[0].key == "doc3" assert results[0].score is not None and results[0].score > results[1].score results = await store_no_defaults.asearch(("test",), query="ggg") assert len(results) == 2 assert results[0].key == "doc4" assert results[0].score is not None and results[0].score > results[1].score results = await store_no_defaults.asearch(("test",), query="bbb") assert len(results) == 2 assert results[0].key != results[1].key assert results[0].score == results[1].score results = await store_no_defaults.asearch(("test",), query="ccc") assert len(results) == 2 assert all(r.score < ascore for r in results) doc5 = { "key0": "hhh", "key1": "iii", } await store_no_defaults.aput(("test",), "doc5", doc5, index=False) results = await store_no_defaults.asearch(("test",), query="hhh") assert len(results) == 3 doc5_result = next(r for r in results if r.key == "doc5") assert doc5_result.score is None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_memory.py`: ```py from typing import Any import pytest from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.memory import MemorySaver class TestMemorySaver: 
@pytest.fixture(autouse=True) def setup(self) -> None: self.memory_saver = MemorySaver() # objects for test setup self.config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", "checkpoint_ns": "", # for backwards compatibility testing "thread_ts": "1", } } self.config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_ns": "", "checkpoint_id": "2", } } self.config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } self.chkpnt_1: Checkpoint = empty_checkpoint() self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1) self.chkpnt_3: Checkpoint = empty_checkpoint() self.metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } self.metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } self.metadata_3: CheckpointMetadata = {} async def test_search(self) -> None: # set up test # save checkpoints self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {}) self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {}) self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = list(self.memory_saver.list(None, filter=query_1)) assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = list(self.memory_saver.list(None, filter=query_2)) assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = list(self.memory_saver.list(None, filter=query_3)) assert len(search_results_3) == 3 search_results_4 = list(self.memory_saver.list(None, 
filter=query_4)) assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = list( self.memory_saver.list({"configurable": {"thread_id": "thread-2"}}) ) assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} # TODO: test before and limit params async def test_asearch(self) -> None: # set up test # save checkpoints self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {}) self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {}) self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = [ c async for c in self.memory_saver.alist(None, filter=query_1) ] assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = [ c async for c in self.memory_saver.alist(None, filter=query_2) ] assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = [ c async for c in self.memory_saver.alist(None, filter=query_3) ] assert len(search_results_3) == 3 search_results_4 = [ c async for c in self.memory_saver.alist(None, filter=query_4) ] assert len(search_results_4) == 0 ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/embed_test_utils.py`: ```py """Embedding utilities for testing.""" import math import random from collections import Counter, defaultdict from typing import Any from langchain_core.embeddings import Embeddings class CharacterEmbeddings(Embeddings): """Simple character-frequency based embeddings using random projections.""" def 
__init__(self, dims: int = 50, seed: int = 42): """Initialize with embedding dimensions and random seed.""" self._rng = random.Random(seed) self.dims = dims # Create projection vector for each character lazily self._char_projections: defaultdict[str, list[float]] = defaultdict( lambda: [ self._rng.gauss(0, 1 / math.sqrt(self.dims)) for _ in range(self.dims) ] ) def _embed_one(self, text: str) -> list[float]: """Embed a single text.""" counts = Counter(text) total = sum(counts.values()) if total == 0: return [0.0] * self.dims embedding = [0.0] * self.dims for char, count in counts.items(): weight = count / total char_proj = self._char_projections[char] for i, proj in enumerate(char_proj): embedding[i] += weight * proj norm = math.sqrt(sum(x * x for x in embedding)) if norm > 0: embedding = [x / norm for x in embedding] return embedding def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed a list of documents.""" return [self._embed_one(text) for text in texts] def embed_query(self, text: str) -> list[float]: """Embed a query string.""" return self._embed_one(text) def __eq__(self, other: Any) -> bool: return isinstance(other, CharacterEmbeddings) and self.dims == other.dims ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/__init__.py`: ```py import threading from contextlib import contextmanager from typing import Any, Iterator, Optional, Sequence from langchain_core.runnables import RunnableConfig import duckdb from langgraph.checkpoint.base import ( WRITES_IDX_MAP, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, get_checkpoint_id, ) from langgraph.checkpoint.duckdb.base import BaseDuckDBSaver from langgraph.checkpoint.serde.base import SerializerProtocol class DuckDBSaver(BaseDuckDBSaver): lock: threading.Lock def __init__( self, conn: duckdb.DuckDBPyConnection, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) self.conn = conn self.lock = 
threading.Lock() @classmethod @contextmanager def from_conn_string(cls, conn_string: str) -> Iterator["DuckDBSaver"]: """Create a new DuckDBSaver instance from a connection string. Args: conn_string (str): The DuckDB connection info string. Returns: DuckDBSaver: A new DuckDBSaver instance. """ with duckdb.connect(conn_string) as conn: yield cls(conn) def setup(self) -> None: """Set up the checkpoint database asynchronously. This method creates the necessary tables in the DuckDB database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time checkpointer is used. """ with self.lock, self.conn.cursor() as cur: try: row = cur.execute( "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" ).fetchone() if row is None: version = -1 else: version = row[0] except duckdb.CatalogException: version = -1 for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), self.MIGRATIONS[version + 1 :], ): cur.execute(migration) cur.execute("INSERT INTO checkpoint_migrations (v) VALUES (?)", [v]) def list( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. This method retrieves a list of checkpoint tuples from the DuckDB database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (RunnableConfig): The config to use for listing the checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. Yields: Iterator[CheckpointTuple]: An iterator of checkpoint tuples. 
Examples: >>> from langgraph.checkpoint.duckdb import DuckDBSaver >>> with DuckDBSaver.from_conn_string(":memory:") as memory: ... # Run a graph, then list the checkpoints >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoints = list(memory.list(config, limit=2)) >>> print(checkpoints) [CheckpointTuple(...), CheckpointTuple(...)] >>> config = {"configurable": {"thread_id": "1"}} >>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}} >>> with DuckDBSaver.from_conn_string(":memory:") as memory: ... # Run a graph, then list the checkpoints >>> checkpoints = list(memory.list(config, before=before)) >>> print(checkpoints) [CheckpointTuple(...), ...] """ where, args = self._search_where(config, filter, before) query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC" if limit: query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor with self._cursor() as cur: cur.execute(query, args) for value in cur.fetchall(): ( thread_id, checkpoint, checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, channel_values, pending_writes, pending_sends, ) = value yield CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, self._load_checkpoint( checkpoint, channel_values, pending_sends, ), self._load_metadata(metadata), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), self._load_writes(pending_writes), ) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the DuckDB database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. 
Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. Examples: Basic: >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) With timestamp: >>> config = { ... "configurable": { ... "thread_id": "1", ... "checkpoint_ns": "", ... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", ... } ... } >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) """ # noqa thread_id = config["configurable"]["thread_id"] checkpoint_id = get_checkpoint_id(config) checkpoint_ns = config["configurable"].get("checkpoint_ns", "") if checkpoint_id: args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id) where = "WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?" else: args = (thread_id, checkpoint_ns) where = "WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1" with self._cursor() as cur: cur.execute( self.SELECT_SQL + where, args, ) value = cur.fetchone() if value: ( thread_id, checkpoint, checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, channel_values, pending_writes, pending_sends, ) = value return CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, self._load_checkpoint( checkpoint, channel_values, pending_sends, ), self._load_metadata(metadata), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), self._load_writes(pending_writes), ) def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. This method saves a checkpoint to the DuckDB database. 
The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. Examples: >>> from langgraph.checkpoint.duckdb import DuckDBSaver >>> with DuckDBSaver.from_conn_string(":memory:") as memory: >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) >>> print(saved_config) {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} """ configurable = config["configurable"].copy() thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") checkpoint_id = configurable.pop( "checkpoint_id", configurable.pop("thread_ts", None) ) copy = checkpoint.copy() next_config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } checkpoint_blobs = self._dump_blobs( thread_id, checkpoint_ns, copy.pop("channel_values"), # type: ignore[misc] new_versions, ) with self._cursor() as cur: if checkpoint_blobs: cur.executemany(self.UPSERT_CHECKPOINT_BLOBS_SQL, checkpoint_blobs) cur.execute( self.UPSERT_CHECKPOINTS_SQL, ( thread_id, checkpoint_ns, checkpoint["id"], checkpoint_id, self._dump_checkpoint(copy), self._dump_metadata(metadata), ), ) return next_config def put_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a 
checkpoint. This method saves intermediate writes associated with a checkpoint to the DuckDB database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (List[Tuple[str, Any]]): List of writes to store. task_id (str): Identifier for the task creating the writes. """ query = ( self.UPSERT_CHECKPOINT_WRITES_SQL if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) with self._cursor() as cur: cur.executemany( query, self._dump_writes( config["configurable"]["thread_id"], config["configurable"]["checkpoint_ns"], config["configurable"]["checkpoint_id"], task_id, writes, ), ) @contextmanager def _cursor(self) -> Iterator[duckdb.DuckDBPyConnection]: with self.lock, self.conn.cursor() as cur: yield cur __all__ = ["DuckDBSaver", "Conn"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/aio.py`: ```py import asyncio from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Iterator, Optional, Sequence from langchain_core.runnables import RunnableConfig import duckdb from langgraph.checkpoint.base import ( WRITES_IDX_MAP, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, get_checkpoint_id, ) from langgraph.checkpoint.duckdb.base import BaseDuckDBSaver from langgraph.checkpoint.serde.base import SerializerProtocol class AsyncDuckDBSaver(BaseDuckDBSaver): lock: asyncio.Lock def __init__( self, conn: duckdb.DuckDBPyConnection, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) self.conn = conn self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str, ) -> AsyncIterator["AsyncDuckDBSaver"]: """Create a new AsyncDuckDBSaver instance from a connection string. Args: conn_string (str): The DuckDB connection info string. Returns: AsyncDuckDBSaver: A new AsyncDuckDBSaver instance. 
""" with duckdb.connect(conn_string) as conn: yield cls(conn) async def setup(self) -> None: """Set up the checkpoint database asynchronously. This method creates the necessary tables in the DuckDB database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time checkpointer is used. """ async with self.lock: with self.conn.cursor() as cur: try: await asyncio.to_thread( cur.execute, "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1", ) row = await asyncio.to_thread(cur.fetchone) if row is None: version = -1 else: version = row[0] except duckdb.CatalogException: version = -1 for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), self.MIGRATIONS[version + 1 :], ): await asyncio.to_thread(cur.execute, migration) await asyncio.to_thread( cur.execute, "INSERT INTO checkpoint_migrations (v) VALUES (?)", [v], ) async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """List checkpoints from the database asynchronously. This method retrieves a list of checkpoint tuples from the DuckDB database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. 
""" where, args = self._search_where(config, filter, before) query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC" if limit: query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor async with self._cursor() as cur: await asyncio.to_thread(cur.execute, query, args) results = await asyncio.to_thread(cur.fetchall) for value in results: ( thread_id, checkpoint, checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, channel_values, pending_writes, pending_sends, ) = value yield CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, await asyncio.to_thread( self._load_checkpoint, checkpoint, channel_values, pending_sends, ), self._load_metadata(metadata), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), await asyncio.to_thread(self._load_writes, pending_writes), ) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. This method retrieves a checkpoint tuple from the DuckDBdatabase based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ thread_id = config["configurable"]["thread_id"] checkpoint_id = get_checkpoint_id(config) checkpoint_ns = config["configurable"].get("checkpoint_ns", "") if checkpoint_id: args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id) where = "WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?" 
else: args = (thread_id, checkpoint_ns) where = "WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1" async with self._cursor() as cur: await asyncio.to_thread( cur.execute, self.SELECT_SQL + where, args, ) value = await asyncio.to_thread(cur.fetchone) if value: ( thread_id, checkpoint, checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, channel_values, pending_writes, pending_sends, ) = value return CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, await asyncio.to_thread( self._load_checkpoint, checkpoint, channel_values, pending_sends, ), self._load_metadata(metadata), ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), await asyncio.to_thread(self._load_writes, pending_writes), ) async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database asynchronously. This method saves a checkpoint to the DuckDB database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. 
""" configurable = config["configurable"].copy() thread_id = configurable.pop("thread_id") checkpoint_ns = configurable.pop("checkpoint_ns") checkpoint_id = configurable.pop( "checkpoint_id", configurable.pop("thread_ts", None) ) copy = checkpoint.copy() next_config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } checkpoint_blobs = await asyncio.to_thread( self._dump_blobs, thread_id, checkpoint_ns, copy.pop("channel_values"), # type: ignore[misc] new_versions, ) async with self._cursor() as cur: if checkpoint_blobs: await asyncio.to_thread( cur.executemany, self.UPSERT_CHECKPOINT_BLOBS_SQL, checkpoint_blobs ) await asyncio.to_thread( cur.execute, self.UPSERT_CHECKPOINTS_SQL, ( thread_id, checkpoint_ns, checkpoint["id"], checkpoint_id, self._dump_checkpoint(copy), self._dump_metadata(metadata), ), ) return next_config async def aput_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint asynchronously. This method saves intermediate writes associated with a checkpoint to the database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. 
""" query = ( self.UPSERT_CHECKPOINT_WRITES_SQL if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) params = await asyncio.to_thread( self._dump_writes, config["configurable"]["thread_id"], config["configurable"]["checkpoint_ns"], config["configurable"]["checkpoint_id"], task_id, writes, ) async with self._cursor() as cur: await asyncio.to_thread(cur.executemany, query, params) @asynccontextmanager async def _cursor(self) -> AsyncIterator[duckdb.DuckDBPyConnection]: async with self.lock: with self.conn.cursor() as cur: yield cur def list( self, config: Optional[RunnableConfig], *, filter: Optional[dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. This method retrieves a list of checkpoint tuples from the DuckDB database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. """ aiter_ = self.alist(config, filter=filter, before=before, limit=limit) while True: try: yield asyncio.run_coroutine_threadsafe( anext(aiter_), self.loop, ).result() except StopAsyncIteration: break def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the DuckDB database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. 
Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ try: # check if we are in the main thread, only bg threads can block # we don't check in other methods to avoid the overhead if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to AsyncDuckDBSaver are only allowed from a " "different thread. From the main thread, use the async interface." "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." ) except RuntimeError: pass return asyncio.run_coroutine_threadsafe( self.aget_tuple(config), self.loop ).result() def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. This method saves a checkpoint to the DuckDB database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. """ return asyncio.run_coroutine_threadsafe( self.aput(config, checkpoint, metadata, new_versions), self.loop ).result() def put_writes( self, config: RunnableConfig, writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. This method saves intermediate writes associated with a checkpoint to the database. Args: config (RunnableConfig): Configuration of the related checkpoint. 
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. """ return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop ).result() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/base.py`: ```py import json import random from typing import Any, List, Optional, Sequence, Tuple, cast from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol MetadataInput = Optional[dict[str, Any]] """ To add a new migration, add a new string to the MIGRATIONS list. The position of the migration in the list is the version number. """ MIGRATIONS = [ """CREATE TABLE IF NOT EXISTS checkpoint_migrations ( v INTEGER PRIMARY KEY );""", """CREATE TABLE IF NOT EXISTS checkpoints ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, type TEXT, checkpoint JSON NOT NULL, metadata JSON NOT NULL DEFAULT '{}', PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) );""", """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', channel TEXT NOT NULL, version TEXT NOT NULL, type TEXT NOT NULL, blob BLOB, PRIMARY KEY (thread_id, checkpoint_ns, channel, version) );""", """CREATE TABLE IF NOT EXISTS checkpoint_writes ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, task_id TEXT NOT NULL, idx INTEGER NOT NULL, channel TEXT NOT NULL, type TEXT, blob BLOB NOT NULL, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) );""", ] SELECT_SQL = f""" select thread_id, checkpoint, 
checkpoint_ns, checkpoint_id, parent_checkpoint_id, metadata, ( select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob]) from ( SELECT unnest(json_keys(json_extract(checkpoint, '$.channel_versions'))) as key ) cv inner join checkpoint_blobs bl on bl.thread_id = checkpoints.thread_id and bl.checkpoint_ns = checkpoints.checkpoint_ns and bl.channel = cv.key and bl.version = json_extract_string(checkpoint, '$.channel_versions.' || cv.key) ) as channel_values, ( select array_agg(array[cw.task_id::blob, cw.channel::blob, cw.type::blob, cw.blob]) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns and cw.checkpoint_id = checkpoints.checkpoint_id ) as pending_writes, ( select array_agg(array[cw.type::blob, cw.blob]) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns and cw.checkpoint_id = checkpoints.parent_checkpoint_id and cw.channel = '{TASKS}' ) as pending_sends from checkpoints """ UPSERT_CHECKPOINT_BLOBS_SQL = """ INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING """ UPSERT_CHECKPOINTS_SQL = """ INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata; """ UPSERT_CHECKPOINT_WRITES_SQL = """ INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET channel = EXCLUDED.channel, type = EXCLUDED.type, blob = EXCLUDED.blob; """ INSERT_CHECKPOINT_WRITES_SQL = """ INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING """ class BaseDuckDBSaver(BaseCheckpointSaver[str]): SELECT_SQL = SELECT_SQL MIGRATIONS = MIGRATIONS UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL jsonplus_serde = JsonPlusSerializer() def _load_checkpoint( self, checkpoint_json_str: str, channel_values: list[tuple[bytes, bytes, bytes]], pending_sends: list[tuple[bytes, bytes]], ) -> Checkpoint: checkpoint = json.loads(checkpoint_json_str) return { **checkpoint, "pending_sends": [ self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or [] ], "channel_values": self._load_blobs(channel_values), } def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: return {**checkpoint, "pending_sends": []} def _load_blobs( self, blob_values: list[tuple[bytes, bytes, bytes]] ) -> dict[str, Any]: if not blob_values: return {} return { k.decode(): self.serde.loads_typed((t.decode(), v)) for k, t, v in blob_values if t.decode() != "empty" } def _dump_blobs( self, thread_id: str, checkpoint_ns: str, values: dict[str, Any], versions: ChannelVersions, ) -> list[tuple[str, str, str, str, str, Optional[bytes]]]: if not versions: return [] return [ ( thread_id, checkpoint_ns, k, cast(str, ver), *( self.serde.dumps_typed(values[k]) if k in values else ("empty", None) ), ) for k, ver in versions.items() ] def _load_writes( self, writes: list[tuple[bytes, bytes, bytes, bytes]] ) -> list[tuple[str, str, Any]]: return ( [ ( tid.decode(), 
channel.decode(), self.serde.loads_typed((t.decode(), v)), ) for tid, channel, t, v in writes ] if writes else [] ) def _dump_writes( self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, writes: Sequence[tuple[str, Any]], ) -> list[tuple[str, str, str, str, int, str, str, bytes]]: return [ ( thread_id, checkpoint_ns, checkpoint_id, task_id, WRITES_IDX_MAP.get(channel, idx), channel, *self.serde.dumps_typed(value), ) for idx, (channel, value) in enumerate(writes) ] def _load_metadata(self, metadata_json_str: str) -> CheckpointMetadata: return self.jsonplus_serde.loads(metadata_json_str.encode()) def _dump_metadata(self, metadata: CheckpointMetadata) -> str: serialized_metadata = self.jsonplus_serde.dumps(metadata) # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing return serialized_metadata.decode().replace("\\u0000", "") def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: current_v = 0 elif isinstance(current, int): current_v = current else: current_v = int(current.split(".")[0]) next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" def _search_where( self, config: Optional[RunnableConfig], filter: MetadataInput, before: Optional[RunnableConfig] = None, ) -> Tuple[str, List[Any]]: """Return WHERE clause predicates for alist() given config, filter, before. This method returns a tuple of a string and a tuple of values. The string is the parametered WHERE clause predicate (including the WHERE keyword): "WHERE column1 = $1 AND column2 IS $2". The list of values contains the values for each of the corresponding parameters. 
""" wheres = [] param_values = [] # construct predicate for config filter if config: wheres.append("thread_id = ?") param_values.append(config["configurable"]["thread_id"]) checkpoint_ns = config["configurable"].get("checkpoint_ns") if checkpoint_ns is not None: wheres.append("checkpoint_ns = ?") param_values.append(checkpoint_ns) if checkpoint_id := get_checkpoint_id(config): wheres.append("checkpoint_id = ?") param_values.append(checkpoint_id) # construct predicate for metadata filter if filter: wheres.append("json_contains(metadata, ?)") param_values.append(json.dumps(filter)) # construct predicate for `before` if before is not None: wheres.append("checkpoint_id < ?") param_values.append(get_checkpoint_id(before)) return ( "WHERE " + " AND ".join(wheres) if wheres else "", param_values, ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/__init__.py`: ```py from langgraph.store.duckdb.aio import AsyncDuckDBStore from langgraph.store.duckdb.base import DuckDBStore __all__ = ["AsyncDuckDBStore", "DuckDBStore"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py`: ```py import asyncio import logging from contextlib import asynccontextmanager from typing import ( AsyncIterator, Iterable, Sequence, cast, ) import duckdb from langgraph.store.base import GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp from langgraph.store.base.batch import AsyncBatchedBaseStore from langgraph.store.duckdb.base import ( BaseDuckDBStore, _convert_ns, _group_ops, _row_to_item, ) logger = logging.getLogger(__name__) class AsyncDuckDBStore(AsyncBatchedBaseStore, BaseDuckDBStore): def __init__( self, conn: duckdb.DuckDBPyConnection, ) -> None: super().__init__() self.conn = conn self.loop = asyncio.get_running_loop() async def abatch(self, ops: Iterable[Op]) -> list[Result]: grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops tasks = [] if GetOp in grouped_ops: 
tasks.append( self._batch_get_ops( cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results ) ) if PutOp in grouped_ops: tasks.append( self._batch_put_ops( cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]) ) ) if SearchOp in grouped_ops: tasks.append( self._batch_search_ops( cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), results, ) ) if ListNamespacesOp in grouped_ops: tasks.append( self._batch_list_namespaces_ops( cast( Sequence[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ), results, ) ) await asyncio.gather(*tasks) return results def batch(self, ops: Iterable[Op]) -> list[Result]: return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() async def _batch_get_ops( self, get_ops: Sequence[tuple[int, GetOp]], results: list[Result], ) -> None: cursors = [] for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): cur = self.conn.cursor() await asyncio.to_thread(cur.execute, query, params) cursors.append((cur, namespace, items)) for cur, namespace, items in cursors: rows = await asyncio.to_thread(cur.fetchall) key_to_row = {row[1]: row for row in rows} for idx, key in items: row = key_to_row.get(key) if row: results[idx] = _row_to_item(namespace, row) else: results[idx] = None async def _batch_put_ops( self, put_ops: Sequence[tuple[int, PutOp]], ) -> None: queries = self._get_batch_PUT_queries(put_ops) for query, params in queries: cur = self.conn.cursor() await asyncio.to_thread(cur.execute, query, params) async def _batch_search_ops( self, search_ops: Sequence[tuple[int, SearchOp]], results: list[Result], ) -> None: queries = self._get_batch_search_queries(search_ops) cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] for (query, params), (idx, _) in zip(queries, search_ops): cur = self.conn.cursor() await asyncio.to_thread(cur.execute, query, params) cursors.append((cur, idx)) for cur, idx in cursors: rows = await asyncio.to_thread(cur.fetchall) items = 
[_row_to_item(_convert_ns(row[0]), row) for row in rows] results[idx] = items async def _batch_list_namespaces_ops( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], results: list[Result], ) -> None: queries = self._get_batch_list_namespaces_queries(list_ops) cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] for (query, params), (idx, _) in zip(queries, list_ops): cur = self.conn.cursor() await asyncio.to_thread(cur.execute, query, params) cursors.append((cur, idx)) for cur, idx in cursors: rows = cast(list[tuple], await asyncio.to_thread(cur.fetchall)) namespaces = [_convert_ns(row[0]) for row in rows] results[idx] = namespaces @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str, ) -> AsyncIterator["AsyncDuckDBStore"]: """Create a new AsyncDuckDBStore instance from a connection string. Args: conn_string (str): The DuckDB connection info string. Returns: AsyncDuckDBStore: A new AsyncDuckDBStore instance. """ with duckdb.connect(conn_string) as conn: yield cls(conn) async def setup(self) -> None: """Set up the store database asynchronously. This method creates the necessary tables in the DuckDB database if they don't already exist and runs database migrations. It is called automatically when needed and should not be called directly by the user. 
""" cur = self.conn.cursor() try: await asyncio.to_thread( cur.execute, "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1" ) row = await asyncio.to_thread(cur.fetchone) if row is None: version = -1 else: version = row[0] except duckdb.CatalogException: version = -1 # Create store_migrations table if it doesn't exist await asyncio.to_thread( cur.execute, """ CREATE TABLE IF NOT EXISTS store_migrations ( v INTEGER PRIMARY KEY ) """, ) for v, migration in enumerate( self.MIGRATIONS[version + 1 :], start=version + 1 ): await asyncio.to_thread(cur.execute, migration) await asyncio.to_thread( cur.execute, "INSERT INTO store_migrations (v) VALUES (?)", (v,) ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py`: ```py import asyncio import json import logging from collections import defaultdict from contextlib import contextmanager from typing import ( Any, Generic, Iterable, Iterator, Sequence, TypeVar, Union, cast, ) import duckdb from langgraph.store.base import ( BaseStore, GetOp, Item, ListNamespacesOp, Op, PutOp, Result, SearchItem, SearchOp, ) logger = logging.getLogger(__name__) MIGRATIONS = [ """ CREATE TABLE IF NOT EXISTS store ( prefix TEXT NOT NULL, key TEXT NOT NULL, value JSON NOT NULL, created_at TIMESTAMP DEFAULT now(), updated_at TIMESTAMP DEFAULT now(), PRIMARY KEY (prefix, key) ); """, """ CREATE INDEX IF NOT EXISTS store_prefix_idx ON store (prefix); """, ] C = TypeVar("C", bound=duckdb.DuckDBPyConnection) class BaseDuckDBStore(Generic[C]): MIGRATIONS = MIGRATIONS conn: C def _get_batch_GET_ops_queries( self, get_ops: Sequence[tuple[int, GetOp]], ) -> list[tuple[str, tuple, tuple[str, ...], list]]: namespace_groups = defaultdict(list) for idx, op in get_ops: namespace_groups[op.namespace].append((idx, op.key)) results = [] for namespace, items in namespace_groups.items(): _, keys = zip(*items) keys_to_query = ",".join(["?"] * len(keys)) query = f""" SELECT prefix, key, value, created_at, updated_at 
FROM store WHERE prefix = ? AND key IN ({keys_to_query}) """ params = (_namespace_to_text(namespace), *keys) results.append((query, params, namespace, items)) return results def _get_batch_PUT_queries( self, put_ops: Sequence[tuple[int, PutOp]], ) -> list[tuple[str, Sequence]]: inserts: list[PutOp] = [] deletes: list[PutOp] = [] for _, op in put_ops: if op.value is None: deletes.append(op) else: inserts.append(op) queries: list[tuple[str, Sequence]] = [] if deletes: namespace_groups: dict[tuple[str, ...], list[str]] = defaultdict(list) for op in deletes: namespace_groups[op.namespace].append(op.key) for namespace, keys in namespace_groups.items(): placeholders = ",".join(["?"] * len(keys)) query = ( f"DELETE FROM store WHERE prefix = ? AND key IN ({placeholders})" ) params = (_namespace_to_text(namespace), *keys) queries.append((query, params)) if inserts: values = [] insertion_params = [] for op in inserts: values.append("(?, ?, ?, now(), now())") insertion_params.extend( [ _namespace_to_text(op.namespace), op.key, json.dumps(op.value), ] ) values_str = ",".join(values) query = f""" INSERT INTO store (prefix, key, value, created_at, updated_at) VALUES {values_str} ON CONFLICT (prefix, key) DO UPDATE SET value = EXCLUDED.value, updated_at = now() """ queries.append((query, insertion_params)) return queries def _get_batch_search_queries( self, search_ops: Sequence[tuple[int, SearchOp]], ) -> list[tuple[str, Sequence]]: queries: list[tuple[str, Sequence]] = [] for _, op in search_ops: query = """ SELECT prefix, key, value, created_at, updated_at FROM store WHERE prefix LIKE ? """ params: list = [f"{_namespace_to_text(op.namespace_prefix)}%"] if op.filter: filter_conditions = [] for key, value in op.filter.items(): filter_conditions.append(f"json_extract(value, '$.{key}') = ?") params.append(json.dumps(value)) query += " AND " + " AND ".join(filter_conditions) query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?" 
params.extend([op.limit, op.offset]) queries.append((query, params)) return queries def _get_batch_list_namespaces_queries( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], ) -> list[tuple[str, Sequence]]: queries: list[tuple[str, Sequence]] = [] for _, op in list_ops: query = """ WITH split_prefix AS ( SELECT prefix, string_split(prefix, '.') AS parts FROM store ) SELECT DISTINCT ON (truncated_prefix) CASE WHEN ? IS NOT NULL THEN array_to_string(array_slice(parts, 1, ?), '.') ELSE prefix END AS truncated_prefix, prefix FROM split_prefix """ params: list[Any] = [op.max_depth, op.max_depth] conditions = [] if op.match_conditions: for condition in op.match_conditions: if condition.match_type == "prefix": conditions.append("prefix LIKE ?") params.append( f"{_namespace_to_text(condition.path, handle_wildcards=True)}%" ) elif condition.match_type == "suffix": conditions.append("prefix LIKE ?") params.append( f"%{_namespace_to_text(condition.path, handle_wildcards=True)}" ) else: logger.warning( f"Unknown match_type in list_namespaces: {condition.match_type}" ) if conditions: query += " WHERE " + " AND ".join(conditions) query += " ORDER BY prefix LIMIT ? OFFSET ?" 
params.extend([op.limit, op.offset]) queries.append((query, params)) return queries class DuckDBStore(BaseStore, BaseDuckDBStore[duckdb.DuckDBPyConnection]): def __init__( self, conn: duckdb.DuckDBPyConnection, ) -> None: super().__init__() self.conn = conn def batch(self, ops: Iterable[Op]) -> list[Result]: grouped_ops, num_ops = _group_ops(ops) results: list[Result] = [None] * num_ops if GetOp in grouped_ops: self._batch_get_ops( cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results ) if PutOp in grouped_ops: self._batch_put_ops(cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp])) if SearchOp in grouped_ops: self._batch_search_ops( cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), results, ) if ListNamespacesOp in grouped_ops: self._batch_list_namespaces_ops( cast( Sequence[tuple[int, ListNamespacesOp]], grouped_ops[ListNamespacesOp], ), results, ) return results async def abatch(self, ops: Iterable[Op]) -> list[Result]: return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops) def _batch_get_ops( self, get_ops: Sequence[tuple[int, GetOp]], results: list[Result], ) -> None: cursors = [] for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): cur = self.conn.cursor() cur.execute(query, params) cursors.append((cur, namespace, items)) for cur, namespace, items in cursors: rows = cur.fetchall() key_to_row = {row[1]: row for row in rows} for idx, key in items: row = key_to_row.get(key) if row: results[idx] = _row_to_item(namespace, row) else: results[idx] = None def _batch_put_ops( self, put_ops: Sequence[tuple[int, PutOp]], ) -> None: queries = self._get_batch_PUT_queries(put_ops) for query, params in queries: cur = self.conn.cursor() cur.execute(query, params) def _batch_search_ops( self, search_ops: Sequence[tuple[int, SearchOp]], results: list[Result], ) -> None: queries = self._get_batch_search_queries(search_ops) cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] for (query, params), 
(idx, _) in zip(queries, search_ops): cur = self.conn.cursor() cur.execute(query, params) cursors.append((cur, idx)) for cur, idx in cursors: rows = cur.fetchall() items = [_row_to_search_item(_convert_ns(row[0]), row) for row in rows] results[idx] = items def _batch_list_namespaces_ops( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], results: list[Result], ) -> None: queries = self._get_batch_list_namespaces_queries(list_ops) cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] for (query, params), (idx, _) in zip(queries, list_ops): cur = self.conn.cursor() cur.execute(query, params) cursors.append((cur, idx)) for cur, idx in cursors: rows = cast(list[dict], cur.fetchall()) namespaces = [_convert_ns(row[0]) for row in rows] results[idx] = namespaces @classmethod @contextmanager def from_conn_string( cls, conn_string: str, ) -> Iterator["DuckDBStore"]: """Create a new BaseDuckDBStore instance from a connection string. Args: conn_string (str): The DuckDB connection info string. Returns: DuckDBStore: A new DuckDBStore instance. """ with duckdb.connect(conn_string) as conn: yield cls(conn=conn) def setup(self) -> None: """Set up the store database. This method creates the necessary tables in the DuckDB database if they don't already exist and runs database migrations. It is called automatically when needed and should not be called directly by the user. 
""" with self.conn.cursor() as cur: try: cur.execute("SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1") row = cast(dict, cur.fetchone()) if row is None: version = -1 else: version = row["v"] except duckdb.CatalogException: version = -1 # Create store_migrations table if it doesn't exist cur.execute( """ CREATE TABLE IF NOT EXISTS store_migrations ( v INTEGER PRIMARY KEY ) """ ) for v, migration in enumerate( self.MIGRATIONS[version + 1 :], start=version + 1 ): cur.execute(migration) cur.execute("INSERT INTO store_migrations (v) VALUES (?)", (v,)) def _namespace_to_text( namespace: tuple[str, ...], handle_wildcards: bool = False ) -> str: """Convert namespace tuple to text string.""" if handle_wildcards: namespace = tuple("%" if val == "*" else val for val in namespace) return ".".join(namespace) def _row_to_item( namespace: tuple[str, ...], row: tuple, ) -> Item: """Convert a row from the database into an Item.""" _, key, val, created_at, updated_at = row return Item( value=val if isinstance(val, dict) else json.loads(val), key=key, namespace=namespace, created_at=created_at, updated_at=updated_at, ) def _row_to_search_item( namespace: tuple[str, ...], row: tuple, ) -> SearchItem: """Convert a row from the database into an SearchItem.""" # TODO: Add support for search _, key, val, created_at, updated_at = row return SearchItem( value=val if isinstance(val, dict) else json.loads(val), key=key, namespace=namespace, created_at=created_at, updated_at=updated_at, ) def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]: grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list) tot = 0 for idx, op in enumerate(ops): grouped_ops[type(op)].append((idx, op)) tot += 1 return grouped_ops, tot def _convert_ns(namespace: Union[str, list]) -> tuple[str, ...]: if isinstance(namespace, list): return tuple(namespace) return tuple(namespace.split(".")) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/pyproject.toml`: 
```toml [tool.poetry] name = "langgraph-checkpoint-duckdb" version = "2.0.1" description = "Library with a DuckDB implementation of LangGraph checkpoint saver." authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph" }] [tool.poetry.dependencies] python = "^3.9.0,<4.0" langgraph-checkpoint = "^2.0.2" duckdb = ">=1.1.2" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" anyio = "^4.4.0" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watch = "^4.2.0" mypy = "^1.10.0" langgraph-checkpoint = {path = "../checkpoint", develop = true} [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config makes any warnings encountered while parsing the `pytest` # section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv" asyncio_mode = "auto" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ "E", # pycodestyle "F", # Pyflakes "UP", # pyupgrade "B", # flake8-bugbear "I", # isort ] lint.ignore = ["E501", "B008", "UP007", "UP006"] [tool.mypy] # https://mypy.readthedocs.io/en/stable/config_file.html disallow_untyped_defs = "True" explicit_package_bases = "True" warn_no_return = "False" warn_unused_ignores = "True" warn_redundant_casts = "True" allow_redefinition = "True" disable_error_code = "typeddict-item, return-value" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_sync.py`: ```py from typing import Any import pytest from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.duckdb import DuckDBSaver class TestDuckDBSaver: @pytest.fixture(autouse=True) def setup(self) -> None: # objects for test setup self.config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } self.config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } self.config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } self.chkpnt_1: Checkpoint = empty_checkpoint() self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1) self.chkpnt_3: Checkpoint = empty_checkpoint() self.metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } self.metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } self.metadata_3: CheckpointMetadata = {} def test_search(self) -> None: with DuckDBSaver.from_conn_string(":memory:") as 
saver: saver.setup() # save checkpoints saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {}) saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {}) saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = list(saver.list(None, filter=query_1)) assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = list(saver.list(None, filter=query_2)) assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = list(saver.list(None, filter=query_3)) assert len(search_results_3) == 3 search_results_4 = list(saver.list(None, filter=query_4)) assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = list( saver.list({"configurable": {"thread_id": "thread-2"}}) ) assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} # TODO: test before and limit params def test_null_chars(self) -> None: with DuckDBSaver.from_conn_string(":memory:") as saver: saver.setup() config = saver.put(self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {}) assert saver.get_tuple(config).metadata["my_key"] == "abc" # type: ignore assert ( list(saver.list(None, filter={"my_key": "abc"}))[0].metadata["my_key"] # type: ignore == "abc" ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_async.py`: ```py from typing import Any import pytest from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, 
empty_checkpoint, ) from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver class TestAsyncDuckDBSaver: @pytest.fixture(autouse=True) async def setup(self) -> None: # objects for test setup self.config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } self.config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } self.config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } self.chkpnt_1: Checkpoint = empty_checkpoint() self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1) self.chkpnt_3: Checkpoint = empty_checkpoint() self.metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } self.metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } self.metadata_3: CheckpointMetadata = {} async def test_asearch(self) -> None: async with AsyncDuckDBSaver.from_conn_string(":memory:") as saver: await saver.setup() await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {}) await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {}) await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = [c async for c in saver.alist(None, filter=query_1)] assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = [c async for c in saver.alist(None, filter=query_2)] assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = [c async for c in 
saver.alist(None, filter=query_3)] assert len(search_results_3) == 3 search_results_4 = [c async for c in saver.alist(None, filter=query_4)] assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = [ c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}}) ] assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} # TODO: test before and limit params async def test_null_chars(self) -> None: async with AsyncDuckDBSaver.from_conn_string(":memory:") as saver: await saver.setup() config = await saver.aput( self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {} ) assert (await saver.aget_tuple(config)).metadata["my_key"] == "abc" # type: ignore assert [c async for c in saver.alist(None, filter={"my_key": "abc"})][ 0 ].metadata["my_key"] == "abc" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_store.py`: ```py # type: ignore import uuid from datetime import datetime from typing import Any from unittest.mock import MagicMock import pytest from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp from langgraph.store.duckdb import DuckDBStore class MockCursor: def __init__(self, fetch_result: Any) -> None: self.fetch_result = fetch_result self.execute = MagicMock() self.fetchall = MagicMock(return_value=self.fetch_result) class MockConnection: def __init__(self) -> None: self.cursor = MagicMock() @pytest.fixture def mock_connection() -> MockConnection: return MockConnection() @pytest.fixture def store(mock_connection: MockConnection) -> DuckDBStore: duck_db_store = DuckDBStore(mock_connection) duck_db_store.setup() return duck_db_store def test_batch_order(store: DuckDBStore) -> None: mock_connection = store.conn mock_get_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), 
datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', datetime.now(), datetime.now(), ), ] ) mock_search_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ] ) mock_list_namespaces_cursor = MockCursor( [ ("test",), ] ) failures = [] def cursor_side_effect() -> Any: cursor = MagicMock() def execute_side_effect(query: str, *params: Any) -> None: # My super sophisticated database. if "WHERE prefix = ? AND key" in query: cursor.fetchall = mock_get_cursor.fetchall elif "SELECT prefix, key, value" in query: cursor.fetchall = mock_search_cursor.fetchall elif "SELECT DISTINCT ON (truncated_prefix)" in query: cursor.fetchall = mock_list_namespaces_cursor.fetchall elif "INSERT INTO " in query: pass else: e = ValueError(f"Unmatched query: {query}") failures.append(e) raise e cursor.execute = MagicMock(side_effect=execute_side_effect) return cursor mock_connection.cursor.side_effect = cursor_side_effect ops = [ GetOp(namespace=("test",), key="key1"), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), GetOp(namespace=("test",), key="key3"), ] results = store.batch(ops) assert not failures assert len(results) == 5 assert isinstance(results[0], Item) assert isinstance(results[0].value, dict) assert results[0].value == {"data": "value1"} assert results[0].key == "key1" assert results[1] is None assert isinstance(results[2], list) assert len(results[2]) == 1 assert isinstance(results[3], list) assert results[3] == [("test",)] assert results[4] is None ops_reordered = [ SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), GetOp(namespace=("test",), key="key2"), ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), PutOp(namespace=("test",), key="key3", value={"data": "value3"}), GetOp(namespace=("test",), 
key="key1"), ] results_reordered = store.batch(ops_reordered) assert not failures assert len(results_reordered) == 5 assert isinstance(results_reordered[0], list) assert len(results_reordered[0]) == 1 assert isinstance(results_reordered[1], Item) assert results_reordered[1].value == {"data": "value2"} assert results_reordered[1].key == "key2" assert isinstance(results_reordered[2], list) assert results_reordered[2] == [("test",)] assert results_reordered[3] is None assert isinstance(results_reordered[4], Item) assert results_reordered[4].value == {"data": "value1"} assert results_reordered[4].key == "key1" def test_batch_get_ops(store: DuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', datetime.now(), datetime.now(), ), ] ) mock_connection.cursor.return_value = mock_cursor ops = [ GetOp(namespace=("test",), key="key1"), GetOp(namespace=("test",), key="key2"), GetOp(namespace=("test",), key="key3"), ] results = store.batch(ops) assert len(results) == 3 assert results[0] is not None assert results[1] is not None assert results[2] is None assert results[0].key == "key1" assert results[1].key == "key2" def test_batch_put_ops(store: DuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor([]) mock_connection.cursor.return_value = mock_cursor ops = [ PutOp(namespace=("test",), key="key1", value={"data": "value1"}), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), PutOp(namespace=("test",), key="key3", value=None), ] results = store.batch(ops) assert len(results) == 3 assert all(result is None for result in results) assert mock_cursor.execute.call_count == 2 def test_batch_search_ops(store: DuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', 
datetime.now(), datetime.now(), ), ] ) mock_connection.cursor.return_value = mock_cursor ops = [ SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), ] results = store.batch(ops) assert len(results) == 2 assert len(results[0]) == 2 assert len(results[1]) == 2 def test_batch_list_namespaces_ops(store: DuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)]) mock_connection.cursor.return_value = mock_cursor ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)] results = store.batch(ops) assert len(results) == 1 assert results[0] == [("test", "namespace1"), ("test", "namespace2")] def test_basic_store_ops() -> None: with DuckDBStore.from_conn_string(":memory:") as store: store.setup() namespace = ("test", "documents") item_id = "doc1" item_value = {"title": "Test Document", "content": "Hello, World!"} store.put(namespace, item_id, item_value) item = store.get(namespace, item_id) assert item assert item.namespace == namespace assert item.key == item_id assert item.value == item_value updated_value = { "title": "Updated Test Document", "content": "Hello, LangGraph!", } store.put(namespace, item_id, updated_value) updated_item = store.get(namespace, item_id) assert updated_item.value == updated_value assert updated_item.updated_at > item.updated_at different_namespace = ("test", "other_documents") item_in_different_namespace = store.get(different_namespace, item_id) assert item_in_different_namespace is None new_item_id = "doc2" new_item_value = {"title": "Another Document", "content": "Greetings!"} store.put(namespace, new_item_id, new_item_value) search_results = store.search(["test"], limit=10) items = search_results assert len(items) == 2 assert any(item.key == item_id for item in items) assert any(item.key == new_item_id for item in items) namespaces = 
store.list_namespaces(prefix=["test"]) assert ("test", "documents") in namespaces store.delete(namespace, item_id) store.delete(namespace, new_item_id) deleted_item = store.get(namespace, item_id) assert deleted_item is None deleted_item = store.get(namespace, new_item_id) assert deleted_item is None empty_search_results = store.search(["test"], limit=10) assert len(empty_search_results) == 0 def test_list_namespaces() -> None: with DuckDBStore.from_conn_string(":memory:") as store: store.setup() test_pref = str(uuid.uuid4()) test_namespaces = [ (test_pref, "test", "documents", "public", test_pref), (test_pref, "test", "documents", "private", test_pref), (test_pref, "test", "images", "public", test_pref), (test_pref, "test", "images", "private", test_pref), (test_pref, "prod", "documents", "public", test_pref), ( test_pref, "prod", "documents", "some", "nesting", "public", test_pref, ), (test_pref, "prod", "documents", "private", test_pref), ] for namespace in test_namespaces: store.put(namespace, "dummy", {"content": "dummy"}) prefix_result = store.list_namespaces(prefix=[test_pref, "test"]) assert len(prefix_result) == 4 assert all([ns[1] == "test" for ns in prefix_result]) specific_prefix_result = store.list_namespaces( prefix=[test_pref, "test", "documents"] ) assert len(specific_prefix_result) == 2 assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result]) suffix_result = store.list_namespaces(suffix=["public", test_pref]) assert len(suffix_result) == 4 assert all(ns[-2] == "public" for ns in suffix_result) prefix_suffix_result = store.list_namespaces( prefix=[test_pref, "test"], suffix=["public", test_pref] ) assert len(prefix_suffix_result) == 2 assert all( ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result ) wildcard_prefix_result = store.list_namespaces( prefix=[test_pref, "*", "documents"] ) assert len(wildcard_prefix_result) == 5 assert all(ns[2] == "documents" for ns in wildcard_prefix_result) 
wildcard_suffix_result = store.list_namespaces( suffix=["*", "public", test_pref] ) assert len(wildcard_suffix_result) == 4 assert all(ns[-2] == "public" for ns in wildcard_suffix_result) wildcard_single = store.list_namespaces( suffix=["some", "*", "public", test_pref] ) assert len(wildcard_single) == 1 assert wildcard_single[0] == ( test_pref, "prod", "documents", "some", "nesting", "public", test_pref, ) max_depth_result = store.list_namespaces(max_depth=3) assert all([len(ns) <= 3 for ns in max_depth_result]) max_depth_result = store.list_namespaces( max_depth=4, prefix=[test_pref, "*", "documents"] ) assert ( len(set(tuple(res) for res in max_depth_result)) == len(max_depth_result) == 5 ) limit_result = store.list_namespaces(prefix=[test_pref], limit=3) assert len(limit_result) == 3 offset_result = store.list_namespaces(prefix=[test_pref], offset=3) assert len(offset_result) == len(test_namespaces) - 3 empty_prefix_result = store.list_namespaces(prefix=[test_pref]) assert len(empty_prefix_result) == len(test_namespaces) assert set(tuple(ns) for ns in empty_prefix_result) == set( tuple(ns) for ns in test_namespaces ) for namespace in test_namespaces: store.delete(namespace, "dummy") def test_search(): with DuckDBStore.from_conn_string(":memory:") as store: store.setup() test_namespaces = [ ("test_search", "documents", "user1"), ("test_search", "documents", "user2"), ("test_search", "reports", "department1"), ("test_search", "reports", "department2"), ] test_items = [ {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, {"title": "Report A", "author": "John Doe", "tags": ["final"]}, {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, ] for namespace, item in zip(test_namespaces, test_items): store.put(namespace, f"item_{namespace[-1]}", item) docs_result = store.search(["test_search", "documents"]) assert len(docs_result) == 2 assert all( [item.namespace[1] == "documents" 
for item in docs_result] ), docs_result reports_result = store.search(["test_search", "reports"]) assert len(reports_result) == 2 assert all(item.namespace[1] == "reports" for item in reports_result) limited_result = store.search(["test_search"], limit=2) assert len(limited_result) == 2 offset_result = store.search(["test_search"]) assert len(offset_result) == 4 offset_result = store.search(["test_search"], offset=2) assert len(offset_result) == 2 assert all(item not in limited_result for item in offset_result) john_doe_result = store.search(["test_search"], filter={"author": "John Doe"}) assert len(john_doe_result) == 2 assert all(item.value["author"] == "John Doe" for item in john_doe_result) draft_result = store.search(["test_search"], filter={"tags": ["draft"]}) assert len(draft_result) == 2 assert all("draft" in item.value["tags"] for item in draft_result) page1 = store.search(["test_search"], limit=2, offset=0) page2 = store.search(["test_search"], limit=2, offset=2) all_items = page1 + page2 assert len(all_items) == 4 assert len(set(item.key for item in all_items)) == 4 for namespace in test_namespaces: store.delete(namespace, f"item_{namespace[-1]}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_async_store.py`: ```py # type: ignore import uuid from datetime import datetime from typing import Any from unittest.mock import MagicMock import pytest from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp from langgraph.store.duckdb import AsyncDuckDBStore class MockCursor: def __init__(self, fetch_result: Any) -> None: self.fetch_result = fetch_result self.execute = MagicMock() self.fetchall = MagicMock(return_value=self.fetch_result) class MockConnection: def __init__(self) -> None: self.cursor = MagicMock() @pytest.fixture def mock_connection() -> MockConnection: return MockConnection() @pytest.fixture async def store(mock_connection: MockConnection) -> AsyncDuckDBStore: duck_db_store = 
AsyncDuckDBStore(mock_connection) await duck_db_store.setup() return duck_db_store async def test_abatch_order(store: AsyncDuckDBStore) -> None: mock_connection = store.conn mock_get_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', datetime.now(), datetime.now(), ), ] ) mock_search_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ] ) mock_list_namespaces_cursor = MockCursor( [ ("test",), ] ) failures = [] def cursor_side_effect() -> Any: cursor = MagicMock() def execute_side_effect(query: str, *params: Any) -> None: # My super sophisticated database. if "WHERE prefix = ? AND key" in query: cursor.fetchall = mock_get_cursor.fetchall elif "SELECT prefix, key, value" in query: cursor.fetchall = mock_search_cursor.fetchall elif "SELECT DISTINCT ON (truncated_prefix)" in query: cursor.fetchall = mock_list_namespaces_cursor.fetchall elif "INSERT INTO " in query: pass else: e = ValueError(f"Unmatched query: {query}") failures.append(e) raise e cursor.execute = MagicMock(side_effect=execute_side_effect) return cursor mock_connection.cursor.side_effect = cursor_side_effect # type: ignore ops = [ GetOp(namespace=("test",), key="key1"), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), GetOp(namespace=("test",), key="key3"), ] results = await store.abatch(ops) assert not failures assert len(results) == 5 assert isinstance(results[0], Item) assert isinstance(results[0].value, dict) assert results[0].value == {"data": "value1"} assert results[0].key == "key1" assert results[1] is None assert isinstance(results[2], list) assert len(results[2]) == 1 assert isinstance(results[3], list) assert results[3] == [("test",)] assert results[4] is None ops_reordered 
= [ SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), GetOp(namespace=("test",), key="key2"), ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), PutOp(namespace=("test",), key="key3", value={"data": "value3"}), GetOp(namespace=("test",), key="key1"), ] results_reordered = await store.abatch(ops_reordered) assert not failures assert len(results_reordered) == 5 assert isinstance(results_reordered[0], list) assert len(results_reordered[0]) == 1 assert isinstance(results_reordered[1], Item) assert results_reordered[1].value == {"data": "value2"} assert results_reordered[1].key == "key2" assert isinstance(results_reordered[2], list) assert results_reordered[2] == [("test",)] assert results_reordered[3] is None assert isinstance(results_reordered[4], Item) assert results_reordered[4].value == {"data": "value1"} assert results_reordered[4].key == "key1" async def test_batch_get_ops(store: AsyncDuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', datetime.now(), datetime.now(), ), ] ) mock_connection.cursor.return_value = mock_cursor ops = [ GetOp(namespace=("test",), key="key1"), GetOp(namespace=("test",), key="key2"), GetOp(namespace=("test",), key="key3"), ] results = await store.abatch(ops) assert len(results) == 3 assert results[0] is not None assert results[1] is not None assert results[2] is None assert results[0].key == "key1" assert results[1].key == "key2" async def test_batch_put_ops(store: AsyncDuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor([]) mock_connection.cursor.return_value = mock_cursor ops = [ PutOp(namespace=("test",), key="key1", value={"data": "value1"}), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), PutOp(namespace=("test",), key="key3", value=None), ] results = await store.abatch(ops) assert len(results) == 3 
assert all(result is None for result in results) assert mock_cursor.execute.call_count == 2 async def test_batch_search_ops(store: AsyncDuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor( [ ( "test.foo", "key1", '{"data": "value1"}', datetime.now(), datetime.now(), ), ( "test.bar", "key2", '{"data": "value2"}', datetime.now(), datetime.now(), ), ] ) mock_connection.cursor.return_value = mock_cursor ops = [ SearchOp( namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 ), SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), ] results = await store.abatch(ops) assert len(results) == 2 assert len(results[0]) == 2 assert len(results[1]) == 2 async def test_batch_list_namespaces_ops(store: AsyncDuckDBStore) -> None: mock_connection = store.conn mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)]) mock_connection.cursor.return_value = mock_cursor ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)] results = await store.abatch(ops) assert len(results) == 1 assert results[0] == [("test", "namespace1"), ("test", "namespace2")] # The following use the actual DB connection async def test_basic_store_ops() -> None: async with AsyncDuckDBStore.from_conn_string(":memory:") as store: await store.setup() namespace = ("test", "documents") item_id = "doc1" item_value = {"title": "Test Document", "content": "Hello, World!"} await store.aput(namespace, item_id, item_value) item = await store.aget(namespace, item_id) assert item assert item.namespace == namespace assert item.key == item_id assert item.value == item_value updated_value = { "title": "Updated Test Document", "content": "Hello, LangGraph!", } await store.aput(namespace, item_id, updated_value) updated_item = await store.aget(namespace, item_id) assert updated_item.value == updated_value assert updated_item.updated_at > item.updated_at different_namespace = ("test", "other_documents") item_in_different_namespace 
= await store.aget(different_namespace, item_id) assert item_in_different_namespace is None new_item_id = "doc2" new_item_value = {"title": "Another Document", "content": "Greetings!"} await store.aput(namespace, new_item_id, new_item_value) search_results = await store.asearch(["test"], limit=10) items = search_results assert len(items) == 2 assert any(item.key == item_id for item in items) assert any(item.key == new_item_id for item in items) namespaces = await store.alist_namespaces(prefix=["test"]) assert ("test", "documents") in namespaces await store.adelete(namespace, item_id) await store.adelete(namespace, new_item_id) deleted_item = await store.aget(namespace, item_id) assert deleted_item is None deleted_item = await store.aget(namespace, new_item_id) assert deleted_item is None empty_search_results = await store.asearch(["test"], limit=10) assert len(empty_search_results) == 0 async def test_list_namespaces() -> None: async with AsyncDuckDBStore.from_conn_string(":memory:") as store: await store.setup() test_pref = str(uuid.uuid4()) test_namespaces = [ (test_pref, "test", "documents", "public", test_pref), (test_pref, "test", "documents", "private", test_pref), (test_pref, "test", "images", "public", test_pref), (test_pref, "test", "images", "private", test_pref), (test_pref, "prod", "documents", "public", test_pref), ( test_pref, "prod", "documents", "some", "nesting", "public", test_pref, ), (test_pref, "prod", "documents", "private", test_pref), ] for namespace in test_namespaces: await store.aput(namespace, "dummy", {"content": "dummy"}) prefix_result = await store.alist_namespaces(prefix=[test_pref, "test"]) assert len(prefix_result) == 4 assert all([ns[1] == "test" for ns in prefix_result]) specific_prefix_result = await store.alist_namespaces( prefix=[test_pref, "test", "documents"] ) assert len(specific_prefix_result) == 2 assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result]) suffix_result = await 
store.alist_namespaces(suffix=["public", test_pref]) assert len(suffix_result) == 4 assert all(ns[-2] == "public" for ns in suffix_result) prefix_suffix_result = await store.alist_namespaces( prefix=[test_pref, "test"], suffix=["public", test_pref] ) assert len(prefix_suffix_result) == 2 assert all( ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result ) wildcard_prefix_result = await store.alist_namespaces( prefix=[test_pref, "*", "documents"] ) assert len(wildcard_prefix_result) == 5 assert all(ns[2] == "documents" for ns in wildcard_prefix_result) wildcard_suffix_result = await store.alist_namespaces( suffix=["*", "public", test_pref] ) assert len(wildcard_suffix_result) == 4 assert all(ns[-2] == "public" for ns in wildcard_suffix_result) wildcard_single = await store.alist_namespaces( suffix=["some", "*", "public", test_pref] ) assert len(wildcard_single) == 1 assert wildcard_single[0] == ( test_pref, "prod", "documents", "some", "nesting", "public", test_pref, ) max_depth_result = await store.alist_namespaces(max_depth=3) assert all([len(ns) <= 3 for ns in max_depth_result]) max_depth_result = await store.alist_namespaces( max_depth=4, prefix=[test_pref, "*", "documents"] ) assert ( len(set(tuple(res) for res in max_depth_result)) == len(max_depth_result) == 5 ) limit_result = await store.alist_namespaces(prefix=[test_pref], limit=3) assert len(limit_result) == 3 offset_result = await store.alist_namespaces(prefix=[test_pref], offset=3) assert len(offset_result) == len(test_namespaces) - 3 empty_prefix_result = await store.alist_namespaces(prefix=[test_pref]) assert len(empty_prefix_result) == len(test_namespaces) assert set(tuple(ns) for ns in empty_prefix_result) == set( tuple(ns) for ns in test_namespaces ) for namespace in test_namespaces: await store.adelete(namespace, "dummy") async def test_search(): async with AsyncDuckDBStore.from_conn_string(":memory:") as store: await store.setup() test_namespaces = [ ("test_search", "documents", 
"user1"), ("test_search", "documents", "user2"), ("test_search", "reports", "department1"), ("test_search", "reports", "department2"), ] test_items = [ {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, {"title": "Report A", "author": "John Doe", "tags": ["final"]}, {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, ] empty = await store.asearch( ( "scoped", "assistant_id", "shared", "6c5356f6-63ab-4158-868d-cd9fd14c736e", ), limit=10, offset=0, ) assert len(empty) == 0 for namespace, item in zip(test_namespaces, test_items): await store.aput(namespace, f"item_{namespace[-1]}", item) docs_result = await store.asearch(["test_search", "documents"]) assert len(docs_result) == 2 assert all([item.namespace[1] == "documents" for item in docs_result]), [ item.namespace for item in docs_result ] reports_result = await store.asearch(["test_search", "reports"]) assert len(reports_result) == 2 assert all(item.namespace[1] == "reports" for item in reports_result) limited_result = await store.asearch(["test_search"], limit=2) assert len(limited_result) == 2 offset_result = await store.asearch(["test_search"]) assert len(offset_result) == 4 offset_result = await store.asearch(["test_search"], offset=2) assert len(offset_result) == 2 assert all(item not in limited_result for item in offset_result) john_doe_result = await store.asearch( ["test_search"], filter={"author": "John Doe"} ) assert len(john_doe_result) == 2 assert all(item.value["author"] == "John Doe" for item in john_doe_result) draft_result = await store.asearch(["test_search"], filter={"tags": ["draft"]}) assert len(draft_result) == 2 assert all("draft" in item.value["tags"] for item in draft_result) page1 = await store.asearch(["test_search"], limit=2, offset=0) page2 = await store.asearch(["test_search"], limit=2, offset=2) all_items = page1 + page2 assert len(all_items) == 4 assert len(set(item.key for item in 
all_items)) == 4 empty = await store.asearch( ( "scoped", "assistant_id", "shared", "again", "maybe", "some-long", "6be5cb0e-2eb4-42e6-bb6b-fba3c269db25", ), limit=10, offset=0, ) assert len(empty) == 0 # Test with a namespace beginning with a number (like a UUID) uuid_namespace = (str(uuid.uuid4()), "documents") uuid_item_id = "uuid_doc" uuid_item_value = { "title": "UUID Document", "content": "This document has a UUID namespace.", } # Insert the item with the UUID namespace await store.aput(uuid_namespace, uuid_item_id, uuid_item_value) # Retrieve the item to verify it was stored correctly retrieved_item = await store.aget(uuid_namespace, uuid_item_id) assert retrieved_item is not None assert retrieved_item.namespace == uuid_namespace assert retrieved_item.key == uuid_item_id assert retrieved_item.value == uuid_item_value # Search for the item using the UUID namespace search_result = await store.asearch([uuid_namespace[0]]) assert len(search_result) == 1 assert search_result[0].key == uuid_item_id assert search_result[0].value == uuid_item_value # Clean up: delete the item with the UUID namespace await store.adelete(uuid_namespace, uuid_item_id) # Verify the item was deleted deleted_item = await store.aget(uuid_namespace, uuid_item_id) assert deleted_item is None for namespace in test_namespaces: await store.adelete(namespace, f"item_{namespace[-1]}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/default_sync.py`: ```py import concurrent.futures from typing import Optional, Sequence from kafka import KafkaConsumer, KafkaProducer from langgraph.scheduler.kafka.types import ConsumerRecord, TopicPartition class DefaultConsumer(KafkaConsumer): def getmany( self, timeout_ms: int, max_records: int ) -> dict[TopicPartition, Sequence[ConsumerRecord]]: return self.poll(timeout_ms=timeout_ms, max_records=max_records) def __enter__(self): return self def __exit__(self, *args): self.close() class DefaultProducer(KafkaProducer): 
def send( self, topic: str, *, key: Optional[bytes] = None, value: Optional[bytes] = None, ) -> concurrent.futures.Future: fut = concurrent.futures.Future() kfut = super().send(topic, key=key, value=value) kfut.add_callback(fut.set_result) kfut.add_errback(fut.set_exception) return fut def __enter__(self): return self def __exit__(self, *args): self.close() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py`: ```py import aiokafka class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer): pass class DefaultAsyncProducer(aiokafka.AIOKafkaProducer): pass ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/serde.py`: ```py from typing import Any import orjson from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer SERIALIZER = JsonPlusSerializer() def loads(v: bytes) -> Any: return SERIALIZER.loads(v) def dumps(v: Any) -> bytes: return orjson.dumps(v, default=_default) def _default(v: Any) -> Any: # things we don't know how to serialize (eg. functions) ignore return None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py`: ```py import asyncio import concurrent.futures from typing import Any, NamedTuple, Optional, Protocol, Sequence, TypedDict, Union from langchain_core.runnables import RunnableConfig class Topics(NamedTuple): orchestrator: str executor: str error: str class Sendable(TypedDict): topic: str value: Optional[Any] key: Optional[Any] class MessageToOrchestrator(TypedDict): input: Optional[dict[str, Any]] config: RunnableConfig finally_send: Optional[Sequence[Sendable]] class ExecutorTask(TypedDict): id: Optional[str] path: tuple[Union[str, int], ...] 
class MessageToExecutor(TypedDict):
    """Message published to the executor topic asking it to run one task."""

    # Config of the checkpoint the task belongs to; the executor reads
    # config["configurable"]["checkpoint_id"] to detect stale messages.
    config: RunnableConfig
    # Identifies the task (id + path) within that checkpoint.
    task: ExecutorTask
    # Messages to send once the run finishes; forwarded back to the
    # orchestrator by the executor when it reports completion.
    finally_send: Optional[Sequence[Sendable]]


class ErrorMessage(TypedDict):
    """Message published to the error topic when processing a message fails."""

    # Topic the failing message was consumed from.
    topic: str
    # repr() of the exception that caused the failure.
    error: str
    # The original message that could not be processed.
    msg: Union[MessageToExecutor, MessageToOrchestrator]


class TopicPartition(Protocol):
    """Structural type for a Kafka topic/partition pair (keys of getmany())."""

    topic: str
    partition: int


class ConsumerRecord(Protocol):
    """Structural type for a single record returned by a consumer."""

    topic: str
    "The topic this record is received from"
    partition: int
    "The partition from which this record is received"
    offset: int
    "The position of this record in the corresponding Kafka partition."
    timestamp: int
    "The timestamp of this record"
    timestamp_type: int
    "The timestamp type of this record"
    key: Optional[bytes]
    "The key (or `None` if no key is specified)"
    value: Optional[bytes]
    "The value"


class Consumer(Protocol):
    """Blocking consumer interface used by the sync orchestrator/executor."""

    def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[TopicPartition, Sequence[ConsumerRecord]]:
        """Fetch up to ``max_records`` records, waiting at most ``timeout_ms``.

        Returns the fetched records grouped by topic partition.
        """
        ...

    def commit(self) -> None:
        """Commit consumed offsets (called after each processed batch)."""
        ...


class AsyncConsumer(Protocol):
    """Awaitable counterpart of ``Consumer`` used by the async variants."""

    async def getmany(
        self, timeout_ms: int, max_records: int
    ) -> dict[TopicPartition, Sequence[ConsumerRecord]]:
        """Fetch up to ``max_records`` records, waiting at most ``timeout_ms``."""
        ...

    async def commit(self) -> None:
        """Commit consumed offsets (awaited after each processed batch)."""
        ...


class Producer(Protocol):
    """Blocking producer interface used by the sync orchestrator/executor."""

    def send(
        self,
        topic: str,
        *,
        key: Optional[bytes] = None,
        value: Optional[bytes] = None,
    ) -> concurrent.futures.Future:
        """Queue a record for ``topic``; callers wait on the returned future."""
        ...


class AsyncProducer(Protocol):
    """Awaitable producer interface used by the async variants."""

    async def send(
        self,
        topic: str,
        *,
        key: Optional[bytes] = None,
        value: Optional[bytes] = None,
    ) -> asyncio.Future:
        """Queue a record for ``topic``.

        Awaiting this call yields a future that callers await again to wait
        for delivery.
        """
        ...
``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/retry.py`: ```py import asyncio import logging import random import time from typing import Awaitable, Callable, Optional from typing_extensions import ParamSpec from langgraph.types import RetryPolicy logger = logging.getLogger(__name__) P = ParamSpec("P") def retry( retry_policy: Optional[RetryPolicy], func: Callable[P, None], *args: P.args, **kwargs: P.kwargs, ) -> None: """Run a task asynchronously with retries.""" interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 while True: try: func(*args, **kwargs) # if successful, end break except Exception as exc: if retry_policy is None: raise # increment attempts attempts += 1 # check if we should retry if callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): raise elif not isinstance(exc, retry_policy.retry_on): raise # check if we should give up if attempts >= retry_policy.max_attempts: raise # sleep before retrying interval = min( retry_policy.max_interval, interval * retry_policy.backoff_factor, ) time.sleep( interval + random.uniform(0, 1) if retry_policy.jitter else interval ) # log the retry logger.info( f"Retrying function {func} with {args} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) async def aretry( retry_policy: Optional[RetryPolicy], func: Callable[P, Awaitable[None]], *args: P.args, **kwargs: P.kwargs, ) -> None: """Run a task asynchronously with retries.""" interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 while True: try: await func(*args, **kwargs) # if successful, end break except Exception as exc: if retry_policy is None: raise # increment attempts attempts += 1 # check if we should retry if callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): raise elif not isinstance(exc, retry_policy.retry_on): raise # check if we should give up if attempts >= 
retry_policy.max_attempts: raise # sleep before retrying interval = min( retry_policy.max_interval, interval * retry_policy.backoff_factor, ) await asyncio.sleep( interval + random.uniform(0, 1) if retry_policy.jitter else interval ) # log the retry logger.info( f"Retrying function {func} with {args} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py`: ```py import asyncio import concurrent.futures from contextlib import ( AbstractAsyncContextManager, AbstractContextManager, AsyncExitStack, ExitStack, ) from typing import Any, Optional from langchain_core.runnables import ensure_config from typing_extensions import Self import langgraph.scheduler.kafka.serde as serde from langgraph.constants import ( CONFIG_KEY_DEDUPE_TASKS, CONFIG_KEY_ENSURE_LATEST, INTERRUPT, NS_END, NS_SEP, SCHEDULED, ) from langgraph.errors import CheckpointNotLatest, GraphInterrupt from langgraph.pregel import Pregel from langgraph.pregel.executor import BackgroundExecutor, Submit from langgraph.pregel.loop import AsyncPregelLoop, SyncPregelLoop from langgraph.scheduler.kafka.retry import aretry, retry from langgraph.scheduler.kafka.types import ( AsyncConsumer, AsyncProducer, Consumer, ErrorMessage, ExecutorTask, MessageToExecutor, MessageToOrchestrator, Producer, Topics, ) from langgraph.types import RetryPolicy from langgraph.utils.config import patch_configurable class AsyncKafkaOrchestrator(AbstractAsyncContextManager): consumer: AsyncConsumer producer: AsyncProducer def __init__( self, graph: Pregel, topics: Topics, batch_max_n: int = 10, batch_max_ms: int = 1000, retry_policy: Optional[RetryPolicy] = None, consumer: Optional[AsyncConsumer] = None, producer: Optional[AsyncProducer] = None, **kwargs: Any, ) -> None: self.graph = graph self.topics = topics self.stack = AsyncExitStack() self.kwargs = kwargs self.consumer = 
consumer self.producer = producer self.batch_max_n = batch_max_n self.batch_max_ms = batch_max_ms self.retry_policy = retry_policy async def __aenter__(self) -> Self: loop = asyncio.get_running_loop() self.subgraphs = { k: v async for k, v in self.graph.aget_subgraphs(recurse=True) } if self.consumer is None: from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer self.consumer = await self.stack.enter_async_context( DefaultAsyncConsumer( self.topics.orchestrator, auto_offset_reset="earliest", group_id="orchestrator", enable_auto_commit=False, loop=loop, **self.kwargs, ) ) if self.producer is None: from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer self.producer = await self.stack.enter_async_context( DefaultAsyncProducer( loop=loop, **self.kwargs, ) ) return self async def __aexit__(self, *args: Any) -> None: return await self.stack.__aexit__(*args) def __aiter__(self) -> Self: return self async def __anext__(self) -> list[MessageToOrchestrator]: # wait for next batch recs = await self.consumer.getmany( timeout_ms=self.batch_max_ms, max_records=self.batch_max_n ) # dedupe messages, eg. 
if multiple nodes finish around same time uniq = set(msg.value for msgs in recs.values() for msg in msgs) msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq] # process batch await asyncio.gather(*(self.each(msg) for msg in msgs)) # commit offsets await self.consumer.commit() # return message return msgs async def each(self, msg: MessageToOrchestrator) -> None: try: await aretry(self.retry_policy, self.attempt, msg) except CheckpointNotLatest: pass except GraphInterrupt: pass except Exception as exc: fut = await self.producer.send( self.topics.error, value=serde.dumps( ErrorMessage( topic=self.topics.orchestrator, msg=msg, error=repr(exc), ) ), ) await fut async def attempt(self, msg: MessageToOrchestrator) -> None: # find graph if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"): # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name if recast_checkpoint_ns in self.subgraphs: graph = self.subgraphs[recast_checkpoint_ns] else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") else: graph = self.graph # process message async with AsyncPregelLoop( msg["input"], config=ensure_config(msg["config"]), stream=None, store=self.graph.store, checkpointer=self.graph.checkpointer, nodes=graph.nodes, specs=graph.channels, output_keys=graph.output_channels, stream_keys=graph.stream_channels, interrupt_after=graph.interrupt_after_nodes, interrupt_before=graph.interrupt_before_nodes, check_subgraphs=False, ) as loop: if loop.tick(input_keys=graph.input_channels): # wait for checkpoint to be saved if hasattr(loop, "_put_checkpoint_fut"): await loop._put_checkpoint_fut # schedule any new tasks if new_tasks := [ t for t in loop.tasks.values() if not t.scheduled and not t.writes ]: # send messages to executor futures = await asyncio.gather( *( self.producer.send( self.topics.executor, value=serde.dumps( 
MessageToExecutor( config=patch_configurable( loop.config, { **loop.checkpoint_config[ "configurable" ], CONFIG_KEY_DEDUPE_TASKS: True, CONFIG_KEY_ENSURE_LATEST: True, }, ), task=ExecutorTask(id=task.id, path=task.path), finally_send=msg.get("finally_send"), ) ), ) for task in new_tasks ) ) # wait for messages to be sent await asyncio.gather(*futures) # mark as scheduled for task in new_tasks: loop.put_writes( task.id, [ ( SCHEDULED, max( loop.checkpoint["versions_seen"] .get(INTERRUPT, {}) .values(), default=None, ), ) ], ) elif loop.status == "done" and msg.get("finally_send"): # send any finally_send messages futs = await asyncio.gather( *( self.producer.send( m["topic"], value=serde.dumps(m["value"]) if m.get("value") else None, key=serde.dumps(m["key"]) if m.get("key") else None, ) for m in msg["finally_send"] ) ) # wait for messages to be sent await asyncio.gather(*futs) class KafkaOrchestrator(AbstractContextManager): consumer: Consumer producer: Producer submit: Submit def __init__( self, graph: Pregel, topics: Topics, batch_max_n: int = 10, batch_max_ms: int = 1000, retry_policy: Optional[RetryPolicy] = None, consumer: Optional[Consumer] = None, producer: Optional[Producer] = None, **kwargs: Any, ) -> None: self.graph = graph self.topics = topics self.stack = ExitStack() self.kwargs = kwargs self.consumer = consumer self.producer = producer self.batch_max_n = batch_max_n self.batch_max_ms = batch_max_ms self.retry_policy = retry_policy def __enter__(self) -> Self: self.subgraphs = dict(self.graph.get_subgraphs(recurse=True)) self.submit = self.stack.enter_context(BackgroundExecutor({})) if self.consumer is None: from langgraph.scheduler.kafka.default_sync import DefaultConsumer self.consumer = self.stack.enter_context( DefaultConsumer( self.topics.orchestrator, auto_offset_reset="earliest", group_id="orchestrator", enable_auto_commit=False, **self.kwargs, ) ) if self.producer is None: from langgraph.scheduler.kafka.default_sync import DefaultProducer 
self.producer = self.stack.enter_context( DefaultProducer( **self.kwargs, ) ) return self def __exit__(self, *args: Any) -> None: return self.stack.__exit__(*args) def __iter__(self) -> Self: return self def __next__(self) -> list[MessageToOrchestrator]: # wait for next batch recs = self.consumer.getmany( timeout_ms=self.batch_max_ms, max_records=self.batch_max_n ) # dedupe messages, eg. if multiple nodes finish around same time uniq = set(msg.value for msgs in recs.values() for msg in msgs) msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq] # process batch concurrent.futures.wait(self.submit(self.each, msg) for msg in msgs) # commit offsets self.consumer.commit() # return message return msgs def each(self, msg: MessageToOrchestrator) -> None: try: retry(self.retry_policy, self.attempt, msg) except CheckpointNotLatest: pass except GraphInterrupt: pass except Exception as exc: fut = self.producer.send( self.topics.error, value=serde.dumps( ErrorMessage( topic=self.topics.orchestrator, msg=msg, error=repr(exc), ) ), ) fut.result() def attempt(self, msg: MessageToOrchestrator) -> None: # find graph if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"): # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name if recast_checkpoint_ns in self.subgraphs: graph = self.subgraphs[recast_checkpoint_ns] else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") else: graph = self.graph # process message with SyncPregelLoop( msg["input"], config=ensure_config(msg["config"]), stream=None, store=self.graph.store, checkpointer=self.graph.checkpointer, nodes=graph.nodes, specs=graph.channels, output_keys=graph.output_channels, stream_keys=graph.stream_channels, interrupt_after=graph.interrupt_after_nodes, interrupt_before=graph.interrupt_before_nodes, check_subgraphs=False, ) as loop: if 
loop.tick(input_keys=graph.input_channels): # wait for checkpoint to be saved if hasattr(loop, "_put_checkpoint_fut"): loop._put_checkpoint_fut.result() # schedule any new tasks if new_tasks := [ t for t in loop.tasks.values() if not t.scheduled and not t.writes ]: # send messages to executor futures = [ self.producer.send( self.topics.executor, value=serde.dumps( MessageToExecutor( config=patch_configurable( loop.config, { **loop.checkpoint_config["configurable"], CONFIG_KEY_DEDUPE_TASKS: True, CONFIG_KEY_ENSURE_LATEST: True, }, ), task=ExecutorTask(id=task.id, path=task.path), finally_send=msg.get("finally_send"), ) ), ) for task in new_tasks ] # wait for messages to be sent concurrent.futures.wait(futures) # mark as scheduled for task in new_tasks: loop.put_writes( task.id, [ ( SCHEDULED, max( loop.checkpoint["versions_seen"] .get(INTERRUPT, {}) .values(), default=None, ), ) ], ) elif loop.status == "done" and msg.get("finally_send"): # schedule any finally_send msgs futs = [ self.producer.send( m["topic"], value=serde.dumps(m["value"]) if m.get("value") else None, key=serde.dumps(m["key"]) if m.get("key") else None, ) for m in msg["finally_send"] ] # wait for messages to be sent concurrent.futures.wait(futs) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py`: ```py import asyncio import concurrent.futures from contextlib import ( AbstractAsyncContextManager, AbstractContextManager, AsyncExitStack, ExitStack, ) from functools import partial from typing import Any, Optional, Sequence from uuid import UUID import orjson from langchain_core.runnables import RunnableConfig from typing_extensions import Self import langgraph.scheduler.kafka.serde as serde from langgraph.constants import CONFIG_KEY_DELEGATE, ERROR, NS_END, NS_SEP from langgraph.errors import CheckpointNotLatest, GraphDelegate, TaskNotFound from langgraph.pregel import Pregel from langgraph.pregel.algo import prepare_single_task from 
langgraph.pregel.executor import ( AsyncBackgroundExecutor, BackgroundExecutor, Submit, ) from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.runner import PregelRunner from langgraph.scheduler.kafka.retry import aretry, retry from langgraph.scheduler.kafka.types import ( AsyncConsumer, AsyncProducer, Consumer, ErrorMessage, MessageToExecutor, MessageToOrchestrator, Producer, Sendable, Topics, ) from langgraph.types import LoopProtocol, PregelExecutableTask, RetryPolicy from langgraph.utils.config import patch_configurable class AsyncKafkaExecutor(AbstractAsyncContextManager): consumer: AsyncConsumer producer: AsyncProducer def __init__( self, graph: Pregel, topics: Topics, *, batch_max_n: int = 10, batch_max_ms: int = 1000, retry_policy: Optional[RetryPolicy] = None, consumer: Optional[AsyncConsumer] = None, producer: Optional[AsyncProducer] = None, **kwargs: Any, ) -> None: self.graph = graph self.topics = topics self.stack = AsyncExitStack() self.kwargs = kwargs self.consumer = consumer self.producer = producer self.batch_max_n = batch_max_n self.batch_max_ms = batch_max_ms self.retry_policy = retry_policy async def __aenter__(self) -> Self: loop = asyncio.get_running_loop() self.subgraphs = { k: v async for k, v in self.graph.aget_subgraphs(recurse=True) } if self.consumer is None: from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer self.consumer = await self.stack.enter_async_context( DefaultAsyncConsumer( self.topics.executor, auto_offset_reset="earliest", group_id="executor", enable_auto_commit=False, loop=loop, **self.kwargs, ) ) if self.producer is None: from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer self.producer = await self.stack.enter_async_context( DefaultAsyncProducer( loop=loop, **self.kwargs, ) ) return self async def __aexit__(self, *args: Any) -> None: return await self.stack.__aexit__(*args) def __aiter__(self) -> Self: return self async def 
__anext__(self) -> Sequence[MessageToExecutor]: # wait for next batch recs = await self.consumer.getmany( timeout_ms=self.batch_max_ms, max_records=self.batch_max_n ) msgs: list[MessageToExecutor] = [ serde.loads(msg.value) for msgs in recs.values() for msg in msgs ] # process batch await asyncio.gather(*(self.each(msg) for msg in msgs)) # commit offsets await self.consumer.commit() # return message return msgs async def each(self, msg: MessageToExecutor) -> None: try: await aretry(self.retry_policy, self.attempt, msg) except CheckpointNotLatest: pass except GraphDelegate as exc: for arg in exc.args: fut = await self.producer.send( self.topics.orchestrator, value=serde.dumps( MessageToOrchestrator( config=arg["config"], input=orjson.Fragment( self.graph.checkpointer.serde.dumps(arg["input"]) ), finally_send=[ Sendable(topic=self.topics.executor, value=msg) ], ) ), # use thread_id, checkpoint_ns as partition key key=serde.dumps( ( arg["config"]["configurable"]["thread_id"], arg["config"]["configurable"].get("checkpoint_ns"), ) ), ) await fut except Exception as exc: fut = await self.producer.send( self.topics.error, value=serde.dumps( ErrorMessage( topic=self.topics.executor, msg=msg, error=repr(exc), ) ), ) await fut async def attempt(self, msg: MessageToExecutor) -> None: # find graph if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"): # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name if recast_checkpoint_ns in self.subgraphs: graph = self.subgraphs[recast_checkpoint_ns] else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") else: graph = self.graph # process message saved = await self.graph.checkpointer.aget_tuple( patch_configurable(msg["config"], {"checkpoint_id": None}) ) if saved is None: raise RuntimeError("Checkpoint not found") if saved.checkpoint["id"] != 
msg["config"]["configurable"]["checkpoint_id"]: raise CheckpointNotLatest() async with AsyncChannelsManager( graph.channels, saved.checkpoint, LoopProtocol( config=msg["config"], store=self.graph.store, step=saved.metadata["step"] + 1, stop=saved.metadata["step"] + 2, ), ) as (channels, managed), AsyncBackgroundExecutor(msg["config"]) as submit: if task := await asyncio.to_thread( prepare_single_task, msg["task"]["path"], msg["task"]["id"], checkpoint=saved.checkpoint, pending_writes=saved.pending_writes or [], processes=graph.nodes, channels=channels, managed=managed, config=patch_configurable(msg["config"], {CONFIG_KEY_DELEGATE: True}), step=saved.metadata["step"] + 1, for_execution=True, checkpointer=self.graph.checkpointer, store=self.graph.store, ): # execute task, saving writes runner = PregelRunner( submit=submit, put_writes=partial(self._put_writes, submit, msg["config"]), schedule_task=self._schedule_task, ) async for _ in runner.atick([task], reraise=False): pass else: # task was not found await self.graph.checkpointer.aput_writes( msg["config"], [(ERROR, TaskNotFound())], str(UUID(int=0)) ) # notify orchestrator fut = await self.producer.send( self.topics.orchestrator, value=serde.dumps( MessageToOrchestrator( input=None, config=msg["config"], finally_send=msg.get("finally_send"), ) ), # use thread_id, checkpoint_ns as partition key key=serde.dumps( ( msg["config"]["configurable"]["thread_id"], msg["config"]["configurable"].get("checkpoint_ns"), ) ), ) await fut def _schedule_task( self, task: PregelExecutableTask, idx: int, ) -> None: # will be scheduled by orchestrator when executor finishes pass def _put_writes( self, submit: Submit, config: RunnableConfig, task_id: str, writes: list[tuple[str, Any]], ) -> None: return submit(self.graph.checkpointer.aput_writes, config, writes, task_id) class KafkaExecutor(AbstractContextManager): consumer: Consumer producer: Producer def __init__( self, graph: Pregel, topics: Topics, *, batch_max_n: int = 10, 
batch_max_ms: int = 1000, retry_policy: Optional[RetryPolicy] = None, consumer: Optional[Consumer] = None, producer: Optional[Producer] = None, **kwargs: Any, ) -> None: self.graph = graph self.topics = topics self.stack = ExitStack() self.kwargs = kwargs self.consumer = consumer self.producer = producer self.batch_max_n = batch_max_n self.batch_max_ms = batch_max_ms self.retry_policy = retry_policy def __enter__(self) -> Self: self.subgraphs = dict(self.graph.get_subgraphs(recurse=True)) self.submit = self.stack.enter_context(BackgroundExecutor({})) if self.consumer is None: from langgraph.scheduler.kafka.default_sync import DefaultConsumer self.consumer = self.stack.enter_context( DefaultConsumer( self.topics.executor, auto_offset_reset="earliest", group_id="executor", enable_auto_commit=False, **self.kwargs, ) ) if self.producer is None: from langgraph.scheduler.kafka.default_sync import DefaultProducer self.producer = self.stack.enter_context( DefaultProducer( **self.kwargs, ) ) return self def __exit__(self, *args: Any) -> None: return self.stack.__exit__(*args) def __iter__(self) -> Self: return self def __next__(self) -> Sequence[MessageToExecutor]: # wait for next batch recs = self.consumer.getmany( timeout_ms=self.batch_max_ms, max_records=self.batch_max_n ) msgs: list[MessageToExecutor] = [ serde.loads(msg.value) for msgs in recs.values() for msg in msgs ] # process batch concurrent.futures.wait(self.submit(self.each, msg) for msg in msgs) # commit offsets self.consumer.commit() # return message return msgs def each(self, msg: MessageToExecutor) -> None: try: retry(self.retry_policy, self.attempt, msg) except CheckpointNotLatest: pass except GraphDelegate as exc: for arg in exc.args: fut = self.producer.send( self.topics.orchestrator, value=serde.dumps( MessageToOrchestrator( config=arg["config"], input=orjson.Fragment( self.graph.checkpointer.serde.dumps(arg["input"]) ), finally_send=[ Sendable(topic=self.topics.executor, value=msg) ], ) ), # use 
thread_id, checkpoint_ns as partition key key=serde.dumps( ( arg["config"]["configurable"]["thread_id"], arg["config"]["configurable"].get("checkpoint_ns"), ) ), ) fut.result() except Exception as exc: fut = self.producer.send( self.topics.error, value=serde.dumps( ErrorMessage( topic=self.topics.executor, msg=msg, error=repr(exc), ) ), ) fut.result() def attempt(self, msg: MessageToExecutor) -> None: # find graph if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"): # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) ) # find the subgraph with the matching name if recast_checkpoint_ns in self.subgraphs: graph = self.subgraphs[recast_checkpoint_ns] else: raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") else: graph = self.graph # process message saved = self.graph.checkpointer.get_tuple( patch_configurable(msg["config"], {"checkpoint_id": None}) ) if saved is None: raise RuntimeError("Checkpoint not found") if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]: raise CheckpointNotLatest() with ChannelsManager( graph.channels, saved.checkpoint, LoopProtocol( config=msg["config"], store=self.graph.store, step=saved.metadata["step"] + 1, stop=saved.metadata["step"] + 2, ), ) as (channels, managed), BackgroundExecutor({}) as submit: if task := prepare_single_task( msg["task"]["path"], msg["task"]["id"], checkpoint=saved.checkpoint, pending_writes=saved.pending_writes or [], processes=graph.nodes, channels=channels, managed=managed, config=patch_configurable(msg["config"], {CONFIG_KEY_DELEGATE: True}), step=saved.metadata["step"] + 1, for_execution=True, checkpointer=self.graph.checkpointer, ): # execute task, saving writes runner = PregelRunner( submit=submit, put_writes=partial(self._put_writes, submit, msg["config"]), schedule_task=self._schedule_task, ) for _ in runner.tick([task], reraise=False): pass else: # task was not found 
self.graph.checkpointer.put_writes( msg["config"], [(ERROR, TaskNotFound())], str(UUID(int=0)) ) # notify orchestrator fut = self.producer.send( self.topics.orchestrator, value=serde.dumps( MessageToOrchestrator( input=None, config=msg["config"], finally_send=msg.get("finally_send"), ) ), # use thread_id, checkpoint_ns as partition key key=serde.dumps( ( msg["config"]["configurable"]["thread_id"], msg["config"]["configurable"].get("checkpoint_ns"), ) ), ) fut.result() def _schedule_task( self, task: PregelExecutableTask, idx: int, ) -> None: # will be scheduled by orchestrator when executor finishes pass def _put_writes( self, submit: Submit, config: RunnableConfig, task_id: str, writes: list[tuple[str, Any]], ) -> None: return submit(self.graph.checkpointer.put_writes, config, writes, task_id) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-scheduler-kafka" version = "1.0.0" description = "Library with Kafka-based work scheduler." authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph" }] [tool.poetry.dependencies] python = "^3.9.0,<4.0" orjson = "^3.10.7" crc32c = "^2.7.post1" aiokafka = "^0.11.0" langgraph = "^0.2.19" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" pytest-mock = "^3.11.1" pytest-watcher = "^0.4.1" mypy = "^1.10.0" langgraph = {path = "../langgraph", develop = true} langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true} langgraph-checkpoint = {path = "../checkpoint", develop = true} kafka-python-ng = "^2.2.2" [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. 
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. addopts = "--strict-markers --strict-config --durations=5 -vv" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ "E", # pycodestyle "F", # Pyflakes "UP", # pyupgrade "B", # flake8-bugbear "I", # isort ] lint.ignore = ["E501", "B008", "UP007", "UP006"] [tool.pytest-watcher] now = true delay = 0.1 runner_args = ["--ff", "-v", "--tb", "short", "-s"] patterns = ["*.py"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/conftest.py`: ```py from typing import AsyncIterator, Iterator from uuid import uuid4 import kafka.admin import pytest from psycopg import AsyncConnection, Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool from langgraph.checkpoint.postgres import PostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.scheduler.kafka.types import Topics DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5443/" @pytest.fixture def anyio_backend(): return "asyncio" @pytest.fixture def topics() -> Iterator[Topics]: o = f"test_o_{uuid4().hex[:16]}" e = f"test_e_{uuid4().hex[:16]}" z = f"test_z_{uuid4().hex[:16]}" admin = kafka.admin.KafkaAdminClient() # create topics admin.create_topics( [ kafka.admin.NewTopic(name=o, num_partitions=1, replication_factor=1), kafka.admin.NewTopic(name=e, num_partitions=1, replication_factor=1), kafka.admin.NewTopic(name=z, num_partitions=1, replication_factor=1), ] ) # yield topics yield Topics(orchestrator=o, executor=e, error=z) # delete topics admin.delete_topics([o, e, z]) admin.close() @pytest.fixture async def acheckpointer() -> AsyncIterator[AsyncPostgresSaver]: database = f"test_{uuid4().hex[:16]}" # create unique db 
async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer async with AsyncConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True} ) as pool: checkpointer = AsyncPostgresSaver(pool) await checkpointer.setup() yield checkpointer finally: # drop unique db async with await AsyncConnection.connect( DEFAULT_POSTGRES_URI, autocommit=True ) as conn: await conn.execute(f"DROP DATABASE {database}") @pytest.fixture def checkpointer() -> Iterator[PostgresSaver]: database = f"test_{uuid4().hex[:16]}" # create unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"CREATE DATABASE {database}") try: # yield checkpointer with ConnectionPool( DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True} ) as pool: checkpointer = PostgresSaver(pool) checkpointer.setup() yield checkpointer finally: # drop unique db with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: conn.execute(f"DROP DATABASE {database}") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_fanout_sync.py`: ```py import operator import time from typing import ( Annotated, Sequence, TypedDict, Union, ) from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.graph.state import StateGraph from langgraph.pregel import Pregel from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.default_sync import DefaultProducer from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from tests.any import AnyDict from tests.drain import drain_topics def mk_fanout_graph( checkpointer: BaseCheckpointSaver, interrupt_before: Sequence[str] = () ) -> Pregel: # copied from test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], 
tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} def retriever_picker(data: State) -> list[str]: return ["analyzer_one", "retriever_two"] def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} def retriever_two(data: State) -> State: time.sleep(0.1) return {"docs": ["doc3", "doc4"]} def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" builder = StateGraph(State) builder.add_node("rewrite_query", rewrite_query) builder.add_node("analyzer_one", analyzer_one) builder.add_node("retriever_one", retriever_one) builder.add_node("retriever_two", retriever_two) builder.add_node("decider", decider) builder.add_node("qa", qa) builder.set_entry_point("rewrite_query") builder.add_conditional_edges("rewrite_query", retriever_picker) builder.add_edge("analyzer_one", "retriever_one") builder.add_edge(["retriever_one", "retriever_two"], "decider") builder.add_conditional_edges("decider", decider_cond) builder.set_finish_point("qa") return builder.compile(checkpointer, interrupt_before=interrupt_before) def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -> None: input = {"query": "what is weather in sf"} config = {"configurable": {"thread_id": "1"}} graph = mk_fanout_graph(checkpointer) # start a new run with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=input, config=config)), ) producer.flush() # drain topics orch_msgs, exec_msgs = drain_topics(topics, graph, debug=1) # check state 
state = graph.get_state(config) assert state.next == () assert ( state.values == graph.invoke(input, {"configurable": {"thread_id": "2"}}) == { "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", } ) # check history history = [c for c in graph.get_state_history(config)] assert len(history) == 11 # check messages assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history) for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] def test_fanout_graph_w_interrupt( topics: Topics, checkpointer: BaseCheckpointSaver ) -> None: input = {"query": "what is weather in sf"} config = {"configurable": {"thread_id": "1"}} graph = mk_fanout_graph(checkpointer, interrupt_before=["qa"]) # start a new run with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=input, config=config)), ) producer.flush() orch_msgs, exec_msgs = drain_topics(topics, graph, debug=1) # check interrupted state state = graph.get_state(config) assert state.next == ("qa",) assert ( state.values == graph.invoke(input, 
{"configurable": {"thread_id": "2"}}) == { "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is weather in sf", } ) # check history history = [c for c in graph.get_state_history(config)] assert len(history) == 10 # check messages assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. after executor sends message to resume checkpoint for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed for t in c.tasks ] # resume the thread with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=None, config=config)), ) producer.flush() orch_msgs, exec_msgs = drain_topics(topics, graph) # check final state state = graph.get_state(config) assert state.next == () assert ( state.values == graph.invoke(None, {"configurable": {"thread_id": "2"}}) == { "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is 
weather in sf", } ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_fanout.py`: ```py import asyncio import operator from typing import ( Annotated, Sequence, TypedDict, Union, ) import pytest from aiokafka import AIOKafkaProducer from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.graph.state import StateGraph from langgraph.pregel import Pregel from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from tests.any import AnyDict from tests.drain import drain_topics_async pytestmark = pytest.mark.anyio def mk_fanout_graph( checkpointer: BaseCheckpointSaver, interrupt_before: Sequence[str] = () ) -> Pregel: # copied from test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge def sorted_add( x: list[str], y: Union[list[str], list[tuple[str, str]]] ) -> list[str]: if isinstance(y[0], tuple): for rem, _ in y: x.remove(rem) y = [t[1] for t in y] return sorted(operator.add(x, y)) class State(TypedDict, total=False): query: str answer: str docs: Annotated[list[str], sorted_add] async def rewrite_query(data: State) -> State: return {"query": f'query: {data["query"]}'} async def retriever_picker(data: State) -> list[str]: return ["analyzer_one", "retriever_two"] async def analyzer_one(data: State) -> State: return {"query": f'analyzed: {data["query"]}'} async def retriever_one(data: State) -> State: return {"docs": ["doc1", "doc2"]} async def retriever_two(data: State) -> State: await asyncio.sleep(0.1) return {"docs": ["doc3", "doc4"]} async def qa(data: State) -> State: return {"answer": ",".join(data["docs"])} async def decider(data: State) -> None: return None def decider_cond(data: State) -> str: if data["query"].count("analyzed") > 1: return "qa" else: return "rewrite_query" builder = StateGraph(State) builder.add_node("rewrite_query", rewrite_query) builder.add_node("analyzer_one", analyzer_one) builder.add_node("retriever_one", retriever_one) 
builder.add_node("retriever_two", retriever_two) builder.add_node("decider", decider) builder.add_node("qa", qa) builder.set_entry_point("rewrite_query") builder.add_conditional_edges("rewrite_query", retriever_picker) builder.add_edge("analyzer_one", "retriever_one") builder.add_edge(["retriever_one", "retriever_two"], "decider") builder.add_conditional_edges("decider", decider_cond) builder.set_finish_point("qa") return builder.compile(checkpointer, interrupt_before=interrupt_before) async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None: input = {"query": "what is weather in sf"} config = {"configurable": {"thread_id": "1"}} graph = mk_fanout_graph(acheckpointer) # start a new run async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=input, config=config), ) # drain topics orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check state state = await graph.aget_state(config) assert state.next == () assert ( state.values == await graph.ainvoke(input, {"configurable": {"thread_id": "2"}}) == { "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is weather in sf", "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", } ) # check history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 11 # check messages assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history) for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, 
"configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] async def test_fanout_graph_w_interrupt( topics: Topics, acheckpointer: BaseCheckpointSaver ) -> None: input = {"query": "what is weather in sf"} config = {"configurable": {"thread_id": "1"}} graph = mk_fanout_graph(acheckpointer, interrupt_before=["qa"]) # start a new run async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=input, config=config), ) orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check interrupted state state = await graph.aget_state(config) assert state.next == ("qa",) assert ( state.values == await graph.ainvoke(input, {"configurable": {"thread_id": "2"}}) == { "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is weather in sf", } ) # check history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 10 # check messages assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. 
after executor sends message to resume checkpoint for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed for t in c.tasks ] # resume the thread async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=None, config=config), ) orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check final state state = await graph.aget_state(config) assert state.next == () assert ( state.values == await graph.ainvoke(None, {"configurable": {"thread_id": "2"}}) == { "answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4", "docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"], "query": "analyzed: query: analyzed: query: what is weather in sf", } ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_subgraph.py`: ```py from typing import Literal, cast import pytest from aiokafka import AIOKafkaProducer from langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, ToolCall from langchain_core.tools import tool from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import END, START from langgraph.graph import MessagesState from langgraph.graph.state import StateGraph from langgraph.pregel import Pregel from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from tests.any import AnyDict, AnyList from tests.drain import drain_topics_async from tests.messages 
import _AnyIdAIMessage, _AnyIdHumanMessage pytestmark = pytest.mark.anyio def mk_weather_graph(checkpointer: BaseCheckpointSaver) -> Pregel: # copied from test_weather_subgraph # setup subgraph @tool def get_weather(city: str): """Get the weather for a specific city""" return f"I'ts sunny in {city}!" weather_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ ToolCall( id="tool_call123", name="get_weather", args={"city": "San Francisco"}, ) ], ) ] ) class SubGraphState(MessagesState): city: str def model_node(state: SubGraphState): result = weather_model.invoke(state["messages"]) return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]} def weather_node(state: SubGraphState): result = get_weather.invoke({"city": state["city"]}) return {"messages": [{"role": "assistant", "content": result}]} subgraph = StateGraph(SubGraphState) subgraph.add_node(model_node) subgraph.add_node(weather_node) subgraph.add_edge(START, "model_node") subgraph.add_edge("model_node", "weather_node") subgraph.add_edge("weather_node", END) subgraph = subgraph.compile(interrupt_before=["weather_node"]) # setup main graph class RouterState(MessagesState): route: Literal["weather", "other"] router_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ ToolCall( id="tool_call123", name="router", args={"dest": "weather"}, ) ], ) ] ) def router_node(state: RouterState): system_message = "Classify the incoming query as either about weather or not." 
messages = [{"role": "system", "content": system_message}] + state["messages"] route = router_model.invoke(messages) return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]} def normal_llm_node(state: RouterState): return {"messages": [AIMessage("Hello!")]} def route_after_prediction(state: RouterState): if state["route"] == "weather": return "weather_graph" else: return "normal_llm_node" async def weather_graph(state: RouterState): return await subgraph.ainvoke(state) graph = StateGraph(RouterState) graph.add_node(router_node) graph.add_node(normal_llm_node) graph.add_node("weather_graph", weather_graph) graph.add_edge(START, "router_node") graph.add_conditional_edges("router_node", route_after_prediction) graph.add_edge("normal_llm_node", END) graph.add_edge("weather_graph", END) return graph.compile(checkpointer=checkpointer) async def test_subgraph_w_interrupt( topics: Topics, acheckpointer: BaseCheckpointSaver ) -> None: input = {"messages": [{"role": "user", "content": "what's the weather in sf"}]} config = {"configurable": {"thread_id": "1"}} graph = mk_weather_graph(acheckpointer) # start a new run async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=input, config=config), ) orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check interrupted state state = await graph.aget_state(config) assert state.next == ("weather_graph",) assert state.values == { "messages": [_AnyIdHumanMessage(content="what's the weather in sf")], "route": "weather", } # check outer history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 3 # check child history child_history = [ c async for c in graph.aget_state_history(history[0].tasks[0].state) ] assert len(child_history) == 3 # check messages assert ( orch_msgs == ( # initial message to outer graph [MessageToOrchestrator(input=input, config=config)] # outer graph messages, 
until interrupted + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. after executor sends message to resume checkpoint for _ in c.tasks ] # initial message to child graph + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": None, "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf") ], "route": "weather", }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } ] # child graph messages, until interrupted + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, 
"__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } for c in reversed(child_history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. 
after executor sends message to resume checkpoint for _ in c.tasks ] ) ) assert ( exec_msgs == ( # outer graph tasks [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] # child graph tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } for c in reversed(child_history[1:]) # the last one wasn't executed for t in c.tasks ] ) ) # resume the thread async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( 
topics.orchestrator, MessageToOrchestrator(input=None, config=config), ) orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check final state state = await graph.aget_state(config) assert state.next == () assert state.values == { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf"), _AnyIdAIMessage(content="I'ts sunny in San Francisco!"), ], "route": "weather", } # check outer history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 4 # check child history # accessing second to last checkpoint, since that's the one w/ subgraph task child_history = [ c async for c in graph.aget_state_history(history[1].tasks[0].state) ] assert len(child_history) == 4 # check messages assert ( orch_msgs == ( # initial message to outer graph [MessageToOrchestrator(input=None, config=config)] # initial message to child graph + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": None, "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": 
list(history[1].tasks[0].path), }, }, } ], } ] # child graph messages, from previous last checkpoint onwards + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, }, } ], } for c in reversed(child_history[:2]) for _ in c.tasks ] # outer graph messages, from previous last checkpoint onwards + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[:2]) for _ in c.tasks ] ) ) assert ( exec_msgs == ( # outer graph tasks [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, 
"checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history[:2]) for t in c.tasks ] # child graph tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, }, } ], } for c in reversed(child_history[:2]) for t in c.tasks ] # "finally" tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config["configurable"][ "checkpoint_id" ], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": 
list(history[1].tasks[0].path), }, } ] ) ) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/messages.py`: ```py """Redefined messages as a work-around for pydantic issue with AnyStr. The code below creates version of pydantic models that will work in unit tests with AnyStr as id field Please note that the `id` field is assigned AFTER the model is created to workaround an issue with pydantic ignoring the __eq__ method on subclassed strings. """ from typing import Any from langchain_core.messages import AIMessage, HumanMessage from tests.any import AnyStr def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" message = AIMessage(**kwargs) message.id = AnyStr() return message def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: """Create a human message with an any id field.""" message = HumanMessage(**kwargs) message.id = AnyStr() return message ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_push.py`: ```py import operator from typing import ( Annotated, Literal, Union, ) import pytest from aiokafka import AIOKafkaProducer from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import FF_SEND_V2, START from langgraph.errors import NodeInterrupt from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from langgraph.types import Command, Send from tests.any import AnyDict from tests.drain import drain_topics_async pytestmark = pytest.mark.anyio def mk_push_graph( checkpointer: BaseCheckpointSaver, ) -> CompiledStateGraph: # copied from test_send_dedupe_on_resume class InterruptOnce: ticks: int = 0 def __call__(self, state): self.ticks += 1 if self.ticks == 1: raise NodeInterrupt("Bahh") return ["|".join(("flaky", str(state)))] class Node: def __init__(self, name: str): self.name = name self.ticks = 0 self.__name__ = name 
def __call__(self, state): self.ticks += 1 update = ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return state.copy(update=update) else: return update def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("flaky", 4))), "3.1", ] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_node("flaky", InterruptOnce()) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) return builder.compile(checkpointer=checkpointer) async def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None: if not FF_SEND_V2: pytest.skip("Test requires FF_SEND_V2") input = ["0"] config = {"configurable": {"thread_id": "1"}} graph = mk_push_graph(acheckpointer) graph_compare = mk_push_graph(acheckpointer) # start a new run async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=input, config=config), ) # drain topics orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check state state = await graph.aget_state(config) assert all(not t.error for t in state.tasks) assert state.next == ("flaky",) assert ( state.values == await graph_compare.ainvoke(input, {"configurable": {"thread_id": "2"}}) == [ "0", "1", "2|Control(goto=Send(node='2', arg=3))", "2|Control(goto=Send(node='flaky', arg=4))", "2|3", ] ) # check history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 2 # check messages assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, 
"__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history) for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": _convert_path(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] # resume the thread async with AIOKafkaProducer(value_serializer=serde.dumps) as producer: await producer.send_and_wait( topics.orchestrator, MessageToOrchestrator(input=None, config=config), ) orch_msgs, exec_msgs = await drain_topics_async(topics, graph) # check final state state = await graph.aget_state(config) assert state.next == () assert ( state.values == await graph_compare.ainvoke(None, {"configurable": {"thread_id": "2"}}) == [ "0", "1", "2|Control(goto=Send(node='2', arg=3))", "2|Control(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ] ) # check history history = [c async for c in graph.aget_state_history(config)] assert len(history) == 4 # check executions # node "2" doesn't get called again, as we recover writes saved before assert graph.builder.nodes["2"].runnable.func.ticks == 3 # node "flaky" gets called again, as it was interrupted assert graph.builder.nodes["flaky"].runnable.func.ticks == 2 def _convert_path( path: tuple[Union[str, int, tuple], ...], ) -> list[Union[str, int, list]]: return list(_convert_path(p) if isinstance(p, tuple) else p for p in path) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_subgraph_sync.py`: ```py from typing import Literal, cast import pytest from 
langchain_core.language_models.fake_chat_models import ( FakeMessagesListChatModel, ) from langchain_core.messages import AIMessage, ToolCall from langchain_core.tools import tool from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import END, START from langgraph.graph import MessagesState from langgraph.graph.state import StateGraph from langgraph.pregel import Pregel from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.default_sync import DefaultProducer from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from tests.any import AnyDict, AnyList from tests.drain import drain_topics from tests.messages import _AnyIdAIMessage, _AnyIdHumanMessage pytestmark = pytest.mark.anyio def mk_weather_graph(checkpointer: BaseCheckpointSaver) -> Pregel: # copied from test_weather_subgraph # setup subgraph @tool def get_weather(city: str): """Get the weather for a specific city""" return f"I'ts sunny in {city}!" weather_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ ToolCall( id="tool_call123", name="get_weather", args={"city": "San Francisco"}, ) ], ) ] ) class SubGraphState(MessagesState): city: str def model_node(state: SubGraphState): result = weather_model.invoke(state["messages"]) return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]} def weather_node(state: SubGraphState): result = get_weather.invoke({"city": state["city"]}) return {"messages": [{"role": "assistant", "content": result}]} subgraph = StateGraph(SubGraphState) subgraph.add_node(model_node) subgraph.add_node(weather_node) subgraph.add_edge(START, "model_node") subgraph.add_edge("model_node", "weather_node") subgraph.add_edge("weather_node", END) subgraph = subgraph.compile(interrupt_before=["weather_node"]) # setup main graph class RouterState(MessagesState): route: Literal["weather", "other"] router_model = FakeMessagesListChatModel( responses=[ AIMessage( content="", tool_calls=[ 
ToolCall( id="tool_call123", name="router", args={"dest": "weather"}, ) ], ) ] ) def router_node(state: RouterState): system_message = "Classify the incoming query as either about weather or not." messages = [{"role": "system", "content": system_message}] + state["messages"] route = router_model.invoke(messages) return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]} def normal_llm_node(state: RouterState): return {"messages": [AIMessage("Hello!")]} def route_after_prediction(state: RouterState): if state["route"] == "weather": return "weather_graph" else: return "normal_llm_node" def weather_graph(state: RouterState): return subgraph.invoke(state) graph = StateGraph(RouterState) graph.add_node(router_node) graph.add_node(normal_llm_node) graph.add_node("weather_graph", weather_graph) graph.add_edge(START, "router_node") graph.add_conditional_edges("router_node", route_after_prediction) graph.add_edge("normal_llm_node", END) graph.add_edge("weather_graph", END) return graph.compile(checkpointer=checkpointer) def test_subgraph_w_interrupt( topics: Topics, checkpointer: BaseCheckpointSaver ) -> None: input = {"messages": [{"role": "user", "content": "what's the weather in sf"}]} config = {"configurable": {"thread_id": "1"}} graph = mk_weather_graph(checkpointer) # start a new run with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=input, config=config)), ) producer.flush() orch_msgs, exec_msgs = drain_topics(topics, graph) # check interrupted state state = graph.get_state(config) assert state.next == ("weather_graph",) assert state.values == { "messages": [_AnyIdHumanMessage(content="what's the weather in sf")], "route": "weather", } # check outer history history = [c for c in graph.get_state_history(config)] assert len(history) == 3 # check child history child_history = [c for c in graph.get_state_history(history[0].tasks[0].state)] assert len(child_history) == 3 # check messages assert 
( orch_msgs == ( # initial message to outer graph [MessageToOrchestrator(input=input, config=config)] # outer graph messages, until interrupted + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. after executor sends message to resume checkpoint for _ in c.tasks ] # initial message to child graph + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": None, "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf") ], "route": "weather", }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } ] # child graph 
messages, until interrupted + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_store": None, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } for c in reversed(child_history[1:]) # the last one wasn't executed # orchestrator messages appear only after tasks for that checkpoint # finish executing, ie. 
after executor sends message to resume checkpoint for _ in c.tasks ] ) ) assert ( exec_msgs == ( # outer graph tasks [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] # child graph tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_store": None, "__pregel_resuming": False, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[0].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[0] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": False, "checkpoint_id": history[0].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[0].tasks[0].id, "path": list(history[0].tasks[0].path), }, }, } ], } for c in reversed(child_history[1:]) # the last one wasn't executed for t in c.tasks ] ) ) # resume the thread with DefaultProducer() as producer: producer.send( topics.orchestrator, 
value=serde.dumps(MessageToOrchestrator(input=None, config=config)), ) producer.flush() orch_msgs, exec_msgs = drain_topics(topics, graph) # check final state state = graph.get_state(config) assert state.next == () assert state.values == { "messages": [ _AnyIdHumanMessage(content="what's the weather in sf"), _AnyIdAIMessage(content="I'ts sunny in San Francisco!"), ], "route": "weather", } # check outer history history = [c for c in graph.get_state_history(config)] assert len(history) == 4 # check child history # accessing second to last checkpoint, since that's the one w/ subgraph task child_history = [c for c in graph.get_state_history(history[1].tasks[0].state)] assert len(child_history) == 4 # check messages assert ( orch_msgs == ( # initial message to outer graph [MessageToOrchestrator(input=None, config=config)] # initial message to child graph + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_store": None, "__pregel_resuming": True, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": None, "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, }, } ], } 
] # child graph messages, from previous last checkpoint onwards + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_store": None, "__pregel_resuming": True, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, }, } ], } for c in reversed(child_history[:2]) for _ in c.tasks ] # outer graph messages, from previous last checkpoint onwards + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history[:2]) for _ in c.tasks ] ) ) assert ( exec_msgs == ( # outer graph tasks [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "checkpoint_id": 
c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": None, } for c in reversed(history[:2]) for t in c.tasks ] # child graph tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_checkpointer": None, "__pregel_delegate": False, "__pregel_read": None, "__pregel_send": None, "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {}, "__pregel_writes": AnyList(), "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_map": { "": history[1].config["configurable"]["checkpoint_id"] }, "checkpoint_ns": history[1] .tasks[0] .state["configurable"]["checkpoint_ns"], "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": list(t.path), }, "finally_send": [ { "topic": topics.executor, "value": { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config[ "configurable" ]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, }, } ], } for c in reversed(child_history[:2]) for t in c.tasks ] # "finally" tasks + [ { "config": { "callbacks": None, "configurable": { "__pregel_dedupe_tasks": True, "__pregel_ensure_latest": True, "__pregel_resuming": True, "checkpoint_id": history[1].config["configurable"][ "checkpoint_id" ], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "finally_send": None, "task": { "id": history[1].tasks[0].id, "path": list(history[1].tasks[0].path), }, } ] ) ) 
``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/any.py`: ```py import re from typing import Union class AnyStr(str): def __init__(self, prefix: Union[str, re.Pattern] = "") -> None: super().__init__() self.prefix = prefix def __eq__(self, other: object) -> bool: return isinstance(other, str) and ( other.startswith(self.prefix) if isinstance(self.prefix, str) else self.prefix.match(other) ) def __hash__(self) -> int: return hash((str(self), self.prefix)) class AnyDict(dict): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def __eq__(self, other: object) -> bool: if not self and isinstance(other, dict): return True if not isinstance(other, dict) or len(self) != len(other): return False for k, v in self.items(): if kk := next((kk for kk in other if kk == k), None): if v == other[kk]: continue else: return False else: return True class AnyList(list): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def __eq__(self, other: object) -> bool: if not self and isinstance(other, list): return True if not isinstance(other, list) or len(self) != len(other): return False for i, v in enumerate(self): if v == other[i]: continue else: return False else: return True ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/drain.py`: ```py import asyncio import threading import time from concurrent.futures import ThreadPoolExecutor from typing import Optional, TypeVar import anyio from aiokafka import AIOKafkaConsumer from typing_extensions import ParamSpec from langgraph.pregel import Pregel from langgraph.scheduler.kafka.default_sync import DefaultConsumer from langgraph.scheduler.kafka.executor import AsyncKafkaExecutor, KafkaExecutor from langgraph.scheduler.kafka.orchestrator import ( AsyncKafkaOrchestrator, KafkaOrchestrator, ) from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics C = ParamSpec("C") R = TypeVar("R") async def drain_topics_async( topics: 
Topics, graph: Pregel, *, debug: bool = False ) -> tuple[list[MessageToOrchestrator], list[MessageToOrchestrator]]: scope: Optional[anyio.CancelScope] = None orch_msgs = [] exec_msgs = [] errors = [] def done() -> bool: return ( len(orch_msgs) > 0 and any(orch_msgs) and len(exec_msgs) > 0 and any(exec_msgs) and not orch_msgs[-1] and not exec_msgs[-1] ) async def orchestrator() -> None: async with AsyncKafkaOrchestrator(graph, topics) as orch: async for msgs in orch: orch_msgs.append(msgs) if debug: print("\n---\norch", len(msgs), msgs) if done(): scope.cancel() async def executor() -> None: async with AsyncKafkaExecutor(graph, topics) as exec: async for msgs in exec: exec_msgs.append(msgs) if debug: print("\n---\nexec", len(msgs), msgs) if done(): scope.cancel() async def error_consumer() -> None: async with AIOKafkaConsumer(topics.error) as consumer: async for msg in consumer: errors.append(msg) if scope: scope.cancel() # start error consumer error_task = asyncio.create_task(error_consumer(), name="error_consumer") # run the orchestrator and executor until break_when async with anyio.create_task_group() as tg: tg.cancel_scope.deadline = anyio.current_time() + 20 scope = tg.cancel_scope tg.start_soon(orchestrator, name="orchestrator") tg.start_soon(executor, name="executor") # cancel error consumer error_task.cancel() try: await error_task except asyncio.CancelledError: pass # check no errors assert not errors, errors return [m for mm in orch_msgs for m in mm], [m for mm in exec_msgs for m in mm] def drain_topics( topics: Topics, graph: Pregel, *, debug: bool = False ) -> tuple[list[MessageToOrchestrator], list[MessageToOrchestrator]]: orch_msgs = [] exec_msgs = [] errors = [] event = threading.Event() def done() -> bool: return ( len(orch_msgs) > 0 and any(orch_msgs) and len(exec_msgs) > 0 and any(exec_msgs) and not orch_msgs[-1] and not exec_msgs[-1] ) def orchestrator() -> None: try: with KafkaOrchestrator(graph, topics) as orch: for msgs in orch: 
orch_msgs.append(msgs) if debug: print("\n---\norch", len(msgs), msgs) if done(): event.set() if event.is_set(): break except Exception as e: errors.append(e) event.set() def executor() -> None: try: with KafkaExecutor(graph, topics) as exec: for msgs in exec: exec_msgs.append(msgs) if debug: print("\n---\nexec", len(msgs), msgs) if done(): event.set() if event.is_set(): break except Exception as e: errors.append(e) event.set() def error_consumer() -> None: try: with DefaultConsumer(topics.error) as consumer: while not event.is_set(): if msg := consumer.poll(timeout_ms=100): errors.append(msg) event.set() except Exception as e: errors.append(e) event.set() with ThreadPoolExecutor() as pool: # start error consumer pool.submit(error_consumer) # run the orchestrator and executor until break_when pool.submit(orchestrator) pool.submit(executor) # timeout start = time.time() while not event.is_set(): time.sleep(0.1) if time.time() - start > 20: event.set() # check no errors assert not errors, errors return [m for mm in orch_msgs for m in mm], [m for mm in exec_msgs for m in mm] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_push_sync.py`: ```py import operator from typing import ( Annotated, Literal, Union, ) import pytest from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import FF_SEND_V2, START from langgraph.errors import NodeInterrupt from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.default_sync import DefaultProducer from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics from langgraph.types import Command, Send from tests.any import AnyDict from tests.drain import drain_topics pytestmark = pytest.mark.anyio def mk_push_graph( checkpointer: BaseCheckpointSaver, ) -> CompiledStateGraph: # copied from test_send_dedupe_on_resume class InterruptOnce: ticks: int = 0 def __call__(self, state): 
self.ticks += 1 if self.ticks == 1: raise NodeInterrupt("Bahh") return ["|".join(("flaky", str(state)))] class Node: def __init__(self, name: str): self.name = name self.ticks = 0 self.__name__ = name def __call__(self, state): self.ticks += 1 update = ( [self.name] if isinstance(state, list) else ["|".join((self.name, str(state)))] ) if isinstance(state, Command): return state.copy(update=update) else: return update def send_for_fun(state): return [ Send("2", Command(goto=Send("2", 3))), Send("2", Command(goto=Send("flaky", 4))), "3.1", ] def route_to_three(state) -> Literal["3"]: return "3" builder = StateGraph(Annotated[list, operator.add]) builder.add_node(Node("1")) builder.add_node(Node("2")) builder.add_node(Node("3")) builder.add_node(Node("3.1")) builder.add_node("flaky", InterruptOnce()) builder.add_edge(START, "1") builder.add_conditional_edges("1", send_for_fun) builder.add_conditional_edges("2", route_to_three) return builder.compile(checkpointer=checkpointer) def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None: if not FF_SEND_V2: pytest.skip("Test requires FF_SEND_V2") input = ["0"] config = {"configurable": {"thread_id": "1"}} graph = mk_push_graph(acheckpointer) graph_compare = mk_push_graph(acheckpointer) # start a new run with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=input, config=config)), ) producer.flush() # drain topics orch_msgs, exec_msgs = drain_topics(topics, graph) # check state state = graph.get_state(config) assert all(not t.error for t in state.tasks) assert state.next == ("flaky",) assert ( state.values == graph_compare.invoke(input, {"configurable": {"thread_id": "2"}}) == [ "0", "1", "2|Control(goto=Send(node='2', arg=3))", "2|Control(goto=Send(node='flaky', arg=4))", "2|3", ] ) # check history history = [c for c in graph.get_state_history(config)] assert len(history) == 2 # check messages assert orch_msgs == 
[MessageToOrchestrator(input=input, config=config)] + [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "input": None, "finally_send": None, } for c in reversed(history) for _ in c.tasks ] assert exec_msgs == [ { "config": { "callbacks": None, "configurable": { "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], "checkpoint_ns": "", "thread_id": "1", }, "metadata": AnyDict(), "recursion_limit": 25, "tags": [], }, "task": { "id": t.id, "path": _convert_path(t.path), }, "finally_send": None, } for c in reversed(history) for t in c.tasks ] # resume the thread with DefaultProducer() as producer: producer.send( topics.orchestrator, value=serde.dumps(MessageToOrchestrator(input=None, config=config)), ) producer.flush() orch_msgs, exec_msgs = drain_topics(topics, graph) # check final state state = graph.get_state(config) assert state.next == () assert ( state.values == graph_compare.invoke(None, {"configurable": {"thread_id": "2"}}) == [ "0", "1", "2|Control(goto=Send(node='2', arg=3))", "2|Control(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", "3.1", ] ) # check history history = [c for c in graph.get_state_history(config)] assert len(history) == 4 # check executions # node "2" doesn't get called again, as we recover writes saved before assert graph.builder.nodes["2"].runnable.func.ticks == 3 # node "flaky" gets called again, as it was interrupted assert graph.builder.nodes["flaky"].runnable.func.ticks == 2 def _convert_path( path: tuple[Union[str, int, tuple], ...], ) -> list[Union[str, int, list]]: return list(_convert_path(p) if isinstance(p, tuple) else p for p in path) ``` 
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-sdk" version = "0.1.42" description = "SDK for interacting with LangGraph API" authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph_sdk" }] [tool.poetry.dependencies] python = "^3.9.0,<4.0" httpx = ">=0.25.2" orjson = ">=3.10.1" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watch = "^4.2.0" mypy = "^1.10.0" [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. addopts = "--strict-markers --strict-config --durations=5 -vv" asyncio_mode = "auto" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ "E", # pycodestyle "F", # Pyflakes "UP", # pyupgrade "B", # flake8-bugbear "I", # isort ] lint.ignore = ["E501", "B008", "UP007", "UP006"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/client.py`: ```py """The LangGraph client implementations connect to the LangGraph API. This module provides both asynchronous (LangGraphClient) and synchronous (SyncLanggraphClient) clients to interacting with the LangGraph API's core resources such as Assistants, Threads, Runs, and Cron jobs, as well as its persistent document Store. 
""" from __future__ import annotations import asyncio import logging import os import sys from typing import ( Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union, overload, ) import httpx import orjson from httpx._types import QueryParamTypes import langgraph_sdk from langgraph_sdk.schema import ( All, Assistant, AssistantVersion, CancelAction, Checkpoint, Command, Config, Cron, DisconnectMode, GraphSchema, IfNotExists, Item, Json, ListNamespaceResponse, MultitaskStrategy, OnCompletionBehavior, OnConflictBehavior, Run, RunCreate, RunStatus, SearchItemsResponse, StreamMode, StreamPart, Subgraphs, Thread, ThreadState, ThreadStatus, ThreadUpdateStateResponse, ) from langgraph_sdk.sse import SSEDecoder, aiter_lines_raw, iter_lines_raw logger = logging.getLogger(__name__) RESERVED_HEADERS = ("x-api-key",) def _get_api_key(api_key: Optional[str] = None) -> Optional[str]: """Get the API key from the environment. Precedence: 1. explicit argument 2. LANGGRAPH_API_KEY 3. LANGSMITH_API_KEY 4. 
LANGCHAIN_API_KEY """ if api_key: return api_key for prefix in ["LANGGRAPH", "LANGSMITH", "LANGCHAIN"]: if env := os.getenv(f"{prefix}_API_KEY"): return env.strip().strip('"').strip("'") return None # type: ignore def get_headers( api_key: Optional[str], custom_headers: Optional[dict[str, str]] ) -> dict[str, str]: """Combine api_key and custom user-provided headers.""" custom_headers = custom_headers or {} for header in RESERVED_HEADERS: if header in custom_headers: raise ValueError(f"Cannot set reserved header '{header}'") headers = { "User-Agent": f"langgraph-sdk-py/{langgraph_sdk.__version__}", **custom_headers, } api_key = _get_api_key(api_key) if api_key: headers["x-api-key"] = api_key return headers def orjson_default(obj: Any) -> Any: if hasattr(obj, "model_dump") and callable(obj.model_dump): return obj.model_dump() elif hasattr(obj, "dict") and callable(obj.dict): return obj.dict() elif isinstance(obj, (set, frozenset)): return list(obj) else: raise TypeError(f"Object of type {type(obj)} is not JSON serializable") def get_client( *, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, ) -> LangGraphClient: """Get a LangGraphClient instance. Args: url: The URL of the LangGraph API. api_key: The API key. If not provided, it will be read from the environment. Precedence: 1. explicit argument 2. LANGGRAPH_API_KEY 3. LANGSMITH_API_KEY 4. LANGCHAIN_API_KEY headers: Optional custom headers Returns: LangGraphClient: The top-level client for accessing AssistantsClient, ThreadsClient, RunsClient, and CronClient. 
Example: from langgraph_sdk import get_client # get top-level LangGraphClient client = get_client(url="http://localhost:8123") # example usage: client.<model>.<method_name>() assistants = await client.assistants.get(assistant_id="some_uuid") """ transport: Optional[httpx.AsyncBaseTransport] = None if url is None: try: from langgraph_api.server import app # type: ignore url = "http://api" transport = httpx.ASGITransport(app, root_path="/noauth") except Exception: url = "http://localhost:8123" if transport is None: transport = httpx.AsyncHTTPTransport(retries=5) client = httpx.AsyncClient( base_url=url, transport=transport, timeout=httpx.Timeout(connect=5, read=300, write=300, pool=5), headers=get_headers(api_key, headers), ) return LangGraphClient(client) class LangGraphClient: """Top-level client for LangGraph API. Attributes: assistants: Manages versioned configuration for your graphs. threads: Handles (potentially) multi-turn interactions, such as conversational threads. runs: Controls individual invocations of the graph. crons: Manages scheduled operations. store: Interfaces with persistent, shared data storage. """ def __init__(self, client: httpx.AsyncClient) -> None: self.http = HttpClient(client) self.assistants = AssistantsClient(self.http) self.threads = ThreadsClient(self.http) self.runs = RunsClient(self.http) self.crons = CronClient(self.http) self.store = StoreClient(self.http) class HttpClient: """Handle async requests to the LangGraph API. Adds additional error messaging & content handling above the provided httpx client. Attributes: client (httpx.AsyncClient): Underlying HTTPX async client. 
""" def __init__(self, client: httpx.AsyncClient) -> None: self.client = client async def get(self, path: str, *, params: Optional[QueryParamTypes] = None) -> Any: """Send a GET request.""" r = await self.client.get(path, params=params) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = (await r.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return await adecode_json(r) async def post(self, path: str, *, json: Optional[dict]) -> Any: """Send a POST request.""" if json is not None: headers, content = await aencode_json(json) else: headers, content = {}, b"" r = await self.client.post(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = (await r.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return await adecode_json(r) async def put(self, path: str, *, json: dict) -> Any: """Send a PUT request.""" headers, content = await aencode_json(json) r = await self.client.put(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = (await r.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return await adecode_json(r) async def patch(self, path: str, *, json: dict) -> Any: """Send a PATCH request.""" headers, content = await aencode_json(json) r = await self.client.patch(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = (await r.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return await adecode_json(r) async def delete(self, path: str, *, json: Optional[Any] = None) -> None: """Send a DELETE request.""" r = await 
self.client.request("DELETE", path, json=json) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = (await r.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e async def stream( self, path: str, method: str, *, json: Optional[dict] = None ) -> AsyncIterator[StreamPart]: """Stream results using SSE.""" headers, content = await aencode_json(json) headers["Accept"] = "text/event-stream" headers["Cache-Control"] = "no-store" async with self.client.stream( method, path, headers=headers, content=content ) as res: # check status try: res.raise_for_status() except httpx.HTTPStatusError as e: body = (await res.aread()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e # check content type content_type = res.headers.get("content-type", "").partition(";")[0] if "text/event-stream" not in content_type: raise httpx.TransportError( "Expected response header Content-Type to contain 'text/event-stream', " f"got {content_type!r}" ) # parse SSE decoder = SSEDecoder() async for line in aiter_lines_raw(res): sse = decoder.decode(line=line.rstrip(b"\n")) if sse is not None: yield sse async def aencode_json(json: Any) -> tuple[dict[str, str], bytes]: body = await asyncio.get_running_loop().run_in_executor( None, orjson.dumps, json, orjson_default, orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS, ) content_length = str(len(body)) content_type = "application/json" headers = {"Content-Length": content_length, "Content-Type": content_type} return headers, body async def adecode_json(r: httpx.Response) -> Any: body = await r.aread() return ( await asyncio.get_running_loop().run_in_executor(None, orjson.loads, body) if body else None ) class AssistantsClient: """Client for managing assistants in LangGraph. 
This class provides methods to interact with assistants, which are versioned configurations of your graph. Example: client = get_client() assistant = await client.assistants.get("assistant_id_123") """ def __init__(self, http: HttpClient) -> None: self.http = http async def get(self, assistant_id: str) -> Assistant: """Get an assistant by ID. Args: assistant_id: The ID of the assistant to get. Returns: Assistant: Assistant Object. Example Usage: assistant = await client.assistants.get( assistant_id="my_assistant_id" ) print(assistant) ---------------------------------------------------- { 'assistant_id': 'my_assistant_id', 'graph_id': 'agent', 'created_at': '2024-06-25T17:10:33.109781+00:00', 'updated_at': '2024-06-25T17:10:33.109781+00:00', 'config': {}, 'metadata': {'created_by': 'system'} } """ # noqa: E501 return await self.http.get(f"/assistants/{assistant_id}") async def get_graph( self, assistant_id: str, *, xray: Union[int, bool] = False ) -> dict[str, list[dict[str, Any]]]: """Get the graph of an assistant by ID. Args: assistant_id: The ID of the assistant to get the graph of. xray: Include graph representation of subgraphs. If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included. Returns: Graph: The graph information for the assistant in JSON format. 
Example Usage: graph_info = await client.assistants.get_graph( assistant_id="my_assistant_id" ) print(graph_info) -------------------------------------------------------------------------------------------------------------------------- { 'nodes': [ {'id': '__start__', 'type': 'schema', 'data': '__start__'}, {'id': '__end__', 'type': 'schema', 'data': '__end__'}, {'id': 'agent','type': 'runnable','data': {'id': ['langgraph', 'utils', 'RunnableCallable'],'name': 'agent'}}, ], 'edges': [ {'source': '__start__', 'target': 'agent'}, {'source': 'agent','target': '__end__'} ] } """ # noqa: E501 return await self.http.get( f"/assistants/{assistant_id}/graph", params={"xray": xray} ) async def get_schemas(self, assistant_id: str) -> GraphSchema: """Get the schemas of an assistant by ID. Args: assistant_id: The ID of the assistant to get the schema of. Returns: GraphSchema: The graph schema for the assistant. Example Usage: schema = await client.assistants.get_schemas( assistant_id="my_assistant_id" ) print(schema) ---------------------------------------------------------------------------------------------------------------------------- { 'graph_id': 'agent', 'state_schema': { 'title': 'LangGraphInput', '$ref': '#/definitions/AgentState', 'definitions': { 'BaseMessage': { 'title': 'BaseMessage', 'description': 'Base abstract Message class. 
Messages are the inputs and outputs of ChatModels.', 'type': 'object', 'properties': { 'content': { 'title': 'Content', 'anyOf': [ {'type': 'string'}, {'type': 'array','items': {'anyOf': [{'type': 'string'}, {'type': 'object'}]}} ] }, 'additional_kwargs': { 'title': 'Additional Kwargs', 'type': 'object' }, 'response_metadata': { 'title': 'Response Metadata', 'type': 'object' }, 'type': { 'title': 'Type', 'type': 'string' }, 'name': { 'title': 'Name', 'type': 'string' }, 'id': { 'title': 'Id', 'type': 'string' } }, 'required': ['content', 'type'] }, 'AgentState': { 'title': 'AgentState', 'type': 'object', 'properties': { 'messages': { 'title': 'Messages', 'type': 'array', 'items': {'$ref': '#/definitions/BaseMessage'} } }, 'required': ['messages'] } } }, 'config_schema': { 'title': 'Configurable', 'type': 'object', 'properties': { 'model_name': { 'title': 'Model Name', 'enum': ['anthropic', 'openai'], 'type': 'string' } } } } """ # noqa: E501 return await self.http.get(f"/assistants/{assistant_id}/schemas") async def get_subgraphs( self, assistant_id: str, namespace: Optional[str] = None, recurse: bool = False ) -> Subgraphs: """Get the schemas of an assistant by ID. Args: assistant_id: The ID of the assistant to get the schema of. Returns: Subgraphs: The graph schema for the assistant. """ # noqa: E501 if namespace is not None: return await self.http.get( f"/assistants/{assistant_id}/subgraphs/{namespace}", params={"recurse": recurse}, ) else: return await self.http.get( f"/assistants/{assistant_id}/subgraphs", params={"recurse": recurse}, ) async def create( self, graph_id: Optional[str], config: Optional[Config] = None, *, metadata: Json = None, assistant_id: Optional[str] = None, if_exists: Optional[OnConflictBehavior] = None, name: Optional[str] = None, ) -> Assistant: """Create a new assistant. Useful when graph is configurable and you want to create different assistants based on different configurations. 
Args: graph_id: The ID of the graph the assistant should use. The graph ID is normally set in your langgraph.json configuration. config: Configuration to use for the graph. metadata: Metadata to add to assistant. assistant_id: Assistant ID to use, will default to a random UUID if not provided. if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood. Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing assistant). name: The name of the assistant. Defaults to 'Untitled' under the hood. Returns: Assistant: The created assistant. Example Usage: assistant = await client.assistants.create( graph_id="agent", config={"configurable": {"model_name": "openai"}}, metadata={"number":1}, assistant_id="my-assistant-id", if_exists="do_nothing", name="my_name" ) """ # noqa: E501 payload: Dict[str, Any] = { "graph_id": graph_id, } if config: payload["config"] = config if metadata: payload["metadata"] = metadata if assistant_id: payload["assistant_id"] = assistant_id if if_exists: payload["if_exists"] = if_exists if name: payload["name"] = name return await self.http.post("/assistants", json=payload) async def update( self, assistant_id: str, *, graph_id: Optional[str] = None, config: Optional[Config] = None, metadata: Json = None, name: Optional[str] = None, ) -> Assistant: """Update an assistant. Use this to point to a different graph, update the configuration, or change the metadata of an assistant. Args: assistant_id: Assistant to update. graph_id: The ID of the graph the assistant should use. The graph ID is normally set in your langgraph.json configuration. If None, assistant will keep pointing to same graph. config: Configuration to use for the graph. metadata: Metadata to merge with existing assistant metadata. name: The new name for the assistant. Returns: Assistant: The updated assistant. 
Example Usage: assistant = await client.assistants.update( assistant_id='e280dad7-8618-443f-87f1-8e41841c180f', graph_id="other-graph", config={"configurable": {"model_name": "anthropic"}}, metadata={"number":2} ) """ # noqa: E501 payload: Dict[str, Any] = {} if graph_id: payload["graph_id"] = graph_id if config: payload["config"] = config if metadata: payload["metadata"] = metadata if name: payload["name"] = name return await self.http.patch( f"/assistants/{assistant_id}", json=payload, ) async def delete( self, assistant_id: str, ) -> None: """Delete an assistant. Args: assistant_id: The assistant ID to delete. Returns: None Example Usage: await client.assistants.delete( assistant_id="my_assistant_id" ) """ # noqa: E501 await self.http.delete(f"/assistants/{assistant_id}") async def search( self, *, metadata: Json = None, graph_id: Optional[str] = None, limit: int = 10, offset: int = 0, ) -> list[Assistant]: """Search for assistants. Args: metadata: Metadata to filter by. Exact match filter for each KV pair. graph_id: The ID of the graph to filter by. The graph ID is normally set in your langgraph.json configuration. limit: The maximum number of results to return. offset: The number of results to skip. Returns: list[Assistant]: A list of assistants. Example Usage: assistants = await client.assistants.search( metadata = {"name":"my_name"}, graph_id="my_graph_id", limit=5, offset=5 ) """ payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata if graph_id: payload["graph_id"] = graph_id return await self.http.post( "/assistants/search", json=payload, ) async def get_versions( self, assistant_id: str, metadata: Json = None, limit: int = 10, offset: int = 0, ) -> list[AssistantVersion]: """List all versions of an assistant. Args: assistant_id: The assistant ID to get versions for. metadata: Metadata to filter versions by. Exact match filter for each KV pair. limit: The maximum number of versions to return. 
offset: The number of versions to skip. Returns: list[Assistant]: A list of assistants. Example Usage: assistant_versions = await client.assistants.get_versions( assistant_id="my_assistant_id" ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata return await self.http.post( f"/assistants/{assistant_id}/versions", json=payload ) async def set_latest(self, assistant_id: str, version: int) -> Assistant: """Change the version of an assistant. Args: assistant_id: The assistant ID to delete. version: The version to change to. Returns: Assistant: Assistant Object. Example Usage: new_version_assistant = await client.assistants.set_latest( assistant_id="my_assistant_id", version=3 ) """ # noqa: E501 payload: Dict[str, Any] = {"version": version} return await self.http.post(f"/assistants/{assistant_id}/latest", json=payload) class ThreadsClient: """Client for managing threads in LangGraph. A thread maintains the state of a graph across multiple interactions/invocations (aka runs). It accumulates and persists the graph's state, allowing for continuity between separate invocations of the graph. Example: client = get_client() new_thread = await client.threads.create(metadata={"user_id": "123"}) """ def __init__(self, http: HttpClient) -> None: self.http = http async def get(self, thread_id: str) -> Thread: """Get a thread by ID. Args: thread_id: The ID of the thread to get. Returns: Thread: Thread object. 
Example Usage: thread = await client.threads.get( thread_id="my_thread_id" ) print(thread) ----------------------------------------------------- { 'thread_id': 'my_thread_id', 'created_at': '2024-07-18T18:35:15.540834+00:00', 'updated_at': '2024-07-18T18:35:15.540834+00:00', 'metadata': {'graph_id': 'agent'} } """ # noqa: E501 return await self.http.get(f"/threads/{thread_id}") async def create( self, *, metadata: Json = None, thread_id: Optional[str] = None, if_exists: Optional[OnConflictBehavior] = None, ) -> Thread: """Create a new thread. Args: metadata: Metadata to add to thread. thread_id: ID of thread. If None, ID will be a randomly generated UUID. if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood. Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing thread). Returns: Thread: The created thread. Example Usage: thread = await client.threads.create( metadata={"number":1}, thread_id="my-thread-id", if_exists="raise" ) """ # noqa: E501 payload: Dict[str, Any] = {} if thread_id: payload["thread_id"] = thread_id if metadata: payload["metadata"] = metadata if if_exists: payload["if_exists"] = if_exists return await self.http.post("/threads", json=payload) async def update(self, thread_id: str, *, metadata: dict[str, Any]) -> Thread: """Update a thread. Args: thread_id: ID of thread to update. metadata: Metadata to merge with existing thread metadata. Returns: Thread: The created thread. Example Usage: thread = await client.threads.update( thread_id="my-thread-id", metadata={"number":1}, ) """ # noqa: E501 return await self.http.patch( f"/threads/{thread_id}", json={"metadata": metadata} ) async def delete(self, thread_id: str) -> None: """Delete a thread. Args: thread_id: The ID of the thread to delete. 
Returns: None Example Usage: await client.threads.delete( thread_id="my_thread_id" ) """ # noqa: E501 await self.http.delete(f"/threads/{thread_id}") async def search( self, *, metadata: Json = None, values: Json = None, status: Optional[ThreadStatus] = None, limit: int = 10, offset: int = 0, ) -> list[Thread]: """Search for threads. Args: metadata: Thread metadata to filter on. values: State values to filter on. status: Thread status to filter on. Must be one of 'idle', 'busy', 'interrupted' or 'error'. limit: Limit on number of threads to return. offset: Offset in threads table to start search from. Returns: list[Thread]: List of the threads matching the search parameters. Example Usage: threads = await client.threads.search( metadata={"number":1}, status="interrupted", limit=15, offset=5 ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata if values: payload["values"] = values if status: payload["status"] = status return await self.http.post( "/threads/search", json=payload, ) async def copy(self, thread_id: str) -> None: """Copy a thread. Args: thread_id: The ID of the thread to copy. Returns: None Example Usage: await client.threads.copy( thread_id="my_thread_id" ) """ # noqa: E501 return await self.http.post(f"/threads/{thread_id}/copy", json=None) async def get_state( self, thread_id: str, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated *, subgraphs: bool = False, ) -> ThreadState: """Get the state of a thread. Args: thread_id: The ID of the thread to get the state of. checkpoint: The checkpoint to get the state of. subgraphs: Include subgraphs states. Returns: ThreadState: the thread of the state. 
Example Usage: thread_state = await client.threads.get_state( thread_id="my_thread_id", checkpoint_id="my_checkpoint_id" ) print(thread_state) ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- { 'values': { 'messages': [ { 'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False }, { 'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None } ] }, 'next': [], 'checkpoint': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1' } 'metadata': { 'step': 1, 'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2', 'source': 'loop', 'writes': { 'agent': { 'messages': [ { 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'name': None, 'type': 'ai', 'content': "I'm doing well, thanks for asking! 
I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'example': False, 'tool_calls': [], 'usage_metadata': None, 'additional_kwargs': {}, 'response_metadata': {}, 'invalid_tool_calls': [] } ] } }, 'user_id': None, 'graph_id': 'agent', 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'created_by': 'system', 'assistant_id': 'fe096781-5601-53d2-b2f6-0d3403f7e9ca'}, 'created_at': '2024-07-25T15:35:44.184703+00:00', 'parent_config': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f' } } """ # noqa: E501 if checkpoint: return await self.http.post( f"/threads/{thread_id}/state/checkpoint", json={"checkpoint": checkpoint, "subgraphs": subgraphs}, ) elif checkpoint_id: return await self.http.get( f"/threads/{thread_id}/state/{checkpoint_id}", params={"subgraphs": subgraphs}, ) else: return await self.http.get( f"/threads/{thread_id}/state", params={"subgraphs": subgraphs}, ) async def update_state( self, thread_id: str, values: Optional[Union[dict, Sequence[dict]]], *, as_node: Optional[str] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated ) -> ThreadUpdateStateResponse: """Update the state of a thread. Args: thread_id: The ID of the thread to update. values: The values to update the state with. as_node: Update the state as if this node had just executed. checkpoint: The checkpoint to update the state of. Returns: ThreadUpdateStateResponse: Response after updating a thread's state. 
Example Usage: response = await client.threads.update_state( thread_id="my_thread_id", values={"messages":[{"role": "user", "content": "hello!"}]}, as_node="my_node", ) print(response) ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- { 'checkpoint': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1', 'checkpoint_map': {} } } """ # noqa: E501 payload: Dict[str, Any] = { "values": values, } if checkpoint_id: payload["checkpoint_id"] = checkpoint_id if checkpoint: payload["checkpoint"] = checkpoint if as_node: payload["as_node"] = as_node return await self.http.post(f"/threads/{thread_id}/state", json=payload) async def get_history( self, thread_id: str, *, limit: int = 10, before: Optional[str | Checkpoint] = None, metadata: Optional[dict] = None, checkpoint: Optional[Checkpoint] = None, ) -> list[ThreadState]: """Get the state history of a thread. Args: thread_id: The ID of the thread to get the state history for. checkpoint: Return states for this subgraph. If empty defaults to root. limit: The maximum number of states to return. before: Return states before this checkpoint. metadata: Filter states by metadata key-value pairs. Returns: list[ThreadState]: the state history of the thread. Example Usage: thread_state = await client.threads.get_history( thread_id="my_thread_id", limit=5, ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, } if before: payload["before"] = before if metadata: payload["metadata"] = metadata if checkpoint: payload["checkpoint"] = checkpoint return await self.http.post(f"/threads/{thread_id}/history", json=payload) class RunsClient: """Client for managing runs in LangGraph. A run is a single assistant invocation with optional input, config, and metadata. This client manages runs, which can be stateful (on threads) or stateless. 
Example: client = get_client() run = await client.runs.create(assistant_id="asst_123", thread_id="thread_456", input={"query": "Hello"}) """ def __init__(self, http: HttpClient) -> None: self.http = http @overload def stream( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: ... @overload def stream( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: ... 
def stream( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> AsyncIterator[StreamPart]: """Create a run and stream the results. Args: thread_id: the thread ID to assign to the thread. If None will create a stateless run. assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: A command to execute. Cannot be combined with input. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. feedback_keys: Feedback keys to assign to run. on_disconnect: The disconnect mode to use. Must be one of 'cancel' or 'continue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. 
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. Returns: AsyncIterator[StreamPart]: Asynchronous iterator of stream results. Example Usage: async for chunk in client.runs.stream( thread_id=None, assistant_id="agent", input={"messages": [{"role": "user", "content": "how are you?"}]}, stream_mode=["values","debug"], metadata={"name":"my_run"}, config={"configurable": {"model_name": "anthropic"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], feedback_keys=["my_feedback_key_1","my_feedback_key_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ): print(chunk) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ StreamPart(event='metadata', data={'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2'}) StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}]}) StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 
'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}, {'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None}]}) StreamPart(event='end', data=None) """ # noqa: E501 payload = { "input": input, "command": command, "config": config, "metadata": metadata, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "feedback_keys": feedback_keys, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, } endpoint = ( f"/threads/{thread_id}/runs/stream" if thread_id is not None else "/runs/stream" ) return self.http.stream( endpoint, "POST", json={k: v for k, v in payload.items() if v is not None} ) @overload async def create( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... 
@overload async def create( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... async def create( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, on_completion: Optional[OnCompletionBehavior] = None, after_seconds: Optional[int] = None, ) -> Run: """Create a background run. Args: thread_id: the thread ID to assign to the thread. If None will create a stateless run. assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: A command to execute. Cannot be combined with input. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. 
interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. Returns: Run: The created background run. Example Usage: background_run = await client.runs.create( thread_id="my_thread_id", assistant_id="my_assistant_id", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) print(background_run) -------------------------------------------------------------------------------- { 'run_id': 'my_run_id', 'thread_id': 'my_thread_id', 'assistant_id': 'my_assistant_id', 'created_at': '2024-07-25T15:35:42.598503+00:00', 'updated_at': '2024-07-25T15:35:42.598503+00:00', 'metadata': {}, 'status': 'pending', 'kwargs': { 'input': { 'messages': [ { 'role': 'user', 'content': 'hello!'
} ] }, 'config': { 'metadata': { 'created_by': 'system' }, 'configurable': { 'run_id': 'my_run_id', 'user_id': None, 'graph_id': 'agent', 'thread_id': 'my_thread_id', 'checkpoint_id': None, 'model_name': "openai", 'assistant_id': 'my_assistant_id' } }, 'webhook': "https://my.fake.webhook.com", 'temporary': False, 'stream_mode': ['values'], 'feedback_keys': None, 'interrupt_after': ["node_to_stop_after_1","node_to_stop_after_2"], 'interrupt_before': ["node_to_stop_before_1","node_to_stop_before_2"] }, 'multitask_strategy': 'interrupt' } """ # noqa: E501 payload = { "input": input, "command": command, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_completion": on_completion, "after_seconds": after_seconds, } payload = {k: v for k, v in payload.items() if v is not None} if thread_id: return await self.http.post(f"/threads/{thread_id}/runs", json=payload) else: return await self.http.post("/runs", json=payload) async def create_batch(self, payloads: list[RunCreate]) -> list[Run]: """Create a batch of stateless background runs.""" def filter_payload(payload: RunCreate): return {k: v for k, v in payload.items() if v is not None} payloads = [filter_payload(payload) for payload in payloads] return await self.http.post("/runs/batch", json=payloads) @overload async def wait( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: 
Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, raise_error: bool = True, ) -> Union[list[dict], dict[str, Any]]: ... @overload async def wait( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, raise_error: bool = True, ) -> Union[list[dict], dict[str, Any]]: ... async def wait( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, raise_error: bool = True, ) -> Union[list[dict], dict[str, Any]]: """Create a run, wait until it finishes and return the final state. Args: thread_id: the thread ID to create the run on. If None will create a stateless run. assistant_id: The assistant ID or graph name to run. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: A command to execute. 
Cannot be combined with input. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. on_disconnect: The disconnect mode to use. Must be one of 'cancel' or 'continue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. raise_error: When True (the default), raise an exception if the response is a dict containing an `__error__` object; when False, that response is returned as-is. Returns: Union[list[dict], dict[str, Any]]: The output of the run. Raises: Exception: If raise_error is True and the response contains an `__error__` object; the message is built from its 'error' and 'message' fields. Example Usage: final_state_of_run = await client.runs.wait( thread_id=None, assistant_id="agent", input={"messages": [{"role": "user", "content": "how are you?"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "anthropic"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) print(final_state_of_run) ------------------------------------------------------------------------------------------------------------------------------------------- { 'messages': [ { 'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'f51a862c-62fe-4866-863b-b0863e8ad78a', 'example': False }, { 'content': "I'm doing well, thanks for asking!
I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-bf1cd3c6-768f-4c16-b62d-ba6f17ad8b36', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None } ] } """ # noqa: E501 payload = { "input": input, "command": command, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, } endpoint = ( f"/threads/{thread_id}/runs/wait" if thread_id is not None else "/runs/wait" ) response = await self.http.post( endpoint, json={k: v for k, v in payload.items() if v is not None} ) if ( raise_error and isinstance(response, dict) and "__error__" in response and isinstance(response["__error__"], dict) ): raise Exception( f"{response['__error__'].get('error')}: {response['__error__'].get('message')}" ) return response async def list( self, thread_id: str, *, limit: int = 10, offset: int = 0, status: Optional[RunStatus] = None, ) -> List[Run]: """List runs. Args: thread_id: The thread ID to list runs for. limit: The maximum number of results to return. offset: The number of results to skip. status: The status of the run to filter by; omitted from the request when None. Returns: List[Run]: The runs for the thread. Example Usage: runs = await client.runs.list( thread_id="my_thread_id", limit=5, offset=5, ) """ # noqa: E501 params = { "limit": limit, "offset": offset, } if status is not None: params["status"] = status return await self.http.get(f"/threads/{thread_id}/runs", params=params) async def get(self, thread_id: str, run_id: str) -> Run: """Get a run. Args: thread_id: The ID of the thread the run belongs to. run_id: The run ID to get. Returns: Run: Run object.
Example Usage: run = await client.runs.get( thread_id="my_thread_id", run_id="my_run_id", ) """ # noqa: E501 return await self.http.get(f"/threads/{thread_id}/runs/{run_id}") async def cancel( self, thread_id: str, run_id: str, *, wait: bool = False, action: CancelAction = "interrupt", ) -> None: """Cancel a run. Args: thread_id: The ID of the thread the run belongs to. run_id: The run ID to cancel. wait: Whether to wait until the run has completed (sent as the query parameter wait=1 or wait=0). action: Action to take when cancelling the run. Possible values are `interrupt` or `rollback`. Default is `interrupt`. Returns: None Example Usage: await client.runs.cancel( thread_id="thread_id_to_cancel", run_id="run_id_to_cancel", wait=True, action="interrupt" ) """ # noqa: E501 return await self.http.post( f"/threads/{thread_id}/runs/{run_id}/cancel?wait={1 if wait else 0}&action={action}", json=None, ) async def join(self, thread_id: str, run_id: str) -> dict: """Block until a run is done. Returns the final state of the thread. Args: thread_id: The thread ID to join. run_id: The run ID to join. Returns: dict: The final state of the thread. Example Usage: result = await client.runs.join( thread_id="thread_id_to_join", run_id="run_id_to_join" ) """ # noqa: E501 return await self.http.get(f"/threads/{thread_id}/runs/{run_id}/join") def join_stream( self, thread_id: str, run_id: str, *, cancel_on_disconnect: bool = False ) -> AsyncIterator[StreamPart]: """Stream output from a run in real-time, until the run is done. Output is not buffered, so any output produced before this call will not be received here. Args: thread_id: The thread ID to join. run_id: The run ID to join. cancel_on_disconnect: Whether to cancel the run when the stream is disconnected.
Returns: AsyncIterator[StreamPart]: Asynchronous iterator over the stream parts emitted by the run. This method is not a coroutine; iterate the result with `async for` rather than awaiting it. Example Usage: async for part in client.runs.join_stream( thread_id="thread_id_to_join", run_id="run_id_to_join" ): print(part) """ # noqa: E501 return self.http.stream( f"/threads/{thread_id}/runs/{run_id}/stream", "GET", params={"cancel_on_disconnect": cancel_on_disconnect}, ) async def delete(self, thread_id: str, run_id: str) -> None: """Delete a run. Args: thread_id: The ID of the thread the run belongs to. run_id: The run ID to delete. Returns: None Example Usage: await client.runs.delete( thread_id="thread_id_to_delete", run_id="run_id_to_delete" ) """ # noqa: E501 await self.http.delete(f"/threads/{thread_id}/runs/{run_id}") class CronClient: """Client for managing recurrent runs (cron jobs) in LangGraph. A run is a single invocation of an assistant with optional input and config. This client allows scheduling recurring runs to occur automatically. Example: client = get_client() cron_job = await client.crons.create_for_thread( thread_id="thread_123", assistant_id="asst_456", schedule="0 9 * * *", input={"message": "Daily update"} ) """ def __init__(self, http_client: HttpClient) -> None: self.http = http_client async def create_for_thread( self, thread_id: str, assistant_id: str, *, schedule: str, input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[str] = None, ) -> Run: """Create a cron job for a thread. Args: thread_id: the thread ID to run the cron job on. assistant_id: The assistant ID or graph name to use for the cron job. If using graph name, will default to first assistant created from that graph. schedule: The cron schedule to execute this job on. input: The input to the graph. metadata: Metadata to assign to the cron job runs. config: The configuration for the assistant. interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. Returns: Run: The cron run. Example Usage: cron_run = await client.crons.create_for_thread( thread_id="my-thread-id", assistant_id="agent", schedule="27 15 * * *", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) """ # noqa: E501 payload = { "schedule": schedule, "input": input, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, } if multitask_strategy: payload["multitask_strategy"] = multitask_strategy payload = {k: v for k, v in payload.items() if v is not None} return await self.http.post(f"/threads/{thread_id}/runs/crons", json=payload) async def create( self, assistant_id: str, *, schedule: str, input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[str] = None, ) -> Run: """Create a cron run. Args: assistant_id: The assistant ID or graph name to use for the cron job. If using graph name, will default to first assistant created from that graph. schedule: The cron schedule to execute this job on. input: The input to the graph. metadata: Metadata to assign to the cron job runs. config: The configuration for the assistant. interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. Returns: Run: The cron run. Example Usage: cron_run = await client.crons.create( assistant_id="agent", schedule="27 15 * * *", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) """ # noqa: E501 payload = { "schedule": schedule, "input": input, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, } if multitask_strategy: payload["multitask_strategy"] = multitask_strategy payload = {k: v for k, v in payload.items() if v is not None} return await self.http.post("/runs/crons", json=payload) async def delete(self, cron_id: str) -> None: """Delete a cron. Args: cron_id: The cron ID to delete. Returns: None Example Usage: await client.crons.delete( cron_id="cron_to_delete" ) """ # noqa: E501 await self.http.delete(f"/runs/crons/{cron_id}") async def search( self, *, assistant_id: Optional[str] = None, thread_id: Optional[str] = None, limit: int = 10, offset: int = 0, ) -> list[Cron]: """Get a list of cron jobs. Args: assistant_id: The assistant ID or graph name to search for. thread_id: the thread ID to search for. limit: The maximum number of results to return. offset: The number of results to skip.
Returns: list[Cron]: The list of cron jobs returned by the search, Example Usage: cron_jobs = await client.crons.search( assistant_id="my_assistant_id", thread_id="my_thread_id", limit=5, offset=5, ) print(cron_jobs) ---------------------------------------------------------- [ { 'cron_id': '1ef3cefa-4c09-6926-96d0-3dc97fd5e39b', 'assistant_id': 'my_assistant_id', 'thread_id': 'my_thread_id', 'user_id': None, 'payload': { 'input': {'start_time': ''}, 'schedule': '4 * * * *', 'assistant_id': 'my_assistant_id' }, 'schedule': '4 * * * *', 'next_run_date': '2024-07-25T17:04:00+00:00', 'end_time': None, 'created_at': '2024-07-08T06:02:23.073257+00:00', 'updated_at': '2024-07-08T06:02:23.073257+00:00' } ] """ # noqa: E501 payload = { "assistant_id": assistant_id, "thread_id": thread_id, "limit": limit, "offset": offset, } payload = {k: v for k, v in payload.items() if v is not None} return await self.http.post("/runs/crons/search", json=payload) class StoreClient: """Client for interacting with the graph's shared storage. The Store provides a key-value storage system for persisting data across graph executions, allowing for stateful operations and data sharing across threads. Example: client = get_client() await client.store.put_item(["users", "user123"], "mem-123451342", {"name": "Alice", "score": 100}) """ def __init__(self, http: HttpClient) -> None: self.http = http async def put_item( self, namespace: Sequence[str], /, key: str, value: dict[str, Any], index: Optional[Union[Literal[False], list[str]]] = None, ) -> None: """Store or update an item. Args: namespace: A list of strings representing the namespace path. key: The unique identifier for the item within the namespace. value: A dictionary containing the item's data. index: Controls search indexing - None (use defaults), False (disable), or list of field paths to index. 
Returns: None Example Usage: await client.store.put_item( ["documents", "user123"], key="item456", value={"title": "My Document", "content": "Hello World"} ) """ for label in namespace: if "." in label: raise ValueError( f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." ) payload = {"namespace": namespace, "key": key, "value": value, "index": index} await self.http.put("/store/items", json=payload) async def get_item(self, namespace: Sequence[str], /, key: str) -> Item: """Retrieve a single item. Args: namespace: A list of strings representing the namespace path (positional-only). Labels must not contain periods ('.'), because the labels are joined with '.' when sent to the server. key: The unique identifier for the item. Returns: Item: The retrieved item. Raises: ValueError: If any namespace label contains a period. Example Usage: item = await client.store.get_item( ["documents", "user123"], key="item456", ) print(item) ---------------------------------------------------------------- { 'namespace': ['documents', 'user123'], 'key': 'item456', 'value': {'title': 'My Document', 'content': 'Hello World'}, 'created_at': '2024-07-30T12:00:00Z', 'updated_at': '2024-07-30T12:00:00Z' } """ for label in namespace: if "." in label: raise ValueError( f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." ) return await self.http.get( "/store/items", params={"namespace": ".".join(namespace), "key": key} ) async def delete_item(self, namespace: Sequence[str], /, key: str) -> None: """Delete an item. Args: namespace: A list of strings representing the namespace path (positional-only). key: The unique identifier for the item. Returns: None Example Usage: await client.store.delete_item( ["documents", "user123"], key="item456", ) """ await self.http.delete( "/store/items", json={"namespace": namespace, "key": key} ) async def search_items( self, namespace_prefix: Sequence[str], /, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, query: Optional[str] = None, ) -> SearchItemsResponse: """Search for items within a namespace prefix.
Args: namespace_prefix: List of strings representing the namespace prefix. filter: Optional dictionary of key-value pairs to filter results. limit: Maximum number of items to return (default is 10). offset: Number of items to skip before returning results (default is 0). query: Optional query for natural language search. Returns: SearchItemsResponse: The search result; as in the example below, the matching items are under the "items" key. Example Usage: items = await client.store.search_items( ["documents"], filter={"author": "John Doe"}, limit=5, offset=0 ) print(items) ---------------------------------------------------------------- { "items": [ { "namespace": ["documents", "user123"], "key": "item789", "value": { "title": "Another Document", "author": "John Doe" }, "created_at": "2024-07-30T12:00:00Z", "updated_at": "2024-07-30T12:00:00Z" }, # ... additional items ... ] } """ payload = { "namespace_prefix": namespace_prefix, "filter": filter, "limit": limit, "offset": offset, "query": query, } return await self.http.post("/store/items/search", json=_provided_vals(payload)) async def list_namespaces( self, prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, max_depth: Optional[int] = None, limit: int = 100, offset: int = 0, ) -> ListNamespaceResponse: """List namespaces with optional match conditions. Args: prefix: Optional list of strings representing the prefix to filter namespaces. suffix: Optional list of strings representing the suffix to filter namespaces. max_depth: Optional integer specifying the maximum depth of namespaces to return. limit: Maximum number of namespaces to return (default is 100). offset: Number of namespaces to skip before returning results (default is 0). Returns: ListNamespaceResponse: The namespaces matching the criteria.
Example Usage: namespaces = await client.store.list_namespaces( prefix=["documents"], max_depth=3, limit=10, offset=0 ) print(namespaces) ---------------------------------------------------------------- [ ["documents", "user123", "reports"], ["documents", "user456", "invoices"], ... ] """ payload = { "prefix": prefix, "suffix": suffix, "max_depth": max_depth, "limit": limit, "offset": offset, } return await self.http.post("/store/namespaces", json=_provided_vals(payload)) def get_sync_client( *, url: Optional[str] = None, api_key: Optional[str] = None, headers: Optional[dict[str, str]] = None, ) -> SyncLangGraphClient: """Get a synchronous LangGraphClient instance. Args: url: The URL of the LangGraph API. api_key: The API key. If not provided, it will be read from the environment. Precedence: 1. explicit argument 2. LANGGRAPH_API_KEY 3. LANGSMITH_API_KEY 4. LANGCHAIN_API_KEY headers: Optional custom headers Returns: SyncLangGraphClient: The top-level synchronous client for accessing AssistantsClient, ThreadsClient, RunsClient, and CronClient. Example: from langgraph_sdk import get_sync_client # get top-level synchronous LangGraphClient client = get_sync_client(url="http://localhost:8123") # example usage: client.<model>.<method_name>() assistant = client.assistants.get(assistant_id="some_uuid") """ if url is None: url = "http://localhost:8123" transport = httpx.HTTPTransport(retries=5) client = httpx.Client( base_url=url, transport=transport, timeout=httpx.Timeout(connect=5, read=300, write=300, pool=5), headers=get_headers(api_key, headers), ) return SyncLangGraphClient(client) class SyncLangGraphClient: """Synchronous client for interacting with the LangGraph API. This class provides synchronous access to LangGraph API endpoints for managing assistants, threads, runs, cron jobs, and data storage. 
Example: client = get_sync_client() assistant = client.assistants.get("asst_123") """ def __init__(self, client: httpx.Client) -> None: self.http = SyncHttpClient(client) self.assistants = SyncAssistantsClient(self.http) self.threads = SyncThreadsClient(self.http) self.runs = SyncRunsClient(self.http) self.crons = SyncCronClient(self.http) self.store = SyncStoreClient(self.http) class SyncHttpClient: def __init__(self, client: httpx.Client) -> None: self.client = client def get(self, path: str, *, params: Optional[QueryParamTypes] = None) -> Any: """Send a GET request.""" r = self.client.get(path, params=params) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = r.read().decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return decode_json(r) def post(self, path: str, *, json: Optional[dict]) -> Any: """Send a POST request.""" if json is not None: headers, content = encode_json(json) else: headers, content = {}, b"" r = self.client.post(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = r.read().decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return decode_json(r) def put(self, path: str, *, json: dict) -> Any: """Send a PUT request.""" headers, content = encode_json(json) r = self.client.put(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = r.read().decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return decode_json(r) def patch(self, path: str, *, json: dict) -> Any: """Send a PATCH request.""" headers, content = encode_json(json) r = self.client.patch(path, headers=headers, content=content) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = r.read().decode() if 
sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e return decode_json(r) def delete(self, path: str, *, json: Optional[Any] = None) -> None: """Send a DELETE request.""" r = self.client.request("DELETE", path, json=json) try: r.raise_for_status() except httpx.HTTPStatusError as e: body = r.read().decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e def stream( self, path: str, method: str, *, json: Optional[dict] = None ) -> Iterator[StreamPart]: """Stream the results of a request using SSE.""" headers, content = encode_json(json) with self.client.stream(method, path, headers=headers, content=content) as res: # check status try: res.raise_for_status() except httpx.HTTPStatusError as e: body = (res.read()).decode() if sys.version_info >= (3, 11): e.add_note(body) else: logger.error(f"Error from langgraph-api: {body}", exc_info=e) raise e # check content type content_type = res.headers.get("content-type", "").partition(";")[0] if "text/event-stream" not in content_type: raise httpx.TransportError( "Expected response header Content-Type to contain 'text/event-stream', " f"got {content_type!r}" ) # parse SSE decoder = SSEDecoder() for line in iter_lines_raw(res): sse = decoder.decode(line.rstrip(b"\n")) if sse is not None: yield sse def encode_json(json: Any) -> tuple[dict[str, str], bytes]: body = orjson.dumps( json, orjson_default, orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS, ) content_length = str(len(body)) content_type = "application/json" headers = {"Content-Length": content_length, "Content-Type": content_type} return headers, body def decode_json(r: httpx.Response) -> Any: body = r.read() return orjson.loads(body if body else None) class SyncAssistantsClient: """Client for managing assistants in LangGraph synchronously. 
This class provides methods to interact with assistants, which are versioned configurations of your graph. Example: client = get_client() assistant = client.assistants.get("assistant_id_123") """ def __init__(self, http: SyncHttpClient) -> None: self.http = http def get(self, assistant_id: str) -> Assistant: """Get an assistant by ID. Args: assistant_id: The ID of the assistant to get. Returns: Assistant: Assistant Object. Example Usage: assistant = client.assistants.get( assistant_id="my_assistant_id" ) print(assistant) ---------------------------------------------------- { 'assistant_id': 'my_assistant_id', 'graph_id': 'agent', 'created_at': '2024-06-25T17:10:33.109781+00:00', 'updated_at': '2024-06-25T17:10:33.109781+00:00', 'config': {}, 'metadata': {'created_by': 'system'} } """ # noqa: E501 return self.http.get(f"/assistants/{assistant_id}") def get_graph( self, assistant_id: str, *, xray: Union[int, bool] = False ) -> dict[str, list[dict[str, Any]]]: """Get the graph of an assistant by ID. Args: assistant_id: The ID of the assistant to get the graph of. xray: Include graph representation of subgraphs. If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included. Returns: Graph: The graph information for the assistant in JSON format. 
Example Usage: graph_info = client.assistants.get_graph( assistant_id="my_assistant_id" ) print(graph_info) -------------------------------------------------------------------------------------------------------------------------- { 'nodes': [ {'id': '__start__', 'type': 'schema', 'data': '__start__'}, {'id': '__end__', 'type': 'schema', 'data': '__end__'}, {'id': 'agent','type': 'runnable','data': {'id': ['langgraph', 'utils', 'RunnableCallable'],'name': 'agent'}}, ], 'edges': [ {'source': '__start__', 'target': 'agent'}, {'source': 'agent','target': '__end__'} ] } """ # noqa: E501 return self.http.get(f"/assistants/{assistant_id}/graph", params={"xray": xray}) def get_schemas(self, assistant_id: str) -> GraphSchema: """Get the schemas of an assistant by ID. Args: assistant_id: The ID of the assistant to get the schema of. Returns: GraphSchema: The graph schema for the assistant. Example Usage: schema = client.assistants.get_schemas( assistant_id="my_assistant_id" ) print(schema) ---------------------------------------------------------------------------------------------------------------------------- { 'graph_id': 'agent', 'state_schema': { 'title': 'LangGraphInput', '$ref': '#/definitions/AgentState', 'definitions': { 'BaseMessage': { 'title': 'BaseMessage', 'description': 'Base abstract Message class. 
Messages are the inputs and outputs of ChatModels.', 'type': 'object', 'properties': { 'content': { 'title': 'Content', 'anyOf': [ {'type': 'string'}, {'type': 'array','items': {'anyOf': [{'type': 'string'}, {'type': 'object'}]}} ] }, 'additional_kwargs': { 'title': 'Additional Kwargs', 'type': 'object' }, 'response_metadata': { 'title': 'Response Metadata', 'type': 'object' }, 'type': { 'title': 'Type', 'type': 'string' }, 'name': { 'title': 'Name', 'type': 'string' }, 'id': { 'title': 'Id', 'type': 'string' } }, 'required': ['content', 'type'] }, 'AgentState': { 'title': 'AgentState', 'type': 'object', 'properties': { 'messages': { 'title': 'Messages', 'type': 'array', 'items': {'$ref': '#/definitions/BaseMessage'} } }, 'required': ['messages'] } } }, 'config_schema': { 'title': 'Configurable', 'type': 'object', 'properties': { 'model_name': { 'title': 'Model Name', 'enum': ['anthropic', 'openai'], 'type': 'string' } } } } """ # noqa: E501 return self.http.get(f"/assistants/{assistant_id}/schemas") def get_subgraphs( self, assistant_id: str, namespace: Optional[str] = None, recurse: bool = False ) -> Subgraphs: """Get the schemas of an assistant by ID. Args: assistant_id: The ID of the assistant to get the schema of. Returns: Subgraphs: The graph schema for the assistant. """ # noqa: E501 if namespace is not None: return self.http.get( f"/assistants/{assistant_id}/subgraphs/{namespace}", params={"recurse": recurse}, ) else: return self.http.get( f"/assistants/{assistant_id}/subgraphs", params={"recurse": recurse}, ) def create( self, graph_id: Optional[str], config: Optional[Config] = None, *, metadata: Json = None, assistant_id: Optional[str] = None, if_exists: Optional[OnConflictBehavior] = None, name: Optional[str] = None, ) -> Assistant: """Create a new assistant. Useful when graph is configurable and you want to create different assistants based on different configurations. Args: graph_id: The ID of the graph the assistant should use. 
The graph ID is normally set in your langgraph.json configuration. config: Configuration to use for the graph. metadata: Metadata to add to assistant. assistant_id: Assistant ID to use, will default to a random UUID if not provided. if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood. Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing assistant). name: The name of the assistant. Defaults to 'Untitled' under the hood. Returns: Assistant: The created assistant. Example Usage: assistant = client.assistants.create( graph_id="agent", config={"configurable": {"model_name": "openai"}}, metadata={"number":1}, assistant_id="my-assistant-id", if_exists="do_nothing", name="my_name" ) """ # noqa: E501 payload: Dict[str, Any] = { "graph_id": graph_id, } if config: payload["config"] = config if metadata: payload["metadata"] = metadata if assistant_id: payload["assistant_id"] = assistant_id if if_exists: payload["if_exists"] = if_exists if name: payload["name"] = name return self.http.post("/assistants", json=payload) def update( self, assistant_id: str, *, graph_id: Optional[str] = None, config: Optional[Config] = None, metadata: Json = None, name: Optional[str] = None, ) -> Assistant: """Update an assistant. Use this to point to a different graph, update the configuration, or change the metadata of an assistant. Args: assistant_id: Assistant to update. graph_id: The ID of the graph the assistant should use. The graph ID is normally set in your langgraph.json configuration. If None, assistant will keep pointing to same graph. config: Configuration to use for the graph. metadata: Metadata to merge with existing assistant metadata. name: The new name for the assistant. Returns: Assistant: The updated assistant. 
Example Usage: assistant = client.assistants.update( assistant_id='e280dad7-8618-443f-87f1-8e41841c180f', graph_id="other-graph", config={"configurable": {"model_name": "anthropic"}}, metadata={"number":2} ) """ # noqa: E501 payload: Dict[str, Any] = {} if graph_id: payload["graph_id"] = graph_id if config: payload["config"] = config if metadata: payload["metadata"] = metadata if name: payload["name"] = name return self.http.patch( f"/assistants/{assistant_id}", json=payload, ) def delete( self, assistant_id: str, ) -> None: """Delete an assistant. Args: assistant_id: The assistant ID to delete. Returns: None Example Usage: client.assistants.delete( assistant_id="my_assistant_id" ) """ # noqa: E501 self.http.delete(f"/assistants/{assistant_id}") def search( self, *, metadata: Json = None, graph_id: Optional[str] = None, limit: int = 10, offset: int = 0, ) -> list[Assistant]: """Search for assistants. Args: metadata: Metadata to filter by. Exact match filter for each KV pair. graph_id: The ID of the graph to filter by. The graph ID is normally set in your langgraph.json configuration. limit: The maximum number of results to return. offset: The number of results to skip. Returns: list[Assistant]: A list of assistants. Example Usage: assistants = client.assistants.search( metadata = {"name":"my_name"}, graph_id="my_graph_id", limit=5, offset=5 ) """ payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata if graph_id: payload["graph_id"] = graph_id return self.http.post( "/assistants/search", json=payload, ) def get_versions( self, assistant_id: str, metadata: Json = None, limit: int = 10, offset: int = 0, ) -> list[AssistantVersion]: """List all versions of an assistant. Args: assistant_id: The assistant ID to get versions for. metadata: Metadata to filter versions by. Exact match filter for each KV pair. limit: The maximum number of versions to return. offset: The number of versions to skip. 
Returns: list[Assistant]: A list of assistants. Example Usage: assistant_versions = await client.assistants.get_versions( assistant_id="my_assistant_id" ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata return self.http.post(f"/assistants/{assistant_id}/versions", json=payload) def set_latest(self, assistant_id: str, version: int) -> Assistant: """Change the version of an assistant. Args: assistant_id: The assistant ID to delete. version: The version to change to. Returns: Assistant: Assistant Object. Example Usage: new_version_assistant = await client.assistants.set_latest( assistant_id="my_assistant_id", version=3 ) """ # noqa: E501 payload: Dict[str, Any] = {"version": version} return self.http.post(f"/assistants/{assistant_id}/latest", json=payload) class SyncThreadsClient: """Synchronous client for managing threads in LangGraph. This class provides methods to create, retrieve, and manage threads, which represent conversations or stateful interactions. Example: client = get_sync_client() thread = client.threads.create(metadata={"user_id": "123"}) """ def __init__(self, http: SyncHttpClient) -> None: self.http = http def get(self, thread_id: str) -> Thread: """Get a thread by ID. Args: thread_id: The ID of the thread to get. Returns: Thread: Thread object. Example Usage: thread = client.threads.get( thread_id="my_thread_id" ) print(thread) ----------------------------------------------------- { 'thread_id': 'my_thread_id', 'created_at': '2024-07-18T18:35:15.540834+00:00', 'updated_at': '2024-07-18T18:35:15.540834+00:00', 'metadata': {'graph_id': 'agent'} } """ # noqa: E501 return self.http.get(f"/threads/{thread_id}") def create( self, *, metadata: Json = None, thread_id: Optional[str] = None, if_exists: Optional[OnConflictBehavior] = None, ) -> Thread: """Create a new thread. Args: metadata: Metadata to add to thread. thread_id: ID of thread. 
If None, ID will be a randomly generated UUID. if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood. Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing thread). Returns: Thread: The created thread. Example Usage: thread = client.threads.create( metadata={"number":1}, thread_id="my-thread-id", if_exists="raise" ) """ # noqa: E501 payload: Dict[str, Any] = {} if thread_id: payload["thread_id"] = thread_id if metadata: payload["metadata"] = metadata if if_exists: payload["if_exists"] = if_exists return self.http.post("/threads", json=payload) def update(self, thread_id: str, *, metadata: dict[str, Any]) -> Thread: """Update a thread. Args: thread_id: ID of thread to update. metadata: Metadata to merge with existing thread metadata. Returns: Thread: The created thread. Example Usage: thread = client.threads.update( thread_id="my-thread-id", metadata={"number":1}, ) """ # noqa: E501 return self.http.patch(f"/threads/{thread_id}", json={"metadata": metadata}) def delete(self, thread_id: str) -> None: """Delete a thread. Args: thread_id: The ID of the thread to delete. Returns: None Example Usage: client.threads.delete( thread_id="my_thread_id" ) """ # noqa: E501 self.http.delete(f"/threads/{thread_id}") def search( self, *, metadata: Json = None, values: Json = None, status: Optional[ThreadStatus] = None, limit: int = 10, offset: int = 0, ) -> list[Thread]: """Search for threads. Args: metadata: Thread metadata to filter on. values: State values to filter on. status: Thread status to filter on. Must be one of 'idle', 'busy', 'interrupted' or 'error'. limit: Limit on number of threads to return. offset: Offset in threads table to start search from. Returns: list[Thread]: List of the threads matching the search parameters. 
Example Usage: threads = client.threads.search( metadata={"number":1}, status="interrupted", limit=15, offset=5 ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, "offset": offset, } if metadata: payload["metadata"] = metadata if values: payload["values"] = values if status: payload["status"] = status return self.http.post( "/threads/search", json=payload, ) def copy(self, thread_id: str) -> None: """Copy a thread. Args: thread_id: The ID of the thread to copy. Returns: None Example Usage: client.threads.copy( thread_id="my_thread_id" ) """ # noqa: E501 return self.http.post(f"/threads/{thread_id}/copy", json=None) def get_state( self, thread_id: str, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated *, subgraphs: bool = False, ) -> ThreadState: """Get the state of a thread. Args: thread_id: The ID of the thread to get the state of. checkpoint: The checkpoint to get the state of. subgraphs: Include subgraphs states. Returns: ThreadState: the thread of the state. Example Usage: thread_state = client.threads.get_state( thread_id="my_thread_id", checkpoint_id="my_checkpoint_id" ) print(thread_state) ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- { 'values': { 'messages': [ { 'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False }, { 'content': "I'm doing well, thanks for asking! 
I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None } ] }, 'next': [], 'checkpoint': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1' } 'metadata': { 'step': 1, 'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2', 'source': 'loop', 'writes': { 'agent': { 'messages': [ { 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'name': None, 'type': 'ai', 'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'example': False, 'tool_calls': [], 'usage_metadata': None, 'additional_kwargs': {}, 'response_metadata': {}, 'invalid_tool_calls': [] } ] } }, 'user_id': None, 'graph_id': 'agent', 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'created_by': 'system', 'assistant_id': 'fe096781-5601-53d2-b2f6-0d3403f7e9ca'}, 'created_at': '2024-07-25T15:35:44.184703+00:00', 'parent_config': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f' } } """ # noqa: E501 if checkpoint: return self.http.post( f"/threads/{thread_id}/state/checkpoint", json={"checkpoint": checkpoint, "subgraphs": subgraphs}, ) elif checkpoint_id: return self.http.get( f"/threads/{thread_id}/state/{checkpoint_id}", params={"subgraphs": subgraphs}, ) else: return self.http.get( f"/threads/{thread_id}/state", params={"subgraphs": subgraphs}, ) def update_state( self, thread_id: str, values: Optional[Union[dict, Sequence[dict]]], *, as_node: Optional[str] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, # deprecated ) -> ThreadUpdateStateResponse: """Update the state of a thread. 
Args: thread_id: The ID of the thread to update. values: The values to update the state with. as_node: Update the state as if this node had just executed. checkpoint: The checkpoint to update the state of. Returns: ThreadUpdateStateResponse: Response after updating a thread's state. Example Usage: response = client.threads.update_state( thread_id="my_thread_id", values={"messages":[{"role": "user", "content": "hello!"}]}, as_node="my_node", ) print(response) ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- { 'checkpoint': { 'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2', 'checkpoint_ns': '', 'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1', 'checkpoint_map': {} } } """ # noqa: E501 payload: Dict[str, Any] = { "values": values, } if checkpoint_id: payload["checkpoint_id"] = checkpoint_id if checkpoint: payload["checkpoint"] = checkpoint if as_node: payload["as_node"] = as_node return self.http.post(f"/threads/{thread_id}/state", json=payload) def get_history( self, thread_id: str, *, limit: int = 10, before: Optional[str | Checkpoint] = None, metadata: Optional[dict] = None, checkpoint: Optional[Checkpoint] = None, ) -> list[ThreadState]: """Get the state history of a thread. Args: thread_id: The ID of the thread to get the state history for. checkpoint: Return states for this subgraph. If empty defaults to root. limit: The maximum number of states to return. before: Return states before this checkpoint. metadata: Filter states by metadata key-value pairs. Returns: list[ThreadState]: the state history of the thread. 
Example Usage: thread_state = client.threads.get_history( thread_id="my_thread_id", limit=5, before="my_timestamp", metadata={"name":"my_name"} ) """ # noqa: E501 payload: Dict[str, Any] = { "limit": limit, } if before: payload["before"] = before if metadata: payload["metadata"] = metadata if checkpoint: payload["checkpoint"] = checkpoint return self.http.post(f"/threads/{thread_id}/history", json=payload) class SyncRunsClient: """Synchronous client for managing runs in LangGraph. This class provides methods to create, retrieve, and manage runs, which represent individual executions of graphs. Example: client = get_sync_client() run = client.runs.create(thread_id="thread_123", assistant_id="asst_456") """ def __init__(self, http: SyncHttpClient) -> None: self.http = http @overload def stream( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: ... 
@overload def stream( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, webhook: Optional[str] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: ... def stream( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, feedback_keys: Optional[Sequence[str]] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Iterator[StreamPart]: """Create a run and stream the results. Args: thread_id: the thread ID to assign to the thread. If None will create a stateless run. assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. 
metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. feedback_keys: Feedback keys to assign to run. on_disconnect: The disconnect mode to use. Must be one of 'cancel' or 'continue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. Returns: Iterator[StreamPart]: Iterator of stream results. 
Example Usage: async for chunk in client.runs.stream( thread_id=None, assistant_id="agent", input={"messages": [{"role": "user", "content": "how are you?"}]}, stream_mode=["values","debug"], metadata={"name":"my_run"}, config={"configurable": {"model_name": "anthropic"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], feedback_keys=["my_feedback_key_1","my_feedback_key_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ): print(chunk) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ StreamPart(event='metadata', data={'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2'}) StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}]}) StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}, {'content': "I'm doing well, thanks for asking! 
I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None}]}) StreamPart(event='end', data=None) """ # noqa: E501 payload = { "input": input, "command": command, "config": config, "metadata": metadata, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "feedback_keys": feedback_keys, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, } endpoint = ( f"/threads/{thread_id}/runs/stream" if thread_id is not None else "/runs/stream" ) return self.http.stream( endpoint, "POST", json={k: v for k, v in payload.items() if v is not None} ) @overload def create( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... 
@overload def create( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: ... def create( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[MultitaskStrategy] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Run: """Create a background run. Args: thread_id: the thread ID to assign to the thread. If None will create a stateless run. assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. 
interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. Returns: Run: The created background run. Example Usage: background_run = client.runs.create( thread_id="my_thread_id", assistant_id="my_assistant_id", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) print(background_run) -------------------------------------------------------------------------------- { 'run_id': 'my_run_id', 'thread_id': 'my_thread_id', 'assistant_id': 'my_assistant_id', 'created_at': '2024-07-25T15:35:42.598503+00:00', 'updated_at': '2024-07-25T15:35:42.598503+00:00', 'metadata': {}, 'status': 'pending', 'kwargs': { 'input': { 'messages': [ { 'role': 'user', 'content': 'how are you?' 
} ] }, 'config': { 'metadata': { 'created_by': 'system' }, 'configurable': { 'run_id': 'my_run_id', 'user_id': None, 'graph_id': 'agent', 'thread_id': 'my_thread_id', 'checkpoint_id': None, 'model_name': "openai", 'assistant_id': 'my_assistant_id' } }, 'webhook': "https://my.fake.webhook.com", 'temporary': False, 'stream_mode': ['values'], 'feedback_keys': None, 'interrupt_after': ["node_to_stop_after_1","node_to_stop_after_2"], 'interrupt_before': ["node_to_stop_before_1","node_to_stop_before_2"] }, 'multitask_strategy': 'interrupt' } """ # noqa: E501 payload = { "input": input, "command": command, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_completion": on_completion, "after_seconds": after_seconds, } payload = {k: v for k, v in payload.items() if v is not None} if thread_id: return self.http.post(f"/threads/{thread_id}/runs", json=payload) else: return self.http.post("/runs", json=payload) def create_batch(self, payloads: list[RunCreate]) -> list[Run]: """Create a batch of stateless background runs.""" def filter_payload(payload: RunCreate): return {k: v for k, v in payload.items() if v is not None} payloads = [filter_payload(payload) for payload in payloads] return self.http.post("/runs/batch", json=payloads) @overload def wait( self, thread_id: str, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: 
Optional[DisconnectMode] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... @overload def wait( self, thread_id: None, assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: ... def wait( self, thread_id: Optional[str], assistant_id: str, *, input: Optional[dict] = None, command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, checkpoint_id: Optional[str] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, webhook: Optional[str] = None, on_disconnect: Optional[DisconnectMode] = None, on_completion: Optional[OnCompletionBehavior] = None, multitask_strategy: Optional[MultitaskStrategy] = None, if_not_exists: Optional[IfNotExists] = None, after_seconds: Optional[int] = None, ) -> Union[list[dict], dict[str, Any]]: """Create a run, wait until it finishes and return the final state. Args: thread_id: the thread ID to create the run on. If None will create a stateless run. assistant_id: The assistant ID or graph name to run. If using graph name, will default to first assistant created from that graph. input: The input to the graph. command: The command to execute. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. 
interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. on_disconnect: The disconnect mode to use. Must be one of 'cancel' or 'continue'. on_completion: Whether to delete or keep the thread created for a stateless run. Must be one of 'delete' or 'keep'. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. if_not_exists: How to handle missing thread. Defaults to 'reject'. Must be either 'reject' (raise error if missing), or 'create' (create new thread). after_seconds: The number of seconds to wait before starting the run. Use to schedule future runs. Returns: Union[list[dict], dict[str, Any]]: The output of the run. Example Usage: final_state_of_run = client.runs.wait( thread_id=None, assistant_id="agent", input={"messages": [{"role": "user", "content": "how are you?"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "anthropic"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) print(final_state_of_run) ------------------------------------------------------------------------------------------------------------------------------------------- { 'messages': [ { 'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'f51a862c-62fe-4866-863b-b0863e8ad78a', 'example': False }, { 'content': "I'm doing well, thanks for asking! 
I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-bf1cd3c6-768f-4c16-b62d-ba6f17ad8b36', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None } ] } """ # noqa: E501 payload = { "input": input, "command": command, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, "checkpoint": checkpoint, "checkpoint_id": checkpoint_id, "multitask_strategy": multitask_strategy, "if_not_exists": if_not_exists, "on_disconnect": on_disconnect, "on_completion": on_completion, "after_seconds": after_seconds, } endpoint = ( f"/threads/{thread_id}/runs/wait" if thread_id is not None else "/runs/wait" ) return self.http.post( endpoint, json={k: v for k, v in payload.items() if v is not None} ) def list(self, thread_id: str, *, limit: int = 10, offset: int = 0) -> List[Run]: """List runs. Args: thread_id: The thread ID to list runs for. limit: The maximum number of results to return. offset: The number of results to skip. Returns: List[Run]: The runs for the thread. Example Usage: client.runs.delete( thread_id="thread_id_to_delete", limit=5, offset=5, ) """ # noqa: E501 return self.http.get(f"/threads/{thread_id}/runs?limit={limit}&offset={offset}") def get(self, thread_id: str, run_id: str) -> Run: """Get a run. Args: thread_id: The thread ID to get. run_id: The run ID to get. Returns: Run: Run object. Example Usage: run = client.runs.get( thread_id="thread_id_to_delete", run_id="run_id_to_delete", ) """ # noqa: E501 return self.http.get(f"/threads/{thread_id}/runs/{run_id}") def cancel( self, thread_id: str, run_id: str, *, wait: bool = False, action: CancelAction = "interrupt", ) -> None: """Get a run. Args: thread_id: The thread ID to cancel. run_id: The run ID to cancek. wait: Whether to wait until run has completed. 
action: Action to take when cancelling the run. Possible values are `interrupt` or `rollback`. Default is `interrupt`. Returns: None Example Usage: client.runs.cancel( thread_id="thread_id_to_cancel", run_id="run_id_to_cancel", wait=True, action="interrupt" ) """ # noqa: E501 return self.http.post( f"/threads/{thread_id}/runs/{run_id}/cancel?wait={1 if wait else 0}&action={action}", json=None, ) def join(self, thread_id: str, run_id: str) -> dict: """Block until a run is done. Returns the final state of the thread. Args: thread_id: The thread ID to join. run_id: The run ID to join. Returns: None Example Usage: client.runs.join( thread_id="thread_id_to_join", run_id="run_id_to_join" ) """ # noqa: E501 return self.http.get(f"/threads/{thread_id}/runs/{run_id}/join") def join_stream(self, thread_id: str, run_id: str) -> Iterator[StreamPart]: """Stream output from a run in real-time, until the run is done. Output is not buffered, so any output produced before this call will not be received here. Args: thread_id: The thread ID to join. run_id: The run ID to join. Returns: None Example Usage: client.runs.join_stream( thread_id="thread_id_to_join", run_id="run_id_to_join" ) """ # noqa: E501 return self.http.stream(f"/threads/{thread_id}/runs/{run_id}/stream", "GET") def delete(self, thread_id: str, run_id: str) -> None: """Delete a run. Args: thread_id: The thread ID to delete. run_id: The run ID to delete. Returns: None Example Usage: client.runs.delete( thread_id="thread_id_to_delete", run_id="run_id_to_delete" ) """ # noqa: E501 self.http.delete(f"/threads/{thread_id}/runs/{run_id}") class SyncCronClient: """Synchronous client for managing cron jobs in LangGraph. This class provides methods to create and manage scheduled tasks (cron jobs) for automated graph executions. 
Example: client = get_sync_client() cron_job = client.crons.create_for_thread(thread_id="thread_123", assistant_id="asst_456", schedule="0 * * * *") """ def __init__(self, http_client: SyncHttpClient) -> None: self.http = http_client def create_for_thread( self, thread_id: str, assistant_id: str, *, schedule: str, input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[str] = None, ) -> Run: """Create a cron job for a thread. Args: thread_id: the thread ID to run the cron job on. assistant_id: The assistant ID or graph name to use for the cron job. If using graph name, will default to first assistant created from that graph. schedule: The cron schedule to execute this job on. input: The input to the graph. metadata: Metadata to assign to the cron job runs. config: The configuration for the assistant. interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. Returns: Run: The cron run. 
Example Usage: cron_run = client.crons.create_for_thread( thread_id="my-thread-id", assistant_id="agent", schedule="27 15 * * *", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) """ # noqa: E501 payload = { "schedule": schedule, "input": input, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, } if multitask_strategy: payload["multitask_strategy"] = multitask_strategy payload = {k: v for k, v in payload.items() if v is not None} return self.http.post(f"/threads/{thread_id}/runs/crons", json=payload) def create( self, assistant_id: str, *, schedule: str, input: Optional[dict] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, list[str]]] = None, interrupt_after: Optional[Union[All, list[str]]] = None, webhook: Optional[str] = None, multitask_strategy: Optional[str] = None, ) -> Run: """Create a cron run. Args: assistant_id: The assistant ID or graph name to use for the cron job. If using graph name, will default to first assistant created from that graph. schedule: The cron schedule to execute this job on. input: The input to the graph. metadata: Metadata to assign to the cron job runs. config: The configuration for the assistant. interrupt_before: Nodes to interrupt immediately before they get executed. interrupt_after: Nodes to Nodes to interrupt immediately after they get executed. webhook: Webhook to call after LangGraph API call is done. multitask_strategy: Multitask strategy to use. Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'. Returns: Run: The cron run. 
Example Usage: cron_run = client.crons.create( assistant_id="agent", schedule="27 15 * * *", input={"messages": [{"role": "user", "content": "hello!"}]}, metadata={"name":"my_run"}, config={"configurable": {"model_name": "openai"}}, interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"], interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"], webhook="https://my.fake.webhook.com", multitask_strategy="interrupt" ) """ # noqa: E501 payload = { "schedule": schedule, "input": input, "config": config, "metadata": metadata, "assistant_id": assistant_id, "interrupt_before": interrupt_before, "interrupt_after": interrupt_after, "webhook": webhook, } if multitask_strategy: payload["multitask_strategy"] = multitask_strategy payload = {k: v for k, v in payload.items() if v is not None} return self.http.post("/runs/crons", json=payload) def delete(self, cron_id: str) -> None: """Delete a cron. Args: cron_id: The cron ID to delete. Returns: None Example Usage: client.crons.delete( cron_id="cron_to_delete" ) """ # noqa: E501 self.http.delete(f"/runs/crons/{cron_id}") def search( self, *, assistant_id: Optional[str] = None, thread_id: Optional[str] = None, limit: int = 10, offset: int = 0, ) -> list[Cron]: """Get a list of cron jobs. Args: assistant_id: The assistant ID or graph name to search for. thread_id: the thread ID to search for. limit: The maximum number of results to return. offset: The number of results to skip. 
Returns: list[Cron]: The list of cron jobs returned by the search, Example Usage: cron_jobs = client.crons.search( assistant_id="my_assistant_id", thread_id="my_thread_id", limit=5, offset=5, ) print(cron_jobs) ---------------------------------------------------------- [ { 'cron_id': '1ef3cefa-4c09-6926-96d0-3dc97fd5e39b', 'assistant_id': 'my_assistant_id', 'thread_id': 'my_thread_id', 'user_id': None, 'payload': { 'input': {'start_time': ''}, 'schedule': '4 * * * *', 'assistant_id': 'my_assistant_id' }, 'schedule': '4 * * * *', 'next_run_date': '2024-07-25T17:04:00+00:00', 'end_time': None, 'created_at': '2024-07-08T06:02:23.073257+00:00', 'updated_at': '2024-07-08T06:02:23.073257+00:00' } ] """ # noqa: E501 payload = { "assistant_id": assistant_id, "thread_id": thread_id, "limit": limit, "offset": offset, } payload = {k: v for k, v in payload.items() if v is not None} return self.http.post("/runs/crons/search", json=payload) class SyncStoreClient: """A client for synchronous operations on a key-value store. Provides methods to interact with a remote key-value store, allowing storage and retrieval of items within namespaced hierarchies. Example: client = get_sync_client() client.store.put_item(["users", "profiles"], "user123", {"name": "Alice", "age": 30}) """ def __init__(self, http: SyncHttpClient) -> None: self.http = http def put_item( self, namespace: Sequence[str], /, key: str, value: dict[str, Any], index: Optional[Union[Literal[False], list[str]]] = None, ) -> None: """Store or update an item. Args: namespace: A list of strings representing the namespace path. key: The unique identifier for the item within the namespace. value: A dictionary containing the item's data. index: Controls search indexing - None (use defaults), False (disable), or list of field paths to index. Returns: None Example Usage: client.store.put_item( ["documents", "user123"], key="item456", value={"title": "My Document", "content": "Hello World"} ) """ for label in namespace: if "." 
in label: raise ValueError( f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." ) payload = { "namespace": namespace, "key": key, "value": value, "index": index, } self.http.put("/store/items", json=payload) def get_item(self, namespace: Sequence[str], /, key: str) -> Item: """Retrieve a single item. Args: key: The unique identifier for the item. namespace: Optional list of strings representing the namespace path. Returns: Item: The retrieved item. Example Usage: item = client.store.get_item( ["documents", "user123"], key="item456", ) print(item) ---------------------------------------------------------------- { 'namespace': ['documents', 'user123'], 'key': 'item456', 'value': {'title': 'My Document', 'content': 'Hello World'}, 'created_at': '2024-07-30T12:00:00Z', 'updated_at': '2024-07-30T12:00:00Z' } """ for label in namespace: if "." in label: raise ValueError( f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')." ) return self.http.get( "/store/items", params={"key": key, "namespace": ".".join(namespace)} ) def delete_item(self, namespace: Sequence[str], /, key: str) -> None: """Delete an item. Args: key: The unique identifier for the item. namespace: Optional list of strings representing the namespace path. Returns: None Example Usage: client.store.delete_item( ["documents", "user123"], key="item456", ) """ self.http.delete("/store/items", json={"key": key, "namespace": namespace}) def search_items( self, namespace_prefix: Sequence[str], /, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, query: Optional[str] = None, ) -> SearchItemsResponse: """Search for items within a namespace prefix. Args: namespace_prefix: List of strings representing the namespace prefix. filter: Optional dictionary of key-value pairs to filter results. limit: Maximum number of items to return (default is 10). offset: Number of items to skip before returning results (default is 0). 
query: Optional query for natural language search. Returns: List[Item]: A list of items matching the search criteria. Example Usage: items = client.store.search_items( ["documents"], filter={"author": "John Doe"}, limit=5, offset=0 ) print(items) ---------------------------------------------------------------- { "items": [ { "namespace": ["documents", "user123"], "key": "item789", "value": { "title": "Another Document", "author": "John Doe" }, "created_at": "2024-07-30T12:00:00Z", "updated_at": "2024-07-30T12:00:00Z" }, # ... additional items ... ] } """ payload = { "namespace_prefix": namespace_prefix, "filter": filter, "limit": limit, "offset": offset, "query": query, } return self.http.post("/store/items/search", json=_provided_vals(payload)) def list_namespaces( self, prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, max_depth: Optional[int] = None, limit: int = 100, offset: int = 0, ) -> ListNamespaceResponse: """List namespaces with optional match conditions. Args: prefix: Optional list of strings representing the prefix to filter namespaces. suffix: Optional list of strings representing the suffix to filter namespaces. max_depth: Optional integer specifying the maximum depth of namespaces to return. limit: Maximum number of namespaces to return (default is 100). offset: Number of namespaces to skip before returning results (default is 0). Returns: List[List[str]]: A list of namespaces matching the criteria. Example Usage: namespaces = client.store.list_namespaces( prefix=["documents"], max_depth=3, limit=10, offset=0 ) print(namespaces) ---------------------------------------------------------------- [ ["documents", "user123", "reports"], ["documents", "user456", "invoices"], ... 
] """ payload = { "prefix": prefix, "suffix": suffix, "max_depth": max_depth, "limit": limit, "offset": offset, } return self.http.post("/store/namespaces", json=_provided_vals(payload)) def _provided_vals(d: dict): return {k: v for k, v in d.items() if v is not None} ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/__init__.py`: ```py from langgraph_sdk.client import get_client, get_sync_client try: from importlib import metadata __version__ = metadata.version(__package__) except metadata.PackageNotFoundError: __version__ = "unknown" __all__ = ["get_client", "get_sync_client"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/sse.py`: ```py """Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec.""" from typing import AsyncIterator, Iterator, Optional, Union import httpx import orjson from langgraph_sdk.schema import StreamPart BytesLike = Union[bytes, bytearray, memoryview] class BytesLineDecoder: """ Handles incrementally reading lines from text. Has the same behaviour as the stdllib bytes splitlines, but handling the input iteratively. """ def __init__(self) -> None: self.buffer = bytearray() self.trailing_cr: bool = False def decode(self, text: bytes) -> list[BytesLike]: # See https://docs.python.org/3/glossary.html#term-universal-newlines NEWLINE_CHARS = b"\n\r" # We always push a trailing `\r` into the next decode iteration. if self.trailing_cr: text = b"\r" + text self.trailing_cr = False if text.endswith(b"\r"): self.trailing_cr = True text = text[:-1] if not text: # NOTE: the edge case input of empty text doesn't occur in practice, # because other httpx internals filter out this value return [] # pragma: no cover trailing_newline = text[-1] in NEWLINE_CHARS lines = text.splitlines() if len(lines) == 1 and not trailing_newline: # No new lines, buffer the input and continue. 
self.buffer.extend(lines[0]) return [] if self.buffer: # Include any existing buffer in the first portion of the # splitlines result. self.buffer.extend(lines[0]) lines = [self.buffer] + lines[1:] self.buffer = bytearray() if not trailing_newline: # If the last segment of splitlines is not newline terminated, # then drop it from our output and start a new buffer. self.buffer.extend(lines.pop()) return lines def flush(self) -> list[BytesLike]: if not self.buffer and not self.trailing_cr: return [] lines = [self.buffer] self.buffer = bytearray() self.trailing_cr = False return lines class SSEDecoder: def __init__(self) -> None: self._event = "" self._data = bytearray() self._last_event_id = "" self._retry: Optional[int] = None def decode(self, line: bytes) -> Optional[StreamPart]: # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 if not line: if ( not self._event and not self._data and not self._last_event_id and self._retry is None ): return None sse = StreamPart( event=self._event, data=orjson.loads(self._data) if self._data else None, ) # NOTE: as per the SSE spec, do not reset last_event_id. self._event = "" self._data = bytearray() self._retry = None return sse if line.startswith(b":"): return None fieldname, _, value = line.partition(b":") if value.startswith(b" "): value = value[1:] if fieldname == b"event": self._event = value.decode() elif fieldname == b"data": self._data.extend(value) elif fieldname == b"id": if b"\0" in value: pass else: self._last_event_id = value.decode() elif fieldname == b"retry": try: self._retry = int(value) except (TypeError, ValueError): pass else: pass # Field is ignored. 
return None async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]: decoder = BytesLineDecoder() async for chunk in response.aiter_bytes(): for line in decoder.decode(chunk): yield line for line in decoder.flush(): yield line def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]: decoder = BytesLineDecoder() for chunk in response.iter_bytes(): for line in decoder.decode(chunk): yield line for line in decoder.flush(): yield line ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/schema.py`: ```py """Data models for interacting with the LangGraph API.""" from datetime import datetime from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, TypedDict, Union Json = Optional[dict[str, Any]] """Represents a JSON-like structure, which can be None or a dictionary with string keys and any values.""" RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"] """ Represents the status of a run: - "pending": The run is waiting to start. - "error": The run encountered an error and stopped. - "success": The run completed successfully. - "timeout": The run exceeded its time limit. - "interrupted": The run was manually stopped or interrupted. """ ThreadStatus = Literal["idle", "busy", "interrupted", "error"] """ Represents the status of a thread: - "idle": The thread is not currently processing any task. - "busy": The thread is actively processing a task. - "interrupted": The thread's execution was interrupted. - "error": An exception occurred during task processing. """ StreamMode = Literal[ "values", "messages", "updates", "events", "debug", "custom", "messages-tuple" ] """ Defines the mode of streaming: - "values": Stream only the values. - "messages": Stream complete messages. - "updates": Stream updates to the state. - "events": Stream events occurring during execution. - "debug": Stream detailed debug information. - "custom": Stream custom events. 
""" DisconnectMode = Literal["cancel", "continue"] """ Specifies behavior on disconnection: - "cancel": Cancel the operation on disconnection. - "continue": Continue the operation even if disconnected. """ MultitaskStrategy = Literal["reject", "interrupt", "rollback", "enqueue"] """ Defines how to handle multiple tasks: - "reject": Reject new tasks when busy. - "interrupt": Interrupt current task for new ones. - "rollback": Roll back current task and start new one. - "enqueue": Queue new tasks for later execution. """ OnConflictBehavior = Literal["raise", "do_nothing"] """ Specifies behavior on conflict: - "raise": Raise an exception when a conflict occurs. - "do_nothing": Ignore conflicts and proceed. """ OnCompletionBehavior = Literal["delete", "keep"] """ Defines action after completion: - "delete": Delete resources after completion. - "keep": Retain resources after completion. """ All = Literal["*"] """Represents a wildcard or 'all' selector.""" IfNotExists = Literal["create", "reject"] """ Specifies behavior if the thread doesn't exist: - "create": Create a new thread if it doesn't exist. - "reject": Reject the operation if the thread doesn't exist. """ CancelAction = Literal["interrupt", "rollback"] """ Action to take when cancelling the run. - "interrupt": Simply cancel the run. - "rollback": Cancel the run. Then delete the run and associated checkpoints. """ class Config(TypedDict, total=False): """Configuration options for a call.""" tags: list[str] """ Tags for this call and any sub-calls (eg. a Chain calling an LLM). You can use these to filter calls. """ recursion_limit: int """ Maximum number of times a call can recurse. If not provided, defaults to 25. """ configurable: dict[str, Any] """ Runtime values for attributes previously made configurable on this Runnable, or sub-Runnables, through .configurable_fields() or .configurable_alternatives(). Check .output_schema() for a description of the attributes that have been made configurable. 
""" class Checkpoint(TypedDict): """Represents a checkpoint in the execution process.""" thread_id: str """Unique identifier for the thread associated with this checkpoint.""" checkpoint_ns: str """Namespace for the checkpoint, used for organization and retrieval.""" checkpoint_id: Optional[str] """Optional unique identifier for the checkpoint itself.""" checkpoint_map: Optional[dict[str, Any]] """Optional dictionary containing checkpoint-specific data.""" class GraphSchema(TypedDict): """Defines the structure and properties of a graph.""" graph_id: str """The ID of the graph.""" input_schema: Optional[dict] """The schema for the graph input. Missing if unable to generate JSON schema from graph.""" output_schema: Optional[dict] """The schema for the graph output. Missing if unable to generate JSON schema from graph.""" state_schema: Optional[dict] """The schema for the graph state. Missing if unable to generate JSON schema from graph.""" config_schema: Optional[dict] """The schema for the graph config. 
Missing if unable to generate JSON schema from graph.""" Subgraphs = dict[str, GraphSchema] class AssistantBase(TypedDict): """Base model for an assistant.""" assistant_id: str """The ID of the assistant.""" graph_id: str """The ID of the graph.""" config: Config """The assistant config.""" created_at: datetime """The time the assistant was created.""" metadata: Json """The assistant metadata.""" version: int """The version of the assistant""" class AssistantVersion(AssistantBase): """Represents a specific version of an assistant.""" pass class Assistant(AssistantBase): """Represents an assistant with additional properties.""" updated_at: datetime """The last time the assistant was updated.""" name: str """The name of the assistant""" class Interrupt(TypedDict, total=False): """Represents an interruption in the execution flow.""" value: Any """The value associated with the interrupt.""" when: Literal["during"] """When the interrupt occurred.""" resumable: bool """Whether the interrupt can be resumed.""" ns: Optional[list[str]] """Optional namespace for the interrupt.""" class Thread(TypedDict): """Represents a conversation thread.""" thread_id: str """The ID of the thread.""" created_at: datetime """The time the thread was created.""" updated_at: datetime """The last time the thread was updated.""" metadata: Json """The thread metadata.""" status: ThreadStatus """The status of the thread, one of 'idle', 'busy', 'interrupted'.""" values: Json """The current state of the thread.""" interrupts: Dict[str, list[Interrupt]] """Interrupts which were thrown in this thread""" class ThreadTask(TypedDict): """Represents a task within a thread.""" id: str name: str error: Optional[str] interrupts: list[Interrupt] checkpoint: Optional[Checkpoint] state: Optional["ThreadState"] result: Optional[dict[str, Any]] class ThreadState(TypedDict): """Represents the state of a thread.""" values: Union[list[dict], dict[str, Any]] """The state values.""" next: Sequence[str] """The next 
nodes to execute. If empty, the thread is done until new input is received.""" checkpoint: Checkpoint """The ID of the checkpoint.""" metadata: Json """Metadata for this state""" created_at: Optional[str] """Timestamp of state creation""" parent_checkpoint: Optional[Checkpoint] """The ID of the parent checkpoint. If missing, this is the root checkpoint.""" tasks: Sequence[ThreadTask] """Tasks to execute in this step. If already attempted, may contain an error.""" class ThreadUpdateStateResponse(TypedDict): """Represents the response from updating a thread's state.""" checkpoint: Checkpoint """Checkpoint of the latest state.""" class Run(TypedDict): """Represents a single execution run.""" run_id: str """The ID of the run.""" thread_id: str """The ID of the thread.""" assistant_id: str """The assistant that was used for this run.""" created_at: datetime """The time the run was created.""" updated_at: datetime """The last time the run was updated.""" status: RunStatus """The status of the run. One of 'pending', 'running', "error", 'success', "timeout", "interrupted".""" metadata: Json """The run metadata.""" multitask_strategy: MultitaskStrategy """Strategy to handle concurrent runs on the same thread.""" class Cron(TypedDict): """Represents a scheduled task.""" cron_id: str """The ID of the cron.""" thread_id: Optional[str] """The ID of the thread.""" end_time: Optional[datetime] """The end date to stop running the cron.""" schedule: str """The schedule to run, cron format.""" created_at: datetime """The time the cron was created.""" updated_at: datetime """The last time the cron was updated.""" payload: dict """The run payload to use for creating new run.""" class RunCreate(TypedDict): """Defines the parameters for initiating a background run.""" thread_id: Optional[str] """The identifier of the thread to run. 
If not provided, the run is stateless.""" assistant_id: str """The identifier of the assistant to use for this run.""" input: Optional[dict] """Initial input data for the run.""" metadata: Optional[dict] """Additional metadata to associate with the run.""" config: Optional[Config] """Configuration options for the run.""" checkpoint_id: Optional[str] """The identifier of a checkpoint to resume from.""" interrupt_before: Optional[list[str]] """List of node names to interrupt execution before.""" interrupt_after: Optional[list[str]] """List of node names to interrupt execution after.""" webhook: Optional[str] """URL to send webhook notifications about the run's progress.""" multitask_strategy: Optional[MultitaskStrategy] """Strategy for handling concurrent runs on the same thread.""" class Item(TypedDict): """Represents a single document or data entry in the graph's Store. Items are used to store cross-thread memories. """ namespace: list[str] """The namespace of the item. A namespace is analogous to a document's directory.""" key: str """The unique identifier of the item within its namespace. In general, keys needn't be globally unique. """ value: dict[str, Any] """The value stored in the item. This is the document itself.""" created_at: datetime """The timestamp when the item was created.""" updated_at: datetime """The timestamp when the item was last updated.""" class ListNamespaceResponse(TypedDict): """Response structure for listing namespaces.""" namespaces: list[list[str]] """A list of namespace paths, where each path is a list of strings.""" class SearchItem(Item, total=False): """Item with an optional relevance score from search operations. Attributes: score (Optional[float]): Relevance/similarity score. Included when searching a compatible store with a natural language query. 
""" score: Optional[float] class SearchItemsResponse(TypedDict): """Response structure for searching items.""" items: list[SearchItem] """A list of items matching the search criteria.""" class StreamPart(NamedTuple): """Represents a part of a stream response.""" event: str """The type of event for this stream part.""" data: dict """The data payload associated with the event.""" class Send(TypedDict): node: str input: Optional[dict[str, Any]] class Command(TypedDict, total=False): goto: Union[Send, str, Sequence[Union[Send, str]]] update: dict[str, Any] resume: Any ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/config.py`: ```py import json import os import pathlib import textwrap from typing import NamedTuple, Optional, TypedDict, Union import click MIN_NODE_VERSION = "20" MIN_PYTHON_VERSION = "3.11" class IndexConfig(TypedDict, total=False): """Configuration for indexing documents for semantic search in the store.""" dims: int """Number of dimensions in the embedding vectors. Common embedding models have the following dimensions: - openai:text-embedding-3-large: 3072 - openai:text-embedding-3-small: 1536 - openai:text-embedding-ada-002: 1536 - cohere:embed-english-v3.0: 1024 - cohere:embed-english-light-v3.0: 384 - cohere:embed-multilingual-v3.0: 1024 - cohere:embed-multilingual-light-v3.0: 384 """ embed: str """Optional model (string) to generate embeddings from text or path to model or function. Examples: - "openai:text-embedding-3-large" - "cohere:embed-multilingual-v3.0" - "src/app.py:embeddings """ fields: Optional[list[str]] """Fields to extract text from for embedding generation. Defaults to the root ["$"], which embeds the json object as a whole. 
""" class StoreConfig(TypedDict, total=False): embed: Optional[IndexConfig] """Configuration for vector embeddings in store.""" class Config(TypedDict, total=False): python_version: str node_version: Optional[str] pip_config_file: Optional[str] dockerfile_lines: list[str] dependencies: list[str] graphs: dict[str, str] env: Union[dict[str, str], str] store: Optional[StoreConfig] def _parse_version(version_str: str) -> tuple[int, int]: """Parse a version string into a tuple of (major, minor).""" try: major, minor = map(int, version_str.split(".")) return (major, minor) except ValueError: raise click.UsageError(f"Invalid version format: {version_str}") from None def _parse_node_version(version_str: str) -> int: """Parse a Node.js version string into a major version number.""" try: if "." in version_str: raise ValueError("Node.js version must be major version only") return int(version_str) except ValueError: raise click.UsageError( f"Invalid Node.js version format: {version_str}. " "Use major version only (e.g., '20')." ) from None def validate_config(config: Config) -> Config: config = ( { "node_version": config.get("node_version"), "dockerfile_lines": config.get("dockerfile_lines", []), "graphs": config.get("graphs", {}), "env": config.get("env", {}), "store": config.get("store"), } if config.get("node_version") else { "python_version": config.get("python_version", "3.11"), "pip_config_file": config.get("pip_config_file"), "dockerfile_lines": config.get("dockerfile_lines", []), "dependencies": config.get("dependencies", []), "graphs": config.get("graphs", {}), "env": config.get("env", {}), "store": config.get("store"), } ) if config.get("node_version"): node_version = config["node_version"] try: major = _parse_node_version(node_version) min_major = _parse_node_version(MIN_NODE_VERSION) if major < min_major: raise click.UsageError( f"Node.js version {node_version} is not supported. " f"Minimum required version is {MIN_NODE_VERSION}." 
) except ValueError as e: raise click.UsageError(str(e)) from None if config.get("python_version"): pyversion = config["python_version"] if not pyversion.count(".") == 1 or not all( part.isdigit() for part in pyversion.split(".") ): raise click.UsageError( f"Invalid Python version format: {pyversion}. " "Use 'major.minor' format (e.g., '3.11'). " "Patch version cannot be specified." ) if _parse_version(pyversion) < _parse_version(MIN_PYTHON_VERSION): raise click.UsageError( f"Python version {pyversion} is not supported. " f"Minimum required version is {MIN_PYTHON_VERSION}." ) if not config["dependencies"]: raise click.UsageError( "No dependencies found in config. " "Add at least one dependency to 'dependencies' list." ) if not config["graphs"]: raise click.UsageError( "No graphs found in config. " "Add at least one graph to 'graphs' dictionary." ) return config def validate_config_file(config_path: pathlib.Path) -> Config: with open(config_path) as f: config = json.load(f) validated = validate_config(config) # Enforce the package.json doesn't enforce an # incompatible Node.js version if validated.get("node_version"): package_json_path = config_path.parent / "package.json" if package_json_path.is_file(): try: with open(package_json_path) as f: package_json = json.load(f) if "engines" in package_json: engines = package_json["engines"] if any(engine != "node" for engine in engines.keys()): raise click.UsageError( "Only 'node' engine is supported in package.json engines." f" Got engines: {list(engines.keys())}" ) if engines: node_version = engines["node"] try: major = _parse_node_version(node_version) min_major = _parse_node_version(MIN_NODE_VERSION) if major < min_major: raise click.UsageError( f"Node.js version in package.json engines must be >= {MIN_NODE_VERSION} " f"(major version only), got '{node_version}'. Minor/patch versions " "(like '20.x.y') are not supported to prevent deployment issues " "when new Node.js versions are released." 
) except ValueError as e: raise click.UsageError(str(e)) from None except json.JSONDecodeError: raise click.UsageError( "Invalid package.json found in langgraph " f"config directory {package_json_path}: file is not valid JSON" ) from None return validated class LocalDeps(NamedTuple): pip_reqs: list[tuple[pathlib.Path, str]] real_pkgs: dict[pathlib.Path, str] faux_pkgs: dict[pathlib.Path, tuple[str, str]] # if . is in dependencies, use it as working_dir working_dir: Optional[str] = None def _assemble_local_deps(config_path: pathlib.Path, config: Config) -> LocalDeps: # ensure reserved package names are not used reserved = { "src", "langgraph-api", "langgraph_api", "langgraph", "langchain-core", "langchain_core", "pydantic", "orjson", "fastapi", "uvicorn", "psycopg", "httpx", "langsmith", } def check_reserved(name: str, ref: str): if name in reserved: raise ValueError( f"Package name '{name}' used in local dep '{ref}' is reserved. " "Rename the directory." ) reserved.add(name) pip_reqs = [] real_pkgs = {} faux_pkgs = {} working_dir = None for local_dep in config["dependencies"]: if not local_dep.startswith("."): continue resolved = config_path.parent / local_dep # validate local dependency if not resolved.exists(): raise FileNotFoundError(f"Could not find local dependency: {resolved}") elif not resolved.is_dir(): raise NotADirectoryError( f"Local dependency must be a directory: {resolved}" ) elif not resolved.is_relative_to(config_path.parent): raise ValueError( f"Local dependency '{resolved}' must be a subdirectory of '{config_path.parent}'" ) # if it's installable, add it to local_pkgs # otherwise, add it to faux_pkgs, and create a pyproject.toml files = os.listdir(resolved) if "pyproject.toml" in files: real_pkgs[resolved] = local_dep if local_dep == ".": working_dir = f"/deps/{resolved.name}" elif "setup.py" in files: real_pkgs[resolved] = local_dep if local_dep == ".": working_dir = f"/deps/{resolved.name}" else: if any(file == "__init__.py" for file in files): 
# flat layout if "-" in resolved.name: raise ValueError( f"Package name '{resolved.name}' contains a hyphen. " "Rename the directory to use it as flat-layout package." ) check_reserved(resolved.name, local_dep) container_path = f"/deps/__outer_{resolved.name}/{resolved.name}" else: # src layout container_path = f"/deps/__outer_{resolved.name}/src" for file in files: rfile = resolved / file if ( rfile.is_dir() and file != "__pycache__" and not file.startswith(".") ): try: for subfile in os.listdir(rfile): if subfile.endswith(".py"): check_reserved(file, local_dep) break except PermissionError: pass faux_pkgs[resolved] = (local_dep, container_path) if local_dep == ".": working_dir = container_path if "requirements.txt" in files: rfile = resolved / "requirements.txt" pip_reqs.append( ( rfile.relative_to(config_path.parent), f"{container_path}/requirements.txt", ) ) return LocalDeps(pip_reqs, real_pkgs, faux_pkgs, working_dir) def _update_graph_paths( config_path: pathlib.Path, config: Config, local_deps: LocalDeps ) -> None: for graph_id, import_str in config["graphs"].items(): module_str, _, attr_str = import_str.partition(":") if not module_str or not attr_str: message = ( 'Import string "{import_str}" must be in format "<module>:<attribute>".' ) raise ValueError(message.format(import_str=import_str)) if "/" in module_str: resolved = config_path.parent / module_str if not resolved.exists(): raise FileNotFoundError(f"Could not find local module: {resolved}") elif not resolved.is_file(): raise IsADirectoryError(f"Local module must be a file: {resolved}") else: for path in local_deps.real_pkgs: if resolved.is_relative_to(path): module_str = f"/deps/{path.name}/{resolved.relative_to(path)}" break else: for faux_pkg, (_, destpath) in local_deps.faux_pkgs.items(): if resolved.is_relative_to(faux_pkg): module_str = f"{destpath}/{resolved.relative_to(faux_pkg)}" break else: raise ValueError( f"Module '{import_str}' not found in 'dependencies' list. 
" "Add its containing package to 'dependencies' list." ) # update the config config["graphs"][graph_id] = f"{module_str}:{attr_str}" def python_config_to_docker(config_path: pathlib.Path, config: Config, base_image: str): # configure pip pip_install = ( "PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt" ) if config.get("pip_config_file"): pip_install = f"PIP_CONFIG_FILE=/pipconfig.txt {pip_install}" pip_config_file_str = ( f"ADD {config['pip_config_file']} /pipconfig.txt" if config.get("pip_config_file") else "" ) # collect dependencies pypi_deps = [dep for dep in config["dependencies"] if not dep.startswith(".")] local_deps = _assemble_local_deps(config_path, config) # rewrite graph paths _update_graph_paths(config_path, config, local_deps) pip_pkgs_str = f"RUN {pip_install} {' '.join(pypi_deps)}" if pypi_deps else "" if local_deps.pip_reqs: pip_reqs_str = os.linesep.join( f"ADD {reqpath} {destpath}" for reqpath, destpath in local_deps.pip_reqs ) pip_reqs_str += f'{os.linesep}RUN {pip_install} {" ".join("-r " + r for _,r in local_deps.pip_reqs)}' else: pip_reqs_str = "" # https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data # https://til.simonwillison.net/python/pyproject faux_pkgs_str = f"{os.linesep}{os.linesep}".join( f"""ADD {relpath} {destpath} RUN set -ex && \\ for line in '[project]' \\ 'name = "{fullpath.name}"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_{fullpath.name}/pyproject.toml; \\ done""" for fullpath, (relpath, destpath) in local_deps.faux_pkgs.items() ) local_pkgs_str = os.linesep.join( f"ADD {relpath} /deps/{fullpath.name}" for fullpath, relpath in local_deps.real_pkgs.items() ) installs = f"{os.linesep}{os.linesep}".join( filter( None, [ pip_config_file_str, pip_pkgs_str, pip_reqs_str, local_pkgs_str, faux_pkgs_str, ], ) ) store_config = config.get("store") env_additional_config = ( "" if not store_config else f""" ENV 
LANGGRAPH_STORE='{json.dumps(store_config)}' """ ) return f"""FROM {base_image}:{config['python_version']} {os.linesep.join(config["dockerfile_lines"])} {installs} RUN {pip_install} -e /deps/* {env_additional_config} ENV LANGSERVE_GRAPHS='{json.dumps(config["graphs"])}' {f"WORKDIR {local_deps.working_dir}" if local_deps.working_dir else ""}""" def node_config_to_docker(config_path: pathlib.Path, config: Config, base_image: str): faux_path = f"/deps/{config_path.parent.name}" def test_file(file_name): full_path = config_path.parent / file_name try: return full_path.is_file() except OSError: return False npm, yarn, pnpm = [ test_file("package-lock.json"), test_file("yarn.lock"), test_file("pnpm-lock.yaml"), ] if yarn: install_cmd = "yarn install --frozen-lockfile" elif pnpm: install_cmd = "pnpm i --frozen-lockfile" elif npm: install_cmd = "npm ci" else: install_cmd = "npm i" store_config = config.get("store") env_additional_config = ( "" if not store_config else f""" ENV LANGGRAPH_STORE='{json.dumps(store_config)}' """ ) return f"""FROM {base_image}:{config['node_version']} {os.linesep.join(config["dockerfile_lines"])} ADD . {faux_path} RUN cd {faux_path} && {install_cmd} {env_additional_config} ENV LANGSERVE_GRAPHS='{json.dumps(config["graphs"])}' WORKDIR {faux_path} RUN (test ! 
-f /api/langgraph_api/js/build.mts && echo "Prebuild script not found, skipping") || tsx /api/langgraph_api/js/build.mts""" def config_to_docker(config_path: pathlib.Path, config: Config, base_image: str): if config.get("node_version"): return node_config_to_docker(config_path, config, base_image) return python_config_to_docker(config_path, config, base_image) def config_to_compose( config_path: pathlib.Path, config: Config, base_image: str, watch: bool = False, ): env_vars = config["env"].items() if isinstance(config["env"], dict) else {} env_vars_str = "\n".join(f' {k}: "{v}"' for k, v in env_vars) env_file_str = ( f"env_file: {config['env']}" if isinstance(config["env"], str) else "" ) if watch: watch_paths = [config_path.name] + [ dep for dep in config["dependencies"] if dep.startswith(".") ] watch_actions = "\n".join( f"""- path: {path} action: rebuild""" for path in watch_paths ) watch_str = f""" develop: watch: {textwrap.indent(watch_actions, " ")} """ else: watch_str = "" return f""" {textwrap.indent(env_vars_str, " ")} {env_file_str} pull_policy: build build: context: . dockerfile_inline: | {textwrap.indent(config_to_docker(config_path, config, base_image), " ")} {watch_str} """ ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/version.py`: ```py """Main entrypoint into package.""" from importlib import metadata try: __version__ = metadata.version(__package__) except metadata.PackageNotFoundError: # Case where package metadata is not available. 
__version__ = "" del metadata # optional, avoids polluting the results of dir(__package__) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/util.py`: ```py def clean_empty_lines(input_str: str): return "\n".join(filter(None, input_str.splitlines())) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/exec.py`: ```py import asyncio import signal import sys from contextlib import contextmanager from typing import Callable, Optional, cast import click.exceptions @contextmanager def Runner(): if hasattr(asyncio, "Runner"): with asyncio.Runner() as runner: yield runner else: class _Runner: def __enter__(self): return self def __exit__(self, *args): pass def run(self, coro): return asyncio.run(coro) yield _Runner() async def subp_exec( cmd: str, *args: str, input: Optional[str] = None, wait: Optional[float] = None, verbose: bool = False, collect: bool = False, on_stdout: Optional[Callable[[str], Optional[bool]]] = None, ) -> tuple[Optional[str], Optional[str]]: if verbose: cmd_str = f"+ {cmd} {' '.join(map(str, args))}" if input: print(cmd_str, " <\n", "\n".join(filter(None, input.splitlines())), sep="") else: print(cmd_str) if wait: await asyncio.sleep(wait) try: proc = await asyncio.create_subprocess_exec( cmd, *args, stdin=asyncio.subprocess.PIPE if input else None, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) def signal_handler(): # make sure process exists, then terminate it if proc.returncode is None: proc.terminate() original_sigint_handler = signal.getsignal(signal.SIGINT) if sys.platform == "win32": def handle_windows_signal(signum, frame): signal_handler() original_sigint_handler(signum, frame) signal.signal(signal.SIGINT, handle_windows_signal) # NOTE: we're not adding a handler for SIGTERM since it's ignored on Windows else: loop = asyncio.get_event_loop() loop.add_signal_handler(signal.SIGINT, signal_handler) loop.add_signal_handler(signal.SIGTERM, signal_handler) empty_fut: asyncio.Future = 
asyncio.Future() empty_fut.set_result(None) stdout, stderr, _ = await asyncio.gather( monitor_stream( cast(asyncio.StreamReader, proc.stdout), collect=True, display=verbose, on_line=on_stdout, ), monitor_stream( cast(asyncio.StreamReader, proc.stderr), collect=True, display=verbose, ), proc._feed_stdin(input.encode()) if input else empty_fut, # type: ignore[attr-defined] ) returncode = await proc.wait() if ( returncode is not None and returncode != 0 # success and returncode != 130 # user interrupt ): sys.stdout.write(stdout.decode() if stdout else "") sys.stderr.write(stderr.decode() if stderr else "") raise click.exceptions.Exit(returncode) if collect: return ( stdout.decode() if stdout else None, stderr.decode() if stderr else None, ) else: return None, None finally: try: if proc.returncode is None: try: proc.terminate() except (ProcessLookupError, KeyboardInterrupt): pass if sys.platform == "win32": signal.signal(signal.SIGINT, original_sigint_handler) else: loop.remove_signal_handler(signal.SIGINT) loop.remove_signal_handler(signal.SIGTERM) except UnboundLocalError: pass async def monitor_stream( stream: asyncio.StreamReader, collect: bool = False, display: bool = False, on_line: Optional[Callable[[str], Optional[bool]]] = None, ) -> Optional[bytearray]: if collect: ba = bytearray() def handle(line: bytes, overrun: bool): nonlocal on_line nonlocal display if display: sys.stdout.buffer.write(line) if overrun: return if collect: ba.extend(line) if on_line: if on_line(line.decode()): on_line = None display = True """Adapted from asyncio.StreamReader.readline() to handle LimitOverrunError.""" sep = b"\n" seplen = len(sep) while True: try: line = await stream.readuntil(sep) overrun = False except asyncio.IncompleteReadError as e: line = e.partial overrun = False except asyncio.LimitOverrunError as e: if stream._buffer.startswith(sep, e.consumed): line = stream._buffer[: e.consumed + seplen] else: line = stream._buffer.clear() overrun = True 
stream._maybe_resume_transport() await asyncio.to_thread(handle, line, overrun) if line == b"": break if collect: return ba else: return None ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/constants.py`: ```py DEFAULT_CONFIG = "langgraph.json" DEFAULT_PORT = 8123 # analytics SUPABASE_PUBLIC_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imt6cmxwcG9qaW5wY3l5YWlweG5iIiwicm9sZSI6ImFub24iLCJpYXQiOjE3MTkyNTc1NzksImV4cCI6MjAzNDgzMzU3OX0.kkVOlLz3BxemA5nP-vat3K4qRtrDuO4SwZSR_htcX9c" SUPABASE_URL = "https://kzrlppojinpcyyaipxnb.supabase.co" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/docker.py`: ```py import json import pathlib import shutil from typing import Literal, NamedTuple, Optional import click.exceptions from langgraph_cli.exec import subp_exec ROOT = pathlib.Path(__file__).parent.resolve() DEFAULT_POSTGRES_URI = ( "postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable" ) class Version(NamedTuple): major: int minor: int patch: int DockerComposeType = Literal["plugin", "standalone"] class DockerCapabilities(NamedTuple): version_docker: Version version_compose: Version healthcheck_start_interval: bool compose_type: DockerComposeType = "plugin" def _parse_version(version: str) -> Version: parts = version.split(".", 2) if len(parts) == 1: major = parts[0] minor = "0" patch = "0" elif len(parts) == 2: major, minor = parts patch = "0" else: major, minor, patch = parts return Version(int(major.lstrip("v")), int(minor), int(patch.split("-")[0])) def check_capabilities(runner) -> DockerCapabilities: # check docker available if shutil.which("docker") is None: raise click.UsageError("Docker not installed") from None try: stdout, _ = runner.run(subp_exec("docker", "info", "-f", "json", collect=True)) info = json.loads(stdout) except (click.exceptions.Exit, json.JSONDecodeError): raise click.UsageError("Docker not installed or not running") from None if not 
info["ServerVersion"]: raise click.UsageError("Docker not running") from None compose_type: DockerComposeType try: compose = next( p for p in info["ClientInfo"]["Plugins"] if p["Name"] == "compose" ) compose_version_str = compose["Version"] compose_type = "plugin" except (KeyError, StopIteration): if shutil.which("docker-compose") is None: raise click.UsageError("Docker Compose not installed") from None compose_version_str, _ = runner.run( subp_exec("docker-compose", "--version", "--short", collect=True) ) compose_type = "standalone" # parse versions docker_version = _parse_version(info["ServerVersion"]) compose_version = _parse_version(compose_version_str) # check capabilities return DockerCapabilities( version_docker=docker_version, version_compose=compose_version, healthcheck_start_interval=docker_version >= Version(25, 0, 0), compose_type=compose_type, ) def debugger_compose( *, port: Optional[int] = None, base_url: Optional[str] = None ) -> dict: if port is None: return "" config = { "langgraph-debugger": { "image": "langchain/langgraph-debugger", "restart": "on-failure", "depends_on": { "langgraph-postgres": {"condition": "service_healthy"}, }, "ports": [f'"{port}:3968"'], } } if base_url: config["langgraph-debugger"]["environment"] = { "VITE_STUDIO_LOCAL_GRAPH_URL": base_url } return config # Function to convert dictionary to YAML def dict_to_yaml(d: dict, *, indent: int = 0) -> str: """Convert a dictionary to a YAML string.""" yaml_str = "" for idx, (key, value) in enumerate(d.items()): # Format things in a visually appealing way # Use an extra newline for top-level keys only if idx >= 1 and indent < 2: yaml_str += "\n" space = " " * indent if isinstance(value, dict): yaml_str += f"{space}{key}:\n" + dict_to_yaml(value, indent=indent + 1) elif isinstance(value, list): yaml_str += f"{space}{key}:\n" for item in value: yaml_str += f"{space} - {item}\n" else: yaml_str += f"{space}{key}: {value}\n" return yaml_str def compose_as_dict( capabilities: 
DockerCapabilities, *, port: int, debugger_port: Optional[int] = None, debugger_base_url: Optional[str] = None, # postgres://user:password@host:port/database?option=value postgres_uri: Optional[str] = None, ) -> dict: """Create a docker compose file as a dictionary in YML style.""" if postgres_uri is None: include_db = True postgres_uri = DEFAULT_POSTGRES_URI else: include_db = False # The services below are defined in a non-intuitive order to match # the existing unit tests for this function. # It's fine to re-order just requires updating the unit tests, so it should # be done with caution. # Define the Redis service first as per the test order services = { "langgraph-redis": { "image": "redis:6", "healthcheck": { "test": "redis-cli ping", "interval": "5s", "timeout": "1s", "retries": 5, }, } } # Add Postgres service before langgraph-api if it is needed if include_db: services["langgraph-postgres"] = { "image": "postgres:16", "ports": ['"5433:5432"'], "environment": { "POSTGRES_DB": "postgres", "POSTGRES_USER": "postgres", "POSTGRES_PASSWORD": "postgres", }, "volumes": ["langgraph-data:/var/lib/postgresql/data"], "healthcheck": { "test": "pg_isready -U postgres", "start_period": "10s", "timeout": "1s", "retries": 5, }, } if capabilities.healthcheck_start_interval: services["langgraph-postgres"]["healthcheck"]["interval"] = "60s" services["langgraph-postgres"]["healthcheck"]["start_interval"] = "1s" else: services["langgraph-postgres"]["healthcheck"]["interval"] = "5s" # Add optional debugger service if debugger_port is specified if debugger_port: services["langgraph-debugger"] = debugger_compose( port=debugger_port, base_url=debugger_base_url )["langgraph-debugger"] # Add langgraph-api service services["langgraph-api"] = { "ports": [f'"{port}:8000"'], "depends_on": { "langgraph-redis": {"condition": "service_healthy"}, }, "environment": { "REDIS_URI": "redis://langgraph-redis:6379", "POSTGRES_URI": postgres_uri, }, } # If Postgres is included, add it to the 
dependencies of langgraph-api if include_db: services["langgraph-api"]["depends_on"]["langgraph-postgres"] = { "condition": "service_healthy" } # Additional healthcheck for langgraph-api if required if capabilities.healthcheck_start_interval: services["langgraph-api"]["healthcheck"] = { "test": "python /api/healthcheck.py", "interval": "60s", "start_interval": "1s", "start_period": "10s", } # Final compose dictionary with volumes included if needed compose_dict = {} if include_db: compose_dict["volumes"] = {"langgraph-data": {"driver": "local"}} compose_dict["services"] = services return compose_dict def compose( capabilities: DockerCapabilities, *, port: int, debugger_port: Optional[int] = None, debugger_base_url: Optional[str] = None, # postgres://user:password@host:port/database?option=value postgres_uri: Optional[str] = None, ) -> str: """Create a docker compose file as a string.""" compose_content = compose_as_dict( capabilities, port=port, debugger_port=debugger_port, debugger_base_url=debugger_base_url, postgres_uri=postgres_uri, ) compose_str = dict_to_yaml(compose_content) return compose_str ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/templates.py`: ```py import os import shutil import sys from io import BytesIO from typing import Dict, Optional from urllib import error, request from zipfile import ZipFile import click TEMPLATES: Dict[str, Dict[str, str]] = { "New LangGraph Project": { "description": "A simple, minimal chatbot with memory.", "python": "https://github.com/langchain-ai/new-langgraph-project/archive/refs/heads/main.zip", "js": "https://github.com/langchain-ai/new-langgraphjs-project/archive/refs/heads/main.zip", }, "ReAct Agent": { "description": "A simple agent that can be flexibly extended to many tools.", "python": "https://github.com/langchain-ai/react-agent/archive/refs/heads/main.zip", "js": "https://github.com/langchain-ai/react-agent-js/archive/refs/heads/main.zip", }, "Memory Agent": { "description": "A 
ReAct-style agent with an additional tool to store memories for use across conversational threads.", "python": "https://github.com/langchain-ai/memory-agent/archive/refs/heads/main.zip", "js": "https://github.com/langchain-ai/memory-agent-js/archive/refs/heads/main.zip", }, "Retrieval Agent": { "description": "An agent that includes a retrieval-based question-answering system.", "python": "https://github.com/langchain-ai/retrieval-agent-template/archive/refs/heads/main.zip", "js": "https://github.com/langchain-ai/retrieval-agent-template-js/archive/refs/heads/main.zip", }, "Data-enrichment Agent": { "description": "An agent that performs web searches and organizes its findings into a structured format.", "python": "https://github.com/langchain-ai/data-enrichment/archive/refs/heads/main.zip", "js": "https://github.com/langchain-ai/data-enrichment-js/archive/refs/heads/main.zip", }, } # Generate TEMPLATE_IDS programmatically TEMPLATE_ID_TO_CONFIG = { f"{name.lower().replace(' ', '-')}-{lang}": (name, lang, url) for name, versions in TEMPLATES.items() for lang, url in versions.items() if lang in {"python", "js"} } TEMPLATE_IDS = list(TEMPLATE_ID_TO_CONFIG.keys()) TEMPLATE_HELP_STRING = ( "The name of the template to use. Available options:\n" + "\n".join(f"{id_}" for id_ in TEMPLATE_ID_TO_CONFIG) ) def _choose_template() -> str: """Presents a list of templates to the user and prompts them to select one. Returns: str: The URL of the selected template. """ click.secho("🌟 Please select a template:", bold=True, fg="yellow") for idx, (template_name, template_info) in enumerate(TEMPLATES.items(), 1): click.secho(f"{idx}. 
", nl=False, fg="cyan") click.secho(template_name, fg="cyan", nl=False) click.secho(f" - {template_info['description']}", fg="white") # Get the template choice from the user, defaulting to the first template if blank template_choice: Optional[int] = click.prompt( "Enter the number of your template choice (default is 1)", type=int, default=1, show_default=False, ) template_keys = list(TEMPLATES.keys()) if 1 <= template_choice <= len(template_keys): selected_template: str = template_keys[template_choice - 1] else: click.secho("❌ Invalid choice. Please try again.", fg="red") return _choose_template() # Prompt the user to choose between Python or JS/TS version click.secho( f"\nYou selected: {selected_template} - {TEMPLATES[selected_template]['description']}", fg="green", ) version_choice: int = click.prompt( "Choose language (1 for Python 🐍, 2 for JS/TS 🌐)", type=int ) if version_choice == 1: return TEMPLATES[selected_template]["python"] elif version_choice == 2: return TEMPLATES[selected_template]["js"] else: click.secho("❌ Invalid choice. Please try again.", fg="red") return _choose_template() def _download_repo_with_requests(repo_url: str, path: str) -> None: """Download a ZIP archive from the given URL and extracts it to the specified path. Args: repo_url (str): The URL of the repository to download. path (str): The path where the repository should be extracted. 
""" click.secho("📥 Attempting to download repository as a ZIP archive...", fg="yellow") click.secho(f"URL: {repo_url}", fg="yellow") try: with request.urlopen(repo_url) as response: if response.status == 200: with ZipFile(BytesIO(response.read())) as zip_file: zip_file.extractall(path) # Move extracted contents to path for item in os.listdir(path): if item.endswith("-main"): extracted_dir = os.path.join(path, item) for filename in os.listdir(extracted_dir): shutil.move(os.path.join(extracted_dir, filename), path) shutil.rmtree(extracted_dir) click.secho( f"✅ Downloaded and extracted repository to {path}", fg="green" ) except error.HTTPError as e: click.secho( f"❌ Error: Failed to download repository.\n" f"Details: {e}\n", fg="red", bold=True, err=True, ) sys.exit(1) def _get_template_url(template_name: str) -> Optional[str]: """ Retrieves the template URL based on the provided template name. Args: template_name (str): The name of the template. Returns: Optional[str]: The URL of the template if found, else None. """ if template_name in TEMPLATES: click.secho(f"Template selected: {template_name}", fg="green") version_choice: int = click.prompt( "Choose version (1 for Python 🐍, 2 for JS/TS 🌐)", type=int ) if version_choice == 1: return TEMPLATES[template_name]["python"] elif version_choice == 2: return TEMPLATES[template_name]["js"] else: click.secho("❌ Invalid choice. Please try again.", fg="red") return None else: click.secho( f"Template '{template_name}' not found. Please select from the available options.", fg="red", ) return None def create_new(path: Optional[str], template: Optional[str]) -> None: """Create a new LangGraph project at the specified PATH using the chosen TEMPLATE. Args: path (Optional[str]): The path where the new project will be created. template (Optional[str]): The name of the template to use. """ # Prompt for path if not provided if not path: path = click.prompt( "📂 Please specify the path to create the application", default="." 
) path = os.path.abspath(path) # Ensure path is absolute # Check if path exists and is not empty if os.path.exists(path) and os.listdir(path): click.secho( "❌ The specified directory already exists and is not empty. " "Aborting to prevent overwriting files.", fg="red", bold=True, ) sys.exit(1) # Get template URL either from command-line argument or # through interactive selection if template: if template not in TEMPLATE_ID_TO_CONFIG: # Format available options in a readable way with descriptions template_options = "" for id_ in TEMPLATE_IDS: name, lang, _ = TEMPLATE_ID_TO_CONFIG[id_] description = TEMPLATES[name]["description"] # Add each template option with color formatting template_options += ( click.style("- ", fg="yellow", bold=True) + click.style(f"{id_}", fg="cyan") + click.style(f": {description}", fg="white") + "\n" ) # Display error message with colors and formatting click.secho("❌ Error:", fg="red", bold=True, nl=False) click.secho(f" Template '{template}' not found.", fg="red") click.secho( "Please select from the available options:\n", fg="yellow", bold=True ) click.secho(template_options, fg="cyan") sys.exit(1) _, _, template_url = TEMPLATE_ID_TO_CONFIG[template] else: template_url = _choose_template() # Download and extract the template _download_repo_with_requests(template_url, path) click.secho(f"🎉 New project created at {path}", fg="green", bold=True) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/cli.py`: ```py import os import pathlib import shutil import sys from typing import Callable, Optional, Sequence import click import click.exceptions from click import secho import langgraph_cli.config import langgraph_cli.docker from langgraph_cli.analytics import log_command from langgraph_cli.config import Config from langgraph_cli.constants import DEFAULT_CONFIG, DEFAULT_PORT from langgraph_cli.docker import DockerCapabilities from langgraph_cli.exec import Runner, subp_exec from langgraph_cli.progress import Progress from 
langgraph_cli.templates import TEMPLATE_HELP_STRING, create_new from langgraph_cli.version import __version__ OPT_DOCKER_COMPOSE = click.option( "--docker-compose", "-d", help="Advanced: Path to docker-compose.yml file with additional services to launch.", type=click.Path( exists=True, file_okay=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path, ), ) OPT_CONFIG = click.option( "--config", "-c", help="""Path to configuration file declaring dependencies, graphs and environment variables. \b Config file must be a JSON file that has the following keys: - "dependencies": array of dependencies for langgraph API server. Dependencies can be one of the following: - ".", which would look for local python packages, as well as pyproject.toml, setup.py or requirements.txt in the app directory - "./local_package" - "<package_name> - "graphs": mapping from graph ID to path where the compiled graph is defined, i.e. ./your_package/your_file.py:variable, where "variable" is an instance of langgraph.graph.graph.CompiledGraph - "env": (optional) path to .env file or a mapping from environment variable to its value - "python_version": (optional) 3.11, 3.12, or 3.13. Defaults to 3.11 - "pip_config_file": (optional) path to pip config file - "dockerfile_lines": (optional) array of additional lines to add to Dockerfile following the import from parent image \b Example: langgraph up -c langgraph.json \b Example: { "dependencies": [ "langchain_openai", "./your_package" ], "graphs": { "my_graph_id": "./your_package/your_file.py:variable" }, "env": "./.env" } \b Example: { "python_version": "3.11", "dependencies": [ "langchain_openai", "." 
], "graphs": { "my_graph_id": "./your_package/your_file.py:variable" }, "env": { "OPENAI_API_KEY": "secret-key" } } Defaults to looking for langgraph.json in the current directory.""", default=DEFAULT_CONFIG, type=click.Path( exists=True, file_okay=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path, ), ) OPT_PORT = click.option( "--port", "-p", type=int, default=DEFAULT_PORT, show_default=True, help=""" Port to expose. \b Example: langgraph up --port 8000 \b """, ) OPT_RECREATE = click.option( "--recreate/--no-recreate", default=False, show_default=True, help="Recreate containers even if their configuration and image haven't changed", ) OPT_PULL = click.option( "--pull/--no-pull", default=True, show_default=True, help=""" Pull latest images. Use --no-pull for running the server with locally-built images. \b Example: langgraph up --no-pull \b """, ) OPT_VERBOSE = click.option( "--verbose", is_flag=True, default=False, help="Show more output from the server logs", ) OPT_WATCH = click.option("--watch", is_flag=True, help="Restart on file changes") OPT_DEBUGGER_PORT = click.option( "--debugger-port", type=int, help="Pull the debugger image locally and serve the UI on specified port", ) OPT_DEBUGGER_BASE_URL = click.option( "--debugger-base-url", type=str, help="URL used by the debugger to access LangGraph API. Defaults to http://127.0.0.1:[PORT]", ) OPT_POSTGRES_URI = click.option( "--postgres-uri", help="Postgres URI to use for the database. Defaults to launching a local database", ) @click.group() @click.version_option(version=__version__, prog_name="LangGraph CLI") def cli(): pass @OPT_RECREATE @OPT_PULL @OPT_PORT @OPT_DOCKER_COMPOSE @OPT_CONFIG @OPT_VERBOSE @OPT_DEBUGGER_PORT @OPT_DEBUGGER_BASE_URL @OPT_WATCH @OPT_POSTGRES_URI @click.option( "--wait", is_flag=True, help="Wait for services to start before returning. 
Implies --detach", ) @cli.command(help="🚀 Launch LangGraph API server.") @log_command def up( config: pathlib.Path, docker_compose: Optional[pathlib.Path], port: int, recreate: bool, pull: bool, watch: bool, wait: bool, verbose: bool, debugger_port: Optional[int], debugger_base_url: Optional[str], postgres_uri: Optional[str], ): click.secho("Starting LangGraph API server...", fg="green") click.secho( """For local dev, requires env var LANGSMITH_API_KEY with access to LangGraph Cloud closed beta. For production use, requires a license key in env var LANGGRAPH_CLOUD_LICENSE_KEY.""", ) with Runner() as runner, Progress(message="Pulling...") as set: capabilities = langgraph_cli.docker.check_capabilities(runner) args, stdin = prepare( runner, capabilities=capabilities, config_path=config, docker_compose=docker_compose, port=port, pull=pull, watch=watch, verbose=verbose, debugger_port=debugger_port, debugger_base_url=debugger_base_url, postgres_uri=postgres_uri, ) # add up + options args.extend(["up", "--remove-orphans"]) if recreate: args.extend(["--force-recreate", "--renew-anon-volumes"]) try: runner.run(subp_exec("docker", "volume", "rm", "langgraph-data")) except click.exceptions.Exit: pass if watch: args.append("--watch") if wait: args.append("--wait") else: args.append("--abort-on-container-exit") # run docker compose set("Building...") def on_stdout(line: str): if "unpacking to docker.io" in line: set("Starting...") elif "Application startup complete" in line: debugger_origin = ( f"http://localhost:{debugger_port}" if debugger_port else "https://smith.langchain.com" ) debugger_base_url_query = ( debugger_base_url or f"http://127.0.0.1:{port}" ) set("") sys.stdout.write( f"""Ready! 
- API: http://localhost:{port} - Docs: http://localhost:{port}/docs - LangGraph Studio: {debugger_origin}/studio/?baseUrl={debugger_base_url_query} """ ) sys.stdout.flush() return True if capabilities.compose_type == "plugin": compose_cmd = ["docker", "compose"] elif capabilities.compose_type == "standalone": compose_cmd = ["docker-compose"] runner.run( subp_exec( *compose_cmd, *args, input=stdin, verbose=verbose, on_stdout=on_stdout, ) ) def _build( runner, set: Callable[[str], None], config: pathlib.Path, config_json: dict, base_image: Optional[str], pull: bool, tag: str, passthrough: Sequence[str] = (), ): base_image = base_image or ( "langchain/langgraphjs-api" if config_json.get("node_version") else "langchain/langgraph-api" ) # pull latest images if pull: runner.run( subp_exec( "docker", "pull", ( f"{base_image}:{config_json['node_version']}" if config_json.get("node_version") else f"{base_image}:{config_json['python_version']}" ), verbose=True, ) ) set("Building...") # apply options args = [ "-f", "-", # stdin "-t", tag, ] # apply config stdin = langgraph_cli.config.config_to_docker(config, config_json, base_image) # run docker build runner.run( subp_exec( "docker", "build", *args, *passthrough, str(config.parent), input=stdin, verbose=True, ) ) @OPT_CONFIG @OPT_PULL @click.option( "--tag", "-t", help="""Tag for the docker image. 
\b Example: langgraph build -t my-image \b """, required=True, ) @click.option( "--base-image", hidden=True, ) @click.argument("docker_build_args", nargs=-1, type=click.UNPROCESSED) @cli.command( help="📦 Build LangGraph API server Docker image.", context_settings=dict( ignore_unknown_options=True, ), ) @log_command def build( config: pathlib.Path, docker_build_args: Sequence[str], base_image: Optional[str], pull: bool, tag: str, ): with Runner() as runner, Progress(message="Pulling...") as set: if shutil.which("docker") is None: raise click.UsageError("Docker not installed") from None config_json = langgraph_cli.config.validate_config_file(config) _build( runner, set, config, config_json, base_image, pull, tag, docker_build_args ) def _get_docker_ignore_content() -> str: """Return the content of a .dockerignore file. This file is used to exclude files and directories from the Docker build context. It may be overly broad, but it's better to be safe than sorry. The main goal is to exclude .env files by default. """ return """\ # Ignore node_modules and other dependency directories node_modules bower_components vendor # Ignore logs and temporary files *.log *.tmp *.swp # Ignore .env files and other environment files .env .env.* *.local # Ignore git-related files .git .gitignore # Ignore Docker-related files and configs .dockerignore docker-compose.yml # Ignore build and cache directories dist build .cache __pycache__ # Ignore IDE and editor configurations .vscode .idea *.sublime-project *.sublime-workspace .DS_Store # macOS-specific # Ignore test and coverage files coverage *.coverage *.test.js *.spec.js tests """ @OPT_CONFIG @click.argument("save_path", type=click.Path(resolve_path=True)) @cli.command( help="🐳 Generate a Dockerfile for the LangGraph API server, with Docker Compose options." 
) @click.option( # Add a flag for adding a docker-compose.yml file as part of the output "--add-docker-compose", help=( "Add additional files for running the LangGraph API server with " "docker-compose. These files include a docker-compose.yml, .env file, " "and a .dockerignore file." ), is_flag=True, ) @log_command def dockerfile(save_path: str, config: pathlib.Path, add_docker_compose: bool) -> None: save_path = pathlib.Path(save_path).absolute() secho(f"🔍 Validating configuration at path: {config}", fg="yellow") config_json = langgraph_cli.config.validate_config_file(config) secho("✅ Configuration validated!", fg="green") secho(f"📝 Generating Dockerfile at {save_path}", fg="yellow") with open(str(save_path), "w", encoding="utf-8") as f: f.write( langgraph_cli.config.config_to_docker( config, config_json, ( "langchain/langgraphjs-api" if config_json.get("node_version") else "langchain/langgraph-api" ), ) ) secho("✅ Created: Dockerfile", fg="green") if add_docker_compose: # Add docker compose and related files # Add .dockerignore file in the same directory as the Dockerfile with open(str(save_path.parent / ".dockerignore"), "w", encoding="utf-8") as f: f.write(_get_docker_ignore_content()) secho("✅ Created: .dockerignore", fg="green") # Generate a docker-compose.yml file path = str(save_path.parent / "docker-compose.yml") with open(path, "w", encoding="utf-8") as f: with Runner() as runner: capabilities = langgraph_cli.docker.check_capabilities(runner) compose_dict = langgraph_cli.docker.compose_as_dict( capabilities, port=8123, ) # Add .env file to the docker-compose.yml for the langgraph-api service compose_dict["services"]["langgraph-api"]["env_file"] = [".env"] # Add the Dockerfile to the build context compose_dict["services"]["langgraph-api"]["build"] = { "context": ".", "dockerfile": save_path.name, } f.write(langgraph_cli.docker.dict_to_yaml(compose_dict)) secho("✅ Created: docker-compose.yml", fg="green") # Check if the .env file exists in the same 
directory as the Dockerfile if not (save_path.parent / ".env").exists(): # Also add an empty .env file with open(str(save_path.parent / ".env"), "w", encoding="utf-8") as f: f.writelines( [ "# Uncomment the following line to add your LangSmith API key", "\n", "# LANGSMITH_API_KEY=your-api-key", "\n", "# Or if you have a LangGraph Cloud license key, " "then uncomment the following line: ", "\n", "# LANGGRAPH_CLOUD_LICENSE_KEY=your-license-key", "\n", "# Add any other environment variables go below...", ] ) secho("✅ Created: .env", fg="green") else: # Do nothing since the .env file already exists. Not a great # idea to overwrite in case the user has added custom env vars set # in the .env file already. secho("➖ Skipped: .env. It already exists!", fg="yellow") secho( f"🎉 Files generated successfully at path {save_path.parent}!", fg="cyan", bold=True, ) @click.option( "--host", default="127.0.0.1", help="Network interface to bind the development server to. Default 127.0.0.1 is recommended for security. Only use 0.0.0.0 in trusted networks", ) @click.option( "--port", default=2024, type=int, help="Port number to bind the development server to. Example: langgraph dev --port 8000", ) @click.option( "--no-reload", is_flag=True, help="Disable automatic reloading when code changes are detected", ) @click.option( "--config", type=click.Path(exists=True), default="langgraph.json", help="Path to configuration file declaring dependencies, graphs and environment variables", ) @click.option( "--n-jobs-per-worker", default=None, type=int, help="Maximum number of concurrent jobs each worker process can handle. Default: 10", ) @click.option( "--no-browser", is_flag=True, help="Skip automatically opening the browser when the server starts", ) @click.option( "--debug-port", default=None, type=int, help="Enable remote debugging by listening on specified port. 
Requires debugpy to be installed", ) @click.option( "--wait-for-client", is_flag=True, help="Wait for a debugger client to connect to the debug port before starting the server", default=False, ) @cli.command( "dev", help="🏃‍♀️‍➡️ Run LangGraph API server in development mode with hot reloading and debugging support", ) @log_command def dev( host: str, port: int, no_reload: bool, config: pathlib.Path, n_jobs_per_worker: Optional[int], no_browser: bool, debug_port: Optional[int], wait_for_client: bool, ): """CLI entrypoint for running the LangGraph API server.""" try: from langgraph_api.cli import run_server except ImportError: try: import pkg_resources pkg_resources.require("langgraph-api-inmem") except (ImportError, pkg_resources.DistributionNotFound): raise click.UsageError( "Required package 'langgraph-api-inmem' is not installed.\n" "Please install it with:\n\n" ' pip install -U "langgraph-cli[inmem]"\n\n' "If you're developing the langgraph-cli package locally, you can install in development mode:\n" " pip install -e ." ) from None raise click.UsageError( "Could not import run_server. 
This likely means your installation is incomplete.\n" "Please ensure langgraph-cli is installed with the 'inmem' extra: pip install -U \"langgraph-cli[inmem]\"" ) from None config_json = langgraph_cli.config.validate_config_file(config) cwd = os.getcwd() sys.path.append(cwd) dependencies = config_json.get("dependencies", []) for dep in dependencies: dep_path = pathlib.Path(cwd) / dep if dep_path.is_dir() and dep_path.exists(): sys.path.append(str(dep_path)) graphs = config_json.get("graphs", {}) run_server( host, port, not no_reload, graphs, n_jobs_per_worker=n_jobs_per_worker, open_browser=not no_browser, debug_port=debug_port, env=config_json.get("env"), store=config_json.get("store"), wait_for_client=wait_for_client, ) @click.argument("path", required=False) @click.option( "--template", type=str, help=TEMPLATE_HELP_STRING, ) @cli.command("new", help="🌱 Create a new LangGraph project from a template.") @log_command def new(path: Optional[str], template: Optional[str]) -> None: """Create a new LangGraph project from a template.""" return create_new(path, template) def prepare_args_and_stdin( *, capabilities: DockerCapabilities, config_path: pathlib.Path, config: Config, docker_compose: Optional[pathlib.Path], port: int, watch: bool, debugger_port: Optional[int] = None, debugger_base_url: Optional[str] = None, postgres_uri: Optional[str] = None, ): # prepare args stdin = langgraph_cli.docker.compose( capabilities, port=port, debugger_port=debugger_port, debugger_base_url=debugger_base_url, postgres_uri=postgres_uri, ) args = [ "--project-directory", str(config_path.parent), ] # apply options if docker_compose: args.extend(["-f", str(docker_compose)]) args.extend(["-f", "-"]) # stdin # apply config stdin += langgraph_cli.config.config_to_compose( config_path, config, watch=watch, base_image=( "langchain/langgraphjs-api" if config.get("node_version") else "langchain/langgraph-api" ), ) return args, stdin def prepare( runner, *, capabilities: DockerCapabilities, 
config_path: pathlib.Path, docker_compose: Optional[pathlib.Path], port: int, pull: bool, watch: bool, verbose: bool, debugger_port: Optional[int] = None, debugger_base_url: Optional[str] = None, postgres_uri: Optional[str] = None, ): config_json = langgraph_cli.config.validate_config_file(config_path) # pull latest images if pull: runner.run( subp_exec( "docker", "pull", ( f"langchain/langgraphjs-api:{config_json['node_version']}" if config_json.get("node_version") else f"langchain/langgraph-api:{config_json['python_version']}" ), verbose=verbose, ) ) args, stdin = prepare_args_and_stdin( capabilities=capabilities, config_path=config_path, config=config_json, docker_compose=docker_compose, port=port, watch=watch, debugger_port=debugger_port, debugger_base_url=debugger_base_url or f"http://127.0.0.1:{port}", postgres_uri=postgres_uri, ) return args, stdin ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/analytics.py`: ```py import functools import json import os import pathlib import platform import threading import urllib.error import urllib.request from typing import Any, TypedDict from langgraph_cli.constants import ( DEFAULT_CONFIG, DEFAULT_PORT, SUPABASE_PUBLIC_API_KEY, SUPABASE_URL, ) from langgraph_cli.version import __version__ class LogData(TypedDict): os: str os_version: str python_version: str cli_version: str cli_command: str params: dict[str, Any] def get_anonymized_params(kwargs: dict[str, Any]) -> dict[str, bool]: params = {} # anonymize params with values if config := kwargs.get("config"): if config != pathlib.Path(DEFAULT_CONFIG).resolve(): params["config"] = True if port := kwargs.get("port"): if port != DEFAULT_PORT: params["port"] = True if kwargs.get("docker_compose"): params["docker_compose"] = True if kwargs.get("debugger_port"): params["debugger_port"] = True if kwargs.get("postgres_uri"): params["postgres_uri"] = True # pick up exact values for boolean flags for boolean_param in ["recreate", "pull", "watch", "wait", 
"verbose"]: if kwargs.get(boolean_param): params[boolean_param] = kwargs[boolean_param] return params def log_data(data: LogData) -> None: headers = { "Content-Type": "application/json", "apikey": SUPABASE_PUBLIC_API_KEY, "User-Agent": "Mozilla/5.0", } supabase_url = SUPABASE_URL req = urllib.request.Request( f"{supabase_url}/rest/v1/logs", data=json.dumps(data).encode("utf-8"), headers=headers, method="POST", ) try: urllib.request.urlopen(req) except urllib.error.URLError: pass def log_command(func): @functools.wraps(func) def decorator(*args, **kwargs): if os.getenv("LANGGRAPH_CLI_NO_ANALYTICS") == "1": return func(*args, **kwargs) data = { "os": platform.system(), "os_version": platform.version(), "python_version": platform.python_version(), "cli_version": __version__, "cli_command": func.__name__, "params": get_anonymized_params(kwargs), } background_thread = threading.Thread(target=log_data, args=(data,)) background_thread.start() return func(*args, **kwargs) return decorator ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/progress.py`: ```py import sys import threading import time from typing import Callable class Progress: delay: float = 0.1 @staticmethod def spinning_cursor(): while True: yield from "|/-\\" def __init__(self, *, message=""): self.message = message self.spinner_generator = self.spinning_cursor() def spinner_iteration(self): message = self.message sys.stdout.write(next(self.spinner_generator) + " " + message) sys.stdout.flush() time.sleep(self.delay) # clear the spinner and message sys.stdout.write( "\b" * (len(message) + 2) + " " * (len(message) + 2) + "\b" * (len(message) + 2) ) sys.stdout.flush() def spinner_task(self): while self.message: message = self.message sys.stdout.write(next(self.spinner_generator) + " " + message) sys.stdout.flush() time.sleep(self.delay) # clear the spinner and message sys.stdout.write( "\b" * (len(message) + 2) + " " * (len(message) + 2) + "\b" * (len(message) + 2) ) sys.stdout.flush() def 
__enter__(self) -> Callable[[str], None]: self.thread = threading.Thread(target=self.spinner_task) self.thread.start() def set_message(message): self.message = message if not message: self.thread.join() return set_message def __exit__(self, exception, value, tb): self.message = "" try: self.thread.join() finally: del self.thread if exception is not None: return False ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-cli" version = "0.1.61" description = "CLI for interacting with LangGraph API" authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph_cli" }] [tool.poetry.scripts] langgraph = "langgraph_cli.cli:cli" [tool.poetry.dependencies] python = "^3.9.0,<4.0" click = "^8.1.7" langgraph-api = { version = ">=0.0.6,<0.1.0", optional = true, python = ">=3.11,<4.0" } python-dotenv = { version = ">=0.8.0", optional = true } [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watch = "^4.2.0" mypy = "^1.10.0" [tool.poetry.extras] inmem = ["langgraph-api", "python-dotenv"] [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. 
addopts = "--strict-markers --strict-config --durations=5 -vv" asyncio_mode = "auto" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ # pycodestyle "E", # Pyflakes "F", # pyupgrade "UP", # flake8-bugbear "B", # isort "I", ] lint.ignore = ["E501", "B008"] ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/conftest.py`: ```py import os from unittest.mock import patch import pytest @pytest.fixture(autouse=True) def disable_analytics_env() -> None: """Disable analytics for unit tests LANGGRAPH_CLI_NO_ANALYTICS.""" # First check if the environment variable is already set, if so, log a warning prior # to overriding it. if "LANGGRAPH_CLI_NO_ANALYTICS" in os.environ: print("⚠️ LANGGRAPH_CLI_NO_ANALYTICS is set. Overriding it for the test.") with patch.dict(os.environ, {"LANGGRAPH_CLI_NO_ANALYTICS": "0"}): yield ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/cli/test_templates.py`: ```py """Unit tests for the 'new' CLI command. This command creates a new LangGraph project using a specified template. 
""" import os from io import BytesIO from pathlib import Path from tempfile import TemporaryDirectory from unittest.mock import MagicMock, patch from urllib import request from zipfile import ZipFile from click.testing import CliRunner from langgraph_cli.cli import cli from langgraph_cli.templates import TEMPLATE_ID_TO_CONFIG @patch.object(request, "urlopen") def test_create_new_with_mocked_download(mock_urlopen: MagicMock) -> None: """Test the 'new' CLI command with a mocked download response using urllib.""" # Mock the response content to simulate a ZIP file mock_zip_content = BytesIO() with ZipFile(mock_zip_content, "w") as mock_zip: mock_zip.writestr("test-file.txt", "Test content.") # Create a mock response that behaves like a context manager mock_response = MagicMock() mock_response.read.return_value = mock_zip_content.getvalue() mock_response.__enter__.return_value = mock_response # Setup enter context mock_response.status = 200 mock_urlopen.return_value = mock_response with TemporaryDirectory() as temp_dir: runner = CliRunner() template = next( iter(TEMPLATE_ID_TO_CONFIG) ) # Select the first template for the test result = runner.invoke(cli, ["new", temp_dir, "--template", template]) # Verify CLI command execution and success assert result.exit_code == 0, result.output assert ( "New project created" in result.output ), "Expected success message in output." # Verify that the directory is not empty assert os.listdir(temp_dir), "Expected files to be created in temp directory." # Check for a known file in the extracted content extracted_files = [f.name for f in Path(temp_dir).glob("*")] assert ( "test-file.txt" in extracted_files ), "Expected 'test-file.txt' in the extracted content." 
def test_invalid_template_id() -> None: """Test that an invalid template ID passed via CLI results in a graceful error.""" runner = CliRunner() result = runner.invoke( cli, ["new", "dummy_path", "--template", "invalid-template-id"] ) # Verify the command failed and proper message is displayed assert result.exit_code != 0, "Expected non-zero exit code for invalid template." assert ( "Template 'invalid-template-id' not found" in result.output ), "Expected error message in output." ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/cli/test_cli.py`: ```py import json import pathlib import shutil import tempfile from contextlib import contextmanager from pathlib import Path from click.testing import CliRunner from langgraph_cli.cli import cli, prepare_args_and_stdin from langgraph_cli.config import Config, validate_config from langgraph_cli.docker import DEFAULT_POSTGRES_URI, DockerCapabilities, Version from langgraph_cli.util import clean_empty_lines DEFAULT_DOCKER_CAPABILITIES = DockerCapabilities( version_docker=Version(26, 1, 1), version_compose=Version(2, 27, 0), healthcheck_start_interval=True, ) @contextmanager def temporary_config_folder(config_content: dict): # Create a temporary directory temp_dir = tempfile.mkdtemp() try: # Define the path for the config.json file config_path = Path(temp_dir) / "config.json" # Write the provided dictionary content to config.json with open(config_path, "w", encoding="utf-8") as config_file: json.dump(config_content, config_file) # Yield the temporary directory path for use within the context yield config_path.parent finally: # Cleanup the temporary directory and its contents shutil.rmtree(temp_dir) def test_prepare_args_and_stdin() -> None: # this basically serves as an end-to-end test for using config and docker helpers config_path = pathlib.Path("./langgraph.json") config = validate_config( Config(dependencies=["."], graphs={"agent": "agent.py:graph"}) ) port = 8000 debugger_port = 8001 
debugger_graph_url = f"http://127.0.0.1:{port}" actual_args, actual_stdin = prepare_args_and_stdin( capabilities=DEFAULT_DOCKER_CAPABILITIES, config_path=config_path, config=config, docker_compose=pathlib.Path("custom-docker-compose.yml"), port=port, debugger_port=debugger_port, debugger_base_url=debugger_graph_url, watch=True, ) expected_args = [ "--project-directory", ".", "-f", "custom-docker-compose.yml", "-f", "-", ] expected_stdin = f"""volumes: langgraph-data: driver: local services: langgraph-redis: image: redis:6 healthcheck: test: redis-cli ping interval: 5s timeout: 1s retries: 5 langgraph-postgres: image: postgres:16 ports: - "5433:5432" environment: POSTGRES_DB: postgres POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres volumes: - langgraph-data:/var/lib/postgresql/data healthcheck: test: pg_isready -U postgres start_period: 10s timeout: 1s retries: 5 interval: 60s start_interval: 1s langgraph-debugger: image: langchain/langgraph-debugger restart: on-failure depends_on: langgraph-postgres: condition: service_healthy ports: - "{debugger_port}:3968" environment: VITE_STUDIO_LOCAL_GRAPH_URL: {debugger_graph_url} langgraph-api: ports: - "8000:8000" depends_on: langgraph-redis: condition: service_healthy langgraph-postgres: condition: service_healthy environment: REDIS_URI: redis://langgraph-redis:6379 POSTGRES_URI: {DEFAULT_POSTGRES_URI} healthcheck: test: python /api/healthcheck.py interval: 60s start_interval: 1s start_period: 10s pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api:3.11 ADD . /deps/ RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{{"agent": "agent.py:graph"}}' WORKDIR /deps/ develop: watch: - path: langgraph.json action: rebuild - path: . 
action: rebuild\ """ assert actual_args == expected_args assert clean_empty_lines(actual_stdin) == expected_stdin def test_version_option() -> None: """Test the --version option of the CLI.""" runner = CliRunner() result = runner.invoke(cli, ["--version"]) # Verify that the command executed successfully assert result.exit_code == 0, "Expected exit code 0 for --version option" # Check that the output contains the correct version information assert ( "LangGraph CLI, version" in result.output ), "Expected version information in output" def test_dockerfile_command_basic() -> None: """Test the 'dockerfile' command with basic configuration.""" runner = CliRunner() config_content = { "node_version": "20", # Add any other necessary configuration fields "graphs": {"agent": "agent.py:graph"}, } with temporary_config_folder(config_content) as temp_dir: save_path = temp_dir / "Dockerfile" result = runner.invoke( cli, ["dockerfile", str(save_path), "--config", str(temp_dir / "config.json")], ) # Assert command was successful assert result.exit_code == 0, result.output assert "✅ Created: Dockerfile" in result.output # Check if Dockerfile was created assert save_path.exists() def test_dockerfile_command_with_docker_compose() -> None: """Test the 'dockerfile' command with Docker Compose configuration.""" runner = CliRunner() config_content = { "dependencies": ["./my_agent"], "graphs": {"agent": "./my_agent/agent.py:graph"}, "env": ".env", } with temporary_config_folder(config_content) as temp_dir: save_path = temp_dir / "Dockerfile" # Add agent.py file agent_path = temp_dir / "my_agent" / "agent.py" agent_path.parent.mkdir(parents=True, exist_ok=True) agent_path.touch() result = runner.invoke( cli, [ "dockerfile", str(save_path), "--config", str(temp_dir / "config.json"), "--add-docker-compose", ], ) # Assert command was successful assert result.exit_code == 0 assert "✅ Created: Dockerfile" in result.output assert "✅ Created: .dockerignore" in result.output assert "✅ Created: 
docker-compose.yml" in result.output assert ( "✅ Created: .env" in result.output or "➖ Skipped: .env" in result.output ) assert "🎉 Files generated successfully" in result.output # Check if Dockerfile, .dockerignore, docker-compose.yml, and .env were created assert save_path.exists() assert (temp_dir / ".dockerignore").exists() assert (temp_dir / "docker-compose.yml").exists() assert (temp_dir / ".env").exists() or "➖ Skipped: .env" in result.output def test_dockerfile_command_with_bad_config() -> None: """Test the 'dockerfile' command with basic configuration.""" runner = CliRunner() config_content = { "node_version": "20" # Add any other necessary configuration fields } with temporary_config_folder(config_content) as temp_dir: save_path = temp_dir / "Dockerfile" result = runner.invoke( cli, ["dockerfile", str(save_path), "--config", str(temp_dir / "conf.json")], ) # Assert command was successful assert result.exit_code == 2 assert "conf.json' does not exist" in result.output ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/agent.py`: ```py import asyncio import os from typing import Annotated, Sequence, TypedDict from langchain_core.language_models.fake_chat_models import FakeListChatModel from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage from langgraph.graph import END, StateGraph, add_messages # check that env var is present os.environ["SOME_ENV_VAR"] class AgentState(TypedDict): some_bytes: bytes some_byte_array: bytearray dict_with_bytes: dict[str, bytes] messages: Annotated[Sequence[BaseMessage], add_messages] sleep: int async def call_model(state, config): if sleep := state.get("sleep"): await asyncio.sleep(sleep) messages = state["messages"] if len(messages) > 1: assert state["some_bytes"] == b"some_bytes" assert state["some_byte_array"] == bytearray(b"some_byte_array") assert state["dict_with_bytes"] == {"more_bytes": b"more_bytes"} # hacky way to reset model to the "first" response if 
isinstance(messages[-1], HumanMessage): model.i = 0 response = await model.ainvoke(messages) return { "messages": [response], "some_bytes": b"some_bytes", "some_byte_array": bytearray(b"some_byte_array"), "dict_with_bytes": {"more_bytes": b"more_bytes"}, } def call_tool(state): last_message_content = state["messages"][-1].content return { "messages": [ ToolMessage( f"tool_call__{last_message_content}", tool_call_id="tool_call_id" ) ] } def should_continue(state): messages = state["messages"] last_message = messages[-1] if last_message.content == "end": return END else: return "tool" # NOTE: the model cycles through responses infinitely here model = FakeListChatModel(responses=["begin", "end"]) workflow = StateGraph(AgentState) workflow.add_node("agent", call_model) workflow.add_node("tool", call_tool) workflow.set_entry_point("agent") workflow.add_conditional_edges( "agent", should_continue, ) workflow.add_edge("tool", "agent") graph = workflow.compile() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/test_config.py`: ```py import json import os import pathlib import tempfile import click import pytest from langgraph_cli.config import ( config_to_compose, config_to_docker, validate_config, validate_config_file, ) from langgraph_cli.util import clean_empty_lines PATH_TO_CONFIG = pathlib.Path(__file__).parent / "test_config.json" def test_validate_config(): # minimal config expected_config = { "dependencies": ["."], "graphs": { "agent": "./agent.py:graph", }, } expected_config = { "python_version": "3.11", "pip_config_file": None, "dockerfile_lines": [], "env": {}, "store": None, **expected_config, } actual_config = validate_config(expected_config) assert actual_config == expected_config # full config env = ".env" expected_config = { "python_version": "3.12", "pip_config_file": "pipconfig.txt", "dockerfile_lines": ["ARG meow"], "dependencies": [".", "langchain"], "graphs": { "agent": "./agent.py:graph", }, "env": env, "store": None, } 
actual_config = validate_config(expected_config) assert actual_config == expected_config expected_config["python_version"] = "3.13" actual_config = validate_config(expected_config) assert actual_config == expected_config # check wrong python version raises with pytest.raises(click.UsageError): validate_config( { "python_version": "3.9", } ) # check missing dependencies key raises with pytest.raises(click.UsageError): validate_config( {"python_version": "3.9", "graphs": {"agent": "./agent.py:graph"}}, ) # check missing graphs key raises with pytest.raises(click.UsageError): validate_config({"python_version": "3.9", "dependencies": ["."]}) with pytest.raises(click.UsageError) as exc_info: validate_config({"python_version": "3.11.0"}) assert "Invalid Python version format" in str(exc_info.value) with pytest.raises(click.UsageError) as exc_info: validate_config({"python_version": "3"}) assert "Invalid Python version format" in str(exc_info.value) with pytest.raises(click.UsageError) as exc_info: validate_config({"python_version": "abc.def"}) assert "Invalid Python version format" in str(exc_info.value) with pytest.raises(click.UsageError) as exc_info: validate_config({"python_version": "3.10"}) assert "Minimum required version" in str(exc_info.value) def test_validate_config_file(): with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = pathlib.Path(tmpdir) config_path = tmpdir_path / "langgraph.json" node_config = {"node_version": "20", "graphs": {"agent": "./agent.js:graph"}} with open(config_path, "w") as f: json.dump(node_config, f) validate_config_file(config_path) package_json = {"name": "test", "engines": {"node": "20"}} with open(tmpdir_path / "package.json", "w") as f: json.dump(package_json, f) validate_config_file(config_path) package_json["engines"]["node"] = "20.18" with open(tmpdir_path / "package.json", "w") as f: json.dump(package_json, f) with pytest.raises(click.UsageError, match="Use major version only"): validate_config_file(config_path) 
package_json["engines"] = {"node": "18"} with open(tmpdir_path / "package.json", "w") as f: json.dump(package_json, f) with pytest.raises(click.UsageError, match="must be >= 20"): validate_config_file(config_path) package_json["engines"] = {"node": "20", "deno": "1.0"} with open(tmpdir_path / "package.json", "w") as f: json.dump(package_json, f) with pytest.raises(click.UsageError, match="Only 'node' engine is supported"): validate_config_file(config_path) with open(tmpdir_path / "package.json", "w") as f: f.write("{invalid json") with pytest.raises(click.UsageError, match="Invalid package.json"): validate_config_file(config_path) python_config = { "python_version": "3.11", "dependencies": ["."], "graphs": {"agent": "./agent.py:graph"}, } with open(config_path, "w") as f: json.dump(python_config, f) validate_config_file(config_path) for package_content in [ {"name": "test"}, {"engines": {"node": "18"}}, {"engines": {"node": "20", "deno": "1.0"}}, "{invalid json", ]: with open(tmpdir_path / "package.json", "w") as f: if isinstance(package_content, dict): json.dump(package_content, f) else: f.write(package_content) validate_config_file(config_path) # config_to_docker def test_config_to_docker_simple(): graphs = {"agent": "./agent.py:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": graphs}), "langchain/langgraph-api", ) expected_docker_stdin = """\ FROM langchain/langgraph-api:3.11 ADD . 
/deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests\ """ assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin def test_config_to_docker_pipconfig(): graphs = {"agent": "./agent.py:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config( { "dependencies": ["."], "graphs": graphs, "pip_config_file": "pipconfig.txt", } ), "langchain/langgraph-api", ) expected_docker_stdin = """\ FROM langchain/langgraph-api:3.11 ADD pipconfig.txt /pipconfig.txt ADD . /deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests\ """ assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin def test_config_to_docker_invalid_inputs(): # test missing local dependencies with pytest.raises(FileNotFoundError): graphs = {"agent": "tests/unit_tests/agent.py:graph"} config_to_docker( PATH_TO_CONFIG, validate_config({"dependencies": ["./missing"], "graphs": graphs}), "langchain/langgraph-api", ) # test missing local module with pytest.raises(FileNotFoundError): graphs = {"agent": "./missing_agent.py:graph"} config_to_docker( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": 
graphs}), "langchain/langgraph-api", ) def test_config_to_docker_local_deps(): graphs = {"agent": "./graphs/agent.py:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config( { "dependencies": ["./graphs"], "graphs": graphs, } ), "langchain/langgraph-api-custom", ) expected_docker_stdin = """\ FROM langchain/langgraph-api-custom:3.11 ADD ./graphs /deps/__outer_graphs/src RUN set -ex && \\ for line in '[project]' \\ 'name = "graphs"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_graphs/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_graphs/src/agent.py:graph"}'\ """ assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin def test_config_to_docker_pyproject(): pyproject_str = """[project] name = "custom" version = "0.1" dependencies = ["langchain"]""" pyproject_path = "tests/unit_tests/pyproject.toml" with open(pyproject_path, "w") as f: f.write(pyproject_str) graphs = {"agent": "./graphs/agent.py:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config( { "dependencies": ["."], "graphs": graphs, } ), "langchain/langgraph-api", ) os.remove(pyproject_path) expected_docker_stdin = """FROM langchain/langgraph-api:3.11 ADD . 
/deps/unit_tests RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/unit_tests/graphs/agent.py:graph"}' WORKDIR /deps/unit_tests""" assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin def test_config_to_docker_end_to_end(): graphs = {"agent": "./graphs/agent.py:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config( { "python_version": "3.12", "dependencies": ["./graphs/", "langchain", "langchain_openai"], "graphs": graphs, "pip_config_file": "pipconfig.txt", "dockerfile_lines": ["ARG meow", "ARG foo"], } ), "langchain/langgraph-api", ) expected_docker_stdin = """FROM langchain/langgraph-api:3.12 ARG meow ARG foo ADD pipconfig.txt /pipconfig.txt RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt langchain langchain_openai ADD ./graphs/ /deps/__outer_graphs/src RUN set -ex && \\ for line in '[project]' \\ 'name = "graphs"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_graphs/pyproject.toml; \\ done RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_graphs/src/agent.py:graph"}'""" assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin # node.js build used for LangGraph Cloud def test_config_to_docker_nodejs(): graphs = {"agent": "./graphs/agent.js:graph"} actual_docker_stdin = config_to_docker( PATH_TO_CONFIG, validate_config( { "node_version": "20", "graphs": graphs, "dockerfile_lines": ["ARG meow", "ARG foo"], } ), "langchain/langgraphjs-api", ) expected_docker_stdin = """FROM langchain/langgraphjs-api:20 ARG meow ARG foo ADD . /deps/unit_tests RUN cd /deps/unit_tests && npm i ENV LANGSERVE_GRAPHS='{"agent": "./graphs/agent.js:graph"}' WORKDIR /deps/unit_tests RUN (test ! 
-f /api/langgraph_api/js/build.mts && echo "Prebuild script not found, skipping") || tsx /api/langgraph_api/js/build.mts""" assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin # config_to_compose def test_config_to_compose_simple_config(): graphs = {"agent": "./agent.py:graph"} expected_compose_stdin = """\ pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api:3.11 ADD . /deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests """ actual_compose_stdin = config_to_compose( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": graphs}), "langchain/langgraph-api", ) assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin def test_config_to_compose_env_vars(): graphs = {"agent": "./agent.py:graph"} expected_compose_stdin = """ OPENAI_API_KEY: "key" pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api-custom:3.11 ADD . 
/deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests """ openai_api_key = "key" actual_compose_stdin = config_to_compose( PATH_TO_CONFIG, validate_config( { "dependencies": ["."], "graphs": graphs, "env": {"OPENAI_API_KEY": openai_api_key}, } ), "langchain/langgraph-api-custom", ) assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin def test_config_to_compose_env_file(): graphs = {"agent": "./agent.py:graph"} expected_compose_stdin = """\ env_file: .env pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api:3.11 ADD . /deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests """ actual_compose_stdin = config_to_compose( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": graphs, "env": ".env"}), "langchain/langgraph-api", ) assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin def test_config_to_compose_watch(): graphs = {"agent": "./agent.py:graph"} expected_compose_stdin = """\ pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api:3.11 ADD . 
/deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests develop: watch: - path: test_config.json action: rebuild - path: . action: rebuild\ """ actual_compose_stdin = config_to_compose( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": graphs}), "langchain/langgraph-api", watch=True, ) assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin def test_config_to_compose_end_to_end(): # test all of the above + langgraph API path graphs = {"agent": "./agent.py:graph"} expected_compose_stdin = """\ env_file: .env pull_policy: build build: context: . dockerfile_inline: | FROM langchain/langgraph-api:3.11 ADD . /deps/__outer_unit_tests/unit_tests RUN set -ex && \\ for line in '[project]' \\ 'name = "unit_tests"' \\ 'version = "0.1"' \\ '[tool.setuptools.package-data]' \\ '"*" = ["**/*"]'; do \\ echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\ done RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}' WORKDIR /deps/__outer_unit_tests/unit_tests develop: watch: - path: test_config.json action: rebuild - path: . 
action: rebuild\ """ actual_compose_stdin = config_to_compose( PATH_TO_CONFIG, validate_config({"dependencies": ["."], "graphs": graphs, "env": ".env"}), "langchain/langgraph-api", watch=True, ) assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/helpers.py`: ```py def clean_empty_lines(input_str: str): return "\n".join(filter(None, input_str.splitlines())) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/test_docker.py`: ```py from langgraph_cli.docker import ( DEFAULT_POSTGRES_URI, DockerCapabilities, Version, compose, ) from langgraph_cli.util import clean_empty_lines DEFAULT_DOCKER_CAPABILITIES = DockerCapabilities( version_docker=Version(26, 1, 1), version_compose=Version(2, 27, 0), healthcheck_start_interval=False, ) def test_compose_with_no_debugger_and_custom_db(): port = 8123 custom_postgres_uri = "custom_postgres_uri" actual_compose_str = compose( DEFAULT_DOCKER_CAPABILITIES, port=port, postgres_uri=custom_postgres_uri ) expected_compose_str = f"""services: langgraph-redis: image: redis:6 healthcheck: test: redis-cli ping interval: 5s timeout: 1s retries: 5 langgraph-api: ports: - "{port}:8000" depends_on: langgraph-redis: condition: service_healthy environment: REDIS_URI: redis://langgraph-redis:6379 POSTGRES_URI: {custom_postgres_uri}""" assert clean_empty_lines(actual_compose_str) == expected_compose_str def test_compose_with_no_debugger_and_custom_db_with_healthcheck(): port = 8123 custom_postgres_uri = "custom_postgres_uri" actual_compose_str = compose( DEFAULT_DOCKER_CAPABILITIES._replace(healthcheck_start_interval=True), port=port, postgres_uri=custom_postgres_uri, ) expected_compose_str = f"""services: langgraph-redis: image: redis:6 healthcheck: test: redis-cli ping interval: 5s timeout: 1s retries: 5 langgraph-api: ports: - "{port}:8000" depends_on: langgraph-redis: condition: service_healthy environment: REDIS_URI: 
redis://langgraph-redis:6379 POSTGRES_URI: {custom_postgres_uri} healthcheck: test: python /api/healthcheck.py interval: 60s start_interval: 1s start_period: 10s""" assert clean_empty_lines(actual_compose_str) == expected_compose_str def test_compose_with_debugger_and_custom_db(): port = 8123 custom_postgres_uri = "custom_postgres_uri" actual_compose_str = compose( DEFAULT_DOCKER_CAPABILITIES, port=port, postgres_uri=custom_postgres_uri, ) expected_compose_str = f"""services: langgraph-redis: image: redis:6 healthcheck: test: redis-cli ping interval: 5s timeout: 1s retries: 5 langgraph-api: ports: - "{port}:8000" depends_on: langgraph-redis: condition: service_healthy environment: REDIS_URI: redis://langgraph-redis:6379 POSTGRES_URI: {custom_postgres_uri}""" assert clean_empty_lines(actual_compose_str) == expected_compose_str def test_compose_with_debugger_and_default_db(): port = 8123 actual_compose_str = compose(DEFAULT_DOCKER_CAPABILITIES, port=port) expected_compose_str = f"""volumes: langgraph-data: driver: local services: langgraph-redis: image: redis:6 healthcheck: test: redis-cli ping interval: 5s timeout: 1s retries: 5 langgraph-postgres: image: postgres:16 ports: - "5433:5432" environment: POSTGRES_DB: postgres POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres volumes: - langgraph-data:/var/lib/postgresql/data healthcheck: test: pg_isready -U postgres start_period: 10s timeout: 1s retries: 5 interval: 5s langgraph-api: ports: - "{port}:8000" depends_on: langgraph-redis: condition: service_healthy langgraph-postgres: condition: service_healthy environment: REDIS_URI: redis://langgraph-redis:6379 POSTGRES_URI: {DEFAULT_POSTGRES_URI}""" assert clean_empty_lines(actual_compose_str) == expected_compose_str ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/integration_tests/test_cli.py`: ```py import pytest import requests from langgraph_cli.templates import TEMPLATE_ID_TO_CONFIG @pytest.mark.parametrize("template_key", 
TEMPLATE_ID_TO_CONFIG.keys()) def test_template_urls_work(template_key: str) -> None: """Integration test to verify that all template URLs are reachable.""" _, _, template_url = TEMPLATE_ID_TO_CONFIG[template_key] response = requests.head(template_url) # Returns 302 on a successful HEAD request assert response.status_code == 302, f"URL {template_url} is not reachable." ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-examples" version = "0.1.0" description = "" authors = [] readme = "README.md" packages = [] package-mode = false [tool.poetry.dependencies] python = "^3.9.0,<4.0" langgraph-cli = {path = "../../cli", develop = true} langgraph-sdk = {path = "../../sdk-py", develop = true} [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/hello.py`: ```py from graphs_submod.agent import graph # noqa from utils.greeter import greet greet() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/graphs_submod/agent.py`: ```py from pathlib import Path from typing import Annotated, Sequence, TypedDict from langchain_anthropic import ChatAnthropic from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.messages import BaseMessage from langchain_openai import ChatOpenAI from langgraph.graph import END, StateGraph, add_messages from langgraph.prebuilt import ToolNode tools = [TavilySearchResults(max_results=1)] model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229") model_oai = ChatOpenAI(temperature=0) model_anth = model_anth.bind_tools(tools) model_oai = model_oai.bind_tools(tools) prompt = open("prompt.txt").read() subprompt = open(Path(__file__).parent / "subprompt.txt").read() class AgentState(TypedDict): messages: Annotated[Sequence[BaseMessage], add_messages] # Define the function that determines whether 
to continue or not def should_continue(state): messages = state["messages"] last_message = messages[-1] # If there are no tool calls, then we finish if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: return "continue" # Define the function that calls the model def call_model(state, config): if config["configurable"].get("model", "anthropic") == "anthropic": model = model_anth else: model = model_oai messages = state["messages"] response = model.invoke(messages) # We return a list, because this will get added to the existing list return {"messages": [response]} # Define the function to execute tools tool_node = ToolNode(tools) # Define a new graph workflow = StateGraph(AgentState) # Define the two nodes we will cycle between workflow.add_node("agent", call_model) workflow.add_node("action", tool_node) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges( # First, we define the start node. We use `agent`. # This means these are the edges taken after the `agent` node is called. "agent", # Next, we pass in the function that will determine which node is called next. should_continue, # Finally we pass in a mapping. # The keys are strings, and the values are other nodes. # END is a special node marking that the graph should finish. # What will happen is we will call `should_continue`, and then the output of that # will be matched against the keys in this mapping. # Based on which one it matches, that node will then be called. { # If `tools`, then we call the tool node. "continue": "action", # Otherwise we finish. "end": END, }, ) # We now add a normal edge from `tools` to `agent`. # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("action", "agent") # Finally, we compile it! 
# This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable graph = workflow.compile() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/utils/greeter.py`: ```py def greet(): print("Hello, world!") ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs/storm.py`: ```py import asyncio import json from typing import Annotated, List, Optional from langchain_community.retrievers import WikipediaRetriever from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.vectorstores import SKLearnVectorStore from langchain_core.documents import Document from langchain_core.messages import ( AIMessage, AnyMessage, HumanMessage, ToolMessage, ) from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableConfig, RunnableLambda from langchain_core.runnables import chain as as_runnable from langchain_core.tools import tool from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langgraph.graph import END, StateGraph from pydantic import BaseModel, Field from typing_extensions import TypedDict fast_llm = ChatOpenAI(model="gpt-3.5-turbo") # Uncomment for a Fireworks model # fast_llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", max_tokens=32_000) long_context_llm = ChatOpenAI(model="gpt-4-turbo-preview") direct_gen_outline_prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are a Wikipedia writer. Write an outline for a Wikipedia page about a user-provided topic. 
Be comprehensive and specific.", ), ("user", "{topic}"), ] ) class Subsection(BaseModel): subsection_title: str = Field(..., title="Title of the subsection") description: str = Field(..., title="Content of the subsection") @property def as_str(self) -> str: return f"### {self.subsection_title}\n\n{self.description}".strip() class Section(BaseModel): section_title: str = Field(..., title="Title of the section") description: str = Field(..., title="Content of the section") subsections: Optional[List[Subsection]] = Field( default=None, title="Titles and descriptions for each subsection of the Wikipedia page.", ) @property def as_str(self) -> str: subsections = "\n\n".join( f"### {subsection.subsection_title}\n\n{subsection.description}" for subsection in self.subsections or [] ) return f"## {self.section_title}\n\n{self.description}\n\n{subsections}".strip() class Outline(BaseModel): page_title: str = Field(..., title="Title of the Wikipedia page") sections: List[Section] = Field( default_factory=list, title="Titles and descriptions for each section of the Wikipedia page.", ) @property def as_str(self) -> str: sections = "\n\n".join(section.as_str for section in self.sections) return f"# {self.page_title}\n\n{sections}".strip() generate_outline_direct = direct_gen_outline_prompt | fast_llm.with_structured_output( Outline ) gen_related_topics_prompt = ChatPromptTemplate.from_template( """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. Please list the as many subjects and urls as you can. 
Topic of interest: {topic} """ ) class RelatedSubjects(BaseModel): topics: List[str] = Field( description="Comprehensive list of related subjects as background research.", ) expand_chain = gen_related_topics_prompt | fast_llm.with_structured_output( RelatedSubjects ) class Editor(BaseModel): affiliation: str = Field( description="Primary affiliation of the editor.", ) name: str = Field( description="Name of the editor.", ) role: str = Field( description="Role of the editor in the context of the topic.", ) description: str = Field( description="Description of the editor's focus, concerns, and motives.", ) @property def persona(self) -> str: return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n" class Perspectives(BaseModel): editors: List[Editor] = Field( description="Comprehensive list of editors with their roles and affiliations.", # Add a pydantic validation/restriction to be at most M editors ) gen_perspectives_prompt = ChatPromptTemplate.from_messages( [ ( "system", """You need to select a diverse (and distinct) group of Wikipedia editors who will work together to create a comprehensive article on the topic. Each of them represents a different perspective, role, or affiliation related to this topic.\ You can use other Wikipedia pages of related topics for inspiration. For each editor, add a description of what they will focus on. 
Wiki page outlines of related topics for inspiration: {examples}""", ), ("user", "Topic of interest: {topic}"), ] ) gen_perspectives_chain = gen_perspectives_prompt | ChatOpenAI( model="gpt-3.5-turbo" ).with_structured_output(Perspectives) wikipedia_retriever = WikipediaRetriever(load_all_available_meta=True, top_k_results=1) def format_doc(doc, max_length=1000): related = "- ".join(doc.metadata["categories"]) return f"### {doc.metadata['title']}\n\nSummary: {doc.page_content}\n\nRelated\n{related}"[ :max_length ] def format_docs(docs): return "\n\n".join(format_doc(doc) for doc in docs) @as_runnable async def survey_subjects(topic: str): related_subjects = await expand_chain.ainvoke({"topic": topic}) retrieved_docs = await wikipedia_retriever.abatch( related_subjects.topics, return_exceptions=True ) all_docs = [] for docs in retrieved_docs: if isinstance(docs, BaseException): continue all_docs.extend(docs) formatted = format_docs(all_docs) return await gen_perspectives_chain.ainvoke({"examples": formatted, "topic": topic}) def add_messages(left, right): if not isinstance(left, list): left = [left] if not isinstance(right, list): right = [right] return left + right def update_references(references, new_references): if not references: references = {} references.update(new_references) return references def update_editor(editor, new_editor): # Can only set at the outset if not editor: return new_editor return editor class InterviewState(TypedDict): messages: Annotated[List[AnyMessage], add_messages] references: Annotated[Optional[dict], update_references] editor: Annotated[Optional[Editor], update_editor] gen_qn_prompt = ChatPromptTemplate.from_messages( [ ( "system", """You are an experienced Wikipedia writer and want to edit a specific page. \ Besides your identity as a Wikipedia writer, you have a specific focus when researching the topic. \ Now, you are chatting with an expert to get information. Ask good questions to get more useful information. 
When you have no more questions to ask, say "Thank you so much for your help!" to end the conversation.\ Please only ask one question at a time and don't ask what you have asked before.\ Your questions should be related to the topic you want to write. Be comprehensive and curious, gaining as much unique insight from the expert as possible.\ Stay true to your specific perspective: {persona}""", ), MessagesPlaceholder(variable_name="messages", optional=True), ] ) def tag_with_name(ai_message: AIMessage, name: str): ai_message.name = name.replace(" ", "_").replace(".", "_") return ai_message def swap_roles(state: InterviewState, name: str): converted = [] for message in state["messages"]: if isinstance(message, AIMessage) and message.name != name: message = HumanMessage(**message.dict(exclude={"type"})) converted.append(message) return {"messages": converted} @as_runnable async def generate_question(state: InterviewState): editor = state["editor"] gn_chain = ( RunnableLambda(swap_roles).bind(name=editor.name) | gen_qn_prompt.partial(persona=editor.persona) | fast_llm | RunnableLambda(tag_with_name).bind(name=editor.name) ) result = await gn_chain.ainvoke(state) return {"messages": [result]} class Queries(BaseModel): queries: List[str] = Field( description="Comprehensive list of search engine queries to answer the user's questions.", ) gen_queries_prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are a helpful research assistant. 
Query the search engine to answer the user's questions.", ), MessagesPlaceholder(variable_name="messages", optional=True), ] ) gen_queries_chain = gen_queries_prompt | ChatOpenAI( model="gpt-3.5-turbo" ).with_structured_output(Queries, include_raw=True) class AnswerWithCitations(BaseModel): answer: str = Field( description="Comprehensive answer to the user's question with citations.", ) cited_urls: List[str] = Field( description="List of urls cited in the answer.", ) @property def as_str(self) -> str: return f"{self.answer}\n\nCitations:\n\n" + "\n".join( f"[{i+1}]: {url}" for i, url in enumerate(self.cited_urls) ) gen_answer_prompt = ChatPromptTemplate.from_messages( [ ( "system", """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants\ to write a Wikipedia page on the topic you know. You have gathered the related information and will now use the information to form a response. Make your response as informative as possible and make sure every sentence is supported by the gathered information. 
Each response must be backed up by a citation from a reliable source, formatted as a footnote, reproducing the URLS after your response.""", ), MessagesPlaceholder(variable_name="messages", optional=True), ] ) gen_answer_chain = gen_answer_prompt | fast_llm.with_structured_output( AnswerWithCitations, include_raw=True ).with_config(run_name="GenerateAnswer") # Tavily is typically a better search engine, but your free queries are limited tavily_search = TavilySearchResults(max_results=4) @tool async def search_engine(query: str): """Search engine to the internet.""" results = tavily_search.invoke(query) return [{"content": r["content"], "url": r["url"]} for r in results] async def gen_answer( state: InterviewState, config: Optional[RunnableConfig] = None, name: str = "Subject_Matter_Expert", max_str_len: int = 15000, ): swapped_state = swap_roles(state, name) # Convert all other AI messages queries = await gen_queries_chain.ainvoke(swapped_state) query_results = await search_engine.abatch( queries["parsed"].queries, config, return_exceptions=True ) successful_results = [ res for res in query_results if not isinstance(res, Exception) ] all_query_results = { res["url"]: res["content"] for results in successful_results for res in results } # We could be more precise about handling max token length if we wanted to here dumped = json.dumps(all_query_results)[:max_str_len] ai_message: AIMessage = queries["raw"] tool_call = queries["raw"].tool_calls[0] tool_id = tool_call["id"] tool_message = ToolMessage(tool_call_id=tool_id, content=dumped) swapped_state["messages"].extend([ai_message, tool_message]) # Only update the shared state with the final answer to avoid # polluting the dialogue history with intermediate messages generated = await gen_answer_chain.ainvoke(swapped_state) cited_urls = set(generated["parsed"].cited_urls) # Save the retrieved information to a the shared state for future reference cited_references = {k: v for k, v in all_query_results.items() if k in 
cited_urls} formatted_message = AIMessage(name=name, content=generated["parsed"].as_str) return {"messages": [formatted_message], "references": cited_references} max_num_turns = 5 def route_messages(state: InterviewState, name: str = "Subject_Matter_Expert"): messages = state["messages"] num_responses = len( [m for m in messages if isinstance(m, AIMessage) and m.name == name] ) if num_responses >= max_num_turns: return END last_question = messages[-2] if last_question.content.endswith("Thank you so much for your help!"): return END return "ask_question" builder = StateGraph(InterviewState) builder.add_node("ask_question", generate_question) builder.add_node("answer_question", gen_answer) builder.add_conditional_edges("answer_question", route_messages) builder.add_edge("ask_question", "answer_question") builder.set_entry_point("ask_question") interview_graph = builder.compile().with_config(run_name="Conduct Interviews") refine_outline_prompt = ChatPromptTemplate.from_messages( [ ( "system", """You are a Wikipedia writer. You have gathered information from experts and search engines. Now, you are refining the outline of the Wikipedia page. \ You need to make sure that the outline is comprehensive and specific. \ Topic you are writing about: {topic} Old outline: {old_outline}""", ), ( "user", "Refine the outline based on your conversations with subject-matter experts:\n\nConversations:\n\n{conversations}\n\nWrite the refined Wikipedia outline:", ), ] ) # Using turbo preview since the context can get quite long refine_outline_chain = refine_outline_prompt | long_context_llm.with_structured_output( Outline ) embeddings = OpenAIEmbeddings(model="text-embedding-3-small") # reference_docs = [ # Document(page_content=v, metadata={"source": k}) # for k, v in final_state["references"].items() # ] # # This really doesn't need to be a vectorstore for this size of data. # # It could just be a numpy matrix. Or you could store documents # # across requests if you want. 
# vectorstore = SKLearnVectorStore.from_documents( # reference_docs, # embedding=embeddings, # ) # retriever = vectorstore.as_retriever(k=10) vectorstore = SKLearnVectorStore(embedding=embeddings) retriever = vectorstore.as_retriever(k=10) class SubSection(BaseModel): subsection_title: str = Field(..., title="Title of the subsection") content: str = Field( ..., title="Full content of the subsection. Include [#] citations to the cited sources where relevant.", ) @property def as_str(self) -> str: return f"### {self.subsection_title}\n\n{self.content}".strip() class WikiSection(BaseModel): section_title: str = Field(..., title="Title of the section") content: str = Field(..., title="Full content of the section") subsections: Optional[List[Subsection]] = Field( default=None, title="Titles and descriptions for each subsection of the Wikipedia page.", ) citations: List[str] = Field(default_factory=list) @property def as_str(self) -> str: subsections = "\n\n".join( subsection.as_str for subsection in self.subsections or [] ) citations = "\n".join([f" [{i}] {cit}" for i, cit in enumerate(self.citations)]) return ( f"## {self.section_title}\n\n{self.content}\n\n{subsections}".strip() + f"\n\n{citations}".strip() ) section_writer_prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are an expert Wikipedia writer. 
Complete your assigned WikiSection from the following outline:\n\n" "{outline}\n\nCite your sources, using the following references:\n\n<Documents>\n{docs}\n<Documents>", ), ("user", "Write the full WikiSection for the {section} section."), ] ) async def retrieve(inputs: dict): docs = await retriever.ainvoke(inputs["topic"] + ": " + inputs["section"]) formatted = "\n".join( [ f'<Document href="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>' for doc in docs ] ) return {"docs": formatted, **inputs} section_writer = ( retrieve | section_writer_prompt | long_context_llm.with_structured_output(WikiSection) ) writer_prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are an expert Wikipedia author. Write the complete wiki article on {topic} using the following section drafts:\n\n" "{draft}\n\nStrictly follow Wikipedia format guidelines.", ), ( "user", 'Write the complete Wiki article using markdown format. Organize citations using footnotes like "[1]",' " avoiding duplicates in the footer. 
Include URLs in the footer.", ), ] ) writer = writer_prompt | long_context_llm | StrOutputParser() class ResearchState(TypedDict): topic: str outline: Outline editors: List[Editor] interview_results: List[InterviewState] # The final sections output sections: List[WikiSection] article: str async def initialize_research(state: ResearchState): topic = state["topic"] coros = ( generate_outline_direct.ainvoke({"topic": topic}), survey_subjects.ainvoke(topic), ) results = await asyncio.gather(*coros) return { **state, "outline": results[0], "editors": results[1].editors, } async def conduct_interviews(state: ResearchState): topic = state["topic"] initial_states = [ { "editor": editor, "messages": [ AIMessage( content=f"So you said you were writing an article on {topic}?", name="Subject_Matter_Expert", ) ], } for editor in state["editors"] ] # We call in to the sub-graph here to parallelize the interviews interview_results = await interview_graph.abatch(initial_states) return { **state, "interview_results": interview_results, } def format_conversation(interview_state): messages = interview_state["messages"] convo = "\n".join(f"{m.name}: {m.content}" for m in messages) return f'Conversation with {interview_state["editor"].name}\n\n' + convo async def refine_outline(state: ResearchState): convos = "\n\n".join( [ format_conversation(interview_state) for interview_state in state["interview_results"] ] ) updated_outline = await refine_outline_chain.ainvoke( { "topic": state["topic"], "old_outline": state["outline"].as_str, "conversations": convos, } ) return {**state, "outline": updated_outline} async def index_references(state: ResearchState): all_docs = [] for interview_state in state["interview_results"]: reference_docs = [ Document(page_content=v, metadata={"source": k}) for k, v in interview_state["references"].items() ] all_docs.extend(reference_docs) await vectorstore.aadd_documents(all_docs) return state async def write_sections(state: ResearchState): outline = 
state["outline"] sections = await section_writer.abatch( [ { "outline": outline.as_str, "section": section.section_title, "topic": state["topic"], } for section in outline.sections ] ) return { **state, "sections": sections, } async def write_article(state: ResearchState): topic = state["topic"] sections = state["sections"] draft = "\n\n".join([section.as_str for section in sections]) article = await writer.ainvoke({"topic": topic, "draft": draft}) return { **state, "article": article, } builder_of_storm = StateGraph(ResearchState) nodes = [ ("init_research", initialize_research), ("conduct_interviews", conduct_interviews), ("refine_outline", refine_outline), ("index_references", index_references), ("write_sections", write_sections), ("write_article", write_article), ] for i in range(len(nodes)): name, node = nodes[i] builder_of_storm.add_node(name, node) if i > 0: builder_of_storm.add_edge(nodes[i - 1][0], name) builder_of_storm.set_entry_point(nodes[0][0]) builder_of_storm.set_finish_point(nodes[-1][0]) graph = builder_of_storm.compile() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs/agent.py`: ```py from typing import Annotated, Literal, Sequence, TypedDict from langchain_anthropic import ChatAnthropic from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.messages import BaseMessage from langchain_openai import ChatOpenAI from langgraph.graph import END, StateGraph, add_messages from langgraph.prebuilt import ToolNode tools = [TavilySearchResults(max_results=1)] model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229") model_oai = ChatOpenAI(temperature=0) model_anth = model_anth.bind_tools(tools) model_oai = model_oai.bind_tools(tools) class AgentState(TypedDict): messages: Annotated[Sequence[BaseMessage], add_messages] # Define the function that determines whether to continue or not def should_continue(state): messages = state["messages"] last_message = messages[-1] # If 
there are no tool calls, then we finish if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: return "continue" # Define the function that calls the model def call_model(state, config): if config["configurable"].get("model", "anthropic") == "anthropic": model = model_anth else: model = model_oai messages = state["messages"] response = model.invoke(messages) # We return a list, because this will get added to the existing list return {"messages": [response]} # Define the function to execute tools tool_node = ToolNode(tools) class ConfigSchema(TypedDict): model: Literal["anthropic", "openai"] # Define a new graph workflow = StateGraph(AgentState, config_schema=ConfigSchema) # Define the two nodes we will cycle between workflow.add_node("agent", call_model) workflow.add_node("action", tool_node) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") # We now add a conditional edge workflow.add_conditional_edges( # First, we define the start node. We use `agent`. # This means these are the edges taken after the `agent` node is called. "agent", # Next, we pass in the function that will determine which node is called next. should_continue, # Finally we pass in a mapping. # The keys are strings, and the values are other nodes. # END is a special node marking that the graph should finish. # What will happen is we will call `should_continue`, and then the output of that # will be matched against the keys in this mapping. # Based on which one it matches, that node will then be called. { # If `tools`, then we call the tool node. "continue": "action", # Otherwise we finish. "end": END, }, ) # We now add a normal edge from `tools` to `agent`. # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("action", "agent") # Finally, we compile it! 
# This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable graph = workflow.compile() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_a/hello.py`: ```py from graphs_reqs_a.graphs_submod.agent import graph # noqa ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_a/graphs_submod/agent.py`: ```py from pathlib import Path from typing import Annotated, Sequence, TypedDict from langchain_anthropic import ChatAnthropic from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.messages import BaseMessage from langchain_openai import ChatOpenAI from langgraph.graph import END, StateGraph, add_messages from langgraph.prebuilt import ToolNode tools = [TavilySearchResults(max_results=1)] model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229") model_oai = ChatOpenAI(temperature=0) model_anth = model_anth.bind_tools(tools) model_oai = model_oai.bind_tools(tools) prompt = open("prompt.txt").read() subprompt = open(Path(__file__).parent / "subprompt.txt").read() class AgentState(TypedDict): messages: Annotated[Sequence[BaseMessage], add_messages] # Define the function that determines whether to continue or not def should_continue(state): messages = state["messages"] last_message = messages[-1] # If there are no tool calls, then we finish if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: return "continue" # Define the function that calls the model def call_model(state, config): if config["configurable"].get("model", "anthropic") == "anthropic": model = model_anth else: model = model_oai messages = state["messages"] response = model.invoke(messages) # We return a list, because this will get added to the existing list return {"messages": [response]} # Define the function to execute tools tool_node = ToolNode(tools) # Define a new graph workflow = StateGraph(AgentState) # Define the two 
# nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and then the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If the agent requested tool calls, we run the `action` (tool) node.
        "continue": "action",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` to `agent`.
# This means that after `action` is called, `agent` node is called next.
workflow.add_edge("action", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable graph = workflow.compile() ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py`: ```py import random import sqlite3 import threading from contextlib import closing, contextmanager from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, SerializerProtocol, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import ChannelProtocol from langgraph.checkpoint.sqlite.utils import search_where _AIO_ERROR_MSG = ( "The SqliteSaver does not support async methods. " "Consider using AsyncSqliteSaver instead.\n" "from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver\n" "Note: AsyncSqliteSaver requires the aiosqlite package to use.\n" "Install with:\n`pip install aiosqlite`\n" "See https://langchain-ai.github.io/langgraph/reference/checkpoints/asyncsqlitesaver" "for more information." ) class SqliteSaver(BaseCheckpointSaver[str]): """A checkpoint saver that stores checkpoints in a SQLite database. Note: This class is meant for lightweight, synchronous use cases (demos and small projects) and does not scale to multiple threads. For a similar sqlite saver with `async` support, consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver]. Args: conn (sqlite3.Connection): The SQLite database connection. serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat. 
Examples: >>> import sqlite3 >>> from langgraph.checkpoint.sqlite import SqliteSaver >>> from langgraph.graph import StateGraph >>> >>> builder = StateGraph(int) >>> builder.add_node("add_one", lambda x: x + 1) >>> builder.set_entry_point("add_one") >>> builder.set_finish_point("add_one") >>> conn = sqlite3.connect("checkpoints.sqlite") >>> memory = SqliteSaver(conn) >>> graph = builder.compile(checkpointer=memory) >>> config = {"configurable": {"thread_id": "1"}} >>> graph.get_state(config) >>> result = graph.invoke(3, config) >>> graph.get_state(config) StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '0c62ca34-ac19-445d-bbb0-5b4984975b2a'}}, parent_config=None) """ # noqa conn: sqlite3.Connection is_setup: bool def __init__( self, conn: sqlite3.Connection, *, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) self.jsonplus_serde = JsonPlusSerializer() self.conn = conn self.is_setup = False self.lock = threading.Lock() @classmethod @contextmanager def from_conn_string(cls, conn_string: str) -> Iterator["SqliteSaver"]: """Create a new SqliteSaver instance from a connection string. Args: conn_string (str): The SQLite connection string. Yields: SqliteSaver: A new SqliteSaver instance. Examples: In memory: with SqliteSaver.from_conn_string(":memory:") as memory: ... To disk: with SqliteSaver.from_conn_string("checkpoints.sqlite") as memory: ... """ with closing( sqlite3.connect( conn_string, # https://ricardoanderegg.com/posts/python-sqlite-thread-safety/ check_same_thread=False, ) ) as conn: yield cls(conn) def setup(self) -> None: """Set up the checkpoint database. This method creates the necessary tables in the SQLite database if they don't already exist. It is called automatically when needed and should not be called directly by the user. 
""" if self.is_setup: return self.conn.executescript( """ PRAGMA journal_mode=WAL; CREATE TABLE IF NOT EXISTS checkpoints ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, type TEXT, checkpoint BLOB, metadata BLOB, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) ); CREATE TABLE IF NOT EXISTS writes ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, task_id TEXT NOT NULL, idx INTEGER NOT NULL, channel TEXT NOT NULL, type TEXT, value BLOB, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) ); """ ) self.is_setup = True @contextmanager def cursor(self, transaction: bool = True) -> Iterator[sqlite3.Cursor]: """Get a cursor for the SQLite database. This method returns a cursor for the SQLite database. It is used internally by the SqliteSaver and should not be called directly by the user. Args: transaction (bool): Whether to commit the transaction when the cursor is closed. Defaults to True. Yields: sqlite3.Cursor: A cursor for the SQLite database. """ with self.lock: self.setup() cur = self.conn.cursor() try: yield cur finally: if transaction: self.conn.commit() cur.close() def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the SQLite database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. Examples: Basic: >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) 
With checkpoint ID: >>> config = { ... "configurable": { ... "thread_id": "1", ... "checkpoint_ns": "", ... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", ... } ... } >>> checkpoint_tuple = memory.get_tuple(config) >>> print(checkpoint_tuple) CheckpointTuple(...) """ # noqa checkpoint_ns = config["configurable"].get("checkpoint_ns", "") with self.cursor(transaction=False) as cur: # find the latest checkpoint for the thread_id if checkpoint_id := get_checkpoint_id(config): cur.execute( "SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?", ( str(config["configurable"]["thread_id"]), checkpoint_ns, checkpoint_id, ), ) else: cur.execute( "SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1", (str(config["configurable"]["thread_id"]), checkpoint_ns), ) # if a checkpoint is found, return it if value := cur.fetchone(): ( thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, ) = value if not get_checkpoint_id(config): config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } } # find any pending writes cur.execute( "SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? 
ORDER BY task_id, idx", ( str(config["configurable"]["thread_id"]), checkpoint_ns, str(config["configurable"]["checkpoint_id"]), ), ) # deserialize the checkpoint and metadata return CheckpointTuple( config, self.serde.loads_typed((type, checkpoint)), self.jsonplus_serde.loads(metadata) if metadata is not None else {}, ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), [ (task_id, channel, self.serde.loads_typed((type, value))) for task_id, channel, type, value in cur ], ) def list( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. This method retrieves a list of checkpoint tuples from the SQLite database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (RunnableConfig): The config to use for listing the checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. Yields: Iterator[CheckpointTuple]: An iterator of checkpoint tuples. Examples: >>> from langgraph.checkpoint.sqlite import SqliteSaver >>> with SqliteSaver.from_conn_string(":memory:") as memory: ... 
# Run a graph, then list the checkpoints >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoints = list(memory.list(config, limit=2)) >>> print(checkpoints) [CheckpointTuple(...), CheckpointTuple(...)] >>> config = {"configurable": {"thread_id": "1"}} >>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}} >>> with SqliteSaver.from_conn_string(":memory:") as memory: ... # Run a graph, then list the checkpoints >>> checkpoints = list(memory.list(config, before=before)) >>> print(checkpoints) [CheckpointTuple(...), ...] """ where, param_values = search_where(config, filter, before) query = f"""SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints {where} ORDER BY checkpoint_id DESC""" if limit: query += f" LIMIT {limit}" with self.cursor(transaction=False) as cur, closing(self.conn.cursor()) as wcur: cur.execute(query, param_values) for ( thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, ) in cur: wcur.execute( "SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx", (thread_id, checkpoint_ns, checkpoint_id), ) yield CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, self.serde.loads_typed((type, checkpoint)), self.jsonplus_serde.loads(metadata) if metadata is not None else {}, ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), [ (task_id, channel, self.serde.loads_typed((type, value))) for task_id, channel, type, value in wcur ], ) def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. 
This method saves a checkpoint to the SQLite database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. Examples: >>> from langgraph.checkpoint.sqlite import SqliteSaver >>> with SqliteSaver.from_conn_string(":memory:") as memory: >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) >>> print(saved_config) {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} """ thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"]["checkpoint_ns"] type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) serialized_metadata = self.jsonplus_serde.dumps(metadata) with self.cursor() as cur: cur.execute( "INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)", ( str(config["configurable"]["thread_id"]), checkpoint_ns, checkpoint["id"], config["configurable"].get("checkpoint_id"), type_, serialized_checkpoint, serialized_metadata, ), ) return { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } def put_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. 
This method saves intermediate writes associated with a checkpoint to the SQLite database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. """ query = ( "INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" if all(w[0] in WRITES_IDX_MAP for w in writes) else "INSERT OR IGNORE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" ) with self.cursor() as cur: cur.executemany( query, [ ( str(config["configurable"]["thread_id"]), str(config["configurable"]["checkpoint_ns"]), str(config["configurable"]["checkpoint_id"]), task_id, WRITES_IDX_MAP.get(channel, idx), channel, *self.serde.dumps_typed(value), ) for idx, (channel, value) in enumerate(writes) ], ) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. Note: This async method is not supported by the SqliteSaver class. Use get_tuple() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver]. """ raise NotImplementedError(_AIO_ERROR_MSG) async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """List checkpoints from the database asynchronously. Note: This async method is not supported by the SqliteSaver class. Use list() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver]. 
""" raise NotImplementedError(_AIO_ERROR_MSG) yield async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database asynchronously. Note: This async method is not supported by the SqliteSaver class. Use put() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver]. """ raise NotImplementedError(_AIO_ERROR_MSG) def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: """Generate the next version ID for a channel. This method creates a new version identifier for a channel based on its current version. Args: current (Optional[str]): The current version identifier of the channel. channel (BaseChannel): The channel being versioned. Returns: str: The next version identifier, which is guaranteed to be monotonically increasing. """ if current is None: current_v = 0 elif isinstance(current, int): current_v = current else: current_v = int(current.split(".")[0]) next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/utils.py`: ```py import json from typing import Any, Dict, Optional, Sequence, Tuple from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import get_checkpoint_id def _metadata_predicate( metadata_filter: Dict[str, Any], ) -> Tuple[Sequence[str], Sequence[Any]]: """Return WHERE clause predicates for (a)search() given metadata filter. This method returns a tuple of a string and a tuple of values. The string is the parametered WHERE clause predicate (excluding the WHERE keyword): "column1 = ? AND column2 IS ?". The tuple of values contains the values for each of the corresponding parameters. 
""" def _where_value(query_value: Any) -> Tuple[str, Any]: """Return tuple of operator and value for WHERE clause predicate.""" if query_value is None: return ("IS ?", None) elif ( isinstance(query_value, str) or isinstance(query_value, int) or isinstance(query_value, float) ): return ("= ?", query_value) elif isinstance(query_value, bool): return ("= ?", 1 if query_value else 0) elif isinstance(query_value, dict) or isinstance(query_value, list): # query value for JSON object cannot have trailing space after separators (, :) # SQLite json_extract() returns JSON string without whitespace return ("= ?", json.dumps(query_value, separators=(",", ":"))) else: return ("= ?", str(query_value)) predicates = [] param_values = [] # process metadata query for query_key, query_value in metadata_filter.items(): operator, param_value = _where_value(query_value) predicates.append( f"json_extract(CAST(metadata AS TEXT), '$.{query_key}') {operator}" ) param_values.append(param_value) return (predicates, param_values) def search_where( config: Optional[RunnableConfig], filter: Optional[Dict[str, Any]], before: Optional[RunnableConfig] = None, ) -> Tuple[str, Sequence[Any]]: """Return WHERE clause predicates for (a)search() given metadata filter and `before` config. This method returns a tuple of a string and a tuple of values. The string is the parametered WHERE clause predicate (including the WHERE keyword): "WHERE column1 = ? AND column2 IS ?". The tuple of values contains the values for each of the corresponding parameters. 
""" wheres = [] param_values = [] # construct predicate for config filter if config is not None: wheres.append("thread_id = ?") param_values.append(config["configurable"]["thread_id"]) checkpoint_ns = config["configurable"].get("checkpoint_ns") if checkpoint_ns is not None: wheres.append("checkpoint_ns = ?") param_values.append(checkpoint_ns) if checkpoint_id := get_checkpoint_id(config): wheres.append("checkpoint_id = ?") param_values.append(checkpoint_id) # construct predicate for metadata filter if filter: metadata_predicates, metadata_values = _metadata_predicate(filter) wheres.extend(metadata_predicates) param_values.extend(metadata_values) # construct predicate for `before` if before is not None: wheres.append("checkpoint_id < ?") param_values.append(get_checkpoint_id(before)) return ("WHERE " + " AND ".join(wheres) if wheres else "", param_values) ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py`: ```py import asyncio import random from contextlib import asynccontextmanager from typing import ( Any, AsyncIterator, Callable, Dict, Iterator, Optional, Sequence, Tuple, TypeVar, ) import aiosqlite from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, SerializerProtocol, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import ChannelProtocol from langgraph.checkpoint.sqlite.utils import search_where T = TypeVar("T", bound=Callable) class AsyncSqliteSaver(BaseCheckpointSaver[str]): """An asynchronous checkpoint saver that stores checkpoints in a SQLite database. This class provides an asynchronous interface for saving and retrieving checkpoints using a SQLite database. 
It's designed for use in asynchronous environments and offers better performance for I/O-bound operations compared to synchronous alternatives. Attributes: conn (aiosqlite.Connection): The asynchronous SQLite database connection. serde (SerializerProtocol): The serializer used for encoding/decoding checkpoints. Tip: Requires the [aiosqlite](https://pypi.org/project/aiosqlite/) package. Install it with `pip install aiosqlite`. Warning: While this class supports asynchronous checkpointing, it is not recommended for production workloads due to limitations in SQLite's write performance. For production use, consider a more robust database like PostgreSQL. Tip: Remember to **close the database connection** after executing your code, otherwise, you may see the graph "hang" after execution (since the program will not exit until the connection is closed). The easiest way is to use the `async with` statement as shown in the examples. ```python async with AsyncSqliteSaver.from_conn_string("checkpoints.sqlite") as saver: # Your code here graph = builder.compile(checkpointer=saver) config = {"configurable": {"thread_id": "thread-1"}} async for event in graph.astream_events(..., config, version="v1"): print(event) ``` Examples: Usage within StateGraph: ```pycon >>> import asyncio >>> >>> from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver >>> from langgraph.graph import StateGraph >>> >>> builder = StateGraph(int) >>> builder.add_node("add_one", lambda x: x + 1) >>> builder.set_entry_point("add_one") >>> builder.set_finish_point("add_one") >>> async with AsyncSqliteSaver.from_conn_string("checkpoints.db") as memory: >>> graph = builder.compile(checkpointer=memory) >>> coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}}) >>> print(asyncio.run(coro)) Output: 2 ``` Raw usage: ```pycon >>> import asyncio >>> import aiosqlite >>> from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver >>> >>> async def main(): >>> async with 
aiosqlite.connect("checkpoints.db") as conn: ... saver = AsyncSqliteSaver(conn) ... config = {"configurable": {"thread_id": "1"}} ... checkpoint = {"ts": "2023-05-03T10:00:00Z", "data": {"key": "value"}} ... saved_config = await saver.aput(config, checkpoint, {}, {}) ... print(saved_config) >>> asyncio.run(main()) {"configurable": {"thread_id": "1", "checkpoint_id": "0c62ca34-ac19-445d-bbb0-5b4984975b2a"}} ``` """ lock: asyncio.Lock is_setup: bool def __init__( self, conn: aiosqlite.Connection, *, serde: Optional[SerializerProtocol] = None, ): super().__init__(serde=serde) self.jsonplus_serde = JsonPlusSerializer() self.conn = conn self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() self.is_setup = False @classmethod @asynccontextmanager async def from_conn_string( cls, conn_string: str ) -> AsyncIterator["AsyncSqliteSaver"]: """Create a new AsyncSqliteSaver instance from a connection string. Args: conn_string (str): The SQLite connection string. Yields: AsyncSqliteSaver: A new AsyncSqliteSaver instance. """ async with aiosqlite.connect(conn_string) as conn: yield cls(conn) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. This method retrieves a checkpoint tuple from the SQLite database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ try: # check if we are in the main thread, only bg threads can block # we don't check in other methods to avoid the overhead if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to AsyncSqliteSaver are only allowed from a " "different thread. 
From the main thread, use the async interface." "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." ) except RuntimeError: pass return asyncio.run_coroutine_threadsafe( self.aget_tuple(config), self.loop ).result() def list( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[CheckpointTuple]: """List checkpoints from the database asynchronously. This method retrieves a list of checkpoint tuples from the SQLite database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. """ aiter_ = self.alist(config, filter=filter, before=before, limit=limit) while True: try: yield asyncio.run_coroutine_threadsafe( anext(aiter_), self.loop, ).result() except StopAsyncIteration: break def put( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database. This method saves a checkpoint to the SQLite database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. 
Returns: RunnableConfig: Updated configuration after storing the checkpoint. """ return asyncio.run_coroutine_threadsafe( self.aput(config, checkpoint, metadata, new_versions), self.loop ).result() def put_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str ) -> None: return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop ).result() async def setup(self) -> None: """Set up the checkpoint database asynchronously. This method creates the necessary tables in the SQLite database if they don't already exist. It is called automatically when needed and should not be called directly by the user. """ async with self.lock: if self.is_setup: return if not self.conn.is_alive(): await self.conn async with self.conn.executescript( """ PRAGMA journal_mode=WAL; CREATE TABLE IF NOT EXISTS checkpoints ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, parent_checkpoint_id TEXT, type TEXT, checkpoint BLOB, metadata BLOB, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) ); CREATE TABLE IF NOT EXISTS writes ( thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', checkpoint_id TEXT NOT NULL, task_id TEXT NOT NULL, idx INTEGER NOT NULL, channel TEXT NOT NULL, type TEXT, value BLOB, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) ); """ ): await self.conn.commit() self.is_setup = True async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. This method retrieves a checkpoint tuple from the SQLite database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. Args: config (RunnableConfig): The config to use for retrieving the checkpoint. 
Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ await self.setup() checkpoint_ns = config["configurable"].get("checkpoint_ns", "") async with self.lock, self.conn.cursor() as cur: # find the latest checkpoint for the thread_id if checkpoint_id := get_checkpoint_id(config): await cur.execute( "SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?", ( str(config["configurable"]["thread_id"]), checkpoint_ns, checkpoint_id, ), ) else: await cur.execute( "SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1", (str(config["configurable"]["thread_id"]), checkpoint_ns), ) # if a checkpoint is found, return it if value := await cur.fetchone(): ( thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, ) = value if not get_checkpoint_id(config): config = { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } } # find any pending writes await cur.execute( "SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? 
ORDER BY task_id, idx", ( str(config["configurable"]["thread_id"]), checkpoint_ns, str(config["configurable"]["checkpoint_id"]), ), ) # deserialize the checkpoint and metadata return CheckpointTuple( config, self.serde.loads_typed((type, checkpoint)), self.jsonplus_serde.loads(metadata) if metadata is not None else {}, ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), [ (task_id, channel, self.serde.loads_typed((type, value))) async for task_id, channel, type, value in cur ], ) async def alist( self, config: Optional[RunnableConfig], *, filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: """List checkpoints from the database asynchronously. This method retrieves a list of checkpoint tuples from the SQLite database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. limit (Optional[int]): Maximum number of checkpoints to return. Yields: AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. 
""" await self.setup() where, params = search_where(config, filter, before) query = f"""SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints {where} ORDER BY checkpoint_id DESC""" if limit: query += f" LIMIT {limit}" async with self.lock, self.conn.execute( query, params ) as cur, self.conn.cursor() as wcur: async for ( thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, ) in cur: await wcur.execute( "SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx", (thread_id, checkpoint_ns, checkpoint_id), ) yield CheckpointTuple( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint_id, } }, self.serde.loads_typed((type, checkpoint)), self.jsonplus_serde.loads(metadata) if metadata is not None else {}, ( { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": parent_checkpoint_id, } } if parent_checkpoint_id else None ), [ (task_id, channel, self.serde.loads_typed((type, value))) async for task_id, channel, type, value in wcur ], ) async def aput( self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, new_versions: ChannelVersions, ) -> RunnableConfig: """Save a checkpoint to the database asynchronously. This method saves a checkpoint to the SQLite database. The checkpoint is associated with the provided config and its parent config (if any). Args: config (RunnableConfig): The config to associate with the checkpoint. checkpoint (Checkpoint): The checkpoint to save. metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. new_versions (ChannelVersions): New channel versions as of this write. Returns: RunnableConfig: Updated configuration after storing the checkpoint. 
""" await self.setup() thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"]["checkpoint_ns"] type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) serialized_metadata = self.jsonplus_serde.dumps(metadata) async with self.lock, self.conn.execute( "INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)", ( str(config["configurable"]["thread_id"]), checkpoint_ns, checkpoint["id"], config["configurable"].get("checkpoint_id"), type_, serialized_checkpoint, serialized_metadata, ), ): await self.conn.commit() return { "configurable": { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, "checkpoint_id": checkpoint["id"], } } async def aput_writes( self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint asynchronously. This method saves intermediate writes associated with a checkpoint to the database. Args: config (RunnableConfig): Configuration of the related checkpoint. writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. task_id (str): Identifier for the task creating the writes. 
""" query = ( "INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" if all(w[0] in WRITES_IDX_MAP for w in writes) else "INSERT OR IGNORE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" ) await self.setup() async with self.lock, self.conn.cursor() as cur: await cur.executemany( query, [ ( str(config["configurable"]["thread_id"]), str(config["configurable"]["checkpoint_ns"]), str(config["configurable"]["checkpoint_id"]), task_id, WRITES_IDX_MAP.get(channel, idx), channel, *self.serde.dumps_typed(value), ) for idx, (channel, value) in enumerate(writes) ], ) def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: """Generate the next version ID for a channel. This method creates a new version identifier for a channel based on its current version. Args: current (Optional[str]): The current version identifier of the channel. channel (BaseChannel): The channel being versioned. Returns: str: The next version identifier, which is guaranteed to be monotonically increasing. """ if current is None: current_v = 0 elif isinstance(current, int): current_v = current else: current_v = int(current.split(".")[0]) next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/pyproject.toml`: ```toml [tool.poetry] name = "langgraph-checkpoint-sqlite" version = "2.0.1" description = "Library with a SQLite implementation of LangGraph checkpoint saver." 
authors = [] license = "MIT" readme = "README.md" repository = "https://www.github.com/langchain-ai/langgraph" packages = [{ include = "langgraph" }] [tool.poetry.dependencies] python = "^3.9.0" langgraph-checkpoint = "^2.0.2" aiosqlite = "^0.20.0" [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" codespell = "^2.2.0" pytest = "^7.2.1" pytest-asyncio = "^0.21.1" pytest-mock = "^3.11.1" pytest-watcher = "^0.4.1" mypy = "^1.10.0" langgraph-checkpoint = {path = "../checkpoint", develop = true} [tool.pytest.ini_options] # --strict-markers will raise errors on unknown marks. # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks # # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. addopts = "--strict-markers --strict-config --durations=5 -vv" asyncio_mode = "auto" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] lint.select = [ "E", # pycodestyle "F", # Pyflakes "UP", # pyupgrade "B", # flake8-bugbear "I", # isort ] lint.ignore = ["E501", "B008", "UP007", "UP006"] [tool.pytest-watcher] now = true delay = 0.1 runner_args = ["--ff", "-v", "--tb", "short"] patterns = ["*.py"] [tool.mypy] # https://mypy.readthedocs.io/en/stable/config_file.html disallow_untyped_defs = "True" explicit_package_bases = "True" warn_no_return = "False" warn_unused_ignores = "True" warn_redundant_casts = "True" allow_redefinition = "True" disable_error_code = "typeddict-item, return-value" ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/tests/test_aiosqlite.py`: ```py from typing import Any import pytest from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver class TestAsyncSqliteSaver: 
@pytest.fixture(autouse=True) def setup(self) -> None: # objects for test setup self.config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } self.config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } self.config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } self.chkpnt_1: Checkpoint = empty_checkpoint() self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1) self.chkpnt_3: Checkpoint = empty_checkpoint() self.metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } self.metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } self.metadata_3: CheckpointMetadata = {} async def test_asearch(self) -> None: async with AsyncSqliteSaver.from_conn_string(":memory:") as saver: await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {}) await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {}) await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = [c async for c in saver.alist(None, filter=query_1)] assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = [c async for c in saver.alist(None, filter=query_2)] assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = [c async for c in saver.alist(None, filter=query_3)] assert len(search_results_3) == 3 search_results_4 = [c async for c in saver.alist(None, filter=query_4)] 
assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = [ c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}}) ] assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} # TODO: test before and limit params ``` `/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/tests/test_sqlite.py`: ```py from typing import Any, cast import pytest from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.checkpoint.sqlite.utils import _metadata_predicate, search_where class TestSqliteSaver: @pytest.fixture(autouse=True) def setup(self) -> None: # objects for test setup self.config_1: RunnableConfig = { "configurable": { "thread_id": "thread-1", # for backwards compatibility testing "thread_ts": "1", "checkpoint_ns": "", } } self.config_2: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2", "checkpoint_ns": "", } } self.config_3: RunnableConfig = { "configurable": { "thread_id": "thread-2", "checkpoint_id": "2-inner", "checkpoint_ns": "inner", } } self.chkpnt_1: Checkpoint = empty_checkpoint() self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1) self.chkpnt_3: Checkpoint = empty_checkpoint() self.metadata_1: CheckpointMetadata = { "source": "input", "step": 2, "writes": {}, "score": 1, } self.metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, "score": None, } self.metadata_3: CheckpointMetadata = {} def test_search(self) -> None: with SqliteSaver.from_conn_string(":memory:") as saver: # set up test # save checkpoints saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {}) saver.put(self.config_2, 
self.chkpnt_2, self.metadata_2, {}) saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {}) # call method / assertions query_1 = {"source": "input"} # search by 1 key query_2 = { "step": 1, "writes": {"foo": "bar"}, } # search by multiple keys query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match search_results_1 = list(saver.list(None, filter=query_1)) assert len(search_results_1) == 1 assert search_results_1[0].metadata == self.metadata_1 search_results_2 = list(saver.list(None, filter=query_2)) assert len(search_results_2) == 1 assert search_results_2[0].metadata == self.metadata_2 search_results_3 = list(saver.list(None, filter=query_3)) assert len(search_results_3) == 3 search_results_4 = list(saver.list(None, filter=query_4)) assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) search_results_5 = list( saver.list({"configurable": {"thread_id": "thread-2"}}) ) assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], search_results_5[1].config["configurable"]["checkpoint_ns"], } == {"", "inner"} # TODO: test before and limit params def test_search_where(self) -> None: # call method / assertions expected_predicate_1 = "WHERE json_extract(CAST(metadata AS TEXT), '$.source') = ? AND json_extract(CAST(metadata AS TEXT), '$.step') = ? AND json_extract(CAST(metadata AS TEXT), '$.writes') = ? AND json_extract(CAST(metadata AS TEXT), '$.score') = ? AND checkpoint_id < ?" 
expected_param_values_1 = ["input", 2, "{}", 1, "1"] assert search_where( None, cast(dict[str, Any], self.metadata_1), self.config_1 ) == ( expected_predicate_1, expected_param_values_1, ) def test_metadata_predicate(self) -> None: # call method / assertions expected_predicate_1 = [ "json_extract(CAST(metadata AS TEXT), '$.source') = ?", "json_extract(CAST(metadata AS TEXT), '$.step') = ?", "json_extract(CAST(metadata AS TEXT), '$.writes') = ?", "json_extract(CAST(metadata AS TEXT), '$.score') = ?", ] expected_predicate_2 = [ "json_extract(CAST(metadata AS TEXT), '$.source') = ?", "json_extract(CAST(metadata AS TEXT), '$.step') = ?", "json_extract(CAST(metadata AS TEXT), '$.writes') = ?", "json_extract(CAST(metadata AS TEXT), '$.score') IS ?", ] expected_predicate_3: list[str] = [] expected_param_values_1 = ["input", 2, "{}", 1] expected_param_values_2 = ["loop", 1, '{"foo":"bar"}', None] expected_param_values_3: list[Any] = [] assert _metadata_predicate(cast(dict[str, Any], self.metadata_1)) == ( expected_predicate_1, expected_param_values_1, ) assert _metadata_predicate(cast(dict[str, Any], self.metadata_2)) == ( expected_predicate_2, expected_param_values_2, ) assert _metadata_predicate(cast(dict[str, Any], self.metadata_3)) == ( expected_predicate_3, expected_param_values_3, ) async def test_informative_async_errors(self) -> None: with SqliteSaver.from_conn_string(":memory:") as saver: # call method / assertions with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"): await saver.aget(self.config_1) with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"): await saver.aget_tuple(self.config_1) with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"): async for _ in saver.alist(self.config_1): pass ``` `/Users/malcolm/dev/langchain-ai/langgraph/examples/chatbot-simulation-evaluation/simulation_utils.py`: ```py import functools from typing import Annotated, Any, Callable, Dict, List, Optional, Union from 
langchain_community.adapters.openai import convert_message_to_dict from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.runnables import chain as as_runnable from langchain_openai import ChatOpenAI from typing_extensions import TypedDict from langgraph.graph import END, StateGraph, START def langchain_to_openai_messages(messages: List[BaseMessage]): """ Convert a list of langchain base messages to a list of openai messages. Parameters: messages (List[BaseMessage]): A list of langchain base messages. Returns: List[dict]: A list of openai messages. """ return [ convert_message_to_dict(m) if isinstance(m, BaseMessage) else m for m in messages ] def create_simulated_user( system_prompt: str, llm: Runnable | None = None ) -> Runnable[Dict, AIMessage]: """ Creates a simulated user for chatbot simulation. Args: system_prompt (str): The system prompt to be used by the simulated user. llm (Runnable | None, optional): The language model to be used for the simulation. Defaults to gpt-3.5-turbo. Returns: Runnable[Dict, AIMessage]: The simulated user for chatbot simulation. """ return ChatPromptTemplate.from_messages( [ ("system", system_prompt), MessagesPlaceholder(variable_name="messages"), ] ) | (llm or ChatOpenAI(model="gpt-3.5-turbo")).with_config( run_name="simulated_user" ) Messages = Union[list[AnyMessage], AnyMessage] def add_messages(left: Messages, right: Messages) -> Messages: if not isinstance(left, list): left = [left] if not isinstance(right, list): right = [right] return left + right class SimulationState(TypedDict): """ Represents the state of a simulation. Attributes: messages (List[AnyMessage]): A list of messages in the simulation. inputs (Optional[dict[str, Any]]): Optional inputs for the simulation. 
""" messages: Annotated[List[AnyMessage], add_messages] inputs: Optional[dict[str, Any]] def create_chat_simulator( assistant: ( Callable[[List[AnyMessage]], str | AIMessage] | Runnable[List[AnyMessage], str | AIMessage] ), simulated_user: Runnable[Dict, AIMessage], *, input_key: str, max_turns: int = 6, should_continue: Optional[Callable[[SimulationState], str]] = None, ): """Creates a chat simulator for evaluating a chatbot. Args: assistant: The chatbot assistant function or runnable object. simulated_user: The simulated user object. input_key: The key for the input to the chat simulation. max_turns: The maximum number of turns in the chat simulation. Default is 6. should_continue: Optional function to determine if the simulation should continue. If not provided, a default function will be used. Returns: The compiled chat simulation graph. """ graph_builder = StateGraph(SimulationState) graph_builder.add_node( "user", _create_simulated_user_node(simulated_user), ) graph_builder.add_node( "assistant", _fetch_messages | assistant | _coerce_to_message ) graph_builder.add_edge("assistant", "user") graph_builder.add_conditional_edges( "user", should_continue or functools.partial(_should_continue, max_turns=max_turns), ) # If your dataset has a 'leading question/input', then we route first to the assistant, otherwise, we let the user take the lead. 
graph_builder.add_edge(START, "assistant" if input_key is not None else "user") return ( RunnableLambda(_prepare_example).bind(input_key=input_key) | graph_builder.compile() ) ## Private methods def _prepare_example(inputs: dict[str, Any], input_key: Optional[str] = None): if input_key is not None: if input_key not in inputs: raise ValueError( f"Dataset's example input must contain the provided input key: '{input_key}'.\nFound: {list(inputs.keys())}" ) messages = [HumanMessage(content=inputs[input_key])] return { "inputs": {k: v for k, v in inputs.items() if k != input_key}, "messages": messages, } return {"inputs": inputs, "messages": []} def _invoke_simulated_user(state: SimulationState, simulated_user: Runnable): """Invoke the simulated user node.""" runnable = ( simulated_user if isinstance(simulated_user, Runnable) else RunnableLambda(simulated_user) ) inputs = state.get("inputs", {}) inputs["messages"] = state["messages"] return runnable.invoke(inputs) def _swap_roles(state: SimulationState): new_messages = [] for m in state["messages"]: if isinstance(m, AIMessage): new_messages.append(HumanMessage(content=m.content)) else: new_messages.append(AIMessage(content=m.content)) return { "inputs": state.get("inputs", {}), "messages": new_messages, } @as_runnable def _fetch_messages(state: SimulationState): """Invoke the simulated user node.""" return state["messages"] def _convert_to_human_message(message: BaseMessage): return {"messages": [HumanMessage(content=message.content)]} def _create_simulated_user_node(simulated_user: Runnable): """Simulated user accepts a {"messages": [...]} argument and returns a single message.""" return ( _swap_roles | RunnableLambda(_invoke_simulated_user).bind(simulated_user=simulated_user) | _convert_to_human_message ) def _coerce_to_message(assistant_output: str | BaseMessage): if isinstance(assistant_output, str): return {"messages": [AIMessage(content=assistant_output)]} else: return {"messages": [assistant_output]} def 
_should_continue(state: SimulationState, max_turns: int = 6): messages = state["messages"] # TODO support other stop criteria if len(messages) > max_turns: return END elif messages[-1].content.strip() == "FINISHED": return END else: return "assistant" ```

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/bossjones/datetime-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.