Project Path: langgraph
Source Tree:
```
langgraph
├── LICENSE
├── requirements.txt
├── Makefile
├── pyproject.toml
├── docs
│ ├── overrides
│ │ ├── main.html
│ │ └── partials
│ │ ├── logo.html
│ │ └── comments.html
│ ├── mkdocs.yml
│ ├── _scripts
│ │ ├── notebook_convert_templates
│ │ │ └── mdoutput
│ │ │ ├── conf.json
│ │ │ └── index.md.j2
│ │ ├── generate_api_reference_links.py
│ │ ├── prepare_notebooks_for_ci.py
│ │ ├── notebook_convert.py
│ │ ├── notebook_hooks.py
│ │ ├── download_tiktoken.py
│ │ └── execute_notebooks.sh
│ ├── test-compose.yml
│ ├── cassettes
│ │ ├── LLMCompiler_31854dfd-b82f-4c24-9b58-6bae66777909.msgpack.zlib
│ │ ├── introduction_6a385b06-8d34-4a2d-aded-1cf4bb0ca590.msgpack.zlib
│ │ ├── wait-user-input_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── state-model_8edb04b9-40b6-46f1-a7a8-4b2d8aba7752.msgpack.zlib
│ │ ├── create-react-agent-hitl_83148e08-63e8-49e5-a08b-02dc907bed1d.msgpack.zlib
│ │ ├── add-summary-conversation-history_048805a4-3d97-4e76-ac45-8d80d4364c46.msgpack.zlib
│ │ ├── introduction_32e4f36e-72ce-4ade-bd7e-94880e0d456b.msgpack.zlib
│ │ ├── create-react-agent_fa16de4c-aac0-4ff4-ab69-60d399f75423.msgpack.zlib
│ │ ├── customer-support_b7443751.msgpack.zlib
│ │ ├── introduction_761d15fb-d5e2-4d50-a630-126d77e77294.msgpack.zlib
│ │ ├── delete-messages_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib
│ │ ├── persistence_postgres_386b78bc-2f73-49ba-a2a4-47bce6fc49b7.msgpack.zlib
│ │ ├── LLMCompiler_38d3ea91-59ba-4267-8060-ed75bbc840c6.msgpack.zlib
│ │ ├── introduction_35c8978e-c07d-4dd0-a97b-0ce3a723eea5.msgpack.zlib
│ │ ├── disable-streaming_8.msgpack.zlib
│ │ ├── introduction_9f318020-ab7e-415b-a5e2-eddec6d9f3a6.msgpack.zlib
│ │ ├── create-react-agent-memory_187479f9-32fa-4611-9487-cf816ba2e147.msgpack.zlib
│ │ ├── streaming-from-final-node_55d60dfa-96e3-442f-9924-0c99f46baed8.msgpack.zlib
│ │ ├── pass-config-to-tools_14.msgpack.zlib
│ │ ├── introduction_dba1b168-f8e0-496d-9bd6-37198fb4776e.msgpack.zlib
│ │ ├── tool-calling_19.msgpack.zlib
│ │ ├── reflexion_6fd51f17-c0b0-44b6-90e2-55a66cb8f5a7.msgpack.zlib
│ │ ├── persistence_08ae8246-11d5-40e1-8567-361e5bef8917.msgpack.zlib
│ │ ├── time-travel_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── customer-support_783a548d-029a-47b7-9ac0-9c5203ec92c7.msgpack.zlib
│ │ ├── sql-agent_cf02e843-438d-4168-a27a-8f1e0266f8d7.msgpack.zlib
│ │ ├── edit-graph-state_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── reflection_dfbf99a8-3aa0-4e09-936e-8452c35fa84d.msgpack.zlib
│ │ ├── retries_746b409c-693d-49af-8c2b-bea0a4b0028d.msgpack.zlib
│ │ ├── langgraph_code_assistant_2dacccf0-d73f-4017-aaf0-9806ffe5bd2c.msgpack.zlib
│ │ ├── semantic-search_8.msgpack.zlib
│ │ ├── langgraph_self_rag_fb69dbb9-91ee-4868-8c3c-93af3cd885be.msgpack.zlib
│ │ ├── recursion-limit_5.msgpack.zlib
│ │ ├── LLMCompiler_730490c6-6e3a-4173-82a1-9eb9d5eeff20.msgpack.zlib
│ │ ├── add-summary-conversation-history_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib
│ │ ├── sql-agent_3bf7709f-500c-4f28-bb85-dda317286c63.msgpack.zlib
│ │ ├── agent_supervisor_45a92dfd-0e11-47f5-aad4-b68d24990e34.msgpack.zlib
│ │ ├── persistence_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── introduction_8b49509c-9d97-457c-a76a-c495fb30ccbc.msgpack.zlib
│ │ ├── langgraph_agentic_rag_278d1d83-dda6-4de4-bf8b-be9965c227fa.msgpack.zlib
│ │ ├── introduction_9fc99c7e-b61d-4aec-9c62-042798185ec3.msgpack.zlib
│ │ ├── introduction_f4009ba6-dc0b-4216-ab0c-fbb104616f73.msgpack.zlib
│ │ ├── subgraph_13.msgpack.zlib
│ │ ├── cross-thread-persistence_d362350b-d730-48bd-9652-983812fd7811.msgpack.zlib
│ │ ├── customer-support_72fceb01-b0ab-4bef-a22f-a2fce6ee33ef.msgpack.zlib
│ │ ├── branching_1d0e6c56.msgpack.zlib
│ │ ├── LLMCompiler_942dab42-ad42-4ba2-90d5-49edbe4fae68.msgpack.zlib
│ │ ├── tool-calling-errors_19.msgpack.zlib
│ │ ├── run-id-langsmith_6.msgpack.zlib
│ │ ├── introduction_f5447778-53d7-47f3-801b-f47bcf2185a0.msgpack.zlib
│ │ ├── semantic-search_19.msgpack.zlib
│ │ ├── branching_83320227-8ab3-44c0-b6cf-064a7a425b9f.msgpack.zlib
│ │ ├── shared-state_d362350b-d730-48bd-9652-983812fd7811.msgpack.zlib
│ │ ├── introduction_85f17be3-eaf6-495e-a846-49436916b4ab.msgpack.zlib
│ │ ├── introduction_051dc374-67cc-4371-9dd1-221e07593148.msgpack.zlib
│ │ ├── configuration_e043a719-f197-46ef-9d45-84740a39aeb0.msgpack.zlib
│ │ ├── review-tool-calls_d57d5131-7912-4216-aa87-b7272507fa51.msgpack.zlib
│ │ ├── edit-graph-state_85e452f8-f33a-4ead-bb4d-7386cdba8edc.msgpack.zlib
│ │ ├── introduction_effb95d9-b7d5-40c5-9253-253d193b23b2.msgpack.zlib
│ │ ├── retries_f4752239-2aa3-4367-b777-8478c16b9471.msgpack.zlib
│ │ ├── review-tool-calls_3f05f8b6-6128-4de5-8884-862fc93f1227.msgpack.zlib
│ │ ├── langgraph_self_rag_bd62276f-bf26-40d0-8cff-e07b10e00321.msgpack.zlib
│ │ ├── langgraph_code_assistant_3ba3df70-f6b4-4ea5-a210-e10944960bc6.msgpack.zlib
│ │ ├── langgraph_code_assistant_ef7cf662-7a6f-4dee-965c-6309d4045feb.msgpack.zlib
│ │ ├── create-react-agent-hitl_740bbaeb.msgpack.zlib
│ │ ├── disable-streaming_4.msgpack.zlib
│ │ ├── plan-and-execute_746e697a-dec4-4342-a814-9b3456828169.msgpack.zlib
│ │ ├── LLMCompiler_55142257-2674-4a47-988e-0d2810917329.msgpack.zlib
│ │ ├── create-react-agent-memory_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib
│ │ ├── pass-config-to-tools_18.msgpack.zlib
│ │ ├── semantic-search_13.msgpack.zlib
│ │ ├── subgraphs-manage-state_13.msgpack.zlib
│ │ ├── subgraphs-manage-state_46.msgpack.zlib
│ │ ├── LLMCompiler_391d6931.msgpack.zlib
│ │ ├── stream-updates_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib
│ │ ├── streaming-from-final-node_68ac2c7f.msgpack.zlib
│ │ ├── review-tool-calls_df4a9900-d953-4465-b8af-bd2858cb63ea.msgpack.zlib
│ │ ├── breakpoints_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── configuration_070f11a6-2441-4db5-9df6-e318f110e281.msgpack.zlib
│ │ ├── langgraph_agentic_rag_7b5a1d35.msgpack.zlib
│ │ ├── create-react-agent-hitl_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib
│ │ ├── tool-calling_23.msgpack.zlib
│ │ ├── reflexion_5922e1fe-7533-4f41-8b1d-d812707c1968.msgpack.zlib
│ │ ├── tool-calling-errors_13.msgpack.zlib
│ │ ├── langgraph_self_rag_e78931ec-940c-46ad-a0b2-f43f953f1fd7.msgpack.zlib
│ │ ├── configuration_f2f7c74b-9fb0-41c6-9728-dcf9d8a3c397.msgpack.zlib
│ │ ├── LLMCompiler_0b3a0916-d8ca-4092-b91c-d9e2b05259d8.msgpack.zlib
│ │ ├── react-agent-from-scratch_11.msgpack.zlib
│ │ ├── persistence_postgres_180d6daf-8fa7-4608-bd2e-bfbf44ed5836.msgpack.zlib
│ │ ├── subgraphs-manage-state_30.msgpack.zlib
│ │ ├── persistence_redis_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib
│ │ ├── sql-agent_4f200d1813897000.msgpack.zlib
│ │ ├── delete-messages_3975f34c-c243-40ea-b9d2-424d50a48dc9.msgpack.zlib
│ │ ├── semantic-search_15.msgpack.zlib
│ │ ├── review-tool-calls_f9a0d5d4-52ff-49e0-a6f4-41f9a0e844d8.msgpack.zlib
│ │ ├── breakpoints_6098e5cb.msgpack.zlib
│ │ ├── rewoo_badaca52-5d55-433f-8770-1bd50c10bf7f.msgpack.zlib
│ │ ├── shared-state_d862be40-1f8a-4057-81c4-b7bf073dc4c1.msgpack.zlib
│ │ ├── add-summary-conversation-history_40e5db8e-9db9-4ac7-9d76-a99fd4034bf3.msgpack.zlib
│ │ ├── subgraphs-manage-state_40.msgpack.zlib
│ │ ├── subgraphs-manage-state_15.msgpack.zlib
│ │ ├── persistence_redis_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib
│ │ ├── configuration_ef50f048-fc43-40c0-b713-346408fcf052.msgpack.zlib
│ │ ├── plan-and-execute_b8ac1f67-e87a-427c-b4f7-44351295b788.msgpack.zlib
│ │ ├── langgraph_self_rag_1fafad21-60cc-483e-92a3-6a7edb1838e3.msgpack.zlib
│ │ ├── disable-streaming_2.msgpack.zlib
│ │ ├── async_8edb04b9-40b6-46f1-a7a8-4b2d8aba7752.msgpack.zlib
│ │ ├── agent-simulation-evaluation_6f80669e-aa78-4666-b67c-a539366d5aab.msgpack.zlib
│ │ ├── react-agent-structured-output_15.msgpack.zlib
│ │ ├── tool-calling-errors_9.msgpack.zlib
│ │ ├── subgraphs-manage-state_9.msgpack.zlib
│ │ ├── retries_04e93401-50e2-42d0-8373-326006badebb.msgpack.zlib
│ │ ├── state-model_e09aaa63.msgpack.zlib
│ │ ├── persistence_273d56a8-f40f-4a51-a27f-7c6bb2bda0ba.msgpack.zlib
│ │ ├── subgraph_9.msgpack.zlib
│ │ ├── langgraph_code_assistant_9bcaafe4-ddcf-4fab-8620-2d9b6c508f98.msgpack.zlib
│ │ ├── tool-calling_25.msgpack.zlib
│ │ ├── langgraph_code_assistant_9f14750f-dddc-485b-ba29-5392cdf4ba43.msgpack.zlib
│ │ ├── cross-thread-persistence_d862be40-1f8a-4057-81c4-b7bf073dc4c1.msgpack.zlib
│ │ ├── subgraphs-manage-state_42.msgpack.zlib
│ │ ├── persistence_postgres_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib
│ │ ├── sql-agent_64b0bf1b14c2e902.msgpack.zlib
│ │ ├── stream-values_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib
│ │ ├── introduction_faa345c6-38a2-42e8-9035-9cf56f7bb5b1.msgpack.zlib
│ │ ├── sql-agent_eb814b85-70ba-4699-9038-20266b53efbd.msgpack.zlib
│ │ ├── add-summary-conversation-history_7094c5ab-66f8-42ff-b1c3-90c8a9468e62.msgpack.zlib
│ │ ├── pass-run-time-values-to-tools_2e3fd1e2-cc19-4023-8ffa-b0fc13da9e09.msgpack.zlib
│ │ ├── introduction_acbec099-e5d2-497f-929e-c548d7bcbf77.msgpack.zlib
│ │ ├── customer-support_96469e95-5070-4169-bedd-45db94b43d97.msgpack.zlib
│ │ ├── semantic-search_17.msgpack.zlib
│ │ ├── semantic-search_10.msgpack.zlib
│ │ ├── streaming-from-final-node_2ab6d079-ba06-48ba-abe5-e72df24407af.msgpack.zlib
│ │ ├── langgraph_self_rag_dcd77cc1-4587-40ec-b633-5364eab9e1ec.msgpack.zlib
│ │ ├── breakpoints_51923913-20f7-4ee1-b9ba-d01f5fb2869b.msgpack.zlib
│ │ ├── wait-user-input_a9f599b5-1a55-406b-a76b-f52b3ca06975.msgpack.zlib
│ │ ├── self-discover_6cbfbe81-f751-42da-843a-f9003ace663d.msgpack.zlib
│ │ ├── pass-run-time-values-to-tools_d683986f-faf4-4724-a13f-bac39ea9bafe.msgpack.zlib
│ │ ├── agent_supervisor_56ba78e9-d9c1-457c-a073-d606d5d3e013.msgpack.zlib
│ │ ├── map-reduce_fd90cace.msgpack.zlib
│ │ ├── reflexion_2634a3ea-7423-4579-9f4e-390e439c3209.msgpack.zlib
│ │ ├── stream-values_c122bf15-a489-47bf-b482-a744a54e2cc4.msgpack.zlib
│ │ ├── reflection_06263a07-8a15-4ec3-b692-1c6cef3b1c1f.msgpack.zlib
│ │ ├── persistence_postgres_bd235fc7-1e5c-4db6-a90b-ea75462ccf7d.msgpack.zlib
│ │ ├── async_4b369a6f.msgpack.zlib
│ │ ├── subgraphs-manage-state_33.msgpack.zlib
│ │ ├── streaming-subgraphs_7.msgpack.zlib
│ │ ├── semantic-search_6.msgpack.zlib
│ │ ├── persistence_mongodb_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib
│ │ ├── sql-agent_293017e8f05ac2b3.msgpack.zlib
│ │ ├── tool-calling-errors_17.msgpack.zlib
│ │ ├── langgraph_code_assistant_c2eb35d1-4990-47dc-a5c4-208bae588a82.msgpack.zlib
│ │ ├── introduction_03a09bfc-3d90-4e54-878f-22e3cb28a418.msgpack.zlib
│ │ ├── introduction_6b32914d-4d60-491f-8e11-1e6867e38ffd.msgpack.zlib
│ │ ├── streaming-tokens_96050fba.msgpack.zlib
│ │ ├── reflexion_7541f82c.msgpack.zlib
│ │ ├── hierarchical_agent_teams_6b8badbf-d728-44bd-a2a7-5b4e587c92fe.msgpack.zlib
│ │ ├── add-summary-conversation-history_0a1a0fda-5309-45f0-9465-9f3dff604d74.msgpack.zlib
│ │ ├── subgraphs-manage-state_11.msgpack.zlib
│ │ ├── reflection_16c5eb2a-8bce-48ab-b87d-9dacb9b64ac6.msgpack.zlib
│ │ ├── persistence_mongodb_6a39d1ff-ca37-4457-8b52-07d33b59c36e.msgpack.zlib
│ │ ├── create-react-agent-system-prompt_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib
│ │ ├── dynamic_breakpoints_9a14c8b2-5c25-4201-93ea-e5358ee99bcb.msgpack.zlib
│ │ ├── branching_66f52a20.msgpack.zlib
│ │ ├── tool-calling_17.msgpack.zlib
│ │ ├── LLMCompiler_15dd9639-691f-4906-9012-83fd6e9ac126.msgpack.zlib
│ │ ├── LLMCompiler_152eecf3-6bef-4718-af71-a0b3c5a3b009.msgpack.zlib
│ │ ├── introduction_69071b02-c011-4b7f-90b1-8e89e032322d.msgpack.zlib
│ │ ├── tool-calling-errors_11.msgpack.zlib
│ │ ├── return-when-recursion-limit-hits_6.msgpack.zlib
│ │ ├── sql-agent_85958809-03c5-4e52-97cc-e7c0ae986f60.msgpack.zlib
│ │ ├── persistence_postgres_5fe54e79-9eaf-44e2-b2d9-1e0284b984d0.msgpack.zlib
│ │ ├── tool-calling_26.msgpack.zlib
│ │ ├── streaming-events-from-within-tools-without-langchain_45c96a79-4147-42e3-89fd-d942b2b49f6c.msgpack.zlib
│ │ ├── react-agent-from-scratch_13.msgpack.zlib
│ │ ├── create-react-agent_187479f9-32fa-4611-9487-cf816ba2e147.msgpack.zlib
│ │ ├── langgraph_code_assistant_71d90f9e-9dad-410c-a709-093d275029ae.msgpack.zlib
│ │ ├── map-reduce_37ed1f71-63db-416f-b715-4617b33d4b7f.msgpack.zlib
│ │ ├── hierarchical_agent_teams_912b0604-a178-4246-a36f-2dedae606680.msgpack.zlib
│ │ ├── introduction_b3220ae2-cba0-4447-96d1-eb0be4684e59.msgpack.zlib
│ │ ├── stream-multiple_e9e9ffb0-2cd5-466f-b70b-b6ed51b852d1.msgpack.zlib
│ │ ├── wait-user-input_f5319e01.msgpack.zlib
│ │ ├── persistence_6fa9a5e3-7101-43ab-a811-592e222b9580.msgpack.zlib
│ │ ├── review-tool-calls_ec77831c-e6b8-4903-9146-e098a4b2fda1.msgpack.zlib
│ │ ├── review-tool-calls_85e452f8-f33a-4ead-bb4d-7386cdba8edc.msgpack.zlib
│ │ ├── review-tool-calls_1b3aa6fc-c7fb-4819-8d7f-ba6057cc4edf.msgpack.zlib
│ │ ├── streaming-events-from-within-tools_ec461f66.msgpack.zlib
│ │ ├── cross-thread-persistence_c871a073-a466-46ad-aafe-2b870831057e.msgpack.zlib
│ │ ├── plan-and-execute_72d233ca-1dbf-4b43-b680-b3bf39e3691f.msgpack.zlib
│ │ ├── introduction_11d5b934-6d8b-4f52-a3bc-b3daa7207e00.msgpack.zlib
│ │ ├── self-discover_a18d8f24-5d9a-45c5-9739-6f3c4ed6c9c9.msgpack.zlib
│ │ ├── plan-and-execute_67ce37b7-e089-479b-bcb8-c3f5d9874613.msgpack.zlib
│ │ ├── langgraph_self_rag_c6f4c70e-1660-4149-82c0-837f19fc9fb5.msgpack.zlib
│ │ ├── agent-simulation-evaluation_32848c2e-be82-46f3-81db-b23fea45461c.msgpack.zlib
│ │ ├── manage-conversation-history_57b27553-21be-43e5-ac48-d1d0a3aa0dca.msgpack.zlib
│ │ ├── time-travel_e986f94f-706f-4b6f-b3c4-f95483b9e9b8.msgpack.zlib
│ │ ├── shared-state_c871a073-a466-46ad-aafe-2b870831057e.msgpack.zlib
│ │ ├── pass-config-to-tools_17.msgpack.zlib
│ │ ├── introduction_c1955d79-a1e4-47d0-ba79-b45bd5752a23.msgpack.zlib
│ │ ├── breakpoints_9b53f191-1e86-4881-a667-d46a3d66958b.msgpack.zlib
│ │ ├── langgraph_self_rag_4138bc51-8c84-4b8a-8d24-f7f470721f6f.msgpack.zlib
│ │ ├── pass-run-time-values-to-tools_c0858273-10f2-45c4-a922-b11321ac3fae.msgpack.zlib
│ │ ├── streaming-tokens-without-langchain_d6ed3df5.msgpack.zlib
│ │ ├── langgraph_agentic_rag_7649f05a-cb67-490d-b24a-74d41895139a.msgpack.zlib
│ │ ├── review-tool-calls_a30d40ad-611d-4ec3-84be-869ea05acb89.msgpack.zlib
│ │ ├── information-gather-prompting_1b1613e0.msgpack.zlib
│ │ ├── retries_5d072c9c-9404-4338-88c6-b3e136969aca.msgpack.zlib
│ │ ├── streaming-tokens_72785b66.msgpack.zlib
│ │ ├── async_f544977e-31f7-41f0-88c4-ec9c27b8cecb.msgpack.zlib
│ │ ├── branching_932c497e.msgpack.zlib
│ │ ├── manage-conversation-history_52468ebb-4b23-45ac-a98e-b4439f37740a.msgpack.zlib
│ │ ├── plan-and-execute_7363e528.msgpack.zlib
│ │ ├── configuration_718685f7-4cdd-4181-9fc8-e7762d584727.msgpack.zlib
│ │ ├── time-travel_9a92d3da-62e2-45a2-8545-e4f6a64e0ffe.msgpack.zlib
│ │ ├── LLMCompiler_5bc4584a-e31c-4065-805e-76a6db30676a.msgpack.zlib
│ │ ├── create-react-agent_9ffff6c3-a4f5-47c9-b51d-97caaee85cd6.msgpack.zlib
│ │ ├── wait-user-input_58eae42d-be32-48da-8d0a-ab64471657d9.msgpack.zlib
│ │ ├── sql-agent_1040233f-3751-4bd3-902f-709fc2e1ecf5.msgpack.zlib
│ │ ├── persistence_postgres_4faf6087-73cc-4957-9a4f-f3509a32a740.msgpack.zlib
│ │ ├── review-tool-calls_2561a38f-edb5-4b44-b2d7-6a7b70d2e6b7.msgpack.zlib
│ │ ├── hierarchical_agent_teams_9860fd46-c24d-40a5-a6ba-e8fddcd43369.msgpack.zlib
│ │ ├── agent-simulation-evaluation_f58959bf-2ab5-4330-9ac2-c00f45237e24.msgpack.zlib
│ │ ├── react-agent-structured-output_9.msgpack.zlib
│ │ ├── multi-agent-collaboration_9f478b05-3f09-447f-a9f4-1b2eae73f5ef.msgpack.zlib
│ │ ├── introduction_a7debb4a-2a3a-40b9-a48c-7052ec2c2726.msgpack.zlib
│ │ ├── edit-graph-state_51923913-20f7-4ee1-b9ba-d01f5fb2869b.msgpack.zlib
│ │ ├── information-gather-prompting_25793988-45a2-4e65-b33c-64e72aadb10e.msgpack.zlib
│ │ ├── rewoo_56ecb45b-ea76-4303-a4f3-51406fe8312a.msgpack.zlib
│ │ ├── async_cfd140f0-a5a6-4697-8115-322242f197b5.msgpack.zlib
│ │ ├── subgraphs-manage-state_48.msgpack.zlib
│ │ ├── reflection_9bbe25dc-fd1e-4ed5-a3c8-fed830b46d12.msgpack.zlib
│ │ ├── introduction_4527cf9a-b191-4bde-858a-e33a74a48c55.msgpack.zlib
│ │ └── pass-config-to-tools_16.msgpack.zlib
│ ├── docs
│ │ ├── troubleshooting
│ │ │ └── errors
│ │ │ ├── INVALID_CHAT_HISTORY.md
│ │ │ ├── INVALID_GRAPH_NODE_RETURN_VALUE.md
│ │ │ ├── INVALID_CONCURRENT_GRAPH_UPDATE.md
│ │ │ ├── index.md
│ │ │ ├── MULTIPLE_SUBGRAPHS.md
│ │ │ └── GRAPH_RECURSION_LIMIT.md
│ │ ├── index.md
│ │ ├── static
│ │ │ ├── wordmark_dark.svg
│ │ │ ├── values_vs_updates.png
│ │ │ ├── wordmark_light.svg
│ │ │ └── favicon.png
│ │ ├── cloud
│ │ │ ├── sdk
│ │ │ │ └── img
│ │ │ │ ├── thread_diagram.png
│ │ │ │ └── graph_diagram.png
│ │ │ ├── deployment
│ │ │ │ ├── graph_rebuild.md
│ │ │ │ ├── setup_pyproject.md
│ │ │ │ ├── custom_docker.md
│ │ │ │ ├── setup.md
│ │ │ │ ├── semantic_search.md
│ │ │ │ ├── setup_javascript.md
│ │ │ │ ├── img
│ │ │ │ │ ├── graph_run.png
│ │ │ │ │ ├── deployed_page.png
│ │ │ │ │ ├── quick_start_studio.png
│ │ │ │ │ ├── cloud_deployment.png
│ │ │ │ │ └── deployment_page.png
│ │ │ │ ├── cloud.md
│ │ │ │ └── test_locally.md
│ │ │ ├── how-tos
│ │ │ │ ├── stream_values.md
│ │ │ │ ├── same-thread.md
│ │ │ │ ├── configuration_cloud.md
│ │ │ │ ├── background_run.md
│ │ │ │ ├── stream_debug.md
│ │ │ │ ├── human_in_the_loop_edit_state.md
│ │ │ │ ├── cron_jobs.md
│ │ │ │ ├── human_in_the_loop_breakpoint.md
│ │ │ │ ├── enqueue_concurrent.md
│ │ │ │ ├── threads_studio.md
│ │ │ │ ├── check_thread_status.md
│ │ │ │ ├── test_deployment.md
│ │ │ │ ├── human_in_the_loop_time_travel.md
│ │ │ │ ├── test_local_deployment.md
│ │ │ │ ├── stream_messages.md
│ │ │ │ ├── assistant_versioning.md
│ │ │ │ ├── webhooks.md
│ │ │ │ ├── invoke_studio.md
│ │ │ │ ├── stream_updates.md
│ │ │ │ ├── stream_events.md
│ │ │ │ ├── img
│ │ │ │ │ ├── studio_threads.mp4
│ │ │ │ │ ├── select_different_version.png
│ │ │ │ │ ├── studio_input_poster.png
│ │ │ │ │ ├── edit_created_assistant.png
│ │ │ │ │ ├── click_create_assistant.png
│ │ │ │ │ ├── studio_usage_poster.png
│ │ │ │ │ ├── create_assistant_view.png
│ │ │ │ │ ├── studio_input.mp4
│ │ │ │ │ ├── see_new_version.png
│ │ │ │ │ ├── studio_forks_poster.png
│ │ │ │ │ ├── create_new_version.png
│ │ │ │ │ ├── studio_threads_poster.png
│ │ │ │ │ ├── see_version_history.png
│ │ │ │ │ ├── create_assistant.png
│ │ │ │ │ ├── studio_screenshot.png
│ │ │ │ │ ├── studio_forks.mp4
│ │ │ │ │ └── studio_usage.mp4
│ │ │ │ ├── human_in_the_loop_review_tool_calls.md
│ │ │ │ ├── reject_concurrent.md
│ │ │ │ ├── human_in_the_loop_user_input.md
│ │ │ │ ├── langgraph_to_langgraph_cloud.ipynb
│ │ │ │ ├── stateless_runs.md
│ │ │ │ ├── interrupt_concurrent.md
│ │ │ │ ├── stream_multiple.md
│ │ │ │ ├── rollback_concurrent.md
│ │ │ │ └── copy_threads.md
│ │ │ ├── reference
│ │ │ │ ├── cli.md
│ │ │ │ ├── env_var.md
│ │ │ │ ├── sdk
│ │ │ │ │ └── python_sdk_ref.md
│ │ │ │ └── api
│ │ │ │ ├── api_ref.html
│ │ │ │ ├── api_ref.md
│ │ │ │ └── openapi.json
│ │ │ └── quick_start.md
│ │ ├── concepts
│ │ │ ├── langgraph_cli.md
│ │ │ ├── plans.md
│ │ │ ├── template_applications.md
│ │ │ ├── langgraph_platform.md
│ │ │ ├── deployment_options.md
│ │ │ ├── low_level.md
│ │ │ ├── faq.md
│ │ │ ├── streaming.md
│ │ │ ├── img
│ │ │ │ ├── multi_agent
│ │ │ │ │ ├── response.png
│ │ │ │ │ ├── architectures.png
│ │ │ │ │ └── request.png
│ │ │ │ ├── assistants.png
│ │ │ │ ├── lg_studio.png
│ │ │ │ ├── byoc_architecture.png
│ │ │ │ ├── human_in_the_loop
│ │ │ │ │ ├── replay.png
│ │ │ │ │ ├── approval.png
│ │ │ │ │ ├── wait_for_input.png
│ │ │ │ │ ├── forking.png
│ │ │ │ │ └── edit_graph_state.png
│ │ │ │ ├── memory
│ │ │ │ │ ├── update-profile.png
│ │ │ │ │ ├── hot_path_vs_background.png
│ │ │ │ │ ├── short-vs-long.png
│ │ │ │ │ ├── filter.png
│ │ │ │ │ ├── summary.png
│ │ │ │ │ ├── update-instructions.png
│ │ │ │ │ └── update-list.png
│ │ │ │ ├── double_texting.png
│ │ │ │ ├── agent_types.png
│ │ │ │ ├── langgraph_cloud_architecture.png
│ │ │ │ ├── persistence
│ │ │ │ │ ├── re_play.jpg
│ │ │ │ │ ├── checkpoints.jpg
│ │ │ │ │ ├── get_state.jpg
│ │ │ │ │ ├── checkpoints_full_story.jpg
│ │ │ │ │ └── shared_state.png
│ │ │ │ ├── langgraph.png
│ │ │ │ ├── tool_call.png
│ │ │ │ ├── challenge.png
│ │ │ │ └── lg_platform.png
│ │ │ ├── sdk.md
│ │ │ ├── high_level.md
│ │ │ ├── bring_your_own_cloud.md
│ │ │ ├── langgraph_cloud.md
│ │ │ ├── persistence.md
│ │ │ ├── double_texting.md
│ │ │ ├── memory.md
│ │ │ ├── assistants.md
│ │ │ ├── index.md
│ │ │ ├── langgraph_studio.md
│ │ │ ├── langgraph_server.md
│ │ │ ├── human_in_the_loop.md
│ │ │ ├── application_structure.md
│ │ │ ├── self_hosted.md
│ │ │ ├── agentic_concepts.md
│ │ │ └── multi_agent.md
│ │ ├── tutorials
│ │ │ ├── multi_agent
│ │ │ │ ├── agent_supervisor_docs.md
│ │ │ │ ├── agent_supervisor.py
│ │ │ │ ├── hierarchical_agent_teams.ipynb
│ │ │ │ ├── hierarchical_agent_teams_docs.md
│ │ │ │ ├── agent_supervisor.ipynb
│ │ │ │ ├── multi-agent-collaboration_docs.md
│ │ │ │ ├── multi-agent-collaboration.py
│ │ │ │ ├── multi-agent-collaboration.ipynb
│ │ │ │ └── hierarchical_agent_teams.py
│ │ │ ├── introduction.ipynb
│ │ │ ├── customer-support
│ │ │ │ ├── img
│ │ │ │ │ ├── customer-support-bot-4.png
│ │ │ │ │ ├── part-3-diagram.png
│ │ │ │ │ ├── part-4-diagram.png
│ │ │ │ │ ├── part-1-diagram.png
│ │ │ │ │ └── part-2-diagram.png
│ │ │ │ └── customer-support.ipynb
│ │ │ ├── plan-and-execute
│ │ │ │ ├── plan-and-execute.py
│ │ │ │ ├── plan-and-execute_docs.md
│ │ │ │ └── plan-and-execute.ipynb
│ │ │ ├── tnt-llm
│ │ │ │ ├── img
│ │ │ │ │ └── tnt_llm.png
│ │ │ │ └── tnt-llm.ipynb
│ │ │ ├── reflexion
│ │ │ │ ├── reflexion.ipynb
│ │ │ │ ├── reflexion_docs.md
│ │ │ │ └── reflexion.py
│ │ │ ├── code_assistant
│ │ │ │ ├── langgraph_code_assistant_docs.md
│ │ │ │ ├── langgraph_code_assistant.ipynb
│ │ │ │ └── langgraph_code_assistant.py
│ │ │ ├── lats
│ │ │ │ └── lats.ipynb
│ │ │ ├── rewoo
│ │ │ │ └── rewoo.ipynb
│ │ │ ├── web-navigation
│ │ │ │ ├── img
│ │ │ │ │ └── web-voyager.excalidraw.jpg
│ │ │ │ ├── web_voyager.ipynb
│ │ │ │ └── mark_page.js
│ │ │ ├── rag
│ │ │ │ ├── langgraph_self_rag_local.ipynb
│ │ │ │ ├── langgraph_adaptive_rag_local.ipynb
│ │ │ │ ├── langgraph_self_rag.ipynb
│ │ │ │ ├── langgraph_crag_local.ipynb
│ │ │ │ ├── langgraph_crag.ipynb
│ │ │ │ ├── langgraph_adaptive_rag.ipynb
│ │ │ │ └── langgraph_agentic_rag.ipynb
│ │ │ ├── llm-compiler
│ │ │ │ ├── LLMCompiler.ipynb
│ │ │ │ ├── output_parser.py
│ │ │ │ └── math_tools.py
│ │ │ ├── tot
│ │ │ │ ├── img
│ │ │ │ │ └── tot.png
│ │ │ │ └── tot.ipynb
│ │ │ ├── sql-agent.ipynb
│ │ │ ├── usaco
│ │ │ │ ├── img
│ │ │ │ │ ├── diagram-part-2.png
│ │ │ │ │ ├── diagram-part-1.png
│ │ │ │ │ ├── benchmark.png
│ │ │ │ │ ├── diagram.png
│ │ │ │ │ └── usaco.png
│ │ │ │ └── usaco.ipynb
│ │ │ ├── index.md
│ │ │ ├── chatbots
│ │ │ │ └── information-gather-prompting.ipynb
│ │ │ ├── chatbot-simulation-evaluation
│ │ │ │ ├── agent-simulation-evaluation.ipynb
│ │ │ │ └── langsmith-agent-simulation-evaluation.ipynb
│ │ │ ├── langgraph-platform
│ │ │ │ └── local-server.md
│ │ │ ├── extraction
│ │ │ │ └── retries.ipynb
│ │ │ ├── storm
│ │ │ │ └── storm.ipynb
│ │ │ ├── self-discover
│ │ │ │ └── self-discover.ipynb
│ │ │ └── reflection
│ │ │ ├── reflection.py
│ │ │ ├── reflection.ipynb
│ │ │ └── reflection_docs.md
│ │ ├── how-tos
│ │ │ ├── run-id-langsmith.ipynb
│ │ │ ├── react-agent-from-scratch.ipynb
│ │ │ ├── streaming-from-final-node.ipynb
│ │ │ ├── recursion-limit.ipynb
│ │ │ ├── human_in_the_loop
│ │ │ │ ├── breakpoints.ipynb
│ │ │ │ ├── time-travel.ipynb
│ │ │ │ ├── edit-graph-state.ipynb
│ │ │ │ ├── dynamic_breakpoints.ipynb
│ │ │ │ ├── review-tool-calls.ipynb
│ │ │ │ └── wait-user-input.ipynb
│ │ │ ├── use-remote-graph.md
│ │ │ ├── memory
│ │ │ │ ├── semantic-search.ipynb
│ │ │ │ ├── add-summary-conversation-history.ipynb
│ │ │ │ ├── manage-conversation-history.ipynb
│ │ │ │ └── delete-messages.ipynb
│ │ │ ├── streaming-tokens.ipynb
│ │ │ ├── local-studio.md
│ │ │ ├── create-react-agent.ipynb
│ │ │ ├── configuration.ipynb
│ │ │ ├── visualization.ipynb
│ │ │ ├── branching.ipynb
│ │ │ ├── create-react-agent-memory.ipynb
│ │ │ ├── many-tools.ipynb
│ │ │ ├── streaming-events-from-within-tools.ipynb
│ │ │ ├── async.ipynb
│ │ │ ├── streaming-content.ipynb
│ │ │ ├── subgraphs-manage-state.ipynb
│ │ │ ├── persistence.ipynb
│ │ │ ├── node-retries.ipynb
│ │ │ ├── persistence_mongodb.ipynb
│ │ │ ├── create-react-agent-hitl.ipynb
│ │ │ ├── state-model.ipynb
│ │ │ ├── persistence_postgres.ipynb
│ │ │ ├── deploy-self-hosted.md
│ │ │ ├── stream-multiple.ipynb
│ │ │ ├── index.md
│ │ │ ├── pass-run-time-values-to-tools.ipynb
│ │ │ ├── stream-updates.ipynb
│ │ │ ├── cross-thread-persistence.ipynb
│ │ │ ├── return-when-recursion-limit-hits.ipynb
│ │ │ ├── tool-calling-errors.ipynb
│ │ │ ├── create-react-agent-system-prompt.ipynb
│ │ │ ├── streaming-subgraphs.ipynb
│ │ │ ├── command.ipynb
│ │ │ ├── tool-calling.ipynb
│ │ │ ├── persistence_redis.ipynb
│ │ │ ├── pass-config-to-tools.ipynb
│ │ │ ├── stream-values.ipynb
│ │ │ ├── react-agent-structured-output.ipynb
│ │ │ ├── autogen-langgraph-platform.ipynb
│ │ │ ├── subgraph.ipynb
│ │ │ ├── autogen-integration.ipynb
│ │ │ ├── react_diagrams.png
│ │ │ ├── streaming-tokens-without-langchain.ipynb
│ │ │ ├── map-reduce.ipynb
│ │ │ ├── subgraph-persistence.ipynb
│ │ │ ├── disable-streaming.ipynb
│ │ │ ├── pass_private_state.ipynb
│ │ │ ├── subgraph-transform-state.ipynb
│ │ │ ├── input_output_schema.ipynb
│ │ │ └── streaming-events-from-within-tools-without-langchain.ipynb
│ │ └── reference
│ │ ├── remote_graph.md
│ │ ├── errors.md
│ │ ├── channels.md
│ │ ├── graphs.md
│ │ ├── prebuilt.md
│ │ ├── index.md
│ │ ├── store.md
│ │ ├── checkpoints.md
│ │ ├── types.md
│ │ └── constants.md
│ ├── README.md
│ └── codespell_notebooks.sh
├── libs
│ ├── checkpoint-postgres
│ │ ├── langgraph
│ │ │ ├── checkpoint
│ │ │ │ └── postgres
│ │ │ │ ├── _internal.py
│ │ │ │ ├── _ainternal.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aio.py
│ │ │ │ ├── py.typed
│ │ │ │ └── base.py
│ │ │ └── store
│ │ │ └── postgres
│ │ │ ├── __init__.py
│ │ │ ├── aio.py
│ │ │ ├── py.typed
│ │ │ └── base.py
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── tests
│ │ │ ├── compose-postgres.yml
│ │ │ ├── conftest.py
│ │ │ ├── test_sync.py
│ │ │ ├── test_async.py
│ │ │ ├── __init__.py
│ │ │ ├── test_store.py
│ │ │ ├── embed_test_utils.py
│ │ │ └── test_async_store.py
│ │ ├── README.md
│ │ └── poetry.lock
│ ├── langgraph
│ │ ├── langgraph
│ │ │ ├── pregel
│ │ │ │ ├── loop_docs.md
│ │ │ │ ├── runner.py
│ │ │ │ ├── remote_docs.md
│ │ │ │ ├── algo_docs.md
│ │ │ │ ├── write.py
│ │ │ │ ├── call_docs.md
│ │ │ │ ├── log.py
│ │ │ │ ├── remote.py
│ │ │ │ ├── runner_docs.md
│ │ │ │ ├── protocol.py
│ │ │ │ ├── io.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── io_docs.md
│ │ │ │ ├── types.py
│ │ │ │ ├── debug_docs.md
│ │ │ │ ├── retry.py
│ │ │ │ ├── write_docs.md
│ │ │ │ ├── protocol_docs.md
│ │ │ │ ├── executor_docs.md
│ │ │ │ ├── validate.py
│ │ │ │ ├── utils_docs.md
│ │ │ │ ├── read_docs.md
│ │ │ │ ├── utils.py
│ │ │ │ ├── debug.py
│ │ │ │ ├── manager_docs.md
│ │ │ │ ├── messages_docs.md
│ │ │ │ ├── messages.py
│ │ │ │ ├── loop.py
│ │ │ │ ├── retry_docs.md
│ │ │ │ ├── log_docs.md
│ │ │ │ ├── types_docs.md
│ │ │ │ ├── validate_docs.md
│ │ │ │ ├── algo.py
│ │ │ │ ├── __init___docs.md
│ │ │ │ ├── manager.py
│ │ │ │ ├── executor.py
│ │ │ │ └── read.py
│ │ │ ├── managed
│ │ │ │ ├── context_docs.md
│ │ │ │ ├── is_last_step_docs.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── shared_value.py
│ │ │ │ ├── context.py
│ │ │ │ ├── base_docs.md
│ │ │ │ ├── is_last_step.py
│ │ │ │ ├── shared_value_docs.md
│ │ │ │ ├── __init___docs.md
│ │ │ │ └── base.py
│ │ │ ├── version_docs.md
│ │ │ ├── version.py
│ │ │ ├── constants_docs.md
│ │ │ ├── constants.py
│ │ │ ├── graph
│ │ │ │ ├── state_docs.md
│ │ │ │ ├── graph_docs.md
│ │ │ │ ├── graph.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── message.py
│ │ │ │ ├── __init___docs.md
│ │ │ │ ├── message_docs.md
│ │ │ │ └── state.py
│ │ │ ├── _api
│ │ │ │ ├── deprecation.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── deprecation_docs.md
│ │ │ │ └── __init___docs.md
│ │ │ ├── utils
│ │ │ │ ├── queue.py
│ │ │ │ ├── fields_docs.md
│ │ │ │ ├── config.py
│ │ │ │ ├── fields.py
│ │ │ │ ├── config_docs.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── pydantic.py
│ │ │ │ ├── future_docs.md
│ │ │ │ ├── runnable_docs.md
│ │ │ │ ├── pydantic_docs.md
│ │ │ │ ├── runnable.py
│ │ │ │ ├── __init___docs.md
│ │ │ │ └── queue_docs.md
│ │ │ ├── types.py
│ │ │ ├── prebuilt
│ │ │ │ ├── chat_agent_executor_docs.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── tool_validator.py
│ │ │ │ ├── tool_executor_docs.md
│ │ │ │ ├── chat_agent_executor.py
│ │ │ │ ├── tool_validator_docs.md
│ │ │ │ ├── tool_node_docs.md
│ │ │ │ ├── tool_node.py
│ │ │ │ ├── __init___docs.md
│ │ │ │ └── tool_executor.py
│ │ │ ├── errors_docs.md
│ │ │ ├── py.typed
│ │ │ ├── types_docs.md
│ │ │ ├── errors.py
│ │ │ ├── channels
│ │ │ │ ├── context_docs.md
│ │ │ │ ├── untracked_value_docs.md
│ │ │ │ ├── dynamic_barrier_value_docs.md
│ │ │ │ ├── last_value.py
│ │ │ │ ├── binop_docs.md
│ │ │ │ ├── named_barrier_value_docs.md
│ │ │ │ ├── topic_docs.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── untracked_value.py
│ │ │ │ ├── last_value_docs.md
│ │ │ │ ├── any_value.py
│ │ │ │ ├── named_barrier_value.py
│ │ │ │ ├── binop.py
│ │ │ │ ├── ephemeral_value.py
│ │ │ │ ├── dynamic_barrier_value.py
│ │ │ │ ├── context.py
│ │ │ │ ├── topic.py
│ │ │ │ ├── base_docs.md
│ │ │ │ ├── ephemeral_value_docs.md
│ │ │ │ ├── __init___docs.md
│ │ │ │ ├── any_value_docs.md
│ │ │ │ └── base.py
│ │ │ └── func
│ │ │ └── __init___docs.md
│ │ ├── bench
│ │ │ ├── fanout_to_subgraph.py
│ │ │ ├── react_agent.py
│ │ │ ├── __init__.py
│ │ │ ├── wide_state.py
│ │ │ └── __main__.py
│ │ ├── LICENSE
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── tests
│ │ │ ├── fake_tracer_docs.md
│ │ │ ├── compose-postgres.yml
│ │ │ ├── fake_chat_docs.md
│ │ │ ├── any_int_docs.md
│ │ │ ├── test_pregel_docs.md
│ │ │ ├── test_utils.py
│ │ │ ├── conftest.py
│ │ │ ├── test_interruption_docs.md
│ │ │ ├── test_remote_graph.py
│ │ │ ├── fake_chat.py
│ │ │ ├── test_channels_docs.md
│ │ │ ├── test_large_cases_docs.md
│ │ │ ├── conftest_docs.md
│ │ │ ├── any_int.py
│ │ │ ├── test_remote_graph_docs.md
│ │ │ ├── test_interruption.py
│ │ │ ├── test_utils_docs.md
│ │ │ ├── test_io_docs.md
│ │ │ ├── memory_assert.py
│ │ │ ├── test_prebuilt_docs.md
│ │ │ ├── __init__.py
│ │ │ ├── test_pregel_async.py
│ │ │ ├── test_prebuilt.py
│ │ │ ├── test_io.py
│ │ │ ├── test_state.py
│ │ │ ├── test_messages_state.py
│ │ │ ├── test_channels.py
│ │ │ ├── test_runnable.py
│ │ │ ├── test_algo_docs.md
│ │ │ ├── agents_docs.md
│ │ │ ├── test_pregel_async_docs.md
│ │ │ ├── test_tracing_interops.py
│ │ │ ├── memory_assert_docs.md
│ │ │ ├── messages_docs.md
│ │ │ ├── test_messages_state_docs.md
│ │ │ ├── messages.py
│ │ │ ├── test_pregel.py
│ │ │ ├── any_str_docs.md
│ │ │ ├── __init___docs.md
│ │ │ ├── any_str.py
│ │ │ ├── test_algo.py
│ │ │ ├── test_tracing_interops_docs.md
│ │ │ ├── test_large_cases_async_docs.md
│ │ │ ├── fake_tracer.py
│ │ │ ├── test_runnable_docs.md
│ │ │ ├── __snapshots__
│ │ │ │ ├── test_pregel.ambr
│ │ │ │ └── test_pregel_async.ambr
│ │ │ └── test_state_docs.md
│ │ ├── README.md
│ │ ├── poetry.toml
│ │ └── poetry.lock
│ ├── checkpoint
│ │ ├── langgraph
│ │ │ ├── checkpoint
│ │ │ │ ├── memory
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── __init__.py.md
│ │ │ │ │ └── py.typed
│ │ │ │ ├── serde
│ │ │ │ │ ├── jsonplus.py
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── types.py
│ │ │ │ │ ├── py.typed
│ │ │ │ │ └── base.py
│ │ │ │ └── base
│ │ │ │ ├── id.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __init__.py.md
│ │ │ │ └── py.typed
│ │ │ └── store
│ │ │ ├── memory
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __init__.py.md
│ │ │ │ └── py.typed
│ │ │ └── base
│ │ │ ├── batch.py.md
│ │ │ ├── embed.py
│ │ │ ├── batch.py
│ │ │ ├── __init__.py
│ │ │ ├── __init__.py.md
│ │ │ ├── py.typed
│ │ │ └── embed.py.md
│ │ ├── LICENSE
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ ├── test_jsonplus.py
│ │ │ ├── test_store.py
│ │ │ ├── test_memory.py
│ │ │ ├── test_memory.py.md
│ │ │ └── embed_test_utils.py
│ │ ├── README.md
│ │ └── poetry.lock
│ ├── checkpoint-duckdb
│ │ ├── langgraph
│ │ │ ├── checkpoint
│ │ │ │ └── duckdb
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aio.py
│ │ │ │ ├── py.typed
│ │ │ │ └── base.py
│ │ │ └── store
│ │ │ └── duckdb
│ │ │ ├── __init__.py
│ │ │ ├── aio.py
│ │ │ ├── py.typed
│ │ │ └── base.py
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── tests
│ │ │ ├── test_sync.py
│ │ │ ├── test_async.py
│ │ │ ├── test_store.py
│ │ │ └── test_async_store.py
│ │ ├── README.md
│ │ └── poetry.lock
│ ├── scheduler-kafka
│ │ ├── langgraph
│ │ │ └── scheduler
│ │ │ └── kafka
│ │ │ ├── default_sync.py
│ │ │ ├── default_async.py
│ │ │ ├── serde.py
│ │ │ ├── __init__.py
│ │ │ ├── types.py
│ │ │ ├── retry.py
│ │ │ ├── orchestrator.py
│ │ │ ├── py.typed
│ │ │ └── executor.py
│ │ ├── LICENSE
│ │ ├── langgraph-distributed.png
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── tests
│ │ │ ├── conftest.py
│ │ │ ├── test_fanout_sync.py
│ │ │ ├── __init__.py
│ │ │ ├── test_fanout.py
│ │ │ ├── test_subgraph.py
│ │ │ ├── messages.py
│ │ │ ├── test_push.py
│ │ │ ├── test_subgraph_sync.py
│ │ │ ├── any.py
│ │ │ ├── drain.py
│ │ │ ├── test_push_sync.py
│ │ │ └── compose.yml
│ │ ├── README.md
│ │ └── poetry.lock
│ ├── sdk-js
│ │ ├── LICENSE
│ │ ├── langchain.config.js
│ │ ├── README.md
│ │ ├── yarn.lock
│ │ ├── package.json
│ │ ├── tsconfig.cjs.json
│ │ ├── tsconfig.json
│ │ └── src
│ │ ├── schema.ts
│ │ ├── utils
│ │ │ ├── async_caller.ts
│ │ │ ├── eventsource-parser
│ │ │ │ ├── LICENSE
│ │ │ │ ├── parse.ts
│ │ │ │ ├── stream.ts
│ │ │ │ ├── types.ts
│ │ │ │ └── index.ts
│ │ │ ├── stream.ts
│ │ │ ├── env.ts
│ │ │ └── signals.ts
│ │ ├── types.ts
│ │ ├── client.ts
│ │ └── index.ts
│ ├── sdk-py
│ │ ├── LICENSE
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ ├── langgraph_sdk
│ │ │ ├── client.py
│ │ │ ├── __init__.py
│ │ │ ├── sse.py
│ │ │ ├── py.typed
│ │ │ └── schema.py
│ │ └── poetry.lock
│ ├── cli
│ │ ├── LICENSE
│ │ ├── langgraph_cli
│ │ │ ├── config.py
│ │ │ ├── version.py
│ │ │ ├── util.py
│ │ │ ├── exec.py
│ │ │ ├── constants.py
│ │ │ ├── __init__.py
│ │ │ ├── docker.py
│ │ │ ├── templates.py
│ │ │ ├── cli.py
│ │ │ ├── py.typed
│ │ │ ├── analytics.py
│ │ │ └── progress.py
│ │ ├── Makefile
│ │ ├── pyproject.toml
│ │ ├── js-examples
│ │ │ ├── LICENSE
│ │ │ ├── jest.config.js
│ │ │ ├── tests
│ │ │ │ ├── graph.int.test.ts
│ │ │ │ └── agent.test.ts
│ │ │ ├── README.md
│ │ │ ├── yarn.lock
│ │ │ ├── langgraph.json
│ │ │ ├── package.json
│ │ │ ├── static
│ │ │ │ └── studio.png
│ │ │ ├── tsconfig.json
│ │ │ └── src
│ │ │ └── agent
│ │ │ ├── state.ts
│ │ │ └── graph.ts
│ │ ├── tests
│ │ │ ├── unit_tests
│ │ │ │ ├── test_config.json
│ │ │ │ ├── conftest.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cli
│ │ │ │ │ ├── test_templates.py
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── test_cli.py
│ │ │ │ ├── pipconfig.txt
│ │ │ │ ├── agent.py
│ │ │ │ ├── test_config.py
│ │ │ │ ├── graphs
│ │ │ │ │ └── agent.py
│ │ │ │ ├── helpers.py
│ │ │ │ └── test_docker.py
│ │ │ ├── integration_tests
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_cli.py
│ │ │ └── __init__.py
│ │ ├── README.md
│ │ ├── examples
│ │ │ ├── Makefile
│ │ │ ├── pipconf.txt
│ │ │ ├── pyproject.toml
│ │ │ ├── graphs_reqs_b
│ │ │ │ ├── hello.py
│ │ │ │ ├── requirements.txt
│ │ │ │ ├── graphs_submod
│ │ │ │ │ ├── subprompt.txt
│ │ │ │ │ └── agent.py
│ │ │ │ ├── prompt.txt
│ │ │ │ ├── utils
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── greeter.py
│ │ │ │ └── langgraph.json
│ │ │ ├── langgraph.json
│ │ │ ├── graphs
│ │ │ │ ├── langgraph.json
│ │ │ │ ├── storm.py
│ │ │ │ └── agent.py
│ │ │ ├── poetry.lock
│ │ │ └── graphs_reqs_a
│ │ │ ├── hello.py
│ │ │ ├── requirements.txt
│ │ │ ├── graphs_submod
│ │ │ │ ├── __init__.py
│ │ │ │ ├── subprompt.txt
│ │ │ │ └── agent.py
│ │ │ ├── __init__.py
│ │ │ ├── prompt.txt
│ │ │ └── langgraph.json
│ │ └── poetry.lock
│ └── checkpoint-sqlite
│ ├── langgraph
│ │ └── checkpoint
│ │ └── sqlite
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── aio.py
│ │ └── py.typed
│ ├── Makefile
│ ├── pyproject.toml
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_aiosqlite.py
│ │ └── test_sqlite.py
│ ├── README.md
│ └── poetry.lock
├── README.md
├── CONTRIBUTING.md
├── examples
│ ├── multi_agent
│ │ ├── hierarchical_agent_teams.ipynb
│ │ ├── agent_supervisor.ipynb
│ │ └── multi-agent-collaboration.ipynb
│ ├── introduction.ipynb
│ ├── customer-support
│ │ └── customer-support.ipynb
│ ├── plan-and-execute
│ │ └── plan-and-execute.ipynb
│ ├── run-id-langsmith.ipynb
│ ├── react-agent-from-scratch.ipynb
│ ├── streaming-from-final-node.ipynb
│ ├── recursion-limit.ipynb
│ ├── human_in_the_loop
│ │ ├── breakpoints.ipynb
│ │ ├── time-travel.ipynb
│ │ ├── edit-graph-state.ipynb
│ │ ├── dynamic_breakpoints.ipynb
│ │ ├── review-tool-calls.ipynb
│ │ └── wait-user-input.ipynb
│ ├── memory
│ │ ├── add-summary-conversation-history.ipynb
│ │ ├── manage-conversation-history.ipynb
│ │ └── delete-messages.ipynb
│ ├── streaming-tokens.ipynb
│ ├── create-react-agent.ipynb
│ ├── reflexion
│ │ └── reflexion.ipynb
│ ├── configuration.ipynb
│ ├── code_assistant
│ │ ├── langgraph_code_assistant.ipynb
│ │ └── langgraph_code_assistant_mistral.ipynb
│ ├── lats
│ │ └── lats.ipynb
│ ├── visualization.ipynb
│ ├── rewoo
│ │ └── rewoo.ipynb
│ ├── branching.ipynb
│ ├── create-react-agent-memory.ipynb
│ ├── streaming-events-from-within-tools.ipynb
│ ├── async.ipynb
│ ├── web-navigation
│ │ └── web_voyager.ipynb
│ ├── cloud_examples
│ │ └── langgraph_to_langgraph_cloud.ipynb
│ ├── rag
│ │ ├── langgraph_self_rag_pinecone_movies.ipynb
│ │ ├── langgraph_self_rag_local.ipynb
│ │ ├── langgraph_adaptive_rag_local.ipynb
│ │ ├── langgraph_self_rag.ipynb
│ │ ├── langgraph_crag_local.ipynb
│ │ ├── langgraph_crag.ipynb
│ │ ├── langgraph_adaptive_rag.ipynb
│ │ ├── langgraph_agentic_rag.ipynb
│ │ └── langgraph_adaptive_rag_cohere.ipynb
│ ├── llm-compiler
│ │ └── LLMCompiler.ipynb
│ ├── streaming-content.ipynb
│ ├── README.md
│ ├── subgraphs-manage-state.ipynb
│ ├── persistence.ipynb
│ ├── node-retries.ipynb
│ ├── persistence_mongodb.ipynb
│ ├── create-react-agent-hitl.ipynb
│ ├── state-model.ipynb
│ ├── persistence_postgres.ipynb
│ ├── usaco
│ │ └── usaco.ipynb
│ ├── stream-multiple.ipynb
│ ├── pass-run-time-values-to-tools.ipynb
│ ├── stream-updates.ipynb
│ ├── chatbots
│ │ └── information-gather-prompting.ipynb
│ ├── tool-calling-errors.ipynb
│ ├── create-react-agent-system-prompt.ipynb
│ ├── streaming-subgraphs.ipynb
│ ├── tutorials
│ │ ├── tnt-llm
│ │ │ └── tnt-llm.ipynb
│ │ └── sql-agent.ipynb
│ ├── tool-calling.ipynb
│ ├── persistence_redis.ipynb
│ ├── pass-config-to-tools.ipynb
│ ├── stream-values.ipynb
│ ├── chatbot-simulation-evaluation
│ │ ├── agent-simulation-evaluation.ipynb
│ │ ├── langsmith-agent-simulation-evaluation.ipynb
│ │ └── simulation_utils.py
│ ├── react-agent-structured-output.ipynb
│ ├── subgraph.ipynb
│ ├── extraction
│ │ └── retries.ipynb
│ ├── streaming-tokens-without-langchain.ipynb
│ ├── map-reduce.ipynb
│ ├── storm
│ │ └── storm.ipynb
│ ├── pass_private_state.ipynb
│ ├── self-discover
│ │ └── self-discover.ipynb
│ ├── subgraph-transform-state.ipynb
│ ├── reflection
│ │ └── reflection.ipynb
│ ├── input_output_schema.ipynb
│ └── streaming-events-from-within-tools-without-langchain.ipynb
├── poetry.lock
└── security.md
```
`/Users/malcolm/dev/langchain-ai/langgraph/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-monorepo"
version = "0.0.1"
description = "LangGraph monorepo"
authors = []
license = "MIT"
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
aiohappyeyeballs = "2.4.3"
[tool.poetry.group.docs.dependencies]
langgraph = { path = "libs/langgraph/", develop = true }
langgraph-checkpoint = { path = "libs/checkpoint/", develop = true }
langgraph-checkpoint-sqlite = { path = "libs/checkpoint-sqlite", develop = true }
langgraph-checkpoint-postgres = { path = "libs/checkpoint-postgres", develop = true }
langgraph-sdk = {path = "libs/sdk-py", develop = true}
mkdocs = "^1.6.0"
mkdocs-autorefs = ">=1.0.1,<1.1.0"
mkdocstrings = "^0.25.1"
mkdocstrings-python = "^1.10.4"
mkdocs-redirects = "^1.2.1"
mkdocs-minify-plugin = "^0.8.0"
mkdocs-rss-plugin = "^1.13.1"
mkdocs-git-committers-plugin-2 = "^2.3.0"
mkdocs-material = {extras = ["imaging"], version = "^9.5.27"}
markdown-include = "^0.8.1"
markdown-callouts = "^0.4.0"
mkdocs-exclude = "^1.0.2"
vcrpy = "^6.0.1"
click = "^8.1.7"
ruff = "^0.6.8"
jupyter = "^1.1.1"
[tool.poetry.group.test.dependencies]
langchain = "^0.3.8"
langchain-openai = "^0.2.0"
langchain-anthropic = "^0.2.1"
langchain-nomic = "^0.1.3"
langchain-fireworks = "^0.2.0"
langchain-community = "^0.3.0"
langchain-experimental = "^0.3.2"
langsmith = "^0.1.129"
chromadb = "^0.5.5"
gpt4all = "^2.8.2"
scikit-learn = "^1.5.2"
numexpr = "^2.10.1"
numpy = "^1.26.4"
matplotlib = "^3.9.2"
redis = "^5.0.8"
pymongo = "^4.8.0"
motor = "^3.5.1"
grandalf = "^0.8"
pyppeteer = "^2.0.0"
networkx = "^3.3"
autogen = { version = "^0.3.0", python = "<3.13,>=3.8" }
[tool.poetry.group.test]
optional = true
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
extend-include = ["*.ipynb"]
[tool.ruff.lint.per-file-ignores]
"docs/*" = [
"E402", # allow imports to appear anywhere in docs
"F401", # allow "imported but unused" example code
"F811", # allow re-importing the same module, so that cells can stay independent
"F841", # allow assignments to variables that are never read -- it's example code
# The issues below should be cleaned up when there's time
"E722", # allow base imports in notebooks
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/generate_api_reference_links.py`:
```py
import importlib
import inspect
import logging
import os
import re
from typing import List, Literal, Optional
from typing_extensions import TypedDict
import nbformat
from nbconvert.preprocessors import Preprocessor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Base URL for all class documentation
_LANGCHAIN_API_REFERENCE = "https://python.langchain.com/api_reference/"
_LANGGRAPH_API_REFERENCE = "https://langchain-ai.github.io/langgraph/reference/"
# (alias/re-exported modules, source module, class, docs namespace)
MANUAL_API_REFERENCES_LANGGRAPH = [
(
["langgraph.prebuilt"],
"langgraph.prebuilt.chat_agent_executor",
"create_react_agent",
"prebuilt",
),
(["langgraph.prebuilt"], "langgraph.prebuilt.tool_node", "ToolNode", "prebuilt"),
(
["langgraph.prebuilt"],
"langgraph.prebuilt.tool_node",
"tools_condition",
"prebuilt",
),
(
["langgraph.prebuilt"],
"langgraph.prebuilt.tool_node",
"InjectedState",
"prebuilt",
),
# Graph
(["langgraph.graph"], "langgraph.graph.message", "add_messages", "graphs"),
(["langgraph.graph"], "langgraph.graph.state", "StateGraph", "graphs"),
(["langgraph.graph"], "langgraph.graph.state", "CompiledStateGraph", "graphs"),
([], "langgraph.types", "StreamMode", "types"),
(["langgraph.graph"], "langgraph.constants", "START", "constants"),
(["langgraph.graph"], "langgraph.constants", "END", "constants"),
(["langgraph.constants"], "langgraph.types", "Send", "types"),
(["langgraph.constants"], "langgraph.types", "Interrupt", "types"),
([], "langgraph.types", "RetryPolicy", "types"),
([], "langgraph.checkpoint.base", "Checkpoint", "checkpoints"),
([], "langgraph.checkpoint.base", "CheckpointMetadata", "checkpoints"),
([], "langgraph.checkpoint.base", "BaseCheckpointSaver", "checkpoints"),
([], "langgraph.checkpoint.base", "SerializerProtocol", "checkpoints"),
([], "langgraph.checkpoint.serde.jsonplus", "JsonPlusSerializer", "checkpoints"),
([], "langgraph.checkpoint.memory", "MemorySaver", "checkpoints"),
([], "langgraph.checkpoint.sqlite.aio", "AsyncSqliteSaver", "checkpoints"),
([], "langgraph.checkpoint.sqlite", "SqliteSaver", "checkpoints"),
([], "langgraph.checkpoint.postgres.aio", "AsyncPostgresSaver", "checkpoints"),
([], "langgraph.checkpoint.postgres", "PostgresSaver", "checkpoints"),
]
WELL_KNOWN_LANGGRAPH_OBJECTS = {
(module_, class_): (source_module, namespace)
for (modules, source_module, class_, namespace) in MANUAL_API_REFERENCES_LANGGRAPH
for module_ in modules + [source_module]
}
def _make_regular_expression(pkg_prefix: str) -> re.Pattern:
if not pkg_prefix.isidentifier():
raise ValueError(f"Invalid package prefix: {pkg_prefix}")
return re.compile(
r"from\s+(" + pkg_prefix + "(?:_\w+)?(?:\.\w+)*?)\s+import\s+"
r"((?:\w+(?:,\s*)?)*" # Match zero or more words separated by a comma+optional ws
r"(?:\s*\(.*?\))?)", # Match optional parentheses block
re.DOTALL, # Match newlines as well
)
# Regular expression to match langchain import lines
_IMPORT_LANGCHAIN_RE = _make_regular_expression("langchain")
_IMPORT_LANGGRAPH_RE = _make_regular_expression("langgraph")
def _get_full_module_name(module_path, class_name) -> Optional[str]:
"""Get full module name using inspect"""
try:
module = importlib.import_module(module_path)
class_ = getattr(module, class_name)
module = inspect.getmodule(class_)
if module is None:
# For constants, inspect.getmodule() might return None
# In this case, we'll return the original module_path
return module_path
return module.__name__
except AttributeError as e:
logger.warning(f"Could not find module for {class_name}, {e}")
return None
except ImportError as e:
logger.warning(f"Failed to load for class {class_name}, {e}")
return None
def _get_doc_title(data: str, file_name: str) -> str:
try:
return re.findall(r"^#\s*(.*)", data, re.MULTILINE)[0]
except IndexError:
pass
# Parse the rst-style titles
try:
return re.findall(r"^(.*)\n=+\n", data, re.MULTILINE)[0]
except IndexError:
return file_name
class ImportInformation(TypedDict):
imported: str # imported class name
source: str # module path
docs: str # URL to the documentation
title: str # Title of the document
def _get_imports(
code: str, doc_title: str, package_ecosystem: Literal["langchain", "langgraph"]
) -> List[ImportInformation]:
"""Get imports from the given code block.
Args:
code: Python code block from which to extract imports
doc_title: Title of the document
package_ecosystem: "langchain" or "langgraph". The two live in different
repositories and have separate documentation sites.
Returns:
List of import information for the given code block
"""
imports = []
if package_ecosystem == "langchain":
pattern = _IMPORT_LANGCHAIN_RE
elif package_ecosystem == "langgraph":
pattern = _IMPORT_LANGGRAPH_RE
else:
raise ValueError(f"Invalid package ecosystem: {package_ecosystem}")
for import_match in pattern.finditer(code):
module = import_match.group(1)
if "pydantic_v1" in module:
continue
imports_str = (
import_match.group(2).replace("(\n", "").replace("\n)", "")
) # Handle newlines within parentheses
# remove any newline and spaces, then split by comma
imported_classes = [
imp.strip()
for imp in re.split(r",\s*", imports_str.replace("\n", ""))
if imp.strip()
]
for class_name in imported_classes:
module_path = _get_full_module_name(module, class_name)
if not module_path:
continue
if len(module_path.split(".")) < 2:
continue
if package_ecosystem == "langchain":
pkg = module_path.split(".")[0].replace("langchain_", "")
top_level_mod = module_path.split(".")[1]
url = (
_LANGCHAIN_API_REFERENCE
+ pkg
+ "/"
+ top_level_mod
+ "/"
+ module_path
+ "."
+ class_name
+ ".html"
)
elif package_ecosystem == "langgraph":
if (module, class_name) not in WELL_KNOWN_LANGGRAPH_OBJECTS:
# Likely not documented yet
continue
source_module, namespace = WELL_KNOWN_LANGGRAPH_OBJECTS[
(module, class_name)
]
url = (
_LANGGRAPH_API_REFERENCE
+ namespace
+ "/#"
+ source_module
+ "."
+ class_name
)
else:
raise ValueError(f"Invalid package ecosystem: {package_ecosystem}")
# Add the import information to our list
imports.append(
{
"imported": class_name,
"source": module,
"docs": url,
"title": doc_title,
}
)
return imports
class ImportPreprocessor(Preprocessor):
"""A preprocessor to replace imports in each Python code cell with links to their
documentation and append the import info in a comment."""
def preprocess(self, nb, resources):
self.all_imports = []
file_name = os.path.basename(resources.get("metadata", {}).get("name", ""))
_DOC_TITLE = _get_doc_title(nb.cells[0].source, file_name)
cells = []
for cell in nb.cells:
if cell.cell_type == "code":
cells.append(cell)
imports = _get_imports(
cell.source, _DOC_TITLE, "langchain"
) + _get_imports(cell.source, _DOC_TITLE, "langgraph")
if not imports:
continue
cells.append(
nbformat.v4.new_markdown_cell(
source=f"""
<div>
<b>API Reference:</b>
{' | '.join(f'<a href="{imp["docs"]}">{imp["imported"]}</a>' for imp in imports)}
</div>
"""
)
)
else:
cells.append(cell)
nb.cells = cells
return nb, resources
```
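For reference, a minimal sketch (not part of the repo) of what `_get_imports` produces for a typical notebook cell. It assumes `langgraph` is installed so `_get_full_module_name` can resolve the imported objects, and that `docs/_scripts` is on `sys.path`:
```py
# Minimal sketch (not part of the repo): resolve API reference links for one import line.
# Assumes langgraph is importable and docs/_scripts is on sys.path.
from generate_api_reference_links import _get_imports

code = "from langgraph.graph import StateGraph, START, END"
for info in _get_imports(code, doc_title="Example", package_ecosystem="langgraph"):
    print(f'{info["imported"]}: {info["docs"]}')
# Roughly:
# StateGraph: https://langchain-ai.github.io/langgraph/reference/graphs/#langgraph.graph.state.StateGraph
# START: https://langchain-ai.github.io/langgraph/reference/constants/#langgraph.constants.START
# END: https://langchain-ai.github.io/langgraph/reference/constants/#langgraph.constants.END
```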
`/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/prepare_notebooks_for_ci.py`:
```py
"""Preprocess notebooks for CI. Currently adds VCR cassettes and optionally removes pip install cells."""
import logging
import os
import json
import click
import nbformat
logger = logging.getLogger(__name__)
NOTEBOOK_DIRS = ("docs/docs/how-tos","docs/docs/tutorials")
DOCS_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CASSETTES_PATH = os.path.join(DOCS_PATH, "cassettes")
BLOCKLIST_COMMANDS = (
# skip if has WebBaseLoader to avoid caching web pages
"WebBaseLoader",
# skip if has draw_mermaid_png to avoid generating mermaid images via API
"draw_mermaid_png",
)
NOTEBOOKS_NO_CASSETTES = (
"docs/docs/how-tos/visualization.ipynb",
"docs/docs/how-tos/many-tools.ipynb"
)
NOTEBOOKS_NO_EXECUTION = [
# this uses a user provided project name for langsmith
"docs/docs/tutorials/tnt-llm/tnt-llm.ipynb",
# this uses langsmith datasets
"docs/docs/tutorials/chatbot-simulation-evaluation/langsmith-agent-simulation-evaluation.ipynb",
# this uses browser APIs
"docs/docs/tutorials/web-navigation/web_voyager.ipynb",
# these RAG guides use an ollama model
"docs/docs/tutorials/rag/langgraph_adaptive_rag_local.ipynb",
"docs/docs/tutorials/rag/langgraph_crag_local.ipynb",
"docs/docs/tutorials/rag/langgraph_self_rag_local.ipynb",
# this loads a massive dataset from gcp
"docs/docs/tutorials/usaco/usaco.ipynb",
# TODO: figure out why the autogen notebook is not runnable (it just hangs; possibly due to code execution?)
"docs/docs/how-tos/autogen-integration.ipynb",
# TODO: need to update these notebooks to make sure they are runnable in CI
"docs/docs/tutorials/storm/storm.ipynb", # issues only when running with VCR
"docs/docs/tutorials/lats/lats.ipynb", # issues only when running with VCR
"docs/docs/tutorials/rag/langgraph_crag.ipynb", # flakiness from tavily
"docs/docs/tutorials/rag/langgraph_adaptive_rag.ipynb", # Cannot create a consistent method resolution error from VCR
"docs/docs/how-tos/map-reduce.ipynb" # flakiness from structured output, only when running with VCR
]
def comment_install_cells(notebook: nbformat.NotebookNode) -> nbformat.NotebookNode:
for cell in notebook.cells:
if cell.cell_type != "code":
continue
if "pip install" in cell.source:
# Comment out the lines in cells containing "pip install"
cell.source = "\n".join(
f"# {line}" if line.strip() else line
for line in cell.source.splitlines()
)
return notebook
def is_magic_command(code: str) -> bool:
return code.strip().startswith("%") or code.strip().startswith("!")
def is_comment(code: str) -> bool:
return code.strip().startswith("#")
def has_blocklisted_command(code: str, metadata: dict) -> bool:
if 'hide_from_vcr' in metadata:
return True
code = code.strip()
for blocklisted_pattern in BLOCKLIST_COMMANDS:
if blocklisted_pattern in code:
return True
return False
def add_vcr_to_notebook(
notebook: nbformat.NotebookNode, cassette_prefix: str
) -> nbformat.NotebookNode:
"""Inject `with vcr.cassette` into each code cell of the notebook."""
# Inject VCR context manager into each code cell
for idx, cell in enumerate(notebook.cells):
if cell.cell_type != "code":
continue
lines = cell.source.splitlines()
# skip if empty cell
if not lines:
continue
are_magic_lines = [is_magic_command(line) for line in lines]
# skip if all magic
if all(are_magic_lines):
continue
if any(are_magic_lines):
raise ValueError(
"Cannot process code cells with mixed magic and non-magic code."
)
# skip if just comments
if all(is_comment(line) or not line.strip() for line in lines):
continue
if has_blocklisted_command(cell.source, cell.metadata):
continue
cell_id = cell.get("id", idx)
cassette_name = f"{cassette_prefix}_{cell_id}.msgpack.zlib"
cell.source = f"with custom_vcr.use_cassette('{cassette_name}', filter_headers=['x-api-key', 'authorization'], record_mode='once', serializer='advanced_compressed'):\n" + "\n".join(
f" {line}" for line in lines
)
# Add import statement
vcr_import_lines = [
"import nest_asyncio",
"nest_asyncio.apply()",
"import vcr",
"import msgpack",
"import base64",
"import zlib",
"import os",
"os.environ.pop(\"LANGCHAIN_TRACING_V2\", None)",
"custom_vcr = vcr.VCR()",
"",
"def compress_data(data, compression_level=9):",
" packed = msgpack.packb(data, use_bin_type=True)",
" compressed = zlib.compress(packed, level=compression_level)",
" return base64.b64encode(compressed).decode('utf-8')",
"",
"def decompress_data(compressed_string):",
" decoded = base64.b64decode(compressed_string)",
" decompressed = zlib.decompress(decoded)",
" return msgpack.unpackb(decompressed, raw=False)",
"",
"class AdvancedCompressedSerializer:",
" def serialize(self, cassette_dict):",
" return compress_data(cassette_dict)",
"",
" def deserialize(self, cassette_string):",
" return decompress_data(cassette_string)",
"",
"custom_vcr.register_serializer('advanced_compressed', AdvancedCompressedSerializer())",
"custom_vcr.serializer = 'advanced_compressed'",
]
import_cell = nbformat.v4.new_code_cell(source="\n".join(vcr_import_lines))
import_cell.pop("id", None)
notebook.cells.insert(0, import_cell)
return notebook
def process_notebooks(should_comment_install_cells: bool) -> None:
for directory in NOTEBOOK_DIRS:
for root, _, files in os.walk(directory):
for file in files:
if not file.endswith(".ipynb") or "ipynb_checkpoints" in root:
continue
notebook_path = os.path.join(root, file)
try:
notebook = nbformat.read(notebook_path, as_version=4)
if should_comment_install_cells:
notebook = comment_install_cells(notebook)
base_filename = os.path.splitext(os.path.basename(file))[0]
cassette_prefix = os.path.join(CASSETTES_PATH, base_filename)
if notebook_path not in NOTEBOOKS_NO_CASSETTES:
notebook = add_vcr_to_notebook(
notebook, cassette_prefix=cassette_prefix
)
if notebook_path in NOTEBOOKS_NO_EXECUTION:
# Add a cell at the beginning to indicate that this notebook should not be executed
warning_cell = nbformat.v4.new_markdown_cell(
source="**Warning:** This notebook is not meant to be executed automatically."
)
notebook.cells.insert(0, warning_cell)
# Add a special tag to the first code cell
if len(notebook.cells) > 1 and notebook.cells[1].cell_type == "code":
notebook.cells[1].metadata["tags"] = notebook.cells[1].metadata.get("tags", []) + ["no_execution"]
nbformat.write(notebook, notebook_path)
logger.info(f"Processed: {notebook_path}")
except Exception as e:
logger.error(f"Error processing {notebook_path}: {e}")
with open(os.path.join(DOCS_PATH, "notebooks_no_execution.json"), "w") as f:
json.dump(NOTEBOOKS_NO_EXECUTION, f)
@click.command()
@click.option(
"--comment-install-cells",
is_flag=True,
default=False,
help="Whether to comment out install cells",
)
def main(comment_install_cells):
process_notebooks(should_comment_install_cells=comment_install_cells)
logger.info("All notebooks processed successfully.")
if __name__ == "__main__":
main()
```
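As a rough illustration (not in the repo), this is how `add_vcr_to_notebook` rewrites a single code cell: the cassette name is derived from the cell id, and the VCR/serializer bootstrap cell is inserted at the top of the notebook.
```py
# Rough illustration (not in the repo) of the transformation applied by add_vcr_to_notebook.
import nbformat
from prepare_notebooks_for_ci import add_vcr_to_notebook

nb = nbformat.v4.new_notebook()
cell = nbformat.v4.new_code_cell("llm.invoke('hi')")
cell["id"] = "abc123"  # hypothetical cell id; real ids come from the notebook file
nb.cells.append(cell)

nb = add_vcr_to_notebook(nb, cassette_prefix="docs/cassettes/example")
print(nb.cells[0].source.splitlines()[0])  # injected bootstrap cell -> "import nest_asyncio"
print(nb.cells[1].source)
# with custom_vcr.use_cassette('docs/cassettes/example_abc123.msgpack.zlib', ...):
#     llm.invoke('hi')
```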
`/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/notebook_convert.py`:
```py
import os
import re
from pathlib import Path
import nbformat
from nbconvert.exporters import MarkdownExporter
from nbconvert.preprocessors import Preprocessor
from generate_api_reference_links import ImportPreprocessor
class EscapePreprocessor(Preprocessor):
def preprocess_cell(self, cell, resources, cell_index):
if cell.cell_type == "markdown":
# rewrite markdown links to html links (excluding image links)
cell.source = re.sub(
r"(?<!!)\[([^\]]*)\]\((?![^\)]*//)([^)]*)(?:\.ipynb)?\)",
r'<a href="\2">\1</a>',
cell.source,
)
# Fix image paths in <img> tags
cell.source = re.sub(
r'<img\s+src="\.?/img/([^"]+)"', r'<img src="../img/\1"', cell.source
)
elif cell.cell_type == "code":
# escape ``` in code
cell.source = cell.source.replace("```", r"\`\`\`")
# escape ``` in output
if "outputs" in cell:
filter_out = set()
for i, output in enumerate(cell["outputs"]):
if "text" in output:
if not output["text"].strip():
filter_out.add(i)
continue
value = output["text"].replace("```", r"\`\`\`")
# handle a funky case w/ references in text
value = re.sub(r"\[(\d+)\](?=\[(\d+)\])", r"[\1]\\", value)
output["text"] = value
elif "data" in output:
for key, value in output["data"].items():
if isinstance(value, str):
value = value.replace("```", r"\`\`\`")
# handle a funky case w/ references in text
output["data"][key] = re.sub(
r"\[(\d+)\](?=\[(\d+)\])", r"[\1]\\", value
)
cell["outputs"] = [
output
for i, output in enumerate(cell["outputs"])
if i not in filter_out
]
return cell, resources
class ExtractAttachmentsPreprocessor(Preprocessor):
"""
Extracts all of the outputs from the notebook file. The extracted
outputs are returned in the 'resources' dictionary.
"""
def preprocess_cell(self, cell, resources, cell_index):
"""
Apply a transformation on each cell,
Parameters
----------
cell : NotebookNode cell
Notebook cell being processed
resources : dictionary
Additional resources used in the conversion process. Allows
preprocessors to pass variables into the Jinja engine.
cell_index : int
Index of the cell being processed (see base.py)
"""
        # Make sure the outputs key exists
if not isinstance(resources["outputs"], dict):
resources["outputs"] = {}
# Loop through all of the attachments in the cell
for name, attach in cell.get("attachments", {}).items():
for mime, data in attach.items():
if mime not in {
"image/png",
"image/jpeg",
"image/svg+xml",
"application/pdf",
}:
continue
# attachments are pre-rendered. Only replace markdown-formatted
# images with the following logic
attach_str = f"({name})"
if attach_str in cell.source:
data = f"(data:{mime};base64,{data})"
cell.source = cell.source.replace(attach_str, data)
return cell, resources
exporter = MarkdownExporter(
preprocessors=[
EscapePreprocessor,
ExtractAttachmentsPreprocessor,
ImportPreprocessor,
],
template_name="mdoutput",
extra_template_basedirs=[
os.path.join(os.path.dirname(__file__), "notebook_convert_templates")
],
)
def convert_notebook(
    notebook_path: Path,
) -> str:
with open(notebook_path) as f:
nb = nbformat.read(f, as_version=4)
body, _ = exporter.from_notebook_node(nb)
return body
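# Example usage (a sketch; the notebook path below is hypothetical):
#   markdown_body = convert_notebook(Path("docs/docs/tutorials/introduction.ipynb"))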
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/notebook_hooks.py`:
```py
import logging
from typing import Any, Dict
from mkdocs.structure.pages import Page
from mkdocs.structure.files import Files, File
from notebook_convert import convert_notebook
logger = logging.getLogger(__name__)
logging.basicConfig()
logger.setLevel(logging.INFO)
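# These functions are MkDocs plugin hooks (presumably registered under `hooks:` in
# mkdocs.yml): `on_files` swaps each .ipynb file for a NotebookFile so MkDocs treats
# it as a documentation page, and `on_page_markdown` converts the notebook source to
# markdown at build time via `convert_notebook`.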
class NotebookFile(File):
def is_documentation_page(self):
return True
def on_files(files: Files, **kwargs: Dict[str, Any]):
new_files = Files([])
for file in files:
if file.src_path.endswith(".ipynb"):
new_file = NotebookFile(
path=file.src_path,
src_dir=file.src_dir,
dest_dir=file.dest_dir,
use_directory_urls=file.use_directory_urls,
)
new_files.append(new_file)
else:
new_files.append(file)
return new_files
def on_page_markdown(markdown: str, page: Page, **kwargs: Dict[str, Any]):
if page.file.src_path.endswith(".ipynb"):
logger.info("Processing Jupyter notebook: %s", page.file.src_path)
body = convert_notebook(page.file.abs_src_path)
return body
return markdown
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/_scripts/download_tiktoken.py`:
```py
import tiktoken
# This will trigger the download and caching of the necessary files
for model in ("gpt2", "gpt-3.5"):
    tiktoken.encoding_for_model(model)
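# tiktoken caches the downloaded BPE files locally; the cache location can be
# overridden with the TIKTOKEN_CACHE_DIR environment variable.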
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/agent_supervisor.py`:
```py
# %% [markdown]
# # Multi-agent supervisor
#
# The [previous example](../multi-agent-collaboration) routed messages automatically based on the output of the initial researcher agent.
#
# We can also choose to use an [LLM to orchestrate](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#supervisor) the different agents.
#
# Below, we will create an agent group, with an agent supervisor to help delegate tasks.
#
# 
#
# To simplify the code in each agent node, we will use LangGraph's prebuilt [create_react_agent](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent). This and other "advanced agent" notebooks are designed to show how you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.
#
# ## Setup
#
# First, let's install required packages and set our API keys
# %%
%%capture --no-stderr
%pip install -U langgraph langchain_community langchain_anthropic langchain_experimental
# %%
import getpass
import os
def _set_if_undefined(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"Please provide your {var}")
_set_if_undefined("ANTHROPIC_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Create tools
#
# For this example, you will make an agent to do web research with a search engine, and one agent to create plots. Define the tools they'll use below:
# %%
from typing import Annotated
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
tavily_tool = TavilySearchResults(max_results=5)
# This executes code locally, which can be unsafe
repl = PythonREPL()
@tool
def python_repl_tool(
code: Annotated[str, "The python code to execute to generate your chart."],
):
"""Use this to execute python code and do math. If you want to see the output of a value,
you should print it out with `print(...)`. This is visible to the user."""
try:
result = repl.run(code)
except BaseException as e:
return f"Failed to execute. Error: {repr(e)}"
result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
return result_str
# %% [markdown]
# ### Create Agent Supervisor
#
# It will use an LLM with structured output to choose the next worker node OR finish processing.
# %%
from typing import Literal
from typing_extensions import TypedDict
from langchain_anthropic import ChatAnthropic
from langgraph.graph import MessagesState
from langgraph.types import Command
members = ["researcher", "coder"]
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = members + ["FINISH"]
system_prompt = (
"You are a supervisor tasked with managing a conversation between the"
f" following workers: {members}. Given the following user request,"
" respond with the worker to act next. Each worker will perform a"
" task and respond with their results and status. When finished,"
" respond with FINISH."
)
class Router(TypedDict):
"""Worker to route to next. If no workers needed, route to FINISH."""
next: Literal[*options]
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
def supervisor_node(state: MessagesState) -> Command[Literal[*members, "__end__"]]:
messages = [
{"role": "system", "content": system_prompt},
] + state["messages"]
response = llm.with_structured_output(Router).invoke(messages)
goto = response["next"]
if goto == "FINISH":
goto = END
return Command(goto=goto)
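# Because each node returns a `Command(goto=...)`, the graph below only needs a
# START -> supervisor edge; all other routing is driven by the returned Command.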
# %% [markdown]
# ## Construct Graph
#
# We're ready to start building the graph. Below, define the state and worker nodes using the function we just defined.
# %%
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import create_react_agent
research_agent = create_react_agent(
llm, tools=[tavily_tool], state_modifier="You are a researcher. DO NOT do any math."
)
def research_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = research_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="researcher")
]
},
goto="supervisor",
)
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
code_agent = create_react_agent(llm, tools=[python_repl_tool])
def code_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = code_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="coder")
]
},
goto="supervisor",
)
builder = StateGraph(MessagesState)
builder.add_edge(START, "supervisor")
builder.add_node("supervisor", supervisor_node)
builder.add_node("researcher", research_node)
builder.add_node("coder", code_node)
graph = builder.compile()
# %%
from IPython.display import display, Image
display(Image(graph.get_graph().draw_mermaid_png()))
# %% [markdown]
# ## Invoke the team
#
# With the graph created, we can now invoke it and see how it performs!
# %%
for s in graph.stream(
{"messages": [("user", "What's the square root of 42?")]}, subgraphs=True
):
print(s)
print("----")
# %%
for s in graph.stream(
{
"messages": [
(
"user",
"Find the latest GDP of New York and California, then calculate the average",
)
]
},
subgraphs=True,
):
print(s)
print("----")
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/multi-agent-collaboration.py`:
```py
# %% [markdown]
# # Multi-agent network
#
# A single agent can usually operate effectively using a handful of tools within a single domain, but even using powerful models like `gpt-4`, it can be less effective at using many tools.
#
# One way to approach complicated tasks is through a "divide-and-conquer" approach: create a specialized agent for each task or domain and route tasks to the correct "expert". This is an example of a [multi-agent network](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#network) architecture.
#
# This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu, et al.) shows one way to do this using LangGraph.
#
# The resulting graph will look something like the following diagram:
#
# 
#
# Before we get started, a quick note: this and other multi-agent notebooks are designed to show _how_ you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.
#
# ## Setup
#
# First, let's install our required packages and set our API keys:
# %%
# %%capture --no-stderr
# %pip install -U langchain_community langchain_anthropic langchain_experimental matplotlib langgraph
# %%
import getpass
import os
def _set_if_undefined(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"Please provide your {var}")
_set_if_undefined("ANTHROPIC_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Define tools
#
# We will also define some tools that our agents will use in the future
# %%
from typing import Annotated
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
tavily_tool = TavilySearchResults(max_results=5)
# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()
@tool
def python_repl_tool(
code: Annotated[str, "The python code to execute to generate your chart."],
):
"""Use this to execute python code. If you want to see the output of a value,
you should print it out with `print(...)`. This is visible to the user."""
try:
result = repl.run(code)
except BaseException as e:
return f"Failed to execute. Error: {repr(e)}"
result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
return (
result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
)
# %% [markdown]
# ## Create graph
#
# Now that we've defined our tools and made some helper functions, we will create the individual agents below and tell them how to talk to each other using LangGraph.
# %% [markdown]
# ### Define Agent Nodes
#
# We now need to define the nodes.
#
# First, we'll create a utility to create a system prompt for each agent.
# %%
def make_system_prompt(suffix: str) -> str:
return (
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress."
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop."
f"\n{suffix}"
)
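# %% [markdown]
# As a quick sanity check (illustrative only), here is what the composed prompt looks like for one of the agents:
# %%
print(make_system_prompt("You can only do research."))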
# %%
from typing import Literal
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import create_react_agent
from langgraph.graph import MessagesState, END
from langgraph.types import Command
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
def get_next_node(last_message: BaseMessage, goto: str):
if "FINAL ANSWER" in last_message.content:
# Any agent decided the work is done
return END
return goto
# Research agent and node
research_agent = create_react_agent(
llm,
tools=[tavily_tool],
state_modifier=make_system_prompt(
"You can only do research. You are working with a chart generator colleague."
),
)
def research_node(
state: MessagesState,
) -> Command[Literal["chart_generator", END]]:
result = research_agent.invoke(state)
goto = get_next_node(result["messages"][-1], "chart_generator")
# wrap in a human message, as not all providers allow
# AI message at the last position of the input messages list
result["messages"][-1] = HumanMessage(
content=result["messages"][-1].content, name="researcher"
)
return Command(
update={
# share internal message history of research agent with other agents
"messages": result["messages"],
},
goto=goto,
)
# Chart generator agent and node
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
chart_agent = create_react_agent(
llm,
[python_repl_tool],
state_modifier=make_system_prompt(
"You can only generate charts. You are working with a researcher colleague."
),
)
def chart_node(state: MessagesState) -> Command[Literal["researcher", END]]:
result = chart_agent.invoke(state)
goto = get_next_node(result["messages"][-1], "researcher")
# wrap in a human message, as not all providers allow
# AI message at the last position of the input messages list
result["messages"][-1] = HumanMessage(
content=result["messages"][-1].content, name="chart_generator"
)
return Command(
update={
# share internal message history of chart agent with other agents
"messages": result["messages"],
},
goto=goto,
)
# %% [markdown]
# ### Define the Graph
#
# We can now put it all together and define the graph!
# %%
from langgraph.graph import StateGraph, START
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_edge(START, "researcher")
graph = workflow.compile()
# %%
from IPython.display import Image, display
try:
display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
# This requires some extra dependencies and is optional
pass
# %% [markdown]
# ## Invoke
#
# With the graph created, you can invoke it! Let's have it chart some stats for us.
# %%
events = graph.stream(
{
"messages": [
(
"user",
"First, get the UK's GDP over the past 5 years, then make a line chart of it. "
"Once you make the chart, finish.",
)
],
},
# Maximum number of steps to take in the graph
{"recursion_limit": 150},
)
for s in events:
print(s)
print("----")
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/multi_agent/hierarchical_agent_teams.py`:
```py
# %% [markdown]
# # Hierarchical Agent Teams
#
# In our previous example ([Agent Supervisor](../agent_supervisor)), we introduced the concept of a single [supervisor node](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#supervisor) to route work between different worker nodes.
#
# But what if the job for a single worker becomes too complex? What if the number of workers becomes too large?
#
# For some applications, the system may be more effective if work is distributed _hierarchically_.
#
# You can do this by composing different subgraphs and creating a top-level supervisor, along with mid-level supervisors.
#
# To do this, let's build a simple research assistant! The graph will look something like the following:
#
# 
#
# This notebook is inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu, et al. In the rest of this notebook, you will:
#
# 1. Define the agents' tools to access the web and write files
# 2. Define some utilities to help create the graph and agents
# 3. Create and define each team (web research + doc writing)
# 4. Compose everything together.
#
# ## Setup
#
# First, let's install our required packages and set our API keys
# %%
%%capture --no-stderr
%pip install -U langgraph langchain_community langchain_anthropic langchain_openai langchain_experimental
# %%
import getpass
import os
def _set_if_undefined(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"Please provide your {var}")
_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Create Tools
#
# Each team will be composed of one or more agents each with one or more tools. Below, define all the tools to be used by your different teams.
#
# We'll start with the research team.
#
# **ResearchTeam tools**
#
# The research team can use a search engine and url scraper to find information on the web. Feel free to add additional functionality below to boost the team performance!
# %%
from typing import Annotated, List
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
tavily_tool = TavilySearchResults(max_results=5)
@tool
def scrape_webpages(urls: List[str]) -> str:
"""Use requests and bs4 to scrape the provided web pages for detailed information."""
loader = WebBaseLoader(urls)
docs = loader.load()
return "\n\n".join(
[
f'<Document name="{doc.metadata.get("title", "")}">\n{doc.page_content}\n</Document>'
for doc in docs
]
)
# %% [markdown]
# **Document writing team tools**
#
# Next up, we will give some tools for the doc writing team to use.
# We define some bare-bones file-access tools below.
#
# Note that this gives the agents access to your file-system, which can be unsafe. We also haven't optimized the tool descriptions for performance.
# %%
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional
from langchain_experimental.utilities import PythonREPL
from typing_extensions import TypedDict
_TEMP_DIRECTORY = TemporaryDirectory()
WORKING_DIRECTORY = Path(_TEMP_DIRECTORY.name)
@tool
def create_outline(
points: Annotated[List[str], "List of main points or sections."],
file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
"""Create and save an outline."""
with (WORKING_DIRECTORY / file_name).open("w") as file:
for i, point in enumerate(points):
file.write(f"{i + 1}. {point}\n")
return f"Outline saved to {file_name}"
@tool
def read_document(
file_name: Annotated[str, "File path to read the document from."],
start: Annotated[Optional[int], "The start line. Default is 0"] = None,
end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
"""Read the specified document."""
with (WORKING_DIRECTORY / file_name).open("r") as file:
lines = file.readlines()
    if start is None:
        start = 0
return "\n".join(lines[start:end])
@tool
def write_document(
content: Annotated[str, "Text content to be written into the document."],
file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
"""Create and save a text document."""
with (WORKING_DIRECTORY / file_name).open("w") as file:
file.write(content)
return f"Document saved to {file_name}"
@tool
def edit_document(
file_name: Annotated[str, "Path of the document to be edited."],
inserts: Annotated[
Dict[int, str],
"Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
],
) -> Annotated[str, "Path of the edited document file."]:
"""Edit a document by inserting text at specific line numbers."""
with (WORKING_DIRECTORY / file_name).open("r") as file:
lines = file.readlines()
sorted_inserts = sorted(inserts.items())
for line_number, text in sorted_inserts:
if 1 <= line_number <= len(lines) + 1:
lines.insert(line_number - 1, text + "\n")
else:
return f"Error: Line number {line_number} is out of range."
with (WORKING_DIRECTORY / file_name).open("w") as file:
file.writelines(lines)
return f"Document edited and saved to {file_name}"
# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()
@tool
def python_repl_tool(
code: Annotated[str, "The python code to execute to generate your chart."],
):
"""Use this to execute python code. If you want to see the output of a value,
you should print it out with `print(...)`. This is visible to the user."""
try:
result = repl.run(code)
except BaseException as e:
return f"Failed to execute. Error: {repr(e)}"
return f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
# %% [markdown]
# ## Helper Utilities
#
# We are going to create a few utility functions to make it more concise when we want to:
#
# 1. Create a worker agent (we will use the prebuilt `create_react_agent` for this).
# 2. Create a supervisor for the sub-graph.
#
# These will simplify the graph composition code at the end so it's easier to see what's going on.
# %%
from typing import List, Optional, Literal
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage, trim_messages
def make_supervisor_node(llm: BaseChatModel, members: list[str]):
options = ["FINISH"] + members
system_prompt = (
"You are a supervisor tasked with managing a conversation between the"
f" following workers: {members}. Given the following user request,"
" respond with the worker to act next. Each worker will perform a"
" task and respond with their results and status. When finished,"
" respond with FINISH."
)
class Router(TypedDict):
"""Worker to route to next. If no workers needed, route to FINISH."""
next: Literal[*options]
def supervisor_node(state: MessagesState) -> Command[Literal[*members, "__end__"]]:
"""An LLM-based router."""
messages = [
{"role": "system", "content": system_prompt},
] + state["messages"]
response = llm.with_structured_output(Router).invoke(messages)
goto = response["next"]
if goto == "FINISH":
goto = END
return Command(goto=goto)
return supervisor_node
# %% [markdown]
# ## Define Agent Teams
#
# Now we can define our hierarchical teams. "Choose your player!"
#
# ### Research Team
#
# The research team will have a search agent and a web scraper agent as its two worker nodes. Let's create those, as well as the team supervisor.
# %%
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
llm = ChatOpenAI(model="gpt-4o")
search_agent = create_react_agent(llm, tools=[tavily_tool])
def search_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = search_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="search")
]
},
# We want our workers to ALWAYS "report back" to the supervisor when done
goto="supervisor",
)
web_scraper_agent = create_react_agent(llm, tools=[scrape_webpages])
def web_scraper_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = web_scraper_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="web_scraper")
]
},
# We want our workers to ALWAYS "report back" to the supervisor when done
goto="supervisor",
)
research_supervisor_node = make_supervisor_node(llm, ["search", "web_scraper"])
# %% [markdown]
# Now that we've created the necessary components, defining their interactions is easy. Add the nodes to the team graph, and define the edges, which determine the transition criteria.
# %%
research_builder = StateGraph(MessagesState)
research_builder.add_node("supervisor", research_supervisor_node)
research_builder.add_node("search", search_node)
research_builder.add_node("web_scraper", web_scraper_node)
research_builder.add_edge(START, "supervisor")
research_graph = research_builder.compile()
# %%
from IPython.display import Image, display
display(Image(research_graph.get_graph().draw_mermaid_png()))
# %% [markdown]
# We can give this team work directly. Try it out below.
# %%
for s in research_graph.stream(
{"messages": [("user", "when is Taylor Swift's next tour?")]},
{"recursion_limit": 100},
):
print(s)
print("---")
# %% [markdown]
# ### Document Writing Team
#
# Create the document writing team below using a similar approach. This time, we will give each agent access to different file-writing tools.
#
# Note that we are giving file-system access to our agent here, which is not safe in all cases.
# %%
llm = ChatOpenAI(model="gpt-4o")
doc_writer_agent = create_react_agent(
llm,
tools=[write_document, edit_document, read_document],
state_modifier=(
"You can read, write and edit documents based on note-taker's outlines. "
"Don't ask follow-up questions."
),
)
def doc_writing_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = doc_writer_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="doc_writer")
]
},
# We want our workers to ALWAYS "report back" to the supervisor when done
goto="supervisor",
)
note_taking_agent = create_react_agent(
llm,
tools=[create_outline, read_document],
state_modifier=(
"You can read documents and create outlines for the document writer. "
"Don't ask follow-up questions."
),
)
def note_taking_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = note_taking_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(content=result["messages"][-1].content, name="note_taker")
]
},
# We want our workers to ALWAYS "report back" to the supervisor when done
goto="supervisor",
)
chart_generating_agent = create_react_agent(
llm, tools=[read_document, python_repl_tool]
)
def chart_generating_node(state: MessagesState) -> Command[Literal["supervisor"]]:
result = chart_generating_agent.invoke(state)
return Command(
update={
"messages": [
HumanMessage(
content=result["messages"][-1].content, name="chart_generator"
)
]
},
# We want our workers to ALWAYS "report back" to the supervisor when done
goto="supervisor",
)
doc_writing_supervisor_node = make_supervisor_node(
llm, ["doc_writer", "note_taker", "chart_generator"]
)
# %% [markdown]
# With the objects themselves created, we can form the graph.
# %%
# Create the graph here
paper_writing_builder = StateGraph(MessagesState)
paper_writing_builder.add_node("supervisor", doc_writing_supervisor_node)
paper_writing_builder.add_node("doc_writer", doc_writing_node)
paper_writing_builder.add_node("note_taker", note_taking_node)
paper_writing_builder.add_node("chart_generator", chart_generating_node)
paper_writing_builder.add_edge(START, "supervisor")
paper_writing_graph = paper_writing_builder.compile()
# %%
from IPython.display import Image, display
display(Image(paper_writing_graph.get_graph().draw_mermaid_png()))
# %%
for s in paper_writing_graph.stream(
{
"messages": [
(
"user",
"Write an outline for poem about cats and then write the poem to disk.",
)
]
},
{"recursion_limit": 100},
):
print(s)
print("---")
# %% [markdown]
# ## Add Layers
#
# In this design, we are enforcing a top-down planning policy. We've created two graphs already, but we have to decide how to route work between the two.
#
# We'll create a _third_ graph to orchestrate the previous two, and add some connectors to define how this top-level state is shared between the different graphs.
# %%
from langchain_core.messages import BaseMessage
llm = ChatOpenAI(model="gpt-4o")
teams_supervisor_node = make_supervisor_node(llm, ["research_team", "writing_team"])
# %%
def call_research_team(state: MessagesState) -> Command[Literal["supervisor"]]:
response = research_graph.invoke({"messages": state["messages"][-1]})
return Command(
update={
"messages": [
HumanMessage(
content=response["messages"][-1].content, name="research_team"
)
]
},
goto="supervisor",
)
def call_paper_writing_team(state: MessagesState) -> Command[Literal["supervisor"]]:
response = paper_writing_graph.invoke({"messages": state["messages"][-1]})
return Command(
update={
"messages": [
HumanMessage(
content=response["messages"][-1].content, name="writing_team"
)
]
},
goto="supervisor",
)
# Define the graph.
super_builder = StateGraph(MessagesState)
super_builder.add_node("supervisor", teams_supervisor_node)
super_builder.add_node("research_team", call_research_team)
super_builder.add_node("writing_team", call_paper_writing_team)
super_builder.add_edge(START, "supervisor")
super_graph = super_builder.compile()
# %%
from IPython.display import Image, display
display(Image(super_graph.get_graph().draw_mermaid_png()))
# %%
for s in super_graph.stream(
{
"messages": [
("user", "Research AI agents and write a brief report about them.")
],
},
{"recursion_limit": 150},
):
print(s)
print("---")
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/plan-and-execute/plan-and-execute.py`:
```py
# %% [markdown]
# # Plan-and-Execute
#
# This notebook shows how to create a "plan-and-execute" style agent. This is heavily inspired by the [Plan-and-Solve](https://arxiv.org/abs/2305.04091) paper as well as the [Baby-AGI](https://github.com/yoheinakajima/babyagi) project.
#
# The core idea is to first come up with a multi-step plan, and then go through that plan one item at a time.
# After accomplishing a particular task, you can then revisit the plan and modify as appropriate.
#
#
# The general computational graph looks like the following:
#
# 
#
#
# This compares to a typical [ReAct](https://arxiv.org/abs/2210.03629) style agent where you think one step at a time.
# The advantages of this "plan-and-execute" style agent are:
#
# 1. Explicit long term planning (which even really strong LLMs can struggle with)
# 2. Ability to use smaller/weaker models for the execution step, only using larger/better models for the planning step
#
#
# The following walkthrough demonstrates how to do so in LangGraph. The resulting agent will leave a trace like the following example: ([link](https://smith.langchain.com/public/d46e24d3-dda6-44d5-9550-b618fca4e0d4/r)).
# %% [markdown]
# ## Setup
#
# First, we need to install the packages required.
# %%
# %%capture --no-stderr
# %pip install --quiet -U langgraph langchain-community langchain-openai tavily-python
# %% [markdown]
# Next, we need to set API keys for OpenAI (the LLM we will use) and Tavily (the search tool we will use)
# %%
import getpass
import os
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Define Tools
#
# We will first define the tools we want to use. For this simple example, we will use a built-in search tool via Tavily. However, it is really easy to create your own tools - see documentation [here](https://python.langchain.com/docs/how_to/custom_tools) on how to do that.
# %%
from langchain_community.tools.tavily_search import TavilySearchResults
tools = [TavilySearchResults(max_results=3)]
# %% [markdown]
# ## Define our Execution Agent
#
# Now we will create the execution agent we want to use to execute tasks.
# Note that for this example, we will be using the same execution agent for each task, but this doesn't HAVE to be the case.
# %%
from langchain import hub
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
# Get the prompt to use - you can modify this!
prompt = hub.pull("ih/ih-react-agent-executor")
prompt.pretty_print()
# Choose the LLM that will drive the agent
llm = ChatOpenAI(model="gpt-4-turbo-preview")
agent_executor = create_react_agent(llm, tools, state_modifier=prompt)
# %%
agent_executor.invoke({"messages": [("user", "who is the winnner of the us open")]})
# %% [markdown]
# ## Define the State
#
# Let's now start by defining the state to track for this agent.
#
# First, we will need to track the current plan. Let's represent that as a list of strings.
#
# Next, we should track previously executed steps. Let's represent that as a list of tuples (these tuples will contain the step and then the result)
#
# Finally, we need to have some state to represent the final response as well as the original input.
# %%
import operator
from typing import Annotated, List, Tuple
from typing_extensions import TypedDict
class PlanExecute(TypedDict):
input: str
plan: List[str]
past_steps: Annotated[List[Tuple], operator.add]
response: str
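    # `past_steps` is annotated with operator.add, so node updates are appended to the
    # existing list instead of replacing it.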
# %% [markdown]
# ## Planning Step
#
# Let's now think about creating the planning step. This will use function calling to create a plan.
# %% [markdown]
# <div class="admonition note">
# <p class="admonition-title">Using Pydantic with LangChain</p>
# <p>
# This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
# </p>
# </div>
# %%
from pydantic import BaseModel, Field
class Plan(BaseModel):
"""Plan to follow in future"""
steps: List[str] = Field(
description="different steps to follow, should be in sorted order"
)
# %%
from langchain_core.prompts import ChatPromptTemplate
planner_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""For the given objective, come up with a simple step by step plan. \
This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. \
The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.""",
),
("placeholder", "{messages}"),
]
)
planner = planner_prompt | ChatOpenAI(
model="gpt-4o", temperature=0
).with_structured_output(Plan)
# %%
planner.invoke(
{
"messages": [
("user", "what is the hometown of the current Australia open winner?")
]
}
)
# %% [markdown]
# ## Re-Plan Step
#
# Now, let's create a step that re-does the plan based on the result of the previous step.
# %%
from typing import Union
class Response(BaseModel):
"""Response to user."""
response: str
class Act(BaseModel):
"""Action to perform."""
action: Union[Response, Plan] = Field(
description="Action to perform. If you want to respond to user, use Response. "
"If you need to further use tools to get the answer, use Plan."
)
replanner_prompt = ChatPromptTemplate.from_template(
"""For the given objective, come up with a simple step by step plan. \
This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. \
The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.
Your objective was this:
{input}
Your original plan was this:
{plan}
You have currently done the following steps:
{past_steps}
Update your plan accordingly. If no more steps are needed and you can return to the user, then respond with that. Otherwise, fill out the plan. Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan."""
)
replanner = replanner_prompt | ChatOpenAI(
model="gpt-4o", temperature=0
).with_structured_output(Act)
# %% [markdown]
# ## Create the Graph
#
# We can now create the graph!
# %%
from typing import Literal
from langgraph.graph import END
async def execute_step(state: PlanExecute):
plan = state["plan"]
plan_str = "\n".join(f"{i+1}. {step}" for i, step in enumerate(plan))
task = plan[0]
task_formatted = f"""For the following plan:
{plan_str}\n\nYou are tasked with executing step {1}, {task}."""
agent_response = await agent_executor.ainvoke(
{"messages": [("user", task_formatted)]}
)
return {
"past_steps": [(task, agent_response["messages"][-1].content)],
}
async def plan_step(state: PlanExecute):
plan = await planner.ainvoke({"messages": [("user", state["input"])]})
return {"plan": plan.steps}
async def replan_step(state: PlanExecute):
output = await replanner.ainvoke(state)
if isinstance(output.action, Response):
return {"response": output.action.response}
else:
return {"plan": output.action.steps}
def should_end(state: PlanExecute):
if "response" in state and state["response"]:
return END
else:
return "agent"
# %%
from langgraph.graph import StateGraph, START
workflow = StateGraph(PlanExecute)
# Add the plan node
workflow.add_node("planner", plan_step)
# Add the execution step
workflow.add_node("agent", execute_step)
# Add a replan node
workflow.add_node("replan", replan_step)
workflow.add_edge(START, "planner")
# From plan we go to agent
workflow.add_edge("planner", "agent")
# From agent, we replan
workflow.add_edge("agent", "replan")
workflow.add_conditional_edges(
"replan",
# Next, we pass in the function that will determine which node is called next.
should_end,
["agent", END],
)
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
# %%
from IPython.display import Image, display
display(Image(app.get_graph(xray=True).draw_mermaid_png()))
# %%
config = {"recursion_limit": 50}
inputs = {"input": "what is the hometown of the mens 2024 Australia open winner?"}
async for event in app.astream(inputs, config=config):
for k, v in event.items():
if k != "__end__":
print(v)
# %% [markdown]
# ## Conclusion
#
# Congrats on making a plan-and-execute agent! One known limitation of the above design is that each task is still executed in sequence, meaning embarrassingly parallel operations all add to the total execution time. You could improve on this by having each task represented as a DAG (similar to LLMCompiler), rather than a regular list.
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/reflexion/reflexion.py`:
```py
# %% [markdown]
# # Reflexion
#
# [Reflexion](https://arxiv.org/abs/2303.11366) by Shinn, et al., is an architecture designed to learn through verbal feedback and self-reflection. The agent explicitly critiques its responses to tasks in order to generate a higher-quality final response, at the expense of longer execution time.
#
# 
#
# The paper outlines 3 main components:
#
# 1. Actor (agent) with self-reflection
# 2. External evaluator (task-specific, e.g. code compilation steps)
# 3. Episodic memory that stores the reflections from (1).
#
# In their code, the last two components are very task-specific, so in this notebook, you will build the _actor_ in LangGraph.
#
# To skip to the graph definition, see the [Construct Graph section](#Construct-Graph) below.
# %% [markdown]
# ## Setup
#
# Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).
#
# We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing.
# %%
# %pip install -U --quiet langgraph langchain_anthropic tavily-python
# %%
import getpass
import os
def _set_if_undefined(var: str) -> None:
if os.environ.get(var):
return
os.environ[var] = getpass.getpass(var)
_set_if_undefined("ANTHROPIC_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
#
# ### Define our LLM
# %%
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
# You could also use OpenAI or another provider
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(model="gpt-4-turbo-preview")
# %% [markdown]
# ## Actor (with reflection)
#
# The main component of Reflexion is the "actor", which is an agent that reflects on its responses and re-executes to improve based on self-critique. Its main sub-components include:
# 1. Tools/tool execution
# 2. Initial responder: generate an initial response (and self-reflection)
# 3. Revisor: re-respond (and reflect) based on previous reflections
#
# We'll first define the tool execution context.
#
# #### Construct tools
# %%
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
# %% [markdown]
# #### Initial responder
# %% [markdown]
# <div class="admonition note">
# <p class="admonition-title">Using Pydantic with LangChain</p>
# <p>
# This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
# </p>
# </div>
# %%
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import ValidationError
from pydantic import BaseModel, Field
class Reflection(BaseModel):
missing: str = Field(description="Critique of what is missing.")
superfluous: str = Field(description="Critique of what is superfluous")
class AnswerQuestion(BaseModel):
"""Answer the question. Provide an answer, reflection, and then follow up with search queries to improve the answer."""
answer: str = Field(description="~250 word detailed answer to the question.")
reflection: Reflection = Field(description="Your reflection on the initial answer.")
search_queries: list[str] = Field(
description="1-3 search queries for researching improvements to address the critique of your current answer."
)
class ResponderWithRetries:
def __init__(self, runnable, validator):
self.runnable = runnable
self.validator = validator
def respond(self, state: dict):
response = []
for attempt in range(3):
response = self.runnable.invoke(
{"messages": state["messages"]}, {"tags": [f"attempt:{attempt}"]}
)
try:
self.validator.invoke(response)
return {"messages": response}
except ValidationError as e:
                # state is a dict, so append the failed response and a corrective
                # ToolMessage to its message list before the next attempt
                state = {
                    "messages": state["messages"]
                    + [
                        response,
                        ToolMessage(
                            content=f"{repr(e)}\n\nPay close attention to the function schema.\n\n"
                            + self.validator.schema_json()
                            + " Respond by fixing all validation errors.",
                            tool_call_id=response.tool_calls[0]["id"],
                        ),
                    ]
                }
return {"messages": response}
# %%
import datetime
actor_prompt_template = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are expert researcher.
Current time: {time}
1. {first_instruction}
2. Reflect and critique your answer. Be severe to maximize improvement.
3. Recommend search queries to research information and improve your answer.""",
),
MessagesPlaceholder(variable_name="messages"),
(
"user",
"\n\n<system>Reflect on the user's original question and the"
" actions taken thus far. Respond using the {function_name} function.</reminder>",
),
]
).partial(
time=lambda: datetime.datetime.now().isoformat(),
)
initial_answer_chain = actor_prompt_template.partial(
first_instruction="Provide a detailed ~250 word answer.",
function_name=AnswerQuestion.__name__,
) | llm.bind_tools(tools=[AnswerQuestion])
validator = PydanticToolsParser(tools=[AnswerQuestion])
first_responder = ResponderWithRetries(
runnable=initial_answer_chain, validator=validator
)
# %%
example_question = "Why is reflection useful in AI?"
initial = first_responder.respond(
{"messages": [HumanMessage(content=example_question)]}
)
# %% [markdown]
# #### Revision
#
# The second part of the actor is a revision step.
# %%
revise_instructions = """Revise your previous answer using the new information.
- You should use the previous critique to add important information to your answer.
- You MUST include numerical citations in your revised answer to ensure it can be verified.
- Add a "References" section to the bottom of your answer (which does not count towards the word limit). In form of:
- [1] https://example.com
- [2] https://example.com
- You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.
"""
# Extend the initial answer schema to include references.
# Forcing citation in the model encourages grounded responses
class ReviseAnswer(AnswerQuestion):
"""Revise your original answer to your question. Provide an answer, reflection,
cite your reflection with references, and finally
add search queries to improve the answer."""
references: list[str] = Field(
description="Citations motivating your updated answer."
)
revision_chain = actor_prompt_template.partial(
first_instruction=revise_instructions,
function_name=ReviseAnswer.__name__,
) | llm.bind_tools(tools=[ReviseAnswer])
revision_validator = PydanticToolsParser(tools=[ReviseAnswer])
revisor = ResponderWithRetries(runnable=revision_chain, validator=revision_validator)
# %%
import json
revised = revisor.respond(
{
"messages": [
HumanMessage(content=example_question),
initial["messages"],
ToolMessage(
tool_call_id=initial["messages"].tool_calls[0]["id"],
content=json.dumps(
tavily_tool.invoke(
{
"query": initial["messages"].tool_calls[0]["args"][
"search_queries"
][0]
}
)
),
),
]
}
)
revised["messages"]
# %% [markdown]
# ## Create Tool Node
#
# Next, create a node to execute the tool calls. While we give the LLMs different schema names (and use those for validation), we want them both to route to the same tool.
# %%
from langchain_core.tools import StructuredTool
from langgraph.prebuilt import ToolNode
def run_queries(search_queries: list[str], **kwargs):
"""Run the generated queries."""
return tavily_tool.batch([{"query": query} for query in search_queries])
tool_node = ToolNode(
[
StructuredTool.from_function(run_queries, name=AnswerQuestion.__name__),
StructuredTool.from_function(run_queries, name=ReviseAnswer.__name__),
]
)
# %% [markdown]
# ## Construct Graph
#
#
# Now we can wire all our components together.
# %%
from typing import Literal
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict
class State(TypedDict):
messages: Annotated[list, add_messages]
MAX_ITERATIONS = 5
builder = StateGraph(State)
builder.add_node("draft", first_responder.respond)
builder.add_node("execute_tools", tool_node)
builder.add_node("revise", revisor.respond)
# draft -> execute_tools
builder.add_edge("draft", "execute_tools")
# execute_tools -> revise
builder.add_edge("execute_tools", "revise")
# Define looping logic:
def _get_num_iterations(state: list):
i = 0
for m in state[::-1]:
if m.type not in {"tool", "ai"}:
break
i += 1
return i
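# _get_num_iterations counts the trailing run of tool/AI messages, which approximates
# how many draft/revise iterations have run so far.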
def event_loop(state: list):
# in our case, we'll just stop after N plans
num_iterations = _get_num_iterations(state["messages"])
if num_iterations > MAX_ITERATIONS:
return END
return "execute_tools"
# revise -> execute_tools OR end
builder.add_conditional_edges("revise", event_loop, ["execute_tools", END])
builder.add_edge(START, "draft")
graph = builder.compile()
# %%
from IPython.display import Image, display
try:
display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
# This requires some extra dependencies and is optional
pass
# %%
events = graph.stream(
{"messages": [("user", "How should we handle the climate crisis?")]},
stream_mode="values",
)
for i, step in enumerate(events):
print(f"Step {i}")
step["messages"][-1].pretty_print()
# %% [markdown]
# ## Conclusion
#
# Congrats on building a Reflexion actor! I'll leave you with a few observations to save you some time when choosing which parts of this agent to adapt to your workflow:
# 1. This agent trades off execution time for quality. It explicitly forces the agent to critique and revise the output over several steps, which usually (not always) increases the response quality but takes much longer to return a final answer
# 2. The 'reflections' can be paired with additional external feedback (such as validators), to further guide the actor.
# 3. In the paper, 1 environment (AlfWorld) uses external memory. It does this by storing summaries of the reflections to an external store and using them in subsequent trials/invocations.
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/code_assistant/langgraph_code_assistant.py`:
```py
# %% [markdown]
# # Code generation with RAG and self-correction
#
# AlphaCodium presented an approach for code generation that uses control flow.
#
# Main idea: [construct an answer to a coding question iteratively.](https://x.com/karpathy/status/1748043513156272416?s=20).
#
# [AlphaCodium](https://github.com/Codium-ai/AlphaCodium) iteratively tests and improves an answer on public and AI-generated tests for a particular question.
#
# We will implement some of these ideas from scratch using [LangGraph](https://langchain-ai.github.io/langgraph/):
#
# 1. We start with a set of documentation specified by a user
# 2. We use a long context LLM to ingest it and perform RAG to answer a question based upon it
# 3. We will invoke a tool to produce a structured output
# 4. We will perform two unit tests (check imports and code execution) prior to returning the solution to the user
#
# 
# %% [markdown]
# ## Setup
#
# First, let's install our required packages and set the API keys we will need
# %%
# ! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4
# %%
import getpass
import os
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
_set_env("ANTHROPIC_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Docs
#
# Load [LangChain Expression Language](https://python.langchain.com/docs/concepts/#langchain-expression-language-lcel) (LCEL) docs as an example.
# %%
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
# LCEL docs
url = "https://python.langchain.com/docs/concepts/lcel/"
loader = RecursiveUrlLoader(
url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
[doc.page_content for doc in d_reversed]
)
# %% [markdown]
# ## LLMs
#
# ### Code solution
#
# First, we will try OpenAI and [Claude3](https://docs.anthropic.com/en/docs/about-claude/models) with function calling.
#
# We will create a `code_gen_chain` w/ either OpenAI or Claude and test them here.
# %% [markdown]
# <div class="admonition note">
# <p class="admonition-title">Using Pydantic with LangChain</p>
# <p>
# This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
# </p>
# </div>
# %%
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
### OpenAI
# Code generation prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a coding assistant with expertise in LCEL, LangChain expression language. \n
Here is a full set of LCEL documentation: \n ------- \n {context} \n ------- \n Answer the user
question based on the above provided documentation. Ensure any code you provide can be executed \n
with all required imports and variables defined. Structure your answer with a description of the code solution. \n
Then list the imports. And finally list the functioning code block. Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
# Data model
class code(BaseModel):
"""Schema for code solutions to questions about LCEL."""
prefix: str = Field(description="Description of the problem and approach")
imports: str = Field(description="Code block import statements")
code: str = Field(description="Code block not including import statements")
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain_oai = code_gen_prompt | llm.with_structured_output(code)
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain_oai.invoke(
{"context": concatenated_content, "messages": [("user", question)]}
)
solution
# %%
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
### Anthropic
# Prompt to enforce tool use
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
[
(
"system",
"""<instructions> You are a coding assistant with expertise in LCEL, LangChain expression language. \n
Here is the LCEL documentation: \n ------- \n {context} \n ------- \n Answer the user question based on the \n
above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
# LLM
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
model=expt_llm,
default_headers={"anthropic-beta": "tools-2024-04-04"},
)
structured_llm_claude = llm.with_structured_output(code, include_raw=True)
# Optional: Check for errors in case tool use is flaky
def check_claude_output(tool_output):
"""Check for parse error or failure to call the tool"""
# Error with parsing
if tool_output["parsing_error"]:
# Report back output and parsing errors
print("Parsing error!")
raw_output = str(tool_output["raw"].content)
error = tool_output["parsing_error"]
raise ValueError(
f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n Parse error: {error}"
)
# Tool was not invoked
elif not tool_output["parsed"]:
print("Failed to invoke tool!")
raise ValueError(
"You did not use the provided tool! Be sure to invoke the tool to structure the output."
)
return tool_output
# Chain with output check
code_chain_claude_raw = (
code_gen_prompt_claude | structured_llm_claude | check_claude_output
)
def insert_errors(inputs):
"""Insert errors for tool parsing in the messages"""
# Get errors
error = inputs["error"]
messages = inputs["messages"]
messages += [
(
"assistant",
f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
)
]
return {
"messages": messages,
"context": inputs["context"],
}
# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_claude_raw
N = 3 # Max re-tries
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(
fallbacks=[fallback_chain] * N, exception_key="error"
)
def parse_output(solution):
"""When we add 'include_raw=True' to structured output,
it will return a dict w 'raw', 'parsed', 'parsing_error'."""
return solution["parsed"]
# Optional: With re-try to correct for failure to invoke tool
code_gen_chain = code_gen_chain_re_try | parse_output
# No re-try
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output
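# %% [markdown]
# As a quick illustration (not part of the original notebook), the raw chain built with `include_raw=True` returns a dict with `raw`, `parsed`, and `parsing_error` keys; `parse_output` simply extracts `parsed`. The question below is arbitrary:
# %%
# Illustrative only: inspect the dict shape produced by `include_raw=True`
raw_solution = code_chain_claude_raw.invoke(
    {
        "context": concatenated_content,
        "messages": [("user", "How do I add two runnables together in LCEL?")],
    }
)
sorted(raw_solution.keys())  # -> ['parsed', 'parsing_error', 'raw']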
# %%
# Test
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain.invoke(
{"context": concatenated_content, "messages": [("user", question)]}
)
solution
# %% [markdown]
# ## State
#
# Our state is a dict containing the keys relevant to code generation: the error flag, the message history, the latest code generation, and the iteration count.
# %%
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
"""
Represents the state of our graph.
Attributes:
error : Binary flag for control flow to indicate whether test error was tripped
messages : With user question, error messages, reasoning
generation : Code solution
iterations : Number of tries
"""
error: str
messages: List
generation: str
iterations: int
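# %% [markdown]
# For illustration (not part of the original notebook), the graph below is invoked with an initial state shaped like this:
# %%
# Minimal initial state: a user question, zero iterations, and no error yet
initial_state = {
    "messages": [("user", "How do I build a RAG chain in LCEL?")],
    "iterations": 0,
    "error": "",
}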
# %% [markdown]
# ## Graph
#
# Our graph lays out the logical flow shown in the figure above.
# %%
### Parameter
# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = "do not reflect"
### Nodes
def generate(state: GraphState):
"""
Generate a code solution
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
# State
messages = state["messages"]
iterations = state["iterations"]
error = state["error"]
# We have been routed back to generation with an error
if error == "yes":
messages += [
(
"user",
"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
)
]
# Solution
code_solution = code_gen_chain.invoke(
{"context": concatenated_content, "messages": messages}
)
messages += [
(
"assistant",
f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
)
]
# Increment
iterations = iterations + 1
return {"generation": code_solution, "messages": messages, "iterations": iterations}
def code_check(state: GraphState):
"""
Check code
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, error
"""
print("---CHECKING CODE---")
# State
messages = state["messages"]
code_solution = state["generation"]
iterations = state["iterations"]
# Get solution components
imports = code_solution.imports
code = code_solution.code
# Check imports
try:
exec(imports)
except Exception as e:
print("---CODE IMPORT CHECK: FAILED---")
error_message = [("user", f"Your solution failed the import test: {e}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
# Check execution
try:
exec(imports + "\n" + code)
except Exception as e:
print("---CODE BLOCK CHECK: FAILED---")
error_message = [("user", f"Your solution failed the code execution test: {e}")]
messages += error_message
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "yes",
}
# No errors
print("---NO CODE TEST FAILURES---")
return {
"generation": code_solution,
"messages": messages,
"iterations": iterations,
"error": "no",
}
def reflect(state: GraphState):
"""
Reflect on errors
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
# State
messages = state["messages"]
iterations = state["iterations"]
code_solution = state["generation"]
# Prompt reflection
# Add reflection
reflections = code_gen_chain.invoke(
{"context": concatenated_content, "messages": messages}
)
messages += [("assistant", f"Here are reflections on the error: {reflections}")]
return {"generation": code_solution, "messages": messages, "iterations": iterations}
### Edges
def decide_to_finish(state: GraphState):
"""
Determines whether to finish.
Args:
state (dict): The current graph state
Returns:
str: Next node to call
"""
error = state["error"]
iterations = state["iterations"]
if error == "no" or iterations == max_iterations:
print("---DECISION: FINISH---")
return "end"
else:
print("---DECISION: RE-TRY SOLUTION---")
if flag == "reflect":
return "reflect"
else:
return "generate"
# %%
from langgraph.graph import END, StateGraph, START
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("generate", generate) # generation solution
workflow.add_node("check_code", code_check) # check code
workflow.add_node("reflect", reflect) # reflect
# Build graph
workflow.add_edge(START, "generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
"check_code",
decide_to_finish,
{
"end": END,
"reflect": "reflect",
"generate": "generate",
},
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()
# %%
question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?"
solution = app.invoke({"messages": [("user", question)], "iterations": 0, "error": ""})
# %%
solution["generation"]
# %% [markdown]
# ## Eval
# %% [markdown]
# [Here](https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d) is a public dataset of LCEL questions.
#
# I saved this as `lcel-teacher-eval`.
#
# You can also find the csv [here](https://github.com/langchain-ai/lcel-teacher/blob/main/eval/eval.csv).
# %%
import langsmith
client = langsmith.Client()
# %%
# Clone the dataset to your tenant to use it
try:
public_dataset = (
"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
)
client.clone_public_dataset(public_dataset)
except Exception:
    print("Please set up LangSmith")
# %% [markdown]
# Custom evals.
# %%
from langsmith.schemas import Example, Run
def check_import(run: Run, example: Example) -> dict:
imports = run.outputs.get("imports")
try:
exec(imports)
return {"key": "import_check", "score": 1}
except Exception:
return {"key": "import_check", "score": 0}
def check_execution(run: Run, example: Example) -> dict:
imports = run.outputs.get("imports")
code = run.outputs.get("code")
try:
exec(imports + "\n" + code)
return {"key": "code_execution_check", "score": 1}
except Exception:
return {"key": "code_execution_check", "score": 0}
# %% [markdown]
# Compare LangGraph to Context Stuffing.
# %%
def predict_base_case(example: dict):
"""Context stuffing"""
solution = code_gen_chain.invoke(
{"context": concatenated_content, "messages": [("user", example["question"])]}
)
return {"imports": solution.imports, "code": solution.code}
def predict_langgraph(example: dict):
"""LangGraph"""
graph = app.invoke(
{"messages": [("user", example["question"])], "iterations": 0, "error": ""}
)
solution = graph["generation"]
return {"imports": solution.imports, "code": solution.code}
# %%
from langsmith.evaluation import evaluate
# Evaluator
code_evaluator = [check_import, check_execution]
# Dataset
dataset_name = "lcel-teacher-eval"
# %%
# Run base case
try:
experiment_results_ = evaluate(
predict_base_case,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-without-langgraph-{expt_llm}",
max_concurrency=2,
metadata={
"llm": expt_llm,
},
)
except Exception:
    print("Please set up LangSmith")
# %%
# Run with langgraph
try:
experiment_results = evaluate(
predict_langgraph,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-with-langgraph-{expt_llm}-{flag}",
max_concurrency=2,
metadata={
"llm": expt_llm,
"feedback": flag,
},
)
except Exception:
    print("Please set up LangSmith")
# %% [markdown]
# `Results:`
#
# * `LangGraph outperforms base case`: adding the re-try loop improves performance
# * `Reflection did not help`: reflecting before the re-try regressed performance vs. just passing the errors directly back to the LLM
# * `GPT-4 outperforms Claude3`: Claude3 had 3 runs (Opus) and 1 run (Haiku) fail due to tool-use errors
#
# https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/llm-compiler/output_parser.py`:
```py
import ast
import re
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from typing_extensions import TypedDict
THOUGHT_PATTERN = r"Thought: ([^\n]*)"
ACTION_PATTERN = r"\n*(\d+)\. (\w+)\((.*)\)(\s*#\w+\n)?"
# $1 or ${1} -> 1
ID_PATTERN = r"\$\{?(\d+)\}?"
END_OF_PLAN = "<END_OF_PLAN>"
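# Illustrative examples (added for clarity, not in the original source):
#   "Thought: I need the population first"  -> THOUGHT_PATTERN captures the thought text
#   '1. search("Barack Obama")'             -> ACTION_PATTERN captures ("1", "search", '"Barack Obama"', ...)
#   "$1" or "${1}" in later arguments       -> ID_PATTERN captures "1", i.e. a dependency on task 1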
### Helper functions
def _ast_parse(arg: str) -> Any:
try:
return ast.literal_eval(arg)
except: # noqa
return arg
def _parse_llm_compiler_action_args(args: str, tool: Union[str, BaseTool]) -> list[Any]:
"""Parse arguments from a string."""
if args == "":
return ()
if isinstance(tool, str):
return ()
extracted_args = {}
tool_key = None
prev_idx = None
for key in tool.args.keys():
# Split if present
if f"{key}=" in args:
idx = args.index(f"{key}=")
if prev_idx is not None:
extracted_args[tool_key] = _ast_parse(
args[prev_idx:idx].strip().rstrip(",")
)
args = args.split(f"{key}=", 1)[1]
tool_key = key
prev_idx = 0
if prev_idx is not None:
extracted_args[tool_key] = _ast_parse(
args[prev_idx:].strip().rstrip(",").rstrip(")")
)
return extracted_args
def default_dependency_rule(idx, args: str):
matches = re.findall(ID_PATTERN, args)
numbers = [int(match) for match in matches]
return idx in numbers
def _get_dependencies_from_graph(
idx: int, tool_name: str, args: Dict[str, Any]
) -> list[int]:
"""Get dependencies from a graph."""
if tool_name == "join":
return list(range(1, idx))
return [i for i in range(1, idx) if default_dependency_rule(i, str(args))]
class Task(TypedDict):
idx: int
tool: BaseTool
args: list
dependencies: Dict[str, list]
thought: Optional[str]
def instantiate_task(
tools: Sequence[BaseTool],
idx: int,
tool_name: str,
args: Union[str, Any],
thought: Optional[str] = None,
) -> Task:
if tool_name == "join":
tool = "join"
else:
try:
tool = tools[[tool.name for tool in tools].index(tool_name)]
except ValueError as e:
raise OutputParserException(f"Tool {tool_name} not found.") from e
tool_args = _parse_llm_compiler_action_args(args, tool)
dependencies = _get_dependencies_from_graph(idx, tool_name, tool_args)
return Task(
idx=idx,
tool=tool,
args=tool_args,
dependencies=dependencies,
thought=thought,
)
class LLMCompilerPlanParser(BaseTransformOutputParser[dict], extra="allow"):
"""Planning output parser."""
tools: List[BaseTool]
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Task]:
texts = []
# TODO: Cleanup tuple state tracking here.
thought = None
for chunk in input:
# Assume input is str. TODO: support vision/other formats
text = chunk if isinstance(chunk, str) else str(chunk.content)
for task, thought in self.ingest_token(text, texts, thought):
yield task
# Final possible task
if texts:
task, _ = self._parse_task("".join(texts), thought)
if task:
yield task
def parse(self, text: str) -> List[Task]:
return list(self._transform([text]))
def stream(
self,
input: str | BaseMessage,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Task]:
yield from self.transform([input], config, **kwargs)
def ingest_token(
self, token: str, buffer: List[str], thought: Optional[str]
) -> Iterator[Tuple[Optional[Task], str]]:
buffer.append(token)
if "\n" in token:
buffer_ = "".join(buffer).split("\n")
suffix = buffer_[-1]
for line in buffer_[:-1]:
task, thought = self._parse_task(line, thought)
if task:
yield task, thought
buffer.clear()
buffer.append(suffix)
def _parse_task(self, line: str, thought: Optional[str] = None):
task = None
if match := re.match(THOUGHT_PATTERN, line):
# Optionally, action can be preceded by a thought
thought = match.group(1)
elif match := re.match(ACTION_PATTERN, line):
# if action is parsed, return the task, and clear the buffer
idx, tool_name, args, _ = match.groups()
idx = int(idx)
task = instantiate_task(
tools=self.tools,
idx=idx,
tool_name=tool_name,
args=args,
thought=thought,
)
thought = None
# Else it is just dropped
return task, thought
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/llm-compiler/math_tools.py`:
```py
import math
import re
from typing import List, Optional
import numexpr
from langchain.chains.openai_functions import create_structured_output_runnable
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
_MATH_DESCRIPTION = (
"math(problem: str, context: Optional[list[str]]) -> float:\n"
" - Solves the provided math problem.\n"
' - `problem` can be either a simple math problem (e.g. "1 + 3") or a word problem (e.g. "how many apples are there if there are 3 apples and 2 apples").\n'
" - You cannot calculate multiple expressions in one call. For instance, `math('1 + 3, 2 + 4')` does not work. "
"If you need to calculate multiple expressions, you need to call them separately like `math('1 + 3')` and then `math('2 + 4')`\n"
" - Minimize the number of `math` actions as much as possible. For instance, instead of calling "
'2. math("what is the 10% of $1") and then call 3. math("$1 + $2"), '
'you MUST call 2. math("what is the 110% of $1") instead, which will reduce the number of math actions.\n'
# Context specific rules below
" - You can optionally provide a list of strings as `context` to help the agent solve the problem. "
"If there are multiple contexts you need to answer the question, you can provide them as a list of strings.\n"
" - `math` action will not see the output of the previous actions unless you provide it as `context`. "
"You MUST provide the output of the previous actions as `context` if you need to do math on it.\n"
" - You MUST NEVER provide `search` type action's outputs as a variable in the `problem` argument. "
"This is because `search` returns a text blob that contains the information about the entity, not a number or value. "
"Therefore, when you need to provide an output of `search` action, you MUST provide it as a `context` argument to `math` action. "
'For example, 1. search("Barack Obama") and then 2. math("age of $1") is NEVER allowed. '
'Use 2. math("age of Barack Obama", context=["$1"]) instead.\n'
" - When you ask a question about `context`, specify the units. "
'For instance, "what is xx in height?" or "what is xx in millions?" instead of "what is xx?"\n'
)
_SYSTEM_PROMPT = """Translate a math problem into a expression that can be executed using Python's numexpr library. Use the output of running this code to answer the question.
Question: ${{Question with math problem.}}
```text
${{single line mathematical expression that solves the problem}}
```
...numexpr.evaluate(text)...
```output
${{Output of running the code}}
```
Answer: ${{Answer}}
Begin.
Question: What is 37593 * 67?
ExecuteCode({{code: "37593 * 67"}})
...numexpr.evaluate("37593 * 67")...
```output
2518731
```
Answer: 2518731
Question: 37593^(1/5)
ExecuteCode({{code: "37593**(1/5)"}})
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718
```
Answer: 8.222831614237718
"""
_ADDITIONAL_CONTEXT_PROMPT = """The following additional context is provided from other functions.\
Use it to substitute into any ${{#}} variables or other words in the problem.\
\n\n${context}\n\nNote that context variables are not defined in code yet.\
You must extract the relevant numbers and directly put them in code."""
class ExecuteCode(BaseModel):
"""The input to the numexpr.evaluate() function."""
reasoning: str = Field(
...,
description="The reasoning behind the code expression, including how context is included, if applicable.",
)
code: str = Field(
...,
description="The simple code expression to execute by numexpr.evaluate().",
)
def _evaluate_expression(expression: str) -> str:
try:
local_dict = {"pi": math.pi, "e": math.e}
output = str(
numexpr.evaluate(
expression.strip(),
global_dict={}, # restrict access to globals
local_dict=local_dict, # add common mathematical functions
)
)
except Exception as e:
raise ValueError(
f'Failed to evaluate "{expression}". Raised error: {repr(e)}.'
" Please try again with a valid numerical expression"
)
# Remove any leading and trailing brackets from the output
return re.sub(r"^\[|\]$", "", output)
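# Illustrative example (not in the original source), mirroring _SYSTEM_PROMPT above:
#   _evaluate_expression("37593 * 67") -> "2518731"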
def get_math_tool(llm: ChatOpenAI):
prompt = ChatPromptTemplate.from_messages(
[
("system", _SYSTEM_PROMPT),
("user", "{problem}"),
MessagesPlaceholder(variable_name="context", optional=True),
]
)
extractor = prompt | llm.with_structured_output(ExecuteCode)
def calculate_expression(
problem: str,
context: Optional[List[str]] = None,
config: Optional[RunnableConfig] = None,
):
chain_input = {"problem": problem}
if context:
context_str = "\n".join(context)
if context_str.strip():
context_str = _ADDITIONAL_CONTEXT_PROMPT.format(
context=context_str.strip()
)
chain_input["context"] = [SystemMessage(content=context_str)]
code_model = extractor.invoke(chain_input, config)
try:
return _evaluate_expression(code_model.code)
except Exception as e:
return repr(e)
return StructuredTool.from_function(
name="math",
func=calculate_expression,
description=_MATH_DESCRIPTION,
)
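# Illustrative usage (not in the original source); the model name is an assumption:
#   calculate = get_math_tool(ChatOpenAI(model="gpt-4o-mini"))
#   calculate.invoke({"problem": "What is 37593 * 67?"})  # -> "2518731"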
```
`/Users/malcolm/dev/langchain-ai/langgraph/docs/docs/tutorials/reflection/reflection.py`:
```py
# %% [markdown]
# # Reflection
#
#
# In the context of LLM agent building, reflection refers to the process of prompting an LLM to observe its past steps (along with potential observations from tools/the environment) to assess the quality of the chosen actions.
# This is then used downstream for things like re-planning, search, or evaluation.
#
# 
#
# This notebook demonstrates a very simple form of reflection in LangGraph.
# %% [markdown]
# ## Setup
#
# First, let's install our required packages and set our API keys
# %%
# %pip install -U --quiet langgraph langchain-fireworks
# %pip install -U --quiet tavily-python
# %%
import getpass
import os
def _set_if_undefined(var: str) -> None:
if os.environ.get(var):
return
os.environ[var] = getpass.getpass(var)
_set_if_undefined("TAVILY_API_KEY")
_set_if_undefined("FIREWORKS_API_KEY")
# %% [markdown]
# <div class="admonition tip">
# <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
# <p style="padding-top: 5px;">
# Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>.
# </p>
# </div>
# %% [markdown]
# ## Generate
#
# For our example, we will create a "5 paragraph essay" generator. First, create the generator:
#
# %%
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_fireworks import ChatFireworks
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an essay assistant tasked with writing excellent 5-paragraph essays."
" Generate the best essay possible for the user's request."
" If the user provides critique, respond with a revised version of your previous attempts.",
),
MessagesPlaceholder(variable_name="messages"),
]
)
llm = ChatFireworks(
model="accounts/fireworks/models/mixtral-8x7b-instruct", max_tokens=32768
)
generate = prompt | llm
# %%
essay = ""
request = HumanMessage(
content="Write an essay on why the little prince is relevant in modern childhood"
)
for chunk in generate.stream({"messages": [request]}):
print(chunk.content, end="")
essay += chunk.content
# %% [markdown]
# ### Reflect
# %%
reflection_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a teacher grading an essay submission. Generate critique and recommendations for the user's submission."
" Provide detailed recommendations, including requests for length, depth, style, etc.",
),
MessagesPlaceholder(variable_name="messages"),
]
)
reflect = reflection_prompt | llm
# %%
reflection = ""
for chunk in reflect.stream({"messages": [request, HumanMessage(content=essay)]}):
print(chunk.content, end="")
reflection += chunk.content
# %% [markdown]
# ### Repeat
#
# And... that's all there is to it! You can repeat in a loop for a fixed number of steps, or use an LLM (or other check) to decide when the finished product is good enough (a rough sketch of such a loop follows the next cell).
# %%
for chunk in generate.stream(
{"messages": [request, AIMessage(content=essay), HumanMessage(content=reflection)]}
):
print(chunk.content, end="")
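# %% [markdown]
# As a rough sketch (not part of the original notebook), a fixed-iteration version of that loop could look like the cell below; the graph in the next section implements the same idea with a conditional edge.
# %%
# Illustrative only: alternate critique and revision for a fixed number of rounds
draft = essay
for _ in range(2):
    critique = reflect.invoke({"messages": [request, HumanMessage(content=draft)]})
    draft = generate.invoke(
        {
            "messages": [
                request,
                AIMessage(content=draft),
                HumanMessage(content=critique.content),
            ]
        }
    ).content
print(draft[:500])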
# %% [markdown]
# ## Define graph
#
# Now that we've shown each step in isolation, we can wire it up in a graph.
# %%
from typing import Annotated, List, Sequence
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from typing_extensions import TypedDict
class State(TypedDict):
messages: Annotated[list, add_messages]
async def generation_node(state: State) -> State:
return {"messages": [await generate.ainvoke(state["messages"])]}
async def reflection_node(state: State) -> State:
# Other messages we need to adjust
cls_map = {"ai": HumanMessage, "human": AIMessage}
# The first message is the original user request. We keep it the same for all nodes
translated = [state["messages"][0]] + [
cls_map[msg.type](content=msg.content) for msg in state["messages"][1:]
]
res = await reflect.ainvoke(translated)
# We treat the output of this as human feedback for the generator
return {"messages": [HumanMessage(content=res.content)]}
builder = StateGraph(State)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_edge(START, "generate")
def should_continue(state: State):
if len(state["messages"]) > 6:
# End after 3 iterations
return END
return "reflect"
builder.add_conditional_edges("generate", should_continue)
builder.add_edge("reflect", "generate")
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
# %%
config = {"configurable": {"thread_id": "1"}}
# %%
async for event in graph.astream(
{
"messages": [
HumanMessage(
content="Generate an essay on the topicality of The Little Prince and its message in modern life"
)
],
},
config,
):
print(event)
print("---")
# %%
state = graph.get_state(config)
# %%
ChatPromptTemplate.from_messages(state.values["messages"]).pretty_print()
# %% [markdown]
# ## Conclusion
#
# Now that you've applied reflection to an LLM agent, one thing to note: self-reflection is inherently cyclic, and it is much more effective if the reflection step has additional context or feedback (from tool observations, checks, etc.). If, as in the scenario above, the reflection step simply prompts the LLM to reflect on its own output, it can still benefit the output quality (since the LLM then has multiple "shots" at getting a good output), but it's less guaranteed.
#
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/_internal.py`:
```py
"""Shared utility functions for the Postgres checkpoint & storage classes."""
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Union
from psycopg import Connection
from psycopg.rows import DictRow
from psycopg_pool import ConnectionPool
Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]]
@contextmanager
def get_connection(conn: Conn) -> Iterator[Connection[DictRow]]:
if isinstance(conn, Connection):
yield conn
elif isinstance(conn, ConnectionPool):
with conn.connection() as conn:
yield conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/_ainternal.py`:
```py
"""Shared async utility functions for the Postgres checkpoint & storage classes."""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Union
from psycopg import AsyncConnection
from psycopg.rows import DictRow
from psycopg_pool import AsyncConnectionPool
Conn = Union[AsyncConnection[DictRow], AsyncConnectionPool[AsyncConnection[DictRow]]]
@asynccontextmanager
async def get_connection(
conn: Conn,
) -> AsyncIterator[AsyncConnection[DictRow]]:
if isinstance(conn, AsyncConnection):
yield conn
elif isinstance(conn, AsyncConnectionPool):
async with conn.connection() as conn:
yield conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py`:
```py
import threading
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from psycopg import Capabilities, Connection, Cursor, Pipeline
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
)
from langgraph.checkpoint.postgres import _internal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
Conn = _internal.Conn # For backward compatibility
class PostgresSaver(BasePostgresSaver):
lock: threading.Lock
def __init__(
self,
conn: _internal.Conn,
pipe: Optional[Pipeline] = None,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
if isinstance(conn, ConnectionPool) and pipe is not None:
raise ValueError(
"Pipeline should be used only with a single Connection, not ConnectionPool."
)
self.conn = conn
self.pipe = pipe
self.lock = threading.Lock()
self.supports_pipeline = Capabilities().has_pipeline()
@classmethod
@contextmanager
def from_conn_string(
cls, conn_string: str, *, pipeline: bool = False
) -> Iterator["PostgresSaver"]:
"""Create a new PostgresSaver instance from a connection string.
Args:
conn_string (str): The Postgres connection info string.
pipeline (bool): whether to use Pipeline
Returns:
PostgresSaver: A new PostgresSaver instance.
"""
with Connection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
) as conn:
if pipeline:
with conn.pipeline() as pipe:
yield cls(conn, pipe)
else:
yield cls(conn)
def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
This method creates the necessary tables in the Postgres database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time checkpointer is used.
"""
with self._cursor() as cur:
cur.execute(self.MIGRATIONS[0])
results = cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = results.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
):
cur.execute(migration)
cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})")
if self.pipe:
self.pipe.sync()
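# Illustrative usage (not part of the library source); DB_URI as in the examples below:
#   with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
#       checkpointer.setup()  # run once to create tables and apply migrations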
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database.
This method retrieves a list of checkpoint tuples from the Postgres database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (RunnableConfig): The config to use for listing the checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.
Yields:
Iterator[CheckpointTuple]: An iterator of checkpoint tuples.
Examples:
>>> from langgraph.checkpoint.postgres import PostgresSaver
>>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
>>> with PostgresSaver.from_conn_string(DB_URI) as memory:
... # Run a graph, then list the checkpoints
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoints = list(memory.list(config, limit=2))
>>> print(checkpoints)
[CheckpointTuple(...), CheckpointTuple(...)]
>>> config = {"configurable": {"thread_id": "1"}}
>>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}}
>>> with PostgresSaver.from_conn_string(DB_URI) as memory:
... # Run a graph, then list the checkpoints
>>> checkpoints = list(memory.list(config, before=before))
>>> print(checkpoints)
[CheckpointTuple(...), ...]
"""
where, args = self._search_where(config, filter, before)
query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC"
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
with self._cursor() as cur:
cur.execute(query, args, binary=True)
for value in cur:
yield CheckpointTuple(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["checkpoint_id"],
}
},
self._load_checkpoint(
value["checkpoint"],
value["channel_values"],
value["pending_sends"],
),
self._load_metadata(value["metadata"]),
(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["parent_checkpoint_id"],
}
}
if value["parent_checkpoint_id"]
else None
),
self._load_writes(value["pending_writes"]),
)
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the Postgres database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
Examples:
Basic:
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
With checkpoint ID:
>>> config = {
... "configurable": {
... "thread_id": "1",
... "checkpoint_ns": "",
... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
... }
... }
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
""" # noqa
thread_id = config["configurable"]["thread_id"]
checkpoint_id = get_checkpoint_id(config)
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id:
args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
else:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1"
with self._cursor() as cur:
cur.execute(
self.SELECT_SQL + where,
args,
binary=True,
)
for value in cur:
return CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["checkpoint_id"],
}
},
self._load_checkpoint(
value["checkpoint"],
value["channel_values"],
value["pending_sends"],
),
self._load_metadata(value["metadata"]),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["parent_checkpoint_id"],
}
}
if value["parent_checkpoint_id"]
else None
),
self._load_writes(value["pending_writes"]),
)
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the Postgres database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Examples:
>>> from langgraph.checkpoint.postgres import PostgresSaver
>>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
>>> with PostgresSaver.from_conn_string(DB_URI) as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
>>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}}
>>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
>>> print(saved_config)
{'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}}
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)
copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
with self._cursor(pipeline=True) as cur:
cur.executemany(
self.UPSERT_CHECKPOINT_BLOBS_SQL,
self._dump_blobs(
thread_id,
checkpoint_ns,
copy.pop("channel_values"), # type: ignore[misc]
new_versions,
),
)
cur.execute(
self.UPSERT_CHECKPOINTS_SQL,
(
thread_id,
checkpoint_ns,
checkpoint["id"],
checkpoint_id,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(metadata),
),
)
return next_config
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the Postgres database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.
"""
query = (
self.UPSERT_CHECKPOINT_WRITES_SQL
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
with self._cursor(pipeline=True) as cur:
cur.executemany(
query,
self._dump_writes(
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
),
)
@contextmanager
def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the PostgresSaver instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
with _internal.get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
# in multiple threads/coroutines, but only one cursor can be
# used at a time
try:
with conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
finally:
if pipeline:
self.pipe.sync()
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
with (
self.lock,
conn.pipeline(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
# Use connection's transaction context manager when pipeline mode not supported
with (
self.lock,
conn.transaction(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
__all__ = ["PostgresSaver", "BasePostgresSaver", "Conn"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py`:
```py
import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
)
from langgraph.checkpoint.postgres import _ainternal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
Conn = _ainternal.Conn # For backward compatibility
class AsyncPostgresSaver(BasePostgresSaver):
lock: asyncio.Lock
def __init__(
self,
conn: _ainternal.Conn,
pipe: Optional[AsyncPipeline] = None,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
if isinstance(conn, AsyncConnectionPool) and pipe is not None:
raise ValueError(
"Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool."
)
self.conn = conn
self.pipe = pipe
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self.supports_pipeline = Capabilities().has_pipeline()
@classmethod
@asynccontextmanager
async def from_conn_string(
cls,
conn_string: str,
*,
pipeline: bool = False,
serde: Optional[SerializerProtocol] = None,
) -> AsyncIterator["AsyncPostgresSaver"]:
"""Create a new AsyncPostgresSaver instance from a connection string.
Args:
conn_string (str): The Postgres connection info string.
pipeline (bool): whether to use AsyncPipeline
Returns:
AsyncPostgresSaver: A new AsyncPostgresSaver instance.
"""
async with await AsyncConnection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
) as conn:
if pipeline:
async with conn.pipeline() as pipe:
yield cls(conn=conn, pipe=pipe, serde=serde)
else:
yield cls(conn=conn, serde=serde)
async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
This method creates the necessary tables in the Postgres database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the checkpointer is used.
"""
async with self._cursor() as cur:
await cur.execute(self.MIGRATIONS[0])
results = await cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = await results.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
):
await cur.execute(migration)
await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})")
if self.pipe:
await self.pipe.sync()
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously.
This method retrieves a list of checkpoint tuples from the Postgres database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples.
"""
where, args = self._search_where(config, filter, before)
query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC"
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
async with self._cursor() as cur:
await cur.execute(query, args, binary=True)
async for value in cur:
yield CheckpointTuple(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["checkpoint_id"],
}
},
await asyncio.to_thread(
self._load_checkpoint,
value["checkpoint"],
value["channel_values"],
value["pending_sends"],
),
self._load_metadata(value["metadata"]),
(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["parent_checkpoint_id"],
}
}
if value["parent_checkpoint_id"]
else None
),
await asyncio.to_thread(self._load_writes, value["pending_writes"]),
)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database asynchronously.
This method retrieves a checkpoint tuple from the Postgres database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = get_checkpoint_id(config)
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id:
args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
else:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1"
async with self._cursor() as cur:
await cur.execute(
self.SELECT_SQL + where,
args,
binary=True,
)
async for value in cur:
return CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["checkpoint_id"],
}
},
await asyncio.to_thread(
self._load_checkpoint,
value["checkpoint"],
value["channel_values"],
value["pending_sends"],
),
self._load_metadata(value["metadata"]),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["parent_checkpoint_id"],
}
}
if value["parent_checkpoint_id"]
else None
),
await asyncio.to_thread(self._load_writes, value["pending_writes"]),
)
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously.
This method saves a checkpoint to the Postgres database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)
copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
async with self._cursor(pipeline=True) as cur:
await cur.executemany(
self.UPSERT_CHECKPOINT_BLOBS_SQL,
await asyncio.to_thread(
self._dump_blobs,
thread_id,
checkpoint_ns,
copy.pop("channel_values"), # type: ignore[misc]
new_versions,
),
)
await cur.execute(
self.UPSERT_CHECKPOINTS_SQL,
(
thread_id,
checkpoint_ns,
checkpoint["id"],
checkpoint_id,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(metadata),
),
)
return next_config
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
This method saves intermediate writes associated with a checkpoint to the database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
query = (
self.UPSERT_CHECKPOINT_WRITES_SQL
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
params = await asyncio.to_thread(
self._dump_writes,
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
)
async with self._cursor(pipeline=True) as cur:
await cur.executemany(query, params)
@asynccontextmanager
async def _cursor(
self, *, pipeline: bool = False
) -> AsyncIterator[AsyncCursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the AsyncPostgresSaver instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
async with _ainternal.get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
# in multiple threads/coroutines, but only one cursor can be
# used at a time
try:
async with conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
finally:
if pipeline:
await self.pipe.sync()
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
async with (
self.lock,
conn.pipeline(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
# Use connection's transaction context manager when pipeline mode not supported
async with (
self.lock,
conn.transaction(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
async with (
self.lock,
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database.
This method retrieves a list of checkpoint tuples from the Postgres database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
aiter_ = self.alist(config, filter=filter, before=before, limit=limit)
while True:
try:
yield asyncio.run_coroutine_threadsafe(
anext(aiter_), # noqa: F821
self.loop,
).result()
except StopAsyncIteration:
break
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the Postgres database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
try:
# check if we are in the main thread, only bg threads can block
# we don't check in other methods to avoid the overhead
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncPostgresSaver are only allowed from a "
"different thread. From the main thread, use the async interface."
"For example, use `await checkpointer.aget_tuple(...)` or `await "
"graph.ainvoke(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aget_tuple(config), self.loop
).result()
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the Postgres database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
return asyncio.run_coroutine_threadsafe(
self.aput(config, checkpoint, metadata, new_versions), self.loop
).result()
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()
__all__ = ["AsyncPostgresSaver", "Conn"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py`:
```py
import random
from collections.abc import Sequence
from typing import Any, Optional, cast
from langchain_core.runnables import RunnableConfig
from psycopg.types.json import Jsonb
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol
MetadataInput = Optional[dict[str, Any]]
"""
To add a new migration, add a new string to the MIGRATIONS list.
The position of the migration in the list is the version number.
"""
MIGRATIONS = [
"""CREATE TABLE IF NOT EXISTS checkpoint_migrations (
v INTEGER PRIMARY KEY
);""",
"""CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSONB NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);""",
"""CREATE TABLE IF NOT EXISTS checkpoint_blobs (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
channel TEXT NOT NULL,
version TEXT NOT NULL,
type TEXT NOT NULL,
blob BYTEA,
PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
);""",
"""CREATE TABLE IF NOT EXISTS checkpoint_writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);""",
"ALTER TABLE checkpoint_blobs ALTER COLUMN blob DROP not null;",
]
SELECT_SQL = f"""
select
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
(
select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob])
from jsonb_each_text(checkpoint -> 'channel_versions')
inner join checkpoint_blobs bl
on bl.thread_id = checkpoints.thread_id
and bl.checkpoint_ns = checkpoints.checkpoint_ns
and bl.channel = jsonb_each_text.key
and bl.version = jsonb_each_text.value
) as channel_values,
(
select
array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_id = checkpoints.parent_checkpoint_id
and cw.channel = '{TASKS}'
) as pending_sends
from checkpoints """
UPSERT_CHECKPOINT_BLOBS_SQL = """
INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob)
VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING
"""
UPSERT_CHECKPOINTS_SQL = """
INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
DO UPDATE SET
checkpoint = EXCLUDED.checkpoint,
metadata = EXCLUDED.metadata;
"""
UPSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
channel = EXCLUDED.channel,
type = EXCLUDED.type,
blob = EXCLUDED.blob;
"""
INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
"""
class BasePostgresSaver(BaseCheckpointSaver[str]):
SELECT_SQL = SELECT_SQL
MIGRATIONS = MIGRATIONS
UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL
UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL
UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
jsonplus_serde = JsonPlusSerializer()
supports_pipeline: bool
def _load_checkpoint(
self,
checkpoint: dict[str, Any],
channel_values: list[tuple[bytes, bytes, bytes]],
pending_sends: list[tuple[bytes, bytes]],
) -> Checkpoint:
return {
**checkpoint,
"pending_sends": [
self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or []
],
"channel_values": self._load_blobs(channel_values),
}
def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
return {**checkpoint, "pending_sends": []}
def _load_blobs(
self, blob_values: list[tuple[bytes, bytes, bytes]]
) -> dict[str, Any]:
if not blob_values:
return {}
return {
k.decode(): self.serde.loads_typed((t.decode(), v))
for k, t, v in blob_values
if t.decode() != "empty"
}
def _dump_blobs(
self,
thread_id: str,
checkpoint_ns: str,
values: dict[str, Any],
versions: ChannelVersions,
) -> list[tuple[str, str, str, str, str, Optional[bytes]]]:
if not versions:
return []
return [
(
thread_id,
checkpoint_ns,
k,
cast(str, ver),
*(
self.serde.dumps_typed(values[k])
if k in values
else ("empty", None)
),
)
for k, ver in versions.items()
]
def _load_writes(
self, writes: list[tuple[bytes, bytes, bytes, bytes]]
) -> list[tuple[str, str, Any]]:
return (
[
(
tid.decode(),
channel.decode(),
self.serde.loads_typed((t.decode(), v)),
)
for tid, channel, t, v in writes
]
if writes
else []
)
def _dump_writes(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
writes: Sequence[tuple[str, Any]],
) -> list[tuple[str, str, str, str, int, str, str, bytes]]:
return [
(
thread_id,
checkpoint_ns,
checkpoint_id,
task_id,
WRITES_IDX_MAP.get(channel, idx),
channel,
*self.serde.dumps_typed(value),
)
for idx, (channel, value) in enumerate(writes)
]
def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata:
return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata))
def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
serialized_metadata = self.jsonplus_serde.dumps(metadata)
# NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
return serialized_metadata.decode().replace("\\u0000", "")
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
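    # NOTE: versions are strings of the form "<32-digit zero-padded counter>.<random suffix>",
    # e.g. "00000000000000000000000000000002.0.95...". The fixed-width counter makes
    # lexicographic comparison of version strings agree with numeric ordering.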
def _search_where(
self,
config: Optional[RunnableConfig],
filter: MetadataInput,
before: Optional[RunnableConfig] = None,
) -> tuple[str, list[Any]]:
"""Return WHERE clause predicates for alist() given config, filter, before.
This method returns a tuple of a string and a list of values. The string
is the parameterized WHERE clause predicate (including the WHERE keyword),
e.g. "WHERE thread_id = %s AND metadata @> %s". The list contains the
values for each of the corresponding parameters.
"""
wheres = []
param_values = []
# construct predicate for config filter
if config:
wheres.append("thread_id = %s ")
param_values.append(config["configurable"]["thread_id"])
checkpoint_ns = config["configurable"].get("checkpoint_ns")
if checkpoint_ns is not None:
wheres.append("checkpoint_ns = %s")
param_values.append(checkpoint_ns)
if checkpoint_id := get_checkpoint_id(config):
wheres.append("checkpoint_id = %s ")
param_values.append(checkpoint_id)
# construct predicate for metadata filter
if filter:
wheres.append("metadata @> %s ")
param_values.append(Jsonb(filter))
# construct predicate for `before`
if before is not None:
wheres.append("checkpoint_id < %s ")
param_values.append(get_checkpoint_id(before))
return (
"WHERE " + " AND ".join(wheres) if wheres else "",
param_values,
)
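    # Example (illustrative): for config {"configurable": {"thread_id": "thread-1"}}
    # and filter {"source": "loop"}, this returns roughly
    # ("WHERE thread_id = %s  AND metadata @> %s ", ["thread-1", Jsonb({"source": "loop"})]).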
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/__init__.py`:
```py
from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.store.postgres.base import PostgresStore
__all__ = ["AsyncPostgresStore", "PostgresStore"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/aio.py`:
```py
import asyncio
import logging
from collections.abc import AsyncIterator, Iterable, Sequence
from contextlib import asynccontextmanager
from typing import Any, Callable, Optional, Union, cast
import orjson
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres import _ainternal
from langgraph.store.base import (
GetOp,
ListNamespacesOp,
Op,
PutOp,
Result,
SearchOp,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.postgres.base import (
_PLACEHOLDER,
BasePostgresStore,
PoolConfig,
PostgresIndexConfig,
Row,
_decode_ns_bytes,
_ensure_index_config,
_group_ops,
_row_to_item,
_row_to_search_item,
)
logger = logging.getLogger(__name__)
class AsyncPostgresStore(AsyncBatchedBaseStore, BasePostgresStore[_ainternal.Conn]):
"""Asynchronous Postgres-backed store with optional vector search using pgvector.
!!! example "Examples"
Basic setup and key-value storage:
```python
from langgraph.store.postgres import AsyncPostgresStore
async with AsyncPostgresStore.from_conn_string(
"postgresql://user:pass@localhost:5432/dbname"
) as store:
await store.setup()
# Store and retrieve data
await store.aput(("users", "123"), "prefs", {"theme": "dark"})
item = await store.aget(("users", "123"), "prefs")
```
Vector search using LangChain embeddings:
```python
from langchain.embeddings import init_embeddings
from langgraph.store.postgres import AsyncPostgresStore
async with AsyncPostgresStore.from_conn_string(
"postgresql://user:pass@localhost:5432/dbname",
index={
"dims": 1536,
"embed": init_embeddings("openai:text-embedding-3-small"),
"fields": ["text"] # specify which fields to embed. Default is the whole serialized value
}
) as store:
await store.setup() # Do this once to run migrations
# Store documents
await store.aput(("docs",), "doc1", {"text": "Python tutorial"})
await store.aput(("docs",), "doc2", {"text": "TypeScript guide"})
# Don't index the following
await store.aput(("docs",), "doc3", {"text": "Other guide"}, index=False)
# Search by similarity
results = await store.asearch(("docs",), query="python programming")
```
Using connection pooling for better performance:
```python
from langgraph.store.postgres import AsyncPostgresStore, PoolConfig
async with AsyncPostgresStore.from_conn_string(
"postgresql://user:pass@localhost:5432/dbname",
pool_config=PoolConfig(
min_size=5,
max_size=20
)
) as store:
await store.setup()
# Use store with connection pooling...
```
Warning:
Make sure to:
1. Call `setup()` before first use to create necessary tables and indexes
2. Have the pgvector extension available to use vector search
3. Use Python 3.10+ for async functionality
Note:
Semantic search is disabled by default. You can enable it by providing an `index` configuration
when creating the store. Without this configuration, all `index` arguments passed to
`put` or `aput` will have no effect.
"""
__slots__ = (
"_deserializer",
"pipe",
"lock",
"supports_pipeline",
"index_config",
"embeddings",
)
def __init__(
self,
conn: _ainternal.Conn,
*,
pipe: Optional[AsyncPipeline] = None,
deserializer: Optional[
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
index: Optional[PostgresIndexConfig] = None,
) -> None:
if isinstance(conn, AsyncConnectionPool) and pipe is not None:
raise ValueError(
"Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool."
)
super().__init__()
self._deserializer = deserializer
self.conn = conn
self.pipe = pipe
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self.supports_pipeline = Capabilities().has_pipeline()
self.index_config = index
if self.index_config:
self.embeddings, self.index_config = _ensure_index_config(self.index_config)
else:
self.embeddings = None
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
grouped_ops, num_ops = _group_ops(ops)
results: list[Result] = [None] * num_ops
async with _ainternal.get_connection(self.conn) as conn:
if self.pipe:
async with self.pipe:
await self._execute_batch(grouped_ops, results, conn)
else:
await self._execute_batch(grouped_ops, results, conn)
return results
def batch(self, ops: Iterable[Op]) -> list[Result]:
return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result()
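    # NOTE: batch() bridges synchronous callers into the async store by scheduling
    # abatch() on the event loop captured in __init__; it must be called from a
    # different thread than the one running that loop, otherwise .result() would
    # block the loop and deadlock.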
@classmethod
@asynccontextmanager
async def from_conn_string(
cls,
conn_string: str,
*,
pipeline: bool = False,
pool_config: Optional[PoolConfig] = None,
index: Optional[PostgresIndexConfig] = None,
) -> AsyncIterator["AsyncPostgresStore"]:
"""Create a new AsyncPostgresStore instance from a connection string.
Args:
conn_string (str): The Postgres connection info string.
pipeline (bool): Whether to use AsyncPipeline (only for single connections)
pool_config (Optional[PoolConfig]): Configuration for the connection pool.
If provided, will create a connection pool and use it instead of a single connection.
This overrides the `pipeline` argument.
index (Optional[PostgresIndexConfig]): The embedding config.
Returns:
AsyncPostgresStore: A new AsyncPostgresStore instance.
"""
if pool_config is not None:
pc = pool_config.copy()
async with cast(
AsyncConnectionPool[AsyncConnection[DictRow]],
AsyncConnectionPool(
conn_string,
min_size=pc.pop("min_size", 1),
max_size=pc.pop("max_size", None),
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
**(pc.pop("kwargs", None) or {}),
},
**cast(dict, pc),
),
) as pool:
yield cls(conn=pool, index=index)
else:
async with await AsyncConnection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
) as conn:
if pipeline:
async with conn.pipeline() as pipe:
yield cls(conn=conn, pipe=pipe, index=index)
else:
yield cls(conn=conn, index=index)
async def setup(self) -> None:
"""Set up the store database asynchronously.
This method creates the necessary tables in the Postgres database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the store is used.
"""
async def _get_version(cur: AsyncCursor[DictRow], table: str) -> int:
try:
await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1")
row = await cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except UndefinedTable:
version = -1
await cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table} (
v INTEGER PRIMARY KEY
)
"""
)
return version
async with self._cursor() as cur:
version = await _get_version(cur, table="store_migrations")
for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1):
await cur.execute(sql)
await cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,))
if self.index_config:
version = await _get_version(cur, table="vector_migrations")
for v, migration in enumerate(
self.VECTOR_MIGRATIONS[version + 1 :], start=version + 1
):
                    if migration.condition and not migration.condition(self):
                        continue
                    sql = migration.sql
if migration.params:
params = {
k: v(self) if v is not None and callable(v) else v
for k, v in migration.params.items()
}
sql = sql % params
await cur.execute(sql)
await cur.execute(
"INSERT INTO vector_migrations (v) VALUES (%s)", (v,)
)
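    # NOTE: setup() is safe to call repeatedly: the store_migrations and
    # vector_migrations tables record the index of the last applied migration,
    # so only migrations added since that version are executed.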
async def _execute_batch(
self,
grouped_ops: dict,
results: list[Result],
conn: AsyncConnection[DictRow],
) -> None:
async with self._cursor(pipeline=True) as cur:
if GetOp in grouped_ops:
await self._batch_get_ops(
cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]),
results,
cur,
)
if SearchOp in grouped_ops:
await self._batch_search_ops(
cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
results,
cur,
)
if ListNamespacesOp in grouped_ops:
await self._batch_list_namespaces_ops(
cast(
Sequence[tuple[int, ListNamespacesOp]],
grouped_ops[ListNamespacesOp],
),
results,
cur,
)
if PutOp in grouped_ops:
await self._batch_put_ops(
cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]),
cur,
)
async def _batch_get_ops(
self,
get_ops: Sequence[tuple[int, GetOp]],
results: list[Result],
cur: AsyncCursor[DictRow],
) -> None:
for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops):
await cur.execute(query, params)
rows = cast(list[Row], await cur.fetchall())
key_to_row = {row["key"]: row for row in rows}
for idx, key in items:
row = key_to_row.get(key)
if row:
results[idx] = _row_to_item(
namespace, row, loader=self._deserializer
)
else:
results[idx] = None
async def _batch_put_ops(
self,
put_ops: Sequence[tuple[int, PutOp]],
cur: AsyncCursor[DictRow],
) -> None:
queries, embedding_request = self._prepare_batch_PUT_queries(put_ops)
if embedding_request:
if self.embeddings is None:
# Should not get here since the embedding config is required
# to return an embedding_request above
raise ValueError(
"Embedding configuration is required for vector operations "
f"(for semantic search). "
f"Please provide an EmbeddingConfig when initializing the {self.__class__.__name__}."
)
query, txt_params = embedding_request
vectors = await self.embeddings.aembed_documents(
[param[-1] for param in txt_params]
)
queries.append(
(
query,
[
p
for (ns, k, pathname, _), vector in zip(txt_params, vectors)
for p in (ns, k, pathname, vector)
],
)
)
for query, params in queries:
await cur.execute(query, params)
async def _batch_search_ops(
self,
search_ops: Sequence[tuple[int, SearchOp]],
results: list[Result],
cur: AsyncCursor[DictRow],
) -> None:
queries, embedding_requests = self._prepare_batch_search_queries(search_ops)
if embedding_requests and self.embeddings:
vectors = await self.embeddings.aembed_documents(
[query for _, query in embedding_requests]
)
for (idx, _), vector in zip(embedding_requests, vectors):
_paramslist = queries[idx][1]
for i in range(len(_paramslist)):
if _paramslist[i] is _PLACEHOLDER:
_paramslist[i] = vector
for (idx, _), (query, params) in zip(search_ops, queries):
await cur.execute(query, params)
rows = cast(list[Row], await cur.fetchall())
items = [
_row_to_search_item(
_decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
)
for row in rows
]
results[idx] = items
async def _batch_list_namespaces_ops(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
results: list[Result],
cur: AsyncCursor[DictRow],
) -> None:
queries = self._get_batch_list_namespaces_queries(list_ops)
for (query, params), (idx, _) in zip(queries, list_ops):
await cur.execute(query, params)
rows = cast(list[dict], await cur.fetchall())
namespaces = [_decode_ns_bytes(row["truncated_prefix"]) for row in rows]
results[idx] = namespaces
@asynccontextmanager
async def _cursor(
self, *, pipeline: bool = False
) -> AsyncIterator[AsyncCursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline: whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the PostgresStore instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
async with _ainternal.get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
# in multiple threads/coroutines, but only one cursor can be
# used at a time
try:
async with conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
finally:
if pipeline:
await self.pipe.sync()
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
async with (
self.lock,
conn.pipeline(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
async with (
self.lock,
conn.transaction(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
async with (
self.lock,
conn.cursor(binary=True) as cur,
):
yield cur
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/langgraph/store/postgres/base.py`:
```py
import asyncio
import json
import logging
import threading
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
cast,
)
import orjson
from psycopg import Capabilities, Connection, Cursor, Pipeline
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from psycopg.types.json import Jsonb
from psycopg_pool import ConnectionPool
from typing_extensions import TypedDict
from langgraph.checkpoint.postgres import _ainternal as _ainternal
from langgraph.checkpoint.postgres import _internal as _pg_internal
from langgraph.store.base import (
BaseStore,
GetOp,
IndexConfig,
Item,
ListNamespacesOp,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
ensure_embeddings,
get_text_at_path,
tokenize_path,
)
if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings
logger = logging.getLogger(__name__)
class Migration(NamedTuple):
"""A database migration with optional conditions and parameters."""
sql: str
params: Optional[dict[str, Any]] = None
condition: Optional[Callable[["BasePostgresStore"], bool]] = None
MIGRATIONS: Sequence[str] = [
"""
CREATE TABLE IF NOT EXISTS store (
-- 'prefix' represents the doc's 'namespace'
prefix text NOT NULL,
key text NOT NULL,
value jsonb NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (prefix, key)
);
""",
"""
-- For faster lookups by prefix
CREATE INDEX IF NOT EXISTS store_prefix_idx ON store USING btree (prefix text_pattern_ops);
""",
]
VECTOR_MIGRATIONS: Sequence[Migration] = [
Migration(
"""
CREATE EXTENSION IF NOT EXISTS vector;
""",
),
Migration(
"""
CREATE TABLE IF NOT EXISTS store_vectors (
prefix text NOT NULL,
key text NOT NULL,
field_name text NOT NULL,
embedding %(vector_type)s(%(dims)s),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (prefix, key, field_name),
FOREIGN KEY (prefix, key) REFERENCES store(prefix, key) ON DELETE CASCADE
);
""",
params={
"dims": lambda store: store.index_config["dims"],
"vector_type": lambda store: (
cast(PostgresIndexConfig, store.index_config)
.get("ann_index_config", {})
.get("vector_type", "vector")
),
},
),
Migration(
"""
CREATE INDEX IF NOT EXISTS store_vectors_embedding_idx ON store_vectors
USING %(index_type)s (embedding %(ops)s)%(index_params)s;
""",
condition=lambda store: bool(
store.index_config and _get_index_params(store)[0] != "flat"
),
params={
"index_type": lambda store: _get_index_params(store)[0],
"ops": lambda store: _get_vector_type_ops(store),
"index_params": lambda store: (
" WITH ("
+ ", ".join(f"{k}={v}" for k, v in _get_index_params(store)[1].items())
+ ")"
if _get_index_params(store)[1]
else ""
),
},
),
]
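# NOTE: VECTOR_MIGRATIONS entries are templates: `params` values are callables
# evaluated against the store at setup() time and substituted into the SQL with
# %-formatting (e.g. dims and vector type), while `condition` can skip a migration
# entirely (the ANN index migration is skipped when the configured kind is "flat").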
C = TypeVar("C", bound=Union[_pg_internal.Conn, _ainternal.Conn])
class PoolConfig(TypedDict, total=False):
"""Connection pool settings for PostgreSQL connections.
Controls connection lifecycle and resource utilization:
- Small pools (1-5) suit low-concurrency workloads
- Larger pools handle concurrent requests but consume more resources
- Setting max_size prevents resource exhaustion under load
"""
min_size: int
"""Minimum number of connections maintained in the pool. Defaults to 1."""
max_size: Optional[int]
"""Maximum number of connections allowed in the pool. None means unlimited."""
kwargs: dict
"""Additional connection arguments passed to each connection in the pool.
Default kwargs set automatically:
- autocommit: True
- prepare_threshold: 0
- row_factory: dict_row
"""
class ANNIndexConfig(TypedDict, total=False):
"""Configuration for vector index in PostgreSQL store."""
    kind: Literal["hnsw", "ivfflat", "flat"]
    """Type of index to use: 'hnsw' for Hierarchical Navigable Small World, 'ivfflat' for Inverted File Flat, or 'flat' to skip creating an ANN index (exact search)."""
vector_type: Literal["vector", "halfvec"]
"""Type of vector storage to use.
Options:
- 'vector': Regular vectors (default)
- 'halfvec': Half-precision vectors for reduced memory usage
"""
class HNSWConfig(ANNIndexConfig, total=False):
"""Configuration for HNSW (Hierarchical Navigable Small World) index."""
kind: Literal["hnsw"] # type: ignore[misc]
m: int
"""Maximum number of connections per layer. Default is 16."""
ef_construction: int
"""Size of dynamic candidate list for index construction. Default is 64."""
class IVFFlatConfig(ANNIndexConfig, total=False):
"""IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff).
Three keys to achieving good recall are:
1. Create the index after the table has some data
2. Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for over 1M rows
3. When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - a good place to start is sqrt(lists)
"""
kind: Literal["ivfflat"] # type: ignore[misc]
nlist: int
"""Number of inverted lists (clusters) for IVF index.
Determines the number of clusters used in the index structure.
Higher values can improve search speed but increase index size and build time.
Typically set to the square root of the number of vectors in the index.
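    Rough sizing (illustrative): ~500,000 rows -> nlist ~ 500 (rows / 1000);
    ~4,000,000 rows -> nlist ~ 2000 (square root of the row count).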
"""
class PostgresIndexConfig(IndexConfig, total=False):
"""Configuration for vector embeddings in PostgreSQL store with pgvector-specific options.
Extends IndexConfig with additional configuration for pgvector indexes and vector types.
"""
ann_index_config: ANNIndexConfig
"""Specific configuration for the chosen index type (HNSW or IVF Flat)."""
distance_type: Literal["l2", "inner_product", "cosine"]
"""Distance metric to use for vector similarity search:
- 'l2': Euclidean distance
- 'inner_product': Dot product
- 'cosine': Cosine similarity
"""
class BasePostgresStore(Generic[C]):
MIGRATIONS = MIGRATIONS
VECTOR_MIGRATIONS = VECTOR_MIGRATIONS
conn: C
_deserializer: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]]
index_config: Optional[PostgresIndexConfig]
def _get_batch_GET_ops_queries(
self,
get_ops: Sequence[tuple[int, GetOp]],
) -> list[tuple[str, tuple, tuple[str, ...], list]]:
namespace_groups = defaultdict(list)
for idx, op in get_ops:
namespace_groups[op.namespace].append((idx, op.key))
results = []
for namespace, items in namespace_groups.items():
_, keys = zip(*items)
keys_to_query = ",".join(["%s"] * len(keys))
query = f"""
SELECT key, value, created_at, updated_at
FROM store
WHERE prefix = %s AND key IN ({keys_to_query})
"""
params = (_namespace_to_text(namespace), *keys)
results.append((query, params, namespace, items))
return results
def _prepare_batch_PUT_queries(
self,
put_ops: Sequence[tuple[int, PutOp]],
) -> tuple[
list[tuple[str, Sequence]],
Optional[tuple[str, Sequence[tuple[str, str, str, str]]]],
]:
# Last-write wins
dedupped_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
for _, op in put_ops:
dedupped_ops[(op.namespace, op.key)] = op
inserts: list[PutOp] = []
deletes: list[PutOp] = []
for op in dedupped_ops.values():
if op.value is None:
deletes.append(op)
else:
inserts.append(op)
queries: list[tuple[str, Sequence]] = []
if deletes:
namespace_groups: dict[tuple[str, ...], list[str]] = defaultdict(list)
for op in deletes:
namespace_groups[op.namespace].append(op.key)
for namespace, keys in namespace_groups.items():
placeholders = ",".join(["%s"] * len(keys))
query = (
f"DELETE FROM store WHERE prefix = %s AND key IN ({placeholders})"
)
params = (_namespace_to_text(namespace), *keys)
queries.append((query, params))
embedding_request: Optional[tuple[str, Sequence[tuple[str, str, str, str]]]] = (
None
)
if inserts:
values = []
insertion_params = []
vector_values = []
embedding_request_params = []
# First handle main store insertions
for op in inserts:
values.append("(%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)")
insertion_params.extend(
[
_namespace_to_text(op.namespace),
op.key,
Jsonb(cast(dict, op.value)),
]
)
# Then handle embeddings if configured
if self.index_config:
for op in inserts:
if op.index is False:
continue
value = op.value
ns = _namespace_to_text(op.namespace)
k = op.key
if op.index is None:
paths = self.index_config["__tokenized_fields"]
else:
paths = [(ix, tokenize_path(ix)) for ix in op.index]
for path, tokenized_path in paths:
texts = get_text_at_path(value, tokenized_path)
for i, text in enumerate(texts):
pathname = f"{path}.{i}" if len(texts) > 1 else path
vector_values.append(
"(%s, %s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
)
embedding_request_params.append((ns, k, pathname, text))
values_str = ",".join(values)
query = f"""
INSERT INTO store (prefix, key, value, created_at, updated_at)
VALUES {values_str}
ON CONFLICT (prefix, key) DO UPDATE
SET value = EXCLUDED.value,
updated_at = CURRENT_TIMESTAMP
"""
queries.append((query, insertion_params))
if vector_values:
values_str = ",".join(vector_values)
query = f"""
INSERT INTO store_vectors (prefix, key, field_name, embedding, created_at, updated_at)
VALUES {values_str}
ON CONFLICT (prefix, key, field_name) DO UPDATE
SET embedding = EXCLUDED.embedding,
updated_at = CURRENT_TIMESTAMP
"""
embedding_request = (query, embedding_request_params)
return queries, embedding_request
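    # NOTE: embeddings are computed lazily: this method only returns the raw texts
    # to embed (embedding_request); _batch_put_ops (sync or async) embeds those
    # texts and appends the store_vectors upsert with the resulting vectors before
    # executing the queries.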
def _prepare_batch_search_queries(
self,
search_ops: Sequence[tuple[int, SearchOp]],
) -> tuple[
list[tuple[str, list[Union[None, str, list[float]]]]], # queries, params
list[tuple[int, str]], # idx, query_text pairs to embed
]:
queries = []
embedding_requests = []
for idx, (_, op) in enumerate(search_ops):
# Build filter conditions first
filter_params = []
filter_conditions = []
if op.filter:
for key, value in op.filter.items():
if isinstance(value, dict):
for op_name, val in value.items():
condition, filter_params_ = self._get_filter_condition(
key, op_name, val
)
filter_conditions.append(condition)
filter_params.extend(filter_params_)
else:
filter_conditions.append("value->%s = %s::jsonb")
filter_params.extend([key, json.dumps(value)])
# Vector search branch
if op.query and self.index_config:
embedding_requests.append((idx, op.query))
score_operator, post_operator = _get_distance_operator(self)
vector_type = (
cast(PostgresIndexConfig, self.index_config)
.get("ann_index_config", {})
.get("vector_type", "vector")
)
if (
vector_type == "bit"
and self.index_config.get("distance_type") == "hamming"
):
score_operator = score_operator % (
"%s",
self.index_config["dims"],
)
else:
score_operator = score_operator % (
"%s",
vector_type,
)
vectors_per_doc_estimate = self.index_config["__estimated_num_vectors"]
expanded_limit = (op.limit * vectors_per_doc_estimate * 2) + 1
# Vector search with CTE for proper score handling
filter_str = (
""
if not filter_conditions
else " AND " + " AND ".join(filter_conditions)
)
if op.namespace_prefix:
prefix_filter_str = f"WHERE s.prefix LIKE %s {filter_str} "
ns_args: Sequence = (f"{_namespace_to_text(op.namespace_prefix)}%",)
else:
ns_args = ()
if filter_str:
prefix_filter_str = f"WHERE {filter_str} "
else:
prefix_filter_str = ""
base_query = f"""
WITH scored AS (
SELECT s.prefix, s.key, s.value, s.created_at, s.updated_at, {score_operator} AS neg_score
FROM store s
JOIN store_vectors sv ON s.prefix = sv.prefix AND s.key = sv.key
{prefix_filter_str}
ORDER BY {score_operator} ASC
LIMIT %s
)
SELECT * FROM (
SELECT DISTINCT ON (prefix, key)
prefix, key, value, created_at, updated_at, {post_operator} as score
FROM scored
ORDER BY prefix, key, score DESC
) AS unique_docs
ORDER BY score DESC
LIMIT %s
OFFSET %s
"""
params = [
_PLACEHOLDER, # Vector placeholder
*ns_args,
*filter_params,
_PLACEHOLDER,
expanded_limit,
op.limit,
op.offset,
]
# Regular search branch
else:
base_query = """
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix LIKE %s
"""
params = [f"{_namespace_to_text(op.namespace_prefix)}%"]
if filter_conditions:
params.extend(filter_params)
base_query += " AND " + " AND ".join(filter_conditions)
base_query += " ORDER BY updated_at DESC"
base_query += " LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])
queries.append((base_query, params))
return queries, embedding_requests
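    # NOTE: vector parameters are emitted as the _PLACEHOLDER sentinel and substituted
    # in _batch_search_ops once the query texts have been embedded. The inner CTE
    # over-fetches (expanded_limit) because a document may contribute several vectors
    # (one per indexed field/path); DISTINCT ON (prefix, key) then keeps the best
    # scoring row per document before applying the user's limit/offset.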
def _get_batch_list_namespaces_queries(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
) -> list[tuple[str, Sequence]]:
queries: list[tuple[str, Sequence]] = []
for _, op in list_ops:
query = """
SELECT DISTINCT ON (truncated_prefix) truncated_prefix, prefix
FROM (
SELECT
prefix,
CASE
WHEN %s::integer IS NOT NULL THEN
(SELECT STRING_AGG(part, '.' ORDER BY idx)
FROM (
SELECT part, ROW_NUMBER() OVER () AS idx
FROM UNNEST(REGEXP_SPLIT_TO_ARRAY(prefix, '\.')) AS part
LIMIT %s::integer
) subquery
)
ELSE prefix
END AS truncated_prefix
FROM store
"""
params: list[Any] = [op.max_depth, op.max_depth]
conditions = []
if op.match_conditions:
for condition in op.match_conditions:
if condition.match_type == "prefix":
conditions.append("prefix LIKE %s")
params.append(
f"{_namespace_to_text(condition.path, handle_wildcards=True)}%"
)
elif condition.match_type == "suffix":
conditions.append("prefix LIKE %s")
params.append(
f"%{_namespace_to_text(condition.path, handle_wildcards=True)}"
)
else:
logger.warning(
f"Unknown match_type in list_namespaces: {condition.match_type}"
)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += ") AS subquery "
query += " ORDER BY truncated_prefix LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])
queries.append((query, tuple(params)))
return queries
def _get_filter_condition(self, key: str, op: str, value: Any) -> tuple[str, list]:
"""Helper to generate filter conditions."""
if op == "$eq":
return "value->%s = %s::jsonb", [key, json.dumps(value)]
elif op == "$gt":
return "value->>%s > %s", [key, str(value)]
elif op == "$gte":
return "value->>%s >= %s", [key, str(value)]
elif op == "$lt":
return "value->>%s < %s", [key, str(value)]
elif op == "$lte":
return "value->>%s <= %s", [key, str(value)]
elif op == "$ne":
return "value->%s != %s::jsonb", [key, json.dumps(value)]
else:
raise ValueError(f"Unsupported operator: {op}")
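    # Example (illustrative): filter {"score": {"$gte": 5}} yields
    # ("value->>%s >= %s", ["score", "5"]), while a plain equality filter such as
    # {"color": "red"} is handled in _prepare_batch_search_queries as
    # ("value->%s = %s::jsonb", ["color", '"red"']).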
class PostgresStore(BaseStore, BasePostgresStore[_pg_internal.Conn]):
"""Postgres-backed store with optional vector search using pgvector.
!!! example "Examples"
Basic setup and key-value storage:
```python
from langgraph.store.postgres import PostgresStore
with PostgresStore.from_conn_string(
    "postgresql://user:pass@localhost:5432/dbname"
) as store:
    store.setup()
    # Store and retrieve data
    store.put(("users", "123"), "prefs", {"theme": "dark"})
    item = store.get(("users", "123"), "prefs")
```
Vector search using LangChain embeddings:
```python
from langchain.embeddings import init_embeddings
from langgraph.store.postgres import PostgresStore

with PostgresStore.from_conn_string(
    "postgresql://user:pass@localhost:5432/dbname",
    index={
        "dims": 1536,
        "embed": init_embeddings("openai:text-embedding-3-small"),
        "fields": ["text"]  # specify which fields to embed. Default is the whole serialized value
    }
) as store:
    store.setup()  # Do this once to run migrations
    # Store documents
    store.put(("docs",), "doc1", {"text": "Python tutorial"})
    store.put(("docs",), "doc2", {"text": "TypeScript guide"})
    store.put(("docs",), "doc3", {"text": "Other guide"}, index=False)  # don't index
    # Search by similarity
    results = store.search(("docs",), query="python programming")
```
Note:
Semantic search is disabled by default. You can enable it by providing an `index` configuration
when creating the store. Without this configuration, all `index` arguments passed to
`put` or `aput` will have no effect.
Warning:
Make sure to call `setup()` before first use to create necessary tables and indexes.
The pgvector extension must be available to use vector search.
"""
__slots__ = (
"_deserializer",
"pipe",
"lock",
"supports_pipeline",
"index_config",
"embeddings",
)
def __init__(
self,
conn: _pg_internal.Conn,
*,
pipe: Optional[Pipeline] = None,
deserializer: Optional[
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
index: Optional[PostgresIndexConfig] = None,
) -> None:
super().__init__()
self._deserializer = deserializer
self.conn = conn
self.pipe = pipe
self.supports_pipeline = Capabilities().has_pipeline()
self.lock = threading.Lock()
self.index_config = index
if self.index_config:
self.embeddings, self.index_config = _ensure_index_config(self.index_config)
else:
self.embeddings = None
@classmethod
@contextmanager
def from_conn_string(
cls,
conn_string: str,
*,
pipeline: bool = False,
pool_config: Optional[PoolConfig] = None,
index: Optional[PostgresIndexConfig] = None,
) -> Iterator["PostgresStore"]:
"""Create a new PostgresStore instance from a connection string.
Args:
conn_string (str): The Postgres connection info string.
pipeline (bool): whether to use Pipeline
pool_config (Optional[PoolConfig]): Configuration for the connection pool.
If provided, will create a connection pool and use it instead of a single connection.
This overrides the `pipeline` argument.
index (Optional[PostgresIndexConfig]): The index configuration for the store.
Returns:
PostgresStore: A new PostgresStore instance.
"""
if pool_config is not None:
pc = pool_config.copy()
with cast(
ConnectionPool[Connection[DictRow]],
ConnectionPool(
conn_string,
min_size=pc.pop("min_size", 1),
max_size=pc.pop("max_size", None),
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
**(pc.pop("kwargs", None) or {}),
},
**cast(dict, pc),
),
) as pool:
yield cls(conn=pool, index=index)
else:
with Connection.connect(
conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row
) as conn:
if pipeline:
with conn.pipeline() as pipe:
yield cls(conn, pipe=pipe, index=index)
else:
yield cls(conn, index=index)
@contextmanager
def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
"""Create a database cursor as a context manager.
Args:
pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
Will be applied regardless of whether the PostgresStore instance was initialized with a pipeline.
If pipeline mode is not supported, will fall back to using transaction context manager.
"""
with _pg_internal.get_connection(self.conn) as conn:
if self.pipe:
# a connection in pipeline mode can be used concurrently
# in multiple threads/coroutines, but only one cursor can be
# used at a time
try:
with conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
finally:
if pipeline:
self.pipe.sync()
elif pipeline:
# a connection not in pipeline mode can only be used by one
# thread/coroutine at a time, so we acquire a lock
if self.supports_pipeline:
with (
self.lock,
conn.pipeline(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
with (
self.lock,
conn.transaction(),
conn.cursor(binary=True, row_factory=dict_row) as cur,
):
yield cur
else:
with conn.cursor(binary=True, row_factory=dict_row) as cur:
yield cur
def batch(self, ops: Iterable[Op]) -> list[Result]:
grouped_ops, num_ops = _group_ops(ops)
results: list[Result] = [None] * num_ops
with self._cursor(pipeline=True) as cur:
if GetOp in grouped_ops:
self._batch_get_ops(
cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results, cur
)
if SearchOp in grouped_ops:
self._batch_search_ops(
cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
results,
cur,
)
if ListNamespacesOp in grouped_ops:
self._batch_list_namespaces_ops(
cast(
Sequence[tuple[int, ListNamespacesOp]],
grouped_ops[ListNamespacesOp],
),
results,
cur,
)
if PutOp in grouped_ops:
self._batch_put_ops(
cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), cur
)
return results
def _batch_get_ops(
self,
get_ops: Sequence[tuple[int, GetOp]],
results: list[Result],
cur: Cursor[DictRow],
) -> None:
for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops):
cur.execute(query, params)
rows = cast(list[Row], cur.fetchall())
key_to_row = {row["key"]: row for row in rows}
for idx, key in items:
row = key_to_row.get(key)
if row:
results[idx] = _row_to_item(
namespace, row, loader=self._deserializer
)
else:
results[idx] = None
def _batch_put_ops(
self,
put_ops: Sequence[tuple[int, PutOp]],
cur: Cursor[DictRow],
) -> None:
queries, embedding_request = self._prepare_batch_PUT_queries(put_ops)
if embedding_request:
if self.embeddings is None:
# Should not get here since the embedding config is required
# to return an embedding_request above
raise ValueError(
"Embedding configuration is required for vector operations "
f"(for semantic search). "
f"Please provide an Embeddings when initializing the {self.__class__.__name__}."
)
query, txt_params = embedding_request
# Update the params to replace the raw text with the vectors
vectors = self.embeddings.embed_documents(
[param[-1] for param in txt_params]
)
queries.append(
(
query,
[
p
for (ns, k, pathname, _), vector in zip(txt_params, vectors)
for p in (ns, k, pathname, vector)
],
)
)
for query, params in queries:
cur.execute(query, params)
def _batch_search_ops(
self,
search_ops: Sequence[tuple[int, SearchOp]],
results: list[Result],
cur: Cursor[DictRow],
) -> None:
queries, embedding_requests = self._prepare_batch_search_queries(search_ops)
if embedding_requests and self.embeddings:
embeddings = self.embeddings.embed_documents(
[query for _, query in embedding_requests]
)
for (idx, _), embedding in zip(embedding_requests, embeddings):
_paramslist = queries[idx][1]
for i in range(len(_paramslist)):
if _paramslist[i] is _PLACEHOLDER:
_paramslist[i] = embedding
for (idx, _), (query, params) in zip(search_ops, queries):
cur.execute(query, params)
rows = cast(list[Row], cur.fetchall())
results[idx] = [
_row_to_search_item(
_decode_ns_bytes(row["prefix"]), row, loader=self._deserializer
)
for row in rows
]
def _batch_list_namespaces_ops(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
results: list[Result],
cur: Cursor[DictRow],
) -> None:
for (query, params), (idx, _) in zip(
self._get_batch_list_namespaces_queries(list_ops), list_ops
):
cur.execute(query, params)
results[idx] = [_decode_ns_bytes(row["truncated_prefix"]) for row in cur]
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops)
def setup(self) -> None:
"""Set up the store database.
This method creates the necessary tables in the Postgres database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the store is used.
"""
def _get_version(cur: Cursor[dict[str, Any]], table: str) -> int:
try:
cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1")
row = cast(dict, cur.fetchone())
if row is None:
version = -1
else:
version = row["v"]
except UndefinedTable:
version = -1
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table} (
v INTEGER PRIMARY KEY
)
"""
)
return version
with self._cursor() as cur:
version = _get_version(cur, table="store_migrations")
for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1):
cur.execute(sql)
cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,))
if self.index_config:
version = _get_version(cur, table="vector_migrations")
for v, migration in enumerate(
self.VECTOR_MIGRATIONS[version + 1 :], start=version + 1
):
if migration.condition and not migration.condition(self):
continue
sql = migration.sql
if migration.params:
params = {
k: v(self) if v is not None and callable(v) else v
for k, v in migration.params.items()
}
sql = sql % params
cur.execute(sql)
cur.execute("INSERT INTO vector_migrations (v) VALUES (%s)", (v,))
class Row(TypedDict):
key: str
value: Any
prefix: str
created_at: datetime
updated_at: datetime
# Private utilities
_DEFAULT_ANN_CONFIG = ANNIndexConfig(
vector_type="vector",
)
def _get_vector_type_ops(store: BasePostgresStore) -> str:
"""Get the vector type operator class based on config."""
if not store.index_config:
return "vector_cosine_ops"
config = cast(PostgresIndexConfig, store.index_config)
index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy()
vector_type = cast(str, index_config.get("vector_type", "vector"))
if vector_type not in ("vector", "halfvec"):
raise ValueError(
f"Vector type must be 'vector' or 'halfvec', got {vector_type}"
)
distance_type = config.get("distance_type", "cosine")
# For regular vectors
type_prefix = {"vector": "vector", "halfvec": "halfvec"}[vector_type]
if distance_type not in ("l2", "inner_product", "cosine"):
raise ValueError(
f"Vector type {vector_type} only supports 'l2', 'inner_product', or 'cosine' distance, got {distance_type}"
)
distance_suffix = {
"l2": "l2_ops",
"inner_product": "ip_ops",
"cosine": "cosine_ops",
}[distance_type]
return f"{type_prefix}_{distance_suffix}"
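# Example (illustrative): vector_type "halfvec" with distance_type "cosine" resolves
# to "halfvec_cosine_ops"; with no index_config the default is "vector_cosine_ops".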
def _get_index_params(store: Any) -> tuple[str, dict[str, Any]]:
"""Get the index type and configuration based on config."""
if not store.index_config:
return "hnsw", {}
config = cast(PostgresIndexConfig, store.index_config)
index_config = config.get("ann_index_config", _DEFAULT_ANN_CONFIG).copy()
kind = index_config.pop("kind", "hnsw")
index_config.pop("vector_type", None)
return kind, index_config
def _namespace_to_text(
namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
"""Convert namespace tuple to text string."""
if handle_wildcards:
namespace = tuple("%" if val == "*" else val for val in namespace)
return ".".join(namespace)
def _row_to_item(
namespace: tuple[str, ...],
row: Row,
*,
loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None,
) -> Item:
"""Convert a row from the database into an Item.
Args:
namespace: Item namespace
row: Database row
loader: Optional value loader for non-dict values
"""
val = row["value"]
if not isinstance(val, dict):
val = (loader or _json_loads)(val)
kwargs = {
"key": row["key"],
"namespace": namespace,
"value": val,
"created_at": row["created_at"],
"updated_at": row["updated_at"],
}
return Item(**kwargs)
def _row_to_search_item(
namespace: tuple[str, ...],
row: Row,
*,
loader: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]] = None,
) -> SearchItem:
"""Convert a row from the database into an Item."""
loader = loader or _json_loads
val = row["value"]
score = row.get("score")
if score is not None:
try:
score = float(score) # type: ignore[arg-type]
except ValueError:
logger.warning("Invalid score: %s", score)
score = None
return SearchItem(
value=val if isinstance(val, dict) else loader(val),
key=row["key"],
namespace=namespace,
created_at=row["created_at"],
updated_at=row["updated_at"],
score=score,
)
def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
tot = 0
for idx, op in enumerate(ops):
grouped_ops[type(op)].append((idx, op))
tot += 1
return grouped_ops, tot
def _json_loads(content: Union[bytes, orjson.Fragment]) -> Any:
if isinstance(content, orjson.Fragment):
if hasattr(content, "buf"):
content = content.buf
else:
if isinstance(content.contents, bytes):
content = content.contents
else:
content = content.contents.encode()
return orjson.loads(cast(bytes, content))
def _decode_ns_bytes(namespace: Union[str, bytes, list]) -> tuple[str, ...]:
if isinstance(namespace, list):
return tuple(namespace)
if isinstance(namespace, bytes):
namespace = namespace.decode()[1:]
return tuple(namespace.split("."))
def _get_distance_operator(store: Any) -> tuple[str, str]:
"""Get the distance operator and score expression based on config."""
# Note: Today, we are not using ANN indices due to restrictions
# on PGVector's support for mixing vector and non-vector filters
# To use the index, PGVector expects:
# - ORDER BY the operator NOT an expression (even negation blocks it)
# - ASCENDING order
# - Any WHERE clause should be over a partial index.
# If we violate any of these, it will use a sequential scan
# See https://github.com/pgvector/pgvector/issues/216 and the
# pgvector documentation for more details.
if not store.index_config:
raise ValueError(
"Embedding configuration is required for vector operations "
f"(for semantic search). "
f"Please provide an Embeddings when initializing the {store.__class__.__name__}."
)
config = cast(PostgresIndexConfig, store.index_config)
distance_type = config.get("distance_type", "cosine")
# Return the operator and the score expression
# The operator is used in the CTE and will be compatible with an ASCENDING ORDER
# sort clause.
# The score expression is used in the final query and will be compatible with
# a DESCENDING ORDER sort clause and the user's expectations of what the similarity score
# should be.
if distance_type == "l2":
# Final: "-(sv.embedding <-> %s::%s)"
# We return the "l2 similarity" so that the sorting order is the same
return "sv.embedding <-> %s::%s", "-scored.neg_score"
elif distance_type == "inner_product":
# Final: "-(sv.embedding <#> %s::%s)"
return "sv.embedding <#> %s::%s", "-(scored.neg_score)"
else: # cosine similarity
# Final: "1 - (sv.embedding <=> %s::%s)"
return "sv.embedding <=> %s::%s", "1 - scored.neg_score"
def _ensure_index_config(
index_config: PostgresIndexConfig,
) -> tuple[Optional["Embeddings"], PostgresIndexConfig]:
index_config = index_config.copy()
tokenized: list[tuple[str, Union[Literal["$"], list[str]]]] = []
tot = 0
text_fields = index_config.get("text_fields") or ["$"]
if isinstance(text_fields, str):
text_fields = [text_fields]
if not isinstance(text_fields, list):
raise ValueError(f"Text fields must be a list or a string. Got {text_fields}")
for p in text_fields:
if p == "$":
tokenized.append((p, "$"))
tot += 1
else:
toks = tokenize_path(p)
tokenized.append((p, toks))
tot += len(toks)
index_config["__tokenized_fields"] = tokenized
index_config["__estimated_num_vectors"] = tot
embeddings = ensure_embeddings(
index_config.get("embed"),
)
return embeddings, index_config
_PLACEHOLDER = object()
```
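A minimal end-to-end sketch tying the pieces of `base.py` together (this is not one of the library's own examples): it assumes a reachable Postgres database with the pgvector extension and an OpenAI embedding model obtained via `init_embeddings`; the connection string, namespace, and field names below are placeholders.
```python
from langchain.embeddings import init_embeddings

from langgraph.store.postgres import PostgresStore

conn_string = "postgresql://user:pass@localhost:5432/dbname"  # placeholder

index_config = {
    "dims": 1536,
    "embed": init_embeddings("openai:text-embedding-3-small"),
    "fields": ["text"],         # embed only the "text" field of each value
    "distance_type": "cosine",  # scores are reported as 1 - cosine distance
    "ann_index_config": {"kind": "hnsw", "m": 16, "ef_construction": 64},
}

with PostgresStore.from_conn_string(conn_string, index=index_config) as store:
    store.setup()  # run store + vector migrations once

    store.put(("docs",), "doc1", {"text": "Python tutorial", "lang": "en"})
    store.put(("docs",), "doc2", {"text": "TypeScript guide", "lang": "fr"})

    # Semantic search combined with a metadata equality filter
    hits = store.search(("docs",), query="python", filter={"lang": "en"}, limit=5)
    for item in hits:
        print(item.key, item.score)
```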
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-checkpoint-postgres"
version = "2.0.7"
description = "Library with a Postgres implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
langgraph-checkpoint = "^2.0.7"
orjson = ">=3.10.1"
psycopg = "^3.2.0"
psycopg-pool = "^3.2.0"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
anyio = "^4.4.0"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watch = "^4.2.0"
mypy = "^1.10.0"
psycopg = {extras = ["binary"], version = ">=3.0.0"}
langgraph-checkpoint = {path = "../checkpoint", develop = true}
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/conftest.py`:
```py
from collections.abc import AsyncIterator
import pytest
from psycopg import AsyncConnection
from psycopg.errors import UndefinedTable
from psycopg.rows import DictRow, dict_row
from tests.embed_test_utils import CharacterEmbeddings
DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5441/"
DEFAULT_URI = "postgres://postgres:postgres@localhost:5441/postgres?sslmode=disable"
@pytest.fixture(scope="function")
async def conn() -> AsyncIterator[AsyncConnection[DictRow]]:
async with await AsyncConnection.connect(
DEFAULT_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
) as conn:
yield conn
@pytest.fixture(scope="function", autouse=True)
async def clear_test_db(conn: AsyncConnection[DictRow]) -> None:
"""Delete all tables before each test."""
try:
await conn.execute("DELETE FROM checkpoints")
await conn.execute("DELETE FROM checkpoint_blobs")
await conn.execute("DELETE FROM checkpoint_writes")
await conn.execute("DELETE FROM checkpoint_migrations")
except UndefinedTable:
pass
try:
await conn.execute("DELETE FROM store_migrations")
await conn.execute("DELETE FROM store")
except UndefinedTable:
pass
@pytest.fixture
def fake_embeddings() -> CharacterEmbeddings:
return CharacterEmbeddings(dims=500)
VECTOR_TYPES = ["vector", "halfvec"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_sync.py`:
```py
# type: ignore
from contextlib import contextmanager
from typing import Any
from uuid import uuid4
import pytest
from langchain_core.runnables import RunnableConfig
from psycopg import Connection
from psycopg.rows import dict_row
from psycopg_pool import ConnectionPool
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.postgres import PostgresSaver
from tests.conftest import DEFAULT_POSTGRES_URI
@contextmanager
def _pool_saver():
"""Fixture for pool mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with ConnectionPool(
DEFAULT_POSTGRES_URI + database,
max_size=10,
kwargs={"autocommit": True, "row_factory": dict_row},
) as pool:
checkpointer = PostgresSaver(pool)
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@contextmanager
def _pipe_saver():
"""Fixture for pipeline mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
with Connection.connect(
DEFAULT_POSTGRES_URI + database,
autocommit=True,
prepare_threshold=0,
row_factory=dict_row,
) as conn:
with conn.pipeline() as pipe:
checkpointer = PostgresSaver(conn, pipe=pipe)
checkpointer.setup()
with conn.pipeline() as pipe:
checkpointer = PostgresSaver(conn, pipe=pipe)
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@contextmanager
def _base_saver():
"""Fixture for regular connection mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
with Connection.connect(
DEFAULT_POSTGRES_URI + database,
autocommit=True,
prepare_threshold=0,
row_factory=dict_row,
) as conn:
checkpointer = PostgresSaver(conn)
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@contextmanager
def _saver(name: str):
if name == "base":
with _base_saver() as saver:
yield saver
elif name == "pool":
with _pool_saver() as saver:
yield saver
elif name == "pipe":
with _pipe_saver() as saver:
yield saver
@pytest.fixture
def test_data():
"""Fixture providing test data for checkpoint tests."""
config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
chkpnt_1: Checkpoint = empty_checkpoint()
chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1)
chkpnt_3: Checkpoint = empty_checkpoint()
metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
metadata_3: CheckpointMetadata = {}
return {
"configs": [config_1, config_2, config_3],
"checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3],
"metadata": [metadata_1, metadata_2, metadata_3],
}
@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
def test_search(saver_name: str, test_data) -> None:
with _saver(saver_name) as saver:
configs = test_data["configs"]
checkpoints = test_data["checkpoints"]
metadata = test_data["metadata"]
saver.put(configs[0], checkpoints[0], metadata[0], {})
saver.put(configs[1], checkpoints[1], metadata[1], {})
saver.put(configs[2], checkpoints[2], metadata[2], {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == metadata[0]
search_results_2 = list(saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == metadata[1]
search_results_3 = list(saver.list(None, filter=query_3))
assert len(search_results_3) == 3
search_results_4 = list(saver.list(None, filter=query_4))
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(saver.list({"configurable": {"thread_id": "thread-2"}}))
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
def test_null_chars(saver_name: str, test_data) -> None:
with _saver(saver_name) as saver:
config = saver.put(
test_data["configs"][0],
test_data["checkpoints"][0],
{"my_key": "\x00abc"},
{},
)
assert saver.get_tuple(config).metadata["my_key"] == "abc" # type: ignore
assert (
list(saver.list(None, filter={"my_key": "abc"}))[0].metadata["my_key"]
== "abc"
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_async.py`:
```py
# type: ignore
from contextlib import asynccontextmanager
from typing import Any
from uuid import uuid4
import pytest
from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from tests.conftest import DEFAULT_POSTGRES_URI
@asynccontextmanager
async def _pool_saver():
"""Fixture for pool mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncConnectionPool(
DEFAULT_POSTGRES_URI + database,
max_size=10,
kwargs={"autocommit": True, "row_factory": dict_row},
) as pool:
checkpointer = AsyncPostgresSaver(pool)
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _pipe_saver():
"""Fixture for pipeline mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI + database,
autocommit=True,
prepare_threshold=0,
row_factory=dict_row,
) as conn:
async with conn.pipeline() as pipe:
checkpointer = AsyncPostgresSaver(conn, pipe=pipe)
await checkpointer.setup()
async with conn.pipeline() as pipe:
checkpointer = AsyncPostgresSaver(conn, pipe=pipe)
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _base_saver():
"""Fixture for regular connection mode testing."""
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI + database,
autocommit=True,
prepare_threshold=0,
row_factory=dict_row,
) as conn:
checkpointer = AsyncPostgresSaver(conn)
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _saver(name: str):
if name == "base":
async with _base_saver() as saver:
yield saver
elif name == "pool":
async with _pool_saver() as saver:
yield saver
elif name == "pipe":
async with _pipe_saver() as saver:
yield saver
@pytest.fixture
def test_data():
"""Fixture providing test data for checkpoint tests."""
config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
chkpnt_1: Checkpoint = empty_checkpoint()
chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1)
chkpnt_3: Checkpoint = empty_checkpoint()
metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
metadata_3: CheckpointMetadata = {}
return {
"configs": [config_1, config_2, config_3],
"checkpoints": [chkpnt_1, chkpnt_2, chkpnt_3],
"metadata": [metadata_1, metadata_2, metadata_3],
}
@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
async def test_asearch(request, saver_name: str, test_data) -> None:
async with _saver(saver_name) as saver:
configs = test_data["configs"]
checkpoints = test_data["checkpoints"]
metadata = test_data["metadata"]
await saver.aput(configs[0], checkpoints[0], metadata[0], {})
await saver.aput(configs[1], checkpoints[1], metadata[1], {})
await saver.aput(configs[2], checkpoints[2], metadata[2], {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == metadata[0]
search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == metadata[1]
search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
assert len(search_results_3) == 3
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = [
c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
]
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"])
async def test_null_chars(request, saver_name: str, test_data) -> None:
async with _saver(saver_name) as saver:
config = await saver.aput(
test_data["configs"][0],
test_data["checkpoints"][0],
{"my_key": "\x00abc"},
{},
)
assert (await saver.aget_tuple(config)).metadata["my_key"] == "abc" # type: ignore
assert [c async for c in saver.alist(None, filter={"my_key": "abc"})][
0
].metadata["my_key"] == "abc"
```
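A minimal usage sketch of the async checkpointer exercised by the tests above. It mirrors the `_base_saver` fixture (direct connection, `setup()`, then `aput`/`aget_tuple`/`alist`); the connection URI is a placeholder assumption, not part of the test suite.
```py
# Sketch only: assumes a reachable Postgres at the placeholder URI below.
import asyncio

from psycopg import AsyncConnection
from psycopg.rows import dict_row

from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

POSTGRES_URI = "postgres://postgres:postgres@localhost:5432/postgres"  # placeholder


async def main() -> None:
    async with await AsyncConnection.connect(
        POSTGRES_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
    ) as conn:
        saver = AsyncPostgresSaver(conn)
        await saver.setup()  # create the checkpoint tables on first use
        config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
        # store a checkpoint with some metadata, then read it back
        saved = await saver.aput(
            config, empty_checkpoint(), {"source": "input", "step": 0}, {}
        )
        print((await saver.aget_tuple(saved)).metadata)
        # list checkpoints for the thread
        print([c.config async for c in saver.alist(config)])


asyncio.run(main())
```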
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_store.py`:
```py
# type: ignore
from contextlib import contextmanager
from typing import Any, Optional
from uuid import uuid4
import pytest
from langchain_core.embeddings import Embeddings
from psycopg import Connection
from langgraph.store.base import (
GetOp,
Item,
ListNamespacesOp,
MatchCondition,
PutOp,
SearchOp,
)
from langgraph.store.postgres import PostgresStore
from tests.conftest import (
DEFAULT_URI,
VECTOR_TYPES,
CharacterEmbeddings,
)
@pytest.fixture(scope="function", params=["default", "pipe", "pool"])
def store(request) -> PostgresStore:
database = f"test_{uuid4().hex[:16]}"
uri_parts = DEFAULT_URI.split("/")
uri_base = "/".join(uri_parts[:-1])
query_params = ""
if "?" in uri_parts[-1]:
db_name, query_params = uri_parts[-1].split("?", 1)
query_params = "?" + query_params
conn_string = f"{uri_base}/{database}{query_params}"
admin_conn_string = DEFAULT_URI
with Connection.connect(admin_conn_string, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
with PostgresStore.from_conn_string(conn_string) as store:
store.setup()
if request.param == "pipe":
with PostgresStore.from_conn_string(conn_string, pipeline=True) as store:
yield store
elif request.param == "pool":
with PostgresStore.from_conn_string(
conn_string, pool_config={"min_size": 1, "max_size": 10}
) as store:
yield store
else: # default
with PostgresStore.from_conn_string(conn_string) as store:
yield store
finally:
with Connection.connect(admin_conn_string, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
def test_batch_order(store: PostgresStore) -> None:
# Setup test data
store.put(("test", "foo"), "key1", {"data": "value1"})
store.put(("test", "bar"), "key2", {"data": "value2"})
ops = [
GetOp(namespace=("test", "foo"), key="key1"),
PutOp(namespace=("test", "bar"), key="key2", value={"data": "value2"}),
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
GetOp(namespace=("test",), key="key3"),
]
results = store.batch(ops)
assert len(results) == 5
assert isinstance(results[0], Item)
assert isinstance(results[0].value, dict)
assert results[0].value == {"data": "value1"}
assert results[0].key == "key1"
assert results[1] is None # Put operation returns None
assert isinstance(results[2], list)
assert len(results[2]) == 1
assert isinstance(results[3], list)
assert len(results[3]) > 0 # Should contain at least our test namespaces
assert results[4] is None # Non-existent key returns None
# Test reordered operations
ops_reordered = [
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
GetOp(namespace=("test", "bar"), key="key2"),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0),
PutOp(namespace=("test",), key="key3", value={"data": "value3"}),
GetOp(namespace=("test", "foo"), key="key1"),
]
results_reordered = store.batch(ops_reordered)
assert len(results_reordered) == 5
assert isinstance(results_reordered[0], list)
assert len(results_reordered[0]) >= 2 # Should find at least our two test items
assert isinstance(results_reordered[1], Item)
assert results_reordered[1].value == {"data": "value2"}
assert results_reordered[1].key == "key2"
assert isinstance(results_reordered[2], list)
assert len(results_reordered[2]) > 0
assert results_reordered[3] is None # Put operation returns None
assert isinstance(results_reordered[4], Item)
assert results_reordered[4].value == {"data": "value1"}
assert results_reordered[4].key == "key1"
def test_batch_get_ops(store: PostgresStore) -> None:
# Setup test data
store.put(("test",), "key1", {"data": "value1"})
store.put(("test",), "key2", {"data": "value2"})
ops = [
GetOp(namespace=("test",), key="key1"),
GetOp(namespace=("test",), key="key2"),
GetOp(namespace=("test",), key="key3"), # Non-existent key
]
results = store.batch(ops)
assert len(results) == 3
assert results[0] is not None
assert results[1] is not None
assert results[2] is None
assert results[0].key == "key1"
assert results[1].key == "key2"
def test_batch_put_ops(store: PostgresStore) -> None:
ops = [
PutOp(namespace=("test",), key="key1", value={"data": "value1"}),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
PutOp(namespace=("test",), key="key3", value=None), # Delete operation
]
results = store.batch(ops)
assert len(results) == 3
assert all(result is None for result in results)
# Verify the puts worked
item1 = store.get(("test",), "key1")
item2 = store.get(("test",), "key2")
item3 = store.get(("test",), "key3")
assert item1 and item1.value == {"data": "value1"}
assert item2 and item2.value == {"data": "value2"}
assert item3 is None
def test_batch_search_ops(store: PostgresStore) -> None:
# Setup test data
test_data = [
(("test", "foo"), "key1", {"data": "value1", "tag": "a"}),
(("test", "bar"), "key2", {"data": "value2", "tag": "a"}),
(("test", "baz"), "key3", {"data": "value3", "tag": "b"}),
]
for namespace, key, value in test_data:
store.put(namespace, key, value)
ops = [
SearchOp(namespace_prefix=("test",), filter={"tag": "a"}, limit=10, offset=0),
SearchOp(namespace_prefix=("test",), filter=None, limit=2, offset=0),
SearchOp(namespace_prefix=("test", "foo"), filter=None, limit=10, offset=0),
]
results = store.batch(ops)
assert len(results) == 3
# First search should find items with tag "a"
assert len(results[0]) == 2
assert all(item.value["tag"] == "a" for item in results[0])
# Second search should return first 2 items
assert len(results[1]) == 2
# Third search should only find items in test/foo namespace
assert len(results[2]) == 1
assert results[2][0].namespace == ("test", "foo")
def test_batch_list_namespaces_ops(store: PostgresStore) -> None:
# Setup test data with various namespaces
test_data = [
(("test", "documents", "public"), "doc1", {"content": "public doc"}),
(("test", "documents", "private"), "doc2", {"content": "private doc"}),
(("test", "images", "public"), "img1", {"content": "public image"}),
(("prod", "documents", "public"), "doc3", {"content": "prod doc"}),
]
for namespace, key, value in test_data:
store.put(namespace, key, value)
ops = [
ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
ListNamespacesOp(match_conditions=None, max_depth=2, limit=10, offset=0),
ListNamespacesOp(
match_conditions=[MatchCondition("suffix", "public")],
max_depth=None,
limit=10,
offset=0,
),
]
results = store.batch(ops)
assert len(results) == 3
# First operation should list all namespaces
assert len(results[0]) == len(test_data)
# Second operation should only return namespaces up to depth 2
assert all(len(ns) <= 2 for ns in results[1])
# Third operation should only return namespaces ending with "public"
assert all(ns[-1] == "public" for ns in results[2])
class TestPostgresStore:
@pytest.fixture(autouse=True)
def setup(self) -> None:
with PostgresStore.from_conn_string(DEFAULT_URI) as store:
store.setup()
def test_basic_store_ops(self) -> None:
with PostgresStore.from_conn_string(DEFAULT_URI) as store:
namespace = ("test", "documents")
item_id = "doc1"
item_value = {"title": "Test Document", "content": "Hello, World!"}
store.put(namespace, item_id, item_value)
item = store.get(namespace, item_id)
assert item
assert item.namespace == namespace
assert item.key == item_id
assert item.value == item_value
# Test update
updated_value = {"title": "Updated Document", "content": "Hello, Updated!"}
store.put(namespace, item_id, updated_value)
updated_item = store.get(namespace, item_id)
assert updated_item.value == updated_value
assert updated_item.updated_at > item.updated_at
# Test get from non-existent namespace
different_namespace = ("test", "other_documents")
item_in_different_namespace = store.get(different_namespace, item_id)
assert item_in_different_namespace is None
# Test delete
store.delete(namespace, item_id)
deleted_item = store.get(namespace, item_id)
assert deleted_item is None
def test_list_namespaces(self) -> None:
with PostgresStore.from_conn_string(DEFAULT_URI) as store:
# Create test data with various namespaces
test_namespaces = [
("test", "documents", "public"),
("test", "documents", "private"),
("test", "images", "public"),
("test", "images", "private"),
("prod", "documents", "public"),
("prod", "documents", "private"),
]
# Insert test data
for namespace in test_namespaces:
store.put(namespace, "dummy", {"content": "dummy"})
# Test listing with various filters
all_namespaces = store.list_namespaces()
assert len(all_namespaces) == len(test_namespaces)
# Test prefix filtering
test_prefix_namespaces = store.list_namespaces(prefix=["test"])
assert len(test_prefix_namespaces) == 4
assert all(ns[0] == "test" for ns in test_prefix_namespaces)
# Test suffix filtering
public_namespaces = store.list_namespaces(suffix=["public"])
assert len(public_namespaces) == 3
assert all(ns[-1] == "public" for ns in public_namespaces)
# Test max depth
depth_2_namespaces = store.list_namespaces(max_depth=2)
assert all(len(ns) <= 2 for ns in depth_2_namespaces)
# Test pagination
paginated_namespaces = store.list_namespaces(limit=3)
assert len(paginated_namespaces) == 3
# Cleanup
for namespace in test_namespaces:
store.delete(namespace, "dummy")
def test_search(self) -> None:
with PostgresStore.from_conn_string(DEFAULT_URI) as store:
# Create test data
test_data = [
(
("test", "docs"),
"doc1",
{"title": "First Doc", "author": "Alice", "tags": ["important"]},
),
(
("test", "docs"),
"doc2",
{"title": "Second Doc", "author": "Bob", "tags": ["draft"]},
),
(
("test", "images"),
"img1",
{"title": "Image 1", "author": "Alice", "tags": ["final"]},
),
]
for namespace, key, value in test_data:
store.put(namespace, key, value)
# Test basic search
all_items = store.search(["test"])
assert len(all_items) == 3
# Test namespace filtering
docs_items = store.search(["test", "docs"])
assert len(docs_items) == 2
assert all(item.namespace == ("test", "docs") for item in docs_items)
# Test value filtering
alice_items = store.search(["test"], filter={"author": "Alice"})
assert len(alice_items) == 2
assert all(item.value["author"] == "Alice" for item in alice_items)
# Test pagination
paginated_items = store.search(["test"], limit=2)
assert len(paginated_items) == 2
offset_items = store.search(["test"], offset=2)
assert len(offset_items) == 1
# Cleanup
for namespace, key, _ in test_data:
store.delete(namespace, key)
@contextmanager
def _create_vector_store(
vector_type: str,
distance_type: str,
fake_embeddings: Embeddings,
text_fields: Optional[list[str]] = None,
) -> PostgresStore:
"""Create a store with vector search enabled."""
database = f"test_{uuid4().hex[:16]}"
uri_parts = DEFAULT_URI.split("/")
uri_base = "/".join(uri_parts[:-1])
query_params = ""
if "?" in uri_parts[-1]:
db_name, query_params = uri_parts[-1].split("?", 1)
query_params = "?" + query_params
conn_string = f"{uri_base}/{database}{query_params}"
admin_conn_string = DEFAULT_URI
index_config = {
"dims": fake_embeddings.dims,
"embed": fake_embeddings,
"ann_index_config": {
"vector_type": vector_type,
},
"distance_type": distance_type,
"text_fields": text_fields,
}
with Connection.connect(admin_conn_string, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
with PostgresStore.from_conn_string(
conn_string,
index=index_config,
) as store:
store.setup()
yield store
finally:
with Connection.connect(admin_conn_string, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(
scope="function",
params=[
(vector_type, distance_type)
for vector_type in VECTOR_TYPES
for distance_type in (
["hamming"] if vector_type == "bit" else ["l2", "inner_product", "cosine"]
)
],
ids=lambda p: f"{p[0]}_{p[1]}",
)
def vector_store(
request,
fake_embeddings: Embeddings,
) -> PostgresStore:
"""Create a store with vector search enabled."""
vector_type, distance_type = request.param
with _create_vector_store(vector_type, distance_type, fake_embeddings) as store:
yield store
def test_vector_store_initialization(
vector_store: PostgresStore, fake_embeddings: CharacterEmbeddings
) -> None:
"""Test store initialization with embedding config."""
# Store should be initialized with embedding config
assert vector_store.index_config is not None
assert vector_store.index_config["dims"] == fake_embeddings.dims
assert vector_store.index_config["embed"] == fake_embeddings
def test_vector_insert_with_auto_embedding(vector_store: PostgresStore) -> None:
"""Test inserting items that get auto-embedded."""
docs = [
("doc1", {"text": "short text"}),
("doc2", {"text": "longer text document"}),
("doc3", {"text": "longest text document here"}),
("doc4", {"description": "text in description field"}),
("doc5", {"content": "text in content field"}),
("doc6", {"body": "text in body field"}),
]
for key, value in docs:
vector_store.put(("test",), key, value)
results = vector_store.search(("test",), query="long text")
assert len(results) > 0
doc_order = [r.key for r in results]
assert "doc2" in doc_order
assert "doc3" in doc_order
def test_vector_update_with_embedding(vector_store: PostgresStore) -> None:
"""Test that updating items properly updates their embeddings."""
vector_store.put(("test",), "doc1", {"text": "zany zebra Xerxes"})
vector_store.put(("test",), "doc2", {"text": "something about dogs"})
vector_store.put(("test",), "doc3", {"text": "text about birds"})
results_initial = vector_store.search(("test",), query="Zany Xerxes")
assert len(results_initial) > 0
assert results_initial[0].key == "doc1"
initial_score = results_initial[0].score
vector_store.put(("test",), "doc1", {"text": "new text about dogs"})
results_after = vector_store.search(("test",), query="Zany Xerxes")
after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0)
assert after_score < initial_score
results_new = vector_store.search(("test",), query="new text about dogs")
for r in results_new:
if r.key == "doc1":
assert r.score > after_score
# Don't index this one
vector_store.put(("test",), "doc4", {"text": "new text about dogs"}, index=False)
results_new = vector_store.search(("test",), query="new text about dogs", limit=3)
assert not any(r.key == "doc4" for r in results_new)
def test_vector_search_with_filters(vector_store: PostgresStore) -> None:
"""Test combining vector search with filters."""
# Insert test documents
docs = [
("doc1", {"text": "red apple", "color": "red", "score": 4.5}),
("doc2", {"text": "red car", "color": "red", "score": 3.0}),
("doc3", {"text": "green apple", "color": "green", "score": 4.0}),
("doc4", {"text": "blue car", "color": "blue", "score": 3.5}),
]
for key, value in docs:
vector_store.put(("test",), key, value)
results = vector_store.search(("test",), query="apple", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc1"
results = vector_store.search(("test",), query="car", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc2"
results = vector_store.search(
("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}}
)
assert len(results) == 3
assert results[0].key == "doc4"
# Multiple filters
results = vector_store.search(
("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"}
)
assert len(results) == 1
assert results[0].key == "doc3"
def test_vector_search_pagination(vector_store: PostgresStore) -> None:
"""Test pagination with vector search."""
# Insert multiple similar documents
for i in range(5):
vector_store.put(("test",), f"doc{i}", {"text": f"test document number {i}"})
# Test with different page sizes
results_page1 = vector_store.search(("test",), query="test", limit=2)
results_page2 = vector_store.search(("test",), query="test", limit=2, offset=2)
assert len(results_page1) == 2
assert len(results_page2) == 2
assert results_page1[0].key != results_page2[0].key
# Get all results
all_results = vector_store.search(("test",), query="test", limit=10)
assert len(all_results) == 5
def test_vector_search_edge_cases(vector_store: PostgresStore) -> None:
"""Test edge cases in vector search."""
vector_store.put(("test",), "doc1", {"text": "test document"})
results = vector_store.search(("test",), query="")
assert len(results) == 1
results = vector_store.search(("test",), query=None)
assert len(results) == 1
long_query = "test " * 100
results = vector_store.search(("test",), query=long_query)
assert len(results) == 1
special_query = "test!@#$%^&*()"
results = vector_store.search(("test",), query=special_query)
assert len(results) == 1
@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "inner_product"),
("halfvec", "cosine"),
("halfvec", "inner_product"),
],
)
def test_embed_with_path_sync(
request: Any,
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
) -> None:
"""Test vector search with specific text fields in Postgres store."""
with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key0", "key1", "key3"],
) as store:
# This will have 2 vectors representing it
doc1 = {
# Omit key0 - check it doesn't raise an error
"key1": "xxx",
"key2": "yyy",
"key3": "zzz",
}
# This will have 3 vectors representing it
doc2 = {
"key0": "uuu",
"key1": "vvv",
"key2": "www",
"key3": "xxx",
}
store.put(("test",), "doc1", doc1)
store.put(("test",), "doc2", doc2)
# doc2.key3 and doc1.key1 both would have the highest score
results = store.search(("test",), query="xxx")
assert len(results) == 2
assert results[0].key != results[1].key
ascore = results[0].score
bscore = results[1].score
assert ascore == pytest.approx(bscore, abs=1e-3)
# ~Only match doc2
results = store.search(("test",), query="uuu")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].key == "doc2"
assert results[0].score > results[1].score
assert ascore == pytest.approx(results[0].score, abs=1e-3)
# ~Only match doc1
results = store.search(("test",), query="zzz")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].key == "doc1"
assert results[0].score > results[1].score
assert ascore == pytest.approx(results[0].score, abs=1e-3)
# Un-indexed - will have low results for both. Not zero (because we're projecting)
# but less than the above.
results = store.search(("test",), query="www")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].score < ascore
assert results[1].score < ascore
@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "inner_product"),
("halfvec", "cosine"),
("halfvec", "inner_product"),
],
)
def test_embed_with_path_operation_config(
request: Any,
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
) -> None:
"""Test operation-level field configuration for vector search."""
with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key17"], # Default fields that won't match our test data
) as store:
doc3 = {
"key0": "aaa",
"key1": "bbb",
"key2": "ccc",
"key3": "ddd",
}
doc4 = {
"key0": "eee",
"key1": "bbb", # Same as doc3.key1
"key2": "fff",
"key3": "ggg",
}
store.put(("test",), "doc3", doc3, index=["key0", "key1"])
store.put(("test",), "doc4", doc4, index=["key1", "key3"])
results = store.search(("test",), query="aaa")
assert len(results) == 2
assert results[0].key == "doc3"
assert len(set(r.key for r in results)) == 2
assert results[0].score > results[1].score
results = store.search(("test",), query="ggg")
assert len(results) == 2
assert results[0].key == "doc4"
assert results[0].score > results[1].score
results = store.search(("test",), query="bbb")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].score == pytest.approx(results[1].score, abs=1e-3)
results = store.search(("test",), query="ccc")
assert len(results) == 2
assert all(
r.score < 0.9 for r in results
) # Unindexed field should have low scores
# Test index=False behavior
doc5 = {
"key0": "hhh",
"key1": "iii",
}
store.put(("test",), "doc5", doc5, index=False)
results = store.search(("test",))
assert len(results) == 3
assert all(r.score is None for r in results)
assert any(r.key == "doc5" for r in results)
results = store.search(("test",), query="hhh")
# TODO: We don't currently fill in additional results if there are not enough
# returned during vector search.
# assert len(results) == 3
# doc5_result = next(r for r in results if r.key == "doc5")
# assert doc5_result.score is None
def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute cosine similarity between a vector X and each row of a matrix Y.
Pure-Python implementation (no numpy dependency).
"""
similarities = []
for y in Y:
dot_product = sum(a * b for a, b in zip(X, y))
norm1 = sum(a * a for a in X) ** 0.5
norm2 = sum(a * a for a in y) ** 0.5
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
similarities.append(similarity)
return similarities
def _inner_product(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute inner product between a vector X and each row of a matrix Y.
Pure-Python implementation (no numpy dependency).
"""
similarities = []
for y in Y:
similarity = sum(a * b for a, b in zip(X, y))
similarities.append(similarity)
return similarities
def _neg_l2_distance(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute the negative L2 distance between a vector X and each row of a matrix Y.
Pure-Python implementation (no numpy dependency).
"""
similarities = []
for y in Y:
similarity = sum((a - b) ** 2 for a, b in zip(X, y)) ** 0.5
similarities.append(-similarity)
return similarities
@pytest.mark.parametrize(
"vector_type,distance_type",
[
("vector", "cosine"),
("vector", "inner_product"),
("halfvec", "l2"),
],
)
@pytest.mark.parametrize("query", ["aaa", "bbb", "ccc", "abcd", "poisson"])
def test_scores(
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
query: str,
) -> None:
"""Test operation-level field configuration for vector search."""
with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key0"],
) as store:
doc = {
"key0": "aaa",
}
store.put(("test",), "doc", doc, index=["key0", "key1"])
results = store.search((), query=query)
vec0 = fake_embeddings.embed_query(doc["key0"])
vec1 = fake_embeddings.embed_query(query)
if distance_type == "cosine":
similarities = _cosine_similarity(vec1, [vec0])
elif distance_type == "inner_product":
similarities = _inner_product(vec1, [vec0])
elif distance_type == "l2":
similarities = _neg_l2_distance(vec1, [vec0])
assert len(results) == 1
assert results[0].score == pytest.approx(similarities[0], abs=1e-3)
```
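For reference, a minimal synchronous sketch of the `PostgresStore` operations the tests above exercise (`put`, `get`, `search`, `list_namespaces`, `delete`); the connection string is a placeholder assumption.
```py
# Sketch only: assumes a reachable Postgres at the placeholder connection string.
from langgraph.store.postgres import PostgresStore

CONN_STRING = "postgres://postgres:postgres@localhost:5432/postgres"  # placeholder

with PostgresStore.from_conn_string(CONN_STRING) as store:
    store.setup()  # create the store tables on first use
    namespace = ("docs", "public")
    store.put(namespace, "doc1", {"title": "Test Document", "author": "Alice"})
    item = store.get(namespace, "doc1")
    assert item is not None and item.value["author"] == "Alice"
    # filtered search within a namespace prefix
    results = store.search(("docs",), filter={"author": "Alice"})
    print([r.key for r in results])
    # enumerate namespaces, then clean up
    print(store.list_namespaces(prefix=["docs"]))
    store.delete(namespace, "doc1")
```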
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/embed_test_utils.py`:
```py
"""Embedding utilities for testing."""
import math
import random
from collections import Counter, defaultdict
from typing import Any
from langchain_core.embeddings import Embeddings
class CharacterEmbeddings(Embeddings):
"""Simple character-frequency based embeddings using random projections."""
def __init__(self, dims: int = 50, seed: int = 42):
"""Initialize with embedding dimensions and random seed."""
self._rng = random.Random(seed)
self.dims = dims
# Create projection vector for each character lazily
self._char_projections: defaultdict[str, list[float]] = defaultdict(
lambda: [
self._rng.gauss(0, 1 / math.sqrt(self.dims)) for _ in range(self.dims)
]
)
def _embed_one(self, text: str) -> list[float]:
"""Embed a single text."""
counts = Counter(text)
total = sum(counts.values())
if total == 0:
return [0.0] * self.dims
embedding = [0.0] * self.dims
for char, count in counts.items():
weight = count / total
char_proj = self._char_projections[char]
for i, proj in enumerate(char_proj):
embedding[i] += weight * proj
norm = math.sqrt(sum(x * x for x in embedding))
if norm > 0:
embedding = [x / norm for x in embedding]
return embedding
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of documents."""
return [self._embed_one(text) for text in texts]
def embed_query(self, text: str) -> list[float]:
"""Embed a query string."""
return self._embed_one(text)
def __eq__(self, other: Any) -> bool:
return isinstance(other, CharacterEmbeddings) and self.dims == other.dims
```
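A quick sanity check of `CharacterEmbeddings` behaviour, assuming the module is importable as `tests.embed_test_utils`: embeddings are unit-normalized and deterministic for a fixed seed, so identical texts embed identically while different character mixes score lower.
```py
# Sketch only: exercises the CharacterEmbeddings class defined above.
from tests.embed_test_utils import CharacterEmbeddings


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


emb = CharacterEmbeddings(dims=50, seed=42)
v1 = emb.embed_query("aaa")
v2 = emb.embed_query("aaa")
v3 = emb.embed_query("zzz")

assert abs(dot(v1, v1) - 1.0) < 1e-6  # embeddings are normalized to unit length
assert v1 == v2                       # deterministic for a fixed seed
assert dot(v1, v3) < 1.0              # different character mix -> (almost surely) lower similarity
```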
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-postgres/tests/test_async_store.py`:
```py
# type: ignore
import itertools
import sys
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Optional
import pytest
from langchain_core.embeddings import Embeddings
from psycopg import AsyncConnection
from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.postgres import AsyncPostgresStore
from tests.conftest import (
DEFAULT_URI,
VECTOR_TYPES,
CharacterEmbeddings,
)
@pytest.fixture(scope="function", params=["default", "pipe", "pool"])
async def store(request) -> AsyncIterator[AsyncPostgresStore]:
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid.uuid4().hex[:16]}"
uri_parts = DEFAULT_URI.split("/")
uri_base = "/".join(uri_parts[:-1])
query_params = ""
if "?" in uri_parts[-1]:
db_name, query_params = uri_parts[-1].split("?", 1)
query_params = "?" + query_params
conn_string = f"{uri_base}/{database}{query_params}"
admin_conn_string = DEFAULT_URI
async with await AsyncConnection.connect(
admin_conn_string, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with AsyncPostgresStore.from_conn_string(conn_string) as store:
await store.setup()
if request.param == "pipe":
async with AsyncPostgresStore.from_conn_string(
conn_string, pipeline=True
) as store:
yield store
elif request.param == "pool":
async with AsyncPostgresStore.from_conn_string(
conn_string, pool_config={"min_size": 1, "max_size": 10}
) as store:
yield store
else: # default
async with AsyncPostgresStore.from_conn_string(conn_string) as store:
yield store
finally:
async with await AsyncConnection.connect(
admin_conn_string, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
async def test_abatch_order(store: AsyncPostgresStore) -> None:
# Setup test data
await store.aput(("test", "foo"), "key1", {"data": "value1"})
await store.aput(("test", "bar"), "key2", {"data": "value2"})
ops = [
GetOp(namespace=("test", "foo"), key="key1"),
PutOp(namespace=("test", "bar"), key="key2", value={"data": "value2"}),
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
GetOp(namespace=("test",), key="key3"),
]
results = await store.abatch(ops)
assert len(results) == 5
assert isinstance(results[0], Item)
assert isinstance(results[0].value, dict)
assert results[0].value == {"data": "value1"}
assert results[0].key == "key1"
assert results[1] is None
assert isinstance(results[2], list)
assert len(results[2]) == 1
assert isinstance(results[3], list)
assert ("test", "foo") in results[3] and ("test", "bar") in results[3]
assert results[4] is None
ops_reordered = [
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
GetOp(namespace=("test", "bar"), key="key2"),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0),
PutOp(namespace=("test",), key="key3", value={"data": "value3"}),
GetOp(namespace=("test", "foo"), key="key1"),
]
results_reordered = await store.abatch(ops_reordered)
assert len(results_reordered) == 5
assert isinstance(results_reordered[0], list)
assert len(results_reordered[0]) == 2
assert isinstance(results_reordered[1], Item)
assert results_reordered[1].value == {"data": "value2"}
assert results_reordered[1].key == "key2"
assert isinstance(results_reordered[2], list)
assert ("test", "foo") in results_reordered[2] and (
"test",
"bar",
) in results_reordered[2]
assert results_reordered[3] is None
assert isinstance(results_reordered[4], Item)
assert results_reordered[4].value == {"data": "value1"}
assert results_reordered[4].key == "key1"
async def test_batch_get_ops(store: AsyncPostgresStore) -> None:
# Setup test data
await store.aput(("test",), "key1", {"data": "value1"})
await store.aput(("test",), "key2", {"data": "value2"})
ops = [
GetOp(namespace=("test",), key="key1"),
GetOp(namespace=("test",), key="key2"),
GetOp(namespace=("test",), key="key3"),
]
results = await store.abatch(ops)
assert len(results) == 3
assert results[0] is not None
assert results[1] is not None
assert results[2] is None
assert results[0].key == "key1"
assert results[1].key == "key2"
async def test_batch_put_ops(store: AsyncPostgresStore) -> None:
ops = [
PutOp(namespace=("test",), key="key1", value={"data": "value1"}),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
PutOp(namespace=("test",), key="key3", value=None),
]
results = await store.abatch(ops)
assert len(results) == 3
assert all(result is None for result in results)
# Verify the puts worked
items = await store.asearch(["test"], limit=10)
assert len(items) == 2 # key3 had None value so wasn't stored
async def test_batch_search_ops(store: AsyncPostgresStore) -> None:
# Setup test data
await store.aput(("test", "foo"), "key1", {"data": "value1"})
await store.aput(("test", "bar"), "key2", {"data": "value2"})
ops = [
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
]
results = await store.abatch(ops)
assert len(results) == 2
assert len(results[0]) == 1 # Filtered results
assert len(results[1]) == 2 # All results
async def test_batch_list_namespaces_ops(store: AsyncPostgresStore) -> None:
# Setup test data
await store.aput(("test", "namespace1"), "key1", {"data": "value1"})
await store.aput(("test", "namespace2"), "key2", {"data": "value2"})
ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)]
results = await store.abatch(ops)
assert len(results) == 1
assert len(results[0]) == 2
assert ("test", "namespace1") in results[0]
assert ("test", "namespace2") in results[0]
@asynccontextmanager
async def _create_vector_store(
vector_type: str,
distance_type: str,
fake_embeddings: CharacterEmbeddings,
text_fields: Optional[list[str]] = None,
) -> AsyncIterator[AsyncPostgresStore]:
"""Create a store with vector search enabled."""
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid.uuid4().hex[:16]}"
uri_parts = DEFAULT_URI.split("/")
uri_base = "/".join(uri_parts[:-1])
query_params = ""
if "?" in uri_parts[-1]:
db_name, query_params = uri_parts[-1].split("?", 1)
query_params = "?" + query_params
conn_string = f"{uri_base}/{database}{query_params}"
admin_conn_string = DEFAULT_URI
index_config = {
"dims": fake_embeddings.dims,
"embed": fake_embeddings,
"ann_index_config": {
"vector_type": vector_type,
},
"distance_type": distance_type,
"text_fields": text_fields,
}
async with await AsyncConnection.connect(
admin_conn_string, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with AsyncPostgresStore.from_conn_string(
conn_string,
index=index_config,
) as store:
await store.setup()
yield store
finally:
async with await AsyncConnection.connect(
admin_conn_string, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(
scope="function",
params=[
(vector_type, distance_type)
for vector_type in VECTOR_TYPES
for distance_type in (
["hamming"] if vector_type == "bit" else ["l2", "inner_product", "cosine"]
)
],
ids=lambda p: f"{p[0]}_{p[1]}",
)
async def vector_store(
request,
fake_embeddings: CharacterEmbeddings,
) -> AsyncIterator[AsyncPostgresStore]:
"""Create a store with vector search enabled."""
vector_type, distance_type = request.param
async with _create_vector_store(
vector_type, distance_type, fake_embeddings
) as store:
yield store
async def test_vector_store_initialization(
vector_store: AsyncPostgresStore, fake_embeddings: CharacterEmbeddings
) -> None:
"""Test store initialization with embedding config."""
assert vector_store.index_config is not None
assert vector_store.index_config["dims"] == fake_embeddings.dims
if isinstance(vector_store.index_config["embed"], Embeddings):
assert vector_store.index_config["embed"] == fake_embeddings
async def test_vector_insert_with_auto_embedding(
vector_store: AsyncPostgresStore,
) -> None:
"""Test inserting items that get auto-embedded."""
docs = [
("doc1", {"text": "short text"}),
("doc2", {"text": "longer text document"}),
("doc3", {"text": "longest text document here"}),
("doc4", {"description": "text in description field"}),
("doc5", {"content": "text in content field"}),
("doc6", {"body": "text in body field"}),
]
for key, value in docs:
await vector_store.aput(("test",), key, value)
results = await vector_store.asearch(("test",), query="long text")
assert len(results) > 0
doc_order = [r.key for r in results]
assert "doc2" in doc_order
assert "doc3" in doc_order
async def test_vector_update_with_embedding(vector_store: AsyncPostgresStore) -> None:
"""Test that updating items properly updates their embeddings."""
await vector_store.aput(("test",), "doc1", {"text": "zany zebra Xerxes"})
await vector_store.aput(("test",), "doc2", {"text": "something about dogs"})
await vector_store.aput(("test",), "doc3", {"text": "text about birds"})
results_initial = await vector_store.asearch(("test",), query="Zany Xerxes")
assert len(results_initial) > 0
assert results_initial[0].key == "doc1"
initial_score = results_initial[0].score
await vector_store.aput(("test",), "doc1", {"text": "new text about dogs"})
results_after = await vector_store.asearch(("test",), query="Zany Xerxes")
after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0)
assert after_score < initial_score
results_new = await vector_store.asearch(("test",), query="new text about dogs")
for r in results_new:
if r.key == "doc1":
assert r.score > after_score
# Don't index this one
await vector_store.aput(
("test",), "doc4", {"text": "new text about dogs"}, index=False
)
results_new = await vector_store.asearch(
("test",), query="new text about dogs", limit=3
)
assert not any(r.key == "doc4" for r in results_new)
async def test_vector_search_with_filters(vector_store: AsyncPostgresStore) -> None:
"""Test combining vector search with filters."""
docs = [
("doc1", {"text": "red apple", "color": "red", "score": 4.5}),
("doc2", {"text": "red car", "color": "red", "score": 3.0}),
("doc3", {"text": "green apple", "color": "green", "score": 4.0}),
("doc4", {"text": "blue car", "color": "blue", "score": 3.5}),
]
for key, value in docs:
await vector_store.aput(("test",), key, value)
results = await vector_store.asearch(
("test",), query="apple", filter={"color": "red"}
)
assert len(results) == 2
assert results[0].key == "doc1"
results = await vector_store.asearch(
("test",), query="car", filter={"color": "red"}
)
assert len(results) == 2
assert results[0].key == "doc2"
results = await vector_store.asearch(
("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}}
)
assert len(results) == 3
assert results[0].key == "doc4"
results = await vector_store.asearch(
("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"}
)
assert len(results) == 1
assert results[0].key == "doc3"
async def test_vector_search_pagination(vector_store: AsyncPostgresStore) -> None:
"""Test pagination with vector search."""
for i in range(5):
await vector_store.aput(
("test",), f"doc{i}", {"text": f"test document number {i}"}
)
results_page1 = await vector_store.asearch(("test",), query="test", limit=2)
results_page2 = await vector_store.asearch(
("test",), query="test", limit=2, offset=2
)
assert len(results_page1) == 2
assert len(results_page2) == 2
assert results_page1[0].key != results_page2[0].key
all_results = await vector_store.asearch(("test",), query="test", limit=10)
assert len(all_results) == 5
async def test_vector_search_edge_cases(vector_store: AsyncPostgresStore) -> None:
"""Test edge cases in vector search."""
await vector_store.aput(("test",), "doc1", {"text": "test document"})
perfect_match = await vector_store.asearch(("test",), query="text test document")
perfect_score = perfect_match[0].score
results = await vector_store.asearch(("test",), query="")
assert len(results) == 1
assert results[0].score is None
results = await vector_store.asearch(("test",), query=None)
assert len(results) == 1
assert results[0].score is None
long_query = "foo " * 100
results = await vector_store.asearch(("test",), query=long_query)
assert len(results) == 1
assert results[0].score < perfect_score
special_query = "test!@#$%^&*()"
results = await vector_store.asearch(("test",), query=special_query)
assert len(results) == 1
assert results[0].score < perfect_score
@pytest.mark.parametrize(
"vector_type,distance_type",
[
*itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]),
],
)
async def test_embed_with_path(
request: Any,
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
) -> None:
"""Test vector search with specific text fields in Postgres store."""
async with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key0", "key1", "key3"],
) as store:
# This will have 2 vectors representing it
doc1 = {
# Omit key0 - check it doesn't raise an error
"key1": "xxx",
"key2": "yyy",
"key3": "zzz",
}
# This will have 3 vectors representing it
doc2 = {
"key0": "uuu",
"key1": "vvv",
"key2": "www",
"key3": "xxx",
}
await store.aput(("test",), "doc1", doc1)
await store.aput(("test",), "doc2", doc2)
# doc2.key3 and doc1.key1 both would have the highest score
results = await store.asearch(("test",), query="xxx")
assert len(results) == 2
assert results[0].key != results[1].key
ascore = results[0].score
bscore = results[1].score
assert ascore == pytest.approx(bscore, abs=1e-3)
results = await store.asearch(("test",), query="uuu")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].key == "doc2"
assert results[0].score > results[1].score
assert ascore == pytest.approx(results[0].score, abs=1e-3)
# Un-indexed - will have low results for both. Not zero (because we're projecting)
# but less than the above.
results = await store.asearch(("test",), query="www")
assert len(results) == 2
assert results[0].score < ascore
assert results[1].score < ascore
@pytest.mark.parametrize(
"vector_type,distance_type",
[
*itertools.product(["vector", "halfvec"], ["cosine", "inner_product", "l2"]),
],
)
async def test_search_sorting(
request: Any,
fake_embeddings: CharacterEmbeddings,
vector_type: str,
distance_type: str,
) -> None:
"""Test operation-level field configuration for vector search."""
async with _create_vector_store(
vector_type,
distance_type,
fake_embeddings,
text_fields=["key1"], # Default fields that won't match our test data
) as store:
amatch = {
"key1": "mmm",
}
await store.aput(("test", "M"), "M", amatch)
N = 100
for i in range(N):
await store.aput(("test", "A"), f"A{i}", {"key1": "no"})
for i in range(N):
await store.aput(("test", "Z"), f"Z{i}", {"key1": "no"})
results = await store.asearch(("test",), query="mmm", limit=10)
assert len(results) == 10
assert len(set(r.key for r in results)) == 10
assert results[0].key == "M"
assert results[0].score > results[1].score
```
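A minimal async sketch of the `AsyncPostgresStore` batch API mirrored from `test_abatch_order` above: one result per op, with puts returning `None` and searches/namespace listings returning lists. The connection string is a placeholder assumption.
```py
# Sketch only: assumes a reachable Postgres at the placeholder connection string.
import asyncio

from langgraph.store.base import GetOp, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.postgres import AsyncPostgresStore

CONN_STRING = "postgres://postgres:postgres@localhost:5432/postgres"  # placeholder


async def main() -> None:
    async with AsyncPostgresStore.from_conn_string(CONN_STRING) as store:
        await store.setup()
        await store.aput(("test", "foo"), "key1", {"data": "value1"})
        ops = [
            GetOp(namespace=("test", "foo"), key="key1"),
            PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
            SearchOp(namespace_prefix=("test",), filter=None, limit=10, offset=0),
            ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
        ]
        results = await store.abatch(ops)
        # results[0] is an Item, results[1] is None, results[2] and results[3] are lists
        print(results)


asyncio.run(main())
```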
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/runner.py`:
```py
import asyncio
import concurrent.futures
import time
from typing import (
Any,
AsyncIterator,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
Type,
Union,
cast,
)
from langgraph.constants import (
CONF,
CONFIG_KEY_SEND,
ERROR,
INTERRUPT,
NO_WRITES,
PUSH,
RESUME,
TAG_HIDDEN,
)
from langgraph.errors import GraphBubbleUp, GraphInterrupt
from langgraph.pregel.executor import Submit
from langgraph.pregel.retry import arun_with_retry, run_with_retry
from langgraph.types import PregelExecutableTask, RetryPolicy
class PregelRunner:
"""Responsible for executing a set of Pregel tasks concurrently, committing
their writes, yielding control to the caller when there is output to emit, and
interrupting other tasks if appropriate."""
def __init__(
self,
*,
submit: Submit,
put_writes: Callable[[str, Sequence[tuple[str, Any]]], None],
schedule_task: Callable[
[PregelExecutableTask, int], Optional[PregelExecutableTask]
],
use_astream: bool = False,
node_finished: Optional[Callable[[str], None]] = None,
) -> None:
self.submit = submit
self.put_writes = put_writes
self.use_astream = use_astream
self.node_finished = node_finished
self.schedule_task = schedule_task
def tick(
self,
tasks: Iterable[PregelExecutableTask],
*,
reraise: bool = True,
timeout: Optional[float] = None,
retry_policy: Optional[RetryPolicy] = None,
get_waiter: Optional[Callable[[], concurrent.futures.Future[None]]] = None,
) -> Iterator[None]:
def writer(
task: PregelExecutableTask, writes: Sequence[tuple[str, Any]]
) -> None:
prev_length = len(task.writes)
# delegate to the underlying writer
task.config[CONF][CONFIG_KEY_SEND](writes)
for idx, w in enumerate(task.writes):
# find the index for the newly inserted writes
if idx < prev_length:
continue
assert writes[idx - prev_length] is w
# bail if not a PUSH write
if w[0] != PUSH:
continue
# schedule the next task, if the callback returns one
if next_task := self.schedule_task(task, idx):
# if the parent task was retried,
# the next task might already be running
if any(
t.id == next_task.id for t in futures.values() if t is not None
):
continue
# schedule the next task
futures[
self.submit(
run_with_retry,
next_task,
retry_policy,
writer=writer,
__reraise_on_exit__=reraise,
)
] = next_task
tasks = tuple(tasks)
futures: dict[concurrent.futures.Future, Optional[PregelExecutableTask]] = {}
# give control back to the caller
yield
# fast path if single task with no timeout and no waiter
if len(tasks) == 1 and timeout is None and get_waiter is None:
t = tasks[0]
try:
run_with_retry(t, retry_policy, writer=writer)
self.commit(t, None)
except Exception as exc:
self.commit(t, exc)
if reraise:
raise
if not futures: # maybe `t` scheduled another task
return
# add waiter task if requested
if get_waiter is not None:
futures[get_waiter()] = None
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
for t in tasks:
if not t.writes:
futures[
self.submit(
run_with_retry,
t,
retry_policy,
writer=writer,
__reraise_on_exit__=reraise,
)
] = t
done_futures: set[concurrent.futures.Future] = set()
end_time = timeout + time.monotonic() if timeout else None
while len(futures) > (1 if get_waiter is not None else 0):
done, inflight = concurrent.futures.wait(
futures,
return_when=concurrent.futures.FIRST_COMPLETED,
timeout=(max(0, end_time - time.monotonic()) if end_time else None),
)
if not done:
break # timed out
for fut in done:
task = futures.pop(fut)
if task is None:
# waiter task finished, schedule another
if inflight and get_waiter is not None:
futures[get_waiter()] = None
else:
# store for panic check
done_futures.add(fut)
# task finished, commit writes
self.commit(task, _exception(fut))
else:
# remove references to loop vars
del fut, task
# maybe stop other tasks
if _should_stop_others(done):
break
# give control back to the caller
yield
# panic on failure or timeout
_panic_or_proceed(
done_futures.union(f for f, t in futures.items() if t is not None),
panic=reraise,
)
async def atick(
self,
tasks: Iterable[PregelExecutableTask],
*,
reraise: bool = True,
timeout: Optional[float] = None,
retry_policy: Optional[RetryPolicy] = None,
get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None,
) -> AsyncIterator[None]:
def writer(
task: PregelExecutableTask, writes: Sequence[tuple[str, Any]]
) -> None:
prev_length = len(task.writes)
# delegate to the underlying writer
task.config[CONF][CONFIG_KEY_SEND](writes)
for idx, w in enumerate(task.writes):
# find the index for the newly inserted writes
if idx < prev_length:
continue
assert writes[idx - prev_length] is w
# bail if not a PUSH write
if w[0] != PUSH:
continue
# schedule the next task, if the callback returns one
if next_task := self.schedule_task(task, idx):
# if the parent task was retried,
# the next task might already be running
if any(
t.id == next_task.id for t in futures.values() if t is not None
):
continue
# schedule the next task
futures[
cast(
asyncio.Future,
self.submit(
arun_with_retry,
next_task,
retry_policy,
stream=self.use_astream,
writer=writer,
__name__=next_task.name,
__cancel_on_exit__=True,
__reraise_on_exit__=reraise,
),
)
] = next_task
loop = asyncio.get_event_loop()
tasks = tuple(tasks)
futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = {}
# give control back to the caller
yield
# fast path if single task with no waiter and no timeout
if len(tasks) == 1 and get_waiter is None and timeout is None:
t = tasks[0]
try:
await arun_with_retry(
t, retry_policy, stream=self.use_astream, writer=writer
)
self.commit(t, None)
except Exception as exc:
self.commit(t, exc)
if reraise:
raise
if not futures: # maybe `t` scheduled another task
return
# add waiter task if requested
if get_waiter is not None:
futures[get_waiter()] = None
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
for t in tasks:
if not t.writes:
futures[
cast(
asyncio.Future,
self.submit(
arun_with_retry,
t,
retry_policy,
stream=self.use_astream,
writer=writer,
__name__=t.name,
__cancel_on_exit__=True,
__reraise_on_exit__=reraise,
),
)
] = t
done_futures: set[asyncio.Future] = set()
end_time = timeout + loop.time() if timeout else None
while len(futures) > (1 if get_waiter is not None else 0):
done, inflight = await asyncio.wait(
futures,
return_when=asyncio.FIRST_COMPLETED,
timeout=(max(0, end_time - loop.time()) if end_time else None),
)
if not done:
break # timed out
for fut in done:
task = futures.pop(fut)
if task is None:
# waiter task finished, schedule another
if inflight and get_waiter is not None:
futures[get_waiter()] = None
else:
# store for panic check
done_futures.add(fut)
# task finished, commit writes
self.commit(task, _exception(fut))
else:
# remove references to loop vars
del fut, task
# maybe stop other tasks
if _should_stop_others(done):
break
# give control back to the caller
yield
# cancel waiter task
for fut in futures:
fut.cancel()
# panic on failure or timeout
_panic_or_proceed(
done_futures.union(f for f, t in futures.items() if t is not None),
timeout_exc_cls=asyncio.TimeoutError,
panic=reraise,
)
def commit(
self, task: PregelExecutableTask, exception: Optional[BaseException]
) -> None:
if exception:
if isinstance(exception, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
if resumes := [w for w in task.writes if w[0] == RESUME]:
interrupts.extend(resumes)
self.put_writes(task.id, interrupts)
elif isinstance(exception, GraphBubbleUp):
raise exception
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exception)])
else:
if self.node_finished and (
task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
):
self.node_finished(task.name)
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)
def _should_stop_others(
done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Future[Any]]],
) -> bool:
"""Check if any task failed, if so, cancel all other tasks.
GraphInterrupts are not considered failures."""
for fut in done:
if fut.cancelled():
return True
if exc := fut.exception():
return not isinstance(exc, GraphBubbleUp)
else:
return False
def _exception(
fut: Union[concurrent.futures.Future[Any], asyncio.Future[Any]],
) -> Optional[BaseException]:
"""Return the exception from a future, without raising CancelledError."""
if fut.cancelled():
if isinstance(fut, asyncio.Future):
return asyncio.CancelledError()
else:
return concurrent.futures.CancelledError()
else:
return fut.exception()
def _panic_or_proceed(
futs: Union[set[concurrent.futures.Future], set[asyncio.Future]],
*,
timeout_exc_cls: Type[Exception] = TimeoutError,
panic: bool = True,
) -> None:
"""Cancel remaining tasks if any failed, re-raise exception if panic is True."""
done: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set()
inflight: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set()
for fut in futs:
if fut.done():
done.add(fut)
else:
inflight.add(fut)
while done:
# if any task failed
if exc := _exception(done.pop()):
# cancel all pending tasks
while inflight:
inflight.pop().cancel()
# raise the exception
if panic:
raise exc
else:
return
if inflight:
# if we got here means we timed out
while inflight:
# cancel all pending tasks
inflight.pop().cancel()
# raise timeout error
raise timeout_exc_cls("Timed out")
```
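A standalone illustration (plain `concurrent.futures`, not langgraph API) of the control flow `PregelRunner.tick` implements: submit tasks, wait with `FIRST_COMPLETED`, "commit" each finished future as it completes, and stop scheduling once any task raises. Task names and delays are illustrative.
```py
# Sketch of the wait/commit loop pattern used by PregelRunner.tick.
import concurrent.futures
import time


def task(name: str, delay: float, fail: bool = False) -> str:
    time.sleep(delay)
    if fail:
        raise RuntimeError(f"{name} failed")
    return f"{name} done"


with concurrent.futures.ThreadPoolExecutor() as pool:
    futures = {
        pool.submit(task, "a", 0.1): "a",
        pool.submit(task, "b", 0.2, True): "b",
        pool.submit(task, "c", 0.5): "c",
    }
    while futures:
        done, _ = concurrent.futures.wait(
            futures, return_when=concurrent.futures.FIRST_COMPLETED
        )
        failed = False
        for fut in done:
            name = futures.pop(fut)
            exc = fut.exception()
            print(name, "->", exc or fut.result())  # "commit" the outcome
            failed = failed or exc is not None
        if failed:
            for fut in futures:  # cancel whatever has not started yet
                fut.cancel()
            break
```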
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/write.py`:
```py
from __future__ import annotations
from typing import (
Any,
Callable,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langgraph.constants import CONF, CONFIG_KEY_SEND, FF_SEND_V2, PUSH, TASKS, Send
from langgraph.errors import InvalidUpdateError
from langgraph.utils.runnable import RunnableCallable
TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None]
R = TypeVar("R", bound=Runnable)
SKIP_WRITE = object()
PASSTHROUGH = object()
class ChannelWriteEntry(NamedTuple):
channel: str
"""Channel name to write to."""
value: Any = PASSTHROUGH
"""Value to write, or PASSTHROUGH to use the input."""
skip_none: bool = False
"""Whether to skip writing if the value is None."""
mapper: Optional[Callable] = None
"""Function to transform the value before writing."""
class ChannelWrite(RunnableCallable):
"""Implements th logic for sending writes to CONFIG_KEY_SEND.
Can be used as a runnable or as a static method to call imperatively."""
writes: list[Union[ChannelWriteEntry, Send]]
"""Sequence of write entries or Send objects to write."""
require_at_least_one_of: Optional[Sequence[str]]
"""If defined, at least one of these channels must be written to."""
def __init__(
self,
writes: Sequence[Union[ChannelWriteEntry, Send]],
*,
tags: Optional[Sequence[str]] = None,
require_at_least_one_of: Optional[Sequence[str]] = None,
):
super().__init__(func=self._write, afunc=self._awrite, name=None, tags=tags)
self.writes = cast(list[Union[ChannelWriteEntry, Send]], writes)
self.require_at_least_one_of = require_at_least_one_of
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
if not name:
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else w.node for w in self.writes)}>"
return super().get_name(suffix, name=name)
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return [
ConfigurableFieldSpec(
id=CONFIG_KEY_SEND,
name=CONFIG_KEY_SEND,
description=None,
default=None,
annotation=None,
),
]
def _write(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
self.require_at_least_one_of if input is not None else None,
)
return input
async def _awrite(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
self.require_at_least_one_of if input is not None else None,
)
return input
@staticmethod
def do_write(
config: RunnableConfig,
writes: Sequence[Union[ChannelWriteEntry, Send]],
require_at_least_one_of: Optional[Sequence[str]] = None,
) -> None:
# validate
for w in writes:
if isinstance(w, ChannelWriteEntry):
if w.channel in (TASKS, PUSH):
raise InvalidUpdateError(
"Cannot write to the reserved channel TASKS"
)
if w.value is PASSTHROUGH:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
# split packets and entries
sends = [
(PUSH if FF_SEND_V2 else TASKS, packet)
for packet in writes
if isinstance(packet, Send)
]
entries = [write for write in writes if isinstance(write, ChannelWriteEntry)]
# process entries into values
values = [
write.mapper(write.value) if write.mapper is not None else write.value
for write in entries
]
values = [
(write.channel, val)
for val, write in zip(values, entries)
if not write.skip_none or val is not None
]
# filter out SKIP_WRITE values
filtered = [(chan, val) for chan, val in values if val is not SKIP_WRITE]
if require_at_least_one_of is not None:
if not {chan for chan, _ in filtered} & set(require_at_least_one_of):
raise InvalidUpdateError(
f"Must write to at least one of {require_at_least_one_of}"
)
write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
write(sends + filtered)
@staticmethod
def is_writer(runnable: Runnable) -> bool:
"""Used by PregelNode to distinguish between writers and other runnables."""
return (
isinstance(runnable, ChannelWrite)
or getattr(runnable, "_is_channel_writer", False) is True
)
@staticmethod
def register_writer(runnable: R) -> R:
"""Used to mark a runnable as a writer, so that it can be detected by is_writer.
Instances of ChannelWrite are automatically marked as writers."""
# using object.__setattr__ to work around objects that override __setattr__
# e.g. pydantic models and dataclasses
object.__setattr__(runnable, "_is_channel_writer", True)
return runnable
```
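A minimal usage sketch (not part of the source tree) for the writer-detection helpers above: `register_writer` tags an arbitrary runnable with the `_is_channel_writer` attribute so that `is_writer` treats it like a `ChannelWrite` instance. The `RunnableLambda` wrapper and the `echo` function are illustrative assumptions.
```py
from langchain_core.runnables import RunnableLambda

from langgraph.pregel.write import ChannelWrite

def echo(value):
    # hypothetical runnable body; a real writer would perform channel writes
    return value

# Mark one instance as a writer; leave another instance untouched.
writer = ChannelWrite.register_writer(RunnableLambda(echo))
plain = RunnableLambda(echo)

assert ChannelWrite.is_writer(writer) is True   # detected via the marker attribute
assert ChannelWrite.is_writer(plain) is False   # unmarked runnable is not a writer
```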
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/log.py`:
```py
import logging
logger = logging.getLogger("langgraph")
```
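A one-line sketch (an assumption about caller code, not part of the repo) for surfacing messages emitted through this shared logger, e.g. the unknown-input-channel warning in `io.py`:
```py
import logging

# Enable output from the shared "langgraph" logger used across the package.
logging.basicConfig()
logging.getLogger("langgraph").setLevel(logging.INFO)
```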
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/remote.py`:
```py
from dataclasses import asdict
from typing import (
Any,
AsyncIterator,
Iterator,
Literal,
Optional,
Sequence,
Union,
cast,
)
import orjson
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.graph import (
Edge as DrawableEdge,
)
from langchain_core.runnables.graph import (
Graph as DrawableGraph,
)
from langchain_core.runnables.graph import (
Node as DrawableNode,
)
from langgraph_sdk.client import (
LangGraphClient,
SyncLangGraphClient,
get_client,
get_sync_client,
)
from langgraph_sdk.schema import Checkpoint, ThreadState
from langgraph_sdk.schema import Command as CommandSDK
from langgraph_sdk.schema import StreamMode as StreamModeSDK
from typing_extensions import Self
from langgraph.checkpoint.base import CheckpointMetadata
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_STREAM,
INTERRUPT,
NS_SEP,
)
from langgraph.errors import GraphInterrupt
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode
from langgraph.types import Command, Interrupt, StreamProtocol
from langgraph.utils.config import merge_configs
class RemoteException(Exception):
"""Exception raised when an error occurs in the remote graph."""
pass
class RemoteGraph(PregelProtocol):
"""The `RemoteGraph` class is a client implementation for calling remote
APIs that implement the LangGraph Server API specification.
For example, the `RemoteGraph` class can be used to call APIs from deployments
on LangGraph Cloud.
`RemoteGraph` behaves the same way as a `Graph` and can be used directly as
a node in another `Graph`.
"""
name: str
def __init__(
self,
name: str, # graph_id
/,
*,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
client: Optional[LangGraphClient] = None,
sync_client: Optional[SyncLangGraphClient] = None,
config: Optional[RunnableConfig] = None,
):
"""Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.
If `client` or `sync_client` are provided, they will be used instead of the default clients.
See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. At least
one of `url`, `client`, or `sync_client` must be provided.
Args:
name: The name of the graph.
url: The URL of the remote API.
api_key: The API key to use for authentication. If not provided, it will be read from the environment (`LANGGRAPH_API_KEY`, `LANGSMITH_API_KEY`, or `LANGCHAIN_API_KEY`).
headers: Additional headers to include in the requests.
client: A `LangGraphClient` instance to use instead of creating a default client.
sync_client: A `SyncLangGraphClient` instance to use instead of creating a default client.
config: An optional `RunnableConfig` instance with additional configuration.
"""
self.name = name
self.config = config
if client is None and url is not None:
client = get_client(url=url, api_key=api_key, headers=headers)
self.client = client
if sync_client is None and url is not None:
sync_client = get_sync_client(url=url, api_key=api_key, headers=headers)
self.sync_client = sync_client
def _validate_client(self) -> LangGraphClient:
if self.client is None:
raise ValueError(
"Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`."
)
return self.client
def _validate_sync_client(self) -> SyncLangGraphClient:
if self.sync_client is None:
raise ValueError(
"Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`."
)
return self.sync_client
def copy(self, update: dict[str, Any]) -> Self:
attrs = {**self.__dict__, **update}
return self.__class__(attrs.pop("name"), **attrs)
def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Self:
return self.copy(
{"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
)
def _get_drawable_nodes(
self, graph: dict[str, list[dict[str, Any]]]
) -> dict[str, DrawableNode]:
nodes = {}
for node in graph["nodes"]:
node_id = str(node["id"])
node_data = node.get("data", {})
# Get the node name from the node itself, falling back to node_data, then node_id.
node_name = node.get("name")
if node_name is None:
if isinstance(node_data, dict):
node_name = node_data.get("name", node_id)
else:
node_name = node_id
nodes[node_id] = DrawableNode(
id=node_id,
name=node_name,
data=node_data,
metadata=node.get("metadata"),
)
return nodes
def get_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
"""Get graph by graph name.
This method calls `GET /assistants/{assistant_id}/graph`.
Args:
config: This parameter is not used.
xray: Include graph representation of subgraphs. If an integer
value is provided, only subgraphs with a depth less than or
equal to the value will be included.
Returns:
The graph information for the assistant in JSON format.
"""
sync_client = self._validate_sync_client()
graph = sync_client.assistants.get_graph(
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
nodes=self._get_drawable_nodes(graph),
edges=[DrawableEdge(**edge) for edge in graph["edges"]],
)
async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
"""Get graph by graph name.
This method calls `GET /assistants/{assistant_id}/graph`.
Args:
config: This parameter is not used.
xray: Include graph representation of subgraphs. If an integer
value is provided, only subgraphs with a depth less than or
equal to the value will be included.
Returns:
The graph information for the assistant in JSON format.
"""
client = self._validate_client()
graph = await client.assistants.get_graph(
assistant_id=self.name,
xray=xray,
)
return DrawableGraph(
nodes=self._get_drawable_nodes(graph),
edges=[DrawableEdge(**edge) for edge in graph["edges"]],
)
def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot:
tasks = []
for task in state["tasks"]:
interrupts = []
for interrupt in task["interrupts"]:
interrupts.append(Interrupt(**interrupt))
tasks.append(
PregelTask(
id=task["id"],
name=task["name"],
path=tuple(),
error=Exception(task["error"]) if task["error"] else None,
interrupts=tuple(interrupts),
state=self._create_state_snapshot(task["state"])
if task["state"]
else cast(RunnableConfig, {"configurable": task["checkpoint"]})
if task["checkpoint"]
else None,
result=task.get("result"),
)
)
return StateSnapshot(
values=state["values"],
next=tuple(state["next"]) if state["next"] else tuple(),
config={
"configurable": {
"thread_id": state["checkpoint"]["thread_id"],
"checkpoint_ns": state["checkpoint"]["checkpoint_ns"],
"checkpoint_id": state["checkpoint"]["checkpoint_id"],
"checkpoint_map": state["checkpoint"].get("checkpoint_map", {}),
}
},
metadata=CheckpointMetadata(**state["metadata"]),
created_at=state["created_at"],
parent_config={
"configurable": {
"thread_id": state["parent_checkpoint"]["thread_id"],
"checkpoint_ns": state["parent_checkpoint"]["checkpoint_ns"],
"checkpoint_id": state["parent_checkpoint"]["checkpoint_id"],
"checkpoint_map": state["parent_checkpoint"].get(
"checkpoint_map", {}
),
}
}
if state["parent_checkpoint"]
else None,
tasks=tuple(tasks),
)
def _get_checkpoint(self, config: Optional[RunnableConfig]) -> Optional[Checkpoint]:
if config is None:
return None
checkpoint = {}
if "thread_id" in config["configurable"]:
checkpoint["thread_id"] = config["configurable"]["thread_id"]
if "checkpoint_ns" in config["configurable"]:
checkpoint["checkpoint_ns"] = config["configurable"]["checkpoint_ns"]
if "checkpoint_id" in config["configurable"]:
checkpoint["checkpoint_id"] = config["configurable"]["checkpoint_id"]
if "checkpoint_map" in config["configurable"]:
checkpoint["checkpoint_map"] = config["configurable"]["checkpoint_map"]
return checkpoint if checkpoint else None
def _get_config(self, checkpoint: Checkpoint) -> RunnableConfig:
return {
"configurable": {
"thread_id": checkpoint["thread_id"],
"checkpoint_ns": checkpoint["checkpoint_ns"],
"checkpoint_id": checkpoint["checkpoint_id"],
"checkpoint_map": checkpoint.get("checkpoint_map", {}),
}
}
def _sanitize_config(self, config: RunnableConfig) -> RunnableConfig:
reserved_configurable_keys = frozenset(
[
"callbacks",
"checkpoint_map",
"checkpoint_id",
"checkpoint_ns",
]
)
def _sanitize_obj(obj: Any) -> Any:
"""Remove non-JSON serializable fields from the given object."""
if isinstance(obj, dict):
return {k: _sanitize_obj(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_sanitize_obj(v) for v in obj]
else:
try:
orjson.dumps(obj)
return obj
except orjson.JSONEncodeError:
return None
# Remove non-JSON serializable fields from the config.
config = _sanitize_obj(config)
# Only include configurable keys that are not reserved and
# not starting with "__pregel_" prefix.
new_configurable = {
k: v
for k, v in config["configurable"].items()
if k not in reserved_configurable_keys and not k.startswith("__pregel_")
}
return {
"tags": config.get("tags") or [],
"metadata": config.get("metadata") or {},
"configurable": new_configurable,
}
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the state of a thread.
This method calls `POST /threads/{thread_id}/state/checkpoint` if a
checkpoint is specified in the config or `GET /threads/{thread_id}/state`
if no checkpoint is specified.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
subgraphs: Include subgraphs in the state.
Returns:
The latest state of the thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
state = sync_client.threads.get_state(
thread_id=merged_config["configurable"]["thread_id"],
checkpoint=self._get_checkpoint(merged_config),
subgraphs=subgraphs,
)
return self._create_state_snapshot(state)
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the state of a thread.
This method calls `POST /threads/{thread_id}/state/checkpoint` if a
checkpoint is specified in the config or `GET /threads/{thread_id}/state`
if no checkpoint is specified.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
subgraphs: Include subgraphs in the state.
Returns:
The latest state of the thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
state = await client.threads.get_state(
thread_id=merged_config["configurable"]["thread_id"],
checkpoint=self._get_checkpoint(merged_config),
subgraphs=subgraphs,
)
return self._create_state_snapshot(state)
def get_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
"""Get the state history of a thread.
This method calls `POST /threads/{thread_id}/history`.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
filter: Metadata to filter on.
before: A `RunnableConfig` that includes checkpoint metadata.
limit: Max number of states to return.
Returns:
States of the thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
states = sync_client.threads.get_history(
thread_id=merged_config["configurable"]["thread_id"],
limit=limit if limit else 10,
before=self._get_checkpoint(before),
metadata=filter,
checkpoint=self._get_checkpoint(merged_config),
)
for state in states:
yield self._create_state_snapshot(state)
async def aget_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
"""Get the state history of a thread.
This method calls `POST /threads/{thread_id}/history`.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
filter: Metadata to filter on.
before: A `RunnableConfig` that includes checkpoint metadata.
limit: Max number of states to return.
Returns:
States of the thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
states = await client.threads.get_history(
thread_id=merged_config["configurable"]["thread_id"],
limit=limit if limit else 10,
before=self._get_checkpoint(before),
metadata=filter,
checkpoint=self._get_checkpoint(merged_config),
)
for state in states:
yield self._create_state_snapshot(state)
def update_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
"""Update the state of a thread.
This method calls `POST /threads/{thread_id}/state`.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
values: Values to update to the state.
as_node: Update the state as if this node had just executed.
Returns:
`RunnableConfig` for the updated thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
response: dict = sync_client.threads.update_state( # type: ignore
thread_id=merged_config["configurable"]["thread_id"],
values=values,
as_node=as_node,
checkpoint=self._get_checkpoint(merged_config),
)
return self._get_config(response["checkpoint"])
async def aupdate_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
"""Update the state of a thread.
This method calls `POST /threads/{thread_id}/state`.
Args:
config: A `RunnableConfig` that includes `thread_id` in the
`configurable` field.
values: Values to update to the state.
as_node: Update the state as if this node had just executed.
Returns:
`RunnableConfig` for the updated thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
response: dict = await client.threads.update_state( # type: ignore
thread_id=merged_config["configurable"]["thread_id"],
values=values,
as_node=as_node,
checkpoint=self._get_checkpoint(merged_config),
)
return self._get_config(response["checkpoint"])
def _get_stream_modes(
self,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]],
config: Optional[RunnableConfig],
default: StreamMode = "updates",
) -> tuple[
list[StreamModeSDK], list[StreamModeSDK], bool, Optional[StreamProtocol]
]:
"""Return a tuple of the final list of stream modes sent to the
remote graph and a boolean flag indicating if stream mode 'updates'
was present in the original list of stream modes.
'updates' mode is added to the list of stream modes so that interrupts
can be detected in the remote graph.
"""
updated_stream_modes: list[StreamModeSDK] = []
req_single = True
# coerce to list, or add default stream mode
if stream_mode:
if isinstance(stream_mode, str):
updated_stream_modes.append(stream_mode)
else:
req_single = False
updated_stream_modes.extend(stream_mode)
else:
updated_stream_modes.append(default)
requested_stream_modes = updated_stream_modes.copy()
# add any from parent graph
stream: Optional[StreamProtocol] = (
(config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
)
if stream:
updated_stream_modes.extend(stream.modes)
# map "messages" to "messages-tuple"
if "messages" in updated_stream_modes:
updated_stream_modes.remove("messages")
updated_stream_modes.append("messages-tuple")
# if requested "messages-tuple",
# map to "messages" in requested_stream_modes
if "messages-tuple" in requested_stream_modes:
requested_stream_modes.remove("messages-tuple")
requested_stream_modes.append("messages")
# add 'updates' mode if not present
if "updates" not in updated_stream_modes:
updated_stream_modes.append("updates")
# remove 'events', as it's not supported in Pregel
if "events" in updated_stream_modes:
updated_stream_modes.remove("events")
return (updated_stream_modes, requested_stream_modes, req_single, stream)
def stream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
**kwargs: Any,
) -> Iterator[Union[dict[str, Any], Any]]:
"""Create a run and stream the results.
This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
is specified in the `configurable` field of the config or
`POST /runs/stream` otherwise.
Args:
input: Input to the graph.
config: A `RunnableConfig` for graph invocation.
stream_mode: Stream mode(s) to use.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
subgraphs: Stream from subgraphs.
**kwargs: Additional params to pass to client.runs.stream.
Yields:
The output of the graph.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)
if isinstance(input, Command):
command: Optional[CommandSDK] = cast(CommandSDK, asdict(input))
input = None
else:
command = None
for chunk in sync_client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
command=command,
config=sanitized_config,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
**kwargs,
):
# split mode and ns
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
# prepend caller ns (as it is not passed to remote graph)
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
# stream to parent stream
if stream is not None and mode in stream.modes:
stream((ns, mode, chunk.data))
# raise interrupt or errors
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
# filter for what was actually requested
if mode not in requested:
continue
# emit chunk
if subgraphs:
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if req_single:
yield ns, chunk.data
else:
yield ns, mode, chunk.data
elif req_single:
yield chunk.data
else:
yield chunk
async def astream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
**kwargs: Any,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
"""Create a run and stream the results.
This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
is specified in the `configurable` field of the config or
`POST /runs/stream` otherwise.
Args:
input: Input to the graph.
config: A `RunnableConfig` for graph invocation.
stream_mode: Stream mode(s) to use.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
subgraphs: Stream from subgraphs.
**kwargs: Additional params to pass to client.runs.stream.
Yields:
The output of the graph.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)
if isinstance(input, Command):
command: Optional[CommandSDK] = cast(CommandSDK, asdict(input))
input = None
else:
command = None
async for chunk in client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
command=command,
config=sanitized_config,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_subgraphs=subgraphs or stream is not None,
if_not_exists="create",
**kwargs,
):
# split mode and ns
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
# prepend caller ns (as it is not passed to remote graph)
if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
caller_ns = tuple(caller_ns.split(NS_SEP))
ns = caller_ns + ns
# stream to parent stream
if stream is not None and mode in stream.modes:
stream((ns, mode, chunk.data))
# raise interrupt or errors
if chunk.event.startswith("updates"):
if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
raise GraphInterrupt(chunk.data[INTERRUPT])
elif chunk.event.startswith("error"):
raise RemoteException(chunk.data)
# filter for what was actually requested
if mode not in requested:
continue
# emit chunk
if subgraphs:
if NS_SEP in chunk.event:
mode, ns_ = chunk.event.split(NS_SEP, 1)
ns = tuple(ns_.split(NS_SEP))
else:
mode, ns = chunk.event, ()
if req_single:
yield ns, chunk.data
else:
yield ns, mode, chunk.data
elif req_single:
yield chunk.data
else:
yield chunk
async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"],
include_names: Optional[Sequence[All]] = None,
include_types: Optional[Sequence[All]] = None,
include_tags: Optional[Sequence[All]] = None,
exclude_names: Optional[Sequence[All]] = None,
exclude_types: Optional[Sequence[All]] = None,
exclude_tags: Optional[Sequence[All]] = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
raise NotImplementedError
def invoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
**kwargs: Any,
) -> Union[dict[str, Any], Any]:
"""Create a run, wait until it finishes and return the final state.
Args:
input: Input to the graph.
config: A `RunnableConfig` for graph invocation.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
**kwargs: Additional params to pass to RemoteGraph.stream.
Returns:
The output of the graph.
"""
for chunk in self.stream(
input,
config=config,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_mode="values",
**kwargs,
):
pass
try:
return chunk
except UnboundLocalError:
return None
async def ainvoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
**kwargs: Any,
) -> Union[dict[str, Any], Any]:
"""Create a run, wait until it finishes and return the final state.
Args:
input: Input to the graph.
config: A `RunnableConfig` for graph invocation.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
**kwargs: Additional params to pass to RemoteGraph.astream.
Returns:
The output of the graph.
"""
async for chunk in self.astream(
input,
config=config,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
stream_mode="values",
**kwargs,
):
pass
try:
return chunk
except UnboundLocalError:
return None
```
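A brief sketch of how `RemoteGraph` is intended to be used, based only on the constructor and `invoke` signatures above; the URL, graph name, input shape, and thread id are placeholder assumptions, not values from this repository.
```py
from langgraph.pregel.remote import RemoteGraph

# Point the client at a running LangGraph Server deployment (placeholder URL).
remote = RemoteGraph("agent", url="http://localhost:2024")

# Create a run, wait for it to finish, and return the final state.
result = remote.invoke(
    {"messages": [{"role": "user", "content": "hi"}]},  # input shape depends on the deployed graph
    config={"configurable": {"thread_id": "thread-1"}},
)

# Per the class docstring, a RemoteGraph can also be added as a node of
# another graph, since it implements the same PregelProtocol interface.
```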
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/protocol.py`:
```py
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Sequence,
Union,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from typing_extensions import Self
from langgraph.pregel.types import All, StateSnapshot, StreamMode
class PregelProtocol(
Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]], ABC
):
@abstractmethod
def with_config(
self, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Self: ...
@abstractmethod
def get_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...
@abstractmethod
async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph: ...
@abstractmethod
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...
@abstractmethod
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...
@abstractmethod
def get_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]: ...
@abstractmethod
def aget_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]: ...
@abstractmethod
def update_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...
@abstractmethod
async def aupdate_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig: ...
@abstractmethod
def stream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]: ...
@abstractmethod
def astream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]: ...
@abstractmethod
def invoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]: ...
@abstractmethod
async def ainvoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]: ...
```
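A short sketch (assumption, not library code) of writing caller code against `PregelProtocol`, so the same helper works whether it is handed a local `Pregel` instance or a `RemoteGraph`.
```py
from typing import Any

from langchain_core.runnables import RunnableConfig

from langgraph.pregel.protocol import PregelProtocol

def latest_values(graph: PregelProtocol, thread_id: str) -> Any:
    """Return the current channel values for a thread, local or remote."""
    config: RunnableConfig = {"configurable": {"thread_id": thread_id}}
    snapshot = graph.get_state(config)
    return snapshot.values
```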
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/io.py`:
```py
from typing import Any, Iterator, Literal, Mapping, Optional, Sequence, TypeVar, Union
from uuid import UUID
from langchain_core.runnables.utils import AddableDict
from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.checkpoint.base import PendingWrite
from langgraph.constants import (
EMPTY_SEQ,
ERROR,
FF_SEND_V2,
INTERRUPT,
NULL_TASK_ID,
PUSH,
RESUME,
TAG_HIDDEN,
TASKS,
)
from langgraph.errors import InvalidUpdateError
from langgraph.pregel.log import logger
from langgraph.types import Command, PregelExecutableTask, Send
def is_task_id(task_id: str) -> bool:
"""Check if a string is a valid task id."""
try:
UUID(task_id)
except ValueError:
return False
return True
def read_channel(
channels: Mapping[str, BaseChannel],
chan: str,
*,
catch: bool = True,
return_exception: bool = False,
) -> Any:
try:
return channels[chan].get()
except EmptyChannelError as exc:
if return_exception:
return exc
elif catch:
return None
else:
raise
def read_channels(
channels: Mapping[str, BaseChannel],
select: Union[Sequence[str], str],
*,
skip_empty: bool = True,
) -> Union[dict[str, Any], Any]:
if isinstance(select, str):
return read_channel(channels, select)
else:
values: dict[str, Any] = {}
for k in select:
try:
values[k] = read_channel(channels, k, catch=not skip_empty)
except EmptyChannelError:
pass
return values
def map_command(
cmd: Command, pending_writes: list[PendingWrite]
) -> Iterator[tuple[str, str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is not parent graph")
if cmd.goto:
if isinstance(cmd.goto, (tuple, list)):
sends = cmd.goto
else:
sends = [cmd.goto]
for send in sends:
if not isinstance(send, Send):
raise TypeError(
f"In Command.goto, expected Send, got {type(send).__name__}"
)
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
# TODO handle goto str for state graph
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
existing: list[Any] = next(
(w[2] for w in pending_writes if w[0] == tid and w[1] == RESUME), []
)
existing.append(resume)
yield (tid, RESUME, existing)
else:
yield (NULL_TASK_ID, RESUME, cmd.resume)
if cmd.update:
if not isinstance(cmd.update, dict):
raise TypeError(
f"Expected cmd.update to be a dict mapping channel names to update values, got {type(cmd.update).__name__}"
)
for k, v in cmd.update.items():
yield (NULL_TASK_ID, k, v)
def map_input(
input_channels: Union[str, Sequence[str]],
chunk: Optional[Union[dict[str, Any], Any]],
) -> Iterator[tuple[str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if chunk is None:
return
elif isinstance(input_channels, str):
yield (input_channels, chunk)
else:
if not isinstance(chunk, dict):
raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
for k in chunk:
if k in input_channels:
yield (k, chunk[k])
else:
logger.warning(f"Input channel {k} not found in {input_channels}")
class AddableValuesDict(AddableDict):
def __add__(self, other: dict[str, Any]) -> "AddableValuesDict":
return self | other
def __radd__(self, other: dict[str, Any]) -> "AddableValuesDict":
return other | self
def map_output_values(
output_channels: Union[str, Sequence[str]],
pending_writes: Union[Literal[True], Sequence[tuple[str, Any]]],
channels: Mapping[str, BaseChannel],
) -> Iterator[Union[dict[str, Any], Any]]:
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
if isinstance(output_channels, str):
if pending_writes is True or any(
chan == output_channels for chan, _ in pending_writes
):
yield read_channel(channels, output_channels)
else:
if pending_writes is True or {
c for c, _ in pending_writes if c in output_channels
}:
yield AddableValuesDict(read_channels(channels, output_channels))
class AddableUpdatesDict(AddableDict):
def __add__(self, other: dict[str, Any]) -> "AddableUpdatesDict":
return [self, other]
def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict":
raise TypeError("AddableUpdatesDict does not support right-side addition")
def map_output_updates(
output_channels: Union[str, Sequence[str]],
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
cached: bool = False,
) -> Iterator[dict[str, Union[Any, dict[str, Any]]]]:
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
output_tasks = [
(t, ww)
for t, ww in tasks
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
and ww[0][0] != ERROR
and ww[0][0] != INTERRUPT
]
if not output_tasks:
return
if isinstance(output_channels, str):
updated = (
(task.name, value)
for task, writes in output_tasks
for chan, value in writes
if chan == output_channels
)
else:
updated = (
(
task.name,
{chan: value for chan, value in writes if chan in output_channels},
)
for task, writes in output_tasks
if any(chan in output_channels for chan, _ in writes)
)
grouped: dict[str, list[Any]] = {t.name: [] for t, _ in output_tasks}
for node, value in updated:
grouped[node].append(value)
for node, value in grouped.items():
if len(value) == 0:
grouped[node] = None # type: ignore[assignment]
if len(value) == 1:
grouped[node] = value[0]
if cached:
grouped["__metadata__"] = {"cached": cached} # type: ignore[assignment]
yield AddableUpdatesDict(grouped)
T = TypeVar("T")
def single(iter: Iterator[T]) -> Optional[T]:
for item in iter:
return item
```
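A small sketch (not in the repo) illustrating `map_input`'s contract from the function above: a single input channel receives the whole chunk, while multiple input channels only receive matching keys and log a warning for unknown ones.
```py
from langgraph.pregel.io import map_input

# Single input channel: the whole chunk becomes one pending write.
assert list(map_input("messages", {"role": "user"})) == [("messages", {"role": "user"})]

# Multiple input channels: only keys that name a declared channel are kept;
# the unknown key "c" is dropped (and logged as a warning).
assert list(map_input(["a", "b"], {"a": 1, "c": 3})) == [("a", 1)]

# A None chunk maps to no writes at all.
assert list(map_input(["a", "b"], None)) == []
```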
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/__init__.py`:
```py
from __future__ import annotations
import asyncio
import concurrent
import concurrent.futures
import queue
from collections import deque
from functools import partial
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
get_type_hints,
overload,
)
from uuid import UUID, uuid5
from langchain_core.globals import get_debug
from langchain_core.runnables import (
RunnableSequence,
)
from langchain_core.runnables.base import Input, Output
from langchain_core.runnables.config import (
RunnableConfig,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from pydantic import BaseModel
from typing_extensions import Self
from langgraph.channels.base import (
BaseChannel,
)
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
CheckpointTuple,
copy_checkpoint,
create_checkpoint,
empty_checkpoint,
)
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_NODE_FINISHED,
CONFIG_KEY_READ,
CONFIG_KEY_RESUMING,
CONFIG_KEY_SEND,
CONFIG_KEY_STORE,
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_TASK_ID,
END,
ERROR,
INPUT,
INTERRUPT,
NS_END,
NS_SEP,
NULL_TASK_ID,
PUSH,
SCHEDULED,
)
from langgraph.errors import (
ErrorCode,
GraphRecursionError,
InvalidUpdateError,
create_error_message,
)
from langgraph.managed.base import ManagedValueSpec
from langgraph.pregel.algo import (
PregelTaskWrites,
apply_writes,
local_read,
local_write,
prepare_next_tasks,
)
from langgraph.pregel.debug import tasks_w_writes
from langgraph.pregel.io import read_channels
from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.messages import StreamMessagesHandler
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.runner import PregelRunner
from langgraph.pregel.utils import find_subgraph_pregel, get_new_channel_versions
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import (
All,
Checkpointer,
LoopProtocol,
StateSnapshot,
StreamChunk,
StreamMode,
)
from langgraph.utils.config import (
ensure_config,
merge_configs,
patch_checkpoint_map,
patch_config,
patch_configurable,
)
from langgraph.utils.pydantic import create_model
from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined]
WriteValue = Union[Callable[[Input], Output], Any]
class Channel:
@overload
@classmethod
def subscribe_to(
cls,
channels: str,
*,
key: Optional[str] = None,
tags: Optional[list[str]] = None,
) -> PregelNode: ...
@overload
@classmethod
def subscribe_to(
cls,
channels: Sequence[str],
*,
key: None = None,
tags: Optional[list[str]] = None,
) -> PregelNode: ...
@classmethod
def subscribe_to(
cls,
channels: Union[str, Sequence[str]],
*,
key: Optional[str] = None,
tags: Optional[list[str]] = None,
) -> PregelNode:
"""Runs process.invoke() each time channels are updated,
with a dict of the channel values as input."""
if not isinstance(channels, str) and key is not None:
raise ValueError(
"Can't specify a key when subscribing to multiple channels"
)
return PregelNode(
channels=cast(
Union[list[str], Mapping[str, str]],
(
{key: channels}
if isinstance(channels, str) and key is not None
else (
[channels]
if isinstance(channels, str)
else {chan: chan for chan in channels}
)
),
),
triggers=[channels] if isinstance(channels, str) else channels,
tags=tags,
)
@classmethod
def write_to(
cls,
*channels: str,
**kwargs: WriteValue,
) -> ChannelWrite:
"""Writes to channels the result of the lambda, or None to skip writing."""
return ChannelWrite(
[ChannelWriteEntry(c) for c in channels]
+ [
(
ChannelWriteEntry(k, mapper=v)
if callable(v)
else ChannelWriteEntry(k, value=v)
)
for k, v in kwargs.items()
]
)
class Pregel(PregelProtocol):
nodes: dict[str, PregelNode]
channels: dict[str, Union[BaseChannel, ManagedValueSpec]]
stream_mode: StreamMode = "values"
"""Mode to stream output, defaults to 'values'."""
output_channels: Union[str, Sequence[str]]
stream_channels: Optional[Union[str, Sequence[str]]] = None
"""Channels to stream, defaults to all channels not in reserved channels"""
interrupt_after_nodes: Union[All, Sequence[str]]
interrupt_before_nodes: Union[All, Sequence[str]]
input_channels: Union[str, Sequence[str]]
step_timeout: Optional[float] = None
"""Maximum time to wait for a step to complete, in seconds. Defaults to None."""
debug: bool
"""Whether to print debug information during execution. Defaults to False."""
checkpointer: Checkpointer = None
"""Checkpointer used to save and load graph state. Defaults to None."""
store: Optional[BaseStore] = None
"""Memory store to use for SharedValues. Defaults to None."""
retry_policy: Optional[RetryPolicy] = None
"""Retry policy to use when running tasks. Set to None to disable."""
config_type: Optional[Type[Any]] = None
config: Optional[RunnableConfig] = None
name: str = "LangGraph"
def __init__(
self,
*,
nodes: dict[str, PregelNode],
channels: Optional[dict[str, Union[BaseChannel, ManagedValueSpec]]],
auto_validate: bool = True,
stream_mode: StreamMode = "values",
output_channels: Union[str, Sequence[str]],
stream_channels: Optional[Union[str, Sequence[str]]] = None,
interrupt_after_nodes: Union[All, Sequence[str]] = (),
interrupt_before_nodes: Union[All, Sequence[str]] = (),
input_channels: Union[str, Sequence[str]],
step_timeout: Optional[float] = None,
debug: Optional[bool] = None,
checkpointer: Optional[BaseCheckpointSaver] = None,
store: Optional[BaseStore] = None,
retry_policy: Optional[RetryPolicy] = None,
config_type: Optional[Type[Any]] = None,
config: Optional[RunnableConfig] = None,
name: str = "LangGraph",
) -> None:
self.nodes = nodes
self.channels = channels or {}
self.stream_mode = stream_mode
self.output_channels = output_channels
self.stream_channels = stream_channels
self.interrupt_after_nodes = interrupt_after_nodes
self.interrupt_before_nodes = interrupt_before_nodes
self.input_channels = input_channels
self.step_timeout = step_timeout
self.debug = debug if debug is not None else get_debug()
self.checkpointer = checkpointer
self.store = store
self.retry_policy = retry_policy
self.config_type = config_type
self.config = config
self.name = name
if auto_validate:
self.validate()
def get_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError
async def aget_graph(
self, config: RunnableConfig | None = None, *, xray: int | bool = False
) -> Graph:
raise NotImplementedError
def copy(self, update: dict[str, Any] | None = None) -> Self:
attrs = {**self.__dict__, **(update or {})}
return self.__class__(**attrs)
def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self:
return self.copy(
{"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
)
def validate(self) -> Self:
validate_graph(
self.nodes,
{k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)},
self.input_channels,
self.output_channels,
self.stream_channels,
self.interrupt_after_nodes,
self.interrupt_before_nodes,
)
return self
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return [
spec
for spec in get_unique_config_specs(
[spec for node in self.nodes.values() for spec in node.config_specs]
+ (
self.checkpointer.config_specs
if isinstance(self.checkpointer, BaseCheckpointSaver)
else []
)
+ (
[
ConfigurableFieldSpec(id=name, annotation=typ)
for name, typ in get_type_hints(self.config_type).items()
]
if self.config_type is not None
else []
)
)
# these are provided by the Pregel class
if spec.id
not in [
CONFIG_KEY_READ,
CONFIG_KEY_SEND,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_RESUMING,
]
]
@property
def InputType(self) -> Any:
if isinstance(self.input_channels, str):
channel = self.channels[self.input_channels]
if isinstance(channel, BaseChannel):
return channel.UpdateType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
config = merge_configs(self.config, config)
if isinstance(self.input_channels, str):
return super().get_input_schema(config)
else:
return create_model(
self.get_name("Input"),
field_definitions={
k: (c.UpdateType, None)
for k in self.input_channels or self.channels.keys()
if (c := self.channels[k]) and isinstance(c, BaseChannel)
},
)
def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[All, Any]:
schema = self.get_input_schema(config)
if hasattr(schema, "model_json_schema"):
return schema.model_json_schema()
else:
return schema.schema()
@property
def OutputType(self) -> Any:
if isinstance(self.output_channels, str):
channel = self.channels[self.output_channels]
if isinstance(channel, BaseChannel):
return channel.ValueType
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
config = merge_configs(self.config, config)
if isinstance(self.output_channels, str):
return super().get_output_schema(config)
else:
return create_model(
self.get_name("Output"),
field_definitions={
k: (c.ValueType, None)
for k in self.output_channels
if (c := self.channels[k]) and isinstance(c, BaseChannel)
},
)
def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[All, Any]:
schema = self.get_output_schema(config)
if hasattr(schema, "model_json_schema"):
return schema.model_json_schema()
else:
return schema.schema()
@property
def stream_channels_list(self) -> Sequence[str]:
stream_channels = self.stream_channels_asis
return (
[stream_channels] if isinstance(stream_channels, str) else stream_channels
)
@property
def stream_channels_asis(self) -> Union[str, Sequence[str]]:
return self.stream_channels or [
k for k in self.channels if isinstance(self.channels[k], BaseChannel)
]
def get_subgraphs(
self, *, namespace: Optional[str] = None, recurse: bool = False
) -> Iterator[tuple[str, Pregel]]:
for name, node in self.nodes.items():
# filter by prefix
if namespace is not None:
if not namespace.startswith(name):
continue
# find the subgraph, if any
graph = cast(Optional[Pregel], find_subgraph_pregel(node.bound))
# if found, yield recursively
if graph:
if name == namespace:
yield name, graph
return # we found it, stop searching
if namespace is None:
yield name, graph
if recurse:
if namespace is not None:
namespace = namespace[len(name) + 1 :]
yield from (
(f"{name}{NS_SEP}{n}", s)
for n, s in graph.get_subgraphs(
namespace=namespace, recurse=recurse
)
)
async def aget_subgraphs(
self, *, namespace: Optional[str] = None, recurse: bool = False
) -> AsyncIterator[tuple[str, Pregel]]:
for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse):
yield name, node
def _prepare_state_snapshot(
self,
config: RunnableConfig,
saved: Optional[CheckpointTuple],
recurse: Optional[BaseCheckpointSaver] = None,
apply_pending_writes: bool = False,
) -> StateSnapshot:
if not saved:
return StateSnapshot(
values={},
next=(),
config=config,
metadata=None,
created_at=None,
parent_config=None,
tasks=(),
)
with ChannelsManager(
self.channels,
saved.checkpoint,
LoopProtocol(
config=saved.config,
step=saved.metadata.get("step", -1) + 1,
stop=saved.metadata.get("step", -1) + 2,
),
skip_context=True,
) as (channels, managed):
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
saved.checkpoint,
saved.pending_writes or [],
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# get the subgraphs
subgraphs = dict(self.get_subgraphs())
parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {}
for task in next_tasks.values():
if task.name not in subgraphs:
continue
# assemble checkpoint_ns for this task
task_ns = f"{task.name}{NS_END}{task.id}"
if parent_ns:
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
if not recurse:
# set config as signal that subgraph checkpoints exist
config = {
CONF: {
"thread_id": saved.config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
task_states[task.id] = config
else:
# get the state of the subgraph
config = {
CONF: {
CONFIG_KEY_CHECKPOINTER: recurse,
"thread_id": saved.config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
task_states[task.id] = subgraphs[task.name].get_state(
config, subgraphs=True
)
# apply pending writes
if null_writes := [
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
if apply_pending_writes and saved.pending_writes:
for tid, k, v in saved.pending_writes:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(saved.checkpoint, channels, tasks, None)
# assemble the state snapshot
return StateSnapshot(
read_channels(channels, self.stream_channels_asis),
tuple(t.name for t in next_tasks.values() if not t.writes),
patch_checkpoint_map(saved.config, saved.metadata),
saved.metadata,
saved.checkpoint["ts"],
patch_checkpoint_map(saved.parent_config, saved.metadata),
tasks_w_writes(
next_tasks.values(),
saved.pending_writes,
task_states,
self.stream_channels_asis,
),
)
async def _aprepare_state_snapshot(
self,
config: RunnableConfig,
saved: Optional[CheckpointTuple],
recurse: Optional[BaseCheckpointSaver] = None,
apply_pending_writes: bool = False,
) -> StateSnapshot:
if not saved:
return StateSnapshot(
values={},
next=(),
config=config,
metadata=None,
created_at=None,
parent_config=None,
tasks=(),
)
async with AsyncChannelsManager(
self.channels,
saved.checkpoint,
LoopProtocol(
config=saved.config,
step=saved.metadata.get("step", -1) + 1,
stop=saved.metadata.get("step", -1) + 2,
),
skip_context=True,
) as (
channels,
managed,
):
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
saved.checkpoint,
saved.pending_writes or [],
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# get the subgraphs
subgraphs = {n: g async for n, g in self.aget_subgraphs()}
parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {}
for task in next_tasks.values():
if task.name not in subgraphs:
continue
# assemble checkpoint_ns for this task
task_ns = f"{task.name}{NS_END}{task.id}"
if parent_ns:
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
if not recurse:
# set config as signal that subgraph checkpoints exist
config = {
CONF: {
"thread_id": saved.config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
task_states[task.id] = config
else:
# get the state of the subgraph
config = {
CONF: {
CONFIG_KEY_CHECKPOINTER: recurse,
"thread_id": saved.config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
task_states[task.id] = await subgraphs[task.name].aget_state(
config, subgraphs=True
)
# apply pending writes
if null_writes := [
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
if apply_pending_writes and saved.pending_writes:
for tid, k, v in saved.pending_writes:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(saved.checkpoint, channels, tasks, None)
# assemble the state snapshot
return StateSnapshot(
read_channels(channels, self.stream_channels_asis),
tuple(t.name for t in next_tasks.values() if not t.writes),
patch_checkpoint_map(saved.config, saved.metadata),
saved.metadata,
saved.checkpoint["ts"],
patch_checkpoint_map(saved.parent_config, saved.metadata),
tasks_w_writes(
next_tasks.values(),
saved.pending_writes,
task_states,
self.stream_channels_asis,
),
)
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the current state of the graph."""
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return pregel.get_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
subgraphs=subgraphs,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
config = merge_configs(self.config, config) if self.config else config
saved = checkpointer.get_tuple(config)
return self._prepare_state_snapshot(
config,
saved,
recurse=checkpointer if subgraphs else None,
apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
)
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the current state of the graph."""
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return await pregel.aget_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
subgraphs=subgraphs,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
config = merge_configs(self.config, config) if self.config else config
saved = await checkpointer.aget_tuple(config)
return await self._aprepare_state_snapshot(
config,
saved,
recurse=checkpointer if subgraphs else None,
apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
)
def get_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
"""Get the history of the state of the graph."""
config = ensure_config(config)
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
yield from pregel.get_state_history(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
filter=filter,
before=before,
limit=limit,
)
return
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
config = merge_configs(
self.config,
config,
{CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
)
# eagerly consume list() to avoid holding up the db cursor
for checkpoint_tuple in list(
checkpointer.list(config, before=before, limit=limit, filter=filter)
):
yield self._prepare_state_snapshot(
checkpoint_tuple.config, checkpoint_tuple
)
async def aget_state_history(
self,
config: RunnableConfig,
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
"""Get the history of the state of the graph."""
config = ensure_config(config)
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
async for state in pregel.aget_state_history(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
filter=filter,
before=before,
limit=limit,
):
yield state
return
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
config = merge_configs(
self.config,
config,
{CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}},
)
# eagerly consume list() to avoid holding up the db cursor
for checkpoint_tuple in [
c
async for c in checkpointer.alist(
config, before=before, limit=limit, filter=filter
)
]:
yield await self._aprepare_state_snapshot(
checkpoint_tuple.config, checkpoint_tuple
)
def update_state(
self,
config: RunnableConfig,
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
"""Update the state of the graph with the given values, as if they came from
node `as_node`. If `as_node` is not provided, it will be set to the last node
that updated the state, if not ambiguous.
"""
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
# delegate to subgraph
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
for _, pregel in self.get_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return pregel.update_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
values,
as_node,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
# get last checkpoint
config = ensure_config(self.config, config)
saved = checkpointer.get_tuple(config)
checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
checkpoint_previous_versions = (
saved.checkpoint["channel_versions"].copy() if saved else {}
)
step = saved.metadata.get("step", -1) if saved else -1
# merge configurable fields with previous checkpoint config
checkpoint_config = patch_configurable(
config,
{CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
)
checkpoint_metadata = config["metadata"]
if saved:
checkpoint_config = patch_configurable(config, saved.config[CONF])
checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
with ChannelsManager(
self.channels,
checkpoint,
LoopProtocol(config=config, step=step + 1, stop=step + 2),
) as (channels, managed):
# no values and as_node == END: just clear all current tasks
if values is None and as_node == END:
if saved is not None:
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
checkpoint,
saved.pending_writes or [],
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# apply null writes
if null_writes := [
w[1:]
for w in saved.pending_writes or []
if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
# apply writes from tasks that already ran
for tid, k, v in saved.pending_writes or []:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
# clear all current tasks
apply_writes(checkpoint, channels, next_tasks.values(), None)
# save checkpoint
next_config = checkpointer.put(
checkpoint_config,
create_checkpoint(checkpoint, None, step),
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = checkpointer.put(
checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
if values is None and as_node == "__copy__":
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = checkpointer.put(
saved.parent_config or saved.config if saved else checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "fork",
"step": step + 1,
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# apply pending writes, if not on specific checkpoint
if (
CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
and saved is not None
and saved.pending_writes
):
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
checkpoint,
saved.pending_writes,
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# apply null writes
if null_writes := [
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
# apply writes
for tid, k, v in saved.pending_writes:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(checkpoint, channels, tasks, None)
# find last node that updated the state, if not provided
if as_node is None and not any(
v for vv in checkpoint["versions_seen"].values() for v in vv.values()
):
if (
isinstance(self.input_channels, str)
and self.input_channels in self.nodes
):
as_node = self.input_channels
elif as_node is None:
last_seen_by_node = sorted(
(v, n)
for n, seen in checkpoint["versions_seen"].items()
if n in self.nodes
for v in seen.values()
)
# if two nodes updated the state at the same time, it's ambiguous
if last_seen_by_node:
if len(last_seen_by_node) == 1:
as_node = last_seen_by_node[0][1]
elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
as_node = last_seen_by_node[-1][1]
if as_node is None:
raise InvalidUpdateError("Ambiguous update, specify as_node")
if as_node not in self.nodes:
raise InvalidUpdateError(f"Node {as_node} does not exist")
# create task to run all writers of the chosen node
writers = self.nodes[as_node].flat_writers
if not writers:
raise InvalidUpdateError(f"Node {as_node} has no writers")
writes: deque[tuple[str, Any]] = deque()
task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
# execute task
run.invoke(
values,
patch_config(
config,
run_name=self.name + "UpdateState",
configurable={
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
writes.extend,
self.nodes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
step + 1,
checkpoint,
channels,
managed,
task,
config,
),
},
),
)
# save task writes
# channel writes are saved to current checkpoint
# push writes are saved to next checkpoint
channel_writes, push_writes = (
[w for w in task.writes if w[0] != PUSH],
[w for w in task.writes if w[0] == PUSH],
)
if saved and channel_writes:
checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
# apply to checkpoint and save
mv_writes = apply_writes(
checkpoint, channels, [task], checkpointer.get_next_version
)
assert not mv_writes, "Can't write to SharedValues from update_state"
checkpoint = create_checkpoint(checkpoint, channels, step + 1)
next_config = checkpointer.put(
checkpoint_config,
checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {as_node: values},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
get_new_channel_versions(
checkpoint_previous_versions, checkpoint["channel_versions"]
),
)
if push_writes:
checkpointer.put_writes(next_config, push_writes, task_id)
return patch_checkpoint_map(next_config, saved.metadata if saved else None)
async def aupdate_state(
self,
config: RunnableConfig,
        values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
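        """Asynchronously update the state of the graph with the given values,
        as if they came from node `as_node`. If `as_node` is not provided, it will
        be set to the last node that updated the state, if not ambiguous.
        """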
checkpointer: Optional[BaseCheckpointSaver] = ensure_config(config)[CONF].get(
CONFIG_KEY_CHECKPOINTER, self.checkpointer
)
if not checkpointer:
raise ValueError("No checkpointer set")
# delegate to subgraph
if (
checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
async for _, pregel in self.aget_subgraphs(
namespace=recast_checkpoint_ns, recurse=True
):
return await pregel.aupdate_state(
patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
values,
as_node,
)
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
# get last checkpoint
config = ensure_config(self.config, config)
saved = await checkpointer.aget_tuple(config)
checkpoint = copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
checkpoint_previous_versions = (
saved.checkpoint["channel_versions"].copy() if saved else {}
)
step = saved.metadata.get("step", -1) if saved else -1
# merge configurable fields with previous checkpoint config
checkpoint_config = patch_configurable(
config,
{CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")},
)
checkpoint_metadata = config["metadata"]
if saved:
checkpoint_config = patch_configurable(config, saved.config[CONF])
checkpoint_metadata = {**saved.metadata, **checkpoint_metadata}
async with AsyncChannelsManager(
self.channels,
checkpoint,
LoopProtocol(config=config, step=step + 1, stop=step + 2),
) as (
channels,
managed,
):
            # no values and as_node is END: clear all current tasks
if values is None and as_node == END:
if saved is not None:
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
checkpoint,
saved.pending_writes or [],
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# apply null writes
if null_writes := [
w[1:]
for w in saved.pending_writes or []
if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
# apply writes from tasks that already ran
for tid, k, v in saved.pending_writes or []:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
# clear all current tasks
apply_writes(checkpoint, channels, next_tasks.values(), None)
# save checkpoint
next_config = await checkpointer.aput(
checkpoint_config,
create_checkpoint(checkpoint, None, step),
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = await checkpointer.aput(
checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
if values is None and as_node == "__copy__":
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = await checkpointer.aput(
saved.parent_config or saved.config if saved else checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "fork",
"step": step + 1,
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# apply pending writes, if not on specific checkpoint
if (
CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
and saved is not None
and saved.pending_writes
):
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
checkpoint,
saved.pending_writes,
self.nodes,
channels,
managed,
saved.config,
saved.metadata.get("step", -1) + 1,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer or None,
manager=None,
)
# apply null writes
if null_writes := [
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
]:
apply_writes(
saved.checkpoint,
channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
None,
)
for tid, k, v in saved.pending_writes:
if k in (ERROR, INTERRUPT, SCHEDULED):
continue
if tid not in next_tasks:
continue
next_tasks[tid].writes.append((k, v))
if tasks := [t for t in next_tasks.values() if t.writes]:
apply_writes(checkpoint, channels, tasks, None)
# find last node that updated the state, if not provided
            if as_node is None and not any(
                v for vv in checkpoint["versions_seen"].values() for v in vv.values()
            ):
if (
isinstance(self.input_channels, str)
and self.input_channels in self.nodes
):
as_node = self.input_channels
elif as_node is None:
last_seen_by_node = sorted(
(v, n)
for n, seen in checkpoint["versions_seen"].items()
if n in self.nodes
for v in seen.values()
)
# if two nodes updated the state at the same time, it's ambiguous
if last_seen_by_node:
if len(last_seen_by_node) == 1:
as_node = last_seen_by_node[0][1]
elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
as_node = last_seen_by_node[-1][1]
if as_node is None:
raise InvalidUpdateError("Ambiguous update, specify as_node")
if as_node not in self.nodes:
raise InvalidUpdateError(f"Node {as_node} does not exist")
# create task to run all writers of the chosen node
writers = self.nodes[as_node].flat_writers
if not writers:
raise InvalidUpdateError(f"Node {as_node} has no writers")
writes: deque[tuple[str, Any]] = deque()
task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
# execute task
await run.ainvoke(
values,
patch_config(
config,
run_name=self.name + "UpdateState",
configurable={
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
writes.extend,
self.nodes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
step + 1,
checkpoint,
channels,
managed,
task,
config,
),
},
),
)
# save task writes
# channel writes are saved to current checkpoint
# push writes are saved to next checkpoint
channel_writes, push_writes = (
[w for w in task.writes if w[0] != PUSH],
[w for w in task.writes if w[0] == PUSH],
)
if saved and channel_writes:
await checkpointer.aput_writes(
checkpoint_config, channel_writes, task_id
)
# apply to checkpoint and save
mv_writes = apply_writes(
checkpoint, channels, [task], checkpointer.get_next_version
)
assert not mv_writes, "Can't write to SharedValues from update_state"
checkpoint = create_checkpoint(checkpoint, channels, step + 1)
# save checkpoint, after applying writes
next_config = await checkpointer.aput(
checkpoint_config,
checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {as_node: values},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
get_new_channel_versions(
checkpoint_previous_versions, checkpoint["channel_versions"]
),
)
# save push writes
if push_writes:
await checkpointer.aput_writes(next_config, push_writes, task_id)
return patch_checkpoint_map(next_config, saved.metadata if saved else None)
def _defaults(
self,
config: RunnableConfig,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]],
output_keys: Optional[Union[str, Sequence[str]]],
interrupt_before: Optional[Union[All, Sequence[str]]],
interrupt_after: Optional[Union[All, Sequence[str]]],
debug: Optional[bool],
) -> tuple[
bool,
set[StreamMode],
Union[str, Sequence[str]],
Union[All, Sequence[str]],
Union[All, Sequence[str]],
Optional[BaseCheckpointSaver],
Optional[BaseStore],
]:
if config["recursion_limit"] < 1:
raise ValueError("recursion_limit must be at least 1")
debug = debug if debug is not None else self.debug
if output_keys is None:
output_keys = self.stream_channels_asis
else:
validate_keys(output_keys, self.channels)
interrupt_before = interrupt_before or self.interrupt_before_nodes
interrupt_after = interrupt_after or self.interrupt_after_nodes
stream_mode = stream_mode if stream_mode is not None else self.stream_mode
if not isinstance(stream_mode, list):
stream_mode = [stream_mode]
if CONFIG_KEY_TASK_ID in config.get(CONF, {}):
# if being called as a node in another graph, always use values mode
stream_mode = ["values"]
if self.checkpointer is False:
checkpointer: Optional[BaseCheckpointSaver] = None
elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}):
checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER]
else:
checkpointer = self.checkpointer
if checkpointer and not config.get(CONF):
raise ValueError(
f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}"
)
if CONFIG_KEY_STORE in config.get(CONF, {}):
store: Optional[BaseStore] = config[CONF][CONFIG_KEY_STORE]
else:
store = self.store
return (
debug,
set(stream_mode),
output_keys,
interrupt_before,
interrupt_after,
checkpointer,
store,
)
def stream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
output_keys: Optional[Union[str, Sequence[str]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
debug: Optional[bool] = None,
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]:
"""Stream graph steps for a single input.
Args:
input: The input to the graph.
config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', 'custom', 'messages', and 'debug'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated values as value.
                    custom: Emit custom data written from inside nodes via the StreamWriter.
                    messages: Emit LLM messages (message chunks plus metadata) produced inside nodes.
                    debug: Emit debug events for each step.
            output_keys: The keys to stream, defaults to all non-context channels.
            interrupt_before: Nodes to interrupt before, defaults to the interrupt points set when compiling the graph.
            interrupt_after: Nodes to interrupt after, defaults to the interrupt points set when compiling the graph.
debug: Whether to print debug information during execution, defaults to False.
subgraphs: Whether to stream subgraphs, defaults to False.
Yields:
The output of each step in the graph. The output shape depends on the stream_mode.
Examples:
Using different stream modes with a graph:
```pycon
>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
... alist: Annotated[list, operator.add]
... another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
```
With stream_mode="values":
```pycon
>>> for event in graph.stream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
... print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
```
With stream_mode="updates":
```pycon
>>> for event in graph.stream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
... print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
```
With stream_mode="debug":
```pycon
>>> for event in graph.stream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
... print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
```
"""
stream = SyncQueue()
def output() -> Iterator:
while True:
try:
ns, mode, payload = stream.get(block=False)
except queue.Empty:
break
if subgraphs and isinstance(stream_mode, list):
yield (ns, mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (ns, payload)
else:
yield payload
config = ensure_config(self.config, config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name", self.get_name()),
run_id=config.get("run_id"),
)
try:
# assign defaults
(
debug,
stream_modes,
output_keys,
interrupt_before_,
interrupt_after_,
checkpointer,
store,
) = self._defaults(
config,
stream_mode=stream_mode,
output_keys=output_keys,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)
# set up messages stream mode
if "messages" in stream_modes:
run_manager.inheritable_handlers.append(
StreamMessagesHandler(stream.put)
)
# set up custom stream mode
if "custom" in stream_modes:
config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put(
((), "custom", c)
)
with SyncPregelLoop(
input,
stream=StreamProtocol(stream.put, stream_modes),
config=config,
store=store,
checkpointer=checkpointer,
nodes=self.nodes,
specs=self.channels,
output_keys=output_keys,
stream_keys=self.stream_channels_asis,
interrupt_before=interrupt_before_,
interrupt_after=interrupt_after_,
manager=run_manager,
debug=debug,
) as loop:
# create runner
runner = PregelRunner(
submit=loop.submit,
put_writes=loop.put_writes,
schedule_task=loop.accept_push,
node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
)
# enable subgraph streaming
if subgraphs:
loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
# enable concurrent streaming
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
# we are careful to have a single waiter live at any one time
# because on exit we increment semaphore count by exactly 1
waiter: Optional[concurrent.futures.Future] = None
# because sync futures cannot be cancelled, we instead
# release the stream semaphore on exit, which will cause
# a pending waiter to return immediately
loop.stack.callback(stream._count.release)
def get_waiter() -> concurrent.futures.Future[None]:
nonlocal waiter
if waiter is None or waiter.done():
waiter = loop.submit(stream.wait)
return waiter
else:
return waiter
else:
get_waiter = None # type: ignore[assignment]
                # Similarly to the Bulk Synchronous Parallel / Pregel model,
                # computation proceeds in steps while there are channel updates.
                # Channel updates from step N are only visible in step N+1;
                # channels are guaranteed to be immutable for the duration of a step,
                # with updates applied only at the transition between steps.
while loop.tick(input_keys=self.input_channels):
for _ in runner.tick(
loop.tasks.values(),
timeout=self.step_timeout,
retry_policy=self.retry_policy,
get_waiter=get_waiter,
):
# emit output
yield from output()
# emit output
yield from output()
# handle exit
if loop.status == "out_of_steps":
msg = create_error_message(
message=(
f"Recursion limit of {config['recursion_limit']} reached "
"without hitting a stop condition. You can increase the "
"limit by setting the `recursion_limit` config key."
),
error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
)
raise GraphRecursionError(msg)
# set final channel values as run output
run_manager.on_chain_end(loop.output)
except BaseException as e:
run_manager.on_chain_error(e)
raise
async def astream(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
output_keys: Optional[Union[str, Sequence[str]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
debug: Optional[bool] = None,
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
"""Stream graph steps for a single input.
Args:
input: The input to the graph.
config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', 'custom', 'messages', and 'debug'.
                    values: Emit the current values of the state for each step.
                    updates: Emit only the updates to the state for each step.
                        Output is a dict with the node name as key and the updated values as value.
                    custom: Emit custom data written from inside nodes via the StreamWriter.
                    messages: Emit LLM messages (message chunks plus metadata) produced inside nodes.
                    debug: Emit debug events for each step.
            output_keys: The keys to stream, defaults to all non-context channels.
            interrupt_before: Nodes to interrupt before, defaults to the interrupt points set when compiling the graph.
            interrupt_after: Nodes to interrupt after, defaults to the interrupt points set when compiling the graph.
debug: Whether to print debug information during execution, defaults to False.
subgraphs: Whether to stream subgraphs, defaults to False.
Yields:
The output of each step in the graph. The output shape depends on the stream_mode.
Examples:
Using different stream modes with a graph:
```pycon
>>> import operator
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.graph import StateGraph
>>> from langgraph.constants import START
...
>>> class State(TypedDict):
... alist: Annotated[list, operator.add]
... another_list: Annotated[list, operator.add]
...
>>> builder = StateGraph(State)
>>> builder.add_node("a", lambda _state: {"another_list": ["hi"]})
>>> builder.add_node("b", lambda _state: {"alist": ["there"]})
>>> builder.add_edge("a", "b")
>>> builder.add_edge(START, "a")
>>> graph = builder.compile()
```
With stream_mode="values":
```pycon
>>> async for event in graph.astream({"alist": ['Ex for stream_mode="values"']}, stream_mode="values"):
... print(event)
{'alist': ['Ex for stream_mode="values"'], 'another_list': []}
{'alist': ['Ex for stream_mode="values"'], 'another_list': ['hi']}
{'alist': ['Ex for stream_mode="values"', 'there'], 'another_list': ['hi']}
```
With stream_mode="updates":
```pycon
>>> async for event in graph.astream({"alist": ['Ex for stream_mode="updates"']}, stream_mode="updates"):
... print(event)
{'a': {'another_list': ['hi']}}
{'b': {'alist': ['there']}}
```
With stream_mode="debug":
```pycon
>>> async for event in graph.astream({"alist": ['Ex for stream_mode="debug"']}, stream_mode="debug"):
... print(event)
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': []}, 'triggers': ['start:a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 1, 'payload': {'id': '...', 'name': 'a', 'result': [('another_list', ['hi'])]}}
{'type': 'task', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'input': {'alist': ['Ex for stream_mode="debug"'], 'another_list': ['hi']}, 'triggers': ['a']}}
{'type': 'task_result', 'timestamp': '2024-06-23T...+00:00', 'step': 2, 'payload': {'id': '...', 'name': 'b', 'result': [('alist', ['there'])]}}
```
"""
stream = AsyncQueue()
aioloop = asyncio.get_running_loop()
stream_put = cast(
Callable[[StreamChunk], None],
partial(aioloop.call_soon_threadsafe, stream.put_nowait),
)
def output() -> Iterator:
while True:
try:
ns, mode, payload = stream.get_nowait()
except asyncio.QueueEmpty:
break
if subgraphs and isinstance(stream_mode, list):
yield (ns, mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (ns, payload)
else:
yield payload
config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name", self.get_name()),
run_id=config.get("run_id"),
)
        # if running from astream_log(), run each proc with streaming
do_stream = next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
None,
)
try:
# assign defaults
(
debug,
stream_modes,
output_keys,
interrupt_before_,
interrupt_after_,
checkpointer,
store,
) = self._defaults(
config,
stream_mode=stream_mode,
output_keys=output_keys,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)
# set up messages stream mode
if "messages" in stream_modes:
run_manager.inheritable_handlers.append(
StreamMessagesHandler(stream_put)
)
# set up custom stream mode
if "custom" in stream_modes:
config[CONF][CONFIG_KEY_STREAM_WRITER] = (
lambda c: aioloop.call_soon_threadsafe(
stream.put_nowait, ((), "custom", c)
)
)
async with AsyncPregelLoop(
input,
stream=StreamProtocol(stream.put_nowait, stream_modes),
config=config,
store=store,
checkpointer=checkpointer,
nodes=self.nodes,
specs=self.channels,
output_keys=output_keys,
stream_keys=self.stream_channels_asis,
interrupt_before=interrupt_before_,
interrupt_after=interrupt_after_,
manager=run_manager,
debug=debug,
) as loop:
# create runner
runner = PregelRunner(
submit=loop.submit,
put_writes=loop.put_writes,
schedule_task=loop.accept_push,
use_astream=do_stream is not None,
node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
)
# enable subgraph streaming
if subgraphs:
loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
stream_put, stream_modes
)
# enable concurrent streaming
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
def get_waiter() -> asyncio.Task[None]:
return aioloop.create_task(stream.wait())
else:
get_waiter = None # type: ignore[assignment]
                # Similarly to the Bulk Synchronous Parallel / Pregel model,
                # computation proceeds in steps while there are channel updates.
                # Channel updates from step N are only visible in step N+1;
                # channels are guaranteed to be immutable for the duration of a step,
                # with updates applied only at the transition between steps.
while loop.tick(input_keys=self.input_channels):
async for _ in runner.atick(
loop.tasks.values(),
timeout=self.step_timeout,
retry_policy=self.retry_policy,
get_waiter=get_waiter,
):
# emit output
for o in output():
yield o
# emit output
for o in output():
yield o
# handle exit
if loop.status == "out_of_steps":
msg = create_error_message(
message=(
f"Recursion limit of {config['recursion_limit']} reached "
"without hitting a stop condition. You can increase the "
"limit by setting the `recursion_limit` config key."
),
error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
)
raise GraphRecursionError(msg)
# set final channel values as run output
await run_manager.on_chain_end(loop.output)
except BaseException as e:
await asyncio.shield(run_manager.on_chain_error(e))
raise
def invoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: StreamMode = "values",
output_keys: Optional[Union[str, Sequence[str]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
debug: Optional[bool] = None,
**kwargs: Any,
) -> Union[dict[str, Any], Any]:
"""Run the graph with a single input and config.
Args:
input: The input data for the graph. It can be a dictionary or any other type.
config: Optional. The configuration for the graph run.
stream_mode: Optional[str]. The stream mode for the graph run. Default is "values".
output_keys: Optional. The output keys to retrieve from the graph run.
interrupt_before: Optional. The nodes to interrupt the graph run before.
interrupt_after: Optional. The nodes to interrupt the graph run after.
debug: Optional. Enable debug mode for the graph run.
**kwargs: Additional keyword arguments to pass to the graph run.
Returns:
The output of the graph run. If stream_mode is "values", it returns the latest output.
If stream_mode is not "values", it returns a list of output chunks.
"""
output_keys = output_keys if output_keys is not None else self.output_channels
if stream_mode == "values":
latest: Union[dict[str, Any], Any] = None
else:
chunks = []
for chunk in self.stream(
input,
config,
stream_mode=stream_mode,
output_keys=output_keys,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
**kwargs,
):
if stream_mode == "values":
latest = chunk
else:
chunks.append(chunk)
if stream_mode == "values":
return latest
else:
return chunks
async def ainvoke(
self,
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
*,
stream_mode: StreamMode = "values",
output_keys: Optional[Union[str, Sequence[str]]] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
debug: Optional[bool] = None,
**kwargs: Any,
) -> Union[dict[str, Any], Any]:
"""Asynchronously invoke the graph on a single input.
Args:
input: The input data for the computation. It can be a dictionary or any other type.
config: Optional. The configuration for the computation.
stream_mode: Optional. The stream mode for the computation. Default is "values".
output_keys: Optional. The output keys to include in the result. Default is None.
interrupt_before: Optional. The nodes to interrupt before. Default is None.
interrupt_after: Optional. The nodes to interrupt after. Default is None.
debug: Optional. Whether to enable debug mode. Default is None.
**kwargs: Additional keyword arguments.
Returns:
The result of the computation. If stream_mode is "values", it returns the latest value.
            If stream_mode is not "values", it returns a list of output chunks.
"""
output_keys = output_keys if output_keys is not None else self.output_channels
if stream_mode == "values":
latest: Union[dict[str, Any], Any] = None
else:
chunks = []
async for chunk in self.astream(
input,
config,
stream_mode=stream_mode,
output_keys=output_keys,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
**kwargs,
):
if stream_mode == "values":
latest = chunk
else:
chunks.append(chunk)
if stream_mode == "values":
return latest
else:
return chunks
```
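The `update_state` / `get_state_history` methods above operate on checkpoints saved for a thread. A minimal sketch of how they are typically driven from a compiled graph (the graph, state schema, and thread id are illustrative; it assumes the in-memory `MemorySaver` checkpointer):
```py
import operator
from typing_extensions import Annotated, TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.constants import START
from langgraph.graph import StateGraph


class State(TypedDict):
    alist: Annotated[list, operator.add]


builder = StateGraph(State)
builder.add_node("a", lambda _state: {"alist": ["hi"]})
builder.add_edge(START, "a")
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "1"}}
graph.invoke({"alist": ["start"]}, config)

# Write to the state as if node "a" had produced the update;
# update_state saves a new checkpoint with metadata source="update".
graph.update_state(config, {"alist": ["patched"]}, as_node="a")

# Walk the thread's checkpoints, newest first.
for snapshot in graph.get_state_history(config):
    print(snapshot.values, snapshot.metadata.get("source"))
```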
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/types.py`:
```py
"""Re-export types moved to langgraph.types"""
from langgraph.types import (
All,
CachePolicy,
PregelExecutableTask,
PregelTask,
RetryPolicy,
StateSnapshot,
StreamMode,
StreamWriter,
default_retry_on,
)
__all__ = [
"All",
"CachePolicy",
"PregelExecutableTask",
"PregelTask",
"RetryPolicy",
"StateSnapshot",
"StreamMode",
"StreamWriter",
"default_retry_on",
]
```
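Since `pregel/types.py` is now a pure re-export shim, both import paths resolve to the same objects; a quick illustrative check:
```py
from langgraph.pregel.types import RetryPolicy as LegacyRetryPolicy
from langgraph.types import RetryPolicy

# The shim only re-exports, so both names refer to the exact same class.
assert RetryPolicy is LegacyRetryPolicy
```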
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/retry.py`:
```py
import asyncio
import logging
import random
import sys
import time
from dataclasses import replace
from functools import partial
from typing import Any, Callable, Optional, Sequence
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUMING,
CONFIG_KEY_SEND,
NS_SEP,
)
from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphBubbleUp, ParentCommand
from langgraph.types import Command, PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable
logger = logging.getLogger(__name__)
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
def run_with_retry(
task: PregelExecutableTask,
retry_policy: Optional[RetryPolicy],
writer: Optional[
Callable[[PregelExecutableTask, Sequence[tuple[str, Any]]], None]
] = None,
) -> None:
"""Run a task with retries."""
retry_policy = task.retry_policy or retry_policy
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
config = task.config
if writer is not None:
config = patch_configurable(config, {CONFIG_KEY_SEND: partial(writer, task)})
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
task.proc.invoke(task.input, config)
# if successful, end
break
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph == ns:
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if retry_policy is None:
raise
# increment attempts
attempts += 1
# check if we should retry
if isinstance(retry_policy.retry_on, Sequence):
if not isinstance(exc, tuple(retry_policy.retry_on)):
raise
elif isinstance(retry_policy.retry_on, type) and issubclass(
retry_policy.retry_on, Exception
):
if not isinstance(exc, retry_policy.retry_on):
raise
elif callable(retry_policy.retry_on):
if not retry_policy.retry_on(exc): # type: ignore[call-arg]
raise
else:
raise TypeError(
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
)
# check if we should give up
if attempts >= retry_policy.max_attempts:
raise
# sleep before retrying
interval = min(
retry_policy.max_interval,
interval * retry_policy.backoff_factor,
)
time.sleep(
interval + random.uniform(0, 1) if retry_policy.jitter else interval
)
# log the retry
logger.info(
                    f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) due to {exc.__class__.__name__}: {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
async def arun_with_retry(
task: PregelExecutableTask,
retry_policy: Optional[RetryPolicy],
stream: bool = False,
writer: Optional[
Callable[[PregelExecutableTask, Sequence[tuple[str, Any]]], None]
] = None,
) -> None:
"""Run a task asynchronously with retries."""
retry_policy = task.retry_policy or retry_policy
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
config = task.config
if writer is not None:
config = patch_configurable(config, {CONFIG_KEY_SEND: partial(writer, task)})
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
if stream:
async for _ in task.proc.astream(task.input, config):
pass
else:
await task.proc.ainvoke(task.input, config)
# if successful, end
break
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph == ns:
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parent_ns = NS_SEP.join(ns.split(NS_SEP)[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if retry_policy is None:
raise
# increment attempts
attempts += 1
# check if we should retry
if isinstance(retry_policy.retry_on, Sequence):
if not isinstance(exc, tuple(retry_policy.retry_on)):
raise
elif isinstance(retry_policy.retry_on, type) and issubclass(
retry_policy.retry_on, Exception
):
if not isinstance(exc, retry_policy.retry_on):
raise
elif callable(retry_policy.retry_on):
if not retry_policy.retry_on(exc): # type: ignore[call-arg]
raise
else:
raise TypeError(
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
)
# check if we should give up
if attempts >= retry_policy.max_attempts:
raise
# sleep before retrying
interval = min(
retry_policy.max_interval,
interval * retry_policy.backoff_factor,
)
await asyncio.sleep(
interval + random.uniform(0, 1) if retry_policy.jitter else interval
)
# log the retry
logger.info(
                    f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) due to {exc.__class__.__name__}: {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
finally:
# clear checkpoint_ns seen (for subgraph detection)
if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
_SEEN_CHECKPOINT_NS.discard(checkpoint_ns)
```
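Between attempts, `run_with_retry` / `arun_with_retry` grow the sleep interval multiplicatively (`interval * backoff_factor`), capped at `max_interval`, optionally adding up to one second of jitter. A small sketch of the resulting schedule, assuming the `RetryPolicy` fields referenced above (values are illustrative):
```py
from langgraph.types import RetryPolicy

policy = RetryPolicy(
    initial_interval=0.5, backoff_factor=2.0, max_interval=8.0, max_attempts=5
)

interval = policy.initial_interval
for attempt in range(1, policy.max_attempts):
    # mirrors: interval = min(max_interval, interval * backoff_factor)
    interval = min(policy.max_interval, interval * policy.backoff_factor)
    print(f"attempt {attempt} failed -> sleep ~{interval:.2f}s before retrying")
```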
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/validate.py`:
```py
from typing import Any, Mapping, Optional, Sequence, Union
from langgraph.channels.base import BaseChannel
from langgraph.constants import RESERVED
from langgraph.pregel.read import PregelNode
from langgraph.types import All
def validate_graph(
nodes: Mapping[str, PregelNode],
channels: dict[str, BaseChannel],
input_channels: Union[str, Sequence[str]],
output_channels: Union[str, Sequence[str]],
stream_channels: Optional[Union[str, Sequence[str]]],
interrupt_after_nodes: Union[All, Sequence[str]],
interrupt_before_nodes: Union[All, Sequence[str]],
) -> None:
for chan in channels:
if chan in RESERVED:
            raise ValueError(f"Channel name '{chan}' is reserved")
subscribed_channels = set[str]()
for name, node in nodes.items():
if name in RESERVED:
raise ValueError(f"Node names {RESERVED} are reserved")
if isinstance(node, PregelNode):
subscribed_channels.update(node.triggers)
else:
raise TypeError(
f"Invalid node type {type(node)}, expected Channel.subscribe_to()"
)
for chan in subscribed_channels:
if chan not in channels:
raise ValueError(f"Subscribed channel '{chan}' not in 'channels'")
if isinstance(input_channels, str):
if input_channels not in channels:
raise ValueError(f"Input channel '{input_channels}' not in 'channels'")
if input_channels not in subscribed_channels:
raise ValueError(
f"Input channel {input_channels} is not subscribed to by any node"
)
else:
for chan in input_channels:
if chan not in channels:
raise ValueError(f"Input channel '{chan}' not in 'channels'")
if all(chan not in subscribed_channels for chan in input_channels):
raise ValueError(
f"None of the input channels {input_channels} are subscribed to by any node"
)
all_output_channels = set[str]()
if isinstance(output_channels, str):
all_output_channels.add(output_channels)
else:
all_output_channels.update(output_channels)
if isinstance(stream_channels, str):
all_output_channels.add(stream_channels)
elif stream_channels is not None:
all_output_channels.update(stream_channels)
for chan in all_output_channels:
if chan not in channels:
raise ValueError(f"Output channel '{chan}' not in 'channels'")
if interrupt_after_nodes != "*":
for n in interrupt_after_nodes:
if n not in nodes:
raise ValueError(f"Node {n} not in nodes")
if interrupt_before_nodes != "*":
for n in interrupt_before_nodes:
if n not in nodes:
raise ValueError(f"Node {n} not in nodes")
def validate_keys(
keys: Optional[Union[str, Sequence[str]]],
channels: Mapping[str, Any],
) -> None:
if isinstance(keys, str):
if keys not in channels:
raise ValueError(f"Key {keys} not in channels")
elif keys is not None:
for chan in keys:
if chan not in channels:
raise ValueError(f"Key {chan} not in channels")
```
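`validate_keys` simply requires every requested key to name an existing channel; a quick sketch with placeholder channel values:
```py
from langgraph.pregel.validate import validate_keys

channels = {"alist": object(), "another_list": object()}

validate_keys("alist", channels)                    # ok: single key present
validate_keys(["alist", "another_list"], channels)  # ok: every key present

try:
    validate_keys("missing", channels)
except ValueError as err:
    print(err)  # Key missing not in channels
```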
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/utils.py`:
```py
from typing import Optional
from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_core.runnables.utils import get_function_nonlocals
from langgraph.checkpoint.base import ChannelVersions
from langgraph.pregel.protocol import PregelProtocol
from langgraph.utils.runnable import Runnable, RunnableCallable, RunnableSeq
def get_new_channel_versions(
previous_versions: ChannelVersions, current_versions: ChannelVersions
) -> ChannelVersions:
"""Get subset of current_versions that are newer than previous_versions."""
if previous_versions:
version_type = type(next(iter(current_versions.values()), None))
null_version = version_type() # type: ignore[misc]
new_versions = {
k: v
for k, v in current_versions.items()
if v > previous_versions.get(k, null_version) # type: ignore[operator]
}
else:
new_versions = current_versions
return new_versions
def find_subgraph_pregel(candidate: Runnable) -> Optional[Runnable]:
from langgraph.pregel import Pregel
candidates: list[Runnable] = [candidate]
for c in candidates:
if (
isinstance(c, PregelProtocol)
# subgraphs that disabled checkpointing are not considered
and (not isinstance(c, Pregel) or c.checkpointer is not False)
):
return c
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
candidates.extend(c.steps)
elif isinstance(c, RunnableLambda):
candidates.extend(c.deps)
elif isinstance(c, RunnableCallable):
if c.func is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.func)
)
elif c.afunc is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.afunc)
)
return None
```
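The docstring on `get_new_channel_versions` is terse; with integer versions the behavior looks like this (illustrative values):
```py
from langgraph.pregel.utils import get_new_channel_versions

previous = {"alist": 1, "another_list": 2}
current = {"alist": 1, "another_list": 3, "extra": 1}

# Only channels whose version advanced past the previous checkpoint are kept.
print(get_new_channel_versions(previous, current))
# -> {'another_list': 3, 'extra': 1}
```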
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/debug.py`:
```py
from collections import defaultdict
from dataclasses import asdict
from datetime import datetime, timezone
from pprint import pformat
from typing import (
Any,
Iterable,
Iterator,
Literal,
Mapping,
Optional,
Sequence,
TypedDict,
Union,
)
from uuid import UUID
from langchain_core.runnables.config import RunnableConfig
from langchain_core.utils.input import get_bolded_text, get_colored_text
from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, PendingWrite
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
ERROR,
INTERRUPT,
NS_END,
NS_SEP,
TAG_HIDDEN,
)
from langgraph.pregel.io import read_channels
from langgraph.pregel.utils import find_subgraph_pregel
from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot
from langgraph.utils.config import patch_checkpoint_map
class TaskPayload(TypedDict):
id: str
name: str
input: Any
triggers: list[str]
class TaskResultPayload(TypedDict):
id: str
name: str
error: Optional[str]
interrupts: list[dict]
result: list[tuple[str, Any]]
class CheckpointTask(TypedDict):
id: str
name: str
error: Optional[str]
interrupts: list[dict]
state: Optional[RunnableConfig]
class CheckpointPayload(TypedDict):
config: Optional[RunnableConfig]
metadata: CheckpointMetadata
values: dict[str, Any]
next: list[str]
parent_config: Optional[RunnableConfig]
tasks: list[CheckpointTask]
class DebugOutputBase(TypedDict):
timestamp: str
step: int
class DebugOutputTask(DebugOutputBase):
type: Literal["task"]
payload: TaskPayload
class DebugOutputTaskResult(DebugOutputBase):
type: Literal["task_result"]
payload: TaskResultPayload
class DebugOutputCheckpoint(DebugOutputBase):
type: Literal["checkpoint"]
payload: CheckpointPayload
DebugOutput = Union[DebugOutputTask, DebugOutputTaskResult, DebugOutputCheckpoint]
TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8")
def map_debug_tasks(
step: int, tasks: Iterable[PregelExecutableTask]
) -> Iterator[DebugOutputTask]:
"""Produce "task" events for stream_mode=debug."""
ts = datetime.now(timezone.utc).isoformat()
for task in tasks:
if task.config is not None and TAG_HIDDEN in task.config.get("tags", []):
continue
yield {
"type": "task",
"timestamp": ts,
"step": step,
"payload": {
"id": task.id,
"name": task.name,
"input": task.input,
"triggers": task.triggers,
},
}
def map_debug_task_results(
step: int,
task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]],
stream_keys: Union[str, Sequence[str]],
) -> Iterator[DebugOutputTaskResult]:
"""Produce "task_result" events for stream_mode=debug."""
stream_channels_list = (
[stream_keys] if isinstance(stream_keys, str) else stream_keys
)
task, writes = task_tup
yield {
"type": "task_result",
"timestamp": datetime.now(timezone.utc).isoformat(),
"step": step,
"payload": {
"id": task.id,
"name": task.name,
"error": next((w[1] for w in writes if w[0] == ERROR), None),
"result": [w for w in writes if w[0] in stream_channels_list],
"interrupts": [asdict(w[1]) for w in writes if w[0] == INTERRUPT],
},
}
def map_debug_checkpoint(
step: int,
config: RunnableConfig,
channels: Mapping[str, BaseChannel],
stream_channels: Union[str, Sequence[str]],
metadata: CheckpointMetadata,
checkpoint: Checkpoint,
tasks: Iterable[PregelExecutableTask],
pending_writes: list[PendingWrite],
parent_config: Optional[RunnableConfig],
output_keys: Union[str, Sequence[str]],
) -> Iterator[DebugOutputCheckpoint]:
"""Produce "checkpoint" events for stream_mode=debug."""
parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {}
for task in tasks:
if not find_subgraph_pregel(task.proc):
continue
# assemble checkpoint_ns for this task
task_ns = f"{task.name}{NS_END}{task.id}"
if parent_ns:
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
# set config as signal that subgraph checkpoints exist
task_states[task.id] = {
CONF: {
"thread_id": config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
yield {
"type": "checkpoint",
"timestamp": checkpoint["ts"],
"step": step,
"payload": {
"config": patch_checkpoint_map(config, metadata),
"parent_config": patch_checkpoint_map(parent_config, metadata),
"values": read_channels(channels, stream_channels),
"metadata": metadata,
"next": [t.name for t in tasks],
"tasks": [
{
"id": t.id,
"name": t.name,
"error": t.error,
"state": t.state,
}
if t.error
else {
"id": t.id,
"name": t.name,
"result": t.result,
"interrupts": tuple(asdict(i) for i in t.interrupts),
"state": t.state,
}
if t.result
else {
"id": t.id,
"name": t.name,
"interrupts": tuple(asdict(i) for i in t.interrupts),
"state": t.state,
}
for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys)
],
},
}
def print_step_tasks(step: int, next_tasks: list[PregelExecutableTask]) -> None:
n_tasks = len(next_tasks)
print(
f"{get_colored_text(f'[{step}:tasks]', color='blue')} "
+ get_bolded_text(
f"Starting {n_tasks} task{'s' if n_tasks != 1 else ''} for step {step}:\n"
)
+ "\n".join(
f"- {get_colored_text(task.name, 'green')} -> {pformat(task.input)}"
for task in next_tasks
)
)
def print_step_writes(
step: int, writes: Sequence[tuple[str, Any]], whitelist: Sequence[str]
) -> None:
by_channel: dict[str, list[Any]] = defaultdict(list)
for channel, value in writes:
if channel in whitelist:
by_channel[channel].append(value)
print(
f"{get_colored_text(f'[{step}:writes]', color='blue')} "
+ get_bolded_text(
f"Finished step {step} with writes to {len(by_channel)} channel{'s' if len(by_channel) != 1 else ''}:\n"
)
+ "\n".join(
f"- {get_colored_text(name, 'yellow')} -> {', '.join(pformat(v) for v in vals)}"
for name, vals in by_channel.items()
)
)
def print_step_checkpoint(
metadata: CheckpointMetadata,
channels: Mapping[str, BaseChannel],
whitelist: Sequence[str],
) -> None:
step = metadata["step"]
print(
f"{get_colored_text(f'[{step}:checkpoint]', color='blue')} "
+ get_bolded_text(f"State at the end of step {step}:\n")
+ pformat(read_channels(channels, whitelist), depth=3)
)
def tasks_w_writes(
tasks: Iterable[Union[PregelTask, PregelExecutableTask]],
pending_writes: Optional[list[PendingWrite]],
states: Optional[dict[str, Union[RunnableConfig, StateSnapshot]]],
output_keys: Union[str, Sequence[str]],
) -> tuple[PregelTask, ...]:
"""Apply writes / subgraph states to tasks to be returned in a StateSnapshot."""
pending_writes = pending_writes or []
return tuple(
PregelTask(
task.id,
task.name,
task.path,
next(
(
exc
for tid, n, exc in pending_writes
if tid == task.id and n == ERROR
),
None,
),
tuple(
v for tid, n, v in pending_writes if tid == task.id and n == INTERRUPT
),
states.get(task.id) if states else None,
(
next(
(
val
for tid, chan, val in pending_writes
if tid == task.id and chan == output_keys
),
None,
)
if isinstance(output_keys, str)
else {
chan: val
for tid, chan, val in pending_writes
if tid == task.id
and (
chan == output_keys
if isinstance(output_keys, str)
else chan in output_keys
)
}
)
if any(
w[0] == task.id and w[1] not in (ERROR, INTERRUPT)
for w in pending_writes
)
else None,
)
for task in tasks
)
```
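For reference, the "task" events yielded by `map_debug_tasks` (and surfaced by `stream_mode="debug"`) conform to the `DebugOutputTask` TypedDict defined above; an illustrative literal:
```py
from langgraph.pregel.debug import DebugOutputTask

event: DebugOutputTask = {
    "type": "task",
    "timestamp": "2024-06-23T00:00:00+00:00",
    "step": 1,
    "payload": {
        "id": "task-id",
        "name": "a",
        "input": {"alist": []},
        "triggers": ["start:a"],
    },
}
```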
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/messages.py`:
```py
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID, uuid4
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.tracers._streaming import T, _StreamingCallbackHandler
from langgraph.constants import NS_SEP, TAG_HIDDEN, TAG_NOSTREAM
from langgraph.types import StreamChunk
Meta = tuple[tuple[str, ...], dict[str, Any]]
class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
"""A callback handler that implements stream_mode=messages.
Collects messages from (1) chat model stream events and (2) node outputs."""
run_inline = True
"""We want this callback to run in the main thread, to avoid order/locking issues."""
def __init__(self, stream: Callable[[StreamChunk], None]):
self.stream = stream
self.metadata: dict[UUID, Meta] = {}
self.seen: set[Union[int, str]] = set()
def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
if dedupe and message.id in self.seen:
return
else:
if message.id is None:
message.id = str(uuid4())
self.seen.add(message.id)
self.stream((meta[0], "messages", (message, meta[1])))
def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
return output
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
return output
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
if metadata and (not tags or TAG_NOSTREAM not in tags):
self.metadata[run_id] = (
tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
metadata,
)
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[ChatGenerationChunk] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if not isinstance(chunk, ChatGenerationChunk):
return
if meta := self.metadata.get(run_id):
self._emit(meta, chunk.message)
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
if (
metadata
and kwargs.get("name") == metadata.get("langgraph_node")
and (not tags or TAG_HIDDEN not in tags)
):
self.metadata[run_id] = (
tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
metadata,
)
def on_chain_end(
self,
response: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if meta := self.metadata.pop(run_id, None):
if isinstance(response, BaseMessage):
self._emit(meta, response, dedupe=True)
elif isinstance(response, Sequence):
for value in response:
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(response, dict):
for value in response.values():
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(value, Sequence):
for item in value:
if isinstance(item, BaseMessage):
self._emit(meta, item, dedupe=True)
elif hasattr(response, "__dir__") and callable(response.__dir__):
for key in dir(response):
try:
value = getattr(response, key)
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(value, Sequence):
for item in value:
if isinstance(item, BaseMessage):
self._emit(meta, item, dedupe=True)
except AttributeError:
pass
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)
```
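`StreamMessagesHandler` surfaces messages both from chat-model token callbacks and from `BaseMessage` values found in node outputs. A hedged sketch of the second path (graph and state schema are illustrative), streamed with `stream_mode="messages"`:
```py
from typing_extensions import TypedDict

from langchain_core.messages import AIMessage
from langgraph.constants import START
from langgraph.graph import StateGraph


class State(TypedDict):
    messages: list


builder = StateGraph(State)
builder.add_node("a", lambda _state: {"messages": [AIMessage(content="hi")]})
builder.add_edge(START, "a")
graph = builder.compile()

# Each event is a (message, metadata) tuple emitted by the handler.
for message, metadata in graph.stream({"messages": []}, stream_mode="messages"):
    print(message.content, metadata.get("langgraph_node"))
```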
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/loop.py`:
```py
import asyncio
import concurrent.futures
from collections import defaultdict, deque
from contextlib import AsyncExitStack, ExitStack
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
Callable,
ContextManager,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
)
from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables import RunnableConfig
from typing_extensions import ParamSpec, Self
from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
PendingWrite,
copy_checkpoint,
create_checkpoint,
empty_checkpoint,
)
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_DELEGATE,
CONFIG_KEY_ENSURE_LATEST,
CONFIG_KEY_RESUMING,
CONFIG_KEY_STREAM,
CONFIG_KEY_TASK_ID,
EMPTY_SEQ,
ERROR,
INPUT,
INTERRUPT,
NS_SEP,
NULL_TASK_ID,
PUSH,
RESUME,
SCHEDULED,
TAG_HIDDEN,
)
from langgraph.errors import (
_SEEN_CHECKPOINT_NS,
CheckpointNotLatest,
EmptyInputError,
GraphDelegate,
GraphInterrupt,
MultipleSubgraphsError,
)
from langgraph.managed.base import (
ManagedValueMapping,
ManagedValueSpec,
WritableManagedValue,
)
from langgraph.pregel.algo import (
GetNextVersion,
PregelTaskWrites,
apply_writes,
increment,
prepare_next_tasks,
prepare_single_task,
should_interrupt,
)
from langgraph.pregel.debug import (
map_debug_checkpoint,
map_debug_task_results,
map_debug_tasks,
print_step_checkpoint,
print_step_tasks,
print_step_writes,
)
from langgraph.pregel.executor import (
AsyncBackgroundExecutor,
BackgroundExecutor,
Submit,
)
from langgraph.pregel.io import (
map_command,
map_input,
map_output_updates,
map_output_values,
read_channels,
single,
)
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.read import PregelNode
from langgraph.pregel.utils import get_new_channel_versions
from langgraph.store.base import BaseStore
from langgraph.types import (
All,
Command,
LoopProtocol,
PregelExecutableTask,
StreamChunk,
StreamProtocol,
)
from langgraph.utils.config import patch_configurable
V = TypeVar("V")
P = ParamSpec("P")
INPUT_DONE = object()
INPUT_RESUMING = object()
SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED)
def DuplexStream(*streams: StreamProtocol) -> StreamProtocol:
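    """Combine multiple StreamProtocol instances into one, forwarding each chunk
    only to the streams whose modes include that chunk's mode."""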
def __call__(value: StreamChunk) -> None:
for stream in streams:
if value[1] in stream.modes:
stream(value)
return StreamProtocol(__call__, {mode for s in streams for mode in s.modes})
class PregelLoop(LoopProtocol):
input: Optional[Any]
checkpointer: Optional[BaseCheckpointSaver]
nodes: Mapping[str, PregelNode]
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]]
output_keys: Union[str, Sequence[str]]
stream_keys: Union[str, Sequence[str]]
skip_done_tasks: bool
is_nested: bool
manager: Union[None, AsyncParentRunManager, ParentRunManager]
interrupt_after: Union[All, Sequence[str]]
interrupt_before: Union[All, Sequence[str]]
checkpointer_get_next_version: GetNextVersion
checkpointer_put_writes: Optional[
Callable[[RunnableConfig, Sequence[tuple[str, Any]], str], Any]
]
_checkpointer_put_after_previous: Optional[
Callable[
[
Optional[concurrent.futures.Future],
RunnableConfig,
Sequence[tuple[str, Any]],
str,
ChannelVersions,
],
Any,
]
]
submit: Submit
channels: Mapping[str, BaseChannel]
managed: ManagedValueMapping
checkpoint: Checkpoint
checkpoint_ns: tuple[str, ...]
checkpoint_config: RunnableConfig
checkpoint_metadata: CheckpointMetadata
checkpoint_pending_writes: List[PendingWrite]
checkpoint_previous_versions: dict[str, Union[str, float, int]]
prev_checkpoint_config: Optional[RunnableConfig]
status: Literal[
"pending", "done", "interrupt_before", "interrupt_after", "out_of_steps"
]
tasks: dict[str, PregelExecutableTask]
to_interrupt: list[PregelExecutableTask]
output: Union[None, dict[str, Any], Any] = None
# public
def __init__(
self,
input: Optional[Any],
*,
stream: Optional[StreamProtocol],
config: RunnableConfig,
store: Optional[BaseStore],
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
output_keys: Union[str, Sequence[str]],
stream_keys: Union[str, Sequence[str]],
interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ,
interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ,
manager: Union[None, AsyncParentRunManager, ParentRunManager] = None,
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
super().__init__(
step=0,
stop=0,
config=config,
stream=stream,
store=store,
)
self.input = input
self.checkpointer = checkpointer
self.nodes = nodes
self.specs = specs
self.output_keys = output_keys
self.stream_keys = stream_keys
self.interrupt_after = interrupt_after
self.interrupt_before = interrupt_before
self.manager = manager
self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {})
self.skip_done_tasks = (
CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
or CONFIG_KEY_DEDUPE_TASKS in config[CONF]
)
self.debug = debug
if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]:
self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM])
if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
self.config = patch_configurable(
self.config,
{CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None},
)
if check_subgraphs and self.is_nested and self.checkpointer is not None:
if self.config[CONF][CONFIG_KEY_CHECKPOINT_NS] in _SEEN_CHECKPOINT_NS:
raise MultipleSubgraphsError(
"Multiple subgraphs called inside the same node\n\n"
"Troubleshooting URL: https://python.langchain.com/docs"
"/troubleshooting/errors/MULTIPLE_SUBGRAPHS/"
)
else:
_SEEN_CHECKPOINT_NS.add(self.config[CONF][CONFIG_KEY_CHECKPOINT_NS])
if (
CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF]
and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP]
):
self.checkpoint_config = patch_configurable(
self.config,
{
CONFIG_KEY_CHECKPOINT_ID: config[CONF][CONFIG_KEY_CHECKPOINT_MAP][
self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]
]
},
)
else:
self.checkpoint_config = config
self.checkpoint_ns = (
tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP))
if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
else ()
)
self.prev_checkpoint_config = None
def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
if not writes:
return
# deduplicate writes to special channels, last write wins
if all(w[0] in WRITES_IDX_MAP for w in writes):
writes = list({w[0]: w for w in writes}.values())
# save writes
for c, v in writes:
if (
c in WRITES_IDX_MAP
and (
idx := next(
(
i
for i, w in enumerate(self.checkpoint_pending_writes)
if w[0] == task_id and w[1] == c
),
None,
)
)
is not None
):
self.checkpoint_pending_writes[idx] = (task_id, c, v)
else:
self.checkpoint_pending_writes.append((task_id, c, v))
if self.checkpointer_put_writes is not None:
self.submit(
self.checkpointer_put_writes,
{
**self.checkpoint_config,
CONF: {
**self.checkpoint_config[CONF],
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
},
writes,
task_id,
)
# output writes
if hasattr(self, "tasks"):
self._output_writes(task_id, writes)
def accept_push(
self, task: PregelExecutableTask, write_idx: int
) -> Optional[PregelExecutableTask]:
"""Accept a PUSH from a task, potentially returning a new task to start."""
# don't start if an earlier PUSH has already triggered an interrupt
if self.to_interrupt:
return
# don't start if we should interrupt *after* the original task
if should_interrupt(self.checkpoint, self.interrupt_after, [task]):
self.to_interrupt.append(task)
return
if pushed := cast(
Optional[PregelExecutableTask],
prepare_single_task(
(PUSH, task.path, write_idx, task.id),
None,
checkpoint=self.checkpoint,
pending_writes=[(task.id, *w) for w in task.writes],
processes=self.nodes,
channels=self.channels,
managed=self.managed,
config=self.config,
step=self.step,
for_execution=True,
store=self.store,
checkpointer=self.checkpointer,
manager=self.manager,
),
):
# don't start if we should interrupt *before* the new task
if should_interrupt(self.checkpoint, self.interrupt_before, [pushed]):
self.to_interrupt.append(pushed)
return
# produce debug output
self._emit("debug", map_debug_tasks, self.step, [pushed])
# debug flag
if self.debug:
print_step_tasks(self.step, [pushed])
# save the new task
self.tasks[pushed.id] = pushed
# match any pending writes to the new task
if self.skip_done_tasks:
self._match_writes({pushed.id: pushed})
# return the new task, to be started, if not run before
if not pushed.writes:
return pushed
def tick(
self,
*,
input_keys: Union[str, Sequence[str]],
) -> bool:
"""Execute a single iteration of the Pregel loop.
Returns True if more iterations are needed."""
if self.status != "pending":
raise RuntimeError("Cannot tick when status is no longer 'pending'")
if self.input not in (INPUT_DONE, INPUT_RESUMING):
self._first(input_keys=input_keys)
elif self.to_interrupt:
# if we need to interrupt, do so
self.status = "interrupt_before"
raise GraphInterrupt()
elif all(task.writes for task in self.tasks.values()):
writes = [w for t in self.tasks.values() for w in t.writes]
# debug flag
if self.debug:
print_step_writes(
self.step,
writes,
(
[self.stream_keys]
if isinstance(self.stream_keys, str)
else self.stream_keys
),
)
# all tasks have finished
mv_writes = apply_writes(
self.checkpoint,
self.channels,
self.tasks.values(),
self.checkpointer_get_next_version,
)
# apply writes to managed values
for key, values in mv_writes.items():
self._update_mv(key, values)
# produce values output
self._emit(
"values", map_output_values, self.output_keys, writes, self.channels
)
# clear pending writes
self.checkpoint_pending_writes.clear()
# "not skip_done_tasks" only applies to first tick after resuming
self.skip_done_tasks = True
# save checkpoint
self._put_checkpoint(
{
"source": "loop",
"writes": single(
map_output_updates(
self.output_keys,
[(t, t.writes) for t in self.tasks.values()],
)
),
}
)
# after execution, check if we should interrupt
if should_interrupt(
self.checkpoint, self.interrupt_after, self.tasks.values()
):
self.status = "interrupt_after"
raise GraphInterrupt()
else:
return False
# check if iteration limit is reached
if self.step > self.stop:
self.status = "out_of_steps"
return False
# apply NULL writes
if null_writes := [
w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
]:
mv_writes = apply_writes(
self.checkpoint,
self.channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
self.checkpointer_get_next_version,
)
for key, values in mv_writes.items():
self._update_mv(key, values)
# prepare next tasks
self.tasks = prepare_next_tasks(
self.checkpoint,
self.checkpoint_pending_writes,
self.nodes,
self.channels,
self.managed,
self.config,
self.step,
for_execution=True,
manager=self.manager,
store=self.store,
checkpointer=self.checkpointer,
)
self.to_interrupt = []
# produce debug output
if self._checkpointer_put_after_previous is not None:
self._emit(
"debug",
map_debug_checkpoint,
self.step - 1, # printing checkpoint for previous step
self.checkpoint_config,
self.channels,
self.stream_keys,
self.checkpoint_metadata,
self.checkpoint,
self.tasks.values(),
self.checkpoint_pending_writes,
self.prev_checkpoint_config,
self.output_keys,
)
# if no more tasks, we're done
if not self.tasks:
self.status = "done"
return False
# check if we should delegate (used by subgraphs in distributed mode)
if self.config[CONF].get(CONFIG_KEY_DELEGATE):
assert self.input is INPUT_RESUMING
raise GraphDelegate(
{
"config": patch_configurable(
self.config, {CONFIG_KEY_DELEGATE: False}
),
"input": None,
}
)
# if there are pending writes from a previous loop, apply them
if self.skip_done_tasks and self.checkpoint_pending_writes:
self._match_writes(self.tasks)
# if all tasks have finished, re-tick
if all(task.writes for task in self.tasks.values()):
return self.tick(input_keys=input_keys)
# before execution, check if we should interrupt
if should_interrupt(
self.checkpoint, self.interrupt_before, self.tasks.values()
):
self.status = "interrupt_before"
raise GraphInterrupt()
# produce debug output
self._emit("debug", map_debug_tasks, self.step, self.tasks.values())
# debug flag
if self.debug:
print_step_tasks(self.step, list(self.tasks.values()))
# print output for any tasks we applied previous writes to
for task in self.tasks.values():
if task.writes:
self._output_writes(task.id, task.writes, cached=True)
return True
# private
def _match_writes(self, tasks: Mapping[str, PregelExecutableTask]) -> None:
for tid, k, v in self.checkpoint_pending_writes:
if k in (ERROR, INTERRUPT, RESUME):
continue
if task := tasks.get(tid):
if k == SCHEDULED:
if v == max(
self.checkpoint["versions_seen"].get(INTERRUPT, {}).values(),
default=None,
):
self.tasks[tid] = task._replace(scheduled=True)
else:
task.writes.append((k, v))
def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
# resuming from previous checkpoint requires
# - finding a previous checkpoint
# - receiving None input (outer graph) or RESUMING flag (subgraph)
configurable = self.config.get(CONF, {})
is_resuming = bool(self.checkpoint["channel_versions"]) and bool(
configurable.get(CONFIG_KEY_RESUMING, self.input is None)
)
# proceed past previous checkpoint
if is_resuming:
self.checkpoint["versions_seen"].setdefault(INTERRUPT, {})
for k in self.channels:
if k in self.checkpoint["channel_versions"]:
version = self.checkpoint["channel_versions"][k]
self.checkpoint["versions_seen"][INTERRUPT][k] = version
# produce values output
self._emit(
"values", map_output_values, self.output_keys, True, self.channels
)
# map command to writes
elif isinstance(self.input, Command):
writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
# group writes by task ID
for tid, c, v in map_command(self.input, self.checkpoint_pending_writes):
writes[tid].append((c, v))
if not writes:
raise EmptyInputError("Received empty Command input")
# save writes
for tid, ws in writes.items():
self.put_writes(tid, ws)
# map inputs to channel updates
elif input_writes := deque(map_input(input_keys, self.input)):
# TODO shouldn't these writes be passed to put_writes too?
# check if we should delegate (used by subgraphs in distributed mode)
if self.config[CONF].get(CONFIG_KEY_DELEGATE):
raise GraphDelegate(
{
"config": patch_configurable(
self.config, {CONFIG_KEY_DELEGATE: False}
),
"input": self.input,
}
)
# discard any unfinished tasks from previous checkpoint
discard_tasks = prepare_next_tasks(
self.checkpoint,
self.checkpoint_pending_writes,
self.nodes,
self.channels,
self.managed,
self.config,
self.step,
for_execution=True,
store=None,
checkpointer=None,
manager=None,
)
# apply input writes
mv_writes = apply_writes(
self.checkpoint,
self.channels,
[
*discard_tasks.values(),
PregelTaskWrites((), INPUT, input_writes, []),
],
self.checkpointer_get_next_version,
)
assert not mv_writes, "Can't write to SharedValues in graph input"
# save input checkpoint
self._put_checkpoint({"source": "input", "writes": dict(input_writes)})
elif CONFIG_KEY_RESUMING not in configurable:
raise EmptyInputError(f"Received no input for {input_keys}")
# done with input
self.input = INPUT_RESUMING if is_resuming else INPUT_DONE
# update config
if not self.is_nested:
self.config = patch_configurable(
self.config, {CONFIG_KEY_RESUMING: is_resuming}
)
def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
for k, v in self.config["metadata"].items():
metadata.setdefault(k, v) # type: ignore
# assign step and parents
metadata["step"] = self.step
metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {})
# debug flag
if self.debug:
print_step_checkpoint(
metadata,
self.channels,
(
[self.stream_keys]
if isinstance(self.stream_keys, str)
else self.stream_keys
),
)
# create new checkpoint
self.checkpoint = create_checkpoint(self.checkpoint, self.channels, self.step)
# bail if no checkpointer
if self._checkpointer_put_after_previous is not None:
self.checkpoint_metadata = metadata
self.prev_checkpoint_config = (
self.checkpoint_config
if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
else None
)
self.checkpoint_config = {
**self.checkpoint_config,
CONF: {
**self.checkpoint_config[CONF],
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
},
}
channel_versions = self.checkpoint["channel_versions"].copy()
new_versions = get_new_channel_versions(
self.checkpoint_previous_versions, channel_versions
)
self.checkpoint_previous_versions = channel_versions
# save it, without blocking
# if there's a previous checkpoint save in progress, wait for it
# ensuring checkpointers receive checkpoints in order
self._put_checkpoint_fut = self.submit(
self._checkpointer_put_after_previous,
getattr(self, "_put_checkpoint_fut", None),
self.checkpoint_config,
copy_checkpoint(self.checkpoint),
self.checkpoint_metadata,
new_versions,
)
self.checkpoint_config = {
**self.checkpoint_config,
CONF: {
**self.checkpoint_config[CONF],
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
}
# increment step
self.step += 1
def _update_mv(self, key: str, values: Sequence[Any]) -> None:
raise NotImplementedError
def _suppress_interrupt(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
if suppress or exc_type is None:
# save final output
self.output = read_channels(self.channels, self.output_keys)
if suppress:
# emit one last "values" event, with pending writes applied
if (
hasattr(self, "tasks")
and self.checkpoint_pending_writes
and any(task.writes for task in self.tasks.values())
):
mv_writes = apply_writes(
self.checkpoint,
self.channels,
self.tasks.values(),
self.checkpointer_get_next_version,
)
for key, values in mv_writes.items():
self._update_mv(key, values)
self._emit(
"values",
map_output_values,
self.output_keys,
[w for t in self.tasks.values() for w in t.writes],
self.channels,
)
# emit INTERRUPT event
self._emit(
"updates",
lambda: iter([{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]),
)
# suppress interrupt
return True
def _emit(
self,
mode: str,
values: Callable[P, Iterator[Any]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if self.stream is None:
return
if mode not in self.stream.modes:
return
for v in values(*args, **kwargs):
self.stream((self.checkpoint_ns, mode, v))
def _output_writes(
self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False
) -> None:
if task := self.tasks.get(task_id):
if task.config is not None and TAG_HIDDEN in task.config.get(
"tags", EMPTY_SEQ
):
return
if writes[0][0] != ERROR and writes[0][0] != INTERRUPT:
self._emit(
"updates",
map_output_updates,
self.output_keys,
[(task, writes)],
cached,
)
if not cached:
self._emit(
"debug",
map_debug_task_results,
self.step,
(task, writes),
self.stream_keys,
)
class SyncPregelLoop(PregelLoop, ContextManager):
def __init__(
self,
input: Optional[Any],
*,
stream: Optional[StreamProtocol],
config: RunnableConfig,
store: Optional[BaseStore],
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
manager: Union[None, AsyncParentRunManager, ParentRunManager] = None,
interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ,
interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ,
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
super().__init__(
input,
stream=stream,
config=config,
checkpointer=checkpointer,
store=store,
nodes=nodes,
specs=specs,
output_keys=output_keys,
stream_keys=stream_keys,
interrupt_after=interrupt_after,
interrupt_before=interrupt_before,
check_subgraphs=check_subgraphs,
manager=manager,
debug=debug,
)
self.stack = ExitStack()
if checkpointer:
self.checkpointer_get_next_version = checkpointer.get_next_version
self.checkpointer_put_writes = checkpointer.put_writes
else:
self.checkpointer_get_next_version = increment
self._checkpointer_put_after_previous = None # type: ignore[assignment]
self.checkpointer_put_writes = None
def _checkpointer_put_after_previous(
self,
prev: Optional[concurrent.futures.Future],
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
try:
if prev is not None:
prev.result()
finally:
cast(BaseCheckpointSaver, self.checkpointer).put(
config, checkpoint, metadata, new_versions
)
def _update_mv(self, key: str, values: Sequence[Any]) -> None:
return self.submit(cast(WritableManagedValue, self.managed[key]).update, values)
# context manager
def __enter__(self) -> Self:
if self.config.get(CONF, {}).get(
CONFIG_KEY_ENSURE_LATEST
) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID):
if self.checkpointer is None:
raise RuntimeError(
"Cannot ensure latest checkpoint without checkpointer"
)
saved = self.checkpointer.get_tuple(
patch_configurable(
self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None}
)
)
if (
saved is None
or saved.checkpoint["id"]
!= self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
):
raise CheckpointNotLatest
elif self.checkpointer:
saved = self.checkpointer.get_tuple(self.checkpoint_config)
else:
saved = None
if saved is None:
saved = CheckpointTuple(
self.config, empty_checkpoint(), {"step": -2}, None, []
)
self.checkpoint_config = {
**self.config,
**saved.config,
CONF: {
CONFIG_KEY_CHECKPOINT_NS: "",
**self.config.get(CONF, {}),
**saved.config.get(CONF, {}),
},
}
self.prev_checkpoint_config = saved.parent_config
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
[(str(tid), k, v) for tid, k, v in saved.pending_writes]
if saved.pending_writes is not None
else []
)
self.submit = self.stack.enter_context(BackgroundExecutor(self.config))
self.channels, self.managed = self.stack.enter_context(
ChannelsManager(self.specs, self.checkpoint, self)
)
self.stack.push(self._suppress_interrupt)
self.status = "pending"
self.step = self.checkpoint_metadata["step"] + 1
self.stop = self.step + self.config["recursion_limit"] + 1
self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
# unwind stack
return self.stack.__exit__(exc_type, exc_value, traceback)
class AsyncPregelLoop(PregelLoop, AsyncContextManager):
def __init__(
self,
input: Optional[Any],
*,
stream: Optional[StreamProtocol],
config: RunnableConfig,
store: Optional[BaseStore],
checkpointer: Optional[BaseCheckpointSaver],
nodes: Mapping[str, PregelNode],
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ,
interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ,
manager: Union[None, AsyncParentRunManager, ParentRunManager] = None,
output_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
stream_keys: Union[str, Sequence[str]] = EMPTY_SEQ,
check_subgraphs: bool = True,
debug: bool = False,
) -> None:
super().__init__(
input,
stream=stream,
config=config,
checkpointer=checkpointer,
store=store,
nodes=nodes,
specs=specs,
output_keys=output_keys,
stream_keys=stream_keys,
interrupt_after=interrupt_after,
interrupt_before=interrupt_before,
check_subgraphs=check_subgraphs,
manager=manager,
debug=debug,
)
self.stack = AsyncExitStack()
if checkpointer:
self.checkpointer_get_next_version = checkpointer.get_next_version
self.checkpointer_put_writes = checkpointer.aput_writes
else:
self.checkpointer_get_next_version = increment
self._checkpointer_put_after_previous = None # type: ignore[assignment]
self.checkpointer_put_writes = None
async def _checkpointer_put_after_previous(
self,
prev: Optional[asyncio.Task],
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
try:
if prev is not None:
await prev
finally:
await cast(BaseCheckpointSaver, self.checkpointer).aput(
config, checkpoint, metadata, new_versions
)
def _update_mv(self, key: str, values: Sequence[Any]) -> None:
return self.submit(
cast(WritableManagedValue, self.managed[key]).aupdate, values
)
# context manager
async def __aenter__(self) -> Self:
if self.config.get(CONF, {}).get(
CONFIG_KEY_ENSURE_LATEST
) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID):
if self.checkpointer is None:
raise RuntimeError(
"Cannot ensure latest checkpoint without checkpointer"
)
saved = await self.checkpointer.aget_tuple(
patch_configurable(
self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None}
)
)
if (
saved is None
or saved.checkpoint["id"]
!= self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
):
raise CheckpointNotLatest
elif self.checkpointer:
saved = await self.checkpointer.aget_tuple(self.checkpoint_config)
else:
saved = None
if saved is None:
saved = CheckpointTuple(
self.config, empty_checkpoint(), {"step": -2}, None, []
)
self.checkpoint_config = {
**self.config,
**saved.config,
CONF: {
CONFIG_KEY_CHECKPOINT_NS: "",
**self.config.get(CONF, {}),
**saved.config.get(CONF, {}),
},
}
self.prev_checkpoint_config = saved.parent_config
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
[(str(tid), k, v) for tid, k, v in saved.pending_writes]
if saved.pending_writes is not None
else []
)
self.submit = await self.stack.enter_async_context(
AsyncBackgroundExecutor(self.config)
)
self.channels, self.managed = await self.stack.enter_async_context(
AsyncChannelsManager(self.specs, self.checkpoint, self)
)
self.stack.push(self._suppress_interrupt)
self.status = "pending"
self.step = self.checkpoint_metadata["step"] + 1
self.stop = self.step + self.config["recursion_limit"] + 1
self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
# unwind stack
return await asyncio.shield(
self.stack.__aexit__(exc_type, exc_value, traceback)
)
```
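A small, self-contained sketch of the fan-out idea behind `DuplexStream` above: each subscriber declares the stream modes it wants, and a duplexed callable forwards a `(namespace, mode, payload)` chunk only to subscribers whose `modes` contain that chunk's mode. `FakeStream` and `duplex` are stand-ins for this sketch, not langgraph's `StreamProtocol`.
```py
# Stand-in sketch of the DuplexStream routing behaviour; not langgraph's API.
from typing import Any, Callable

Chunk = tuple[tuple[str, ...], str, Any]  # (namespace, mode, payload)


class FakeStream:
    def __init__(self, modes: set[str]) -> None:
        self.modes = modes
        self.seen: list[Chunk] = []

    def __call__(self, chunk: Chunk) -> None:
        self.seen.append(chunk)


def duplex(*streams: FakeStream) -> Callable[[Chunk], None]:
    def emit(chunk: Chunk) -> None:
        for stream in streams:
            if chunk[1] in stream.modes:  # route by the chunk's mode
                stream(chunk)

    return emit


values_only = FakeStream({"values"})
debug_and_updates = FakeStream({"debug", "updates"})
emit = duplex(values_only, debug_and_updates)

emit(((), "values", {"x": 1}))
emit(((), "debug", {"step": 0}))

assert values_only.seen == [((), "values", {"x": 1})]
assert debug_and_updates.seen == [((), "debug", {"step": 0})]
```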
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/algo.py`:
```py
import sys
from collections import defaultdict, deque
from functools import partial
from hashlib import sha1
from typing import (
Any,
Callable,
Iterable,
Iterator,
Literal,
Mapping,
NamedTuple,
Optional,
Protocol,
Sequence,
Union,
cast,
overload,
)
from uuid import UUID
from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables.config import RunnableConfig
from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
Checkpoint,
PendingWrite,
V,
copy_checkpoint,
)
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_READ,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
CONFIG_KEY_STORE,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
EMPTY_SEQ,
INTERRUPT,
NO_WRITES,
NS_END,
NS_SEP,
NULL_TASK_ID,
PULL,
PUSH,
RESERVED,
RESUME,
TAG_HIDDEN,
TASKS,
Send,
)
from langgraph.errors import EmptyChannelError, InvalidUpdateError
from langgraph.managed.base import ManagedValueMapping
from langgraph.pregel.io import read_channel, read_channels
from langgraph.pregel.log import logger
from langgraph.pregel.manager import ChannelsManager
from langgraph.pregel.read import PregelNode
from langgraph.store.base import BaseStore
from langgraph.types import All, LoopProtocol, PregelExecutableTask, PregelTask
from langgraph.utils.config import merge_configs, patch_config
GetNextVersion = Callable[[Optional[V], BaseChannel], V]
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
class WritesProtocol(Protocol):
"""Protocol for objects containing writes to be applied to checkpoint.
Implemented by PregelTaskWrites and PregelExecutableTask."""
@property
def path(self) -> tuple[Union[str, int, tuple], ...]: ...
@property
def name(self) -> str: ...
@property
def writes(self) -> Sequence[tuple[str, Any]]: ...
@property
def triggers(self) -> Sequence[str]: ...
class PregelTaskWrites(NamedTuple):
"""Simplest implementation of WritesProtocol, for usage with writes that
don't originate from a runnable task, eg. graph input, update_state, etc."""
path: tuple[Union[str, int, tuple], ...]
name: str
writes: Sequence[tuple[str, Any]]
triggers: Sequence[str]
def should_interrupt(
checkpoint: Checkpoint,
interrupt_nodes: Union[All, Sequence[str]],
tasks: Iterable[PregelExecutableTask],
) -> list[PregelExecutableTask]:
"""Check if the graph should be interrupted based on current state."""
version_type = type(next(iter(checkpoint["channel_versions"].values()), None))
null_version = version_type() # type: ignore[misc]
seen = checkpoint["versions_seen"].get(INTERRUPT, {})
# interrupt if any channel has been updated since last interrupt
any_updates_since_prev_interrupt = any(
version > seen.get(chan, null_version) # type: ignore[operator]
for chan, version in checkpoint["channel_versions"].items()
)
# and any triggered node is in interrupt_nodes list
return (
[
task
for task in tasks
if (
(
not task.config
or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ)
)
if interrupt_nodes == "*"
else task.name in interrupt_nodes
)
]
if any_updates_since_prev_interrupt
else []
)
def local_read(
step: int,
checkpoint: Checkpoint,
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
task: WritesProtocol,
config: RunnableConfig,
select: Union[list[str], str],
fresh: bool = False,
) -> Union[dict[str, Any], Any]:
"""Function injected under CONFIG_KEY_READ in task config, to read current state.
Used by conditional edges to read a copy of the state reflecting the writes
from that node only."""
if isinstance(select, str):
managed_keys = []
for c, _ in task.writes:
if c == select:
updated = {c}
break
else:
updated = set()
else:
managed_keys = [k for k in select if k in managed]
select = [k for k in select if k not in managed]
updated = set(select).intersection(c for c, _ in task.writes)
if fresh and updated:
with ChannelsManager(
{k: v for k, v in channels.items() if k in updated},
checkpoint,
LoopProtocol(config=config, step=step, stop=step + 1),
skip_context=True,
) as (local_channels, _):
apply_writes(copy_checkpoint(checkpoint), local_channels, [task], None)
values = read_channels({**channels, **local_channels}, select)
else:
values = read_channels(channels, select)
if managed_keys:
values.update({k: managed[k]() for k in managed_keys})
return values
def local_write(
commit: Callable[[Sequence[tuple[str, Any]]], None],
process_keys: Iterable[str],
writes: Sequence[tuple[str, Any]],
) -> None:
"""Function injected under CONFIG_KEY_SEND in task config, to write to channels.
Validates writes and forwards them to `commit` function."""
for chan, value in writes:
if chan in (PUSH, TASKS):
if not isinstance(value, Send):
raise InvalidUpdateError(f"Expected Send, got {value}")
if value.node not in process_keys:
raise InvalidUpdateError(f"Invalid node name {value.node} in packet")
commit(writes)
def increment(current: Optional[int], channel: BaseChannel) -> int:
"""Default channel versioning function, increments the current int version."""
return current + 1 if current is not None else 1
def apply_writes(
checkpoint: Checkpoint,
channels: Mapping[str, BaseChannel],
tasks: Iterable[WritesProtocol],
get_next_version: Optional[GetNextVersion],
) -> dict[str, list[Any]]:
"""Apply writes from a set of tasks (usually the tasks from a Pregel step)
to the checkpoint and channels, and return managed values writes to be applied
externally."""
# sort tasks on path, to ensure deterministic order for update application
# any path parts after the 3rd are ignored for sorting
# (we use them for eg. task ids which aren't good for sorting)
tasks = sorted(tasks, key=lambda t: t.path[:3])
# if no task has triggers this is applying writes from the null task only
# so we don't do anything other than update the channels written to
bump_step = any(t.triggers for t in tasks)
# update seen versions
for task in tasks:
checkpoint["versions_seen"].setdefault(task.name, {}).update(
{
chan: checkpoint["channel_versions"][chan]
for chan in task.triggers
if chan in checkpoint["channel_versions"]
}
)
# Find the highest version of all channels
if checkpoint["channel_versions"]:
max_version = max(checkpoint["channel_versions"].values())
else:
max_version = None
# Consume all channels that were read
for chan in {
chan
for task in tasks
for chan in task.triggers
if chan not in RESERVED and chan in channels
}:
if channels[chan].consume() and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version,
channels[chan],
)
# clear pending sends
if checkpoint["pending_sends"] and bump_step:
checkpoint["pending_sends"].clear()
# Group writes by channel
pending_writes_by_channel: dict[str, list[Any]] = defaultdict(list)
pending_writes_by_managed: dict[str, list[Any]] = defaultdict(list)
for task in tasks:
for chan, val in task.writes:
if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT):
pass
elif chan == TASKS: # TODO: remove branch in 1.0
checkpoint["pending_sends"].append(val)
elif chan in channels:
pending_writes_by_channel[chan].append(val)
else:
pending_writes_by_managed[chan].append(val)
# Find the highest version of all channels
if checkpoint["channel_versions"]:
max_version = max(checkpoint["channel_versions"].values())
else:
max_version = None
# Apply writes to channels
updated_channels: set[str] = set()
for chan, vals in pending_writes_by_channel.items():
if chan in channels:
if channels[chan].update(vals) and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version,
channels[chan],
)
updated_channels.add(chan)
# Channels that weren't updated in this step are notified of a new step
if bump_step:
for chan in channels:
if chan not in updated_channels:
if channels[chan].update([]) and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version,
channels[chan],
)
# Return managed values writes to be applied externally
return pending_writes_by_managed
@overload
def prepare_next_tasks(
checkpoint: Checkpoint,
pending_writes: Sequence[PendingWrite],
processes: Mapping[str, PregelNode],
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
config: RunnableConfig,
step: int,
*,
for_execution: Literal[False],
store: Literal[None] = None,
checkpointer: Literal[None] = None,
manager: Literal[None] = None,
) -> dict[str, PregelTask]: ...
@overload
def prepare_next_tasks(
checkpoint: Checkpoint,
pending_writes: Sequence[PendingWrite],
processes: Mapping[str, PregelNode],
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
config: RunnableConfig,
step: int,
*,
for_execution: Literal[True],
store: Optional[BaseStore],
checkpointer: Optional[BaseCheckpointSaver],
manager: Union[None, ParentRunManager, AsyncParentRunManager],
) -> dict[str, PregelExecutableTask]: ...
def prepare_next_tasks(
checkpoint: Checkpoint,
pending_writes: Sequence[PendingWrite],
processes: Mapping[str, PregelNode],
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
config: RunnableConfig,
step: int,
*,
for_execution: bool,
store: Optional[BaseStore] = None,
checkpointer: Optional[BaseCheckpointSaver] = None,
manager: Union[None, ParentRunManager, AsyncParentRunManager] = None,
) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]:
"""Prepare the set of tasks that will make up the next Pregel step.
This is the union of all PUSH tasks (Sends) and PULL tasks (nodes triggered
by edges)."""
tasks: list[Union[PregelTask, PregelExecutableTask]] = []
# Consume pending_sends from previous step (legacy version of Send)
for idx, _ in enumerate(checkpoint["pending_sends"]): # TODO: remove branch in 1.0
if task := prepare_single_task(
(PUSH, idx),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
tasks.append(task)
# Check if any processes should be run in next step
# If so, prepare the values to be passed to them
for name in processes:
if task := prepare_single_task(
(PULL, name),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
tasks.append(task)
# Consume pending Sends from this step (new version of Send)
if any(c == PUSH for _, c, _ in pending_writes):
# group writes by task id
grouped_by_task = defaultdict(list)
for tid, c, _ in pending_writes:
grouped_by_task[tid].append(c)
# prepare send tasks from grouped writes
# 1. start from sends originating from existing tasks
tidx = 0
while tidx < len(tasks):
task = tasks[tidx]
if twrites := grouped_by_task.pop(task.id, None):
for idx, c in enumerate(twrites):
if c != PUSH:
continue
if next_task := prepare_single_task(
(PUSH, task.path, idx, task.id),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
tasks.append(next_task)
tidx += 1
# key tasks by id
task_map = {t.id: t for t in tasks}
# 2. create new tasks for remaining sends (eg. from update_state)
for tid, writes in grouped_by_task.items():
task = task_map.get(tid)
for idx, c in enumerate(writes):
if c != PUSH:
continue
if next_task := prepare_single_task(
(PUSH, task.path if task else (), idx, tid),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
task_map[next_task.id] = next_task
else:
task_map = {t.id: t for t in tasks}
return task_map
def prepare_single_task(
task_path: tuple[Union[str, int, tuple], ...],
task_id_checksum: Optional[str],
*,
checkpoint: Checkpoint,
pending_writes: Sequence[PendingWrite],
processes: Mapping[str, PregelNode],
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
config: RunnableConfig,
step: int,
for_execution: bool,
store: Optional[BaseStore] = None,
checkpointer: Optional[BaseCheckpointSaver] = None,
manager: Union[None, ParentRunManager, AsyncParentRunManager] = None,
) -> Union[None, PregelTask, PregelExecutableTask]:
"""Prepares a single task for the next Pregel step, given a task path, which
uniquely identifies a PUSH or PULL task within the graph."""
checkpoint_id = UUID(checkpoint["id"]).bytes
configurable = config.get(CONF, {})
parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
if task_path[0] == PUSH:
if len(task_path) == 2: # TODO: remove branch in 1.0
# legacy SEND tasks, executed in superstep n+1
# (PUSH, idx of pending send)
idx = cast(int, task_path[1])
if idx >= len(checkpoint["pending_sends"]):
return
packet = checkpoint["pending_sends"][idx]
if not isinstance(packet, Send):
logger.warning(
f"Ignoring invalid packet type {type(packet)} in pending sends"
)
return
if packet.node not in processes:
logger.warning(
f"Ignoring unknown node name {packet.node} in pending sends"
)
return
# create task id
triggers = [PUSH]
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
str(idx),
)
elif len(task_path) == 4:
# new PUSH tasks, executed in superstep n
# (PUSH, parent task path, idx of PUSH write, id of parent task)
task_path_t = cast(tuple[str, tuple, int, str], task_path)
writes_for_path = [w for w in pending_writes if w[0] == task_path_t[3]]
if task_path_t[2] >= len(writes_for_path):
logger.warning(
f"Ignoring invalid write index {task_path[2]} in pending writes"
)
return
packet = writes_for_path[task_path_t[2]][2]
if not isinstance(packet, Send):
logger.warning(
f"Ignoring invalid packet type {type(packet)} in pending writes"
)
return
if packet.node not in processes:
logger.warning(
f"Ignoring unknown node name {packet.node} in pending writes"
)
return
# create task id
triggers = [PUSH]
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
else:
logger.warning(f"Ignoring invalid PUSH task path {task_path}")
return
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
metadata = {
"langgraph_step": step,
"langgraph_node": packet.node,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"
if for_execution:
proc = processes[packet.node]
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes: deque[tuple[str, Any]] = deque()
return PregelExecutableTask(
packet.node,
packet.arg,
node,
writes,
patch_config(
merge_configs(
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=packet.node,
callbacks=(
manager.get_child(f"graph:step:{step}") if manager else None
),
configurable={
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
writes.extend,
processes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
step,
checkpoint,
channels,
managed,
PregelTaskWrites(
task_path, packet.node, writes, triggers
),
config,
),
CONFIG_KEY_STORE: (
store or configurable.get(CONFIG_KEY_STORE)
),
CONFIG_KEY_CHECKPOINTER: (
checkpointer
or configurable.get(CONFIG_KEY_CHECKPOINTER)
),
CONFIG_KEY_CHECKPOINT_MAP: {
**configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
parent_ns: checkpoint["id"],
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
proc.retry_policy,
None,
task_id,
task_path,
writers=proc.flat_writers,
)
else:
return PregelTask(task_id, packet.node, task_path)
elif task_path[0] == PULL:
# (PULL, node name)
name = cast(str, task_path[1])
if name not in processes:
return
proc = processes[name]
version_type = type(next(iter(checkpoint["channel_versions"].values()), None))
null_version = version_type() # type: ignore[misc]
if null_version is None:
return
seen = checkpoint["versions_seen"].get(name, {})
# If any of the channels read by this process were updated
if triggers := sorted(
chan
for chan in proc.triggers
if not isinstance(
read_channel(channels, chan, return_exception=True), EmptyChannelError
)
and checkpoint["channel_versions"].get(chan, null_version) # type: ignore[operator]
> seen.get(chan, null_version)
):
try:
val = next(
_proc_input(proc, managed, channels, for_execution=for_execution)
)
except StopIteration:
return
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(
f"Before task with name '{name}' and path '{task_path[:3]}'"
)
raise
# create task id
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PULL,
*triggers,
)
task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}"
metadata = {
"langgraph_step": step,
"langgraph_node": name,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
assert task_id == task_id_checksum
if for_execution:
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
return PregelExecutableTask(
name,
val,
node,
writes,
patch_config(
merge_configs(
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=name,
callbacks=(
manager.get_child(f"graph:step:{step}")
if manager
else None
),
configurable={
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
writes.extend,
processes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
step,
checkpoint,
channels,
managed,
PregelTaskWrites(task_path, name, writes, triggers),
config,
),
CONFIG_KEY_STORE: (
store or configurable.get(CONFIG_KEY_STORE)
),
CONFIG_KEY_CHECKPOINTER: (
checkpointer
or configurable.get(CONFIG_KEY_CHECKPOINTER)
),
CONFIG_KEY_CHECKPOINT_MAP: {
**configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
parent_ns: checkpoint["id"],
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
proc.retry_policy,
None,
task_id,
task_path,
writers=proc.flat_writers,
)
else:
return PregelTask(task_id, name, task_path)
def _proc_input(
proc: PregelNode,
managed: ManagedValueMapping,
channels: Mapping[str, BaseChannel],
*,
for_execution: bool,
) -> Iterator[Any]:
"""Prepare input for a PULL task, based on the process's channels and triggers."""
# If all trigger channels subscribed by this process are not empty
# then invoke the process with the values of all non-empty channels
if isinstance(proc.channels, dict):
try:
val: dict[str, Any] = {}
for k, chan in proc.channels.items():
if chan in proc.triggers:
val[k] = read_channel(channels, chan, catch=False)
elif chan in channels:
try:
val[k] = read_channel(channels, chan, catch=False)
except EmptyChannelError:
continue
else:
val[k] = managed[k]()
except EmptyChannelError:
return
elif isinstance(proc.channels, list):
for chan in proc.channels:
try:
val = read_channel(channels, chan, catch=False)
break
except EmptyChannelError:
pass
else:
return
else:
raise RuntimeError(
"Invalid channels type, expected list or dict, got {proc.channels}"
)
# If the process has a mapper, apply it to the value
if for_execution and proc.mapper is not None:
val = proc.mapper(val)
yield val
def _uuid5_str(namespace: bytes, *parts: str) -> str:
"""Generate a UUID from the SHA-1 hash of a namespace UUID and a name."""
sha = sha1(namespace, usedforsecurity=False)
sha.update(b"".join(p.encode() for p in parts))
hex = sha.hexdigest()
return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"
def _tuple_str(tup: Union[str, int, tuple]) -> str:
"""Generate a string representation of a tuple."""
return (
f"({', '.join(_tuple_str(x) for x in tup)})"
if isinstance(tup, (tuple, list))
else str(tup)
)
```
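A quick, self-contained check of the deterministic task-id scheme used by `prepare_single_task` above: ids come from a SHA-1 hash of the checkpoint id plus the task's namespace, step, node name, kind, and triggers, formatted as a UUID, so re-preparing the same step against the same checkpoint reproduces the same ids (which is what lets pending writes be matched back to tasks on resume). The literal strings below are placeholders for illustration, not langgraph's actual PUSH/PULL constants.
```py
# Re-derives the same id for the same inputs; the namespace/step/trigger strings
# are made up for this example.
from hashlib import sha1
from uuid import UUID, uuid4


def uuid5_str(namespace: bytes, *parts: str) -> str:
    sha = sha1(namespace, usedforsecurity=False)
    sha.update(b"".join(p.encode() for p in parts))
    hex = sha.hexdigest()
    return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"


checkpoint_id = uuid4().bytes
first = uuid5_str(checkpoint_id, "agent", "3", "agent", "pull", "messages")
second = uuid5_str(checkpoint_id, "agent", "3", "agent", "pull", "messages")

assert first == second          # deterministic for identical inputs
assert UUID(first)              # and shaped like a UUID
assert first != uuid5_str(checkpoint_id, "agent", "4", "agent", "pull", "messages")
```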
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/manager.py`:
```py
import asyncio
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from typing import AsyncIterator, Iterator, Mapping, Union
from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import Checkpoint
from langgraph.managed.base import (
ConfiguredManagedValue,
ManagedValueMapping,
ManagedValueSpec,
)
from langgraph.managed.context import Context
from langgraph.types import LoopProtocol
@contextmanager
def ChannelsManager(
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
checkpoint: Checkpoint,
loop: LoopProtocol,
*,
skip_context: bool = False,
) -> Iterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]:
"""Manage channels for the lifetime of a Pregel invocation (multiple steps)."""
channel_specs: dict[str, BaseChannel] = {}
managed_specs: dict[str, ManagedValueSpec] = {}
for k, v in specs.items():
if isinstance(v, BaseChannel):
channel_specs[k] = v
elif (
skip_context and isinstance(v, ConfiguredManagedValue) and v.cls is Context
):
managed_specs[k] = Context.of(noop_context)
else:
managed_specs[k] = v
with ExitStack() as stack:
yield (
{
k: v.from_checkpoint(checkpoint["channel_values"].get(k))
for k, v in channel_specs.items()
},
ManagedValueMapping(
{
key: stack.enter_context(
value.cls.enter(loop, **value.kwargs)
if isinstance(value, ConfiguredManagedValue)
else value.enter(loop)
)
for key, value in managed_specs.items()
}
),
)
@asynccontextmanager
async def AsyncChannelsManager(
specs: Mapping[str, Union[BaseChannel, ManagedValueSpec]],
checkpoint: Checkpoint,
loop: LoopProtocol,
*,
skip_context: bool = False,
) -> AsyncIterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]:
"""Manage channels for the lifetime of a Pregel invocation (multiple steps)."""
channel_specs: dict[str, BaseChannel] = {}
managed_specs: dict[str, ManagedValueSpec] = {}
for k, v in specs.items():
if isinstance(v, BaseChannel):
channel_specs[k] = v
elif (
skip_context and isinstance(v, ConfiguredManagedValue) and v.cls is Context
):
managed_specs[k] = Context.of(noop_context)
else:
managed_specs[k] = v
async with AsyncExitStack() as stack:
# managed: create enter tasks with reference to spec, await them
if tasks := {
asyncio.create_task(
stack.enter_async_context(
value.cls.aenter(loop, **value.kwargs)
if isinstance(value, ConfiguredManagedValue)
else value.aenter(loop)
)
): key
for key, value in managed_specs.items()
}:
done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
else:
done = set()
yield (
# channels: enter each channel with checkpoint
{
k: v.from_checkpoint(checkpoint["channel_values"].get(k))
for k, v in channel_specs.items()
},
# managed: build mapping from spec to result
ManagedValueMapping({tasks[task]: task.result() for task in done}),
)
@contextmanager
def noop_context() -> Iterator[None]:
yield None
```
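`AsyncChannelsManager` above enters all managed-value context managers concurrently by keying each asyncio task by its spec name, awaiting them together, and then rebuilding a name-to-result mapping from the finished tasks. A stand-alone sketch of that pattern, with invented names (`managed_value`, `user_id`, `session`):
```py
# Invented context manager and spec names, purely to illustrate the pattern.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator


@asynccontextmanager
async def managed_value(name: str) -> AsyncIterator[str]:
    await asyncio.sleep(0)  # stand-in for real setup work
    yield f"value-for-{name}"


async def enter_all() -> dict[str, Any]:
    specs = {"user_id": managed_value("user_id"), "session": managed_value("session")}
    async with AsyncExitStack() as stack:
        # key each enter-task by its spec name, await them together,
        # then map the finished tasks back to their names
        tasks = {
            asyncio.create_task(stack.enter_async_context(cm)): key
            for key, cm in specs.items()
        }
        done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
        return {tasks[task]: task.result() for task in done}


print(asyncio.run(enter_all()))
# e.g. {'user_id': 'value-for-user_id', 'session': 'value-for-session'} (key order may vary)
```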
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/executor.py`:
```py
import asyncio
import concurrent.futures
import sys
from contextlib import ExitStack
from contextvars import copy_context
from types import TracebackType
from typing import (
AsyncContextManager,
Awaitable,
Callable,
ContextManager,
Coroutine,
Optional,
Protocol,
TypeVar,
cast,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from typing_extensions import ParamSpec
from langgraph.errors import GraphBubbleUp
P = ParamSpec("P")
T = TypeVar("T")
class Submit(Protocol[P, T]):
def __call__(
self,
fn: Callable[P, T],
*args: P.args,
__name__: Optional[str] = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]: ...
class BackgroundExecutor(ContextManager):
"""A context manager that runs sync tasks in the background.
Uses a thread pool executor to delegate tasks to separate threads.
On exit,
- cancels any (not yet started) tasks with `__cancel_on_exit__=True`
- waits for all tasks to finish
- re-raises the first exception from tasks with `__reraise_on_exit__=True`"""
def __init__(self, config: RunnableConfig) -> None:
self.stack = ExitStack()
self.executor = self.stack.enter_context(get_executor_for_config(config))
self.tasks: dict[concurrent.futures.Future, tuple[bool, bool]] = {}
def submit( # type: ignore[valid-type]
self,
fn: Callable[P, T],
*args: P.args,
__name__: Optional[str] = None, # currently not used in sync version
__cancel_on_exit__: bool = False, # for sync, can cancel only if not started
__reraise_on_exit__: bool = True,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]:
task = self.executor.submit(fn, *args, **kwargs)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
task.add_done_callback(self.done)
return task
def done(self, task: concurrent.futures.Future) -> None:
try:
task.result()
except GraphBubbleUp:
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
self.tasks.pop(task)
except BaseException:
pass
else:
self.tasks.pop(task)
def __enter__(self) -> Submit:
return self.submit
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
# copy the tasks as done() callback may modify the dict
tasks = self.tasks.copy()
# cancel all tasks that should be cancelled
for task, (cancel, _) in tasks.items():
if cancel:
task.cancel()
# wait for all tasks to finish
if pending := {t for t in tasks if not t.done()}:
concurrent.futures.wait(pending)
# shutdown the executor
self.stack.__exit__(exc_type, exc_value, traceback)
# re-raise the first exception that occurred in a task
if exc_type is None:
# if there's already an exception being raised, don't raise another one
for task, (_, reraise) in tasks.items():
if not reraise:
continue
try:
task.result()
except concurrent.futures.CancelledError:
pass
class AsyncBackgroundExecutor(AsyncContextManager):
"""A context manager that runs async tasks in the background.
Uses the current event loop to delegate tasks to asyncio tasks.
On exit,
- cancels any tasks with `__cancel_on_exit__=True`
- waits for all tasks to finish
- re-raises the first exception from tasks with `__reraise_on_exit__=True`
ignoring CancelledError"""
def __init__(self, config: RunnableConfig) -> None:
self.context_not_supported = sys.version_info < (3, 11)
self.tasks: dict[asyncio.Task, tuple[bool, bool]] = {}
self.sentinel = object()
self.loop = asyncio.get_running_loop()
if max_concurrency := config.get("max_concurrency"):
self.semaphore: Optional[asyncio.Semaphore] = asyncio.Semaphore(
max_concurrency
)
else:
self.semaphore = None
def submit( # type: ignore[valid-type]
self,
fn: Callable[P, Awaitable[T]],
*args: P.args,
__name__: Optional[str] = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
**kwargs: P.kwargs,
) -> asyncio.Task[T]:
coro = cast(Coroutine[None, None, T], fn(*args, **kwargs))
if self.semaphore:
coro = gated(self.semaphore, coro)
if self.context_not_supported:
task = self.loop.create_task(coro, name=__name__)
else:
task = self.loop.create_task(coro, name=__name__, context=copy_context())
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
task.add_done_callback(self.done)
return task
def done(self, task: asyncio.Task) -> None:
try:
if exc := task.exception():
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
if isinstance(exc, GraphBubbleUp):
self.tasks.pop(task)
else:
self.tasks.pop(task)
except asyncio.CancelledError:
self.tasks.pop(task)
async def __aenter__(self) -> Submit:
return self.submit
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
# copy the tasks as done() callback may modify the dict
tasks = self.tasks.copy()
# cancel all tasks that should be cancelled
for task, (cancel, _) in tasks.items():
if cancel:
task.cancel(self.sentinel)
# wait for all tasks to finish
if tasks:
await asyncio.wait(tasks)
# if there's already an exception being raised, don't raise another one
if exc_type is None:
# re-raise the first exception that occurred in a task
for task, (_, reraise) in tasks.items():
if not reraise:
continue
try:
if exc := task.exception():
raise exc
except asyncio.CancelledError:
pass
async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[None, None, T]) -> T:
"""A coroutine that waits for a semaphore before running another coroutine."""
async with semaphore:
return await coro
```
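The `gated` helper above is how `AsyncBackgroundExecutor` honours `max_concurrency`: each submitted coroutine must hold a semaphore permit while it runs. A self-contained demo of the same pattern, with an invented `work` coroutine and counters:
```py
# Self-contained demo; the toy work() coroutine and counters are invented here.
import asyncio

running = 0   # how many work() coroutines are in flight right now
peak = 0      # the highest concurrency observed


async def work(i: int) -> int:
    global running, peak
    running += 1
    peak = max(peak, running)
    await asyncio.sleep(0.01)
    running -= 1
    return i


async def gated(semaphore: asyncio.Semaphore, coro):
    # same shape as the helper above: hold a permit for the coroutine's lifetime
    async with semaphore:
        return await coro


async def main() -> None:
    semaphore = asyncio.Semaphore(2)  # as if max_concurrency=2 were configured
    results = await asyncio.gather(*(gated(semaphore, work(i)) for i in range(6)))
    assert results == list(range(6))  # gather preserves submission order
    assert peak <= 2                  # never more than two running at once


asyncio.run(main())
```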
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/pregel/read.py`:
```py
from __future__ import annotations
from functools import cached_property
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
Mapping,
Optional,
Sequence,
Union,
)
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnablePassthrough,
RunnableSerializable,
)
from langchain_core.runnables.base import Input, Other, coerce_to_runnable
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langgraph.constants import CONF, CONFIG_KEY_READ
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.write import ChannelWrite
from langgraph.utils.config import merge_configs
from langgraph.utils.runnable import RunnableCallable, RunnableSeq
READ_TYPE = Callable[[Union[str, Sequence[str]], bool], Union[Any, dict[str, Any]]]
class ChannelRead(RunnableCallable):
"""Implements the logic for reading state from CONFIG_KEY_READ.
Usable both as a runnable as well as a static method to call imperatively."""
channel: Union[str, list[str]]
fresh: bool = False
mapper: Optional[Callable[[Any], Any]] = None
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return [
ConfigurableFieldSpec(
id=CONFIG_KEY_READ,
name=CONFIG_KEY_READ,
description=None,
default=None,
annotation=None,
),
]
def __init__(
self,
channel: Union[str, list[str]],
*,
fresh: bool = False,
mapper: Optional[Callable[[Any], Any]] = None,
tags: Optional[list[str]] = None,
) -> None:
super().__init__(func=self._read, afunc=self._aread, tags=tags, name=None)
self.fresh = fresh
self.mapper = mapper
self.channel = channel
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
if name:
pass
elif isinstance(self.channel, str):
name = f"ChannelRead<{self.channel}>"
else:
name = f"ChannelRead<{','.join(self.channel)}>"
return super().get_name(suffix, name=name)
def _read(self, _: Any, config: RunnableConfig) -> Any:
return self.do_read(
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
)
async def _aread(self, _: Any, config: RunnableConfig) -> Any:
return self.do_read(
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
)
@staticmethod
def do_read(
config: RunnableConfig,
*,
select: Union[str, list[str]],
fresh: bool = False,
mapper: Optional[Callable[[Any], Any]] = None,
) -> Any:
try:
read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
except KeyError:
raise RuntimeError(
"Not configured with a read function"
"Make sure to call in the context of a Pregel process"
)
if mapper:
return mapper(read(select, fresh))
else:
return read(select, fresh)
DEFAULT_BOUND: RunnablePassthrough = RunnablePassthrough()
class PregelNode(Runnable):
"""A node in a Pregel graph. This won't be invoked as a runnable by the graph
itself, but instead acts as a container for the components necessary to make
a PregelExecutableTask for a node."""
channels: Union[list[str], Mapping[str, str]]
"""The channels that will be passed as input to `bound`.
If a list, the node will be invoked with the first of these channels that isn't empty.
If a dict, the keys are the names of the channels, and the values are the keys
to use in the input to `bound`."""
triggers: list[str]
"""If any of these channels is written to, this node will be triggered in
the next step."""
mapper: Optional[Callable[[Any], Any]]
"""A function to transform the input before passing it to `bound`."""
writers: list[Runnable]
"""A list of writers that will be executed after `bound`, responsible for
taking the output of `bound` and writing it to the appropriate channels."""
bound: Runnable[Any, Any]
"""The main logic of the node. This will be invoked with the input from
`channels`."""
retry_policy: Optional[RetryPolicy]
"""The retry policy to use when invoking the node."""
tags: Optional[Sequence[str]]
"""Tags to attach to the node for tracing."""
metadata: Optional[Mapping[str, Any]]
"""Metadata to attach to the node for tracing."""
def __init__(
self,
*,
channels: Union[list[str], Mapping[str, str]],
triggers: Sequence[str],
mapper: Optional[Callable[[Any], Any]] = None,
writers: Optional[list[Runnable]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
self.mapper = mapper
self.writers = writers or []
self.bound = bound if bound is not None else DEFAULT_BOUND
self.retry_policy = retry_policy
self.tags = tags
self.metadata = metadata
def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
return PregelNode(**attrs)
@cached_property
def flat_writers(self) -> list[Runnable]:
"""Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
writers = self.writers.copy()
while (
len(writers) > 1
and isinstance(writers[-1], ChannelWrite)
and isinstance(writers[-2], ChannelWrite)
):
# we can combine writes if they are consecutive
# careful to not modify the original writers list or ChannelWrite
writers[-2] = ChannelWrite(
writes=writers[-2].writes + writers[-1].writes,
tags=writers[-2].tags,
require_at_least_one_of=writers[-2].require_at_least_one_of,
)
writers.pop()
return writers
@cached_property
def node(self) -> Optional[Runnable[Any, Any]]:
"""Get a runnable that combines `bound` and `writers`."""
writers = self.flat_writers
if self.bound is DEFAULT_BOUND and not writers:
return None
elif self.bound is DEFAULT_BOUND and len(writers) == 1:
return writers[0]
elif self.bound is DEFAULT_BOUND:
return RunnableSeq(*writers)
elif writers:
return RunnableSeq(self.bound, *writers)
else:
return self.bound
def join(self, channels: Sequence[str]) -> PregelNode:
assert isinstance(channels, list) or isinstance(
channels, tuple
), "channels must be a list or tuple"
assert isinstance(
self.channels, dict
), "all channels must be named when using .join()"
return self.copy(
update=dict(
channels={
**self.channels,
**{chan: chan for chan in channels},
}
),
)
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Runnable[Any, Other] | Callable[[Any], Other]],
],
) -> PregelNode:
if isinstance(other, Runnable) and ChannelWrite.is_writer(other):
return self.copy(update=dict(writers=[*self.writers, other]))
elif self.bound is DEFAULT_BOUND:
return self.copy(update=dict(bound=coerce_to_runnable(other)))
else:
return self.copy(update=dict(bound=RunnableSeq(self.bound, other)))
def pipe(
self,
*others: Runnable[Any, Other] | Callable[[Any], Other],
name: Optional[str] = None,
) -> RunnableSerializable[Any, Other]:
for other in others:
self = self | other
return self
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSerializable:
raise NotImplementedError()
def invoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Any:
return self.bound.invoke(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Any:
return await self.bound.ainvoke(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
yield from self.bound.stream(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
async for item in self.bound.astream(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
):
yield item
```
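A minimal sketch (not part of the repository) of how these pieces compose, mirroring how `CompiledGraph.attach_node` below assembles nodes; the channel name `value` is illustrative, and the node only executes inside a Pregel loop that supplies the read/write functions via config:
```py
# Hypothetical composition sketch: subscribe a node to a channel, bind its
# logic, and append a ChannelWrite so the result is published back.
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry

node = (
    PregelNode(channels=["value"], triggers=["value"])  # read + trigger on "value"
    | (lambda x: x + 1)                                 # becomes `bound`
    | ChannelWrite([ChannelWriteEntry("value")])        # appended to `writers`
)
# node.node is the combined runnable: `bound` followed by the flattened writers
```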
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/__init__.py`:
```py
from langgraph.managed.is_last_step import IsLastStep, RemainingSteps
__all__ = ["IsLastStep", "RemainingSteps"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/shared_value.py`:
```py
import collections.abc
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Sequence,
Type,
)
from typing_extensions import NotRequired, Required, Self
from langgraph.constants import CONF
from langgraph.errors import InvalidUpdateError
from langgraph.managed.base import (
ChannelKeyPlaceholder,
ChannelTypePlaceholder,
ConfiguredManagedValue,
WritableManagedValue,
)
from langgraph.store.base import PutOp
from langgraph.types import LoopProtocol
V = dict[str, Any]
Value = dict[str, V]
Update = dict[str, Optional[V]]
# Adapted from typing_extensions
def _strip_extras(t): # type: ignore[no-untyped-def]
"""Strips Annotated, Required and NotRequired from a given type."""
if hasattr(t, "__origin__"):
return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
return _strip_extras(t.__args__[0])
return t
class SharedValue(WritableManagedValue[Value, Update]):
@staticmethod
def on(scope: str) -> ConfiguredManagedValue:
return ConfiguredManagedValue(
SharedValue,
{
"scope": scope,
"key": ChannelKeyPlaceholder,
"typ": ChannelTypePlaceholder,
},
)
@classmethod
@contextmanager
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
with super().enter(loop, **kwargs) as value:
if loop.store is not None:
saved = loop.store.search(value.ns)
value.value = {it.key: it.value for it in saved}
yield value
@classmethod
@asynccontextmanager
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(loop, **kwargs) as value:
if loop.store is not None:
saved = await loop.store.asearch(value.ns)
value.value = {it.key: it.value for it in saved}
yield value
def __init__(
self, loop: LoopProtocol, *, typ: Type[Any], scope: str, key: str
) -> None:
super().__init__(loop)
if typ := _strip_extras(typ):
if typ not in (
dict,
collections.abc.Mapping,
collections.abc.MutableMapping,
):
raise ValueError("SharedValue must be a dict")
self.scope = scope
self.value: Value = {}
if self.loop.store is None:
pass
elif scope_value := self.loop.config[CONF].get(self.scope):
self.ns = ("scoped", scope, key, scope_value)
else:
raise ValueError(
f"Scope {scope} for shared state key not in config.configurable"
)
def __call__(self) -> Value:
return self.value
def _process_update(self, values: Sequence[Update]) -> list[PutOp]:
writes: list[PutOp] = []
for vv in values:
for k, v in vv.items():
if v is None:
if k in self.value:
del self.value[k]
writes.append(PutOp(self.ns, k, None))
elif not isinstance(v, dict):
raise InvalidUpdateError("Received a non-dict value")
else:
self.value[k] = v
writes.append(PutOp(self.ns, k, v))
return writes
def update(self, values: Sequence[Update]) -> None:
if self.loop.store is None:
self._process_update(values)
else:
return self.loop.store.batch(self._process_update(values))
async def aupdate(self, writes: Sequence[Update]) -> None:
if self.loop.store is None:
self._process_update(writes)
else:
return await self.loop.store.abatch(self._process_update(writes))
```
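A sketch of how `SharedValue` appears to be attached to a state schema via `Annotated` metadata (based on the managed-value handling in `graph/state.py` below); the field name and scope key are illustrative assumptions, the graph must be compiled with a store, and the scope key must be present in `config["configurable"]` at runtime:
```py
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.managed.shared_value import SharedValue

class State(TypedDict):
    messages: list
    # a dict shared by every thread invoked with the same "user_id" in
    # config["configurable"]; persisted through the graph's BaseStore
    user_info: Annotated[dict, SharedValue.on("user_id")]
```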
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/context.py`:
```py
from contextlib import asynccontextmanager, contextmanager
from inspect import signature
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
ContextManager,
Generic,
Iterator,
Optional,
Type,
Union,
)
from typing_extensions import Self
from langgraph.managed.base import ConfiguredManagedValue, ManagedValue, V
from langgraph.types import LoopProtocol
class Context(ManagedValue[V], Generic[V]):
runtime = True
value: V
@staticmethod
def of(
ctx: Union[
None,
Callable[..., ContextManager[V]],
Type[ContextManager[V]],
Callable[..., AsyncContextManager[V]],
Type[AsyncContextManager[V]],
] = None,
actx: Optional[
Union[
Callable[..., AsyncContextManager[V]],
Type[AsyncContextManager[V]],
]
] = None,
) -> ConfiguredManagedValue:
if ctx is None and actx is None:
raise ValueError("Must provide either sync or async context manager.")
return ConfiguredManagedValue(Context, {"ctx": ctx, "actx": actx})
@classmethod
@contextmanager
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
with super().enter(loop, **kwargs) as self:
if self.ctx is None:
raise ValueError(
"Synchronous context manager not found. Please initialize Context value with a sync context manager, or invoke your graph asynchronously."
)
ctx = (
self.ctx(loop.config) # type: ignore[call-arg]
if signature(self.ctx).parameters.get("config")
else self.ctx()
)
with ctx as v: # type: ignore[union-attr]
self.value = v
yield self
@classmethod
@asynccontextmanager
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
async with super().aenter(loop, **kwargs) as self:
if self.actx is not None:
ctx = (
self.actx(loop.config) # type: ignore[call-arg]
if signature(self.actx).parameters.get("config")
else self.actx()
)
elif self.ctx is not None:
ctx = (
self.ctx(loop.config) # type: ignore
if signature(self.ctx).parameters.get("config")
else self.ctx()
)
else:
raise ValueError(
"Asynchronous context manager not found. Please initialize Context value with an async context manager, or invoke your graph synchronously."
)
if hasattr(ctx, "__aenter__"):
async with ctx as v:
self.value = v
yield self
elif hasattr(ctx, "__enter__") and hasattr(ctx, "__exit__"):
with ctx as v:
self.value = v
yield self
else:
raise ValueError(
"Context manager must have either __enter__ or __aenter__ method."
)
def __init__(
self,
loop: LoopProtocol,
*,
ctx: Union[None, Type[ContextManager[V]], Type[AsyncContextManager[V]]] = None,
actx: Optional[Type[AsyncContextManager[V]]] = None,
) -> None:
self.ctx = ctx
self.actx = actx
def __call__(self) -> V:
return self.value
```
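A sketch of supplying a per-run resource through `Context.of`; `make_client` is a made-up context manager standing in for e.g. an HTTP client or DB session:
```py
from contextlib import contextmanager
from typing import Annotated, Iterator
from typing_extensions import TypedDict
from langgraph.managed.context import Context

@contextmanager
def make_client() -> Iterator[object]:
    client = object()  # stand-in for a real client; opened once per run
    try:
        yield client
    finally:
        pass  # close/cleanup here

class State(TypedDict):
    messages: list
    # entered when the graph starts, exited when it finishes
    client: Annotated[object, Context.of(make_client)]
```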
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/is_last_step.py`:
```py
from typing import Annotated
from langgraph.managed.base import ManagedValue
class IsLastStepManager(ManagedValue[bool]):
def __call__(self) -> bool:
return self.loop.step == self.loop.stop - 1
IsLastStep = Annotated[bool, IsLastStepManager]
class RemainingStepsManager(ManagedValue[int]):
def __call__(self) -> int:
return self.loop.stop - self.loop.step
RemainingSteps = Annotated[int, RemainingStepsManager]
```
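These annotated types are used directly as state fields; their values are computed from the loop's step counters rather than written by nodes. For example (field names are illustrative):
```py
from typing_extensions import TypedDict
from langgraph.managed import IsLastStep, RemainingSteps

class AgentState(TypedDict):
    messages: list
    is_last_step: IsLastStep         # True only on the step right before the recursion limit
    remaining_steps: RemainingSteps  # steps left before the recursion limit is hit
```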
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/managed/base.py`:
```py
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from inspect import isclass
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
NamedTuple,
Sequence,
Type,
TypeVar,
Union,
)
from typing_extensions import Self, TypeGuard
from langgraph.types import LoopProtocol
V = TypeVar("V")
U = TypeVar("U")
class ManagedValue(ABC, Generic[V]):
def __init__(self, loop: LoopProtocol) -> None:
self.loop = loop
@classmethod
@contextmanager
def enter(cls, loop: LoopProtocol, **kwargs: Any) -> Iterator[Self]:
try:
value = cls(loop, **kwargs)
yield value
finally:
# because managed value and Pregel have reference to each other
# let's make sure to break the reference on exit
try:
del value
except UnboundLocalError:
pass
@classmethod
@asynccontextmanager
async def aenter(cls, loop: LoopProtocol, **kwargs: Any) -> AsyncIterator[Self]:
try:
value = cls(loop, **kwargs)
yield value
finally:
# because managed value and Pregel have reference to each other
# let's make sure to break the reference on exit
try:
del value
except UnboundLocalError:
pass
@abstractmethod
def __call__(self) -> V: ...
class WritableManagedValue(Generic[V, U], ManagedValue[V], ABC):
@abstractmethod
def update(self, writes: Sequence[U]) -> None: ...
@abstractmethod
async def aupdate(self, writes: Sequence[U]) -> None: ...
class ConfiguredManagedValue(NamedTuple):
cls: Type[ManagedValue]
kwargs: dict[str, Any]
ManagedValueSpec = Union[Type[ManagedValue], ConfiguredManagedValue]
def is_managed_value(value: Any) -> TypeGuard[ManagedValueSpec]:
return (isclass(value) and issubclass(value, ManagedValue)) or isinstance(
value, ConfiguredManagedValue
)
def is_readonly_managed_value(value: Any) -> TypeGuard[Type[ManagedValue]]:
return (
isclass(value)
and issubclass(value, ManagedValue)
and not issubclass(value, WritableManagedValue)
) or (
isinstance(value, ConfiguredManagedValue)
and not issubclass(value.cls, WritableManagedValue)
)
def is_writable_managed_value(value: Any) -> TypeGuard[Type[WritableManagedValue]]:
return (isclass(value) and issubclass(value, WritableManagedValue)) or (
isinstance(value, ConfiguredManagedValue)
and issubclass(value.cls, WritableManagedValue)
)
ChannelKeyPlaceholder = object()
ChannelTypePlaceholder = object()
ManagedValueMapping = dict[str, ManagedValue]
```
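A minimal sketch of defining a custom read-only managed value against this base class; `CurrentStep` is hypothetical and mirrors `IsLastStepManager` above:
```py
from langgraph.managed.base import ManagedValue, is_managed_value, is_writable_managed_value

class CurrentStep(ManagedValue[int]):
    """Hypothetical read-only managed value exposing the current loop step."""

    def __call__(self) -> int:
        return self.loop.step

assert is_managed_value(CurrentStep)               # it is a managed value spec
assert not is_writable_managed_value(CurrentStep)  # but not a writable one
```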
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/version.py`:
```py
"""Exports package version."""
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/constants.py`:
```py
import sys
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast
from langgraph.types import Interrupt, Send # noqa: F401
# Interrupt, Send re-exported for backwards compatibility
# --- Empty read-only containers ---
EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})
EMPTY_SEQ: tuple[str, ...] = tuple()
MISSING = object()
# --- Public constants ---
TAG_NOSTREAM = sys.intern("langsmith:nostream")
"""Tag to disable streaming for a chat model."""
TAG_HIDDEN = sys.intern("langsmith:hidden")
"""Tag to hide a node/edge from certain tracing/streaming environments."""
START = sys.intern("__start__")
"""The first (maybe virtual) node in graph-style Pregel."""
END = sys.intern("__end__")
"""The last (maybe virtual) node in graph-style Pregel."""
SELF = sys.intern("__self__")
"""The implicit branch that handles each node's Control values."""
# --- Reserved write keys ---
INPUT = sys.intern("__input__")
# for values passed as input to the graph
INTERRUPT = sys.intern("__interrupt__")
# for dynamic interrupts raised by nodes
RESUME = sys.intern("__resume__")
# for values passed to resume a node after an interrupt
ERROR = sys.intern("__error__")
# for errors raised by nodes
NO_WRITES = sys.intern("__no_writes__")
# marker to signal node didn't write anything
SCHEDULED = sys.intern("__scheduled__")
# marker to signal node was scheduled (in distributed mode)
TASKS = sys.intern("__pregel_tasks")
# for Send objects returned by nodes/edges, corresponds to PUSH below
# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = sys.intern("__pregel_send")
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = sys.intern("__pregel_read")
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
# holds a `StreamProtocol` passed from parent graph to child graphs
CONFIG_KEY_STREAM_WRITER = sys.intern("__pregel_stream_writer")
# holds a `StreamWriter` for stream_mode=custom
CONFIG_KEY_STORE = sys.intern("__pregel_store")
# holds a `BaseStore` made available to managed values
CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming")
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id")
# holds the task ID for the current task
CONFIG_KEY_DEDUPE_TASKS = sys.intern("__pregel_dedupe_tasks")
# holds a boolean indicating if tasks should be deduplicated (for distributed mode)
CONFIG_KEY_ENSURE_LATEST = sys.intern("__pregel_ensure_latest")
# holds a boolean indicating whether to assert the requested checkpoint is the latest
# (for distributed mode)
CONFIG_KEY_DELEGATE = sys.intern("__pregel_delegate")
# holds a boolean indicating whether to delegate subgraphs (for distributed mode)
CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map")
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id")
# holds the current checkpoint_id, if any
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
# holds the current checkpoint_ns, "" for root graph
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
# holds the value that "answers" an interrupt() call
CONFIG_KEY_WRITES = sys.intern("__pregel_writes")
# read-only list of existing task writes
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
# --- Other constants ---
PUSH = sys.intern("__pregel_push")
# denotes push-style tasks, ie. those created by Send objects
PULL = sys.intern("__pregel_pull")
# denotes pull-style tasks, ie. those triggered by edges
NS_SEP = sys.intern("|")
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = sys.intern(":")
# for checkpoint_ns, for each level, separates the namespace from the task_id
CONF = cast(Literal["configurable"], sys.intern("configurable"))
# key for the configurable dict in RunnableConfig
FF_SEND_V2 = getenv("LANGGRAPH_FF_SEND_V2", "false").lower() == "true"
# temporary flag to enable new Send semantics
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
# the task_id to use for writes that are not associated with a task
RESERVED = {
TAG_HIDDEN,
# reserved write keys
INPUT,
INTERRUPT,
RESUME,
ERROR,
NO_WRITES,
SCHEDULED,
TASKS,
# reserved config.configurable keys
CONFIG_KEY_SEND,
CONFIG_KEY_READ,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_STORE,
CONFIG_KEY_RESUMING,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
CONFIG_KEY_DELEGATE,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_NS,
# other constants
PUSH,
PULL,
NS_SEP,
NS_END,
CONF,
}
```
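An illustrative sketch of how the namespace constants compose into a `checkpoint_ns` value; the node names and task ids are made up:
```py
from langgraph.constants import CONF, CONFIG_KEY_CHECKPOINT_NS, NS_END, NS_SEP

# each level is "<node name><NS_END><task id>"; levels are joined with NS_SEP
levels = [("child", "1"), ("grandchild", "2")]
ns = NS_SEP.join(f"{node}{NS_END}{task_id}" for node, task_id in levels)
config = {CONF: {CONFIG_KEY_CHECKPOINT_NS: ns}}
# config == {"configurable": {"checkpoint_ns": "child:1|grandchild:2"}}
```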
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/graph.py`:
```py
import asyncio
import logging
from collections import defaultdict
from typing import (
Any,
Awaitable,
Callable,
Hashable,
Literal,
NamedTuple,
Optional,
Sequence,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
)
from langchain_core.runnables import Runnable
from langchain_core.runnables.base import RunnableLike
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from langchain_core.runnables.graph import Node as DrawableNode
from typing_extensions import Self
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.constants import (
EMPTY_SEQ,
END,
NS_END,
NS_SEP,
START,
TAG_HIDDEN,
Send,
)
from langgraph.errors import InvalidUpdateError
from langgraph.pregel import Channel, Pregel
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.types import All, Checkpointer
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
logger = logging.getLogger(__name__)
class NodeSpec(NamedTuple):
runnable: Runnable
metadata: Optional[dict[str, Any]] = None
ends: Optional[tuple[str, ...]] = EMPTY_SEQ
class Branch(NamedTuple):
path: Runnable[Any, Union[Hashable, list[Hashable]]]
ends: Optional[dict[Hashable, str]]
then: Optional[str] = None
def run(
self,
writer: Callable[
[Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite]
],
reader: Optional[Callable[[RunnableConfig], Any]] = None,
) -> RunnableCallable:
return ChannelWrite.register_writer(
RunnableCallable(
func=self._route,
afunc=self._aroute,
writer=writer,
reader=reader,
name=None,
trace=False,
)
)
def _route(
self,
input: Any,
config: RunnableConfig,
*,
reader: Optional[Callable[[RunnableConfig], Any]],
writer: Callable[
[Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite]
],
) -> Runnable:
if reader:
value = reader(config)
# passthrough additional keys from node to branch
# only doable when using dict states
if isinstance(value, dict) and isinstance(input, dict):
value = {**input, **value}
else:
value = input
result = self.path.invoke(value, config)
return self._finish(writer, input, result, config)
async def _aroute(
self,
input: Any,
config: RunnableConfig,
*,
reader: Optional[Callable[[RunnableConfig], Any]],
writer: Callable[
[Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite]
],
) -> Runnable:
if reader:
value = await asyncio.to_thread(reader, config)
# passthrough additional keys from node to branch
# only doable when using dict states
if isinstance(value, dict) and isinstance(input, dict):
value = {**input, **value}
else:
value = input
result = await self.path.ainvoke(value, config)
return self._finish(writer, input, result, config)
def _finish(
self,
writer: Callable[
[Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite]
],
input: Any,
result: Any,
config: RunnableConfig,
) -> Union[Runnable, Any]:
if not isinstance(result, (list, tuple)):
result = [result]
if self.ends:
destinations: Sequence[Union[Send, str]] = [
r if isinstance(r, Send) else self.ends[r] for r in result
]
else:
destinations = cast(Sequence[Union[Send, str]], result)
if any(dest is None or dest == START for dest in destinations):
raise ValueError("Branch did not return a valid destination")
if any(p.node == END for p in destinations if isinstance(p, Send)):
raise InvalidUpdateError("Cannot send a packet to the END node")
return writer(destinations, config) or input
class Graph:
def __init__(self) -> None:
self.nodes: dict[str, NodeSpec] = {}
self.edges = set[tuple[str, str]]()
self.branches: defaultdict[str, dict[str, Branch]] = defaultdict(dict)
self.support_multiple_edges = False
self.compiled = False
@property
def _all_edges(self) -> set[tuple[str, str]]:
return self.edges
@overload
def add_node(
self,
node: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
) -> Self: ...
@overload
def add_node(
self,
node: str,
action: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
) -> Self: ...
def add_node(
self,
node: Union[str, RunnableLike],
action: Optional[RunnableLike] = None,
*,
metadata: Optional[dict[str, Any]] = None,
) -> Self:
if isinstance(node, str):
for character in (NS_SEP, NS_END):
if character in node:
raise ValueError(
f"'{character}' is a reserved character and is not allowed in the node names."
)
if self.compiled:
logger.warning(
"Adding a node to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
if not isinstance(node, str):
action = node
node = getattr(action, "name", getattr(action, "__name__"))
if node is None:
raise ValueError(
"Node name must be provided if action is not a function"
)
if action is None:
raise RuntimeError(
"Expected a function or Runnable action in add_node. Received None."
)
if node in self.nodes:
raise ValueError(f"Node `{node}` already present.")
if node == END or node == START:
raise ValueError(f"Node `{node}` is reserved.")
self.nodes[cast(str, node)] = NodeSpec(
coerce_to_runnable(action, name=cast(str, node), trace=False), metadata
)
return self
def add_edge(self, start_key: str, end_key: str) -> Self:
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
if start_key == END:
raise ValueError("END cannot be a start node")
if end_key == START:
raise ValueError("START cannot be an end node")
# run this validation only for non-StateGraph graphs
if not hasattr(self, "channels") and start_key in set(
start for start, _ in self.edges
):
raise ValueError(
f"Already found path for node '{start_key}'.\n"
"For multiple edges, use StateGraph with an Annotated state key."
)
self.edges.add((start_key, end_key))
return self
def add_conditional_edges(
self,
source: str,
path: Union[
Callable[..., Union[Hashable, list[Hashable]]],
Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
Runnable[Any, Union[Hashable, list[Hashable]]],
],
path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
then: Optional[str] = None,
) -> Self:
"""Add a conditional edge from the starting node to any number of destination nodes.
Args:
source (str): The starting node. This conditional edge will run when
exiting this node.
path (Union[Callable, Runnable]): The callable that determines the next
node or nodes. If not specifying `path_map` it should return one or
more nodes. If it returns END, the graph will stop execution.
path_map (Optional[dict[Hashable, str]]): Optional mapping of paths to node
names. If omitted the paths returned by `path` should be node names.
then (Optional[str]): The name of a node to execute after the nodes
selected by `path`.
Returns:
            Self
Note: Without typehints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`)
or a path_map, the graph visualization assumes the edge could transition to any node in the graph.
""" # noqa: E501
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
# coerce path_map to a dictionary
try:
if isinstance(path_map, dict):
path_map_ = path_map.copy()
elif isinstance(path_map, list):
path_map_ = {name: name for name in path_map}
elif isinstance(path, Runnable):
path_map_ = None
elif rtn_type := get_type_hints(path.__call__).get( # type: ignore[operator]
"return"
) or get_type_hints(path).get("return"):
if get_origin(rtn_type) is Literal:
path_map_ = {name: name for name in get_args(rtn_type)}
else:
path_map_ = None
else:
path_map_ = None
except Exception:
path_map_ = None
# find a name for the condition
path = coerce_to_runnable(path, name=None, trace=True)
name = path.name or "condition"
# validate the condition
if name in self.branches[source]:
raise ValueError(
f"Branch with name `{path.name}` already exists for node " f"`{source}`"
)
# save it
self.branches[source][name] = Branch(path, path_map_, then)
return self
def set_entry_point(self, key: str) -> Self:
"""Specifies the first node to be called in the graph.
Equivalent to calling `add_edge(START, key)`.
Parameters:
key (str): The key of the node to set as the entry point.
Returns:
            Self
"""
return self.add_edge(START, key)
def set_conditional_entry_point(
self,
path: Union[
Callable[..., Union[Hashable, list[Hashable]]],
Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
Runnable[Any, Union[Hashable, list[Hashable]]],
],
path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
then: Optional[str] = None,
) -> Self:
"""Sets a conditional entry point in the graph.
Args:
path (Union[Callable, Runnable]): The callable that determines the next
node or nodes. If not specifying `path_map` it should return one or
more nodes. If it returns END, the graph will stop execution.
path_map (Optional[dict[str, str]]): Optional mapping of paths to node
names. If omitted the paths returned by `path` should be node names.
then (Optional[str]): The name of a node to execute after the nodes
selected by `path`.
Returns:
            Self
"""
return self.add_conditional_edges(START, path, path_map, then)
def set_finish_point(self, key: str) -> Self:
"""Marks a node as a finish point of the graph.
If the graph reaches this node, it will cease execution.
Parameters:
key (str): The key of the node to set as the finish point.
Returns:
            Self
"""
return self.add_edge(key, END)
def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self:
# assemble sources
all_sources = {src for src, _ in self._all_edges}
for start, branches in self.branches.items():
all_sources.add(start)
for cond, branch in branches.items():
if branch.then is not None:
if branch.ends is not None:
for end in branch.ends.values():
if end != END:
all_sources.add(end)
else:
for node in self.nodes:
if node != start and node != branch.then:
all_sources.add(node)
for name, spec in self.nodes.items():
if spec.ends:
all_sources.add(name)
# validate sources
for source in all_sources:
if source not in self.nodes and source != START:
raise ValueError(f"Found edge starting at unknown node '{source}'")
if START not in all_sources:
raise ValueError(
"Graph must have an entrypoint: add at least one edge from START to another node"
)
# assemble targets
all_targets = {end for _, end in self._all_edges}
for start, branches in self.branches.items():
for cond, branch in branches.items():
if branch.then is not None:
all_targets.add(branch.then)
if branch.ends is not None:
for end in branch.ends.values():
if end not in self.nodes and end != END:
raise ValueError(
f"At '{start}' node, '{cond}' branch found unknown target '{end}'"
)
all_targets.add(end)
else:
all_targets.add(END)
for node in self.nodes:
if node != start and node != branch.then:
all_targets.add(node)
for name, spec in self.nodes.items():
if spec.ends:
all_targets.update(spec.ends)
for target in all_targets:
if target not in self.nodes and target != END:
raise ValueError(f"Found edge ending at unknown node `{target}`")
# validate interrupts
if interrupt:
for node in interrupt:
if node not in self.nodes:
raise ValueError(f"Interrupt node `{node}` not found")
self.compiled = True
return self
def compile(
self,
checkpointer: Checkpointer = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
debug: bool = False,
) -> "CompiledGraph":
# assign default values
interrupt_before = interrupt_before or []
interrupt_after = interrupt_after or []
# validate the graph
self.validate(
interrupt=(
(interrupt_before if interrupt_before != "*" else []) + interrupt_after
if interrupt_after != "*"
else []
)
)
# create empty compiled graph
compiled = CompiledGraph(
builder=self,
nodes={},
channels={START: EphemeralValue(Any), END: EphemeralValue(Any)},
input_channels=START,
output_channels=END,
stream_mode="values",
stream_channels=[],
checkpointer=checkpointer,
interrupt_before_nodes=interrupt_before,
interrupt_after_nodes=interrupt_after,
auto_validate=False,
debug=debug,
)
# attach nodes, edges, and branches
for key, node in self.nodes.items():
compiled.attach_node(key, node)
for start, end in self.edges:
compiled.attach_edge(start, end)
for start, branches in self.branches.items():
for name, branch in branches.items():
compiled.attach_branch(start, name, branch)
# validate the compiled graph
return compiled.validate()
class CompiledGraph(Pregel):
builder: Graph
def __init__(self, *, builder: Graph, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.builder = builder
def attach_node(self, key: str, node: NodeSpec) -> None:
self.channels[key] = EphemeralValue(Any)
self.nodes[key] = (
PregelNode(channels=[], triggers=[], metadata=node.metadata)
| node.runnable
| ChannelWrite([ChannelWriteEntry(key)], tags=[TAG_HIDDEN])
)
cast(list[str], self.stream_channels).append(key)
def attach_edge(self, start: str, end: str) -> None:
if end == END:
# publish to end channel
self.nodes[start].writers.append(
ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])
)
else:
# subscribe to start channel
self.nodes[end].triggers.append(start)
cast(list[str], self.nodes[end].channels).append(start)
def attach_branch(self, start: str, name: str, branch: Branch) -> None:
def branch_writer(
packets: Sequence[Union[str, Send]], config: RunnableConfig
) -> Optional[ChannelWrite]:
writes = [
(
ChannelWriteEntry(f"branch:{start}:{name}:{p}" if p != END else END)
if not isinstance(p, Send)
else p
)
for p in packets
]
return ChannelWrite(
cast(Sequence[Union[ChannelWriteEntry, Send]], writes),
tags=[TAG_HIDDEN],
)
# add hidden start node
if start == START and start not in self.nodes:
self.nodes[start] = Channel.subscribe_to(START, tags=[TAG_HIDDEN])
# attach branch writer
self.nodes[start] |= branch.run(branch_writer)
# attach branch readers
ends = branch.ends.values() if branch.ends else [node for node in self.nodes]
for end in ends:
if end != END:
channel_name = f"branch:{start}:{name}:{end}"
self.channels[channel_name] = EphemeralValue(Any)
self.nodes[end].triggers.append(channel_name)
cast(list[str], self.nodes[end].channels).append(channel_name)
async def aget_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
return self.get_graph(config, xray=xray)
def get_graph(
self,
config: Optional[RunnableConfig] = None,
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
"""Returns a drawable representation of the computation graph."""
graph = DrawableGraph()
start_nodes: dict[str, DrawableNode] = {
START: graph.add_node(self.get_input_schema(config), START)
}
end_nodes: dict[str, DrawableNode] = {}
if xray:
subgraphs = {
k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph)
}
else:
subgraphs = {}
def add_edge(
start: str,
end: str,
label: Optional[Hashable] = None,
conditional: bool = False,
) -> None:
if end == END and END not in end_nodes:
end_nodes[END] = graph.add_node(self.get_output_schema(config), END)
return graph.add_edge(
start_nodes[start],
end_nodes[end],
str(label) if label is not None else None,
conditional,
)
for key, n in self.builder.nodes.items():
node = n.runnable
metadata = n.metadata or {}
if key in self.interrupt_before_nodes and key in self.interrupt_after_nodes:
metadata["__interrupt"] = "before,after"
elif key in self.interrupt_before_nodes:
metadata["__interrupt"] = "before"
elif key in self.interrupt_after_nodes:
metadata["__interrupt"] = "after"
if xray and key in subgraphs:
subgraph = subgraphs[key].get_graph(
config=config,
xray=xray - 1
if isinstance(xray, int) and not isinstance(xray, bool) and xray > 0
else xray,
)
subgraph.trim_first_node()
subgraph.trim_last_node()
if len(subgraph.nodes) > 1:
e, s = graph.extend(subgraph, prefix=key)
if e is None:
raise ValueError(
f"Could not extend subgraph '{key}' due to missing entrypoint"
)
if s is not None:
start_nodes[key] = s
end_nodes[key] = e
else:
nn = graph.add_node(node, key, metadata=metadata or None)
start_nodes[key] = nn
end_nodes[key] = nn
else:
nn = graph.add_node(node, key, metadata=metadata or None)
start_nodes[key] = nn
end_nodes[key] = nn
for start, end in sorted(self.builder._all_edges):
add_edge(start, end)
for start, branches in self.builder.branches.items():
default_ends = {
**{k: k for k in self.builder.nodes if k != start},
END: END,
}
for _, branch in branches.items():
if branch.ends is not None:
ends = branch.ends
elif branch.then is not None:
ends = {k: k for k in default_ends if k not in (END, branch.then)}
else:
ends = cast(dict[Hashable, str], default_ends)
for label, end in ends.items():
add_edge(
start,
end,
label if label != end else None,
conditional=True,
)
if branch.then is not None:
add_edge(end, branch.then)
for key, n in self.builder.nodes.items():
if n.ends:
for end in n.ends:
add_edge(key, end, conditional=True)
return graph
```
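A minimal sketch of the plain `Graph` builder defined above, where each node receives the previous node's raw output; the node names and routing function are illustrative:
```py
from langgraph.graph import END, START, Graph

def shout(text: str) -> str:
    return text.upper()

def route(output: str) -> str:
    return "done" if output else "retry"

builder = Graph()
builder.add_node("shout", shout)
builder.add_edge(START, "shout")
builder.add_conditional_edges("shout", route, {"done": END, "retry": "shout"})
graph = builder.compile()
graph.invoke("hello")  # expected to return 'HELLO'
```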
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/__init__.py`:
```py
from langgraph.graph.graph import END, START, Graph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.graph.state import StateGraph
__all__ = [
"END",
"START",
"Graph",
"StateGraph",
"MessageGraph",
"add_messages",
"MessagesState",
]
```
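A minimal sketch tying these exports together: a single-node `StateGraph` over `MessagesState`, where the node echoes a canned reply instead of calling a model:
```py
from langgraph.graph import END, START, MessagesState, StateGraph

def chatbot(state: MessagesState) -> dict:
    # a real node would call a chat model here
    return {"messages": [("assistant", "Hello!")]}

builder = StateGraph(MessagesState)
builder.add_node("chatbot", chatbot)
builder.add_edge(START, "chatbot")
builder.add_edge("chatbot", END)
graph = builder.compile()
graph.invoke({"messages": [("user", "Hi")]})
# roughly: {'messages': [HumanMessage('Hi'), AIMessage('Hello!')]}
```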
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/message.py`:
```py
import uuid
from typing import Annotated, TypedDict, Union, cast
from langchain_core.messages import (
AnyMessage,
BaseMessageChunk,
MessageLikeRepresentation,
RemoveMessage,
convert_to_messages,
message_chunk_to_message,
)
from langgraph.graph.state import StateGraph
Messages = Union[list[MessageLikeRepresentation], MessageLikeRepresentation]
def add_messages(left: Messages, right: Messages) -> Messages:
"""Merges two lists of messages, updating existing messages by ID.
By default, this ensures the state is "append-only", unless the
new message has the same ID as an existing message.
Args:
left: The base list of messages.
right: The list of messages (or single message) to merge
into the base list.
Returns:
A new list of messages with the messages from `right` merged into `left`.
If a message in `right` has the same ID as a message in `left`, the
message from `right` will replace the message from `left`.
Examples:
```pycon
>>> from langchain_core.messages import AIMessage, HumanMessage
>>> msgs1 = [HumanMessage(content="Hello", id="1")]
>>> msgs2 = [AIMessage(content="Hi there!", id="2")]
>>> add_messages(msgs1, msgs2)
[HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
>>> msgs1 = [HumanMessage(content="Hello", id="1")]
>>> msgs2 = [HumanMessage(content="Hello again", id="1")]
>>> add_messages(msgs1, msgs2)
[HumanMessage(content='Hello again', id='1')]
>>> from typing import Annotated
>>> from typing_extensions import TypedDict
>>> from langgraph.graph import StateGraph
>>>
>>> class State(TypedDict):
... messages: Annotated[list, add_messages]
...
>>> builder = StateGraph(State)
>>> builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> graph = builder.compile()
>>> graph.invoke({})
{'messages': [AIMessage(content='Hello', id=...)]}
```
"""
# coerce to list
if not isinstance(left, list):
left = [left] # type: ignore[assignment]
if not isinstance(right, list):
right = [right] # type: ignore[assignment]
# coerce to message
left = [
message_chunk_to_message(cast(BaseMessageChunk, m))
for m in convert_to_messages(left)
]
right = [
message_chunk_to_message(cast(BaseMessageChunk, m))
for m in convert_to_messages(right)
]
# assign missing ids
for m in left:
if m.id is None:
m.id = str(uuid.uuid4())
for m in right:
if m.id is None:
m.id = str(uuid.uuid4())
# merge
left_idx_by_id = {m.id: i for i, m in enumerate(left)}
merged = left.copy()
ids_to_remove = set()
for m in right:
if (existing_idx := left_idx_by_id.get(m.id)) is not None:
if isinstance(m, RemoveMessage):
ids_to_remove.add(m.id)
else:
merged[existing_idx] = m
else:
if isinstance(m, RemoveMessage):
raise ValueError(
f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
)
merged.append(m)
merged = [m for m in merged if m.id not in ids_to_remove]
return merged
class MessageGraph(StateGraph):
"""A StateGraph where every node receives a list of messages as input and returns one or more messages as output.
MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages.
Each node in a MessageGraph takes a list of messages as input and returns zero or more
messages as output. The `add_messages` function is used to merge the output messages from each node
into the existing list of messages in the graph's state.
Examples:
```pycon
>>> from langgraph.graph.message import MessageGraph
...
>>> builder = MessageGraph()
>>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> builder.compile().invoke([("user", "Hi there.")])
[HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
```
```pycon
>>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
>>> from langgraph.graph.message import MessageGraph
...
>>> builder = MessageGraph()
>>> builder.add_node(
... "chatbot",
... lambda state: [
... AIMessage(
... content="Hello!",
... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
... )
... ],
... )
>>> builder.add_node(
... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
... )
>>> builder.set_entry_point("chatbot")
>>> builder.add_edge("chatbot", "search")
>>> builder.set_finish_point("search")
>>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
        [HumanMessage(content="Hi there. Can you search for X?", id='...'),
         AIMessage(content="Hello!", id='...'),
         ToolMessage(content="Searching...", id='...', tool_call_id="123")]
```
"""
def __init__(self) -> None:
super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type]
class MessagesState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
```
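A short sketch of the deletion path through `add_messages` described above: returning a `RemoveMessage` whose id matches an existing message removes it from the merged list:
```py
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from langgraph.graph.message import add_messages

base = [HumanMessage(content="hi", id="1"), AIMessage(content="hello", id="2")]
add_messages(base, RemoveMessage(id="2"))
# expected: [HumanMessage(content='hi', id='1')]
```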
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/graph/state.py`:
```py
import inspect
import logging
import typing
import warnings
from functools import partial
from inspect import isclass, isfunction, ismethod, signature
from types import FunctionType
from typing import (
Any,
Callable,
Literal,
NamedTuple,
Optional,
Sequence,
Type,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.base import RunnableLike
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self
from langgraph._api.deprecation import LangGraphDeprecationWarning
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.dynamic_barrier_value import DynamicBarrierValue, WaitForNames
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.channels.named_barrier_value import NamedBarrierValue
from langgraph.constants import EMPTY_SEQ, NS_END, NS_SEP, SELF, TAG_HIDDEN
from langgraph.errors import (
ErrorCode,
InvalidUpdateError,
ParentCommand,
create_error_message,
)
from langgraph.graph.graph import END, START, Branch, CompiledGraph, Graph, Send
from langgraph.managed.base import (
ChannelKeyPlaceholder,
ChannelTypePlaceholder,
ConfiguredManagedValue,
ManagedValueSpec,
is_managed_value,
is_writable_managed_value,
)
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
logger = logging.getLogger(__name__)
def _warn_invalid_state_schema(schema: Union[Type[Any], Any]) -> None:
if isinstance(schema, type):
return
if typing.get_args(schema):
return
warnings.warn(
f"Invalid state_schema: {schema}. Expected a type or Annotated[type, reducer]. "
"Please provide a valid schema to ensure correct updates.\n"
" See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph"
)
def _get_node_name(node: RunnableLike) -> str:
if isinstance(node, Runnable):
return node.get_name()
elif callable(node):
return getattr(node, "__name__", node.__class__.__name__)
else:
raise TypeError(f"Unsupported node type: {type(node)}")
class StateNodeSpec(NamedTuple):
runnable: Runnable
metadata: Optional[dict[str, Any]]
input: Type[Any]
retry_policy: Optional[RetryPolicy]
ends: Optional[tuple[str, ...]] = EMPTY_SEQ
class StateGraph(Graph):
"""A graph whose nodes communicate by reading and writing to a shared state.
The signature of each node is State -> Partial<State>.
Each state key can optionally be annotated with a reducer function that
will be used to aggregate the values of that key received from multiple nodes.
The signature of a reducer function is (Value, Value) -> Value.
Args:
state_schema (Type[Any]): The schema class that defines the state.
config_schema (Optional[Type[Any]]): The schema class that defines the configuration.
Use this to expose configurable parameters in your API.
Examples:
>>> from langchain_core.runnables import RunnableConfig
>>> from typing_extensions import Annotated, TypedDict
>>> from langgraph.checkpoint.memory import MemorySaver
>>> from langgraph.graph import StateGraph
>>>
>>> def reducer(a: list, b: int | None) -> list:
... if b is not None:
... return a + [b]
... return a
>>>
>>> class State(TypedDict):
... x: Annotated[list, reducer]
>>>
>>> class ConfigSchema(TypedDict):
... r: float
>>>
>>> graph = StateGraph(State, config_schema=ConfigSchema)
>>>
>>> def node(state: State, config: RunnableConfig) -> dict:
... r = config["configurable"].get("r", 1.0)
... x = state["x"][-1]
... next_value = x * r * (1 - x)
... return {"x": next_value}
>>>
>>> graph.add_node("A", node)
>>> graph.set_entry_point("A")
>>> graph.set_finish_point("A")
>>> compiled = graph.compile()
>>>
>>> print(compiled.config_specs)
[ConfigurableFieldSpec(id='r', annotation=<class 'float'>, name=None, description=None, default=None, is_shared=False, dependencies=None)]
>>>
>>> step1 = compiled.invoke({"x": 0.5}, {"configurable": {"r": 3.0}})
>>> print(step1)
{'x': [0.5, 0.75]}"""
nodes: dict[str, StateNodeSpec] # type: ignore[assignment]
channels: dict[str, BaseChannel]
managed: dict[str, ManagedValueSpec]
schemas: dict[Type[Any], dict[str, Union[BaseChannel, ManagedValueSpec]]]
def __init__(
self,
state_schema: Optional[Type[Any]] = None,
config_schema: Optional[Type[Any]] = None,
*,
input: Optional[Type[Any]] = None,
output: Optional[Type[Any]] = None,
) -> None:
super().__init__()
if state_schema is None:
if input is None or output is None:
raise ValueError("Must provide state_schema or input and output")
state_schema = input
warnings.warn(
"Initializing StateGraph without state_schema is deprecated. "
"Please pass in an explicit state_schema instead of just an input and output schema.",
LangGraphDeprecationWarning,
stacklevel=2,
)
else:
if input is None:
input = state_schema
if output is None:
output = state_schema
self.schemas = {}
self.channels = {}
self.managed = {}
self.schema = state_schema
self.input = input
self.output = output
self._add_schema(state_schema)
self._add_schema(input, allow_managed=False)
self._add_schema(output, allow_managed=False)
self.config_schema = config_schema
self.waiting_edges: set[tuple[tuple[str, ...], str]] = set()
@property
def _all_edges(self) -> set[tuple[str, str]]:
return self.edges | {
(start, end) for starts, end in self.waiting_edges for start in starts
}
def _add_schema(self, schema: Type[Any], /, allow_managed: bool = True) -> None:
if schema not in self.schemas:
_warn_invalid_state_schema(schema)
channels, managed = _get_channels(schema)
if managed and not allow_managed:
names = ", ".join(managed)
schema_name = getattr(schema, "__name__", "")
raise ValueError(
f"Invalid managed channels detected in {schema_name}: {names}."
" Managed channels are not permitted in Input/Output schema."
)
self.schemas[schema] = {**channels, **managed}
for key, channel in channels.items():
if key in self.channels:
if self.channels[key] != channel:
if isinstance(channel, LastValue):
pass
else:
raise ValueError(
f"Channel '{key}' already exists with a different type"
)
else:
self.channels[key] = channel
for key, managed in managed.items():
if key in self.managed:
if self.managed[key] != managed:
raise ValueError(
f"Managed value '{key}' already exists with a different type"
)
else:
self.managed[key] = managed
@overload
def add_node(
self,
node: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Args:
node (RunnableLike): The function or runnable this node will run.
Raises:
ValueError: If the key is already being used as a state key.
Returns:
StateGraph
"""
...
@overload
def add_node(
self,
node: str,
action: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> Self:
"""Adds a new node to the state graph.
Args:
node (str): The key of the node.
action (RunnableLike): The action associated with the node.
Raises:
ValueError: If the key is already being used as a state key.
Returns:
StateGraph
"""
...
def add_node(
self,
node: Union[str, RunnableLike],
action: Optional[RunnableLike] = None,
*,
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Args:
            node (Union[str, RunnableLike]): The function or runnable this node will run.
action (Optional[RunnableLike]): The action associated with the node. (default: None)
metadata (Optional[dict[str, Any]]): The metadata associated with the node. (default: None)
input (Optional[Type[Any]]): The input schema for the node. (default: the graph's input schema)
retry (Optional[RetryPolicy]): The policy for retrying the node. (default: None)
Raises:
ValueError: If the key is already being used as a state key.
Examples:
```pycon
>>> from langgraph.graph import START, StateGraph
...
>>> def my_node(state, config):
... return {"x": state["x"] + 1}
...
>>> builder = StateGraph(dict)
>>> builder.add_node(my_node) # node name will be 'my_node'
>>> builder.add_edge(START, "my_node")
>>> graph = builder.compile()
>>> graph.invoke({"x": 1})
{'x': 2}
```
Customize the name:
```pycon
>>> builder = StateGraph(dict)
>>> builder.add_node("my_fair_node", my_node)
>>> builder.add_edge(START, "my_fair_node")
>>> graph = builder.compile()
>>> graph.invoke({"x": 1})
{'x': 2}
```
Returns:
StateGraph
"""
if not isinstance(node, str):
action = node
if isinstance(action, Runnable):
node = action.get_name()
else:
node = getattr(action, "__name__", action.__class__.__name__)
if node is None:
raise ValueError(
"Node name must be provided if action is not a function"
)
if node in self.channels:
raise ValueError(f"'{node}' is already being used as a state key")
if self.compiled:
logger.warning(
"Adding a node to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
if not isinstance(node, str):
action = node
node = cast(str, getattr(action, "name", getattr(action, "__name__", None)))
if node is None:
raise ValueError(
"Node name must be provided if action is not a function"
)
if action is None:
            raise RuntimeError(
                "Expected a function or Runnable action in add_node. Received None."
            )
if node in self.nodes:
raise ValueError(f"Node `{node}` already present.")
if node == END or node == START:
raise ValueError(f"Node `{node}` is reserved.")
for character in (NS_SEP, NS_END):
if character in cast(str, node):
raise ValueError(
f"'{character}' is a reserved character and is not allowed in the node names."
)
ends = EMPTY_SEQ
try:
if (isfunction(action) or ismethod(getattr(action, "__call__", None))) and (
hints := get_type_hints(getattr(action, "__call__"))
or get_type_hints(action)
):
if input is None:
first_parameter_name = next(
iter(
inspect.signature(
cast(FunctionType, action)
).parameters.keys()
)
)
if input_hint := hints.get(first_parameter_name):
if isinstance(input_hint, type) and get_type_hints(input_hint):
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
):
ends = vals
except (TypeError, StopIteration):
pass
if input is not None:
self._add_schema(input)
self.nodes[cast(str, node)] = StateNodeSpec(
coerce_to_runnable(action, name=cast(str, node), trace=False),
metadata,
input=input or self.schema,
retry_policy=retry,
ends=ends,
)
return self
def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:
"""Adds a directed edge from the start node to the end node.
If the graph transitions to the start_key node, it will always transition to the end_key node next.
Args:
start_key (Union[str, list[str]]): The key(s) of the start node(s) of the edge.
end_key (str): The key of the end node of the edge.
Raises:
ValueError: If the start key is 'END' or if the start key or end key is not present in the graph.
Returns:
StateGraph
"""
if isinstance(start_key, str):
return super().add_edge(start_key, end_key)
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
for start in start_key:
if start == END:
raise ValueError("END cannot be a start node")
if start not in self.nodes:
raise ValueError(f"Need to add_node `{start}` first")
if end_key == START:
raise ValueError("START cannot be an end node")
if end_key != END and end_key not in self.nodes:
raise ValueError(f"Need to add_node `{end_key}` first")
self.waiting_edges.add((tuple(start_key), end_key))
return self
def add_sequence(
self,
nodes: Sequence[Union[RunnableLike, tuple[str, RunnableLike]]],
) -> Self:
"""Add a sequence of nodes that will be executed in the provided order.
Args:
nodes: A sequence of RunnableLike objects (e.g. a LangChain Runnable or a callable) or (name, RunnableLike) tuples.
If no names are provided, the name will be inferred from the node object (e.g. a runnable or a callable name).
Each node will be executed in the order provided.
Raises:
ValueError: if the sequence is empty.
ValueError: if the sequence contains duplicate node names.
Returns:
StateGraph
"""
if len(nodes) < 1:
raise ValueError("Sequence requires at least one node.")
previous_name: Optional[str] = None
for node in nodes:
if isinstance(node, tuple) and len(node) == 2:
name, node = node
else:
name = _get_node_name(node)
if name in self.nodes:
raise ValueError(
f"Node names must be unique: node with the name '{name}' already exists. "
"If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
)
self.add_node(name, node)
if previous_name is not None:
self.add_edge(previous_name, name)
previous_name = name
return self
def compile(
self,
checkpointer: Checkpointer = None,
*,
store: Optional[BaseStore] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
debug: bool = False,
) -> "CompiledStateGraph":
"""Compiles the state graph into a `CompiledGraph` object.
The compiled graph implements the `Runnable` interface and can be invoked,
streamed, batched, and run asynchronously.
Args:
checkpointer (Optional[Union[Checkpointer, Literal[False]]]): A checkpoint saver object or flag.
If provided, this Checkpointer serves as a fully versioned "short-term memory" for the graph,
allowing it to be paused, resumed, and replayed from any point.
If None, it may inherit the parent graph's checkpointer when used as a subgraph.
If False, it will not use or inherit any checkpointer.
interrupt_before (Optional[Sequence[str]]): An optional list of node names to interrupt before.
interrupt_after (Optional[Sequence[str]]): An optional list of node names to interrupt after.
debug (bool): A flag indicating whether to enable debug mode.
Returns:
CompiledStateGraph: The compiled state graph.
"""
# assign default values
interrupt_before = interrupt_before or []
interrupt_after = interrupt_after or []
# validate the graph
self.validate(
interrupt=(
(interrupt_before if interrupt_before != "*" else []) + interrupt_after
if interrupt_after != "*"
else []
)
)
# prepare output channels
output_channels = (
"__root__"
if len(self.schemas[self.output]) == 1
and "__root__" in self.schemas[self.output]
else [
key
for key, val in self.schemas[self.output].items()
if not is_managed_value(val)
]
)
stream_channels = (
"__root__"
if len(self.channels) == 1 and "__root__" in self.channels
else [
key for key, val in self.channels.items() if not is_managed_value(val)
]
)
compiled = CompiledStateGraph(
builder=self,
config_type=self.config_schema,
nodes={},
channels={
**self.channels,
**self.managed,
START: EphemeralValue(self.input),
},
input_channels=START,
stream_mode="updates",
output_channels=output_channels,
stream_channels=stream_channels,
checkpointer=checkpointer,
interrupt_before_nodes=interrupt_before,
interrupt_after_nodes=interrupt_after,
auto_validate=False,
debug=debug,
store=store,
)
compiled.attach_node(START, None)
for key, node in self.nodes.items():
compiled.attach_node(key, node)
for key, node in self.nodes.items():
compiled.attach_branch(key, SELF, CONTROL_BRANCH, with_reader=False)
for start, end in self.edges:
compiled.attach_edge(start, end)
for starts, end in self.waiting_edges:
compiled.attach_edge(starts, end)
for start, branches in self.branches.items():
for name, branch in branches.items():
compiled.attach_branch(start, name, branch)
return compiled.validate()
class CompiledStateGraph(CompiledGraph):
builder: StateGraph
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
return _get_schema(
typ=self.builder.input,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Input"),
)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
return _get_schema(
typ=self.builder.output,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Output"),
)
def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
if key == START:
output_keys = [
k
for k, v in self.builder.schemas[self.builder.input].items()
if not is_managed_value(v)
]
else:
output_keys = list(self.builder.channels) + [
k
for k, v in self.builder.managed.items()
if is_writable_managed_value(v)
]
def _get_root(input: Any) -> Any:
if isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return input.update
else:
return input
# to avoid name collision below
node_key = key
def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
if input is None:
return SKIP_WRITE
elif isinstance(input, dict):
if all(k not in output_keys for k in input):
raise InvalidUpdateError(
f"Expected node {node_key} to update at least one of {output_keys}, got {input}"
)
return input.get(key, SKIP_WRITE)
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return _get_state_key(input.update, key=key)
elif get_type_hints(type(input)):
value = getattr(input, key, SKIP_WRITE)
return value if value is not None else SKIP_WRITE
else:
msg = create_error_message(
message=f"Expected dict, got {input}",
error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE,
)
raise InvalidUpdateError(msg)
# state updaters
write_entries = (
[ChannelWriteEntry("__root__", skip_none=True, mapper=_get_root)]
if output_keys == ["__root__"]
else [
ChannelWriteEntry(key, mapper=partial(_get_state_key, key=key))
for key in output_keys
]
)
# add node and output channel
if key == START:
self.nodes[key] = PregelNode(
tags=[TAG_HIDDEN],
triggers=[START],
channels=[START],
writers=[
ChannelWrite(
write_entries,
tags=[TAG_HIDDEN],
require_at_least_one_of=output_keys,
),
],
)
elif node is not None:
input_schema = node.input if node else self.builder.schema
input_values = {k: k for k in self.builder.schemas[input_schema]}
is_single_input = len(input_values) == 1 and "__root__" in input_values
self.channels[key] = EphemeralValue(Any, guard=False)
self.nodes[key] = PregelNode(
triggers=[],
# read state keys and managed values
channels=(list(input_values) if is_single_input else input_values),
# coerce state dict to schema class (eg. pydantic model)
mapper=(
None
if is_single_input or issubclass(input_schema, dict)
else partial(_coerce_state, input_schema)
),
writers=[
# publish to this channel and state keys
ChannelWrite(
[ChannelWriteEntry(key, key)] + write_entries,
tags=[TAG_HIDDEN],
),
],
metadata=node.metadata,
retry_policy=node.retry_policy,
bound=node.runnable,
)
else:
raise RuntimeError
def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None:
if isinstance(starts, str):
if starts == START:
channel_name = f"start:{end}"
# register channel
self.channels[channel_name] = EphemeralValue(Any)
# subscribe to channel
self.nodes[end].triggers.append(channel_name)
# publish to channel
self.nodes[START] |= ChannelWrite(
[ChannelWriteEntry(channel_name, START)], tags=[TAG_HIDDEN]
)
elif end != END:
# subscribe to start channel
self.nodes[end].triggers.append(starts)
elif end != END:
channel_name = f"join:{'+'.join(starts)}:{end}"
# register channel
self.channels[channel_name] = NamedBarrierValue(str, set(starts))
# subscribe to channel
self.nodes[end].triggers.append(channel_name)
# publish to channel
for start in starts:
self.nodes[start] |= ChannelWrite(
[ChannelWriteEntry(channel_name, start)], tags=[TAG_HIDDEN]
)
def attach_branch(
self, start: str, name: str, branch: Branch, *, with_reader: bool = True
) -> None:
def branch_writer(
packets: Sequence[Union[str, Send]], config: RunnableConfig
) -> None:
if filtered := [p for p in packets if p != END]:
writes = [
(
ChannelWriteEntry(f"branch:{start}:{name}:{p}", start)
if not isinstance(p, Send)
else p
)
for p in filtered
]
if branch.then and branch.then != END:
writes.append(
ChannelWriteEntry(
f"branch:{start}:{name}::then",
WaitForNames(
{p.node if isinstance(p, Send) else p for p in filtered}
),
)
)
ChannelWrite.do_write(
config, cast(Sequence[Union[Send, ChannelWriteEntry]], writes)
)
# attach branch publisher
schema = (
self.builder.nodes[start].input
if start in self.builder.nodes
else self.builder.schema
)
self.nodes[start] |= branch.run(
branch_writer,
_get_state_reader(self.builder, schema) if with_reader else None,
)
# attach branch subscribers
ends = (
branch.ends.values()
if branch.ends
else [node for node in self.builder.nodes if node != branch.then]
)
for end in ends:
if end != END:
channel_name = f"branch:{start}:{name}:{end}"
self.channels[channel_name] = EphemeralValue(Any, guard=False)
self.nodes[end].triggers.append(channel_name)
# attach then subscriber
if branch.then and branch.then != END:
channel_name = f"branch:{start}:{name}::then"
self.channels[channel_name] = DynamicBarrierValue(str)
self.nodes[branch.then].triggers.append(channel_name)
for end in ends:
if end != END:
self.nodes[end] |= ChannelWrite(
[ChannelWriteEntry(channel_name, end)], tags=[TAG_HIDDEN]
)
def _get_state_reader(
builder: StateGraph, schema: Type[Any]
) -> Callable[[RunnableConfig], Any]:
state_keys = list(builder.channels)
select = list(builder.schemas[schema])
return partial(
ChannelRead.do_read,
select=select[0] if select == ["__root__"] else select,
fresh=True,
# coerce state dict to schema class (eg. pydantic model)
mapper=(
None
if state_keys == ["__root__"] or issubclass(schema, dict)
else partial(_coerce_state, schema)
),
)
def _coerce_state(schema: Type[Any], input: dict[str, Any]) -> dict[str, Any]:
return schema(**input)
def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if isinstance(value, Send):
return [value]
if not isinstance(value, Command):
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
return rtn
async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if isinstance(value, Send):
return [value]
if not isinstance(value, Command):
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
return rtn
CONTROL_BRANCH_PATH = RunnableCallable(
_control_branch, _acontrol_branch, tags=[TAG_HIDDEN], trace=False, recurse=False
)
CONTROL_BRANCH = Branch(CONTROL_BRANCH_PATH, None)
def _get_channels(
schema: Type[dict],
) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec]]:
if not hasattr(schema, "__annotations__"):
return {"__root__": _get_channel("__root__", schema, allow_managed=False)}, {}
all_keys = {
name: _get_channel(name, typ)
for name, typ in get_type_hints(schema, include_extras=True).items()
if name != "__slots__"
}
return (
{k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)},
{k: v for k, v in all_keys.items() if is_managed_value(v)},
)
@overload
def _get_channel(
name: str, annotation: Any, *, allow_managed: Literal[False]
) -> BaseChannel: ...
@overload
def _get_channel(
name: str, annotation: Any, *, allow_managed: Literal[True] = True
) -> Union[BaseChannel, ManagedValueSpec]: ...
def _get_channel(
name: str, annotation: Any, *, allow_managed: bool = True
) -> Union[BaseChannel, ManagedValueSpec]:
if manager := _is_field_managed_value(name, annotation):
if allow_managed:
return manager
else:
raise ValueError(f"This {annotation} not allowed in this position")
elif channel := _is_field_channel(annotation):
channel.key = name
return channel
elif channel := _is_field_binop(annotation):
channel.key = name
return channel
fallback: LastValue = LastValue(annotation)
fallback.key = name
return fallback
def _is_field_channel(typ: Type[Any]) -> Optional[BaseChannel]:
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1 and isinstance(meta[-1], BaseChannel):
return meta[-1]
elif len(meta) >= 1 and isclass(meta[-1]) and issubclass(meta[-1], BaseChannel):
return meta[-1](typ.__origin__ if hasattr(typ, "__origin__") else typ)
return None
def _is_field_binop(typ: Type[Any]) -> Optional[BinaryOperatorAggregate]:
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1 and callable(meta[-1]):
sig = signature(meta[-1])
params = list(sig.parameters.values())
if len(params) == 2 and all(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
):
return BinaryOperatorAggregate(typ, meta[-1])
else:
raise ValueError(
f"Invalid reducer signature. Expected (a, b) -> c. Got {sig}"
)
return None
def _is_field_managed_value(name: str, typ: Type[Any]) -> Optional[ManagedValueSpec]:
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1:
decoration = get_origin(meta[-1]) or meta[-1]
if is_managed_value(decoration):
if isinstance(decoration, ConfiguredManagedValue):
for k, v in decoration.kwargs.items():
if v is ChannelKeyPlaceholder:
decoration.kwargs[k] = name
if v is ChannelTypePlaceholder:
decoration.kwargs[k] = typ.__origin__
return decoration
return None
def _get_schema(
typ: Type,
schemas: dict,
channels: dict,
name: str,
) -> type[BaseModel]:
if isclass(typ) and issubclass(typ, (BaseModel, BaseModelV1)):
return typ
else:
keys = list(schemas[typ].keys())
if len(keys) == 1 and keys[0] == "__root__":
return create_model(
name,
root=(channels[keys[0]].UpdateType, None),
)
else:
return create_model(
name,
field_definitions={
k: (
channels[k].UpdateType,
(
get_field_default(
k,
channels[k].UpdateType,
typ,
)
),
)
for k in schemas[typ]
if k in channels and isinstance(channels[k], BaseChannel)
},
)
```
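Usage sketch (not part of the repo): the channel inference above maps a plain annotation to a `LastValue` channel and an `Annotated[..., reducer]` field to a `BinaryOperatorAggregate`, and `compile()` then wires each node via `attach_node`/`attach_edge`. The node names (`fetch`, `summarize`), the `totals` field, and the in-memory checkpointer are illustrative assumptions:
```py
import operator
from typing import Annotated, TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    query: str                              # plain annotation -> LastValue channel
    totals: Annotated[list, operator.add]   # 2-arg reducer -> BinaryOperatorAggregate


def fetch(state: State) -> dict:
    return {"totals": [1]}


def summarize(state: State) -> dict:
    return {"totals": [sum(state["totals"])]}


builder = StateGraph(State)
builder.add_node("fetch", fetch)
builder.add_node("summarize", summarize)
builder.add_edge(START, "fetch")
builder.add_edge("fetch", "summarize")
builder.add_edge("summarize", END)

# compile() validates the graph, derives output/stream channels from the schema,
# and attaches a PregelNode per node (attach_node / attach_edge above).
graph = builder.compile(checkpointer=MemorySaver(), interrupt_before=["summarize"])
config = {"configurable": {"thread_id": "demo"}}
graph.invoke({"query": "hi", "totals": []}, config)  # pauses before "summarize"
graph.invoke(None, config)                           # resumes from the checkpoint
print(graph.get_state(config).values)                # {'query': 'hi', 'totals': [1, 1]}
```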
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/_api/deprecation.py`:
```py
import functools
import warnings
from typing import Any, Callable, Type, TypeVar, Union, cast
class LangGraphDeprecationWarning(DeprecationWarning):
pass
F = TypeVar("F", bound=Callable[..., Any])
C = TypeVar("C", bound=Type[Any])
def deprecated(
since: str, alternative: str, *, removal: str = "", example: str = ""
) -> Callable[[F], F]:
def decorator(obj: Union[F, C]) -> Union[F, C]:
removal_str = removal if removal else "a future version"
message = (
f"{obj.__name__} is deprecated as of version {since} and will be"
f" removed in {removal_str}. Use {alternative} instead.{example}"
)
if isinstance(obj, type):
original_init = obj.__init__ # type: ignore[misc]
@functools.wraps(original_init)
def new_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def]
warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2)
original_init(self, *args, **kwargs)
obj.__init__ = new_init # type: ignore[misc]
docstring = (
f"**Deprecated**: This class is deprecated as of version {since}. "
f"Use `{alternative}` instead."
)
if obj.__doc__:
docstring = docstring + f"\n\n{obj.__doc__}"
obj.__doc__ = docstring
return cast(C, obj)
elif callable(obj):
@functools.wraps(obj)
def wrapper(*args: Any, **kwargs: Any) -> Any:
warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2)
return obj(*args, **kwargs)
docstring = (
f"**Deprecated**: This function is deprecated as of version {since}. "
f"Use `{alternative}` instead."
)
if obj.__doc__:
docstring = docstring + f"\n\n{obj.__doc__}"
wrapper.__doc__ = docstring
return cast(F, wrapper)
else:
raise TypeError(
f"Can only add deprecation decorator to classes or callables, got '{type(obj)}' instead."
)
return decorator
def deprecated_parameter(
arg_name: str, since: str, alternative: str, *, removal: str
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@functools.wraps(func)
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
if arg_name in kwargs:
warnings.warn(
f"Parameter '{arg_name}' in function '{func.__name__}' is "
f"deprecated as of version {since} and will be removed in version {removal}. "
f"Use '{alternative}' parameter instead.",
category=LangGraphDeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)
return cast(F, wrapper)
return decorator
```
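Illustrative example (not in the repo): the `deprecated` decorator applied to a throwaway function; `old_sum`/`new_sum` are made-up names used only for the sketch:
```py
import warnings

from langgraph._api.deprecation import LangGraphDeprecationWarning, deprecated


@deprecated("0.2.0", "new_sum", removal="0.3.0")
def old_sum(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_sum(1, 2) == 3
# A LangGraphDeprecationWarning was emitted by the call above...
assert issubclass(caught[0].category, LangGraphDeprecationWarning)
# ...and the wrapper's docstring now carries a deprecation notice.
assert old_sum.__doc__.startswith("**Deprecated**")
```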
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/queue.py`:
```py
# type: ignore
import asyncio
import queue
import sys
import threading
import types
from collections import deque
from time import monotonic
from typing import Optional
PY_310 = sys.version_info >= (3, 10)
class AsyncQueue(asyncio.Queue):
"""Async unbounded FIFO queue with a wait() method.
Subclassed from asyncio.Queue, adding a wait() method."""
async def wait(self) -> None:
"""If queue is empty, wait until an item is available.
Copied from Queue.get(), removing the call to .get_nowait(),
ie. this doesn't consume the item, just waits for it.
"""
while self.empty():
if PY_310:
getter = self._get_loop().create_future()
else:
getter = self._loop.create_future()
self._getters.append(getter)
try:
await getter
except:
getter.cancel() # Just in case getter is not done yet.
try:
# Clean self._getters from canceled getters.
self._getters.remove(getter)
except ValueError:
# The getter could be removed from self._getters by a
# previous put_nowait call.
pass
if not self.empty() and not getter.cancelled():
# We were woken up by put_nowait(), but can't take
# the call. Wake up the next in line.
self._wakeup_next(self._getters)
raise
class Semaphore(threading.Semaphore):
"""Semaphore subclass with a wait() method."""
def wait(self, blocking: bool = True, timeout: Optional[float] = None):
"""Block until the semaphore can be acquired, but don't acquire it."""
if not blocking and timeout is not None:
raise ValueError("can't specify timeout for non-blocking acquire")
rc = False
endtime = None
with self._cond:
while self._value == 0:
if not blocking:
break
if timeout is not None:
if endtime is None:
endtime = monotonic() + timeout
else:
timeout = endtime - monotonic()
if timeout <= 0:
break
self._cond.wait(timeout)
else:
rc = True
return rc
class SyncQueue:
"""Unbounded FIFO queue with a wait() method.
Adapted from pure Python implementation of queue.SimpleQueue.
"""
def __init__(self):
self._queue = deque()
self._count = Semaphore(0)
def put(self, item, block=True, timeout=None):
"""Put the item on the queue.
The optional 'block' and 'timeout' arguments are ignored, as this method
never blocks. They are provided for compatibility with the Queue class.
"""
self._queue.append(item)
self._count.release()
def get(self, block=True, timeout=None):
"""Remove and return an item from the queue.
If optional args 'block' is true and 'timeout' is None (the default),
block if necessary until an item is available. If 'timeout' is
a non-negative number, it blocks at most 'timeout' seconds and raises
the Empty exception if no item was available within that time.
Otherwise ('block' is false), return an item if one is immediately
available, else raise the Empty exception ('timeout' is ignored
in that case).
"""
if timeout is not None and timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
if not self._count.acquire(block, timeout):
raise queue.Empty
try:
return self._queue.popleft()
except IndexError:
raise queue.Empty
def wait(self, block=True, timeout=None):
"""If queue is empty, wait until an item maybe is available,
but don't consume it.
"""
if timeout is not None and timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
self._count.wait(block, timeout)
def empty(self):
"""Return True if the queue is empty, False otherwise (not reliable!)."""
return len(self._queue) == 0
def qsize(self):
"""Return the approximate size of the queue (not reliable!)."""
return len(self._queue)
__class_getitem__ = classmethod(types.GenericAlias)
__all__ = ["AsyncQueue", "SyncQueue"]
```
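Usage sketch (not part of the repo): `SyncQueue.wait()` blocks until an item is available without consuming it; the producer thread below is illustrative:
```py
import threading
import time

from langgraph.utils.queue import SyncQueue

q = SyncQueue()


def producer() -> None:
    time.sleep(0.1)
    q.put("hello")


threading.Thread(target=producer).start()
q.wait()                  # blocks until the producer has put an item
assert not q.empty()      # the item is still queued: wait() did not consume it
assert q.get() == "hello"
```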
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/config.py`:
```py
import asyncio
import sys
from collections import ChainMap
from typing import Any, Optional, Sequence, cast
from langchain_core.callbacks import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
Callbacks,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
CONFIG_KEYS,
COPIABLE_KEYS,
DEFAULT_RECURSION_LIMIT,
var_child_runnable_config,
)
from langgraph.checkpoint.base import CheckpointMetadata
from langgraph.constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
)
def patch_configurable(
config: Optional[RunnableConfig], patch: dict[str, Any]
) -> RunnableConfig:
if config is None:
return {CONF: patch}
elif CONF not in config:
return {**config, CONF: patch}
else:
return {**config, CONF: {**config[CONF], **patch}}
def patch_checkpoint_map(
config: Optional[RunnableConfig], metadata: Optional[CheckpointMetadata]
) -> RunnableConfig:
if config is None:
return config
elif parents := (metadata.get("parents") if metadata else None):
conf = config[CONF]
return patch_configurable(
config,
{
CONFIG_KEY_CHECKPOINT_MAP: {
**parents,
conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID],
},
},
)
else:
return config
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
"""Merge multiple configs into one.
Args:
*configs (Optional[RunnableConfig]): The configs to merge.
Returns:
RunnableConfig: The merged config.
"""
base: RunnableConfig = {}
# Even though the keys aren't literals, this is correct
# because both dicts are the same type
for config in configs:
if config is None:
continue
for key, value in config.items():
if not value:
continue
if key == "metadata":
if base_value := base.get(key):
base[key] = {**base_value, **value} # type: ignore
else:
base[key] = value # type: ignore[literal-required]
elif key == "tags":
if base_value := base.get(key):
base[key] = [*base_value, *value] # type: ignore
else:
base[key] = value # type: ignore[literal-required]
elif key == CONF:
if base_value := base.get(key):
base[key] = {**base_value, **value} # type: ignore[dict-item]
else:
base[key] = value
elif key == "callbacks":
base_callbacks = base.get("callbacks")
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(value, list):
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + value
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in value:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif isinstance(value, BaseCallbackManager):
# value is a manager
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
mngr = value.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.merge(value)
else:
raise NotImplementedError
elif key == "recursion_limit":
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
base["recursion_limit"] = config["recursion_limit"]
else:
base[key] = config[key] # type: ignore[literal-required]
if CONF not in base:
base[CONF] = {}
return base
def patch_config(
config: Optional[RunnableConfig],
*,
callbacks: Optional[Callbacks] = None,
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
configurable: Optional[dict[str, Any]] = None,
) -> RunnableConfig:
"""Patch a config with new values.
Args:
config (Optional[RunnableConfig]): The config to patch.
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
Defaults to None.
recursion_limit (Optional[int], optional): The recursion limit to set.
Defaults to None.
max_concurrency (Optional[int], optional): The max concurrency to set.
Defaults to None.
run_name (Optional[str], optional): The run name to set. Defaults to None.
configurable (Optional[Dict[str, Any]], optional): The configurable to set.
Defaults to None.
Returns:
RunnableConfig: The patched config.
"""
config = config.copy() if config is not None else {}
if callbacks is not None:
# If we're replacing callbacks, we need to unset run_name
# As that should apply only to the same run as the original callbacks
config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if "run_id" in config:
del config["run_id"]
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
if max_concurrency is not None:
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
if configurable is not None:
config[CONF] = {**config.get(CONF, {}), **configurable}
return config
def get_callback_manager_for_config(
config: RunnableConfig, tags: Optional[Sequence[str]] = None
) -> CallbackManager:
"""Get a callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
CallbackManager: The callback manager.
"""
from langchain_core.callbacks.manager import CallbackManager
# merge tags
all_tags = config.get("tags")
if all_tags is not None and tags is not None:
all_tags = [*all_tags, *tags]
elif tags is not None:
all_tags = list(tags)
# use existing callbacks if they exist
if (callbacks := config.get("callbacks")) and isinstance(
callbacks, CallbackManager
):
if all_tags:
callbacks.add_tags(all_tags)
if metadata := config.get("metadata"):
callbacks.add_metadata(metadata)
return callbacks
else:
# otherwise create a new manager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=all_tags,
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
tags: Optional[Sequence[str]] = None,
) -> AsyncCallbackManager:
"""Get an async callback manager for a config.
Args:
config (RunnableConfig): The config.
Returns:
AsyncCallbackManager: The async callback manager.
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
# merge tags
all_tags = config.get("tags")
if all_tags is not None and tags is not None:
all_tags = [*all_tags, *tags]
elif tags is not None:
all_tags = list(tags)
# use existing callbacks if they exist
if (callbacks := config.get("callbacks")) and isinstance(
callbacks, AsyncCallbackManager
):
if all_tags:
callbacks.add_tags(all_tags)
if metadata := config.get("metadata"):
callbacks.add_metadata(metadata)
return callbacks
else:
# otherwise create a new manager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=all_tags,
inheritable_metadata=config.get("metadata"),
)
def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.
Args:
*configs (Optional[RunnableConfig]): Configs to merge into the ensured config.
None entries are skipped.
Returns:
RunnableConfig: The ensured config.
"""
empty = RunnableConfig(
tags=[],
metadata=ChainMap(),
callbacks=None,
recursion_limit=DEFAULT_RECURSION_LIMIT,
configurable={},
)
if var_config := var_child_runnable_config.get():
empty.update(
{
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
for k, v in var_config.items()
if v is not None
},
)
for config in configs:
if config is None:
continue
for k, v in config.items():
if v is not None and k in CONFIG_KEYS:
if k == CONF:
empty[k] = cast(dict, v).copy()
else:
empty[k] = v # type: ignore[literal-required]
for k, v in config.items():
if v is not None and k not in CONFIG_KEYS:
empty[CONF][k] = v
for key, value in empty[CONF].items():
if (
not key.startswith("__")
and isinstance(value, (str, int, float, bool))
and key not in empty["metadata"]
):
empty["metadata"][key] = value
return empty
def get_configurable() -> dict[str, Any]:
if sys.version_info < (3, 11):
try:
if asyncio.current_task():
raise RuntimeError(
"Python 3.11 or later required to use this in an async context"
)
except RuntimeError:
pass
if var_config := var_child_runnable_config.get():
return var_config[CONF]
else:
raise RuntimeError("Called get_configurable outside of a runnable context")
```
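Behavior sketch (not in the repo): how `merge_configs` combines `tags`, `metadata`, and `configurable`, and how `patch_config` overrides individual keys; all values are illustrative:
```py
from langgraph.utils.config import merge_configs, patch_config

base = {"tags": ["a"], "metadata": {"k": 1}, "configurable": {"thread_id": "t1"}}
override = {"tags": ["b"], "metadata": {"k": 2}, "configurable": {"user_id": "u1"}}

merged = merge_configs(base, override)
assert merged["tags"] == ["a", "b"]    # tags are concatenated
assert merged["metadata"] == {"k": 2}  # metadata dicts are merged, later value wins
assert merged["configurable"] == {"thread_id": "t1", "user_id": "u1"}

patched = patch_config(merged, run_name="step-1", recursion_limit=10)
assert patched["run_name"] == "step-1"
assert patched["recursion_limit"] == 10
```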
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/fields.py`:
```py
import dataclasses
from typing import Any, Optional, Type, Union
from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin
def _is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
if hasattr(type_, "__origin__") and hasattr(type_, "__args__"):
origin = get_origin(type_)
if origin is Optional:
return True
if origin is Union:
return any(
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
)
if origin is Annotated:
return _is_optional_type(type_.__args__[0])
return origin is None
if hasattr(type_, "__bound__") and type_.__bound__ is not None:
return _is_optional_type(type_.__bound__)
return type_ is None
def _is_required_type(type_: Any) -> Optional[bool]:
"""Check if an annotation is marked as Required/NotRequired.
Returns:
- True if required
- False if not required
- None if not annotated with either
"""
origin = get_origin(type_)
if origin is Required:
return True
if origin is NotRequired:
return False
if origin is Annotated or getattr(origin, "__args__", None):
# See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated
return _is_required_type(type_.__args__[0])
return None
def _is_readonly_type(type_: Any) -> bool:
"""Check if an annotation is marked as ReadOnly.
Returns:
- True if is read only
- False if not read only
"""
# See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier
origin = get_origin(type_)
if origin is Annotated:
return _is_readonly_type(type_.__args__[0])
if origin is ReadOnly:
return True
return False
_DEFAULT_KEYS: frozenset[str] = frozenset()
def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any:
"""Determine the default value for a field in a state schema.
This is based on:
If TypedDict:
- Required/NotRequired
- total=False -> everything optional
- Type annotation (Optional/Union[None])
"""
optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS)
irq = _is_required_type(type_)
if name in optional_keys:
# Either total=False or explicit NotRequired.
# No type annotation trumps this.
if irq:
# Unless it's earlier versions of python & explicit Required
return ...
return None
if irq is not None:
if irq:
# Handle Required[<type>]
# (we already handled NotRequired and total=False)
return ...
# Handle NotRequired[<type>] for earlier versions of python
return None
if dataclasses.is_dataclass(schema):
field_info = next(
(f for f in dataclasses.fields(schema) if f.name == name), None
)
if field_info:
if (
field_info.default is not dataclasses.MISSING
and field_info.default is not ...
):
return field_info.default
elif field_info.default_factory is not dataclasses.MISSING:
return field_info.default_factory()
# Note: we ignore ReadOnly qualifiers, since mutating the state object
# inside a node has no effect on the graph's own state anyway.
# Base case is the annotation
if _is_optional_type(type_):
return None
return ...
```
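Behavior sketch (not in the repo): what `get_field_default` returns for an illustrative TypedDict, where `...` (Ellipsis) marks a required field:
```py
from typing import Optional

from typing_extensions import NotRequired, TypedDict, get_type_hints

from langgraph.utils.fields import get_field_default


class Example(TypedDict):
    required_field: str              # plain annotation -> required (Ellipsis)
    maybe_missing: NotRequired[str]  # NotRequired -> default None
    nullable: Optional[str]          # Optional -> default None


hints = get_type_hints(Example, include_extras=True)
assert get_field_default("required_field", hints["required_field"], Example) is ...
assert get_field_default("maybe_missing", hints["maybe_missing"], Example) is None
assert get_field_default("nullable", hints["nullable"], Example) is None
```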
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/pydantic.py`:
```py
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
def create_model(
model_name: str,
*,
field_definitions: Optional[Dict[str, Any]] = None,
root: Optional[Any] = None,
) -> Union[BaseModel, BaseModelV1]:
"""Create a pydantic model with the given field definitions.
Args:
model_name: The name of the model.
field_definitions: The field definitions for the model.
root: Type for a root model (RootModel)
"""
try:
# for langchain-core >= 0.3.0
from langchain_core.utils.pydantic import create_model_v2
return create_model_v2(
model_name,
field_definitions=field_definitions,
root=root,
)
except ImportError:
# for langchain-core < 0.3.0
from langchain_core.runnables.utils import create_model
v1_kwargs = {}
if root is not None:
v1_kwargs["__root__"] = root
return create_model(model_name, **v1_kwargs, **(field_definitions or {}))
```
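Usage sketch (not part of the repo): `create_model` built from illustrative field definitions, assuming the standard `(type, default)` tuple form accepted by both code paths above:
```py
from langgraph.utils.pydantic import create_model

Point = create_model(
    "Point",
    field_definitions={"x": (int, ...), "y": (int, 0)},  # (type, default); ... = required
)
p = Point(x=1)
assert (p.x, p.y) == (1, 0)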
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/utils/runnable.py`:
```py
import asyncio
import enum
import inspect
import sys
from contextlib import AsyncExitStack
from contextvars import copy_context
from functools import partial, wraps
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterator,
Optional,
Sequence,
Union,
cast,
)
from langchain_core.runnables.base import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableLike,
RunnableParallel,
RunnableSequence,
)
from langchain_core.runnables.config import (
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import Input
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard
from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
try:
from langchain_core.runnables.config import _set_config_context
except ImportError:
# For forwards compatibility
def _set_config_context(context: RunnableConfig) -> None: # type: ignore
"""Set the context for the current thread."""
var_child_runnable_config.set(context)
# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
"""A string enum."""
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
(
sys.intern("writer"),
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
CONFIG_KEY_STREAM_WRITER,
lambda _: None,
),
(
sys.intern("store"),
(BaseStore, "BaseStore", inspect.Parameter.empty),
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations."""
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
class RunnableCallable(Runnable):
"""A much simpler version of RunnableLambda that requires sync and async functions."""
def __init__(
self,
func: Optional[Callable[..., Union[Any, Runnable]]],
afunc: Optional[Callable[..., Awaitable[Union[Any, Runnable]]]] = None,
*,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
trace: bool = True,
recurse: bool = True,
**kwargs: Any,
) -> None:
self.name = name
if self.name is None:
if func:
try:
if func.__name__ != "<lambda>":
self.name = func.__name__
except AttributeError:
pass
elif afunc:
try:
self.name = afunc.__name__
except AttributeError:
pass
self.func = func
self.afunc = afunc
self.tags = tags
self.kwargs = kwargs
self.trace = trace
self.recurse = recurse
# check signature
if func is None and afunc is None:
raise ValueError("At least one of func or afunc must be provided.")
params = inspect.signature(cast(Callable, func or afunc)).parameters
self.func_accepts_config = "config" in params
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
def __repr__(self) -> str:
repr_args = {
k: v
for k, v in self.__dict__.items()
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
}
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
if self.func is None:
raise TypeError(
f'No synchronous function provided to "{self.name}".'
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
if config is None:
config = ensure_config()
kwargs = {**self.kwargs, **kwargs}
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue
if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
context = copy_context()
if self.trace:
callback_manager = get_callback_manager_for_config(config, self.tags)
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
ret = context.run(self.func, input, **kwargs)
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(ret)
else:
context.run(_set_config_context, config)
ret = context.run(self.func, input, **kwargs)
if isinstance(ret, Runnable) and self.recurse:
return ret.invoke(input, config)
return ret
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
if not self.afunc:
return self.invoke(input, config)
if config is None:
config = ensure_config()
kwargs = {**self.kwargs, **kwargs}
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue
if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.name,
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context.run(_set_config_context, child_config)
coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs))
if ASYNCIO_ACCEPTS_CONTEXT:
ret = await asyncio.create_task(coro, context=context)
else:
ret = await coro
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(ret)
else:
context.run(_set_config_context, config)
if ASYNCIO_ACCEPTS_CONTEXT:
coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs))
ret = await asyncio.create_task(coro, context=context)
else:
ret = await self.afunc(input, **kwargs)
if isinstance(ret, Runnable) and self.recurse:
return await ret.ainvoke(input, config)
return ret
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__")
and asyncio.iscoroutinefunction(func.__call__)
)
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)
def coerce_to_runnable(
thing: RunnableLike, *, name: Optional[str], trace: bool
) -> Runnable:
"""Coerce a runnable-like object into a Runnable.
Args:
thing: A runnable-like object.
Returns:
A Runnable.
"""
if isinstance(thing, Runnable):
return thing
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableLambda(thing, name=name)
elif callable(thing):
if is_async_callable(thing):
return RunnableCallable(None, thing, name=name, trace=trace)
else:
return RunnableCallable(
thing,
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
name=name,
trace=trace,
)
elif isinstance(thing, dict):
return RunnableParallel(thing)
else:
raise TypeError(
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
class RunnableSeq(Runnable):
"""A simpler version of RunnableSequence."""
def __init__(
self,
*steps: RunnableLike,
name: Optional[str] = None,
) -> None:
"""Create a new RunnableSequence.
Args:
steps: The steps to include in the sequence.
name: The name of the Runnable. Defaults to None.
first: The first Runnable in the sequence. Defaults to None.
middle: The middle Runnables in the sequence. Defaults to None.
last: The last Runnable in the sequence. Defaults to None.
Raises:
ValueError: If the sequence has less than 2 steps.
"""
steps_flat: list[Runnable] = []
for step in steps:
if isinstance(step, RunnableSequence):
steps_flat.extend(step.steps)
elif isinstance(step, RunnableSeq):
steps_flat.extend(step.steps)
else:
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
if len(steps_flat) < 2:
raise ValueError(
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
)
self.steps = steps_flat
self.name = name
def __or__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSeq(
*self.steps,
other.first,
*other.middle,
other.last,
name=self.name or other.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*self.steps,
*other.steps,
name=self.name or other.name,
)
else:
return RunnableSeq(
*self.steps,
coerce_to_runnable(other, name=None, trace=True),
name=self.name,
)
def __ror__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSequence(
other.first,
*other.middle,
other.last,
*self.steps,
name=other.name or self.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*other.steps,
*self.steps,
name=other.name or self.name,
)
else:
return RunnableSequence(
coerce_to_runnable(other, name=None, trace=True),
*self.steps,
name=self.name,
)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks and context
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
else:
input = context.run(step.invoke, input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(input)
return input
async def ainvoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
context = copy_context()
context.run(_set_config_context, config)
if i == 0:
coro = step.ainvoke(input, config, **kwargs)
else:
coro = step.ainvoke(input, config)
if ASYNCIO_ACCEPTS_CONTEXT:
input = await asyncio.create_task(coro, context=context)
else:
input = await asyncio.create_task(coro)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(input)
return input
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
for idx, step in enumerate(self.steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
)
if idx == 0:
iterator = step.stream(input, config, **kwargs)
else:
iterator = step.transform(iterator, config)
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
output: Any = None
add_supported = False
for chunk in iterator:
yield chunk
# collect final output
if output is None:
output = chunk
elif add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
async with AsyncExitStack() as stack:
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
for idx, step in enumerate(self.steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
)
if idx == 0:
aiterator = step.astream(input, config, **kwargs)
else:
aiterator = step.atransform(aiterator, config)
if hasattr(aiterator, "aclose"):
stack.push_async_callback(aiterator.aclose)
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
None,
):
# populates streamed_output in astream_log() output if needed
aiterator = stream_handler.tap_output_aiter(
run_manager.run_id, aiterator
)
output: Any = None
add_supported = False
async for chunk in aiterator:
yield chunk
# collect final output
if add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
```
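Usage sketch (not part of the repo): composing plain functions with `RunnableCallable` and `RunnableSeq`; the function names are illustrative:
```py
import asyncio

from langgraph.utils.runnable import RunnableCallable, RunnableSeq


def double(x: int) -> int:
    return x * 2


async def adouble(x: int) -> int:
    return x * 2


def add_one(x: int) -> int:
    return x + 1


# RunnableCallable wraps a sync and (optionally) an async implementation.
step = RunnableCallable(double, adouble, trace=False)
assert step.invoke(3) == 6
assert asyncio.run(step.ainvoke(3)) == 6

# RunnableSeq chains steps; plain callables are coerced via coerce_to_runnable.
seq = RunnableSeq(step, add_one, name="double_then_add")
assert seq.invoke(3) == 7
```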
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/types.py`:
```py
import dataclasses
import sys
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Hashable,
Literal,
NamedTuple,
Optional,
Sequence,
Type,
TypedDict,
TypeVar,
Union,
cast,
)
from langchain_core.runnables import Runnable, RunnableConfig
from typing_extensions import Self
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
CheckpointMetadata,
PendingWrite,
)
if TYPE_CHECKING:
from langgraph.store.base import BaseStore
All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""
Checkpointer = Union[None, Literal[False], BaseCheckpointSaver]
"""Type of the checkpointer to use for a subgraph. False disables checkpointing,
even if the parent graph has a checkpointer. None inherits checkpointer."""
StreamMode = Literal["values", "updates", "debug", "messages", "custom"]
"""How the stream method should emit outputs.
- 'values': Emit all values of the state for each step.
- 'updates': Emit only the node name(s) and updates
that were returned by the node(s) **after** each step.
- 'debug': Emit debug events for each step.
- 'messages': Emit LLM messages token-by-token.
- 'custom': Emit custom output written via the `writer: StreamWriter` kwarg of each node.
"""
StreamWriter = Callable[[Any], None]
"""Callable that accepts a single argument and writes it to the output stream.
Always injected into nodes if requested as a keyword argument, but it's a no-op
when not using stream_mode="custom"."""
if sys.version_info >= (3, 10):
_DC_KWARGS = {"kw_only": True, "slots": True, "frozen": True}
else:
_DC_KWARGS = {"frozen": True}
def default_retry_on(exc: Exception) -> bool:
import httpx
import requests
if isinstance(exc, ConnectionError):
return True
if isinstance(
exc,
(
ValueError,
TypeError,
ArithmeticError,
ImportError,
LookupError,
NameError,
SyntaxError,
RuntimeError,
ReferenceError,
StopIteration,
StopAsyncIteration,
OSError,
),
):
return False
if isinstance(exc, httpx.HTTPStatusError):
return 500 <= exc.response.status_code < 600
if isinstance(exc, requests.HTTPError):
return 500 <= exc.response.status_code < 600 if exc.response else True
return True
class RetryPolicy(NamedTuple):
"""Configuration for retrying nodes."""
initial_interval: float = 0.5
"""Amount of time that must elapse before the first retry occurs. In seconds."""
backoff_factor: float = 2.0
"""Multiplier by which the interval increases after each retry."""
max_interval: float = 128.0
"""Maximum amount of time that may elapse between retries. In seconds."""
max_attempts: int = 3
"""Maximum number of attempts to make before giving up, including the first."""
jitter: bool = True
"""Whether to add random jitter to the interval between retries."""
retry_on: Union[
Type[Exception], Sequence[Type[Exception]], Callable[[Exception], bool]
] = default_retry_on
"""List of exception classes that should trigger a retry, or a callable that returns True for exceptions that should trigger a retry."""
class CachePolicy(NamedTuple):
"""Configuration for caching nodes."""
pass
@dataclasses.dataclass(**_DC_KWARGS)
class Interrupt:
value: Any
resumable: bool = False
ns: Optional[Sequence[str]] = None
when: Literal["during"] = "during"
class PregelTask(NamedTuple):
id: str
name: str
path: tuple[Union[str, int, tuple], ...]
error: Optional[Exception] = None
interrupts: tuple[Interrupt, ...] = ()
state: Union[None, RunnableConfig, "StateSnapshot"] = None
result: Optional[dict[str, Any]] = None
class PregelExecutableTask(NamedTuple):
name: str
input: Any
proc: Runnable
writes: deque[tuple[str, Any]]
config: RunnableConfig
triggers: list[str]
retry_policy: Optional[RetryPolicy]
cache_policy: Optional[CachePolicy]
id: str
path: tuple[Union[str, int, tuple], ...]
scheduled: bool = False
writers: Sequence[Runnable] = ()
class StateSnapshot(NamedTuple):
"""Snapshot of the state of the graph at the beginning of a step."""
values: Union[dict[str, Any], Any]
"""Current values of channels"""
next: tuple[str, ...]
"""The name of the node to execute in each task for this step."""
config: RunnableConfig
"""Config used to fetch this snapshot"""
metadata: Optional[CheckpointMetadata]
"""Metadata associated with this snapshot"""
created_at: Optional[str]
"""Timestamp of snapshot creation"""
parent_config: Optional[RunnableConfig]
"""Config used to fetch the parent snapshot, if any"""
tasks: tuple[PregelTask, ...]
"""Tasks to execute in this step. If already attempted, may contain an error."""
class Send:
"""A message or packet to send to a specific node in the graph.
The `Send` class is used within a `StateGraph`'s conditional edges to
dynamically invoke a node with a custom state at the next step.
Importantly, the sent state can differ from the core graph's state,
allowing for flexible and dynamic workflow management.
One such example is a "map-reduce" workflow where your graph invokes
the same node multiple times in parallel with different states,
before aggregating the results back into the main graph's state.
Attributes:
node (str): The name of the target node to send the message to.
arg (Any): The state or message to send to the target node.
Examples:
>>> from typing import Annotated
>>> import operator
>>> class OverallState(TypedDict):
... subjects: list[str]
... jokes: Annotated[list[str], operator.add]
...
>>> from langgraph.types import Send
>>> from langgraph.graph import END, START
>>> def continue_to_jokes(state: OverallState):
... return [Send("generate_joke", {"subject": s}) for s in state['subjects']]
...
>>> from langgraph.graph import StateGraph
>>> builder = StateGraph(OverallState)
>>> builder.add_node("generate_joke", lambda state: {"jokes": [f"Joke about {state['subject']}"]})
>>> builder.add_conditional_edges(START, continue_to_jokes)
>>> builder.add_edge("generate_joke", END)
>>> graph = builder.compile()
>>>
>>> # Invoking with two subjects results in a generated joke for each
>>> graph.invoke({"subjects": ["cats", "dogs"]})
{'subjects': ['cats', 'dogs'], 'jokes': ['Joke about cats', 'Joke about dogs']}
"""
__slots__ = ("node", "arg")
node: str
arg: Any
def __init__(self, /, node: str, arg: Any) -> None:
"""
Initialize a new instance of the Send class.
Args:
node (str): The name of the target node to send the message to.
arg (Any): The state or message to send to the target node.
"""
self.node = node
self.arg = arg
def __hash__(self) -> int:
return hash((self.node, self.arg))
def __repr__(self) -> str:
return f"Send(node={self.node!r}, arg={self.arg!r})"
def __eq__(self, value: object) -> bool:
return (
isinstance(value, Send)
and self.node == value.node
and self.arg == value.arg
)
N = TypeVar("N", bound=Hashable)
@dataclasses.dataclass(**_DC_KWARGS)
class Command(Generic[N]):
"""One or more commands to update the graph's state and send messages to nodes.
Args:
graph: graph to send the command to. Supported values are:
- None: the current graph (default)
- Command.PARENT: closest parent graph
update: update to apply to the graph's state.
resume: value to resume execution with. To be used together with [`interrupt()`][langgraph.types.interrupt].
goto: can be one of the following:
- name of the node to navigate to next (any node that belongs to the specified `graph`)
- sequence of node names to navigate to next
- `Send` object (to execute a node with the input provided)
- sequence of `Send` objects
"""
graph: Optional[str] = None
update: Optional[dict[str, Any]] = None
resume: Optional[Union[Any, dict[str, Any]]] = None
goto: Union[Send, Sequence[Union[Send, str]], str] = ()
def __repr__(self) -> str:
# get all non-None values
contents = ", ".join(
f"{key}={value!r}"
for key, value in dataclasses.asdict(self).items()
if value
)
return f"Command({contents})"
PARENT: ClassVar[Literal["__parent__"]] = "__parent__"
StreamChunk = tuple[tuple[str, ...], str, Any]
class StreamProtocol:
__slots__ = ("modes", "__call__")
modes: set[StreamMode]
__call__: Callable[[Self, StreamChunk], None]
def __init__(
self,
__call__: Callable[[StreamChunk], None],
modes: set[StreamMode],
) -> None:
self.__call__ = cast(Callable[[Self, StreamChunk], None], __call__)
self.modes = modes
class LoopProtocol:
config: RunnableConfig
store: Optional["BaseStore"]
stream: Optional[StreamProtocol]
step: int
stop: int
def __init__(
self,
*,
step: int,
stop: int,
config: RunnableConfig,
store: Optional["BaseStore"] = None,
stream: Optional[StreamProtocol] = None,
) -> None:
self.stream = stream
self.config = config
self.store = store
self.step = step
self.stop = stop
class PregelScratchpad(TypedDict, total=False):
interrupt_counter: int
used_null_resume: bool
resume: list[Any]
def interrupt(value: Any) -> Any:
from langgraph.constants import (
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
NS_SEP,
NULL_TASK_ID,
RESUME,
)
from langgraph.errors import GraphInterrupt
from langgraph.utils.config import get_configurable
conf = get_configurable()
# track interrupt index
scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD]
if "interrupt_counter" not in scratchpad:
scratchpad["interrupt_counter"] = 0
else:
scratchpad["interrupt_counter"] += 1
idx = scratchpad["interrupt_counter"]
# find previous resume values
task_id = conf[CONFIG_KEY_TASK_ID]
writes: list[PendingWrite] = conf[CONFIG_KEY_WRITES]
scratchpad.setdefault(
"resume", next((w[2] for w in writes if w[0] == task_id and w[1] == RESUME), [])
)
if scratchpad["resume"]:
if idx < len(scratchpad["resume"]):
return scratchpad["resume"][idx]
# find current resume value
if not scratchpad.get("used_null_resume"):
scratchpad["used_null_resume"] = True
for tid, c, v in sorted(writes, key=lambda x: x[0], reverse=True):
if tid == NULL_TASK_ID and c == RESUME:
assert len(scratchpad["resume"]) == idx, (scratchpad["resume"], idx)
scratchpad["resume"].append(v)
conf[CONFIG_KEY_SEND]([(RESUME, scratchpad["resume"])])
return v
# no resume value found
raise GraphInterrupt(
(
Interrupt(
value=value,
resumable=True,
ns=cast(str, conf[CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP),
),
)
)
```
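Usage sketch (not part of the repo): a node requesting the `StreamWriter` kwarg (injected per `KWARGS_CONFIG_KEYS` in `utils/runnable.py`), consumed via `stream_mode="custom"`; node and field names are illustrative:
```py
from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import StreamWriter


class State(TypedDict):
    value: int


def work(state: State, writer: StreamWriter) -> dict:
    # Anything passed to writer() is emitted on the "custom" stream;
    # under other stream modes the injected writer is a no-op.
    writer({"progress": "halfway"})
    return {"value": state["value"] + 1}


builder = StateGraph(State)
builder.add_node("work", work)
builder.add_edge(START, "work")
builder.add_edge("work", END)
graph = builder.compile()

for chunk in graph.stream({"value": 0}, stream_mode="custom"):
    print(chunk)  # {'progress': 'halfway'}
```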
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/__init__.py`:
```py
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langgraph.prebuilt.tool_node import (
InjectedState,
InjectedStore,
ToolNode,
tools_condition,
)
from langgraph.prebuilt.tool_validator import ValidationNode
__all__ = [
"create_react_agent",
"ToolExecutor",
"ToolInvocation",
"ToolNode",
"tools_condition",
"ValidationNode",
"InjectedState",
"InjectedStore",
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_validator.py`:
```py
"""This module provides a ValidationNode class that can be used to validate tool calls
in a langchain graph. It applies a pydantic schema to tool_calls in the models' outputs,
and returns a ToolMessage with the validated content. If the schema is not valid, it
returns a ToolMessage with the error message. The ValidationNode can be used in a
StateGraph with a "messages" key or in a MessageGraph. If multiple tool calls are
requested, they will be run in parallel.
"""
from typing import (
Any,
Callable,
Dict,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.messages import (
AIMessage,
AnyMessage,
ToolCall,
ToolMessage,
)
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool, create_schema_from_function
from pydantic import BaseModel, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from langgraph.utils.runnable import RunnableCallable
def _default_format_error(
error: BaseException,
call: ToolCall,
schema: Union[Type[BaseModel], Type[BaseModelV1]],
) -> str:
"""Default error formatting function."""
return f"{repr(error)}\n\nRespond after fixing all validation errors."
class ValidationNode(RunnableCallable):
"""A node that validates all tools requests from the last AIMessage.
It can be used either in StateGraph with a "messages" key or in MessageGraph.
!!! note
This node does not actually **run** the tools, it only validates the tool calls,
which is useful for extraction and other use cases where you need to generate
structured output that conforms to a complex schema without losing the original
messages and tool IDs (for use in multi-turn conversations).
Args:
schemas: A list of schemas to validate the tool calls with. These can be
any of the following:
- A pydantic BaseModel class
- A BaseTool instance (the args_schema will be used)
- A function (a schema will be created from the function signature)
format_error: A function that takes an exception, a ToolCall, and a schema
and returns a formatted error string. By default, it returns the
exception repr and a message to respond after fixing validation errors.
name: The name of the node.
tags: A list of tags to add to the node.
Returns:
(Union[Dict[str, List[ToolMessage]], Sequence[ToolMessage]]): A list of ToolMessages with the validated content or error messages.
Examples:
Example usage for re-prompting the model to generate a valid response:
>>> from typing import Literal, Annotated, TypedDict
...
>>> from langchain_anthropic import ChatAnthropic
>>> from pydantic import BaseModel, validator
...
>>> from langgraph.graph import END, START, StateGraph
>>> from langgraph.prebuilt import ValidationNode
>>> from langgraph.graph.message import add_messages
...
...
>>> class SelectNumber(BaseModel):
... a: int
...
... @validator("a")
... def a_must_be_meaningful(cls, v):
... if v != 37:
... raise ValueError("Only 37 is allowed")
... return v
...
...
>>> class State(TypedDict):
... messages: Annotated[list, add_messages]
...
>>> builder = StateGraph(State)
>>> llm = ChatAnthropic(model="claude-3-haiku-20240307").bind_tools([SelectNumber])
>>> builder.add_node("model", llm)
>>> builder.add_node("validation", ValidationNode([SelectNumber]))
>>> builder.add_edge(START, "model")
...
...
>>> def should_validate(state: list) -> Literal["validation", "__end__"]:
... if state[-1].tool_calls:
... return "validation"
... return END
...
...
>>> builder.add_conditional_edges("model", should_validate)
...
...
>>> def should_reprompt(state: list) -> Literal["model", "__end__"]:
... for msg in state[::-1]:
... # None of the tool calls were errors
... if msg.type == "ai":
... return END
... if msg.additional_kwargs.get("is_error"):
... return "model"
... return END
...
...
>>> builder.add_conditional_edges("validation", should_reprompt)
...
...
>>> graph = builder.compile()
>>> res = graph.invoke(("user", "Select a number, any number"))
>>> # Show the retry logic
>>> for msg in res:
... msg.pretty_print()
================================ Human Message =================================
Select a number, any number
================================== Ai Message ==================================
[{'id': 'toolu_01JSjT9Pq8hGmTgmMPc6KnvM', 'input': {'a': 42}, 'name': 'SelectNumber', 'type': 'tool_use'}]
Tool Calls:
SelectNumber (toolu_01JSjT9Pq8hGmTgmMPc6KnvM)
Call ID: toolu_01JSjT9Pq8hGmTgmMPc6KnvM
Args:
a: 42
================================= Tool Message =================================
Name: SelectNumber
ValidationError(model='SelectNumber', errors=[{'loc': ('a',), 'msg': 'Only 37 is allowed', 'type': 'value_error'}])
Respond after fixing all validation errors.
================================== Ai Message ==================================
[{'id': 'toolu_01PkxSVxNxc5wqwCPW1FiSmV', 'input': {'a': 37}, 'name': 'SelectNumber', 'type': 'tool_use'}]
Tool Calls:
SelectNumber (toolu_01PkxSVxNxc5wqwCPW1FiSmV)
Call ID: toolu_01PkxSVxNxc5wqwCPW1FiSmV
Args:
a: 37
================================= Tool Message =================================
Name: SelectNumber
{"a": 37}
"""
def __init__(
self,
schemas: Sequence[Union[BaseTool, Type[BaseModel], Callable]],
*,
format_error: Optional[
Callable[[BaseException, ToolCall, Type[BaseModel]], str]
] = None,
name: str = "validation",
tags: Optional[list[str]] = None,
) -> None:
super().__init__(self._func, None, name=name, tags=tags, trace=False)
self._format_error = format_error or _default_format_error
self.schemas_by_name: Dict[str, Type[BaseModel]] = {}
for schema in schemas:
if isinstance(schema, BaseTool):
if schema.args_schema is None:
raise ValueError(
f"Tool {schema.name} does not have an args_schema defined."
)
self.schemas_by_name[schema.name] = schema.args_schema
elif isinstance(schema, type) and issubclass(
schema, (BaseModel, BaseModelV1)
):
self.schemas_by_name[schema.__name__] = cast(Type[BaseModel], schema)
elif callable(schema):
base_model = create_schema_from_function("Validation", schema)
self.schemas_by_name[schema.__name__] = base_model
else:
raise ValueError(
f"Unsupported input to ValidationNode. Expected BaseModel, tool or function. Got: {type(schema)}."
)
def _get_message(
self, input: Union[list[AnyMessage], dict[str, Any]]
) -> Tuple[str, AIMessage]:
"""Extract the last AIMessage from the input."""
if isinstance(input, list):
output_type = "list"
messages: list = input
elif messages := input.get("messages", []):
output_type = "dict"
else:
raise ValueError("No message found in input")
message: AnyMessage = messages[-1]
if not isinstance(message, AIMessage):
raise ValueError("Last message is not an AIMessage")
return output_type, message
def _func(
self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
) -> Any:
"""Validate and run tool calls synchronously."""
output_type, message = self._get_message(input)
def run_one(call: ToolCall) -> ToolMessage:
schema = self.schemas_by_name[call["name"]]
try:
if issubclass(schema, BaseModel):
output = schema.model_validate(call["args"])
content = output.model_dump_json()
elif issubclass(schema, BaseModelV1):
output = schema.validate(call["args"])
content = output.json()
else:
raise ValueError(
f"Unsupported schema type: {type(schema)}. Expected BaseModel or BaseModelV1."
)
return ToolMessage(
content=content,
name=call["name"],
tool_call_id=cast(str, call["id"]),
)
except (ValidationError, ValidationErrorV1) as e:
return ToolMessage(
content=self._format_error(e, call, schema),
name=call["name"],
tool_call_id=cast(str, call["id"]),
additional_kwargs={"is_error": True},
)
with get_executor_for_config(config) as executor:
outputs = [*executor.map(run_one, message.tool_calls)]
if output_type == "list":
return outputs
else:
return {"messages": outputs}
```
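For reference, a minimal standalone sketch (not part of the repository; it assumes pydantic v2 and a hand-built `AIMessage`) of invoking `ValidationNode` directly on a message list:
```py
# Hypothetical direct usage of ValidationNode, outside a graph.
from langchain_core.messages import AIMessage
from pydantic import BaseModel

from langgraph.prebuilt import ValidationNode


class SelectNumber(BaseModel):
    a: int


node = ValidationNode([SelectNumber])

# An AIMessage carrying one tool call, as a chat model would produce it.
msg = AIMessage(
    content="",
    tool_calls=[
        {"name": "SelectNumber", "args": {"a": 37}, "id": "call-1", "type": "tool_call"}
    ],
)

# List input -> list of ToolMessages; {"messages": [...]} input would return a dict.
result = node.invoke([msg])
print(result[0].content)  # '{"a":37}'
```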
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py`:
```py
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast
from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableConfig,
)
from langchain_core.tools import BaseTool
from typing_extensions import Annotated, TypedDict
from langgraph._api.deprecation import deprecated_parameter
from langgraph.errors import ErrorCode, create_error_message
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep, RemainingSteps
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer
from langgraph.utils.runnable import RunnableCallable
# We create the AgentState that we will pass around
# This simply involves a list of messages
# We want steps to return messages to append to the list
# So we annotate the messages attribute with the add_messages reducer
class AgentState(TypedDict):
"""The state of the agent."""
messages: Annotated[Sequence[BaseMessage], add_messages]
is_last_step: IsLastStep
remaining_steps: RemainingSteps
StateSchema = TypeVar("StateSchema", bound=AgentState)
StateSchemaType = Type[StateSchema]
STATE_MODIFIER_RUNNABLE_NAME = "StateModifier"
MessagesModifier = Union[
SystemMessage,
str,
Callable[[Sequence[BaseMessage]], Sequence[BaseMessage]],
Runnable[Sequence[BaseMessage], Sequence[BaseMessage]],
]
StateModifier = Union[
SystemMessage,
str,
Callable[[StateSchema], Sequence[BaseMessage]],
Runnable[StateSchema, Sequence[BaseMessage]],
]
def _get_state_modifier_runnable(
state_modifier: Optional[StateModifier], store: Optional[BaseStore] = None
) -> Runnable:
state_modifier_runnable: Runnable
if state_modifier is None:
state_modifier_runnable = RunnableCallable(
lambda state: state["messages"], name=STATE_MODIFIER_RUNNABLE_NAME
)
elif isinstance(state_modifier, str):
_system_message: BaseMessage = SystemMessage(content=state_modifier)
state_modifier_runnable = RunnableCallable(
lambda state: [_system_message] + state["messages"],
name=STATE_MODIFIER_RUNNABLE_NAME,
)
elif isinstance(state_modifier, SystemMessage):
state_modifier_runnable = RunnableCallable(
lambda state: [state_modifier] + state["messages"],
name=STATE_MODIFIER_RUNNABLE_NAME,
)
elif callable(state_modifier):
state_modifier_runnable = RunnableCallable(
state_modifier,
name=STATE_MODIFIER_RUNNABLE_NAME,
)
elif isinstance(state_modifier, Runnable):
state_modifier_runnable = state_modifier
else:
raise ValueError(
f"Got unexpected type for `state_modifier`: {type(state_modifier)}"
)
return state_modifier_runnable
def _convert_messages_modifier_to_state_modifier(
messages_modifier: MessagesModifier,
) -> StateModifier:
state_modifier: StateModifier
if isinstance(messages_modifier, (str, SystemMessage)):
return messages_modifier
elif callable(messages_modifier):
def state_modifier(state: AgentState) -> Sequence[BaseMessage]:
return messages_modifier(state["messages"])
return state_modifier
elif isinstance(messages_modifier, Runnable):
state_modifier = (lambda state: state["messages"]) | messages_modifier
return state_modifier
raise ValueError(
f"Got unexpected type for `messages_modifier`: {type(messages_modifier)}"
)
def _get_model_preprocessing_runnable(
state_modifier: Optional[StateModifier],
messages_modifier: Optional[MessagesModifier],
store: Optional[BaseStore],
) -> Runnable:
# Add the state or message modifier, if exists
if state_modifier is not None and messages_modifier is not None:
raise ValueError(
"Expected value for either state_modifier or messages_modifier, got values for both"
)
if state_modifier is None and messages_modifier is not None:
state_modifier = _convert_messages_modifier_to_state_modifier(messages_modifier)
return _get_state_modifier_runnable(state_modifier, store)
def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
if not isinstance(model, RunnableBinding):
return True
if "tools" not in model.kwargs:
return True
bound_tools = model.kwargs["tools"]
if len(tools) != len(bound_tools):
raise ValueError(
"Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
)
tool_names = set(tool.name for tool in tools)
bound_tool_names = set()
for bound_tool in bound_tools:
# OpenAI-style tool
if bound_tool.get("type") == "function":
bound_tool_name = bound_tool["function"]["name"]
# Anthropic-style tool
elif bound_tool.get("name"):
bound_tool_name = bound_tool["name"]
else:
# unknown tool type so we'll ignore it
continue
bound_tool_names.add(bound_tool_name)
if missing_tools := tool_names - bound_tool_names:
raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")
return False
def _validate_chat_history(
messages: Sequence[BaseMessage],
) -> None:
"""Validate that all tool calls in AIMessages have a corresponding ToolMessage."""
all_tool_calls = [
tool_call
for message in messages
if isinstance(message, AIMessage)
for tool_call in message.tool_calls
]
tool_call_ids_with_results = {
message.tool_call_id for message in messages if isinstance(message, ToolMessage)
}
tool_calls_without_results = [
tool_call
for tool_call in all_tool_calls
if tool_call["id"] not in tool_call_ids_with_results
]
if not tool_calls_without_results:
return
error_message = create_error_message(
message="Found AIMessages with tool_calls that do not have a corresponding ToolMessage. "
f"Here are the first few of those tool calls: {tool_calls_without_results[:3]}.\n\n"
"Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage "
"(result of a tool invocation to return to the LLM) - this is required by most LLM providers.",
error_code=ErrorCode.INVALID_CHAT_HISTORY,
)
raise ValueError(error_message)
@deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0")
def create_react_agent(
model: LanguageModelLike,
tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode],
*,
state_schema: Optional[StateSchemaType] = None,
messages_modifier: Optional[MessagesModifier] = None,
state_modifier: Optional[StateModifier] = None,
checkpointer: Optional[Checkpointer] = None,
store: Optional[BaseStore] = None,
interrupt_before: Optional[list[str]] = None,
interrupt_after: Optional[list[str]] = None,
debug: bool = False,
) -> CompiledGraph:
"""Creates a graph that works with a chat model that utilizes tool calling.
Args:
model: The `LangChain` chat model that supports tool calling.
tools: A list of tools, a ToolExecutor, or a ToolNode instance.
If an empty list is provided, the agent will consist of a single LLM node without tool calling.
state_schema: An optional state schema that defines graph state.
Must have `messages` and `is_last_step` keys.
Defaults to `AgentState` that defines those two keys.
messages_modifier: An optional
messages modifier. This applies to messages BEFORE they are passed into the LLM.
Can take a few different forms:
- SystemMessage: this is added to the beginning of the list of messages.
- str: This is converted to a SystemMessage and added to the beginning of the list of messages.
- Callable: This function should take in a list of messages and the output is then passed to the language model.
- Runnable: This runnable should take in a list of messages and the output is then passed to the language model.
!!! Warning
`messages_modifier` parameter is deprecated as of version 0.1.9 and will be removed in 0.3.0
state_modifier: An optional
state modifier. This takes full graph state BEFORE the LLM is called and prepares the input to LLM.
Can take a few different forms:
- SystemMessage: this is added to the beginning of the list of messages in state["messages"].
- str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
- Callable: This function should take in full graph state and the output is then passed to the language model.
- Runnable: This runnable should take in full graph state and the output is then passed to the language model.
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
store: An optional store object. This is used for persisting data
across multiple threads (e.g., multiple conversations / users).
interrupt_before: An optional list of node names to interrupt before.
Should be one of the following: "agent", "tools".
This is useful if you want to add a user confirmation or other interrupt before taking an action.
interrupt_after: An optional list of node names to interrupt after.
Should be one of the following: "agent", "tools".
This is useful if you want to return directly or run additional processing on an output.
debug: A flag indicating whether to enable debug mode.
Returns:
A compiled LangChain runnable that can be used for chat interactions.
The resulting graph looks like this:
``` mermaid
stateDiagram-v2
[*] --> Start
Start --> Agent
Agent --> Tools : continue
Tools --> Agent
Agent --> End : end
End --> [*]
classDef startClass fill:#ffdfba;
classDef endClass fill:#baffc9;
classDef otherClass fill:#fad7de;
class Start startClass
class End endClass
class Agent,Tools otherClass
```
The "agent" node calls the language model with the messages list (after applying the messages modifier).
If the resulting AIMessage contains `tool_calls`, the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
The "tools" node executes the tools (1 tool per `tool_call`) and adds the responses to the messages list
as `ToolMessage` objects. The agent node then calls the language model again.
The process repeats until no more `tool_calls` are present in the response.
The agent then returns the full list of messages as a dictionary containing the key "messages".
``` mermaid
sequenceDiagram
participant U as User
participant A as Agent (LLM)
participant T as Tools
U->>A: Initial input
Note over A: Messages modifier + LLM
loop while tool_calls present
A->>T: Execute tools
T-->>A: ToolMessage for each tool call
end
A->>U: Return final state
```
Examples:
Use with a simple tool:
```pycon
>>> from datetime import datetime
>>> from langchain_openai import ChatOpenAI
>>> from langgraph.prebuilt import create_react_agent
... def check_weather(location: str, at_time: datetime | None = None) -> str:
... '''Return the weather forecast for the specified location.'''
... return f"It's always sunny in {location}"
>>>
>>> tools = [check_weather]
>>> model = ChatOpenAI(model="gpt-4o")
>>> graph = create_react_agent(model, tools=tools)
>>> inputs = {"messages": [("user", "what is the weather in sf")]}
>>> for s in graph.stream(inputs, stream_mode="values"):
... message = s["messages"][-1]
... if isinstance(message, tuple):
... print(message)
... else:
... message.pretty_print()
('user', 'what is the weather in sf')
================================== Ai Message ==================================
Tool Calls:
check_weather (call_LUzFvKJRuaWQPeXvBOzwhQOu)
Call ID: call_LUzFvKJRuaWQPeXvBOzwhQOu
Args:
location: San Francisco
================================= Tool Message =================================
Name: check_weather
It's always sunny in San Francisco
================================== Ai Message ==================================
The weather in San Francisco is sunny.
```
Add a system prompt for the LLM:
```pycon
>>> system_prompt = "You are a helpful bot named Fred."
>>> graph = create_react_agent(model, tools, state_modifier=system_prompt)
>>> inputs = {"messages": [("user", "What's your name? And what's the weather in SF?")]}
>>> for s in graph.stream(inputs, stream_mode="values"):
... message = s["messages"][-1]
... if isinstance(message, tuple):
... print(message)
... else:
... message.pretty_print()
('user', "What's your name? And what's the weather in SF?")
================================== Ai Message ==================================
Hi, my name is Fred. Let me check the weather in San Francisco for you.
Tool Calls:
check_weather (call_lqhj4O0hXYkW9eknB4S41EXk)
Call ID: call_lqhj4O0hXYkW9eknB4S41EXk
Args:
location: San Francisco
================================= Tool Message =================================
Name: check_weather
It's always sunny in San Francisco
================================== Ai Message ==================================
The weather in San Francisco is currently sunny. If you need any more details or have other questions, feel free to ask!
```
Add a more complex prompt for the LLM:
```pycon
>>> from langchain_core.prompts import ChatPromptTemplate
>>> prompt = ChatPromptTemplate.from_messages([
... ("system", "You are a helpful bot named Fred."),
... ("placeholder", "{messages}"),
... ("user", "Remember, always be polite!"),
... ])
>>> def format_for_model(state: AgentState):
... # You can do more complex modifications here
... return prompt.invoke({"messages": state["messages"]})
>>>
>>> graph = create_react_agent(model, tools, state_modifier=format_for_model)
>>> inputs = {"messages": [("user", "What's your name? And what's the weather in SF?")]}
>>> for s in graph.stream(inputs, stream_mode="values"):
... message = s["messages"][-1]
... if isinstance(message, tuple):
... print(message)
... else:
... message.pretty_print()
```
Add complex prompt with custom graph state:
```pycon
>>> from typing import TypedDict
>>> prompt = ChatPromptTemplate.from_messages(
... [
... ("system", "Today is {today}"),
... ("placeholder", "{messages}"),
... ]
... )
>>>
>>> class CustomState(TypedDict):
... today: str
... messages: Annotated[list[BaseMessage], add_messages]
... is_last_step: str
>>>
>>> graph = create_react_agent(model, tools, state_schema=CustomState, state_modifier=prompt)
>>> inputs = {"messages": [("user", "What's today's date? And what's the weather in SF?")], "today": "July 16, 2004"}
>>> for s in graph.stream(inputs, stream_mode="values"):
... message = s["messages"][-1]
... if isinstance(message, tuple):
... print(message)
... else:
... message.pretty_print()
```
Add thread-level "chat memory" to the graph:
```pycon
>>> from langgraph.checkpoint.memory import MemorySaver
>>> graph = create_react_agent(model, tools, checkpointer=MemorySaver())
>>> config = {"configurable": {"thread_id": "thread-1"}}
>>> def print_stream(graph, inputs, config):
... for s in graph.stream(inputs, config, stream_mode="values"):
... message = s["messages"][-1]
... if isinstance(message, tuple):
... print(message)
... else:
... message.pretty_print()
>>> inputs = {"messages": [("user", "What's the weather in SF?")]}
>>> print_stream(graph, inputs, config)
>>> inputs2 = {"messages": [("user", "Cool, so then should i go biking today?")]}
>>> print_stream(graph, inputs2, config)
('user', "What's the weather in SF?")
================================== Ai Message ==================================
Tool Calls:
check_weather (call_ChndaktJxpr6EMPEB5JfOFYc)
Call ID: call_ChndaktJxpr6EMPEB5JfOFYc
Args:
location: San Francisco
================================= Tool Message =================================
Name: check_weather
It's always sunny in San Francisco
================================== Ai Message ==================================
The weather in San Francisco is sunny. Enjoy your day!
================================ Human Message =================================
Cool, so then should i go biking today?
================================== Ai Message ==================================
Since the weather in San Francisco is sunny, it sounds like a great day for biking! Enjoy your ride!
```
Add an interrupt to let the user confirm before taking an action:
```pycon
>>> graph = create_react_agent(
... model, tools, interrupt_before=["tools"], checkpointer=MemorySaver()
... )
>>> config = {"configurable": {"thread_id": "thread-1"}}
>>> inputs = {"messages": [("user", "What's the weather in SF?")]}
>>> print_stream(graph, inputs, config)
>>> snapshot = graph.get_state(config)
>>> print("Next step: ", snapshot.next)
>>> print_stream(graph, None, config)
```
Add cross-thread memory to the graph:
```pycon
>>> from langgraph.prebuilt import InjectedStore
>>> from langgraph.store.base import BaseStore
>>> def save_memory(memory: str, *, config: RunnableConfig, store: Annotated[BaseStore, InjectedStore()]) -> str:
... '''Save the given memory for the current user.'''
... # This is a **tool** the model can use to save memories to storage
... user_id = config.get("configurable", {}).get("user_id")
... namespace = ("memories", user_id)
... store.put(namespace, f"memory_{len(store.search(namespace))}", {"data": memory})
... return f"Saved memory: {memory}"
>>> def prepare_model_inputs(state: AgentState, config: RunnableConfig, store: BaseStore):
... # Retrieve user memories and add them to the system message
... # This function is called **every time** the model is prompted. It converts the state to a prompt
... user_id = config.get("configurable", {}).get("user_id")
... namespace = ("memories", user_id)
... memories = [m.value["data"] for m in store.search(namespace)]
... system_msg = f"User memories: {', '.join(memories)}"
... return [{"role": "system", "content": system_msg)] + state["messages"]
>>> from langgraph.checkpoint.memory import MemorySaver
>>> from langgraph.store.memory import InMemoryStore
>>> store = InMemoryStore()
>>> graph = create_react_agent(model, [save_memory], state_modifier=prepare_model_inputs, store=store, checkpointer=MemorySaver())
>>> config = {"configurable": {"thread_id": "thread-1", "user_id": "1"}}
>>> inputs = {"messages": [("user", "Hey I'm Will, how's it going?")]}
>>> print_stream(graph, inputs, config)
('user', "Hey I'm Will, how's it going?")
================================== Ai Message ==================================
Hello Will! It's nice to meet you. I'm doing well, thank you for asking. How are you doing today?
>>> inputs2 = {"messages": [("user", "I like to bike")]}
>>> print_stream(graph, inputs2, config)
================================ Human Message =================================
I like to bike
================================== Ai Message ==================================
That's great to hear, Will! Biking is an excellent hobby and form of exercise. It's a fun way to stay active and explore your surroundings. Do you have any favorite biking routes or trails you enjoy? Or perhaps you're into a specific type of biking, like mountain biking or road cycling?
>>> config = {"configurable": {"thread_id": "thread-2", "user_id": "1"}}
>>> inputs3 = {"messages": [("user", "Hi there! Remember me?")]}
>>> print_stream(graph, inputs3, config)
================================ Human Message =================================
Hi there! Remember me?
================================== Ai Message ==================================
User memories:
Hello! Of course, I remember you, Will! You mentioned earlier that you like to bike. It's great to hear from you again. How have you been? Have you been on any interesting bike rides lately?
```
Add a timeout for a given step:
```pycon
>>> import time
... def check_weather(location: str, at_time: datetime | None = None) -> str:
... '''Return the weather forecast for the specified location.'''
... time.sleep(2)
... return f"It's always sunny in {location}"
>>>
>>> tools = [check_weather]
>>> graph = create_react_agent(model, tools)
>>> graph.step_timeout = 1 # Seconds
>>> for s in graph.stream({"messages": [("user", "what is the weather in sf")]}):
... print(s)
TimeoutError: Timed out at step 2
```
"""
if state_schema is not None:
if missing_keys := {"messages", "is_last_step"} - set(
state_schema.__annotations__
):
raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")
if isinstance(tools, ToolExecutor):
tool_classes: Sequence[BaseTool] = tools.tools
tool_node = ToolNode(tool_classes)
elif isinstance(tools, ToolNode):
tool_classes = list(tools.tools_by_name.values())
tool_node = tools
else:
tool_node = ToolNode(tools)
# get the tool functions wrapped in a tool class from the ToolNode
tool_classes = list(tool_node.tools_by_name.values())
tool_calling_enabled = len(tool_classes) > 0
if _should_bind_tools(model, tool_classes) and tool_calling_enabled:
model = cast(BaseChatModel, model).bind_tools(tool_classes)
# we're passing store here for validation
preprocessor = _get_model_preprocessing_runnable(
state_modifier, messages_modifier, store
)
model_runnable = preprocessor | model
# Define the function that calls the model
def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
_validate_chat_history(state["messages"])
response = model_runnable.invoke(state, config)
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
if isinstance(response, AIMessage)
else False
)
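# Replace the response with a fallback message when the agent is about to run out
# of steps: either the legacy `is_last_step` flag is set while the model still wants
# to call tools, or `remaining_steps` is too low to finish another tool round-trip
# (< 2 with pending tool calls, or < 1 when every requested tool returns directly).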
if (
(
"remaining_steps" not in state
and state["is_last_step"]
and has_tool_calls
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 1
and all_tools_return_direct
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 2
and has_tool_calls
)
):
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# We return a list, because this will get added to the existing list
return {"messages": [response]}
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
_validate_chat_history(state["messages"])
response = await model_runnable.ainvoke(state, config)
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
if isinstance(response, AIMessage)
else False
)
if (
(
"remaining_steps" not in state
and state["is_last_step"]
and has_tool_calls
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 1
and all_tools_return_direct
)
or (
"remaining_steps" in state
and state["remaining_steps"] < 2
and has_tool_calls
)
):
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# We return a list, because this will get added to the existing list
return {"messages": [response]}
if not tool_calling_enabled:
# Define a new graph
workflow = StateGraph(state_schema or AgentState)
workflow.add_node("agent", RunnableCallable(call_model, acall_model))
workflow.set_entry_point("agent")
return workflow.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)
# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "__end__"
# Otherwise if there is, we continue
else:
return "tools"
# Define a new graph
workflow = StateGraph(state_schema or AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", RunnableCallable(call_model, acall_model))
workflow.add_node("tools", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
)
# If any of the tools are configured to return directly after running (return_direct=True),
# our graph needs to check whether one of those tools was called
should_return_direct = {t.name for t in tool_classes if t.return_direct}
def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
for m in reversed(state["messages"]):
if not isinstance(m, ToolMessage):
break
if m.name in should_return_direct:
return "__end__"
return "agent"
if should_return_direct:
workflow.add_conditional_edges("tools", route_tool_responses)
else:
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
return workflow.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)
# Keep for backwards compatibility
create_tool_calling_executor = create_react_agent
__all__ = [
"create_react_agent",
"create_tool_calling_executor",
"AgentState",
]
```
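As an illustration of the graph `create_react_agent` compiles, the sketch below hand-wires the same agent/tools loop with `StateGraph`, `ToolNode`, and `tools_condition`; the `get_weather` tool and the `gpt-4o` model are placeholders, not part of the source:
```py
# Hand-rolled equivalent of the prebuilt ReAct loop (illustrative only).
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.prebuilt.chat_agent_executor import AgentState


@tool
def get_weather(city: str) -> str:
    """Return a canned weather report for the city."""
    return f"It's always sunny in {city}"


model = ChatOpenAI(model="gpt-4o").bind_tools([get_weather])


def call_model(state: AgentState):
    # Append the model's response; add_messages merges it into state["messages"].
    return {"messages": [model.invoke(state["messages"])]}


builder = StateGraph(AgentState)
builder.add_node("agent", call_model)
builder.add_node("tools", ToolNode([get_weather]))
builder.add_edge(START, "agent")
# Route to "tools" while the last AIMessage has tool_calls, otherwise end.
builder.add_conditional_edges("agent", tools_condition)
builder.add_edge("tools", "agent")
graph = builder.compile()
```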
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_node.py`:
```py
from __future__ import annotations
import asyncio
import inspect
import json
from copy import copy
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
get_type_hints,
)
from langchain_core.messages import (
AIMessage,
AnyMessage,
ToolCall,
ToolMessage,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
get_config_list,
get_executor_for_config,
)
from langchain_core.runnables.utils import Input
from langchain_core.tools import BaseTool, InjectedToolArg
from langchain_core.tools import tool as create_tool
from langchain_core.tools.base import get_all_basemodel_annotations
from typing_extensions import Annotated, get_args, get_origin
from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.utils.runnable import RunnableCallable
if TYPE_CHECKING:
from pydantic import BaseModel
INVALID_TOOL_NAME_ERROR_TEMPLATE = (
"Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
)
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
def msg_content_output(output: Any) -> str | List[dict]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
return output
elif all(
[
isinstance(x, dict) and x.get("type") in recognized_content_block_types
for x in output
]
):
return output
# Technically a list of strings is also valid message content but it's not currently
# well tested that all chat models support this. And for backwards compatibility
# we want to make sure we don't break any existing ToolNode usage.
else:
try:
return json.dumps(output, ensure_ascii=False)
except Exception:
return str(output)
def _handle_tool_error(
e: Exception,
*,
flag: Union[
bool,
str,
Callable[..., str],
tuple[type[Exception], ...],
],
) -> str:
if isinstance(flag, (bool, tuple)):
content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {flag}"
)
return content
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
if params:
# If it's a method, the first argument is typically 'self' or 'cls'
if params[0].name in ["self", "cls"] and len(params) == 2:
first_param = params[1]
else:
first_param = params[0]
type_hints = get_type_hints(handler)
if first_param.name in type_hints:
origin = get_origin(first_param.annotation)
if origin is Union:
args = get_args(first_param.annotation)
if all(issubclass(arg, Exception) for arg in args):
return tuple(args)
else:
raise ValueError(
"All types in the error handler error annotation must be Exception types. "
"For example, `def custom_handler(e: Union[ValueError, TypeError])`. "
f"Got '{first_param.annotation}' instead."
)
exception_type = type_hints[first_param.name]
if Exception in exception_type.__mro__:
return (exception_type,)
else:
raise ValueError(
f"Arbitrary types are not supported in the error handler signature. "
"Please annotate the error with either a specific Exception type or a union of Exception types. "
"For example, `def custom_handler(e: ValueError)` or `def custom_handler(e: Union[ValueError, TypeError])`. "
f"Got '{exception_type}' instead."
)
# If no type information is available, return (Exception,) for backwards compatibility.
return (Exception,)
class ToolNode(RunnableCallable):
"""A node that runs the tools called in the last AIMessage.
It can be used in a StateGraph with a "messages" state key (or a custom key passed via ToolNode's 'messages_key').
If multiple tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
Args:
tools: A sequence of tools that can be invoked by the ToolNode.
name: The name of the ToolNode in the graph. Defaults to "tools".
tags: Optional tags to associate with the node. Defaults to None.
handle_tool_errors: How to handle tool errors raised by tools inside the node. Defaults to True.
Must be one of the following:
- True: all errors will be caught and
a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
- str: all errors will be caught and
a ToolMessage with the string value of 'handle_tool_errors' will be returned.
- tuple[type[Exception], ...]: exceptions in the tuple will be caught and
a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
- Callable[..., str]: exceptions from the signature of the callable will be caught and
a ToolMessage with the string value of the result of the 'handle_tool_errors' callable will be returned.
- False: none of the errors raised by the tools will be caught
messages_key: The state key in the input that contains the list of messages.
The same key will be used for the output from the ToolNode.
Defaults to "messages".
The `ToolNode` is roughly analogous to:
```python
tools_by_name = {tool.name: tool for tool in tools}
def tool_node(state: dict):
result = []
for tool_call in state["messages"][-1].tool_calls:
tool = tools_by_name[tool_call["name"]]
observation = tool.invoke(tool_call["args"])
result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
return {"messages": result}
```
Important:
- The state MUST contain a list of messages.
- The last message MUST be an `AIMessage`.
- The `AIMessage` MUST have `tool_calls` populated.
"""
name: str = "ToolNode"
def __init__(
self,
tools: Sequence[Union[BaseTool, Callable]],
*,
name: str = "tools",
tags: Optional[list[str]] = None,
handle_tool_errors: Union[
bool, str, Callable[..., str], tuple[type[Exception], ...]
] = True,
messages_key: str = "messages",
) -> None:
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
self.tools_by_name: Dict[str, BaseTool] = {}
self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {}
self.tool_to_store_arg: Dict[str, Optional[str]] = {}
self.handle_tool_errors = handle_tool_errors
self.messages_key = messages_key
for tool_ in tools:
if not isinstance(tool_, BaseTool):
tool_ = create_tool(tool_)
self.tools_by_name[tool_.name] = tool_
self.tool_to_state_args[tool_.name] = _get_state_args(tool_)
self.tool_to_store_arg[tool_.name] = _get_store_arg(tool_)
def _func(
self,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
config: RunnableConfig,
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
config_list = get_config_list(config, len(tool_calls))
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, config_list)]
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
if "store" not in kwargs:
kwargs["store"] = None
return super().invoke(input, config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
if "store" not in kwargs:
kwargs["store"] = None
return await super().ainvoke(input, config, **kwargs)
async def _afunc(
self,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
config: RunnableConfig,
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
outputs = await asyncio.gather(
*(self._arun_one(call, config) for call in tool_calls)
)
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}
def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
elif callable(self.handle_tool_errors):
handled_types = _infer_handled_types(self.handle_tool_errors)
else:
# default behavior is catching all exceptions
handled_types = (Exception,)
# Unhandled
if not self.handle_tool_errors or not isinstance(e, handled_types):
raise e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)
return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
elif callable(self.handle_tool_errors):
handled_types = _infer_handled_types(self.handle_tool_errors)
else:
# default behavior is catching all exceptions
handled_types = (Exception,)
# Unhandled
if not self.handle_tool_errors or not isinstance(e, handled_types):
raise e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)
return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
def _parse_input(
self,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
store: BaseStore,
) -> Tuple[List[ToolCall], Literal["list", "dict"]]:
if isinstance(input, list):
output_type = "list"
message: AnyMessage = input[-1]
elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])):
output_type = "dict"
message = messages[-1]
elif messages := getattr(input, self.messages_key, None):
# Assume dataclass-like state that can coerce from dict
output_type = "dict"
message = messages[-1]
else:
raise ValueError("No message found in input")
if not isinstance(message, AIMessage):
raise ValueError("Last message is not an AIMessage")
tool_calls = [
self._inject_tool_args(call, input, store) for call in message.tool_calls
]
return tool_calls, output_type
def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
if (requested_tool := call["name"]) not in self.tools_by_name:
content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
requested_tool=requested_tool,
available_tools=", ".join(self.tools_by_name.keys()),
)
return ToolMessage(
content, name=requested_tool, tool_call_id=call["id"], status="error"
)
else:
return None
def _inject_state(
self,
tool_call: ToolCall,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
) -> ToolCall:
state_args = self.tool_to_state_args[tool_call["name"]]
if state_args and isinstance(input, list):
required_fields = list(state_args.values())
if (
len(required_fields) == 1
and required_fields[0] == self.messages_key
or required_fields[0] is None
):
input = {self.messages_key: input}
else:
err_msg = (
f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
f"graph state dict as input."
)
if any(state_field for state_field in state_args.values()):
required_fields_str = ", ".join(f for f in required_fields if f)
err_msg += f" State should contain fields {required_fields_str}."
raise ValueError(err_msg)
if isinstance(input, dict):
tool_state_args = {
tool_arg: input[state_field] if state_field else input
for tool_arg, state_field in state_args.items()
}
else:
tool_state_args = {
tool_arg: getattr(input, state_field) if state_field else input
for tool_arg, state_field in state_args.items()
}
tool_call["args"] = {
**tool_call["args"],
**tool_state_args,
}
return tool_call
def _inject_store(self, tool_call: ToolCall, store: BaseStore) -> ToolCall:
store_arg = self.tool_to_store_arg[tool_call["name"]]
if not store_arg:
return tool_call
if store is None:
raise ValueError(
"Cannot inject store into tools with InjectedStore annotations - "
"please compile your graph with a store."
)
tool_call["args"] = {
**tool_call["args"],
store_arg: store,
}
return tool_call
def _inject_tool_args(
self,
tool_call: ToolCall,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
store: BaseStore,
) -> ToolCall:
if tool_call["name"] not in self.tools_by_name:
return tool_call
tool_call_copy: ToolCall = copy(tool_call)
tool_call_with_state = self._inject_state(tool_call_copy, input)
tool_call_with_store = self._inject_store(tool_call_with_state, store)
return tool_call_with_store
def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
messages_key: str = "messages",
) -> Literal["tools", "__end__"]:
"""Use in the conditional_edge to route to the ToolNode if the last message
has tool calls. Otherwise, route to the end.
Args:
state (Union[list[AnyMessage], dict[str, Any], BaseModel]): The state to check for
tool calls. Must have a list of messages (MessageGraph) or have the
"messages" key (StateGraph).
Returns:
The next node to route to.
Examples:
Create a custom ReAct-style agent with tools.
```pycon
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_core.tools import tool
...
>>> from langgraph.graph import StateGraph
>>> from langgraph.prebuilt import ToolNode, tools_condition
>>> from langgraph.graph.message import add_messages
...
>>> from typing import TypedDict, Annotated
...
>>> @tool
... def divide(a: float, b: float) -> float:
... \"\"\"Return a / b.\"\"\"
... return a / b
...
>>> llm = ChatAnthropic(model="claude-3-haiku-20240307")
>>> tools = [divide]
...
>>> class State(TypedDict):
... messages: Annotated[list, add_messages]
>>>
>>> graph_builder = StateGraph(State)
>>> graph_builder.add_node("tools", ToolNode(tools))
>>> graph_builder.add_node("chatbot", lambda state: {"messages":llm.bind_tools(tools).invoke(state['messages'])})
>>> graph_builder.add_edge("tools", "chatbot")
>>> graph_builder.add_conditional_edges(
... "chatbot", tools_condition
... )
>>> graph_builder.set_entry_point("chatbot")
>>> graph = graph_builder.compile()
>>> graph.invoke({"messages": {"role": "user", "content": "What's 329993 divided by 13662?"}})
```
"""
if isinstance(state, list):
ai_message = state[-1]
elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
ai_message = messages[-1]
elif messages := getattr(state, messages_key, []):
ai_message = messages[-1]
else:
raise ValueError(f"No messages found in input state to tool_edge: {state}")
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
return "tools"
return "__end__"
class InjectedState(InjectedToolArg):
"""Annotation for a Tool arg that is meant to be populated with the graph state.
Any Tool argument annotated with InjectedState will be hidden from a tool-calling
model, so that the model doesn't attempt to generate the argument. If using
ToolNode, the appropriate graph state field will be automatically injected into
the model-generated tool args.
Args:
field: The key from state to insert. If None, the entire state is expected to
be passed in.
Example:
```python
from typing import List
from typing_extensions import Annotated, TypedDict
from langchain_core.messages import BaseMessage, AIMessage
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState, ToolNode
class AgentState(TypedDict):
messages: List[BaseMessage]
foo: str
@tool
def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
'''Do something with state.'''
if len(state["messages"]) > 2:
return state["foo"] + str(x)
else:
return "not enough messages"
@tool
def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
'''Do something else with state.'''
return foo + str(x + 1)
node = ToolNode([state_tool, foo_tool])
tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"}
state = {
"messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])],
"foo": "bar",
}
node.invoke(state)
```
```pycon
[
ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'),
ToolMessage(content='bar2', name='foo_tool', tool_call_id='2')
]
```
""" # noqa: E501
def __init__(self, field: Optional[str] = None) -> None:
self.field = field
class InjectedStore(InjectedToolArg):
"""Annotation for a Tool arg that is meant to be populated with LangGraph store.
Any Tool argument annotated with InjectedStore will be hidden from a tool-calling
model, so that the model doesn't attempt to generate the argument. If using
ToolNode, the appropriate store field will be automatically injected into
the model-generated tool args. Note: if a graph is compiled with a store object,
the store will be automatically propagated to the tools with InjectedStore args
when using ToolNode.
!!! Warning
`InjectedStore` annotation requires `langchain-core >= 0.3.8`
Example:
```python
from typing import Any
from typing_extensions import Annotated
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langgraph.store.memory import InMemoryStore
from langgraph.prebuilt import InjectedStore, ToolNode
store = InMemoryStore()
store.put(("values",), "foo", {"bar": 2})
@tool
def store_tool(x: int, my_store: Annotated[Any, InjectedStore()]) -> str:
'''Do something with store.'''
stored_value = my_store.get(("values",), "foo").value["bar"]
return stored_value + x
node = ToolNode([store_tool])
tool_call = {"name": "store_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
state = {
"messages": [AIMessage("", tool_calls=[tool_call])],
}
node.invoke(state, store=store)
```
```pycon
{
"messages": [
ToolMessage(content='3', name='store_tool', tool_call_id='1'),
]
}
```
""" # noqa: E501
def _is_injection(
type_arg: Any, injection_type: Union[Type[InjectedState], Type[InjectedStore]]
) -> bool:
if isinstance(type_arg, injection_type) or (
isinstance(type_arg, type) and issubclass(type_arg, injection_type)
):
return True
origin_ = get_origin(type_arg)
if origin_ is Union or origin_ is Annotated:
return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
return False
def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]:
full_schema = tool.get_input_schema()
tool_args_to_state_fields: Dict = {}
for name, type_ in get_all_basemodel_annotations(full_schema).items():
injections = [
type_arg
for type_arg in get_args(type_)
if _is_injection(type_arg, InjectedState)
]
if len(injections) > 1:
raise ValueError(
"A tool argument should not be annotated with InjectedState more than "
f"once. Received arg {name} with annotations {injections}."
)
elif len(injections) == 1:
injection = injections[0]
if isinstance(injection, InjectedState) and injection.field:
tool_args_to_state_fields[name] = injection.field
else:
tool_args_to_state_fields[name] = None
else:
pass
return tool_args_to_state_fields
def _get_store_arg(tool: BaseTool) -> Optional[str]:
full_schema = tool.get_input_schema()
for name, type_ in get_all_basemodel_annotations(full_schema).items():
injections = [
type_arg
for type_arg in get_args(type_)
if _is_injection(type_arg, InjectedStore)
]
if len(injections) > 1:
raise ValueError(
"A tool argument should not be annotated with InjectedStore more than "
f"once. Received arg {name} with annotations {injections}."
)
elif len(injections) == 1:
return name
else:
pass
return None
```
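A short example of the `handle_tool_errors` callable path: the handler's type annotation is what `_infer_handled_types` inspects to decide which exceptions are caught. The `divide` tool is hypothetical:
```py
# Hypothetical tool plus a typed error handler (sketch, not from the source).
from langchain_core.messages import AIMessage
from langchain_core.tools import tool

from langgraph.prebuilt import ToolNode


@tool
def divide(a: float, b: float) -> float:
    """Return a / b."""
    return a / b


def handle_division_error(e: ZeroDivisionError) -> str:
    # Only ZeroDivisionError is caught; any other exception propagates.
    return f"Cannot divide by zero: {e!r}"


node = ToolNode([divide], handle_tool_errors=handle_division_error)

call = {"name": "divide", "args": {"a": 1, "b": 0}, "id": "1", "type": "tool_call"}
state = {"messages": [AIMessage("", tool_calls=[call])]}
result = node.invoke(state)
# result["messages"][0] is a ToolMessage with status="error" and the handler's text.
```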
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/prebuilt/tool_executor.py`:
```py
from typing import Any, Callable, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool
from langgraph._api.deprecation import deprecated
from langgraph.utils.runnable import RunnableCallable
INVALID_TOOL_MSG_TEMPLATE = (
"{requested_tool_name} is not a valid tool, "
"try one of [{available_tool_names_str}]."
)
@deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0")
class ToolInvocationInterface:
"""Interface for invoking a tool.
Attributes:
tool (str): The name of the tool to invoke.
tool_input (Union[str, dict]): The input to pass to the tool.
"""
tool: str
tool_input: Union[str, dict]
@deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0")
class ToolInvocation(Serializable):
"""Information about how to invoke a tool.
Attributes:
tool (str): The name of the Tool to execute.
tool_input (Union[str, dict]): The input to pass in to the Tool.
Examples:
Basic usage:
```pycon
>>> invocation = ToolInvocation(
... tool="search",
... tool_input="What is the capital of France?"
... )
```
"""
tool: str
tool_input: Union[str, dict]
@deprecated("0.2.0", "langgraph.prebuilt.ToolNode", removal="0.3.0")
class ToolExecutor(RunnableCallable):
"""Executes a tool invocation.
Args:
tools (Sequence[BaseTool]): A sequence of tools that can be invoked.
invalid_tool_msg_template (str, optional): The template for the error message
when an invalid tool is requested. Defaults to INVALID_TOOL_MSG_TEMPLATE.
Examples:
Basic usage:
```pycon
>>> from langchain_core.tools import tool
>>> from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
...
...
>>> @tool
... def search(query: str) -> str:
... \"\"\"Search engine.\"\"\"
... return f"Searching for: {query}"
...
...
>>> tools = [search]
>>> executor = ToolExecutor(tools)
...
>>> invocation = ToolInvocation(tool="search", tool_input="What is the capital of France?")
>>> result = executor.invoke(invocation)
>>> print(result)
"Searching for: What is the capital of France?"
```
Handling invalid tool:
```pycon
>>> invocation = ToolInvocation(
... tool="nonexistent", tool_input="What is the capital of France?"
... )
>>> result = executor.invoke(invocation)
>>> print(result)
"nonexistent is not a valid tool, try one of [search]."
```
"""
def __init__(
self,
tools: Sequence[Union[BaseTool, Callable]],
*,
invalid_tool_msg_template: str = INVALID_TOOL_MSG_TEMPLATE,
) -> None:
super().__init__(self._execute, afunc=self._aexecute, trace=False)
tools_ = [
tool if isinstance(tool, BaseTool) else create_tool(tool) for tool in tools
]
self.tools = tools_
self.tool_map = {t.name: t for t in tools_}
self.invalid_tool_msg_template = invalid_tool_msg_template
def _execute(
self, tool_invocation: ToolInvocationInterface, config: RunnableConfig
) -> Any:
if tool_invocation.tool not in self.tool_map:
return self.invalid_tool_msg_template.format(
requested_tool_name=tool_invocation.tool,
available_tool_names_str=", ".join([t.name for t in self.tools]),
)
else:
tool = self.tool_map[tool_invocation.tool]
output = tool.invoke(tool_invocation.tool_input, config)
return output
async def _aexecute(
self, tool_invocation: ToolInvocationInterface, config: RunnableConfig
) -> Any:
if tool_invocation.tool not in self.tool_map:
return self.invalid_tool_msg_template.format(
requested_tool_name=tool_invocation.tool,
available_tool_names_str=", ".join([t.name for t in self.tools]),
)
else:
tool = self.tool_map[tool_invocation.tool]
output = await tool.ainvoke(tool_invocation.tool_input, config)
return output
```
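Since `ToolExecutor` and `ToolInvocation` are deprecated in favor of `ToolNode`, a rough migration sketch (with a hypothetical `search` tool) might look like this:
```py
# Deprecated ToolExecutor/ToolInvocation path next to the ToolNode equivalent.
from langchain_core.messages import AIMessage
from langchain_core.tools import tool

from langgraph.prebuilt import ToolNode
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation


@tool
def search(query: str) -> str:
    """Search engine."""
    return f"Searching for: {query}"


# Deprecated: invoke a single tool by name.
executor = ToolExecutor([search])
print(executor.invoke(ToolInvocation(tool="search", tool_input="capital of France")))

# Preferred: ToolNode reads tool_calls off the last AIMessage.
node = ToolNode([search])
msg = AIMessage(
    "",
    tool_calls=[
        {"name": "search", "args": {"query": "capital of France"}, "id": "1", "type": "tool_call"}
    ],
)
print(node.invoke({"messages": [msg]}))
```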
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/errors.py`:
```py
from enum import Enum
from typing import Any, Sequence
from langgraph.checkpoint.base import EmptyChannelError # noqa: F401
from langgraph.types import Command, Interrupt
# EmptyChannelError re-exported for backwards compatibility
class ErrorCode(Enum):
GRAPH_RECURSION_LIMIT = "GRAPH_RECURSION_LIMIT"
INVALID_CONCURRENT_GRAPH_UPDATE = "INVALID_CONCURRENT_GRAPH_UPDATE"
INVALID_GRAPH_NODE_RETURN_VALUE = "INVALID_GRAPH_NODE_RETURN_VALUE"
MULTIPLE_SUBGRAPHS = "MULTIPLE_SUBGRAPHS"
INVALID_CHAT_HISTORY = "INVALID_CHAT_HISTORY"
def create_error_message(*, message: str, error_code: ErrorCode) -> str:
return (
f"{message}\n"
"For troubleshooting, visit: https://python.langchain.com/docs/"
f"troubleshooting/errors/{error_code.value}"
)
class GraphRecursionError(RecursionError):
"""Raised when the graph has exhausted the maximum number of steps.
This prevents infinite loops. To increase the maximum number of steps,
run your graph with a config specifying a higher `recursion_limit`.
Troubleshooting Guides:
- [GRAPH_RECURSION_LIMIT](https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT)
Examples:
graph = builder.compile()
graph.invoke(
{"messages": [("user", "Hello, world!")]},
# The config is the second positional argument
{"recursion_limit": 1000},
)
"""
pass
class InvalidUpdateError(Exception):
"""Raised when attempting to update a channel with an invalid set of updates.
Troubleshooting Guides:
- [INVALID_CONCURRENT_GRAPH_UPDATE](https://python.langchain.com/docs/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE)
- [INVALID_GRAPH_NODE_RETURN_VALUE](https://python.langchain.com/docs/troubleshooting/errors/INVALID_GRAPH_NODE_RETURN_VALUE)
"""
pass
class GraphBubbleUp(Exception):
pass
class GraphInterrupt(GraphBubbleUp):
"""Raised when a subgraph is interrupted, suppressed by the root graph.
Never raised directly, or surfaced to the user."""
def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None:
super().__init__(interrupts)
class NodeInterrupt(GraphInterrupt):
"""Raised by a node to interrupt execution."""
def __init__(self, value: Any) -> None:
super().__init__([Interrupt(value=value)])
class GraphDelegate(GraphBubbleUp):
"""Raised when a graph is delegated (for distributed mode)."""
def __init__(self, *args: dict[str, Any]) -> None:
super().__init__(*args)
class ParentCommand(GraphBubbleUp):
args: tuple[Command]
def __init__(self, command: Command) -> None:
super().__init__(command)
class EmptyInputError(Exception):
"""Raised when graph receives an empty input."""
pass
class TaskNotFound(Exception):
"""Raised when the executor is unable to find a task (for distributed mode)."""
pass
class CheckpointNotLatest(Exception):
"""Raised when the checkpoint is not the latest version (for distributed mode)."""
pass
class MultipleSubgraphsError(Exception):
"""Raised when multiple subgraphs are called inside the same node.
Troubleshooting guides:
- [MULTIPLE_SUBGRAPHS](https://python.langchain.com/docs/troubleshooting/errors/MULTIPLE_SUBGRAPHS)
"""
pass
_SEEN_CHECKPOINT_NS: set[str] = set()
"""Used for subgraph detection."""
```
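A small self-contained sketch of when `GraphRecursionError` surfaces in practice: a graph whose only node always routes back to itself exhausts the default `recursion_limit` of 25. The node and state names here are made up for illustration:
```python
from typing import TypedDict

from langgraph.errors import GraphRecursionError
from langgraph.graph import START, StateGraph


class State(TypedDict):
    count: int


def bump(state: State) -> dict:
    return {"count": state["count"] + 1}


builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
# Always loop back, so the recursion limit is the only stop condition.
builder.add_conditional_edges("bump", lambda state: "bump")
graph = builder.compile()

try:
    graph.invoke({"count": 0})
except GraphRecursionError as e:
    # The message ends with the troubleshooting URL built by create_error_message.
    print(e)
```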
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/last_value.py`:
```py
from typing import Generic, Optional, Sequence, Type
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import (
EmptyChannelError,
ErrorCode,
InvalidUpdateError,
create_error_message,
)
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, can receive at most one value per step."""
__slots__ = ("value",)
def __eq__(self, value: object) -> bool:
return isinstance(value, LastValue)
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
return False
if len(values) != 1:
msg = create_error_message(
message=f"At key '{self.key}': Can receive only one value per step. Use an Annotated key to handle multiple values.",
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
)
raise InvalidUpdateError(msg)
self.value = values[-1]
return True
def get(self) -> Value:
try:
return self.value
except AttributeError:
raise EmptyChannelError()
```
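Channels are normally driven by the Pregel loop, but their semantics can be exercised directly. A sketch of `LastValue`'s one-writer-per-step rule:
```python
from langgraph.channels import LastValue
from langgraph.errors import EmptyChannelError, InvalidUpdateError

channel = LastValue(int)

try:
    channel.get()
except EmptyChannelError:
    print("empty until the first update")

channel.update([1])          # a single value per step is accepted
print(channel.get())         # -> 1

try:
    channel.update([2, 3])   # two concurrent writes in one step are rejected
except InvalidUpdateError as e:
    print(e)
```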
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/__init__.py`:
```py
from langgraph.channels.any_value import AnyValue
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.context import Context
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.channels.topic import Topic
from langgraph.channels.untracked_value import UntrackedValue
__all__ = [
"LastValue",
"Topic",
"Context",
"BinaryOperatorAggregate",
"UntrackedValue",
"EphemeralValue",
"AnyValue",
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/untracked_value.py`:
```py
from typing import Generic, Optional, Sequence, Type
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, never checkpointed."""
__slots__ = ("value", "guard")
def __init__(self, typ: Type[Value], guard: bool = True) -> None:
super().__init__(typ)
self.guard = guard
def __eq__(self, value: object) -> bool:
return isinstance(value, UntrackedValue) and value.guard == self.guard
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def checkpoint(self) -> Value:
raise EmptyChannelError()
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
return False
if len(values) != 1 and self.guard:
raise InvalidUpdateError(
f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
)
self.value = values[-1]
return True
def get(self) -> Value:
try:
return self.value
except AttributeError:
raise EmptyChannelError()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/any_value.py`:
```py
from typing import Generic, Optional, Sequence, Type
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError
class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, assumes that if multiple values are
received, they are all equal."""
__slots__ = ("typ", "value")
def __eq__(self, value: object) -> bool:
return isinstance(value, AnyValue)
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
try:
del self.value
return True
except AttributeError:
return False
self.value = values[-1]
return True
def get(self) -> Value:
try:
return self.value
except AttributeError:
raise EmptyChannelError()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/named_barrier_value.py`:
```py
from typing import Generic, Optional, Sequence, Type
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
"""A channel that waits until all named values are received before making the value available."""
__slots__ = ("names", "seen")
names: set[Value]
seen: set[Value]
def __init__(self, typ: Type[Value], names: set[Value]) -> None:
super().__init__(typ)
self.names = names
self.seen: set[str] = set()
def __eq__(self, value: object) -> bool:
return isinstance(value, NamedBarrierValue) and value.names == self.names
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def checkpoint(self) -> set[Value]:
return self.seen
def from_checkpoint(self, checkpoint: Optional[set[Value]]) -> Self:
empty = self.__class__(self.typ, self.names)
empty.key = self.key
if checkpoint is not None:
empty.seen = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
updated = False
for value in values:
if value in self.names:
if value not in self.seen:
self.seen.add(value)
updated = True
else:
raise InvalidUpdateError(
f"At key '{self.key}': Value {value} not in {self.names}"
)
return updated
def get(self) -> Value:
if self.seen != self.names:
raise EmptyChannelError()
return None
def consume(self) -> bool:
if self.seen == self.names:
self.seen = set()
return True
return False
```
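A sketch of the barrier semantics: the channel stays unreadable until every name in the set has written once, then `consume()` resets it for the next round:
```python
from langgraph.channels.named_barrier_value import NamedBarrierValue
from langgraph.errors import EmptyChannelError

barrier = NamedBarrierValue(str, {"a", "b"})

barrier.update(["a"])
try:
    barrier.get()
except EmptyChannelError:
    print("still waiting for 'b'")

barrier.update(["b"])
print(barrier.get())      # -> None; the barrier only signals readiness
print(barrier.consume())  # -> True, and the seen set is cleared for the next round
```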
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/binop.py`:
```py
import collections.abc
from typing import (
Callable,
Generic,
Optional,
Sequence,
Type,
)
from typing_extensions import NotRequired, Required, Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError
# Adapted from typing_extensions
def _strip_extras(t): # type: ignore[no-untyped-def]
"""Strips Annotated, Required and NotRequired from a given type."""
if hasattr(t, "__origin__"):
return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
return _strip_extras(t.__args__[0])
return t
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the result of applying a binary operator to the current value and each new value.
```python
import operator
total = Channels.BinaryOperatorAggregate(int, operator.add)
```
"""
__slots__ = ("value", "operator")
def __init__(self, typ: Type[Value], operator: Callable[[Value, Value], Value]):
super().__init__(typ)
self.operator = operator
# special forms from typing or collections.abc are not instantiable
# so we need to replace them with their concrete counterparts
typ = _strip_extras(typ)
if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
typ = list
if typ in (collections.abc.Set, collections.abc.MutableSet):
typ = set
if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
typ = dict
try:
self.value = typ()
except Exception:
pass
def __eq__(self, value: object) -> bool:
return isinstance(value, BinaryOperatorAggregate) and (
value.operator is self.operator
if value.operator.__name__ != "<lambda>"
and self.operator.__name__ != "<lambda>"
else True
)
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.operator)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if not values:
return False
if not hasattr(self, "value"):
self.value = values[0]
values = values[1:]
for value in values:
self.value = self.operator(self.value, value)
return True
def get(self) -> Value:
try:
return self.value
except AttributeError:
raise EmptyChannelError()
```
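The docstring's `Channels.` prefix is shorthand; in practice the class is imported from `langgraph.channels`. A sketch of the reducer behaviour with `operator.add`:
```python
import operator

from langgraph.channels import BinaryOperatorAggregate

# Starts from int() == 0 and folds every update in with operator.add.
total = BinaryOperatorAggregate(int, operator.add)

total.update([1, 2, 3])
print(total.get())   # -> 6
total.update([10])
print(total.get())   # -> 16
```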
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/ephemeral_value.py`:
```py
from typing import Any, Generic, Optional, Sequence, Type
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the value received in the step immediately preceding, clears after."""
__slots__ = ("value", "guard")
def __init__(self, typ: Any, guard: bool = True) -> None:
super().__init__(typ)
self.guard = guard
def __eq__(self, value: object) -> bool:
return isinstance(value, EphemeralValue) and value.guard == self.guard
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
try:
del self.value
return True
except AttributeError:
return False
if len(values) != 1 and self.guard:
raise InvalidUpdateError(
f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
)
self.value = values[-1]
return True
def get(self) -> Value:
try:
return self.value
except AttributeError:
raise EmptyChannelError()
```
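A sketch of the clear-after-one-step behaviour: an update with no values (a step in which nothing wrote to the channel) deletes the stored value:
```python
from langgraph.channels import EphemeralValue
from langgraph.errors import EmptyChannelError

step_input = EphemeralValue(str)

step_input.update(["hello"])
print(step_input.get())   # -> 'hello'

step_input.update([])     # a step with no writers clears the channel
try:
    step_input.get()
except EmptyChannelError:
    print("cleared")
```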
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/dynamic_barrier_value.py`:
```py
from typing import Any, Generic, NamedTuple, Optional, Sequence, Type, Union
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
class WaitForNames(NamedTuple):
names: set[Any]
class DynamicBarrierValue(
Generic[Value], BaseChannel[Value, Union[Value, WaitForNames], set[Value]]
):
"""A channel that switches between two states
- in the "priming" state it can't be read from.
- if it receives a WaitForNames update, it switches to the "waiting" state.
- in the "waiting" state it collects named values until all are received.
- once all named values are received, it can be read once, and it switches
back to the "priming" state.
"""
__slots__ = ("names", "seen")
names: Optional[set[Value]]
seen: set[Value]
def __init__(self, typ: Type[Value]) -> None:
super().__init__(typ)
self.names = None
self.seen = set()
def __eq__(self, value: object) -> bool:
return isinstance(value, DynamicBarrierValue) and value.names == self.names
@property
def ValueType(self) -> Type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ
def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]:
return (self.names, self.seen)
def from_checkpoint(
self,
checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]],
) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
names, seen = checkpoint
empty.names = names if names is not None else None
empty.seen = seen
return empty
def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool:
if wait_for_names := [v for v in values if isinstance(v, WaitForNames)]:
if len(wait_for_names) > 1:
raise InvalidUpdateError(
f"At key '{self.key}': Received multiple WaitForNames updates in the same step."
)
self.names = wait_for_names[0].names
return True
elif self.names is not None:
updated = False
for value in values:
assert not isinstance(value, WaitForNames)
if value in self.names:
if value not in self.seen:
self.seen.add(value)
updated = True
else:
raise InvalidUpdateError(f"Value {value} not in {self.names}")
return updated
def get(self) -> Value:
if self.seen != self.names:
raise EmptyChannelError()
return None
def consume(self) -> bool:
if self.seen == self.names:
self.seen = set()
self.names = None
return True
return False
```
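A sketch of the priming/waiting cycle described in the docstring: a `WaitForNames` update arms the barrier, named values fill it, and `consume()` returns it to the priming state:
```python
from langgraph.channels.dynamic_barrier_value import DynamicBarrierValue, WaitForNames
from langgraph.errors import EmptyChannelError

gate = DynamicBarrierValue(str)

gate.update([WaitForNames({"a", "b"})])  # arm the barrier with the names to wait for

gate.update(["a"])
try:
    gate.get()
except EmptyChannelError:
    print("still waiting for 'b'")

gate.update(["b"])
gate.get()             # readable (returns None) once all names are seen
print(gate.consume())  # -> True, and the channel goes back to priming
```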
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/context.py`:
```py
from langgraph.managed.context import Context as ContextManagedValue
Context = ContextManagedValue.of
__all__ = ["Context"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/topic.py`:
```py
from typing import Any, Generic, Iterator, Optional, Sequence, Type, Union
from typing_extensions import Self
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError
def flatten(values: Sequence[Union[Value, list[Value]]]) -> Iterator[Value]:
for value in values:
if isinstance(value, list):
yield from value
else:
yield value
class Topic(
Generic[Value],
BaseChannel[
Sequence[Value], Union[Value, list[Value]], tuple[set[Value], list[Value]]
],
):
"""A configurable PubSub Topic.
Args:
typ: The type of the value stored in the channel.
accumulate: Whether to accumulate values across steps. If False, the channel will be emptied after each step.
"""
__slots__ = ("values", "accumulate")
def __init__(self, typ: Type[Value], accumulate: bool = False) -> None:
super().__init__(typ)
# attrs
self.accumulate = accumulate
# state
self.values = list[Value]()
def __eq__(self, value: object) -> bool:
return isinstance(value, Topic) and value.accumulate == self.accumulate
@property
def ValueType(self) -> Any:
"""The type of the value stored in the channel."""
return Sequence[self.typ] # type: ignore[name-defined]
@property
def UpdateType(self) -> Any:
"""The type of the update received by the channel."""
return Union[self.typ, list[self.typ]] # type: ignore[name-defined]
def checkpoint(self) -> tuple[set[Value], list[Value]]:
return self.values
def from_checkpoint(self, checkpoint: Optional[list[Value]]) -> Self:
empty = self.__class__(self.typ, self.accumulate)
empty.key = self.key
if checkpoint is not None:
if isinstance(checkpoint, tuple):
empty.values = checkpoint[1]
else:
empty.values = checkpoint
return empty
def update(self, values: Sequence[Union[Value, list[Value]]]) -> None:
current = list(self.values)
if not self.accumulate:
self.values = list[Value]()
if flat_values := flatten(values):
self.values.extend(flat_values)
return self.values != current
def get(self) -> Sequence[Value]:
if self.values:
return list(self.values)
else:
raise EmptyChannelError
```
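A sketch of the two `Topic` modes: without `accumulate` the stored values are replaced each step, with `accumulate=True` they keep growing; both single values and lists are accepted as updates:
```python
from langgraph.channels import Topic

inbox = Topic(str)
inbox.update(["a", ["b", "c"]])
print(inbox.get())   # -> ['a', 'b', 'c']
inbox.update(["d"])
print(inbox.get())   # -> ['d'] (previous values were discarded)

log = Topic(str, accumulate=True)
log.update(["a"])
log.update(["b"])
print(log.get())     # -> ['a', 'b']
```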
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/langgraph/channels/base.py`:
```py
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Sequence, Type, TypeVar
from typing_extensions import Self
from langgraph.errors import EmptyChannelError, InvalidUpdateError
Value = TypeVar("Value")
Update = TypeVar("Update")
C = TypeVar("C")
class BaseChannel(Generic[Value, Update, C], ABC):
__slots__ = ("key", "typ")
def __init__(self, typ: Type[Any], key: str = "") -> None:
self.typ = typ
self.key = key
@property
@abstractmethod
def ValueType(self) -> Any:
"""The type of the value stored in the channel."""
@property
@abstractmethod
def UpdateType(self) -> Any:
"""The type of the update received by the channel."""
# serialize/deserialize methods
def checkpoint(self) -> Optional[C]:
"""Return a serializable representation of the channel's current state.
Raises EmptyChannelError if the channel is empty (never updated yet),
or doesn't support checkpoints."""
return self.get()
@abstractmethod
def from_checkpoint(self, checkpoint: Optional[C]) -> Self:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied."""
# state methods
@abstractmethod
def update(self, values: Sequence[Update]) -> bool:
"""Update the channel's value with the given sequence of updates.
The order of the updates in the sequence is arbitrary.
This method is called by Pregel for all channels at the end of each step.
If there are no updates, it is called with an empty sequence.
Raises InvalidUpdateError if the sequence of updates is invalid.
Returns True if the channel was updated, False otherwise."""
@abstractmethod
def get(self) -> Value:
"""Return the current value of the channel.
Raises EmptyChannelError if the channel is empty (never updated yet)."""
def consume(self) -> bool:
"""Mark the current value of the channel as consumed. By default, no-op.
This is called by Pregel before the start of the next step, for all
channels that triggered a node. If the channel was updated, return True.
"""
return False
__all__ = [
"BaseChannel",
"EmptyChannelError",
"InvalidUpdateError",
]
```
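A minimal sketch of a custom channel built on the interface above: a hypothetical `MaxValue` that keeps the largest value received so far (not part of the library, shown only to illustrate the required methods):
```python
from typing import Generic, Optional, Sequence, Type

from typing_extensions import Self

from langgraph.channels.base import BaseChannel, EmptyChannelError, Value


class MaxValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Hypothetical channel that keeps the largest value received so far."""

    __slots__ = ("value",)

    @property
    def ValueType(self) -> Type[Value]:
        return self.typ

    @property
    def UpdateType(self) -> Type[Value]:
        return self.typ

    def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
        empty = self.__class__(self.typ)
        empty.key = self.key
        if checkpoint is not None:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        if not values:
            return False
        candidate = max(values)  # assumes the value type is orderable
        if not hasattr(self, "value") or candidate > self.value:
            self.value = candidate
            return True
        return False

    def get(self) -> Value:
        try:
            return self.value
        except AttributeError:
            raise EmptyChannelError()


highest = MaxValue(int)
highest.update([3, 1, 2])
print(highest.get())  # -> 3
highest.update([2])
print(highest.get())  # -> 3 (smaller updates are ignored)
```
The default `checkpoint()` inherited from `BaseChannel` simply returns `get()`, which is sufficient here.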
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/fanout_to_subgraph.py`:
```py
import operator
from typing import Annotated, TypedDict
from langgraph.constants import END, START, Send
from langgraph.graph.state import StateGraph
def fanout_to_subgraph() -> StateGraph:
class OverallState(TypedDict):
subjects: list[str]
jokes: Annotated[list[str], operator.add]
async def continue_to_jokes(state: OverallState):
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
class JokeInput(TypedDict):
subject: str
class JokeOutput(TypedDict):
jokes: list[str]
async def bump(state: JokeOutput):
return {"jokes": [state["jokes"][0] + " a"]}
async def generate(state: JokeInput):
return {"jokes": [f"Joke about {state['subject']}"]}
async def edit(state: JokeInput):
subject = state["subject"]
return {"subject": f"{subject} - hohoho"}
async def bump_loop(state: JokeOutput):
return END if state["jokes"][0].endswith(" a" * 10) else "bump"
# subgraph
subgraph = StateGraph(input=JokeInput, output=JokeOutput)
subgraph.add_node("edit", edit)
subgraph.add_node("generate", generate)
subgraph.add_node("bump", bump)
subgraph.set_entry_point("edit")
subgraph.add_edge("edit", "generate")
subgraph.add_edge("generate", "bump")
subgraph.add_conditional_edges("bump", bump_loop)
subgraph.set_finish_point("generate")
subgraphc = subgraph.compile()
# parent graph
builder = StateGraph(OverallState)
builder.add_node("generate_joke", subgraphc)
builder.add_conditional_edges(START, continue_to_jokes)
builder.add_edge("generate_joke", END)
return builder
def fanout_to_subgraph_sync() -> StateGraph:
class OverallState(TypedDict):
subjects: list[str]
jokes: Annotated[list[str], operator.add]
def continue_to_jokes(state: OverallState):
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
class JokeInput(TypedDict):
subject: str
class JokeOutput(TypedDict):
jokes: list[str]
def bump(state: JokeOutput):
return {"jokes": [state["jokes"][0] + " a"]}
def generate(state: JokeInput):
return {"jokes": [f"Joke about {state['subject']}"]}
def edit(state: JokeInput):
subject = state["subject"]
return {"subject": f"{subject} - hohoho"}
def bump_loop(state: JokeOutput):
return END if state["jokes"][0].endswith(" a" * 10) else "bump"
# subgraph
subgraph = StateGraph(input=JokeInput, output=JokeOutput)
subgraph.add_node("edit", edit)
subgraph.add_node("generate", generate)
subgraph.add_node("bump", bump)
subgraph.set_entry_point("edit")
subgraph.add_edge("edit", "generate")
subgraph.add_edge("generate", "bump")
subgraph.add_conditional_edges("bump", bump_loop)
subgraph.set_finish_point("generate")
subgraphc = subgraph.compile()
# parent graph
builder = StateGraph(OverallState)
builder.add_node("generate_joke", subgraphc)
builder.add_conditional_edges(START, continue_to_jokes)
builder.add_edge("generate_joke", END)
return builder
if __name__ == "__main__":
import asyncio
import random
import uvloop
from langgraph.checkpoint.memory import MemorySaver
graph = fanout_to_subgraph().compile(checkpointer=MemorySaver())
input = {
"subjects": [
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(1000)
]
}
config = {"configurable": {"thread_id": "1"}}
async def run():
len([c async for c in graph.astream(input, config=config)])
uvloop.install()
asyncio.run(run())
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/react_agent.py`:
```py
from typing import Any, Optional
from uuid import uuid4
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.pregel import Pregel
def react_agent(n_tools: int, checkpointer: Optional[BaseCheckpointSaver]) -> Pregel:
class FakeFuntionChatModel(FakeMessagesListChatModel):
def bind_tools(self, functions: list):
return self
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = self.responses[self.i].copy()
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])
tool = StructuredTool.from_function(
lambda query: f"result for query: {query}" * 10,
name=str(uuid4()),
description="",
)
model = FakeFuntionChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": str(uuid4()),
"name": tool.name,
"args": {"query": str(uuid4()) * 100},
}
],
id=str(uuid4()),
)
for _ in range(n_tools)
]
+ [
AIMessage(content="answer" * 100, id=str(uuid4())),
]
)
return create_react_agent(model, [tool], checkpointer=checkpointer)
if __name__ == "__main__":
import asyncio
import uvloop
from langgraph.checkpoint.memory import MemorySaver
graph = react_agent(100, checkpointer=MemorySaver())
input = {"messages": [HumanMessage("hi?")]}
config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000}
async def run():
len([c async for c in graph.astream(input, config=config)])
uvloop.install()
asyncio.run(run())
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/wide_state.py`:
```py
import operator
from dataclasses import dataclass, field
from functools import partial
from typing import Annotated, Optional, Sequence
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
def wide_state(n: int) -> StateGraph:
@dataclass(kw_only=True)
class State:
messages: Annotated[list, operator.add] = field(default_factory=list)
trigger_events: Annotated[list, operator.add] = field(default_factory=list)
"""The external events that are converted by the graph."""
primary_issue_medium: Annotated[str, lambda x, y: y or x] = field(
default="email"
)
autoresponse: Annotated[Optional[dict], lambda _, y: y] = field(
default=None
) # Always overwrite
issue: Annotated[dict | None, lambda x, y: y if y else x] = field(default=None)
relevant_rules: Optional[list[dict]] = field(default=None)
"""SOPs fetched from the rulebook that are relevant to the current conversation."""
memory_docs: Optional[list[dict]] = field(default=None)
"""Memory docs fetched from the memory service that are relevant to the current conversation."""
categorizations: Annotated[list[dict], operator.add] = field(
default_factory=list
)
"""The issue categorizations auto-generated by the AI."""
responses: Annotated[list[dict], operator.add] = field(default_factory=list)
"""The draft responses recommended by the AI."""
user_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = (
field(default=None)
)
"""The current user state (by email)."""
crm_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = (
field(default=None)
)
"""The CRM information for organization the current user is from."""
email_thread_id: Annotated[
Optional[str], lambda x, y: y if y is not None else x
] = field(default=None)
"""The current email thread ID."""
slack_participants: Annotated[dict, operator.or_] = field(default_factory=dict)
"""The growing list of current slack participants."""
bot_id: Optional[str] = field(default=None)
"""The ID of the bot user in the slack channel."""
notified_assignees: Annotated[dict, operator.or_] = field(default_factory=dict)
def read_write(read: str, write: Sequence[str], input: State) -> dict:
val = getattr(input, read)
val_single = val[-1] if isinstance(val, list) else val
val_list = val if isinstance(val, list) else [val]
return {
k: val_list if isinstance(getattr(input, k), list) else val_single
for k in write
}
builder = StateGraph(State)
builder.add_edge(START, "one")
builder.add_node(
"one",
partial(read_write, "messages", ["trigger_events", "primary_issue_medium"]),
)
builder.add_edge("one", "two")
builder.add_node(
"two",
partial(read_write, "trigger_events", ["autoresponse", "issue"]),
)
builder.add_edge("two", "three")
builder.add_edge("two", "four")
builder.add_node(
"three",
partial(read_write, "autoresponse", ["relevant_rules"]),
)
builder.add_node(
"four",
partial(
read_write,
"trigger_events",
["categorizations", "responses", "memory_docs"],
),
)
builder.add_node(
"five",
partial(
read_write,
"categorizations",
[
"user_info",
"crm_info",
"email_thread_id",
"slack_participants",
"bot_id",
"notified_assignees",
],
),
)
builder.add_edge(["three", "four"], "five")
builder.add_edge("five", "six")
builder.add_node(
"six",
partial(read_write, "responses", ["messages"]),
)
builder.add_conditional_edges(
"six", lambda state: END if len(state.messages) > n else "one"
)
return builder
if __name__ == "__main__":
import asyncio
import uvloop
from langgraph.checkpoint.memory import MemorySaver
graph = wide_state(1000).compile(checkpointer=MemorySaver())
input = {
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(5)
}
]
}
config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000}
async def run():
async for c in graph.astream(input, config=config):
print(c.keys())
uvloop.install()
asyncio.run(run())
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/bench/__main__.py`:
```py
import random
from uuid import uuid4
from langchain_core.messages import HumanMessage
from pyperf._runner import Runner
from uvloop import new_event_loop
from bench.fanout_to_subgraph import fanout_to_subgraph, fanout_to_subgraph_sync
from bench.react_agent import react_agent
from bench.wide_state import wide_state
from langgraph.checkpoint.memory import MemorySaver
from langgraph.pregel import Pregel
async def arun(graph: Pregel, input: dict):
len(
[
c
async for c in graph.astream(
input,
{
"configurable": {"thread_id": str(uuid4())},
"recursion_limit": 1000000000,
},
)
]
)
def run(graph: Pregel, input: dict):
len(
[
c
for c in graph.stream(
input,
{
"configurable": {"thread_id": str(uuid4())},
"recursion_limit": 1000000000,
},
)
]
)
benchmarks = (
(
"fanout_to_subgraph_10x",
fanout_to_subgraph().compile(checkpointer=None),
fanout_to_subgraph_sync().compile(checkpointer=None),
{
"subjects": [
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10)
]
},
),
(
"fanout_to_subgraph_10x_checkpoint",
fanout_to_subgraph().compile(checkpointer=MemorySaver()),
fanout_to_subgraph_sync().compile(checkpointer=MemorySaver()),
{
"subjects": [
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10)
]
},
),
(
"fanout_to_subgraph_100x",
fanout_to_subgraph().compile(checkpointer=None),
fanout_to_subgraph_sync().compile(checkpointer=None),
{
"subjects": [
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100)
]
},
),
(
"fanout_to_subgraph_100x_checkpoint",
fanout_to_subgraph().compile(checkpointer=MemorySaver()),
fanout_to_subgraph_sync().compile(checkpointer=MemorySaver()),
{
"subjects": [
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100)
]
},
),
(
"react_agent_10x",
react_agent(10, checkpointer=None),
react_agent(10, checkpointer=None),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_10x_checkpoint",
react_agent(10, checkpointer=MemorySaver()),
react_agent(10, checkpointer=MemorySaver()),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_100x",
react_agent(100, checkpointer=None),
react_agent(100, checkpointer=None),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_100x_checkpoint",
react_agent(100, checkpointer=MemorySaver()),
react_agent(100, checkpointer=MemorySaver()),
{"messages": [HumanMessage("hi?")]},
),
(
"wide_state_25x300",
wide_state(300).compile(checkpointer=None),
wide_state(300).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(5)
}
]
},
),
(
"wide_state_25x300_checkpoint",
wide_state(300).compile(checkpointer=MemorySaver()),
wide_state(300).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(5)
}
]
},
),
(
"wide_state_15x600",
wide_state(600).compile(checkpointer=None),
wide_state(600).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(3)
}
]
},
),
(
"wide_state_15x600_checkpoint",
wide_state(600).compile(checkpointer=MemorySaver()),
wide_state(600).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(3)
}
]
},
),
(
"wide_state_9x1200",
wide_state(1200).compile(checkpointer=None),
wide_state(1200).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(3)
}
for i in range(3)
}
]
},
),
(
"wide_state_9x1200_checkpoint",
wide_state(1200).compile(checkpointer=MemorySaver()),
wide_state(1200).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(3)
}
for i in range(3)
}
]
},
),
)
r = Runner()
for name, agraph, graph, input in benchmarks:
r.bench_async_func(name, arun, agraph, input, loop_factory=new_event_loop)
if graph is not None:
r.bench_func(name + "_sync", run, graph, input)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph"
version = "0.2.56"
description = "Building stateful, multi-actor applications with LLMs"
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
[tool.poetry.dependencies]
python = ">=3.9.0,<4.0"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14"
langgraph-checkpoint = "^2.0.4"
langgraph-sdk = "^0.1.42"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.2"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
httpx = "^0.26.0"
pytest-watcher = "^0.4.1"
mypy = "^1.6.0"
ruff = "^0.6.2"
jupyter = "^1.0.0"
pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}
pytest-repeat = "^0.9.3"
langgraph-checkpoint = {path = "../checkpoint", develop = true}
langgraph-checkpoint-duckdb = {path = "../checkpoint-duckdb", develop = true}
langgraph-checkpoint-sqlite = {path = "../checkpoint-sqlite", develop = true}
langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true}
langgraph-sdk = {path = "../sdk-py", develop = true}
psycopg = {extras = ["binary"], version = ">=3.0.0", python = ">=3.10"}
uvloop = "0.21.0beta1"
pyperf = "^2.7.0"
py-spy = "^0.3.14"
types-requests = "^2.32.0.20240914"
[tool.ruff]
lint.select = [ "E", "F", "I" ]
lint.ignore = [ "E501" ]
line-length = 88
indent-width = 4
extend-include = ["*.ipynb"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
docstring-code-line-length = "dynamic"
[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value, override, has-type"
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest-watcher]
now = true
delay = 0.1
patterns = ["*.py"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--full-trace --strict-markers --strict-config --durations=5 --snapshot-warn-unused"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_utils.py`:
```py
import functools
import sys
import uuid
from typing import (
Any,
Callable,
Dict,
ForwardRef,
List,
Literal,
Optional,
TypedDict,
TypeVar,
Union,
)
from unittest.mock import patch
import langsmith
import pytest
from typing_extensions import Annotated, NotRequired, Required
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.utils.fields import _is_optional_type, get_field_default
from langgraph.utils.runnable import is_async_callable, is_async_generator
pytestmark = pytest.mark.anyio
def test_is_async() -> None:
async def func() -> None:
pass
assert is_async_callable(func)
wrapped_func = functools.wraps(func)(func)
assert is_async_callable(wrapped_func)
def sync_func() -> None:
pass
assert not is_async_callable(sync_func)
wrapped_sync_func = functools.wraps(sync_func)(sync_func)
assert not is_async_callable(wrapped_sync_func)
class AsyncFuncCallable:
async def __call__(self) -> None:
pass
runnable = AsyncFuncCallable()
assert is_async_callable(runnable)
wrapped_runnable = functools.wraps(runnable)(runnable)
assert is_async_callable(wrapped_runnable)
class SyncFuncCallable:
def __call__(self) -> None:
pass
sync_runnable = SyncFuncCallable()
assert not is_async_callable(sync_runnable)
wrapped_sync_runnable = functools.wraps(sync_runnable)(sync_runnable)
assert not is_async_callable(wrapped_sync_runnable)
def test_is_generator() -> None:
async def gen():
yield
assert is_async_generator(gen)
wrapped_gen = functools.wraps(gen)(gen)
assert is_async_generator(wrapped_gen)
def sync_gen():
yield
assert not is_async_generator(sync_gen)
wrapped_sync_gen = functools.wraps(sync_gen)(sync_gen)
assert not is_async_generator(wrapped_sync_gen)
class AsyncGenCallable:
async def __call__(self):
yield
runnable = AsyncGenCallable()
assert is_async_generator(runnable)
wrapped_runnable = functools.wraps(runnable)(runnable)
assert is_async_generator(wrapped_runnable)
class SyncGenCallable:
def __call__(self):
yield
sync_runnable = SyncGenCallable()
assert not is_async_generator(sync_runnable)
wrapped_sync_runnable = functools.wraps(sync_runnable)(sync_runnable)
assert not is_async_generator(wrapped_sync_runnable)
@pytest.fixture
def rt_graph() -> CompiledGraph:
class State(TypedDict):
foo: int
node_run_id: int
def node(_: State):
from langsmith import get_current_run_tree # type: ignore
return {"node_run_id": get_current_run_tree().id} # type: ignore
graph = StateGraph(State)
graph.add_node(node)
graph.set_entry_point("node")
graph.add_edge("node", END)
return graph.compile()
def test_runnable_callable_tracing_nested(rt_graph: CompiledGraph) -> None:
with patch("langsmith.client.Client", spec=langsmith.Client) as mock_client:
with patch("langchain_core.tracers.langchain.get_client") as mock_get_client:
mock_get_client.return_value = mock_client
with langsmith.tracing_context(enabled=True):
res = rt_graph.invoke({"foo": 1})
assert isinstance(res["node_run_id"], uuid.UUID)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
async def test_runnable_callable_tracing_nested_async(rt_graph: CompiledGraph) -> None:
with patch("langsmith.client.Client", spec=langsmith.Client) as mock_client:
with patch("langchain_core.tracers.langchain.get_client") as mock_get_client:
mock_get_client.return_value = mock_client
with langsmith.tracing_context(enabled=True):
res = await rt_graph.ainvoke({"foo": 1})
assert isinstance(res["node_run_id"], uuid.UUID)
def test_is_optional_type():
assert _is_optional_type(None)
assert not _is_optional_type(type(None))
assert _is_optional_type(Optional[list])
assert not _is_optional_type(int)
assert _is_optional_type(Optional[Literal[1, 2, 3]])
assert not _is_optional_type(Literal[1, 2, 3])
assert _is_optional_type(Optional[List[int]])
assert _is_optional_type(Optional[Dict[str, int]])
assert not _is_optional_type(List[Optional[int]])
assert _is_optional_type(Union[Optional[str], Optional[int]])
assert _is_optional_type(
Union[
Union[Optional[str], Optional[int]], Union[Optional[float], Optional[dict]]
]
)
assert not _is_optional_type(Union[Union[str, int], Union[float, dict]])
assert _is_optional_type(Union[int, None])
assert _is_optional_type(Union[str, None, int])
assert _is_optional_type(Union[None, str, int])
assert not _is_optional_type(Union[int, str])
assert not _is_optional_type(Any) # Do we actually want this?
assert _is_optional_type(Optional[Any])
class MyClass:
pass
assert _is_optional_type(Optional[MyClass])
assert not _is_optional_type(MyClass)
assert _is_optional_type(Optional[ForwardRef("MyClass")])
assert not _is_optional_type(ForwardRef("MyClass"))
assert _is_optional_type(Optional[Union[List[int], Dict[str, Optional[int]]]])
assert not _is_optional_type(Union[List[int], Dict[str, Optional[int]]])
assert _is_optional_type(Optional[Callable[[int], str]])
assert not _is_optional_type(Callable[[int], Optional[str]])
T = TypeVar("T")
assert _is_optional_type(Optional[T])
assert not _is_optional_type(T)
U = TypeVar("U", bound=Optional[T]) # type: ignore
assert _is_optional_type(U)
def test_is_required():
class MyBaseTypedDict(TypedDict):
val_1: Required[Optional[str]]
val_2: Required[str]
val_3: NotRequired[str]
val_4: NotRequired[Optional[str]]
val_5: Annotated[NotRequired[int], "foo"]
val_6: NotRequired[Annotated[int, "foo"]]
val_7: Annotated[Required[int], "foo"]
val_8: Required[Annotated[int, "foo"]]
val_9: Optional[str]
val_10: str
annos = MyBaseTypedDict.__annotations__
assert get_field_default("val_1", annos["val_1"], MyBaseTypedDict) == ...
assert get_field_default("val_2", annos["val_2"], MyBaseTypedDict) == ...
assert get_field_default("val_3", annos["val_3"], MyBaseTypedDict) is None
assert get_field_default("val_4", annos["val_4"], MyBaseTypedDict) is None
# See https://peps.python.org/pep-0655/#interaction-with-annotated
assert get_field_default("val_5", annos["val_5"], MyBaseTypedDict) is None
assert get_field_default("val_6", annos["val_6"], MyBaseTypedDict) is None
assert get_field_default("val_7", annos["val_7"], MyBaseTypedDict) == ...
assert get_field_default("val_8", annos["val_8"], MyBaseTypedDict) == ...
assert get_field_default("val_9", annos["val_9"], MyBaseTypedDict) is None
assert get_field_default("val_10", annos["val_10"], MyBaseTypedDict) == ...
class MyChildDict(MyBaseTypedDict):
val_11: int
val_11b: Optional[int]
val_11c: Union[int, None, str]
class MyGrandChildDict(MyChildDict, total=False):
val_12: int
val_13: Required[str]
cannos = MyChildDict.__annotations__
gcannos = MyGrandChildDict.__annotations__
assert get_field_default("val_11", cannos["val_11"], MyChildDict) == ...
assert get_field_default("val_11b", cannos["val_11b"], MyChildDict) is None
assert get_field_default("val_11c", cannos["val_11c"], MyChildDict) is None
assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None
assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None
assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ...
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/conftest.py`:
```py
import sys
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from uuid import UUID, uuid4
import pytest
from langchain_core import __version__ as core_version
from packaging import version
from psycopg import AsyncConnection, Connection
from psycopg_pool import AsyncConnectionPool, ConnectionPool
from pytest_mock import MockerFixture
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.duckdb import DuckDBSaver
from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.store.base import BaseStore
from langgraph.store.duckdb import AsyncDuckDBStore, DuckDBStore
from langgraph.store.memory import InMemoryStore
from langgraph.store.postgres import AsyncPostgresStore, PostgresStore
pytest.register_assert_rewrite("tests.memory_assert")
DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5442/"
# TODO: fix this once core is released
IS_LANGCHAIN_CORE_030_OR_GREATER = version.parse(core_version) >= version.parse(
"0.3.0.dev0"
)
SHOULD_CHECK_SNAPSHOTS = IS_LANGCHAIN_CORE_030_OR_GREATER
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.fixture()
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
return mocker.patch("uuid.uuid4", side_effect=side_effect)
# checkpointer fixtures
@pytest.fixture(scope="function")
def checkpointer_memory():
from tests.memory_assert import MemorySaverAssertImmutable
yield MemorySaverAssertImmutable()
@pytest.fixture(scope="function")
def checkpointer_sqlite():
with SqliteSaver.from_conn_string(":memory:") as checkpointer:
yield checkpointer
@asynccontextmanager
async def _checkpointer_sqlite_aio():
async with AsyncSqliteSaver.from_conn_string(":memory:") as checkpointer:
yield checkpointer
@pytest.fixture(scope="function")
def checkpointer_duckdb():
with DuckDBSaver.from_conn_string(":memory:") as checkpointer:
checkpointer.setup()
yield checkpointer
@asynccontextmanager
async def _checkpointer_duckdb_aio():
async with AsyncDuckDBSaver.from_conn_string(":memory:") as checkpointer:
await checkpointer.setup()
yield checkpointer
@pytest.fixture(scope="function")
def checkpointer_postgres():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with PostgresSaver.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as checkpointer:
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(scope="function")
def checkpointer_postgres_pipe():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with PostgresSaver.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as checkpointer:
checkpointer.setup()
# setup can't run inside pipeline because of implicit transaction
with checkpointer.conn.pipeline() as pipe:
checkpointer.pipe = pipe
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(scope="function")
def checkpointer_postgres_pool():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with ConnectionPool(
DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True}
) as pool:
checkpointer = PostgresSaver(pool)
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _checkpointer_postgres_aio():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncPostgresSaver.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as checkpointer:
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _checkpointer_postgres_aio_pipe():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncPostgresSaver.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as checkpointer:
await checkpointer.setup()
# setup can't run inside pipeline because of implicit transaction
async with checkpointer.conn.pipeline() as pipe:
checkpointer.pipe = pipe
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _checkpointer_postgres_aio_pool():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncConnectionPool(
DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True}
) as pool:
checkpointer = AsyncPostgresSaver(pool)
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def awith_checkpointer(
checkpointer_name: Optional[str],
) -> AsyncIterator[BaseCheckpointSaver]:
if checkpointer_name is None:
yield None
elif checkpointer_name == "memory":
from tests.memory_assert import MemorySaverAssertImmutable
yield MemorySaverAssertImmutable()
elif checkpointer_name == "sqlite_aio":
async with _checkpointer_sqlite_aio() as checkpointer:
yield checkpointer
elif checkpointer_name == "duckdb_aio":
async with _checkpointer_duckdb_aio() as checkpointer:
yield checkpointer
elif checkpointer_name == "postgres_aio":
async with _checkpointer_postgres_aio() as checkpointer:
yield checkpointer
elif checkpointer_name == "postgres_aio_pipe":
async with _checkpointer_postgres_aio_pipe() as checkpointer:
yield checkpointer
elif checkpointer_name == "postgres_aio_pool":
async with _checkpointer_postgres_aio_pool() as checkpointer:
yield checkpointer
else:
raise NotImplementedError(f"Unknown checkpointer: {checkpointer_name}")
@asynccontextmanager
async def _store_postgres_aio():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with AsyncPostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as store:
await store.setup()
yield store
finally:
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _store_postgres_aio_pipe():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with AsyncPostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database
) as store:
await store.setup() # Run in its own transaction
async with AsyncPostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database, pipeline=True
) as store:
yield store
finally:
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _store_postgres_aio_pool():
if sys.version_info < (3, 10):
pytest.skip("Async Postgres tests require Python 3.10+")
database = f"test_{uuid4().hex[:16]}"
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
async with AsyncPostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database,
pool_config={"max_size": 10},
) as store:
await store.setup()
yield store
finally:
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@asynccontextmanager
async def _store_duckdb_aio():
async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
await store.setup()
yield store
@pytest.fixture(scope="function")
def store_postgres():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield store
with PostgresStore.from_conn_string(DEFAULT_POSTGRES_URI + database) as store:
store.setup()
yield store
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(scope="function")
def store_postgres_pipe():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield store
with PostgresStore.from_conn_string(DEFAULT_POSTGRES_URI + database) as store:
store.setup() # Run in its own transaction
with PostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database, pipeline=True
) as store:
yield store
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(scope="function")
def store_postgres_pool():
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield store
with PostgresStore.from_conn_string(
DEFAULT_POSTGRES_URI + database, pool_config={"max_size": 10}
) as store:
store.setup()
yield store
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
@pytest.fixture(scope="function")
def store_duckdb():
with DuckDBStore.from_conn_string(":memory:") as store:
store.setup()
yield store
@pytest.fixture(scope="function")
def store_in_memory():
yield InMemoryStore()
@asynccontextmanager
async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
if store_name is None:
yield None
elif store_name == "in_memory":
yield InMemoryStore()
elif store_name == "postgres_aio":
async with _store_postgres_aio() as store:
yield store
elif store_name == "postgres_aio_pipe":
async with _store_postgres_aio_pipe() as store:
yield store
elif store_name == "postgres_aio_pool":
async with _store_postgres_aio_pool() as store:
yield store
elif store_name == "duckdb_aio":
async with _store_duckdb_aio() as store:
yield store
else:
raise NotImplementedError(f"Unknown store {store_name}")
ALL_CHECKPOINTERS_SYNC = [
"memory",
"sqlite",
"postgres",
"postgres_pipe",
"postgres_pool",
]
ALL_CHECKPOINTERS_ASYNC = [
"memory",
"sqlite_aio",
"postgres_aio",
"postgres_aio_pipe",
"postgres_aio_pool",
]
ALL_CHECKPOINTERS_ASYNC_PLUS_NONE = [
*ALL_CHECKPOINTERS_ASYNC,
None,
]
ALL_STORES_SYNC = [
"in_memory",
"postgres",
"postgres_pipe",
"postgres_pool",
"duckdb",
]
ALL_STORES_ASYNC = [
"in_memory",
"postgres_aio",
"postgres_aio_pipe",
"postgres_aio_pool",
"duckdb_aio",
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_remote_graph.py`:
```py
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.runnables.graph import (
Edge as DrawableEdge,
)
from langchain_core.runnables.graph import (
Node as DrawableNode,
)
from langgraph_sdk.schema import StreamPart
from langgraph.errors import GraphInterrupt
from langgraph.pregel.remote import RemoteGraph
from langgraph.pregel.types import StateSnapshot
def test_with_config():
# set up test
remote_pregel = RemoteGraph(
"test_graph_id",
config={
"configurable": {
"foo": "bar",
"thread_id": "thread_id_1",
}
},
)
# call method / assertions
config = {"configurable": {"hello": "world"}}
remote_pregel_copy = remote_pregel.with_config(config)
# assert that a copy was returned
assert remote_pregel_copy != remote_pregel
# assert that configs were merged
assert remote_pregel_copy.config == {
"configurable": {
"foo": "bar",
"thread_id": "thread_id_1",
"hello": "world",
}
}
def test_get_graph():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.assistants.get_graph.return_value = {
"nodes": [
{"id": "__start__", "type": "schema", "data": "__start__"},
{"id": "__end__", "type": "schema", "data": "__end__"},
{
"id": "agent",
"type": "runnable",
"data": {
"id": ["langgraph", "utils", "RunnableCallable"],
"name": "agent_1",
},
},
],
"edges": [
{"source": "__start__", "target": "agent"},
{"source": "agent", "target": "__end__"},
],
}
remote_pregel = RemoteGraph("test_graph_id", sync_client=mock_sync_client)
# call method / assertions
drawable_graph = remote_pregel.get_graph()
assert drawable_graph.nodes == {
"__start__": DrawableNode(
id="__start__", name="__start__", data="__start__", metadata=None
),
"__end__": DrawableNode(
id="__end__", name="__end__", data="__end__", metadata=None
),
"agent": DrawableNode(
id="agent",
name="agent_1",
data={"id": ["langgraph", "utils", "RunnableCallable"], "name": "agent_1"},
metadata=None,
),
}
assert drawable_graph.edges == [
DrawableEdge(source="__start__", target="agent"),
DrawableEdge(source="agent", target="__end__"),
]
@pytest.mark.anyio
async def test_aget_graph():
# set up test
mock_async_client = AsyncMock()
mock_async_client.assistants.get_graph.return_value = {
"nodes": [
{"id": "__start__", "type": "schema", "data": "__start__"},
{"id": "__end__", "type": "schema", "data": "__end__"},
{
"id": "agent",
"type": "runnable",
"data": {
"id": ["langgraph", "utils", "RunnableCallable"],
"name": "agent_1",
},
},
],
"edges": [
{"source": "__start__", "target": "agent"},
{"source": "agent", "target": "__end__"},
],
}
remote_pregel = RemoteGraph("test_graph_id", client=mock_async_client)
# call method / assertions
drawable_graph = await remote_pregel.aget_graph()
assert drawable_graph.nodes == {
"__start__": DrawableNode(
id="__start__", name="__start__", data="__start__", metadata=None
),
"__end__": DrawableNode(
id="__end__", name="__end__", data="__end__", metadata=None
),
"agent": DrawableNode(
id="agent",
name="agent_1",
data={"id": ["langgraph", "utils", "RunnableCallable"], "name": "agent_1"},
metadata=None,
),
}
assert drawable_graph.edges == [
DrawableEdge(source="__start__", target="agent"),
DrawableEdge(source="agent", target="__end__"),
]
def test_get_state():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.threads.get_state.return_value = {
"values": {"messages": [{"type": "human", "content": "hello"}]},
"next": None,
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
},
"metadata": {},
"created_at": "timestamp",
"parent_checkpoint": None,
"tasks": [],
}
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
sync_client=mock_sync_client,
)
config = {"configurable": {"thread_id": "thread1"}}
state_snapshot = remote_pregel.get_state(config)
assert state_snapshot == StateSnapshot(
values={"messages": [{"type": "human", "content": "hello"}]},
next=(),
config={
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
},
metadata={},
created_at="timestamp",
parent_config=None,
tasks=(),
)
@pytest.mark.anyio
async def test_aget_state():
mock_async_client = AsyncMock()
mock_async_client.threads.get_state.return_value = {
"values": {"messages": [{"type": "human", "content": "hello"}]},
"next": None,
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_2",
"checkpoint_map": {},
},
"metadata": {},
"created_at": "timestamp",
"parent_checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
},
"tasks": [],
}
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
client=mock_async_client,
)
config = {"configurable": {"thread_id": "thread1"}}
state_snapshot = await remote_pregel.aget_state(config)
assert state_snapshot == StateSnapshot(
values={"messages": [{"type": "human", "content": "hello"}]},
next=(),
config={
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_2",
"checkpoint_map": {},
}
},
metadata={},
created_at="timestamp",
parent_config={
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
},
tasks=(),
)
def test_get_state_history():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.threads.get_history.return_value = [
{
"values": {"messages": [{"type": "human", "content": "hello"}]},
"next": None,
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
},
"metadata": {},
"created_at": "timestamp",
"parent_checkpoint": None,
"tasks": [],
}
]
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
sync_client=mock_sync_client,
)
config = {"configurable": {"thread_id": "thread1"}}
state_history_snapshot = list(
remote_pregel.get_state_history(config, filter=None, before=None, limit=None)
)
assert len(state_history_snapshot) == 1
assert state_history_snapshot[0] == StateSnapshot(
values={"messages": [{"type": "human", "content": "hello"}]},
next=(),
config={
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
},
metadata={},
created_at="timestamp",
parent_config=None,
tasks=(),
)
@pytest.mark.anyio
async def test_aget_state_history():
# set up test
mock_async_client = AsyncMock()
mock_async_client.threads.get_history.return_value = [
{
"values": {"messages": [{"type": "human", "content": "hello"}]},
"next": None,
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
},
"metadata": {},
"created_at": "timestamp",
"parent_checkpoint": None,
"tasks": [],
}
]
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
client=mock_async_client,
)
config = {"configurable": {"thread_id": "thread1"}}
state_history_snapshot = []
async for state_snapshot in remote_pregel.aget_state_history(
config, filter=None, before=None, limit=None
):
state_history_snapshot.append(state_snapshot)
assert len(state_history_snapshot) == 1
assert state_history_snapshot[0] == StateSnapshot(
values={"messages": [{"type": "human", "content": "hello"}]},
next=(),
config={
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
},
metadata={},
created_at="timestamp",
parent_config=None,
tasks=(),
)
def test_update_state():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.threads.update_state.return_value = {
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
}
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
sync_client=mock_sync_client,
)
config = {"configurable": {"thread_id": "thread1"}}
response = remote_pregel.update_state(config, {"key": "value"})
assert response == {
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
}
@pytest.mark.anyio
async def test_aupdate_state():
# set up test
mock_async_client = AsyncMock()
mock_async_client.threads.update_state.return_value = {
"checkpoint": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
}
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
client=mock_async_client,
)
config = {"configurable": {"thread_id": "thread1"}}
response = await remote_pregel.aupdate_state(config, {"key": "value"})
assert response == {
"configurable": {
"thread_id": "thread_1",
"checkpoint_ns": "ns",
"checkpoint_id": "checkpoint_1",
"checkpoint_map": {},
}
}
def test_stream():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.runs.stream.return_value = [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
sync_client=mock_sync_client,
)
    # stream_mode doesn't include 'updates'
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode="values",
):
stream_parts.append(stream_part)
assert stream_parts == [
{"chunk": "data1"},
{"chunk": "data2"},
{"chunk": "data3"},
]
mock_sync_client.runs.stream.return_value = [
StreamPart(event="updates", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]
# default stream_mode is updates
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
):
stream_parts.append(stream_part)
assert stream_parts == [
{"chunk": "data3"},
{"chunk": "data4"},
]
# list stream_mode includes mode names
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
):
stream_parts.append(stream_part)
assert stream_parts == [
("updates", {"chunk": "data3"}),
("updates", {"chunk": "data4"}),
]
# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
((), "updates", {"chunk": "data3"}),
((), "updates", {"chunk": "data4"}),
]
# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
for stream_part in remote_pregel.stream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
((), {"chunk": "data3"}),
((), {"chunk": "data4"}),
]
@pytest.mark.anyio
async def test_astream():
# set up test
mock_async_client = MagicMock()
async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(event="values", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]
mock_async_client.runs.stream.return_value = async_iter
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
client=mock_async_client,
)
    # stream_mode doesn't include 'updates'
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode="values",
):
stream_parts.append(stream_part)
assert stream_parts == [
{"chunk": "data1"},
{"chunk": "data2"},
{"chunk": "data3"},
]
async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="updates", data={"chunk": "data3"}),
StreamPart(event="updates", data={"chunk": "data4"}),
StreamPart(event="updates", data={"__interrupt__": ()}),
]
mock_async_client.runs.stream.return_value = async_iter
# default stream_mode is updates
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
):
stream_parts.append(stream_part)
assert stream_parts == [
{"chunk": "data3"},
{"chunk": "data4"},
]
# list stream_mode includes mode names
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
):
stream_parts.append(stream_part)
assert stream_parts == [
("updates", {"chunk": "data3"}),
("updates", {"chunk": "data4"}),
]
# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
((), "updates", {"chunk": "data3"}),
((), "updates", {"chunk": "data4"}),
]
# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
((), {"chunk": "data3"}),
((), {"chunk": "data4"}),
]
async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="updates|my|subgraph", data={"chunk": "data3"}),
StreamPart(event="updates|hello|subgraph", data={"chunk": "data4"}),
StreamPart(event="updates|bye|subgraph", data={"__interrupt__": ()}),
]
mock_async_client.runs.stream.return_value = async_iter
# subgraphs + list modes
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
stream_mode=["updates"],
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
(("my", "subgraph"), "updates", {"chunk": "data3"}),
(("hello", "subgraph"), "updates", {"chunk": "data4"}),
]
# subgraphs + single mode
stream_parts = []
with pytest.raises(GraphInterrupt):
async for stream_part in remote_pregel.astream(
{"input": "data"},
config={"configurable": {"thread_id": "thread_1"}},
subgraphs=True,
):
stream_parts.append(stream_part)
assert stream_parts == [
(("my", "subgraph"), {"chunk": "data3"}),
(("hello", "subgraph"), {"chunk": "data4"}),
]
def test_invoke():
# set up test
mock_sync_client = MagicMock()
mock_sync_client.runs.stream.return_value = [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(
event="values", data={"messages": [{"type": "human", "content": "world"}]}
),
]
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
sync_client=mock_sync_client,
)
config = {"configurable": {"thread_id": "thread_1"}}
result = remote_pregel.invoke(
{"input": {"messages": [{"type": "human", "content": "hello"}]}}, config
)
assert result == {"messages": [{"type": "human", "content": "world"}]}
@pytest.mark.anyio
async def test_ainvoke():
# set up test
mock_async_client = MagicMock()
async_iter = MagicMock()
async_iter.__aiter__.return_value = [
StreamPart(event="values", data={"chunk": "data1"}),
StreamPart(event="values", data={"chunk": "data2"}),
StreamPart(
event="values", data={"messages": [{"type": "human", "content": "world"}]}
),
]
mock_async_client.runs.stream.return_value = async_iter
# call method / assertions
remote_pregel = RemoteGraph(
"test_graph_id",
client=mock_async_client,
)
config = {"configurable": {"thread_id": "thread_1"}}
result = await remote_pregel.ainvoke(
{"input": {"messages": [{"type": "human", "content": "hello"}]}}, config
)
assert result == {"messages": [{"type": "human", "content": "world"}]}
@pytest.mark.skip("Unskip this test to manually test the LangGraph Cloud integration")
@pytest.mark.anyio
async def test_langgraph_cloud_integration():
from langgraph_sdk.client import get_client, get_sync_client
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
# create RemotePregel instance
client = get_client()
sync_client = get_sync_client()
remote_pregel = RemoteGraph(
"agent",
client=client,
sync_client=sync_client,
)
# define graph
workflow = StateGraph(MessagesState)
workflow.add_node("agent", remote_pregel)
workflow.add_edge(START, "agent")
workflow.add_edge("agent", END)
app = workflow.compile(checkpointer=MemorySaver())
# test invocation
input = {
"messages": [
{
"role": "human",
"content": "What's the weather in SF?",
}
]
}
# test invoke
response = app.invoke(
input,
config={"configurable": {"thread_id": "39a6104a-34e7-4f83-929c-d9eb163003c9"}},
interrupt_before=["agent"],
)
print("response:", response["messages"][-1].content)
# test stream
async for chunk in app.astream(
input,
config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}},
subgraphs=True,
stream_mode=["debug", "messages"],
):
print("chunk:", chunk)
# test stream events
async for chunk in remote_pregel.astream_events(
input,
config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}},
version="v2",
subgraphs=True,
stream_mode=[],
):
print("chunk:", chunk)
# test get state
state_snapshot = await remote_pregel.aget_state(
config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}},
subgraphs=True,
)
print("state snapshot:", state_snapshot)
# test update state
response = await remote_pregel.aupdate_state(
config={"configurable": {"thread_id": "6645e002-ed50-4022-92a3-d0d186fdf812"}},
values={
"messages": [
{
"role": "ai",
"content": "Hello world again!",
}
]
},
)
print("response:", response)
# test get history
async for state in remote_pregel.aget_state_history(
config={"configurable": {"thread_id": "2dc3e3e7-39ac-4597-aa57-4404b944e82a"}},
):
print("state snapshot:", state)
# test get graph
remote_pregel.graph_id = "fe096781-5601-53d2-b2f6-0d3403f7e9ca" # must be UUID
graph = await remote_pregel.aget_graph(xray=True)
print("graph:", graph)
```
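As the tests above show, RemoteGraph only needs a client whose runs.stream yields StreamPart events: invoke() consumes the "values" stream and returns the last chunk as the final state, while an "__interrupt__" update is surfaced as GraphInterrupt. A minimal sketch with a mocked sync client (the field names here are illustrative):
```py
from unittest.mock import MagicMock

from langgraph_sdk.schema import StreamPart

from langgraph.pregel.remote import RemoteGraph

mock_sync_client = MagicMock()
mock_sync_client.runs.stream.return_value = [
    StreamPart(event="values", data={"answer": "draft"}),
    StreamPart(event="values", data={"answer": "final"}),
]
remote = RemoteGraph("test_graph_id", sync_client=mock_sync_client)
config = {"configurable": {"thread_id": "thread_1"}}

# invoke() returns the last "values" chunk produced by the run
assert remote.invoke({"question": "hi"}, config) == {"answer": "final"}
```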
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/fake_chat.py`:
```py
import re
from typing import Any, AsyncIterator, Iterator, List, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
class FakeChatModel(GenericFakeChatModel):
messages: list[BaseMessage]
i: int = 0
def bind_tools(self, functions: list):
return self
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
if self.i >= len(self.messages):
self.i = 0
message = self.messages[self.i]
self.i += 1
if isinstance(message, str):
message_ = AIMessage(content=message)
else:
if hasattr(message, "model_copy"):
message_ = message.model_copy()
else:
message_ = message.copy()
generation = ChatGeneration(message=message_)
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model."""
chat_result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
if not isinstance(chat_result, ChatResult):
raise ValueError(
f"Expected generate to return a ChatResult, "
f"but got {type(chat_result)} instead."
)
message = chat_result.generations[0].message
if not isinstance(message, AIMessage):
raise ValueError(
f"Expected invoke to return an AIMessage, "
f"but got {type(message)} instead."
)
content = message.content
if content:
# Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output.
assert isinstance(content, str)
content_chunks = cast(list[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, id=message.id)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
else:
args = message.__dict__
args.pop("type")
chunk = ChatGenerationChunk(message=AIMessageChunk(**args))
if run_manager:
run_manager.on_llm_new_token("", chunk=chunk)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
"""Stream the output of the model."""
chat_result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
if not isinstance(chat_result, ChatResult):
raise ValueError(
f"Expected generate to return a ChatResult, "
f"but got {type(chat_result)} instead."
)
message = chat_result.generations[0].message
if not isinstance(message, AIMessage):
raise ValueError(
f"Expected invoke to return an AIMessage, "
f"but got {type(message)} instead."
)
content = message.content
if content:
# Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output.
assert isinstance(content, str)
content_chunks = cast(list[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, id=message.id)
)
if run_manager:
                    await run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
else:
args = message.__dict__
args.pop("type")
chunk = ChatGenerationChunk(message=AIMessageChunk(**args))
if run_manager:
await run_manager.on_llm_new_token("", chunk=chunk)
yield chunk
```
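FakeChatModel replays its canned messages list in order (wrapping around when exhausted), and its streaming methods re-split the chosen message on whitespace with a capture group so the emitted AIMessageChunk tokens reproduce the original spacing. A minimal usage sketch:
```py
from langchain_core.messages import AIMessage

from tests.fake_chat import FakeChatModel

model = FakeChatModel(messages=[AIMessage(content="hello world")])

# streaming yields one chunk per token, including the whitespace tokens
assert [chunk.content for chunk in model.stream("ignored input")] == ["hello", " ", "world"]

# the internal index wraps, so the next call replays the same message
assert model.invoke("ignored input").content == "hello world"
```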
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/any_int.py`:
```py
class AnyInt(int):
def __init__(self) -> None:
super().__init__()
def __eq__(self, other: object) -> bool:
return isinstance(other, int)
```
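AnyInt is an equality sentinel that compares equal to any integer (Python tries the subclass's __eq__ first, so it also works when the real value is on the left-hand side), which keeps assertions over values like step counters or timestamps readable. For example:
```py
from tests.any_int import AnyInt

assert AnyInt() == 42
assert AnyInt() != "42"
# useful inside nested expected values where the exact integer is unpredictable
assert {"step": 7, "name": "agent"} == {"step": AnyInt(), "name": "agent"}
```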
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_interruption.py`:
```py
from typing import TypedDict
import pytest
from pytest_mock import MockerFixture
from langgraph.graph import END, START, StateGraph
from tests.conftest import (
ALL_CHECKPOINTERS_ASYNC,
ALL_CHECKPOINTERS_SYNC,
awith_checkpointer,
)
pytestmark = pytest.mark.anyio
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interruption_without_state_updates(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
"""Test interruption without state updates. This test confirms that
    interrupting doesn't require that a state key was updated in the previous step"""
class State(TypedDict):
input: str
def noop(_state):
pass
builder = StateGraph(State)
builder.add_node("step_1", noop)
builder.add_node("step_2", noop)
builder.add_node("step_3", noop)
builder.add_edge(START, "step_1")
builder.add_edge("step_1", "step_2")
builder.add_edge("step_2", "step_3")
builder.add_edge("step_3", END)
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
graph = builder.compile(checkpointer=checkpointer, interrupt_after="*")
initial_input = {"input": "hello world"}
thread = {"configurable": {"thread_id": "1"}}
graph.invoke(initial_input, thread, debug=True)
assert graph.get_state(thread).next == ("step_2",)
graph.invoke(None, thread, debug=True)
assert graph.get_state(thread).next == ("step_3",)
graph.invoke(None, thread, debug=True)
assert graph.get_state(thread).next == ()
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interruption_without_state_updates_async(
checkpointer_name: str, mocker: MockerFixture
):
"""Test interruption without state updates. This test confirms that
    interrupting doesn't require that a state key was updated in the previous step"""
class State(TypedDict):
input: str
async def noop(_state):
pass
builder = StateGraph(State)
builder.add_node("step_1", noop)
builder.add_node("step_2", noop)
builder.add_node("step_3", noop)
builder.add_edge(START, "step_1")
builder.add_edge("step_1", "step_2")
builder.add_edge("step_2", "step_3")
builder.add_edge("step_3", END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_after="*")
initial_input = {"input": "hello world"}
thread = {"configurable": {"thread_id": "1"}}
await graph.ainvoke(initial_input, thread, debug=True)
assert (await graph.aget_state(thread)).next == ("step_2",)
await graph.ainvoke(None, thread, debug=True)
assert (await graph.aget_state(thread)).next == ("step_3",)
await graph.ainvoke(None, thread, debug=True)
assert (await graph.aget_state(thread)).next == ()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/memory_assert.py`:
```py
import asyncio
import os
import tempfile
from collections import defaultdict
from functools import partial
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
copy_checkpoint,
)
from langgraph.checkpoint.memory import MemorySaver, PersistentDict
class NoopSerializer(SerializerProtocol):
def loads_typed(self, data: tuple[str, bytes]) -> Any:
return data[1]
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return "type", obj
class MemorySaverAssertImmutable(MemorySaver):
storage_for_copies: defaultdict[str, dict[str, dict[str, Checkpoint]]]
def __init__(
self,
*,
serde: Optional[SerializerProtocol] = None,
put_sleep: Optional[float] = None,
) -> None:
_, filename = tempfile.mkstemp()
super().__init__(
serde=serde, factory=partial(PersistentDict, filename=filename)
)
self.storage_for_copies = defaultdict(lambda: defaultdict(dict))
self.put_sleep = put_sleep
self.stack.callback(os.remove, filename)
    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
if self.put_sleep:
import time
time.sleep(self.put_sleep)
# assert checkpoint hasn't been modified since last written
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
if saved := super().get(config):
assert (
self.serde.loads_typed(
self.storage_for_copies[thread_id][checkpoint_ns][saved["id"]]
)
== saved
)
self.storage_for_copies[thread_id][checkpoint_ns][checkpoint["id"]] = (
self.serde.dumps_typed(copy_checkpoint(checkpoint))
)
# call super to write checkpoint
return super().put(config, checkpoint, metadata, new_versions)
class MemorySaverAssertCheckpointMetadata(MemorySaver):
"""This custom checkpointer is for verifying that a run's configurable
fields are merged with the previous checkpoint config for each step in
the run. This is the desired behavior. Because the checkpointer's (a)put()
method is called for each step, the implementation of this checkpointer
should produce a side effect that can be asserted.
"""
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
    ) -> RunnableConfig:
"""The implementation of put() merges config["configurable"] (a run's
configurable fields) with the metadata field. The state of the
checkpoint metadata can be asserted to confirm that the run's
configurable fields were merged with the previous checkpoint config.
"""
configurable = config["configurable"].copy()
# remove checkpoint_id to make testing simpler
checkpoint_id = configurable.pop("checkpoint_id", None)
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
self.storage[thread_id][checkpoint_ns].update(
{
checkpoint["id"]: (
self.serde.dumps_typed(checkpoint),
# merge configurable fields and metadata
self.serde.dumps_typed({**configurable, **metadata}),
checkpoint_id,
)
}
)
return {
"configurable": {
"thread_id": config["configurable"]["thread_id"],
"checkpoint_id": checkpoint["id"],
}
}
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
return await asyncio.get_running_loop().run_in_executor(
None, self.put, config, checkpoint, metadata, new_versions
)
class MemorySaverNoPending(MemorySaver):
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
result = super().get_tuple(config)
if result:
return CheckpointTuple(result.config, result.checkpoint, result.metadata)
return result
```
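MemorySaverAssertCheckpointMetadata's put() persists {**configurable, **metadata} as the checkpoint metadata, so a test can pass an arbitrary configurable key on the run config and later assert that it was merged into the stored metadata. A minimal sketch, assuming a one-node graph (the "run_owner" key is purely illustrative):
```py
from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from tests.memory_assert import MemorySaverAssertCheckpointMetadata


class State(TypedDict):
    value: str


builder = StateGraph(State)
builder.add_node("shout", lambda state: {"value": state["value"] + "!"})
builder.add_edge(START, "shout")
builder.add_edge("shout", END)

checkpointer = MemorySaverAssertCheckpointMetadata()
graph = builder.compile(checkpointer=checkpointer)

# the extra configurable key travels with the run config into every put() call
graph.invoke({"value": "hi"}, {"configurable": {"thread_id": "1", "run_owner": "tests"}})

# and ends up merged into the metadata stored for the latest checkpoint
saved = checkpointer.get_tuple({"configurable": {"thread_id": "1"}})
assert saved.metadata["run_owner"] == "tests"
```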
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_pregel_async.py`:
```py
import asyncio
import logging
import operator
import random
import re
import sys
import uuid
from collections import Counter
from contextlib import asynccontextmanager, contextmanager
from dataclasses import replace
from time import perf_counter
from typing import (
Annotated,
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generator,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
cast,
)
from uuid import UUID
import httpx
import pytest
from langchain_core.messages import ToolCall
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
RunnablePassthrough,
RunnablePick,
)
from langchain_core.utils.aiter import aclosing
from pydantic import BaseModel
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.context import Context
from langgraph.channels.last_value import LastValue
from langgraph.channels.topic import Topic
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.checkpoint.base import (
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.constants import (
CONFIG_KEY_NODE_FINISHED,
ERROR,
FF_SEND_V2,
PULL,
PUSH,
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot
from langgraph.pregel.retry import RetryPolicy
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import (
Command,
Interrupt,
PregelTask,
Send,
StreamWriter,
interrupt,
)
from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence
from tests.conftest import (
ALL_CHECKPOINTERS_ASYNC,
ALL_CHECKPOINTERS_ASYNC_PLUS_NONE,
ALL_STORES_ASYNC,
SHOULD_CHECK_SNAPSHOTS,
awith_checkpointer,
awith_store,
)
from tests.fake_chat import FakeChatModel
from tests.fake_tracer import FakeTracer
from tests.memory_assert import (
MemorySaverAssertCheckpointMetadata,
MemorySaverNoPending,
)
from tests.messages import (
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
_AnyIdToolMessage,
)
logger = logging.getLogger(__name__)
pytestmark = pytest.mark.anyio
async def test_checkpoint_errors() -> None:
class FaultyGetCheckpointer(MemorySaver):
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
raise ValueError("Faulty get_tuple")
class FaultyPutCheckpointer(MemorySaver):
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
raise ValueError("Faulty put")
class FaultyPutWritesCheckpointer(MemorySaver):
async def aput_writes(
self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str
) -> RunnableConfig:
raise ValueError("Faulty put_writes")
class FaultyVersionCheckpointer(MemorySaver):
def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int:
raise ValueError("Faulty get_next_version")
def logic(inp: str) -> str:
return ""
builder = StateGraph(Annotated[str, operator.add])
builder.add_node("agent", logic)
builder.add_edge(START, "agent")
graph = builder.compile(checkpointer=FaultyGetCheckpointer())
with pytest.raises(ValueError, match="Faulty get_tuple"):
await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}})
with pytest.raises(ValueError, match="Faulty get_tuple"):
async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}):
pass
with pytest.raises(ValueError, match="Faulty get_tuple"):
async for _ in graph.astream_events(
"", {"configurable": {"thread_id": "thread-3"}}, version="v2"
):
pass
graph = builder.compile(checkpointer=FaultyPutCheckpointer())
with pytest.raises(ValueError, match="Faulty put"):
await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}})
with pytest.raises(ValueError, match="Faulty put"):
async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}):
pass
with pytest.raises(ValueError, match="Faulty put"):
async for _ in graph.astream_events(
"", {"configurable": {"thread_id": "thread-3"}}, version="v2"
):
pass
graph = builder.compile(checkpointer=FaultyVersionCheckpointer())
with pytest.raises(ValueError, match="Faulty get_next_version"):
await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}})
with pytest.raises(ValueError, match="Faulty get_next_version"):
async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}):
pass
with pytest.raises(ValueError, match="Faulty get_next_version"):
async for _ in graph.astream_events(
"", {"configurable": {"thread_id": "thread-3"}}, version="v2"
):
pass
# add a parallel node
builder.add_node("parallel", logic)
builder.add_edge(START, "parallel")
graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer())
with pytest.raises(ValueError, match="Faulty put_writes"):
await graph.ainvoke("", {"configurable": {"thread_id": "thread-1"}})
with pytest.raises(ValueError, match="Faulty put_writes"):
async for _ in graph.astream("", {"configurable": {"thread_id": "thread-2"}}):
pass
with pytest.raises(ValueError, match="Faulty put_writes"):
async for _ in graph.astream_events(
"", {"configurable": {"thread_id": "thread-3"}}, version="v2"
):
pass
async def test_node_cancellation_on_external_cancel() -> None:
inner_task_cancelled = False
async def awhile(input: Any) -> None:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
builder = Graph()
builder.add_node("agent", awhile)
builder.set_entry_point("agent")
builder.set_finish_point("agent")
graph = builder.compile()
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(graph.ainvoke(1), 0.5)
assert inner_task_cancelled
async def test_node_cancellation_on_other_node_exception() -> None:
inner_task_cancelled = False
async def awhile(input: Any) -> None:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
async def iambad(input: Any) -> None:
raise ValueError("I am bad")
builder = Graph()
builder.add_node("agent", awhile)
builder.add_node("bad", iambad)
builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END)
graph = builder.compile()
with pytest.raises(ValueError, match="I am bad"):
# This will raise ValueError, not TimeoutError
await asyncio.wait_for(graph.ainvoke(1), 0.5)
assert inner_task_cancelled
async def test_node_cancellation_on_other_node_exception_two() -> None:
async def awhile(input: Any) -> None:
await asyncio.sleep(1)
async def iambad(input: Any) -> None:
raise ValueError("I am bad")
builder = Graph()
builder.add_node("agent", awhile)
builder.add_node("bad", iambad)
builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END)
graph = builder.compile()
with pytest.raises(ValueError, match="I am bad"):
# This will raise ValueError, not CancelledError
await graph.ainvoke(1)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_dynamic_interrupt(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_node_count = 0
async def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert await tool_two.ainvoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}
async with awith_checkpointer(checkpointer_name) as checkpointer:
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread2
)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [
c async for c in tool_two.astream(Command(resume=" my answer"), thread2)
] == [
{"tool_two": {"my_key": " my answer"}},
]
# flow: interrupt -> clear
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread1
)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# clear the interrupt and next tasks
await tool_two.aupdate_state(thread1, None, as_node=END)
# interrupt is cleared, as well as the next tasks
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None:
class SubgraphState(TypedDict):
my_key: str
market: str
tool_two_node_count = 0
def tool_two_node(s: SubgraphState) -> SubgraphState:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
subgraph = StateGraph(SubgraphState)
subgraph.add_node("do", tool_two_node, retry=RetryPolicy())
subgraph.add_edge(START, "do")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", subgraph.compile())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert await tool_two.ainvoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}
async with awith_checkpointer(checkpointer_name) as checkpointer:
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread2
)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
)
},
]
# resume with answer
assert [
c async for c in tool_two.astream(Command(resume=" my answer"), thread2)
] == [
{"tool_two": {"my_key": " my answer", "market": "DE"}},
]
# flow: interrupt -> clear
thread1 = {"configurable": {"thread_id": "1"}}
thread1root = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread1
)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
)
},
]
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("tool_two:"),
}
},
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1root, limit=2)
][-1].config,
)
# clear the interrupt and next tasks
await tool_two.aupdate_state(thread1, None, as_node=END)
# interrupt is cleared, as well as the next tasks
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1root, limit=2)
][-1].config,
)
@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled")
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_copy_checkpoint(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
def tool_one(s: State) -> State:
return {"my_key": " one"}
tool_two_node_count = 0
def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
def start(state: State) -> list[Union[Send, str]]:
return ["tool_two", Send("tool_one", state)]
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_node("tool_one", tool_one)
tool_two_graph.set_conditional_entry_point(start)
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert await tool_two.ainvoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value one",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value one"}
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value one all good",
"market": "US",
}
async with awith_checkpointer(checkpointer_name) as checkpointer:
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread2
)
] == [
{
"tool_one": {"my_key": " one"},
},
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [
c async for c in tool_two.astream(Command(resume=" my answer"), thread2)
] == [
{"tool_two": {"my_key": " my answer"}},
]
# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert await tool_two.ainvoke(
{"my_key": "value ⛰️", "market": "DE"}, thread1
) == {
"my_key": "value ⛰️ one",
"market": "DE",
}
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# clear the interrupt and next tasks
await tool_two.aupdate_state(thread1, None)
# interrupt is cleared, next task is kept
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_node_not_cancelled_on_other_node_interrupted(
checkpointer_name: str,
) -> None:
class State(TypedDict):
hello: Annotated[str, operator.add]
awhiles = 0
inner_task_cancelled = False
async def awhile(input: State) -> None:
nonlocal awhiles
awhiles += 1
try:
await asyncio.sleep(1)
return {"hello": " again"}
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
async def iambad(input: State) -> None:
return {"hello": interrupt("I am bad")}
builder = StateGraph(State)
builder.add_node("agent", awhile)
builder.add_node("bad", iambad)
builder.set_conditional_entry_point(lambda _: ["agent", "bad"], then=END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread = {"configurable": {"thread_id": "1"}}
# writes from "awhile" are applied to last chunk
assert await graph.ainvoke({"hello": "world"}, thread) == {
"hello": "world again"
}
assert not inner_task_cancelled
assert awhiles == 1
assert await graph.ainvoke(None, thread, debug=True) == {"hello": "world again"}
assert not inner_task_cancelled
assert awhiles == 1
# resume with answer
assert await graph.ainvoke(Command(resume=" okay"), thread) == {
"hello": "world again okay"
}
assert not inner_task_cancelled
assert awhiles == 1
async def test_step_timeout_on_stream_hang() -> None:
inner_task_cancelled = False
async def awhile(input: Any) -> None:
try:
await asyncio.sleep(1.5)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
async def alittlewhile(input: Any) -> None:
await asyncio.sleep(0.6)
return "1"
builder = Graph()
builder.add_node(awhile)
builder.add_node(alittlewhile)
builder.set_conditional_entry_point(lambda _: ["awhile", "alittlewhile"], then=END)
graph = builder.compile()
graph.step_timeout = 1
with pytest.raises(asyncio.TimeoutError):
async for chunk in graph.astream(1, stream_mode="updates"):
assert chunk == {"alittlewhile": {"alittlewhile": "1"}}
await asyncio.sleep(0.6)
assert inner_task_cancelled
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC_PLUS_NONE)
async def test_cancel_graph_astream(checkpointer_name: str) -> None:
class State(TypedDict):
value: Annotated[int, operator.add]
class AwhileMaker:
def __init__(self) -> None:
self.reset()
async def __call__(self, input: State) -> Any:
self.started = True
try:
await asyncio.sleep(1.5)
except asyncio.CancelledError:
self.cancelled = True
raise
def reset(self):
self.started = False
self.cancelled = False
async def alittlewhile(input: State) -> None:
await asyncio.sleep(0.6)
return {"value": 2}
awhile = AwhileMaker()
aparallelwhile = AwhileMaker()
builder = StateGraph(State)
builder.add_node("awhile", awhile)
builder.add_node("aparallelwhile", aparallelwhile)
builder.add_node(alittlewhile)
builder.add_edge(START, "alittlewhile")
builder.add_edge(START, "aparallelwhile")
builder.add_edge("alittlewhile", "awhile")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
# test interrupting astream
got_event = False
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
async with aclosing(graph.astream({"value": 1}, thread1)) as stream:
async for chunk in stream:
assert chunk == {"alittlewhile": {"value": 2}}
got_event = True
break
assert got_event
# node aparallelwhile should start, but be cancelled
assert aparallelwhile.started is True
assert aparallelwhile.cancelled is True
# node "awhile" should never start
assert awhile.started is False
# checkpoint with output of "alittlewhile" should not be saved
# but we should have applied pending writes
if checkpointer is not None:
state = await graph.aget_state(thread1)
assert state is not None
assert state.values == {"value": 3} # 1 + 2
assert state.next == ("aparallelwhile",)
assert state.metadata == {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
}
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC_PLUS_NONE)
async def test_cancel_graph_astream_events_v2(checkpointer_name: Optional[str]) -> None:
class State(TypedDict):
value: int
class AwhileMaker:
def __init__(self) -> None:
self.reset()
async def __call__(self, input: State) -> Any:
self.started = True
try:
await asyncio.sleep(1.5)
except asyncio.CancelledError:
self.cancelled = True
raise
def reset(self):
self.started = False
self.cancelled = False
async def alittlewhile(input: State) -> None:
await asyncio.sleep(0.6)
return {"value": 2}
awhile = AwhileMaker()
anotherwhile = AwhileMaker()
builder = StateGraph(State)
builder.add_node(alittlewhile)
builder.add_node("awhile", awhile)
builder.add_node("anotherwhile", anotherwhile)
builder.add_edge(START, "alittlewhile")
builder.add_edge("alittlewhile", "awhile")
builder.add_edge("awhile", "anotherwhile")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
# test interrupting astream_events v2
got_event = False
thread2: RunnableConfig = {"configurable": {"thread_id": "2"}}
async with aclosing(
graph.astream_events({"value": 1}, thread2, version="v2")
) as stream:
async for chunk in stream:
if chunk["event"] == "on_chain_stream" and not chunk["parent_ids"]:
got_event = True
assert chunk["data"]["chunk"] == {"alittlewhile": {"value": 2}}
await asyncio.sleep(0.1)
break
# did break
assert got_event
# node "awhile" maybe starts (impl detail of astream_events)
# if it does start, it must be cancelled
if awhile.started:
assert awhile.cancelled is True
# node "anotherwhile" should never start
assert anotherwhile.started is False
# checkpoint with output of "alittlewhile" should not be saved
if checkpointer is not None:
state = await graph.aget_state(thread2)
assert state is not None
assert state.values == {"value": 2}
assert state.next == ("awhile",)
assert state.metadata == {
"parents": {},
"source": "loop",
"step": 1,
"writes": {"alittlewhile": {"value": 2}},
"thread_id": "2",
}
async def test_node_schemas_custom_output() -> None:
class State(TypedDict):
hello: str
bye: str
messages: Annotated[list[str], add_messages]
class Output(TypedDict):
messages: list[str]
class StateForA(TypedDict):
hello: str
messages: Annotated[list[str], add_messages]
async def node_a(state: StateForA):
assert state == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
class StateForB(TypedDict):
bye: str
now: int
async def node_b(state: StateForB):
assert state == {
"bye": "world",
}
return {
"now": 123,
"hello": "again",
}
class StateForC(TypedDict):
hello: str
now: int
async def node_c(state: StateForC):
assert state == {
"hello": "again",
"now": 123,
}
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()
assert await graph.ainvoke(
{"hello": "there", "bye": "world", "messages": "hello"}
) == {
"messages": [_AnyIdHumanMessage(content="hello")],
}
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()
assert await graph.ainvoke(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
) == {
"messages": [_AnyIdHumanMessage(content="hello")],
}
assert [
c
async for c in graph.astream(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
)
] == [
{"a": None},
{"b": {"hello": "again", "now": 123}},
{"c": None},
]
async def test_invoke_single_process_in_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={
"one": chain,
},
channels={
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
)
graph = Graph()
graph.add_node("add_one", add_one)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one")
gapp = graph.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "integer",
}
assert await app.ainvoke(2) == 3
assert await app.ainvoke(2, output_keys=["output"]) == {"output": 3}
assert await gapp.ainvoke(2) == 3
@pytest.mark.parametrize(
"falsy_value",
[None, False, 0, "", [], {}, set(), frozenset(), 0.0, 0j],
)
async def test_invoke_single_process_in_out_falsy_values(falsy_value: Any) -> None:
graph = Graph()
graph.add_node("return_falsy_const", lambda *args, **kwargs: falsy_value)
graph.set_entry_point("return_falsy_const")
graph.set_finish_point("return_falsy_const")
gapp = graph.compile()
assert falsy_value == await gapp.ainvoke(1)
async def test_invoke_single_process_in_write_kwargs(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = (
Channel.subscribe_to("input")
| add_one
| Channel.write_to("output", fixed=5, output_plus_one=lambda x: x + 1)
)
app = Pregel(
nodes={"one": chain},
channels={
"input": LastValue(int),
"output": LastValue(int),
"fixed": LastValue(int),
"output_plus_one": LastValue(int),
},
output_channels=["output", "fixed", "output_plus_one"],
input_channels="input",
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None},
"fixed": {"title": "Fixed", "type": "integer", "default": None},
"output_plus_one": {
"title": "Output Plus One",
"type": "integer",
"default": None,
},
},
}
assert await app.ainvoke(2) == {"output": 3, "fixed": 5, "output_plus_one": 4}
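# With output_channels given as a list, the output schema becomes an object
# keyed by channel name, even for a single channel.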
async def test_invoke_single_process_in_out_dict(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": chain},
channels={"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels=["output"],
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None}
},
}
assert await app.ainvoke(2) == {"output": 3}
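# Same as above, but input_channels is a list too, so the input schema is an
# object keyed by channel name as well.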
async def test_invoke_single_process_in_dict_out_dict(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": chain},
channels={"input": LastValue(int), "output": LastValue(int)},
input_channels=["input"],
output_channels=["output"],
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "object",
"properties": {
"input": {"title": "Input", "type": "integer", "default": None}
},
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None}
},
}
assert await app.ainvoke({"input": 2}) == {"output": 3}
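# Two chained nodes ("one" -> inbox -> "two"): ainvoke returns the final
# output, astream surfaces the declared stream_channels step by step, and a
# recursion_limit of 1 is too low to complete both steps.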
async def test_invoke_two_processes_in_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
stream_channels=["inbox", "output"],
)
assert await app.ainvoke(2) == 4
with pytest.raises(GraphRecursionError):
await app.ainvoke(2, {"recursion_limit": 1})
step = 0
async for values in app.astream(2):
step += 1
if step == 1:
assert values == {
"inbox": 3,
}
elif step == 2:
assert values == {
"inbox": 3,
"output": 4,
}
assert step == 2
graph = Graph()
graph.add_node("add_one", add_one)
graph.add_node("add_one_more", add_one)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one_more")
graph.add_edge("add_one", "add_one_more")
gapp = graph.compile()
assert await gapp.ainvoke(2) == 4
step = 0
async for values in gapp.astream(2):
step += 1
if step == 1:
assert values == {
"add_one": 3,
}
elif step == 2:
assert values == {
"add_one_more": 4,
}
assert step == 2
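# Interrupting after node "one" lets the caller inspect the checkpoint,
# update state, and resume with ainvoke(None, config).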
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_two_processes_in_out_interrupt(
checkpointer_name: str, mocker: MockerFixture
) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
interrupt_after_nodes=["one"],
)
thread1 = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
# start execution, stop at inbox
assert await app.ainvoke(2, thread1) is None
# inbox == 3
checkpoint = await checkpointer.aget(thread1)
assert checkpoint is not None
assert checkpoint["channel_values"]["inbox"] == 3
# resume execution, finish
assert await app.ainvoke(None, thread1) == 4
# start execution again, stop at inbox
assert await app.ainvoke(20, thread1) is None
# inbox == 21
checkpoint = await checkpointer.aget(thread1)
assert checkpoint is not None
assert checkpoint["channel_values"]["inbox"] == 21
# send a new value in, interrupting the previous execution
assert await app.ainvoke(3, thread1) is None
assert await app.ainvoke(None, thread1) == 5
# start execution again, stopping at inbox
assert await app.ainvoke(20, thread2) is None
# inbox == 21
snapshot = await app.aget_state(thread2)
assert snapshot.values["inbox"] == 21
assert snapshot.next == ("two",)
# update the state, resume
await app.aupdate_state(thread2, 25, as_node="one")
assert await app.ainvoke(None, thread2) == 26
# no pending tasks
snapshot = await app.aget_state(thread2)
assert snapshot.next == ()
# list history
history = [c async for c in app.aget_state_history(thread1)]
assert history == [
StateSnapshot(
values={"inbox": 4, "output": 5, "input": 3},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 6,
"writes": {"two": 5},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[1].config,
),
StateSnapshot(
values={"inbox": 4, "output": 4, "input": 3},
tasks=(
PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 5}),
),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 5,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[2].config,
),
StateSnapshot(
values={"inbox": 21, "output": 4, "input": 3},
tasks=(
PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 4}),
),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": 4,
"writes": {"input": 3},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[3].config,
),
StateSnapshot(
values={"inbox": 21, "output": 4, "input": 20},
tasks=(PregelTask(AnyStr(), "two", (PULL, "two")),),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[4].config,
),
StateSnapshot(
values={"inbox": 3, "output": 4, "input": 20},
tasks=(
PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 21}),
),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": 2,
"writes": {"input": 20},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[5].config,
),
StateSnapshot(
values={"inbox": 3, "output": 4, "input": 2},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"two": 4},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[6].config,
),
StateSnapshot(
values={"inbox": 3, "input": 2},
tasks=(
PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 4}),
),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[7].config,
),
StateSnapshot(
values={"input": 2},
tasks=(
PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 3}),
),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": -1,
"writes": {"input": 2},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# forking from any previous checkpoint should re-run nodes
assert [
c async for c in app.astream(None, history[0].config, stream_mode="updates")
] == []
assert [
c async for c in app.astream(None, history[1].config, stream_mode="updates")
] == [
{"two": {"output": 5}},
]
assert [
c async for c in app.astream(None, history[2].config, stream_mode="updates")
] == [
{"one": {"inbox": 4}},
{"__interrupt__": ()},
]
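# A single add_one node loops via a conditional edge, adding 1 to the
# accumulated value each step until it reaches 6; the recorded history is
# then used to fork from earlier checkpoints.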
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_fork_always_re_runs_nodes(
checkpointer_name: str, mocker: MockerFixture
) -> None:
add_one = mocker.Mock(side_effect=lambda _: 1)
builder = StateGraph(Annotated[int, operator.add])
builder.add_node("add_one", add_one)
builder.add_edge(START, "add_one")
builder.add_conditional_edges("add_one", lambda cnt: "add_one" if cnt < 6 else END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
# run the graph to completion, streaming both values and updates
assert [
c
async for c in graph.astream(1, thread1, stream_mode=["values", "updates"])
] == [
("values", 1),
("updates", {"add_one": 1}),
("values", 2),
("updates", {"add_one": 1}),
("values", 3),
("updates", {"add_one": 1}),
("values", 4),
("updates", {"add_one": 1}),
("values", 5),
("updates", {"add_one": 1}),
("values", 6),
]
# list history
history = [c async for c in graph.aget_state_history(thread1)]
assert history == [
StateSnapshot(
values=6,
next=(),
tasks=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 5,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[1].config,
),
StateSnapshot(
values=5,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[2].config,
),
StateSnapshot(
values=4,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[3].config,
),
StateSnapshot(
values=3,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[4].config,
),
StateSnapshot(
values=2,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[5].config,
),
StateSnapshot(
values=1,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[6].config,
),
StateSnapshot(
values=0,
tasks=(
PregelTask(AnyStr(), "__start__", (PULL, "__start__"), result=1),
),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# forking from any previous checkpoint should re-run nodes
assert [
c
async for c in graph.astream(None, history[0].config, stream_mode="updates")
] == []
assert [
c
async for c in graph.astream(None, history[1].config, stream_mode="updates")
] == [
{"add_one": 1},
]
assert [
c
async for c in graph.astream(None, history[2].config, stream_mode="updates")
] == [
{"add_one": 1},
{"add_one": 1},
]
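# "inbox" is a Topic, so node "two" consumes batches of accumulated values:
# first the seeded 12, then the 3 produced by node "one".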
async def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = (
Channel.subscribe_to("inbox")
| RunnableLambda(add_one).abatch
| Channel.write_to("output").abatch
)
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": Topic(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels=["input", "inbox"],
stream_channels=["output", "inbox"],
output_channels=["output"],
)
# [12 + 1, 2 + 1 + 1]
assert [
c
async for c in app.astream(
{"input": 2, "inbox": 12}, output_keys="output", stream_mode="updates"
)
] == [
{"one": None},
{"two": 13},
{"two": 4},
]
assert [
c async for c in app.astream({"input": 2, "inbox": 12}, output_keys="output")
] == [13, 4]
assert [
c async for c in app.astream({"input": 2, "inbox": 12}, stream_mode="updates")
] == [
{"one": {"inbox": 3}},
{"two": {"output": 13}},
{"two": {"output": 4}},
]
assert [c async for c in app.astream({"input": 2, "inbox": 12})] == [
{"inbox": [3], "output": 13},
{"output": 4},
]
assert [
c async for c in app.astream({"input": 2, "inbox": 12}, stream_mode="debug")
] == [
{
"type": "task",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "one",
"input": 2,
"triggers": ["input"],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "two",
"input": [12],
"triggers": ["inbox"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "one",
"result": [("inbox", 3)],
"error": None,
"interrupts": [],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "two",
"result": [("output", 13)],
"error": None,
"interrupts": [],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "two",
"input": [3],
"triggers": ["inbox"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "two",
"result": [("output", 4)],
"error": None,
"interrupts": [],
},
},
]
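# abatch runs several inputs through the two-node chain; each input sleeps
# proportionally to its value, so results must stay aligned with their inputs.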
async def test_batch_two_processes_in_out() -> None:
async def add_one_with_delay(inp: int) -> int:
await asyncio.sleep(inp / 10)
return inp + 1
one = Channel.subscribe_to("input") | add_one_with_delay | Channel.write_to("one")
two = Channel.subscribe_to("one") | add_one_with_delay | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"one": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
assert await app.abatch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]
assert await app.abatch([3, 2, 1, 3, 5], output_keys=["output"]) == [
{"output": 5},
{"output": 4},
{"output": 3},
{"output": 5},
{"output": 7},
]
graph = Graph()
graph.add_node("add_one", add_one_with_delay)
graph.add_node("add_one_more", add_one_with_delay)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one_more")
graph.add_edge("add_one", "add_one_more")
gapp = graph.compile()
assert await gapp.abatch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]
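# A chain of test_size nodes: every invocation should return input + test_size,
# with no state leaking between sequential or concurrent invocations.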
async def test_invoke_many_processes_in_out(mocker: MockerFixture) -> None:
test_size = 100
add_one = mocker.Mock(side_effect=lambda x: x + 1)
nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
for i in range(test_size - 2):
nodes[str(i)] = (
Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
)
nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")
app = Pregel(
nodes=nodes,
channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
| {"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels="output",
)
# No state is left over from previous invocations
for _ in range(10):
assert await app.ainvoke(2, {"recursion_limit": test_size}) == 2 + test_size
# Concurrent invocations do not interfere with each other
assert await asyncio.gather(
*(app.ainvoke(2, {"recursion_limit": test_size}) for _ in range(10))
) == [2 + test_size for _ in range(10)]
async def test_batch_many_processes_in_out(mocker: MockerFixture) -> None:
test_size = 100
add_one = mocker.Mock(side_effect=lambda x: x + 1)
nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
for i in range(test_size - 2):
nodes[str(i)] = (
Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
)
nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")
app = Pregel(
nodes=nodes,
channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
| {"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels="output",
)
# No state is left over from previous invocations
for _ in range(3):
# then invoke the app in batch
assert await app.abatch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) == [
2 + test_size,
1 + test_size,
3 + test_size,
4 + test_size,
5 + test_size,
]
# Concurrent invocations do not interfere with each other
assert await asyncio.gather(
*(app.abatch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) for _ in range(3))
) == [
[2 + test_size, 1 + test_size, 3 + test_size, 4 + test_size, 5 + test_size]
for _ in range(3)
]
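# Two nodes writing to the same LastValue channel in one step is invalid;
# the follow-up test uses a Topic channel, which accumulates both writes.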
async def test_invoke_two_processes_two_in_two_out_invalid(
mocker: MockerFixture,
) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={"output": LastValue(int), "input": LastValue(int)},
input_channels="input",
output_channels="output",
)
with pytest.raises(InvalidUpdateError):
# LastValue channels can only be updated once per iteration
await app.ainvoke(2)
async def test_invoke_two_processes_two_in_two_out_valid(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"output": Topic(int),
},
input_channels="input",
output_channels="output",
)
# A Topic channel accumulates updates into a sequence
assert await app.ainvoke(2) == [3, 3]
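# With a checkpointer, the BinaryOperatorAggregate "total" channel persists
# across invocations on the same thread, and the RetryPolicy retries the
# transient ConnectionError raised once by the node.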
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_checkpoint(mocker: MockerFixture, checkpointer_name: str) -> None:
add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"])
errored_once = False
def raise_if_above_10(input: int) -> int:
nonlocal errored_once
if input > 4:
if errored_once:
pass
else:
errored_once = True
raise ConnectionError("I will be retried")
if input > 10:
raise ValueError("Input is too large")
return input
one = (
Channel.subscribe_to(["input"]).join(["total"])
| add_one
| Channel.write_to("output", "total")
| raise_if_above_10
)
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = Pregel(
nodes={"one": one},
channels={
"total": BinaryOperatorAggregate(int, operator.add),
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
retry_policy=RetryPolicy(),
)
# total starts out as 0, so output is 0+2=2
assert await app.ainvoke(2, {"configurable": {"thread_id": "1"}}) == 2
checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 2
# total is now 2, so output is 2+3=5
assert await app.ainvoke(3, {"configurable": {"thread_id": "1"}}) == 5
assert errored_once, "errored and retried"
checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 7
# total is now 2+5=7, so output would be 7+4=11, but raises ValueError
with pytest.raises(ValueError):
await app.ainvoke(4, {"configurable": {"thread_id": "1"}})
# checkpoint is not updated
checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 7
# on a new thread, total starts out as 0, so output is 0+5=5
assert await app.ainvoke(5, {"configurable": {"thread_id": "2"}}) == 5
checkpoint = await checkpointer.aget({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 7
checkpoint = await checkpointer.aget({"configurable": {"thread_id": "2"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 5
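# When one parallel node succeeds and the other fails, the successful node's
# writes are stored as pending writes and re-applied on resume instead of
# re-running that node.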
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_pending_writes_resume(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
class State(TypedDict):
value: Annotated[int, operator.add]
class AwhileMaker:
def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None:
self.sleep = sleep
self.rtn = rtn
self.reset()
async def __call__(self, input: State) -> Any:
self.calls += 1
await asyncio.sleep(self.sleep)
if isinstance(self.rtn, Exception):
raise self.rtn
else:
return self.rtn
def reset(self):
self.calls = 0
one = AwhileMaker(0.1, {"value": 2})
two = AwhileMaker(0.3, ConnectionError("I'm not good"))
builder = StateGraph(State)
builder.add_node("one", one)
builder.add_node("two", two, retry=RetryPolicy(max_attempts=2))
builder.add_edge(START, "one")
builder.add_edge(START, "two")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
with pytest.raises(ConnectionError, match="I'm not good"):
await graph.ainvoke({"value": 1}, thread1)
# both nodes should have been called once
assert one.calls == 1
assert two.calls == 2
# latest checkpoint should be before nodes "one", "two"
# but we should have applied pending writes from "one"
state = await graph.aget_state(thread1)
assert state is not None
assert state.values == {"value": 3}
assert state.next == ("two",)
assert state.tasks == (
PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}),
PregelTask(
AnyStr(),
"two",
(PULL, "two"),
'ConnectionError("I\'m not good")',
),
)
assert state.metadata == {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
}
# get_state with checkpoint_id should not apply any pending writes
state = await graph.aget_state(state.config)
assert state is not None
assert state.values == {"value": 1}
assert state.next == ("one", "two")
# should contain pending write of "one"
checkpoint = await checkpointer.aget_tuple(thread1)
assert checkpoint is not None
# should contain error from "two"
expected_writes = [
(AnyStr(), "one", "one"),
(AnyStr(), "value", 2),
(AnyStr(), ERROR, 'ConnectionError("I\'m not good")'),
]
assert len(checkpoint.pending_writes) == 3
assert all(w in expected_writes for w in checkpoint.pending_writes)
# both non-error pending writes come from the same task
non_error_writes = [w for w in checkpoint.pending_writes if w[1] != ERROR]
assert non_error_writes[0][0] == non_error_writes[1][0]
# error write is from the other task
error_write = next(w for w in checkpoint.pending_writes if w[1] == ERROR)
assert error_write[0] != non_error_writes[0][0]
# resume execution
with pytest.raises(ConnectionError, match="I'm not good"):
await graph.ainvoke(None, thread1)
# node "one" succeeded previously, so shouldn't be called again
assert one.calls == 1
# node "two" should have been called once again
assert two.calls == 4
# confirm no new checkpoints saved
state_two = await graph.aget_state(thread1)
assert state_two.metadata == state.metadata
# resume execution, without exception
two.rtn = {"value": 3}
# both the pending write and the new write were applied, 1 + 2 + 3 = 6
assert await graph.ainvoke(None, thread1) == {"value": 6}
# check all final checkpoints
checkpoints = [c async for c in checkpointer.alist(thread1)]
# we should have 3
assert len(checkpoints) == 3
# the latest checkpoint is not too interesting for this test
assert checkpoints[0] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {
"one": {
"start:one": AnyVersion(),
},
"two": {
"start:two": AnyVersion(),
},
"__input__": {},
"__start__": {
"__start__": AnyVersion(),
},
"__interrupt__": {
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
},
"channel_versions": {
"one": AnyVersion(),
"two": AnyVersion(),
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
"channel_values": {"one": "one", "two": "two", "value": 6},
},
metadata={
"parents": {},
"step": 1,
"source": "loop",
"writes": {"one": {"value": 2}, "two": {"value": 3}},
"thread_id": "1",
},
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": checkpoints[1].config["configurable"][
"checkpoint_id"
],
}
},
pending_writes=[],
)
# for the previous checkpoint we assert that pending writes contain both
# - the original error
# - the successful writes from resuming after the error was removed
assert checkpoints[1] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {
"__input__": {},
"__start__": {
"__start__": AnyVersion(),
},
},
"channel_versions": {
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
"channel_values": {
"value": 1,
"start:one": "__start__",
"start:two": "__start__",
},
},
metadata={
"parents": {},
"step": 0,
"source": "loop",
"writes": None,
"thread_id": "1",
},
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": checkpoints[2].config["configurable"][
"checkpoint_id"
],
}
},
pending_writes=UnsortedSequence(
(AnyStr(), "one", "one"),
(AnyStr(), "value", 2),
(AnyStr(), "__error__", 'ConnectionError("I\'m not good")'),
(AnyStr(), "two", "two"),
(AnyStr(), "value", 3),
),
)
assert checkpoints[2] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {"__input__": {}},
"channel_versions": {
"__start__": AnyVersion(),
},
"channel_values": {"__start__": {"value": 1}},
},
metadata={
"parents": {},
"step": -1,
"source": "input",
"writes": {"__start__": {"value": 1}},
"thread_id": "1",
},
parent_config=None,
pending_writes=UnsortedSequence(
(AnyStr(), "value", 1),
(AnyStr(), "start:one", "__start__"),
(AnyStr(), "start:two", "__start__"),
),
)
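# Re-running from an earlier checkpoint_id forks the thread's history while
# retaining the writes recorded up to that checkpoint.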
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_run_from_checkpoint_id_retains_previous_writes(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
class MyState(TypedDict):
myval: Annotated[int, operator.add]
otherval: bool
class Anode:
def __init__(self):
self.switch = False
async def __call__(self, state: MyState):
self.switch = not self.switch
return {"myval": 2 if self.switch else 1, "otherval": self.switch}
builder = StateGraph(MyState)
thenode = Anode()  # the same instance backs both nodes
builder.add_node("node_one", thenode)
builder.add_node("node_two", thenode)
builder.add_edge(START, "node_one")
def _getedge(src: str):
swap = "node_one" if src == "node_two" else "node_two"
def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]:
if st["myval"] > 3:
return END
if st["otherval"]:
return swap
return src
return _edge
builder.add_conditional_edges("node_one", _getedge("node_one"))
builder.add_conditional_edges("node_two", _getedge("node_two"))
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread_id = uuid.uuid4()
thread1 = {"configurable": {"thread_id": str(thread_id)}}
result = await graph.ainvoke({"myval": 1}, thread1)
assert result["myval"] == 4
history = [c async for c in graph.aget_state_history(thread1)]
assert len(history) == 4
assert history[-1].values == {"myval": 0}
assert history[0].values == {"myval": 4, "otherval": False}
second_run_config = {
**thread1,
"configurable": {
**thread1["configurable"],
"checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
},
}
second_result = await graph.ainvoke(None, second_run_config)
assert second_result == {"myval": 5, "otherval": True}
new_history = [
c
async for c in graph.aget_state_history(
{"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}}
)
]
assert len(new_history) == len(history) + 1
for original, new in zip(history, new_history[1:]):
assert original.values == new.values
assert original.next == new.next
assert original.metadata["step"] == new.metadata["step"]
def _get_tasks(hist: list, start: int):
return [h.tasks for h in hist[start:]]
assert _get_tasks(new_history, 1) == _get_tasks(history, 0)
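# A node invoked via Send still follows its conditional edge afterwards,
# routing both Send-delivered invocations of "2" on to node "3".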
async def test_cond_edge_after_send() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
async def __call__(self, state):
return [self.name]
async def send_for_fun(state):
return [Send("2", state), Send("2", state)]
async def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == ["0", "1", "2", "2", "3"]
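# Several nodes emit Sends concurrently; the expected ordering of the results
# differs depending on the FF_SEND_V2 flag.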
async def test_concurrent_emit_sends() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
async def __call__(self, state):
return (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
async def send_for_fun(state):
return [Send("2", 1), Send("2", 2), "3.1"]
async def send_for_profit(state):
return [Send("2", 3), Send("2", 4)]
async def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("1.1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_edge(START, "1.1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("1.1", send_for_profit)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == (
[
"0",
"1",
"1.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3",
"3.1",
]
if FF_SEND_V2
else [
"0",
"1",
"1.1",
"3.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3",
]
)
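# Send payloads can themselves be Commands carrying further Sends; node "2"
# forwards them, producing a second round of Send tasks.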
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_sequences(checkpointer_name: str) -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
async def __call__(self, state):
update = (
[self.name]
if isinstance(state, list) # or isinstance(state, Control)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return replace(state, update=update)
else:
return update
async def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("2", 4))),
"3.1",
]
async def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert (
await graph.ainvoke(["0"])
== [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
"3.1",
]
if FF_SEND_V2
else [
"0",
"1",
"3.1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"3",
"2|3",
"2|4",
"3",
]
)
if not FF_SEND_V2:
return
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["3.1"])
thread1 = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke(["0"], thread1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
]
assert await graph.ainvoke(None, thread1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
"3.1",
]
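# On resume after an interrupt, Send tasks that already completed are not
# re-run (their saved writes are recovered); only the interrupted "flaky"
# node executes again.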
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_dedupe_on_resume(checkpointer_name: str) -> None:
if not FF_SEND_V2:
pytest.skip("Send deduplication is only available in Send V2")
class InterruptOnce:
ticks: int = 0
def __call__(self, state):
self.ticks += 1
if self.ticks == 1:
raise NodeInterrupt("Bahh")
return ["|".join(("flaky", str(state)))]
class Node:
def __init__(self, name: str):
self.name = name
self.ticks = 0
setattr(self, "__name__", name)
def __call__(self, state):
self.ticks += 1
update = (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return replace(state, update=update)
else:
return update
def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_node("flaky", InterruptOnce())
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
assert builder.nodes["flaky"].runnable.func.ticks == 1
# resume execution
assert await graph.ainvoke(None, thread1, debug=1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
]
# node "2" doesn't get called again, as we recover writes saved before
assert builder.nodes["2"].runnable.func.ticks == 3
# node "flaky" gets called again, as it was interrupted
assert builder.nodes["flaky"].runnable.func.ticks == 2
# check history
history = [c async for c in graph.aget_state_history(thread1)]
assert history == [
StateSnapshot(
values=[
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
],
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"3": ["3"], "3.1": ["3.1"]},
"thread_id": "1",
"step": 2,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
),
StateSnapshot(
values=[
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
next=("3", "3.1"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"1": ["1"],
"2": [
["2|Command(goto=Send(node='2', arg=3))"],
["2|Command(goto=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="3",
path=("__pregel_pull", "3"),
error=None,
interrupts=(),
state=None,
result=["3"],
),
PregelTask(
id=AnyStr(),
name="3.1",
path=("__pregel_pull", "3.1"),
error=None,
interrupts=(),
state=None,
result=["3.1"],
),
),
),
StateSnapshot(
values=["0"],
next=("1", "2", "2", "2", "flaky"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": None,
"thread_id": "1",
"step": 0,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="1",
path=("__pregel_pull", "1"),
error=None,
interrupts=(),
state=None,
result=["1"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
("__pregel_pull", "1"),
2,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|Command(goto=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
("__pregel_pull", "1"),
3,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|Command(goto=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
(
"__pregel_push",
("__pregel_pull", "1"),
2,
AnyStr(),
),
2,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|3"],
),
PregelTask(
id=AnyStr(),
name="flaky",
path=(
"__pregel_push",
(
"__pregel_push",
("__pregel_pull", "1"),
3,
AnyStr(),
),
2,
AnyStr(),
),
error=None,
interrupts=(Interrupt(value="Bahh", when="during"),),
state=None,
result=["flaky|4"],
),
),
),
StateSnapshot(
values=[],
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "input",
"writes": {"__start__": ["0"]},
"thread_id": "1",
"step": -1,
"parents": {},
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result=["0"],
),
),
),
]
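# ReAct-style tool dispatch via Send: interrupting before the tool node lets
# the caller resume as-is, or edit the AI message to cancel or replace the
# pending tool call.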
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_react_interrupt(checkpointer_name: str) -> None:
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
ai_message = AIMessage(
"",
id="ai1",
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)
async def agent(state):
return {"messages": ai_message}
def route(state):
if isinstance(state["messages"][-1], AIMessage):
return [
Send(call["name"], call) for call in state["messages"][-1].tool_calls
]
foo_called = 0
async def foo(call: ToolCall):
nonlocal foo_called
foo_called += 1
return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])}
builder = StateGraph(MessagesState)
builder.add_node(agent)
builder.add_node(foo)
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", route)
graph = builder.compile()
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
async with awith_checkpointer(checkpointer_name) as checkpointer:
# simple interrupt-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
assert await graph.ainvoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
# interrupt-update-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "2"}}
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
if not FF_SEND_V2:
return
# get state should show the pending task
state = await graph.aget_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# remove the tool call, clearing the pending task
await graph.aupdate_state(
thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])}
)
# tool call no longer in pending tasks
assert await graph.aget_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="Bye now",
tool_calls=[],
),
]
},
next=(),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="Bye now",
tool_calls=[],
)
}
},
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
# tool call not executed
assert await graph.ainvoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(content="Bye now"),
]
}
assert foo_called == 0
# interrupt-update-resume flow, creating new Send in update call
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "3"}}
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
# get state should show the pending task
state = await graph.aget_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "3",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# replace the tool call, should clear previous send, create new one
await graph.aupdate_state(
thread1,
{
"messages": AIMessage(
"",
id=ai_message.id,
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
)
},
)
# prev tool call no longer in pending tasks, new tool call is
assert await graph.aget_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
)
}
},
"parents": {},
"thread_id": "3",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", (), 0, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# prev tool call not executed, new tool call is
assert await graph.ainvoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(content="{'hi': [4, 5, 6]}", tool_call_id="tool1"),
]
}
assert foo_called == 1
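# Same ReAct-style flow as above, but the agent routes with
# Command(update=..., goto=Send(...)) instead of a conditional edge.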
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_react_interrupt_control(
checkpointer_name: str, snapshot: SnapshotAssertion
) -> None:
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
ai_message = AIMessage(
"",
id="ai1",
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)
async def agent(state) -> Command[Literal["foo"]]:
return Command(
update={"messages": ai_message},
goto=[Send(call["name"], call) for call in ai_message.tool_calls],
)
foo_called = 0
async def foo(call: ToolCall):
nonlocal foo_called
foo_called += 1
return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])}
builder = StateGraph(MessagesState)
builder.add_node(agent)
builder.add_node(foo)
builder.add_edge(START, "agent")
graph = builder.compile()
assert graph.get_graph().draw_mermaid() == snapshot
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
async with awith_checkpointer(checkpointer_name) as checkpointer:
# simple interrupt-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
assert await graph.ainvoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
# interrupt-update-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "2"}}
assert await graph.ainvoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
if not FF_SEND_V2:
return
# get state should show the pending task
state = await graph.aget_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# remove the tool call, clearing the pending task
await graph.aupdate_state(
thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])}
)
# tool call no longer in pending tasks
assert await graph.aget_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="Bye now",
tool_calls=[],
),
]
},
next=(),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="Bye now",
tool_calls=[],
)
}
},
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
# tool call not executed
assert await graph.ainvoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(content="Bye now"),
]
}
assert foo_called == 0
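# "max_concurrency" in the config caps how many Send-spawned invocations of
# node "2" run at once; the node records its peak concurrency to verify the cap.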
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_max_concurrency(checkpointer_name: str) -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
self.currently = 0
self.max_currently = 0
async def __call__(self, state):
self.currently += 1
if self.currently > self.max_currently:
self.max_currently = self.currently
await asyncio.sleep(random.random() / 10)
self.currently -= 1
return [state]
def one(state):
return ["1"]
def three(state):
return ["3"]
async def send_to_many(state):
return [Send("2", idx) for idx in range(100)]
async def route_to_three(state) -> Literal["3"]:
return "3"
node2 = Node("2")
builder = StateGraph(Annotated[list, operator.add])
builder.add_node("1", one)
builder.add_node(node2)
builder.add_node("3", three)
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_to_many)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert await graph.ainvoke(["0"]) == ["0", "1", *range(100), "3"]
assert node2.max_currently == 100
assert node2.currently == 0
node2.max_currently = 0
assert await graph.ainvoke(["0"], {"max_concurrency": 10}) == [
"0",
"1",
*range(100),
"3",
]
assert node2.max_currently == 10
assert node2.currently == 0
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["2"])
thread1 = {"max_concurrency": 10, "configurable": {"thread_id": "1"}}
assert await graph.ainvoke(["0"], thread1, debug=True) == ["0", "1"]
state = await graph.aget_state(thread1)
assert state.values == ["0", "1"]
assert await graph.ainvoke(None, thread1) == ["0", "1", *range(100), "3"]
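# Same concurrency cap, but the fan-out happens through
# Command(update=..., goto=[Send(...)]) returned from node "1".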
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_max_concurrency_control(checkpointer_name: str) -> None:
async def node1(state) -> Command[Literal["2"]]:
return Command(update=["1"], goto=[Send("2", idx) for idx in range(100)])
node2_currently = 0
node2_max_currently = 0
async def node2(state) -> Command[Literal["3"]]:
nonlocal node2_currently, node2_max_currently
node2_currently += 1
if node2_currently > node2_max_currently:
node2_max_currently = node2_currently
await asyncio.sleep(0.1)
node2_currently -= 1
return Command(update=[state], goto="3")
async def node3(state) -> Literal["3"]:
return ["3"]
builder = StateGraph(Annotated[list, operator.add])
builder.add_node("1", node1)
builder.add_node("2", node2)
builder.add_node("3", node3)
builder.add_edge(START, "1")
graph = builder.compile()
assert (
graph.get_graph().draw_mermaid()
== """%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
__start__([<p>__start__</p>]):::first
1(1)
2(2)
3([3]):::last
__start__ --> 1;
1 -.-> 2;
2 -.-> 3;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
"""
)
assert await graph.ainvoke(["0"], debug=True) == ["0", "1", *range(100), "3"]
assert node2_max_currently == 100
assert node2_currently == 0
node2_max_currently = 0
assert await graph.ainvoke(["0"], {"max_concurrency": 10}) == [
"0",
"1",
*range(100),
"3",
]
assert node2_max_currently == 10
assert node2_currently == 0
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["2"])
thread1 = {"max_concurrency": 10, "configurable": {"thread_id": "1"}}
assert await graph.ainvoke(["0"], thread1) == ["0", "1"]
assert await graph.ainvoke(None, thread1) == ["0", "1", *range(100), "3"]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_checkpoint_three(
mocker: MockerFixture, checkpointer_name: str
) -> None:
add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"])
def raise_if_above_10(input: int) -> int:
if input > 10:
raise ValueError("Input is too large")
return input
one = (
Channel.subscribe_to(["input"]).join(["total"])
| add_one
| Channel.write_to("output", "total")
| raise_if_above_10
)
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = Pregel(
nodes={"one": one},
channels={
"total": BinaryOperatorAggregate(int, operator.add),
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
debug=True,
)
thread_1 = {"configurable": {"thread_id": "1"}}
# total starts out as 0, so output is 0+2=2
assert await app.ainvoke(2, thread_1) == 2
state = await app.aget_state(thread_1)
assert state is not None
assert state.values.get("total") == 2
assert (
state.config["configurable"]["checkpoint_id"]
== (await checkpointer.aget(thread_1))["id"]
)
# total is now 2, so output is 2+3=5
assert await app.ainvoke(3, thread_1) == 5
state = await app.aget_state(thread_1)
assert state is not None
assert state.values.get("total") == 7
assert (
state.config["configurable"]["checkpoint_id"]
== (await checkpointer.aget(thread_1))["id"]
)
# total is now 2+5=7, so output would be 7+4=11, but raises ValueError
with pytest.raises(ValueError):
await app.ainvoke(4, thread_1)
# checkpoint is not updated
state = await app.aget_state(thread_1)
assert state is not None
assert state.values.get("total") == 7
assert state.next == ("one",)
"""we checkpoint inputs and it failed on "one", so the next node is one"""
# we can recover from error by sending new inputs
assert await app.ainvoke(2, thread_1) == 9
state = await app.aget_state(thread_1)
assert state is not None
assert state.values.get("total") == 16, "total is now 7+9=16"
assert state.next == ()
thread_2 = {"configurable": {"thread_id": "2"}}
# on a new thread, total starts out as 0, so output is 0+5=5
assert await app.ainvoke(5, thread_2) == 5
state = await app.aget_state({"configurable": {"thread_id": "1"}})
assert state is not None
assert state.values.get("total") == 16
assert state.next == ()
state = await app.aget_state(thread_2)
assert state is not None
assert state.values.get("total") == 5
assert state.next == ()
assert len([c async for c in app.aget_state_history(thread_1, limit=1)]) == 1
# list all checkpoints for thread 1
thread_1_history = [c async for c in app.aget_state_history(thread_1)]
# there are 7 checkpoints
assert len(thread_1_history) == 7
assert Counter(c.metadata["source"] for c in thread_1_history) == {
"input": 4,
"loop": 3,
}
# sorted descending
assert (
thread_1_history[0].config["configurable"]["checkpoint_id"]
> thread_1_history[1].config["configurable"]["checkpoint_id"]
)
# cursor pagination
cursored = [
c
async for c in app.aget_state_history(
thread_1, limit=1, before=thread_1_history[0].config
)
]
assert len(cursored) == 1
assert cursored[0].config == thread_1_history[1].config
# the most recent checkpoint
assert thread_1_history[0].values["total"] == 16
# the first "loop" checkpoint
assert thread_1_history[-2].values["total"] == 2
# can get each checkpoint using aget with config
assert (await checkpointer.aget(thread_1_history[0].config))[
"id"
] == thread_1_history[0].config["configurable"]["checkpoint_id"]
assert (await checkpointer.aget(thread_1_history[1].config))[
"id"
] == thread_1_history[1].config["configurable"]["checkpoint_id"]
thread_1_next_config = await app.aupdate_state(thread_1_history[1].config, 10)
# update creates a new checkpoint
assert (
thread_1_next_config["configurable"]["checkpoint_id"]
> thread_1_history[0].config["configurable"]["checkpoint_id"]
)
# 1 more checkpoint in history
assert len([c async for c in app.aget_state_history(thread_1)]) == 8
assert Counter(
[c.metadata["source"] async for c in app.aget_state_history(thread_1)]
) == {
"update": 1,
"input": 4,
"loop": 3,
}
# the latest checkpoint is the updated one
assert await app.aget_state(thread_1) == await app.aget_state(
thread_1_next_config
)
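# A minimal illustrative sketch of the fold behaviour relied on above: BinaryOperatorAggregate
# combines every write made to the channel in a step using the supplied operator. The helper
# below is new (it is not part of the original suite and, lacking the test_ prefix, is never
# collected by pytest); it reuses only names already imported in this module.
async def _sketch_binary_operator_aggregate_folds_writes() -> None:
    # Two nodes each write the input to "total" in the same step, so with operator.add
    # an input of 2 should yield 2 + 2 = 4.
    app = Pregel(
        nodes={
            "a": Channel.subscribe_to("input") | (lambda x: x) | Channel.write_to("total"),
            "b": Channel.subscribe_to("input") | (lambda x: x) | Channel.write_to("total"),
        },
        channels={
            "input": LastValue(int),
            "total": BinaryOperatorAggregate(int, operator.add),
        },
        input_channels="input",
        output_channels="total",
    )
    assert await app.ainvoke(2) == 4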
async def test_invoke_two_processes_two_in_join_two_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
add_10_each = mocker.Mock(side_effect=lambda x: sorted(y + 10 for y in x))
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
chain_three = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
chain_four = (
Channel.subscribe_to("inbox") | add_10_each | Channel.write_to("output")
)
app = Pregel(
nodes={
"one": one,
"chain_three": chain_three,
"chain_four": chain_four,
},
channels={
"inbox": Topic(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
# Then invoke the app
# We get a single list result, as chain_four waits for all publishers to finish
# before operating on all elements published to the "inbox" topic as a list
for _ in range(100):
assert await app.ainvoke(2) == [13, 13]
assert await asyncio.gather(*(app.ainvoke(2) for _ in range(100))) == [
[13, 13] for _ in range(100)
]
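# A minimal illustrative sketch of the join behaviour asserted above: values published to a
# Topic channel are delivered to the consumer as a single batch once all publishers in the
# step have finished. This helper is new (not part of the original suite, not collected by
# pytest) and reuses only names already imported in this module.
async def _sketch_topic_channel_joins_writes() -> None:
    app = Pregel(
        nodes={
            "p1": Channel.subscribe_to("input") | (lambda x: x + 1) | Channel.write_to("inbox"),
            "p2": Channel.subscribe_to("input") | (lambda x: x + 2) | Channel.write_to("inbox"),
            # "consume" only runs after both producers have written to "inbox"
            "consume": Channel.subscribe_to("inbox") | sorted | Channel.write_to("output"),
        },
        channels={
            "input": LastValue(int),
            "inbox": Topic(int),
            "output": LastValue(list),
        },
        input_channels="input",
        output_channels="output",
    )
    assert await app.ainvoke(2) == [3, 4]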
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_invoke_join_then_call_other_pregel(
mocker: MockerFixture, checkpointer_name: str
) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
add_10_each = mocker.Mock(side_effect=lambda x: [y + 10 for y in x])
inner_app = Pregel(
nodes={
"one": Channel.subscribe_to("input") | add_one | Channel.write_to("output")
},
channels={
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
one = (
Channel.subscribe_to("input")
| add_10_each
| Channel.write_to("inbox_one").map()
)
two = (
Channel.subscribe_to("inbox_one")
| inner_app.map()
| sorted
| Channel.write_to("outbox_one")
)
chain_three = Channel.subscribe_to("outbox_one") | sum | Channel.write_to("output")
app = Pregel(
nodes={
"one": one,
"two": two,
"chain_three": chain_three,
},
channels={
"inbox_one": Topic(int),
"outbox_one": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
# Then invoke pubsub
for _ in range(10):
assert await app.ainvoke([2, 3]) == 27
assert await asyncio.gather(*(app.ainvoke([2, 3]) for _ in range(10))) == [
27 for _ in range(10)
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
# add checkpointer
app.checkpointer = checkpointer
# subgraph is called twice in the same node (through .map()), so this raises MultipleSubgraphsError
with pytest.raises(MultipleSubgraphsError):
await app.ainvoke([2, 3], {"configurable": {"thread_id": "1"}})
# disable checkpointing for the inner graph by setting its checkpointer to False
inner_app.checkpointer = False
# subgraph still called twice, but checkpointing for inner graph is disabled
assert await app.ainvoke([2, 3], {"configurable": {"thread_id": "1"}}) == 27
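# The block above exercises two behaviours: invoking the same checkpointed subgraph more
# than once from a single node (here via .map()) raises MultipleSubgraphsError, while
# setting the inner graph's checkpointer to False opts it out of checkpointing so the
# outer run can complete.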
async def test_invoke_two_processes_one_in_two_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = (
Channel.subscribe_to("input") | add_one | Channel.write_to("output", "between")
)
two = Channel.subscribe_to("between") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"between": LastValue(int),
"output": LastValue(int),
},
stream_channels=["output", "between"],
input_channels="input",
output_channels="output",
)
# Then invoke pubsub
assert [c async for c in app.astream(2)] == [
{"between": 3, "output": 3},
{"between": 3, "output": 4},
]
async def test_invoke_two_processes_no_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("between")
two = Channel.subscribe_to("between") | add_one
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"between": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
)
# It finishes executing (once no more messages are being published)
# but returns nothing, as nothing was written to the "output" channel
assert await app.ainvoke(2) is None
async def test_channel_enter_exit_timing(mocker: MockerFixture) -> None:
setup_sync = mocker.Mock()
cleanup_sync = mocker.Mock()
setup_async = mocker.Mock()
cleanup_async = mocker.Mock()
@contextmanager
def an_int() -> Generator[int, None, None]:
setup_sync()
try:
yield 5
finally:
cleanup_sync()
@asynccontextmanager
async def an_int_async() -> AsyncGenerator[int, None]:
setup_async()
try:
yield 5
finally:
cleanup_async()
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = (
Channel.subscribe_to("inbox")
| RunnableLambda(add_one).abatch
| Channel.write_to("output").abatch
)
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"output": LastValue(int),
"inbox": Topic(int),
"ctx": Context(an_int, an_int_async),
},
input_channels="input",
output_channels=["inbox", "output"],
stream_channels=["inbox", "output"],
)
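    # aenumerate is a local async analogue of enumerate (the standard library has no
    # async version); it pairs each streamed chunk with its index so the assertions in
    # the loop below can branch on the chunk's position.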
async def aenumerate(aiter: AsyncIterator[Any]) -> AsyncIterator[tuple[int, Any]]:
i = 0
async for chunk in aiter:
yield i, chunk
i += 1
assert setup_sync.call_count == 0
assert cleanup_sync.call_count == 0
assert setup_async.call_count == 0
assert cleanup_async.call_count == 0
async for i, chunk in aenumerate(app.astream(2)):
assert setup_sync.call_count == 0, "Sync context manager should not be used"
assert cleanup_sync.call_count == 0, "Sync context manager should not be used"
assert setup_async.call_count == 1, "Expected setup to be called once"
if i == 0:
assert chunk == {"inbox": [3]}
elif i == 1:
assert chunk == {"output": 4}
else:
assert False, "Expected only two chunks"
assert setup_sync.call_count == 0
assert cleanup_sync.call_count == 0
assert setup_async.call_count == 1, "Expected setup to be called once"
assert cleanup_async.call_count == 1, "Expected cleanup to be called once"
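# The assertions above pin down the Context channel's lifecycle under async execution:
# only the async context manager is entered (the sync one is never touched), setup runs
# once before the first chunk is yielded, and cleanup has happened exactly once by the
# time the stream is exhausted.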
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_conditional_graph(checkpointer_name: str) -> None:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
# Assemble the tools
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
# Construct the agent
prompt = PromptTemplate.from_template("Hello!")
llm = FakeStreamingListLLM(
responses=[
"tool:search_api:query",
"tool:search_api:another",
"finish:answer",
]
)
async def agent_parser(input: str) -> Union[AgentAction, AgentFinish]:
if input.startswith("finish"):
_, answer = input.split(":")
return AgentFinish(return_values={"answer": answer}, log=input)
else:
_, tool_name, tool_input = input.split(":")
return AgentAction(tool=tool_name, tool_input=tool_input, log=input)
agent = RunnablePassthrough.assign(agent_outcome=prompt | llm | agent_parser)
# Define tool execution logic
async def execute_tools(data: dict) -> dict:
data = data.copy()
agent_action: AgentAction = data.pop("agent_outcome")
observation = await {t.name: t for t in tools}[agent_action.tool].ainvoke(
agent_action.tool_input
)
if data.get("intermediate_steps") is None:
data["intermediate_steps"] = []
else:
data["intermediate_steps"] = data["intermediate_steps"].copy()
data["intermediate_steps"].append([agent_action, observation])
return data
# Define decision-making logic
async def should_continue(data: dict, config: RunnableConfig) -> str:
# Logic to decide whether to continue in the loop or exit
if isinstance(data["agent_outcome"], AgentFinish):
return "exit"
else:
return "continue"
# Define a new graph
workflow = Graph()
workflow.add_node("agent", agent)
workflow.add_node("tools", execute_tools)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
app = workflow.compile()
assert await app.ainvoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
assert [c async for c in app.astream({"input": "what is weather in sf"})] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
patches = [c async for c in app.astream_log({"input": "what is weather in sf"})]
patch_paths = {op["path"] for log in patches for op in log.ops}
# Check that agent (one of the nodes) has its output streamed to the logs
assert "/logs/agent/streamed_output/-" in patch_paths
assert "/logs/agent:2/streamed_output/-" in patch_paths
assert "/logs/agent:3/streamed_output/-" in patch_paths
# Check that agent (one of the nodes) has its final output set in the logs
assert "/logs/agent/final_output" in patch_paths
assert "/logs/agent:2/final_output" in patch_paths
assert "/logs/agent:3/final_output" in patch_paths
assert [
p["value"]
for log in patches
for p in log.ops
if p["path"] == "/logs/agent/final_output"
or p["path"] == "/logs/agent:2/final_output"
or p["path"] == "/logs/agent:3/final_output"
] == [
{
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
},
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
},
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
# test state get/update methods with interrupt_after
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"input": "what is weather in sf"}, config
)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
await app_w_interrupt.aupdate_state(
config,
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
},
tasks=(),
next=(),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 4,
"writes": {
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# test state get/update methods with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
llm.i = 0
assert [
c
async for c in app_w_interrupt.astream(
{"input": "what is weather in sf"}, config
)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
await app_w_interrupt.aupdate_state(
config,
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
},
tasks=(),
next=(),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 4,
"writes": {
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# test re-invoke to continue with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "3"}}
llm.i = 0 # reset the llm
assert [
c
async for c in app_w_interrupt.astream(
{"input": "what is weather in sf"}, config
)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
},
"thread_id": "3",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
},
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_conditional_graph_state(
mocker: MockerFixture, checkpointer_name: str
) -> None:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
setup = mocker.Mock()
teardown = mocker.Mock()
@asynccontextmanager
async def assert_ctx_once() -> AsyncIterator[None]:
assert setup.call_count == 0
assert teardown.call_count == 0
try:
yield
finally:
assert setup.call_count == 1
assert teardown.call_count == 1
setup.reset_mock()
teardown.reset_mock()
class MyPydanticContextModel(BaseModel, arbitrary_types_allowed=True):
session: httpx.AsyncClient
something_else: str
@asynccontextmanager
async def make_context(
config: RunnableConfig,
) -> AsyncIterator[MyPydanticContextModel]:
assert isinstance(config, dict)
setup()
session = httpx.AsyncClient()
try:
yield MyPydanticContextModel(session=session, something_else="hello")
finally:
await session.aclose()
teardown()
class AgentState(TypedDict):
input: Annotated[str, UntrackedValue]
agent_outcome: Optional[Union[AgentAction, AgentFinish]]
intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
context: Annotated[MyPydanticContextModel, Context(make_context)]
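    # Two less common state annotations are exercised here: "input" is an UntrackedValue
    # and "context" is supplied by a Context channel built from make_context. Neither
    # appears in the StateSnapshot values captured later in this test; the context object
    # is instead injected into each node invocation for the duration of a run.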
# Assemble the tools
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
# Construct the agent
prompt = PromptTemplate.from_template("Hello!")
llm = FakeStreamingListLLM(
responses=[
"tool:search_api:query",
"tool:search_api:another",
"finish:answer",
]
)
def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]:
if input.startswith("finish"):
_, answer = input.split(":")
return {
"agent_outcome": AgentFinish(
return_values={"answer": answer}, log=input
)
}
else:
_, tool_name, tool_input = input.split(":")
return {
"agent_outcome": AgentAction(
tool=tool_name, tool_input=tool_input, log=input
)
}
agent = prompt | llm | agent_parser
# Define tool execution logic
def execute_tools(data: AgentState) -> dict:
# check that the httpx session from the context is present in AgentState
assert isinstance(data["context"], MyPydanticContextModel)
# execute the tool
agent_action: AgentAction = data.pop("agent_outcome")
observation = {t.name: t for t in tools}[agent_action.tool].invoke(
agent_action.tool_input
)
return {"intermediate_steps": [[agent_action, observation]]}
# Define decision-making logic
def should_continue(data: AgentState) -> str:
# check that the httpx session from the context is present in AgentState
assert isinstance(data["context"], MyPydanticContextModel)
# Logic to decide whether to continue in the loop or exit
if isinstance(data["agent_outcome"], AgentFinish):
return "exit"
else:
return "continue"
# Define a new graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent)
workflow.add_node("tools", execute_tools)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
app = workflow.compile()
async with assert_ctx_once():
assert await app.ainvoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
async with assert_ctx_once():
assert [c async for c in app.astream({"input": "what is weather in sf"})] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
async with assert_ctx_once():
patches = [c async for c in app.astream_log({"input": "what is weather in sf"})]
patch_paths = {op["path"] for log in patches for op in log.ops}
# Check that agent (one of the nodes) has its output streamed to the logs
assert "/logs/agent/streamed_output/-" in patch_paths
# Check that agent (one of the nodes) has its final output set in the logs
assert "/logs/agent/final_output" in patch_paths
assert [
p["value"]
for log in patches
for p in log.ops
if p["path"] == "/logs/agent/final_output"
or p["path"] == "/logs/agent:2/final_output"
or p["path"] == "/logs/agent:3/final_output"
] == [
{
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
)
},
{
"agent_outcome": AgentAction(
tool="search_api", tool_input="another", log="tool:search_api:another"
)
},
{
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
# test state get/update methods with interrupt_after
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
async with assert_ctx_once():
assert [
c
async for c in app_w_interrupt.astream(
{"input": "what is weather in sf"}, config
)
] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
{"__interrupt__": ()},
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
async with assert_ctx_once():
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
async with assert_ctx_once():
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
async with assert_ctx_once():
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
},
tasks=(),
next=(),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# test state get/update methods with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
llm.i = 0 # reset the llm
assert [
c
async for c in app_w_interrupt.astream(
{"input": "what is weather in sf"}, config
)
] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
{"__interrupt__": ()},
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
await app_w_interrupt.aupdate_state(
config,
{
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
},
)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
},
tasks=(),
next=(),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
async def test_conditional_entrypoint_graph() -> None:
async def left(data: str) -> str:
return data + "->left"
async def right(data: str) -> str:
return data + "->right"
def should_start(data: str) -> str:
# Logic to decide where to start
if len(data) > 10:
return "go-right"
else:
return "go-left"
# Define a new graph
workflow = Graph()
workflow.add_node("left", left)
workflow.add_node("right", right)
workflow.set_conditional_entry_point(
should_start, {"go-left": "left", "go-right": "right"}
)
workflow.add_conditional_edges("left", lambda data: END)
workflow.add_edge("right", END)
app = workflow.compile()
assert await app.ainvoke("what is weather in sf") == "what is weather in sf->right"
assert [c async for c in app.astream("what is weather in sf")] == [
{"right": "what is weather in sf->right"},
]
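# A minimal illustrative sketch of the two routing styles used in the test above:
# set_conditional_entry_point with a path map, and add_conditional_edges without a map,
# where the callable's return value is used directly as the destination (here END). The
# helper is new (not part of the original suite, not collected by pytest) and assumes
# plain sync callables are accepted as Graph nodes, just as sync callables are accepted
# for the conditions above.
async def _sketch_conditional_routing() -> None:
    wf = Graph()
    wf.add_node("upper", lambda s: s.upper())
    wf.add_node("lower", lambda s: s.lower())
    wf.set_conditional_entry_point(
        lambda s: "shout" if s.islower() else "hush",
        {"shout": "upper", "hush": "lower"},
    )
    # with no path map, the returned value is the target node (or END)
    wf.add_conditional_edges("upper", lambda s: END)
    wf.add_conditional_edges("lower", lambda s: END)
    app = wf.compile()
    assert await app.ainvoke("hello") == "HELLO"
    assert await app.ainvoke("HELLO") == "hello"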
async def test_conditional_entrypoint_graph_state() -> None:
class AgentState(TypedDict, total=False):
input: str
output: str
steps: Annotated[list[str], operator.add]
async def left(data: AgentState) -> AgentState:
return {"output": data["input"] + "->left"}
async def right(data: AgentState) -> AgentState:
return {"output": data["input"] + "->right"}
def should_start(data: AgentState) -> str:
assert data["steps"] == [], "Expected input to be read from the state"
# Logic to decide where to start
if len(data["input"]) > 10:
return "go-right"
else:
return "go-left"
# Define a new graph
workflow = StateGraph(AgentState)
workflow.add_node("left", left)
workflow.add_node("right", right)
workflow.set_conditional_entry_point(
should_start, {"go-left": "left", "go-right": "right"}
)
workflow.add_conditional_edges("left", lambda data: END)
workflow.add_edge("right", END)
app = workflow.compile()
assert await app.ainvoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"output": "what is weather in sf->right",
"steps": [],
}
assert [c async for c in app.astream({"input": "what is weather in sf"})] == [
{"right": {"output": "what is weather in sf->right"}},
]
async def test_prebuilt_tool_chat() -> None:
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
model = FakeChatModel(
messages=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
),
AIMessage(content="answer"),
]
)
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
app = create_tool_calling_executor(model, tools)
assert await app.ainvoke(
{"messages": [HumanMessage(content="what is weather in sf")]}
) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
id=AnyStr(),
),
_AnyIdAIMessage(content="answer"),
]
}
assert [
c
async for c in app.astream(
{"messages": [HumanMessage(content="what is weather in sf")]},
stream_mode="messages",
)
] == [
(
_AnyIdAIMessageChunk(
content="",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "search_api",
"args": '{"query": "query"}',
"id": "tool_call123",
"index": None,
"type": "tool_call_chunk",
}
],
),
{
"langgraph_step": 1,
"langgraph_node": "agent",
"langgraph_triggers": ["start:agent"],
"langgraph_path": ("__pregel_pull", "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
(
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
{
"langgraph_step": 2,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": ("__pregel_pull", "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdAIMessageChunk(
content="",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another"},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one"},
"id": "tool_call567",
"type": "tool_call",
},
],
tool_call_chunks=[
{
"name": "search_api",
"args": '{"query": "another"}',
"id": "tool_call234",
"index": None,
"type": "tool_call_chunk",
},
{
"name": "search_api",
"args": '{"query": "a third one"}',
"id": "tool_call567",
"index": None,
"type": "tool_call_chunk",
},
],
),
{
"langgraph_step": 3,
"langgraph_node": "agent",
"langgraph_triggers": ["tools"],
"langgraph_path": ("__pregel_pull", "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
(
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
{
"langgraph_step": 4,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": ("__pregel_pull", "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
{
"langgraph_step": 4,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": ("__pregel_pull", "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdAIMessageChunk(
content="answer",
),
{
"langgraph_step": 5,
"langgraph_node": "agent",
"langgraph_triggers": ["tools"],
"langgraph_path": ("__pregel_pull", "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
]
assert [
c
async for c in app.astream(
{"messages": [HumanMessage(content="what is weather in sf")]}
)
] == [
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
]
}
},
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
]
}
},
{"agent": {"messages": [_AnyIdAIMessage(content="answer")]}},
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_state_graph_packets(checkpointer_name: str) -> None:
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.tools import tool
class AgentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
session: Annotated[httpx.AsyncClient, Context(httpx.AsyncClient)]
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
tools_by_name = {t.name: t for t in tools}
model = FakeMessagesListChatModel(
responses=[
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
AIMessage(id="ai3", content="answer"),
]
)
# Define decision-making logic
def should_continue(data: AgentState) -> str:
assert isinstance(data["session"], httpx.AsyncClient)
# Logic to decide whether to continue in the loop or exit
if tool_calls := data["messages"][-1].tool_calls:
return [Send("tools", tool_call) for tool_call in tool_calls]
else:
return END
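    # Returning a list of Send objects from the conditional edge fans out one "tools"
    # task per tool call, with that tool_call dict as the task's input; the ToolMessages
    # written by those parallel tasks are merged back into "messages" by the add_messages
    # reducer, which is what the expected outputs below rely on.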
async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
await asyncio.sleep(input["args"].get("idx", 0) / 10)
output = await tools_by_name[input["name"]].ainvoke(input["args"], config)
return {
"messages": ToolMessage(
content=output, name=input["name"], tool_call_id=input["id"]
)
}
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", {"messages": RunnablePick("messages") | model})
workflow.add_node("tools", tools_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges("agent", should_continue)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
assert await app.ainvoke(
{"messages": HumanMessage(content="what is weather in sf")}
) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
AIMessage(content="answer", id="ai3"),
]
}
assert [
c
async for c in app.astream(
{"messages": [HumanMessage(content="what is weather in sf")]}
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
},
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
}
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
)
},
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
},
},
{"agent": {"messages": AIMessage(content="answer", id="ai3")}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
# interrupt after agent
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"messages": HumanMessage(content="what is weather in sf")}, config
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
}
},
{"__interrupt__": ()},
]
if not FF_SEND_V2:
return
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# modify ai message
last_message = (await app_w_interrupt.aget_state(config)).values["messages"][-1]
last_message.tool_calls[0]["args"]["query"] = "a different query"
await app_w_interrupt.aupdate_state(config, {"messages": last_message})
# message was replaced instead of appended
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
]
},
tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
next=("tools",),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
)
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
},
},
{"__interrupt__": ()},
]
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai2",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another", "idx": 0},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
"id": "tool_call567",
"type": "tool_call",
},
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
),
),
next=("tools", "tools"),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
},
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
{"messages": AIMessage(content="answer", id="ai2")},
)
# replaces message even if object identity is different, as long as id is the same
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
]
},
tasks=(),
next=(),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 3,
"writes": {
"agent": {
"messages": AIMessage(content="answer", id="ai2"),
}
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# interrupt before tools
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
model.i = 0
assert [
c
async for c in app_w_interrupt.astream(
{"messages": HumanMessage(content="what is weather in sf")}, config
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
}
},
{"__interrupt__": ()},
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
),
next=("tools",),
config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config,
created_at=(
await app_w_interrupt.checkpointer.aget_tuple(config)
).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# modify ai message
last_message = (await app_w_interrupt.aget_state(config)).values["messages"][-1]
last_message.tool_calls[0]["args"]["query"] = "a different query"
await app_w_interrupt.aupdate_state(config, {"messages": last_message})
# message was replaced instead of appended
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
]
},
tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
next=("tools",),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
)
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
},
},
{"__interrupt__": ()},
]
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai2",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another", "idx": 0},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
"id": "tool_call567",
"type": "tool_call",
},
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
),
),
next=("tools", "tools"),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
},
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
{"messages": AIMessage(content="answer", id="ai2")},
)
# replaces message even if object identity is different, as long as id is the same
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
]
},
tasks=(),
next=(),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 3,
"writes": {
"agent": {
"messages": AIMessage(content="answer", id="ai2"),
}
},
"thread_id": "2",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
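# MessageGraph variant of the agent/tools loop: the graph state is a bare list of
# messages rather than a dict. Covers ainvoke, astream, and interrupt_after=["agent"]
# with state edits that replace messages by id.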
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_message_graph(checkpointer_name: str) -> None:
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
class FakeFunctionChatModel(FakeMessagesListChatModel):
def bind_functions(self, functions: list):
return self
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
model = FakeFunctionChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
AIMessage(content="answer", id="ai3"),
]
)
# Define the function that determines whether to continue or not
def should_continue(messages):
last_message = messages[-1]
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise, if there are tool calls, we continue
else:
return "continue"
# Define a new graph
workflow = MessageGraph()
# Define the two nodes we will cycle between
workflow.add_node("agent", model)
workflow.add_node("tools", ToolNode(tools))
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# At runtime `should_continue` is called, and its output is matched against the
# keys in this mapping.
# Whichever key it matches determines the node that is called next.
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, the `agent` node is called next.
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
assert await app.ainvoke(HumanMessage(content="what is weather in sf")) == [
_AnyIdHumanMessage(
content="what is weather in sf",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1", # respects ids passed in
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
),
AIMessage(content="answer", id="ai3"),
]
assert [
c async for c in app.astream([HumanMessage(content="what is weather in sf")])
] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{
"tools": [
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{
"tools": [
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
)
]
},
{"agent": AIMessage(content="answer", id="ai3")},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
HumanMessage(content="what is weather in sf"), config
)
] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{"__interrupt__": ()},
]
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
# modify ai message
last_message = (await app_w_interrupt.aget_state(config)).values[-1]
last_message.tool_calls[0]["args"] = {"query": "a different query"}
await app_w_interrupt.aupdate_state(config, last_message)
# message was replaced instead of appended
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{
"tools": [
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{"__interrupt__": ()},
]
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
await app_w_interrupt.aupdate_state(
config,
AIMessage(content="answer", id="ai2"),
)
# replaces message even if object identity is different, as long as id is the same
tup = await app_w_interrupt.checkpointer.aget_tuple(config)
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
],
tasks=(),
next=(),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {"agent": AIMessage(content="answer", id="ai2")},
"thread_id": "1",
},
parent_config=[
c async for c in app_w_interrupt.checkpointer.alist(config, limit=2)
][-1].config,
)
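# Fan-out/fan-in on a StateGraph: rewrite_query feeds two retrievers in parallel,
# and their docs are merged by the operator.add reducer before qa runs. Also checks
# the "values", "updates", and "debug" stream modes.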
async def test_in_one_fan_out_out_one_graph_state() -> None:
def sorted_add(x: list[str], y: list[str]) -> list[str]:
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], operator.add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def retriever_one(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge("retriever_one", "qa")
workflow.add_edge("retriever_two", "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
assert [
c
async for c in app.astream(
{"query": "what is weather in sf"}, stream_mode="values"
)
] == [
{"query": "what is weather in sf", "docs": []},
{"query": "query: what is weather in sf", "docs": []},
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
},
]
assert [
c
async for c in app.astream(
{"query": "what is weather in sf"},
stream_mode=["values", "updates", "debug"],
)
] == [
("values", {"query": "what is weather in sf", "docs": []}),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "rewrite_query",
"input": {"query": "what is weather in sf", "docs": []},
"triggers": ["start:rewrite_query"],
},
},
),
("updates", {"rewrite_query": {"query": "query: what is weather in sf"}}),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "rewrite_query",
"result": [("query", "query: what is weather in sf")],
"error": None,
"interrupts": [],
},
},
),
("values", {"query": "query: what is weather in sf", "docs": []}),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_one",
"input": {"query": "query: what is weather in sf", "docs": []},
"triggers": ["rewrite_query"],
},
},
),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_two",
"input": {"query": "query: what is weather in sf", "docs": []},
"triggers": ["rewrite_query"],
},
},
),
(
"updates",
{"retriever_two": {"docs": ["doc3", "doc4"]}},
),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_two",
"result": [("docs", ["doc3", "doc4"])],
"error": None,
"interrupts": [],
},
},
),
(
"updates",
{"retriever_one": {"docs": ["doc1", "doc2"]}},
),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_one",
"result": [("docs", ["doc1", "doc2"])],
"error": None,
"interrupts": [],
},
},
),
(
"values",
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "qa",
"input": {
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
"triggers": ["retriever_one", "retriever_two"],
},
},
),
("updates", {"qa": {"answer": "doc1,doc2,doc3,doc4"}}),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "qa",
"result": [("answer", "doc1,doc2,doc3,doc4")],
"error": None,
"interrupts": [],
},
},
),
(
"values",
{
"query": "query: what is weather in sf",
"answer": "doc1,doc2,doc3,doc4",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
),
]
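# Conditional entry point with then=END, plus SharedValue.on("assistant_id") so that
# values written by one thread are visible to other threads of the same assistant.
# Exercises interrupt_before, resuming with None, and aupdate_state on an
# interrupted thread.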
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_start_branch_then(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
shared: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")]
other: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")]
def assert_shared_value(data: State, config: RunnableConfig) -> State:
assert "shared" in data
if thread_id := config["configurable"].get("thread_id"):
if thread_id == "1":
# this is the first thread, so it should not see a value
assert data["shared"] == {}
return {"shared": {"1": {"hello": "world"}}, "other": {"2": {1: 2}}}
elif thread_id == "2":
# this should get value saved by thread 1
assert data["shared"] == {"1": {"hello": "world"}}
elif thread_id == "3":
# this is a different assistant, so it should not see the previous value
assert data["shared"] == {}
return {}
def tool_two_slow(data: State, config: RunnableConfig) -> State:
return {"my_key": " slow", **assert_shared_value(data, config)}
def tool_two_fast(data: State, config: RunnableConfig) -> State:
return {"my_key": " fast", **assert_shared_value(data, config)}
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two_slow", tool_two_slow)
tool_two_graph.add_node("tool_two_fast", tool_two_fast)
tool_two_graph.set_conditional_entry_point(
lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then=END
)
tool_two = tool_two_graph.compile()
assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value slow",
"market": "DE",
}
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value fast",
"market": "US",
}
async with awith_checkpointer(checkpointer_name) as checkpointer:
tool_two = tool_two_graph.compile(
store=InMemoryStore(),
checkpointer=checkpointer,
interrupt_before=["tool_two_fast", "tool_two_slow"],
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "1", "assistant_id": "a"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, thread1) == {
"my_key": "value",
"market": "DE",
}
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value", "market": "DE"}},
"assistant_id": "a",
"thread_id": "1",
},
]
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread1, debug=1) == {
"my_key": "value slow",
"market": "DE",
}
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value slow", "market": "DE"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"tool_two_slow": {"my_key": " slow"}},
"assistant_id": "a",
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "2",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread2, debug=1) == {
"my_key": "value fast",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value fast", "market": "US"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"tool_two_fast": {"my_key": " fast"}},
"assistant_id": "a",
"thread_id": "2",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread3) == {
"my_key": "value",
"market": "US",
}
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "value", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=(await tool_two.checkpointer.aget_tuple(thread3)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread3, limit=2)
][-1].config,
)
# update state
await tool_two.aupdate_state(thread3, {"my_key": "key"}) # appends to my_key
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "valuekey", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=(await tool_two.checkpointer.aget_tuple(thread3)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {START: {"my_key": "key"}},
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread3, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread3, debug=1) == {
"my_key": "valuekey fast",
"market": "US",
}
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "valuekey fast", "market": "US"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread3)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {"tool_two_fast": {"my_key": " fast"}},
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread3, limit=2)
][-1].config,
)
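# Conditional edges with then="finish": prepare branches to a slow or fast node
# depending on market, and both branches converge on finish. Checks the
# stream_mode="debug" checkpoint/task/task_result events, interrupt_before on the
# tool nodes, interrupt_after=["prepare"], and updating an empty thread before its
# first run.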
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_branch_then(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_graph = StateGraph(State)
tool_two_graph.set_entry_point("prepare")
tool_two_graph.set_finish_point("finish")
tool_two_graph.add_conditional_edges(
source="prepare",
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
then="finish",
)
tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
tool_two = tool_two_graph.compile()
assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value prepared fast finished",
"market": "US",
}
async with awith_checkpointer(checkpointer_name) as checkpointer:
# test stream_mode=debug
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
thread10 = {"configurable": {"thread_id": "10"}}
assert [
c
async for c in tool_two.astream(
{"my_key": "value", "market": "DE"}, thread10, stream_mode="debug"
)
] == [
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": -1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {"my_key": ""},
"metadata": {
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value", "market": "DE"}},
"thread_id": "10",
},
"parent_config": None,
"next": ["__start__"],
"tasks": [
{
"id": AnyStr(),
"name": "__start__",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["prepare"],
"tasks": [
{
"id": AnyStr(),
"name": "prepare",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"input": {"my_key": "value", "market": "DE"},
"triggers": ["start:prepare"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"result": [("my_key", " prepared")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["tool_two_slow"],
"tasks": [
{
"id": AnyStr(),
"name": "tool_two_slow",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "tool_two_slow",
"input": {"my_key": "value prepared", "market": "DE"},
"triggers": ["branch:prepare:condition:tool_two_slow"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "tool_two_slow",
"result": [("my_key", " slow")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared slow",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 2,
"writes": {"tool_two_slow": {"my_key": " slow"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["finish"],
"tasks": [
{
"id": AnyStr(),
"name": "finish",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "finish",
"input": {"my_key": "value prepared slow", "market": "DE"},
"triggers": ["branch:prepare:condition::then"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "finish",
"result": [("my_key", " finished")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared slow finished",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": [],
"tasks": [],
},
},
]
tool_two = tool_two_graph.compile(
checkpointer=checkpointer,
interrupt_before=["tool_two_fast", "tool_two_slow"],
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "11"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value", "market": "DE"}, thread1, stream_mode="debug"
)
] == [
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": -1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "11"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "11",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {"my_key": ""},
"metadata": {
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value", "market": "DE"}},
"thread_id": "11",
},
"parent_config": None,
"next": ["__start__"],
"tasks": [
{
"id": AnyStr(),
"name": "__start__",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "11"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "11",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "11",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "11"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "11",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["prepare"],
"tasks": [
{
"id": AnyStr(),
"name": "prepare",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"input": {"my_key": "value", "market": "DE"},
"triggers": ["start:prepare"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"result": [("my_key", " prepared")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "11"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "11",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "11",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "11"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "11",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["tool_two_slow"],
"tasks": [
{
"id": AnyStr(),
"name": "tool_two_slow",
"interrupts": (),
"state": None,
}
],
},
},
]
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "11",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread1, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "11",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
thread2 = {"configurable": {"thread_id": "12"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value prepared",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value prepared", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "12",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread2, debug=1) == {
"my_key": "value prepared fast finished",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value prepared fast finished", "market": "US"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "12",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
tool_two = tool_two_graph.compile(
checkpointer=checkpointer, interrupt_after=["prepare"]
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "21"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}, thread1) == {
"my_key": "value prepared",
"market": "DE",
}
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "21",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread1, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread1)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "21",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
thread2 = {"configurable": {"thread_id": "22"}}
# stop when about to enter node
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value prepared",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value prepared", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "22",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread2, debug=1) == {
"my_key": "value prepared fast finished",
"market": "US",
}
assert await tool_two.aget_state(thread2) == StateSnapshot(
values={"my_key": "value prepared fast finished", "market": "US"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread2)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "22",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread2, limit=2)
][-1].config,
)
thread3 = {"configurable": {"thread_id": "23"}}
# update an empty thread before first run
uconfig = await tool_two.aupdate_state(
thread3, {"my_key": "key", "market": "DE"}
)
# check current state
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "key", "market": "DE"},
tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),),
next=("prepare",),
config=uconfig,
created_at=AnyStr(),
metadata={
"parents": {},
"source": "update",
"step": 0,
"writes": {START: {"my_key": "key", "market": "DE"}},
"thread_id": "23",
},
parent_config=None,
)
# run from this point
assert await tool_two.ainvoke(None, thread3) == {
"my_key": "key prepared",
"market": "DE",
}
# get state after first node
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "key prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=(await tool_two.checkpointer.aget_tuple(thread3)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "23",
},
parent_config=uconfig,
)
# resume, for same result as above
assert await tool_two.ainvoke(None, thread3, debug=1) == {
"my_key": "key prepared slow finished",
"market": "DE",
}
assert await tool_two.aget_state(thread3) == StateSnapshot(
values={"my_key": "key prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=(await tool_two.checkpointer.aget_tuple(thread3)).config,
created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[
"ts"
],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "23",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread3, limit=2)
][-1].config,
)
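# Waiting edge: qa only runs once both retriever_one and retriever_two have
# completed, expressed via add_edge(["retriever_one", "retriever_two"], "qa").
# Also checks resuming after interrupt_after=["retriever_one"].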
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_in_one_fan_out_state_graph_waiting_edge(checkpointer_name: str) -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"query": "what is weather in sf"}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
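# Same waiting-edge topology as above, but retriever_two is reached through a
# conditional edge instead of a plain edge, to confirm the barrier still holds.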
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_in_one_fan_out_state_graph_waiting_edge_via_branch(
snapshot: SnapshotAssertion, checkpointer_name: str
) -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges(
"rewrite_query", lambda _: "retriever_two", {"retriever_two": "retriever_two"}
)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}, debug=True) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"query": "what is weather in sf"}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
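# Same topology with a pydantic.v1 BaseModel state, separate Input/Output schemas,
# and an httpx.AsyncClient provided through a Context(make_httpx_client) channel;
# assert_ctx_once verifies the client's context manager is entered and exited
# exactly once per run.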
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class(
snapshot: SnapshotAssertion, mocker: MockerFixture, checkpointer_name: str
) -> None:
from pydantic.v1 import BaseModel, ValidationError
setup = mocker.Mock()
teardown = mocker.Mock()
@asynccontextmanager
async def assert_ctx_once() -> AsyncIterator[None]:
assert setup.call_count == 0
assert teardown.call_count == 0
try:
yield
finally:
assert setup.call_count == 1
assert teardown.call_count == 1
setup.reset_mock()
teardown.reset_mock()
@asynccontextmanager
async def make_httpx_client() -> AsyncIterator[httpx.AsyncClient]:
setup()
async with httpx.AsyncClient() as client:
try:
yield client
finally:
teardown()
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(BaseModel):
class Config:
arbitrary_types_allowed = True
query: str
answer: Optional[str] = None
docs: Annotated[list[str], sorted_add]
client: Annotated[httpx.AsyncClient, Context(make_httpx_client)]
class Input(BaseModel):
query: str
class Output(BaseModel):
answer: str
docs: list[str]
class StateUpdate(BaseModel):
query: Optional[str] = None
answer: Optional[str] = None
docs: Optional[list[str]] = None
async def rewrite_query(data: State) -> State:
return {"query": f"query: {data.query}"}
async def analyzer_one(data: State) -> State:
return StateUpdate(query=f"analyzed: {data.query}")
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data.docs)}
async def decider(data: State) -> str:
assert isinstance(data, State)
return "retriever_two"
workflow = StateGraph(State, input=Input, output=Output)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges(
"rewrite_query", decider, {"retriever_two": "retriever_two"}
)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
async with assert_ctx_once():
with pytest.raises(ValidationError):
await app.ainvoke({"query": {}})
async with assert_ctx_once():
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
async with assert_ctx_once():
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
async with assert_ctx_once():
assert [
c
async for c in app_w_interrupt.astream(
{"query": "what is weather in sf"}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
async with assert_ctx_once():
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
assert await app_w_interrupt.aget_state(config) == StateSnapshot(
values={
"query": "analyzed: query: what is weather in sf",
"answer": "doc1,doc2,doc3,doc4",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"qa": {"answer": "doc1,doc2,doc3,doc4"}},
"step": 4,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
async with assert_ctx_once():
assert await app_w_interrupt.aupdate_state(
config, {"docs": ["doc5"]}, as_node="rewrite_query"
) == {
"configurable": {
"thread_id": "1",
"checkpoint_id": AnyStr(),
"checkpoint_ns": "",
}
}
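# Pydantic v2 flavour of the previous test: nested BaseModel state (InnerObject),
# graph/schema snapshots, validation errors on bad input, and aupdate_state with
# as_node.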
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2(
snapshot: SnapshotAssertion, checkpointer_name: str
) -> None:
from pydantic import BaseModel, ValidationError
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class InnerObject(BaseModel):
yo: int
class State(BaseModel):
query: str
inner: InnerObject
answer: Optional[str] = None
docs: Annotated[list[str], sorted_add]
class StateUpdate(BaseModel):
query: Optional[str] = None
answer: Optional[str] = None
docs: Optional[list[str]] = None
async def rewrite_query(data: State) -> State:
return {"query": f"query: {data.query}"}
async def analyzer_one(data: State) -> State:
return StateUpdate(query=f"analyzed: {data.query}")
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data.docs)}
async def decider(data: State) -> str:
assert isinstance(data, State)
return "retriever_two"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges(
"rewrite_query", decider, {"retriever_two": "retriever_two"}
)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_input_schema().model_json_schema() == snapshot
assert app.get_output_schema().model_json_schema() == snapshot
with pytest.raises(ValidationError):
await app.ainvoke({"query": {}})
assert await app.ainvoke(
{"query": "what is weather in sf", "inner": {"yo": 1}}
) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
"inner": {"yo": 1},
}
assert [
c
async for c in app.astream(
{"query": "what is weather in sf", "inner": {"yo": 1}}
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"query": "what is weather in sf", "inner": {"yo": 1}}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
assert await app_w_interrupt.aupdate_state(
config, {"docs": ["doc5"]}, as_node="rewrite_query"
) == {
"configurable": {
"thread_id": "1",
"checkpoint_id": AnyStr(),
"checkpoint_ns": "",
}
}
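# Waiting edge plus an extra direct edge from rewrite_query to qa, so qa fires once
# early (with empty docs) and again after the barrier is satisfied.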
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_in_one_fan_out_state_graph_waiting_edge_plus_regular(
checkpointer_name: str,
) -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def analyzer_one(data: State) -> State:
await asyncio.sleep(0.1)
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.2)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
# silly edge, to make sure that a node having been triggered before doesn't break
# the semantics of a named barrier (== waiting edges)
workflow.add_edge("rewrite_query", "qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"qa": {"answer": ""}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
async with awith_checkpointer(checkpointer_name) as checkpointer:
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
async for c in app_w_interrupt.astream(
{"query": "what is weather in sf"}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"qa": {"answer": ""}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c async for c in app_w_interrupt.astream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
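# The two tests above exercise "waiting edges" (a named barrier): passing a list of source
# nodes to `add_edge` makes the target run only after *all* of them have finished, even if
# one completes in an earlier step. A minimal sketch of the join, using only the StateGraph
# API already used in this file (illustration, not part of the test suite):
def _example_waiting_edge() -> None:
    import operator
    from typing import Annotated, TypedDict

    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        values: Annotated[list[str], operator.add]

    builder = StateGraph(State)
    builder.add_node("a", lambda s: {"values": ["a"]})
    builder.add_node("b", lambda s: {"values": ["b"]})
    builder.add_node("join", lambda s: {"values": ["join"]})
    builder.add_edge(START, "a")
    builder.add_edge(START, "b")
    # "join" waits until both "a" and "b" have run.
    builder.add_edge(["a", "b"], "join")
    builder.add_edge("join", END)

    result = builder.compile().invoke({"values": []})
    assert result["values"][-1] == "join"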
async def test_in_one_fan_out_state_graph_waiting_edge_multiple() -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
async def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("decider", decider)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "decider")
workflow.add_conditional_edges("decider", decider_cond)
workflow.set_finish_point("qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}},
{
"analyzer_one": {
"query": "analyzed: query: analyzed: query: what is weather in sf"
}
},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}},
]
async def test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge() -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]
async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
async def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("decider", decider)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_conditional_edges("rewrite_query", retriever_picker)
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge(["retriever_one", "retriever_two"], "decider")
workflow.add_conditional_edges("decider", decider_cond)
workflow.set_finish_point("qa")
app = workflow.compile()
assert await app.ainvoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
}
assert [c async for c in app.astream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}},
{
"analyzer_one": {
"query": "analyzed: query: analyzed: query: what is weather in sf"
}
},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}},
]
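# Both tests above route with `add_conditional_edges`: the routing callable returns the
# name of the next node, or (as in `retriever_picker`) a list of names to fan out to
# several nodes at once. A minimal sketch of a router that loops a node until a condition
# holds and then finishes (illustration only, same public API as above):
def _example_conditional_edges() -> None:
    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        count: int

    def increment(state: State) -> State:
        return {"count": state["count"] + 1}

    def route(state: State) -> str:
        # Loop back to "increment" until count reaches 3, then stop.
        return END if state["count"] >= 3 else "increment"

    builder = StateGraph(State)
    builder.add_node("increment", increment)
    builder.add_edge(START, "increment")
    builder.add_conditional_edges("increment", route)
    assert builder.compile().invoke({"count": 0}) == {"count": 3}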
async def test_nested_graph(snapshot: SnapshotAssertion) -> None:
def never_called_fn(state: Any):
assert 0, "This function should never be called"
never_called = RunnableLambda(never_called_fn)
class InnerState(TypedDict):
my_key: str
my_other_key: str
def up(state: InnerState):
return {"my_key": state["my_key"] + " there", "my_other_key": state["my_key"]}
inner = StateGraph(InnerState)
inner.add_node("up", up)
inner.set_entry_point("up")
inner.set_finish_point("up")
class State(TypedDict):
my_key: str
never_called: Any
async def side(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("inner", inner.compile())
graph.add_node("side", side)
graph.set_entry_point("inner")
graph.add_edge("inner", "side")
graph.set_finish_point("side")
app = graph.compile()
assert await app.ainvoke({"my_key": "my value", "never_called": never_called}) == {
"my_key": "my value there and back again",
"never_called": never_called,
}
assert [
chunk
async for chunk in app.astream(
{"my_key": "my value", "never_called": never_called}
)
] == [
{"inner": {"my_key": "my value there"}},
{"side": {"my_key": "my value there and back again"}},
]
assert [
chunk
async for chunk in app.astream(
{"my_key": "my value", "never_called": never_called}, stream_mode="values"
)
] == [
{"my_key": "my value", "never_called": never_called},
{"my_key": "my value there", "never_called": never_called},
{"my_key": "my value there and back again", "never_called": never_called},
]
times_called = 0
async for event in app.astream_events(
{"my_key": "my value", "never_called": never_called},
version="v2",
config={"run_id": UUID(int=0)},
stream_mode="values",
):
if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)):
times_called += 1
assert event["data"] == {
"output": {
"my_key": "my value there and back again",
"never_called": never_called,
}
}
assert times_called == 1
times_called = 0
async for event in app.astream_events(
{"my_key": "my value", "never_called": never_called},
version="v2",
config={"run_id": UUID(int=0)},
):
if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)):
times_called += 1
assert event["data"] == {
"output": {
"my_key": "my value there and back again",
"never_called": never_called,
}
}
assert times_called == 1
chain = app | RunnablePassthrough()
assert await chain.ainvoke(
{"my_key": "my value", "never_called": never_called}
) == {
"my_key": "my value there and back again",
"never_called": never_called,
}
assert [
chunk
async for chunk in chain.astream(
{"my_key": "my value", "never_called": never_called}
)
] == [
{"inner": {"my_key": "my value there"}},
{"side": {"my_key": "my value there and back again"}},
]
times_called = 0
async for event in chain.astream_events(
{"my_key": "my value", "never_called": never_called},
version="v2",
config={"run_id": UUID(int=0)},
):
if event["event"] == "on_chain_end" and event["run_id"] == str(UUID(int=0)):
times_called += 1
assert event["data"] == {
"output": [
{"inner": {"my_key": "my value there"}},
{"side": {"my_key": "my value there and back again"}},
]
}
assert times_called == 1
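# `test_nested_graph` streams the same run in two modes: the default "updates" mode yields
# one {node_name: delta} chunk per node, while stream_mode="values" yields the full state
# after every step, starting with the input. A minimal synchronous sketch of the
# difference (illustration only):
def _example_stream_modes() -> None:
    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        my_key: str

    builder = StateGraph(State)
    builder.add_node("greet", lambda s: {"my_key": s["my_key"] + " there"})
    builder.add_edge(START, "greet")
    builder.add_edge("greet", END)
    graph = builder.compile()

    # "updates" (the default): per-node deltas.
    assert list(graph.stream({"my_key": "hi"})) == [{"greet": {"my_key": "hi there"}}]
    # "values": the full state after each step, starting with the input.
    assert list(graph.stream({"my_key": "hi"}, stream_mode="values")) == [
        {"my_key": "hi"},
        {"my_key": "hi there"},
    ]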
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_stream_subgraphs_during_execution(checkpointer_name: str) -> None:
class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str
async def inner_1(state: InnerState):
return {"my_key": "got here", "my_other_key": state["my_key"]}
async def inner_2(state: InnerState):
await asyncio.sleep(0.5)
return {
"my_key": " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: Annotated[str, operator.add]
async def outer_1(state: State):
await asyncio.sleep(0.2)
return {"my_key": " and parallel"}
async def outer_2(state: State):
return {"my_key": " and back again"}
graph = StateGraph(State)
graph.add_node("inner", inner.compile())
graph.add_node("outer_1", outer_1)
graph.add_node("outer_2", outer_2)
graph.add_edge(START, "inner")
graph.add_edge(START, "outer_1")
graph.add_edge(["inner", "outer_1"], "outer_2")
graph.add_edge("outer_2", END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)
start = perf_counter()
chunks: list[tuple[float, Any]] = []
config = {"configurable": {"thread_id": "2"}}
async for c in app.astream({"my_key": ""}, config, subgraphs=True):
chunks.append((round(perf_counter() - start, 1), c))
for idx in range(len(chunks)):
elapsed, c = chunks[idx]
chunks[idx] = (round(elapsed - chunks[0][0], 1), c)
assert chunks == [
# arrives before "inner" finishes
(
FloatBetween(0.0, 0.1),
(
(AnyStr("inner:"),),
{"inner_1": {"my_key": "got here", "my_other_key": ""}},
),
),
(FloatBetween(0.2, 0.4), ((), {"outer_1": {"my_key": " and parallel"}})),
(
FloatBetween(0.5, 0.7),
(
(AnyStr("inner:"),),
{"inner_2": {"my_key": " and there", "my_other_key": "got here"}},
),
),
(FloatBetween(0.5, 0.7), ((), {"inner": {"my_key": "got here and there"}})),
(FloatBetween(0.5, 0.7), ((), {"outer_2": {"my_key": " and back again"}})),
]
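# With `subgraphs=True` the stream above yields `(namespace, chunk)` pairs: the namespace
# is () for parent-level updates and a tuple like ("inner:<task-id>",) for updates coming
# from inside a subgraph, which are surfaced while the subgraph is still running. A
# minimal synchronous sketch of consuming that shape (illustration only):
def _example_stream_subgraphs() -> None:
    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        my_key: str

    inner = StateGraph(State)
    inner.add_node("inner_1", lambda s: {"my_key": s["my_key"] + " inner"})
    inner.add_edge(START, "inner_1")
    inner.add_edge("inner_1", END)

    outer = StateGraph(State)
    outer.add_node("child", inner.compile())
    outer.add_edge(START, "child")
    outer.add_edge("child", END)

    chunks = list(outer.compile().stream({"my_key": "hi"}, subgraphs=True))
    # Subgraph updates carry a namespace like ("child:<task-id>",); parent updates use ().
    assert chunks[-1] == ((), {"child": {"my_key": "hi inner"}})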
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_stream_buffering_single_node(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
async def node(state: State, writer: StreamWriter):
writer("Before sleep")
await asyncio.sleep(0.2)
writer("After sleep")
return {"my_key": "got here"}
builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")
builder.add_edge("node", END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
start = perf_counter()
chunks: list[tuple[float, Any]] = []
config = {"configurable": {"thread_id": "2"}}
async for c in graph.astream({"my_key": ""}, config, stream_mode="custom"):
chunks.append((round(perf_counter() - start, 1), c))
assert chunks == [
(FloatBetween(0.0, 0.1), "Before sleep"),
(FloatBetween(0.2, 0.3), "After sleep"),
]
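# The assertion above shows that values passed to the injected `StreamWriter` are surfaced
# as soon as the node calls `writer(...)`, not buffered until the node returns. A minimal
# synchronous sketch; StreamWriter is imported from langgraph.types in recent releases
# (treat the exact import location as version-dependent), and the snippet is illustration
# only:
def _example_custom_stream_mode() -> None:
    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import StreamWriter

    class State(TypedDict):
        my_key: str

    def node(state: State, writer: StreamWriter) -> State:
        writer("progress: started")  # emitted immediately in "custom" mode
        writer("progress: finished")
        return {"my_key": "done"}

    builder = StateGraph(State)
    builder.add_node("node", node)
    builder.add_edge(START, "node")
    builder.add_edge("node", END)

    chunks = list(builder.compile().stream({"my_key": ""}, stream_mode="custom"))
    assert chunks == ["progress: started", "progress: finished"]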
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_nested_graph_interrupts_parallel(checkpointer_name: str) -> None:
class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str
async def inner_1(state: InnerState):
await asyncio.sleep(0.1)
return {"my_key": "got here", "my_other_key": state["my_key"]}
async def inner_2(state: InnerState):
return {
"my_key": " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: Annotated[str, operator.add]
async def outer_1(state: State):
return {"my_key": " and parallel"}
async def outer_2(state: State):
return {"my_key": " and back again"}
graph = StateGraph(State)
graph.add_node(
"inner",
inner.compile(interrupt_before=["inner_2"]),
)
graph.add_node("outer_1", outer_1)
graph.add_node("outer_2", outer_2)
graph.add_edge(START, "inner")
graph.add_edge(START, "outer_1")
graph.add_edge(["inner", "outer_1"], "outer_2")
graph.set_finish_point("outer_2")
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert await app.ainvoke({"my_key": ""}, config, debug=True) == {
"my_key": " and parallel",
}
assert await app.ainvoke(None, config, debug=True) == {
"my_key": "got here and there and parallel and back again",
}
# the combination of assertions below checks two things:
# - outer_1 finishes before inner interrupts (we see its output in the stream, which only happens after a node finishes)
# - the writes of outer_1 are persisted in the 1st call and reused in the 2nd call, i.e. outer_1 isn't run again (we don't see its output again in the 2nd stream)
# test stream updates w/ nested interrupt
config = {"configurable": {"thread_id": "2"}}
assert [
c async for c in app.astream({"my_key": ""}, config, subgraphs=True)
] == [
# we got to parallel node first
((), {"outer_1": {"my_key": " and parallel"}}),
(
(AnyStr("inner:"),),
{"inner_1": {"my_key": "got here", "my_other_key": ""}},
),
((), {"__interrupt__": ()}),
]
assert [c async for c in app.astream(None, config)] == [
{"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}},
{"inner": {"my_key": "got here and there"}},
{"outer_2": {"my_key": " and back again"}},
]
# test stream values w/ nested interrupt
config = {"configurable": {"thread_id": "3"}}
assert [
c async for c in app.astream({"my_key": ""}, config, stream_mode="values")
] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
# test interrupts BEFORE the parallel node
app = graph.compile(checkpointer=checkpointer, interrupt_before=["outer_1"])
config = {"configurable": {"thread_id": "4"}}
assert [
c async for c in app.astream({"my_key": ""}, config, stream_mode="values")
] == [
{"my_key": ""},
]
# while we're waiting for the node w/ interrupt inside to finish
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
# test interrupts AFTER the parallel node
app = graph.compile(checkpointer=checkpointer, interrupt_after=["outer_1"])
config = {"configurable": {"thread_id": "5"}}
assert [
c async for c in app.astream({"my_key": ""}, config, stream_mode="values")
] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
]
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
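# The pattern exercised above: compile with a checkpointer, let the run stop at an
# interrupt (here raised inside a subgraph), then resume by invoking with None and the
# same thread config. A minimal sketch of interrupt-and-resume, assuming the in-memory
# MemorySaver checkpointer (illustration only):
def _example_interrupt_and_resume() -> None:
    from typing import TypedDict

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        my_key: str

    builder = StateGraph(State)
    builder.add_node("step_1", lambda s: {"my_key": s["my_key"] + " one"})
    builder.add_node("step_2", lambda s: {"my_key": s["my_key"] + " two"})
    builder.add_edge(START, "step_1")
    builder.add_edge("step_1", "step_2")
    builder.add_edge("step_2", END)
    graph = builder.compile(checkpointer=MemorySaver(), interrupt_before=["step_2"])

    config = {"configurable": {"thread_id": "example"}}
    # The first call stops before step_2 and returns the state accumulated so far.
    assert graph.invoke({"my_key": "hi"}, config) == {"my_key": "hi one"}
    # Passing None with the same config resumes from the saved checkpoint.
    assert graph.invoke(None, config) == {"my_key": "hi one two"}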
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_doubly_nested_graph_interrupts(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: str
class ChildState(TypedDict):
my_key: str
class GrandChildState(TypedDict):
my_key: str
async def grandchild_1(state: ChildState):
return {"my_key": state["my_key"] + " here"}
async def grandchild_2(state: ChildState):
return {
"my_key": state["my_key"] + " and there",
}
grandchild = StateGraph(GrandChildState)
grandchild.add_node("grandchild_1", grandchild_1)
grandchild.add_node("grandchild_2", grandchild_2)
grandchild.add_edge("grandchild_1", "grandchild_2")
grandchild.set_entry_point("grandchild_1")
grandchild.set_finish_point("grandchild_2")
child = StateGraph(ChildState)
child.add_node(
"child_1",
grandchild.compile(interrupt_before=["grandchild_2"]),
)
child.set_entry_point("child_1")
child.set_finish_point("child_1")
async def parent_1(state: State):
return {"my_key": "hi " + state["my_key"]}
async def parent_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("parent_1", parent_1)
graph.add_node("child", child.compile())
graph.add_node("parent_2", parent_2)
graph.set_entry_point("parent_1")
graph.add_edge("parent_1", "child")
graph.add_edge("child", "parent_2")
graph.set_finish_point("parent_2")
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert await app.ainvoke({"my_key": "my value"}, config, debug=True) == {
"my_key": "hi my value",
}
assert await app.ainvoke(None, config, debug=True) == {
"my_key": "hi my value here and there and back again",
}
# test stream updates w/ nested interrupt
nodes: list[str] = []
config = {
"configurable": {"thread_id": "2", CONFIG_KEY_NODE_FINISHED: nodes.append}
}
assert [c async for c in app.astream({"my_key": "my value"}, config)] == [
{"parent_1": {"my_key": "hi my value"}},
{"__interrupt__": ()},
]
assert nodes == ["parent_1", "grandchild_1"]
assert [c async for c in app.astream(None, config)] == [
{"child": {"my_key": "hi my value here and there"}},
{"parent_2": {"my_key": "hi my value here and there and back again"}},
]
assert nodes == [
"parent_1",
"grandchild_1",
"grandchild_2",
"child_1",
"child",
"parent_2",
]
# test stream values w/ nested interrupt
config = {"configurable": {"thread_id": "3"}}
assert [
c
async for c in app.astream(
{"my_key": "my value"}, config, stream_mode="values"
)
] == [
{"my_key": "my value"},
{"my_key": "hi my value"},
]
assert [c async for c in app.astream(None, config, stream_mode="values")] == [
{"my_key": "hi my value"},
{"my_key": "hi my value here and there"},
{"my_key": "hi my value here and there and back again"},
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_nested_graph_state(checkpointer_name: str) -> None:
class InnerState(TypedDict):
my_key: str
my_other_key: str
def inner_1(state: InnerState):
return {
"my_key": state["my_key"] + " here",
"my_other_key": state["my_key"],
}
def inner_2(state: InnerState):
return {
"my_key": state["my_key"] + " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: str
other_parent_key: str
def outer_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def outer_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("outer_1", outer_1)
graph.add_node(
"inner",
inner.compile(interrupt_before=["inner_2"]),
)
graph.add_node("outer_2", outer_2)
graph.set_entry_point("outer_1")
graph.add_edge("outer_1", "inner")
graph.add_edge("inner", "outer_2")
graph.set_finish_point("outer_2")
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
await app.ainvoke({"my_key": "my value"}, config, debug=True)
# test state w/ nested subgraph state (right after interrupt)
# first get_state without subgraph state
assert await app.aget_state(config) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={
"configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()}
},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# now, get_state with subgraph state included
assert await app.aget_state(config, subgraphs=True) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state=StateSnapshot(
values={
"my_key": "hi my value here",
"my_other_key": "hi my value",
},
tasks=(
PregelTask(
AnyStr(),
"inner_2",
(PULL, "inner_2"),
),
),
next=("inner_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"parents": {
"": AnyStr(),
},
"source": "loop",
"writes": {
"inner_1": {
"my_key": "hi my value here",
"my_other_key": "hi my value",
}
},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
),
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# get_state_history returns outer graph checkpoints
history = [c async for c in app.aget_state_history(config)]
assert history == [
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
}
},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
tasks=(
PregelTask(
AnyStr(),
"outer_1",
(PULL, "outer_1"),
result={"my_key": "hi my value"},
),
),
next=("outer_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={},
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "my value"},
),
),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"my_key": "my value"}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# get_state_history for a subgraph returns its checkpoints
child_history = [
c async for c in app.aget_state_history(history[0].tasks[0].state)
]
assert child_history == [
StateSnapshot(
values={"my_key": "hi my value here", "my_other_key": "hi my value"},
next=("inner_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": {
"inner_1": {
"my_key": "hi my value here",
"my_other_key": "hi my value",
}
},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("inner_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(
PregelTask(
AnyStr(),
"inner_1",
(PULL, "inner_1"),
result={
"my_key": "hi my value here",
"my_other_key": "hi my value",
},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# resume
await app.ainvoke(None, config, debug=True)
# test state w/ nested subgraph state (after resuming from interrupt)
assert await app.aget_state(config) == StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"outer_2": {"my_key": "hi my value here and there and back again"}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# test full history at the end
actual_history = [c async for c in app.aget_state_history(config)]
expected_history = [
StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"outer_2": {
"my_key": "hi my value here and there and back again"
}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value here and there"},
tasks=(
PregelTask(
AnyStr(),
"outer_2",
(PULL, "outer_2"),
result={"my_key": "hi my value here and there and back again"},
),
),
next=("outer_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"inner": {"my_key": "hi my value here and there"}},
"step": 2,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
}
},
result={"my_key": "hi my value here and there"},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
tasks=(
PregelTask(
AnyStr(),
"outer_1",
(PULL, "outer_1"),
result={"my_key": "hi my value"},
),
),
next=("outer_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={},
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "my value"},
),
),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"my_key": "my value"}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
assert actual_history == expected_history
# test looking up parent state by checkpoint ID
for actual_snapshot, expected_snapshot in zip(actual_history, expected_history):
assert await app.aget_state(actual_snapshot.config) == expected_snapshot
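# The snapshots above show two ways to look at a paused subgraph: `aget_state(config)`
# returns a parent snapshot whose pending task exposes a `state` config pointing at the
# subgraph's own checkpoints, while `aget_state(config, subgraphs=True)` inlines the child
# StateSnapshot. A minimal synchronous sketch of both, assuming MemorySaver (illustration
# only):
def _example_inspect_subgraph_state() -> None:
    from typing import TypedDict

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        my_key: str

    inner = StateGraph(State)
    inner.add_node("inner_1", lambda s: {"my_key": s["my_key"] + " here"})
    inner.add_node("inner_2", lambda s: {"my_key": s["my_key"] + " there"})
    inner.add_edge(START, "inner_1")
    inner.add_edge("inner_1", "inner_2")
    inner.add_edge("inner_2", END)

    outer = StateGraph(State)
    outer.add_node("inner", inner.compile(interrupt_before=["inner_2"]))
    outer.add_edge(START, "inner")
    outer.add_edge("inner", END)
    graph = outer.compile(checkpointer=MemorySaver())

    config = {"configurable": {"thread_id": "example"}}
    graph.invoke({"my_key": "hi"}, config)  # pauses inside the subgraph

    parent = graph.get_state(config)
    # The pending "inner" task carries a config for the subgraph's checkpoints...
    child = graph.get_state(parent.tasks[0].state)
    assert child.next == ("inner_2",)
    # ...or ask for the nested snapshot inline:
    nested = graph.get_state(config, subgraphs=True)
    assert nested.tasks[0].state.values == {"my_key": "hi here"}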
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_doubly_nested_graph_state(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: str
class ChildState(TypedDict):
my_key: str
class GrandChildState(TypedDict):
my_key: str
def grandchild_1(state: ChildState):
return {"my_key": state["my_key"] + " here"}
def grandchild_2(state: ChildState):
return {
"my_key": state["my_key"] + " and there",
}
grandchild = StateGraph(GrandChildState)
grandchild.add_node("grandchild_1", grandchild_1)
grandchild.add_node("grandchild_2", grandchild_2)
grandchild.add_edge("grandchild_1", "grandchild_2")
grandchild.set_entry_point("grandchild_1")
grandchild.set_finish_point("grandchild_2")
child = StateGraph(ChildState)
child.add_node(
"child_1",
grandchild.compile(interrupt_before=["grandchild_2"]),
)
child.set_entry_point("child_1")
child.set_finish_point("child_1")
def parent_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def parent_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("parent_1", parent_1)
graph.add_node("child", child.compile())
graph.add_node("parent_2", parent_2)
graph.set_entry_point("parent_1")
graph.add_edge("parent_1", "child")
graph.add_edge("child", "parent_2")
graph.set_finish_point("parent_2")
async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert [
c async for c in app.astream({"my_key": "my value"}, config, subgraphs=True)
] == [
((), {"parent_1": {"my_key": "hi my value"}}),
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
# get state without subgraphs
outer_state = await app.aget_state(config)
assert outer_state == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
child_state = await app.aget_state(outer_state.tasks[0].state)
assert (
child_state.tasks[0]
== StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
}
},
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
}
},
).tasks[0]
)
grandchild_state = await app.aget_state(child_state.tasks[0].state)
assert grandchild_state == StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"writes": {"grandchild_1": {"my_key": "hi my value here"}},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [PULL, AnyStr("child_1")],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
)
# get state with subgraphs
assert await app.aget_state(config, subgraphs=True) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state=StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state=StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(
re.compile(r"child:.+|child1:")
): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"writes": {
"grandchild_1": {
"my_key": "hi my value here"
}
},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(
re.compile(r"child:.+|child1:")
): AnyStr(),
}
),
}
},
),
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
),
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# resume
assert [c async for c in app.astream(None, config, subgraphs=True)] == [
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_2": {"my_key": "hi my value here and there"}},
),
(
(AnyStr("child:"),),
{"child_1": {"my_key": "hi my value here and there"}},
),
((), {"child": {"my_key": "hi my value here and there"}}),
((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),
]
# get state with and without subgraphs
assert (
await app.aget_state(config)
== await app.aget_state(config, subgraphs=True)
== StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"parent_2": {
"my_key": "hi my value here and there and back again"
}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
)
# get outer graph history
outer_history = [c async for c in app.aget_state_history(config)]
assert (
outer_history[0]
== [
StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"parent_2": {
"my_key": "hi my value here and there and back again"
}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=("parent_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"child": {"my_key": "hi my value here and there"}},
"step": 2,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(), name="parent_2", path=(PULL, "parent_2")
),
),
),
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
next=("parent_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(), name="parent_1", path=(PULL, "parent_1")
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"my_key": "my value"},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(), name="__start__", path=(PULL, "__start__")
),
),
),
][0]
)
# get child graph history
child_history = [
c async for c in app.aget_state_history(outer_history[2].tasks[0].state)
]
assert child_history == [
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": {"child_1": {"my_key": "hi my value here and there"}},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="child_1",
path=(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
}
},
result={"my_key": "hi my value here and there"},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# get grandchild graph history
grandchild_history = [
c async for c in app.aget_state_history(child_history[1].tasks[0].state)
]
assert grandchild_history == [
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {
"grandchild_2": {"my_key": "hi my value here and there"}
},
"step": 2,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(),
),
StateSnapshot(
values={"my_key": "hi my value here"},
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {"grandchild_1": {"my_key": "hi my value here"}},
"step": 1,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="grandchild_2",
path=(PULL, "grandchild_2"),
result={"my_key": "hi my value here and there"},
),
),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("grandchild_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="grandchild_1",
path=(PULL, "grandchild_1"),
result={"my_key": "hi my value here"},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# replay grandchild checkpoint
assert [
c
async for c in app.astream(
None, grandchild_history[2].config, subgraphs=True
)
] == [
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
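# The final assertion above replays from a historical checkpoint: every snapshot yielded
# by `aget_state_history` carries a `config` with a `checkpoint_id`, and streaming or
# invoking with None plus that config re-runs the thread from that point. A minimal
# synchronous sketch of the same time-travel pattern, assuming MemorySaver (illustration
# only):
def _example_replay_from_checkpoint() -> None:
    from typing import TypedDict

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, START, StateGraph

    class State(TypedDict):
        my_key: str

    builder = StateGraph(State)
    builder.add_node("one", lambda s: {"my_key": s["my_key"] + " one"})
    builder.add_node("two", lambda s: {"my_key": s["my_key"] + " two"})
    builder.add_edge(START, "one")
    builder.add_edge("one", "two")
    builder.add_edge("two", END)
    graph = builder.compile(checkpointer=MemorySaver())

    config = {"configurable": {"thread_id": "example"}}
    graph.invoke({"my_key": "hi"}, config)

    # History is newest-first; pick the checkpoint where "two" was still pending.
    before_two = next(
        s for s in graph.get_state_history(config) if s.next == ("two",)
    )
    # Re-run from that checkpoint by passing None with the historical config.
    assert graph.invoke(None, before_two.config) == {"my_key": "hi one two"}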
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_to_nested_graphs(checkpointer_name: str) -> None:
class OverallState(TypedDict):
subjects: list[str]
jokes: Annotated[list[str], operator.add]
async def continue_to_jokes(state: OverallState):
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
class JokeState(TypedDict):
subject: str
async def edit(state: JokeState):
subject = state["subject"]
return {"subject": f"{subject} - hohoho"}
# subgraph
subgraph = StateGraph(JokeState, output=OverallState)
subgraph.add_node("edit", edit)
subgraph.add_node(
"generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]}
)
subgraph.set_entry_point("edit")
subgraph.add_edge("edit", "generate")
subgraph.set_finish_point("generate")
# parent graph
builder = StateGraph(OverallState)
builder.add_node(
"generate_joke",
subgraph.compile(interrupt_before=["generate"]),
)
builder.add_conditional_edges(START, continue_to_jokes)
builder.add_edge("generate_joke", END)
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
tracer = FakeTracer()
# invoke and pause at nested interrupt
assert await graph.ainvoke(
{"subjects": ["cats", "dogs"]},
config={**config, "callbacks": [tracer]},
) == {
"subjects": ["cats", "dogs"],
"jokes": [],
}
assert len(tracer.runs) == 1, "Should produce exactly 1 root run"
# check state
outer_state = await graph.aget_state(config)
if not FF_SEND_V2:
# update state of dogs joke graph
await graph.aupdate_state(
outer_state.tasks[1].state, {"subject": "turtles - hohoho"}
)
# continue past interrupt
assert await graph.ainvoke(None, config=config) == {
"subjects": ["cats", "dogs"],
"jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
}
return
assert outer_state == StateSnapshot(
values={"subjects": ["cats", "dogs"], "jokes": []},
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result={"subjects": ["cats", "dogs"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
),
),
next=("generate_joke", "generate_joke"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {
"__start__": {
"subjects": [
"cats",
"dogs",
],
}
},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
)
# update state of dogs joke graph
await graph.aupdate_state(
outer_state.tasks[2].state, {"subject": "turtles - hohoho"}
)
# continue past interrupt
assert await graph.ainvoke(None, config=config) == {
"subjects": ["cats", "dogs"],
"jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
}
actual_snapshot = await graph.aget_state(config)
expected_snapshot = StateSnapshot(
values={
"subjects": ["cats", "dogs"],
"jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"generate_joke": [
{"jokes": ["Joke about cats - hohoho"]},
{"jokes": ["Joke about turtles - hohoho"]},
]
},
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
assert actual_snapshot == expected_snapshot
# test full history
actual_history = [c async for c in graph.aget_state_history(config)]
expected_history = [
StateSnapshot(
values={
"subjects": ["cats", "dogs"],
"jokes": [
"Joke about cats - hohoho",
"Joke about turtles - hohoho",
],
},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"generate_joke": [
{"jokes": ["Joke about cats - hohoho"]},
{"jokes": ["Joke about turtles - hohoho"]},
]
},
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"jokes": []},
next=("__start__", "generate_joke", "generate_joke"),
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result={"subjects": ["cats", "dogs"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
result={"jokes": ["Joke about cats - hohoho"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
result={"jokes": ["Joke about turtles - hohoho"]},
),
),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"subjects": ["cats", "dogs"]}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
assert actual_history == expected_history
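# `continue_to_jokes` above fans out with the Send API: a conditional-edge callable may
# return `Send(node, payload)` objects, and each one schedules a separate invocation of
# the target node (here a compiled subgraph) with its own private input. A minimal sketch
# of the map step against a plain node; Send is imported from langgraph.types in recent
# releases (older versions exposed it from langgraph.constants), and this is illustration
# only:
def _example_send_fan_out() -> None:
    import operator
    from typing import Annotated, TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import Send

    class OverallState(TypedDict):
        subjects: list[str]
        jokes: Annotated[list[str], operator.add]

    class JokeState(TypedDict):
        subject: str

    def fan_out(state: OverallState):
        return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]

    def generate_joke(state: JokeState):
        return {"jokes": [f"Joke about {state['subject']}"]}

    builder = StateGraph(OverallState)
    builder.add_node("generate_joke", generate_joke)
    builder.add_conditional_edges(START, fan_out)
    builder.add_edge("generate_joke", END)

    result = builder.compile().invoke({"subjects": ["cats", "dogs"], "jokes": []})
    assert sorted(result["jokes"]) == ["Joke about cats", "Joke about dogs"]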
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_weather_subgraph(
checkpointer_name: str, snapshot: SnapshotAssertion
) -> None:
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.graph import MessagesState
# setup subgraph
@tool
def get_weather(city: str):
"""Get the weather for a specific city"""
return f"I'ts sunny in {city}!"
weather_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="get_weather",
args={"city": "San Francisco"},
)
],
)
]
)
class SubGraphState(MessagesState):
city: str
def model_node(state: SubGraphState, writer: StreamWriter):
writer(" very")
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}
def weather_node(state: SubGraphState, writer: StreamWriter):
writer(" good")
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}
subgraph = StateGraph(SubGraphState)
subgraph.add_node(model_node)
subgraph.add_node(weather_node)
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)
subgraph = subgraph.compile(interrupt_before=["weather_node"])
# setup main graph
class RouterState(MessagesState):
route: Literal["weather", "other"]
class Router(TypedDict):
route: Literal["weather", "other"]
router_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="router",
args={"dest": "weather"},
)
],
)
]
)
def router_node(state: RouterState, writer: StreamWriter):
writer("I'm")
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]}
def normal_llm_node(state: RouterState):
return {"messages": [AIMessage("Hello!")]}
def route_after_prediction(state: RouterState):
if state["route"] == "weather":
return "weather_graph"
else:
return "normal_llm_node"
def weather_graph(state: RouterState):
# this verifies that all async checkpointers under test also implement the sync
# methods, since the subgraph is invoked synchronously here and therefore uses
# the checkpointer's sync methods
return subgraph.invoke(state)
graph = StateGraph(RouterState)
graph.add_node(router_node)
graph.add_node(normal_llm_node)
graph.add_node("weather_graph", weather_graph)
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("normal_llm_node", END)
graph.add_edge("weather_graph", END)
def get_first_in_list():
return [*graph.get_state_history(config, limit=1)][0]
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = graph.compile(checkpointer=checkpointer)
assert graph.get_graph(xray=1).draw_mermaid() == snapshot
config = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
# run with custom output
assert [
c async for c in graph.astream(inputs, thread2, stream_mode="custom")
] == [
"I'm",
" very",
]
assert [
c async for c in graph.astream(None, thread2, stream_mode="custom")
] == [
" good",
]
# run until interrupt
assert [
c
async for c in graph.astream(
inputs, config=config, stream_mode="updates", subgraphs=True
)
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]
# check current state
state = await graph.aget_state(config)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("weather_graph:"),
}
},
),
),
)
# confirm that list() delegates to alist() correctly
assert await asyncio.to_thread(get_first_in_list) == state
# update
await graph.aupdate_state(state.tasks[0].state, {"city": "la"})
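# note: state.tasks[0].state is the config scoped to the subgraph's checkpoint
# namespace (checkpoint_ns starting with "weather_graph:"), so this update rewrites
# the subgraph's state and the pending weather_node resumes with city="la"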
# run after update
assert [
c
async for c in graph.astream(
None, config=config, stream_mode="updates", subgraphs=True
)
] == [
(
(AnyStr("weather_graph:"),),
{
"weather_node": {
"messages": [
{"role": "assistant", "content": "I'ts sunny in la!"}
]
}
},
),
(
(),
{
"weather_graph": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="I'ts sunny in la!"),
]
}
},
),
]
# try updating acting as weather node
config = {"configurable": {"thread_id": "14"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
assert [
c
async for c in graph.astream(
inputs, config=config, stream_mode="updates", subgraphs=True
)
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]
state = await graph.aget_state(config, subgraphs=True)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "14",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state=StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf")
],
"city": "San Francisco",
},
next=("weather_node",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {"model_node": {"city": "San Francisco"}},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"langgraph_node": "weather_graph",
"langgraph_path": [PULL, "weather_graph"],
"langgraph_step": 2,
"langgraph_triggers": [
"branch:router_node:route_after_prediction:weather_graph"
],
"langgraph_checkpoint_ns": AnyStr("weather_graph:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_node",
path=(PULL, "weather_node"),
),
),
),
),
),
)
await graph.aupdate_state(
state.tasks[0].state.config,
{"messages": [{"role": "assistant", "content": "rainy"}]},
as_node="weather_node",
)
state = await graph.aget_state(config, subgraphs=True)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "14",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state=StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="rainy"),
],
"city": "San Francisco",
},
next=(),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
metadata={
"step": 2,
"source": "update",
"writes": {
"weather_node": {
"messages": [
{"role": "assistant", "content": "rainy"}
]
}
},
"parents": {"": AnyStr()},
"thread_id": "14",
"checkpoint_id": AnyStr(),
"checkpoint_ns": AnyStr("weather_graph:"),
"langgraph_node": "weather_graph",
"langgraph_path": [PULL, "weather_graph"],
"langgraph_step": 2,
"langgraph_triggers": [
"branch:router_node:route_after_prediction:weather_graph"
],
"langgraph_checkpoint_ns": AnyStr("weather_graph:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
tasks=(),
),
),
),
)
assert [
c
async for c in graph.astream(
None, config=config, stream_mode="updates", subgraphs=True
)
] == [
(
(),
{
"weather_graph": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="rainy"),
]
}
},
),
]
async def test_checkpoint_metadata() -> None:
"""This test verifies that a run's configurable fields are merged with the
previous checkpoint config for each step in the run.
"""
# set up test
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
# graph state
class BaseState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
# initialize graph nodes
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a nice assistant."),
("placeholder", "{messages}"),
]
)
model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(content="answer"),
]
)
def agent(state: BaseState, config: RunnableConfig) -> BaseState:
formatted = prompt.invoke(state)
response = model.invoke(formatted)
return {"messages": response}
def should_continue(data: BaseState) -> str:
# Logic to decide whether to continue in the loop or exit
if not data["messages"][-1].tool_calls:
return "exit"
else:
return "continue"
# define graphs w/ and w/o interrupt
workflow = StateGraph(BaseState)
workflow.add_node("agent", agent)
workflow.add_node("tools", ToolNode(tools))
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
# graph w/o interrupt
checkpointer_1 = MemorySaverAssertCheckpointMetadata()
app = workflow.compile(checkpointer=checkpointer_1)
# graph w/ interrupt
checkpointer_2 = MemorySaverAssertCheckpointMetadata()
app_w_interrupt = workflow.compile(
checkpointer=checkpointer_2, interrupt_before=["tools"]
)
# assertions
# invoke graph w/o interrupt
await app.ainvoke(
{"messages": ["what is weather in sf"]},
{
"configurable": {
"thread_id": "1",
"test_config_1": "foo",
"test_config_2": "bar",
},
},
)
config = {"configurable": {"thread_id": "1"}}
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_1 = (await checkpointer_1.aget_tuple(config)).metadata
assert chkpnt_metadata_1["thread_id"] == "1"
assert chkpnt_metadata_1["test_config_1"] == "foo"
assert chkpnt_metadata_1["test_config_2"] == "bar"
# Verify that all checkpoint metadata have the expected keys. This check
# is needed because a run may have an arbitrary number of steps depending
# on how the graph is constructed.
chkpnt_tuples_1 = checkpointer_1.alist(config)
async for chkpnt_tuple in chkpnt_tuples_1:
assert chkpnt_tuple.metadata["thread_id"] == "1"
assert chkpnt_tuple.metadata["test_config_1"] == "foo"
assert chkpnt_tuple.metadata["test_config_2"] == "bar"
# invoke graph, but interrupt before tool call
await app_w_interrupt.ainvoke(
{"messages": ["what is weather in sf"]},
{
"configurable": {
"thread_id": "2",
"test_config_3": "foo",
"test_config_4": "bar",
},
},
)
config = {"configurable": {"thread_id": "2"}}
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_2 = (await checkpointer_2.aget_tuple(config)).metadata
assert chkpnt_metadata_2["thread_id"] == "2"
assert chkpnt_metadata_2["test_config_3"] == "foo"
assert chkpnt_metadata_2["test_config_4"] == "bar"
# resume graph execution
await app_w_interrupt.ainvoke(
input=None,
config={
"configurable": {
"thread_id": "2",
"test_config_3": "foo",
"test_config_4": "bar",
}
},
)
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_3 = (await checkpointer_2.aget_tuple(config)).metadata
assert chkpnt_metadata_3["thread_id"] == "2"
assert chkpnt_metadata_3["test_config_3"] == "foo"
assert chkpnt_metadata_3["test_config_4"] == "bar"
# Verify that all checkpoint metadata have the expected keys. This check
# is needed because a run may have an arbitrary number of steps depending
# on how the graph is constructed.
chkpnt_tuples_2 = checkpointer_2.alist(config)
async for chkpnt_tuple in chkpnt_tuples_2:
assert chkpnt_tuple.metadata["thread_id"] == "2"
assert chkpnt_tuple.metadata["test_config_3"] == "foo"
assert chkpnt_tuple.metadata["test_config_4"] == "bar"
async def test_checkpointer_null_pending_writes() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
def __call__(self, state):
return [self.name]
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_edge(START, "1")
graph = builder.compile(checkpointer=MemorySaverNoPending())
assert graph.invoke([], {"configurable": {"thread_id": "foo"}}) == ["1"]
assert graph.invoke([], {"configurable": {"thread_id": "foo"}}) == ["1"] * 2
assert (await graph.ainvoke([], {"configurable": {"thread_id": "foo"}})) == [
"1"
] * 3
assert (await graph.ainvoke([], {"configurable": {"thread_id": "foo"}})) == [
"1"
] * 4
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
@pytest.mark.parametrize("store_name", ALL_STORES_ASYNC)
async def test_store_injected_async(checkpointer_name: str, store_name: str) -> None:
class State(TypedDict):
count: Annotated[int, operator.add]
doc_id = str(uuid.uuid4())
doc = {"some-key": "this-is-a-val"}
uid = uuid.uuid4().hex
namespace = (f"foo-{uid}", "bar")
thread_1 = str(uuid.uuid4())
thread_2 = str(uuid.uuid4())
class Node:
def __init__(self, i: Optional[int] = None):
self.i = i
async def __call__(
self, inputs: State, config: RunnableConfig, store: BaseStore
):
assert isinstance(store, BaseStore)
await store.aput(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
doc_id,
{
**doc,
"from_thread": config["configurable"]["thread_id"],
"some_val": inputs["count"],
},
)
return {"count": 1}
builder = StateGraph(State)
builder.add_node("node", Node())
builder.add_edge("__start__", "node")
N = 500
M = 1
if "duckdb" in store_name:
logger.warning(
"DuckDB store implementation has a known issue that does not"
" support concurrent writes, so we're reducing the test scope"
)
N = M = 1
for i in range(N):
builder.add_node(f"node_{i}", Node(i))
builder.add_edge("__start__", f"node_{i}")
async with awith_checkpointer(checkpointer_name) as checkpointer, awith_store(
store_name
) as the_store:
graph = builder.compile(store=the_store, checkpointer=checkpointer)
# Test batch operations with multiple threads
results = await graph.abatch(
[{"count": 0}] * M,
([{"configurable": {"thread_id": str(uuid.uuid4())}}] * (M - 1))
+ [{"configurable": {"thread_id": thread_1}}],
)
result = results[-1]
assert result == {"count": N + 1}
returned_doc = (await the_store.aget(namespace, doc_id)).value
assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0}
assert len((await the_store.asearch(namespace))) == 1
# Check results after another turn of the same thread
result = await graph.ainvoke(
{"count": 0}, {"configurable": {"thread_id": thread_1}}
)
assert result == {"count": (N + 1) * 2}
returned_doc = (await the_store.aget(namespace, doc_id)).value
assert returned_doc == {**doc, "from_thread": thread_1, "some_val": N + 1}
assert len((await the_store.asearch(namespace))) == 1
# Test with a different thread
result = await graph.ainvoke(
{"count": 0}, {"configurable": {"thread_id": thread_2}}
)
assert result == {"count": N + 1}
returned_doc = (await the_store.aget(namespace, doc_id)).value
assert returned_doc == {
**doc,
"from_thread": thread_2,
"some_val": 0,
} # Overwrites the whole doc
assert (
len((await the_store.asearch(namespace))) == 1
) # still overwriting the same one
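# note: aput replaces the whole value stored under (namespace, key), so even with N
# writer nodes only one item remains in this namespace, holding whichever write
# landed last. roughly:
#
#     await store.aput(("ns",), "doc", {"v": 1})
#     await store.aput(("ns",), "doc", {"v": 2})    # overwrites the first value
#     (await store.aget(("ns",), "doc")).value      # -> {"v": 2}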
async def test_debug_retry():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
async def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
builder = StateGraph(State)
builder.add_node("one", node("one"))
builder.add_node("two", node("two"))
builder.add_edge(START, "one")
builder.add_edge("one", "two")
builder.add_edge("two", END)
saver = MemorySaver()
graph = builder.compile(checkpointer=saver)
config = {"configurable": {"thread_id": "1"}}
await graph.ainvoke({"messages": []}, config=config)
# re-run step: 1
async for c in saver.alist(config):
if c.metadata["step"] == 1:
target_config = c.parent_config
break
assert target_config is not None
update_config = await graph.aupdate_state(target_config, values=None)
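# note: passing values=None writes a fresh checkpoint on top of target_config without
# changing any channel values; streaming None from update_config below should then
# re-run the graph from that point (re-running step 1) while the original checkpoint
# history stays intact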
events = [
c async for c in graph.astream(None, config=update_config, stream_mode="debug")
]
checkpoint_events = list(
reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
)
checkpoint_history = {
c.config["configurable"]["checkpoint_id"]: c
async for c in graph.aget_state_history(config)
}
def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]
for stream in checkpoint_events:
stream_conf = lax_normalize_config(stream["config"])
stream_parent_conf = lax_normalize_config(stream["parent_config"])
assert stream_conf != stream_parent_conf
# ensure the streamed checkpoint == checkpoint from checkpointer.list()
history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]]
history_conf = lax_normalize_config(history.config)
assert stream_conf == history_conf
history_parent_conf = lax_normalize_config(history.parent_config)
assert stream_parent_conf == history_parent_conf
async def test_debug_subgraphs():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
async def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge(START, "c_one")
child.add_edge("c_one", "c_two")
child.add_edge("c_two", END)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge(START, "p_one")
parent.add_edge("p_one", "p_two")
parent.add_edge("p_two", END)
graph = parent.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "1"}}
events = [
c
async for c in graph.astream(
{"messages": []},
config=config,
stream_mode="debug",
)
]
checkpoint_events = list(
reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
)
checkpoint_history = [c async for c in graph.aget_state_history(config)]
assert len(checkpoint_events) == len(checkpoint_history)
def normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]
for stream, history in zip(checkpoint_events, checkpoint_history):
assert stream["values"] == history.values
assert stream["next"] == list(history.next)
assert normalize_config(stream["config"]) == normalize_config(history.config)
assert normalize_config(stream["parent_config"]) == normalize_config(
history.parent_config
)
assert len(stream["tasks"]) == len(history.tasks)
for stream_task, history_task in zip(stream["tasks"], history.tasks):
assert stream_task["id"] == history_task.id
assert stream_task["name"] == history_task.name
assert stream_task["interrupts"] == history_task.interrupts
assert stream_task.get("error") == history_task.error
assert stream_task.get("state") == history_task.state
async def test_debug_nested_subgraphs():
from collections import defaultdict
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
async def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
grand_parent = StateGraph(State)
parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge(START, "c_one")
child.add_edge("c_one", "c_two")
child.add_edge("c_two", END)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge(START, "p_one")
parent.add_edge("p_one", "p_two")
parent.add_edge("p_two", END)
grand_parent.add_node("gp_one", node("gp_one"))
grand_parent.add_node("gp_two", parent.compile())
grand_parent.add_edge(START, "gp_one")
grand_parent.add_edge("gp_one", "gp_two")
grand_parent.add_edge("gp_two", END)
graph = grand_parent.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "1"}}
events = [
c
async for c in graph.astream(
{"messages": []},
config=config,
stream_mode="debug",
subgraphs=True,
)
]
stream_ns: dict[tuple, dict] = defaultdict(list)
for ns, e in events:
if e["type"] == "checkpoint":
stream_ns[ns].append(e["payload"])
assert list(stream_ns.keys()) == [
(),
(AnyStr("gp_two:"),),
(AnyStr("gp_two:"), AnyStr("p_two:")),
]
history_ns = {}
for ns in stream_ns.keys():
async def get_history():
history = [
c
async for c in graph.aget_state_history(
{"configurable": {"thread_id": "1", "checkpoint_ns": "|".join(ns)}}
)
]
return history[::-1]
history_ns[ns] = await get_history()
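# note: for nested subgraphs the checkpoint namespace is the "|"-joined chain of
# parent task namespaces (roughly "gp_two:<task-id>|p_two:<task-id>"), which is why
# the lookup above joins ns with "|" before calling aget_state_history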
def normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
clean_config = {}
clean_config["thread_id"] = config["configurable"]["thread_id"]
clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"]
clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"]
if "checkpoint_map" in config["configurable"]:
clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"]
return clean_config
for checkpoint_events, checkpoint_history in zip(
stream_ns.values(), history_ns.values()
):
for stream, history in zip(checkpoint_events, checkpoint_history):
assert stream["values"] == history.values
assert stream["next"] == list(history.next)
assert normalize_config(stream["config"]) == normalize_config(
history.config
)
assert normalize_config(stream["parent_config"]) == normalize_config(
history.parent_config
)
assert len(stream["tasks"]) == len(history.tasks)
for stream_task, history_task in zip(stream["tasks"], history.tasks):
assert stream_task["id"] == history_task.id
assert stream_task["name"] == history_task.name
assert stream_task["interrupts"] == history_task.interrupts
assert stream_task.get("error") == history_task.error
assert stream_task.get("state") == history_task.state
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_parent_command(checkpointer_name: str) -> None:
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
@tool(return_direct=True)
def get_user_name() -> Command:
"""Retrieve user name"""
return Command(update={"user_name": "Meow"}, graph=Command.PARENT)
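# note: returning Command(update=..., graph=Command.PARENT) from a tool routes the
# state update to the parent graph rather than the subgraph running the tool; that is
# how user_name, a key defined only on CustomParentState below, gets populated in the
# parent's state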
subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
subgraph_builder.add_edge(START, "tool")
subgraph = subgraph_builder.compile()
class CustomParentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
# this key is not available to the child graph
user_name: str
builder = StateGraph(CustomParentState)
builder.add_node("alice", subgraph)
builder.add_edge(START, "alice")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke(
{"messages": [("user", "get user name")]}, config
) == {
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
}
assert await graph.aget_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(
content="get user name",
additional_kwargs={},
response_metadata={},
),
],
"user_name": "Meow",
},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"alice": {
"user_name": "Meow",
}
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_subgraph(checkpointer_name: str):
class State(TypedDict):
baz: str
def foo(state):
return {"baz": "foo"}
def bar(state):
value = interrupt("Please provide baz value:")
return {"baz": value}
child_builder = StateGraph(State)
child_builder.add_node(bar)
child_builder.add_edge(START, "bar")
builder = StateGraph(State)
builder.add_node(foo)
builder.add_node("bar", child_builder.compile())
builder.add_edge(START, "foo")
builder.add_edge("foo", "bar")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
# First run, interrupted at bar
assert await graph.ainvoke({"baz": ""}, thread1)
# Resume with answer
assert await graph.ainvoke(Command(resume="bar"), thread1)
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_multiple(checkpointer_name: str):
class State(TypedDict):
my_key: Annotated[str, operator.add]
async def node(s: State) -> State:
answer = interrupt({"value": 1})
answer2 = interrupt({"value": 2})
return {"my_key": answer + " " + answer2}
builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert [
e async for e in graph.astream({"my_key": "DE", "market": "DE"}, thread1)
] == [
{
"__interrupt__": (
Interrupt(
value={"value": 1},
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
async for event in graph.astream(
Command(resume="answer 1", update={"my_key": "foofoo"}),
thread1,
stream_mode="updates",
)
] == [
{
"__interrupt__": (
Interrupt(
value={"value": 2},
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
async for event in graph.astream(
Command(resume="answer 2"), thread1, stream_mode="updates"
)
] == [
{"node": {"my_key": "answer 1 answer 2"}},
]
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_loop(checkpointer_name: str):
class State(TypedDict):
age: int
other: str
async def ask_age(s: State):
"""Ask an expert for help."""
question = "How old are you?"
value = None
for _ in range(10):
value: str = interrupt(question)
if not value.isdigit() or int(value) < 18:
question = "invalid response"
value = None
else:
break
return {"age": int(value)}
builder = StateGraph(State)
builder.add_node("node", ask_age)
builder.add_edge(START, "node")
async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert [e async for e in graph.astream({"other": ""}, thread1)] == [
{
"__interrupt__": (
Interrupt(
value="How old are you?",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
async for event in graph.astream(
Command(resume="13"),
thread1,
)
] == [
{
"__interrupt__": (
Interrupt(
value="invalid response",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
async for event in graph.astream(
Command(resume="15"),
thread1,
)
] == [
{
"__interrupt__": (
Interrupt(
value="invalid response",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event async for event in graph.astream(Command(resume="19"), thread1)
] == [
{"node": {"age": 19}},
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_prebuilt.py`:
```py
import dataclasses
import json
from functools import partial
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, ToolException
from langchain_core.tools import tool as dec_tool
from pydantic import BaseModel, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import NodeInterrupt
from langgraph.graph import START, MessagesState, StateGraph, add_messages
from langgraph.prebuilt import (
ToolNode,
ValidationNode,
create_react_agent,
tools_condition,
)
from langgraph.prebuilt.chat_agent_executor import _validate_chat_history
from langgraph.prebuilt.tool_node import (
TOOL_CALL_ERROR_TEMPLATE,
InjectedState,
InjectedStore,
_get_state_args,
_infer_handled_types,
)
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Interrupt
from tests.conftest import (
ALL_CHECKPOINTERS_ASYNC,
ALL_CHECKPOINTERS_SYNC,
IS_LANGCHAIN_CORE_030_OR_GREATER,
awith_checkpointer,
)
from tests.messages import _AnyIdHumanMessage, _AnyIdToolMessage
pytestmark = pytest.mark.anyio
class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([m.content for m in messages])
tool_calls = (
self.tool_calls[self.index % len(self.tool_calls)]
if self.tool_calls
else []
)
message = AIMessage(
content=messages_string, id=str(self.index), tool_calls=tool_calls.copy()
)
self.index += 1
return ChatResult(generations=[ChatGeneration(message=message)])
@property
def _llm_type(self) -> str:
return "fake-tool-call-model"
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
if len(tools) == 0:
raise ValueError("Must provide at least one tool")
tool_dicts = []
for tool in tools:
if not isinstance(tool, BaseTool):
raise TypeError(
"Only BaseTool is supported by FakeToolCallingModel.bind_tools"
)
# NOTE: this is a simplified tool spec for testing purposes only
if self.tool_style == "openai":
tool_dicts.append(
{
"type": "function",
"function": {
"name": tool.name,
},
}
)
elif self.tool_style == "anthropic":
tool_dicts.append(
{
"name": tool.name,
}
)
return self.bind(tools=tool_dicts)
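# note: this fake model cycles through the provided tool_calls lists by call index and
# echoes the dash-joined contents of the prompt messages. typical usage in these tests,
# roughly:
#
#     model = FakeToolCallingModel(tool_calls=[[ToolCall(name="t", args={}, id="1")], []])
#     # first call emits an AIMessage carrying the tool call, second call emits none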
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_no_modifier(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
"checkpointer_" + checkpointer_name
)
model = FakeToolCallingModel()
agent = create_react_agent(model, [], checkpointer=checkpointer)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = agent.invoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]}
assert response == expected_response
if checkpointer:
saved = checkpointer.get_tuple(thread)
assert saved is not None
assert saved.checkpoint["channel_values"] == {
"messages": [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
],
"agent": "agent",
}
assert saved.metadata == {
"parents": {},
"source": "loop",
"writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
"step": 1,
"thread_id": "123",
}
assert saved.pending_writes == []
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_no_modifier_async(checkpointer_name: str) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
model = FakeToolCallingModel()
agent = create_react_agent(model, [], checkpointer=checkpointer)
inputs = [HumanMessage("hi?")]
thread = {"configurable": {"thread_id": "123"}}
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]}
assert response == expected_response
if checkpointer:
saved = await checkpointer.aget_tuple(thread)
assert saved is not None
assert saved.checkpoint["channel_values"] == {
"messages": [
_AnyIdHumanMessage(content="hi?"),
AIMessage(content="hi?", id="0"),
],
"agent": "agent",
}
assert saved.metadata == {
"parents": {},
"source": "loop",
"writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
"step": 1,
"thread_id": "123",
}
assert saved.pending_writes == []
def test_passing_two_modifiers():
model = FakeToolCallingModel()
with pytest.raises(ValueError):
create_react_agent(model, [], messages_modifier="Foo", state_modifier="Bar")
def test_system_message_modifier():
messages_modifier = SystemMessage(content="Foo")
agent_1 = create_react_agent(
FakeToolCallingModel(), [], messages_modifier=messages_modifier
)
agent_2 = create_react_agent(
FakeToolCallingModel(), [], state_modifier=messages_modifier
)
for agent in [agent_1, agent_2]:
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {
"messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
}
assert response == expected_response
def test_system_message_string_modifier():
messages_modifier = "Foo"
agent_1 = create_react_agent(
FakeToolCallingModel(), [], messages_modifier=messages_modifier
)
agent_2 = create_react_agent(
FakeToolCallingModel(), [], state_modifier=messages_modifier
)
for agent in [agent_1, agent_2]:
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {
"messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])]
}
assert response == expected_response
def test_callable_messages_modifier():
model = FakeToolCallingModel()
def messages_modifier(messages):
modified_message = f"Bar {messages[-1].content}"
return [HumanMessage(content=modified_message)]
agent = create_react_agent(model, [], messages_modifier=messages_modifier)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": inputs + [AIMessage(content="Bar hi?", id="0")]}
assert response == expected_response
def test_callable_state_modifier():
model = FakeToolCallingModel()
def state_modifier(state):
modified_message = f"Bar {state['messages'][-1].content}"
return [HumanMessage(content=modified_message)]
agent = create_react_agent(model, [], state_modifier=state_modifier)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": inputs + [AIMessage(content="Bar hi?", id="0")]}
assert response == expected_response
def test_runnable_messages_modifier():
model = FakeToolCallingModel()
messages_modifier = RunnableLambda(
lambda messages: [HumanMessage(content=f"Baz {messages[-1].content}")]
)
agent = create_react_agent(model, [], messages_modifier=messages_modifier)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": inputs + [AIMessage(content="Baz hi?", id="0")]}
assert response == expected_response
def test_runnable_state_modifier():
model = FakeToolCallingModel()
state_modifier = RunnableLambda(
lambda state: [HumanMessage(content=f"Baz {state['messages'][-1].content}")]
)
agent = create_react_agent(model, [], state_modifier=state_modifier)
inputs = [HumanMessage("hi?")]
response = agent.invoke({"messages": inputs})
expected_response = {"messages": inputs + [AIMessage(content="Baz hi?", id="0")]}
assert response == expected_response
def test_state_modifier_with_store():
def add(a: int, b: int):
"""Adds a and b"""
return a + b
in_memory_store = InMemoryStore()
in_memory_store.put(("memories", "1"), "user_name", {"data": "User name is Alice"})
in_memory_store.put(("memories", "2"), "user_name", {"data": "User name is Bob"})
def modify(state, config, *, store):
user_id = config["configurable"]["user_id"]
system_str = store.get(("memories", user_id), "user_name").value["data"]
return [SystemMessage(system_str)] + state["messages"]
def modify_no_store(state, config):
return SystemMessage("foo") + state["messages"]
model = FakeToolCallingModel()
# test state modifier that uses store works
agent = create_react_agent(
model, [add], state_modifier=modify, store=in_memory_store
)
response = agent.invoke(
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}}
)
assert response["messages"][-1].content == "User name is Alice-hi"
# test state modifier that doesn't use store works
agent = create_react_agent(
model, [add], state_modifier=modify_no_store, store=in_memory_store
)
response = agent.invoke(
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}}
)
assert response["messages"][-1].content == "foo-hi"
@pytest.mark.parametrize("tool_style", ["openai", "anthropic"])
def test_model_with_tools(tool_style: str):
model = FakeToolCallingModel(tool_style=tool_style)
@dec_tool
def tool1(some_val: int) -> str:
"""Tool 1 docstring."""
return f"Tool 1: {some_val}"
@dec_tool
def tool2(some_val: int) -> str:
"""Tool 2 docstring."""
return f"Tool 2: {some_val}"
# check valid agent constructor
agent = create_react_agent(model.bind_tools([tool1, tool2]), [tool1, tool2])
result = agent.nodes["tools"].invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 2},
"id": "some 1",
},
{
"name": "tool2",
"args": {"some_val": 2},
"id": "some 2",
},
],
)
]
}
)
tool_messages: ToolMessage = result["messages"][-2:]
for tool_message in tool_messages:
assert tool_message.type == "tool"
assert tool_message.content in {"Tool 1: 2", "Tool 2: 2"}
assert tool_message.tool_call_id in {"some 1", "some 2"}
# test mismatching tool lengths
with pytest.raises(ValueError):
create_react_agent(model.bind_tools([tool1]), [tool1, tool2])
# test missing bound tools
with pytest.raises(ValueError):
create_react_agent(model.bind_tools([tool1]), [tool2])
def test__validate_messages():
# empty input
_validate_chat_history([])
# single human message
_validate_chat_history(
[
HumanMessage(content="What's the weather?"),
]
)
# human + AI
_validate_chat_history(
[
HumanMessage(content="What's the weather?"),
AIMessage(content="The weather is sunny and 75°F."),
]
)
# Answered tool calls
_validate_chat_history(
[
HumanMessage(content="What's the weather?"),
AIMessage(
content="Let me check that for you.",
tool_calls=[{"id": "call1", "name": "get_weather", "args": {}}],
),
ToolMessage(content="Sunny, 75°F", tool_call_id="call1"),
AIMessage(content="The weather is sunny and 75°F."),
]
)
# Unanswered tool calls
with pytest.raises(ValueError):
_validate_chat_history(
[
AIMessage(
content="I'll check that for you.",
tool_calls=[
{"id": "call1", "name": "get_weather", "args": {}},
{"id": "call2", "name": "get_time", "args": {}},
],
)
]
)
with pytest.raises(ValueError):
_validate_chat_history(
[
HumanMessage(content="What's the weather and time?"),
AIMessage(
content="I'll check that for you.",
tool_calls=[
{"id": "call1", "name": "get_weather", "args": {}},
{"id": "call2", "name": "get_time", "args": {}},
],
),
ToolMessage(content="Sunny, 75°F", tool_call_id="call1"),
AIMessage(
content="The weather is sunny and 75°F. Let me check the time."
),
]
)
def test__infer_handled_types() -> None:
def handle(e): # type: ignore
return ""
def handle2(e: Exception) -> str:
return ""
def handle3(e: Union[ValueError, ToolException]) -> str:
return ""
class Handler:
def handle(self, e: ValueError) -> str:
return ""
handle4 = Handler().handle
def handle5(e: Union[Union[TypeError, ValueError], ToolException]):
return ""
expected: tuple = (Exception,)
actual = _infer_handled_types(handle)
assert expected == actual
expected = (Exception,)
actual = _infer_handled_types(handle2)
assert expected == actual
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle3)
assert expected == actual
expected = (ValueError,)
actual = _infer_handled_types(handle4)
assert expected == actual
expected = (TypeError, ValueError, ToolException)
actual = _infer_handled_types(handle5)
assert expected == actual
with pytest.raises(ValueError):
def handler(e: str):
return ""
_infer_handled_types(handler)
with pytest.raises(ValueError):
def handler(e: list[Exception]):
return ""
_infer_handled_types(handler)
with pytest.raises(ValueError):
def handler(e: Union[str, int]):
return ""
_infer_handled_types(handler)
# tools for testing ToolNode
def tool1(some_val: int, some_other_val: str) -> str:
"""Tool 1 docstring."""
if some_val == 0:
raise ValueError("Test error")
return f"{some_val} - {some_other_val}"
async def tool2(some_val: int, some_other_val: str) -> str:
"""Tool 2 docstring."""
if some_val == 0:
raise ToolException("Test error")
return f"tool2: {some_val} - {some_other_val}"
async def tool3(some_val: int, some_other_val: str) -> str:
"""Tool 3 docstring."""
return [
{"key_1": some_val, "key_2": "foo"},
{"key_1": some_other_val, "key_2": "baz"},
]
async def tool4(some_val: int, some_other_val: str) -> str:
"""Tool 4 docstring."""
return [
{"type": "image_url", "image_url": {"url": "abdc"}},
]
@dec_tool
def tool5(some_val: int):
"""Tool 5 docstring."""
raise ToolException("Test error")
tool5.handle_tool_error = "foo"
async def test_tool_node():
result = ToolNode([tool1]).invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 1, "some_other_val": "foo"},
"id": "some 0",
}
],
)
]
}
)
tool_message: ToolMessage = result["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.content == "1 - foo"
assert tool_message.tool_call_id == "some 0"
result2 = await ToolNode([tool2]).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool2",
"args": {"some_val": 2, "some_other_val": "bar"},
"id": "some 1",
}
],
)
]
}
)
tool_message: ToolMessage = result2["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.content == "tool2: 2 - bar"
# list of dicts tool content
result3 = await ToolNode([tool3]).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool3",
"args": {"some_val": 2, "some_other_val": "bar"},
"id": "some 2",
}
],
)
]
}
)
tool_message: ToolMessage = result3["messages"][-1]
assert tool_message.type == "tool"
assert (
tool_message.content
== '[{"key_1": 2, "key_2": "foo"}, {"key_1": "bar", "key_2": "baz"}]'
)
assert tool_message.tool_call_id == "some 2"
# list of content blocks tool content
result4 = await ToolNode([tool4]).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool4",
"args": {"some_val": 2, "some_other_val": "bar"},
"id": "some 3",
}
],
)
]
}
)
tool_message: ToolMessage = result4["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}]
assert tool_message.tool_call_id == "some 3"
async def test_tool_node_error_handling():
def handle_all(e: Union[ValueError, ToolException, ValidationError]):
return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
# test catching all exceptions, via:
# - handle_tool_errors = True
# - passing a tuple of all exceptions
# - passing a callable with all exceptions in the signature
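# note: handle_tool_errors also accepts a string (used verbatim as the error
# ToolMessage content, see test_tool_node_error_handling_callable below), and a
# callable's type hints determine which exceptions it handles (_infer_handled_types);
# NodeInterrupt is re-raised regardless (see test_tool_node_node_interrupt below)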
for handle_tool_errors in (
True,
(ValueError, ToolException, ValidationError),
handle_all,
):
result_error = await ToolNode(
[tool1, tool2, tool3], handle_tool_errors=handle_tool_errors
).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0, "some_other_val": "foo"},
"id": "some id",
},
{
"name": "tool2",
"args": {"some_val": 0, "some_other_val": "bar"},
"id": "some other id",
},
{
"name": "tool3",
"args": {"some_val": 0},
"id": "another id",
},
],
)
]
}
)
assert all(m.type == "tool" for m in result_error["messages"])
assert all(m.status == "error" for m in result_error["messages"])
assert (
result_error["messages"][0].content
== f"Error: {repr(ValueError('Test error'))}\n Please fix your mistakes."
)
assert (
result_error["messages"][1].content
== f"Error: {repr(ToolException('Test error'))}\n Please fix your mistakes."
)
assert (
"ValidationError" in result_error["messages"][2].content
or "validation error" in result_error["messages"][2].content
)
assert result_error["messages"][0].tool_call_id == "some id"
assert result_error["messages"][1].tool_call_id == "some other id"
assert result_error["messages"][2].tool_call_id == "another id"
async def test_tool_node_error_handling_callable():
def handle_value_error(e: ValueError):
return "Value error"
def handle_tool_exception(e: ToolException):
return "Tool exception"
for handle_tool_errors in ("Value error", handle_value_error):
result_error = await ToolNode(
[tool1], handle_tool_errors=handle_tool_errors
).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0, "some_other_val": "foo"},
"id": "some id",
},
],
)
]
}
)
tool_message: ToolMessage = result_error["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.status == "error"
assert tool_message.content == "Value error"
# test raising for an unhandled exception, via:
# - passing a tuple of all exceptions
# - passing a callable with all exceptions in the signature
for handle_tool_errors in ((ValueError,), handle_value_error):
with pytest.raises(ToolException) as exc_info:
await ToolNode(
[tool1, tool2], handle_tool_errors=handle_tool_errors
).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0, "some_other_val": "foo"},
"id": "some id",
},
{
"name": "tool2",
"args": {"some_val": 0, "some_other_val": "bar"},
"id": "some other id",
},
],
)
]
}
)
assert str(exc_info.value) == "Test error"
for handle_tool_errors in ((ToolException,), handle_tool_exception):
with pytest.raises(ValueError) as exc_info:
await ToolNode(
[tool1, tool2], handle_tool_errors=handle_tool_errors
).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0, "some_other_val": "foo"},
"id": "some id",
},
{
"name": "tool2",
"args": {"some_val": 0, "some_other_val": "bar"},
"id": "some other id",
},
],
)
]
}
)
assert str(exc_info.value) == "Test error"
async def test_tool_node_handle_tool_errors_false():
with pytest.raises(ValueError) as exc_info:
ToolNode([tool1], handle_tool_errors=False).invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0, "some_other_val": "foo"},
"id": "some id",
}
],
)
]
}
)
assert str(exc_info.value) == "Test error"
with pytest.raises(ToolException):
await ToolNode([tool2], handle_tool_errors=False).ainvoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool2",
"args": {"some_val": 0, "some_other_val": "bar"},
"id": "some id",
}
],
)
]
}
)
assert str(exc_info.value) == "Test error"
# test validation errors get raised if handle_tool_errors is False
with pytest.raises((ValidationError, ValidationErrorV1)):
ToolNode([tool1], handle_tool_errors=False).invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool1",
"args": {"some_val": 0},
"id": "some id",
}
],
)
]
}
)
def test_tool_node_individual_tool_error_handling():
# test error handling on individual tools (and that it overrides overall error handling!)
result_individual_tool_error_handler = ToolNode(
[tool5], handle_tool_errors="bar"
).invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool5",
"args": {"some_val": 0},
"id": "some 0",
}
],
)
]
}
)
tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.status == "error"
assert tool_message.content == "foo"
assert tool_message.tool_call_id == "some 0"
def test_tool_node_incorrect_tool_name():
result_incorrect_name = ToolNode([tool1, tool2]).invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool3",
"args": {"some_val": 1, "some_other_val": "foo"},
"id": "some 0",
}
],
)
]
}
)
tool_message: ToolMessage = result_incorrect_name["messages"][-1]
assert tool_message.type == "tool"
assert tool_message.status == "error"
assert (
tool_message.content
== "Error: tool3 is not a valid tool, try one of [tool1, tool2]."
)
assert tool_message.tool_call_id == "some 0"
def test_tool_node_node_interrupt():
def tool_normal(some_val: int) -> str:
"""Tool docstring."""
return "normal"
def tool_interrupt(some_val: int) -> str:
"""Tool docstring."""
raise NodeInterrupt("foo")
def handle(e: NodeInterrupt):
return "handled"
for handle_tool_errors in (True, (NodeInterrupt,), "handled", handle, False):
node = ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors)
with pytest.raises(NodeInterrupt) as exc_info:
node.invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool_interrupt",
"args": {"some_val": 0},
"id": "some 0",
}
],
)
]
}
)
assert exc_info.value == "foo"
# test inside react agent
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="tool_interrupt", args={"some_val": 0}, id="1"),
ToolCall(name="tool_normal", args={"some_val": 1}, id="2"),
],
[],
]
)
checkpointer = MemorySaver()
config = {"configurable": {"thread_id": "1"}}
agent = create_react_agent(
model, [tool_interrupt, tool_normal], checkpointer=checkpointer
)
result = agent.invoke({"messages": [HumanMessage("hi?")]}, config)
assert result["messages"] == [
_AnyIdHumanMessage(
content="hi?",
),
AIMessage(
content="hi?",
id="0",
tool_calls=[
{
"name": "tool_interrupt",
"args": {"some_val": 0},
"id": "1",
"type": "tool_call",
},
{
"name": "tool_normal",
"args": {"some_val": 1},
"id": "2",
"type": "tool_call",
},
],
),
]
state = agent.get_state(config)
assert state.next == ("tools",)
task = state.tasks[0]
assert task.name == "tools"
assert task.interrupts == (Interrupt(value="foo", when="during"),)
def my_function(some_val: int, some_other_val: str) -> str:
return f"{some_val} - {some_other_val}"
class MyModel(BaseModel):
some_val: int
some_other_val: str
class MyModelV1(BaseModelV1):
some_val: int
some_other_val: str
@dec_tool
def my_tool(some_val: int, some_other_val: str) -> str:
"""Cool."""
return f"{some_val} - {some_other_val}"
@pytest.mark.parametrize(
"tool_schema",
[
my_function,
MyModel,
MyModelV1,
my_tool,
],
)
@pytest.mark.parametrize("use_message_key", [True, False])
async def test_validation_node(tool_schema: Any, use_message_key: bool):
validation_node = ValidationNode([tool_schema])
tool_name = getattr(tool_schema, "name", getattr(tool_schema, "__name__", None))
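# note: ValidationNode does not execute the tools; it validates each tool call's args
# against the schema and emits one ToolMessage per call, flagging failures via
# additional_kwargs["is_error"], which check_results below asserts on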
inputs = [
AIMessage(
"hi?",
tool_calls=[
{
"name": tool_name,
"args": {"some_val": 1, "some_other_val": "foo"},
"id": "some 0",
},
{
"name": tool_name,
# Wrong type for some_val
"args": {"some_val": "bar", "some_other_val": "foo"},
"id": "some 1",
},
],
),
]
if use_message_key:
inputs = {"messages": inputs}
result = await validation_node.ainvoke(inputs)
if use_message_key:
result = result["messages"]
def check_results(messages: list):
assert len(messages) == 2
assert all(m.type == "tool" for m in messages)
assert not messages[0].additional_kwargs.get("is_error")
assert messages[1].additional_kwargs.get("is_error")
check_results(result)
result_sync = validation_node.invoke(inputs)
if use_message_key:
result_sync = result_sync["messages"]
check_results(result_sync)
class _InjectStateSchema(TypedDict):
messages: list
foo: str
class _InjectedStatePydanticSchema(BaseModelV1):
messages: list
foo: str
class _InjectedStatePydanticV2Schema(BaseModel):
messages: list
foo: str
@dataclasses.dataclass
class _InjectedStateDataclassSchema:
messages: list
foo: str
T = TypeVar("T")
@pytest.mark.parametrize(
"schema_",
[
_InjectStateSchema,
_InjectedStatePydanticSchema,
_InjectedStatePydanticV2Schema,
_InjectedStateDataclassSchema,
],
)
def test_tool_node_inject_state(schema_: Type[T]) -> None:
def tool1(some_val: int, state: Annotated[T, InjectedState]) -> str:
"""Tool 1 docstring."""
if isinstance(state, dict):
return state["foo"]
else:
return getattr(state, "foo")
def tool2(some_val: int, state: Annotated[T, InjectedState()]) -> str:
"""Tool 2 docstring."""
if isinstance(state, dict):
return state["foo"]
else:
return getattr(state, "foo")
def tool3(
some_val: int,
foo: Annotated[str, InjectedState("foo")],
msgs: Annotated[List[AnyMessage], InjectedState("messages")],
) -> str:
"""Tool 1 docstring."""
return foo
def tool4(
some_val: int, msgs: Annotated[List[AnyMessage], InjectedState("messages")]
) -> str:
"""Tool 1 docstring."""
return msgs[0].content
node = ToolNode([tool1, tool2, tool3, tool4])
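# note: parameters annotated with InjectedState are hidden from the model-facing tool
# schema and filled in by ToolNode at call time, either as the whole graph state
# (tool1/tool2) or as a single key such as "foo" or "messages" (tool3/tool4), so the
# tool calls below only need to supply some_val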
for tool_name in ("tool1", "tool2", "tool3"):
tool_call = {
"name": tool_name,
"args": {"some_val": 1},
"id": "some 0",
"type": "tool_call",
}
msg = AIMessage("hi?", tool_calls=[tool_call])
result = node.invoke(schema_(**{"messages": [msg], "foo": "bar"}))
tool_message = result["messages"][-1]
assert tool_message.content == "bar", f"Failed for tool={tool_name}"
if tool_name == "tool3":
failure_input = None
try:
failure_input = schema_(**{"messages": [msg], "notfoo": "bar"})
except Exception:
pass
if failure_input is not None:
with pytest.raises(KeyError):
node.invoke(failure_input)
with pytest.raises(ValueError):
node.invoke([msg])
else:
failure_input = None
try:
failure_input = schema_(**{"messages": [msg], "notfoo": "bar"})
except Exception:
# We'd get a validation error from pydantic state and wouldn't make it to the node
# anyway
pass
if failure_input is not None:
messages_ = node.invoke(failure_input)
tool_message = messages_["messages"][-1]
assert "KeyError" in tool_message.content
tool_message = node.invoke([msg])[-1]
assert "KeyError" in tool_message.content
tool_call = {
"name": "tool4",
"args": {"some_val": 1},
"id": "some 0",
"type": "tool_call",
}
msg = AIMessage("hi?", tool_calls=[tool_call])
result = node.invoke(schema_(**{"messages": [msg], "foo": ""}))
tool_message = result["messages"][-1]
assert tool_message.content == "hi?"
result = node.invoke([msg])
tool_message = result[-1]
assert tool_message.content == "hi?"
@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
reason="Langchain core 0.3.0 or greater is required",
)
def test_tool_node_inject_store() -> None:
store = InMemoryStore()
namespace = ("test",)
def tool1(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str:
"""Tool 1 docstring."""
store_val = store.get(namespace, "test_key").value["foo"]
return f"Some val: {some_val}, store val: {store_val}"
def tool2(some_val: int, store: Annotated[BaseStore, InjectedStore()]) -> str:
"""Tool 2 docstring."""
store_val = store.get(namespace, "test_key").value["foo"]
return f"Some val: {some_val}, store val: {store_val}"
def tool3(
some_val: int,
bar: Annotated[str, InjectedState("bar")],
store: Annotated[BaseStore, InjectedStore()],
) -> str:
"""Tool 3 docstring."""
store_val = store.get(namespace, "test_key").value["foo"]
return f"Some val: {some_val}, store val: {store_val}, state val: {bar}"
node = ToolNode([tool1, tool2, tool3], handle_tool_errors=True)
store.put(namespace, "test_key", {"foo": "bar"})
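# note: likewise, a parameter annotated with InjectedStore() is populated with the
# graph's BaseStore instead of coming from the model, which is why invoking the graph
# compiled without store=... at the end of this test raises a ValueError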
class State(MessagesState):
bar: str
builder = StateGraph(State)
builder.add_node("tools", node)
builder.add_edge(START, "tools")
graph = builder.compile(store=store)
for tool_name in ("tool1", "tool2"):
tool_call = {
"name": tool_name,
"args": {"some_val": 1},
"id": "some 0",
"type": "tool_call",
}
msg = AIMessage("hi?", tool_calls=[tool_call])
node_result = node.invoke({"messages": [msg]}, store=store)
graph_result = graph.invoke({"messages": [msg]})
for result in (node_result, graph_result):
result["messages"][-1]
tool_message = result["messages"][-1]
assert (
tool_message.content == "Some val: 1, store val: bar"
), f"Failed for tool={tool_name}"
tool_call = {
"name": "tool3",
"args": {"some_val": 1},
"id": "some 0",
"type": "tool_call",
}
msg = AIMessage("hi?", tool_calls=[tool_call])
node_result = node.invoke({"messages": [msg], "bar": "baz"}, store=store)
graph_result = graph.invoke({"messages": [msg], "bar": "baz"})
for result in (node_result, graph_result):
result["messages"][-1]
tool_message = result["messages"][-1]
assert (
tool_message.content == "Some val: 1, store val: bar, state val: baz"
), f"Failed for tool={tool_name}"
# test injected store without passing store to compiled graph
failing_graph = builder.compile()
with pytest.raises(ValueError):
failing_graph.invoke({"messages": [msg], "bar": "baz"})
def test_tool_node_ensure_utf8() -> None:
@dec_tool
def get_day_list(days: list[str]) -> list[str]:
"""choose days"""
return days
data = ["星期一", "水曜日", "목요일", "Friday"]
tools = [get_day_list]
tool_calls = [ToolCall(name=get_day_list.name, args={"days": data}, id="test_id")]
outputs: list[ToolMessage] = ToolNode(tools).invoke(
[AIMessage(content="", tool_calls=tool_calls)]
)
assert outputs[0].content == json.dumps(data, ensure_ascii=False)
def test_tool_node_messages_key() -> None:
@dec_tool
def add(a: int, b: int):
"""Adds a and b."""
return a + b
model = FakeToolCallingModel(
tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]]
)
class State(TypedDict):
subgraph_messages: Annotated[list[AnyMessage], add_messages]
def call_model(state: State):
response = model.invoke(state["subgraph_messages"])
model.tool_calls = []
return {"subgraph_messages": response}
builder = StateGraph(State)
builder.add_node("agent", call_model)
builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages"))
builder.add_conditional_edges(
"agent", partial(tools_condition, messages_key="subgraph_messages")
)
builder.add_edge(START, "agent")
builder.add_edge("tools", "agent")
graph = builder.compile()
result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]})
assert result["subgraph_messages"] == [
_AnyIdHumanMessage(content="hi"),
AIMessage(
content="hi",
id="0",
tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")],
),
_AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"),
AIMessage(content="hi-hi-3", id="1"),
]
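# return_direct tools end the agent loop immediately: the run stops right after the
# ToolMessage instead of handing the result back to the model. The three cases below
# cover a return_direct tool, a normal tool, and both called in the same step.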
async def test_return_direct() -> None:
@dec_tool(return_direct=True)
def tool_return_direct(input: str) -> str:
"""A tool that returns directly."""
return f"Direct result: {input}"
@dec_tool
def tool_normal(input: str) -> str:
"""A normal tool."""
return f"Normal result: {input}"
first_tool_call = [
ToolCall(
name="tool_return_direct",
args={"input": "Test direct"},
id="1",
),
]
expected_ai = AIMessage(
content="Test direct",
id="0",
tool_calls=first_tool_call,
)
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
# Test direct return for tool_return_direct
result = agent.invoke(
{"messages": [HumanMessage(content="Test direct", id="hum0")]}
)
assert result["messages"] == [
HumanMessage(content="Test direct", id="hum0"),
expected_ai,
ToolMessage(
content="Direct result: Test direct",
name="tool_return_direct",
tool_call_id="1",
id=result["messages"][2].id,
),
]
second_tool_call = [
ToolCall(
name="tool_normal",
args={"input": "Test normal"},
id="2",
),
]
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke(
{"messages": [HumanMessage(content="Test normal", id="hum1")]}
)
assert result["messages"] == [
HumanMessage(content="Test normal", id="hum1"),
AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
ToolMessage(
content="Normal result: Test normal",
name="tool_normal",
tool_call_id="2",
id=result["messages"][2].id,
),
AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
]
both_tool_calls = [
ToolCall(
name="tool_return_direct",
args={"input": "Test both direct"},
id="3",
),
ToolCall(
name="tool_normal",
args={"input": "Test both normal"},
id="4",
),
]
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
assert result["messages"] == [
HumanMessage(content="Test both", id="hum2"),
AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
ToolMessage(
content="Direct result: Test both direct",
name="tool_return_direct",
tool_call_id="3",
id=result["messages"][2].id,
),
ToolMessage(
content="Normal result: Test both normal",
name="tool_normal",
tool_call_id="4",
id=result["messages"][3].id,
),
]
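# _get_state_args maps a tool's InjectedState annotations to state keys: a bare
# InjectedState maps to the whole state (None), while InjectedState("bar") maps to
# the "bar" key.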
def test__get_state_args() -> None:
class Schema1(BaseModel):
a: Annotated[str, InjectedState]
class Schema2(Schema1):
b: Annotated[int, InjectedState("bar")]
@dec_tool(args_schema=Schema2)
def foo(a: str, b: int) -> float:
"""return"""
return 0.0
assert _get_state_args(foo) == {"a": None, "b": "bar"}
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_io.py`:
```py
from typing import Iterator
from langgraph.pregel.io import single
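# single() consumes only the first item of an iterator and then closes it, so any
# finally block in the generator runs, which is what the flag below verifies.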
def test_single() -> None:
closed = False
def myiter() -> Iterator[int]:
try:
yield 1
yield 2
finally:
nonlocal closed
closed = True
assert single(myiter()) == 1
assert closed
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_state.py`:
```py
import inspect
import warnings
from dataclasses import dataclass, field
from typing import Annotated as Annotated2
from typing import Any, Optional
import pytest
from langchain_core.runnables import RunnableConfig, RunnableLambda
from pydantic.v1 import BaseModel
from typing_extensions import Annotated, NotRequired, Required, TypedDict
from langgraph.graph.state import StateGraph, _get_node_name, _warn_invalid_state_schema
from langgraph.managed.shared_value import SharedValue
class State(BaseModel):
foo: str
bar: int
class State2(TypedDict):
foo: str
bar: int
@pytest.mark.parametrize(
"schema",
[
{"foo": "bar"},
["hi", lambda x, y: x + y],
State(foo="bar", bar=1),
State2(foo="bar", bar=1),
],
)
def test_warns_invalid_schema(schema: Any):
with pytest.warns(UserWarning):
_warn_invalid_state_schema(schema)
@pytest.mark.parametrize(
"schema",
[
Annotated[dict, lambda x, y: y],
Annotated2[list, lambda x, y: y],
dict,
State,
State2,
],
)
def test_doesnt_warn_valid_schema(schema: Any):
# Assert the function does not raise a warning
with warnings.catch_warnings():
warnings.simplefilter("error")
_warn_invalid_state_schema(schema)
def test_state_schema_with_type_hint():
class InputState(TypedDict):
question: str
class OutputState(TypedDict):
input_state: InputState
class FooState(InputState):
foo: str
def complete_hint(state: InputState) -> OutputState:
return {"input_state": state}
def miss_first_hint(state, config: RunnableConfig) -> OutputState:
return {"input_state": state}
def only_return_hint(state, config) -> OutputState:
return {"input_state": state}
def miss_all_hint(state, config):
return {"input_state": state}
def pre_foo(_) -> FooState:
return {"foo": "bar"}
class Foo:
def __call__(self, state: FooState) -> OutputState:
assert state.pop("foo") == "bar"
return {"input_state": state}
graph = StateGraph(InputState, output=OutputState)
actions = [
complete_hint,
miss_first_hint,
only_return_hint,
miss_all_hint,
pre_foo,
Foo(),
]
for action in actions:
graph.add_node(action)
def get_name(action) -> str:
return getattr(action, "__name__", action.__class__.__name__)
graph.set_entry_point(get_name(actions[0]))
for i in range(len(actions) - 1):
graph.add_edge(get_name(actions[i]), get_name(actions[i + 1]))
graph.set_finish_point(get_name(actions[-1]))
graph = graph.compile()
input_state = InputState(question="Hello World!")
output_state = OutputState(input_state=input_state)
foo_state = FooState(foo="bar")
for i, c in enumerate(graph.stream(input_state, stream_mode="updates")):
node_name = get_name(actions[i])
if node_name == get_name(pre_foo):
assert c[node_name] == foo_state
else:
assert c[node_name] == output_state
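# Totality and Required/NotRequired annotations on the state TypedDicts determine
# which keys land in the "required" list of the compiled graph's input/output JSON
# schema, as the expected sets below spell out.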
@pytest.mark.parametrize("total_", [True, False])
def test_state_schema_optional_values(total_: bool):
class SomeParentState(TypedDict):
val0a: str
val0b: Optional[str]
class InputState(SomeParentState, total=total_): # type: ignore
val1: str
val2: Optional[str]
val3: Required[str]
val4: NotRequired[dict]
val5: Annotated[Required[str], "foo"]
val6: Annotated[NotRequired[str], "bar"]
class OutputState(SomeParentState, total=total_): # type: ignore
out_val1: str
out_val2: Optional[str]
out_val3: Required[str]
out_val4: NotRequired[dict]
out_val5: Annotated[Required[str], "foo"]
out_val6: Annotated[NotRequired[str], "bar"]
class State(InputState): # this would be ignored
val4: dict
some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field(
default="foo"
)
builder = StateGraph(State, input=InputState, output=OutputState)
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
json_schema = graph.get_input_jsonschema()
if total_ is False:
expected_required = set()
expected_optional = {"val2", "val1"}
else:
expected_required = {"val1"}
expected_optional = {"val2"}
# Explicit Required/NotRequired annotations always take precedence over total=
expected_required |= {"val0a", "val3", "val5"}
expected_optional |= {"val0b", "val4", "val6"}
assert set(json_schema.get("required", set())) == expected_required
assert (
set(json_schema["properties"].keys()) == expected_required | expected_optional
)
# Check the output schema; the same rules apply
output_schema = graph.get_output_jsonschema()
if total_ is False:
expected_required = set()
expected_optional = {"out_val2", "out_val1"}
else:
expected_required = {"out_val1"}
expected_optional = {"out_val2"}
expected_required |= {"val0a", "out_val3", "out_val5"}
expected_optional |= {"val0b", "out_val4", "out_val6"}
assert set(output_schema.get("required", set())) == expected_required
assert (
set(output_schema["properties"].keys()) == expected_required | expected_optional
)
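# For dataclass state schemas, fields with a default or default_factory become
# optional in the input/output JSON schema, while fields without one (including
# field(default=...), the Ellipsis sentinel) remain required.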
@pytest.mark.parametrize("kw_only_", [False, True])
def test_state_schema_default_values(kw_only_: bool):
kwargs = {}
if "kw_only" in inspect.signature(dataclass).parameters:
kwargs = {"kw_only": kw_only_}
@dataclass(**kwargs)
class InputState:
val1: str
val2: Optional[int]
val3: Annotated[Optional[float], "optional annotated"]
val4: Optional[str] = None
val5: list[int] = field(default_factory=lambda: [1, 2, 3])
val6: dict[str, int] = field(default_factory=lambda: {"a": 1})
val7: str = field(default=...)
val8: Annotated[int, "some metadata"] = 42
val9: Annotated[str, "more metadata"] = field(default="some foo")
val10: str = "default"
val11: Annotated[list[str], "annotated list"] = field(
default_factory=lambda: ["a", "b"]
)
some_shared_channel: Annotated[str, SharedValue.on("assistant_id")] = field(
default="foo"
)
builder = StateGraph(InputState)
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
for json_schema in [graph.get_input_jsonschema(), graph.get_output_jsonschema()]:
expected_required = {"val1", "val7"}
expected_optional = {
"val2",
"val3",
"val4",
"val5",
"val6",
"val8",
"val9",
"val10",
"val11",
}
assert set(json_schema.get("required", set())) == expected_required
assert (
set(json_schema["properties"].keys()) == expected_required | expected_optional
)
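# Managed channels (e.g. SharedValue) may only appear in the graph's state schema;
# declaring them in an input or output schema raises the ValueError matched below.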
def test_raises_invalid_managed():
class BadInputState(TypedDict):
some_thing: str
some_input_channel: Annotated[str, SharedValue.on("assistant_id")]
class InputState(TypedDict):
some_thing: str
some_input_channel: str
class BadOutputState(TypedDict):
some_thing: str
some_output_channel: Annotated[str, SharedValue.on("assistant_id")]
class OutputState(TypedDict):
some_thing: str
some_output_channel: str
class State(TypedDict):
some_thing: str
some_channel: Annotated[str, SharedValue.on("assistant_id")]
# All OK
StateGraph(State, input=InputState, output=OutputState)
StateGraph(State)
StateGraph(State, input=State, output=State)
StateGraph(State, input=InputState)
StateGraph(State, input=InputState)
bad_input_examples = [
(State, BadInputState, OutputState),
(State, BadInputState, BadOutputState),
(State, BadInputState, State),
(State, BadInputState, None),
]
for _state, _inp, _outp in bad_input_examples:
with pytest.raises(
ValueError,
match="Invalid managed channels detected in BadInputState: some_input_channel. Managed channels are not permitted in Input/Output schema.",
):
StateGraph(_state, input=_inp, output=_outp)
bad_output_examples = [
(State, InputState, BadOutputState),
(State, None, BadOutputState),
]
for _state, _inp, _outp in bad_output_examples:
with pytest.raises(
ValueError,
match="Invalid managed channels detected in BadOutputState: some_output_channel. Managed channels are not permitted in Input/Output schema.",
):
StateGraph(_state, input=_inp, output=_outp)
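# Node name resolution covered below: a Runnable's configured name, a function's
# __name__ (including "<lambda>" and bound-method names), and the class name for
# callable instances.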
def test__get_node_name() -> None:
# default runnable name
assert _get_node_name(RunnableLambda(func=lambda x: x)) == "RunnableLambda"
# custom runnable name
assert (
_get_node_name(RunnableLambda(name="my_runnable", func=lambda x: x))
== "my_runnable"
)
# lambda
assert _get_node_name(lambda x: x) == "<lambda>"
# regular function
def func(state):
return
assert _get_node_name(func) == "func"
class MyClass:
def __call__(self, state):
return
def class_method(self, state):
return
# callable class
assert _get_node_name(MyClass()) == "MyClass"
# class method
assert _get_node_name(MyClass().class_method) == "class_method"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_messages_state.py`:
```py
from typing import Annotated
from uuid import UUID
import pytest
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
)
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from langgraph.graph import add_messages
from langgraph.graph.message import MessagesState
from langgraph.graph.state import END, START, StateGraph
from tests.conftest import IS_LANGCHAIN_CORE_030_OR_GREATER
from tests.messages import _AnyIdHumanMessage
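# add_messages is the reducer used for message channels: it appends new messages,
# assigns uuid4 ids when they are missing, replaces an existing message when ids
# match, and deletes via RemoveMessage (raising ValueError for unknown ids).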
def test_add_single_message():
left = [HumanMessage(content="Hello", id="1")]
right = AIMessage(content="Hi there!", id="2")
result = add_messages(left, right)
expected_result = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
assert result == expected_result
def test_add_multiple_messages():
left = [HumanMessage(content="Hello", id="1")]
right = [
AIMessage(content="Hi there!", id="2"),
SystemMessage(content="System message", id="3"),
]
result = add_messages(left, right)
expected_result = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
SystemMessage(content="System message", id="3"),
]
assert result == expected_result
def test_update_existing_message():
left = [HumanMessage(content="Hello", id="1")]
right = HumanMessage(content="Hello again", id="1")
result = add_messages(left, right)
expected_result = [HumanMessage(content="Hello again", id="1")]
assert result == expected_result
def test_missing_ids():
left = [HumanMessage(content="Hello")]
right = [AIMessage(content="Hi there!")]
result = add_messages(left, right)
assert len(result) == 2
assert all(isinstance(m.id, str) and UUID(m.id, version=4) for m in result)
def test_remove_message():
left = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
right = RemoveMessage(id="2")
result = add_messages(left, right)
expected_result = [HumanMessage(content="Hello", id="1")]
assert result == expected_result
def test_duplicate_remove_message():
left = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
right = [RemoveMessage(id="2"), RemoveMessage(id="2")]
result = add_messages(left, right)
expected_result = [HumanMessage(content="Hello", id="1")]
assert result == expected_result
def test_remove_nonexistent_message():
left = [HumanMessage(content="Hello", id="1")]
right = RemoveMessage(id="2")
with pytest.raises(
ValueError, match="Attempting to delete a message with an ID that doesn't exist"
):
add_messages(left, right)
def test_mixed_operations():
left = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
right = [
HumanMessage(content="Updated hello", id="1"),
RemoveMessage(id="2"),
SystemMessage(content="New message", id="3"),
]
result = add_messages(left, right)
expected_result = [
HumanMessage(content="Updated hello", id="1"),
SystemMessage(content="New message", id="3"),
]
assert result == expected_result
def test_empty_inputs():
assert add_messages([], []) == []
assert add_messages([], [HumanMessage(content="Hello", id="1")]) == [
HumanMessage(content="Hello", id="1")
]
assert add_messages([HumanMessage(content="Hello", id="1")], []) == [
HumanMessage(content="Hello", id="1")
]
def test_non_list_inputs():
left = HumanMessage(content="Hello", id="1")
right = AIMessage(content="Hi there!", id="2")
result = add_messages(left, right)
expected_result = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
assert result == expected_result
def test_delete_all():
left = [
HumanMessage(content="Hello", id="1"),
AIMessage(content="Hi there!", id="2"),
]
right = [
RemoveMessage(id="1"),
RemoveMessage(id="2"),
]
result = add_messages(left, right)
expected_result = []
assert result == expected_result
MESSAGES_STATE_SCHEMAS = [MessagesState]
if IS_LANGCHAIN_CORE_030_OR_GREATER:
class MessagesStatePydantic(BaseModel):
messages: Annotated[list[AnyMessage], add_messages]
MESSAGES_STATE_SCHEMAS.append(MessagesStatePydantic)
else:
class MessagesStatePydanticV1(BaseModelV1):
messages: Annotated[list[AnyMessage], add_messages]
MESSAGES_STATE_SCHEMAS.append(MessagesStatePydanticV1)
@pytest.mark.parametrize("state_schema", MESSAGES_STATE_SCHEMAS)
def test_messages_state(state_schema):
def foo(state):
return {"messages": [HumanMessage("foo")]}
graph = StateGraph(state_schema)
graph.add_edge(START, "foo")
graph.add_edge("foo", END)
graph.add_node(foo)
app = graph.compile()
assert app.invoke({"messages": [("user", "meow")]}) == {
"messages": [
_AnyIdHumanMessage(content="meow"),
_AnyIdHumanMessage(content="foo"),
]
}
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_channels.py`:
```py
import operator
from typing import Sequence, Union
import pytest
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.last_value import LastValue
from langgraph.channels.topic import Topic
from langgraph.errors import EmptyChannelError, InvalidUpdateError
pytestmark = pytest.mark.anyio
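# Channel semantics covered below: LastValue holds a single value per step and
# rejects multi-value updates; Topic collects (and optionally accumulates) values,
# flattening nested lists; BinaryOperatorAggregate folds updates with an operator.
# Each test also round-trips the channel through checkpoint()/from_checkpoint().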
def test_last_value() -> None:
channel = LastValue(int).from_checkpoint(None)
assert channel.ValueType is int
assert channel.UpdateType is int
with pytest.raises(EmptyChannelError):
channel.get()
with pytest.raises(InvalidUpdateError):
channel.update([5, 6])
channel.update([3])
assert channel.get() == 3
channel.update([4])
assert channel.get() == 4
checkpoint = channel.checkpoint()
channel = LastValue(int).from_checkpoint(checkpoint)
assert channel.get() == 4
def test_topic() -> None:
channel = Topic(str).from_checkpoint(None)
assert channel.ValueType is Sequence[str]
assert channel.UpdateType is Union[str, list[str]]
assert channel.update(["a", "b"])
assert channel.get() == ["a", "b"]
assert channel.update([["c", "d"], "d"])
assert channel.get() == ["c", "d", "d"]
assert channel.update([])
with pytest.raises(EmptyChannelError):
channel.get()
assert not channel.update([]), "channel already empty"
assert channel.update(["e"])
assert channel.get() == ["e"]
checkpoint = channel.checkpoint()
channel = Topic(str).from_checkpoint(checkpoint)
assert channel.get() == ["e"]
channel_copy = Topic(str).from_checkpoint(checkpoint)
channel_copy.update(["f"])
assert channel_copy.get() == ["f"]
assert channel.get() == ["e"]
def test_topic_accumulate() -> None:
channel = Topic(str, accumulate=True).from_checkpoint(None)
assert channel.ValueType is Sequence[str]
assert channel.UpdateType is Union[str, list[str]]
assert channel.update(["a", "b"])
assert channel.get() == ["a", "b"]
assert channel.update(["b", ["c", "d"], "d"])
assert channel.get() == ["a", "b", "b", "c", "d", "d"]
assert not channel.update([])
assert channel.get() == ["a", "b", "b", "c", "d", "d"]
checkpoint = channel.checkpoint()
channel = Topic(str, accumulate=True).from_checkpoint(checkpoint)
assert channel.get() == ["a", "b", "b", "c", "d", "d"]
assert channel.update(["e"])
assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"]
def test_binop() -> None:
channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(None)
assert channel.ValueType is int
assert channel.UpdateType is int
assert channel.get() == 0
channel.update([1, 2, 3])
assert channel.get() == 6
channel.update([4])
assert channel.get() == 10
checkpoint = channel.checkpoint()
channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(checkpoint)
assert channel.get() == 10
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_runnable.py`:
```py
from __future__ import annotations
from typing import Any
import pytest
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.runnable import RunnableCallable
pytestmark = pytest.mark.anyio
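# RunnableCallable inspects the wrapped function's signature and records in
# func_accepts whether it takes a `store` or `writer` parameter, which is what the
# expected_store/expected_writer maps below assert.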
def test_runnable_callable_func_accepts():
def sync_func(x: Any) -> str:
return f"{x}"
async def async_func(x: Any) -> str:
return f"{x}"
def func_with_store(x: Any, store: BaseStore) -> str:
return f"{x}"
def func_with_writer(x: Any, writer: StreamWriter) -> str:
return f"{x}"
async def afunc_with_store(x: Any, store: BaseStore) -> str:
return f"{x}"
async def afunc_with_writer(x: Any, writer: StreamWriter) -> str:
return f"{x}"
runnables = {
"sync": RunnableCallable(sync_func),
"async": RunnableCallable(func=None, afunc=async_func),
"with_store": RunnableCallable(func_with_store),
"with_writer": RunnableCallable(func_with_writer),
"awith_store": RunnableCallable(afunc_with_store),
"awith_writer": RunnableCallable(afunc_with_writer),
}
expected_store = {"with_store": True, "awith_store": True}
expected_writer = {"with_writer": True, "awith_writer": True}
for name, runnable in runnables.items():
assert runnable.func_accepts["writer"] == expected_writer.get(name, False)
assert runnable.func_accepts["store"] == expected_store.get(name, False)
async def test_runnable_callable_basic():
def sync_func(x: Any) -> str:
return f"{x}"
async def async_func(x: Any) -> str:
return f"{x}"
runnable_sync = RunnableCallable(sync_func)
runnable_async = RunnableCallable(func=None, afunc=async_func)
result_sync = runnable_sync.invoke("test")
assert result_sync == "test"
# Test asynchronous ainvoke
result_async = await runnable_async.ainvoke("test")
assert result_async == "test"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_tracing_interops.py`:
```py
import json
import sys
import time
from typing import Any, Callable, Tuple, TypedDict, TypeVar
from unittest.mock import MagicMock
import langsmith as ls
import pytest
from langchain_core.runnables import RunnableConfig
from langchain_core.tracers import LangChainTracer
from langgraph.graph import StateGraph
pytestmark = pytest.mark.anyio
def _get_mock_client(**kwargs: Any) -> ls.Client:
mock_session = MagicMock()
return ls.Client(session=mock_session, api_key="test", **kwargs)
def _get_calls(
mock_client: Any,
verbs: set[str] = {"POST"},
) -> list:
return [
c
for c in mock_client.session.request.mock_calls
if c.args and c.args[0] in verbs
]
T = TypeVar("T")
def wait_for(
condition: Callable[[], Tuple[T, bool]],
max_sleep_time: int = 10,
sleep_time: int = 3,
) -> T:
"""Wait for a condition to be true."""
start_time = time.time()
last_e = None
while time.time() - start_time < max_sleep_time:
try:
res, cond = condition()
if cond:
return res
except Exception as e:
last_e = e
time.sleep(sleep_time)
total_time = time.time() - start_time
if last_e is not None:
raise last_e
raise ValueError(f"Callable did not return within {total_time}")
@pytest.mark.skip("This test times out in CI")
async def test_nested_tracing():
lt_py_311 = sys.version_info < (3, 11)
mock_client = _get_mock_client()
class State(TypedDict):
value: str
@ls.traceable
async def some_traceable(content: State):
return await child_graph.ainvoke(content)
async def parent_node(state: State, config: RunnableConfig) -> State:
if lt_py_311:
result = await some_traceable(state, langsmith_extra={"config": config})
else:
result = await some_traceable(state)
return {"value": f"parent_{result['value']}"}
async def child_node(state: State) -> State:
return {"value": f"child_{state['value']}"}
child_builder = StateGraph(State)
child_builder.add_node(child_node)
child_builder.add_edge("__start__", "child_node")
child_graph = child_builder.compile().with_config(run_name="child_graph")
parent_builder = StateGraph(State)
parent_builder.add_node(parent_node)
parent_builder.add_edge("__start__", "parent_node")
parent_graph = parent_builder.compile()
tracer = LangChainTracer(client=mock_client)
result = await parent_graph.ainvoke({"value": "input"}, {"callbacks": [tracer]})
assert result == {"value": "parent_child_input"}
def get_posts():
post_calls = _get_calls(mock_client, verbs={"POST"})
posts = [p for c in post_calls for p in json.loads(c.kwargs["data"])["post"]]
names = [p.get("name") for p in posts]
if "child_node" in names:
return posts, True
return None, False
posts = wait_for(get_posts)
# If the callbacks weren't propagated correctly, we'd
# end up with broken dotted_orders
parent_run = next(data for data in posts if data["name"] == "parent_node")
child_run = next(data for data in posts if data["name"] == "child_graph")
traceable_run = next(data for data in posts if data["name"] == "some_traceable")
assert child_run["dotted_order"].startswith(traceable_run["dotted_order"])
assert traceable_run["dotted_order"].startswith(parent_run["dotted_order"])
assert child_run["parent_run_id"] == traceable_run["id"]
assert traceable_run["parent_run_id"] == parent_run["id"]
assert parent_run["trace_id"] == child_run["trace_id"] == traceable_run["trace_id"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/messages.py`:
```py
"""Redefined messages as a work-around for pydantic issue with AnyStr.
The code below creates version of pydantic models
that will work in unit tests with AnyStr as id field
Please note that the `id` field is assigned AFTER the model is created
to workaround an issue with pydantic ignoring the __eq__ method on
subclassed strings.
"""
from typing import Any
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from tests.any_str import AnyStr
def _AnyIdDocument(**kwargs: Any) -> Document:
"""Create a document with an id field."""
message = Document(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
"""Create a human message with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage:
"""Create a tool message with an any id field."""
message = ToolMessage(**kwargs)
message.id = AnyStr()
return message
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_pregel.py`:
```py
import enum
import json
import logging
import operator
import re
import time
import uuid
import warnings
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import replace
from random import randrange
from typing import (
Annotated,
Any,
Dict,
Generator,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
cast,
get_type_hints,
)
import httpx
import pytest
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
RunnablePick,
)
from langsmith import traceable
from pydantic import BaseModel
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.context import Context
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.channels.topic import Topic
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.constants import (
CONFIG_KEY_NODE_FINISHED,
ERROR,
FF_SEND_V2,
PULL,
PUSH,
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.pregel import Channel, GraphRecursionError, Pregel, StateSnapshot
from langgraph.pregel.retry import RetryPolicy
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import (
Command,
Interrupt,
PregelTask,
Send,
StreamWriter,
interrupt,
)
from tests.any_str import AnyDict, AnyStr, AnyVersion, FloatBetween, UnsortedSequence
from tests.conftest import (
ALL_CHECKPOINTERS_SYNC,
ALL_STORES_SYNC,
SHOULD_CHECK_SNAPSHOTS,
)
from tests.fake_chat import FakeChatModel
from tests.fake_tracer import FakeTracer
from tests.memory_assert import MemorySaverAssertCheckpointMetadata
from tests.messages import (
_AnyIdAIMessage,
_AnyIdAIMessageChunk,
_AnyIdHumanMessage,
_AnyIdToolMessage,
)
logger = logging.getLogger(__name__)
# define these objects to avoid importing langchain_core.agents
# and therefore avoid relying on core Pydantic version
class AgentAction(BaseModel):
tool: str
tool_input: Union[str, dict]
log: str
type: Literal["AgentAction"] = "AgentAction"
model_config = {
"json_schema_extra": {
"description": (
"""Represents a request to execute an action by an agent.
The action consists of the name of the tool to execute and the input to pass
to the tool. The log is used to pass along extra information about the action."""
)
}
}
class AgentFinish(BaseModel):
"""Final return value of an ActionAgent.
Agents return an AgentFinish when they have reached a stopping condition.
"""
return_values: dict
log: str
type: Literal["AgentFinish"] = "AgentFinish"
model_config = {
"json_schema_extra": {
"description": (
"""Final return value of an ActionAgent.
Agents return an AgentFinish when they have reached a stopping condition."""
)
}
}
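# Graph validation rules exercised below: an entrypoint is required, dead-end nodes
# are allowed, edges pointing at undefined nodes fail at compile time, invalid
# reducers are rejected when the StateGraph is built, and conflicting or unknown
# state updates surface as InvalidUpdateError at runtime.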
def test_graph_validation() -> None:
def logic(inp: str) -> str:
return ""
workflow = Graph()
workflow.add_node("agent", logic)
workflow.set_entry_point("agent")
workflow.set_finish_point("agent")
assert workflow.compile(), "valid graph"
# Accept a dead-end
workflow = Graph()
workflow.add_node("agent", logic)
workflow.set_entry_point("agent")
workflow.compile()
workflow = Graph()
workflow.add_node("agent", logic)
workflow.set_finish_point("agent")
with pytest.raises(ValueError, match="must have an entrypoint"):
workflow.compile()
workflow = Graph()
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END})
workflow.add_edge("tools", "agent")
assert workflow.compile(), "valid graph"
workflow = Graph()
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
workflow.set_entry_point("tools")
workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END})
workflow.add_edge("tools", "agent")
assert workflow.compile(), "valid graph"
workflow = Graph()
workflow.set_entry_point("tools")
workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END})
workflow.add_edge("tools", "agent")
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
assert workflow.compile(), "valid graph"
workflow = Graph()
workflow.set_entry_point("tools")
workflow.add_conditional_edges(
"agent", logic, {"continue": "tools", "exit": END, "hmm": "extra"}
)
workflow.add_edge("tools", "agent")
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
with pytest.raises(ValueError, match="unknown"): # extra is not defined
workflow.compile()
workflow = Graph()
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", logic, {"continue": "tools", "exit": END})
workflow.add_edge("tools", "extra")
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
with pytest.raises(ValueError, match="unknown"): # extra is not defined
workflow.compile()
workflow = Graph()
workflow.add_node("agent", logic)
workflow.add_node("tools", logic)
workflow.add_node("extra", logic)
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", logic)
workflow.add_edge("tools", "agent")
# Accept, even though extra is dead-end
workflow.compile()
class State(TypedDict):
hello: str
def node_a(state: State) -> State:
# typo
return {"hell": "world"}
builder = StateGraph(State)
builder.add_node("a", node_a)
builder.set_entry_point("a")
builder.set_finish_point("a")
graph = builder.compile()
with pytest.raises(InvalidUpdateError):
graph.invoke({"hello": "there"})
graph = StateGraph(State)
graph.add_node("start", lambda x: x)
graph.add_edge("__start__", "start")
graph.add_edge("unknown", "start")
graph.add_edge("start", "__end__")
with pytest.raises(ValueError, match="Found edge starting at unknown node "):
graph.compile()
def bad_reducer(a): ...
class BadReducerState(TypedDict):
hello: Annotated[str, bad_reducer]
with pytest.raises(ValueError, match="Invalid reducer"):
StateGraph(BadReducerState)
def node_b(state: State) -> State:
return {"hello": "world"}
builder = StateGraph(State)
builder.add_node("a", node_b)
builder.add_node("b", node_b)
builder.add_node("c", node_b)
builder.set_entry_point("a")
builder.add_edge("a", "b")
builder.add_edge("a", "c")
graph = builder.compile()
with pytest.raises(InvalidUpdateError, match="At key 'hello'"):
graph.invoke({"hello": "there"})
def test_graph_validation_with_command() -> None:
class State(TypedDict):
foo: str
bar: str
def node_a(state: State):
return Command(goto="b", update={"foo": "bar"})
def node_b(state: State):
return Command(goto=END, update={"bar": "baz"})
builder = StateGraph(State)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_edge(START, "a")
graph = builder.compile()
assert graph.invoke({"foo": ""}) == {"foo": "bar", "bar": "baz"}
def test_checkpoint_errors() -> None:
class FaultyGetCheckpointer(MemorySaver):
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
raise ValueError("Faulty get_tuple")
class FaultyPutCheckpointer(MemorySaver):
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: Optional[dict[str, Union[str, int, float]]] = None,
) -> RunnableConfig:
raise ValueError("Faulty put")
class FaultyPutWritesCheckpointer(MemorySaver):
def put_writes(
self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str
) -> RunnableConfig:
raise ValueError("Faulty put_writes")
class FaultyVersionCheckpointer(MemorySaver):
def get_next_version(self, current: Optional[int], channel: BaseChannel) -> int:
raise ValueError("Faulty get_next_version")
def logic(inp: str) -> str:
return ""
builder = StateGraph(Annotated[str, operator.add])
builder.add_node("agent", logic)
builder.add_edge(START, "agent")
graph = builder.compile(checkpointer=FaultyGetCheckpointer())
with pytest.raises(ValueError, match="Faulty get_tuple"):
graph.invoke("", {"configurable": {"thread_id": "thread-1"}})
graph = builder.compile(checkpointer=FaultyPutCheckpointer())
with pytest.raises(ValueError, match="Faulty put"):
graph.invoke("", {"configurable": {"thread_id": "thread-1"}})
graph = builder.compile(checkpointer=FaultyVersionCheckpointer())
with pytest.raises(ValueError, match="Faulty get_next_version"):
graph.invoke("", {"configurable": {"thread_id": "thread-1"}})
# add parallel node
builder.add_node("parallel", logic)
builder.add_edge(START, "parallel")
graph = builder.compile(checkpointer=FaultyPutWritesCheckpointer())
with pytest.raises(ValueError, match="Faulty put_writes"):
graph.invoke("", {"configurable": {"thread_id": "thread-1"}})
def test_node_schemas_custom_output() -> None:
class State(TypedDict):
hello: str
bye: str
messages: Annotated[list[str], add_messages]
class Output(TypedDict):
messages: list[str]
class StateForA(TypedDict):
hello: str
messages: Annotated[list[str], add_messages]
def node_a(state: StateForA) -> State:
assert state == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
class StateForB(TypedDict):
bye: str
now: int
def node_b(state: StateForB):
assert state == {
"bye": "world",
}
return {
"now": 123,
"hello": "again",
}
class StateForC(TypedDict):
hello: str
now: int
def node_c(state: StateForC) -> StateForC:
assert state == {
"hello": "again",
"now": 123,
}
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()
assert graph.invoke({"hello": "there", "bye": "world", "messages": "hello"}) == {
"messages": [_AnyIdHumanMessage(content="hello")],
}
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
graph = builder.compile()
assert graph.invoke(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
) == {
"messages": [_AnyIdHumanMessage(content="hello")],
}
assert [
c
for c in graph.stream(
{
"hello": "there",
"bye": "world",
"messages": "hello",
"now": 345, # ignored because not in input schema
}
)
] == [
{"a": None},
{"b": {"hello": "again", "now": 123}},
{"c": None},
]
def test_reducer_before_first_node() -> None:
class State(TypedDict):
hello: str
messages: Annotated[list[str], add_messages]
def node_a(state: State) -> State:
assert state == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
builder = StateGraph(State)
builder.add_node("a", node_a)
builder.set_entry_point("a")
builder.set_finish_point("a")
graph = builder.compile()
assert graph.invoke({"hello": "there", "messages": "hello"}) == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
class State(TypedDict):
hello: str
messages: Annotated[List[str], add_messages]
def node_a(state: State) -> State:
assert state == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
builder = StateGraph(State)
builder.add_node("a", node_a)
builder.set_entry_point("a")
builder.set_finish_point("a")
graph = builder.compile()
assert graph.invoke({"hello": "there", "messages": "hello"}) == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
class State(TypedDict):
hello: str
messages: Annotated[Sequence[str], add_messages]
def node_a(state: State) -> State:
assert state == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
builder = StateGraph(State)
builder.add_node("a", node_a)
builder.set_entry_point("a")
builder.set_finish_point("a")
graph = builder.compile()
assert graph.invoke({"hello": "there", "messages": "hello"}) == {
"hello": "there",
"messages": [_AnyIdHumanMessage(content="hello")],
}
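# Low-level Pregel wiring: a node is a chain built from Channel.subscribe_to(...)
# piped through a runnable into Channel.write_to(...); input_channels and
# output_channels control what invoke() reads from and returns.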
def test_invoke_single_process_in_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={
"one": chain,
},
channels={
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
)
graph = Graph()
graph.add_node("add_one", add_one)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one")
gapp = graph.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "integer",
}
with warnings.catch_warnings():
warnings.simplefilter("error") # raise warnings as errors
assert app.config_schema().model_json_schema() == {
"properties": {},
"title": "LangGraphConfig",
"type": "object",
}
assert app.invoke(2) == 3
assert app.invoke(2, output_keys=["output"]) == {"output": 3}
assert repr(app), "does not raise recursion error"
assert gapp.invoke(2, debug=True) == 3
@pytest.mark.parametrize(
"falsy_value",
[None, False, 0, "", [], {}, set(), frozenset(), 0.0, 0j],
)
def test_invoke_single_process_in_out_falsy_values(falsy_value: Any) -> None:
graph = Graph()
graph.add_node("return_falsy_const", lambda *args, **kwargs: falsy_value)
graph.set_entry_point("return_falsy_const")
graph.set_finish_point("return_falsy_const")
gapp = graph.compile()
assert gapp.invoke(1) == falsy_value
def test_invoke_single_process_in_write_kwargs(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = (
Channel.subscribe_to("input")
| add_one
| Channel.write_to("output", fixed=5, output_plus_one=lambda x: x + 1)
)
app = Pregel(
nodes={"one": chain},
channels={
"input": LastValue(int),
"output": LastValue(int),
"fixed": LastValue(int),
"output_plus_one": LastValue(int),
},
output_channels=["output", "fixed", "output_plus_one"],
input_channels="input",
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None},
"fixed": {"title": "Fixed", "type": "integer", "default": None},
"output_plus_one": {
"title": "Output Plus One",
"type": "integer",
"default": None,
},
},
}
assert app.invoke(2) == {"output": 3, "fixed": 5, "output_plus_one": 4}
def test_invoke_single_process_in_out_dict(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": chain},
channels={"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels=["output"],
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "integer",
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None}
},
}
assert app.invoke(2) == {"output": 3}
def test_invoke_single_process_in_dict_out_dict(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
chain = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": chain},
channels={"input": LastValue(int), "output": LastValue(int)},
input_channels=["input"],
output_channels=["output"],
)
if SHOULD_CHECK_SNAPSHOTS:
assert app.input_schema.model_json_schema() == {
"title": "LangGraphInput",
"type": "object",
"properties": {
"input": {"title": "Input", "type": "integer", "default": None}
},
}
assert app.output_schema.model_json_schema() == {
"title": "LangGraphOutput",
"type": "object",
"properties": {
"output": {"title": "Output", "type": "integer", "default": None}
},
}
assert app.invoke({"input": 2}) == {"output": 3}
def test_invoke_two_processes_in_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
assert app.invoke(2) == 4
with pytest.raises(GraphRecursionError):
app.invoke(2, {"recursion_limit": 1}, debug=1)
graph = Graph()
graph.add_node("add_one", add_one)
graph.add_node("add_one_more", add_one)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one_more")
graph.add_edge("add_one", "add_one_more")
gapp = graph.compile()
assert gapp.invoke(2) == 4
for step, values in enumerate(gapp.stream(2, debug=1), start=1):
if step == 1:
assert values == {
"add_one": 3,
}
elif step == 2:
assert values == {
"add_one_more": 4,
}
else:
assert 0, f"{step}:{values}"
assert step == 2
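# With a checkpointer and interrupt_after_nodes, execution pauses after the named
# node and returns None; invoke(None, thread) resumes from the saved checkpoint,
# update_state(..., as_node=...) records a write as if that node had run, and
# get_state_history yields snapshots newest-first.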
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_invoke_two_processes_in_out_interrupt(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = Channel.subscribe_to("inbox") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
interrupt_after_nodes=["one"],
)
thread1 = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
# start execution, stop at inbox
assert app.invoke(2, thread1) is None
# inbox == 3
checkpoint = checkpointer.get(thread1)
assert checkpoint is not None
assert checkpoint["channel_values"]["inbox"] == 3
# resume execution, finish
assert app.invoke(None, thread1) == 4
# start execution again, stop at inbox
assert app.invoke(20, thread1) is None
# inbox == 21
checkpoint = checkpointer.get(thread1)
assert checkpoint is not None
assert checkpoint["channel_values"]["inbox"] == 21
# send a new value in, interrupting the previous execution
assert app.invoke(3, thread1) is None
assert app.invoke(None, thread1) == 5
# start execution again, stopping at inbox
assert app.invoke(20, thread2) is None
# inbox == 21
snapshot = app.get_state(thread2)
assert snapshot.values["inbox"] == 21
assert snapshot.next == ("two",)
# update the state, resume
app.update_state(thread2, 25, as_node="one")
assert app.invoke(None, thread2) == 26
# no pending tasks
snapshot = app.get_state(thread2)
assert snapshot.next == ()
# list history
history = [c for c in app.get_state_history(thread1)]
assert history == [
StateSnapshot(
values={"inbox": 4, "output": 5, "input": 3},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 6,
"writes": {"two": 5},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[1].config,
),
StateSnapshot(
values={"inbox": 4, "output": 4, "input": 3},
tasks=(PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 5}),),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 5,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[2].config,
),
StateSnapshot(
values={"inbox": 21, "output": 4, "input": 3},
tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 4}),),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": 4,
"writes": {"input": 3},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[3].config,
),
StateSnapshot(
values={"inbox": 21, "output": 4, "input": 20},
tasks=(PregelTask(AnyStr(), "two", (PULL, "two")),),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[4].config,
),
StateSnapshot(
values={"inbox": 3, "output": 4, "input": 20},
tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 21}),),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": 2,
"writes": {"input": 20},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[5].config,
),
StateSnapshot(
values={"inbox": 3, "output": 4, "input": 2},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"two": 4},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[6].config,
),
StateSnapshot(
values={"inbox": 3, "input": 2},
tasks=(PregelTask(AnyStr(), "two", (PULL, "two"), result={"output": 4}),),
next=("two",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"one": None},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[7].config,
),
StateSnapshot(
values={"input": 2},
tasks=(PregelTask(AnyStr(), "one", (PULL, "one"), result={"inbox": 3}),),
next=("one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": -1,
"writes": {"input": 2},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# re-running from any previous checkpoint should re-run nodes
assert [c for c in app.stream(None, history[0].config, stream_mode="updates")] == []
assert [c for c in app.stream(None, history[1].config, stream_mode="updates")] == [
{"two": {"output": 5}},
]
assert [c for c in app.stream(None, history[2].config, stream_mode="updates")] == [
{"one": {"inbox": 4}},
{"__interrupt__": ()},
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_fork_always_re_runs_nodes(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
add_one = mocker.Mock(side_effect=lambda _: 1)
builder = StateGraph(Annotated[int, operator.add])
builder.add_node("add_one", add_one)
builder.add_edge(START, "add_one")
builder.add_conditional_edges("add_one", lambda cnt: "add_one" if cnt < 6 else END)
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
# run to completion, streaming both values and updates
assert [*graph.stream(1, thread1, stream_mode=["values", "updates"])] == [
("values", 1),
("updates", {"add_one": 1}),
("values", 2),
("updates", {"add_one": 1}),
("values", 3),
("updates", {"add_one": 1}),
("values", 4),
("updates", {"add_one": 1}),
("values", 5),
("updates", {"add_one": 1}),
("values", 6),
]
# list history
history = [c for c in graph.get_state_history(thread1)]
assert history == [
StateSnapshot(
values=6,
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 5,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[1].config,
),
StateSnapshot(
values=5,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[2].config,
),
StateSnapshot(
values=4,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[3].config,
),
StateSnapshot(
values=3,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[4].config,
),
StateSnapshot(
values=2,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"add_one": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[5].config,
),
StateSnapshot(
values=1,
tasks=(PregelTask(AnyStr(), "add_one", (PULL, "add_one"), result=1),),
next=("add_one",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=history[6].config,
),
StateSnapshot(
values=0,
tasks=(PregelTask(AnyStr(), "__start__", (PULL, "__start__"), result=1),),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": 1},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# forking from any previous checkpoint should re-run nodes
assert [
c for c in graph.stream(None, history[0].config, stream_mode="updates")
] == []
assert [
c for c in graph.stream(None, history[1].config, stream_mode="updates")
] == [
{"add_one": 1},
]
assert [
c for c in graph.stream(None, history[2].config, stream_mode="updates")
] == [
{"add_one": 1},
{"add_one": 1},
]
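# Passing a checkpoint_id in the config forks the thread from that checkpoint:
# earlier writes are retained, the re-run appends new snapshots to the history, and
# the pre-fork entries keep their original values, next nodes, and step numbers.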
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_run_from_checkpoint_id_retains_previous_writes(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class MyState(TypedDict):
myval: Annotated[int, operator.add]
otherval: bool
class Anode:
def __init__(self):
self.switch = False
def __call__(self, state: MyState):
self.switch = not self.switch
return {"myval": 2 if self.switch else 1, "otherval": self.switch}
builder = StateGraph(MyState)
thenode = Anode() # Fun.
builder.add_node("node_one", thenode)
builder.add_node("node_two", thenode)
builder.add_edge(START, "node_one")
def _getedge(src: str):
swap = "node_one" if src == "node_two" else "node_two"
def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]:
if st["myval"] > 3:
return END
if st["otherval"]:
return swap
return src
return _edge
builder.add_conditional_edges("node_one", _getedge("node_one"))
builder.add_conditional_edges("node_two", _getedge("node_two"))
graph = builder.compile(checkpointer=checkpointer)
thread_id = uuid.uuid4()
thread1 = {"configurable": {"thread_id": str(thread_id)}}
result = graph.invoke({"myval": 1}, thread1)
assert result["myval"] == 4
history = [c for c in graph.get_state_history(thread1)]
assert len(history) == 4
assert history[-1].values == {"myval": 0}
assert history[0].values == {"myval": 4, "otherval": False}
second_run_config = {
**thread1,
"configurable": {
**thread1["configurable"],
"checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
},
}
second_result = graph.invoke(None, second_run_config)
assert second_result == {"myval": 5, "otherval": True}
new_history = [
c
for c in graph.get_state_history(
{"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}}
)
]
assert len(new_history) == len(history) + 1
for original, new in zip(history, new_history[1:]):
assert original.values == new.values
assert original.next == new.next
assert original.metadata["step"] == new.metadata["step"]
def _get_tasks(hist: list, start: int):
return [h.tasks for h in hist[start:]]
assert _get_tasks(new_history, 1) == _get_tasks(history, 0)
def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = (
Channel.subscribe_to("inbox")
| RunnableLambda(add_one).batch
| RunnablePassthrough(lambda _: time.sleep(0.1))
| Channel.write_to("output").batch
)
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": Topic(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels=["input", "inbox"],
stream_channels=["output", "inbox"],
output_channels=["output"],
)
# [12 + 1, 2 + 1 + 1]
assert [
*app.stream(
{"input": 2, "inbox": 12}, output_keys="output", stream_mode="updates"
)
] == [
{"one": None},
{"two": 13},
{"two": 4},
]
assert [*app.stream({"input": 2, "inbox": 12}, output_keys="output")] == [
13,
4,
]
assert [*app.stream({"input": 2, "inbox": 12}, stream_mode="updates")] == [
{"one": {"inbox": 3}},
{"two": {"output": 13}},
{"two": {"output": 4}},
]
assert [*app.stream({"input": 2, "inbox": 12})] == [
{"inbox": [3], "output": 13},
{"output": 4},
]
assert [*app.stream({"input": 2, "inbox": 12}, stream_mode="debug")] == [
{
"type": "task",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "one",
"input": 2,
"triggers": ["input"],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "two",
"input": [12],
"triggers": ["inbox"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "one",
"result": [("inbox", 3)],
"error": None,
"interrupts": [],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"id": AnyStr(),
"name": "two",
"result": [("output", 13)],
"error": None,
"interrupts": [],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "two",
"input": [3],
"triggers": ["inbox"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "two",
"result": [("output", 4)],
"error": None,
"interrupts": [],
},
},
]
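# batch() runs several inputs through the same app concurrently; the per-input
# delays below differ, but results still come back in input order.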
def test_batch_two_processes_in_out() -> None:
def add_one_with_delay(inp: int) -> int:
time.sleep(inp / 10)
return inp + 1
one = Channel.subscribe_to("input") | add_one_with_delay | Channel.write_to("one")
two = Channel.subscribe_to("one") | add_one_with_delay | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"one": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
assert app.batch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]
assert app.batch([3, 2, 1, 3, 5], output_keys=["output"]) == [
{"output": 5},
{"output": 4},
{"output": 3},
{"output": 5},
{"output": 7},
]
graph = Graph()
graph.add_node("add_one", add_one_with_delay)
graph.add_node("add_one_more", add_one_with_delay)
graph.set_entry_point("add_one")
graph.set_finish_point("add_one_more")
graph.add_edge("add_one", "add_one_more")
gapp = graph.compile()
assert gapp.batch([3, 2, 1, 3, 5]) == [5, 4, 3, 5, 7]
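# The two tests below chain ~100 single-step nodes, so the recursion_limit must
# be raised to at least the number of steps for invoke()/batch() to complete.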
def test_invoke_many_processes_in_out(mocker: MockerFixture) -> None:
test_size = 100
add_one = mocker.Mock(side_effect=lambda x: x + 1)
nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
for i in range(test_size - 2):
nodes[str(i)] = (
Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
)
nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")
app = Pregel(
nodes=nodes,
channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
| {"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels="output",
)
for _ in range(10):
assert app.invoke(2, {"recursion_limit": test_size}) == 2 + test_size
with ThreadPoolExecutor() as executor:
assert [
*executor.map(app.invoke, [2] * 10, [{"recursion_limit": test_size}] * 10)
] == [2 + test_size] * 10
def test_batch_many_processes_in_out(mocker: MockerFixture) -> None:
test_size = 100
add_one = mocker.Mock(side_effect=lambda x: x + 1)
nodes = {"-1": Channel.subscribe_to("input") | add_one | Channel.write_to("-1")}
for i in range(test_size - 2):
nodes[str(i)] = (
Channel.subscribe_to(str(i - 1)) | add_one | Channel.write_to(str(i))
)
nodes["last"] = Channel.subscribe_to(str(i)) | add_one | Channel.write_to("output")
app = Pregel(
nodes=nodes,
channels={str(i): LastValue(int) for i in range(-1, test_size - 2)}
| {"input": LastValue(int), "output": LastValue(int)},
input_channels="input",
output_channels="output",
)
for _ in range(3):
assert app.batch([2, 1, 3, 4, 5], {"recursion_limit": test_size}) == [
2 + test_size,
1 + test_size,
3 + test_size,
4 + test_size,
5 + test_size,
]
with ThreadPoolExecutor() as executor:
assert [
*executor.map(
app.batch, [[2, 1, 3, 4, 5]] * 3, [{"recursion_limit": test_size}] * 3
)
] == [
[2 + test_size, 1 + test_size, 3 + test_size, 4 + test_size, 5 + test_size]
] * 3
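# Two nodes writing to the same LastValue channel in one step is invalid and
# raises InvalidUpdateError; switching that channel to a Topic (next test)
# makes the concurrent writes legal.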
def test_invoke_two_processes_two_in_two_out_invalid(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={"output": LastValue(int), "input": LastValue(int)},
input_channels="input",
output_channels="output",
)
with pytest.raises(InvalidUpdateError):
# LastValue channels can only be updated once per iteration
app.invoke(2)
class State(TypedDict):
hello: str
def my_node(input: State) -> State:
return {"hello": "world"}
builder = StateGraph(State)
builder.add_node("one", my_node)
builder.add_node("two", my_node)
builder.set_conditional_entry_point(lambda _: ["one", "two"])
graph = builder.compile()
with pytest.raises(InvalidUpdateError, match="At key 'hello'"):
graph.invoke({"hello": "there"}, debug=True)
def test_invoke_two_processes_two_in_two_out_valid(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
two = Channel.subscribe_to("input") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"output": Topic(int),
},
input_channels="input",
output_channels="output",
)
    # A Topic channel accumulates concurrent updates into a sequence
assert app.invoke(2) == [3, 3]
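# The checkpointer test below accumulates a running total across invocations on
# the same thread_id, verifying that a transient ConnectionError is retried by
# the RetryPolicy while a ValueError is surfaced and recorded as a pending
# error write on the checkpoint.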
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_invoke_checkpoint_two(
mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
add_one = mocker.Mock(side_effect=lambda x: x["total"] + x["input"])
errored_once = False
def raise_if_above_10(input: int) -> int:
nonlocal errored_once
if input > 4:
if errored_once:
pass
else:
errored_once = True
raise ConnectionError("I will be retried")
if input > 10:
raise ValueError("Input is too large")
return input
one = (
Channel.subscribe_to(["input"]).join(["total"])
| add_one
| Channel.write_to("output", "total")
| raise_if_above_10
)
app = Pregel(
nodes={"one": one},
channels={
"total": BinaryOperatorAggregate(int, operator.add),
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
retry_policy=RetryPolicy(),
)
# total starts out as 0, so output is 0+2=2
assert app.invoke(2, {"configurable": {"thread_id": "1"}}) == 2
checkpoint = checkpointer.get({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 2
# total is now 2, so output is 2+3=5
assert app.invoke(3, {"configurable": {"thread_id": "1"}}) == 5
assert errored_once, "errored and retried"
checkpoint_tup = checkpointer.get_tuple({"configurable": {"thread_id": "1"}})
assert checkpoint_tup is not None
assert checkpoint_tup.checkpoint["channel_values"].get("total") == 7
# total is now 2+5=7, so output would be 7+4=11, but raises ValueError
with pytest.raises(ValueError):
app.invoke(4, {"configurable": {"thread_id": "1"}})
# checkpoint is not updated, error is recorded
checkpoint_tup = checkpointer.get_tuple({"configurable": {"thread_id": "1"}})
assert checkpoint_tup is not None
assert checkpoint_tup.checkpoint["channel_values"].get("total") == 7
assert checkpoint_tup.pending_writes == [
(AnyStr(), ERROR, "ValueError('Input is too large')")
]
# on a new thread, total starts out as 0, so output is 0+5=5
assert app.invoke(5, {"configurable": {"thread_id": "2"}}) == 5
checkpoint = checkpointer.get({"configurable": {"thread_id": "1"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 7
checkpoint = checkpointer.get({"configurable": {"thread_id": "2"}})
assert checkpoint is not None
assert checkpoint["channel_values"].get("total") == 5
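# When one parallel node succeeds and another fails, the successful node's
# writes are kept as pending writes on the checkpoint; resuming re-runs only
# the failed node and then applies both sets of writes.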
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_pending_writes_resume(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
class State(TypedDict):
value: Annotated[int, operator.add]
class AwhileMaker:
def __init__(self, sleep: float, rtn: Union[Dict, Exception]) -> None:
self.sleep = sleep
self.rtn = rtn
self.reset()
def __call__(self, input: State) -> Any:
self.calls += 1
time.sleep(self.sleep)
if isinstance(self.rtn, Exception):
raise self.rtn
else:
return self.rtn
def reset(self):
self.calls = 0
one = AwhileMaker(0.1, {"value": 2})
two = AwhileMaker(0.3, ConnectionError("I'm not good"))
builder = StateGraph(State)
builder.add_node("one", one)
builder.add_node("two", two, retry=RetryPolicy(max_attempts=2))
builder.add_edge(START, "one")
builder.add_edge(START, "two")
graph = builder.compile(checkpointer=checkpointer)
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
with pytest.raises(ConnectionError, match="I'm not good"):
graph.invoke({"value": 1}, thread1)
    # node "one" should have been called once; node "two" twice, due to the retry policy
assert one.calls == 1
assert two.calls == 2 # two attempts
# latest checkpoint should be before nodes "one", "two"
# but we should have applied the write from "one"
state = graph.get_state(thread1)
assert state is not None
assert state.values == {"value": 3}
assert state.next == ("two",)
assert state.tasks == (
PregelTask(AnyStr(), "one", (PULL, "one"), result={"value": 2}),
PregelTask(AnyStr(), "two", (PULL, "two"), 'ConnectionError("I\'m not good")'),
)
assert state.metadata == {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
}
# get_state with checkpoint_id should not apply any pending writes
state = graph.get_state(state.config)
assert state is not None
assert state.values == {"value": 1}
assert state.next == ("one", "two")
# should contain pending write of "one"
checkpoint = checkpointer.get_tuple(thread1)
assert checkpoint is not None
# should contain error from "two"
expected_writes = [
(AnyStr(), "one", "one"),
(AnyStr(), "value", 2),
(AnyStr(), ERROR, 'ConnectionError("I\'m not good")'),
]
assert len(checkpoint.pending_writes) == 3
assert all(w in expected_writes for w in checkpoint.pending_writes)
# both non-error pending writes come from same task
non_error_writes = [w for w in checkpoint.pending_writes if w[1] != ERROR]
assert non_error_writes[0][0] == non_error_writes[1][0]
# error write is from the other task
error_write = next(w for w in checkpoint.pending_writes if w[1] == ERROR)
assert error_write[0] != non_error_writes[0][0]
# resume execution
with pytest.raises(ConnectionError, match="I'm not good"):
graph.invoke(None, thread1)
# node "one" succeeded previously, so shouldn't be called again
assert one.calls == 1
# node "two" should have been called once again
assert two.calls == 4 # two attempts before + two attempts now
# confirm no new checkpoints saved
state_two = graph.get_state(thread1)
assert state_two.metadata == state.metadata
# resume execution, without exception
two.rtn = {"value": 3}
# both the pending write and the new write were applied, 1 + 2 + 3 = 6
assert graph.invoke(None, thread1) == {"value": 6}
# check all final checkpoints
checkpoints = [c for c in checkpointer.list(thread1)]
# we should have 3
assert len(checkpoints) == 3
    # the latest checkpoint (final state) is not the focus of this test
assert checkpoints[0] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {
"one": {
"start:one": AnyVersion(),
},
"two": {
"start:two": AnyVersion(),
},
"__input__": {},
"__start__": {
"__start__": AnyVersion(),
},
"__interrupt__": {
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
},
"channel_versions": {
"one": AnyVersion(),
"two": AnyVersion(),
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
"channel_values": {"one": "one", "two": "two", "value": 6},
},
metadata={
"parents": {},
"step": 1,
"source": "loop",
"writes": {"one": {"value": 2}, "two": {"value": 3}},
"thread_id": "1",
},
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": checkpoints[1].config["configurable"]["checkpoint_id"],
}
},
pending_writes=[],
)
    # for the previous checkpoint we assert that pending writes contain both
    # - the original error
    # - the successful writes from resuming after the error was fixed
assert checkpoints[1] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {
"__input__": {},
"__start__": {
"__start__": AnyVersion(),
},
},
"channel_versions": {
"value": AnyVersion(),
"__start__": AnyVersion(),
"start:one": AnyVersion(),
"start:two": AnyVersion(),
},
"channel_values": {
"value": 1,
"start:one": "__start__",
"start:two": "__start__",
},
},
metadata={
"parents": {},
"step": 0,
"source": "loop",
"writes": None,
"thread_id": "1",
},
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": checkpoints[2].config["configurable"]["checkpoint_id"],
}
},
pending_writes=UnsortedSequence(
(AnyStr(), "one", "one"),
(AnyStr(), "value", 2),
(AnyStr(), "__error__", 'ConnectionError("I\'m not good")'),
(AnyStr(), "two", "two"),
(AnyStr(), "value", 3),
),
)
assert checkpoints[2] == CheckpointTuple(
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
checkpoint={
"v": 1,
"id": AnyStr(),
"ts": AnyStr(),
"pending_sends": [],
"versions_seen": {"__input__": {}},
"channel_versions": {
"__start__": AnyVersion(),
},
"channel_values": {"__start__": {"value": 1}},
},
metadata={
"parents": {},
"step": -1,
"source": "input",
"writes": {"__start__": {"value": 1}},
"thread_id": "1",
},
parent_config=None,
pending_writes=UnsortedSequence(
(AnyStr(), "value", 1),
(AnyStr(), "start:one", "__start__"),
(AnyStr(), "start:two", "__start__"),
),
)
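# Send() lets a conditional edge dispatch the same node multiple times with
# different payloads; the tests below also cover how Send tasks are ordered
# (with and without FF_SEND_V2) and deduplicated when resuming after an
# interrupt.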
def test_cond_edge_after_send() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
def __call__(self, state):
return [self.name]
def send_for_fun(state):
return [Send("2", state), Send("2", state)]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert graph.invoke(["0"]) == ["0", "1", "2", "2", "3"]
def test_concurrent_emit_sends() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
def __call__(self, state):
return (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
def send_for_fun(state):
return [Send("2", 1), Send("2", 2), "3.1"]
def send_for_profit(state):
return [Send("2", 3), Send("2", 4)]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("1.1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_edge(START, "1.1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("1.1", send_for_profit)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert graph.invoke(["0"]) == (
[
"0",
"1",
"1.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3",
"3.1",
]
if FF_SEND_V2
else [
"0",
"1",
"1.1",
"3.1",
"2|1",
"2|2",
"2|3",
"2|4",
"3",
]
)
def test_send_sequences() -> None:
class Node:
def __init__(self, name: str):
self.name = name
setattr(self, "__name__", name)
def __call__(self, state):
update = (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return replace(state, update=update)
else:
return update
def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("2", 4))),
"3.1",
]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile()
assert (
graph.invoke(["0"])
== [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
"3.1",
]
if FF_SEND_V2
else [
"0",
"1",
"3.1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"3",
"2|3",
"2|4",
"3",
]
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_send_dedupe_on_resume(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
if not FF_SEND_V2:
pytest.skip("Send deduplication is only available in Send V2")
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class InterruptOnce:
ticks: int = 0
def __call__(self, state):
self.ticks += 1
if self.ticks == 1:
raise NodeInterrupt("Bahh")
return ["|".join(("flaky", str(state)))]
class Node:
def __init__(self, name: str):
self.name = name
self.ticks = 0
setattr(self, "__name__", name)
def __call__(self, state):
self.ticks += 1
update = (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return replace(state, update=update)
else:
return update
def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_node("flaky", InterruptOnce())
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert graph.invoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
assert builder.nodes["flaky"].runnable.func.ticks == 1
# check state
state = graph.get_state(thread1)
assert state.next == ("flaky",)
# check history
history = [c for c in graph.get_state_history(thread1)]
assert len(history) == 2
# resume execution
assert graph.invoke(None, thread1, debug=1) == [
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
]
# node "2" doesn't get called again, as we recover writes saved before
assert builder.nodes["2"].runnable.func.ticks == 3
# node "flaky" gets called again, as it was interrupted
assert builder.nodes["flaky"].runnable.func.ticks == 2
# check state
state = graph.get_state(thread1)
assert state.next == ()
# check history
history = [c for c in graph.get_state_history(thread1)]
assert (
history[1]
== [
StateSnapshot(
values=[
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
],
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"3": ["3"], "3.1": ["3.1"]},
"thread_id": "1",
"step": 2,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
),
StateSnapshot(
values=[
"0",
"1",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
next=("3", "3.1"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"1": ["1"],
"2": [
["2|Command(goto=Send(node='2', arg=3))"],
["2|Command(goto=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="3",
path=("__pregel_pull", "3"),
error=None,
interrupts=(),
state=None,
result=["3"],
),
PregelTask(
id=AnyStr(),
name="3.1",
path=("__pregel_pull", "3.1"),
error=None,
interrupts=(),
state=None,
result=["3.1"],
),
),
),
StateSnapshot(
values=["0"],
next=("1", "2", "2", "2", "flaky"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": None,
"thread_id": "1",
"step": 0,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="1",
path=("__pregel_pull", "1"),
error=None,
interrupts=(),
state=None,
result=["1"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
("__pregel_pull", "1"),
2,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|Command(goto=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
("__pregel_pull", "1"),
3,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|Command(goto=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
name="2",
path=(
"__pregel_push",
(
"__pregel_push",
("__pregel_pull", "1"),
2,
AnyStr(),
),
2,
AnyStr(),
),
error=None,
interrupts=(),
state=None,
result=["2|3"],
),
PregelTask(
id=AnyStr(),
name="flaky",
path=(
"__pregel_push",
(
"__pregel_push",
("__pregel_pull", "1"),
3,
AnyStr(),
),
2,
AnyStr(),
),
error=None,
interrupts=(Interrupt(value="Bahh", when="during"),),
state=None,
result=["flaky|4"],
),
),
),
StateSnapshot(
values=[],
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "input",
"writes": {"__start__": ["0"]},
"thread_id": "1",
"step": -1,
"parents": {},
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result=["0"],
),
),
),
][1]
)
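# ReAct-style flow: the agent node emits a tool call routed via Send, and
# interrupt_before=["foo"] pauses before the tool runs so the pending tool
# call can be inspected, removed, or replaced via update_state().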
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_send_react_interrupt(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
ai_message = AIMessage(
"",
id="ai1",
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)
def agent(state):
return {"messages": ai_message}
def route(state):
if isinstance(state["messages"][-1], AIMessage):
return [
Send(call["name"], call) for call in state["messages"][-1].tool_calls
]
foo_called = 0
def foo(call: ToolCall):
nonlocal foo_called
foo_called += 1
return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])}
builder = StateGraph(MessagesState)
builder.add_node(agent)
builder.add_node(foo)
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", route)
graph = builder.compile()
assert graph.invoke({"messages": [HumanMessage("hello")]}) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
# simple interrupt-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "1"}}
assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
assert graph.invoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
# interrupt-update-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "2"}}
assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
if not FF_SEND_V2:
return
# get state should show the pending task
state = graph.get_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# remove the tool call, clearing the pending task
graph.update_state(
thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])}
)
# tool call no longer in pending tasks
assert graph.get_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="Bye now",
tool_calls=[],
),
]
},
next=(),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="Bye now",
tool_calls=[],
)
}
},
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
# tool call not executed
assert graph.invoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(content="Bye now"),
]
}
assert foo_called == 0
# interrupt-update-resume flow, creating new Send in update call
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "3"}}
assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
# get state should show the pending task
state = graph.get_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "3",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# replace the tool call, should clear previous send, create new one
graph.update_state(
thread1,
{
"messages": AIMessage(
"",
id=ai_message.id,
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
)
},
)
# prev tool call no longer in pending tasks, new tool call is
assert graph.get_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
)
}
},
"parents": {},
"thread_id": "3",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "3",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", (), 0, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# prev tool call not executed, new tool call is
assert graph.invoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [4, 5, 6]},
"id": "tool1",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(content="{'hi': [4, 5, 6]}", tool_call_id="tool1"),
]
}
assert foo_called == 1
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_send_react_interrupt_control(
request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion
) -> None:
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
ai_message = AIMessage(
"",
id="ai1",
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)
def agent(state) -> Command[Literal["foo"]]:
return Command(
update={"messages": ai_message},
goto=[Send(call["name"], call) for call in ai_message.tool_calls],
)
foo_called = 0
def foo(call: ToolCall):
nonlocal foo_called
foo_called += 1
return {"messages": ToolMessage(str(call["args"]), tool_call_id=call["id"])}
builder = StateGraph(MessagesState)
builder.add_node(agent)
builder.add_node(foo)
builder.add_edge(START, "agent")
graph = builder.compile()
assert graph.get_graph().draw_mermaid() == snapshot
assert graph.invoke({"messages": [HumanMessage("hello")]}) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
# simple interrupt-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "1"}}
assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
assert graph.invoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="{'hi': [1, 2, 3]}",
tool_call_id=AnyStr(),
),
]
}
assert foo_called == 1
if not FF_SEND_V2:
return
# interrupt-update-resume flow
foo_called = 0
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["foo"])
thread1 = {"configurable": {"thread_id": "2"}}
assert graph.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
}
assert foo_called == 0
# get state should show the pending task
state = graph.get_state(thread1)
assert state == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
),
]
},
next=("foo",),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 0,
"source": "loop",
"writes": None,
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "foo",
"args": {"hi": [1, 2, 3]},
"id": "",
"type": "tool_call",
}
],
)
},
),
PregelTask(
id=AnyStr(),
name="foo",
path=("__pregel_push", ("__pregel_pull", "agent"), 2, AnyStr()),
error=None,
interrupts=(),
state=None,
result=None,
),
),
)
# remove the tool call, clearing the pending task
graph.update_state(
thread1, {"messages": AIMessage("Bye now", id=ai_message.id, tool_calls=[])}
)
# tool call no longer in pending tasks
assert graph.get_state(thread1) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(
content="Bye now",
tool_calls=[],
),
]
},
next=(),
config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"step": 1,
"source": "update",
"writes": {
"agent": {
"messages": _AnyIdAIMessage(
content="Bye now",
tool_calls=[],
)
}
},
"parents": {},
"thread_id": "2",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "2",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
# tool call not executed
assert graph.invoke(None, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
_AnyIdAIMessage(content="Bye now"),
]
}
assert foo_called == 0
# interrupt-update-resume flow, creating new Send in update call
# TODO add here test with invoke(Command())
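# The next test walks the full checkpoint history of a thread: listing with a
# limit, cursor pagination via `before`, and update_state() creating a new
# checkpoint whose parent is the checkpoint that was updated.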
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_invoke_checkpoint_three(
mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
adder = mocker.Mock(side_effect=lambda x: x["total"] + x["input"])
def raise_if_above_10(input: int) -> int:
if input > 10:
raise ValueError("Input is too large")
return input
one = (
Channel.subscribe_to(["input"]).join(["total"])
| adder
| Channel.write_to("output", "total")
| raise_if_above_10
)
app = Pregel(
nodes={"one": one},
channels={
"total": BinaryOperatorAggregate(int, operator.add),
"input": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
checkpointer=checkpointer,
)
thread_1 = {"configurable": {"thread_id": "1"}}
# total starts out as 0, so output is 0+2=2
assert app.invoke(2, thread_1, debug=1) == 2
state = app.get_state(thread_1)
assert state is not None
assert state.values.get("total") == 2
assert state.next == ()
assert (
state.config["configurable"]["checkpoint_id"]
== checkpointer.get(thread_1)["id"]
)
# total is now 2, so output is 2+3=5
assert app.invoke(3, thread_1) == 5
state = app.get_state(thread_1)
assert state is not None
assert state.values.get("total") == 7
assert (
state.config["configurable"]["checkpoint_id"]
== checkpointer.get(thread_1)["id"]
)
# total is now 2+5=7, so output would be 7+4=11, but raises ValueError
with pytest.raises(ValueError):
app.invoke(4, thread_1)
# checkpoint is updated with new input
state = app.get_state(thread_1)
assert state is not None
assert state.values.get("total") == 7
assert state.next == ("one",)
"""we checkpoint inputs and it failed on "one", so the next node is one"""
# we can recover from error by sending new inputs
assert app.invoke(2, thread_1) == 9
state = app.get_state(thread_1)
assert state is not None
assert state.values.get("total") == 16, "total is now 7+9=16"
assert state.next == ()
thread_2 = {"configurable": {"thread_id": "2"}}
# on a new thread, total starts out as 0, so output is 0+5=5
assert app.invoke(5, thread_2, debug=True) == 5
state = app.get_state({"configurable": {"thread_id": "1"}})
assert state is not None
assert state.values.get("total") == 16
assert state.next == (), "checkpoint of other thread not touched"
state = app.get_state(thread_2)
assert state is not None
assert state.values.get("total") == 5
assert state.next == ()
assert len(list(app.get_state_history(thread_1, limit=1))) == 1
# list all checkpoints for thread 1
thread_1_history = [c for c in app.get_state_history(thread_1)]
# there are 7 checkpoints
assert len(thread_1_history) == 7
assert Counter(c.metadata["source"] for c in thread_1_history) == {
"input": 4,
"loop": 3,
}
# sorted descending
assert (
thread_1_history[0].config["configurable"]["checkpoint_id"]
> thread_1_history[1].config["configurable"]["checkpoint_id"]
)
# cursor pagination
cursored = list(
app.get_state_history(thread_1, limit=1, before=thread_1_history[0].config)
)
assert len(cursored) == 1
assert cursored[0].config == thread_1_history[1].config
# the last checkpoint
assert thread_1_history[0].values["total"] == 16
# the first "loop" checkpoint
assert thread_1_history[-2].values["total"] == 2
    # can get each checkpoint using get() with its config
assert (
checkpointer.get(thread_1_history[0].config)["id"]
== thread_1_history[0].config["configurable"]["checkpoint_id"]
)
assert (
checkpointer.get(thread_1_history[1].config)["id"]
== thread_1_history[1].config["configurable"]["checkpoint_id"]
)
thread_1_next_config = app.update_state(thread_1_history[1].config, 10)
# update creates a new checkpoint
assert (
thread_1_next_config["configurable"]["checkpoint_id"]
> thread_1_history[0].config["configurable"]["checkpoint_id"]
)
# update makes new checkpoint child of the previous one
assert (
app.get_state(thread_1_next_config).parent_config == thread_1_history[1].config
)
# 1 more checkpoint in history
assert len(list(app.get_state_history(thread_1))) == 8
assert Counter(c.metadata["source"] for c in app.get_state_history(thread_1)) == {
"update": 1,
"input": 4,
"loop": 3,
}
# the latest checkpoint is the updated one
assert app.get_state(thread_1) == app.get_state(thread_1_next_config)
def test_invoke_two_processes_two_in_join_two_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
add_10_each = mocker.Mock(side_effect=lambda x: sorted(y + 10 for y in x))
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
chain_three = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
chain_four = (
Channel.subscribe_to("inbox") | add_10_each | Channel.write_to("output")
)
app = Pregel(
nodes={
"one": one,
"chain_three": chain_three,
"chain_four": chain_four,
},
channels={
"inbox": Topic(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
    # Then invoke the app.
    # We get a single array result, as chain_four waits for all publishers to finish
    # before operating on all elements published to the "inbox" topic as an array
for _ in range(100):
assert app.invoke(2) == [13, 13]
with ThreadPoolExecutor() as executor:
assert [*executor.map(app.invoke, [2] * 100)] == [[13, 13]] * 100
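# A Pregel app can itself be used inside another app's node; calling it twice
# in the same node via .map() raises MultipleSubgraphsError once checkpointing
# is enabled, unless the inner app's checkpointer is explicitly disabled.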
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_invoke_join_then_call_other_pregel(
mocker: MockerFixture, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
add_one = mocker.Mock(side_effect=lambda x: x + 1)
add_10_each = mocker.Mock(side_effect=lambda x: [y + 10 for y in x])
inner_app = Pregel(
nodes={
"one": Channel.subscribe_to("input") | add_one | Channel.write_to("output")
},
channels={
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
one = (
Channel.subscribe_to("input")
| add_10_each
| Channel.write_to("inbox_one").map()
)
two = (
Channel.subscribe_to("inbox_one")
| inner_app.map()
| sorted
| Channel.write_to("outbox_one")
)
chain_three = Channel.subscribe_to("outbox_one") | sum | Channel.write_to("output")
app = Pregel(
nodes={
"one": one,
"two": two,
"chain_three": chain_three,
},
channels={
"inbox_one": Topic(int),
"outbox_one": LastValue(int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels="output",
)
for _ in range(10):
assert app.invoke([2, 3]) == 27
with ThreadPoolExecutor() as executor:
assert [*executor.map(app.invoke, [[2, 3]] * 10)] == [27] * 10
# add checkpointer
app.checkpointer = checkpointer
# subgraph is called twice in the same node, through .map(), so raises
with pytest.raises(MultipleSubgraphsError):
app.invoke([2, 3], {"configurable": {"thread_id": "1"}})
    # disable checkpointing for the inner graph (checkpointer=False)
inner_app.checkpointer = False
# subgraph still called twice, but checkpointing for inner graph is disabled
assert app.invoke([2, 3], {"configurable": {"thread_id": "1"}}) == 27
def test_invoke_two_processes_one_in_two_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = (
Channel.subscribe_to("input") | add_one | Channel.write_to("output", "between")
)
two = Channel.subscribe_to("between") | add_one | Channel.write_to("output")
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"between": LastValue(int),
"output": LastValue(int),
},
stream_channels=["output", "between"],
input_channels="input",
output_channels="output",
)
assert [c for c in app.stream(2, stream_mode="updates")] == [
{"one": {"between": 3, "output": 3}},
{"two": {"output": 4}},
]
assert [c for c in app.stream(2)] == [
{"between": 3, "output": 3},
{"between": 3, "output": 4},
]
def test_invoke_two_processes_no_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("between")
two = Channel.subscribe_to("between") | add_one
app = Pregel(
nodes={"one": one, "two": two},
channels={
"input": LastValue(int),
"between": LastValue(int),
"output": LastValue(int),
},
input_channels="input",
output_channels="output",
)
    # It finishes executing (once no more messages are being published)
    # but returns nothing, as nothing was published to the "output" channel
assert app.invoke(2) is None
def test_invoke_two_processes_no_in(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("between") | add_one | Channel.write_to("output")
two = Channel.subscribe_to("between") | add_one
with pytest.raises(TypeError):
Pregel(nodes={"one": one, "two": two})
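# Context channels wrap a context manager: setup runs when streaming starts
# and cleanup runs exactly once after the final chunk is produced.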
def test_channel_enter_exit_timing(mocker: MockerFixture) -> None:
setup = mocker.Mock()
cleanup = mocker.Mock()
@contextmanager
def an_int() -> Generator[int, None, None]:
setup()
try:
yield 5
finally:
cleanup()
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
two = (
Channel.subscribe_to("inbox")
| RunnableLambda(add_one).batch
| Channel.write_to("output").batch
)
app = Pregel(
nodes={"one": one, "two": two},
channels={
"inbox": Topic(int),
"ctx": Context(an_int),
"output": LastValue(int),
"input": LastValue(int),
},
input_channels="input",
output_channels=["inbox", "output"],
stream_channels=["inbox", "output"],
)
assert setup.call_count == 0
assert cleanup.call_count == 0
for i, chunk in enumerate(app.stream(2)):
assert setup.call_count == 1, "Expected setup to be called once"
if i == 0:
assert chunk == {"inbox": [3]}
elif i == 1:
assert chunk == {"output": 4}
else:
assert False, "Expected only two chunks"
assert cleanup.call_count == 1, "Expected cleanup to be called once"
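# The long test below builds a fake agent/tools loop on a plain Graph and then
# re-compiles it with interrupt_after and interrupt_before to exercise
# get_state()/update_state() at each pause point.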
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_conditional_graph(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
# Assemble the tools
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
# Construct the agent
prompt = PromptTemplate.from_template("Hello!")
llm = FakeStreamingListLLM(
responses=[
"tool:search_api:query",
"tool:search_api:another",
"finish:answer",
]
)
def agent_parser(input: str) -> Union[AgentAction, AgentFinish]:
if input.startswith("finish"):
_, answer = input.split(":")
return AgentFinish(return_values={"answer": answer}, log=input)
else:
_, tool_name, tool_input = input.split(":")
return AgentAction(tool=tool_name, tool_input=tool_input, log=input)
agent = RunnablePassthrough.assign(agent_outcome=prompt | llm | agent_parser)
# Define tool execution logic
def execute_tools(data: dict) -> dict:
data = data.copy()
agent_action: AgentAction = data.pop("agent_outcome")
observation = {t.name: t for t in tools}[agent_action.tool].invoke(
agent_action.tool_input
)
if data.get("intermediate_steps") is None:
data["intermediate_steps"] = []
else:
data["intermediate_steps"] = data["intermediate_steps"].copy()
data["intermediate_steps"].append([agent_action, observation])
return data
# Define decision-making logic
def should_continue(data: dict) -> str:
# Logic to decide whether to continue in the loop or exit
if isinstance(data["agent_outcome"], AgentFinish):
return "exit"
else:
return "continue"
# Define a new graph
workflow = Graph()
workflow.add_node("agent", agent)
workflow.add_node(
"tools",
execute_tools,
metadata={"parents": {}, "version": 2, "variant": "b"},
)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_graph().draw_mermaid() == snapshot
assert json.dumps(app.get_graph(xray=True).to_json(), indent=2) == snapshot
assert app.get_graph(xray=True).draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
assert [c for c in app.stream({"input": "what is weather in sf"})] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
# test state get/update methods with interrupt_after
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
if SHOULD_CHECK_SNAPSHOTS:
assert app_w_interrupt.get_graph().to_json() == snapshot
assert app_w_interrupt.get_graph().draw_mermaid() == snapshot
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
}
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
config=app_w_interrupt.checkpointer.get_tuple(config).config,
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert (
app_w_interrupt.checkpointer.get_tuple(config).config["configurable"][
"checkpoint_id"
]
is not None
)
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
app_w_interrupt.update_state(
config,
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 4,
"writes": {
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
}
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# test state get/update methods with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
llm.i = 0 # reset the llm
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
}
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"input": "what is weather in sf",
},
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
app_w_interrupt.update_state(
config,
{
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
},
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 4,
"writes": {
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# test re-invoke to continue with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "3"}}
llm.i = 0 # reset the llm
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
}
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
},
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {
"agent": {
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
}
},
"thread_id": "3",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"input": "what is weather in sf",
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
},
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
]
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
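# The next test exercises set_conditional_entry_point on a plain Graph: should_start
# routes the input string to "left" or "right" based on its length, and the "left"
# branch exits through a conditional edge that always returns END.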
def test_conditional_entrypoint_graph(snapshot: SnapshotAssertion) -> None:
def left(data: str) -> str:
return data + "->left"
def right(data: str) -> str:
return data + "->right"
def should_start(data: str) -> str:
# Logic to decide where to start
if len(data) > 10:
return "go-right"
else:
return "go-left"
# Define a new graph
workflow = Graph()
workflow.add_node("left", left)
workflow.add_node("right", right)
workflow.set_conditional_entry_point(
should_start, {"go-left": "left", "go-right": "right"}
)
workflow.add_conditional_edges("left", lambda data: END, {END: END})
workflow.add_edge("right", END)
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert (
app.invoke("what is weather in sf", debug=True)
== "what is weather in sf->right"
)
assert [*app.stream("what is weather in sf")] == [
{"right": "what is weather in sf->right"},
]
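# The next test fans out from a conditional entry point using Send: continue_to_weather
# returns one Send per location, so get_weather runs once per location and the per-node
# results are merged back into OverallState via the operator.add reducer.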
def test_conditional_entrypoint_to_multiple_state_graph(
snapshot: SnapshotAssertion,
) -> None:
class OverallState(TypedDict):
locations: list[str]
results: Annotated[list[str], operator.add]
def get_weather(state: OverallState) -> OverallState:
location = state["location"]
weather = "sunny" if len(location) > 2 else "cloudy"
return {"results": [f"It's {weather} in {location}"]}
def continue_to_weather(state: OverallState) -> list[Send]:
return [
Send("get_weather", {"location": location})
for location in state["locations"]
]
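# Note: each Send carries its own payload dict, so get_weather above receives
# {"location": ...} rather than the full OverallState.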
workflow = StateGraph(OverallState)
workflow.add_node("get_weather", get_weather)
workflow.add_edge("get_weather", END)
workflow.set_conditional_entry_point(continue_to_weather)
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"locations": ["sf", "nyc"]}, debug=True) == {
"locations": ["sf", "nyc"],
"results": ["It's cloudy in sf", "It's sunny in nyc"],
}
assert [*app.stream({"locations": ["sf", "nyc"]}, stream_mode="values")][-1] == {
"locations": ["sf", "nyc"],
"results": ["It's cloudy in sf", "It's sunny in nyc"],
}
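# The next test covers a conditional StateGraph with an UntrackedValue input channel,
# a Context-managed httpx.Client shared through the state, a node-specific input schema
# (ToolState), and get_state/update_state across interrupt_before/interrupt_after runs.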
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_conditional_state_graph(
snapshot: SnapshotAssertion,
mocker: MockerFixture,
request: pytest.FixtureRequest,
checkpointer_name: str,
) -> None:
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
setup = mocker.Mock()
teardown = mocker.Mock()
@contextmanager
def assert_ctx_once() -> Iterator[None]:
assert setup.call_count == 0
assert teardown.call_count == 0
try:
yield
finally:
assert setup.call_count == 1
assert teardown.call_count == 1
setup.reset_mock()
teardown.reset_mock()
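# make_httpx_client backs the Context channel declared on the state below;
# assert_ctx_once verifies the client is set up and torn down exactly once per invocation.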
@contextmanager
def make_httpx_client() -> Iterator[httpx.Client]:
setup()
with httpx.Client() as client:
try:
yield client
finally:
teardown()
class AgentState(TypedDict, total=False):
input: Annotated[str, UntrackedValue]
agent_outcome: Optional[Union[AgentAction, AgentFinish]]
intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
session: Annotated[httpx.Client, Context(make_httpx_client)]
class ToolState(TypedDict, total=False):
agent_outcome: Union[AgentAction, AgentFinish]
session: Annotated[httpx.Client, Context(make_httpx_client)]
# Assemble the tools
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
# Construct the agent
prompt = PromptTemplate.from_template("Hello!")
llm = FakeStreamingListLLM(
responses=[
"tool:search_api:query",
"tool:search_api:another",
"finish:answer",
]
)
def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]:
if input.startswith("finish"):
_, answer = input.split(":")
return {
"agent_outcome": AgentFinish(
return_values={"answer": answer}, log=input
)
}
else:
_, tool_name, tool_input = input.split(":")
return {
"agent_outcome": AgentAction(
tool=tool_name, tool_input=tool_input, log=input
)
}
agent = prompt | llm | agent_parser
# Define tool execution logic
def execute_tools(data: ToolState) -> dict:
# check session in data
assert isinstance(data["session"], httpx.Client)
assert "input" not in data
assert "intermediate_steps" not in data
# execute the tool
agent_action: AgentAction = data.pop("agent_outcome")
observation = {t.name: t for t in tools}[agent_action.tool].invoke(
agent_action.tool_input
)
return {"intermediate_steps": [[agent_action, observation]]}
# Define decision-making logic
def should_continue(data: AgentState) -> str:
# check session in data
assert isinstance(data["session"], httpx.Client)
# Logic to decide whether to continue in the loop or exit
if isinstance(data["agent_outcome"], AgentFinish):
return "exit"
else:
return "continue"
# Define a new graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent)
workflow.add_node("tools", execute_tools, input=ToolState)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
with assert_ctx_once():
assert app.invoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
],
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
with assert_ctx_once():
assert [*app.stream({"input": "what is weather in sf"})] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
],
],
}
},
{
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
}
},
]
# test state get/update methods with interrupt_after
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
with assert_ctx_once():
assert [
c
for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
{"__interrupt__": ()},
]
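# The trailing {"__interrupt__": ()} chunk marks where the run paused; resuming later
# with stream(None, config) continues from the checkpoint saved at the interrupt.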
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
with assert_ctx_once():
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
},
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
with assert_ctx_once():
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
with assert_ctx_once():
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
}
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# test state get/update methods with interrupt_before
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
debug=True,
)
config = {"configurable": {"thread_id": "2"}}
llm.i = 0 # reset the llm
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
)
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
}
},
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
app_w_interrupt.update_state(
config,
{
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
},
)
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:a different query",
),
"result for query",
]
],
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {
"agent": {
"agent_outcome": AgentFinish(
return_values={"answer": "a really nice answer"},
log="finish:a really nice answer",
)
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# test with interrupt_before="*" (interrupt before every node)
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before="*",
debug=True,
)
config = {"configurable": {"thread_id": "3"}}
llm.i = 0 # reset the llm
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "3",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "3",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
},
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
"thread_id": "3",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
# test with interrupt_after="*" (interrupt after every node)
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after="*",
)
config = {"configurable": {"thread_id": "4"}}
llm.i = 0 # reset the llm
assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [],
},
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
}
},
"thread_id": "4",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"agent_outcome": AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
},
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"intermediate_steps": [
[
AgentAction(
tool="search_api",
tool_input="query",
log="tool:search_api:query",
),
"result for query",
]
],
}
},
"thread_id": "4",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"agent": {
"agent_outcome": AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
}
},
{"__interrupt__": ()},
]
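# The next test checks multi-source edges: add_edge(["A", "B"], END) makes END wait
# for both "A" and "B", whose writes are merged by the operator.add reducer on foo.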
def test_conditional_state_graph_with_list_edge_inputs(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
foo: Annotated[list[str], operator.add]
graph_builder = StateGraph(State)
graph_builder.add_node("A", lambda x: {"foo": ["A"]})
graph_builder.add_node("B", lambda x: {"foo": ["B"]})
graph_builder.add_edge(START, "A")
graph_builder.add_edge(START, "B")
graph_builder.add_edge(["A", "B"], END)
app = graph_builder.compile()
assert app.invoke({"foo": []}) == {"foo": ["A", "B"]}
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
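# The next test checks that state keys inherited from a base TypedDict are picked up
# as channels, and that a separate Config TypedDict feeds the compiled app's config schema.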
def test_state_graph_w_config_inherited_state_keys(snapshot: SnapshotAssertion) -> None:
from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
class BaseState(TypedDict):
input: str
agent_outcome: Optional[Union[AgentAction, AgentFinish]]
class AgentState(BaseState, total=False):
intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
assert get_type_hints(AgentState).keys() == {
"input",
"agent_outcome",
"intermediate_steps",
}
class Config(TypedDict, total=False):
tools: list[str]
# Assemble the tools
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
# Construct the agent
prompt = PromptTemplate.from_template("Hello!")
llm = FakeStreamingListLLM(
responses=[
"tool:search_api:query",
"tool:search_api:another",
"finish:answer",
]
)
def agent_parser(input: str) -> dict[str, Union[AgentAction, AgentFinish]]:
if input.startswith("finish"):
_, answer = input.split(":")
return {
"agent_outcome": AgentFinish(
return_values={"answer": answer}, log=input
)
}
else:
_, tool_name, tool_input = input.split(":")
return {
"agent_outcome": AgentAction(
tool=tool_name, tool_input=tool_input, log=input
)
}
agent = prompt | llm | agent_parser
# Define tool execution logic
def execute_tools(data: AgentState) -> dict:
agent_action: AgentAction = data.pop("agent_outcome")
observation = {t.name: t for t in tools}[agent_action.tool].invoke(
agent_action.tool_input
)
return {"intermediate_steps": [(agent_action, observation)]}
# Define decision-making logic
def should_continue(data: AgentState) -> str:
# Logic to decide whether to continue in the loop or exit
if isinstance(data["agent_outcome"], AgentFinish):
return "exit"
else:
return "continue"
# Define a new graph
builder = StateGraph(AgentState, Config)
builder.add_node("agent", agent)
builder.add_node("tools", execute_tools)
builder.set_entry_point("agent")
builder.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
builder.add_edge("tools", "agent")
app = builder.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.config_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert builder.channels.keys() == {"input", "agent_outcome", "intermediate_steps"}
assert app.invoke({"input": "what is weather in sf"}) == {
"agent_outcome": AgentFinish(
return_values={"answer": "answer"}, log="finish:answer"
),
"input": "what is weather in sf",
"intermediate_steps": [
(
AgentAction(
tool="search_api", tool_input="query", log="tool:search_api:query"
),
"result for query",
),
(
AgentAction(
tool="search_api",
tool_input="another",
log="tool:search_api:another",
),
"result for another",
),
],
}
def test_conditional_entrypoint_graph_state(snapshot: SnapshotAssertion) -> None:
class AgentState(TypedDict, total=False):
input: str
output: str
steps: Annotated[list[str], operator.add]
def left(data: AgentState) -> AgentState:
return {"output": data["input"] + "->left"}
def right(data: AgentState) -> AgentState:
return {"output": data["input"] + "->right"}
def should_start(data: AgentState) -> str:
assert data["steps"] == [], "Expected input to be read from the state"
# Logic to decide where to start
if len(data["input"]) > 10:
return "go-right"
else:
return "go-left"
# Define a new graph
workflow = StateGraph(AgentState)
workflow.add_node("left", left)
workflow.add_node("right", right)
workflow.set_conditional_entry_point(
should_start, {"go-left": "left", "go-right": "right"}
)
workflow.add_conditional_edges("left", lambda data: END, {END: END})
workflow.add_edge("right", END)
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"input": "what is weather in sf"}) == {
"input": "what is weather in sf",
"output": "what is weather in sf->right",
"steps": [],
}
assert [*app.stream({"input": "what is weather in sf"})] == [
{"right": {"output": "what is weather in sf->right"}},
]
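# The next test exercises the prebuilt create_tool_calling_executor, including parallel
# tool calls in one AIMessage, stream_mode="messages" (which yields (message, metadata)
# tuples), and the fallback message emitted when recursion_limit is hit.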
def test_prebuilt_tool_chat(snapshot: SnapshotAssertion) -> None:
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
model = FakeChatModel(
messages=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
),
AIMessage(content="answer"),
]
)
app = create_tool_calling_executor(model, tools)
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke(
{"messages": [HumanMessage(content="what is weather in sf")]}
) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
id=AnyStr(),
),
_AnyIdAIMessage(content="answer"),
]
}
assert [
c
for c in app.stream(
{"messages": [HumanMessage(content="what is weather in sf")]},
stream_mode="messages",
)
] == [
(
_AnyIdAIMessageChunk(
content="",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
tool_call_chunks=[
{
"name": "search_api",
"args": '{"query": "query"}',
"id": "tool_call123",
"index": None,
"type": "tool_call_chunk",
}
],
),
{
"langgraph_step": 1,
"langgraph_node": "agent",
"langgraph_triggers": ["start:agent"],
"langgraph_path": (PULL, "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
(
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
{
"langgraph_step": 2,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": (PULL, "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdAIMessageChunk(
content="",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another"},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one"},
"id": "tool_call567",
"type": "tool_call",
},
],
tool_call_chunks=[
{
"name": "search_api",
"args": '{"query": "another"}',
"id": "tool_call234",
"index": None,
"type": "tool_call_chunk",
},
{
"name": "search_api",
"args": '{"query": "a third one"}',
"id": "tool_call567",
"index": None,
"type": "tool_call_chunk",
},
],
),
{
"langgraph_step": 3,
"langgraph_node": "agent",
"langgraph_triggers": ["tools"],
"langgraph_path": (PULL, "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
(
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
{
"langgraph_step": 4,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": (PULL, "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
{
"langgraph_step": 4,
"langgraph_node": "tools",
"langgraph_triggers": ["branch:agent:should_continue:tools"],
"langgraph_path": (PULL, "tools"),
"langgraph_checkpoint_ns": AnyStr("tools:"),
},
),
(
_AnyIdAIMessageChunk(
content="answer",
),
{
"langgraph_step": 5,
"langgraph_node": "agent",
"langgraph_triggers": ["tools"],
"langgraph_path": (PULL, "agent"),
"langgraph_checkpoint_ns": AnyStr("agent:"),
"checkpoint_ns": AnyStr("agent:"),
"ls_provider": "fakechatmodel",
"ls_model_type": "chat",
},
),
]
assert app.invoke(
{"messages": [HumanMessage(content="what is weather in sf")]},
{"recursion_limit": 2},
debug=True,
) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
_AnyIdAIMessage(content="Sorry, need more steps to process this request."),
]
}
model.i = 0 # reset the model
assert (
app.invoke(
{"messages": [HumanMessage(content="what is weather in sf")]},
stream_mode="updates",
)[0]["agent"]["messages"]
== [
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
]
}
},
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
]
}
},
{"agent": {"messages": [_AnyIdAIMessage(content="answer")]}},
][0]["agent"]["messages"]
)
assert [
*app.stream({"messages": [HumanMessage(content="what is weather in sf")]})
] == [
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
]
}
},
{
"agent": {
"messages": [
_AnyIdAIMessage(
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another"},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one"},
},
],
)
]
}
},
{
"tools": {
"messages": [
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
]
}
},
{"agent": {"messages": [_AnyIdAIMessage(content="answer")]}},
]
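# The next test drives tool calls through Send packets: should_continue returns one
# Send("tools", tool_call) per tool call, so the tools node runs once per call and the
# resulting ToolMessages are merged back into messages by the add_messages reducer.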
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_state_graph_packets(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolCall,
ToolMessage,
)
from langchain_core.tools import tool
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
class AgentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
session: Annotated[httpx.Client, Context(httpx.Client)]
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
tools_by_name = {t.name: t for t in tools}
model = FakeMessagesListChatModel(
responses=[
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
AIMessage(id="ai3", content="answer"),
]
)
def agent(data: AgentState) -> AgentState:
assert isinstance(data["session"], httpx.Client)
return {
"messages": model.invoke(data["messages"]),
"something_extra": "hi there",
}
# Define decision-making logic
def should_continue(data: AgentState) -> str:
assert isinstance(data["session"], httpx.Client)
assert (
data["something_extra"] == "hi there"
), "nodes can pass extra data to their cond edges, which isn't saved in state"
# Logic to decide whether to continue in the loop or exit
if tool_calls := data["messages"][-1].tool_calls:
return [Send("tools", tool_call) for tool_call in tool_calls]
else:
return END
def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
time.sleep(input["args"].get("idx", 0) / 10)
output = tools_by_name[input["name"]].invoke(input["args"], config)
return {
"messages": ToolMessage(
content=output, name=input["name"], tool_call_id=input["id"]
)
}
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", agent)
workflow.add_node("tools", tools_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges("agent", should_continue)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, the `agent` node is called next.
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
assert app.invoke({"messages": HumanMessage(content="what is weather in sf")}) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
),
_AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
AIMessage(content="answer", id="ai3"),
]
}
assert [
c
for c in app.stream(
{"messages": [HumanMessage(content="what is weather in sf")]}
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
},
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
}
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call234",
)
},
},
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a third one",
name="search_api",
tool_call_id="tool_call567",
),
},
},
{"agent": {"messages": AIMessage(content="answer", id="ai3")}},
]
# interrupt after agent
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c
for c in app_w_interrupt.stream(
{"messages": HumanMessage(content="what is weather in sf")}, config
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
}
},
{"__interrupt__": ()},
]
if not FF_SEND_V2:
return
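# The remaining assertions assume the Send v2 task layout (PUSH task paths recorded on
# the checkpoint), which is gated behind the FF_SEND_V2 feature flag checked above.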
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
content="",
additional_kwargs={},
response_metadata={},
id="ai1",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
),
next=("tools",),
config=(app_w_interrupt.checkpointer.get_tuple(config)).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = (app_w_interrupt.get_state(config)).values["messages"][-1]
last_message.tool_calls[0]["args"]["query"] = "a different query"
app_w_interrupt.update_state(
config, {"messages": last_message, "something_extra": "hi there"}
)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
]
},
tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
"something_extra": "hi there",
}
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
},
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai2",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another", "idx": 0},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
"id": "tool_call567",
"type": "tool_call",
},
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
),
),
next=("tools", "tools"),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
},
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
{
"messages": AIMessage(content="answer", id="ai2"),
"something_extra": "hi there",
},
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
]
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 3,
"writes": {
"agent": {
"messages": AIMessage(content="answer", id="ai2"),
"something_extra": "hi there",
}
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# interrupt before tools
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
model.i = 0
assert [
c
for c in app_w_interrupt.stream(
{"messages": HumanMessage(content="what is weather in sf")}, config
)
] == [
{
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
)
}
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai1",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
),
next=("tools",),
config=(app_w_interrupt.checkpointer.get_tuple(config)).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = (app_w_interrupt.get_state(config)).values["messages"][-1]
last_message.tool_calls[0]["args"]["query"] = "a different query"
app_w_interrupt.update_state(
config, {"messages": last_message, "something_extra": "hi there"}
)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
]
},
tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0, AnyStr())),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {
"agent": {
"messages": AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
"something_extra": "hi there",
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
}
},
{
"agent": {
"messages": AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
)
},
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
id="ai2",
content="",
tool_calls=[
{
"id": "tool_call234",
"name": "search_api",
"args": {"query": "another", "idx": 0},
},
{
"id": "tool_call567",
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
},
],
),
]
},
tasks=(
PregelTask(
id=AnyStr(),
name="agent",
path=("__pregel_pull", "agent"),
error=None,
interrupts=(),
state=None,
result={
"messages": AIMessage(
"",
id="ai2",
tool_calls=[
{
"name": "search_api",
"args": {"query": "another", "idx": 0},
"id": "tool_call234",
"type": "tool_call",
},
{
"name": "search_api",
"args": {"query": "a third one", "idx": 1},
"id": "tool_call567",
"type": "tool_call",
},
],
)
},
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2, AnyStr())
),
PregelTask(
AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3, AnyStr())
),
),
next=("tools", "tools"),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {
"tools": {
"messages": _AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
},
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
{
"messages": AIMessage(content="answer", id="ai2"),
"something_extra": "hi there",
},
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
id="ai1",
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
},
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
]
},
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 3,
"writes": {
"agent": {
"messages": AIMessage(content="answer", id="ai2"),
"something_extra": "hi there",
}
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
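# The next test uses MessageGraph, where the state is the message list itself; updates
# whose message id matches an existing message replace it in place rather than appending.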
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_message_graph(
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
request: pytest.FixtureRequest,
checkpointer_name: str,
) -> None:
from copy import deepcopy
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
class FakeFunctionChatModel(FakeMessagesListChatModel):
def bind_functions(self, functions: list):
return self
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = deepcopy(self.responses[self.i])
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
model = FakeFunctionChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
AIMessage(content="answer", id="ai3"),
]
)
# Define the function that determines whether to continue or not
def should_continue(messages):
last_message = messages[-1]
# If there is no tool call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise if there is, we continue
else:
return "continue"
# Define a new graph
workflow = MessageGraph()
# Define the two nodes we will cycle between
workflow.add_node("agent", model)
workflow.add_node("tools", ToolNode(tools))
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# We call `should_continue`, match its output against the keys in this mapping,
# and then call the node that the matching value points to.
{
# If `continue`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, the `agent` node is called next.
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert json.dumps(app.get_input_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_output_schema().model_json_schema()) == snapshot
assert json.dumps(app.get_graph().to_json(), indent=2) == snapshot
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke(HumanMessage(content="what is weather in sf")) == [
_AnyIdHumanMessage(
content="what is weather in sf",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1", # respects ids passed in
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
),
AIMessage(content="answer", id="ai3"),
]
assert [*app.stream([HumanMessage(content="what is weather in sf")])] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{
"tools": [
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{
"tools": [
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
)
]
},
{"agent": AIMessage(content="answer", id="ai3")},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c for c in app_w_interrupt.stream(("human", "what is weather in sf"), config)
] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = app_w_interrupt.get_state(config).values[-1]
last_message.tool_calls[0]["args"] = {"query": "a different query"}
next_config = app_w_interrupt.update_state(config, last_message)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=next_config,
created_at=AnyStr(),
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": [
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
AIMessage(content="answer", id="ai2"), # replace existing message
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
],
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {"agent": AIMessage(content="answer", id="ai2")},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
model.i = 0 # reset the llm
assert [c for c in app_w_interrupt.stream("what is weather in sf", config)] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = app_w_interrupt.get_state(config).values[-1]
last_message.tool_calls[0]["args"] = {"query": "a different query"}
app_w_interrupt.update_state(config, last_message)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
id="ai1",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": [
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
AIMessage(content="answer", id="ai2"),
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(content="answer", id="ai2"),
],
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {"agent": AIMessage(content="answer", id="ai2")},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# add an extra message as if it came from "tools" node
app_w_interrupt.update_state(config, ("ai", "an extra message"), as_node="tools")
# extra message is coerced to a BaseMessage and appended
# now the next node is "agent" per the graph edges
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(content="answer", id="ai2"),
_AnyIdAIMessage(content="an extra message"),
],
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 6,
"writes": {"tools": UnsortedSequence("ai", "an extra message")},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
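# --- illustrative sketch (not part of the original test suite) ---------------
# A minimal, hedged example of the human-in-the-loop pattern exercised above:
# pause after the "agent" node, rewrite the pending tool call, then resume by
# streaming `None` on the same thread. `example_edit_pending_tool_call` is a
# hypothetical helper; it assumes a graph compiled with a checkpointer and
# `interrupt_after=["agent"]`, like `app_w_interrupt` above.
def example_edit_pending_tool_call(app, thread_id: str, new_query: str):
    config = {"configurable": {"thread_id": thread_id}}
    # the paused state's last value is the AI message carrying the tool call
    last_message = app.get_state(config).values[-1]
    last_message.tool_calls[0]["args"] = {"query": new_query}
    # writing back a message with the same id replaces it rather than appending
    app.update_state(config, last_message)
    # resuming with None continues from the edited checkpoint
    return [chunk for chunk in app.stream(None, config)]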
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_root_graph(
deterministic_uuids: MockerFixture,
request: pytest.FixtureRequest,
checkpointer_name: str,
) -> None:
from copy import deepcopy
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
class FakeFuntionChatModel(FakeMessagesListChatModel):
def bind_functions(self, functions: list):
return self
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = deepcopy(self.responses[self.i])
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
model = FakeFuntionChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
AIMessage(content="answer", id="ai3"),
]
)
# Define the function that determines whether to continue or not
def should_continue(messages):
last_message = messages[-1]
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise, there is a tool call, so we continue
else:
return "continue"
class State(TypedDict):
__root__: Annotated[list[BaseMessage], add_messages]
# Define a new graph
workflow = StateGraph(State)
# Define the two nodes we will cycle between
workflow.add_node("agent", model)
workflow.add_node("tools", ToolNode(tools))
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("tools", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
assert app.invoke(HumanMessage(content="what is weather in sf")) == [
_AnyIdHumanMessage(
content="what is weather in sf",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1", # respects ids passed in
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
_AnyIdToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
),
AIMessage(content="answer", id="ai3"),
]
assert [*app.stream([HumanMessage(content="what is weather in sf")])] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{
"tools": [
ToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
id="00000000-0000-4000-8000-000000000033",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{
"tools": [
ToolMessage(
content="result for another",
name="search_api",
tool_call_id="tool_call456",
id="00000000-0000-4000-8000-000000000041",
)
]
},
{"agent": AIMessage(content="answer", id="ai3")},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["agent"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c for c in app_w_interrupt.stream(("human", "what is weather in sf"), config)
] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = app_w_interrupt.get_state(config).values[-1]
last_message.tool_calls[0]["args"] = {"query": "a different query"}
next_config = app_w_interrupt.update_state(config, last_message)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=next_config,
created_at=AnyStr(),
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
id="ai1",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": [
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
AIMessage(content="answer", id="ai2"), # replace existing message
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(content="answer", id="ai2"),
],
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {"agent": AIMessage(content="answer", id="ai2")},
"thread_id": "1",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["tools"],
)
config = {"configurable": {"thread_id": "2"}}
model.i = 0 # reset the llm
assert [c for c in app_w_interrupt.stream("what is weather in sf", config)] == [
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
}
],
id="ai1",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# modify ai message
last_message = app_w_interrupt.get_state(config).values[-1]
last_message.tool_calls[0]["args"] = {"query": "a different query"}
app_w_interrupt.update_state(config, last_message)
# message was replaced instead of appended
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 2,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
id="ai1",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config)] == [
{
"tools": [
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
)
]
},
{
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
{"__interrupt__": ()},
]
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
),
],
tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),),
next=("tools",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 4,
"writes": {
"agent": AIMessage(
content="",
tool_calls=[
{
"id": "tool_call456",
"name": "search_api",
"args": {"query": "another"},
}
],
id="ai2",
)
},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
app_w_interrupt.update_state(
config,
AIMessage(content="answer", id="ai2"),
)
# replaces message even if object identity is different, as long as id is the same
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
],
tasks=(),
next=(),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 5,
"writes": {"agent": AIMessage(content="answer", id="ai2")},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# add an extra message as if it came from "tools" node
app_w_interrupt.update_state(config, ("ai", "an extra message"), as_node="tools")
# extra message is coerced to a BaseMessage and appended
# now the next node is "agent" per the graph edges
assert app_w_interrupt.get_state(config) == StateSnapshot(
values=[
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
id=AnyStr(),
),
AIMessage(content="answer", id="ai2"),
_AnyIdAIMessage(content="an extra message"),
],
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 6,
"writes": {"tools": UnsortedSequence("ai", "an extra message")},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# create new graph with one more state key, reuse previous thread history
def simple_add(left, right):
if not isinstance(right, list):
right = [right]
return left + right
class MoreState(TypedDict):
__root__: Annotated[list[BaseMessage], simple_add]
something_else: str
# Define a new graph
new_workflow = StateGraph(MoreState)
new_workflow.add_node(
"agent", RunnableMap(__root__=RunnablePick("__root__") | model)
)
new_workflow.add_node(
"tools", RunnableMap(__root__=RunnablePick("__root__") | ToolNode(tools))
)
new_workflow.set_entry_point("agent")
new_workflow.add_conditional_edges(
"agent",
RunnablePick("__root__") | should_continue,
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)
new_workflow.add_edge("tools", "agent")
new_app = new_workflow.compile(checkpointer=checkpointer)
model.i = 0 # reset the llm
# previous state is converted to new schema
assert new_app.get_state(config) == StateSnapshot(
values={
"__root__": [
_AnyIdHumanMessage(content="what is weather in sf"),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "a different query"},
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
_AnyIdAIMessage(content="an extra message"),
]
},
tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),),
next=("agent",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 6,
"writes": {"tools": UnsortedSequence("ai", "an extra message")},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
# new input is merged into the old state
assert new_app.invoke(
{
"__root__": [HumanMessage(content="what is weather in la")],
"something_else": "value",
},
config,
interrupt_before=["agent"],
) == {
"__root__": [
HumanMessage(
content="what is weather in sf",
id="00000000-0000-4000-8000-000000000070",
),
AIMessage(
content="",
id="ai1",
tool_calls=[
{
"name": "search_api",
"args": {"query": "a different query"},
"id": "tool_call123",
}
],
),
_AnyIdToolMessage(
content="result for a different query",
name="search_api",
tool_call_id="tool_call123",
),
AIMessage(content="answer", id="ai2"),
AIMessage(
content="an extra message", id="00000000-0000-4000-8000-000000000092"
),
HumanMessage(content="what is weather in la"),
],
"something_else": "value",
}
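# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of the "__root__" state pattern used above: declaring a single
# __root__ key makes the whole graph state the annotated value itself, so nodes
# (and the routing function) receive and return the message list directly,
# much like MessageGraph. `build_example_root_graph` is a hypothetical helper
# that assumes a chat `model` runnable and a `tools` list such as the fakes
# defined in the test above.
def build_example_root_graph(model, tools):
    from typing import Annotated
    from typing_extensions import TypedDict
    from langchain_core.messages import BaseMessage
    from langgraph.graph import END, StateGraph
    from langgraph.graph.message import add_messages
    from langgraph.prebuilt import ToolNode
    class RootState(TypedDict):
        __root__: Annotated[list[BaseMessage], add_messages]
    def route(messages: list[BaseMessage]) -> str:
        # the condition sees the bare message list because of the __root__ key
        return "tools" if messages[-1].tool_calls else END
    builder = StateGraph(RootState)
    builder.add_node("agent", model)
    builder.add_node("tools", ToolNode(tools))
    builder.set_entry_point("agent")
    builder.add_conditional_edges("agent", route)
    builder.add_edge("tools", "agent")
    return builder.compile()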
def test_in_one_fan_out_out_one_graph_state() -> None:
def sorted_add(x: list[str], y: list[str]) -> list[str]:
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def retriever_one(data: State) -> State:
# timer ensures stream output order is stable
# also, it confirms that the update order is not dependent on finishing order,
# instead being defined by the order of the nodes/edges in the graph definition,
# i.e. stable between invocations
time.sleep(0.1)
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge("retriever_one", "qa")
workflow.add_edge("retriever_two", "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
assert [*app.stream({"query": "what is weather in sf"}, stream_mode="values")] == [
{"query": "what is weather in sf", "docs": []},
{"query": "query: what is weather in sf", "docs": []},
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
},
]
assert [
*app.stream(
{"query": "what is weather in sf"},
stream_mode=["values", "updates", "debug"],
)
] == [
("values", {"query": "what is weather in sf", "docs": []}),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "rewrite_query",
"input": {"query": "what is weather in sf", "docs": []},
"triggers": ["start:rewrite_query"],
},
},
),
("updates", {"rewrite_query": {"query": "query: what is weather in sf"}}),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "rewrite_query",
"result": [("query", "query: what is weather in sf")],
"error": None,
"interrupts": [],
},
},
),
("values", {"query": "query: what is weather in sf", "docs": []}),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_one",
"input": {"query": "query: what is weather in sf", "docs": []},
"triggers": ["rewrite_query"],
},
},
),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_two",
"input": {"query": "query: what is weather in sf", "docs": []},
"triggers": ["rewrite_query"],
},
},
),
(
"updates",
{"retriever_two": {"docs": ["doc3", "doc4"]}},
),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_two",
"result": [("docs", ["doc3", "doc4"])],
"error": None,
"interrupts": [],
},
},
),
(
"updates",
{"retriever_one": {"docs": ["doc1", "doc2"]}},
),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "retriever_one",
"result": [("docs", ["doc1", "doc2"])],
"error": None,
"interrupts": [],
},
},
),
(
"values",
{
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
),
(
"debug",
{
"type": "task",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "qa",
"input": {
"query": "query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
"triggers": ["retriever_one", "retriever_two"],
},
},
),
("updates", {"qa": {"answer": "doc1,doc2,doc3,doc4"}}),
(
"debug",
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "qa",
"result": [("answer", "doc1,doc2,doc3,doc4")],
"error": None,
"interrupts": [],
},
},
),
(
"values",
{
"query": "query: what is weather in sf",
"answer": "doc1,doc2,doc3,doc4",
"docs": ["doc1", "doc2", "doc3", "doc4"],
},
),
]
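# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of the stream modes asserted above: "updates" yields one dict
# per node write, "values" yields the full accumulated state after each step,
# and passing a list of modes yields (mode, payload) tuples that interleave
# task/task_result debug events with the values and updates streams.
# `print_stream_modes` is a hypothetical helper assuming a compiled graph
# shaped like `app` in the test above.
def print_stream_modes(app, query: str):
    inputs = {"query": query}
    for update in app.stream(inputs, stream_mode="updates"):
        print("update:", update)  # e.g. {"rewrite_query": {"query": "..."}}
    for state in app.stream(inputs, stream_mode="values"):
        print("state:", state)  # full state after each superstep
    for mode, payload in app.stream(inputs, stream_mode=["values", "updates", "debug"]):
        print(mode, payload)  # tagged tuples, as in the assertions above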
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_dynamic_interrupt(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_node_count = 0
def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer"}},
]
# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None, as_node=END)
# interrupt and next tasks are cleared
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
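# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of the dynamic-interrupt flow tested above: a node calls
# interrupt(...) to pause with a payload, the caller sees it under the
# "__interrupt__" key, and the run is resumed on the same thread with
# Command(resume=...), whose value becomes the return value of interrupt()
# inside the node. `example_resume_interrupt` is a hypothetical helper that
# assumes a graph compiled with a checkpointer, like `tool_two` above.
def example_resume_interrupt(graph, thread_id: str, answer: str):
    from langgraph.types import Command
    config = {"configurable": {"thread_id": thread_id}}
    chunks = list(graph.stream({"my_key": "value", "market": "DE"}, config))
    # the final chunk carries the Interrupt raised inside the node
    assert "__interrupt__" in chunks[-1]
    return list(graph.stream(Command(resume=answer), config))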
@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled")
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_copy_checkpoint(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
def tool_one(s: State) -> State:
return {"my_key": " one"}
tool_two_node_count = 0
def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
def start(state: State) -> list[Union[Send, str]]:
return ["tool_two", Send("tool_one", state)]
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_node("tool_one", tool_one)
tool_two_graph.set_conditional_entry_point(start)
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value one",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value one"}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value one all good",
"market": "US",
}
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"tool_one": {"my_key": " one"},
},
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer"}},
]
# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️ one",
"market": "DE",
}
assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None)
# interrupt is cleared, next task is kept
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(),
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
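# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of the fan-out entry point used above: a conditional entry
# point may return a mix of plain node names and Send(...) packets, so one
# branch runs against the shared state while the other receives an explicit
# payload. `build_example_send_graph` is a hypothetical helper mirroring the
# State shape of the test above.
def build_example_send_graph():
    import operator
    from typing import Annotated, Union
    from typing_extensions import TypedDict
    from langgraph.graph import StateGraph
    from langgraph.types import Send
    class State(TypedDict):
        my_key: Annotated[str, operator.add]
        market: str
    def fan_out(state: State) -> list[Union[Send, str]]:
        # "tool_two" is scheduled by name; "tool_one" gets an explicit payload
        return ["tool_two", Send("tool_one", state)]
    builder = StateGraph(State)
    builder.add_node("tool_one", lambda s: {"my_key": " one"})
    builder.add_node("tool_two", lambda s: {"my_key": " two"})
    builder.set_conditional_entry_point(fan_out)
    return builder.compile()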
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_dynamic_interrupt_subgraph(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class SubgraphState(TypedDict):
my_key: str
market: str
tool_two_node_count = 0
def tool_two_node(s: SubgraphState) -> SubgraphState:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}
subgraph = StateGraph(SubgraphState)
subgraph.add_node("do", tool_two_node, retry=RetryPolicy())
subgraph.add_edge(START, "do")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", subgraph.compile())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()
tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer", "market": "DE"}},
]
# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [
c.metadata
for c in tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
)
] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("tool_two:"),
}
},
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None, as_node=END)
# interrupt and next tasks are cleared
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)
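# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of inspecting an interrupt raised inside a subgraph, as tested
# above: the Interrupt's ns lists both the parent node namespace and the
# subgraph node namespace, and the pending parent task records the subgraph's
# checkpoint config under `state`. `inspect_subgraph_interrupt` is a
# hypothetical helper that assumes a parent graph compiled with a checkpointer,
# like `tool_two` above.
def inspect_subgraph_interrupt(graph, thread_id: str):
    config = {"configurable": {"thread_id": thread_id}}
    snapshot = graph.get_state(config)
    task = snapshot.tasks[0]
    namespaces = task.interrupts[0].ns if task.interrupts else []
    # e.g. ["tool_two:<id>", "do:<id>"] for an interrupt inside the "do" node
    return task.state, namespaces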
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_start_branch_then(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
shared: Annotated[dict[str, dict[str, Any]], SharedValue.on("assistant_id")]
def assert_shared_value(data: State, config: RunnableConfig) -> State:
assert "shared" in data
if thread_id := config["configurable"].get("thread_id"):
if thread_id == "1":
# this is the first thread, so should not see a value
assert data["shared"] == {}
return {"shared": {"1": {"hello": "world"}}}
elif thread_id == "2":
# this should get value saved by thread 1
assert data["shared"] == {"1": {"hello": "world"}}
elif thread_id == "3":
# this is a different assistant, so should not see previous value
assert data["shared"] == {}
return {}
def tool_two_slow(data: State, config: RunnableConfig) -> State:
return {"my_key": " slow", **assert_shared_value(data, config)}
def tool_two_fast(data: State, config: RunnableConfig) -> State:
return {"my_key": " fast", **assert_shared_value(data, config)}
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two_slow", tool_two_slow)
tool_two_graph.add_node("tool_two_fast", tool_two_fast)
tool_two_graph.set_conditional_entry_point(
lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast", then=END
)
tool_two = tool_two_graph.compile()
assert tool_two.get_graph().draw_mermaid() == snapshot
assert tool_two.invoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value slow",
"market": "DE",
}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value fast",
"market": "US",
}
tool_two = tool_two_graph.compile(
store=InMemoryStore(),
checkpointer=checkpointer,
interrupt_before=["tool_two_fast", "tool_two_slow"],
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "1", "assistant_id": "a"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"assistant_id": "a",
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread1, debug=1) == {
"my_key": "value ⛰️ slow",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ slow", "market": "DE"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"tool_two_slow": {"my_key": " slow"}},
"assistant_id": "a",
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "a",
"thread_id": "2",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread2, debug=1) == {
"my_key": "value fast",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value fast", "market": "US"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"tool_two_fast": {"my_key": " fast"}},
"assistant_id": "a",
"thread_id": "2",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value", "market": "US"}, thread3) == {
"my_key": "value",
"market": "US",
}
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "value", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=tool_two.checkpointer.get_tuple(thread3).config,
created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config,
)
# update state
tool_two.update_state(thread3, {"my_key": "key"}) # appends to my_key
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "valuekey", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=tool_two.checkpointer.get_tuple(thread3).config,
created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {START: {"my_key": "key"}},
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread3, debug=1) == {
"my_key": "valuekey fast",
"market": "US",
}
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "valuekey fast", "market": "US"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread3).config,
created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {"tool_two_fast": {"my_key": " fast"}},
"assistant_id": "b",
"thread_id": "3",
},
parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config,
)
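# --- illustrative sketch (not part of the original test suite) ---------------
# Hedged sketch of the `then=` pattern used above: the conditional entry point
# picks one of the branch nodes based on state, and `then=END` routes every
# branch target on to END without adding explicit edges for each branch.
# `build_example_branch_then` is a hypothetical helper over a minimal state,
# omitting the store-backed SharedValue channel exercised by the test above.
def build_example_branch_then():
    import operator
    from typing import Annotated
    from typing_extensions import TypedDict
    from langgraph.graph import END, StateGraph
    class State(TypedDict):
        my_key: Annotated[str, operator.add]
        market: str
    builder = StateGraph(State)
    builder.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
    builder.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
    builder.set_conditional_entry_point(
        lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
        then=END,
    )
    return builder.compile()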
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_branch_then(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
tool_two_graph = StateGraph(State)
tool_two_graph.set_entry_point("prepare")
tool_two_graph.set_finish_point("finish")
tool_two_graph.add_conditional_edges(
source="prepare",
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
then="finish",
)
tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
tool_two = tool_two_graph.compile()
assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
assert tool_two.get_graph().draw_mermaid() == snapshot
assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value prepared fast finished",
"market": "US",
}
# test stream_mode=debug
tool_two = tool_two_graph.compile(checkpointer=checkpointer)
thread10 = {"configurable": {"thread_id": "10"}}
res = [
*tool_two.stream(
{"my_key": "value", "market": "DE"}, thread10, stream_mode="debug"
)
]
assert res == [
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": -1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {"my_key": ""},
"metadata": {
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value", "market": "DE"}},
"thread_id": "10",
},
"parent_config": None,
"next": ["__start__"],
"tasks": [
{
"id": AnyStr(),
"name": "__start__",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 0,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["prepare"],
"tasks": [
{"id": AnyStr(), "name": "prepare", "interrupts": (), "state": None}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"input": {"my_key": "value", "market": "DE"},
"triggers": ["start:prepare"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"id": AnyStr(),
"name": "prepare",
"result": [("my_key", " prepared")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 1,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["tool_two_slow"],
"tasks": [
{
"id": AnyStr(),
"name": "tool_two_slow",
"interrupts": (),
"state": None,
}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "tool_two_slow",
"input": {"my_key": "value prepared", "market": "DE"},
"triggers": ["branch:prepare:condition:tool_two_slow"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"id": AnyStr(),
"name": "tool_two_slow",
"result": [("my_key", " slow")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 2,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared slow",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 2,
"writes": {"tool_two_slow": {"my_key": " slow"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": ["finish"],
"tasks": [
{"id": AnyStr(), "name": "finish", "interrupts": (), "state": None}
],
},
},
{
"type": "task",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "finish",
"input": {"my_key": "value prepared slow", "market": "DE"},
"triggers": ["branch:prepare:condition::then"],
},
},
{
"type": "task_result",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"id": AnyStr(),
"name": "finish",
"result": [("my_key", " finished")],
"error": None,
"interrupts": [],
},
},
{
"type": "checkpoint",
"timestamp": AnyStr(),
"step": 3,
"payload": {
"config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"values": {
"my_key": "value prepared slow finished",
"market": "DE",
},
"metadata": {
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "10",
},
"parent_config": {
"tags": [],
"metadata": {"thread_id": "10"},
"callbacks": None,
"recursion_limit": 25,
"configurable": {
"thread_id": "10",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
},
},
"next": [],
"tasks": [],
},
},
]
tool_two = tool_two_graph.compile(
checkpointer=checkpointer, interrupt_before=["tool_two_fast", "tool_two_slow"]
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == {
"my_key": "value prepared",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread1, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value prepared",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value prepared", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "2",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread2, debug=1) == {
"my_key": "value prepared fast finished",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value prepared fast finished", "market": "US"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "2",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
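# recompile interrupting right before the final node; the update_state call below
# is attributed to the last-run node (tool_two_slow) and its value is appended to
# my_key by the reducer ("slow" -> "slower")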
tool_two = tool_two_graph.compile(
checkpointer=checkpointer, interrupt_before=["finish"]
)
thread1 = {"configurable": {"thread_id": "11"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == {
"my_key": "value prepared slow",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={
"my_key": "value prepared slow",
"market": "DE",
},
tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),),
next=("finish",),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 2,
"writes": {"tool_two_slow": {"my_key": " slow"}},
"thread_id": "11",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# update state
tool_two.update_state(thread1, {"my_key": "er"})
assert tool_two.get_state(thread1) == StateSnapshot(
values={
"my_key": "value prepared slower",
"market": "DE",
},
tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),),
next=("finish",),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 3,
"writes": {"tool_two_slow": {"my_key": "er"}},
"thread_id": "11",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
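# recompile interrupting after the prepare node instead; the thread still pauses
# with the branch target as the next task, so the resume flow is the same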
tool_two = tool_two_graph.compile(
checkpointer=checkpointer, interrupt_after=["prepare"]
)
# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})
thread1 = {"configurable": {"thread_id": "21"}}
# stop after the prepare node has run (interrupt_after)
assert tool_two.invoke({"my_key": "value", "market": "DE"}, thread1) == {
"my_key": "value prepared",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "21",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread1, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "21",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
thread2 = {"configurable": {"thread_id": "22"}}
# stop after the prepare node has run (interrupt_after)
assert tool_two.invoke({"my_key": "value", "market": "US"}, thread2) == {
"my_key": "value prepared",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value prepared", "market": "US"},
tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),),
next=("tool_two_fast",),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "22",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
# resume, for same result as above
assert tool_two.invoke(None, thread2, debug=1) == {
"my_key": "value prepared fast finished",
"market": "US",
}
assert tool_two.get_state(thread2) == StateSnapshot(
values={"my_key": "value prepared fast finished", "market": "US"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread2).config,
created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "22",
},
parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config,
)
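# thread 23: seed a brand-new thread via update_state before any run; this creates
# a step 0 "update" checkpoint whose writes are attributed to START and has no parent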
thread3 = {"configurable": {"thread_id": "23"}}
# update an empty thread before first run
uconfig = tool_two.update_state(thread3, {"my_key": "key", "market": "DE"})
# check current state
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "key", "market": "DE"},
tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),),
next=("prepare",),
config=uconfig,
created_at=AnyStr(),
metadata={
"parents": {},
"source": "update",
"step": 0,
"writes": {START: {"my_key": "key", "market": "DE"}},
"thread_id": "23",
},
parent_config=None,
)
# run from this point
assert tool_two.invoke(None, thread3) == {
"my_key": "key prepared",
"market": "DE",
}
# get state after first node
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "key prepared", "market": "DE"},
tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),),
next=("tool_two_slow",),
config=tool_two.checkpointer.get_tuple(thread3).config,
created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 1,
"writes": {"prepare": {"my_key": " prepared"}},
"thread_id": "23",
},
parent_config=uconfig,
)
# resume, running the remaining nodes to completion
assert tool_two.invoke(None, thread3, debug=1) == {
"my_key": "key prepared slow finished",
"market": "DE",
}
assert tool_two.get_state(thread3) == StateSnapshot(
values={"my_key": "key prepared slow finished", "market": "DE"},
tasks=(),
next=(),
config=tool_two.checkpointer.get_tuple(thread3).config,
created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 3,
"writes": {"finish": {"my_key": " finished"}},
"thread_id": "23",
},
parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config,
)
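# fan-out graph: rewrite_query feeds analyzer_one -> retriever_one and retriever_two
# in parallel; the waiting edge ["retriever_one", "retriever_two"] -> "qa" acts as a
# barrier so qa runs only after both retrievers have written their docs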
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_in_one_fan_out_state_graph_waiting_edge(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
workflow = StateGraph(State)
@workflow.add_node
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1) # to ensure stream order
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow.add_node(analyzer_one)
workflow.add_node(retriever_one)
workflow.add_node(retriever_two)
workflow.add_node(qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c for c in app_w_interrupt.stream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_before=["qa"],
)
config = {"configurable": {"thread_id": "2"}}
assert [
c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
app_w_interrupt.update_state(config, {"docs": ["doc5"]})
assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4", "doc5"],
},
tasks=(PregelTask(AnyStr(), "qa", (PULL, "qa")),),
next=("qa",),
config=app_w_interrupt.checkpointer.get_tuple(config).config,
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 4,
"writes": {"retriever_one": {"docs": ["doc5"]}},
"thread_id": "2",
},
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config,
)
assert [c for c in app_w_interrupt.stream(None, config, debug=1)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4,doc5"}},
]
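# same fan-out shape as above, but the edge to retriever_two goes through a
# conditional edge (a path function returning a Literal) instead of a plain edge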
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_in_one_fan_out_state_graph_waiting_edge_via_branch(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
def rewrite_query_then(data: State) -> Literal["retriever_two"]:
return "retriever_two"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges("rewrite_query", rewrite_query_then)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"query": "what is weather in sf"}, debug=True) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c for c in app_w_interrupt.stream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
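# pydantic v1 state model: nodes may return dicts or model instances, input/output
# schemas are separate models, and the Context-managed httpx.Client must be set up
# and torn down exactly once per invocation (checked by assert_ctx_once)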
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1(
snapshot: SnapshotAssertion,
mocker: MockerFixture,
request: pytest.FixtureRequest,
checkpointer_name: str,
) -> None:
from pydantic.v1 import BaseModel, ValidationError
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
setup = mocker.Mock()
teardown = mocker.Mock()
@contextmanager
def assert_ctx_once() -> Iterator[None]:
assert setup.call_count == 0
assert teardown.call_count == 0
try:
yield
finally:
assert setup.call_count == 1
assert teardown.call_count == 1
setup.reset_mock()
teardown.reset_mock()
@contextmanager
def make_httpx_client() -> Iterator[httpx.Client]:
setup()
with httpx.Client() as client:
try:
yield client
finally:
teardown()
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class InnerObject(BaseModel):
yo: int
class State(BaseModel):
class Config:
arbitrary_types_allowed = True
query: str
inner: InnerObject
answer: Optional[str] = None
docs: Annotated[list[str], sorted_add]
client: Annotated[httpx.Client, Context(make_httpx_client)]
class Input(BaseModel):
query: str
inner: InnerObject
class Output(BaseModel):
answer: str
docs: list[str]
class StateUpdate(BaseModel):
query: Optional[str] = None
answer: Optional[str] = None
docs: Optional[list[str]] = None
def rewrite_query(data: State) -> State:
return {"query": f"query: {data.query}"}
def analyzer_one(data: State) -> State:
return StateUpdate(query=f"analyzed: {data.query}")
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data.docs)}
def decider(data: State) -> str:
assert isinstance(data, State)
return "retriever_two"
workflow = StateGraph(State, input=Input, output=Output)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges(
"rewrite_query", decider, {"retriever_two": "retriever_two"}
)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_input_jsonschema() == snapshot
assert app.get_output_jsonschema() == snapshot
with pytest.raises(ValidationError), assert_ctx_once():
app.invoke({"query": {}})
with assert_ctx_once():
assert app.invoke({"query": "what is weather in sf", "inner": {"yo": 1}}) == {
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
with assert_ctx_once():
assert [
*app.stream({"query": "what is weather in sf", "inner": {"yo": 1}})
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
with assert_ctx_once():
assert [
c
for c in app_w_interrupt.stream(
{"query": "what is weather in sf", "inner": {"yo": 1}}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
with assert_ctx_once():
assert [c for c in app_w_interrupt.stream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
with assert_ctx_once():
assert app_w_interrupt.update_state(
config, {"docs": ["doc5"]}, as_node="rewrite_query"
) == {
"configurable": {
"thread_id": "1",
"checkpoint_id": AnyStr(),
"checkpoint_ns": "",
}
}
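# same test as above, using pydantic v2 (ConfigDict instead of class Config)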
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2(
snapshot: SnapshotAssertion,
mocker: MockerFixture,
request: pytest.FixtureRequest,
checkpointer_name: str,
) -> None:
from pydantic import BaseModel, ConfigDict, ValidationError
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
setup = mocker.Mock()
teardown = mocker.Mock()
@contextmanager
def assert_ctx_once() -> Iterator[None]:
assert setup.call_count == 0
assert teardown.call_count == 0
try:
yield
finally:
assert setup.call_count == 1
assert teardown.call_count == 1
setup.reset_mock()
teardown.reset_mock()
@contextmanager
def make_httpx_client() -> Iterator[httpx.Client]:
setup()
with httpx.Client() as client:
try:
yield client
finally:
teardown()
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class InnerObject(BaseModel):
yo: int
class State(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
query: str
inner: InnerObject
answer: Optional[str] = None
docs: Annotated[list[str], sorted_add]
client: Annotated[httpx.Client, Context(make_httpx_client)]
class StateUpdate(BaseModel):
query: Optional[str] = None
answer: Optional[str] = None
docs: Optional[list[str]] = None
class Input(BaseModel):
query: str
inner: InnerObject
class Output(BaseModel):
answer: str
docs: list[str]
def rewrite_query(data: State) -> State:
return {"query": f"query: {data.query}"}
def analyzer_one(data: State) -> State:
return StateUpdate(query=f"analyzed: {data.query}")
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data.docs)}
def decider(data: State) -> str:
assert isinstance(data, State)
return "retriever_two"
workflow = StateGraph(State, input=Input, output=Output)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_conditional_edges(
"rewrite_query", decider, {"retriever_two": "retriever_two"}
)
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
app = workflow.compile()
if SHOULD_CHECK_SNAPSHOTS:
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_input_schema().model_json_schema() == snapshot
assert app.get_output_schema().model_json_schema() == snapshot
with pytest.raises(ValidationError), assert_ctx_once():
app.invoke({"query": {}})
with assert_ctx_once():
assert app.invoke({"query": "what is weather in sf", "inner": {"yo": 1}}) == {
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
with assert_ctx_once():
assert [
*app.stream({"query": "what is weather in sf", "inner": {"yo": 1}})
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
with assert_ctx_once():
assert [
c
for c in app_w_interrupt.stream(
{"query": "what is weather in sf", "inner": {"yo": 1}}, config
)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
with assert_ctx_once():
assert [c for c in app_w_interrupt.stream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
with assert_ctx_once():
assert app_w_interrupt.update_state(
config, {"docs": ["doc5"]}, as_node="rewrite_query"
) == {
"configurable": {
"thread_id": "1",
"checkpoint_id": AnyStr(),
"checkpoint_ns": "",
}
}
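# adds a plain rewrite_query -> qa edge on top of the waiting edge, so qa first
# runs early with empty docs and then again once the barrier releases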
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_in_one_fan_out_state_graph_waiting_edge_plus_regular(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer: BaseCheckpointSaver = request.getfixturevalue(
f"checkpointer_{checkpointer_name}"
)
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyzer_one(data: State) -> State:
time.sleep(0.1)
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.2)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "qa")
workflow.set_finish_point("qa")
# silly edge, to make sure that a node having been triggered before doesn't break
# the semantics of the named barrier (i.e. waiting edges)
workflow.add_edge("rewrite_query", "qa")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
"docs": ["doc1", "doc2", "doc3", "doc4"],
"answer": "doc1,doc2,doc3,doc4",
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"qa": {"answer": ""}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
app_w_interrupt = workflow.compile(
checkpointer=checkpointer,
interrupt_after=["retriever_one"],
)
config = {"configurable": {"thread_id": "1"}}
assert [
c for c in app_w_interrupt.stream({"query": "what is weather in sf"}, config)
] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"qa": {"answer": ""}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]
assert [c for c in app_w_interrupt.stream(None, config)] == [
{"qa": {"answer": "doc1,doc2,doc3,doc4"}},
]
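# runs the fan-out twice: decider_cond loops back to rewrite_query until the query
# has been analyzed more than once, so each retriever contributes its docs twice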
def test_in_one_fan_out_state_graph_waiting_edge_multiple() -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("decider", decider)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_edge("rewrite_query", "analyzer_one")
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge("rewrite_query", "retriever_two")
workflow.add_edge(["retriever_one", "retriever_two"], "decider")
workflow.add_conditional_edges("decider", decider_cond)
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}},
{
"analyzer_one": {
"query": "analyzed: query: analyzed: query: what is weather in sf"
}
},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}},
]
def test_callable_in_conditional_edges_with_no_path_map() -> None:
class State(TypedDict, total=False):
query: str
def rewrite(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyze(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
class ChooseAnalyzer:
def __call__(self, data: State) -> str:
return "analyzer"
workflow = StateGraph(State)
workflow.add_node("rewriter", rewrite)
workflow.add_node("analyzer", analyze)
workflow.add_conditional_edges("rewriter", ChooseAnalyzer())
workflow.set_entry_point("rewriter")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
}
def test_function_in_conditional_edges_with_no_path_map() -> None:
class State(TypedDict, total=False):
query: str
def rewrite(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def analyze(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def choose_analyzer(data: State) -> str:
return "analyzer"
workflow = StateGraph(State)
workflow.add_node("rewriter", rewrite)
workflow.add_node("analyzer", analyze)
workflow.add_conditional_edges("rewriter", choose_analyzer)
workflow.set_entry_point("rewriter")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: what is weather in sf",
}
def test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge() -> None:
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]
def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
workflow = StateGraph(State)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("analyzer_one", analyzer_one)
workflow.add_node("retriever_one", retriever_one)
workflow.add_node("retriever_two", retriever_two)
workflow.add_node("decider", decider)
workflow.add_node("qa", qa)
workflow.set_entry_point("rewrite_query")
workflow.add_conditional_edges("rewrite_query", retriever_picker)
workflow.add_edge("analyzer_one", "retriever_one")
workflow.add_edge(["retriever_one", "retriever_two"], "decider")
workflow.add_conditional_edges("decider", decider_cond)
workflow.set_finish_point("qa")
app = workflow.compile()
assert app.invoke({"query": "what is weather in sf"}) == {
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
}
assert [*app.stream({"query": "what is weather in sf"})] == [
{"rewrite_query": {"query": "query: what is weather in sf"}},
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"rewrite_query": {"query": "query: analyzed: query: what is weather in sf"}},
{
"analyzer_one": {
"query": "analyzed: query: analyzed: query: what is weather in sf"
}
},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"decider": None},
{"qa": {"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4"}},
]
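# minimal multi-edge case: "up" fans out to "side" and "other", and the waiting
# edge ["up", "side"] -> "down" tolerates "up" having already completed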
def test_simple_multi_edge(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
def up(state: State):
pass
def side(state: State):
pass
def other(state: State):
return {"my_key": "_more"}
def down(state: State):
pass
graph = StateGraph(State)
graph.add_node("up", up)
graph.add_node("side", side)
graph.add_node("other", other)
graph.add_node("down", down)
graph.set_entry_point("up")
graph.add_edge("up", "side")
graph.add_edge("up", "other")
graph.add_edge(["up", "side"], "down")
graph.set_finish_point("down")
app = graph.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.invoke({"my_key": "my_value"}) == {"my_key": "my_value_more"}
assert [*app.stream({"my_key": "my_value"})] in (
[
{"up": None},
{"side": None},
{"other": {"my_key": "_more"}},
{"down": None},
],
[
{"up": None},
{"other": {"my_key": "_more"}},
{"side": None},
{"down": None},
],
)
def test_nested_graph_xray(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str
def logic(state: State):
pass
tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two_slow", logic)
tool_two_graph.add_node("tool_two_fast", logic)
tool_two_graph.set_conditional_entry_point(
lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
then=END,
)
tool_two = tool_two_graph.compile()
graph = StateGraph(State)
graph.add_node("tool_one", logic)
graph.add_node("tool_two", tool_two)
graph.add_node("tool_three", logic)
graph.set_conditional_entry_point(lambda s: "tool_one", then=END)
app = graph.compile()
assert app.get_graph(xray=True).to_json() == snapshot
assert app.get_graph(xray=True).draw_mermaid() == snapshot
def test_nested_graph(snapshot: SnapshotAssertion) -> None:
def never_called_fn(state: Any):
assert 0, "This function should never be called"
never_called = RunnableLambda(never_called_fn)
class InnerState(TypedDict):
my_key: str
my_other_key: str
def up(state: InnerState):
return {"my_key": state["my_key"] + " there", "my_other_key": state["my_key"]}
inner = StateGraph(InnerState)
inner.add_node("up", up)
inner.set_entry_point("up")
inner.set_finish_point("up")
class State(TypedDict):
my_key: str
never_called: Any
def side(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("inner", inner.compile())
graph.add_node("side", side)
graph.set_entry_point("inner")
graph.add_edge("inner", "side")
graph.set_finish_point("side")
app = graph.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_graph(xray=True).draw_mermaid() == snapshot
assert app.invoke(
{"my_key": "my value", "never_called": never_called}, debug=True
) == {
"my_key": "my value there and back again",
"never_called": never_called,
}
assert [*app.stream({"my_key": "my value", "never_called": never_called})] == [
{"inner": {"my_key": "my value there"}},
{"side": {"my_key": "my value there and back again"}},
]
assert [
*app.stream(
{"my_key": "my value", "never_called": never_called}, stream_mode="values"
)
] == [
{
"my_key": "my value",
"never_called": never_called,
},
{
"my_key": "my value there",
"never_called": never_called,
},
{
"my_key": "my value there and back again",
"never_called": never_called,
},
]
chain = app | RunnablePassthrough()
assert chain.invoke({"my_key": "my value", "never_called": never_called}) == {
"my_key": "my value there and back again",
"never_called": never_called,
}
assert [*chain.stream({"my_key": "my value", "never_called": never_called})] == [
{"inner": {"my_key": "my value there"}},
{"side": {"my_key": "my value there and back again"}},
]
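# streams updates from a subgraph as they are produced (subgraphs=True), using
# elapsed-time bounds to check that inner_1 output arrives before the subgraph finishes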
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_stream_subgraphs_during_execution(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str
def inner_1(state: InnerState):
return {"my_key": "got here", "my_other_key": state["my_key"]}
def inner_2(state: InnerState):
time.sleep(0.5)
return {
"my_key": " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: Annotated[str, operator.add]
def outer_1(state: State):
time.sleep(0.2)
return {"my_key": " and parallel"}
def outer_2(state: State):
return {"my_key": " and back again"}
graph = StateGraph(State)
graph.add_node("inner", inner.compile())
graph.add_node("outer_1", outer_1)
graph.add_node("outer_2", outer_2)
graph.add_edge(START, "inner")
graph.add_edge(START, "outer_1")
graph.add_edge(["inner", "outer_1"], "outer_2")
graph.add_edge("outer_2", END)
app = graph.compile(checkpointer=checkpointer)
start = time.perf_counter()
chunks: list[tuple[float, Any]] = []
config = {"configurable": {"thread_id": "2"}}
for c in app.stream({"my_key": ""}, config, subgraphs=True):
chunks.append((round(time.perf_counter() - start, 1), c))
for idx in range(len(chunks)):
elapsed, c = chunks[idx]
chunks[idx] = (round(elapsed - chunks[0][0], 1), c)
assert chunks == [
# arrives before "inner" finishes
(
FloatBetween(0.0, 0.1),
(
(AnyStr("inner:"),),
{"inner_1": {"my_key": "got here", "my_other_key": ""}},
),
),
(FloatBetween(0.2, 0.3), ((), {"outer_1": {"my_key": " and parallel"}})),
(
FloatBetween(0.5, 0.8),
(
(AnyStr("inner:"),),
{"inner_2": {"my_key": " and there", "my_other_key": "got here"}},
),
),
(FloatBetween(0.5, 0.8), ((), {"inner": {"my_key": "got here and there"}})),
(FloatBetween(0.5, 0.8), ((), {"outer_2": {"my_key": " and back again"}})),
]
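# with stream_mode="custom", StreamWriter output is surfaced immediately rather than
# buffered until the node returns (the two chunks straddle the 0.2s sleep)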
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_stream_buffering_single_node(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class State(TypedDict):
my_key: Annotated[str, operator.add]
def node(state: State, writer: StreamWriter):
writer("Before sleep")
time.sleep(0.2)
writer("After sleep")
return {"my_key": "got here"}
builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")
builder.add_edge("node", END)
graph = builder.compile(checkpointer=checkpointer)
start = time.perf_counter()
chunks: list[tuple[float, Any]] = []
config = {"configurable": {"thread_id": "2"}}
for c in graph.stream({"my_key": ""}, config, stream_mode="custom"):
chunks.append((round(time.perf_counter() - start, 1), c))
assert chunks == [
(FloatBetween(0.0, 0.1), "Before sleep"),
(FloatBetween(0.2, 0.3), "After sleep"),
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_nested_graph_interrupts_parallel(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str
def inner_1(state: InnerState):
time.sleep(0.1)
return {"my_key": "got here", "my_other_key": state["my_key"]}
def inner_2(state: InnerState):
return {
"my_key": " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: Annotated[str, operator.add]
def outer_1(state: State):
return {"my_key": " and parallel"}
def outer_2(state: State):
return {"my_key": " and back again"}
graph = StateGraph(State)
graph.add_node("inner", inner.compile(interrupt_before=["inner_2"]))
graph.add_node("outer_1", outer_1)
graph.add_node("outer_2", outer_2)
graph.add_edge(START, "inner")
graph.add_edge(START, "outer_1")
graph.add_edge(["inner", "outer_1"], "outer_2")
graph.set_finish_point("outer_2")
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert app.invoke({"my_key": ""}, config, debug=True) == {
"my_key": " and parallel",
}
assert app.invoke(None, config, debug=True) == {
"my_key": "got here and there and parallel and back again",
}
# the combination of assertions below checks two things:
# - outer_1 finishes before inner interrupts (we see its output in the stream, which only happens after a node finishes)
# - outer_1's writes are persisted by the 1st call and reused by the 2nd, i.e. outer_1 isn't run again (its output reappears in the 2nd stream only as a cached replay)
# test stream updates w/ nested interrupt
config = {"configurable": {"thread_id": "2"}}
assert [*app.stream({"my_key": ""}, config, subgraphs=True)] == [
# we got to parallel node first
((), {"outer_1": {"my_key": " and parallel"}}),
((AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}),
((), {"__interrupt__": ()}),
]
assert [*app.stream(None, config)] == [
{"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}},
{"inner": {"my_key": "got here and there"}},
{"outer_2": {"my_key": " and back again"}},
]
# test stream values w/ nested interrupt
config = {"configurable": {"thread_id": "3"}}
assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
# test interrupts BEFORE the parallel node
app = graph.compile(checkpointer=checkpointer, interrupt_before=["outer_1"])
config = {"configurable": {"thread_id": "4"}}
assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [
{"my_key": ""}
]
# while we're waiting for the node w/ interrupt inside to finish
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
# test interrupts AFTER the parallel node
app = graph.compile(checkpointer=checkpointer, interrupt_after=["outer_1"])
config = {"configurable": {"thread_id": "5"}}
assert [*app.stream({"my_key": ""}, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": " and parallel"},
]
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": ""},
{"my_key": "got here and there and parallel"},
]
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": "got here and there and parallel"},
{"my_key": "got here and there and parallel and back again"},
]
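# interrupt raised two levels down (inside the grandchild graph); CONFIG_KEY_NODE_FINISHED
# records node completion order across all nesting levels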
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_doubly_nested_graph_interrupts(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class State(TypedDict):
my_key: str
class ChildState(TypedDict):
my_key: str
class GrandChildState(TypedDict):
my_key: str
def grandchild_1(state: ChildState):
return {"my_key": state["my_key"] + " here"}
def grandchild_2(state: ChildState):
return {
"my_key": state["my_key"] + " and there",
}
grandchild = StateGraph(GrandChildState)
grandchild.add_node("grandchild_1", grandchild_1)
grandchild.add_node("grandchild_2", grandchild_2)
grandchild.add_edge("grandchild_1", "grandchild_2")
grandchild.set_entry_point("grandchild_1")
grandchild.set_finish_point("grandchild_2")
child = StateGraph(ChildState)
child.add_node(
"child_1",
grandchild.compile(interrupt_before=["grandchild_2"]),
)
child.set_entry_point("child_1")
child.set_finish_point("child_1")
def parent_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def parent_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("parent_1", parent_1)
graph.add_node("child", child.compile())
graph.add_node("parent_2", parent_2)
graph.set_entry_point("parent_1")
graph.add_edge("parent_1", "child")
graph.add_edge("child", "parent_2")
graph.set_finish_point("parent_2")
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert app.invoke({"my_key": "my value"}, config, debug=True) == {
"my_key": "hi my value",
}
assert app.invoke(None, config, debug=True) == {
"my_key": "hi my value here and there and back again",
}
# test stream updates w/ nested interrupt
nodes: list[str] = []
config = {
"configurable": {"thread_id": "2", CONFIG_KEY_NODE_FINISHED: nodes.append}
}
assert [*app.stream({"my_key": "my value"}, config)] == [
{"parent_1": {"my_key": "hi my value"}},
{"__interrupt__": ()},
]
assert nodes == ["parent_1", "grandchild_1"]
assert [*app.stream(None, config)] == [
{"child": {"my_key": "hi my value here and there"}},
{"parent_2": {"my_key": "hi my value here and there and back again"}},
]
assert nodes == [
"parent_1",
"grandchild_1",
"grandchild_2",
"child_1",
"child",
"parent_2",
]
# test stream values w/ nested interrupt
config = {"configurable": {"thread_id": "3"}}
assert [*app.stream({"my_key": "my value"}, config, stream_mode="values")] == [
{"my_key": "my value"},
{"my_key": "hi my value"},
]
assert [*app.stream(None, config, stream_mode="values")] == [
{"my_key": "hi my value"},
{"my_key": "hi my value here and there"},
{"my_key": "hi my value here and there and back again"},
]
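# inspect parent vs. subgraph state: get_state(config) shows only the outer checkpoint
# with a task pointing at the subgraph namespace, while get_state(config, subgraphs=True)
# inlines the inner StateSnapshot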
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_nested_graph_state(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class InnerState(TypedDict):
my_key: str
my_other_key: str
def inner_1(state: InnerState):
return {
"my_key": state["my_key"] + " here",
"my_other_key": state["my_key"],
}
def inner_2(state: InnerState):
return {
"my_key": state["my_key"] + " and there",
"my_other_key": state["my_key"],
}
inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")
class State(TypedDict):
my_key: str
other_parent_key: str
def outer_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def outer_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("outer_1", outer_1)
graph.add_node(
"inner",
inner.compile(interrupt_before=["inner_2"]),
)
graph.add_node("outer_2", outer_2)
graph.set_entry_point("outer_1")
graph.add_edge("outer_1", "inner")
graph.add_edge("inner", "outer_2")
graph.set_finish_point("outer_2")
app = graph.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
app.invoke({"my_key": "my value"}, config, debug=True)
# test state w/ nested subgraph state (right after interrupt)
# first get_state without subgraph state
assert app.get_state(config) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={"configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()}},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# now, get_state with subgraphs=True to include the nested subgraph state
assert app.get_state(config, subgraphs=True) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state=StateSnapshot(
values={
"my_key": "hi my value here",
"my_other_key": "hi my value",
},
tasks=(
PregelTask(
AnyStr(),
"inner_2",
(PULL, "inner_2"),
),
),
next=("inner_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"parents": {
"": AnyStr(),
},
"source": "loop",
"writes": {
"inner_1": {
"my_key": "hi my value here",
"my_other_key": "hi my value",
}
},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
),
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# get_state_history returns outer graph checkpoints
history = list(app.get_state_history(config))
assert history == [
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
}
},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
tasks=(
PregelTask(
AnyStr(),
"outer_1",
(PULL, "outer_1"),
result={"my_key": "hi my value"},
),
),
next=("outer_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={},
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "my value"},
),
),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"my_key": "my value"}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
# get_state_history for a subgraph returns its checkpoints
child_history = [*app.get_state_history(history[0].tasks[0].state)]
assert child_history == [
StateSnapshot(
values={"my_key": "hi my value here", "my_other_key": "hi my value"},
next=("inner_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": {
"inner_1": {
"my_key": "hi my value here",
"my_other_key": "hi my value",
}
},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("inner_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(
PregelTask(
AnyStr(),
"inner_1",
(PULL, "inner_1"),
result={
"my_key": "hi my value here",
"my_other_key": "hi my value",
},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("inner:"),
"langgraph_node": "inner",
"langgraph_path": [PULL, "inner"],
"langgraph_step": 2,
"langgraph_triggers": ["outer_1"],
"langgraph_checkpoint_ns": AnyStr("inner:"),
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# resume
app.invoke(None, config, debug=True)
# test state w/ nested subgraph state (after resuming from interrupt)
assert app.get_state(config) == StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"outer_2": {"my_key": "hi my value here and there and back again"}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# test full history at the end
actual_history = list(app.get_state_history(config))
expected_history = [
StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"outer_2": {"my_key": "hi my value here and there and back again"}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value here and there"},
tasks=(
PregelTask(
AnyStr(),
"outer_2",
(PULL, "outer_2"),
result={"my_key": "hi my value here and there and back again"},
),
),
next=("outer_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"inner": {"my_key": "hi my value here and there"}},
"step": 2,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"inner",
(PULL, "inner"),
state={
"configurable": {"thread_id": "1", "checkpoint_ns": AnyStr()}
},
result={"my_key": "hi my value here and there"},
),
),
next=("inner",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"outer_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
tasks=(
PregelTask(
AnyStr(),
"outer_1",
(PULL, "outer_1"),
result={"my_key": "hi my value"},
),
),
next=("outer_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={},
tasks=(
PregelTask(
AnyStr(),
"__start__",
(PULL, "__start__"),
result={"my_key": "my value"},
),
),
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"my_key": "my value"}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
assert actual_history == expected_history
# test looking up parent state by checkpoint ID
for actual_snapshot, expected_snapshot in zip(actual_history, expected_history):
assert app.get_state(actual_snapshot.config) == expected_snapshot
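# The following test nests a grandchild graph (interrupted before
# "grandchild_2") inside a child graph inside a parent graph, then inspects
# state and history at every level of nesting before and after resuming.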
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_doubly_nested_graph_state(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class State(TypedDict):
my_key: str
class ChildState(TypedDict):
my_key: str
class GrandChildState(TypedDict):
my_key: str
def grandchild_1(state: ChildState):
return {"my_key": state["my_key"] + " here"}
def grandchild_2(state: ChildState):
return {
"my_key": state["my_key"] + " and there",
}
grandchild = StateGraph(GrandChildState)
grandchild.add_node("grandchild_1", grandchild_1)
grandchild.add_node("grandchild_2", grandchild_2)
grandchild.add_edge("grandchild_1", "grandchild_2")
grandchild.set_entry_point("grandchild_1")
grandchild.set_finish_point("grandchild_2")
child = StateGraph(ChildState)
child.add_node(
"child_1",
grandchild.compile(interrupt_before=["grandchild_2"]),
)
child.set_entry_point("child_1")
child.set_finish_point("child_1")
def parent_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def parent_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("parent_1", parent_1)
graph.add_node("child", child.compile())
graph.add_node("parent_2", parent_2)
graph.set_entry_point("parent_1")
graph.add_edge("parent_1", "child")
graph.add_edge("child", "parent_2")
graph.set_finish_point("parent_2")
app = graph.compile(checkpointer=checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert [c for c in app.stream({"my_key": "my value"}, config, subgraphs=True)] == [
((), {"parent_1": {"my_key": "hi my value"}}),
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
# get state without subgraphs
outer_state = app.get_state(config)
assert outer_state == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
child_state = app.get_state(outer_state.tasks[0].state)
assert (
child_state.tasks[0]
== StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
}
},
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
}
},
).tasks[0]
)
grandchild_state = app.get_state(child_state.tasks[0].state)
assert grandchild_state == StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"writes": {"grandchild_1": {"my_key": "hi my value here"}},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [PULL, AnyStr("child_1")],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
)
# get state with subgraphs
assert app.get_state(config, subgraphs=True) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state=StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state=StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(
re.compile(r"child:.+|child1:")
): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"writes": {
"grandchild_1": {"my_key": "hi my value here"}
},
"step": 1,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(
re.compile(r"child:.+|child1:")
): AnyStr(),
}
),
}
},
),
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"writes": None,
"step": 0,
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
),
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
# resume
assert [c for c in app.stream(None, config, subgraphs=True)] == [
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_2": {"my_key": "hi my value here and there"}},
),
((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}),
((), {"child": {"my_key": "hi my value here and there"}}),
((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),
]
# get state with and without subgraphs
assert (
app.get_state(config)
== app.get_state(config, subgraphs=True)
== StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"parent_2": {"my_key": "hi my value here and there and back again"}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
)
# get outer graph history
outer_history = list(app.get_state_history(config))
assert outer_history == [
StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"parent_2": {"my_key": "hi my value here and there and back again"}
},
"step": 3,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=("parent_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"child": {"my_key": "hi my value here and there"}},
"step": 2,
"parents": {},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="parent_2",
path=(PULL, "parent_2"),
result={"my_key": "hi my value here and there and back again"},
),
),
),
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
result={"my_key": "hi my value here and there"},
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {"parent_1": {"my_key": "hi my value"}},
"step": 1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"my_key": "my value"},
next=("parent_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": {},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="parent_1",
path=(PULL, "parent_1"),
result={"my_key": "hi my value"},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "my value"}},
"step": -1,
"parents": {},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=(PULL, "__start__"),
result={"my_key": "my value"},
),
),
),
]
# get child graph history
child_history = list(app.get_state_history(outer_history[2].tasks[0].state))
assert child_history == [
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": {"child_1": {"my_key": "hi my value here and there"}},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="child_1",
path=(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
}
},
result={"my_key": "hi my value here and there"},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child",
"langgraph_path": [PULL, AnyStr("child")],
"langgraph_step": 2,
"langgraph_triggers": [AnyStr("parent_1")],
"langgraph_checkpoint_ns": AnyStr("child:"),
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# get grandchild graph history
grandchild_history = list(app.get_state_history(child_history[1].tasks[0].state))
assert grandchild_history == [
StateSnapshot(
values={"my_key": "hi my value here and there"},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {"grandchild_2": {"my_key": "hi my value here and there"}},
"step": 2,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(),
),
StateSnapshot(
values={"my_key": "hi my value here"},
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {"grandchild_1": {"my_key": "hi my value here"}},
"step": 1,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="grandchild_2",
path=(PULL, "grandchild_2"),
result={"my_key": "hi my value here and there"},
),
),
),
StateSnapshot(
values={"my_key": "hi my value"},
next=("grandchild_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": None,
"step": 0,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="grandchild_1",
path=(PULL, "grandchild_1"),
result={"my_key": "hi my value here"},
),
),
),
StateSnapshot(
values={},
next=("__start__",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "input",
"writes": {"__start__": {"my_key": "hi my value"}},
"step": -1,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"langgraph_checkpoint_ns": AnyStr("child:"),
"langgraph_node": "child_1",
"langgraph_path": [
PULL,
AnyStr("child_1"),
],
"langgraph_step": 1,
"langgraph_triggers": [AnyStr("start:child_1")],
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=(PULL, "__start__"),
result={"my_key": "hi my value"},
),
),
),
]
# replay grandchild checkpoint
assert [
c for c in app.stream(None, grandchild_history[2].config, subgraphs=True)
] == [
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
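# The following test fans out to a compiled subgraph via the Send API: each
# subject in the parent state spawns its own "generate_joke" subgraph run,
# which is interrupted before its "generate" node and then updated/resumed.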
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_send_to_nested_graphs(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
class OverallState(TypedDict):
subjects: list[str]
jokes: Annotated[list[str], operator.add]
def continue_to_jokes(state: OverallState):
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
class JokeState(TypedDict):
subject: str
def edit(state: JokeState):
subject = state["subject"]
return {"subject": f"{subject} - hohoho"}
# subgraph
subgraph = StateGraph(JokeState, output=OverallState)
subgraph.add_node("edit", edit)
subgraph.add_node(
"generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]}
)
subgraph.set_entry_point("edit")
subgraph.add_edge("edit", "generate")
subgraph.set_finish_point("generate")
# parent graph
builder = StateGraph(OverallState)
builder.add_node(
"generate_joke",
subgraph.compile(interrupt_before=["generate"]),
)
builder.add_conditional_edges(START, continue_to_jokes)
builder.add_edge("generate_joke", END)
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
tracer = FakeTracer()
# invoke and pause at nested interrupt
assert graph.invoke(
{"subjects": ["cats", "dogs"]}, config={**config, "callbacks": [tracer]}
) == {
"subjects": ["cats", "dogs"],
"jokes": [],
}
assert len(tracer.runs) == 1, "Should produce exactly 1 root run"
# check state
outer_state = graph.get_state(config)
if not FF_SEND_V2:
# update state of dogs joke graph
graph.update_state(outer_state.tasks[1].state, {"subject": "turtles - hohoho"})
# continue past interrupt
assert sorted(
graph.stream(None, config=config),
key=lambda d: d["generate_joke"]["jokes"][0],
) == [
{"generate_joke": {"jokes": ["Joke about cats - hohoho"]}},
{"generate_joke": {"jokes": ["Joke about turtles - hohoho"]}},
]
return
assert outer_state == StateSnapshot(
values={"subjects": ["cats", "dogs"], "jokes": []},
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result={"subjects": ["cats", "dogs"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
),
),
next=("generate_joke", "generate_joke"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"subjects": ["cats", "dogs"]}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
)
# check state of each of the inner tasks
assert graph.get_state(outer_state.tasks[1].state) == StateSnapshot(
values={"subject": "cats - hohoho", "jokes": []},
next=("generate",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("generate_joke:"): AnyStr(),
}
),
}
},
metadata={
"step": 1,
"source": "loop",
"writes": {"edit": None},
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"langgraph_checkpoint_ns": AnyStr("generate_joke:"),
"langgraph_node": "generate_joke",
"langgraph_path": [PUSH, ["__pregel_pull", "__start__"], 1, AnyStr()],
"langgraph_step": 0,
"langgraph_triggers": [PUSH],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("generate_joke:"): AnyStr(),
}
),
}
},
tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),),
)
assert graph.get_state(outer_state.tasks[2].state) == StateSnapshot(
values={"subject": "dogs - hohoho", "jokes": []},
next=("generate",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("generate_joke:"): AnyStr(),
}
),
}
},
metadata={
"step": 1,
"source": "loop",
"writes": {"edit": None},
"parents": {"": AnyStr()},
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"langgraph_checkpoint_ns": AnyStr("generate_joke:"),
"langgraph_node": "generate_joke",
"langgraph_path": [PUSH, ["__pregel_pull", "__start__"], 2, AnyStr()],
"langgraph_step": 0,
"langgraph_triggers": [PUSH],
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("generate_joke:"): AnyStr(),
}
),
}
},
tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),),
)
# update state of dogs joke graph
graph.update_state(
outer_state.tasks[2 if FF_SEND_V2 else 1].state, {"subject": "turtles - hohoho"}
)
# continue past interrupt
assert sorted(
graph.stream(None, config=config), key=lambda d: d["generate_joke"]["jokes"][0]
) == [
{"generate_joke": {"jokes": ["Joke about cats - hohoho"]}},
{"generate_joke": {"jokes": ["Joke about turtles - hohoho"]}},
]
actual_snapshot = graph.get_state(config)
expected_snapshot = StateSnapshot(
values={
"subjects": ["cats", "dogs"],
"jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"generate_joke": [
{"jokes": ["Joke about cats - hohoho"]},
{"jokes": ["Joke about turtles - hohoho"]},
]
},
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
)
assert actual_snapshot == expected_snapshot
# test full history
actual_history = list(graph.get_state_history(config))
# get subgraph node state for expected history
expected_history = [
StateSnapshot(
values={
"subjects": ["cats", "dogs"],
"jokes": ["Joke about cats - hohoho", "Joke about turtles - hohoho"],
},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"writes": {
"generate_joke": [
{"jokes": ["Joke about cats - hohoho"]},
{"jokes": ["Joke about turtles - hohoho"]},
]
},
"step": 0,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
),
StateSnapshot(
values={"jokes": []},
tasks=(
PregelTask(
id=AnyStr(),
name="__start__",
path=("__pregel_pull", "__start__"),
error=None,
interrupts=(),
state=None,
result={"subjects": ["cats", "dogs"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 1, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
result={"jokes": ["Joke about cats - hohoho"]},
),
PregelTask(
AnyStr(),
"generate_joke",
(PUSH, ("__pregel_pull", "__start__"), 2, AnyStr()),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("generate_joke:"),
}
},
result={"jokes": ["Joke about turtles - hohoho"]},
),
),
next=("__start__", "generate_joke", "generate_joke"),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "input",
"writes": {"__start__": {"subjects": ["cats", "dogs"]}},
"step": -1,
"thread_id": "1",
},
created_at=AnyStr(),
parent_config=None,
),
]
assert actual_history == expected_history
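# The following test routes between a plain LLM node and a weather subgraph,
# streams custom writer output from both the parent graph and the subgraph,
# interrupts before the subgraph's "weather_node", and updates subgraph state
# both directly and with as_node="weather_node" before resuming.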
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_weather_subgraph(
request: pytest.FixtureRequest, checkpointer_name: str, snapshot: SnapshotAssertion
) -> None:
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.graph import MessagesState
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
# setup subgraph
@tool
def get_weather(city: str):
"""Get the weather for a specific city"""
return f"I'ts sunny in {city}!"
weather_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="get_weather",
args={"city": "San Francisco"},
)
],
)
]
)
class SubGraphState(MessagesState):
city: str
def model_node(state: SubGraphState, writer: StreamWriter):
writer(" very")
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}
def weather_node(state: SubGraphState, writer: StreamWriter):
writer(" good")
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}
subgraph = StateGraph(SubGraphState)
subgraph.add_node(model_node)
subgraph.add_node(weather_node)
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)
subgraph = subgraph.compile(interrupt_before=["weather_node"])
# setup main graph
class RouterState(MessagesState):
route: Literal["weather", "other"]
router_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="router",
args={"dest": "weather"},
)
],
)
]
)
def router_node(state: RouterState, writer: StreamWriter):
writer("I'm")
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]}
def normal_llm_node(state: RouterState):
return {"messages": [AIMessage("Hello!")]}
def route_after_prediction(state: RouterState):
if state["route"] == "weather":
return "weather_graph"
else:
return "normal_llm_node"
def weather_graph(state: RouterState):
return subgraph.invoke(state)
graph = StateGraph(RouterState)
graph.add_node(router_node)
graph.add_node(normal_llm_node)
graph.add_node("weather_graph", weather_graph)
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("normal_llm_node", END)
graph.add_edge("weather_graph", END)
graph = graph.compile(checkpointer=checkpointer)
assert graph.get_graph(xray=1).draw_mermaid() == snapshot
config = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
# run with custom output
assert [c for c in graph.stream(inputs, thread2, stream_mode="custom")] == [
"I'm",
" very",
]
assert [c for c in graph.stream(None, thread2, stream_mode="custom")] == [
" good",
]
# run until interrupt
assert [
c
for c in graph.stream(
inputs, config=config, stream_mode="updates", subgraphs=True
)
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]
# check current state
state = graph.get_state(config)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "1",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("weather_graph:"),
}
},
),
),
)
# update the subgraph state (set the city) before resuming
graph.update_state(state.tasks[0].state, {"city": "la"})
# run after update
assert [
c
for c in graph.stream(
None, config=config, stream_mode="updates", subgraphs=True
)
] == [
(
(AnyStr("weather_graph:"),),
{
"weather_node": {
"messages": [{"role": "assistant", "content": "I'ts sunny in la!"}]
}
},
),
(
(),
{
"weather_graph": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="I'ts sunny in la!"),
]
}
},
),
]
# try updating the subgraph state while acting as the weather node
config = {"configurable": {"thread_id": "14"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
assert [
c
for c in graph.stream(
inputs, config=config, stream_mode="updates", subgraphs=True
)
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]
state = graph.get_state(config, subgraphs=True)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "14",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state=StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf")
],
"city": "San Francisco",
},
next=("weather_node",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"writes": {"model_node": {"city": "San Francisco"}},
"step": 1,
"parents": {"": AnyStr()},
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"langgraph_node": "weather_graph",
"langgraph_path": [PULL, "weather_graph"],
"langgraph_step": 2,
"langgraph_triggers": [
"branch:router_node:route_after_prediction:weather_graph"
],
"langgraph_checkpoint_ns": AnyStr("weather_graph:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_node",
path=(PULL, "weather_node"),
),
),
),
),
),
)
graph.update_state(
state.tasks[0].state.config,
{"messages": [{"role": "assistant", "content": "rainy"}]},
as_node="weather_node",
)
state = graph.get_state(config, subgraphs=True)
assert state == StateSnapshot(
values={
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
},
next=("weather_graph",),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {"router_node": {"route": "weather"}},
"step": 1,
"parents": {},
"thread_id": "14",
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(
PregelTask(
id=AnyStr(),
name="weather_graph",
path=(PULL, "weather_graph"),
state=StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="rainy"),
],
"city": "San Francisco",
},
next=(),
config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
metadata={
"step": 2,
"source": "update",
"writes": {
"weather_node": {
"messages": [{"role": "assistant", "content": "rainy"}]
}
},
"parents": {"": AnyStr()},
"thread_id": "14",
"checkpoint_id": AnyStr(),
"checkpoint_ns": AnyStr("weather_graph:"),
"langgraph_node": "weather_graph",
"langgraph_path": [PULL, "weather_graph"],
"langgraph_step": 2,
"langgraph_triggers": [
"branch:router_node:route_after_prediction:weather_graph"
],
"langgraph_checkpoint_ns": AnyStr("weather_graph:"),
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "14",
"checkpoint_ns": AnyStr("weather_graph:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("weather_graph:"): AnyStr(),
}
),
}
},
tasks=(),
),
),
),
)
assert [
c
for c in graph.stream(
None, config=config, stream_mode="updates", subgraphs=True
)
] == [
(
(),
{
"weather_graph": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="rainy"),
]
}
},
),
]
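# The following test only checks graph rendering: conditional edges that can
# route a node back to itself ("redo") should still draw correctly in mermaid.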
def test_repeat_condition(snapshot: SnapshotAssertion) -> None:
class AgentState(TypedDict):
hello: str
def router(state: AgentState) -> str:
return "hmm"
workflow = StateGraph(AgentState)
workflow.add_node("Researcher", lambda x: x)
workflow.add_node("Chart Generator", lambda x: x)
workflow.add_node("Call Tool", lambda x: x)
workflow.add_conditional_edges(
"Researcher",
router,
{
"redo": "Researcher",
"continue": "Chart Generator",
"call_tool": "Call Tool",
"end": END,
},
)
workflow.add_conditional_edges(
"Chart Generator",
router,
{"continue": "Researcher", "call_tool": "Call Tool", "end": END},
)
workflow.add_conditional_edges(
"Call Tool",
# Each agent node updates the 'sender' field;
# the tool-calling node does not, so this
# edge routes back to the original agent
# that invoked the tool.
lambda x: x["sender"],
{
"Researcher": "Researcher",
"Chart Generator": "Chart Generator",
},
)
workflow.set_entry_point("Researcher")
app = workflow.compile()
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
def test_checkpoint_metadata() -> None:
"""This test verifies that a run's configurable fields are merged with the
previous checkpoint config for each step in the run.
"""
# set up test
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
# graph state
class BaseState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
# initialize graph nodes
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
return f"result for {query}"
tools = [search_api]
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a nice assistant."),
("placeholder", "{messages}"),
]
)
model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": "tool_call123",
"name": "search_api",
"args": {"query": "query"},
},
],
),
AIMessage(content="answer"),
]
)
@traceable(run_type="llm")
def agent(state: BaseState) -> BaseState:
formatted = prompt.invoke(state)
response = model.invoke(formatted)
return {"messages": response, "usage_metadata": {"total_tokens": 123}}
def should_continue(data: BaseState) -> str:
# Logic to decide whether to continue in the loop or exit
if not data["messages"][-1].tool_calls:
return "exit"
else:
return "continue"
# define graphs w/ and w/o interrupt
workflow = StateGraph(BaseState)
workflow.add_node("agent", agent)
workflow.add_node("tools", ToolNode(tools))
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent", should_continue, {"continue": "tools", "exit": END}
)
workflow.add_edge("tools", "agent")
# graph w/o interrupt
checkpointer_1 = MemorySaverAssertCheckpointMetadata()
app = workflow.compile(checkpointer=checkpointer_1)
# graph w/ interrupt
checkpointer_2 = MemorySaverAssertCheckpointMetadata()
app_w_interrupt = workflow.compile(
checkpointer=checkpointer_2, interrupt_before=["tools"]
)
# assertions
# invoke graph w/o interrupt
assert app.invoke(
{"messages": ["what is weather in sf"]},
{
"configurable": {
"thread_id": "1",
"test_config_1": "foo",
"test_config_2": "bar",
},
},
) == {
"messages": [
_AnyIdHumanMessage(content="what is weather in sf"),
_AnyIdAIMessage(
content="",
tool_calls=[
{
"name": "search_api",
"args": {"query": "query"},
"id": "tool_call123",
"type": "tool_call",
}
],
),
_AnyIdToolMessage(
content="result for query",
name="search_api",
tool_call_id="tool_call123",
),
_AnyIdAIMessage(content="answer"),
]
}
config = {"configurable": {"thread_id": "1"}}
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_1 = checkpointer_1.get_tuple(config).metadata
assert chkpnt_metadata_1["thread_id"] == "1"
assert chkpnt_metadata_1["test_config_1"] == "foo"
assert chkpnt_metadata_1["test_config_2"] == "bar"
# Verify that every checkpoint's metadata contains the run's configurable
# fields. This check is needed because a run may have an arbitrary number
# of steps depending on how the graph is constructed.
chkpnt_tuples_1 = checkpointer_1.list(config)
for chkpnt_tuple in chkpnt_tuples_1:
assert chkpnt_tuple.metadata["thread_id"] == "1"
assert chkpnt_tuple.metadata["test_config_1"] == "foo"
assert chkpnt_tuple.metadata["test_config_2"] == "bar"
# invoke graph, but interrupt before tool call
app_w_interrupt.invoke(
{"messages": ["what is weather in sf"]},
{
"configurable": {
"thread_id": "2",
"test_config_3": "foo",
"test_config_4": "bar",
},
},
)
config = {"configurable": {"thread_id": "2"}}
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_2 = checkpointer_2.get_tuple(config).metadata
assert chkpnt_metadata_2["thread_id"] == "2"
assert chkpnt_metadata_2["test_config_3"] == "foo"
assert chkpnt_metadata_2["test_config_4"] == "bar"
# resume graph execution
app_w_interrupt.invoke(
input=None,
config={
"configurable": {
"thread_id": "2",
"test_config_3": "foo",
"test_config_4": "bar",
}
},
)
# assert that checkpoint metadata contains the run's configurable fields
chkpnt_metadata_3 = checkpointer_2.get_tuple(config).metadata
assert chkpnt_metadata_3["thread_id"] == "2"
assert chkpnt_metadata_3["test_config_3"] == "foo"
assert chkpnt_metadata_3["test_config_4"] == "bar"
# Verify that every checkpoint's metadata contains the run's configurable
# fields. This check is needed because a run may have an arbitrary number
# of steps depending on how the graph is constructed.
chkpnt_tuples_2 = checkpointer_2.list(config)
for chkpnt_tuple in chkpnt_tuples_2:
assert chkpnt_tuple.metadata["thread_id"] == "2"
assert chkpnt_tuple.metadata["test_config_3"] == "foo"
assert chkpnt_tuple.metadata["test_config_4"] == "bar"
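# The next two tests delete messages: first via a RemoveMessage passed to
# update_state, then via a RemoveMessage returned from a graph node.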
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_remove_message_via_state_update(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
workflow = MessageGraph()
workflow.add_node(
"chatbot",
lambda state: [
AIMessage(
content="Hello! How can I help you",
)
],
)
workflow.set_entry_point("chatbot")
workflow.add_edge("chatbot", END)
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)
app = workflow.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
output = app.invoke([HumanMessage(content="Hi")], config=config)
app.update_state(config, values=[RemoveMessage(id=output[-1].id)])
updated_state = app.get_state(config)
assert len(updated_state.values) == 1
assert updated_state.values[-1].content == "Hi"
def test_remove_message_from_node():
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
workflow = MessageGraph()
workflow.add_node(
"chatbot",
lambda state: [
AIMessage(
content="Hello!",
),
AIMessage(
content="How can I help you?",
),
],
)
workflow.add_node("delete_messages", lambda state: [RemoveMessage(id=state[-2].id)])
workflow.set_entry_point("chatbot")
workflow.add_edge("chatbot", "delete_messages")
workflow.add_edge("delete_messages", END)
app = workflow.compile()
output = app.invoke([HumanMessage(content="Hi")])
assert len(output) == 2
assert output[-1].content == "How can I help you?"
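# The following test builds a research graph whose "conduct_interview" node is
# itself a compiled interview subgraph fanned out via Send, and snapshots the
# graph JSON with and without xray expansion.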
def test_xray_lance(snapshot: SnapshotAssertion):
from langchain_core.messages import AnyMessage, HumanMessage
from pydantic import BaseModel, Field
class Analyst(BaseModel):
affiliation: str = Field(
description="Primary affiliation of the investment analyst.",
)
name: str = Field(
description="Name of the investment analyst.",
pattern=r"^[a-zA-Z0-9_-]{1,64}$",
)
role: str = Field(
description="Role of the investment analyst in the context of the topic.",
)
description: str = Field(
description="Description of the investment analyst focus, concerns, and motives.",
)
@property
def persona(self) -> str:
return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n"
class Perspectives(BaseModel):
analysts: List[Analyst] = Field(
description="Comprehensive list of investment analysts with their roles and affiliations.",
)
class Section(BaseModel):
section_title: str = Field(..., title="Title of the section")
context: str = Field(
..., title="Provide a clear summary of the focus area that you researched."
)
findings: str = Field(
...,
title="Give a clear and detailed overview of your findings based upon the expert interview.",
)
thesis: str = Field(
...,
title="Give a clear and specific investment thesis based upon these findings.",
)
class InterviewState(TypedDict):
messages: Annotated[List[AnyMessage], add_messages]
analyst: Analyst
section: Section
class ResearchGraphState(TypedDict):
analysts: List[Analyst]
topic: str
max_analysts: int
sections: List[Section]
interviews: Annotated[list, operator.add]
# Conditional edge
def route_messages(state):
return "ask_question"
def generate_question(state):
return ...
def generate_answer(state):
return ...
# Add nodes and edges
interview_builder = StateGraph(InterviewState)
interview_builder.add_node("ask_question", generate_question)
interview_builder.add_node("answer_question", generate_answer)
# Flow
interview_builder.add_edge(START, "ask_question")
interview_builder.add_edge("ask_question", "answer_question")
interview_builder.add_conditional_edges("answer_question", route_messages)
# Set up memory
memory = MemorySaver()
# Interview
interview_graph = interview_builder.compile(checkpointer=memory).with_config(
run_name="Conduct Interviews"
)
# View
assert interview_graph.get_graph().to_json() == snapshot
def run_all_interviews(state: ResearchGraphState):
"""Edge to run the interview sub-graph using Send"""
return [
Send(
"conduct_interview",
{
"analyst": Analyst(),
"messages": [
HumanMessage(
content="So you said you were writing an article on ...?"
)
],
},
)
for s in state["analysts"]
]
def generate_sections(state: ResearchGraphState):
return ...
def generate_analysts(state: ResearchGraphState):
return ...
builder = StateGraph(ResearchGraphState)
builder.add_node("generate_analysts", generate_analysts)
builder.add_node("conduct_interview", interview_builder.compile())
builder.add_node("generate_sections", generate_sections)
builder.add_edge(START, "generate_analysts")
builder.add_conditional_edges(
"generate_analysts", run_all_interviews, ["conduct_interview"]
)
builder.add_edge("conduct_interview", "generate_sections")
builder.add_edge("generate_sections", END)
graph = builder.compile()
# View
assert graph.get_graph().to_json() == snapshot
assert graph.get_graph(xray=1).to_json() == snapshot
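# The following test checks that an EphemeralValue channel is accepted as
# input but is not persisted into the saved checkpoint's channel_values.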
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_channel_values(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
config = {"configurable": {"thread_id": "1"}}
chain = Channel.subscribe_to("input") | Channel.write_to("output")
app = Pregel(
nodes={
"one": chain,
},
channels={
"ephemeral": EphemeralValue(Any),
"input": LastValue(int),
"output": LastValue(int),
},
input_channels=["input", "ephemeral"],
output_channels="output",
checkpointer=checkpointer,
)
app.invoke({"input": 1, "ephemeral": "meow"}, config)
assert checkpointer.get(config)["channel_values"] == {"input": 1, "output": 1}
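# The two tests below snapshot mermaid output for nested graphs rendered with
# xray=True (one level of nesting, then two).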
def test_xray_issue(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
messages: Annotated[list, add_messages]
def node(name):
def _node(state: State):
return {"messages": [("human", f"entered {name} node")]}
return _node
parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge("__start__", "c_one")
child.add_edge("c_two", "c_one")
child.add_conditional_edges(
"c_one", lambda x: str(randrange(0, 2)), {"0": "c_two", "1": "__end__"}
)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge("__start__", "p_one")
parent.add_edge("p_two", "p_one")
parent.add_conditional_edges(
"p_one", lambda x: str(randrange(0, 2)), {"0": "p_two", "1": "__end__"}
)
app = parent.compile()
assert app.get_graph(xray=True).draw_mermaid() == snapshot
def test_xray_bool(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
messages: Annotated[list, add_messages]
def node(name):
def _node(state: State):
return {"messages": [("human", f"entered {name} node")]}
return _node
grand_parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge("__start__", "c_one")
child.add_edge("c_two", "c_one")
child.add_conditional_edges(
"c_one", lambda x: str(randrange(0, 2)), {"0": "c_two", "1": "__end__"}
)
parent = StateGraph(State)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge("__start__", "p_one")
parent.add_edge("p_two", "p_one")
parent.add_conditional_edges(
"p_one", lambda x: str(randrange(0, 2)), {"0": "p_two", "1": "__end__"}
)
grand_parent.add_node("gp_one", node("gp_one"))
grand_parent.add_node("gp_two", parent.compile())
grand_parent.add_edge("__start__", "gp_one")
grand_parent.add_edge("gp_two", "gp_one")
grand_parent.add_conditional_edges(
"gp_one", lambda x: str(randrange(0, 2)), {"0": "gp_two", "1": "__end__"}
)
app = grand_parent.compile()
assert app.get_graph(xray=True).draw_mermaid() == snapshot
def test_multiple_sinks_subgraphs(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
messages: Annotated[list, add_messages]
subgraph_builder = StateGraph(State)
subgraph_builder.add_node("one", lambda x: x)
subgraph_builder.add_node("two", lambda x: x)
subgraph_builder.add_node("three", lambda x: x)
subgraph_builder.add_edge("__start__", "one")
subgraph_builder.add_conditional_edges("one", lambda x: "two", ["two", "three"])
subgraph = subgraph_builder.compile()
builder = StateGraph(State)
builder.add_node("uno", lambda x: x)
builder.add_node("dos", lambda x: x)
builder.add_node("subgraph", subgraph)
builder.add_edge("__start__", "uno")
builder.add_conditional_edges("uno", lambda x: "dos", ["dos", "subgraph"])
app = builder.compile()
assert app.get_graph(xray=True).draw_mermaid() == snapshot
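# The following test attaches a RetryPolicy to a compiled-subgraph node and
# checks that retries re-run only the failing child node rather than
# re-computing the whole subgraph (child_node_a asserts it runs exactly once).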
def test_subgraph_retries():
class State(TypedDict):
count: int
class ChildState(State):
some_list: Annotated[list, operator.add]
called_times = 0
class RandomError(ValueError):
"""This will be retried on."""
def parent_node(state: State):
return {"count": state["count"] + 1}
def child_node_a(state: ChildState):
nonlocal called_times
# The retry should re-run only child_node_b,
# NOT re-compute the whole subgraph.
assert not called_times
called_times += 1
return {"some_list": ["val"]}
def child_node_b(state: ChildState):
raise RandomError("First attempt fails")
child = StateGraph(ChildState)
child.add_node(child_node_a)
child.add_node(child_node_b)
child.add_edge("__start__", "child_node_a")
child.add_edge("child_node_a", "child_node_b")
parent = StateGraph(State)
parent.add_node("parent_node", parent_node)
parent.add_node(
"child_graph",
child.compile(),
retry=RetryPolicy(
max_attempts=3,
retry_on=(RandomError,),
backoff_factor=0.0001,
initial_interval=0.0001,
),
)
parent.add_edge("parent_node", "child_graph")
parent.set_entry_point("parent_node")
checkpointer = MemorySaver()
app = parent.compile(checkpointer=checkpointer)
with pytest.raises(RandomError):
app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}})
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
@pytest.mark.parametrize("store_name", ALL_STORES_SYNC)
def test_store_injected(
request: pytest.FixtureRequest, checkpointer_name: str, store_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
the_store = request.getfixturevalue(f"store_{store_name}")
class State(TypedDict):
count: Annotated[int, operator.add]
doc_id = str(uuid.uuid4())
doc = {"some-key": "this-is-a-val"}
uid = uuid.uuid4().hex
namespace = (f"foo-{uid}", "bar")
thread_1 = str(uuid.uuid4())
thread_2 = str(uuid.uuid4())
class Node:
def __init__(self, i: Optional[int] = None):
self.i = i
def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
doc_id,
{
**doc,
"from_thread": config["configurable"]["thread_id"],
"some_val": inputs["count"],
},
)
return {"count": 1}
builder = StateGraph(State)
builder.add_node("node", Node())
builder.add_edge("__start__", "node")
N = 500
M = 1
if "duckdb" in store_name:
logger.warning(
"DuckDB store implementation has a known issue that does not"
" support concurrent writes, so we're reducing the test scope"
)
N = M = 1
for i in range(N):
builder.add_node(f"node_{i}", Node(i))
builder.add_edge("__start__", f"node_{i}")
graph = builder.compile(store=the_store, checkpointer=checkpointer)
results = graph.batch(
[{"count": 0}] * M,
([{"configurable": {"thread_id": str(uuid.uuid4())}}] * (M - 1))
+ [{"configurable": {"thread_id": thread_1}}],
)
result = results[-1]
assert result == {"count": N + 1}
returned_doc = the_store.get(namespace, doc_id).value
assert returned_doc == {**doc, "from_thread": thread_1, "some_val": 0}
assert len(the_store.search(namespace)) == 1
# Check results after another turn of the same thread
result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_1}})
assert result == {"count": (N + 1) * 2}
returned_doc = the_store.get(namespace, doc_id).value
assert returned_doc == {**doc, "from_thread": thread_1, "some_val": N + 1}
assert len(the_store.search(namespace)) == 1
result = graph.invoke({"count": 0}, {"configurable": {"thread_id": thread_2}})
assert result == {"count": N + 1}
returned_doc = the_store.get(namespace, doc_id).value
assert returned_doc == {
**doc,
"from_thread": thread_2,
"some_val": 0,
} # Overwrites the whole doc
assert len(the_store.search(namespace)) == 1 # still overwriting the same one
def test_enum_node_names():
class NodeName(str, enum.Enum):
BAZ = "baz"
class State(TypedDict):
foo: str
bar: str
def baz(state: State):
return {"bar": state["foo"] + "!"}
graph = StateGraph(State)
graph.add_node(NodeName.BAZ, baz)
graph.add_edge(START, NodeName.BAZ)
graph.add_edge(NodeName.BAZ, END)
graph = graph.compile()
assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"}
def test_debug_retry():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
builder = StateGraph(State)
builder.add_node("one", node("one"))
builder.add_node("two", node("two"))
builder.add_edge(START, "one")
builder.add_edge("one", "two")
builder.add_edge("two", END)
saver = MemorySaver()
graph = builder.compile(checkpointer=saver)
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": []}, config=config)
# re-run step: 1
target_config = next(
c.parent_config for c in saver.list(config) if c.metadata["step"] == 1
)
update_config = graph.update_state(target_config, values=None)
events = [*graph.stream(None, config=update_config, stream_mode="debug")]
checkpoint_events = list(
reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
)
checkpoint_history = {
c.config["configurable"]["checkpoint_id"]: c
for c in graph.get_state_history(config)
}
def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]
for stream in checkpoint_events:
stream_conf = lax_normalize_config(stream["config"])
stream_parent_conf = lax_normalize_config(stream["parent_config"])
assert stream_conf != stream_parent_conf
# ensure the streamed checkpoint == checkpoint from checkpointer.list()
history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]]
history_conf = lax_normalize_config(history.config)
assert stream_conf == history_conf
history_parent_conf = lax_normalize_config(history.parent_config)
assert stream_parent_conf == history_parent_conf
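# The following test compares debug-mode checkpoint events against
# get_state_history for a parent graph containing one compiled subgraph.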
def test_debug_subgraphs():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge(START, "c_one")
child.add_edge("c_one", "c_two")
child.add_edge("c_two", END)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge(START, "p_one")
parent.add_edge("p_one", "p_two")
parent.add_edge("p_two", END)
graph = parent.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "1"}}
events = [
*graph.stream(
{"messages": []},
config=config,
stream_mode="debug",
)
]
checkpoint_events = list(
reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
)
checkpoint_history = list(graph.get_state_history(config))
assert len(checkpoint_events) == len(checkpoint_history)
def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]
for stream, history in zip(checkpoint_events, checkpoint_history):
assert stream["values"] == history.values
assert stream["next"] == list(history.next)
assert lax_normalize_config(stream["config"]) == lax_normalize_config(
history.config
)
assert lax_normalize_config(stream["parent_config"]) == lax_normalize_config(
history.parent_config
)
assert len(stream["tasks"]) == len(history.tasks)
for stream_task, history_task in zip(stream["tasks"], history.tasks):
assert stream_task["id"] == history_task.id
assert stream_task["name"] == history_task.name
assert stream_task["interrupts"] == history_task.interrupts
assert stream_task.get("error") == history_task.error
assert stream_task.get("state") == history_task.state
def test_debug_nested_subgraphs():
from collections import defaultdict
class State(TypedDict):
messages: Annotated[list[str], operator.add]
def node(name):
def _node(state: State):
return {"messages": [f"entered {name} node"]}
return _node
grand_parent = StateGraph(State)
parent = StateGraph(State)
child = StateGraph(State)
child.add_node("c_one", node("c_one"))
child.add_node("c_two", node("c_two"))
child.add_edge(START, "c_one")
child.add_edge("c_one", "c_two")
child.add_edge("c_two", END)
parent.add_node("p_one", node("p_one"))
parent.add_node("p_two", child.compile())
parent.add_edge(START, "p_one")
parent.add_edge("p_one", "p_two")
parent.add_edge("p_two", END)
grand_parent.add_node("gp_one", node("gp_one"))
grand_parent.add_node("gp_two", parent.compile())
grand_parent.add_edge(START, "gp_one")
grand_parent.add_edge("gp_one", "gp_two")
grand_parent.add_edge("gp_two", END)
graph = grand_parent.compile(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "1"}}
events = [
*graph.stream(
{"messages": []},
config=config,
stream_mode="debug",
subgraphs=True,
)
]
stream_ns: dict[tuple, dict] = defaultdict(list)
for ns, e in events:
if e["type"] == "checkpoint":
stream_ns[ns].append(e["payload"])
assert list(stream_ns.keys()) == [
(),
(AnyStr("gp_two:"),),
(AnyStr("gp_two:"), AnyStr("p_two:")),
]
history_ns = {
ns: list(
graph.get_state_history(
{"configurable": {"thread_id": "1", "checkpoint_ns": "|".join(ns)}}
)
)[::-1]
for ns in stream_ns.keys()
}
def normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
clean_config = {}
clean_config["thread_id"] = config["configurable"]["thread_id"]
clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"]
clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"]
if "checkpoint_map" in config["configurable"]:
clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"]
return clean_config
for checkpoint_events, checkpoint_history in zip(
stream_ns.values(), history_ns.values()
):
for stream, history in zip(checkpoint_events, checkpoint_history):
assert stream["values"] == history.values
assert stream["next"] == list(history.next)
assert normalize_config(stream["config"]) == normalize_config(
history.config
)
assert normalize_config(stream["parent_config"]) == normalize_config(
history.parent_config
)
assert len(stream["tasks"]) == len(history.tasks)
for stream_task, history_task in zip(stream["tasks"], history.tasks):
assert stream_task["id"] == history_task.id
assert stream_task["name"] == history_task.name
assert stream_task["interrupts"] == history_task.interrupts
assert stream_task.get("error") == history_task.error
assert stream_task.get("state") == history_task.state
def test_add_sequence():
class State(TypedDict):
foo: Annotated[list[str], operator.add]
bar: str
def step1(state: State):
return {"foo": ["step1"], "bar": "baz"}
def step2(state: State):
return {"foo": ["step2"]}
# test raising if less than 1 steps
with pytest.raises(ValueError):
StateGraph(State).add_sequence([])
# test raising if duplicate step names
with pytest.raises(ValueError):
StateGraph(State).add_sequence([step1, step1])
with pytest.raises(ValueError):
StateGraph(State).add_sequence([("foo", step1), ("foo", step1)])
# test unnamed steps
builder = StateGraph(State)
builder.add_sequence([step1, step2])
builder.add_edge(START, "step1")
graph = builder.compile()
result = graph.invoke({"foo": []})
assert result == {"foo": ["step1", "step2"], "bar": "baz"}
stream_chunks = list(graph.stream({"foo": []}))
assert stream_chunks == [
{"step1": {"foo": ["step1"], "bar": "baz"}},
{"step2": {"foo": ["step2"]}},
]
# test named steps
builder_named_steps = StateGraph(State)
builder_named_steps.add_sequence([("meow1", step1), ("meow2", step2)])
builder_named_steps.add_edge(START, "meow1")
graph_named_steps = builder_named_steps.compile()
result = graph_named_steps.invoke({"foo": []})
stream_chunks = list(graph_named_steps.stream({"foo": []}))
assert result == {"foo": ["step1", "step2"], "bar": "baz"}
assert stream_chunks == [
{"meow1": {"foo": ["step1"], "bar": "baz"}},
{"meow2": {"foo": ["step2"]}},
]
builder_named_steps = StateGraph(State)
builder_named_steps.add_sequence(
[
("meow1", lambda state: {"foo": ["foo"]}),
("meow2", lambda state: {"bar": state["foo"][0] + "bar"}),
],
)
builder_named_steps.add_edge(START, "meow1")
graph_named_steps = builder_named_steps.compile()
result = graph_named_steps.invoke({"foo": []})
stream_chunks = list(graph_named_steps.stream({"foo": []}))
# filtered by output schema
assert result == {"bar": "foobar", "foo": ["foo"]}
assert stream_chunks == [
{"meow1": {"foo": ["foo"]}},
{"meow2": {"bar": "foobar"}},
]
# test two sequences
def a(state: State):
return {"foo": ["a"]}
def b(state: State):
return {"foo": ["b"]}
builder_two_sequences = StateGraph(State)
builder_two_sequences.add_sequence([a])
builder_two_sequences.add_sequence([b])
builder_two_sequences.add_edge(START, "a")
builder_two_sequences.add_edge("a", "b")
graph_two_sequences = builder_two_sequences.compile()
result = graph_two_sequences.invoke({"foo": []})
assert result == {"foo": ["a", "b"]}
stream_chunks = list(graph_two_sequences.stream({"foo": []}))
assert stream_chunks == [
{"a": {"foo": ["a"]}},
{"b": {"foo": ["b"]}},
]
# test mixed nodes and sequences
def c(state: State):
return {"foo": ["c"]}
def d(state: State):
return {"foo": ["d"]}
def e(state: State):
return {"foo": ["e"]}
def foo(state: State):
if state["foo"][0] == "a":
return "d"
else:
return "c"
builder_complex = StateGraph(State)
builder_complex.add_sequence([a, b])
builder_complex.add_conditional_edges("b", foo)
builder_complex.add_node(c)
builder_complex.add_sequence([d, e])
builder_complex.add_edge(START, "a")
graph_complex = builder_complex.compile()
result = graph_complex.invoke({"foo": []})
assert result == {"foo": ["a", "b", "d", "e"]}
result = graph_complex.invoke({"foo": ["start"]})
assert result == {"foo": ["start", "a", "b", "c"]}
stream_chunks = list(graph_complex.stream({"foo": []}))
assert stream_chunks == [
{"a": {"foo": ["a"]}},
{"b": {"foo": ["b"]}},
{"d": {"foo": ["d"]}},
{"e": {"foo": ["e"]}},
]
def test_runnable_passthrough_node_graph() -> None:
class State(TypedDict):
changeme: str
async def dummy(state):
return state
agent = dummy | RunnablePassthrough.assign(prediction=RunnableLambda(lambda x: x))
graph_builder = StateGraph(State)
graph_builder.add_node("agent", agent)
graph_builder.add_edge(START, "agent")
graph = graph_builder.compile()
assert graph.get_graph(xray=True).to_json() == graph.get_graph(xray=False).to_json()
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
@tool(return_direct=True)
def get_user_name() -> Command:
"""Retrieve user name"""
return Command(update={"user_name": "Meow"}, graph=Command.PARENT)
subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
subgraph_builder.add_edge(START, "tool")
subgraph = subgraph_builder.compile()
class CustomParentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
# this key is not available to the child graph
user_name: str
builder = StateGraph(CustomParentState)
builder.add_node("alice", subgraph)
builder.add_edge(START, "alice")
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
graph = builder.compile(checkpointer=checkpointer)
config = {"configurable": {"thread_id": "1"}}
assert graph.invoke({"messages": [("user", "get user name")]}, config) == {
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
}
assert graph.get_state(config) == StateSnapshot(
values={
"messages": [
_AnyIdHumanMessage(
content="get user name", additional_kwargs={}, response_metadata={}
),
],
"user_name": "Meow",
},
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"source": "loop",
"writes": {
"alice": {
"user_name": "Meow",
}
},
"thread_id": "1",
"step": 1,
"parents": {},
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
tasks=(),
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str):
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
baz: str
def foo(state):
return {"baz": "foo"}
def bar(state):
value = interrupt("Please provide baz value:")
return {"baz": value}
child_builder = StateGraph(State)
child_builder.add_node(bar)
child_builder.add_edge(START, "bar")
builder = StateGraph(State)
builder.add_node(foo)
builder.add_node("bar", child_builder.compile())
builder.add_edge(START, "foo")
builder.add_edge("foo", "bar")
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
# First run, interrupted at bar
assert graph.invoke({"baz": ""}, thread1)
# Resume with answer
assert graph.invoke(Command(resume="bar"), thread1)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interrupt_multiple(request: pytest.FixtureRequest, checkpointer_name: str):
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
my_key: Annotated[str, operator.add]
def node(s: State) -> State:
answer = interrupt({"value": 1})
answer2 = interrupt({"value": 2})
return {"my_key": answer + " " + answer2}
builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert [e for e in graph.stream({"my_key": "DE", "market": "DE"}, thread1)] == [
{
"__interrupt__": (
Interrupt(
value={"value": 1},
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
for event in graph.stream(
Command(resume="answer 1", update={"my_key": "foofoo"}), thread1
)
] == [
{
"__interrupt__": (
Interrupt(
value={"value": 2},
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [event for event in graph.stream(Command(resume="answer 2"), thread1)] == [
{"node": {"my_key": "answer 1 answer 2"}},
]
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interrupt_loop(request: pytest.FixtureRequest, checkpointer_name: str):
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
class State(TypedDict):
age: int
other: str
def ask_age(s: State):
"""Ask an expert for help."""
question = "How old are you?"
value = None
for _ in range(10):
value: str = interrupt(question)
if not value.isdigit() or int(value) < 18:
question = "invalid response"
value = None
else:
break
return {"age": int(value)}
builder = StateGraph(State)
builder.add_node("node", ask_age)
builder.add_edge(START, "node")
graph = builder.compile(checkpointer=checkpointer)
thread1 = {"configurable": {"thread_id": "1"}}
assert [e for e in graph.stream({"other": ""}, thread1)] == [
{
"__interrupt__": (
Interrupt(
value="How old are you?",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
for event in graph.stream(
Command(resume="13"),
thread1,
)
] == [
{
"__interrupt__": (
Interrupt(
value="invalid response",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [
event
for event in graph.stream(
Command(resume="15"),
thread1,
)
] == [
{
"__interrupt__": (
Interrupt(
value="invalid response",
resumable=True,
ns=[AnyStr("node:")],
when="during",
),
)
}
]
assert [event for event in graph.stream(Command(resume="19"), thread1)] == [
{"node": {"age": 19}},
]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/any_str.py`:
```py
import re
from typing import Any, Sequence, Union
from typing_extensions import Self
class FloatBetween(float):
def __new__(cls, min_value: float, max_value: float) -> Self:
return super().__new__(cls, min_value)
def __init__(self, min_value: float, max_value: float) -> None:
super().__init__()
self.min_value = min_value
self.max_value = max_value
def __eq__(self, other: object) -> bool:
return (
isinstance(other, float)
and other >= self.min_value
and other <= self.max_value
)
def __hash__(self) -> int:
return hash((float(self), self.min_value, self.max_value))
class AnyStr(str):
def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
super().__init__()
self.prefix = prefix
def __eq__(self, other: object) -> bool:
return isinstance(other, str) and (
other.startswith(self.prefix)
if isinstance(self.prefix, str)
else self.prefix.match(other)
)
def __hash__(self) -> int:
return hash((str(self), self.prefix))
class AnyDict(dict):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __eq__(self, other: object) -> bool:
if not isinstance(other, dict) or len(self) != len(other):
return False
for k, v in self.items():
if kk := next((kk for kk in other if kk == k), None):
if v == other[kk]:
continue
else:
return False
else:
return True
class AnyVersion:
def __init__(self) -> None:
super().__init__()
def __eq__(self, other: object) -> bool:
return isinstance(other, (str, int, float))
def __hash__(self) -> int:
return hash(str(self))
class UnsortedSequence:
def __init__(self, *values: Any) -> None:
self.seq = values
def __eq__(self, value: object) -> bool:
return (
isinstance(value, Sequence)
and len(self.seq) == len(value)
and all(a in value for a in self.seq)
)
def __hash__(self) -> int:
return hash(frozenset(self.seq))
def __repr__(self) -> str:
return repr(self.seq)
```
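The matcher classes above are what the tests use to compare against nondeterministic values (for example `ns=[AnyStr("node:")]` in the interrupt tests). A minimal sketch of how they behave in plain assertions, assuming the module is importable as `any_str` (the import path is illustrative; in the repo the module lives under the tests package):
```py
# Illustrative only; assumes this module is importable as `any_str`.
from any_str import AnyStr, FloatBetween, UnsortedSequence

# AnyStr equals any string that starts with the given prefix, which is
# how the tests match task namespaces such as "node:<task id>".
assert AnyStr("node:") == "node:123e4567"
assert AnyStr("node:") != "other:123"

# FloatBetween equals any float inside the inclusive [min, max] range.
assert FloatBetween(0.0, 1.0) == 0.25

# UnsortedSequence compares equal to any same-length sequence with the
# same elements, regardless of order.
assert UnsortedSequence("a", "b") == ["b", "a"]
```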
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/test_algo.py`:
```py
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.pregel.algo import prepare_next_tasks
from langgraph.pregel.manager import ChannelsManager
def test_prepare_next_tasks() -> None:
config = {}
processes = {}
checkpoint = empty_checkpoint()
with ChannelsManager({}, checkpoint, config) as (channels, managed):
assert (
prepare_next_tasks(
checkpoint,
{},
processes,
channels,
managed,
config,
0,
for_execution=False,
)
== {}
)
assert (
prepare_next_tasks(
checkpoint,
{},
processes,
channels,
managed,
config,
0,
for_execution=True,
checkpointer=None,
store=None,
manager=None,
)
== {}
)
# TODO: add more tests
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/tests/fake_tracer.py`:
```py
from typing import Any, Optional
from uuid import UUID
from langchain_core.messages.base import BaseMessage
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers import BaseTracer, Run
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution.
It replaces run ids with deterministic UUIDs for snapshotting."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: list[Run] = []
self.uuids_map: dict[UUID, UUID] = {}
self.uuids_generator = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
def _replace_uuid(self, uuid: UUID) -> UUID:
if uuid not in self.uuids_map:
self.uuids_map[uuid] = next(self.uuids_generator)
return self.uuids_map[uuid]
def _replace_message_id(self, maybe_message: Any) -> Any:
if isinstance(maybe_message, BaseMessage):
maybe_message.id = str(next(self.uuids_generator))
if isinstance(maybe_message, ChatGeneration):
maybe_message.message.id = str(next(self.uuids_generator))
if isinstance(maybe_message, LLMResult):
for i, gen_list in enumerate(maybe_message.generations):
for j, gen in enumerate(gen_list):
maybe_message.generations[i][j] = self._replace_message_id(gen)
if isinstance(maybe_message, dict):
for k, v in maybe_message.items():
maybe_message[k] = self._replace_message_id(v)
if isinstance(maybe_message, list):
for i, v in enumerate(maybe_message):
maybe_message[i] = self._replace_message_id(v)
return maybe_message
def _copy_run(self, run: Run) -> Run:
if run.dotted_order:
levels = run.dotted_order.split(".")
processed_levels = []
for level in levels:
timestamp, run_id = level.split("Z")
new_run_id = self._replace_uuid(UUID(run_id))
processed_level = f"{timestamp}Z{new_run_id}"
processed_levels.append(processed_level)
new_dotted_order = ".".join(processed_levels)
else:
new_dotted_order = None
return run.copy(
update={
"id": self._replace_uuid(run.id),
"parent_run_id": (
self.uuids_map[run.parent_run_id] if run.parent_run_id else None
),
"child_runs": [self._copy_run(child) for child in run.child_runs],
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
"dotted_order": new_dotted_order,
"inputs": self._replace_message_id(run.inputs),
"outputs": self._replace_message_id(run.outputs),
}
)
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> list[Run]:
q = [] + self.runs
result = []
while q:
parent = q.pop()
result.append(parent)
if parent.child_runs:
q.extend(parent.child_runs)
return result
@property
def run_ids(self) -> list[Optional[UUID]]:
runs = self.flattened_runs()
uuids_map = {v: k for k, v in self.uuids_map.items()}
return [uuids_map.get(r.id) for r in runs]
```
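FakeTracer is a LangChain callback handler, so it can be attached via the `callbacks` entry of a run config to capture runs with deterministic UUIDs. A small sketch under that assumption (the `fake_tracer` import path is illustrative; in the repo the module lives under the tests package):
```py
# Illustrative sketch: record a run with deterministic UUIDs.
from langchain_core.runnables import RunnableLambda

from fake_tracer import FakeTracer  # illustrative import path

tracer = FakeTracer()
runnable = RunnableLambda(lambda x: x + 1)
runnable.invoke(1, {"callbacks": [tracer]})

# The persisted run has its id remapped into the deterministic
# 00000000-0000-4000-8000-... sequence, so traces can be snapshotted.
assert len(tracer.runs) == 1
print(tracer.runs[0].id)
print(tracer.flattened_runs())  # root run plus any child runs
```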
`/Users/malcolm/dev/langchain-ai/langgraph/libs/langgraph/poetry.toml`:
```toml
[virtualenvs]
in-project = true
[installer]
modern-installation = false
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/memory/__init__.py`:
```py
import asyncio
import logging
import os
import pickle
import random
import shutil
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol
logger = logging.getLogger(__name__)
class MemorySaver(
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
):
"""An in-memory checkpoint saver.
This checkpoint saver stores checkpoints in memory using a defaultdict.
Note:
Only use `MemorySaver` for debugging or testing purposes.
For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`.
Args:
serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to None.
Examples:
import asyncio
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.set_entry_point("add_one")
builder.set_finish_point("add_one")
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
asyncio.run(coro) # Output: 2
"""
# thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping
storage: defaultdict[
str,
dict[
str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], Optional[str]]]
],
]
writes: defaultdict[
tuple[str, str, str], dict[tuple[str, int], tuple[str, str, tuple[str, bytes]]]
]
def __init__(
self,
*,
serde: Optional[SerializerProtocol] = None,
factory: Type[defaultdict] = defaultdict,
) -> None:
super().__init__(serde=serde)
self.storage = factory(lambda: defaultdict(dict))
self.writes = factory(dict)
self.stack = ExitStack()
if factory is not defaultdict:
self.stack.enter_context(self.storage) # type: ignore[arg-type]
self.stack.enter_context(self.writes) # type: ignore[arg-type]
def __enter__(self) -> "MemorySaver":
return self.stack.__enter__()
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
return self.stack.__exit__(exc_type, exc_value, traceback)
async def __aenter__(self) -> "MemorySaver":
return self.stack.__enter__()
async def __aexit__(
self,
__exc_type: Optional[type[BaseException]],
__exc_value: Optional[BaseException],
__traceback: Optional[TracebackType],
) -> Optional[bool]:
return self.stack.__exit__(__exc_type, __exc_value, __traceback)
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the in-memory storage.
This method retrieves a checkpoint tuple from the in-memory storage based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id := get_checkpoint_id(config):
if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
checkpoint, metadata, parent_checkpoint_id = saved
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
if parent_checkpoint_id:
sends = [
w[2]
for w in self.writes[
(thread_id, checkpoint_ns, parent_checkpoint_id)
].values()
if w[1] == TASKS
]
else:
sends = []
return CheckpointTuple(
config=config,
checkpoint={
**self.serde.loads_typed(checkpoint),
"pending_sends": [self.serde.loads_typed(s) for s in sends],
},
metadata=self.serde.loads_typed(metadata),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None,
)
else:
if checkpoints := self.storage[thread_id][checkpoint_ns]:
checkpoint_id = max(checkpoints.keys())
checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id]
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
if parent_checkpoint_id:
sends = [
w[2]
for w in self.writes[
(thread_id, checkpoint_ns, parent_checkpoint_id)
].values()
if w[1] == TASKS
]
else:
sends = []
return CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
checkpoint={
**self.serde.loads_typed(checkpoint),
"pending_sends": [self.serde.loads_typed(s) for s in sends],
},
metadata=self.serde.loads_typed(metadata),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None,
)
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the in-memory storage.
This method retrieves a list of checkpoint tuples from the in-memory storage based
on the provided criteria.
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
config_checkpoint_ns = (
config["configurable"].get("checkpoint_ns") if config else None
)
config_checkpoint_id = get_checkpoint_id(config) if config else None
for thread_id in thread_ids:
for checkpoint_ns in self.storage[thread_id].keys():
if (
config_checkpoint_ns is not None
and checkpoint_ns != config_checkpoint_ns
):
continue
for checkpoint_id, (
checkpoint,
metadata_b,
parent_checkpoint_id,
) in sorted(
self.storage[thread_id][checkpoint_ns].items(),
key=lambda x: x[0],
reverse=True,
):
# filter by checkpoint ID from config
if config_checkpoint_id and checkpoint_id != config_checkpoint_id:
continue
# filter by checkpoint ID from `before` config
if (
before
and (before_checkpoint_id := get_checkpoint_id(before))
and checkpoint_id >= before_checkpoint_id
):
continue
# filter by metadata
metadata = self.serde.loads_typed(metadata_b)
if filter and not all(
query_value == metadata.get(query_key)
for query_key, query_value in filter.items()
):
continue
# limit search results
if limit is not None and limit <= 0:
break
elif limit is not None:
limit -= 1
writes = self.writes[
(thread_id, checkpoint_ns, checkpoint_id)
].values()
if parent_checkpoint_id:
sends = [
w[2]
for w in self.writes[
(thread_id, checkpoint_ns, parent_checkpoint_id)
].values()
if w[1] == TASKS
]
else:
sends = []
yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
checkpoint={
**self.serde.loads_typed(checkpoint),
"pending_sends": [self.serde.loads_typed(s) for s in sends],
},
metadata=metadata,
parent_config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None,
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v in writes
],
)
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the in-memory storage.
This method saves a checkpoint to the in-memory storage. The checkpoint is associated
with the provided config.
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (dict): New versions as of this write
Returns:
RunnableConfig: The updated config containing the saved checkpoint's ID.
"""
c = checkpoint.copy()
c.pop("pending_sends") # type: ignore[misc]
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
self.storage[thread_id][checkpoint_ns].update(
{
checkpoint["id"]: (
self.serde.dumps_typed(c),
self.serde.dumps_typed(metadata),
config["configurable"].get("checkpoint_id"), # parent
)
}
)
return {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Save a list of writes to the in-memory storage.
This method saves a list of writes to the in-memory storage. The writes are associated
with the provided config.
Args:
config (RunnableConfig): The config to associate with the writes.
writes (list[tuple[str, Any]]): The writes to save.
task_id (str): Identifier for the task creating the writes.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
checkpoint_id = config["configurable"]["checkpoint_id"]
outer_key = (thread_id, checkpoint_ns, checkpoint_id)
outer_writes_ = self.writes.get(outer_key)
for idx, (c, v) in enumerate(writes):
inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_:
continue
self.writes[outer_key][inner_key] = (task_id, c, self.serde.dumps_typed(v))
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Asynchronous version of get_tuple.
This method is an asynchronous wrapper around get_tuple that runs the synchronous
method in a separate thread using asyncio.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.get_tuple, config
)
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""Asynchronous version of list.
This method is an asynchronous wrapper around list that runs the synchronous
method in a separate thread using asyncio.
Args:
config (RunnableConfig): The config to use for listing the checkpoints.
Yields:
AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples.
"""
loop = asyncio.get_running_loop()
iter = await loop.run_in_executor(
None,
partial(
self.list,
before=before,
limit=limit,
filter=filter,
),
config,
)
while True:
# handling StopIteration exception inside coroutine won't work
# as expected, so using next() with default value to break the loop
if item := await loop.run_in_executor(None, next, iter, None):
yield item
else:
break
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Asynchronous version of put.
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (dict): New versions as of this write
Returns:
RunnableConfig: The updated config containing the saved checkpoint's ID.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.put, config, checkpoint, metadata, new_versions
)
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronous version of put_writes.
This method is an asynchronous wrapper around put_writes that runs the synchronous
method in a separate thread using asyncio.
Args:
config (RunnableConfig): The config to associate with the writes.
writes (List[Tuple[str, Any]]): The writes to save, each as a (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.put_writes, config, writes, task_id
)
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
class PersistentDict(defaultdict):
"""Persistent dictionary with an API compatible with shelve and anydbm.
The dict is kept in memory, so the dictionary operations run as fast as
a regular dictionary.
Write to disk is delayed until close or sync (similar to gdbm's fast mode).
Input file format is automatically discovered.
Output file format is selectable between pickle, json, and csv.
All three serialization formats are backed by fast C implementations.
Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/
"""
def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
self.flag = "c" # r=readonly, c=create, or n=new
self.mode = None # None or an octal triple like 0644
self.format = "pickle" # 'csv', 'json', or 'pickle'
self.filename = filename
super().__init__(*args, **kwds)
def sync(self) -> None:
"Write dict to disk"
if self.flag == "r":
return
tempname = self.filename + ".tmp"
fileobj = open(tempname, "wb" if self.format == "pickle" else "w")
try:
self.dump(fileobj)
except Exception:
os.remove(tempname)
raise
finally:
fileobj.close()
shutil.move(tempname, self.filename) # atomic commit
if self.mode is not None:
os.chmod(self.filename, self.mode)
def close(self) -> None:
self.sync()
self.clear()
def __enter__(self) -> "PersistentDict":
return self
def __exit__(self, *exc_info: Any) -> None:
self.close()
def dump(self, fileobj: Any) -> None:
if self.format == "pickle":
pickle.dump(dict(self), fileobj, 2)
else:
raise NotImplementedError("Unknown format: " + repr(self.format))
def load(self) -> None:
# try formats from most restrictive to least restrictive
if self.flag == "n":
return
with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj:
for loader in (pickle.load,):
fileobj.seek(0)
try:
return self.update(loader(fileobj))
except EOFError:
return
except Exception:
logging.error(f"Failed to load file: {fileobj.name}")
raise
raise ValueError("File not in a supported f ormat")
```
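A minimal sketch of driving MemorySaver directly, outside a compiled graph, using only the methods shown above; the metadata values here are illustrative:
```py
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import MemorySaver

saver = MemorySaver()
config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

# Store one checkpoint; put() returns a config pointing at the new checkpoint.
checkpoint = empty_checkpoint()
saved_config = saver.put(
    config,
    checkpoint,
    {"source": "input", "step": -1, "writes": {}, "parents": {}},
    {},
)
assert saved_config["configurable"]["checkpoint_id"] == checkpoint["id"]

# Without a checkpoint_id, get_tuple() returns the latest checkpoint for the thread.
tup = saver.get_tuple({"configurable": {"thread_id": "thread-1"}})
assert tup.checkpoint["id"] == checkpoint["id"]

# list() yields matching checkpoints, newest first.
assert len(list(saver.list(config))) == 1
```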
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py`:
```py
import dataclasses
import decimal
import importlib
import json
import pathlib
import re
from collections import deque
from datetime import date, datetime, time, timedelta, timezone
from enum import Enum
from inspect import isclass
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from typing import Any, Callable, Optional, Sequence, Union, cast
from uuid import UUID
import msgpack # type: ignore[import-untyped]
from langchain_core.load.load import Reviver
from langchain_core.load.serializable import Serializable
from zoneinfo import ZoneInfo
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import SendProtocol
from langgraph.store.base import Item
LC_REVIVER = Reviver()
class JsonPlusSerializer(SerializerProtocol):
def _encode_constructor_args(
self,
constructor: Union[Callable, type[Any]],
*,
method: Union[None, str, Sequence[Union[None, str]]] = None,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
out = {
"lc": 2,
"type": "constructor",
"id": (*constructor.__module__.split("."), constructor.__name__),
}
if method is not None:
out["method"] = method
if args is not None:
out["args"] = args
if kwargs is not None:
out["kwargs"] = kwargs
return out
def _default(self, obj: Any) -> Union[str, dict[str, Any]]:
if isinstance(obj, Serializable):
return cast(dict[str, Any], obj.to_json())
elif hasattr(obj, "model_dump") and callable(obj.model_dump):
return self._encode_constructor_args(
obj.__class__, method=(None, "model_construct"), kwargs=obj.model_dump()
)
elif hasattr(obj, "dict") and callable(obj.dict):
return self._encode_constructor_args(
obj.__class__, method=(None, "construct"), kwargs=obj.dict()
)
elif hasattr(obj, "_asdict") and callable(obj._asdict):
return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict())
elif isinstance(obj, pathlib.Path):
return self._encode_constructor_args(pathlib.Path, args=obj.parts)
elif isinstance(obj, re.Pattern):
return self._encode_constructor_args(
re.compile, args=(obj.pattern, obj.flags)
)
elif isinstance(obj, UUID):
return self._encode_constructor_args(UUID, args=(obj.hex,))
elif isinstance(obj, decimal.Decimal):
return self._encode_constructor_args(decimal.Decimal, args=(str(obj),))
elif isinstance(obj, (set, frozenset, deque)):
return self._encode_constructor_args(type(obj), args=(tuple(obj),))
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return self._encode_constructor_args(obj.__class__, args=(str(obj),))
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return self._encode_constructor_args(obj.__class__, args=(str(obj),))
elif isinstance(obj, datetime):
return self._encode_constructor_args(
datetime, method="fromisoformat", args=(obj.isoformat(),)
)
elif isinstance(obj, timezone):
return self._encode_constructor_args(
timezone,
args=obj.__getinitargs__(), # type: ignore[attr-defined]
)
elif isinstance(obj, ZoneInfo):
return self._encode_constructor_args(ZoneInfo, args=(obj.key,))
elif isinstance(obj, timedelta):
return self._encode_constructor_args(
timedelta, args=(obj.days, obj.seconds, obj.microseconds)
)
elif isinstance(obj, date):
return self._encode_constructor_args(
date, args=(obj.year, obj.month, obj.day)
)
elif isinstance(obj, time):
return self._encode_constructor_args(
time,
args=(obj.hour, obj.minute, obj.second, obj.microsecond, obj.tzinfo),
kwargs={"fold": obj.fold},
)
elif dataclasses.is_dataclass(obj):
return self._encode_constructor_args(
obj.__class__,
kwargs={
field.name: getattr(obj, field.name)
for field in dataclasses.fields(obj)
},
)
elif isinstance(obj, Enum):
return self._encode_constructor_args(obj.__class__, args=(obj.value,))
elif isinstance(obj, SendProtocol):
return self._encode_constructor_args(
obj.__class__, kwargs={"node": obj.node, "arg": obj.arg}
)
elif isinstance(obj, (bytes, bytearray)):
return self._encode_constructor_args(
obj.__class__, method="fromhex", args=(obj.hex(),)
)
elif isinstance(obj, BaseException):
return repr(obj)
else:
raise TypeError(
f"Object of type {obj.__class__.__name__} is not JSON serializable"
)
def _reviver(self, value: dict[str, Any]) -> Any:
if (
value.get("lc", None) == 2
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
try:
# Get module and class name
[*module, name] = value["id"]
# Import module
mod = importlib.import_module(".".join(module))
# Import class
cls = getattr(mod, name)
# Instantiate class
method = value.get("method")
if isinstance(method, str):
methods = [getattr(cls, method)]
elif isinstance(method, list):
methods = [
cls if method is None else getattr(cls, method)
for method in method
]
else:
methods = [cls]
args = value.get("args")
kwargs = value.get("kwargs")
for method in methods:
try:
if isclass(method) and issubclass(method, BaseException):
return None
if args and kwargs:
return method(*args, **kwargs)
elif args:
return method(*args)
elif kwargs:
return method(**kwargs)
else:
return method()
except Exception:
continue
except Exception:
return None
return LC_REVIVER(value)
def dumps(self, obj: Any) -> bytes:
return json.dumps(obj, default=self._default, ensure_ascii=False).encode(
"utf-8", "ignore"
)
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
if isinstance(obj, bytes):
return "bytes", obj
elif isinstance(obj, bytearray):
return "bytearray", obj
else:
try:
return "msgpack", _msgpack_enc(obj)
except UnicodeEncodeError:
return "json", self.dumps(obj)
def loads(self, data: bytes) -> Any:
return json.loads(data, object_hook=self._reviver)
def loads_typed(self, data: tuple[str, bytes]) -> Any:
type_, data_ = data
if type_ == "bytes":
return data_
elif type_ == "bytearray":
return bytearray(data_)
elif type_ == "json":
return self.loads(data_)
elif type_ == "msgpack":
return msgpack.unpackb(
data_, ext_hook=_msgpack_ext_hook, strict_map_key=False
)
else:
raise NotImplementedError(f"Unknown serialization type: {type_}")
# --- msgpack ---
EXT_CONSTRUCTOR_SINGLE_ARG = 0
EXT_CONSTRUCTOR_POS_ARGS = 1
EXT_CONSTRUCTOR_KW_ARGS = 2
EXT_METHOD_SINGLE_ARG = 3
EXT_PYDANTIC_V1 = 4
EXT_PYDANTIC_V2 = 5
def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]:
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
return msgpack.ExtType(
EXT_PYDANTIC_V2,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.model_dump(),
"model_validate_json",
),
),
)
elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.get_secret_value(),
),
),
)
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
return msgpack.ExtType(
EXT_PYDANTIC_V1,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.dict(),
),
),
)
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj._asdict(),
),
),
)
elif isinstance(obj, pathlib.Path):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
),
)
elif isinstance(obj, re.Pattern):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
("re", "compile", (obj.pattern, obj.flags)),
),
)
elif isinstance(obj, UUID):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
),
)
elif isinstance(obj, decimal.Decimal):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (set, frozenset, deque)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
),
)
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, datetime):
return msgpack.ExtType(
EXT_METHOD_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.isoformat(),
"fromisoformat",
),
),
)
elif isinstance(obj, timedelta):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.days, obj.seconds, obj.microseconds),
),
),
)
elif isinstance(obj, date):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.year, obj.month, obj.day),
),
),
)
elif isinstance(obj, time):
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
"hour": obj.hour,
"minute": obj.minute,
"second": obj.second,
"microsecond": obj.microsecond,
"tzinfo": obj.tzinfo,
"fold": obj.fold,
},
),
),
)
elif isinstance(obj, timezone):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.__getinitargs__(), # type: ignore[attr-defined]
),
),
)
elif isinstance(obj, ZoneInfo):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
),
)
elif isinstance(obj, Enum):
return msgpack.ExtType(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
),
)
elif isinstance(obj, SendProtocol):
return msgpack.ExtType(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
field.name: getattr(obj, field.name)
for field in dataclasses.fields(obj)
},
),
),
)
elif isinstance(obj, Item):
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{k: getattr(obj, k) for k in obj.__slots__},
),
),
)
elif isinstance(obj, BaseException):
return repr(obj)
else:
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
def _msgpack_ext_hook(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, arg
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
except Exception:
return
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, args
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
except Exception:
return
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, kwargs
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
except Exception:
return
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, arg, method
return getattr(getattr(importlib.import_module(tup[0]), tup[1]), tup[3])(
tup[2]
)
except Exception:
return
elif code == EXT_PYDANTIC_V1:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, kwargs
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.construct(**tup[2])
except Exception:
return
elif code == EXT_PYDANTIC_V2:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
# module, name, kwargs, method
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.model_construct(**tup[2])
except Exception:
return
ENC_POOL: deque[msgpack.Packer] = deque(maxlen=32)
def _msgpack_enc(data: Any) -> bytes:
try:
enc = ENC_POOL.popleft()
except IndexError:
enc = msgpack.Packer(default=_msgpack_default)
try:
return enc.pack(data)
finally:
ENC_POOL.append(enc)
```
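A round-trip sketch with JsonPlusSerializer: ordinary values are encoded to msgpack using the extension types defined above (falling back to JSON on encoding errors) and reconstructed on load:
```py
from datetime import datetime, timezone
from uuid import uuid4

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()
value = {
    "when": datetime(2024, 1, 1, tzinfo=timezone.utc),
    "id": uuid4(),
    "tags": {"a", "b"},
}

type_, data = serde.dumps_typed(value)
assert type_ == "msgpack"  # raw bytes/bytearray inputs are passed through instead
restored = serde.loads_typed((type_, data))
assert restored == value   # datetime, UUID and set are reconstructed via ext hooks
```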
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/types.py`:
```py
from typing import (
Any,
Optional,
Protocol,
Sequence,
TypeVar,
runtime_checkable,
)
from typing_extensions import Self
ERROR = "__error__"
SCHEDULED = "__scheduled__"
INTERRUPT = "__interrupt__"
RESUME = "__resume__"
TASKS = "__pregel_tasks"
Value = TypeVar("Value", covariant=True)
Update = TypeVar("Update", contravariant=True)
C = TypeVar("C")
class ChannelProtocol(Protocol[Value, Update, C]):
# Mirrors langgraph.channels.base.BaseChannel
@property
def ValueType(self) -> Any: ...
@property
def UpdateType(self) -> Any: ...
def checkpoint(self) -> Optional[C]: ...
def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ...
def update(self, values: Sequence[Update]) -> bool: ...
def get(self) -> Value: ...
def consume(self) -> bool: ...
@runtime_checkable
class SendProtocol(Protocol):
# Mirrors langgraph.constants.Send
node: str
arg: Any
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def __eq__(self, value: object) -> bool: ...
```
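ChannelProtocol mirrors BaseChannel, so any class with this shape can stand in for a channel in serializer-level code. A structural sketch (illustrative only; this is not the library's LastValue channel):
```py
from typing import Any, Optional, Sequence


class SimpleLastValue:
    """Illustrative last-value channel satisfying ChannelProtocol structurally."""

    def __init__(self, value: Any = None, set_: bool = False) -> None:
        self._value = value
        self._set = set_

    @property
    def ValueType(self) -> Any:
        return Any

    @property
    def UpdateType(self) -> Any:
        return Any

    def checkpoint(self) -> Optional[Any]:
        # Snapshot the current value (None if the channel was never written).
        return self._value if self._set else None

    def from_checkpoint(self, checkpoint: Optional[Any]) -> "SimpleLastValue":
        return SimpleLastValue(checkpoint, checkpoint is not None)

    def update(self, values: Sequence[Any]) -> bool:
        if not values:
            return False
        self._value, self._set = values[-1], True
        return True

    def get(self) -> Any:
        if not self._set:
            raise ValueError("channel is empty")
        return self._value

    def consume(self) -> bool:
        # A plain last-value channel has nothing to consume.
        return False
```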
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/serde/base.py`:
```py
from typing import Any, Protocol
class SerializerProtocol(Protocol):
"""Protocol for serialization and deserialization of objects.
- `dumps`: Serialize an object to bytes.
- `dumps_typed`: Serialize an object to a tuple (type, bytes).
- `loads`: Deserialize an object from bytes.
- `loads_typed`: Deserialize an object from a tuple (type, bytes).
Valid implementations include the `pickle`, `json` and `orjson` modules.
"""
def dumps(self, obj: Any) -> bytes: ...
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...
def loads(self, data: bytes) -> Any: ...
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...
class SerializerCompat(SerializerProtocol):
def __init__(self, serde: SerializerProtocol) -> None:
self.serde = serde
def dumps(self, obj: Any) -> bytes:
return self.serde.dumps(obj)
def loads(self, data: bytes) -> Any:
return self.serde.loads(data)
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return type(obj).__name__, self.serde.dumps(obj)
def loads_typed(self, data: tuple[str, bytes]) -> Any:
return self.serde.loads(data[1])
def maybe_add_typed_methods(serde: SerializerProtocol) -> SerializerProtocol:
"""Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility."""
if not hasattr(serde, "loads_typed") or not hasattr(serde, "dumps_typed"):
return SerializerCompat(serde)
return serde
```
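maybe_add_typed_methods exists so older serde implementations that only define dumps/loads keep working. A sketch with an illustrative pickle-backed serde:
```py
import pickle
from typing import Any

from langgraph.checkpoint.serde.base import maybe_add_typed_methods


class PickleSerde:
    """Illustrative legacy serde with only dumps/loads."""

    def dumps(self, obj: Any) -> bytes:
        return pickle.dumps(obj)

    def loads(self, data: bytes) -> Any:
        return pickle.loads(data)


serde = maybe_add_typed_methods(PickleSerde())
# The SerializerCompat wrapper derives dumps_typed/loads_typed from dumps/loads.
type_, data = serde.dumps_typed({"a": 1})
assert type_ == "dict"
assert serde.loads_typed((type_, data)) == {"a": 1}
```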
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/base/id.py`:
```py
"""Adapted from
https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95
Bundled in to avoid install issues with uuid6 package
"""
import random
import time
import uuid
from typing import Optional, Tuple
_last_v6_timestamp = None
class UUID(uuid.UUID):
r"""UUID draft version objects"""
__slots__ = ()
def __init__(
self,
hex: Optional[str] = None,
bytes: Optional[bytes] = None,
bytes_le: Optional[bytes] = None,
fields: Optional[Tuple[int, int, int, int, int, int]] = None,
int: Optional[int] = None,
version: Optional[int] = None,
*,
is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown,
) -> None:
r"""Create a UUID."""
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
return super().__init__(
hex=hex,
bytes=bytes,
bytes_le=bytes_le,
fields=fields,
int=int,
version=version,
is_safe=is_safe,
)
if not 0 <= int < 1 << 128:
raise ValueError("int is out of range (need a 128-bit value)")
if version is not None:
if not 6 <= version <= 8:
raise ValueError("illegal version number")
# Set the variant to RFC 4122.
int &= ~(0xC000 << 48)
int |= 0x8000 << 48
# Set the version number.
int &= ~(0xF000 << 64)
int |= version << 76
super().__init__(int=int, is_safe=is_safe)
@property
def subsec(self) -> int:
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
@property
def time(self) -> int:
if self.version == 6:
return (
(self.time_low << 28)
| (self.time_mid << 12)
| (self.time_hi_version & 0x0FFF)
)
if self.version == 7:
return self.int >> 80
if self.version == 8:
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
return super().time
def _subsec_decode(value: int) -> int:
return -(-value * 10**6 // 2**20)
def uuid6(node: Optional[int] = None, clock_seq: Optional[int] = None) -> UUID:
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
improved DB locality. It is expected that UUIDv6 will primarily be
used in contexts where there are existing v1 UUIDs. Systems that do
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
If 'node' is not given, a random 48-bit number is chosen.
If 'clock_seq' is given, it is used as the sequence number;
otherwise a random 14-bit sequence number is chosen."""
global _last_v6_timestamp
nanoseconds = time.time_ns()
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
timestamp = nanoseconds // 100 + 0x01B21DD213814000
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
timestamp = _last_v6_timestamp + 1
_last_v6_timestamp = timestamp
if clock_seq is None:
clock_seq = random.getrandbits(14) # instead of stable storage
if node is None:
node = random.getrandbits(48)
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
time_low_and_version = timestamp & 0x0FFF
uuid_int = time_high_and_time_mid << 80
uuid_int |= time_low_and_version << 64
uuid_int |= (clock_seq & 0x3FFF) << 48
uuid_int |= node & 0xFFFFFFFFFFFF
return UUID(int=uuid_int, version=6)
```
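uuid6() produces IDs whose string form sorts in creation order, which is what lets checkpoint IDs be compared lexicographically (e.g. max(checkpoints.keys()) in MemorySaver above). A small sanity sketch:
```py
from langgraph.checkpoint.base.id import uuid6

first = uuid6()
second = uuid6()

assert first.version == 6
# Timestamps are forced to be strictly increasing, so the hex string
# of a later UUID always sorts after an earlier one.
assert str(first) < str(second)
```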
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/checkpoint/base/__init__.py`:
```py
from datetime import datetime, timezone
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
TypedDict,
TypeVar,
Union,
)
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint.base.id import uuid6
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import (
ERROR,
INTERRUPT,
RESUME,
SCHEDULED,
ChannelProtocol,
SendProtocol,
)
V = TypeVar("V", int, float, str)
PendingWrite = Tuple[str, str, Any]
# Marked as total=False to allow for future expansion.
class CheckpointMetadata(TypedDict, total=False):
"""Metadata associated with a checkpoint."""
source: Literal["input", "loop", "update", "fork"]
"""The source of the checkpoint.
- "input": The checkpoint was created from an input to invoke/stream/batch.
- "loop": The checkpoint was created from inside the pregel loop.
- "update": The checkpoint was created from a manual state update.
- "fork": The checkpoint was created as a copy of another checkpoint.
"""
step: int
"""The step number of the checkpoint.
-1 for the first "input" checkpoint.
0 for the first "loop" checkpoint.
... for the nth checkpoint afterwards.
"""
writes: dict[str, Any]
"""The writes that were made between the previous checkpoint and this one.
Mapping from node name to writes emitted by that node.
"""
parents: dict[str, str]
"""The IDs of the parent checkpoints.
Mapping from checkpoint namespace to checkpoint ID.
"""
class TaskInfo(TypedDict):
status: Literal["scheduled", "success", "error"]
ChannelVersions = dict[str, Union[str, int, float]]
class Checkpoint(TypedDict):
"""State snapshot at a given point in time."""
v: int
"""The version of the checkpoint format. Currently 1."""
id: str
"""The ID of the checkpoint. This is both unique and monotonically
increasing, so can be used for sorting checkpoints from first to last."""
ts: str
"""The timestamp of the checkpoint in ISO 8601 format."""
channel_values: dict[str, Any]
"""The values of the channels at the time of the checkpoint.
Mapping from channel name to deserialized channel snapshot value.
"""
channel_versions: ChannelVersions
"""The versions of the channels at the time of the checkpoint.
The keys are channel names and the values are monotonically increasing
version strings for each channel.
"""
versions_seen: dict[str, ChannelVersions]
"""Map from node ID to map from channel name to version seen.
This keeps track of the versions of the channels that each node has seen.
Used to determine which nodes to execute next.
"""
pending_sends: List[SendProtocol]
"""List of inputs pushed to nodes but not yet processed.
Cleared by the next checkpoint."""
def empty_checkpoint() -> Checkpoint:
return Checkpoint(
v=1,
id=str(uuid6(clock_seq=-2)),
ts=datetime.now(timezone.utc).isoformat(),
channel_values={},
channel_versions={},
versions_seen={},
pending_sends=[],
)
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
return Checkpoint(
v=checkpoint["v"],
ts=checkpoint["ts"],
id=checkpoint["id"],
channel_values=checkpoint["channel_values"].copy(),
channel_versions=checkpoint["channel_versions"].copy(),
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
pending_sends=checkpoint.get("pending_sends", []).copy(),
)
def create_checkpoint(
checkpoint: Checkpoint,
channels: Optional[Mapping[str, ChannelProtocol]],
step: int,
*,
id: Optional[str] = None,
) -> Checkpoint:
"""Create a checkpoint for the given channels."""
ts = datetime.now(timezone.utc).isoformat()
if channels is None:
values = checkpoint["channel_values"]
else:
values = {}
for k, v in channels.items():
if k not in checkpoint["channel_versions"]:
continue
try:
values[k] = v.checkpoint()
except EmptyChannelError:
pass
return Checkpoint(
v=1,
ts=ts,
id=id or str(uuid6(clock_seq=step)),
channel_values=values,
channel_versions=checkpoint["channel_versions"],
versions_seen=checkpoint["versions_seen"],
pending_sends=checkpoint.get("pending_sends", []),
)
class CheckpointTuple(NamedTuple):
"""A tuple containing a checkpoint and its associated data."""
config: RunnableConfig
checkpoint: Checkpoint
metadata: CheckpointMetadata
parent_config: Optional[RunnableConfig] = None
pending_writes: Optional[List[PendingWrite]] = None
CheckpointThreadId = ConfigurableFieldSpec(
id="thread_id",
annotation=str,
name="Thread ID",
description=None,
default="",
is_shared=True,
)
CheckpointNS = ConfigurableFieldSpec(
id="checkpoint_ns",
annotation=str,
name="Checkpoint NS",
description='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).',
default="",
is_shared=True,
)
CheckpointId = ConfigurableFieldSpec(
id="checkpoint_id",
annotation=Optional[str],
name="Checkpoint ID",
description="Pass to fetch a past checkpoint. If None, fetches the latest checkpoint.",
default=None,
is_shared=True,
)
class BaseCheckpointSaver(Generic[V]):
"""Base class for creating a graph checkpointer.
Checkpointers allow LangGraph agents to persist their state
within and across multiple interactions.
Attributes:
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.
Note:
When creating a custom checkpoint saver, consider implementing async
versions to avoid blocking the main thread.
"""
serde: SerializerProtocol = JsonPlusSerializer()
def __init__(
self,
*,
serde: Optional[SerializerProtocol] = None,
) -> None:
self.serde = maybe_add_typed_methods(serde or self.serde)
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Define the configuration options for the checkpoint saver.
Returns:
list[ConfigurableFieldSpec]: List of configuration field specs.
"""
return [CheckpointThreadId, CheckpointNS, CheckpointId]
def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""Fetch a checkpoint using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[Checkpoint]: The requested checkpoint, or None if not found.
"""
if value := self.get_tuple(config):
return value.checkpoint
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Fetch a checkpoint tuple using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints that match the given criteria.
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.
Returns:
Iterator[CheckpointTuple]: Iterator of matching checkpoint tuples.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Store a checkpoint with its configuration and metadata.
Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as a (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def aget(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""Asynchronously fetch a checkpoint using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[Checkpoint]: The requested checkpoint, or None if not found.
"""
if value := await self.aget_tuple(config):
return value.checkpoint
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Asynchronously fetch a checkpoint tuple using the given configuration.
Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""Asynchronously list checkpoints that match the given criteria.
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.
Returns:
AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
yield
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Asynchronously store a checkpoint with its configuration and metadata.
Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as a (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
"""Generate the next version ID for a channel.
Default is to use integer versions, incrementing by 1. If you override, you can use str/int/float versions,
as long as they are monotonically increasing.
Args:
current (Optional[V]): The current version identifier (int, float, or str).
channel (ChannelProtocol): The channel being versioned.
Returns:
V: The next version identifier, which must be increasing.
"""
if isinstance(current, str):
raise NotImplementedError
elif current is None:
return 1
else:
return current + 1
class EmptyChannelError(Exception):
"""Raised when attempting to get the value of a channel that hasn't been updated
for the first time yet."""
pass
def get_checkpoint_id(config: RunnableConfig) -> Optional[str]:
"""Get checkpoint ID in a backwards-compatible manner (fallback on thread_ts)."""
return config["configurable"].get(
"checkpoint_id", config["configurable"].get("thread_ts")
)
"""
Mapping from error type to error index.
Regular writes just map to their index in the list of writes being saved.
Special writes (e.g. errors) map to negative indices, to avoid those writes from
conflicting with regular writes.
Each Checkpointer implementation should use this mapping in put_writes.
"""
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
```
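To show how the abstract pieces above fit together, here is a minimal, hypothetical sketch of a dict-backed saver. The class name `DictCheckpointSaver` and its internal layout are invented for illustration (it is not the shipped in-memory saver), `list` and the async variants are omitted, but the calls it relies on (`get_checkpoint_id`, `CheckpointTuple`, the `configurable` keys) all appear in the source above.

```python
from typing import Any, Optional, Sequence, Tuple

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
)


class DictCheckpointSaver(BaseCheckpointSaver):
    """Hypothetical saver keeping checkpoints in nested dicts, keyed by (thread_id, checkpoint_ns)."""

    def __init__(self) -> None:
        super().__init__()
        # (thread_id, ns) -> {checkpoint_id: (checkpoint, metadata, parent_config)}
        self._checkpoints: dict[tuple[str, str], dict[str, tuple]] = {}
        # (thread_id, ns, checkpoint_id) -> [(task_id, channel, value), ...]
        self._writes: dict[tuple[str, str, str], list] = {}

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        thread_id = config["configurable"]["thread_id"]
        ns = config["configurable"].get("checkpoint_ns", "")
        self._checkpoints.setdefault((thread_id, ns), {})[checkpoint["id"]] = (
            checkpoint,
            metadata,
            # Treat the incoming config as the parent if it points at a prior checkpoint.
            config if get_checkpoint_id(config) else None,
        )
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": ns,
                "checkpoint_id": checkpoint["id"],
            }
        }

    def put_writes(
        self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str
    ) -> None:
        thread_id = config["configurable"]["thread_id"]
        ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = get_checkpoint_id(config)
        self._writes.setdefault((thread_id, ns, checkpoint_id), []).extend(
            (task_id, channel, value) for channel, value in writes
        )

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        thread_id = config["configurable"]["thread_id"]
        ns = config["configurable"].get("checkpoint_ns", "")
        checkpoints = self._checkpoints.get((thread_id, ns), {})
        if not checkpoints:
            return None
        # Checkpoint IDs are monotonically increasing, so max() picks the latest.
        checkpoint_id = get_checkpoint_id(config) or max(checkpoints)
        if checkpoint_id not in checkpoints:
            return None
        checkpoint, metadata, parent_config = checkpoints[checkpoint_id]
        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": ns,
                    "checkpoint_id": checkpoint_id,
                }
            },
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=self._writes.get((thread_id, ns, checkpoint_id)),
        )
```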
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/memory/__init__.py`:
```py
"""In-memory dictionary-backed store with optional vector search.
!!! example "Examples"
Basic key-value storage:
```python
from langgraph.store.memory import InMemoryStore
store = InMemoryStore()
store.put(("users", "123"), "prefs", {"theme": "dark"})
item = store.get(("users", "123"), "prefs")
```
Vector search using LangChain embeddings:
```python
from langchain.embeddings import init_embeddings
from langgraph.store.memory import InMemoryStore
store = InMemoryStore(
index={
"dims": 1536,
"embed": init_embeddings("openai:text-embedding-3-small")
}
)
# Store documents
store.put(("docs",), "doc1", {"text": "Python tutorial"})
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = store.search(("docs",), query="python programming")
```
Vector search using OpenAI SDK directly:
```python
from openai import OpenAI
from langgraph.store.memory import InMemoryStore
client = OpenAI()
def embed_texts(texts: list[str]) -> list[list[float]]:
response = client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": embed_texts
}
)
# Store documents
store.put(("docs",), "doc1", {"text": "Python tutorial"})
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = store.search(("docs",), query="python programming")
```
Async vector search using OpenAI SDK:
```python
from openai import AsyncOpenAI
from langgraph.store.memory import InMemoryStore
client = AsyncOpenAI()
async def aembed_texts(texts: list[str]) -> list[list[float]]:
response = await client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": aembed_texts
}
)
# Store documents
await store.aput(("docs",), "doc1", {"text": "Python tutorial"})
await store.aput(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = await store.asearch(("docs",), query="python programming")
```
Warning:
This store keeps all data in memory. Data is lost when the process exits.
For persistence, use a database-backed store like PostgresStore.
Tip:
For vector search, install numpy for better performance:
```bash
pip install numpy
```
"""
import asyncio
import concurrent.futures as cf
import functools
import logging
from collections import defaultdict
from datetime import datetime, timezone
from importlib import util
from typing import Any, Iterable, Optional
from langchain_core.embeddings import Embeddings
from langgraph.store.base import (
BaseStore,
GetOp,
IndexConfig,
Item,
ListNamespacesOp,
MatchCondition,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
ensure_embeddings,
get_text_at_path,
tokenize_path,
)
logger = logging.getLogger(__name__)
class InMemoryStore(BaseStore):
"""In-memory dictionary-backed store with optional vector search.
!!! example "Examples"
Basic key-value storage:
store = InMemoryStore()
store.put(("users", "123"), "prefs", {"theme": "dark"})
item = store.get(("users", "123"), "prefs")
Vector search with embeddings:
from langchain.embeddings import init_embeddings
store = InMemoryStore(index={
"dims": 1536,
"embed": init_embeddings("openai:text-embedding-3-small"),
"fields": ["text"],
})
# Store documents
store.put(("docs",), "doc1", {"text": "Python tutorial"})
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = store.search(("docs",), query="python programming")
Note:
Semantic search is disabled by default. You can enable it by providing an `index` configuration
when creating the store. Without this configuration, all `index` arguments passed to
`put` or `aput` will have no effect.
Warning:
This store keeps all data in memory. Data is lost when the process exits.
For persistence, use a database-backed store like PostgresStore.
Tip:
For vector search, install numpy for better performance:
```bash
pip install numpy
```
"""
__slots__ = (
"_data",
"_vectors",
"index_config",
"embeddings",
)
def __init__(self, *, index: Optional[IndexConfig] = None) -> None:
# Both _data and _vectors are wrapped in the In-memory API
# Do not change their names
self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict)
# [ns][key][path]
self._vectors: dict[tuple[str, ...], dict[str, dict[str, list[float]]]] = (
defaultdict(lambda: defaultdict(dict))
)
self.index_config = index
if self.index_config:
self.index_config = self.index_config.copy()
self.embeddings: Optional[Embeddings] = ensure_embeddings(
self.index_config.get("embed"),
)
self.index_config["__tokenized_fields"] = [
(p, tokenize_path(p)) if p != "$" else (p, p)
for p in (self.index_config.get("fields") or ["$"])
]
else:
self.index_config = None
self.embeddings = None
def batch(self, ops: Iterable[Op]) -> list[Result]:
# The batch/abatch methods are treated as internal.
# Users should access via put/search/get/list_namespaces/etc.
results, put_ops, search_ops = self._prepare_ops(ops)
if search_ops:
queryinmem_store = self._embed_search_queries(search_ops)
self._batch_search(search_ops, queryinmem_store, results)
to_embed = self._extract_texts(put_ops)
if to_embed and self.index_config and self.embeddings:
embeddings = self.embeddings.embed_documents(list(to_embed))
self._insertinmem_store(to_embed, embeddings)
self._apply_put_ops(put_ops)
return results
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
# The batch/abatch methods are treated as internal.
# Users should access via put/search/get/list_namespaces/etc.
results, put_ops, search_ops = self._prepare_ops(ops)
if search_ops:
queryinmem_store = await self._aembed_search_queries(search_ops)
self._batch_search(search_ops, queryinmem_store, results)
to_embed = self._extract_texts(put_ops)
if to_embed and self.index_config and self.embeddings:
embeddings = await self.embeddings.aembed_documents(list(to_embed))
self._insertinmem_store(to_embed, embeddings)
self._apply_put_ops(put_ops)
return results
# Helpers
def _filter_items(self, op: SearchOp) -> list[tuple[Item, list[list[float]]]]:
"""Filter items by namespace and filter function, return items with their embeddings."""
namespace_prefix = op.namespace_prefix
def filter_func(item: Item) -> bool:
if not op.filter:
return True
return all(
_compare_values(item.value.get(key), filter_value)
for key, filter_value in op.filter.items()
)
filtered = []
for namespace in self._data:
if not (
namespace[: len(namespace_prefix)] == namespace_prefix
if len(namespace) >= len(namespace_prefix)
else False
):
continue
for key, item in self._data[namespace].items():
if filter_func(item):
if op.query and (embeddings := self._vectors[namespace].get(key)):
filtered.append((item, list(embeddings.values())))
else:
filtered.append((item, []))
return filtered
def _embed_search_queries(
self,
search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
) -> dict[str, list[float]]:
queryinmem_store = {}
if self.index_config and self.embeddings and search_ops:
queries = {op.query for (op, _) in search_ops.values() if op.query}
if queries:
with cf.ThreadPoolExecutor() as executor:
futures = {
q: executor.submit(self.embeddings.embed_query, q)
for q in list(queries)
}
for query, future in futures.items():
queryinmem_store[query] = future.result()
return queryinmem_store
async def _aembed_search_queries(
self,
search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
) -> dict[str, list[float]]:
queryinmem_store = {}
if self.index_config and self.embeddings and search_ops:
queries = {op.query for (op, _) in search_ops.values() if op.query}
if queries:
coros = [self.embeddings.aembed_query(q) for q in list(queries)]
results = await asyncio.gather(*coros)
queryinmem_store = dict(zip(queries, results))
return queryinmem_store
def _batch_search(
self,
ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
queryinmem_store: dict[str, list[float]],
results: list[Result],
) -> None:
"""Perform batch similarity search for multiple queries."""
for i, (op, candidates) in ops.items():
if not candidates:
results[i] = []
continue
if op.query and queryinmem_store:
query_embedding = queryinmem_store[op.query]
flat_items, flat_vectors = [], []
scoreless = []
for item, vectors in candidates:
for vector in vectors:
flat_items.append(item)
flat_vectors.append(vector)
if not vectors:
scoreless.append(item)
scores = _cosine_similarity(query_embedding, flat_vectors)
sorted_results = sorted(
zip(scores, flat_items), key=lambda x: x[0], reverse=True
)
# max pooling
seen: set[tuple[tuple[str, ...], str]] = set()
kept: list[tuple[Optional[float], Item]] = []
for score, item in sorted_results:
key = (item.namespace, item.key)
if key in seen:
continue
ix = len(seen)
seen.add(key)
if ix >= op.offset + op.limit:
break
if ix < op.offset:
continue
kept.append((score, item))
if scoreless and len(kept) < op.limit:
# Corner case: if we request more items than what we have embedded,
# fill the rest with non-scored items
kept.extend(
(None, item) for item in scoreless[: op.limit - len(kept)]
)
results[i] = [
SearchItem(
namespace=item.namespace,
key=item.key,
value=item.value,
created_at=item.created_at,
updated_at=item.updated_at,
score=float(score) if score is not None else None,
)
for score, item in kept
]
else:
results[i] = [
SearchItem(
namespace=item.namespace,
key=item.key,
value=item.value,
created_at=item.created_at,
updated_at=item.updated_at,
)
for (item, _) in candidates[op.offset : op.offset + op.limit]
]
def _prepare_ops(
self, ops: Iterable[Op]
) -> tuple[
list[Result],
dict[tuple[tuple[str, ...], str], PutOp],
dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
]:
results: list[Result] = []
put_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
search_ops: dict[
int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]
] = {}
for i, op in enumerate(ops):
if isinstance(op, GetOp):
item = self._data[op.namespace].get(op.key)
results.append(item)
elif isinstance(op, SearchOp):
search_ops[i] = (op, self._filter_items(op))
results.append(None)
elif isinstance(op, ListNamespacesOp):
results.append(self._handle_list_namespaces(op))
elif isinstance(op, PutOp):
put_ops[(op.namespace, op.key)] = op
results.append(None)
else:
raise ValueError(f"Unknown operation type: {type(op)}")
return results, put_ops, search_ops
def _apply_put_ops(self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]) -> None:
for (namespace, key), op in put_ops.items():
if op.value is None:
self._data[namespace].pop(key, None)
self._vectors[namespace].pop(key, None)
else:
self._data[namespace][key] = Item(
value=op.value,
key=key,
namespace=namespace,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
def _extract_texts(
self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]
) -> dict[str, list[tuple[tuple[str, ...], str, str]]]:
if put_ops and self.index_config and self.embeddings:
to_embed = defaultdict(list)
for op in put_ops.values():
if op.value is not None and op.index is not False:
if op.index is None:
paths = self.index_config["__tokenized_fields"]
else:
paths = [(ix, tokenize_path(ix)) for ix in op.index]
for path, field in paths:
texts = get_text_at_path(op.value, field)
if texts:
if len(texts) > 1:
for i, text in enumerate(texts):
to_embed[text].append(
(op.namespace, op.key, f"{path}.{i}")
)
else:
to_embed[texts[0]].append((op.namespace, op.key, path))
return to_embed
return {}
def _insertinmem_store(
self,
to_embed: dict[str, list[tuple[tuple[str, ...], str, str]]],
embeddings: list[list[float]],
) -> None:
indices = [index for indices in to_embed.values() for index in indices]
if len(indices) != len(embeddings):
raise ValueError(
f"Number of embeddings ({len(embeddings)}) does not"
f" match number of indices ({len(indices)})"
)
for embedding, (ns, key, path) in zip(embeddings, indices):
self._vectors[ns][key][path] = embedding
def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]:
all_namespaces = list(
self._data.keys()
) # Avoid collection size changing while iterating
namespaces = all_namespaces
if op.match_conditions:
namespaces = [
ns
for ns in namespaces
if all(_does_match(condition, ns) for condition in op.match_conditions)
]
if op.max_depth is not None:
namespaces = sorted({ns[: op.max_depth] for ns in namespaces})
else:
namespaces = sorted(namespaces)
return namespaces[op.offset : op.offset + op.limit]
@functools.lru_cache(maxsize=1)
def _check_numpy() -> bool:
if bool(util.find_spec("numpy")):
return True
logger.warning(
"NumPy not found in the current Python environment. "
"The InMemoryStore will use a pure Python implementation for vector operations, "
"which may significantly impact performance, especially for large datasets or frequent searches. "
"For optimal speed and efficiency, consider installing NumPy: "
"pip install numpy"
)
return False
def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute cosine similarity between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""
if not Y:
return []
if _check_numpy():
import numpy as np # type: ignore
X_arr = np.array(X) if not isinstance(X, np.ndarray) else X
Y_arr = np.array(Y) if not isinstance(Y, np.ndarray) else Y
X_norm = np.linalg.norm(X_arr)
Y_norm = np.linalg.norm(Y_arr, axis=1)
# Avoid division by zero
mask = Y_norm != 0
similarities = np.zeros_like(Y_norm)
similarities[mask] = np.dot(Y_arr[mask], X_arr) / (Y_norm[mask] * X_norm)
return similarities.tolist()
similarities = []
for y in Y:
dot_product = sum(a * b for a, b in zip(X, y))
norm1 = sum(a * a for a in X) ** 0.5
norm2 = sum(a * a for a in y) ** 0.5
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
similarities.append(similarity)
return similarities
def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool:
"""Whether a namespace key matches a match condition."""
match_type = match_condition.match_type
path = match_condition.path
if len(key) < len(path):
return False
if match_type == "prefix":
for k_elem, p_elem in zip(key, path):
if p_elem == "*":
continue # Wildcard matches any element
if k_elem != p_elem:
return False
return True
elif match_type == "suffix":
for k_elem, p_elem in zip(reversed(key), reversed(path)):
if p_elem == "*":
continue # Wildcard matches any element
if k_elem != p_elem:
return False
return True
else:
raise ValueError(f"Unsupported match type: {match_type}")
def _compare_values(item_value: Any, filter_value: Any) -> bool:
"""Compare values in a JSONB-like way, handling nested objects."""
if isinstance(filter_value, dict):
if any(k.startswith("$") for k in filter_value):
return all(
_apply_operator(item_value, op_key, op_value)
for op_key, op_value in filter_value.items()
)
if not isinstance(item_value, dict):
return False
return all(
_compare_values(item_value.get(k), v) for k, v in filter_value.items()
)
elif isinstance(filter_value, (list, tuple)):
return (
isinstance(item_value, (list, tuple))
and len(item_value) == len(filter_value)
and all(_compare_values(iv, fv) for iv, fv in zip(item_value, filter_value))
)
else:
return item_value == filter_value
def _apply_operator(value: Any, operator: str, op_value: Any) -> bool:
"""Apply a comparison operator, matching PostgreSQL's JSONB behavior."""
if operator == "$eq":
return value == op_value
elif operator == "$gt":
return float(value) > float(op_value)
elif operator == "$gte":
return float(value) >= float(op_value)
elif operator == "$lt":
return float(value) < float(op_value)
elif operator == "$lte":
return float(value) <= float(op_value)
elif operator == "$ne":
return value != op_value
else:
raise ValueError(f"Unsupported operator: {operator}")
```
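For reference, a short usage sketch of the filter operators and namespace listing implemented above (`_compare_values`, `_apply_operator`, `_handle_list_namespaces`); the namespaces, keys, and values are made up for illustration.

```python
from langgraph.store.memory import InMemoryStore

store = InMemoryStore()
store.put(("reviews", "user1"), "r1", {"score": 4.5, "status": "active"})
store.put(("reviews", "user1"), "r2", {"score": 2.0, "status": "active"})
store.put(("reviews", "user2"), "r3", {"score": 5.0, "status": "archived"})

# Plain values are matched by equality; operator dicts go through _apply_operator.
high_active = store.search(
    ("reviews",),
    filter={"status": "active", "score": {"$gte": 4.0}},
)
assert [item.key for item in high_active] == ["r1"]

# Prefix matching and sorting happen in _handle_list_namespaces.
assert store.list_namespaces(prefix=("reviews",)) == [
    ("reviews", "user1"),
    ("reviews", "user2"),
]
```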
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/embed.py`:
```py
"""Utilities for working with embedding functions and LangChain's Embeddings interface.
This module provides tools to wrap arbitrary embedding functions (both sync and async)
into LangChain's Embeddings interface. This enables using custom embedding functions
with LangChain-compatible tools while maintaining support for both synchronous and
asynchronous operations.
"""
import asyncio
import json
from typing import Any, Awaitable, Callable, Optional, Sequence, Union
from langchain_core.embeddings import Embeddings
EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]]
"""Type for synchronous embedding functions.
The function should take a sequence of strings and return a list of embeddings,
where each embedding is a list of floats. The dimensionality of the embeddings
should be consistent for all inputs.
"""
AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]]
"""Type for asynchronous embedding functions.
Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings.
"""
def ensure_embeddings(
embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, None],
) -> Embeddings:
"""Ensure that an embedding function conforms to LangChain's Embeddings interface.
This function wraps arbitrary embedding functions to make them compatible with
LangChain's Embeddings interface. It handles both synchronous and asynchronous
functions.
Args:
embed: Either an existing Embeddings instance, or a function that converts
text to embeddings. If the function is async, it will be used for both
sync and async operations.
Returns:
An Embeddings instance that wraps the provided function(s).
??? example "Examples"
Wrap a synchronous embedding function:
```python
def my_embed_fn(texts):
return [[0.1, 0.2] for _ in texts]
embeddings = ensure_embeddings(my_embed_fn)
result = embeddings.embed_query("hello") # Returns [0.1, 0.2]
```
Wrap an asynchronous embedding function:
```python
async def my_async_fn(texts):
return [[0.1, 0.2] for _ in texts]
embeddings = ensure_embeddings(my_async_fn)
result = await embeddings.aembed_query("hello") # Returns [0.1, 0.2]
```
"""
if embed is None:
raise ValueError("embed must be provided")
if isinstance(embed, Embeddings):
return embed
return EmbeddingsLambda(embed)
class EmbeddingsLambda(Embeddings):
"""Wrapper to convert embedding functions into LangChain's Embeddings interface.
This class allows arbitrary embedding functions to be used with LangChain-compatible
tools. It supports both synchronous and asynchronous operations, and can handle:
1. A synchronous function for sync operations (async operations will use sync function)
2. An async function for async operations (sync operations will raise an error)
The embedding functions should convert text into fixed-dimensional vectors that
capture the semantic meaning of the text.
Args:
func: Function that converts text to embeddings. Can be sync or async.
If async, it will be used for async operations, but sync operations
will raise an error. If sync, it will be used for both sync and async operations.
??? example "Examples"
With a sync function:
```python
def my_embed_fn(texts):
# Return 2D embeddings for each text
return [[0.1, 0.2] for _ in texts]
embeddings = EmbeddingsLambda(my_embed_fn)
result = embeddings.embed_query("hello") # Returns [0.1, 0.2]
await embeddings.aembed_query("hello") # Also returns [0.1, 0.2]
```
With an async function:
```python
async def my_async_fn(texts):
return [[0.1, 0.2] for _ in texts]
embeddings = EmbeddingsLambda(my_async_fn)
await embeddings.aembed_query("hello") # Returns [0.1, 0.2]
# Note: embed_query() would raise an error
```
"""
def __init__(
self,
func: Union[EmbeddingsFunc, AEmbeddingsFunc],
) -> None:
if func is None:
raise ValueError("func must be provided")
if _is_async_callable(func):
self.afunc = func
else:
self.func = func
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts into vectors.
Args:
texts: list of texts to convert to embeddings.
Returns:
list of embeddings, one per input text. Each embedding is a list of floats.
Raises:
ValueError: If the instance was initialized with only an async function.
"""
func = getattr(self, "func", None)
if func is None:
raise ValueError(
"EmbeddingsLambda was initialized with an async function but no sync function. "
"Use aembed_documents for async operation or provide a sync function."
)
return func(texts)
def embed_query(self, text: str) -> list[float]:
"""Embed a single piece of text.
Args:
text: Text to convert to an embedding.
Returns:
Embedding vector as a list of floats.
Note:
This is equivalent to calling embed_documents with a single text
and taking the first result.
"""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronously embed a list of texts into vectors.
Args:
texts: list of texts to convert to embeddings.
Returns:
list of embeddings, one per input text. Each embedding is a list of floats.
Note:
If no async function was provided, this falls back to the sync implementation.
"""
afunc = getattr(self, "afunc", None)
if afunc is None:
return await super().aembed_documents(texts)
return await afunc(texts)
async def aembed_query(self, text: str) -> list[float]:
"""Asynchronously embed a single piece of text.
Args:
text: Text to convert to an embedding.
Returns:
Embedding vector as a list of floats.
Note:
This is equivalent to calling aembed_documents with a single text
and taking the first result.
"""
afunc = getattr(self, "afunc", None)
if afunc is None:
return await super().aembed_query(text)
return (await afunc([text]))[0]
def get_text_at_path(obj: Any, path: Union[str, list[str]]) -> list[str]:
"""Extract text from an object using a path expression or pre-tokenized path.
Args:
obj: The object to extract text from
path: Either a path string or pre-tokenized path list.
!!! info "Path types handled"
- Simple paths: "field1.field2"
- Array indexing: "[0]", "[*]", "[-1]"
- Wildcards: "*"
- Multi-field selection: "{field1,field2}"
- Nested paths in multi-field: "{field1,nested.field2}"
"""
if not path or path == "$":
return [json.dumps(obj, sort_keys=True)]
tokens = tokenize_path(path) if isinstance(path, str) else path
def _extract_from_obj(obj: Any, tokens: list[str], pos: int) -> list[str]:
if pos >= len(tokens):
if isinstance(obj, (str, int, float, bool)):
return [str(obj)]
elif obj is None:
return []
elif isinstance(obj, (list, dict)):
return [json.dumps(obj, sort_keys=True)]
return []
token = tokens[pos]
results = []
if token.startswith("[") and token.endswith("]"):
if not isinstance(obj, list):
return []
index = token[1:-1]
if index == "*":
for item in obj:
results.extend(_extract_from_obj(item, tokens, pos + 1))
else:
try:
idx = int(index)
if idx < 0:
idx = len(obj) + idx
if 0 <= idx < len(obj):
results.extend(_extract_from_obj(obj[idx], tokens, pos + 1))
except (ValueError, IndexError):
return []
elif token.startswith("{") and token.endswith("}"):
if not isinstance(obj, dict):
return []
fields = [f.strip() for f in token[1:-1].split(",")]
for field in fields:
nested_tokens = tokenize_path(field)
if nested_tokens:
current_obj: Optional[dict] = obj
for nested_token in nested_tokens:
if (
isinstance(current_obj, dict)
and nested_token in current_obj
):
current_obj = current_obj[nested_token]
else:
current_obj = None
break
if current_obj is not None:
if isinstance(current_obj, (str, int, float, bool)):
results.append(str(current_obj))
elif isinstance(current_obj, (list, dict)):
results.append(json.dumps(current_obj, sort_keys=True))
# Handle wildcard
elif token == "*":
if isinstance(obj, dict):
for value in obj.values():
results.extend(_extract_from_obj(value, tokens, pos + 1))
elif isinstance(obj, list):
for item in obj:
results.extend(_extract_from_obj(item, tokens, pos + 1))
# Handle regular field
else:
if isinstance(obj, dict) and token in obj:
results.extend(_extract_from_obj(obj[token], tokens, pos + 1))
return results
return _extract_from_obj(obj, tokens, 0)
# Private utility functions
def tokenize_path(path: str) -> list[str]:
"""Tokenize a path into components.
!!! info "Types handled"
- Simple paths: "field1.field2"
- Array indexing: "[0]", "[*]", "[-1]"
- Wildcards: "*"
- Multi-field selection: "{field1,field2}"
"""
if not path:
return []
tokens = []
current: list[str] = []
i = 0
while i < len(path):
char = path[i]
if char == "[": # Handle array index
if current:
tokens.append("".join(current))
current = []
bracket_count = 1
index_chars = ["["]
i += 1
while i < len(path) and bracket_count > 0:
if path[i] == "[":
bracket_count += 1
elif path[i] == "]":
bracket_count -= 1
index_chars.append(path[i])
i += 1
tokens.append("".join(index_chars))
continue
elif char == "{": # Handle multi-field selection
if current:
tokens.append("".join(current))
current = []
brace_count = 1
field_chars = ["{"]
i += 1
while i < len(path) and brace_count > 0:
if path[i] == "{":
brace_count += 1
elif path[i] == "}":
brace_count -= 1
field_chars.append(path[i])
i += 1
tokens.append("".join(field_chars))
continue
elif char == ".": # Handle regular field
if current:
tokens.append("".join(current))
current = []
else:
current.append(char)
i += 1
if current:
tokens.append("".join(current))
return tokens
def _is_async_callable(
func: Any,
) -> bool:
"""Check if a function is async.
This includes both async def functions and classes with async __call__ methods.
Args:
func: Function or callable object to check.
Returns:
True if the function is async, False otherwise.
"""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__") # noqa: B004
and asyncio.iscoroutinefunction(func.__call__)
)
__all__ = [
"ensure_embeddings",
"EmbeddingsFunc",
"AEmbeddingsFunc",
]
```
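Two quick illustrations of the helpers above: `ensure_embeddings` wrapping a plain sync function, and `get_text_at_path` with nested and array path syntax. The embedding function and document contents are invented for the example.

```python
from langgraph.store.base.embed import ensure_embeddings, get_text_at_path

# Wrap a toy sync function; embed_query delegates to embed_documents([text])[0].
embeddings = ensure_embeddings(lambda texts: [[float(len(t)), 0.0] for t in texts])
assert embeddings.embed_query("hello") == [5.0, 0.0]

doc = {
    "metadata": {"title": "LangGraph store"},
    "sections": [
        {"paragraphs": [{"text": "first"}, {"text": "second"}]},
        {"paragraphs": [{"text": "third"}]},
    ],
}

# Nested field access via dot notation.
assert get_text_at_path(doc, "metadata.title") == ["LangGraph store"]

# "[*]" fans out over every array element at each level.
assert get_text_at_path(doc, "sections[*].paragraphs[*].text") == [
    "first",
    "second",
    "third",
]
```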
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/batch.py`:
```py
import asyncio
import weakref
from typing import Any, Literal, Optional, Union
from langgraph.store.base import (
BaseStore,
GetOp,
Item,
ListNamespacesOp,
MatchCondition,
NamespacePath,
Op,
PutOp,
SearchItem,
SearchOp,
_validate_namespace,
)
class AsyncBatchedBaseStore(BaseStore):
"""Efficiently batch operations in a background task."""
__slots__ = ("_loop", "_aqueue", "_task")
def __init__(self) -> None:
self._loop = asyncio.get_running_loop()
self._aqueue: dict[asyncio.Future, Op] = {}
self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
def __del__(self) -> None:
self._task.cancel()
async def aget(
self,
namespace: tuple[str, ...],
key: str,
) -> Optional[Item]:
fut = self._loop.create_future()
self._aqueue[fut] = GetOp(namespace, key)
return await fut
async def asearch(
self,
namespace_prefix: tuple[str, ...],
/,
*,
query: Optional[str] = None,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
) -> list[SearchItem]:
fut = self._loop.create_future()
self._aqueue[fut] = SearchOp(namespace_prefix, filter, limit, offset, query)
return await fut
async def aput(
self,
namespace: tuple[str, ...],
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
_validate_namespace(namespace)
fut = self._loop.create_future()
self._aqueue[fut] = PutOp(namespace, key, value, index)
return await fut
async def adelete(
self,
namespace: tuple[str, ...],
key: str,
) -> None:
fut = self._loop.create_future()
self._aqueue[fut] = PutOp(namespace, key, None)
return await fut
async def alist_namespaces(
self,
*,
prefix: Optional[NamespacePath] = None,
suffix: Optional[NamespacePath] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> list[tuple[str, ...]]:
fut = self._loop.create_future()
match_conditions = []
if prefix:
match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
if suffix:
match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
op = ListNamespacesOp(
match_conditions=tuple(match_conditions),
max_depth=max_depth,
limit=limit,
offset=offset,
)
self._aqueue[fut] = op
return await fut
def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]:
"""Dedupe operations while preserving order for results.
Args:
values: List of operations to dedupe
Returns:
Tuple of (listen indices, deduped operations)
where listen indices map deduped operation results back to original positions
"""
if len(values) <= 1:
return None, list(values)
dedupped: list[Op] = []
listen: list[int] = []
puts: dict[tuple[tuple[str, ...], str], int] = {}
for op in values:
if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)):
try:
listen.append(dedupped.index(op))
except ValueError:
listen.append(len(dedupped))
dedupped.append(op)
elif isinstance(op, PutOp):
putkey = (op.namespace, op.key)
if putkey in puts:
# Overwrite previous put
ix = puts[putkey]
dedupped[ix] = op
listen.append(ix)
else:
puts[putkey] = len(dedupped)
listen.append(len(dedupped))
dedupped.append(op)
else: # Any new ops will be treated regularly
listen.append(len(dedupped))
dedupped.append(op)
return listen, dedupped
async def _run(
aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore]
) -> None:
while True:
await asyncio.sleep(0)
if not aqueue:
continue
if s := store():
# get the operations to run
taken = aqueue.copy()
# action each operation
try:
values = list(taken.values())
listen, dedupped = _dedupe_ops(values)
results = await s.abatch(dedupped)
if listen is not None:
results = [results[ix] for ix in listen]
# set the results of each operation
for fut, result in zip(taken, results):
fut.set_result(result)
except Exception as e:
for fut in taken:
fut.set_exception(e)
# remove the operations from the queue
for fut in taken:
del aqueue[fut]
else:
break
# remove strong ref to store
del s
```
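A minimal sketch of how `AsyncBatchedBaseStore` is meant to be used: subclass it with concrete `batch`/`abatch` implementations, and concurrent awaits are funneled through the background `_run` task, deduped by `_dedupe_ops`, and executed as a single `abatch` call. `BatchedInMemoryStore` is a hypothetical name for this sketch; it simply delegates to `InMemoryStore` and is not part of the library.

```python
import asyncio
from typing import Iterable

from langgraph.store.base import Op, Result
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.memory import InMemoryStore


class BatchedInMemoryStore(AsyncBatchedBaseStore):
    """Hypothetical: run queued ops against a plain InMemoryStore."""

    def __init__(self) -> None:
        super().__init__()  # captures the running loop and starts the _run task
        self._inner = InMemoryStore()

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        return self._inner.batch(ops)

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        return await self._inner.abatch(ops)


async def main() -> None:
    # Must be constructed inside a running event loop (asyncio.get_running_loop()).
    store = BatchedInMemoryStore()
    await store.aput(("docs",), "doc1", {"text": "hello"})
    # Both gets are queued together; _dedupe_ops collapses the identical GetOps
    # so the inner store sees a single operation, and both futures get the result.
    first, second = await asyncio.gather(
        store.aget(("docs",), "doc1"),
        store.aget(("docs",), "doc1"),
    )
    assert first.value == second.value == {"text": "hello"}


asyncio.run(main())
```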
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/langgraph/store/base/__init__.py`:
```py
"""Base classes and types for persistent key-value stores.
Stores provide long-term memory that persists across threads and conversations.
Supports hierarchical namespaces, key-value storage, and optional vector search.
Core types:
- BaseStore: Store interface with sync/async operations
- Item: Stored key-value pairs with metadata
- Op: Get/Put/Search/List operations
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Iterable, Literal, NamedTuple, Optional, TypedDict, Union, cast
from langchain_core.embeddings import Embeddings
from langgraph.store.base.embed import (
AEmbeddingsFunc,
EmbeddingsFunc,
ensure_embeddings,
get_text_at_path,
tokenize_path,
)
class Item:
"""Represents a stored item with metadata.
Args:
value (dict[str, Any]): The stored data as a dictionary. Keys are filterable.
key (str): Unique identifier within the namespace.
namespace (tuple[str, ...]): Hierarchical path defining the collection in which this document resides.
Represented as a tuple of strings, allowing for nested categorization.
For example: ("documents", "user123")
created_at (datetime): Timestamp of item creation.
updated_at (datetime): Timestamp of last update.
"""
__slots__ = ("value", "key", "namespace", "created_at", "updated_at")
def __init__(
self,
*,
value: dict[str, Any],
key: str,
namespace: tuple[str, ...],
created_at: datetime,
updated_at: datetime,
):
self.value = value
self.key = key
# The casting from json-like types is for if this object is
# deserialized.
self.namespace = tuple(namespace)
self.created_at = (
datetime.fromisoformat(cast(str, created_at))
if isinstance(created_at, str)
else created_at
)
self.updated_at = (
datetime.fromisoformat(cast(str, updated_at))
if isinstance(updated_at, str)
else updated_at
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Item):
return False
return (
self.value == other.value
and self.key == other.key
and self.namespace == other.namespace
and self.created_at == other.created_at
and self.updated_at == other.updated_at
)
def __hash__(self) -> int:
return hash((self.namespace, self.key))
def dict(self) -> dict:
return {
"value": self.value,
"key": self.key,
"namespace": list(self.namespace),
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
class SearchItem(Item):
"""Represents an item returned from a search operation with additional metadata."""
__slots__ = ("score",)
def __init__(
self,
namespace: tuple[str, ...],
key: str,
value: dict[str, Any],
created_at: datetime,
updated_at: datetime,
score: Optional[float] = None,
) -> None:
"""Initialize a result item.
Args:
namespace: Hierarchical path to the item.
key: Unique identifier within the namespace.
value: The stored value.
created_at: When the item was first created.
updated_at: When the item was last updated.
score: Relevance/similarity score if from a ranked operation.
"""
super().__init__(
value=value,
key=key,
namespace=namespace,
created_at=created_at,
updated_at=updated_at,
)
self.score = score
def dict(self) -> dict:
result = super().dict()
result["score"] = self.score
return result
class GetOp(NamedTuple):
"""Operation to retrieve a specific item by its namespace and key.
This operation allows precise retrieval of stored items using their full path
(namespace) and unique identifier (key) combination.
???+ example "Examples"
Basic item retrieval:
```python
GetOp(namespace=("users", "profiles"), key="user123")
GetOp(namespace=("cache", "embeddings"), key="doc456")
```
"""
namespace: tuple[str, ...]
"""Hierarchical path that uniquely identifies the item's location.
???+ example "Examples"
```python
("users",) # Root level users namespace
("users", "profiles") # Profiles within users namespace
```
"""
key: str
"""Unique identifier for the item within its specific namespace.
???+ example "Examples"
```python
"user123" # For a user profile
"doc456" # For a document
```
"""
class SearchOp(NamedTuple):
"""Operation to search for items within a specified namespace hierarchy.
This operation supports both structured filtering and natural language search
within a given namespace prefix. It provides pagination through limit and offset
parameters.
Note:
Natural language search support depends on your store implementation.
???+ example "Examples"
Search with filters and pagination:
```python
SearchOp(
namespace_prefix=("documents",),
filter={"type": "report", "status": "active"},
limit=5,
offset=10
)
```
Natural language search:
```python
SearchOp(
namespace_prefix=("users", "content"),
query="technical documentation about APIs",
limit=20
)
```
"""
namespace_prefix: tuple[str, ...]
"""Hierarchical path prefix defining the search scope.
???+ example "Examples"
```python
() # Search entire store
("documents",) # Search all documents
("users", "content") # Search within user content
```
"""
filter: Optional[dict[str, Any]] = None
"""Key-value pairs for filtering results based on exact matches or comparison operators.
The filter supports both exact matches and operator-based comparisons.
Supported Operators:
- $eq: Equal to (same as direct value comparison)
- $ne: Not equal to
- $gt: Greater than
- $gte: Greater than or equal to
- $lt: Less than
- $lte: Less than or equal to
???+ example "Examples"
Simple exact match:
```python
{"status": "active"}
```
Comparison operators:
```python
{"score": {"$gt": 4.99}} # Score greater than 4.99
```
Multiple conditions:
```python
{
"score": {"$gte": 3.0},
"color": "red"
}
```
"""
limit: int = 10
"""Maximum number of items to return in the search results."""
offset: int = 0
"""Number of matching items to skip for pagination."""
query: Optional[str] = None
"""Natural language search query for semantic search capabilities.
???+ example "Examples"
- "technical documentation about REST APIs"
- "machine learning papers from 2023"
"""
# Type representing a namespace path that can include wildcards
NamespacePath = tuple[Union[str, Literal["*"]], ...]
"""A tuple representing a namespace path that can include wildcards.
???+ example "Examples"
```python
("users",) # Exact users namespace
("documents", "*") # Any sub-namespace under documents
("cache", "*", "v1") # Any cache category with v1 version
```
"""
# Type for specifying how to match namespaces
NamespaceMatchType = Literal["prefix", "suffix"]
"""Specifies how to match namespace paths.
Values:
"prefix": Match from the start of the namespace
"suffix": Match from the end of the namespace
"""
class MatchCondition(NamedTuple):
"""Represents a pattern for matching namespaces in the store.
This class combines a match type (prefix or suffix) with a namespace path
pattern that can include wildcards to flexibly match different namespace
hierarchies.
???+ example "Examples"
Prefix matching:
```python
MatchCondition(match_type="prefix", path=("users", "profiles"))
```
Suffix matching with wildcard:
```python
MatchCondition(match_type="suffix", path=("cache", "*"))
```
Simple suffix matching:
```python
MatchCondition(match_type="suffix", path=("v1",))
```
"""
match_type: NamespaceMatchType
"""Type of namespace matching to perform."""
path: NamespacePath
"""Namespace path pattern that can include wildcards."""
class ListNamespacesOp(NamedTuple):
"""Operation to list and filter namespaces in the store.
This operation allows exploring the organization of data, finding specific
collections, and navigating the namespace hierarchy.
???+ example "Examples"
List all namespaces under the "documents" path:
```python
ListNamespacesOp(
match_conditions=(MatchCondition(match_type="prefix", path=("documents",)),),
max_depth=2
)
```
List all namespaces that end with "v1":
```python
ListNamespacesOp(
match_conditions=(MatchCondition(match_type="suffix", path=("v1",)),),
limit=50
)
```
"""
match_conditions: Optional[tuple[MatchCondition, ...]] = None
"""Optional conditions for filtering namespaces.
???+ example "Examples"
All user namespaces:
```python
(MatchCondition(match_type="prefix", path=("users",)),)
```
All namespaces that start with "docs" and end with "draft":
```python
(
MatchCondition(match_type="prefix", path=("docs",)),
MatchCondition(match_type="suffix", path=("draft",))
)
```
"""
max_depth: Optional[int] = None
"""Maximum depth of namespace hierarchy to return.
Note:
Namespaces deeper than this level will be truncated.
"""
limit: int = 100
"""Maximum number of namespaces to return."""
offset: int = 0
"""Number of namespaces to skip for pagination."""
class PutOp(NamedTuple):
"""Operation to store, update, or delete an item in the store.
This class represents a single operation to modify the store's contents,
whether adding new items, updating existing ones, or removing them.
"""
namespace: tuple[str, ...]
"""Hierarchical path that identifies the location of the item.
The namespace acts as a folder-like structure to organize items.
Each element in the tuple represents one level in the hierarchy.
???+ example "Examples"
Root level documents
```python
("documents",)
```
User-specific documents
```python
("documents", "user123")
```
Nested cache structure
```python
("cache", "embeddings", "v1")
```
"""
key: str
"""Unique identifier for the item within its namespace.
The key must be unique within the specific namespace to avoid conflicts.
Together with the namespace, it forms a complete path to the item.
Example:
If namespace is ("documents", "user123") and key is "report1",
the full path would effectively be "documents/user123/report1"
"""
value: Optional[dict[str, Any]]
"""The data to store, or None to mark the item for deletion.
The value must be a dictionary with string keys and JSON-serializable values.
Setting this to None signals that the item should be deleted.
Example:
{
"field1": "string value",
"field2": 123,
"nested": {"can": "contain", "any": "serializable data"}
}
"""
index: Optional[Union[Literal[False], list[str]]] = None # type: ignore[assignment]
"""Controls how the item's fields are indexed for search operations.
Indexing configuration determines how the item can be found through search:
- None (default): Uses the store's default indexing configuration (if provided)
- False: Disables indexing for this item
- list[str]: Specifies which json path fields to index for search
The item remains accessible through direct get() operations regardless of indexing.
When indexed, fields can be searched using natural language queries through
vector similarity search (if supported by the store implementation).
Path Syntax:
- Simple field access: "field"
- Nested fields: "parent.child.grandchild"
- Array indexing:
- Specific index: "array[0]"
- Last element: "array[-1]"
- All elements (each individually): "array[*]"
???+ example "Examples"
- None - Use store defaults (whole item)
- list[str] - List of fields to index
```python
[
"metadata.title", # Nested field access
"context[*].content", # Index content from all context as separate vectors
"authors[0].name", # First author's name
"revisions[-1].changes", # Most recent revision's changes
"sections[*].paragraphs[*].text", # All text from all paragraphs in all sections
"metadata.tags[*]", # All tags in metadata
]
```
"""
Op = Union[GetOp, SearchOp, PutOp, ListNamespacesOp]
Result = Union[Item, list[Item], list[SearchItem], list[tuple[str, ...]], None]
class InvalidNamespaceError(ValueError):
"""Provided namespace is invalid."""
class IndexConfig(TypedDict, total=False):
"""Configuration for indexing documents for semantic search in the store.
If not provided to the store, the store will not support vector search.
In that case, all `index` arguments to `put()` and `aput()` operations will be ignored.
"""
dims: int
"""Number of dimensions in the embedding vectors.
Common embedding models have the following dimensions:
- openai:text-embedding-3-large: 3072
- openai:text-embedding-3-small: 1536
- openai:text-embedding-ada-002: 1536
- cohere:embed-english-v3.0: 1024
- cohere:embed-english-light-v3.0: 384
- cohere:embed-multilingual-v3.0: 1024
- cohere:embed-multilingual-light-v3.0: 384
"""
embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc]
"""Optional function to generate embeddings from text.
Can be specified in three ways:
1. A LangChain Embeddings instance
2. A synchronous embedding function (EmbeddingsFunc)
3. An asynchronous embedding function (AEmbeddingsFunc)
???+ example "Examples"
Using LangChain's initialization with InMemoryStore:
```python
from langchain.embeddings import init_embeddings
from langgraph.store.memory import InMemoryStore
store = InMemoryStore(
index={
"dims": 1536,
"embed": init_embeddings("openai:text-embedding-3-small")
}
)
```
Using a custom embedding function with InMemoryStore:
```python
from openai import OpenAI
from langgraph.store.memory import InMemoryStore
client = OpenAI()
def embed_texts(texts: list[str]) -> list[list[float]]:
response = client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": embed_texts
}
)
```
Using an asynchronous embedding function with InMemoryStore:
```python
from openai import AsyncOpenAI
from langgraph.store.memory import InMemoryStore
client = AsyncOpenAI()
async def aembed_texts(texts: list[str]) -> list[list[float]]:
response = await client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": aembed_texts
}
)
```
"""
fields: Optional[list[str]]
"""Fields to extract text from for embedding generation.
Controls which parts of stored items are embedded for semantic search. Follows JSON path syntax:
- ["$"]: Embeds the entire JSON object as one vector (default)
- ["field1", "field2"]: Embeds specific top-level fields
- ["parent.child"]: Embeds nested fields using dot notation
- ["array[*].field"]: Embeds field from each array element separately
Note:
You can always override this behavior when storing an item using the
`index` parameter in the `put` or `aput` operations.
???+ example "Examples"
```python
# Embed entire document (default)
fields=["$"]
# Embed specific fields
fields=["text", "summary"]
# Embed nested fields
fields=["metadata.title", "content.body"]
# Embed from arrays
fields=["messages[*].content"] # Each message content separately
fields=["context[0].text"] # First context item's text
```
Note:
- Fields missing from a document are skipped
- Array notation creates separate embeddings for each element
- Complex nested paths are supported (e.g., "a.b[*].c.d")
"""
class BaseStore(ABC):
"""Abstract base class for persistent key-value stores.
Stores enable persistence and memory that can be shared across threads,
scoped to user IDs, assistant IDs, or other arbitrary namespaces.
Some implementations may support semantic search capabilities through
an optional `index` configuration.
Note:
Semantic search capabilities vary by implementation and are typically
disabled by default. Stores that support this feature can be configured
by providing an `index` configuration at creation time. Without this
configuration, semantic search is disabled and any `index` arguments
to storage operations will have no effect.
"""
__slots__ = ("__weakref__",)
@abstractmethod
def batch(self, ops: Iterable[Op]) -> list[Result]:
"""Execute multiple operations synchronously in a single batch.
Args:
ops: An iterable of operations to execute.
Returns:
A list of results, where each result corresponds to an operation in the input.
The order of results matches the order of input operations.
"""
@abstractmethod
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
"""Execute multiple operations asynchronously in a single batch.
Args:
ops: An iterable of operations to execute.
Returns:
A list of results, where each result corresponds to an operation in the input.
The order of results matches the order of input operations.
"""
def get(self, namespace: tuple[str, ...], key: str) -> Optional[Item]:
"""Retrieve a single item.
Args:
namespace: Hierarchical path for the item.
key: Unique identifier within the namespace.
Returns:
The retrieved item or None if not found.
"""
return self.batch([GetOp(namespace, key)])[0]
def search(
self,
namespace_prefix: tuple[str, ...],
/,
*,
query: Optional[str] = None,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
) -> list[SearchItem]:
"""Search for items within a namespace prefix.
Args:
namespace_prefix: Hierarchical path prefix to search within.
query: Optional query for natural language search.
filter: Key-value pairs to filter results.
limit: Maximum number of items to return.
offset: Number of items to skip before returning results.
Returns:
List of items matching the search criteria.
???+ example "Examples"
Basic filtering:
```python
# Search for documents with specific metadata
results = store.search(
("docs",),
filter={"type": "article", "status": "published"}
)
```
Natural language search (requires vector store implementation):
```python
# Initialize store with embedding configuration
store = YourStore( # e.g., InMemoryStore, AsyncPostgresStore
index={
"dims": 1536, # embedding dimensions
"embed": your_embedding_function, # function to create embeddings
"fields": ["text"] # fields to embed. Defaults to ["$"]
}
)
# Search for semantically similar documents
results = store.search(
("docs",),
query="machine learning applications in healthcare",
filter={"type": "research_paper"},
limit=5
)
```
Note: Natural language search support depends on your store implementation
and requires proper embedding configuration.
"""
return self.batch([SearchOp(namespace_prefix, filter, limit, offset, query)])[0]
def put(
self,
namespace: tuple[str, ...],
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
"""Store or update an item in the store.
Args:
namespace: Hierarchical path for the item, represented as a tuple of strings.
Example: ("documents", "user123")
key: Unique identifier within the namespace. Together with namespace forms
the complete path to the item.
value: Dictionary containing the item's data. Must contain string keys
and JSON-serializable values.
index: Controls how the item's fields are indexed for search:
- None (default): Use `fields` you configured when creating the store (if any)
If you do not initialize the store with indexing capabilities,
the `index` parameter will be ignored
- False: Disable indexing for this item
- list[str]: List of field paths to index, supporting:
- Nested fields: "metadata.title"
- Array access: "chapters[*].content" (each indexed separately)
- Specific indices: "authors[0].name"
Note:
Indexing support depends on your store implementation.
If you do not initialize the store with indexing capabilities,
the `index` parameter will be ignored.
???+ example "Examples"
Store an item. Indexing depends on how you configure the store.
```python
store.put(("docs",), "report", {"memory": "Will likes ai"})
```
Do not index the item for semantic search. It is still accessible through get()
and search() operations but won't have a vector representation.
```python
store.put(("docs",), "report", {"memory": "Will likes ai"}, index=False)
```
Index specific fields for search.
```python
store.put(("docs",), "report", {"memory": "Will likes ai"}, index=["memory"])
```
"""
_validate_namespace(namespace)
self.batch([PutOp(namespace, key, value, index=index)])
def delete(self, namespace: tuple[str, ...], key: str) -> None:
"""Delete an item.
Args:
namespace: Hierarchical path for the item.
key: Unique identifier within the namespace.
"""
self.batch([PutOp(namespace, key, None)])
def list_namespaces(
self,
*,
prefix: Optional[NamespacePath] = None,
suffix: Optional[NamespacePath] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> list[tuple[str, ...]]:
"""List and filter namespaces in the store.
Used to explore the organization of data,
find specific collections, or navigate the namespace hierarchy.
Args:
prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path.
suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path.
max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy.
Namespaces deeper than this level will be truncated to this depth.
limit (int): Maximum number of namespaces to return (default 100).
offset (int): Number of namespaces to skip for pagination (default 0).
Returns:
List[Tuple[str, ...]]: A list of namespace tuples that match the criteria.
Each tuple represents a full namespace path up to `max_depth`.
???+ example "Examples":
Setting max_depth=3. Given the namespaces:
```python
# Example if you have the following namespaces:
# ("a", "b", "c")
# ("a", "b", "d", "e")
# ("a", "b", "d", "i")
# ("a", "b", "f")
# ("a", "c", "f")
store.list_namespaces(prefix=("a", "b"), max_depth=3)
# [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")]
```
"""
match_conditions = []
if prefix:
match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
if suffix:
match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
op = ListNamespacesOp(
match_conditions=tuple(match_conditions),
max_depth=max_depth,
limit=limit,
offset=offset,
)
return self.batch([op])[0]
async def aget(self, namespace: tuple[str, ...], key: str) -> Optional[Item]:
"""Asynchronously retrieve a single item.
Args:
namespace: Hierarchical path for the item.
key: Unique identifier within the namespace.
Returns:
The retrieved item or None if not found.
"""
return (await self.abatch([GetOp(namespace, key)]))[0]
async def asearch(
self,
namespace_prefix: tuple[str, ...],
/,
*,
query: Optional[str] = None,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
) -> list[SearchItem]:
"""Asynchronously search for items within a namespace prefix.
Args:
namespace_prefix: Hierarchical path prefix to search within.
query: Optional query for natural language search.
filter: Key-value pairs to filter results.
limit: Maximum number of items to return.
offset: Number of items to skip before returning results.
Returns:
List of items matching the search criteria.
???+ example "Examples"
Basic filtering:
```python
# Search for documents with specific metadata
results = await store.asearch(
("docs",),
filter={"type": "article", "status": "published"}
)
```
Natural language search (requires vector store implementation):
```python
# Initialize store with embedding configuration
store = YourStore( # e.g., InMemoryStore, AsyncPostgresStore
index={
"dims": 1536, # embedding dimensions
"embed": your_embedding_function, # function to create embeddings
"fields": ["text"] # fields to embed
}
)
# Search for semantically similar documents
results = await store.asearch(
("docs",),
query="machine learning applications in healthcare",
filter={"type": "research_paper"},
limit=5
)
```
Note: Natural language search support depends on your store implementation
and requires proper embedding configuration.
"""
return (
await self.abatch(
[SearchOp(namespace_prefix, filter, limit, offset, query)]
)
)[0]
async def aput(
self,
namespace: tuple[str, ...],
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
"""Asynchronously store or update an item in the store.
Args:
namespace: Hierarchical path for the item, represented as a tuple of strings.
Example: ("documents", "user123")
key: Unique identifier within the namespace. Together with namespace forms
the complete path to the item.
value: Dictionary containing the item's data. Must contain string keys
and JSON-serializable values.
index: Controls how the item's fields are indexed for search:
- None (default): Use the `fields` you configured when creating the store (if any).
If you do not initialize the store with indexing capabilities,
the `index` parameter will be ignored.
- False: Disable indexing for this item
- list[str]: List of field paths to index, supporting:
- Nested fields: "metadata.title"
- Array access: "chapters[*].content" (each indexed separately)
- Specific indices: "authors[0].name"
Note:
Indexing support depends on your store implementation.
If you do not initialize the store with indexing capabilities,
the `index` parameter will be ignored.
???+ example "Examples"
Store an item. Indexing depends on how you configure the store.
```python
await store.aput(("docs",), "report", {"memory": "Will likes ai"})
```
Do not index the item for semantic search. It is still accessible through get()
and search() operations but won't have a vector representation.
```python
await store.aput(("docs",), "report", {"memory": "Will likes ai"}, index=False)
```
Index specific fields for search (if store configured to index items):
```python
await store.aput(
("docs",),
"report",
{
"memory": "Will likes ai",
"context": [{"content": "..."}, {"content": "..."}]
},
index=["memory", "context[*].content"]
)
```
"""
_validate_namespace(namespace)
await self.abatch([PutOp(namespace, key, value, index=index)])
async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
"""Asynchronously delete an item.
Args:
namespace: Hierarchical path for the item.
key: Unique identifier within the namespace.
"""
await self.abatch([PutOp(namespace, key, None)])
async def alist_namespaces(
self,
*,
prefix: Optional[NamespacePath] = None,
suffix: Optional[NamespacePath] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> list[tuple[str, ...]]:
"""List and filter namespaces in the store asynchronously.
Used to explore the organization of data,
find specific collections, or navigate the namespace hierarchy.
Args:
prefix (Optional[Tuple[str, ...]]): Filter namespaces that start with this path.
suffix (Optional[Tuple[str, ...]]): Filter namespaces that end with this path.
max_depth (Optional[int]): Return namespaces up to this depth in the hierarchy.
Namespaces deeper than this level will be truncated to this depth.
limit (int): Maximum number of namespaces to return (default 100).
offset (int): Number of namespaces to skip for pagination (default 0).
Returns:
List[Tuple[str, ...]]: A list of namespace tuples that match the criteria.
Each tuple represents a full namespace path up to `max_depth`.
???+ example "Examples"
Setting max_depth=3 with existing namespaces:
```python
# Given the following namespaces:
# ("a", "b", "c")
# ("a", "b", "d", "e")
# ("a", "b", "d", "i")
# ("a", "b", "f")
# ("a", "c", "f")
await store.alist_namespaces(prefix=("a", "b"), max_depth=3)
# Returns: [("a", "b", "c"), ("a", "b", "d"), ("a", "b", "f")]
```
"""
match_conditions = []
if prefix:
match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
if suffix:
match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
op = ListNamespacesOp(
match_conditions=tuple(match_conditions),
max_depth=max_depth,
limit=limit,
offset=offset,
)
return (await self.abatch([op]))[0]
def _validate_namespace(namespace: tuple[str, ...]) -> None:
if not namespace:
raise InvalidNamespaceError("Namespace cannot be empty.")
for label in namespace:
if not isinstance(label, str):
raise InvalidNamespaceError(
f"Invalid namespace label '{label}' found in {namespace}. Namespace labels"
f" must be strings, but got {type(label).__name__}."
)
if "." in label:
raise InvalidNamespaceError(
f"Invalid namespace label '{label}' found in {namespace}. Namespace labels cannot contain periods ('.')."
)
elif not label:
raise InvalidNamespaceError(
f"Namespace labels cannot be empty strings. Got {label} in {namespace}"
)
if namespace[0] == "langgraph":
raise InvalidNamespaceError(
f'Root label for namespace cannot be "langgraph". Got: {namespace}'
)
__all__ = [
"BaseStore",
"Item",
"Op",
"PutOp",
"GetOp",
"SearchOp",
"ListNamespacesOp",
"MatchCondition",
"NamespacePath",
"NamespaceMatchType",
"Embeddings",
"ensure_embeddings",
"tokenize_path",
"get_text_at_path",
]
```
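The key design point of `BaseStore` is that every public method (`get`, `put`, `delete`, `search`, `list_namespaces`, and their async twins) funnels into the two abstract methods `batch`/`abatch`. The sketch below is illustrative only and not part of the library: a toy `DictStore` that handles just `GetOp` and `PutOp`, which is enough to back `get`, `put`, and `delete`.

```py
# Minimal illustrative subclass: handle GetOp and PutOp only; SearchOp and
# ListNamespacesOp are omitted from this sketch.
from datetime import datetime, timezone
from typing import Iterable

from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, Result


class DictStore(BaseStore):
    def __init__(self) -> None:
        self._items: dict[tuple[tuple[str, ...], str], Item] = {}

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        results: list[Result] = []
        for op in ops:
            if isinstance(op, GetOp):
                results.append(self._items.get((op.namespace, op.key)))
            elif isinstance(op, PutOp):
                if op.value is None:
                    # delete() issues a PutOp with value=None
                    self._items.pop((op.namespace, op.key), None)
                else:
                    now = datetime.now(timezone.utc)
                    self._items[(op.namespace, op.key)] = Item(
                        value=op.value,
                        key=op.key,
                        namespace=op.namespace,
                        created_at=now,
                        updated_at=now,
                    )
                results.append(None)
            else:
                raise NotImplementedError(type(op).__name__)
        return results

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        return self.batch(ops)


store = DictStore()
store.put(("docs",), "report", {"memory": "Will likes ai"})
assert store.get(("docs",), "report").value == {"memory": "Will likes ai"}
store.delete(("docs",), "report")
assert store.get(("docs",), "report") is None
```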
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-checkpoint"
version = "2.0.8"
description = "Library with base interfaces for LangGraph checkpoint savers."
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
langchain-core = ">=0.2.38,<0.4"
msgpack = "^1.1.0"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watcher = "^0.4.1"
mypy = "^1.10.0"
dataclasses-json = "^0.6.7"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file will raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
[tool.pytest-watcher]
now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short"]
patterns = ["*.py"]
[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_jsonplus.py`:
```py
import dataclasses
import pathlib
import re
import sys
import uuid
from collections import deque
from datetime import date, datetime, time, timezone
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address
import dataclasses_json
from pydantic import BaseModel, SecretStr
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import SecretStr as SecretStrV1
from zoneinfo import ZoneInfo
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.store.base import Item
class InnerPydantic(BaseModel):
hello: str
class MyPydantic(BaseModel):
foo: str
bar: int
inner: InnerPydantic
class InnerPydanticV1(BaseModelV1):
hello: str
class MyPydanticV1(BaseModelV1):
foo: str
bar: int
inner: InnerPydanticV1
@dataclasses.dataclass
class InnerDataclass:
hello: str
@dataclasses.dataclass
class MyDataclass:
foo: str
bar: int
inner: InnerDataclass
def something(self) -> None:
pass
if sys.version_info < (3, 10):
class MyDataclassWSlots(MyDataclass):
pass
else:
@dataclasses.dataclass(slots=True)
class MyDataclassWSlots:
foo: str
bar: int
inner: InnerDataclass
def something(self) -> None:
pass
class MyEnum(Enum):
FOO = "foo"
BAR = "bar"
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class Person:
name: str
def test_serde_jsonplus() -> None:
uid = uuid.UUID(int=1)
deque_instance = deque([1, 2, 3])
tzn = ZoneInfo("America/New_York")
ip4 = IPv4Address("192.168.0.1")
current_date = date(2024, 4, 19)
current_time = time(23, 4, 57, 51022, timezone.max)
current_timestamp = datetime(2024, 4, 19, 23, 4, 57, 51022, timezone.max)
to_serialize = {
"path": pathlib.Path("foo", "bar"),
"re": re.compile(r"foo", re.DOTALL),
"decimal": Decimal("1.10101"),
"set": {1, 2, frozenset({1, 2})},
"frozen_set": frozenset({1, 2, 3}),
"ip4": ip4,
"deque": deque_instance,
"tzn": tzn,
"date": current_date,
"time": current_time,
"uid": uid,
"timestamp": current_timestamp,
"my_slotted_class": MyDataclassWSlots("bar", 2, InnerDataclass("hello")),
"my_dataclass": MyDataclass("foo", 1, InnerDataclass("hello")),
"my_enum": MyEnum.FOO,
"my_pydantic": MyPydantic(foo="foo", bar=1, inner=InnerPydantic(hello="hello")),
"my_pydantic_v1": MyPydanticV1(
foo="foo", bar=1, inner=InnerPydanticV1(hello="hello")
),
"my_secret_str": SecretStr("meow"),
"my_secret_str_v1": SecretStrV1("meow"),
"person": Person(name="foo"),
"a_bool": True,
"a_none": None,
"a_str": "foo",
"a_str_nuc": "foo\u0000",
"a_str_uc": "foo ⛰️",
"a_str_ucuc": "foo \u26f0\ufe0f\u0000",
"a_str_ucucuc": "foo \\u26f0\\ufe0f",
"an_int": 1,
"a_float": 1.1,
"a_bytes": b"my bytes",
"a_bytearray": bytearray([42]),
"my_item": Item(
value={},
key="my-key",
namespace=("a", "name", " "),
created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
),
}
serde = JsonPlusSerializer()
dumped = serde.dumps_typed(to_serialize)
assert dumped[0] == "msgpack"
assert serde.loads_typed(dumped) == to_serialize
for value in to_serialize.values():
assert serde.loads_typed(serde.dumps_typed(value)) == value
surrogates = [
"Hello\ud83d\ude00",
"Python\ud83d\udc0d",
"Surrogate\ud834\udd1e",
"Example\ud83c\udf89",
"String\ud83c\udfa7",
"With\ud83c\udf08",
"Surrogates\ud83d\ude0e",
"Embedded\ud83d\udcbb",
"In\ud83c\udf0e",
"The\ud83d\udcd6",
"Text\ud83d\udcac",
"收花🙄·到",
]
assert serde.loads_typed(serde.dumps_typed(surrogates)) == [
v.encode("utf-8", "ignore").decode() for v in surrogates
]
def test_serde_jsonplus_bytes() -> None:
serde = JsonPlusSerializer()
some_bytes = b"my bytes"
dumped = serde.dumps_typed(some_bytes)
assert dumped == ("bytes", some_bytes)
assert serde.loads_typed(dumped) == some_bytes
def test_serde_jsonplus_bytearray() -> None:
serde = JsonPlusSerializer()
some_bytearray = bytearray([42])
dumped = serde.dumps_typed(some_bytearray)
assert dumped == ("bytearray", some_bytearray)
assert serde.loads_typed(dumped) == some_bytearray
def test_loads_cannot_find() -> None:
serde = JsonPlusSerializer()
dumped = (
"json",
b'{"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydanticccc"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}',
)
assert serde.loads_typed(dumped) is None, "Should return None if cannot find class"
dumped = (
"json",
b'{"lc": 2, "type": "constructor", "id": ["tests", "test_jsonpluss", "MyPydantic"], "method": null, "args": [], "kwargs": {"foo": "foo", "bar": 1}}',
)
assert serde.loads_typed(dumped) is None, "Should return None if cannot find module"
```
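As a quick orientation for the round-trip behaviour these tests assert, here is a small hedged sketch: `dumps_typed()` yields a `(type_tag, payload)` pair, non-bytes values are serialized under the `"msgpack"` tag, and raw `bytes` pass through under their own tag. The payload values below are illustrative.

```py
import uuid
from datetime import datetime, timezone

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()

payload = {
    "uid": uuid.UUID(int=1),
    "when": datetime(2024, 4, 19, tzinfo=timezone.utc),
    "tags": frozenset({"a", "b"}),
}
tag, data = serde.dumps_typed(payload)
assert tag == "msgpack"                  # non-bytes values use the msgpack tag
assert serde.loads_typed((tag, data)) == payload

tag, data = serde.dumps_typed(b"raw")
assert (tag, data) == ("bytes", b"raw")  # bytes pass through unchanged
```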
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_store.py`:
```py
# mypy: disable-error-code="operator"
import asyncio
import json
from datetime import datetime
from typing import Any, Iterable
import pytest
from pytest_mock import MockerFixture
from langgraph.store.base import (
GetOp,
InvalidNamespaceError,
Item,
Op,
PutOp,
Result,
get_text_at_path,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.memory import InMemoryStore
from tests.embed_test_utils import CharacterEmbeddings
class MockAsyncBatchedStore(AsyncBatchedBaseStore):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
self._store = InMemoryStore(**kwargs)
def batch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)
def test_get_text_at_path() -> None:
nested_data = {
"name": "test",
"info": {
"age": 25,
"tags": ["a", "b", "c"],
"metadata": {"created": "2024-01-01", "updated": "2024-01-02"},
},
"items": [
{"id": 1, "value": "first", "tags": ["x", "y"]},
{"id": 2, "value": "second", "tags": ["y", "z"]},
{"id": 3, "value": "third", "tags": ["z", "w"]},
],
"empty": None,
"zeros": [0, 0.0, "0"],
"empty_list": [],
"empty_dict": {},
}
assert get_text_at_path(nested_data, "$") == [
json.dumps(nested_data, sort_keys=True)
]
assert get_text_at_path(nested_data, "name") == ["test"]
assert get_text_at_path(nested_data, "info.age") == ["25"]
assert get_text_at_path(nested_data, "info.metadata.created") == ["2024-01-01"]
assert get_text_at_path(nested_data, "items[0].value") == ["first"]
assert get_text_at_path(nested_data, "items[-1].value") == ["third"]
assert get_text_at_path(nested_data, "items[1].tags[0]") == ["y"]
values = get_text_at_path(nested_data, "items[*].value")
assert set(values) == {"first", "second", "third"}
metadata_dates = get_text_at_path(nested_data, "info.metadata.*")
assert set(metadata_dates) == {"2024-01-01", "2024-01-02"}
name_and_age = get_text_at_path(nested_data, "{name,info.age}")
assert set(name_and_age) == {"test", "25"}
item_fields = get_text_at_path(nested_data, "items[*].{id,value}")
assert set(item_fields) == {"1", "2", "3", "first", "second", "third"}
all_tags = get_text_at_path(nested_data, "items[*].tags[*]")
assert set(all_tags) == {"x", "y", "z", "w"}
assert get_text_at_path(None, "any.path") == []
assert get_text_at_path({}, "any.path") == []
assert get_text_at_path(nested_data, "") == [
json.dumps(nested_data, sort_keys=True)
]
assert get_text_at_path(nested_data, "nonexistent") == []
assert get_text_at_path(nested_data, "items[99].value") == []
assert get_text_at_path(nested_data, "items[*].nonexistent") == []
assert get_text_at_path(nested_data, "empty") == []
assert get_text_at_path(nested_data, "empty_list") == ["[]"]
assert get_text_at_path(nested_data, "empty_dict") == ["{}"]
zeros = get_text_at_path(nested_data, "zeros[*]")
assert set(zeros) == {"0", "0.0"}
assert get_text_at_path(nested_data, "items[].value") == []
assert get_text_at_path(nested_data, "items[abc].value") == []
assert get_text_at_path(nested_data, "{unclosed") == []
assert get_text_at_path(nested_data, "nested[{invalid}]") == []
async def test_async_batch_store(mocker: MockerFixture) -> None:
abatch = mocker.stub()
class MockStore(AsyncBatchedBaseStore):
def batch(self, ops: Iterable[Op]) -> list[Result]:
raise NotImplementedError
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
assert all(isinstance(op, GetOp) for op in ops)
abatch(ops)
return [
Item(
value={},
key=getattr(op, "key", ""),
namespace=getattr(op, "namespace", ()),
created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
)
for op in ops
]
store = MockStore()
# concurrent calls are batched
results = await asyncio.gather(
store.aget(namespace=("a",), key="b"),
store.aget(namespace=("c",), key="d"),
)
assert results == [
Item(
value={},
key="b",
namespace=("a",),
created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
),
Item(
value={},
key="d",
namespace=("c",),
created_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
updated_at=datetime(2024, 9, 24, 17, 29, 10, 128397),
),
]
assert abatch.call_count == 1
assert [tuple(c.args[0]) for c in abatch.call_args_list] == [
(
GetOp(("a",), "b"),
GetOp(("c",), "d"),
),
]
def test_list_namespaces_basic() -> None:
store = InMemoryStore()
namespaces = [
("a", "b", "c"),
("a", "b", "d", "e"),
("a", "b", "d", "i"),
("a", "b", "f"),
("a", "c", "f"),
("b", "a", "f"),
("users", "123"),
("users", "456", "settings"),
("admin", "users", "789"),
]
for i, ns in enumerate(namespaces):
store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"})
result = store.list_namespaces(prefix=("a", "b"))
expected = [
("a", "b", "c"),
("a", "b", "d", "e"),
("a", "b", "d", "i"),
("a", "b", "f"),
]
assert sorted(result) == sorted(expected)
result = store.list_namespaces(suffix=("f",))
expected = [
("a", "b", "f"),
("a", "c", "f"),
("b", "a", "f"),
]
assert sorted(result) == sorted(expected)
result = store.list_namespaces(prefix=("a",), suffix=("f",))
expected = [
("a", "b", "f"),
("a", "c", "f"),
]
assert sorted(result) == sorted(expected)
# Test max_depth
result = store.list_namespaces(prefix=("a", "b"), max_depth=3)
expected = [
("a", "b", "c"),
("a", "b", "d"),
("a", "b", "f"),
]
assert sorted(result) == sorted(expected)
# Test limit and offset
result = store.list_namespaces(prefix=("a", "b"), limit=2)
expected = [
("a", "b", "c"),
("a", "b", "d", "e"),
]
assert result == expected
result = store.list_namespaces(prefix=("a", "b"), offset=2)
expected = [
("a", "b", "d", "i"),
("a", "b", "f"),
]
assert result == expected
result = store.list_namespaces(prefix=("a", "*", "f"))
expected = [
("a", "b", "f"),
("a", "c", "f"),
]
assert sorted(result) == sorted(expected)
result = store.list_namespaces(suffix=("*", "f"))
expected = [
("a", "b", "f"),
("a", "c", "f"),
("b", "a", "f"),
]
assert sorted(result) == sorted(expected)
result = store.list_namespaces(prefix=("nonexistent",))
assert result == []
result = store.list_namespaces(prefix=("users", "123"))
expected = [("users", "123")]
assert result == expected
def test_list_namespaces_with_wildcards() -> None:
store = InMemoryStore()
namespaces = [
("users", "123"),
("users", "456"),
("users", "789", "settings"),
("admin", "users", "789"),
("guests", "123"),
("guests", "456", "preferences"),
]
for i, ns in enumerate(namespaces):
store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"})
result = store.list_namespaces(prefix=("users", "*"))
expected = [
("users", "123"),
("users", "456"),
("users", "789", "settings"),
]
assert sorted(result) == sorted(expected)
result = store.list_namespaces(suffix=("*", "preferences"))
expected = [
("guests", "456", "preferences"),
]
assert result == expected
result = store.list_namespaces(prefix=("*", "users"), suffix=("*", "settings"))
assert result == []
store.put(
namespace=("admin", "users", "settings", "789"),
key="foo",
value={"data": "some_val"},
)
expected = [
("admin", "users", "settings", "789"),
]
def test_list_namespaces_pagination() -> None:
store = InMemoryStore()
for i in range(20):
ns = ("namespace", f"sub_{i:02d}")
store.put(namespace=ns, key=f"id_{i:02d}", value={"data": f"value_{i:02d}"})
result = store.list_namespaces(prefix=("namespace",), limit=5, offset=0)
expected = [("namespace", f"sub_{i:02d}") for i in range(5)]
assert result == expected
result = store.list_namespaces(prefix=("namespace",), limit=5, offset=5)
expected = [("namespace", f"sub_{i:02d}") for i in range(5, 10)]
assert result == expected
result = store.list_namespaces(prefix=("namespace",), limit=5, offset=15)
expected = [("namespace", f"sub_{i:02d}") for i in range(15, 20)]
assert result == expected
def test_list_namespaces_max_depth() -> None:
store = InMemoryStore()
namespaces = [
("a", "b", "c", "d"),
("a", "b", "c", "e"),
("a", "b", "f"),
("a", "g"),
("h", "i", "j", "k"),
]
for i, ns in enumerate(namespaces):
store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"})
result = store.list_namespaces(max_depth=2)
expected = [
("a", "b"),
("a", "g"),
("h", "i"),
]
assert sorted(result) == sorted(expected)
def test_list_namespaces_no_conditions() -> None:
store = InMemoryStore()
namespaces = [
("a", "b"),
("c", "d"),
("e", "f", "g"),
]
for i, ns in enumerate(namespaces):
store.put(namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"})
result = store.list_namespaces()
expected = namespaces
assert sorted(result) == sorted(expected)
def test_list_namespaces_empty_store() -> None:
store = InMemoryStore()
result = store.list_namespaces()
assert result == []
async def test_cannot_put_empty_namespace() -> None:
store = InMemoryStore()
doc = {"foo": "bar"}
with pytest.raises(InvalidNamespaceError):
store.put((), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await store.aput((), "foo", doc)
with pytest.raises(InvalidNamespaceError):
store.put(("the", "thing.about"), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await store.aput(("the", "thing.about"), "foo", doc)
with pytest.raises(InvalidNamespaceError):
store.put(("some", "fun", ""), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await store.aput(("some", "fun", ""), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await store.aput(("langgraph", "foo"), "bar", doc)
with pytest.raises(InvalidNamespaceError):
store.put(("langgraph", "foo"), "bar", doc)
await store.aput(("foo", "langgraph", "foo"), "bar", doc)
assert (await store.aget(("foo", "langgraph", "foo"), "bar")).value == doc # type: ignore[union-attr]
assert (await store.asearch(("foo", "langgraph", "foo"), query="bar"))[
0
].value == doc
await store.adelete(("foo", "langgraph", "foo"), "bar")
assert (await store.aget(("foo", "langgraph", "foo"), "bar")) is None
store.put(("foo", "langgraph", "foo"), "bar", doc)
assert store.get(("foo", "langgraph", "foo"), "bar").value == doc # type: ignore[union-attr]
assert store.search(("foo", "langgraph", "foo"), query="bar")[0].value == doc
store.delete(("foo", "langgraph", "foo"), "bar")
assert store.get(("foo", "langgraph", "foo"), "bar") is None
# Do the same, but bypass the public put API
await store.abatch([PutOp(("langgraph", "foo"), "bar", doc)])
assert (await store.aget(("langgraph", "foo"), "bar")).value == doc # type: ignore[union-attr]
assert (await store.asearch(("langgraph", "foo")))[0].value == doc
await store.adelete(("langgraph", "foo"), "bar")
assert (await store.aget(("langgraph", "foo"), "bar")) is None
store.batch([PutOp(("langgraph", "foo"), "bar", doc)])
assert store.get(("langgraph", "foo"), "bar").value == doc # type: ignore[union-attr]
assert store.search(("langgraph", "foo"))[0].value == doc
store.delete(("langgraph", "foo"), "bar")
assert store.get(("langgraph", "foo"), "bar") is None
async_store = MockAsyncBatchedStore()
doc = {"foo": "bar"}
with pytest.raises(InvalidNamespaceError):
await async_store.aput((), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await async_store.aput(("the", "thing.about"), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await async_store.aput(("some", "fun", ""), "foo", doc)
with pytest.raises(InvalidNamespaceError):
await async_store.aput(("langgraph", "foo"), "bar", doc)
await async_store.aput(("foo", "langgraph", "foo"), "bar", doc)
val = await async_store.aget(("foo", "langgraph", "foo"), "bar")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc
assert (await async_store.asearch(("foo", "langgraph", "foo"), query="bar"))[
0
].value == doc
await async_store.adelete(("foo", "langgraph", "foo"), "bar")
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None
await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)])
val = await async_store.aget(("valid", "namespace"), "key")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("valid", "namespace")))[0].value == doc
await async_store.adelete(("valid", "namespace"), "key")
assert (await async_store.aget(("valid", "namespace"), "key")) is None
async def test_async_batch_store_deduplication(mocker: MockerFixture) -> None:
abatch = mocker.spy(InMemoryStore, "batch")
store = MockAsyncBatchedStore()
same_doc = {"value": "same"}
diff_doc = {"value": "different"}
await asyncio.gather(
store.aput(namespace=("test",), key="same", value=same_doc),
store.aput(namespace=("test",), key="different", value=diff_doc),
)
abatch.reset_mock()
results = await asyncio.gather(
store.aget(namespace=("test",), key="same"),
store.aget(namespace=("test",), key="same"),
store.aget(namespace=("test",), key="different"),
)
assert len(results) == 3
assert results[0] == results[1]
assert results[0] != results[2]
assert results[0].value == same_doc # type: ignore
assert results[2].value == diff_doc # type: ignore
assert len(abatch.call_args_list) == 1
ops = list(abatch.call_args_list[0].args[1])
assert len(ops) == 2
assert GetOp(("test",), "same") in ops
assert GetOp(("test",), "different") in ops
abatch.reset_mock()
doc1 = {"value": 1}
doc2 = {"value": 2}
results = await asyncio.gather(
store.aput(namespace=("test",), key="key", value=doc1),
store.aput(namespace=("test",), key="key", value=doc2),
)
assert len(abatch.call_args_list) == 1
ops = list(abatch.call_args_list[0].args[1])
assert len(ops) == 1
assert ops[0] == PutOp(("test",), "key", doc2)
assert len(results) == 2
assert all(result is None for result in results)
result = await store.aget(namespace=("test",), key="key")
assert result is not None
assert result.value == doc2
abatch.reset_mock()
results = await asyncio.gather(
store.asearch(("test",), filter={"value": 2}),
store.asearch(("test",), filter={"value": 2}),
)
assert len(abatch.call_args_list) == 1
ops = list(abatch.call_args_list[0].args[1])
assert len(ops) == 1
assert len(results) == 2
assert results[0] == results[1]
assert len(results[0]) == 1
assert results[0][0].value == doc2
abatch.reset_mock()
@pytest.fixture
def fake_embeddings() -> CharacterEmbeddings:
return CharacterEmbeddings(dims=500)
def test_vector_store_initialization(fake_embeddings: CharacterEmbeddings) -> None:
"""Test store initialization with embedding config."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
assert store.index_config is not None
assert store.index_config["dims"] == fake_embeddings.dims
assert store.index_config["embed"] == fake_embeddings
def test_vector_insert_with_auto_embedding(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test inserting items that get auto-embedded."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
docs = [
("doc1", {"text": "short text"}),
("doc2", {"text": "longer text document"}),
("doc3", {"text": "longest text document here"}),
("doc4", {"description": "text in description field"}),
("doc5", {"content": "text in content field"}),
("doc6", {"body": "text in body field"}),
]
for key, value in docs:
store.put(("test",), key, value)
results = store.search(("test",), query="long text")
assert len(results) > 0
doc_order = [r.key for r in results]
assert "doc2" in doc_order
assert "doc3" in doc_order
async def test_async_vector_insert_with_auto_embedding(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test inserting items that get auto-embedded using async methods."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
docs = [
("doc1", {"text": "short text"}),
("doc2", {"text": "longer text document"}),
("doc3", {"text": "longest text document here"}),
("doc4", {"description": "text in description field"}),
("doc5", {"content": "text in content field"}),
("doc6", {"body": "text in body field"}),
]
for key, value in docs:
await store.aput(("test",), key, value)
results = await store.asearch(("test",), query="long text")
assert len(results) > 0
doc_order = [r.key for r in results]
assert "doc2" in doc_order
assert "doc3" in doc_order
def test_vector_update_with_embedding(fake_embeddings: CharacterEmbeddings) -> None:
"""Test that updating items properly updates their embeddings."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
store.put(("test",), "doc1", {"text": "zany zebra Xerxes"})
store.put(("test",), "doc2", {"text": "something about dogs"})
store.put(("test",), "doc3", {"text": "text about birds"})
results_initial = store.search(("test",), query="Zany Xerxes")
assert len(results_initial) > 0
assert results_initial[0].key == "doc1"
initial_score = results_initial[0].score
assert initial_score is not None
store.put(("test",), "doc1", {"text": "new text about dogs"})
results_after = store.search(("test",), query="Zany Xerxes")
after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0)
assert after_score is not None
assert after_score < initial_score
results_new = store.search(("test",), query="new text about dogs")
for r in results_new:
if r.key == "doc1":
assert r.score > after_score
# Don't index this one
store.put(("test",), "doc4", {"text": "new text about dogs"}, index=False)
results_new = store.search(("test",), query="new text about dogs", limit=3)
assert not any(r.key == "doc4" for r in results_new)
async def test_async_vector_update_with_embedding(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test that updating items properly updates their embeddings using async methods."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
await store.aput(("test",), "doc1", {"text": "zany zebra Xerxes"})
await store.aput(("test",), "doc2", {"text": "something about dogs"})
await store.aput(("test",), "doc3", {"text": "text about birds"})
results_initial = await store.asearch(("test",), query="Zany Xerxes")
assert len(results_initial) > 0
assert results_initial[0].key == "doc1"
initial_score = results_initial[0].score
await store.aput(("test",), "doc1", {"text": "new text about dogs"})
results_after = await store.asearch(("test",), query="Zany Xerxes")
after_score = next((r.score for r in results_after if r.key == "doc1"), 0.0)
assert after_score is not None
assert after_score < initial_score
results_new = await store.asearch(("test",), query="new text about dogs")
for r in results_new:
if r.key == "doc1":
assert r.score is not None
assert r.score > after_score
# Don't index this one
await store.aput(("test",), "doc4", {"text": "new text about dogs"}, index=False)
results_new = await store.asearch(("test",), query="new text about dogs", limit=3)
assert not any(r.key == "doc4" for r in results_new)
def test_vector_search_with_filters(fake_embeddings: CharacterEmbeddings) -> None:
"""Test combining vector search with filters."""
inmem_store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
# Insert test documents
docs = [
("doc1", {"text": "red apple", "color": "red", "score": 4.5}),
("doc2", {"text": "red car", "color": "red", "score": 3.0}),
("doc3", {"text": "green apple", "color": "green", "score": 4.0}),
("doc4", {"text": "blue car", "color": "blue", "score": 3.5}),
]
for key, value in docs:
inmem_store.put(("test",), key, value)
results = inmem_store.search(("test",), query="apple", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc1"
results = inmem_store.search(("test",), query="car", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc2"
results = inmem_store.search(
("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}}
)
assert len(results) == 3
assert results[0].key == "doc4"
# Multiple filters
results = inmem_store.search(
("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"}
)
assert len(results) == 1
assert results[0].key == "doc3"
async def test_async_vector_search_with_filters(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test combining vector search with filters using async methods."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
# Insert test documents
docs = [
("doc1", {"text": "red apple", "color": "red", "score": 4.5}),
("doc2", {"text": "red car", "color": "red", "score": 3.0}),
("doc3", {"text": "green apple", "color": "green", "score": 4.0}),
("doc4", {"text": "blue car", "color": "blue", "score": 3.5}),
]
for key, value in docs:
await store.aput(("test",), key, value)
results = await store.asearch(("test",), query="apple", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc1"
results = await store.asearch(("test",), query="car", filter={"color": "red"})
assert len(results) == 2
assert results[0].key == "doc2"
results = await store.asearch(
("test",), query="bbbbluuu", filter={"score": {"$gt": 3.2}}
)
assert len(results) == 3
assert results[0].key == "doc4"
# Multiple filters
results = await store.asearch(
("test",), query="apple", filter={"score": {"$gte": 4.0}, "color": "green"}
)
assert len(results) == 1
assert results[0].key == "doc3"
async def test_async_batched_vector_search_concurrent(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test concurrent vector search operations using async batched store."""
store = MockAsyncBatchedStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
colors = ["red", "blue", "green", "yellow", "purple"]
items = ["apple", "car", "house", "book", "phone"]
scores = [3.0, 3.5, 4.0, 4.5, 5.0]
docs = []
for i in range(50):
color = colors[i % len(colors)]
item = items[i % len(items)]
score = scores[i % len(scores)]
docs.append(
(
f"doc{i}",
{"text": f"{color} {item}", "color": color, "score": score, "index": i},
)
)
coros = [
*[store.aput(("test",), key, value) for key, value in docs],
*[store.adelete(("test",), key) for key, value in docs],
*[store.aput(("test",), key, value) for key, value in docs],
]
await asyncio.gather(*coros)
# Prepare multiple search queries with different filters
search_queries: list[tuple[str, dict[str, Any]]] = [
("apple", {"color": "red"}),
("car", {"color": "blue"}),
("house", {"color": "green"}),
("phone", {"score": {"$gt": 4.99}}),
("book", {"score": {"$lte": 3.5}}),
("apple", {"score": {"$gte": 3.0}, "color": "red"}),
("car", {"score": {"$lt": 5.1}, "color": "blue"}),
("house", {"index": {"$gt": 25}}),
("phone", {"index": {"$lte": 10}}),
]
all_results = await asyncio.gather(
*[
store.asearch(("test",), query=query, filter=filter_)
for query, filter_ in search_queries
]
)
for results, (query, filter_) in zip(all_results, search_queries):
assert len(results) > 0, f"No results for query '{query}' with filter {filter_}"
for result in results:
if "color" in filter_:
assert result.value["color"] == filter_["color"]
if "score" in filter_:
score = result.value["score"]
for op, value in filter_["score"].items():
if op == "$gt":
assert score > value
elif op == "$gte":
assert score >= value
elif op == "$lt":
assert score < value
elif op == "$lte":
assert score <= value
if "index" in filter_:
index = result.value["index"]
for op, value in filter_["index"].items():
if op == "$gt":
assert index > value
elif op == "$gte":
assert index >= value
elif op == "$lt":
assert index < value
elif op == "$lte":
assert index <= value
def test_vector_search_pagination(fake_embeddings: CharacterEmbeddings) -> None:
"""Test pagination with vector search."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
for i in range(5):
store.put(("test",), f"doc{i}", {"text": f"test document number {i}"})
results_page1 = store.search(("test",), query="test", limit=2)
results_page2 = store.search(("test",), query="test", limit=2, offset=2)
assert len(results_page1) == 2
assert len(results_page2) == 2
assert results_page1[0].key != results_page2[0].key
all_results = store.search(("test",), query="test", limit=10)
assert len(all_results) == 5
async def test_async_vector_search_pagination(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test pagination with vector search using async methods."""
store = InMemoryStore(
index={"dims": fake_embeddings.dims, "embed": fake_embeddings}
)
for i in range(5):
await store.aput(("test",), f"doc{i}", {"text": f"test document number {i}"})
results_page1 = await store.asearch(("test",), query="test", limit=2)
results_page2 = await store.asearch(("test",), query="test", limit=2, offset=2)
assert len(results_page1) == 2
assert len(results_page2) == 2
assert results_page1[0].key != results_page2[0].key
all_results = await store.asearch(("test",), query="test", limit=10)
assert len(all_results) == 5
async def test_embed_with_path(fake_embeddings: CharacterEmbeddings) -> None:
# Test store-level field configuration
store = InMemoryStore(
index={
"dims": fake_embeddings.dims,
"embed": fake_embeddings,
# key2 isn't included, so it won't be indexed.
"fields": ["key0", "key1", "key3"],
}
)
# This will have 2 vectors representing it
doc1 = {
# Omit key0 - check it doesn't raise an error
"key1": "xxx",
"key2": "yyy",
"key3": "zzz",
}
# This will have 3 vectors representing it
doc2 = {
"key0": "uuu",
"key1": "vvv",
"key2": "www",
"key3": "xxx",
}
await store.aput(("test",), "doc1", doc1)
await store.aput(("test",), "doc2", doc2)
# doc2.key3 and doc1.key1 both would have the highest score
results = await store.asearch(("test",), query="xxx")
assert len(results) == 2
assert results[0].key != results[1].key
ascore = results[0].score
bscore = results[1].score
assert ascore == bscore
assert ascore is not None and bscore is not None
results = await store.asearch(("test",), query="uuu")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].key == "doc2"
assert results[0].score is not None and results[0].score > results[1].score
assert ascore == pytest.approx(results[0].score, abs=1e-5)
# Un-indexed - both will have low scores. Not zero (because we're projecting),
# but lower than the scores above.
results = await store.asearch(("test",), query="www")
assert len(results) == 2
assert results[0].score < ascore
assert results[1].score < ascore
# Test operation-level field configuration
store_no_defaults = InMemoryStore(
index={
"dims": fake_embeddings.dims,
"embed": fake_embeddings,
"fields": ["key17"],
}
)
doc3 = {
"key0": "aaa",
"key1": "bbb",
"key2": "ccc",
"key3": "ddd",
}
doc4 = {
"key0": "eee",
"key1": "bbb", # Same as doc3.key1
"key2": "fff",
"key3": "ggg",
}
await store_no_defaults.aput(("test",), "doc3", doc3, index=["key0", "key1"])
await store_no_defaults.aput(("test",), "doc4", doc4, index=["key1", "key3"])
results = await store_no_defaults.asearch(("test",), query="aaa")
assert len(results) == 2
assert results[0].key == "doc3"
assert results[0].score is not None and results[0].score > results[1].score
results = await store_no_defaults.asearch(("test",), query="ggg")
assert len(results) == 2
assert results[0].key == "doc4"
assert results[0].score is not None and results[0].score > results[1].score
results = await store_no_defaults.asearch(("test",), query="bbb")
assert len(results) == 2
assert results[0].key != results[1].key
assert results[0].score == results[1].score
results = await store_no_defaults.asearch(("test",), query="ccc")
assert len(results) == 2
assert all(r.score < ascore for r in results)
doc5 = {
"key0": "hhh",
"key1": "iii",
}
await store_no_defaults.aput(("test",), "doc5", doc5, index=False)
results = await store_no_defaults.asearch(("test",), query="hhh")
assert len(results) == 3
doc5_result = next(r for r in results if r.key == "doc5")
assert doc5_result.score is None
```
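For readers skimming these tests, the two store behaviours exercised most often are exact-match filtering within a namespace prefix and namespace listing; the semantic-search tests additionally configure an `index` with `dims`, `embed`, and `fields`. A compact sketch of the non-vector part follows, with illustrative namespaces and values.

```py
from langgraph.store.memory import InMemoryStore

store = InMemoryStore()
store.put(("users", "123"), "prefs", {"theme": "dark", "lang": "en"})
store.put(("users", "456"), "prefs", {"theme": "light", "lang": "en"})

# Exact-match filtering within a namespace prefix (no embeddings required).
hits = store.search(("users",), filter={"theme": "dark"})
assert [item.key for item in hits] == ["prefs"]
assert hits[0].namespace == ("users", "123")

# Namespace exploration, as in test_list_namespaces_basic above.
assert sorted(store.list_namespaces(prefix=("users",))) == [
    ("users", "123"),
    ("users", "456"),
]
```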
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/test_memory.py`:
```py
from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.memory import MemorySaver
class TestMemorySaver:
@pytest.fixture(autouse=True)
def setup(self) -> None:
self.memory_saver = MemorySaver()
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
"checkpoint_ns": "",
# for backwards compatibility testing
"thread_ts": "1",
}
}
self.config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_ns": "",
"checkpoint_id": "2",
}
}
self.config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
self.chkpnt_1: Checkpoint = empty_checkpoint()
self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1)
self.chkpnt_3: Checkpoint = empty_checkpoint()
self.metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
self.metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
self.metadata_3: CheckpointMetadata = {}
async def test_search(self) -> None:
# set up test
# save checkpoints
self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = list(self.memory_saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = list(self.memory_saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = list(self.memory_saver.list(None, filter=query_3))
assert len(search_results_3) == 3
search_results_4 = list(self.memory_saver.list(None, filter=query_4))
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
self.memory_saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
# TODO: test before and limit params
async def test_asearch(self) -> None:
# set up test
# save checkpoints
self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = [
c async for c in self.memory_saver.alist(None, filter=query_1)
]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = [
c async for c in self.memory_saver.alist(None, filter=query_2)
]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = [
c async for c in self.memory_saver.alist(None, filter=query_3)
]
assert len(search_results_3) == 3
search_results_4 = [
c async for c in self.memory_saver.alist(None, filter=query_4)
]
assert len(search_results_4) == 0
```
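A minimal sketch of the `MemorySaver` flow these tests drive: save one checkpoint for a thread, then query it back with a metadata filter. The config and metadata values here are illustrative.

```py
from langgraph.checkpoint.base import CheckpointMetadata, empty_checkpoint
from langgraph.checkpoint.memory import MemorySaver

saver = MemorySaver()
config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
metadata: CheckpointMetadata = {"source": "input", "step": 2, "writes": {}}

# Persist an (empty) checkpoint together with its metadata.
saver.put(config, empty_checkpoint(), metadata, {})

# list() filters on metadata keys; passing config=None searches all threads.
found = list(saver.list(None, filter={"source": "input"}))
assert len(found) == 1
assert found[0].metadata == metadata
```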
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint/tests/embed_test_utils.py`:
```py
"""Embedding utilities for testing."""
import math
import random
from collections import Counter, defaultdict
from typing import Any
from langchain_core.embeddings import Embeddings
class CharacterEmbeddings(Embeddings):
"""Simple character-frequency based embeddings using random projections."""
def __init__(self, dims: int = 50, seed: int = 42):
"""Initialize with embedding dimensions and random seed."""
self._rng = random.Random(seed)
self.dims = dims
# Create projection vector for each character lazily
self._char_projections: defaultdict[str, list[float]] = defaultdict(
lambda: [
self._rng.gauss(0, 1 / math.sqrt(self.dims)) for _ in range(self.dims)
]
)
def _embed_one(self, text: str) -> list[float]:
"""Embed a single text."""
counts = Counter(text)
total = sum(counts.values())
if total == 0:
return [0.0] * self.dims
embedding = [0.0] * self.dims
for char, count in counts.items():
weight = count / total
char_proj = self._char_projections[char]
for i, proj in enumerate(char_proj):
embedding[i] += weight * proj
norm = math.sqrt(sum(x * x for x in embedding))
if norm > 0:
embedding = [x / norm for x in embedding]
return embedding
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of documents."""
return [self._embed_one(text) for text in texts]
def embed_query(self, text: str) -> list[float]:
"""Embed a query string."""
return self._embed_one(text)
def __eq__(self, other: Any) -> bool:
return isinstance(other, CharacterEmbeddings) and self.dims == other.dims
```
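A short usage sketch of the helper above: the embeddings are deterministic for a given seed, unit-normalised for non-empty text, and two instances compare equal when their `dims` match.

```py
# Continues from the CharacterEmbeddings class defined above.
embeddings = CharacterEmbeddings(dims=50)

vecs = embeddings.embed_documents(["abc", "abd"])
assert len(vecs) == 2 and all(len(v) == 50 for v in vecs)

query_vec = embeddings.embed_query("abc")
# Unit norm (within floating-point error) because _embed_one normalises its output.
assert abs(sum(x * x for x in query_vec) - 1.0) < 1e-9

assert CharacterEmbeddings(dims=50) == CharacterEmbeddings(dims=50)
```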
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/__init__.py`:
```py
import threading
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Sequence
from langchain_core.runnables import RunnableConfig
import duckdb
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
)
from langgraph.checkpoint.duckdb.base import BaseDuckDBSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
class DuckDBSaver(BaseDuckDBSaver):
lock: threading.Lock
def __init__(
self,
conn: duckdb.DuckDBPyConnection,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
self.conn = conn
self.lock = threading.Lock()
@classmethod
@contextmanager
def from_conn_string(cls, conn_string: str) -> Iterator["DuckDBSaver"]:
"""Create a new DuckDBSaver instance from a connection string.
Args:
conn_string (str): The DuckDB connection info string.
Returns:
DuckDBSaver: A new DuckDBSaver instance.
"""
with duckdb.connect(conn_string) as conn:
yield cls(conn)
def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
This method creates the necessary tables in the DuckDB database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time checkpointer is used.
"""
with self.lock, self.conn.cursor() as cur:
try:
row = cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
).fetchone()
if row is None:
version = -1
else:
version = row[0]
except duckdb.CatalogException:
version = -1
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
):
cur.execute(migration)
cur.execute("INSERT INTO checkpoint_migrations (v) VALUES (?)", [v])
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database.
This method retrieves a list of checkpoint tuples from the DuckDB database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (RunnableConfig): The config to use for listing the checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.
Yields:
Iterator[CheckpointTuple]: An iterator of checkpoint tuples.
Examples:
>>> from langgraph.checkpoint.duckdb import DuckDBSaver
>>> with DuckDBSaver.from_conn_string(":memory:") as memory:
... # Run a graph, then list the checkpoints
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoints = list(memory.list(config, limit=2))
>>> print(checkpoints)
[CheckpointTuple(...), CheckpointTuple(...)]
>>> config = {"configurable": {"thread_id": "1"}}
>>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}}
>>> with DuckDBSaver.from_conn_string(":memory:") as memory:
... # Run a graph, then list the checkpoints
>>> checkpoints = list(memory.list(config, before=before))
>>> print(checkpoints)
[CheckpointTuple(...), ...]
"""
where, args = self._search_where(config, filter, before)
query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC"
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
with self._cursor() as cur:
cur.execute(query, args)
for value in cur.fetchall():
(
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
channel_values,
pending_writes,
pending_sends,
) = value
yield CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
self._load_checkpoint(
checkpoint,
channel_values,
pending_sends,
),
self._load_metadata(metadata),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
self._load_writes(pending_writes),
)
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the DuckDB database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
Examples:
Basic:
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
With checkpoint ID:
>>> config = {
... "configurable": {
... "thread_id": "1",
... "checkpoint_ns": "",
... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
... }
... }
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
""" # noqa
thread_id = config["configurable"]["thread_id"]
checkpoint_id = get_checkpoint_id(config)
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id:
args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
where = "WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?"
else:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1"
with self._cursor() as cur:
cur.execute(
self.SELECT_SQL + where,
args,
)
value = cur.fetchone()
if value:
(
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
channel_values,
pending_writes,
pending_sends,
) = value
return CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
self._load_checkpoint(
checkpoint,
channel_values,
pending_sends,
),
self._load_metadata(metadata),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
self._load_writes(pending_writes),
)
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the DuckDB database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Examples:
>>> from langgraph.checkpoint.duckdb import DuckDBSaver
>>> with DuckDBSaver.from_conn_string(":memory:") as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
>>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}}
>>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
>>> print(saved_config)
{'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}}
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)
copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
checkpoint_blobs = self._dump_blobs(
thread_id,
checkpoint_ns,
copy.pop("channel_values"), # type: ignore[misc]
new_versions,
)
with self._cursor() as cur:
if checkpoint_blobs:
cur.executemany(self.UPSERT_CHECKPOINT_BLOBS_SQL, checkpoint_blobs)
cur.execute(
self.UPSERT_CHECKPOINTS_SQL,
(
thread_id,
checkpoint_ns,
checkpoint["id"],
checkpoint_id,
self._dump_checkpoint(copy),
self._dump_metadata(metadata),
),
)
return next_config
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the DuckDB database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
query = (
self.UPSERT_CHECKPOINT_WRITES_SQL
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
with self._cursor() as cur:
cur.executemany(
query,
self._dump_writes(
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
),
)
@contextmanager
def _cursor(self) -> Iterator[duckdb.DuckDBPyConnection]:
with self.lock, self.conn.cursor() as cur:
yield cur
__all__ = ["DuckDBSaver", "Conn"]
```
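The file above defines the synchronous `DuckDBSaver`. As a quick orientation, here is a minimal end-to-end sketch mirroring the doctest examples in its docstrings and the calls exercised in `tests/test_sync.py` later in this dump; the metadata payload is illustrative, and in normal use LangGraph constructs and persists checkpoints for you rather than you calling `put` by hand.
```py
# Minimal sketch: store one checkpoint and read it back with DuckDBSaver.
# Uses an in-memory DuckDB database; metadata values are illustrative.
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.duckdb import DuckDBSaver

with DuckDBSaver.from_conn_string(":memory:") as saver:
    saver.setup()  # create tables and run migrations on first use

    config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    checkpoint = empty_checkpoint()

    # Persist the checkpoint, then fetch the latest checkpoint for the thread.
    saved_config = saver.put(
        config, checkpoint, {"source": "input", "step": 1, "writes": {}}, {}
    )
    tup = saver.get_tuple(saved_config)
    assert tup is not None and tup.checkpoint["id"] == checkpoint["id"]

    # Checkpoints are listed newest first.
    for item in saver.list({"configurable": {"thread_id": "thread-1"}}):
        print(item.config["configurable"]["checkpoint_id"])
```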
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/aio.py`:
```py
import asyncio
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Iterator, Optional, Sequence
from langchain_core.runnables import RunnableConfig
import duckdb
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
)
from langgraph.checkpoint.duckdb.base import BaseDuckDBSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
class AsyncDuckDBSaver(BaseDuckDBSaver):
lock: asyncio.Lock
def __init__(
self,
conn: duckdb.DuckDBPyConnection,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
self.conn = conn
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
@classmethod
@asynccontextmanager
async def from_conn_string(
cls,
conn_string: str,
) -> AsyncIterator["AsyncDuckDBSaver"]:
"""Create a new AsyncDuckDBSaver instance from a connection string.
Args:
conn_string (str): The DuckDB connection info string.
Returns:
AsyncDuckDBSaver: A new AsyncDuckDBSaver instance.
"""
with duckdb.connect(conn_string) as conn:
yield cls(conn)
async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
This method creates the necessary tables in the DuckDB database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the checkpointer is used.
"""
async with self.lock:
with self.conn.cursor() as cur:
try:
await asyncio.to_thread(
cur.execute,
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1",
)
row = await asyncio.to_thread(cur.fetchone)
if row is None:
version = -1
else:
version = row[0]
except duckdb.CatalogException:
version = -1
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
):
await asyncio.to_thread(cur.execute, migration)
await asyncio.to_thread(
cur.execute,
"INSERT INTO checkpoint_migrations (v) VALUES (?)",
[v],
)
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously.
This method retrieves a list of checkpoint tuples from the DuckDB database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples.
"""
where, args = self._search_where(config, filter, before)
query = self.SELECT_SQL + where + " ORDER BY checkpoint_id DESC"
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
async with self._cursor() as cur:
await asyncio.to_thread(cur.execute, query, args)
results = await asyncio.to_thread(cur.fetchall)
for value in results:
(
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
channel_values,
pending_writes,
pending_sends,
) = value
yield CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
await asyncio.to_thread(
self._load_checkpoint,
checkpoint,
channel_values,
pending_sends,
),
self._load_metadata(metadata),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
await asyncio.to_thread(self._load_writes, pending_writes),
)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database asynchronously.
This method retrieves a checkpoint tuple from the DuckDB database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_id = get_checkpoint_id(config)
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id:
args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
where = "WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?"
else:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1"
async with self._cursor() as cur:
await asyncio.to_thread(
cur.execute,
self.SELECT_SQL + where,
args,
)
value = await asyncio.to_thread(cur.fetchone)
if value:
(
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
channel_values,
pending_writes,
pending_sends,
) = value
return CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
await asyncio.to_thread(
self._load_checkpoint,
checkpoint,
channel_values,
pending_sends,
),
self._load_metadata(metadata),
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
await asyncio.to_thread(self._load_writes, pending_writes),
)
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously.
This method saves a checkpoint to the DuckDB database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)
copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
checkpoint_blobs = await asyncio.to_thread(
self._dump_blobs,
thread_id,
checkpoint_ns,
copy.pop("channel_values"), # type: ignore[misc]
new_versions,
)
async with self._cursor() as cur:
if checkpoint_blobs:
await asyncio.to_thread(
cur.executemany, self.UPSERT_CHECKPOINT_BLOBS_SQL, checkpoint_blobs
)
await asyncio.to_thread(
cur.execute,
self.UPSERT_CHECKPOINTS_SQL,
(
thread_id,
checkpoint_ns,
checkpoint["id"],
checkpoint_id,
self._dump_checkpoint(copy),
self._dump_metadata(metadata),
),
)
return next_config
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
This method saves intermediate writes associated with a checkpoint to the database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
query = (
self.UPSERT_CHECKPOINT_WRITES_SQL
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
params = await asyncio.to_thread(
self._dump_writes,
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
)
async with self._cursor() as cur:
await asyncio.to_thread(cur.executemany, query, params)
@asynccontextmanager
async def _cursor(self) -> AsyncIterator[duckdb.DuckDBPyConnection]:
async with self.lock:
with self.conn.cursor() as cur:
yield cur
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database.
This method retrieves a list of checkpoint tuples from the DuckDB database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
aiter_ = self.alist(config, filter=filter, before=before, limit=limit)
while True:
try:
yield asyncio.run_coroutine_threadsafe(
anext(aiter_),
self.loop,
).result()
except StopAsyncIteration:
break
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the DuckDB database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
try:
# check if we are in the main thread, only bg threads can block
# we don't check in other methods to avoid the overhead
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncDuckDBSaver are only allowed from a "
"different thread. From the main thread, use the async interface."
"For example, use `await checkpointer.aget_tuple(...)` or `await "
"graph.ainvoke(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aget_tuple(config), self.loop
).result()
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the DuckDB database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
return asyncio.run_coroutine_threadsafe(
self.aput(config, checkpoint, metadata, new_versions), self.loop
).result()
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()
```
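`AsyncDuckDBSaver` mirrors the synchronous saver with an async interface; it must be constructed inside a running event loop because `__init__` captures `asyncio.get_running_loop()`. A minimal sketch, based on the calls exercised in `tests/test_async.py` later in this dump (metadata values are illustrative):
```py
# Minimal async sketch for AsyncDuckDBSaver against an in-memory database.
import asyncio

from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver


async def main() -> None:
    async with AsyncDuckDBSaver.from_conn_string(":memory:") as saver:
        await saver.setup()  # run migrations once per database
        config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
        saved = await saver.aput(
            config, empty_checkpoint(), {"source": "input", "step": 1, "writes": {}}, {}
        )
        tup = await saver.aget_tuple(saved)
        assert tup is not None
        async for item in saver.alist({"configurable": {"thread_id": "thread-1"}}):
            print(item.config["configurable"]["checkpoint_id"])


asyncio.run(main())
```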
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/checkpoint/duckdb/base.py`:
```py
import json
import random
from typing import Any, List, Optional, Sequence, Tuple, cast
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol
MetadataInput = Optional[dict[str, Any]]
"""
To add a new migration, add a new string to the MIGRATIONS list.
The position of the migration in the list is the version number.
"""
MIGRATIONS = [
"""CREATE TABLE IF NOT EXISTS checkpoint_migrations (
v INTEGER PRIMARY KEY
);""",
"""CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSON NOT NULL,
metadata JSON NOT NULL DEFAULT '{}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);""",
"""CREATE TABLE IF NOT EXISTS checkpoint_blobs (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
channel TEXT NOT NULL,
version TEXT NOT NULL,
type TEXT NOT NULL,
blob BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
);""",
"""CREATE TABLE IF NOT EXISTS checkpoint_writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BLOB NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);""",
]
SELECT_SQL = f"""
select
thread_id,
checkpoint,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
metadata,
(
select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob])
from (
SELECT unnest(json_keys(json_extract(checkpoint, '$.channel_versions'))) as key
) cv
inner join checkpoint_blobs bl
on bl.thread_id = checkpoints.thread_id
and bl.checkpoint_ns = checkpoints.checkpoint_ns
and bl.channel = cv.key
and bl.version = json_extract_string(checkpoint, '$.channel_versions.' || cv.key)
) as channel_values,
(
select
array_agg(array[cw.task_id::blob, cw.channel::blob, cw.type::blob, cw.blob])
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
select array_agg(array[cw.type::blob, cw.blob])
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_id = checkpoints.parent_checkpoint_id
and cw.channel = '{TASKS}'
) as pending_sends
from checkpoints """
UPSERT_CHECKPOINT_BLOBS_SQL = """
INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING
"""
UPSERT_CHECKPOINTS_SQL = """
INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
DO UPDATE SET
checkpoint = EXCLUDED.checkpoint,
metadata = EXCLUDED.metadata;
"""
UPSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
channel = EXCLUDED.channel,
type = EXCLUDED.type,
blob = EXCLUDED.blob;
"""
INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
"""
class BaseDuckDBSaver(BaseCheckpointSaver[str]):
SELECT_SQL = SELECT_SQL
MIGRATIONS = MIGRATIONS
UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL
UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL
UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
jsonplus_serde = JsonPlusSerializer()
def _load_checkpoint(
self,
checkpoint_json_str: str,
channel_values: list[tuple[bytes, bytes, bytes]],
pending_sends: list[tuple[bytes, bytes]],
) -> Checkpoint:
checkpoint = json.loads(checkpoint_json_str)
return {
**checkpoint,
"pending_sends": [
self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or []
],
"channel_values": self._load_blobs(channel_values),
}
def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
return {**checkpoint, "pending_sends": []}
def _load_blobs(
self, blob_values: list[tuple[bytes, bytes, bytes]]
) -> dict[str, Any]:
if not blob_values:
return {}
return {
k.decode(): self.serde.loads_typed((t.decode(), v))
for k, t, v in blob_values
if t.decode() != "empty"
}
def _dump_blobs(
self,
thread_id: str,
checkpoint_ns: str,
values: dict[str, Any],
versions: ChannelVersions,
) -> list[tuple[str, str, str, str, str, Optional[bytes]]]:
if not versions:
return []
return [
(
thread_id,
checkpoint_ns,
k,
cast(str, ver),
*(
self.serde.dumps_typed(values[k])
if k in values
else ("empty", None)
),
)
for k, ver in versions.items()
]
def _load_writes(
self, writes: list[tuple[bytes, bytes, bytes, bytes]]
) -> list[tuple[str, str, Any]]:
return (
[
(
tid.decode(),
channel.decode(),
self.serde.loads_typed((t.decode(), v)),
)
for tid, channel, t, v in writes
]
if writes
else []
)
def _dump_writes(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
writes: Sequence[tuple[str, Any]],
) -> list[tuple[str, str, str, str, int, str, str, bytes]]:
return [
(
thread_id,
checkpoint_ns,
checkpoint_id,
task_id,
WRITES_IDX_MAP.get(channel, idx),
channel,
*self.serde.dumps_typed(value),
)
for idx, (channel, value) in enumerate(writes)
]
def _load_metadata(self, metadata_json_str: str) -> CheckpointMetadata:
return self.jsonplus_serde.loads(metadata_json_str.encode())
def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
serialized_metadata = self.jsonplus_serde.dumps(metadata)
# NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
return serialized_metadata.decode().replace("\\u0000", "")
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
def _search_where(
self,
config: Optional[RunnableConfig],
filter: MetadataInput,
before: Optional[RunnableConfig] = None,
) -> Tuple[str, List[Any]]:
"""Return WHERE clause predicates for alist() given config, filter, before.
This method returns a tuple of a string and a list of values. The string
is the parameterized WHERE clause predicate (including the WHERE keyword),
e.g. "WHERE thread_id = ? AND json_contains(metadata, ?)". The list of values
contains the values for each of the corresponding parameters.
"""
wheres = []
param_values = []
# construct predicate for config filter
if config:
wheres.append("thread_id = ?")
param_values.append(config["configurable"]["thread_id"])
checkpoint_ns = config["configurable"].get("checkpoint_ns")
if checkpoint_ns is not None:
wheres.append("checkpoint_ns = ?")
param_values.append(checkpoint_ns)
if checkpoint_id := get_checkpoint_id(config):
wheres.append("checkpoint_id = ?")
param_values.append(checkpoint_id)
# construct predicate for metadata filter
if filter:
wheres.append("json_contains(metadata, ?)")
param_values.append(json.dumps(filter))
# construct predicate for `before`
if before is not None:
wheres.append("checkpoint_id < ?")
param_values.append(get_checkpoint_id(before))
return (
"WHERE " + " AND ".join(wheres) if wheres else "",
param_values,
)
```
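Two helpers in `BaseDuckDBSaver` are worth a concrete illustration: `_search_where`, which builds the parameterized predicate consumed by `list`/`alist`, and `get_next_version`, whose zero-padded counters sort correctly as strings. A small sketch, assuming `BaseDuckDBSaver` can be instantiated directly (it only layers SQL and serde helpers over `BaseCheckpointSaver`); the outputs shown in comments are illustrative:
```py
# Illustrative sketch of BaseDuckDBSaver helpers; expected outputs in comments.
from langgraph.checkpoint.duckdb.base import BaseDuckDBSaver

saver = BaseDuckDBSaver()

# _search_where returns the WHERE clause plus its parameter values.
where, params = saver._search_where(
    {"configurable": {"thread_id": "thread-1"}},  # config filter
    {"source": "loop"},                           # metadata filter
    None,                                         # no `before` bound
)
print(where)   # WHERE thread_id = ? AND json_contains(metadata, ?)
print(params)  # ['thread-1', '{"source": "loop"}']

# get_next_version pads the counter to 32 digits, so versions compare
# correctly as strings (the channel argument is unused by this implementation).
v1 = saver.get_next_version(None, None)
v2 = saver.get_next_version(v1, None)
assert v1 < v2
```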
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/__init__.py`:
```py
from langgraph.store.duckdb.aio import AsyncDuckDBStore
from langgraph.store.duckdb.base import DuckDBStore
__all__ = ["AsyncDuckDBStore", "DuckDBStore"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py`:
```py
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import (
AsyncIterator,
Iterable,
Sequence,
cast,
)
import duckdb
from langgraph.store.base import GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.duckdb.base import (
BaseDuckDBStore,
_convert_ns,
_group_ops,
_row_to_item,
)
logger = logging.getLogger(__name__)
class AsyncDuckDBStore(AsyncBatchedBaseStore, BaseDuckDBStore):
def __init__(
self,
conn: duckdb.DuckDBPyConnection,
) -> None:
super().__init__()
self.conn = conn
self.loop = asyncio.get_running_loop()
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
grouped_ops, num_ops = _group_ops(ops)
results: list[Result] = [None] * num_ops
tasks = []
if GetOp in grouped_ops:
tasks.append(
self._batch_get_ops(
cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results
)
)
if PutOp in grouped_ops:
tasks.append(
self._batch_put_ops(
cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp])
)
)
if SearchOp in grouped_ops:
tasks.append(
self._batch_search_ops(
cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
results,
)
)
if ListNamespacesOp in grouped_ops:
tasks.append(
self._batch_list_namespaces_ops(
cast(
Sequence[tuple[int, ListNamespacesOp]],
grouped_ops[ListNamespacesOp],
),
results,
)
)
await asyncio.gather(*tasks)
return results
def batch(self, ops: Iterable[Op]) -> list[Result]:
return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result()
async def _batch_get_ops(
self,
get_ops: Sequence[tuple[int, GetOp]],
results: list[Result],
) -> None:
cursors = []
for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops):
cur = self.conn.cursor()
await asyncio.to_thread(cur.execute, query, params)
cursors.append((cur, namespace, items))
for cur, namespace, items in cursors:
rows = await asyncio.to_thread(cur.fetchall)
key_to_row = {row[1]: row for row in rows}
for idx, key in items:
row = key_to_row.get(key)
if row:
results[idx] = _row_to_item(namespace, row)
else:
results[idx] = None
async def _batch_put_ops(
self,
put_ops: Sequence[tuple[int, PutOp]],
) -> None:
queries = self._get_batch_PUT_queries(put_ops)
for query, params in queries:
cur = self.conn.cursor()
await asyncio.to_thread(cur.execute, query, params)
async def _batch_search_ops(
self,
search_ops: Sequence[tuple[int, SearchOp]],
results: list[Result],
) -> None:
queries = self._get_batch_search_queries(search_ops)
cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = []
for (query, params), (idx, _) in zip(queries, search_ops):
cur = self.conn.cursor()
await asyncio.to_thread(cur.execute, query, params)
cursors.append((cur, idx))
for cur, idx in cursors:
rows = await asyncio.to_thread(cur.fetchall)
items = [_row_to_item(_convert_ns(row[0]), row) for row in rows]
results[idx] = items
async def _batch_list_namespaces_ops(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
results: list[Result],
) -> None:
queries = self._get_batch_list_namespaces_queries(list_ops)
cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = []
for (query, params), (idx, _) in zip(queries, list_ops):
cur = self.conn.cursor()
await asyncio.to_thread(cur.execute, query, params)
cursors.append((cur, idx))
for cur, idx in cursors:
rows = cast(list[tuple], await asyncio.to_thread(cur.fetchall))
namespaces = [_convert_ns(row[0]) for row in rows]
results[idx] = namespaces
@classmethod
@asynccontextmanager
async def from_conn_string(
cls,
conn_string: str,
) -> AsyncIterator["AsyncDuckDBStore"]:
"""Create a new AsyncDuckDBStore instance from a connection string.
Args:
conn_string (str): The DuckDB connection info string.
Returns:
AsyncDuckDBStore: A new AsyncDuckDBStore instance.
"""
with duckdb.connect(conn_string) as conn:
yield cls(conn)
async def setup(self) -> None:
"""Set up the store database asynchronously.
This method creates the necessary tables in the DuckDB database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the store is used.
"""
cur = self.conn.cursor()
try:
await asyncio.to_thread(
cur.execute, "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1"
)
row = await asyncio.to_thread(cur.fetchone)
if row is None:
version = -1
else:
version = row[0]
except duckdb.CatalogException:
version = -1
# Create store_migrations table if it doesn't exist
await asyncio.to_thread(
cur.execute,
"""
CREATE TABLE IF NOT EXISTS store_migrations (
v INTEGER PRIMARY KEY
)
""",
)
for v, migration in enumerate(
self.MIGRATIONS[version + 1 :], start=version + 1
):
await asyncio.to_thread(cur.execute, migration)
await asyncio.to_thread(
cur.execute, "INSERT INTO store_migrations (v) VALUES (?)", (v,)
)
```
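`AsyncDuckDBStore` exposes the batched `Op` interface directly, which is how `tests/test_async_store.py` later in this dump drives it. A minimal sketch against an in-memory database, with illustrative namespaces and values:
```py
# Minimal async sketch for AsyncDuckDBStore using explicit batched ops.
import asyncio

from langgraph.store.base import GetOp, PutOp
from langgraph.store.duckdb import AsyncDuckDBStore


async def main() -> None:
    async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
        await store.setup()  # create the store table and run migrations
        # One PutOp to write an item, then one GetOp to read it back.
        await store.abatch(
            [PutOp(namespace=("docs",), key="doc1", value={"title": "Hello"})]
        )
        (item,) = await store.abatch([GetOp(namespace=("docs",), key="doc1")])
        print(item.value if item else None)


asyncio.run(main())
```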
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py`:
```py
import asyncio
import json
import logging
from collections import defaultdict
from contextlib import contextmanager
from typing import (
Any,
Generic,
Iterable,
Iterator,
Sequence,
TypeVar,
Union,
cast,
)
import duckdb
from langgraph.store.base import (
BaseStore,
GetOp,
Item,
ListNamespacesOp,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
)
logger = logging.getLogger(__name__)
MIGRATIONS = [
"""
CREATE TABLE IF NOT EXISTS store (
prefix TEXT NOT NULL,
key TEXT NOT NULL,
value JSON NOT NULL,
created_at TIMESTAMP DEFAULT now(),
updated_at TIMESTAMP DEFAULT now(),
PRIMARY KEY (prefix, key)
);
""",
"""
CREATE INDEX IF NOT EXISTS store_prefix_idx ON store (prefix);
""",
]
C = TypeVar("C", bound=duckdb.DuckDBPyConnection)
class BaseDuckDBStore(Generic[C]):
MIGRATIONS = MIGRATIONS
conn: C
def _get_batch_GET_ops_queries(
self,
get_ops: Sequence[tuple[int, GetOp]],
) -> list[tuple[str, tuple, tuple[str, ...], list]]:
namespace_groups = defaultdict(list)
for idx, op in get_ops:
namespace_groups[op.namespace].append((idx, op.key))
results = []
for namespace, items in namespace_groups.items():
_, keys = zip(*items)
keys_to_query = ",".join(["?"] * len(keys))
query = f"""
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix = ? AND key IN ({keys_to_query})
"""
params = (_namespace_to_text(namespace), *keys)
results.append((query, params, namespace, items))
return results
def _get_batch_PUT_queries(
self,
put_ops: Sequence[tuple[int, PutOp]],
) -> list[tuple[str, Sequence]]:
inserts: list[PutOp] = []
deletes: list[PutOp] = []
for _, op in put_ops:
if op.value is None:
deletes.append(op)
else:
inserts.append(op)
queries: list[tuple[str, Sequence]] = []
if deletes:
namespace_groups: dict[tuple[str, ...], list[str]] = defaultdict(list)
for op in deletes:
namespace_groups[op.namespace].append(op.key)
for namespace, keys in namespace_groups.items():
placeholders = ",".join(["?"] * len(keys))
query = (
f"DELETE FROM store WHERE prefix = ? AND key IN ({placeholders})"
)
params = (_namespace_to_text(namespace), *keys)
queries.append((query, params))
if inserts:
values = []
insertion_params = []
for op in inserts:
values.append("(?, ?, ?, now(), now())")
insertion_params.extend(
[
_namespace_to_text(op.namespace),
op.key,
json.dumps(op.value),
]
)
values_str = ",".join(values)
query = f"""
INSERT INTO store (prefix, key, value, created_at, updated_at)
VALUES {values_str}
ON CONFLICT (prefix, key) DO UPDATE
SET value = EXCLUDED.value, updated_at = now()
"""
queries.append((query, insertion_params))
return queries
def _get_batch_search_queries(
self,
search_ops: Sequence[tuple[int, SearchOp]],
) -> list[tuple[str, Sequence]]:
queries: list[tuple[str, Sequence]] = []
for _, op in search_ops:
query = """
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix LIKE ?
"""
params: list = [f"{_namespace_to_text(op.namespace_prefix)}%"]
if op.filter:
filter_conditions = []
for key, value in op.filter.items():
filter_conditions.append(f"json_extract(value, '$.{key}') = ?")
params.append(json.dumps(value))
query += " AND " + " AND ".join(filter_conditions)
query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?"
params.extend([op.limit, op.offset])
queries.append((query, params))
return queries
def _get_batch_list_namespaces_queries(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
) -> list[tuple[str, Sequence]]:
queries: list[tuple[str, Sequence]] = []
for _, op in list_ops:
query = """
WITH split_prefix AS (
SELECT
prefix,
string_split(prefix, '.') AS parts
FROM store
)
SELECT DISTINCT ON (truncated_prefix)
CASE
WHEN ? IS NOT NULL THEN
array_to_string(array_slice(parts, 1, ?), '.')
ELSE prefix
END AS truncated_prefix,
prefix
FROM split_prefix
"""
params: list[Any] = [op.max_depth, op.max_depth]
conditions = []
if op.match_conditions:
for condition in op.match_conditions:
if condition.match_type == "prefix":
conditions.append("prefix LIKE ?")
params.append(
f"{_namespace_to_text(condition.path, handle_wildcards=True)}%"
)
elif condition.match_type == "suffix":
conditions.append("prefix LIKE ?")
params.append(
f"%{_namespace_to_text(condition.path, handle_wildcards=True)}"
)
else:
logger.warning(
f"Unknown match_type in list_namespaces: {condition.match_type}"
)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY prefix LIMIT ? OFFSET ?"
params.extend([op.limit, op.offset])
queries.append((query, params))
return queries
class DuckDBStore(BaseStore, BaseDuckDBStore[duckdb.DuckDBPyConnection]):
def __init__(
self,
conn: duckdb.DuckDBPyConnection,
) -> None:
super().__init__()
self.conn = conn
def batch(self, ops: Iterable[Op]) -> list[Result]:
grouped_ops, num_ops = _group_ops(ops)
results: list[Result] = [None] * num_ops
if GetOp in grouped_ops:
self._batch_get_ops(
cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results
)
if PutOp in grouped_ops:
self._batch_put_ops(cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]))
if SearchOp in grouped_ops:
self._batch_search_ops(
cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
results,
)
if ListNamespacesOp in grouped_ops:
self._batch_list_namespaces_ops(
cast(
Sequence[tuple[int, ListNamespacesOp]],
grouped_ops[ListNamespacesOp],
),
results,
)
return results
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops)
def _batch_get_ops(
self,
get_ops: Sequence[tuple[int, GetOp]],
results: list[Result],
) -> None:
cursors = []
for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops):
cur = self.conn.cursor()
cur.execute(query, params)
cursors.append((cur, namespace, items))
for cur, namespace, items in cursors:
rows = cur.fetchall()
key_to_row = {row[1]: row for row in rows}
for idx, key in items:
row = key_to_row.get(key)
if row:
results[idx] = _row_to_item(namespace, row)
else:
results[idx] = None
def _batch_put_ops(
self,
put_ops: Sequence[tuple[int, PutOp]],
) -> None:
queries = self._get_batch_PUT_queries(put_ops)
for query, params in queries:
cur = self.conn.cursor()
cur.execute(query, params)
def _batch_search_ops(
self,
search_ops: Sequence[tuple[int, SearchOp]],
results: list[Result],
) -> None:
queries = self._get_batch_search_queries(search_ops)
cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = []
for (query, params), (idx, _) in zip(queries, search_ops):
cur = self.conn.cursor()
cur.execute(query, params)
cursors.append((cur, idx))
for cur, idx in cursors:
rows = cur.fetchall()
items = [_row_to_search_item(_convert_ns(row[0]), row) for row in rows]
results[idx] = items
def _batch_list_namespaces_ops(
self,
list_ops: Sequence[tuple[int, ListNamespacesOp]],
results: list[Result],
) -> None:
queries = self._get_batch_list_namespaces_queries(list_ops)
cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = []
for (query, params), (idx, _) in zip(queries, list_ops):
cur = self.conn.cursor()
cur.execute(query, params)
cursors.append((cur, idx))
for cur, idx in cursors:
rows = cast(list[tuple], cur.fetchall())
namespaces = [_convert_ns(row[0]) for row in rows]
results[idx] = namespaces
@classmethod
@contextmanager
def from_conn_string(
cls,
conn_string: str,
) -> Iterator["DuckDBStore"]:
"""Create a new BaseDuckDBStore instance from a connection string.
Args:
conn_string (str): The DuckDB connection info string.
Returns:
DuckDBStore: A new DuckDBStore instance.
"""
with duckdb.connect(conn_string) as conn:
yield cls(conn=conn)
def setup(self) -> None:
"""Set up the store database.
This method creates the necessary tables in the DuckDB database if they don't
already exist and runs database migrations. It MUST be called directly by the user
the first time the store is used.
"""
with self.conn.cursor() as cur:
try:
cur.execute("SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1")
row = cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except duckdb.CatalogException:
version = -1
# Create store_migrations table if it doesn't exist
cur.execute(
"""
CREATE TABLE IF NOT EXISTS store_migrations (
v INTEGER PRIMARY KEY
)
"""
)
for v, migration in enumerate(
self.MIGRATIONS[version + 1 :], start=version + 1
):
cur.execute(migration)
cur.execute("INSERT INTO store_migrations (v) VALUES (?)", (v,))
def _namespace_to_text(
namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
"""Convert namespace tuple to text string."""
if handle_wildcards:
namespace = tuple("%" if val == "*" else val for val in namespace)
return ".".join(namespace)
def _row_to_item(
namespace: tuple[str, ...],
row: tuple,
) -> Item:
"""Convert a row from the database into an Item."""
_, key, val, created_at, updated_at = row
return Item(
value=val if isinstance(val, dict) else json.loads(val),
key=key,
namespace=namespace,
created_at=created_at,
updated_at=updated_at,
)
def _row_to_search_item(
namespace: tuple[str, ...],
row: tuple,
) -> SearchItem:
"""Convert a row from the database into an SearchItem."""
# TODO: Add support for search
_, key, val, created_at, updated_at = row
return SearchItem(
value=val if isinstance(val, dict) else json.loads(val),
key=key,
namespace=namespace,
created_at=created_at,
updated_at=updated_at,
)
def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
tot = 0
for idx, op in enumerate(ops):
grouped_ops[type(op)].append((idx, op))
tot += 1
return grouped_ops, tot
def _convert_ns(namespace: Union[str, list]) -> tuple[str, ...]:
if isinstance(namespace, list):
return tuple(namespace)
return tuple(namespace.split("."))
```
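The synchronous `DuckDBStore` is usually driven through the high-level `BaseStore` methods (`put`, `get`, `search`, `list_namespaces`, `delete`), as `tests/test_store.py` later in this dump does. A minimal sketch with illustrative data:
```py
# Minimal synchronous sketch for DuckDBStore; keys and values are illustrative.
from langgraph.store.duckdb import DuckDBStore

with DuckDBStore.from_conn_string(":memory:") as store:
    store.setup()
    namespace = ("docs", "user1")
    store.put(namespace, "doc1", {"title": "Hello", "tags": ["draft"]})

    item = store.get(namespace, "doc1")
    assert item is not None and item.value["title"] == "Hello"

    # Prefix search over namespaces with an exact-match filter on the JSON value.
    results = store.search(("docs",), filter={"title": "Hello"})
    print([r.key for r in results])

    print(store.list_namespaces(prefix=("docs",)))
    store.delete(namespace, "doc1")
```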
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-checkpoint-duckdb"
version = "2.0.1"
description = "Library with a DuckDB implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
langgraph-checkpoint = "^2.0.2"
duckdb = ">=1.1.2"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
anyio = "^4.4.0"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watch = "^4.2.0"
mypy = "^1.10.0"
langgraph-checkpoint = {path = "../checkpoint", develop = true}
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_sync.py`:
```py
from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.duckdb import DuckDBSaver
class TestDuckDBSaver:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
self.config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
self.config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
self.chkpnt_1: Checkpoint = empty_checkpoint()
self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1)
self.chkpnt_3: Checkpoint = empty_checkpoint()
self.metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
self.metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
self.metadata_3: CheckpointMetadata = {}
def test_search(self) -> None:
with DuckDBSaver.from_conn_string(":memory:") as saver:
saver.setup()
# save checkpoints
saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = list(saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = list(saver.list(None, filter=query_3))
assert len(search_results_3) == 3
search_results_4 = list(saver.list(None, filter=query_4))
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
# TODO: test before and limit params
def test_null_chars(self) -> None:
with DuckDBSaver.from_conn_string(":memory:") as saver:
saver.setup()
config = saver.put(self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {})
assert saver.get_tuple(config).metadata["my_key"] == "abc" # type: ignore
assert (
list(saver.list(None, filter={"my_key": "abc"}))[0].metadata["my_key"] # type: ignore
== "abc"
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_async.py`:
```py
from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver
class TestAsyncDuckDBSaver:
@pytest.fixture(autouse=True)
async def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
self.config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
self.config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
self.chkpnt_1: Checkpoint = empty_checkpoint()
self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1)
self.chkpnt_3: Checkpoint = empty_checkpoint()
self.metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
self.metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
self.metadata_3: CheckpointMetadata = {}
async def test_asearch(self) -> None:
async with AsyncDuckDBSaver.from_conn_string(":memory:") as saver:
await saver.setup()
await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {})
await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {})
await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
assert len(search_results_3) == 3
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = [
c
async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
]
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
# TODO: test before and limit params
async def test_null_chars(self) -> None:
async with AsyncDuckDBSaver.from_conn_string(":memory:") as saver:
await saver.setup()
config = await saver.aput(
self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {}
)
assert (await saver.aget_tuple(config)).metadata["my_key"] == "abc" # type: ignore
assert [c async for c in saver.alist(None, filter={"my_key": "abc"})][
0
].metadata["my_key"] == "abc"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_store.py`:
```py
# type: ignore
import uuid
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock
import pytest
from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.duckdb import DuckDBStore
class MockCursor:
def __init__(self, fetch_result: Any) -> None:
self.fetch_result = fetch_result
self.execute = MagicMock()
self.fetchall = MagicMock(return_value=self.fetch_result)
class MockConnection:
def __init__(self) -> None:
self.cursor = MagicMock()
@pytest.fixture
def mock_connection() -> MockConnection:
return MockConnection()
@pytest.fixture
def store(mock_connection: MockConnection) -> DuckDBStore:
duck_db_store = DuckDBStore(mock_connection)
duck_db_store.setup()
return duck_db_store
def test_batch_order(store: DuckDBStore) -> None:
mock_connection = store.conn
mock_get_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_search_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
]
)
mock_list_namespaces_cursor = MockCursor(
[
("test",),
]
)
failures = []
def cursor_side_effect() -> Any:
cursor = MagicMock()
def execute_side_effect(query: str, *params: Any) -> None:
# My super sophisticated database.
if "WHERE prefix = ? AND key" in query:
cursor.fetchall = mock_get_cursor.fetchall
elif "SELECT prefix, key, value" in query:
cursor.fetchall = mock_search_cursor.fetchall
elif "SELECT DISTINCT ON (truncated_prefix)" in query:
cursor.fetchall = mock_list_namespaces_cursor.fetchall
elif "INSERT INTO " in query:
pass
else:
e = ValueError(f"Unmatched query: {query}")
failures.append(e)
raise e
cursor.execute = MagicMock(side_effect=execute_side_effect)
return cursor
mock_connection.cursor.side_effect = cursor_side_effect
ops = [
GetOp(namespace=("test",), key="key1"),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
GetOp(namespace=("test",), key="key3"),
]
results = store.batch(ops)
assert not failures
assert len(results) == 5
assert isinstance(results[0], Item)
assert isinstance(results[0].value, dict)
assert results[0].value == {"data": "value1"}
assert results[0].key == "key1"
assert results[1] is None
assert isinstance(results[2], list)
assert len(results[2]) == 1
assert isinstance(results[3], list)
assert results[3] == [("test",)]
assert results[4] is None
ops_reordered = [
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
GetOp(namespace=("test",), key="key2"),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0),
PutOp(namespace=("test",), key="key3", value={"data": "value3"}),
GetOp(namespace=("test",), key="key1"),
]
results_reordered = store.batch(ops_reordered)
assert not failures
assert len(results_reordered) == 5
assert isinstance(results_reordered[0], list)
assert len(results_reordered[0]) == 1
assert isinstance(results_reordered[1], Item)
assert results_reordered[1].value == {"data": "value2"}
assert results_reordered[1].key == "key2"
assert isinstance(results_reordered[2], list)
assert results_reordered[2] == [("test",)]
assert results_reordered[3] is None
assert isinstance(results_reordered[4], Item)
assert results_reordered[4].value == {"data": "value1"}
assert results_reordered[4].key == "key1"
def test_batch_get_ops(store: DuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_connection.cursor.return_value = mock_cursor
ops = [
GetOp(namespace=("test",), key="key1"),
GetOp(namespace=("test",), key="key2"),
GetOp(namespace=("test",), key="key3"),
]
results = store.batch(ops)
assert len(results) == 3
assert results[0] is not None
assert results[1] is not None
assert results[2] is None
assert results[0].key == "key1"
assert results[1].key == "key2"
def test_batch_put_ops(store: DuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor([])
mock_connection.cursor.return_value = mock_cursor
ops = [
PutOp(namespace=("test",), key="key1", value={"data": "value1"}),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
PutOp(namespace=("test",), key="key3", value=None),
]
results = store.batch(ops)
assert len(results) == 3
assert all(result is None for result in results)
assert mock_cursor.execute.call_count == 2
def test_batch_search_ops(store: DuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_connection.cursor.return_value = mock_cursor
ops = [
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
]
results = store.batch(ops)
assert len(results) == 2
assert len(results[0]) == 2
assert len(results[1]) == 2
def test_batch_list_namespaces_ops(store: DuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)])
mock_connection.cursor.return_value = mock_cursor
ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)]
results = store.batch(ops)
assert len(results) == 1
assert results[0] == [("test", "namespace1"), ("test", "namespace2")]
def test_basic_store_ops() -> None:
with DuckDBStore.from_conn_string(":memory:") as store:
store.setup()
namespace = ("test", "documents")
item_id = "doc1"
item_value = {"title": "Test Document", "content": "Hello, World!"}
store.put(namespace, item_id, item_value)
item = store.get(namespace, item_id)
assert item
assert item.namespace == namespace
assert item.key == item_id
assert item.value == item_value
updated_value = {
"title": "Updated Test Document",
"content": "Hello, LangGraph!",
}
store.put(namespace, item_id, updated_value)
updated_item = store.get(namespace, item_id)
assert updated_item.value == updated_value
assert updated_item.updated_at > item.updated_at
different_namespace = ("test", "other_documents")
item_in_different_namespace = store.get(different_namespace, item_id)
assert item_in_different_namespace is None
new_item_id = "doc2"
new_item_value = {"title": "Another Document", "content": "Greetings!"}
store.put(namespace, new_item_id, new_item_value)
search_results = store.search(["test"], limit=10)
items = search_results
assert len(items) == 2
assert any(item.key == item_id for item in items)
assert any(item.key == new_item_id for item in items)
namespaces = store.list_namespaces(prefix=["test"])
assert ("test", "documents") in namespaces
store.delete(namespace, item_id)
store.delete(namespace, new_item_id)
deleted_item = store.get(namespace, item_id)
assert deleted_item is None
deleted_item = store.get(namespace, new_item_id)
assert deleted_item is None
empty_search_results = store.search(["test"], limit=10)
assert len(empty_search_results) == 0
def test_list_namespaces() -> None:
with DuckDBStore.from_conn_string(":memory:") as store:
store.setup()
test_pref = str(uuid.uuid4())
test_namespaces = [
(test_pref, "test", "documents", "public", test_pref),
(test_pref, "test", "documents", "private", test_pref),
(test_pref, "test", "images", "public", test_pref),
(test_pref, "test", "images", "private", test_pref),
(test_pref, "prod", "documents", "public", test_pref),
(
test_pref,
"prod",
"documents",
"some",
"nesting",
"public",
test_pref,
),
(test_pref, "prod", "documents", "private", test_pref),
]
for namespace in test_namespaces:
store.put(namespace, "dummy", {"content": "dummy"})
prefix_result = store.list_namespaces(prefix=[test_pref, "test"])
assert len(prefix_result) == 4
assert all([ns[1] == "test" for ns in prefix_result])
specific_prefix_result = store.list_namespaces(
prefix=[test_pref, "test", "documents"]
)
assert len(specific_prefix_result) == 2
assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result])
suffix_result = store.list_namespaces(suffix=["public", test_pref])
assert len(suffix_result) == 4
assert all(ns[-2] == "public" for ns in suffix_result)
prefix_suffix_result = store.list_namespaces(
prefix=[test_pref, "test"], suffix=["public", test_pref]
)
assert len(prefix_suffix_result) == 2
assert all(
ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result
)
wildcard_prefix_result = store.list_namespaces(
prefix=[test_pref, "*", "documents"]
)
assert len(wildcard_prefix_result) == 5
assert all(ns[2] == "documents" for ns in wildcard_prefix_result)
wildcard_suffix_result = store.list_namespaces(
suffix=["*", "public", test_pref]
)
assert len(wildcard_suffix_result) == 4
assert all(ns[-2] == "public" for ns in wildcard_suffix_result)
wildcard_single = store.list_namespaces(
suffix=["some", "*", "public", test_pref]
)
assert len(wildcard_single) == 1
assert wildcard_single[0] == (
test_pref,
"prod",
"documents",
"some",
"nesting",
"public",
test_pref,
)
max_depth_result = store.list_namespaces(max_depth=3)
assert all([len(ns) <= 3 for ns in max_depth_result])
max_depth_result = store.list_namespaces(
max_depth=4, prefix=[test_pref, "*", "documents"]
)
assert (
len(set(tuple(res) for res in max_depth_result))
== len(max_depth_result)
== 5
)
limit_result = store.list_namespaces(prefix=[test_pref], limit=3)
assert len(limit_result) == 3
offset_result = store.list_namespaces(prefix=[test_pref], offset=3)
assert len(offset_result) == len(test_namespaces) - 3
empty_prefix_result = store.list_namespaces(prefix=[test_pref])
assert len(empty_prefix_result) == len(test_namespaces)
assert set(tuple(ns) for ns in empty_prefix_result) == set(
tuple(ns) for ns in test_namespaces
)
for namespace in test_namespaces:
store.delete(namespace, "dummy")
def test_search():
with DuckDBStore.from_conn_string(":memory:") as store:
store.setup()
test_namespaces = [
("test_search", "documents", "user1"),
("test_search", "documents", "user2"),
("test_search", "reports", "department1"),
("test_search", "reports", "department2"),
]
test_items = [
{"title": "Doc 1", "author": "John Doe", "tags": ["important"]},
{"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]},
{"title": "Report A", "author": "John Doe", "tags": ["final"]},
{"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]},
]
for namespace, item in zip(test_namespaces, test_items):
store.put(namespace, f"item_{namespace[-1]}", item)
docs_result = store.search(["test_search", "documents"])
assert len(docs_result) == 2
assert all(
[item.namespace[1] == "documents" for item in docs_result]
), docs_result
reports_result = store.search(["test_search", "reports"])
assert len(reports_result) == 2
assert all(item.namespace[1] == "reports" for item in reports_result)
limited_result = store.search(["test_search"], limit=2)
assert len(limited_result) == 2
offset_result = store.search(["test_search"])
assert len(offset_result) == 4
offset_result = store.search(["test_search"], offset=2)
assert len(offset_result) == 2
assert all(item not in limited_result for item in offset_result)
john_doe_result = store.search(["test_search"], filter={"author": "John Doe"})
assert len(john_doe_result) == 2
assert all(item.value["author"] == "John Doe" for item in john_doe_result)
draft_result = store.search(["test_search"], filter={"tags": ["draft"]})
assert len(draft_result) == 2
assert all("draft" in item.value["tags"] for item in draft_result)
page1 = store.search(["test_search"], limit=2, offset=0)
page2 = store.search(["test_search"], limit=2, offset=2)
all_items = page1 + page2
assert len(all_items) == 4
assert len(set(item.key for item in all_items)) == 4
for namespace in test_namespaces:
store.delete(namespace, f"item_{namespace[-1]}")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-duckdb/tests/test_async_store.py`:
```py
# type: ignore
import uuid
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock
import pytest
from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.duckdb import AsyncDuckDBStore
class MockCursor:
def __init__(self, fetch_result: Any) -> None:
self.fetch_result = fetch_result
self.execute = MagicMock()
self.fetchall = MagicMock(return_value=self.fetch_result)
class MockConnection:
def __init__(self) -> None:
self.cursor = MagicMock()
@pytest.fixture
def mock_connection() -> MockConnection:
return MockConnection()
@pytest.fixture
async def store(mock_connection: MockConnection) -> AsyncDuckDBStore:
duck_db_store = AsyncDuckDBStore(mock_connection)
await duck_db_store.setup()
return duck_db_store
async def test_abatch_order(store: AsyncDuckDBStore) -> None:
mock_connection = store.conn
mock_get_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_search_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
]
)
mock_list_namespaces_cursor = MockCursor(
[
("test",),
]
)
failures = []
def cursor_side_effect() -> Any:
cursor = MagicMock()
def execute_side_effect(query: str, *params: Any) -> None:
            # Fake DB routing: pick the mocked fetchall result based on the SQL text.
if "WHERE prefix = ? AND key" in query:
cursor.fetchall = mock_get_cursor.fetchall
elif "SELECT prefix, key, value" in query:
cursor.fetchall = mock_search_cursor.fetchall
elif "SELECT DISTINCT ON (truncated_prefix)" in query:
cursor.fetchall = mock_list_namespaces_cursor.fetchall
elif "INSERT INTO " in query:
pass
else:
e = ValueError(f"Unmatched query: {query}")
failures.append(e)
raise e
cursor.execute = MagicMock(side_effect=execute_side_effect)
return cursor
mock_connection.cursor.side_effect = cursor_side_effect # type: ignore
ops = [
GetOp(namespace=("test",), key="key1"),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0),
GetOp(namespace=("test",), key="key3"),
]
results = await store.abatch(ops)
assert not failures
assert len(results) == 5
assert isinstance(results[0], Item)
assert isinstance(results[0].value, dict)
assert results[0].value == {"data": "value1"}
assert results[0].key == "key1"
assert results[1] is None
assert isinstance(results[2], list)
assert len(results[2]) == 1
assert isinstance(results[3], list)
assert results[3] == [("test",)]
assert results[4] is None
ops_reordered = [
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
GetOp(namespace=("test",), key="key2"),
ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0),
PutOp(namespace=("test",), key="key3", value={"data": "value3"}),
GetOp(namespace=("test",), key="key1"),
]
results_reordered = await store.abatch(ops_reordered)
assert not failures
assert len(results_reordered) == 5
assert isinstance(results_reordered[0], list)
assert len(results_reordered[0]) == 1
assert isinstance(results_reordered[1], Item)
assert results_reordered[1].value == {"data": "value2"}
assert results_reordered[1].key == "key2"
assert isinstance(results_reordered[2], list)
assert results_reordered[2] == [("test",)]
assert results_reordered[3] is None
assert isinstance(results_reordered[4], Item)
assert results_reordered[4].value == {"data": "value1"}
assert results_reordered[4].key == "key1"
async def test_batch_get_ops(store: AsyncDuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_connection.cursor.return_value = mock_cursor
ops = [
GetOp(namespace=("test",), key="key1"),
GetOp(namespace=("test",), key="key2"),
GetOp(namespace=("test",), key="key3"),
]
results = await store.abatch(ops)
assert len(results) == 3
assert results[0] is not None
assert results[1] is not None
assert results[2] is None
assert results[0].key == "key1"
assert results[1].key == "key2"
async def test_batch_put_ops(store: AsyncDuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor([])
mock_connection.cursor.return_value = mock_cursor
ops = [
PutOp(namespace=("test",), key="key1", value={"data": "value1"}),
PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
PutOp(namespace=("test",), key="key3", value=None),
]
results = await store.abatch(ops)
assert len(results) == 3
assert all(result is None for result in results)
assert mock_cursor.execute.call_count == 2
async def test_batch_search_ops(store: AsyncDuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor(
[
(
"test.foo",
"key1",
'{"data": "value1"}',
datetime.now(),
datetime.now(),
),
(
"test.bar",
"key2",
'{"data": "value2"}',
datetime.now(),
datetime.now(),
),
]
)
mock_connection.cursor.return_value = mock_cursor
ops = [
SearchOp(
namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0
),
SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0),
]
results = await store.abatch(ops)
assert len(results) == 2
assert len(results[0]) == 2
assert len(results[1]) == 2
async def test_batch_list_namespaces_ops(store: AsyncDuckDBStore) -> None:
mock_connection = store.conn
mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)])
mock_connection.cursor.return_value = mock_cursor
ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)]
results = await store.abatch(ops)
assert len(results) == 1
assert results[0] == [("test", "namespace1"), ("test", "namespace2")]
# The following tests use an actual DB connection
async def test_basic_store_ops() -> None:
async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
await store.setup()
namespace = ("test", "documents")
item_id = "doc1"
item_value = {"title": "Test Document", "content": "Hello, World!"}
await store.aput(namespace, item_id, item_value)
item = await store.aget(namespace, item_id)
assert item
assert item.namespace == namespace
assert item.key == item_id
assert item.value == item_value
updated_value = {
"title": "Updated Test Document",
"content": "Hello, LangGraph!",
}
await store.aput(namespace, item_id, updated_value)
updated_item = await store.aget(namespace, item_id)
assert updated_item.value == updated_value
assert updated_item.updated_at > item.updated_at
different_namespace = ("test", "other_documents")
item_in_different_namespace = await store.aget(different_namespace, item_id)
assert item_in_different_namespace is None
new_item_id = "doc2"
new_item_value = {"title": "Another Document", "content": "Greetings!"}
await store.aput(namespace, new_item_id, new_item_value)
search_results = await store.asearch(["test"], limit=10)
items = search_results
assert len(items) == 2
assert any(item.key == item_id for item in items)
assert any(item.key == new_item_id for item in items)
namespaces = await store.alist_namespaces(prefix=["test"])
assert ("test", "documents") in namespaces
await store.adelete(namespace, item_id)
await store.adelete(namespace, new_item_id)
deleted_item = await store.aget(namespace, item_id)
assert deleted_item is None
deleted_item = await store.aget(namespace, new_item_id)
assert deleted_item is None
empty_search_results = await store.asearch(["test"], limit=10)
assert len(empty_search_results) == 0
async def test_list_namespaces() -> None:
async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
await store.setup()
test_pref = str(uuid.uuid4())
test_namespaces = [
(test_pref, "test", "documents", "public", test_pref),
(test_pref, "test", "documents", "private", test_pref),
(test_pref, "test", "images", "public", test_pref),
(test_pref, "test", "images", "private", test_pref),
(test_pref, "prod", "documents", "public", test_pref),
(
test_pref,
"prod",
"documents",
"some",
"nesting",
"public",
test_pref,
),
(test_pref, "prod", "documents", "private", test_pref),
]
for namespace in test_namespaces:
await store.aput(namespace, "dummy", {"content": "dummy"})
prefix_result = await store.alist_namespaces(prefix=[test_pref, "test"])
assert len(prefix_result) == 4
assert all([ns[1] == "test" for ns in prefix_result])
specific_prefix_result = await store.alist_namespaces(
prefix=[test_pref, "test", "documents"]
)
assert len(specific_prefix_result) == 2
assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result])
suffix_result = await store.alist_namespaces(suffix=["public", test_pref])
assert len(suffix_result) == 4
assert all(ns[-2] == "public" for ns in suffix_result)
prefix_suffix_result = await store.alist_namespaces(
prefix=[test_pref, "test"], suffix=["public", test_pref]
)
assert len(prefix_suffix_result) == 2
assert all(
ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result
)
wildcard_prefix_result = await store.alist_namespaces(
prefix=[test_pref, "*", "documents"]
)
assert len(wildcard_prefix_result) == 5
assert all(ns[2] == "documents" for ns in wildcard_prefix_result)
wildcard_suffix_result = await store.alist_namespaces(
suffix=["*", "public", test_pref]
)
assert len(wildcard_suffix_result) == 4
assert all(ns[-2] == "public" for ns in wildcard_suffix_result)
wildcard_single = await store.alist_namespaces(
suffix=["some", "*", "public", test_pref]
)
assert len(wildcard_single) == 1
assert wildcard_single[0] == (
test_pref,
"prod",
"documents",
"some",
"nesting",
"public",
test_pref,
)
max_depth_result = await store.alist_namespaces(max_depth=3)
assert all([len(ns) <= 3 for ns in max_depth_result])
max_depth_result = await store.alist_namespaces(
max_depth=4, prefix=[test_pref, "*", "documents"]
)
assert (
len(set(tuple(res) for res in max_depth_result))
== len(max_depth_result)
== 5
)
limit_result = await store.alist_namespaces(prefix=[test_pref], limit=3)
assert len(limit_result) == 3
offset_result = await store.alist_namespaces(prefix=[test_pref], offset=3)
assert len(offset_result) == len(test_namespaces) - 3
empty_prefix_result = await store.alist_namespaces(prefix=[test_pref])
assert len(empty_prefix_result) == len(test_namespaces)
assert set(tuple(ns) for ns in empty_prefix_result) == set(
tuple(ns) for ns in test_namespaces
)
for namespace in test_namespaces:
await store.adelete(namespace, "dummy")
async def test_search():
async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
await store.setup()
test_namespaces = [
("test_search", "documents", "user1"),
("test_search", "documents", "user2"),
("test_search", "reports", "department1"),
("test_search", "reports", "department2"),
]
test_items = [
{"title": "Doc 1", "author": "John Doe", "tags": ["important"]},
{"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]},
{"title": "Report A", "author": "John Doe", "tags": ["final"]},
{"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]},
]
empty = await store.asearch(
(
"scoped",
"assistant_id",
"shared",
"6c5356f6-63ab-4158-868d-cd9fd14c736e",
),
limit=10,
offset=0,
)
assert len(empty) == 0
for namespace, item in zip(test_namespaces, test_items):
await store.aput(namespace, f"item_{namespace[-1]}", item)
docs_result = await store.asearch(["test_search", "documents"])
assert len(docs_result) == 2
assert all([item.namespace[1] == "documents" for item in docs_result]), [
item.namespace for item in docs_result
]
reports_result = await store.asearch(["test_search", "reports"])
assert len(reports_result) == 2
assert all(item.namespace[1] == "reports" for item in reports_result)
limited_result = await store.asearch(["test_search"], limit=2)
assert len(limited_result) == 2
offset_result = await store.asearch(["test_search"])
assert len(offset_result) == 4
offset_result = await store.asearch(["test_search"], offset=2)
assert len(offset_result) == 2
assert all(item not in limited_result for item in offset_result)
john_doe_result = await store.asearch(
["test_search"], filter={"author": "John Doe"}
)
assert len(john_doe_result) == 2
assert all(item.value["author"] == "John Doe" for item in john_doe_result)
draft_result = await store.asearch(["test_search"], filter={"tags": ["draft"]})
assert len(draft_result) == 2
assert all("draft" in item.value["tags"] for item in draft_result)
page1 = await store.asearch(["test_search"], limit=2, offset=0)
page2 = await store.asearch(["test_search"], limit=2, offset=2)
all_items = page1 + page2
assert len(all_items) == 4
assert len(set(item.key for item in all_items)) == 4
empty = await store.asearch(
(
"scoped",
"assistant_id",
"shared",
"again",
"maybe",
"some-long",
"6be5cb0e-2eb4-42e6-bb6b-fba3c269db25",
),
limit=10,
offset=0,
)
assert len(empty) == 0
# Test with a namespace beginning with a number (like a UUID)
uuid_namespace = (str(uuid.uuid4()), "documents")
uuid_item_id = "uuid_doc"
uuid_item_value = {
"title": "UUID Document",
"content": "This document has a UUID namespace.",
}
# Insert the item with the UUID namespace
await store.aput(uuid_namespace, uuid_item_id, uuid_item_value)
# Retrieve the item to verify it was stored correctly
retrieved_item = await store.aget(uuid_namespace, uuid_item_id)
assert retrieved_item is not None
assert retrieved_item.namespace == uuid_namespace
assert retrieved_item.key == uuid_item_id
assert retrieved_item.value == uuid_item_value
# Search for the item using the UUID namespace
search_result = await store.asearch([uuid_namespace[0]])
assert len(search_result) == 1
assert search_result[0].key == uuid_item_id
assert search_result[0].value == uuid_item_value
# Clean up: delete the item with the UUID namespace
await store.adelete(uuid_namespace, uuid_item_id)
# Verify the item was deleted
deleted_item = await store.aget(uuid_namespace, uuid_item_id)
assert deleted_item is None
for namespace in test_namespaces:
await store.adelete(namespace, f"item_{namespace[-1]}")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/default_sync.py`:
```py
import concurrent.futures
from typing import Optional, Sequence
from kafka import KafkaConsumer, KafkaProducer
from langgraph.scheduler.kafka.types import ConsumerRecord, TopicPartition
class DefaultConsumer(KafkaConsumer):
def getmany(
self, timeout_ms: int, max_records: int
) -> dict[TopicPartition, Sequence[ConsumerRecord]]:
return self.poll(timeout_ms=timeout_ms, max_records=max_records)
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class DefaultProducer(KafkaProducer):
def send(
self,
topic: str,
*,
key: Optional[bytes] = None,
value: Optional[bytes] = None,
) -> concurrent.futures.Future:
fut = concurrent.futures.Future()
kfut = super().send(topic, key=key, value=value)
kfut.add_callback(fut.set_result)
kfut.add_errback(fut.set_exception)
return fut
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
```
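The wrappers above adapt kafka-python's client classes to the scheduler's Consumer/Producer protocols. A minimal, hedged usage sketch (the broker address and topic name are assumptions, not part of the library):
```py
from langgraph.scheduler.kafka.default_sync import DefaultProducer

with DefaultProducer(bootstrap_servers="localhost:9092") as producer:
    # send() wraps the kafka-python future in a concurrent.futures.Future
    fut = producer.send("orchestrator", value=b'{"hello": "world"}')
    fut.result(timeout=10)  # block until the broker acknowledges the record
```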
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py`:
```py
import aiokafka
class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
pass
class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
pass
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/serde.py`:
```py
from typing import Any
import orjson
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
SERIALIZER = JsonPlusSerializer()
def loads(v: bytes) -> Any:
return SERIALIZER.loads(v)
def dumps(v: Any) -> bytes:
return orjson.dumps(v, default=_default)
def _default(v: Any) -> Any:
    # Ignore values we don't know how to serialize (e.g. functions).
return None
```
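Note the asymmetry above: dumps() goes through orjson with a default hook that silently drops unserializable values, while loads() delegates to JsonPlusSerializer. A small illustrative round trip (a sketch, not part of the library):
```py
from langgraph.scheduler.kafka import serde

payload = serde.dumps({"config": {"thread_id": "1"}, "fn": print})
# `print` is not JSON-serializable, so the _default hook turns it into null,
# roughly: b'{"config":{"thread_id":"1"},"fn":null}'
print(payload)
print(serde.loads(payload))  # {'config': {'thread_id': '1'}, 'fn': None}
```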
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py`:
```py
import asyncio
import concurrent.futures
from typing import Any, NamedTuple, Optional, Protocol, Sequence, TypedDict, Union
from langchain_core.runnables import RunnableConfig
class Topics(NamedTuple):
orchestrator: str
executor: str
error: str
class Sendable(TypedDict):
topic: str
value: Optional[Any]
key: Optional[Any]
class MessageToOrchestrator(TypedDict):
input: Optional[dict[str, Any]]
config: RunnableConfig
finally_send: Optional[Sequence[Sendable]]
class ExecutorTask(TypedDict):
id: Optional[str]
path: tuple[Union[str, int], ...]
class MessageToExecutor(TypedDict):
config: RunnableConfig
task: ExecutorTask
finally_send: Optional[Sequence[Sendable]]
class ErrorMessage(TypedDict):
topic: str
error: str
msg: Union[MessageToExecutor, MessageToOrchestrator]
class TopicPartition(Protocol):
topic: str
partition: int
class ConsumerRecord(Protocol):
topic: str
"The topic this record is received from"
partition: int
"The partition from which this record is received"
offset: int
"The position of this record in the corresponding Kafka partition."
timestamp: int
"The timestamp of this record"
timestamp_type: int
"The timestamp type of this record"
key: Optional[bytes]
"The key (or `None` if no key is specified)"
value: Optional[bytes]
"The value"
class Consumer(Protocol):
def getmany(
self, timeout_ms: int, max_records: int
) -> dict[TopicPartition, Sequence[ConsumerRecord]]: ...
def commit(self) -> None: ...
class AsyncConsumer(Protocol):
async def getmany(
self, timeout_ms: int, max_records: int
) -> dict[TopicPartition, Sequence[ConsumerRecord]]: ...
async def commit(self) -> None: ...
class Producer(Protocol):
def send(
self,
topic: str,
*,
key: Optional[bytes] = None,
value: Optional[bytes] = None,
) -> concurrent.futures.Future: ...
class AsyncProducer(Protocol):
async def send(
self,
topic: str,
*,
key: Optional[bytes] = None,
value: Optional[bytes] = None,
) -> asyncio.Future: ...
```
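The TypedDicts above are plain dicts at runtime; a hedged construction example (topic names and config values are placeholders):
```py
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics

topics = Topics(orchestrator="orchestrator", executor="executor", error="error")
msg = MessageToOrchestrator(
    input={"query": "what is weather in sf"},
    config={"configurable": {"thread_id": "1"}},
    finally_send=None,
)
# `msg` is just a dict and can be passed through serde.dumps() before producing.
```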
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/retry.py`:
```py
import asyncio
import logging
import random
import time
from typing import Awaitable, Callable, Optional
from typing_extensions import ParamSpec
from langgraph.types import RetryPolicy
logger = logging.getLogger(__name__)
P = ParamSpec("P")
def retry(
retry_policy: Optional[RetryPolicy],
func: Callable[P, None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Run a task asynchronously with retries."""
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
while True:
try:
func(*args, **kwargs)
# if successful, end
break
except Exception as exc:
if retry_policy is None:
raise
# increment attempts
attempts += 1
# check if we should retry
if callable(retry_policy.retry_on):
if not retry_policy.retry_on(exc):
raise
elif not isinstance(exc, retry_policy.retry_on):
raise
# check if we should give up
if attempts >= retry_policy.max_attempts:
raise
# sleep before retrying
interval = min(
retry_policy.max_interval,
interval * retry_policy.backoff_factor,
)
time.sleep(
interval + random.uniform(0, 1) if retry_policy.jitter else interval
)
# log the retry
logger.info(
f"Retrying function {func} with {args} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
async def aretry(
retry_policy: Optional[RetryPolicy],
func: Callable[P, Awaitable[None]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Run a task asynchronously with retries."""
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
while True:
try:
await func(*args, **kwargs)
# if successful, end
break
except Exception as exc:
if retry_policy is None:
raise
# increment attempts
attempts += 1
# check if we should retry
if callable(retry_policy.retry_on):
if not retry_policy.retry_on(exc):
raise
elif not isinstance(exc, retry_policy.retry_on):
raise
# check if we should give up
if attempts >= retry_policy.max_attempts:
raise
# sleep before retrying
interval = min(
retry_policy.max_interval,
interval * retry_policy.backoff_factor,
)
await asyncio.sleep(
interval + random.uniform(0, 1) if retry_policy.jitter else interval
)
# log the retry
logger.info(
f"Retrying function {func} with {args} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
```
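Both helpers walk the same backoff schedule: the interval starts at retry_policy.initial_interval, is multiplied by backoff_factor after each failure, and is capped at max_interval, with optional jitter added to each sleep. A minimal sketch of calling the sync helper (flaky_send and the policy values are illustrative, not from the library):
```py
from langgraph.scheduler.kafka.retry import retry
from langgraph.types import RetryPolicy

def flaky_send(payload: bytes) -> None:
    ...  # e.g. a producer call that may raise ConnectionError

policy = RetryPolicy(
    initial_interval=0.5,   # first sleep is 0.5 * backoff_factor = 1.0s
    backoff_factor=2.0,
    max_interval=8.0,       # sleeps are capped at 8 seconds
    max_attempts=5,         # give up (re-raise) after the 5th failure
    jitter=True,            # add up to 1s of random jitter to each sleep
    retry_on=(ConnectionError,),
)
retry(policy, flaky_send, b"hello")
```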
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py`:
```py
import asyncio
import concurrent.futures
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
AsyncExitStack,
ExitStack,
)
from typing import Any, Optional
from langchain_core.runnables import ensure_config
from typing_extensions import Self
import langgraph.scheduler.kafka.serde as serde
from langgraph.constants import (
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
INTERRUPT,
NS_END,
NS_SEP,
SCHEDULED,
)
from langgraph.errors import CheckpointNotLatest, GraphInterrupt
from langgraph.pregel import Pregel
from langgraph.pregel.executor import BackgroundExecutor, Submit
from langgraph.pregel.loop import AsyncPregelLoop, SyncPregelLoop
from langgraph.scheduler.kafka.retry import aretry, retry
from langgraph.scheduler.kafka.types import (
AsyncConsumer,
AsyncProducer,
Consumer,
ErrorMessage,
ExecutorTask,
MessageToExecutor,
MessageToOrchestrator,
Producer,
Topics,
)
from langgraph.types import RetryPolicy
from langgraph.utils.config import patch_configurable
class AsyncKafkaOrchestrator(AbstractAsyncContextManager):
consumer: AsyncConsumer
producer: AsyncProducer
def __init__(
self,
graph: Pregel,
topics: Topics,
batch_max_n: int = 10,
batch_max_ms: int = 1000,
retry_policy: Optional[RetryPolicy] = None,
consumer: Optional[AsyncConsumer] = None,
producer: Optional[AsyncProducer] = None,
**kwargs: Any,
) -> None:
self.graph = graph
self.topics = topics
self.stack = AsyncExitStack()
self.kwargs = kwargs
self.consumer = consumer
self.producer = producer
self.batch_max_n = batch_max_n
self.batch_max_ms = batch_max_ms
self.retry_policy = retry_policy
async def __aenter__(self) -> Self:
loop = asyncio.get_running_loop()
self.subgraphs = {
k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
}
if self.consumer is None:
from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
self.consumer = await self.stack.enter_async_context(
DefaultAsyncConsumer(
self.topics.orchestrator,
auto_offset_reset="earliest",
group_id="orchestrator",
enable_auto_commit=False,
loop=loop,
**self.kwargs,
)
)
if self.producer is None:
from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
self.producer = await self.stack.enter_async_context(
DefaultAsyncProducer(
loop=loop,
**self.kwargs,
)
)
return self
async def __aexit__(self, *args: Any) -> None:
return await self.stack.__aexit__(*args)
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> list[MessageToOrchestrator]:
# wait for next batch
recs = await self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
        # dedupe messages, e.g. if multiple nodes finish around the same time
uniq = set(msg.value for msgs in recs.values() for msg in msgs)
msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
# process batch
await asyncio.gather(*(self.each(msg) for msg in msgs))
# commit offsets
await self.consumer.commit()
# return message
return msgs
async def each(self, msg: MessageToOrchestrator) -> None:
try:
await aretry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except GraphInterrupt:
pass
except Exception as exc:
fut = await self.producer.send(
self.topics.error,
value=serde.dumps(
ErrorMessage(
topic=self.topics.orchestrator,
msg=msg,
error=repr(exc),
)
),
)
await fut
async def attempt(self, msg: MessageToOrchestrator) -> None:
# find graph
if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"):
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
if recast_checkpoint_ns in self.subgraphs:
graph = self.subgraphs[recast_checkpoint_ns]
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
else:
graph = self.graph
# process message
async with AsyncPregelLoop(
msg["input"],
config=ensure_config(msg["config"]),
stream=None,
store=self.graph.store,
checkpointer=self.graph.checkpointer,
nodes=graph.nodes,
specs=graph.channels,
output_keys=graph.output_channels,
stream_keys=graph.stream_channels,
interrupt_after=graph.interrupt_after_nodes,
interrupt_before=graph.interrupt_before_nodes,
check_subgraphs=False,
) as loop:
if loop.tick(input_keys=graph.input_channels):
# wait for checkpoint to be saved
if hasattr(loop, "_put_checkpoint_fut"):
await loop._put_checkpoint_fut
# schedule any new tasks
if new_tasks := [
t for t in loop.tasks.values() if not t.scheduled and not t.writes
]:
# send messages to executor
futures = await asyncio.gather(
*(
self.producer.send(
self.topics.executor,
value=serde.dumps(
MessageToExecutor(
config=patch_configurable(
loop.config,
{
**loop.checkpoint_config[
"configurable"
],
CONFIG_KEY_DEDUPE_TASKS: True,
CONFIG_KEY_ENSURE_LATEST: True,
},
),
task=ExecutorTask(id=task.id, path=task.path),
finally_send=msg.get("finally_send"),
)
),
)
for task in new_tasks
)
)
# wait for messages to be sent
await asyncio.gather(*futures)
# mark as scheduled
for task in new_tasks:
loop.put_writes(
task.id,
[
(
SCHEDULED,
max(
loop.checkpoint["versions_seen"]
.get(INTERRUPT, {})
.values(),
default=None,
),
)
],
)
elif loop.status == "done" and msg.get("finally_send"):
# send any finally_send messages
futs = await asyncio.gather(
*(
self.producer.send(
m["topic"],
value=serde.dumps(m["value"]) if m.get("value") else None,
key=serde.dumps(m["key"]) if m.get("key") else None,
)
for m in msg["finally_send"]
)
)
# wait for messages to be sent
await asyncio.gather(*futs)
class KafkaOrchestrator(AbstractContextManager):
consumer: Consumer
producer: Producer
submit: Submit
def __init__(
self,
graph: Pregel,
topics: Topics,
batch_max_n: int = 10,
batch_max_ms: int = 1000,
retry_policy: Optional[RetryPolicy] = None,
consumer: Optional[Consumer] = None,
producer: Optional[Producer] = None,
**kwargs: Any,
) -> None:
self.graph = graph
self.topics = topics
self.stack = ExitStack()
self.kwargs = kwargs
self.consumer = consumer
self.producer = producer
self.batch_max_n = batch_max_n
self.batch_max_ms = batch_max_ms
self.retry_policy = retry_policy
def __enter__(self) -> Self:
self.subgraphs = dict(self.graph.get_subgraphs(recurse=True))
self.submit = self.stack.enter_context(BackgroundExecutor({}))
if self.consumer is None:
from langgraph.scheduler.kafka.default_sync import DefaultConsumer
self.consumer = self.stack.enter_context(
DefaultConsumer(
self.topics.orchestrator,
auto_offset_reset="earliest",
group_id="orchestrator",
enable_auto_commit=False,
**self.kwargs,
)
)
if self.producer is None:
from langgraph.scheduler.kafka.default_sync import DefaultProducer
self.producer = self.stack.enter_context(
DefaultProducer(
**self.kwargs,
)
)
return self
def __exit__(self, *args: Any) -> None:
return self.stack.__exit__(*args)
def __iter__(self) -> Self:
return self
def __next__(self) -> list[MessageToOrchestrator]:
# wait for next batch
recs = self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
        # dedupe messages, e.g. if multiple nodes finish around the same time
uniq = set(msg.value for msgs in recs.values() for msg in msgs)
msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
# process batch
concurrent.futures.wait(self.submit(self.each, msg) for msg in msgs)
# commit offsets
self.consumer.commit()
# return message
return msgs
def each(self, msg: MessageToOrchestrator) -> None:
try:
retry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except GraphInterrupt:
pass
except Exception as exc:
fut = self.producer.send(
self.topics.error,
value=serde.dumps(
ErrorMessage(
topic=self.topics.orchestrator,
msg=msg,
error=repr(exc),
)
),
)
fut.result()
def attempt(self, msg: MessageToOrchestrator) -> None:
# find graph
if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"):
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
if recast_checkpoint_ns in self.subgraphs:
graph = self.subgraphs[recast_checkpoint_ns]
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
else:
graph = self.graph
# process message
with SyncPregelLoop(
msg["input"],
config=ensure_config(msg["config"]),
stream=None,
store=self.graph.store,
checkpointer=self.graph.checkpointer,
nodes=graph.nodes,
specs=graph.channels,
output_keys=graph.output_channels,
stream_keys=graph.stream_channels,
interrupt_after=graph.interrupt_after_nodes,
interrupt_before=graph.interrupt_before_nodes,
check_subgraphs=False,
) as loop:
if loop.tick(input_keys=graph.input_channels):
# wait for checkpoint to be saved
if hasattr(loop, "_put_checkpoint_fut"):
loop._put_checkpoint_fut.result()
# schedule any new tasks
if new_tasks := [
t for t in loop.tasks.values() if not t.scheduled and not t.writes
]:
# send messages to executor
futures = [
self.producer.send(
self.topics.executor,
value=serde.dumps(
MessageToExecutor(
config=patch_configurable(
loop.config,
{
**loop.checkpoint_config["configurable"],
CONFIG_KEY_DEDUPE_TASKS: True,
CONFIG_KEY_ENSURE_LATEST: True,
},
),
task=ExecutorTask(id=task.id, path=task.path),
finally_send=msg.get("finally_send"),
)
),
)
for task in new_tasks
]
# wait for messages to be sent
concurrent.futures.wait(futures)
# mark as scheduled
for task in new_tasks:
loop.put_writes(
task.id,
[
(
SCHEDULED,
max(
loop.checkpoint["versions_seen"]
.get(INTERRUPT, {})
.values(),
default=None,
),
)
],
)
elif loop.status == "done" and msg.get("finally_send"):
# schedule any finally_send msgs
futs = [
self.producer.send(
m["topic"],
value=serde.dumps(m["value"]) if m.get("value") else None,
key=serde.dumps(m["key"]) if m.get("key") else None,
)
for m in msg["finally_send"]
]
# wait for messages to be sent
concurrent.futures.wait(futs)
```
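A hedged end-to-end sketch of driving the sync orchestrator loop. Everything below is illustrative: the broker address, topic names, and the toy graph are assumptions, and MemorySaver is used only to keep the sketch self-contained (the test suite pairs the scheduler with a Postgres checkpointer instead).
```py
from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.scheduler.kafka.orchestrator import KafkaOrchestrator
from langgraph.scheduler.kafka.types import Topics

class State(TypedDict, total=False):
    query: str

builder = StateGraph(State)
builder.add_node("echo", lambda state: {"query": f"echo: {state['query']}"})
builder.set_entry_point("echo")
builder.set_finish_point("echo")
graph = builder.compile(MemorySaver())

topics = Topics(orchestrator="orchestrator", executor="executor", error="error")

# Each iteration pulls a batch from the orchestrator topic, advances the
# Pregel loop by one tick, schedules new tasks onto the executor topic, and
# commits offsets. Runs until interrupted.
with KafkaOrchestrator(graph, topics, bootstrap_servers="localhost:9092") as orch:
    for msgs in orch:
        print(f"processed {len(msgs)} orchestrator message(s)")
```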
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py`:
```py
import asyncio
import concurrent.futures
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
AsyncExitStack,
ExitStack,
)
from functools import partial
from typing import Any, Optional, Sequence
from uuid import UUID
import orjson
from langchain_core.runnables import RunnableConfig
from typing_extensions import Self
import langgraph.scheduler.kafka.serde as serde
from langgraph.constants import CONFIG_KEY_DELEGATE, ERROR, NS_END, NS_SEP
from langgraph.errors import CheckpointNotLatest, GraphDelegate, TaskNotFound
from langgraph.pregel import Pregel
from langgraph.pregel.algo import prepare_single_task
from langgraph.pregel.executor import (
AsyncBackgroundExecutor,
BackgroundExecutor,
Submit,
)
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.runner import PregelRunner
from langgraph.scheduler.kafka.retry import aretry, retry
from langgraph.scheduler.kafka.types import (
AsyncConsumer,
AsyncProducer,
Consumer,
ErrorMessage,
MessageToExecutor,
MessageToOrchestrator,
Producer,
Sendable,
Topics,
)
from langgraph.types import LoopProtocol, PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable
class AsyncKafkaExecutor(AbstractAsyncContextManager):
consumer: AsyncConsumer
producer: AsyncProducer
def __init__(
self,
graph: Pregel,
topics: Topics,
*,
batch_max_n: int = 10,
batch_max_ms: int = 1000,
retry_policy: Optional[RetryPolicy] = None,
consumer: Optional[AsyncConsumer] = None,
producer: Optional[AsyncProducer] = None,
**kwargs: Any,
) -> None:
self.graph = graph
self.topics = topics
self.stack = AsyncExitStack()
self.kwargs = kwargs
self.consumer = consumer
self.producer = producer
self.batch_max_n = batch_max_n
self.batch_max_ms = batch_max_ms
self.retry_policy = retry_policy
async def __aenter__(self) -> Self:
loop = asyncio.get_running_loop()
self.subgraphs = {
k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
}
if self.consumer is None:
from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
self.consumer = await self.stack.enter_async_context(
DefaultAsyncConsumer(
self.topics.executor,
auto_offset_reset="earliest",
group_id="executor",
enable_auto_commit=False,
loop=loop,
**self.kwargs,
)
)
if self.producer is None:
from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
self.producer = await self.stack.enter_async_context(
DefaultAsyncProducer(
loop=loop,
**self.kwargs,
)
)
return self
async def __aexit__(self, *args: Any) -> None:
return await self.stack.__aexit__(*args)
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> Sequence[MessageToExecutor]:
# wait for next batch
recs = await self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
serde.loads(msg.value) for msgs in recs.values() for msg in msgs
]
# process batch
await asyncio.gather(*(self.each(msg) for msg in msgs))
# commit offsets
await self.consumer.commit()
# return message
return msgs
async def each(self, msg: MessageToExecutor) -> None:
try:
await aretry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except GraphDelegate as exc:
for arg in exc.args:
fut = await self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
config=arg["config"],
input=orjson.Fragment(
self.graph.checkpointer.serde.dumps(arg["input"])
),
finally_send=[
Sendable(topic=self.topics.executor, value=msg)
],
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
arg["config"]["configurable"]["thread_id"],
arg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
await fut
except Exception as exc:
fut = await self.producer.send(
self.topics.error,
value=serde.dumps(
ErrorMessage(
topic=self.topics.executor,
msg=msg,
error=repr(exc),
)
),
)
await fut
async def attempt(self, msg: MessageToExecutor) -> None:
# find graph
if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"):
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
if recast_checkpoint_ns in self.subgraphs:
graph = self.subgraphs[recast_checkpoint_ns]
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
else:
graph = self.graph
# process message
saved = await self.graph.checkpointer.aget_tuple(
patch_configurable(msg["config"], {"checkpoint_id": None})
)
if saved is None:
raise RuntimeError("Checkpoint not found")
if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]:
raise CheckpointNotLatest()
async with AsyncChannelsManager(
graph.channels,
saved.checkpoint,
LoopProtocol(
config=msg["config"],
store=self.graph.store,
step=saved.metadata["step"] + 1,
stop=saved.metadata["step"] + 2,
),
) as (channels, managed), AsyncBackgroundExecutor(msg["config"]) as submit:
if task := await asyncio.to_thread(
prepare_single_task,
msg["task"]["path"],
msg["task"]["id"],
checkpoint=saved.checkpoint,
pending_writes=saved.pending_writes or [],
processes=graph.nodes,
channels=channels,
managed=managed,
config=patch_configurable(msg["config"], {CONFIG_KEY_DELEGATE: True}),
step=saved.metadata["step"] + 1,
for_execution=True,
checkpointer=self.graph.checkpointer,
store=self.graph.store,
):
# execute task, saving writes
runner = PregelRunner(
submit=submit,
put_writes=partial(self._put_writes, submit, msg["config"]),
schedule_task=self._schedule_task,
)
async for _ in runner.atick([task], reraise=False):
pass
else:
# task was not found
await self.graph.checkpointer.aput_writes(
msg["config"], [(ERROR, TaskNotFound())], str(UUID(int=0))
)
# notify orchestrator
fut = await self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_send=msg.get("finally_send"),
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
msg["config"]["configurable"]["thread_id"],
msg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
await fut
def _schedule_task(
self,
task: PregelExecutableTask,
idx: int,
) -> None:
# will be scheduled by orchestrator when executor finishes
pass
def _put_writes(
self,
submit: Submit,
config: RunnableConfig,
task_id: str,
writes: list[tuple[str, Any]],
) -> None:
return submit(self.graph.checkpointer.aput_writes, config, writes, task_id)
class KafkaExecutor(AbstractContextManager):
consumer: Consumer
producer: Producer
def __init__(
self,
graph: Pregel,
topics: Topics,
*,
batch_max_n: int = 10,
batch_max_ms: int = 1000,
retry_policy: Optional[RetryPolicy] = None,
consumer: Optional[Consumer] = None,
producer: Optional[Producer] = None,
**kwargs: Any,
) -> None:
self.graph = graph
self.topics = topics
self.stack = ExitStack()
self.kwargs = kwargs
self.consumer = consumer
self.producer = producer
self.batch_max_n = batch_max_n
self.batch_max_ms = batch_max_ms
self.retry_policy = retry_policy
def __enter__(self) -> Self:
self.subgraphs = dict(self.graph.get_subgraphs(recurse=True))
self.submit = self.stack.enter_context(BackgroundExecutor({}))
if self.consumer is None:
from langgraph.scheduler.kafka.default_sync import DefaultConsumer
self.consumer = self.stack.enter_context(
DefaultConsumer(
self.topics.executor,
auto_offset_reset="earliest",
group_id="executor",
enable_auto_commit=False,
**self.kwargs,
)
)
if self.producer is None:
from langgraph.scheduler.kafka.default_sync import DefaultProducer
self.producer = self.stack.enter_context(
DefaultProducer(
**self.kwargs,
)
)
return self
def __exit__(self, *args: Any) -> None:
return self.stack.__exit__(*args)
def __iter__(self) -> Self:
return self
def __next__(self) -> Sequence[MessageToExecutor]:
# wait for next batch
recs = self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
msgs: list[MessageToExecutor] = [
serde.loads(msg.value) for msgs in recs.values() for msg in msgs
]
# process batch
concurrent.futures.wait(self.submit(self.each, msg) for msg in msgs)
# commit offsets
self.consumer.commit()
# return message
return msgs
def each(self, msg: MessageToExecutor) -> None:
try:
retry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except GraphDelegate as exc:
for arg in exc.args:
fut = self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
config=arg["config"],
input=orjson.Fragment(
self.graph.checkpointer.serde.dumps(arg["input"])
),
finally_send=[
Sendable(topic=self.topics.executor, value=msg)
],
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
arg["config"]["configurable"]["thread_id"],
arg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
fut.result()
except Exception as exc:
fut = self.producer.send(
self.topics.error,
value=serde.dumps(
ErrorMessage(
topic=self.topics.executor,
msg=msg,
error=repr(exc),
)
),
)
fut.result()
def attempt(self, msg: MessageToExecutor) -> None:
# find graph
if checkpoint_ns := msg["config"]["configurable"].get("checkpoint_ns"):
# remove task_ids from checkpoint_ns
recast_checkpoint_ns = NS_SEP.join(
part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP)
)
# find the subgraph with the matching name
if recast_checkpoint_ns in self.subgraphs:
graph = self.subgraphs[recast_checkpoint_ns]
else:
raise ValueError(f"Subgraph {recast_checkpoint_ns} not found")
else:
graph = self.graph
# process message
saved = self.graph.checkpointer.get_tuple(
patch_configurable(msg["config"], {"checkpoint_id": None})
)
if saved is None:
raise RuntimeError("Checkpoint not found")
if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]:
raise CheckpointNotLatest()
with ChannelsManager(
graph.channels,
saved.checkpoint,
LoopProtocol(
config=msg["config"],
store=self.graph.store,
step=saved.metadata["step"] + 1,
stop=saved.metadata["step"] + 2,
),
) as (channels, managed), BackgroundExecutor({}) as submit:
if task := prepare_single_task(
msg["task"]["path"],
msg["task"]["id"],
checkpoint=saved.checkpoint,
pending_writes=saved.pending_writes or [],
processes=graph.nodes,
channels=channels,
managed=managed,
config=patch_configurable(msg["config"], {CONFIG_KEY_DELEGATE: True}),
step=saved.metadata["step"] + 1,
for_execution=True,
checkpointer=self.graph.checkpointer,
):
# execute task, saving writes
runner = PregelRunner(
submit=submit,
put_writes=partial(self._put_writes, submit, msg["config"]),
schedule_task=self._schedule_task,
)
for _ in runner.tick([task], reraise=False):
pass
else:
# task was not found
self.graph.checkpointer.put_writes(
msg["config"], [(ERROR, TaskNotFound())], str(UUID(int=0))
)
# notify orchestrator
fut = self.producer.send(
self.topics.orchestrator,
value=serde.dumps(
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_send=msg.get("finally_send"),
)
),
# use thread_id, checkpoint_ns as partition key
key=serde.dumps(
(
msg["config"]["configurable"]["thread_id"],
msg["config"]["configurable"].get("checkpoint_ns"),
)
),
)
fut.result()
def _schedule_task(
self,
task: PregelExecutableTask,
idx: int,
) -> None:
# will be scheduled by orchestrator when executor finishes
pass
def _put_writes(
self,
submit: Submit,
config: RunnableConfig,
task_id: str,
writes: list[tuple[str, Any]],
) -> None:
return submit(self.graph.checkpointer.put_writes, config, writes, task_id)
```
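A companion sketch to the orchestrator example above: the executor consumes MessageToExecutor tasks, runs them against the graph, and notifies the orchestrator when each task finishes. It reuses the hypothetical `graph`, `topics`, and broker from that sketch and would typically run in a separate process.
```py
from langgraph.scheduler.kafka.executor import KafkaExecutor

# `graph` and `topics` are the same illustrative objects built in the
# orchestrator sketch; the broker address is likewise an assumption.
with KafkaExecutor(graph, topics, bootstrap_servers="localhost:9092") as executor:
    for msgs in executor:
        print(f"executed {len(msgs)} task(s)")
```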
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-scheduler-kafka"
version = "1.0.0"
description = "Library with Kafka-based work scheduler."
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
orjson = "^3.10.7"
crc32c = "^2.7.post1"
aiokafka = "^0.11.0"
langgraph = "^0.2.19"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
pytest-mock = "^3.11.1"
pytest-watcher = "^0.4.1"
mypy = "^1.10.0"
langgraph = {path = "../langgraph", develop = true}
langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true}
langgraph-checkpoint = {path = "../checkpoint", develop = true}
kafka-python-ng = "^2.2.2"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config will raise errors on any warnings encountered while parsing
# the `pytest` section of the configuration file.
addopts = "--strict-markers --strict-config --durations=5 -vv"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
[tool.pytest-watcher]
now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short", "-s"]
patterns = ["*.py"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/conftest.py`:
```py
from typing import AsyncIterator, Iterator
from uuid import uuid4
import kafka.admin
import pytest
from psycopg import AsyncConnection, Connection
from psycopg_pool import AsyncConnectionPool, ConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.scheduler.kafka.types import Topics
DEFAULT_POSTGRES_URI = "postgres://postgres:postgres@localhost:5443/"
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.fixture
def topics() -> Iterator[Topics]:
o = f"test_o_{uuid4().hex[:16]}"
e = f"test_e_{uuid4().hex[:16]}"
z = f"test_z_{uuid4().hex[:16]}"
admin = kafka.admin.KafkaAdminClient()
# create topics
admin.create_topics(
[
kafka.admin.NewTopic(name=o, num_partitions=1, replication_factor=1),
kafka.admin.NewTopic(name=e, num_partitions=1, replication_factor=1),
kafka.admin.NewTopic(name=z, num_partitions=1, replication_factor=1),
]
)
# yield topics
yield Topics(orchestrator=o, executor=e, error=z)
# delete topics
admin.delete_topics([o, e, z])
admin.close()
@pytest.fixture
async def acheckpointer() -> AsyncIterator[AsyncPostgresSaver]:
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncConnectionPool(
DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True}
) as pool:
checkpointer = AsyncPostgresSaver(pool)
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await AsyncConnection.connect(
DEFAULT_POSTGRES_URI, autocommit=True
) as conn:
await conn.execute(f"DROP DATABASE {database}")
@pytest.fixture
def checkpointer() -> Iterator[PostgresSaver]:
database = f"test_{uuid4().hex[:16]}"
# create unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with ConnectionPool(
DEFAULT_POSTGRES_URI + database, max_size=10, kwargs={"autocommit": True}
) as pool:
checkpointer = PostgresSaver(pool)
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn:
conn.execute(f"DROP DATABASE {database}")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_fanout_sync.py`:
```py
import operator
import time
from typing import (
Annotated,
Sequence,
TypedDict,
Union,
)
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import StateGraph
from langgraph.pregel import Pregel
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.default_sync import DefaultProducer
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from tests.any import AnyDict
from tests.drain import drain_topics
def mk_fanout_graph(
checkpointer: BaseCheckpointSaver, interrupt_before: Sequence[str] = ()
) -> Pregel:
# copied from test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]
def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
def retriever_two(data: State) -> State:
time.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
builder = StateGraph(State)
builder.add_node("rewrite_query", rewrite_query)
builder.add_node("analyzer_one", analyzer_one)
builder.add_node("retriever_one", retriever_one)
builder.add_node("retriever_two", retriever_two)
builder.add_node("decider", decider)
builder.add_node("qa", qa)
builder.set_entry_point("rewrite_query")
builder.add_conditional_edges("rewrite_query", retriever_picker)
builder.add_edge("analyzer_one", "retriever_one")
builder.add_edge(["retriever_one", "retriever_two"], "decider")
builder.add_conditional_edges("decider", decider_cond)
builder.set_finish_point("qa")
return builder.compile(checkpointer, interrupt_before=interrupt_before)
def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -> None:
input = {"query": "what is weather in sf"}
config = {"configurable": {"thread_id": "1"}}
graph = mk_fanout_graph(checkpointer)
# start a new run
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=input, config=config)),
)
producer.flush()
# drain topics
orch_msgs, exec_msgs = drain_topics(topics, graph, debug=1)
# check state
state = graph.get_state(config)
assert state.next == ()
assert (
state.values
== graph.invoke(input, {"configurable": {"thread_id": "2"}})
== {
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
}
)
# check history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 11
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
def test_fanout_graph_w_interrupt(
topics: Topics, checkpointer: BaseCheckpointSaver
) -> None:
input = {"query": "what is weather in sf"}
config = {"configurable": {"thread_id": "1"}}
graph = mk_fanout_graph(checkpointer, interrupt_before=["qa"])
# start a new run
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=input, config=config)),
)
producer.flush()
orch_msgs, exec_msgs = drain_topics(topics, graph, debug=1)
# check interrupted state
state = graph.get_state(config)
assert state.next == ("qa",)
assert (
state.values
== graph.invoke(input, {"configurable": {"thread_id": "2"}})
== {
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
}
)
# check history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 10
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, ie. after executor sends message to resume checkpoint
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
for t in c.tasks
]
# resume the thread
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=None, config=config)),
)
producer.flush()
orch_msgs, exec_msgs = drain_topics(topics, graph)
# check final state
state = graph.get_state(config)
assert state.next == ()
assert (
state.values
== graph.invoke(None, {"configurable": {"thread_id": "2"}})
== {
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
}
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_fanout.py`:
```py
import asyncio
import operator
from typing import (
Annotated,
Sequence,
TypedDict,
Union,
)
import pytest
from aiokafka import AIOKafkaProducer
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import StateGraph
from langgraph.pregel import Pregel
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from tests.any import AnyDict
from tests.drain import drain_topics_async
pytestmark = pytest.mark.anyio
def mk_fanout_graph(
checkpointer: BaseCheckpointSaver, interrupt_before: Sequence[str] = ()
) -> Pregel:
# copied from test_in_one_fan_out_state_graph_waiting_edge_multiple_cond_edge
def sorted_add(
x: list[str], y: Union[list[str], list[tuple[str, str]]]
) -> list[str]:
if isinstance(y[0], tuple):
for rem, _ in y:
x.remove(rem)
y = [t[1] for t in y]
return sorted(operator.add(x, y))
class State(TypedDict, total=False):
query: str
answer: str
docs: Annotated[list[str], sorted_add]
async def rewrite_query(data: State) -> State:
return {"query": f'query: {data["query"]}'}
async def retriever_picker(data: State) -> list[str]:
return ["analyzer_one", "retriever_two"]
async def analyzer_one(data: State) -> State:
return {"query": f'analyzed: {data["query"]}'}
async def retriever_one(data: State) -> State:
return {"docs": ["doc1", "doc2"]}
async def retriever_two(data: State) -> State:
await asyncio.sleep(0.1)
return {"docs": ["doc3", "doc4"]}
async def qa(data: State) -> State:
return {"answer": ",".join(data["docs"])}
async def decider(data: State) -> None:
return None
def decider_cond(data: State) -> str:
if data["query"].count("analyzed") > 1:
return "qa"
else:
return "rewrite_query"
builder = StateGraph(State)
builder.add_node("rewrite_query", rewrite_query)
builder.add_node("analyzer_one", analyzer_one)
builder.add_node("retriever_one", retriever_one)
builder.add_node("retriever_two", retriever_two)
builder.add_node("decider", decider)
builder.add_node("qa", qa)
builder.set_entry_point("rewrite_query")
builder.add_conditional_edges("rewrite_query", retriever_picker)
builder.add_edge("analyzer_one", "retriever_one")
builder.add_edge(["retriever_one", "retriever_two"], "decider")
builder.add_conditional_edges("decider", decider_cond)
builder.set_finish_point("qa")
return builder.compile(checkpointer, interrupt_before=interrupt_before)
async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None:
input = {"query": "what is weather in sf"}
config = {"configurable": {"thread_id": "1"}}
graph = mk_fanout_graph(acheckpointer)
# start a new run
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=input, config=config),
)
# drain topics
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check state
state = await graph.aget_state(config)
assert state.next == ()
assert (
state.values
== await graph.ainvoke(input, {"configurable": {"thread_id": "2"}})
== {
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
}
)
# check history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 11
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
async def test_fanout_graph_w_interrupt(
topics: Topics, acheckpointer: BaseCheckpointSaver
) -> None:
input = {"query": "what is weather in sf"}
config = {"configurable": {"thread_id": "1"}}
graph = mk_fanout_graph(acheckpointer, interrupt_before=["qa"])
# start a new run
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=input, config=config),
)
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check interrupted state
state = await graph.aget_state(config)
assert state.next == ("qa",)
assert (
state.values
== await graph.ainvoke(input, {"configurable": {"thread_id": "2"}})
== {
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
}
)
# check history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 10
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, i.e. after executor sends message to resume checkpoint
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
for t in c.tasks
]
# resume the thread
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=None, config=config),
)
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check final state
state = await graph.aget_state(config)
assert state.next == ()
assert (
state.values
== await graph.ainvoke(None, {"configurable": {"thread_id": "2"}})
== {
"answer": "doc1,doc1,doc2,doc2,doc3,doc3,doc4,doc4",
"docs": ["doc1", "doc1", "doc2", "doc2", "doc3", "doc3", "doc4", "doc4"],
"query": "analyzed: query: analyzed: query: what is weather in sf",
}
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_subgraph.py`:
```py
from typing import Literal, cast
import pytest
from aiokafka import AIOKafkaProducer
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START
from langgraph.graph import MessagesState
from langgraph.graph.state import StateGraph
from langgraph.pregel import Pregel
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from tests.any import AnyDict, AnyList
from tests.drain import drain_topics_async
from tests.messages import _AnyIdAIMessage, _AnyIdHumanMessage
pytestmark = pytest.mark.anyio
def mk_weather_graph(checkpointer: BaseCheckpointSaver) -> Pregel:
# copied from test_weather_subgraph
# setup subgraph
@tool
def get_weather(city: str):
"""Get the weather for a specific city"""
return f"I'ts sunny in {city}!"
weather_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="get_weather",
args={"city": "San Francisco"},
)
],
)
]
)
class SubGraphState(MessagesState):
city: str
def model_node(state: SubGraphState):
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}
def weather_node(state: SubGraphState):
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}
subgraph = StateGraph(SubGraphState)
subgraph.add_node(model_node)
subgraph.add_node(weather_node)
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)
subgraph = subgraph.compile(interrupt_before=["weather_node"])
# setup main graph
class RouterState(MessagesState):
route: Literal["weather", "other"]
router_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="router",
args={"dest": "weather"},
)
],
)
]
)
def router_node(state: RouterState):
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]}
def normal_llm_node(state: RouterState):
return {"messages": [AIMessage("Hello!")]}
def route_after_prediction(state: RouterState):
if state["route"] == "weather":
return "weather_graph"
else:
return "normal_llm_node"
async def weather_graph(state: RouterState):
return await subgraph.ainvoke(state)
graph = StateGraph(RouterState)
graph.add_node(router_node)
graph.add_node(normal_llm_node)
graph.add_node("weather_graph", weather_graph)
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("normal_llm_node", END)
graph.add_edge("weather_graph", END)
return graph.compile(checkpointer=checkpointer)
async def test_subgraph_w_interrupt(
topics: Topics, acheckpointer: BaseCheckpointSaver
) -> None:
input = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
config = {"configurable": {"thread_id": "1"}}
graph = mk_weather_graph(acheckpointer)
# start a new run
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=input, config=config),
)
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check interrupted state
state = await graph.aget_state(config)
assert state.next == ("weather_graph",)
assert state.values == {
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
}
# check outer history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 3
# check child history
child_history = [
c async for c in graph.aget_state_history(history[0].tasks[0].state)
]
assert len(child_history) == 3
# check messages
assert (
orch_msgs
== (
# initial message to outer graph
[MessageToOrchestrator(input=input, config=config)]
# outer graph messages, until interrupted
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, i.e. after executor sends message to resume checkpoint
for _ in c.tasks
]
# initial message to child graph
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"__pregel_store": None,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": None,
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf")
],
"route": "weather",
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
]
# child graph messages, until interrupted
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"__pregel_store": None,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, i.e. after executor sends message to resume checkpoint
for _ in c.tasks
]
)
)
assert (
exec_msgs
== (
# outer graph tasks
[
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
# child graph tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"__pregel_store": None,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[1:]) # the last one wasn't executed
for t in c.tasks
]
)
)
# resume the thread
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=None, config=config),
)
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check final state
state = await graph.aget_state(config)
assert state.next == ()
assert state.values == {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="I'ts sunny in San Francisco!"),
],
"route": "weather",
}
# check outer history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 4
# check child history
# accessing second to last checkpoint, since that's the one w/ subgraph task
child_history = [
c async for c in graph.aget_state_history(history[1].tasks[0].state)
]
assert len(child_history) == 4
# check messages
assert (
orch_msgs
== (
# initial message to outer graph
[MessageToOrchestrator(input=None, config=config)]
# initial message to child graph
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"__pregel_store": None,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": None,
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
]
# child graph messages, from previous last checkpoint onwards
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"__pregel_store": None,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[:2])
for _ in c.tasks
]
# outer graph messages, from previous last checkpoint onwards
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[:2])
for _ in c.tasks
]
)
)
assert (
exec_msgs
== (
# outer graph tasks
[
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history[:2])
for t in c.tasks
]
# child graph tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"__pregel_store": None,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[:2])
for t in c.tasks
]
# "finally" tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config["configurable"][
"checkpoint_id"
],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
}
]
)
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/messages.py`:
```py
"""Redefined messages as a work-around for pydantic issue with AnyStr.
The code below creates version of pydantic models
that will work in unit tests with AnyStr as id field
Please note that the `id` field is assigned AFTER the model is created
to workaround an issue with pydantic ignoring the __eq__ method on
subclassed strings.
"""
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage
from tests.any import AnyStr
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
"""Create a human message with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()
return message
```
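The helpers above rely on the AnyStr matcher (defined in `tests/any.py`) so that a message built by a test compares equal to a real message regardless of the auto-generated `id` it carries. A minimal illustrative sketch of that behaviour (not part of the repo; the content and id values are made up):

```py
from langchain_core.messages import AIMessage

from tests.messages import _AnyIdAIMessage

# The real message carries whatever id the runtime assigned; the expected
# message carries an AnyStr id, which compares equal to any string.
real = AIMessage(content="hello", id="run-1234")
assert real == _AnyIdAIMessage(content="hello")
```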
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_push.py`:
```py
import operator
from typing import (
Annotated,
Literal,
Union,
)
import pytest
from aiokafka import AIOKafkaProducer
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import FF_SEND_V2, START
from langgraph.errors import NodeInterrupt
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from langgraph.types import Command, Send
from tests.any import AnyDict
from tests.drain import drain_topics_async
pytestmark = pytest.mark.anyio
def mk_push_graph(
checkpointer: BaseCheckpointSaver,
) -> CompiledStateGraph:
# copied from test_send_dedupe_on_resume
class InterruptOnce:
ticks: int = 0
def __call__(self, state):
self.ticks += 1
if self.ticks == 1:
raise NodeInterrupt("Bahh")
return ["|".join(("flaky", str(state)))]
class Node:
def __init__(self, name: str):
self.name = name
self.ticks = 0
self.__name__ = name
def __call__(self, state):
self.ticks += 1
update = (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return state.copy(update=update)
else:
return update
def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_node("flaky", InterruptOnce())
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
return builder.compile(checkpointer=checkpointer)
async def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None:
if not FF_SEND_V2:
pytest.skip("Test requires FF_SEND_V2")
input = ["0"]
config = {"configurable": {"thread_id": "1"}}
graph = mk_push_graph(acheckpointer)
graph_compare = mk_push_graph(acheckpointer)
# start a new run
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=input, config=config),
)
# drain topics
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check state
state = await graph.aget_state(config)
assert all(not t.error for t in state.tasks)
assert state.next == ("flaky",)
assert (
state.values
== await graph_compare.ainvoke(input, {"configurable": {"thread_id": "2"}})
== [
"0",
"1",
"2|Control(goto=Send(node='2', arg=3))",
"2|Control(goto=Send(node='flaky', arg=4))",
"2|3",
]
)
# check history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 2
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": _convert_path(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
# resume the thread
async with AIOKafkaProducer(value_serializer=serde.dumps) as producer:
await producer.send_and_wait(
topics.orchestrator,
MessageToOrchestrator(input=None, config=config),
)
orch_msgs, exec_msgs = await drain_topics_async(topics, graph)
# check final state
state = await graph.aget_state(config)
assert state.next == ()
assert (
state.values
== await graph_compare.ainvoke(None, {"configurable": {"thread_id": "2"}})
== [
"0",
"1",
"2|Control(goto=Send(node='2', arg=3))",
"2|Control(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
]
)
# check history
history = [c async for c in graph.aget_state_history(config)]
assert len(history) == 4
# check executions
# node "2" doesn't get called again, as we recover writes saved before
assert graph.builder.nodes["2"].runnable.func.ticks == 3
# node "flaky" gets called again, as it was interrupted
assert graph.builder.nodes["flaky"].runnable.func.ticks == 2
def _convert_path(
path: tuple[Union[str, int, tuple], ...],
) -> list[Union[str, int, list]]:
return list(_convert_path(p) if isinstance(p, tuple) else p for p in path)
```
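A note on the `_convert_path` helper above: it recursively converts a task path of nested tuples into nested lists, which appears to be the shape paths take after a round trip through the Kafka serializer (hence `_convert_path(t.path)` here versus `list(t.path)` in the other tests). A small illustrative sketch, using a made-up path:

```py
from typing import Union

def _convert_path(
    path: tuple[Union[str, int, tuple], ...],
) -> list[Union[str, int, list]]:
    # Same helper as in the test: tuples become lists, recursively.
    return list(_convert_path(p) if isinstance(p, tuple) else p for p in path)

assert _convert_path(("__pregel_push", 0, ("child", 2))) == [
    "__pregel_push",
    0,
    ["child", 2],
]
```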
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_subgraph_sync.py`:
```py
from typing import Literal, cast
import pytest
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START
from langgraph.graph import MessagesState
from langgraph.graph.state import StateGraph
from langgraph.pregel import Pregel
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.default_sync import DefaultProducer
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from tests.any import AnyDict, AnyList
from tests.drain import drain_topics
from tests.messages import _AnyIdAIMessage, _AnyIdHumanMessage
pytestmark = pytest.mark.anyio
def mk_weather_graph(checkpointer: BaseCheckpointSaver) -> Pregel:
# copied from test_weather_subgraph
# setup subgraph
@tool
def get_weather(city: str):
"""Get the weather for a specific city"""
return f"I'ts sunny in {city}!"
weather_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="get_weather",
args={"city": "San Francisco"},
)
],
)
]
)
class SubGraphState(MessagesState):
city: str
def model_node(state: SubGraphState):
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}
def weather_node(state: SubGraphState):
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}
subgraph = StateGraph(SubGraphState)
subgraph.add_node(model_node)
subgraph.add_node(weather_node)
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)
subgraph = subgraph.compile(interrupt_before=["weather_node"])
# setup main graph
class RouterState(MessagesState):
route: Literal["weather", "other"]
router_model = FakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
ToolCall(
id="tool_call123",
name="router",
args={"dest": "weather"},
)
],
)
]
)
def router_node(state: RouterState):
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
return {"route": cast(AIMessage, route).tool_calls[0]["args"]["dest"]}
def normal_llm_node(state: RouterState):
return {"messages": [AIMessage("Hello!")]}
def route_after_prediction(state: RouterState):
if state["route"] == "weather":
return "weather_graph"
else:
return "normal_llm_node"
def weather_graph(state: RouterState):
return subgraph.invoke(state)
graph = StateGraph(RouterState)
graph.add_node(router_node)
graph.add_node(normal_llm_node)
graph.add_node("weather_graph", weather_graph)
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("normal_llm_node", END)
graph.add_edge("weather_graph", END)
return graph.compile(checkpointer=checkpointer)
def test_subgraph_w_interrupt(
topics: Topics, checkpointer: BaseCheckpointSaver
) -> None:
input = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
config = {"configurable": {"thread_id": "1"}}
graph = mk_weather_graph(checkpointer)
# start a new run
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=input, config=config)),
)
producer.flush()
orch_msgs, exec_msgs = drain_topics(topics, graph)
# check interrupted state
state = graph.get_state(config)
assert state.next == ("weather_graph",)
assert state.values == {
"messages": [_AnyIdHumanMessage(content="what's the weather in sf")],
"route": "weather",
}
# check outer history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 3
# check child history
child_history = [c for c in graph.get_state_history(history[0].tasks[0].state)]
assert len(child_history) == 3
# check messages
assert (
orch_msgs
== (
# initial message to outer graph
[MessageToOrchestrator(input=input, config=config)]
# outer graph messages, until interrupted
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, i.e. after executor sends message to resume checkpoint
for _ in c.tasks
]
# initial message to child graph
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"__pregel_store": None,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": None,
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf")
],
"route": "weather",
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
]
# child graph messages, until interrupted
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_store": None,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
# finish executing, i.e. after executor sends message to resume checkpoint
for _ in c.tasks
]
)
)
assert (
exec_msgs
== (
# outer graph tasks
[
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
# child graph tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_store": None,
"__pregel_resuming": False,
"__pregel_task_id": history[0].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[0].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[0]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": False,
"checkpoint_id": history[0].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[0].tasks[0].id,
"path": list(history[0].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[1:]) # the last one wasn't executed
for t in c.tasks
]
)
)
# resume the thread
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=None, config=config)),
)
producer.flush()
orch_msgs, exec_msgs = drain_topics(topics, graph)
# check final state
state = graph.get_state(config)
assert state.next == ()
assert state.values == {
"messages": [
_AnyIdHumanMessage(content="what's the weather in sf"),
_AnyIdAIMessage(content="I'ts sunny in San Francisco!"),
],
"route": "weather",
}
# check outer history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 4
# check child history
# accessing second to last checkpoint, since that's the one w/ subgraph task
child_history = [c for c in graph.get_state_history(history[1].tasks[0].state)]
assert len(child_history) == 4
# check messages
assert (
orch_msgs
== (
# initial message to outer graph
[MessageToOrchestrator(input=None, config=config)]
# initial message to child graph
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_store": None,
"__pregel_resuming": True,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": None,
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
]
# child graph messages, from previous last checkpoint onwards
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_store": None,
"__pregel_resuming": True,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[:2])
for _ in c.tasks
]
# outer graph messages, from previous last checkpoint onwards
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history[:2])
for _ in c.tasks
]
)
)
assert (
exec_msgs
== (
# outer graph tasks
[
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": None,
}
for c in reversed(history[:2])
for t in c.tasks
]
# child graph tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_checkpointer": None,
"__pregel_delegate": False,
"__pregel_read": None,
"__pregel_send": None,
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": True,
"__pregel_store": None,
"__pregel_task_id": history[1].tasks[0].id,
"__pregel_scratchpad": {},
"__pregel_writes": AnyList(),
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_map": {
"": history[1].config["configurable"]["checkpoint_id"]
},
"checkpoint_ns": history[1]
.tasks[0]
.state["configurable"]["checkpoint_ns"],
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": list(t.path),
},
"finally_send": [
{
"topic": topics.executor,
"value": {
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config[
"configurable"
]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
},
}
],
}
for c in reversed(child_history[:2])
for t in c.tasks
]
# "finally" tasks
+ [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_dedupe_tasks": True,
"__pregel_ensure_latest": True,
"__pregel_resuming": True,
"checkpoint_id": history[1].config["configurable"][
"checkpoint_id"
],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"finally_send": None,
"task": {
"id": history[1].tasks[0].id,
"path": list(history[1].tasks[0].path),
},
}
]
)
)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/any.py`:
```py
import re
from typing import Union
class AnyStr(str):
def __init__(self, prefix: Union[str, re.Pattern] = "") -> None:
super().__init__()
self.prefix = prefix
def __eq__(self, other: object) -> bool:
return isinstance(other, str) and (
other.startswith(self.prefix)
if isinstance(self.prefix, str)
else self.prefix.match(other)
)
def __hash__(self) -> int:
return hash((str(self), self.prefix))
class AnyDict(dict):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __eq__(self, other: object) -> bool:
if not self and isinstance(other, dict):
return True
if not isinstance(other, dict) or len(self) != len(other):
return False
for k, v in self.items():
if kk := next((kk for kk in other if kk == k), None):
if v == other[kk]:
continue
else:
return False
else:
return True
class AnyList(list):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __eq__(self, other: object) -> bool:
if not self and isinstance(other, list):
return True
if not isinstance(other, list) or len(self) != len(other):
return False
for i, v in enumerate(self):
if v == other[i]:
continue
else:
return False
else:
return True
```
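These matcher classes make equality assertions tolerant of values the tests cannot predict: an empty AnyDict or AnyList matches any dict or list, and AnyStr matches any string with the given prefix or pattern. A brief illustrative sketch (not part of the repo):

```py
import re

from tests.any import AnyDict, AnyList, AnyStr

assert AnyStr() == "anything at all"
assert AnyStr("run-") == "run-1234"            # prefix match
assert AnyStr(re.compile(r"\d+")) == "42"      # regex match
assert AnyDict() == {"metadata": {"step": 1}}  # empty AnyDict matches any dict
assert AnyList() == [1, 2, 3]                  # empty AnyList matches any list
```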
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/drain.py`:
```py
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, TypeVar
import anyio
from aiokafka import AIOKafkaConsumer
from typing_extensions import ParamSpec
from langgraph.pregel import Pregel
from langgraph.scheduler.kafka.default_sync import DefaultConsumer
from langgraph.scheduler.kafka.executor import AsyncKafkaExecutor, KafkaExecutor
from langgraph.scheduler.kafka.orchestrator import (
AsyncKafkaOrchestrator,
KafkaOrchestrator,
)
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
C = ParamSpec("C")
R = TypeVar("R")
async def drain_topics_async(
topics: Topics, graph: Pregel, *, debug: bool = False
) -> tuple[list[MessageToOrchestrator], list[MessageToOrchestrator]]:
scope: Optional[anyio.CancelScope] = None
orch_msgs = []
exec_msgs = []
errors = []
def done() -> bool:
return (
len(orch_msgs) > 0
and any(orch_msgs)
and len(exec_msgs) > 0
and any(exec_msgs)
and not orch_msgs[-1]
and not exec_msgs[-1]
)
async def orchestrator() -> None:
async with AsyncKafkaOrchestrator(graph, topics) as orch:
async for msgs in orch:
orch_msgs.append(msgs)
if debug:
print("\n---\norch", len(msgs), msgs)
if done():
scope.cancel()
async def executor() -> None:
async with AsyncKafkaExecutor(graph, topics) as exec:
async for msgs in exec:
exec_msgs.append(msgs)
if debug:
print("\n---\nexec", len(msgs), msgs)
if done():
scope.cancel()
async def error_consumer() -> None:
async with AIOKafkaConsumer(topics.error) as consumer:
async for msg in consumer:
errors.append(msg)
if scope:
scope.cancel()
# start error consumer
error_task = asyncio.create_task(error_consumer(), name="error_consumer")
# run the orchestrator and executor until break_when
async with anyio.create_task_group() as tg:
tg.cancel_scope.deadline = anyio.current_time() + 20
scope = tg.cancel_scope
tg.start_soon(orchestrator, name="orchestrator")
tg.start_soon(executor, name="executor")
# cancel error consumer
error_task.cancel()
try:
await error_task
except asyncio.CancelledError:
pass
# check no errors
assert not errors, errors
return [m for mm in orch_msgs for m in mm], [m for mm in exec_msgs for m in mm]
def drain_topics(
topics: Topics, graph: Pregel, *, debug: bool = False
) -> tuple[list[MessageToOrchestrator], list[MessageToOrchestrator]]:
orch_msgs = []
exec_msgs = []
errors = []
event = threading.Event()
def done() -> bool:
return (
len(orch_msgs) > 0
and any(orch_msgs)
and len(exec_msgs) > 0
and any(exec_msgs)
and not orch_msgs[-1]
and not exec_msgs[-1]
)
def orchestrator() -> None:
try:
with KafkaOrchestrator(graph, topics) as orch:
for msgs in orch:
orch_msgs.append(msgs)
if debug:
print("\n---\norch", len(msgs), msgs)
if done():
event.set()
if event.is_set():
break
except Exception as e:
errors.append(e)
event.set()
def executor() -> None:
try:
with KafkaExecutor(graph, topics) as exec:
for msgs in exec:
exec_msgs.append(msgs)
if debug:
print("\n---\nexec", len(msgs), msgs)
if done():
event.set()
if event.is_set():
break
except Exception as e:
errors.append(e)
event.set()
def error_consumer() -> None:
try:
with DefaultConsumer(topics.error) as consumer:
while not event.is_set():
if msg := consumer.poll(timeout_ms=100):
errors.append(msg)
event.set()
except Exception as e:
errors.append(e)
event.set()
with ThreadPoolExecutor() as pool:
# start error consumer
pool.submit(error_consumer)
# run the orchestrator and executor until break_when
pool.submit(orchestrator)
pool.submit(executor)
# timeout
start = time.time()
while not event.is_set():
time.sleep(0.1)
if time.time() - start > 20:
event.set()
# check no errors
assert not errors, errors
return [m for mm in orch_msgs for m in mm], [m for mm in exec_msgs for m in mm]
```
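Both `done()` predicates above encode the same stop condition: draining is complete once each consumer has yielded at least one non-empty batch and the most recent batch from each is empty, meaning neither side produced new work on its last poll. Restated in isolation on plain lists of message batches (illustrative only):

```py
def drained(orch_batches: list[list], exec_batches: list[list]) -> bool:
    # Mirror of the done() condition in drain.py, applied to plain lists.
    return (
        len(orch_batches) > 0
        and any(orch_batches)
        and len(exec_batches) > 0
        and any(exec_batches)
        and not orch_batches[-1]
        and not exec_batches[-1]
    )

assert not drained([[{"m": 1}]], [[{"t": 1}]])      # last batches still non-empty
assert drained([[{"m": 1}], []], [[{"t": 1}], []])  # both sides idle after doing work
```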
`/Users/malcolm/dev/langchain-ai/langgraph/libs/scheduler-kafka/tests/test_push_sync.py`:
```py
import operator
from typing import (
Annotated,
Literal,
Union,
)
import pytest
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import FF_SEND_V2, START
from langgraph.errors import NodeInterrupt
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.scheduler.kafka import serde
from langgraph.scheduler.kafka.default_sync import DefaultProducer
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics
from langgraph.types import Command, Send
from tests.any import AnyDict
from tests.drain import drain_topics
pytestmark = pytest.mark.anyio
def mk_push_graph(
checkpointer: BaseCheckpointSaver,
) -> CompiledStateGraph:
# copied from test_send_dedupe_on_resume
class InterruptOnce:
ticks: int = 0
def __call__(self, state):
self.ticks += 1
if self.ticks == 1:
raise NodeInterrupt("Bahh")
return ["|".join(("flaky", str(state)))]
class Node:
def __init__(self, name: str):
self.name = name
self.ticks = 0
self.__name__ = name
def __call__(self, state):
self.ticks += 1
update = (
[self.name]
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return state.copy(update=update)
else:
return update
def send_for_fun(state):
return [
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]
def route_to_three(state) -> Literal["3"]:
return "3"
builder = StateGraph(Annotated[list, operator.add])
builder.add_node(Node("1"))
builder.add_node(Node("2"))
builder.add_node(Node("3"))
builder.add_node(Node("3.1"))
builder.add_node("flaky", InterruptOnce())
builder.add_edge(START, "1")
builder.add_conditional_edges("1", send_for_fun)
builder.add_conditional_edges("2", route_to_three)
return builder.compile(checkpointer=checkpointer)
def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None:
if not FF_SEND_V2:
pytest.skip("Test requires FF_SEND_V2")
input = ["0"]
config = {"configurable": {"thread_id": "1"}}
graph = mk_push_graph(acheckpointer)
graph_compare = mk_push_graph(acheckpointer)
# start a new run
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=input, config=config)),
)
producer.flush()
# drain topics
orch_msgs, exec_msgs = drain_topics(topics, graph)
# check state
state = graph.get_state(config)
assert all(not t.error for t in state.tasks)
assert state.next == ("flaky",)
assert (
state.values
== graph_compare.invoke(input, {"configurable": {"thread_id": "2"}})
== [
"0",
"1",
"2|Control(goto=Send(node='2', arg=3))",
"2|Control(goto=Send(node='flaky', arg=4))",
"2|3",
]
)
# check history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 2
# check messages
assert orch_msgs == [MessageToOrchestrator(input=input, config=config)] + [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"input": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
]
assert exec_msgs == [
{
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
"checkpoint_ns": "",
"thread_id": "1",
},
"metadata": AnyDict(),
"recursion_limit": 25,
"tags": [],
},
"task": {
"id": t.id,
"path": _convert_path(t.path),
},
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
]
# resume the thread
with DefaultProducer() as producer:
producer.send(
topics.orchestrator,
value=serde.dumps(MessageToOrchestrator(input=None, config=config)),
)
producer.flush()
orch_msgs, exec_msgs = drain_topics(topics, graph)
# check final state
state = graph.get_state(config)
assert state.next == ()
assert (
state.values
== graph_compare.invoke(None, {"configurable": {"thread_id": "2"}})
== [
"0",
"1",
"2|Control(goto=Send(node='2', arg=3))",
"2|Control(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
"3.1",
]
)
# check history
history = [c for c in graph.get_state_history(config)]
assert len(history) == 4
# check executions
# node "2" doesn't get called again, as we recover writes saved before
assert graph.builder.nodes["2"].runnable.func.ticks == 3
# node "flaky" gets called again, as it was interrupted
assert graph.builder.nodes["flaky"].runnable.func.ticks == 2
def _convert_path(
path: tuple[Union[str, int, tuple], ...],
) -> list[Union[str, int, list]]:
return list(_convert_path(p) if isinstance(p, tuple) else p for p in path)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-sdk"
version = "0.1.42"
description = "SDK for interacting with LangGraph API"
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph_sdk" }]
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
httpx = ">=0.25.2"
orjson = ">=3.10.1"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watch = "^4.2.0"
mypy = "^1.10.0"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config will raise errors on any warnings encountered while parsing
# the `pytest` section of the configuration file.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/client.py`:
```py
"""The LangGraph client implementations connect to the LangGraph API.
This module provides both asynchronous (LangGraphClient) and synchronous (SyncLangGraphClient)
clients for interacting with the LangGraph API's core resources such as
Assistants, Threads, Runs, and Cron jobs, as well as its persistent
document Store.
"""
from __future__ import annotations
import asyncio
import logging
import os
import sys
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
overload,
)
import httpx
import orjson
from httpx._types import QueryParamTypes
import langgraph_sdk
from langgraph_sdk.schema import (
All,
Assistant,
AssistantVersion,
CancelAction,
Checkpoint,
Command,
Config,
Cron,
DisconnectMode,
GraphSchema,
IfNotExists,
Item,
Json,
ListNamespaceResponse,
MultitaskStrategy,
OnCompletionBehavior,
OnConflictBehavior,
Run,
RunCreate,
RunStatus,
SearchItemsResponse,
StreamMode,
StreamPart,
Subgraphs,
Thread,
ThreadState,
ThreadStatus,
ThreadUpdateStateResponse,
)
from langgraph_sdk.sse import SSEDecoder, aiter_lines_raw, iter_lines_raw
logger = logging.getLogger(__name__)
RESERVED_HEADERS = ("x-api-key",)
def _get_api_key(api_key: Optional[str] = None) -> Optional[str]:
"""Get the API key from the environment.
Precedence:
1. explicit argument
2. LANGGRAPH_API_KEY
3. LANGSMITH_API_KEY
4. LANGCHAIN_API_KEY
"""
if api_key:
return api_key
for prefix in ["LANGGRAPH", "LANGSMITH", "LANGCHAIN"]:
if env := os.getenv(f"{prefix}_API_KEY"):
return env.strip().strip('"').strip("'")
return None # type: ignore
def get_headers(
api_key: Optional[str], custom_headers: Optional[dict[str, str]]
) -> dict[str, str]:
"""Combine api_key and custom user-provided headers."""
custom_headers = custom_headers or {}
for header in RESERVED_HEADERS:
if header in custom_headers:
raise ValueError(f"Cannot set reserved header '{header}'")
headers = {
"User-Agent": f"langgraph-sdk-py/{langgraph_sdk.__version__}",
**custom_headers,
}
api_key = _get_api_key(api_key)
if api_key:
headers["x-api-key"] = api_key
return headers
def orjson_default(obj: Any) -> Any:
if hasattr(obj, "model_dump") and callable(obj.model_dump):
return obj.model_dump()
elif hasattr(obj, "dict") and callable(obj.dict):
return obj.dict()
elif isinstance(obj, (set, frozenset)):
return list(obj)
else:
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
def get_client(
*,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> LangGraphClient:
"""Get a LangGraphClient instance.
Args:
url: The URL of the LangGraph API.
api_key: The API key. If not provided, it will be read from the environment.
Precedence:
1. explicit argument
2. LANGGRAPH_API_KEY
3. LANGSMITH_API_KEY
4. LANGCHAIN_API_KEY
headers: Optional custom headers
Returns:
LangGraphClient: The top-level client for accessing AssistantsClient,
ThreadsClient, RunsClient, and CronClient.
Example:
from langgraph_sdk import get_client
# get top-level LangGraphClient
client = get_client(url="http://localhost:8123")
# example usage: client.<model>.<method_name>()
assistants = await client.assistants.get(assistant_id="some_uuid")
"""
transport: Optional[httpx.AsyncBaseTransport] = None
if url is None:
try:
from langgraph_api.server import app # type: ignore
url = "http://api"
transport = httpx.ASGITransport(app, root_path="/noauth")
except Exception:
url = "http://localhost:8123"
if transport is None:
transport = httpx.AsyncHTTPTransport(retries=5)
client = httpx.AsyncClient(
base_url=url,
transport=transport,
timeout=httpx.Timeout(connect=5, read=300, write=300, pool=5),
headers=get_headers(api_key, headers),
)
return LangGraphClient(client)
class LangGraphClient:
"""Top-level client for LangGraph API.
Attributes:
assistants: Manages versioned configuration for your graphs.
threads: Handles (potentially) multi-turn interactions, such as conversational threads.
runs: Controls individual invocations of the graph.
crons: Manages scheduled operations.
store: Interfaces with persistent, shared data storage.
"""
def __init__(self, client: httpx.AsyncClient) -> None:
self.http = HttpClient(client)
self.assistants = AssistantsClient(self.http)
self.threads = ThreadsClient(self.http)
self.runs = RunsClient(self.http)
self.crons = CronClient(self.http)
self.store = StoreClient(self.http)
class HttpClient:
"""Handle async requests to the LangGraph API.
Adds additional error messaging & content handling above the
provided httpx client.
Attributes:
client (httpx.AsyncClient): Underlying HTTPX async client.
"""
def __init__(self, client: httpx.AsyncClient) -> None:
self.client = client
async def get(self, path: str, *, params: Optional[QueryParamTypes] = None) -> Any:
"""Send a GET request."""
r = await self.client.get(path, params=params)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await r.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return await adecode_json(r)
async def post(self, path: str, *, json: Optional[dict]) -> Any:
"""Send a POST request."""
if json is not None:
headers, content = await aencode_json(json)
else:
headers, content = {}, b""
r = await self.client.post(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await r.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return await adecode_json(r)
async def put(self, path: str, *, json: dict) -> Any:
"""Send a PUT request."""
headers, content = await aencode_json(json)
r = await self.client.put(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await r.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return await adecode_json(r)
async def patch(self, path: str, *, json: dict) -> Any:
"""Send a PATCH request."""
headers, content = await aencode_json(json)
r = await self.client.patch(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await r.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return await adecode_json(r)
async def delete(self, path: str, *, json: Optional[Any] = None) -> None:
"""Send a DELETE request."""
r = await self.client.request("DELETE", path, json=json)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await r.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
async def stream(
self, path: str, method: str, *, json: Optional[dict] = None
) -> AsyncIterator[StreamPart]:
"""Stream results using SSE."""
headers, content = await aencode_json(json)
headers["Accept"] = "text/event-stream"
headers["Cache-Control"] = "no-store"
async with self.client.stream(
method, path, headers=headers, content=content
) as res:
# check status
try:
res.raise_for_status()
except httpx.HTTPStatusError as e:
body = (await res.aread()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
# check content type
content_type = res.headers.get("content-type", "").partition(";")[0]
if "text/event-stream" not in content_type:
raise httpx.TransportError(
"Expected response header Content-Type to contain 'text/event-stream', "
f"got {content_type!r}"
)
# parse SSE
decoder = SSEDecoder()
async for line in aiter_lines_raw(res):
sse = decoder.decode(line=line.rstrip(b"\n"))
if sse is not None:
yield sse
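# Illustrative sketch (not part of the SDK source): on Python 3.11+ the error
# handlers above attach the response body to the raised httpx.HTTPStatusError
# via e.add_note(), so callers can surface it from __notes__:
#
#     try:
#         await client.assistants.get("does-not-exist")
#     except httpx.HTTPStatusError as e:
#         print(e.response.status_code)
#         print(getattr(e, "__notes__", []))  # contains the error body on 3.11+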
async def aencode_json(json: Any) -> tuple[dict[str, str], bytes]:
body = await asyncio.get_running_loop().run_in_executor(
None,
orjson.dumps,
json,
orjson_default,
orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS,
)
content_length = str(len(body))
content_type = "application/json"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, body
async def adecode_json(r: httpx.Response) -> Any:
body = await r.aread()
return (
await asyncio.get_running_loop().run_in_executor(None, orjson.loads, body)
if body
else None
)
class AssistantsClient:
"""Client for managing assistants in LangGraph.
This class provides methods to interact with assistants,
which are versioned configurations of your graph.
Example:
client = get_client()
assistant = await client.assistants.get("assistant_id_123")
"""
def __init__(self, http: HttpClient) -> None:
self.http = http
async def get(self, assistant_id: str) -> Assistant:
"""Get an assistant by ID.
Args:
assistant_id: The ID of the assistant to get.
Returns:
Assistant: Assistant Object.
Example Usage:
assistant = await client.assistants.get(
assistant_id="my_assistant_id"
)
print(assistant)
----------------------------------------------------
{
'assistant_id': 'my_assistant_id',
'graph_id': 'agent',
'created_at': '2024-06-25T17:10:33.109781+00:00',
'updated_at': '2024-06-25T17:10:33.109781+00:00',
'config': {},
'metadata': {'created_by': 'system'}
}
""" # noqa: E501
return await self.http.get(f"/assistants/{assistant_id}")
async def get_graph(
self, assistant_id: str, *, xray: Union[int, bool] = False
) -> dict[str, list[dict[str, Any]]]:
"""Get the graph of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the graph of.
xray: Include graph representation of subgraphs. If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included.
Returns:
Graph: The graph information for the assistant in JSON format.
Example Usage:
graph_info = await client.assistants.get_graph(
assistant_id="my_assistant_id"
)
print(graph_info)
--------------------------------------------------------------------------------------------------------------------------
{
'nodes':
[
{'id': '__start__', 'type': 'schema', 'data': '__start__'},
{'id': '__end__', 'type': 'schema', 'data': '__end__'},
{'id': 'agent','type': 'runnable','data': {'id': ['langgraph', 'utils', 'RunnableCallable'],'name': 'agent'}},
],
'edges':
[
{'source': '__start__', 'target': 'agent'},
{'source': 'agent','target': '__end__'}
]
}
""" # noqa: E501
return await self.http.get(
f"/assistants/{assistant_id}/graph", params={"xray": xray}
)
async def get_schemas(self, assistant_id: str) -> GraphSchema:
"""Get the schemas of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the schema of.
Returns:
GraphSchema: The graph schema for the assistant.
Example Usage:
schema = await client.assistants.get_schemas(
assistant_id="my_assistant_id"
)
print(schema)
----------------------------------------------------------------------------------------------------------------------------
{
'graph_id': 'agent',
'state_schema':
{
'title': 'LangGraphInput',
'$ref': '#/definitions/AgentState',
'definitions':
{
'BaseMessage':
{
'title': 'BaseMessage',
'description': 'Base abstract Message class. Messages are the inputs and outputs of ChatModels.',
'type': 'object',
'properties':
{
'content':
{
'title': 'Content',
'anyOf': [
{'type': 'string'},
{'type': 'array','items': {'anyOf': [{'type': 'string'}, {'type': 'object'}]}}
]
},
'additional_kwargs':
{
'title': 'Additional Kwargs',
'type': 'object'
},
'response_metadata':
{
'title': 'Response Metadata',
'type': 'object'
},
'type':
{
'title': 'Type',
'type': 'string'
},
'name':
{
'title': 'Name',
'type': 'string'
},
'id':
{
'title': 'Id',
'type': 'string'
}
},
'required': ['content', 'type']
},
'AgentState':
{
'title': 'AgentState',
'type': 'object',
'properties':
{
'messages':
{
'title': 'Messages',
'type': 'array',
'items': {'$ref': '#/definitions/BaseMessage'}
}
},
'required': ['messages']
}
}
},
'config_schema':
{
'title': 'Configurable',
'type': 'object',
'properties':
{
'model_name':
{
'title': 'Model Name',
'enum': ['anthropic', 'openai'],
'type': 'string'
}
}
}
}
""" # noqa: E501
return await self.http.get(f"/assistants/{assistant_id}/schemas")
async def get_subgraphs(
self, assistant_id: str, namespace: Optional[str] = None, recurse: bool = False
) -> Subgraphs:
"""Get the schemas of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the schema of.
Returns:
Subgraphs: The graph schema for the assistant.
""" # noqa: E501
if namespace is not None:
return await self.http.get(
f"/assistants/{assistant_id}/subgraphs/{namespace}",
params={"recurse": recurse},
)
else:
return await self.http.get(
f"/assistants/{assistant_id}/subgraphs",
params={"recurse": recurse},
)
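# Illustrative usage sketch (not part of the SDK source); the assistant ID and
# namespace below are hypothetical:
#
#     subgraphs = await client.assistants.get_subgraphs(
#         "my_assistant_id",
#         recurse=True,                 # also include nested subgraphs
#     )
#     one_subgraph = await client.assistants.get_subgraphs(
#         "my_assistant_id",
#         namespace="some_namespace",   # restrict to a single subgraph namespace
#     )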
async def create(
self,
graph_id: Optional[str],
config: Optional[Config] = None,
*,
metadata: Json = None,
assistant_id: Optional[str] = None,
if_exists: Optional[OnConflictBehavior] = None,
name: Optional[str] = None,
) -> Assistant:
"""Create a new assistant.
Useful when graph is configurable and you want to create different assistants based on different configurations.
Args:
graph_id: The ID of the graph the assistant should use. The graph ID is normally set in your langgraph.json configuration.
config: Configuration to use for the graph.
metadata: Metadata to add to assistant.
assistant_id: Assistant ID to use, will default to a random UUID if not provided.
if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood.
Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing assistant).
name: The name of the assistant. Defaults to 'Untitled' under the hood.
Returns:
Assistant: The created assistant.
Example Usage:
assistant = await client.assistants.create(
graph_id="agent",
config={"configurable": {"model_name": "openai"}},
metadata={"number":1},
assistant_id="my-assistant-id",
if_exists="do_nothing",
name="my_name"
)
""" # noqa: E501
payload: Dict[str, Any] = {
"graph_id": graph_id,
}
if config:
payload["config"] = config
if metadata:
payload["metadata"] = metadata
if assistant_id:
payload["assistant_id"] = assistant_id
if if_exists:
payload["if_exists"] = if_exists
if name:
payload["name"] = name
return await self.http.post("/assistants", json=payload)
async def update(
self,
assistant_id: str,
*,
graph_id: Optional[str] = None,
config: Optional[Config] = None,
metadata: Json = None,
name: Optional[str] = None,
) -> Assistant:
"""Update an assistant.
Use this to point to a different graph, update the configuration, or change the metadata of an assistant.
Args:
assistant_id: Assistant to update.
graph_id: The ID of the graph the assistant should use.
The graph ID is normally set in your langgraph.json configuration. If None, assistant will keep pointing to same graph.
config: Configuration to use for the graph.
metadata: Metadata to merge with existing assistant metadata.
name: The new name for the assistant.
Returns:
Assistant: The updated assistant.
Example Usage:
assistant = await client.assistants.update(
assistant_id='e280dad7-8618-443f-87f1-8e41841c180f',
graph_id="other-graph",
config={"configurable": {"model_name": "anthropic"}},
metadata={"number":2}
)
""" # noqa: E501
payload: Dict[str, Any] = {}
if graph_id:
payload["graph_id"] = graph_id
if config:
payload["config"] = config
if metadata:
payload["metadata"] = metadata
if name:
payload["name"] = name
return await self.http.patch(
f"/assistants/{assistant_id}",
json=payload,
)
async def delete(
self,
assistant_id: str,
) -> None:
"""Delete an assistant.
Args:
assistant_id: The assistant ID to delete.
Returns:
None
Example Usage:
await client.assistants.delete(
assistant_id="my_assistant_id"
)
""" # noqa: E501
await self.http.delete(f"/assistants/{assistant_id}")
async def search(
self,
*,
metadata: Json = None,
graph_id: Optional[str] = None,
limit: int = 10,
offset: int = 0,
) -> list[Assistant]:
"""Search for assistants.
Args:
metadata: Metadata to filter by. Exact match filter for each KV pair.
graph_id: The ID of the graph to filter by.
The graph ID is normally set in your langgraph.json configuration.
limit: The maximum number of results to return.
offset: The number of results to skip.
Returns:
list[Assistant]: A list of assistants.
Example Usage:
assistants = await client.assistants.search(
metadata = {"name":"my_name"},
graph_id="my_graph_id",
limit=5,
offset=5
)
"""
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
if graph_id:
payload["graph_id"] = graph_id
return await self.http.post(
"/assistants/search",
json=payload,
)
async def get_versions(
self,
assistant_id: str,
metadata: Json = None,
limit: int = 10,
offset: int = 0,
) -> list[AssistantVersion]:
"""List all versions of an assistant.
Args:
assistant_id: The assistant ID to get versions for.
metadata: Metadata to filter versions by. Exact match filter for each KV pair.
limit: The maximum number of versions to return.
offset: The number of versions to skip.
Returns:
list[AssistantVersion]: A list of assistant versions.
Example Usage:
assistant_versions = await client.assistants.get_versions(
assistant_id="my_assistant_id"
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
return await self.http.post(
f"/assistants/{assistant_id}/versions", json=payload
)
async def set_latest(self, assistant_id: str, version: int) -> Assistant:
"""Change the version of an assistant.
Args:
assistant_id: The assistant ID to change the version of.
version: The version to change to.
Returns:
Assistant: Assistant Object.
Example Usage:
new_version_assistant = await client.assistants.set_latest(
assistant_id="my_assistant_id",
version=3
)
""" # noqa: E501
payload: Dict[str, Any] = {"version": version}
return await self.http.post(f"/assistants/{assistant_id}/latest", json=payload)
class ThreadsClient:
"""Client for managing threads in LangGraph.
A thread maintains the state of a graph across multiple interactions/invocations (aka runs).
It accumulates and persists the graph's state, allowing for continuity between separate
invocations of the graph.
Example:
client = get_client()
new_thread = await client.threads.create(metadata={"user_id": "123"})
"""
def __init__(self, http: HttpClient) -> None:
self.http = http
async def get(self, thread_id: str) -> Thread:
"""Get a thread by ID.
Args:
thread_id: The ID of the thread to get.
Returns:
Thread: Thread object.
Example Usage:
thread = await client.threads.get(
thread_id="my_thread_id"
)
print(thread)
-----------------------------------------------------
{
'thread_id': 'my_thread_id',
'created_at': '2024-07-18T18:35:15.540834+00:00',
'updated_at': '2024-07-18T18:35:15.540834+00:00',
'metadata': {'graph_id': 'agent'}
}
""" # noqa: E501
return await self.http.get(f"/threads/{thread_id}")
async def create(
self,
*,
metadata: Json = None,
thread_id: Optional[str] = None,
if_exists: Optional[OnConflictBehavior] = None,
) -> Thread:
"""Create a new thread.
Args:
metadata: Metadata to add to thread.
thread_id: ID of thread.
If None, ID will be a randomly generated UUID.
if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood.
Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing thread).
Returns:
Thread: The created thread.
Example Usage:
thread = await client.threads.create(
metadata={"number":1},
thread_id="my-thread-id",
if_exists="raise"
)
""" # noqa: E501
payload: Dict[str, Any] = {}
if thread_id:
payload["thread_id"] = thread_id
if metadata:
payload["metadata"] = metadata
if if_exists:
payload["if_exists"] = if_exists
return await self.http.post("/threads", json=payload)
async def update(self, thread_id: str, *, metadata: dict[str, Any]) -> Thread:
"""Update a thread.
Args:
thread_id: ID of thread to update.
metadata: Metadata to merge with existing thread metadata.
Returns:
Thread: The updated thread.
Example Usage:
thread = await client.threads.update(
thread_id="my-thread-id",
metadata={"number":1},
)
""" # noqa: E501
return await self.http.patch(
f"/threads/{thread_id}", json={"metadata": metadata}
)
async def delete(self, thread_id: str) -> None:
"""Delete a thread.
Args:
thread_id: The ID of the thread to delete.
Returns:
None
Example Usage:
await client.threads.delete(
thread_id="my_thread_id"
)
""" # noqa: E501
await self.http.delete(f"/threads/{thread_id}")
async def search(
self,
*,
metadata: Json = None,
values: Json = None,
status: Optional[ThreadStatus] = None,
limit: int = 10,
offset: int = 0,
) -> list[Thread]:
"""Search for threads.
Args:
metadata: Thread metadata to filter on.
values: State values to filter on.
status: Thread status to filter on.
Must be one of 'idle', 'busy', 'interrupted' or 'error'.
limit: Limit on number of threads to return.
offset: Offset in threads table to start search from.
Returns:
list[Thread]: List of the threads matching the search parameters.
Example Usage:
threads = await client.threads.search(
metadata={"number":1},
status="interrupted",
limit=15,
offset=5
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
if values:
payload["values"] = values
if status:
payload["status"] = status
return await self.http.post(
"/threads/search",
json=payload,
)
async def copy(self, thread_id: str) -> None:
"""Copy a thread.
Args:
thread_id: The ID of the thread to copy.
Returns:
None
Example Usage:
await client.threads.copy(
thread_id="my_thread_id"
)
""" # noqa: E501
return await self.http.post(f"/threads/{thread_id}/copy", json=None)
async def get_state(
self,
thread_id: str,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
*,
subgraphs: bool = False,
) -> ThreadState:
"""Get the state of a thread.
Args:
thread_id: The ID of the thread to get the state of.
checkpoint: The checkpoint to get the state of.
subgraphs: Include subgraphs states.
Returns:
ThreadState: the state of the thread.
Example Usage:
thread_state = await client.threads.get_state(
thread_id="my_thread_id",
checkpoint_id="my_checkpoint_id"
)
print(thread_state)
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
{
'values': {
'messages': [
{
'content': 'how are you?',
'additional_kwargs': {},
'response_metadata': {},
'type': 'human',
'name': None,
'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10',
'example': False
},
{
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'additional_kwargs': {},
'response_metadata': {},
'type': 'ai',
'name': None,
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'example': False,
'tool_calls': [],
'invalid_tool_calls': [],
'usage_metadata': None
}
]
},
'next': [],
'checkpoint':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1'
},
'metadata':
{
'step': 1,
'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2',
'source': 'loop',
'writes':
{
'agent':
{
'messages': [
{
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'name': None,
'type': 'ai',
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'example': False,
'tool_calls': [],
'usage_metadata': None,
'additional_kwargs': {},
'response_metadata': {},
'invalid_tool_calls': []
}
]
}
},
'user_id': None,
'graph_id': 'agent',
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'created_by': 'system',
'assistant_id': 'fe096781-5601-53d2-b2f6-0d3403f7e9ca'},
'created_at': '2024-07-25T15:35:44.184703+00:00',
'parent_config':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f'
}
}
""" # noqa: E501
if checkpoint:
return await self.http.post(
f"/threads/{thread_id}/state/checkpoint",
json={"checkpoint": checkpoint, "subgraphs": subgraphs},
)
elif checkpoint_id:
return await self.http.get(
f"/threads/{thread_id}/state/{checkpoint_id}",
params={"subgraphs": subgraphs},
)
else:
return await self.http.get(
f"/threads/{thread_id}/state",
params={"subgraphs": subgraphs},
)
async def update_state(
self,
thread_id: str,
values: Optional[Union[dict, Sequence[dict]]],
*,
as_node: Optional[str] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
) -> ThreadUpdateStateResponse:
"""Update the state of a thread.
Args:
thread_id: The ID of the thread to update.
values: The values to update the state with.
as_node: Update the state as if this node had just executed.
checkpoint: The checkpoint to update the state of.
Returns:
ThreadUpdateStateResponse: Response after updating a thread's state.
Example Usage:
response = await client.threads.update_state(
thread_id="my_thread_id",
values={"messages":[{"role": "user", "content": "hello!"}]},
as_node="my_node",
)
print(response)
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
{
'checkpoint': {
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1',
'checkpoint_map': {}
}
}
""" # noqa: E501
payload: Dict[str, Any] = {
"values": values,
}
if checkpoint_id:
payload["checkpoint_id"] = checkpoint_id
if checkpoint:
payload["checkpoint"] = checkpoint
if as_node:
payload["as_node"] = as_node
return await self.http.post(f"/threads/{thread_id}/state", json=payload)
async def get_history(
self,
thread_id: str,
*,
limit: int = 10,
before: Optional[str | Checkpoint] = None,
metadata: Optional[dict] = None,
checkpoint: Optional[Checkpoint] = None,
) -> list[ThreadState]:
"""Get the state history of a thread.
Args:
thread_id: The ID of the thread to get the state history for.
checkpoint: Return states for this subgraph. If empty defaults to root.
limit: The maximum number of states to return.
before: Return states before this checkpoint.
metadata: Filter states by metadata key-value pairs.
Returns:
list[ThreadState]: the state history of the thread.
Example Usage:
thread_state = await client.threads.get_history(
thread_id="my_thread_id",
limit=5,
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
}
if before:
payload["before"] = before
if metadata:
payload["metadata"] = metadata
if checkpoint:
payload["checkpoint"] = checkpoint
return await self.http.post(f"/threads/{thread_id}/history", json=payload)
class RunsClient:
"""Client for managing runs in LangGraph.
A run is a single assistant invocation with optional input, config, and metadata.
This client manages runs, which can be stateful (on threads) or stateless.
Example:
client = get_client()
run = await client.runs.create(assistant_id="asst_123", thread_id="thread_456", input={"query": "Hello"})
"""
def __init__(self, http: HttpClient) -> None:
self.http = http
@overload
def stream(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> AsyncIterator[StreamPart]: ...
@overload
def stream(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
webhook: Optional[str] = None,
after_seconds: Optional[int] = None,
) -> AsyncIterator[StreamPart]: ...
def stream(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> AsyncIterator[StreamPart]:
"""Create a run and stream the results.
Args:
thread_id: the thread ID to create the run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to stream from.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: A command to execute. Cannot be combined with input.
stream_mode: The stream mode(s) to use.
stream_subgraphs: Whether to stream output from subgraphs.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
feedback_keys: Feedback keys to assign to run.
on_disconnect: The disconnect mode to use.
Must be one of 'cancel' or 'continue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
AsyncIterator[StreamPart]: Asynchronous iterator of stream results.
Example Usage:
async for chunk in client.runs.stream(
thread_id=None,
assistant_id="agent",
input={"messages": [{"role": "user", "content": "how are you?"}]},
stream_mode=["values","debug"],
metadata={"name":"my_run"},
config={"configurable": {"model_name": "anthropic"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
feedback_keys=["my_feedback_key_1","my_feedback_key_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
):
print(chunk)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
StreamPart(event='metadata', data={'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2'})
StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}]})
StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}, {'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None}]})
StreamPart(event='end', data=None)
""" # noqa: E501
payload = {
"input": input,
"command": command,
"config": config,
"metadata": metadata,
"stream_mode": stream_mode,
"stream_subgraphs": stream_subgraphs,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"feedback_keys": feedback_keys,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_disconnect": on_disconnect,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
endpoint = (
f"/threads/{thread_id}/runs/stream"
if thread_id is not None
else "/runs/stream"
)
return self.http.stream(
endpoint, "POST", json={k: v for k, v in payload.items() if v is not None}
)
@overload
async def create(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Run: ...
@overload
async def create(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Run: ...
async def create(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
on_completion: Optional[OnCompletionBehavior] = None,
after_seconds: Optional[int] = None,
) -> Run:
"""Create a background run.
Args:
thread_id: the thread ID to create the run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to stream from.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: A command to execute. Cannot be combined with input.
stream_mode: The stream mode(s) to use.
stream_subgraphs: Whether to stream output from subgraphs.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
Run: The created background run.
Example Usage:
background_run = await client.runs.create(
thread_id="my_thread_id",
assistant_id="my_assistant_id",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
print(background_run)
--------------------------------------------------------------------------------
{
'run_id': 'my_run_id',
'thread_id': 'my_thread_id',
'assistant_id': 'my_assistant_id',
'created_at': '2024-07-25T15:35:42.598503+00:00',
'updated_at': '2024-07-25T15:35:42.598503+00:00',
'metadata': {},
'status': 'pending',
'kwargs':
{
'input':
{
'messages': [
{
'role': 'user',
'content': 'how are you?'
}
]
},
'config':
{
'metadata':
{
'created_by': 'system'
},
'configurable':
{
'run_id': 'my_run_id',
'user_id': None,
'graph_id': 'agent',
'thread_id': 'my_thread_id',
'checkpoint_id': None,
'model_name': "openai",
'assistant_id': 'my_assistant_id'
}
},
'webhook': "https://my.fake.webhook.com",
'temporary': False,
'stream_mode': ['values'],
'feedback_keys': None,
'interrupt_after': ["node_to_stop_after_1","node_to_stop_after_2"],
'interrupt_before': ["node_to_stop_before_1","node_to_stop_before_2"]
},
'multitask_strategy': 'interrupt'
}
""" # noqa: E501
payload = {
"input": input,
"command": command,
"stream_mode": stream_mode,
"stream_subgraphs": stream_subgraphs,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
payload = {k: v for k, v in payload.items() if v is not None}
if thread_id:
return await self.http.post(f"/threads/{thread_id}/runs", json=payload)
else:
return await self.http.post("/runs", json=payload)
async def create_batch(self, payloads: list[RunCreate]) -> list[Run]:
"""Create a batch of stateless background runs."""
def filter_payload(payload: RunCreate):
return {k: v for k, v in payload.items() if v is not None}
payloads = [filter_payload(payload) for payload in payloads]
return await self.http.post("/runs/batch", json=payloads)
@overload
async def wait(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
raise_error: bool = True,
) -> Union[list[dict], dict[str, Any]]: ...
@overload
async def wait(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
raise_error: bool = True,
) -> Union[list[dict], dict[str, Any]]: ...
async def wait(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
raise_error: bool = True,
) -> Union[list[dict], dict[str, Any]]:
"""Create a run, wait until it finishes and return the final state.
Args:
thread_id: the thread ID to create the run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to run.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: A command to execute. Cannot be combined with input.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
on_disconnect: The disconnect mode to use.
Must be one of 'cancel' or 'continue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
Union[list[dict], dict[str, Any]]: The output of the run.
Example Usage:
final_state_of_run = await client.runs.wait(
thread_id=None,
assistant_id="agent",
input={"messages": [{"role": "user", "content": "how are you?"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "anthropic"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
print(final_state_of_run)
-------------------------------------------------------------------------------------------------------------------------------------------
{
'messages': [
{
'content': 'how are you?',
'additional_kwargs': {},
'response_metadata': {},
'type': 'human',
'name': None,
'id': 'f51a862c-62fe-4866-863b-b0863e8ad78a',
'example': False
},
{
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'additional_kwargs': {},
'response_metadata': {},
'type': 'ai',
'name': None,
'id': 'run-bf1cd3c6-768f-4c16-b62d-ba6f17ad8b36',
'example': False,
'tool_calls': [],
'invalid_tool_calls': [],
'usage_metadata': None
}
]
}
""" # noqa: E501
payload = {
"input": input,
"command": command,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_disconnect": on_disconnect,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
endpoint = (
f"/threads/{thread_id}/runs/wait" if thread_id is not None else "/runs/wait"
)
response = await self.http.post(
endpoint, json={k: v for k, v in payload.items() if v is not None}
)
if (
raise_error
and isinstance(response, dict)
and "__error__" in response
and isinstance(response["__error__"], dict)
):
raise Exception(
f"{response['__error__'].get('error')}: {response['__error__'].get('message')}"
)
return response
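# Illustrative sketch (not part of the SDK source): with raise_error=False the
# error payload is returned instead of raised, so callers can inspect the
# "__error__" key checked above. The assistant ID and input are hypothetical:
#
#     result = await client.runs.wait(
#         None,
#         "agent",
#         input={"messages": [{"role": "user", "content": "how are you?"}]},
#         raise_error=False,
#     )
#     if isinstance(result, dict) and "__error__" in result:
#         print(result["__error__"])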
async def list(
self,
thread_id: str,
*,
limit: int = 10,
offset: int = 0,
status: Optional[RunStatus] = None,
) -> List[Run]:
"""List runs.
Args:
thread_id: The thread ID to list runs for.
limit: The maximum number of results to return.
offset: The number of results to skip.
status: The status of the run to filter by.
Returns:
List[Run]: The runs for the thread.
Example Usage:
runs = await client.runs.list(
thread_id="my_thread_id",
limit=5,
offset=5,
)
""" # noqa: E501
params = {
"limit": limit,
"offset": offset,
}
if status is not None:
params["status"] = status
return await self.http.get(f"/threads/{thread_id}/runs", params=params)
async def get(self, thread_id: str, run_id: str) -> Run:
"""Get a run.
Args:
thread_id: The thread ID the run belongs to.
run_id: The run ID to get.
Returns:
Run: Run object.
Example Usage:
run = await client.runs.get(
thread_id="thread_id_to_delete",
run_id="run_id_to_delete",
)
""" # noqa: E501
return await self.http.get(f"/threads/{thread_id}/runs/{run_id}")
async def cancel(
self,
thread_id: str,
run_id: str,
*,
wait: bool = False,
action: CancelAction = "interrupt",
) -> None:
"""Get a run.
Args:
thread_id: The thread ID the run belongs to.
run_id: The run ID to cancel.
wait: Whether to wait until run has completed.
action: Action to take when cancelling the run. Possible values
are `interrupt` or `rollback`. Default is `interrupt`.
Returns:
None
Example Usage:
await client.runs.cancel(
thread_id="thread_id_to_cancel",
run_id="run_id_to_cancel",
wait=True,
action="interrupt"
)
""" # noqa: E501
return await self.http.post(
f"/threads/{thread_id}/runs/{run_id}/cancel?wait={1 if wait else 0}&action={action}",
json=None,
)
async def join(self, thread_id: str, run_id: str) -> dict:
"""Block until a run is done. Returns the final state of the thread.
Args:
thread_id: The thread ID to join.
run_id: The run ID to join.
Returns:
dict: The final state of the thread.
Example Usage:
result = await client.runs.join(
thread_id="thread_id_to_join",
run_id="run_id_to_join"
)
""" # noqa: E501
return await self.http.get(f"/threads/{thread_id}/runs/{run_id}/join")
def join_stream(
self, thread_id: str, run_id: str, *, cancel_on_disconnect: bool = False
) -> AsyncIterator[StreamPart]:
"""Stream output from a run in real-time, until the run is done.
Output is not buffered, so any output produced before this call will
not be received here.
Args:
thread_id: The thread ID to join.
run_id: The run ID to join.
cancel_on_disconnect: Whether to cancel the run when the stream is disconnected.
Returns:
AsyncIterator[StreamPart]: An iterator over the stream parts produced by the run.
Example Usage:
async for part in client.runs.join_stream(
thread_id="thread_id_to_join",
run_id="run_id_to_join"
):
print(part)
""" # noqa: E501
return self.http.stream(
f"/threads/{thread_id}/runs/{run_id}/stream",
"GET",
params={"cancel_on_disconnect": cancel_on_disconnect},
)
async def delete(self, thread_id: str, run_id: str) -> None:
"""Delete a run.
Args:
thread_id: The thread ID the run belongs to.
run_id: The run ID to delete.
Returns:
None
Example Usage:
await client.runs.delete(
thread_id="thread_id_to_delete",
run_id="run_id_to_delete"
)
""" # noqa: E501
await self.http.delete(f"/threads/{thread_id}/runs/{run_id}")
class CronClient:
"""Client for managing recurrent runs (cron jobs) in LangGraph.
A run is a single invocation of an assistant with optional input and config.
This client allows scheduling recurring runs to occur automatically.
Example:
client = get_client()
cron_job = await client.crons.create_for_thread(
thread_id="thread_123",
assistant_id="asst_456",
schedule="0 9 * * *",
input={"message": "Daily update"}
)
"""
def __init__(self, http_client: HttpClient) -> None:
self.http = http_client
async def create_for_thread(
self,
thread_id: str,
assistant_id: str,
*,
schedule: str,
input: Optional[dict] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[str] = None,
) -> Run:
"""Create a cron job for a thread.
Args:
thread_id: the thread ID to run the cron job on.
assistant_id: The assistant ID or graph name to use for the cron job.
If using graph name, will default to first assistant created from that graph.
schedule: The cron schedule to execute this job on.
input: The input to the graph.
metadata: Metadata to assign to the cron job runs.
config: The configuration for the assistant.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
Returns:
Run: The cron run.
Example Usage:
cron_run = await client.crons.create_for_thread(
thread_id="my-thread-id",
assistant_id="agent",
schedule="27 15 * * *",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
""" # noqa: E501
payload = {
"schedule": schedule,
"input": input,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
}
if multitask_strategy:
payload["multitask_strategy"] = multitask_strategy
payload = {k: v for k, v in payload.items() if v is not None}
return await self.http.post(f"/threads/{thread_id}/runs/crons", json=payload)
async def create(
self,
assistant_id: str,
*,
schedule: str,
input: Optional[dict] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[str] = None,
) -> Run:
"""Create a cron run.
Args:
assistant_id: The assistant ID or graph name to use for the cron job.
If using graph name, will default to first assistant created from that graph.
schedule: The cron schedule to execute this job on.
input: The input to the graph.
metadata: Metadata to assign to the cron job runs.
config: The configuration for the assistant.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
Returns:
Run: The cron run.
Example Usage:
cron_run = client.crons.create(
assistant_id="agent",
schedule="27 15 * * *",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
""" # noqa: E501
payload = {
"schedule": schedule,
"input": input,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
}
if multitask_strategy:
payload["multitask_strategy"] = multitask_strategy
payload = {k: v for k, v in payload.items() if v is not None}
return await self.http.post("/runs/crons", json=payload)
async def delete(self, cron_id: str) -> None:
"""Delete a cron.
Args:
cron_id: The cron ID to delete.
Returns:
None
Example Usage:
await client.crons.delete(
cron_id="cron_to_delete"
)
""" # noqa: E501
await self.http.delete(f"/runs/crons/{cron_id}")
async def search(
self,
*,
assistant_id: Optional[str] = None,
thread_id: Optional[str] = None,
limit: int = 10,
offset: int = 0,
) -> list[Cron]:
"""Get a list of cron jobs.
Args:
assistant_id: The assistant ID or graph name to search for.
thread_id: the thread ID to search for.
limit: The maximum number of results to return.
offset: The number of results to skip.
Returns:
list[Cron]: The list of cron jobs returned by the search.
Example Usage:
cron_jobs = await client.crons.search(
assistant_id="my_assistant_id",
thread_id="my_thread_id",
limit=5,
offset=5,
)
print(cron_jobs)
----------------------------------------------------------
[
{
'cron_id': '1ef3cefa-4c09-6926-96d0-3dc97fd5e39b',
'assistant_id': 'my_assistant_id',
'thread_id': 'my_thread_id',
'user_id': None,
'payload':
{
'input': {'start_time': ''},
'schedule': '4 * * * *',
'assistant_id': 'my_assistant_id'
},
'schedule': '4 * * * *',
'next_run_date': '2024-07-25T17:04:00+00:00',
'end_time': None,
'created_at': '2024-07-08T06:02:23.073257+00:00',
'updated_at': '2024-07-08T06:02:23.073257+00:00'
}
]
""" # noqa: E501
payload = {
"assistant_id": assistant_id,
"thread_id": thread_id,
"limit": limit,
"offset": offset,
}
payload = {k: v for k, v in payload.items() if v is not None}
return await self.http.post("/runs/crons/search", json=payload)
class StoreClient:
"""Client for interacting with the graph's shared storage.
The Store provides a key-value storage system for persisting data across graph executions,
allowing for stateful operations and data sharing across threads.
Example:
client = get_client()
await client.store.put_item(["users", "user123"], "mem-123451342", {"name": "Alice", "score": 100})
"""
def __init__(self, http: HttpClient) -> None:
self.http = http
async def put_item(
self,
namespace: Sequence[str],
/,
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
"""Store or update an item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
value: A dictionary containing the item's data.
index: Controls search indexing - None (use defaults), False (disable), or list of field paths to index.
Returns:
None
Example Usage:
await client.store.put_item(
["documents", "user123"],
key="item456",
value={"title": "My Document", "content": "Hello World"}
)
"""
for label in namespace:
if "." in label:
raise ValueError(
f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')."
)
payload = {"namespace": namespace, "key": key, "value": value, "index": index}
await self.http.put("/store/items", json=payload)
async def get_item(self, namespace: Sequence[str], /, key: str) -> Item:
"""Retrieve a single item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
Returns:
Item: The retrieved item.
Example Usage:
item = await client.store.get_item(
["documents", "user123"],
key="item456",
)
print(item)
----------------------------------------------------------------
{
'namespace': ['documents', 'user123'],
'key': 'item456',
'value': {'title': 'My Document', 'content': 'Hello World'},
'created_at': '2024-07-30T12:00:00Z',
'updated_at': '2024-07-30T12:00:00Z'
}
"""
for label in namespace:
if "." in label:
raise ValueError(
f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')."
)
return await self.http.get(
"/store/items", params={"namespace": ".".join(namespace), "key": key}
)
async def delete_item(self, namespace: Sequence[str], /, key: str) -> None:
"""Delete an item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
Returns:
None
Example Usage:
await client.store.delete_item(
["documents", "user123"],
key="item456",
)
"""
await self.http.delete(
"/store/items", json={"namespace": namespace, "key": key}
)
async def search_items(
self,
namespace_prefix: Sequence[str],
/,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
query: Optional[str] = None,
) -> SearchItemsResponse:
"""Search for items within a namespace prefix.
Args:
namespace_prefix: List of strings representing the namespace prefix.
filter: Optional dictionary of key-value pairs to filter results.
limit: Maximum number of items to return (default is 10).
offset: Number of items to skip before returning results (default is 0).
query: Optional query for natural language search.
Returns:
SearchItemsResponse: The search results, containing the items matching the criteria.
Example Usage:
items = await client.store.search_items(
["documents"],
filter={"author": "John Doe"},
limit=5,
offset=0
)
print(items)
----------------------------------------------------------------
{
"items": [
{
"namespace": ["documents", "user123"],
"key": "item789",
"value": {
"title": "Another Document",
"author": "John Doe"
},
"created_at": "2024-07-30T12:00:00Z",
"updated_at": "2024-07-30T12:00:00Z"
},
# ... additional items ...
]
}
"""
payload = {
"namespace_prefix": namespace_prefix,
"filter": filter,
"limit": limit,
"offset": offset,
"query": query,
}
return await self.http.post("/store/items/search", json=_provided_vals(payload))
async def list_namespaces(
self,
prefix: Optional[List[str]] = None,
suffix: Optional[List[str]] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> ListNamespaceResponse:
"""List namespaces with optional match conditions.
Args:
prefix: Optional list of strings representing the prefix to filter namespaces.
suffix: Optional list of strings representing the suffix to filter namespaces.
max_depth: Optional integer specifying the maximum depth of namespaces to return.
limit: Maximum number of namespaces to return (default is 100).
offset: Number of namespaces to skip before returning results (default is 0).
Returns:
ListNamespaceResponse: The namespaces matching the criteria.
Example Usage:
namespaces = await client.store.list_namespaces(
prefix=["documents"],
max_depth=3,
limit=10,
offset=0
)
print(namespaces)
----------------------------------------------------------------
[
["documents", "user123", "reports"],
["documents", "user456", "invoices"],
...
]
"""
payload = {
"prefix": prefix,
"suffix": suffix,
"max_depth": max_depth,
"limit": limit,
"offset": offset,
}
return await self.http.post("/store/namespaces", json=_provided_vals(payload))
def get_sync_client(
*,
url: Optional[str] = None,
api_key: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> SyncLangGraphClient:
"""Get a synchronous LangGraphClient instance.
Args:
url: The URL of the LangGraph API.
api_key: The API key. If not provided, it will be read from the environment.
Precedence:
1. explicit argument
2. LANGGRAPH_API_KEY
3. LANGSMITH_API_KEY
4. LANGCHAIN_API_KEY
headers: Optional custom headers
Returns:
SyncLangGraphClient: The top-level synchronous client for accessing SyncAssistantsClient,
SyncThreadsClient, SyncRunsClient, SyncCronClient, and SyncStoreClient.
Example:
from langgraph_sdk import get_sync_client
# get top-level synchronous LangGraphClient
client = get_sync_client(url="http://localhost:8123")
# example usage: client.<model>.<method_name>()
assistant = client.assistants.get(assistant_id="some_uuid")
"""
if url is None:
url = "http://localhost:8123"
transport = httpx.HTTPTransport(retries=5)
client = httpx.Client(
base_url=url,
transport=transport,
timeout=httpx.Timeout(connect=5, read=300, write=300, pool=5),
headers=get_headers(api_key, headers),
)
return SyncLangGraphClient(client)
class SyncLangGraphClient:
"""Synchronous client for interacting with the LangGraph API.
This class provides synchronous access to LangGraph API endpoints for managing
assistants, threads, runs, cron jobs, and data storage.
Example:
client = get_sync_client()
assistant = client.assistants.get("asst_123")
"""
def __init__(self, client: httpx.Client) -> None:
self.http = SyncHttpClient(client)
self.assistants = SyncAssistantsClient(self.http)
self.threads = SyncThreadsClient(self.http)
self.runs = SyncRunsClient(self.http)
self.crons = SyncCronClient(self.http)
self.store = SyncStoreClient(self.http)
class SyncHttpClient:
def __init__(self, client: httpx.Client) -> None:
self.client = client
def get(self, path: str, *, params: Optional[QueryParamTypes] = None) -> Any:
"""Send a GET request."""
r = self.client.get(path, params=params)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = r.read().decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return decode_json(r)
def post(self, path: str, *, json: Optional[dict]) -> Any:
"""Send a POST request."""
if json is not None:
headers, content = encode_json(json)
else:
headers, content = {}, b""
r = self.client.post(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = r.read().decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return decode_json(r)
def put(self, path: str, *, json: dict) -> Any:
"""Send a PUT request."""
headers, content = encode_json(json)
r = self.client.put(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = r.read().decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return decode_json(r)
def patch(self, path: str, *, json: dict) -> Any:
"""Send a PATCH request."""
headers, content = encode_json(json)
r = self.client.patch(path, headers=headers, content=content)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = r.read().decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
return decode_json(r)
def delete(self, path: str, *, json: Optional[Any] = None) -> None:
"""Send a DELETE request."""
r = self.client.request("DELETE", path, json=json)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
body = r.read().decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
def stream(
self, path: str, method: str, *, json: Optional[dict] = None
) -> Iterator[StreamPart]:
"""Stream the results of a request using SSE."""
headers, content = encode_json(json)
with self.client.stream(method, path, headers=headers, content=content) as res:
# check status
try:
res.raise_for_status()
except httpx.HTTPStatusError as e:
body = (res.read()).decode()
if sys.version_info >= (3, 11):
e.add_note(body)
else:
logger.error(f"Error from langgraph-api: {body}", exc_info=e)
raise e
# check content type
content_type = res.headers.get("content-type", "").partition(";")[0]
if "text/event-stream" not in content_type:
raise httpx.TransportError(
"Expected response header Content-Type to contain 'text/event-stream', "
f"got {content_type!r}"
)
# parse SSE
decoder = SSEDecoder()
for line in iter_lines_raw(res):
sse = decoder.decode(line.rstrip(b"\n"))
if sse is not None:
yield sse
def encode_json(json: Any) -> tuple[dict[str, str], bytes]:
body = orjson.dumps(
json,
orjson_default,
orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS,
)
content_length = str(len(body))
content_type = "application/json"
headers = {"Content-Length": content_length, "Content-Type": content_type}
return headers, body
def decode_json(r: httpx.Response) -> Any:
body = r.read()
return orjson.loads(body if body else None)
class SyncAssistantsClient:
"""Client for managing assistants in LangGraph synchronously.
This class provides methods to interact with assistants, which are versioned configurations of your graph.
Example:
client = get_sync_client()
assistant = client.assistants.get("assistant_id_123")
"""
def __init__(self, http: SyncHttpClient) -> None:
self.http = http
def get(self, assistant_id: str) -> Assistant:
"""Get an assistant by ID.
Args:
assistant_id: The ID of the assistant to get.
Returns:
Assistant: Assistant Object.
Example Usage:
assistant = client.assistants.get(
assistant_id="my_assistant_id"
)
print(assistant)
----------------------------------------------------
{
'assistant_id': 'my_assistant_id',
'graph_id': 'agent',
'created_at': '2024-06-25T17:10:33.109781+00:00',
'updated_at': '2024-06-25T17:10:33.109781+00:00',
'config': {},
'metadata': {'created_by': 'system'}
}
""" # noqa: E501
return self.http.get(f"/assistants/{assistant_id}")
def get_graph(
self, assistant_id: str, *, xray: Union[int, bool] = False
) -> dict[str, list[dict[str, Any]]]:
"""Get the graph of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the graph of.
xray: Include graph representation of subgraphs. If an integer value is provided, only subgraphs with a depth less than or equal to the value will be included.
Returns:
Graph: The graph information for the assistant in JSON format.
Example Usage:
graph_info = client.assistants.get_graph(
assistant_id="my_assistant_id"
)
print(graph_info)
--------------------------------------------------------------------------------------------------------------------------
{
'nodes':
[
{'id': '__start__', 'type': 'schema', 'data': '__start__'},
{'id': '__end__', 'type': 'schema', 'data': '__end__'},
{'id': 'agent','type': 'runnable','data': {'id': ['langgraph', 'utils', 'RunnableCallable'],'name': 'agent'}},
],
'edges':
[
{'source': '__start__', 'target': 'agent'},
{'source': 'agent','target': '__end__'}
]
}
""" # noqa: E501
return self.http.get(f"/assistants/{assistant_id}/graph", params={"xray": xray})
def get_schemas(self, assistant_id: str) -> GraphSchema:
"""Get the schemas of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the schema of.
Returns:
GraphSchema: The graph schema for the assistant.
Example Usage:
schema = client.assistants.get_schemas(
assistant_id="my_assistant_id"
)
print(schema)
----------------------------------------------------------------------------------------------------------------------------
{
'graph_id': 'agent',
'state_schema':
{
'title': 'LangGraphInput',
'$ref': '#/definitions/AgentState',
'definitions':
{
'BaseMessage':
{
'title': 'BaseMessage',
'description': 'Base abstract Message class. Messages are the inputs and outputs of ChatModels.',
'type': 'object',
'properties':
{
'content':
{
'title': 'Content',
'anyOf': [
{'type': 'string'},
{'type': 'array','items': {'anyOf': [{'type': 'string'}, {'type': 'object'}]}}
]
},
'additional_kwargs':
{
'title': 'Additional Kwargs',
'type': 'object'
},
'response_metadata':
{
'title': 'Response Metadata',
'type': 'object'
},
'type':
{
'title': 'Type',
'type': 'string'
},
'name':
{
'title': 'Name',
'type': 'string'
},
'id':
{
'title': 'Id',
'type': 'string'
}
},
'required': ['content', 'type']
},
'AgentState':
{
'title': 'AgentState',
'type': 'object',
'properties':
{
'messages':
{
'title': 'Messages',
'type': 'array',
'items': {'$ref': '#/definitions/BaseMessage'}
}
},
'required': ['messages']
}
}
},
'config_schema':
{
'title': 'Configurable',
'type': 'object',
'properties':
{
'model_name':
{
'title': 'Model Name',
'enum': ['anthropic', 'openai'],
'type': 'string'
}
}
}
}
""" # noqa: E501
return self.http.get(f"/assistants/{assistant_id}/schemas")
def get_subgraphs(
self, assistant_id: str, namespace: Optional[str] = None, recurse: bool = False
) -> Subgraphs:
"""Get the schemas of an assistant by ID.
Args:
assistant_id: The ID of the assistant to get the schema of.
Returns:
Subgraphs: The graph schema for the assistant.
""" # noqa: E501
if namespace is not None:
return self.http.get(
f"/assistants/{assistant_id}/subgraphs/{namespace}",
params={"recurse": recurse},
)
else:
return self.http.get(
f"/assistants/{assistant_id}/subgraphs",
params={"recurse": recurse},
)
def create(
self,
graph_id: Optional[str],
config: Optional[Config] = None,
*,
metadata: Json = None,
assistant_id: Optional[str] = None,
if_exists: Optional[OnConflictBehavior] = None,
name: Optional[str] = None,
) -> Assistant:
"""Create a new assistant.
Useful when graph is configurable and you want to create different assistants based on different configurations.
Args:
graph_id: The ID of the graph the assistant should use. The graph ID is normally set in your langgraph.json configuration.
config: Configuration to use for the graph.
metadata: Metadata to add to assistant.
assistant_id: Assistant ID to use, will default to a random UUID if not provided.
if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood.
Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing assistant).
name: The name of the assistant. Defaults to 'Untitled' under the hood.
Returns:
Assistant: The created assistant.
Example Usage:
assistant = client.assistants.create(
graph_id="agent",
config={"configurable": {"model_name": "openai"}},
metadata={"number":1},
assistant_id="my-assistant-id",
if_exists="do_nothing",
name="my_name"
)
""" # noqa: E501
payload: Dict[str, Any] = {
"graph_id": graph_id,
}
if config:
payload["config"] = config
if metadata:
payload["metadata"] = metadata
if assistant_id:
payload["assistant_id"] = assistant_id
if if_exists:
payload["if_exists"] = if_exists
if name:
payload["name"] = name
return self.http.post("/assistants", json=payload)
def update(
self,
assistant_id: str,
*,
graph_id: Optional[str] = None,
config: Optional[Config] = None,
metadata: Json = None,
name: Optional[str] = None,
) -> Assistant:
"""Update an assistant.
Use this to point to a different graph, update the configuration, or change the metadata of an assistant.
Args:
assistant_id: Assistant to update.
graph_id: The ID of the graph the assistant should use.
The graph ID is normally set in your langgraph.json configuration. If None, assistant will keep pointing to same graph.
config: Configuration to use for the graph.
metadata: Metadata to merge with existing assistant metadata.
name: The new name for the assistant.
Returns:
Assistant: The updated assistant.
Example Usage:
assistant = client.assistants.update(
assistant_id='e280dad7-8618-443f-87f1-8e41841c180f',
graph_id="other-graph",
config={"configurable": {"model_name": "anthropic"}},
metadata={"number":2}
)
""" # noqa: E501
payload: Dict[str, Any] = {}
if graph_id:
payload["graph_id"] = graph_id
if config:
payload["config"] = config
if metadata:
payload["metadata"] = metadata
if name:
payload["name"] = name
return self.http.patch(
f"/assistants/{assistant_id}",
json=payload,
)
def delete(
self,
assistant_id: str,
) -> None:
"""Delete an assistant.
Args:
assistant_id: The assistant ID to delete.
Returns:
None
Example Usage:
client.assistants.delete(
assistant_id="my_assistant_id"
)
""" # noqa: E501
self.http.delete(f"/assistants/{assistant_id}")
def search(
self,
*,
metadata: Json = None,
graph_id: Optional[str] = None,
limit: int = 10,
offset: int = 0,
) -> list[Assistant]:
"""Search for assistants.
Args:
metadata: Metadata to filter by. Exact match filter for each KV pair.
graph_id: The ID of the graph to filter by.
The graph ID is normally set in your langgraph.json configuration.
limit: The maximum number of results to return.
offset: The number of results to skip.
Returns:
list[Assistant]: A list of assistants.
Example Usage:
assistants = client.assistants.search(
metadata = {"name":"my_name"},
graph_id="my_graph_id",
limit=5,
offset=5
)
"""
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
if graph_id:
payload["graph_id"] = graph_id
return self.http.post(
"/assistants/search",
json=payload,
)
def get_versions(
self,
assistant_id: str,
metadata: Json = None,
limit: int = 10,
offset: int = 0,
) -> list[AssistantVersion]:
"""List all versions of an assistant.
Args:
assistant_id: The assistant ID to get versions for.
metadata: Metadata to filter versions by. Exact match filter for each KV pair.
limit: The maximum number of versions to return.
offset: The number of versions to skip.
Returns:
list[AssistantVersion]: A list of assistant versions.
Example Usage:
assistant_versions = client.assistants.get_versions(
assistant_id="my_assistant_id"
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
return self.http.post(f"/assistants/{assistant_id}/versions", json=payload)
def set_latest(self, assistant_id: str, version: int) -> Assistant:
"""Change the version of an assistant.
Args:
assistant_id: The assistant ID to update.
version: The version to change to.
Returns:
Assistant: Assistant Object.
Example Usage:
new_version_assistant = client.assistants.set_latest(
assistant_id="my_assistant_id",
version=3
)
""" # noqa: E501
payload: Dict[str, Any] = {"version": version}
return self.http.post(f"/assistants/{assistant_id}/latest", json=payload)
class SyncThreadsClient:
"""Synchronous client for managing threads in LangGraph.
This class provides methods to create, retrieve, and manage threads,
which represent conversations or stateful interactions.
Example:
client = get_sync_client()
thread = client.threads.create(metadata={"user_id": "123"})
"""
def __init__(self, http: SyncHttpClient) -> None:
self.http = http
def get(self, thread_id: str) -> Thread:
"""Get a thread by ID.
Args:
thread_id: The ID of the thread to get.
Returns:
Thread: Thread object.
Example Usage:
thread = client.threads.get(
thread_id="my_thread_id"
)
print(thread)
-----------------------------------------------------
{
'thread_id': 'my_thread_id',
'created_at': '2024-07-18T18:35:15.540834+00:00',
'updated_at': '2024-07-18T18:35:15.540834+00:00',
'metadata': {'graph_id': 'agent'}
}
""" # noqa: E501
return self.http.get(f"/threads/{thread_id}")
def create(
self,
*,
metadata: Json = None,
thread_id: Optional[str] = None,
if_exists: Optional[OnConflictBehavior] = None,
) -> Thread:
"""Create a new thread.
Args:
metadata: Metadata to add to thread.
thread_id: ID of thread.
If None, ID will be a randomly generated UUID.
if_exists: How to handle duplicate creation. Defaults to 'raise' under the hood.
Must be either 'raise' (raise error if duplicate), or 'do_nothing' (return existing thread).
Returns:
Thread: The created thread.
Example Usage:
thread = client.threads.create(
metadata={"number":1},
thread_id="my-thread-id",
if_exists="raise"
)
""" # noqa: E501
payload: Dict[str, Any] = {}
if thread_id:
payload["thread_id"] = thread_id
if metadata:
payload["metadata"] = metadata
if if_exists:
payload["if_exists"] = if_exists
return self.http.post("/threads", json=payload)
def update(self, thread_id: str, *, metadata: dict[str, Any]) -> Thread:
"""Update a thread.
Args:
thread_id: ID of thread to update.
metadata: Metadata to merge with existing thread metadata.
Returns:
Thread: The updated thread.
Example Usage:
thread = client.threads.update(
thread_id="my-thread-id",
metadata={"number":1},
)
""" # noqa: E501
return self.http.patch(f"/threads/{thread_id}", json={"metadata": metadata})
def delete(self, thread_id: str) -> None:
"""Delete a thread.
Args:
thread_id: The ID of the thread to delete.
Returns:
None
Example Usage:
client.threads.delete(
thread_id="my_thread_id"
)
""" # noqa: E501
self.http.delete(f"/threads/{thread_id}")
def search(
self,
*,
metadata: Json = None,
values: Json = None,
status: Optional[ThreadStatus] = None,
limit: int = 10,
offset: int = 0,
) -> list[Thread]:
"""Search for threads.
Args:
metadata: Thread metadata to filter on.
values: State values to filter on.
status: Thread status to filter on.
Must be one of 'idle', 'busy', 'interrupted' or 'error'.
limit: Limit on number of threads to return.
offset: Offset in threads table to start search from.
Returns:
list[Thread]: List of the threads matching the search parameters.
Example Usage:
threads = client.threads.search(
metadata={"number":1},
status="interrupted",
limit=15,
offset=5
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
"offset": offset,
}
if metadata:
payload["metadata"] = metadata
if values:
payload["values"] = values
if status:
payload["status"] = status
return self.http.post(
"/threads/search",
json=payload,
)
def copy(self, thread_id: str) -> None:
"""Copy a thread.
Args:
thread_id: The ID of the thread to copy.
Returns:
None
Example Usage:
client.threads.copy(
thread_id="my_thread_id"
)
""" # noqa: E501
return self.http.post(f"/threads/{thread_id}/copy", json=None)
def get_state(
self,
thread_id: str,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
*,
subgraphs: bool = False,
) -> ThreadState:
"""Get the state of a thread.
Args:
thread_id: The ID of the thread to get the state of.
checkpoint: The checkpoint to get the state of.
subgraphs: Include subgraphs states.
Returns:
ThreadState: The state of the thread.
Example Usage:
thread_state = client.threads.get_state(
thread_id="my_thread_id",
checkpoint_id="my_checkpoint_id"
)
print(thread_state)
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
{
'values': {
'messages': [
{
'content': 'how are you?',
'additional_kwargs': {},
'response_metadata': {},
'type': 'human',
'name': None,
'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10',
'example': False
},
{
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'additional_kwargs': {},
'response_metadata': {},
'type': 'ai',
'name': None,
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'example': False,
'tool_calls': [],
'invalid_tool_calls': [],
'usage_metadata': None
}
]
},
'next': [],
'checkpoint':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1'
},
'metadata':
{
'step': 1,
'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2',
'source': 'loop',
'writes':
{
'agent':
{
'messages': [
{
'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b',
'name': None,
'type': 'ai',
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'example': False,
'tool_calls': [],
'usage_metadata': None,
'additional_kwargs': {},
'response_metadata': {},
'invalid_tool_calls': []
}
]
}
},
'user_id': None,
'graph_id': 'agent',
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'created_by': 'system',
'assistant_id': 'fe096781-5601-53d2-b2f6-0d3403f7e9ca'},
'created_at': '2024-07-25T15:35:44.184703+00:00',
'parent_config':
{
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-d80d-6fa7-8000-9300467fad0f'
}
}
""" # noqa: E501
if checkpoint:
return self.http.post(
f"/threads/{thread_id}/state/checkpoint",
json={"checkpoint": checkpoint, "subgraphs": subgraphs},
)
elif checkpoint_id:
return self.http.get(
f"/threads/{thread_id}/state/{checkpoint_id}",
params={"subgraphs": subgraphs},
)
else:
return self.http.get(
f"/threads/{thread_id}/state",
params={"subgraphs": subgraphs},
)
def update_state(
self,
thread_id: str,
values: Optional[Union[dict, Sequence[dict]]],
*,
as_node: Optional[str] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None, # deprecated
) -> ThreadUpdateStateResponse:
"""Update the state of a thread.
Args:
thread_id: The ID of the thread to update.
values: The values to update the state with.
as_node: Update the state as if this node had just executed.
checkpoint: The checkpoint to update the state of.
Returns:
ThreadUpdateStateResponse: Response after updating a thread's state.
Example Usage:
response = client.threads.update_state(
thread_id="my_thread_id",
values={"messages":[{"role": "user", "content": "hello!"}]},
as_node="my_node",
)
print(response)
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
{
'checkpoint': {
'thread_id': 'e2496803-ecd5-4e0c-a779-3226296181c2',
'checkpoint_ns': '',
'checkpoint_id': '1ef4a9b8-e6fb-67b1-8001-abd5184439d1',
'checkpoint_map': {}
}
}
""" # noqa: E501
payload: Dict[str, Any] = {
"values": values,
}
if checkpoint_id:
payload["checkpoint_id"] = checkpoint_id
if checkpoint:
payload["checkpoint"] = checkpoint
if as_node:
payload["as_node"] = as_node
return self.http.post(f"/threads/{thread_id}/state", json=payload)
def get_history(
self,
thread_id: str,
*,
limit: int = 10,
before: Optional[str | Checkpoint] = None,
metadata: Optional[dict] = None,
checkpoint: Optional[Checkpoint] = None,
) -> list[ThreadState]:
"""Get the state history of a thread.
Args:
thread_id: The ID of the thread to get the state history for.
checkpoint: Return states for this subgraph. If empty defaults to root.
limit: The maximum number of states to return.
before: Return states before this checkpoint.
metadata: Filter states by metadata key-value pairs.
Returns:
list[ThreadState]: the state history of the thread.
Example Usage:
thread_state = client.threads.get_history(
thread_id="my_thread_id",
limit=5,
before="my_timestamp",
metadata={"name":"my_name"}
)
""" # noqa: E501
payload: Dict[str, Any] = {
"limit": limit,
}
if before:
payload["before"] = before
if metadata:
payload["metadata"] = metadata
if checkpoint:
payload["checkpoint"] = checkpoint
return self.http.post(f"/threads/{thread_id}/history", json=payload)
class SyncRunsClient:
"""Synchronous client for managing runs in LangGraph.
This class provides methods to create, retrieve, and manage runs, which represent
individual executions of graphs.
Example:
client = get_sync_client()
run = client.runs.create(thread_id="thread_123", assistant_id="asst_456")
"""
def __init__(self, http: SyncHttpClient) -> None:
self.http = http
@overload
def stream(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Iterator[StreamPart]: ...
@overload
def stream(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
webhook: Optional[str] = None,
after_seconds: Optional[int] = None,
) -> Iterator[StreamPart]: ...
def stream(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
feedback_keys: Optional[Sequence[str]] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Iterator[StreamPart]:
"""Create a run and stream the results.
Args:
thread_id: The thread ID to run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to stream from.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: The command to execute.
stream_mode: The stream mode(s) to use.
stream_subgraphs: Whether to stream output from subgraphs.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
feedback_keys: Feedback keys to assign to run.
on_disconnect: The disconnect mode to use.
Must be one of 'cancel' or 'continue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
Iterator[StreamPart]: Iterator of stream results.
Example Usage:
for chunk in client.runs.stream(
thread_id=None,
assistant_id="agent",
input={"messages": [{"role": "user", "content": "how are you?"}]},
stream_mode=["values","debug"],
metadata={"name":"my_run"},
config={"configurable": {"model_name": "anthropic"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
feedback_keys=["my_feedback_key_1","my_feedback_key_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
):
print(chunk)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
StreamPart(event='metadata', data={'run_id': '1ef4a9b8-d7da-679a-a45a-872054341df2'})
StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}]})
StreamPart(event='values', data={'messages': [{'content': 'how are you?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': 'fe0a5778-cfe9-42ee-b807-0adaa1873c10', 'example': False}, {'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.", 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'ai', 'name': None, 'id': 'run-159b782c-b679-4830-83c6-cef87798fe8b', 'example': False, 'tool_calls': [], 'invalid_tool_calls': [], 'usage_metadata': None}]})
StreamPart(event='end', data=None)
""" # noqa: E501
payload = {
"input": input,
"command": command,
"config": config,
"metadata": metadata,
"stream_mode": stream_mode,
"stream_subgraphs": stream_subgraphs,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"feedback_keys": feedback_keys,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_disconnect": on_disconnect,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
endpoint = (
f"/threads/{thread_id}/runs/stream"
if thread_id is not None
else "/runs/stream"
)
return self.http.stream(
endpoint, "POST", json={k: v for k, v in payload.items() if v is not None}
)
@overload
def create(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Run: ...
@overload
def create(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Run: ...
def create(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values",
stream_subgraphs: bool = False,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Run:
"""Create a background run.
Args:
thread_id: The thread ID to run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to stream from.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: The command to execute.
stream_mode: The stream mode(s) to use.
stream_subgraphs: Whether to stream output from subgraphs.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
Run: The created background run.
Example Usage:
background_run = client.runs.create(
thread_id="my_thread_id",
assistant_id="my_assistant_id",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
print(background_run)
--------------------------------------------------------------------------------
{
'run_id': 'my_run_id',
'thread_id': 'my_thread_id',
'assistant_id': 'my_assistant_id',
'created_at': '2024-07-25T15:35:42.598503+00:00',
'updated_at': '2024-07-25T15:35:42.598503+00:00',
'metadata': {},
'status': 'pending',
'kwargs':
{
'input':
{
'messages': [
{
'role': 'user',
'content': 'how are you?'
}
]
},
'config':
{
'metadata':
{
'created_by': 'system'
},
'configurable':
{
'run_id': 'my_run_id',
'user_id': None,
'graph_id': 'agent',
'thread_id': 'my_thread_id',
'checkpoint_id': None,
'model_name': "openai",
'assistant_id': 'my_assistant_id'
}
},
'webhook': "https://my.fake.webhook.com",
'temporary': False,
'stream_mode': ['values'],
'feedback_keys': None,
'interrupt_after': ["node_to_stop_after_1","node_to_stop_after_2"],
'interrupt_before': ["node_to_stop_before_1","node_to_stop_before_2"]
},
'multitask_strategy': 'interrupt'
}
""" # noqa: E501
payload = {
"input": input,
"command": command,
"stream_mode": stream_mode,
"stream_subgraphs": stream_subgraphs,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
payload = {k: v for k, v in payload.items() if v is not None}
if thread_id:
return self.http.post(f"/threads/{thread_id}/runs", json=payload)
else:
return self.http.post("/runs", json=payload)
def create_batch(self, payloads: list[RunCreate]) -> list[Run]:
"""Create a batch of stateless background runs."""
def filter_payload(payload: RunCreate):
return {k: v for k, v in payload.items() if v is not None}
payloads = [filter_payload(payload) for payload in payloads]
return self.http.post("/runs/batch", json=payloads)
@overload
def wait(
self,
thread_id: str,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Union[list[dict], dict[str, Any]]: ...
@overload
def wait(
self,
thread_id: None,
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Union[list[dict], dict[str, Any]]: ...
def wait(
self,
thread_id: Optional[str],
assistant_id: str,
*,
input: Optional[dict] = None,
command: Optional[Command] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
checkpoint: Optional[Checkpoint] = None,
checkpoint_id: Optional[str] = None,
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
webhook: Optional[str] = None,
on_disconnect: Optional[DisconnectMode] = None,
on_completion: Optional[OnCompletionBehavior] = None,
multitask_strategy: Optional[MultitaskStrategy] = None,
if_not_exists: Optional[IfNotExists] = None,
after_seconds: Optional[int] = None,
) -> Union[list[dict], dict[str, Any]]:
"""Create a run, wait until it finishes and return the final state.
Args:
thread_id: the thread ID to create the run on.
If None will create a stateless run.
assistant_id: The assistant ID or graph name to run.
If using graph name, will default to first assistant created from that graph.
input: The input to the graph.
command: The command to execute.
metadata: Metadata to assign to the run.
config: The configuration for the assistant.
checkpoint: The checkpoint to resume from.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
on_disconnect: The disconnect mode to use.
Must be one of 'cancel' or 'continue'.
on_completion: Whether to delete or keep the thread created for a stateless run.
Must be one of 'delete' or 'keep'.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
if_not_exists: How to handle missing thread. Defaults to 'reject'.
Must be either 'reject' (raise error if missing), or 'create' (create new thread).
after_seconds: The number of seconds to wait before starting the run.
Use to schedule future runs.
Returns:
Union[list[dict], dict[str, Any]]: The output of the run.
Example Usage:
final_state_of_run = client.runs.wait(
thread_id=None,
assistant_id="agent",
input={"messages": [{"role": "user", "content": "how are you?"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "anthropic"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
print(final_state_of_run)
-------------------------------------------------------------------------------------------------------------------------------------------
{
'messages': [
{
'content': 'how are you?',
'additional_kwargs': {},
'response_metadata': {},
'type': 'human',
'name': None,
'id': 'f51a862c-62fe-4866-863b-b0863e8ad78a',
'example': False
},
{
'content': "I'm doing well, thanks for asking! I'm an AI assistant created by Anthropic to be helpful, honest, and harmless.",
'additional_kwargs': {},
'response_metadata': {},
'type': 'ai',
'name': None,
'id': 'run-bf1cd3c6-768f-4c16-b62d-ba6f17ad8b36',
'example': False,
'tool_calls': [],
'invalid_tool_calls': [],
'usage_metadata': None
}
]
}
""" # noqa: E501
payload = {
"input": input,
"command": command,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
"checkpoint": checkpoint,
"checkpoint_id": checkpoint_id,
"multitask_strategy": multitask_strategy,
"if_not_exists": if_not_exists,
"on_disconnect": on_disconnect,
"on_completion": on_completion,
"after_seconds": after_seconds,
}
endpoint = (
f"/threads/{thread_id}/runs/wait" if thread_id is not None else "/runs/wait"
)
return self.http.post(
endpoint, json={k: v for k, v in payload.items() if v is not None}
)
def list(self, thread_id: str, *, limit: int = 10, offset: int = 0) -> List[Run]:
"""List runs.
Args:
thread_id: The thread ID to list runs for.
limit: The maximum number of results to return.
offset: The number of results to skip.
Returns:
List[Run]: The runs for the thread.
Example Usage:
client.runs.list(
thread_id="my_thread_id",
limit=5,
offset=5,
)
""" # noqa: E501
return self.http.get(f"/threads/{thread_id}/runs?limit={limit}&offset={offset}")
def get(self, thread_id: str, run_id: str) -> Run:
"""Get a run.
Args:
thread_id: The thread ID to get.
run_id: The run ID to get.
Returns:
Run: Run object.
Example Usage:
run = client.runs.get(
thread_id="thread_id_to_delete",
run_id="run_id_to_delete",
)
""" # noqa: E501
return self.http.get(f"/threads/{thread_id}/runs/{run_id}")
def cancel(
self,
thread_id: str,
run_id: str,
*,
wait: bool = False,
action: CancelAction = "interrupt",
) -> None:
"""Get a run.
Args:
thread_id: The thread ID to cancel.
run_id: The run ID to cancel.
wait: Whether to wait until run has completed.
action: Action to take when cancelling the run. Possible values
are `interrupt` or `rollback`. Default is `interrupt`.
Returns:
None
Example Usage:
client.runs.cancel(
thread_id="thread_id_to_cancel",
run_id="run_id_to_cancel",
wait=True,
action="interrupt"
)
""" # noqa: E501
return self.http.post(
f"/threads/{thread_id}/runs/{run_id}/cancel?wait={1 if wait else 0}&action={action}",
json=None,
)
def join(self, thread_id: str, run_id: str) -> dict:
"""Block until a run is done. Returns the final state of the thread.
Args:
thread_id: The thread ID to join.
run_id: The run ID to join.
Returns:
dict: The final state of the thread.
Example Usage:
client.runs.join(
thread_id="thread_id_to_join",
run_id="run_id_to_join"
)
""" # noqa: E501
return self.http.get(f"/threads/{thread_id}/runs/{run_id}/join")
def join_stream(self, thread_id: str, run_id: str) -> Iterator[StreamPart]:
"""Stream output from a run in real-time, until the run is done.
Output is not buffered, so any output produced before this call will
not be received here.
Args:
thread_id: The thread ID to join.
run_id: The run ID to join.
Returns:
Iterator[StreamPart]: The stream of output from the run.
Example Usage:
client.runs.join_stream(
thread_id="thread_id_to_join",
run_id="run_id_to_join"
)
""" # noqa: E501
return self.http.stream(f"/threads/{thread_id}/runs/{run_id}/stream", "GET")
def delete(self, thread_id: str, run_id: str) -> None:
"""Delete a run.
Args:
thread_id: The thread ID to delete.
run_id: The run ID to delete.
Returns:
None
Example Usage:
client.runs.delete(
thread_id="thread_id_to_delete",
run_id="run_id_to_delete"
)
""" # noqa: E501
self.http.delete(f"/threads/{thread_id}/runs/{run_id}")
class SyncCronClient:
"""Synchronous client for managing cron jobs in LangGraph.
This class provides methods to create and manage scheduled tasks (cron jobs) for automated graph executions.
Example:
client = get_sync_client()
cron_job = client.crons.create_for_thread(thread_id="thread_123", assistant_id="asst_456", schedule="0 * * * *")
"""
def __init__(self, http_client: SyncHttpClient) -> None:
self.http = http_client
def create_for_thread(
self,
thread_id: str,
assistant_id: str,
*,
schedule: str,
input: Optional[dict] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[str] = None,
) -> Run:
"""Create a cron job for a thread.
Args:
thread_id: the thread ID to run the cron job on.
assistant_id: The assistant ID or graph name to use for the cron job.
If using graph name, will default to first assistant created from that graph.
schedule: The cron schedule to execute this job on.
input: The input to the graph.
metadata: Metadata to assign to the cron job runs.
config: The configuration for the assistant.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
Returns:
Run: The cron run.
Example Usage:
cron_run = client.crons.create_for_thread(
thread_id="my-thread-id",
assistant_id="agent",
schedule="27 15 * * *",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
""" # noqa: E501
payload = {
"schedule": schedule,
"input": input,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
}
if multitask_strategy:
payload["multitask_strategy"] = multitask_strategy
payload = {k: v for k, v in payload.items() if v is not None}
return self.http.post(f"/threads/{thread_id}/runs/crons", json=payload)
def create(
self,
assistant_id: str,
*,
schedule: str,
input: Optional[dict] = None,
metadata: Optional[dict] = None,
config: Optional[Config] = None,
interrupt_before: Optional[Union[All, list[str]]] = None,
interrupt_after: Optional[Union[All, list[str]]] = None,
webhook: Optional[str] = None,
multitask_strategy: Optional[str] = None,
) -> Run:
"""Create a cron run.
Args:
assistant_id: The assistant ID or graph name to use for the cron job.
If using graph name, will default to first assistant created from that graph.
schedule: The cron schedule to execute this job on.
input: The input to the graph.
metadata: Metadata to assign to the cron job runs.
config: The configuration for the assistant.
interrupt_before: Nodes to interrupt immediately before they get executed.
interrupt_after: Nodes to interrupt immediately after they get executed.
webhook: Webhook to call after LangGraph API call is done.
multitask_strategy: Multitask strategy to use.
Must be one of 'reject', 'interrupt', 'rollback', or 'enqueue'.
Returns:
Run: The cron run.
Example Usage:
cron_run = client.crons.create(
assistant_id="agent",
schedule="27 15 * * *",
input={"messages": [{"role": "user", "content": "hello!"}]},
metadata={"name":"my_run"},
config={"configurable": {"model_name": "openai"}},
interrupt_before=["node_to_stop_before_1","node_to_stop_before_2"],
interrupt_after=["node_to_stop_after_1","node_to_stop_after_2"],
webhook="https://my.fake.webhook.com",
multitask_strategy="interrupt"
)
""" # noqa: E501
payload = {
"schedule": schedule,
"input": input,
"config": config,
"metadata": metadata,
"assistant_id": assistant_id,
"interrupt_before": interrupt_before,
"interrupt_after": interrupt_after,
"webhook": webhook,
}
if multitask_strategy:
payload["multitask_strategy"] = multitask_strategy
payload = {k: v for k, v in payload.items() if v is not None}
return self.http.post("/runs/crons", json=payload)
def delete(self, cron_id: str) -> None:
"""Delete a cron.
Args:
cron_id: The cron ID to delete.
Returns:
None
Example Usage:
client.crons.delete(
cron_id="cron_to_delete"
)
""" # noqa: E501
self.http.delete(f"/runs/crons/{cron_id}")
def search(
self,
*,
assistant_id: Optional[str] = None,
thread_id: Optional[str] = None,
limit: int = 10,
offset: int = 0,
) -> list[Cron]:
"""Get a list of cron jobs.
Args:
assistant_id: The assistant ID or graph name to search for.
thread_id: the thread ID to search for.
limit: The maximum number of results to return.
offset: The number of results to skip.
Returns:
list[Cron]: The list of cron jobs returned by the search.
Example Usage:
cron_jobs = client.crons.search(
assistant_id="my_assistant_id",
thread_id="my_thread_id",
limit=5,
offset=5,
)
print(cron_jobs)
----------------------------------------------------------
[
{
'cron_id': '1ef3cefa-4c09-6926-96d0-3dc97fd5e39b',
'assistant_id': 'my_assistant_id',
'thread_id': 'my_thread_id',
'user_id': None,
'payload':
{
'input': {'start_time': ''},
'schedule': '4 * * * *',
'assistant_id': 'my_assistant_id'
},
'schedule': '4 * * * *',
'next_run_date': '2024-07-25T17:04:00+00:00',
'end_time': None,
'created_at': '2024-07-08T06:02:23.073257+00:00',
'updated_at': '2024-07-08T06:02:23.073257+00:00'
}
]
""" # noqa: E501
payload = {
"assistant_id": assistant_id,
"thread_id": thread_id,
"limit": limit,
"offset": offset,
}
payload = {k: v for k, v in payload.items() if v is not None}
return self.http.post("/runs/crons/search", json=payload)
class SyncStoreClient:
"""A client for synchronous operations on a key-value store.
Provides methods to interact with a remote key-value store, allowing
storage and retrieval of items within namespaced hierarchies.
Example:
client = get_sync_client()
client.store.put_item(["users", "profiles"], "user123", {"name": "Alice", "age": 30})
"""
def __init__(self, http: SyncHttpClient) -> None:
self.http = http
def put_item(
self,
namespace: Sequence[str],
/,
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
"""Store or update an item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
value: A dictionary containing the item's data.
index: Controls search indexing - None (use defaults), False (disable), or list of field paths to index.
Returns:
None
Example Usage:
client.store.put_item(
["documents", "user123"],
key="item456",
value={"title": "My Document", "content": "Hello World"}
)
"""
for label in namespace:
if "." in label:
raise ValueError(
f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')."
)
payload = {
"namespace": namespace,
"key": key,
"value": value,
"index": index,
}
self.http.put("/store/items", json=payload)
def get_item(self, namespace: Sequence[str], /, key: str) -> Item:
"""Retrieve a single item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
Returns:
Item: The retrieved item.
Example Usage:
item = client.store.get_item(
["documents", "user123"],
key="item456",
)
print(item)
----------------------------------------------------------------
{
'namespace': ['documents', 'user123'],
'key': 'item456',
'value': {'title': 'My Document', 'content': 'Hello World'},
'created_at': '2024-07-30T12:00:00Z',
'updated_at': '2024-07-30T12:00:00Z'
}
"""
for label in namespace:
if "." in label:
raise ValueError(
f"Invalid namespace label '{label}'. Namespace labels cannot contain periods ('.')."
)
return self.http.get(
"/store/items", params={"key": key, "namespace": ".".join(namespace)}
)
def delete_item(self, namespace: Sequence[str], /, key: str) -> None:
"""Delete an item.
Args:
namespace: A list of strings representing the namespace path.
key: The unique identifier for the item within the namespace.
Returns:
None
Example Usage:
client.store.delete_item(
["documents", "user123"],
key="item456",
)
"""
self.http.delete("/store/items", json={"key": key, "namespace": namespace})
def search_items(
self,
namespace_prefix: Sequence[str],
/,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
query: Optional[str] = None,
) -> SearchItemsResponse:
"""Search for items within a namespace prefix.
Args:
namespace_prefix: List of strings representing the namespace prefix.
filter: Optional dictionary of key-value pairs to filter results.
limit: Maximum number of items to return (default is 10).
offset: Number of items to skip before returning results (default is 0).
query: Optional query for natural language search.
Returns:
SearchItemsResponse: Response containing the items matching the search criteria.
Example Usage:
items = client.store.search_items(
["documents"],
filter={"author": "John Doe"},
limit=5,
offset=0
)
print(items)
----------------------------------------------------------------
{
"items": [
{
"namespace": ["documents", "user123"],
"key": "item789",
"value": {
"title": "Another Document",
"author": "John Doe"
},
"created_at": "2024-07-30T12:00:00Z",
"updated_at": "2024-07-30T12:00:00Z"
},
# ... additional items ...
]
}
"""
payload = {
"namespace_prefix": namespace_prefix,
"filter": filter,
"limit": limit,
"offset": offset,
"query": query,
}
return self.http.post("/store/items/search", json=_provided_vals(payload))
def list_namespaces(
self,
prefix: Optional[List[str]] = None,
suffix: Optional[List[str]] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> ListNamespaceResponse:
"""List namespaces with optional match conditions.
Args:
prefix: Optional list of strings representing the prefix to filter namespaces.
suffix: Optional list of strings representing the suffix to filter namespaces.
max_depth: Optional integer specifying the maximum depth of namespaces to return.
limit: Maximum number of namespaces to return (default is 100).
offset: Number of namespaces to skip before returning results (default is 0).
Returns:
ListNamespaceResponse: Response containing the namespaces matching the criteria.
Example Usage:
namespaces = client.store.list_namespaces(
prefix=["documents"],
max_depth=3,
limit=10,
offset=0
)
print(namespaces)
----------------------------------------------------------------
[
["documents", "user123", "reports"],
["documents", "user456", "invoices"],
...
]
"""
payload = {
"prefix": prefix,
"suffix": suffix,
"max_depth": max_depth,
"limit": limit,
"offset": offset,
}
return self.http.post("/store/namespaces", json=_provided_vals(payload))
def _provided_vals(d: dict):
return {k: v for k, v in d.items() if v is not None}
```
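Taken together, the synchronous classes above compose into a small end-to-end workflow. The sketch below is illustrative only: it assumes a LangGraph API server is reachable at the default `http://localhost:8123` and that a graph named `"agent"` is registered in `langgraph.json`; neither assumption comes from the SDK itself.
```py
# Illustrative sketch of the sync client; "agent" and the local URL are assumptions.
from langgraph_sdk import get_sync_client

client = get_sync_client(url="http://localhost:8123")

# Assistants are versioned configurations of a graph.
assistant = client.assistants.create(graph_id="agent", if_exists="do_nothing")

# Threads hold conversation state across runs.
thread = client.threads.create(metadata={"user_id": "123"})

# Stream a run; each chunk is a StreamPart(event, data) parsed from the SSE response.
for chunk in client.runs.stream(
    thread_id=thread["thread_id"],
    assistant_id=assistant["assistant_id"],
    input={"messages": [{"role": "user", "content": "hello!"}]},
    stream_mode="updates",
):
    print(chunk.event, chunk.data)
```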
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/__init__.py`:
```py
from langgraph_sdk.client import get_client, get_sync_client
try:
from importlib import metadata
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
__version__ = "unknown"
__all__ = ["get_client", "get_sync_client"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/sse.py`:
```py
"""Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec."""
from typing import AsyncIterator, Iterator, Optional, Union
import httpx
import orjson
from langgraph_sdk.schema import StreamPart
BytesLike = Union[bytes, bytearray, memoryview]
class BytesLineDecoder:
"""
Handles incrementally reading lines from text.
Has the same behaviour as the stdlib bytes splitlines,
but handling the input iteratively.
"""
def __init__(self) -> None:
self.buffer = bytearray()
self.trailing_cr: bool = False
def decode(self, text: bytes) -> list[BytesLike]:
# See https://docs.python.org/3/glossary.html#term-universal-newlines
NEWLINE_CHARS = b"\n\r"
# We always push a trailing `\r` into the next decode iteration.
if self.trailing_cr:
text = b"\r" + text
self.trailing_cr = False
if text.endswith(b"\r"):
self.trailing_cr = True
text = text[:-1]
if not text:
# NOTE: the edge case input of empty text doesn't occur in practice,
# because other httpx internals filter out this value
return [] # pragma: no cover
trailing_newline = text[-1] in NEWLINE_CHARS
lines = text.splitlines()
if len(lines) == 1 and not trailing_newline:
# No new lines, buffer the input and continue.
self.buffer.extend(lines[0])
return []
if self.buffer:
# Include any existing buffer in the first portion of the
# splitlines result.
self.buffer.extend(lines[0])
lines = [self.buffer] + lines[1:]
self.buffer = bytearray()
if not trailing_newline:
# If the last segment of splitlines is not newline terminated,
# then drop it from our output and start a new buffer.
self.buffer.extend(lines.pop())
return lines
def flush(self) -> list[BytesLike]:
if not self.buffer and not self.trailing_cr:
return []
lines = [self.buffer]
self.buffer = bytearray()
self.trailing_cr = False
return lines
class SSEDecoder:
def __init__(self) -> None:
self._event = ""
self._data = bytearray()
self._last_event_id = ""
self._retry: Optional[int] = None
def decode(self, line: bytes) -> Optional[StreamPart]:
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
if not line:
if (
not self._event
and not self._data
and not self._last_event_id
and self._retry is None
):
return None
sse = StreamPart(
event=self._event,
data=orjson.loads(self._data) if self._data else None,
)
# NOTE: as per the SSE spec, do not reset last_event_id.
self._event = ""
self._data = bytearray()
self._retry = None
return sse
if line.startswith(b":"):
return None
fieldname, _, value = line.partition(b":")
if value.startswith(b" "):
value = value[1:]
if fieldname == b"event":
self._event = value.decode()
elif fieldname == b"data":
self._data.extend(value)
elif fieldname == b"id":
if b"\0" in value:
pass
else:
self._last_event_id = value.decode()
elif fieldname == b"retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
else:
pass # Field is ignored.
return None
async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]:
decoder = BytesLineDecoder()
async for chunk in response.aiter_bytes():
for line in decoder.decode(chunk):
yield line
for line in decoder.flush():
yield line
def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]:
decoder = BytesLineDecoder()
for chunk in response.iter_bytes():
for line in decoder.decode(chunk):
yield line
for line in decoder.flush():
yield line
```
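To make the interplay between the two decoders concrete, here is a minimal sketch that pushes a hand-made SSE payload through `BytesLineDecoder` and `SSEDecoder`, mirroring the loop in `SyncHttpClient.stream`; the byte chunks are invented for the example.
```py
# Minimal sketch: decoding a hand-made SSE payload with the helpers above.
from langgraph_sdk.sse import BytesLineDecoder, SSEDecoder

# Hypothetical payload split across two "network" chunks, one of them mid-JSON.
chunks = [b'event: values\ndata: {"ok"', b': true}\n\n']

line_decoder = BytesLineDecoder()
sse_decoder = SSEDecoder()

def handle(line):
    part = sse_decoder.decode(bytes(line).rstrip(b"\n"))
    if part is not None:
        print(part)  # StreamPart(event='values', data={'ok': True})

for chunk in chunks:
    for line in line_decoder.decode(chunk):
        handle(line)
for line in line_decoder.flush():
    handle(line)
```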
`/Users/malcolm/dev/langchain-ai/langgraph/libs/sdk-py/langgraph_sdk/schema.py`:
```py
"""Data models for interacting with the LangGraph API."""
from datetime import datetime
from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, TypedDict, Union
Json = Optional[dict[str, Any]]
"""Represents a JSON-like structure, which can be None or a dictionary with string keys and any values."""
RunStatus = Literal["pending", "error", "success", "timeout", "interrupted"]
"""
Represents the status of a run:
- "pending": The run is waiting to start.
- "error": The run encountered an error and stopped.
- "success": The run completed successfully.
- "timeout": The run exceeded its time limit.
- "interrupted": The run was manually stopped or interrupted.
"""
ThreadStatus = Literal["idle", "busy", "interrupted", "error"]
"""
Represents the status of a thread:
- "idle": The thread is not currently processing any task.
- "busy": The thread is actively processing a task.
- "interrupted": The thread's execution was interrupted.
- "error": An exception occurred during task processing.
"""
StreamMode = Literal[
"values", "messages", "updates", "events", "debug", "custom", "messages-tuple"
]
"""
Defines the mode of streaming:
- "values": Stream only the values.
- "messages": Stream complete messages.
- "updates": Stream updates to the state.
- "events": Stream events occurring during execution.
- "debug": Stream detailed debug information.
- "custom": Stream custom events.
"""
DisconnectMode = Literal["cancel", "continue"]
"""
Specifies behavior on disconnection:
- "cancel": Cancel the operation on disconnection.
- "continue": Continue the operation even if disconnected.
"""
MultitaskStrategy = Literal["reject", "interrupt", "rollback", "enqueue"]
"""
Defines how to handle multiple tasks:
- "reject": Reject new tasks when busy.
- "interrupt": Interrupt current task for new ones.
- "rollback": Roll back current task and start new one.
- "enqueue": Queue new tasks for later execution.
"""
OnConflictBehavior = Literal["raise", "do_nothing"]
"""
Specifies behavior on conflict:
- "raise": Raise an exception when a conflict occurs.
- "do_nothing": Ignore conflicts and proceed.
"""
OnCompletionBehavior = Literal["delete", "keep"]
"""
Defines action after completion:
- "delete": Delete resources after completion.
- "keep": Retain resources after completion.
"""
All = Literal["*"]
"""Represents a wildcard or 'all' selector."""
IfNotExists = Literal["create", "reject"]
"""
Specifies behavior if the thread doesn't exist:
- "create": Create a new thread if it doesn't exist.
- "reject": Reject the operation if the thread doesn't exist.
"""
CancelAction = Literal["interrupt", "rollback"]
"""
Action to take when cancelling the run.
- "interrupt": Simply cancel the run.
- "rollback": Cancel the run. Then delete the run and associated checkpoints.
"""
class Config(TypedDict, total=False):
"""Configuration options for a call."""
tags: list[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 25.
"""
configurable: dict[str, Any]
"""
Runtime values for attributes previously made configurable on this Runnable,
or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
Check .output_schema() for a description of the attributes that have been made
configurable.
"""
class Checkpoint(TypedDict):
"""Represents a checkpoint in the execution process."""
thread_id: str
"""Unique identifier for the thread associated with this checkpoint."""
checkpoint_ns: str
"""Namespace for the checkpoint, used for organization and retrieval."""
checkpoint_id: Optional[str]
"""Optional unique identifier for the checkpoint itself."""
checkpoint_map: Optional[dict[str, Any]]
"""Optional dictionary containing checkpoint-specific data."""
class GraphSchema(TypedDict):
"""Defines the structure and properties of a graph."""
graph_id: str
"""The ID of the graph."""
input_schema: Optional[dict]
"""The schema for the graph input.
Missing if unable to generate JSON schema from graph."""
output_schema: Optional[dict]
"""The schema for the graph output.
Missing if unable to generate JSON schema from graph."""
state_schema: Optional[dict]
"""The schema for the graph state.
Missing if unable to generate JSON schema from graph."""
config_schema: Optional[dict]
"""The schema for the graph config.
Missing if unable to generate JSON schema from graph."""
Subgraphs = dict[str, GraphSchema]
class AssistantBase(TypedDict):
"""Base model for an assistant."""
assistant_id: str
"""The ID of the assistant."""
graph_id: str
"""The ID of the graph."""
config: Config
"""The assistant config."""
created_at: datetime
"""The time the assistant was created."""
metadata: Json
"""The assistant metadata."""
version: int
"""The version of the assistant"""
class AssistantVersion(AssistantBase):
"""Represents a specific version of an assistant."""
pass
class Assistant(AssistantBase):
"""Represents an assistant with additional properties."""
updated_at: datetime
"""The last time the assistant was updated."""
name: str
"""The name of the assistant"""
class Interrupt(TypedDict, total=False):
"""Represents an interruption in the execution flow."""
value: Any
"""The value associated with the interrupt."""
when: Literal["during"]
"""When the interrupt occurred."""
resumable: bool
"""Whether the interrupt can be resumed."""
ns: Optional[list[str]]
"""Optional namespace for the interrupt."""
class Thread(TypedDict):
"""Represents a conversation thread."""
thread_id: str
"""The ID of the thread."""
created_at: datetime
"""The time the thread was created."""
updated_at: datetime
"""The last time the thread was updated."""
metadata: Json
"""The thread metadata."""
status: ThreadStatus
"""The status of the thread, one of 'idle', 'busy', 'interrupted'."""
values: Json
"""The current state of the thread."""
interrupts: Dict[str, list[Interrupt]]
"""Interrupts which were thrown in this thread"""
class ThreadTask(TypedDict):
"""Represents a task within a thread."""
id: str
name: str
error: Optional[str]
interrupts: list[Interrupt]
checkpoint: Optional[Checkpoint]
state: Optional["ThreadState"]
result: Optional[dict[str, Any]]
class ThreadState(TypedDict):
"""Represents the state of a thread."""
values: Union[list[dict], dict[str, Any]]
"""The state values."""
next: Sequence[str]
"""The next nodes to execute. If empty, the thread is done until new input is
received."""
checkpoint: Checkpoint
"""The ID of the checkpoint."""
metadata: Json
"""Metadata for this state"""
created_at: Optional[str]
"""Timestamp of state creation"""
parent_checkpoint: Optional[Checkpoint]
"""The ID of the parent checkpoint. If missing, this is the root checkpoint."""
tasks: Sequence[ThreadTask]
"""Tasks to execute in this step. If already attempted, may contain an error."""
class ThreadUpdateStateResponse(TypedDict):
"""Represents the response from updating a thread's state."""
checkpoint: Checkpoint
"""Checkpoint of the latest state."""
class Run(TypedDict):
"""Represents a single execution run."""
run_id: str
"""The ID of the run."""
thread_id: str
"""The ID of the thread."""
assistant_id: str
"""The assistant that was used for this run."""
created_at: datetime
"""The time the run was created."""
updated_at: datetime
"""The last time the run was updated."""
status: RunStatus
"""The status of the run. One of 'pending', 'running', "error", 'success', "timeout", "interrupted"."""
metadata: Json
"""The run metadata."""
multitask_strategy: MultitaskStrategy
"""Strategy to handle concurrent runs on the same thread."""
class Cron(TypedDict):
"""Represents a scheduled task."""
cron_id: str
"""The ID of the cron."""
thread_id: Optional[str]
"""The ID of the thread."""
end_time: Optional[datetime]
"""The end date to stop running the cron."""
schedule: str
"""The schedule to run, cron format."""
created_at: datetime
"""The time the cron was created."""
updated_at: datetime
"""The last time the cron was updated."""
payload: dict
"""The run payload to use for creating new run."""
class RunCreate(TypedDict):
"""Defines the parameters for initiating a background run."""
thread_id: Optional[str]
"""The identifier of the thread to run. If not provided, the run is stateless."""
assistant_id: str
"""The identifier of the assistant to use for this run."""
input: Optional[dict]
"""Initial input data for the run."""
metadata: Optional[dict]
"""Additional metadata to associate with the run."""
config: Optional[Config]
"""Configuration options for the run."""
checkpoint_id: Optional[str]
"""The identifier of a checkpoint to resume from."""
interrupt_before: Optional[list[str]]
"""List of node names to interrupt execution before."""
interrupt_after: Optional[list[str]]
"""List of node names to interrupt execution after."""
webhook: Optional[str]
"""URL to send webhook notifications about the run's progress."""
multitask_strategy: Optional[MultitaskStrategy]
"""Strategy for handling concurrent runs on the same thread."""
class Item(TypedDict):
"""Represents a single document or data entry in the graph's Store.
Items are used to store cross-thread memories.
"""
namespace: list[str]
"""The namespace of the item. A namespace is analogous to a document's directory."""
key: str
"""The unique identifier of the item within its namespace.
In general, keys needn't be globally unique.
"""
value: dict[str, Any]
"""The value stored in the item. This is the document itself."""
created_at: datetime
"""The timestamp when the item was created."""
updated_at: datetime
"""The timestamp when the item was last updated."""
class ListNamespaceResponse(TypedDict):
"""Response structure for listing namespaces."""
namespaces: list[list[str]]
"""A list of namespace paths, where each path is a list of strings."""
class SearchItem(Item, total=False):
"""Item with an optional relevance score from search operations.
Attributes:
score (Optional[float]): Relevance/similarity score. Included when
searching a compatible store with a natural language query.
"""
score: Optional[float]
class SearchItemsResponse(TypedDict):
"""Response structure for searching items."""
items: list[SearchItem]
"""A list of items matching the search criteria."""
class StreamPart(NamedTuple):
"""Represents a part of a stream response."""
event: str
"""The type of event for this stream part."""
data: dict
"""The data payload associated with the event."""
class Send(TypedDict):
node: str
input: Optional[dict[str, Any]]
class Command(TypedDict, total=False):
goto: Union[Send, str, Sequence[Union[Send, str]]]
update: dict[str, Any]
resume: Any
```
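The models above are plain TypedDicts, so request payloads are just dictionaries. A minimal, illustrative sketch (the IDs, node name and message content are placeholders, and it assumes the `langgraph_sdk` package is importable):
```py
from langgraph_sdk.schema import Command, RunCreate, Send

# Build a stateless run request as an ordinary dict.
run: RunCreate = {
    "thread_id": None,  # no thread -> stateless run
    "assistant_id": "agent",
    "input": {"messages": [{"role": "user", "content": "hi"}]},
    "metadata": {"source": "example"},
    "config": {"recursion_limit": 10},
    "checkpoint_id": None,
    "interrupt_before": None,
    "interrupt_after": None,
    "webhook": None,
    "multitask_strategy": "enqueue",
}

# Resume an interrupted graph and route to a (hypothetical) "tools" node.
cmd: Command = {"resume": "approved", "goto": Send(node="tools", input={"query": "..."})}
print(run["assistant_id"], cmd["resume"])
```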
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/config.py`:
```py
import json
import os
import pathlib
import textwrap
from typing import NamedTuple, Optional, TypedDict, Union
import click
MIN_NODE_VERSION = "20"
MIN_PYTHON_VERSION = "3.11"
class IndexConfig(TypedDict, total=False):
"""Configuration for indexing documents for semantic search in the store."""
dims: int
"""Number of dimensions in the embedding vectors.
Common embedding models have the following dimensions:
- openai:text-embedding-3-large: 3072
- openai:text-embedding-3-small: 1536
- openai:text-embedding-ada-002: 1536
- cohere:embed-english-v3.0: 1024
- cohere:embed-english-light-v3.0: 384
- cohere:embed-multilingual-v3.0: 1024
- cohere:embed-multilingual-light-v3.0: 384
"""
embed: str
"""Optional model (string) to generate embeddings from text or path to model or function.
Examples:
- "openai:text-embedding-3-large"
- "cohere:embed-multilingual-v3.0"
- "src/app.py:embeddings
"""
fields: Optional[list[str]]
"""Fields to extract text from for embedding generation.
Defaults to the root ["$"], which embeds the json object as a whole.
"""
class StoreConfig(TypedDict, total=False):
embed: Optional[IndexConfig]
"""Configuration for vector embeddings in store."""
class Config(TypedDict, total=False):
python_version: str
node_version: Optional[str]
pip_config_file: Optional[str]
dockerfile_lines: list[str]
dependencies: list[str]
graphs: dict[str, str]
env: Union[dict[str, str], str]
store: Optional[StoreConfig]
def _parse_version(version_str: str) -> tuple[int, int]:
"""Parse a version string into a tuple of (major, minor)."""
try:
major, minor = map(int, version_str.split("."))
return (major, minor)
except ValueError:
raise click.UsageError(f"Invalid version format: {version_str}") from None
def _parse_node_version(version_str: str) -> int:
"""Parse a Node.js version string into a major version number."""
try:
if "." in version_str:
raise ValueError("Node.js version must be major version only")
return int(version_str)
except ValueError:
raise click.UsageError(
f"Invalid Node.js version format: {version_str}. "
"Use major version only (e.g., '20')."
) from None
def validate_config(config: Config) -> Config:
config = (
{
"node_version": config.get("node_version"),
"dockerfile_lines": config.get("dockerfile_lines", []),
"graphs": config.get("graphs", {}),
"env": config.get("env", {}),
"store": config.get("store"),
}
if config.get("node_version")
else {
"python_version": config.get("python_version", "3.11"),
"pip_config_file": config.get("pip_config_file"),
"dockerfile_lines": config.get("dockerfile_lines", []),
"dependencies": config.get("dependencies", []),
"graphs": config.get("graphs", {}),
"env": config.get("env", {}),
"store": config.get("store"),
}
)
if config.get("node_version"):
node_version = config["node_version"]
try:
major = _parse_node_version(node_version)
min_major = _parse_node_version(MIN_NODE_VERSION)
if major < min_major:
raise click.UsageError(
f"Node.js version {node_version} is not supported. "
f"Minimum required version is {MIN_NODE_VERSION}."
)
except ValueError as e:
raise click.UsageError(str(e)) from None
if config.get("python_version"):
pyversion = config["python_version"]
if not pyversion.count(".") == 1 or not all(
part.isdigit() for part in pyversion.split(".")
):
raise click.UsageError(
f"Invalid Python version format: {pyversion}. "
"Use 'major.minor' format (e.g., '3.11'). "
"Patch version cannot be specified."
)
if _parse_version(pyversion) < _parse_version(MIN_PYTHON_VERSION):
raise click.UsageError(
f"Python version {pyversion} is not supported. "
f"Minimum required version is {MIN_PYTHON_VERSION}."
)
if not config["dependencies"]:
raise click.UsageError(
"No dependencies found in config. "
"Add at least one dependency to 'dependencies' list."
)
if not config["graphs"]:
raise click.UsageError(
"No graphs found in config. "
"Add at least one graph to 'graphs' dictionary."
)
return config
def validate_config_file(config_path: pathlib.Path) -> Config:
with open(config_path) as f:
config = json.load(f)
validated = validate_config(config)
# Ensure the package.json doesn't require an
# incompatible Node.js version
if validated.get("node_version"):
package_json_path = config_path.parent / "package.json"
if package_json_path.is_file():
try:
with open(package_json_path) as f:
package_json = json.load(f)
if "engines" in package_json:
engines = package_json["engines"]
if any(engine != "node" for engine in engines.keys()):
raise click.UsageError(
"Only 'node' engine is supported in package.json engines."
f" Got engines: {list(engines.keys())}"
)
if engines:
node_version = engines["node"]
try:
major = _parse_node_version(node_version)
min_major = _parse_node_version(MIN_NODE_VERSION)
if major < min_major:
raise click.UsageError(
f"Node.js version in package.json engines must be >= {MIN_NODE_VERSION} "
f"(major version only), got '{node_version}'. Minor/patch versions "
"(like '20.x.y') are not supported to prevent deployment issues "
"when new Node.js versions are released."
)
except ValueError as e:
raise click.UsageError(str(e)) from None
except json.JSONDecodeError:
raise click.UsageError(
"Invalid package.json found in langgraph "
f"config directory {package_json_path}: file is not valid JSON"
) from None
return validated
class LocalDeps(NamedTuple):
pip_reqs: list[tuple[pathlib.Path, str]]
real_pkgs: dict[pathlib.Path, str]
faux_pkgs: dict[pathlib.Path, tuple[str, str]]
# if . is in dependencies, use it as working_dir
working_dir: Optional[str] = None
def _assemble_local_deps(config_path: pathlib.Path, config: Config) -> LocalDeps:
# ensure reserved package names are not used
reserved = {
"src",
"langgraph-api",
"langgraph_api",
"langgraph",
"langchain-core",
"langchain_core",
"pydantic",
"orjson",
"fastapi",
"uvicorn",
"psycopg",
"httpx",
"langsmith",
}
def check_reserved(name: str, ref: str):
if name in reserved:
raise ValueError(
f"Package name '{name}' used in local dep '{ref}' is reserved. "
"Rename the directory."
)
reserved.add(name)
pip_reqs = []
real_pkgs = {}
faux_pkgs = {}
working_dir = None
for local_dep in config["dependencies"]:
if not local_dep.startswith("."):
continue
resolved = config_path.parent / local_dep
# validate local dependency
if not resolved.exists():
raise FileNotFoundError(f"Could not find local dependency: {resolved}")
elif not resolved.is_dir():
raise NotADirectoryError(
f"Local dependency must be a directory: {resolved}"
)
elif not resolved.is_relative_to(config_path.parent):
raise ValueError(
f"Local dependency '{resolved}' must be a subdirectory of '{config_path.parent}'"
)
# if it's installable, add it to local_pkgs
# otherwise, add it to faux_pkgs, and create a pyproject.toml
files = os.listdir(resolved)
if "pyproject.toml" in files:
real_pkgs[resolved] = local_dep
if local_dep == ".":
working_dir = f"/deps/{resolved.name}"
elif "setup.py" in files:
real_pkgs[resolved] = local_dep
if local_dep == ".":
working_dir = f"/deps/{resolved.name}"
else:
if any(file == "__init__.py" for file in files):
# flat layout
if "-" in resolved.name:
raise ValueError(
f"Package name '{resolved.name}' contains a hyphen. "
"Rename the directory to use it as flat-layout package."
)
check_reserved(resolved.name, local_dep)
container_path = f"/deps/__outer_{resolved.name}/{resolved.name}"
else:
# src layout
container_path = f"/deps/__outer_{resolved.name}/src"
for file in files:
rfile = resolved / file
if (
rfile.is_dir()
and file != "__pycache__"
and not file.startswith(".")
):
try:
for subfile in os.listdir(rfile):
if subfile.endswith(".py"):
check_reserved(file, local_dep)
break
except PermissionError:
pass
faux_pkgs[resolved] = (local_dep, container_path)
if local_dep == ".":
working_dir = container_path
if "requirements.txt" in files:
rfile = resolved / "requirements.txt"
pip_reqs.append(
(
rfile.relative_to(config_path.parent),
f"{container_path}/requirements.txt",
)
)
return LocalDeps(pip_reqs, real_pkgs, faux_pkgs, working_dir)
def _update_graph_paths(
config_path: pathlib.Path, config: Config, local_deps: LocalDeps
) -> None:
for graph_id, import_str in config["graphs"].items():
module_str, _, attr_str = import_str.partition(":")
if not module_str or not attr_str:
message = (
'Import string "{import_str}" must be in format "<module>:<attribute>".'
)
raise ValueError(message.format(import_str=import_str))
if "/" in module_str:
resolved = config_path.parent / module_str
if not resolved.exists():
raise FileNotFoundError(f"Could not find local module: {resolved}")
elif not resolved.is_file():
raise IsADirectoryError(f"Local module must be a file: {resolved}")
else:
for path in local_deps.real_pkgs:
if resolved.is_relative_to(path):
module_str = f"/deps/{path.name}/{resolved.relative_to(path)}"
break
else:
for faux_pkg, (_, destpath) in local_deps.faux_pkgs.items():
if resolved.is_relative_to(faux_pkg):
module_str = f"{destpath}/{resolved.relative_to(faux_pkg)}"
break
else:
raise ValueError(
f"Module '{import_str}' not found in 'dependencies' list. "
"Add its containing package to 'dependencies' list."
)
# update the config
config["graphs"][graph_id] = f"{module_str}:{attr_str}"
def python_config_to_docker(config_path: pathlib.Path, config: Config, base_image: str):
# configure pip
pip_install = (
"PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt"
)
if config.get("pip_config_file"):
pip_install = f"PIP_CONFIG_FILE=/pipconfig.txt {pip_install}"
pip_config_file_str = (
f"ADD {config['pip_config_file']} /pipconfig.txt"
if config.get("pip_config_file")
else ""
)
# collect dependencies
pypi_deps = [dep for dep in config["dependencies"] if not dep.startswith(".")]
local_deps = _assemble_local_deps(config_path, config)
# rewrite graph paths
_update_graph_paths(config_path, config, local_deps)
pip_pkgs_str = f"RUN {pip_install} {' '.join(pypi_deps)}" if pypi_deps else ""
if local_deps.pip_reqs:
pip_reqs_str = os.linesep.join(
f"ADD {reqpath} {destpath}" for reqpath, destpath in local_deps.pip_reqs
)
pip_reqs_str += f'{os.linesep}RUN {pip_install} {" ".join("-r " + r for _,r in local_deps.pip_reqs)}'
else:
pip_reqs_str = ""
# https://setuptools.pypa.io/en/latest/userguide/datafiles.html#package-data
# https://til.simonwillison.net/python/pyproject
faux_pkgs_str = f"{os.linesep}{os.linesep}".join(
f"""ADD {relpath} {destpath}
RUN set -ex && \\
for line in '[project]' \\
'name = "{fullpath.name}"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_{fullpath.name}/pyproject.toml; \\
done"""
for fullpath, (relpath, destpath) in local_deps.faux_pkgs.items()
)
local_pkgs_str = os.linesep.join(
f"ADD {relpath} /deps/{fullpath.name}"
for fullpath, relpath in local_deps.real_pkgs.items()
)
installs = f"{os.linesep}{os.linesep}".join(
filter(
None,
[
pip_config_file_str,
pip_pkgs_str,
pip_reqs_str,
local_pkgs_str,
faux_pkgs_str,
],
)
)
store_config = config.get("store")
env_additional_config = (
""
if not store_config
else f"""
ENV LANGGRAPH_STORE='{json.dumps(store_config)}'
"""
)
return f"""FROM {base_image}:{config['python_version']}
{os.linesep.join(config["dockerfile_lines"])}
{installs}
RUN {pip_install} -e /deps/*
{env_additional_config}
ENV LANGSERVE_GRAPHS='{json.dumps(config["graphs"])}'
{f"WORKDIR {local_deps.working_dir}" if local_deps.working_dir else ""}"""
def node_config_to_docker(config_path: pathlib.Path, config: Config, base_image: str):
faux_path = f"/deps/{config_path.parent.name}"
def test_file(file_name):
full_path = config_path.parent / file_name
try:
return full_path.is_file()
except OSError:
return False
npm, yarn, pnpm = [
test_file("package-lock.json"),
test_file("yarn.lock"),
test_file("pnpm-lock.yaml"),
]
if yarn:
install_cmd = "yarn install --frozen-lockfile"
elif pnpm:
install_cmd = "pnpm i --frozen-lockfile"
elif npm:
install_cmd = "npm ci"
else:
install_cmd = "npm i"
store_config = config.get("store")
env_additional_config = (
""
if not store_config
else f"""
ENV LANGGRAPH_STORE='{json.dumps(store_config)}'
"""
)
return f"""FROM {base_image}:{config['node_version']}
{os.linesep.join(config["dockerfile_lines"])}
ADD . {faux_path}
RUN cd {faux_path} && {install_cmd}
{env_additional_config}
ENV LANGSERVE_GRAPHS='{json.dumps(config["graphs"])}'
WORKDIR {faux_path}
RUN (test ! -f /api/langgraph_api/js/build.mts && echo "Prebuild script not found, skipping") || tsx /api/langgraph_api/js/build.mts"""
def config_to_docker(config_path: pathlib.Path, config: Config, base_image: str):
if config.get("node_version"):
return node_config_to_docker(config_path, config, base_image)
return python_config_to_docker(config_path, config, base_image)
def config_to_compose(
config_path: pathlib.Path,
config: Config,
base_image: str,
watch: bool = False,
):
env_vars = config["env"].items() if isinstance(config["env"], dict) else {}
env_vars_str = "\n".join(f' {k}: "{v}"' for k, v in env_vars)
env_file_str = (
f"env_file: {config['env']}" if isinstance(config["env"], str) else ""
)
if watch:
watch_paths = [config_path.name] + [
dep for dep in config["dependencies"] if dep.startswith(".")
]
watch_actions = "\n".join(
f"""- path: {path}
action: rebuild"""
for path in watch_paths
)
watch_str = f"""
develop:
watch:
{textwrap.indent(watch_actions, " ")}
"""
else:
watch_str = ""
return f"""
{textwrap.indent(env_vars_str, " ")}
{env_file_str}
pull_policy: build
build:
context: .
dockerfile_inline: |
{textwrap.indent(config_to_docker(config_path, config, base_image), " ")}
{watch_str}
"""
```
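A small sketch of calling `validate_config` directly on an in-memory dict, mirroring what `validate_config_file` does after loading `langgraph.json`; the graph path is a placeholder, and only the shape of the config is checked at this stage:
```py
from langgraph_cli.config import validate_config

config = validate_config(
    {
        "dependencies": ["."],
        "graphs": {"agent": "./my_app/graph.py:graph"},  # placeholder import string
        "env": ".env",
    }
)
print(config["python_version"])  # "3.11" is filled in by default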
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/version.py`:
```py
"""Main entrypoint into package."""
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/util.py`:
```py
def clean_empty_lines(input_str: str):
return "\n".join(filter(None, input_str.splitlines()))
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/exec.py`:
```py
import asyncio
import signal
import sys
from contextlib import contextmanager
from typing import Callable, Optional, cast
import click.exceptions
@contextmanager
def Runner():
if hasattr(asyncio, "Runner"):
with asyncio.Runner() as runner:
yield runner
else:
class _Runner:
def __enter__(self):
return self
def __exit__(self, *args):
pass
def run(self, coro):
return asyncio.run(coro)
yield _Runner()
async def subp_exec(
cmd: str,
*args: str,
input: Optional[str] = None,
wait: Optional[float] = None,
verbose: bool = False,
collect: bool = False,
on_stdout: Optional[Callable[[str], Optional[bool]]] = None,
) -> tuple[Optional[str], Optional[str]]:
if verbose:
cmd_str = f"+ {cmd} {' '.join(map(str, args))}"
if input:
print(cmd_str, " <\n", "\n".join(filter(None, input.splitlines())), sep="")
else:
print(cmd_str)
if wait:
await asyncio.sleep(wait)
try:
proc = await asyncio.create_subprocess_exec(
cmd,
*args,
stdin=asyncio.subprocess.PIPE if input else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
def signal_handler():
# make sure the process is still running, then terminate it
if proc.returncode is None:
proc.terminate()
original_sigint_handler = signal.getsignal(signal.SIGINT)
if sys.platform == "win32":
def handle_windows_signal(signum, frame):
signal_handler()
original_sigint_handler(signum, frame)
signal.signal(signal.SIGINT, handle_windows_signal)
# NOTE: we're not adding a handler for SIGTERM since it's ignored on Windows
else:
loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
empty_fut: asyncio.Future = asyncio.Future()
empty_fut.set_result(None)
stdout, stderr, _ = await asyncio.gather(
monitor_stream(
cast(asyncio.StreamReader, proc.stdout),
collect=True,
display=verbose,
on_line=on_stdout,
),
monitor_stream(
cast(asyncio.StreamReader, proc.stderr),
collect=True,
display=verbose,
),
proc._feed_stdin(input.encode()) if input else empty_fut, # type: ignore[attr-defined]
)
returncode = await proc.wait()
if (
returncode is not None
and returncode != 0 # success
and returncode != 130 # user interrupt
):
sys.stdout.write(stdout.decode() if stdout else "")
sys.stderr.write(stderr.decode() if stderr else "")
raise click.exceptions.Exit(returncode)
if collect:
return (
stdout.decode() if stdout else None,
stderr.decode() if stderr else None,
)
else:
return None, None
finally:
try:
if proc.returncode is None:
try:
proc.terminate()
except (ProcessLookupError, KeyboardInterrupt):
pass
if sys.platform == "win32":
signal.signal(signal.SIGINT, original_sigint_handler)
else:
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)
except UnboundLocalError:
pass
async def monitor_stream(
stream: asyncio.StreamReader,
collect: bool = False,
display: bool = False,
on_line: Optional[Callable[[str], Optional[bool]]] = None,
) -> Optional[bytearray]:
if collect:
ba = bytearray()
def handle(line: bytes, overrun: bool):
nonlocal on_line
nonlocal display
if display:
sys.stdout.buffer.write(line)
if overrun:
return
if collect:
ba.extend(line)
if on_line:
if on_line(line.decode()):
on_line = None
display = True
"""Adapted from asyncio.StreamReader.readline() to handle LimitOverrunError."""
sep = b"\n"
seplen = len(sep)
while True:
try:
line = await stream.readuntil(sep)
overrun = False
except asyncio.IncompleteReadError as e:
line = e.partial
overrun = False
except asyncio.LimitOverrunError as e:
if stream._buffer.startswith(sep, e.consumed):
line = stream._buffer[: e.consumed + seplen]
else:
line = stream._buffer.clear()
overrun = True
stream._maybe_resume_transport()
await asyncio.to_thread(handle, line, overrun)
if line == b"":
break
if collect:
return ba
else:
return None
```
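A minimal sketch of driving `subp_exec` through the `Runner()` shim; `echo` is only a stand-in command, and the example assumes a POSIX environment where `langgraph_cli` is importable:
```py
from langgraph_cli.exec import Runner, subp_exec

with Runner() as runner:
    # collect=True returns the decoded stdout/stderr instead of discarding them
    stdout, stderr = runner.run(subp_exec("echo", "hello", collect=True))
print(stdout)  # "hello\n"
```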
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/constants.py`:
```py
DEFAULT_CONFIG = "langgraph.json"
DEFAULT_PORT = 8123
# analytics
SUPABASE_PUBLIC_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imt6cmxwcG9qaW5wY3l5YWlweG5iIiwicm9sZSI6ImFub24iLCJpYXQiOjE3MTkyNTc1NzksImV4cCI6MjAzNDgzMzU3OX0.kkVOlLz3BxemA5nP-vat3K4qRtrDuO4SwZSR_htcX9c"
SUPABASE_URL = "https://kzrlppojinpcyyaipxnb.supabase.co"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/docker.py`:
```py
import json
import pathlib
import shutil
from typing import Literal, NamedTuple, Optional
import click.exceptions
from langgraph_cli.exec import subp_exec
ROOT = pathlib.Path(__file__).parent.resolve()
DEFAULT_POSTGRES_URI = (
"postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable"
)
class Version(NamedTuple):
major: int
minor: int
patch: int
DockerComposeType = Literal["plugin", "standalone"]
class DockerCapabilities(NamedTuple):
version_docker: Version
version_compose: Version
healthcheck_start_interval: bool
compose_type: DockerComposeType = "plugin"
def _parse_version(version: str) -> Version:
parts = version.split(".", 2)
if len(parts) == 1:
major = parts[0]
minor = "0"
patch = "0"
elif len(parts) == 2:
major, minor = parts
patch = "0"
else:
major, minor, patch = parts
return Version(int(major.lstrip("v")), int(minor), int(patch.split("-")[0]))
def check_capabilities(runner) -> DockerCapabilities:
# check docker available
if shutil.which("docker") is None:
raise click.UsageError("Docker not installed") from None
try:
stdout, _ = runner.run(subp_exec("docker", "info", "-f", "json", collect=True))
info = json.loads(stdout)
except (click.exceptions.Exit, json.JSONDecodeError):
raise click.UsageError("Docker not installed or not running") from None
if not info["ServerVersion"]:
raise click.UsageError("Docker not running") from None
compose_type: DockerComposeType
try:
compose = next(
p for p in info["ClientInfo"]["Plugins"] if p["Name"] == "compose"
)
compose_version_str = compose["Version"]
compose_type = "plugin"
except (KeyError, StopIteration):
if shutil.which("docker-compose") is None:
raise click.UsageError("Docker Compose not installed") from None
compose_version_str, _ = runner.run(
subp_exec("docker-compose", "--version", "--short", collect=True)
)
compose_type = "standalone"
# parse versions
docker_version = _parse_version(info["ServerVersion"])
compose_version = _parse_version(compose_version_str)
# check capabilities
return DockerCapabilities(
version_docker=docker_version,
version_compose=compose_version,
healthcheck_start_interval=docker_version >= Version(25, 0, 0),
compose_type=compose_type,
)
def debugger_compose(
*, port: Optional[int] = None, base_url: Optional[str] = None
) -> dict:
if port is None:
return ""
config = {
"langgraph-debugger": {
"image": "langchain/langgraph-debugger",
"restart": "on-failure",
"depends_on": {
"langgraph-postgres": {"condition": "service_healthy"},
},
"ports": [f'"{port}:3968"'],
}
}
if base_url:
config["langgraph-debugger"]["environment"] = {
"VITE_STUDIO_LOCAL_GRAPH_URL": base_url
}
return config
# Function to convert dictionary to YAML
def dict_to_yaml(d: dict, *, indent: int = 0) -> str:
"""Convert a dictionary to a YAML string."""
yaml_str = ""
for idx, (key, value) in enumerate(d.items()):
# Format things in a visually appealing way
# Use an extra newline for top-level keys only
if idx >= 1 and indent < 2:
yaml_str += "\n"
space = " " * indent
if isinstance(value, dict):
yaml_str += f"{space}{key}:\n" + dict_to_yaml(value, indent=indent + 1)
elif isinstance(value, list):
yaml_str += f"{space}{key}:\n"
for item in value:
yaml_str += f"{space} - {item}\n"
else:
yaml_str += f"{space}{key}: {value}\n"
return yaml_str
def compose_as_dict(
capabilities: DockerCapabilities,
*,
port: int,
debugger_port: Optional[int] = None,
debugger_base_url: Optional[str] = None,
# postgres://user:password@host:port/database?option=value
postgres_uri: Optional[str] = None,
) -> dict:
"""Create a docker compose file as a dictionary in YML style."""
if postgres_uri is None:
include_db = True
postgres_uri = DEFAULT_POSTGRES_URI
else:
include_db = False
# The services below are defined in a non-intuitive order to match
# the existing unit tests for this function.
# Re-ordering is fine, but it requires updating the unit tests, so it should
# be done with caution.
# Define the Redis service first as per the test order
services = {
"langgraph-redis": {
"image": "redis:6",
"healthcheck": {
"test": "redis-cli ping",
"interval": "5s",
"timeout": "1s",
"retries": 5,
},
}
}
# Add Postgres service before langgraph-api if it is needed
if include_db:
services["langgraph-postgres"] = {
"image": "postgres:16",
"ports": ['"5433:5432"'],
"environment": {
"POSTGRES_DB": "postgres",
"POSTGRES_USER": "postgres",
"POSTGRES_PASSWORD": "postgres",
},
"volumes": ["langgraph-data:/var/lib/postgresql/data"],
"healthcheck": {
"test": "pg_isready -U postgres",
"start_period": "10s",
"timeout": "1s",
"retries": 5,
},
}
if capabilities.healthcheck_start_interval:
services["langgraph-postgres"]["healthcheck"]["interval"] = "60s"
services["langgraph-postgres"]["healthcheck"]["start_interval"] = "1s"
else:
services["langgraph-postgres"]["healthcheck"]["interval"] = "5s"
# Add optional debugger service if debugger_port is specified
if debugger_port:
services["langgraph-debugger"] = debugger_compose(
port=debugger_port, base_url=debugger_base_url
)["langgraph-debugger"]
# Add langgraph-api service
services["langgraph-api"] = {
"ports": [f'"{port}:8000"'],
"depends_on": {
"langgraph-redis": {"condition": "service_healthy"},
},
"environment": {
"REDIS_URI": "redis://langgraph-redis:6379",
"POSTGRES_URI": postgres_uri,
},
}
# If Postgres is included, add it to the dependencies of langgraph-api
if include_db:
services["langgraph-api"]["depends_on"]["langgraph-postgres"] = {
"condition": "service_healthy"
}
# Additional healthcheck for langgraph-api if required
if capabilities.healthcheck_start_interval:
services["langgraph-api"]["healthcheck"] = {
"test": "python /api/healthcheck.py",
"interval": "60s",
"start_interval": "1s",
"start_period": "10s",
}
# Final compose dictionary with volumes included if needed
compose_dict = {}
if include_db:
compose_dict["volumes"] = {"langgraph-data": {"driver": "local"}}
compose_dict["services"] = services
return compose_dict
def compose(
capabilities: DockerCapabilities,
*,
port: int,
debugger_port: Optional[int] = None,
debugger_base_url: Optional[str] = None,
# postgres://user:password@host:port/database?option=value
postgres_uri: Optional[str] = None,
) -> str:
"""Create a docker compose file as a string."""
compose_content = compose_as_dict(
capabilities,
port=port,
debugger_port=debugger_port,
debugger_base_url=debugger_base_url,
postgres_uri=postgres_uri,
)
compose_str = dict_to_yaml(compose_content)
return compose_str
```
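A sketch of rendering a Compose file without probing the local Docker installation; the `DockerCapabilities` values below are assumed, not detected:
```py
from langgraph_cli.docker import DockerCapabilities, Version, compose

caps = DockerCapabilities(
    version_docker=Version(27, 0, 0),
    version_compose=Version(2, 27, 0),
    healthcheck_start_interval=True,
)
# Prints YAML with langgraph-redis, langgraph-postgres and langgraph-api services.
print(compose(caps, port=8123))
```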
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/templates.py`:
```py
import os
import shutil
import sys
from io import BytesIO
from typing import Dict, Optional
from urllib import error, request
from zipfile import ZipFile
import click
TEMPLATES: Dict[str, Dict[str, str]] = {
"New LangGraph Project": {
"description": "A simple, minimal chatbot with memory.",
"python": "https://github.com/langchain-ai/new-langgraph-project/archive/refs/heads/main.zip",
"js": "https://github.com/langchain-ai/new-langgraphjs-project/archive/refs/heads/main.zip",
},
"ReAct Agent": {
"description": "A simple agent that can be flexibly extended to many tools.",
"python": "https://github.com/langchain-ai/react-agent/archive/refs/heads/main.zip",
"js": "https://github.com/langchain-ai/react-agent-js/archive/refs/heads/main.zip",
},
"Memory Agent": {
"description": "A ReAct-style agent with an additional tool to store memories for use across conversational threads.",
"python": "https://github.com/langchain-ai/memory-agent/archive/refs/heads/main.zip",
"js": "https://github.com/langchain-ai/memory-agent-js/archive/refs/heads/main.zip",
},
"Retrieval Agent": {
"description": "An agent that includes a retrieval-based question-answering system.",
"python": "https://github.com/langchain-ai/retrieval-agent-template/archive/refs/heads/main.zip",
"js": "https://github.com/langchain-ai/retrieval-agent-template-js/archive/refs/heads/main.zip",
},
"Data-enrichment Agent": {
"description": "An agent that performs web searches and organizes its findings into a structured format.",
"python": "https://github.com/langchain-ai/data-enrichment/archive/refs/heads/main.zip",
"js": "https://github.com/langchain-ai/data-enrichment-js/archive/refs/heads/main.zip",
},
}
# Generate TEMPLATE_IDS programmatically
TEMPLATE_ID_TO_CONFIG = {
f"{name.lower().replace(' ', '-')}-{lang}": (name, lang, url)
for name, versions in TEMPLATES.items()
for lang, url in versions.items()
if lang in {"python", "js"}
}
TEMPLATE_IDS = list(TEMPLATE_ID_TO_CONFIG.keys())
TEMPLATE_HELP_STRING = (
"The name of the template to use. Available options:\n"
+ "\n".join(f"{id_}" for id_ in TEMPLATE_ID_TO_CONFIG)
)
def _choose_template() -> str:
"""Presents a list of templates to the user and prompts them to select one.
Returns:
str: The URL of the selected template.
"""
click.secho("🌟 Please select a template:", bold=True, fg="yellow")
for idx, (template_name, template_info) in enumerate(TEMPLATES.items(), 1):
click.secho(f"{idx}. ", nl=False, fg="cyan")
click.secho(template_name, fg="cyan", nl=False)
click.secho(f" - {template_info['description']}", fg="white")
# Get the template choice from the user, defaulting to the first template if blank
template_choice: Optional[int] = click.prompt(
"Enter the number of your template choice (default is 1)",
type=int,
default=1,
show_default=False,
)
template_keys = list(TEMPLATES.keys())
if 1 <= template_choice <= len(template_keys):
selected_template: str = template_keys[template_choice - 1]
else:
click.secho("❌ Invalid choice. Please try again.", fg="red")
return _choose_template()
# Prompt the user to choose between Python or JS/TS version
click.secho(
f"\nYou selected: {selected_template} - {TEMPLATES[selected_template]['description']}",
fg="green",
)
version_choice: int = click.prompt(
"Choose language (1 for Python 🐍, 2 for JS/TS 🌐)", type=int
)
if version_choice == 1:
return TEMPLATES[selected_template]["python"]
elif version_choice == 2:
return TEMPLATES[selected_template]["js"]
else:
click.secho("❌ Invalid choice. Please try again.", fg="red")
return _choose_template()
def _download_repo_with_requests(repo_url: str, path: str) -> None:
"""Download a ZIP archive from the given URL and extracts it to the specified path.
Args:
repo_url (str): The URL of the repository to download.
path (str): The path where the repository should be extracted.
"""
click.secho("📥 Attempting to download repository as a ZIP archive...", fg="yellow")
click.secho(f"URL: {repo_url}", fg="yellow")
try:
with request.urlopen(repo_url) as response:
if response.status == 200:
with ZipFile(BytesIO(response.read())) as zip_file:
zip_file.extractall(path)
# Move extracted contents to path
for item in os.listdir(path):
if item.endswith("-main"):
extracted_dir = os.path.join(path, item)
for filename in os.listdir(extracted_dir):
shutil.move(os.path.join(extracted_dir, filename), path)
shutil.rmtree(extracted_dir)
click.secho(
f"✅ Downloaded and extracted repository to {path}", fg="green"
)
except error.HTTPError as e:
click.secho(
f"❌ Error: Failed to download repository.\n" f"Details: {e}\n",
fg="red",
bold=True,
err=True,
)
sys.exit(1)
def _get_template_url(template_name: str) -> Optional[str]:
"""
Retrieves the template URL based on the provided template name.
Args:
template_name (str): The name of the template.
Returns:
Optional[str]: The URL of the template if found, else None.
"""
if template_name in TEMPLATES:
click.secho(f"Template selected: {template_name}", fg="green")
version_choice: int = click.prompt(
"Choose version (1 for Python 🐍, 2 for JS/TS 🌐)", type=int
)
if version_choice == 1:
return TEMPLATES[template_name]["python"]
elif version_choice == 2:
return TEMPLATES[template_name]["js"]
else:
click.secho("❌ Invalid choice. Please try again.", fg="red")
return None
else:
click.secho(
f"Template '{template_name}' not found. Please select from the available options.",
fg="red",
)
return None
def create_new(path: Optional[str], template: Optional[str]) -> None:
"""Create a new LangGraph project at the specified PATH using the chosen TEMPLATE.
Args:
path (Optional[str]): The path where the new project will be created.
template (Optional[str]): The name of the template to use.
"""
# Prompt for path if not provided
if not path:
path = click.prompt(
"📂 Please specify the path to create the application", default="."
)
path = os.path.abspath(path) # Ensure path is absolute
# Check if path exists and is not empty
if os.path.exists(path) and os.listdir(path):
click.secho(
"❌ The specified directory already exists and is not empty. "
"Aborting to prevent overwriting files.",
fg="red",
bold=True,
)
sys.exit(1)
# Get template URL either from command-line argument or
# through interactive selection
if template:
if template not in TEMPLATE_ID_TO_CONFIG:
# Format available options in a readable way with descriptions
template_options = ""
for id_ in TEMPLATE_IDS:
name, lang, _ = TEMPLATE_ID_TO_CONFIG[id_]
description = TEMPLATES[name]["description"]
# Add each template option with color formatting
template_options += (
click.style("- ", fg="yellow", bold=True)
+ click.style(f"{id_}", fg="cyan")
+ click.style(f": {description}", fg="white")
+ "\n"
)
# Display error message with colors and formatting
click.secho("❌ Error:", fg="red", bold=True, nl=False)
click.secho(f" Template '{template}' not found.", fg="red")
click.secho(
"Please select from the available options:\n", fg="yellow", bold=True
)
click.secho(template_options, fg="cyan")
sys.exit(1)
_, _, template_url = TEMPLATE_ID_TO_CONFIG[template]
else:
template_url = _choose_template()
# Download and extract the template
_download_repo_with_requests(template_url, path)
click.secho(f"🎉 New project created at {path}", fg="green", bold=True)
```
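The derived template identifiers can be inspected directly; a tiny sketch, assuming `langgraph_cli` is importable:
```py
from langgraph_cli.templates import TEMPLATE_ID_TO_CONFIG, TEMPLATE_IDS

print(TEMPLATE_IDS[:2])
# ['new-langgraph-project-python', 'new-langgraph-project-js']
print(TEMPLATE_ID_TO_CONFIG["react-agent-python"][0])  # 'ReAct Agent'
```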
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/cli.py`:
```py
import os
import pathlib
import shutil
import sys
from typing import Callable, Optional, Sequence
import click
import click.exceptions
from click import secho
import langgraph_cli.config
import langgraph_cli.docker
from langgraph_cli.analytics import log_command
from langgraph_cli.config import Config
from langgraph_cli.constants import DEFAULT_CONFIG, DEFAULT_PORT
from langgraph_cli.docker import DockerCapabilities
from langgraph_cli.exec import Runner, subp_exec
from langgraph_cli.progress import Progress
from langgraph_cli.templates import TEMPLATE_HELP_STRING, create_new
from langgraph_cli.version import __version__
OPT_DOCKER_COMPOSE = click.option(
"--docker-compose",
"-d",
help="Advanced: Path to docker-compose.yml file with additional services to launch.",
type=click.Path(
exists=True,
file_okay=True,
dir_okay=False,
resolve_path=True,
path_type=pathlib.Path,
),
)
OPT_CONFIG = click.option(
"--config",
"-c",
help="""Path to configuration file declaring dependencies, graphs and environment variables.
\b
Config file must be a JSON file that has the following keys:
- "dependencies": array of dependencies for langgraph API server. Dependencies can be one of the following:
- ".", which would look for local python packages, as well as pyproject.toml, setup.py or requirements.txt in the app directory
- "./local_package"
- "<package_name>
- "graphs": mapping from graph ID to path where the compiled graph is defined, i.e. ./your_package/your_file.py:variable, where
"variable" is an instance of langgraph.graph.graph.CompiledGraph
- "env": (optional) path to .env file or a mapping from environment variable to its value
- "python_version": (optional) 3.11, 3.12, or 3.13. Defaults to 3.11
- "pip_config_file": (optional) path to pip config file
- "dockerfile_lines": (optional) array of additional lines to add to Dockerfile following the import from parent image
\b
Example:
langgraph up -c langgraph.json
\b
Example:
{
"dependencies": [
"langchain_openai",
"./your_package"
],
"graphs": {
"my_graph_id": "./your_package/your_file.py:variable"
},
"env": "./.env"
}
\b
Example:
{
"python_version": "3.11",
"dependencies": [
"langchain_openai",
"."
],
"graphs": {
"my_graph_id": "./your_package/your_file.py:variable"
},
"env": {
"OPENAI_API_KEY": "secret-key"
}
}
Defaults to looking for langgraph.json in the current directory.""",
default=DEFAULT_CONFIG,
type=click.Path(
exists=True,
file_okay=True,
dir_okay=False,
resolve_path=True,
path_type=pathlib.Path,
),
)
OPT_PORT = click.option(
"--port",
"-p",
type=int,
default=DEFAULT_PORT,
show_default=True,
help="""
Port to expose.
\b
Example:
langgraph up --port 8000
\b
""",
)
OPT_RECREATE = click.option(
"--recreate/--no-recreate",
default=False,
show_default=True,
help="Recreate containers even if their configuration and image haven't changed",
)
OPT_PULL = click.option(
"--pull/--no-pull",
default=True,
show_default=True,
help="""
Pull latest images. Use --no-pull for running the server with locally-built images.
\b
Example:
langgraph up --no-pull
\b
""",
)
OPT_VERBOSE = click.option(
"--verbose",
is_flag=True,
default=False,
help="Show more output from the server logs",
)
OPT_WATCH = click.option("--watch", is_flag=True, help="Restart on file changes")
OPT_DEBUGGER_PORT = click.option(
"--debugger-port",
type=int,
help="Pull the debugger image locally and serve the UI on specified port",
)
OPT_DEBUGGER_BASE_URL = click.option(
"--debugger-base-url",
type=str,
help="URL used by the debugger to access LangGraph API. Defaults to http://127.0.0.1:[PORT]",
)
OPT_POSTGRES_URI = click.option(
"--postgres-uri",
help="Postgres URI to use for the database. Defaults to launching a local database",
)
@click.group()
@click.version_option(version=__version__, prog_name="LangGraph CLI")
def cli():
pass
@OPT_RECREATE
@OPT_PULL
@OPT_PORT
@OPT_DOCKER_COMPOSE
@OPT_CONFIG
@OPT_VERBOSE
@OPT_DEBUGGER_PORT
@OPT_DEBUGGER_BASE_URL
@OPT_WATCH
@OPT_POSTGRES_URI
@click.option(
"--wait",
is_flag=True,
help="Wait for services to start before returning. Implies --detach",
)
@cli.command(help="🚀 Launch LangGraph API server.")
@log_command
def up(
config: pathlib.Path,
docker_compose: Optional[pathlib.Path],
port: int,
recreate: bool,
pull: bool,
watch: bool,
wait: bool,
verbose: bool,
debugger_port: Optional[int],
debugger_base_url: Optional[str],
postgres_uri: Optional[str],
):
click.secho("Starting LangGraph API server...", fg="green")
click.secho(
"""For local dev, requires env var LANGSMITH_API_KEY with access to LangGraph Cloud closed beta.
For production use, requires a license key in env var LANGGRAPH_CLOUD_LICENSE_KEY.""",
)
with Runner() as runner, Progress(message="Pulling...") as set:
capabilities = langgraph_cli.docker.check_capabilities(runner)
args, stdin = prepare(
runner,
capabilities=capabilities,
config_path=config,
docker_compose=docker_compose,
port=port,
pull=pull,
watch=watch,
verbose=verbose,
debugger_port=debugger_port,
debugger_base_url=debugger_base_url,
postgres_uri=postgres_uri,
)
# add up + options
args.extend(["up", "--remove-orphans"])
if recreate:
args.extend(["--force-recreate", "--renew-anon-volumes"])
try:
runner.run(subp_exec("docker", "volume", "rm", "langgraph-data"))
except click.exceptions.Exit:
pass
if watch:
args.append("--watch")
if wait:
args.append("--wait")
else:
args.append("--abort-on-container-exit")
# run docker compose
set("Building...")
def on_stdout(line: str):
if "unpacking to docker.io" in line:
set("Starting...")
elif "Application startup complete" in line:
debugger_origin = (
f"http://localhost:{debugger_port}"
if debugger_port
else "https://smith.langchain.com"
)
debugger_base_url_query = (
debugger_base_url or f"http://127.0.0.1:{port}"
)
set("")
sys.stdout.write(
f"""Ready!
- API: http://localhost:{port}
- Docs: http://localhost:{port}/docs
- LangGraph Studio: {debugger_origin}/studio/?baseUrl={debugger_base_url_query}
"""
)
sys.stdout.flush()
return True
if capabilities.compose_type == "plugin":
compose_cmd = ["docker", "compose"]
elif capabilities.compose_type == "standalone":
compose_cmd = ["docker-compose"]
runner.run(
subp_exec(
*compose_cmd,
*args,
input=stdin,
verbose=verbose,
on_stdout=on_stdout,
)
)
def _build(
runner,
set: Callable[[str], None],
config: pathlib.Path,
config_json: dict,
base_image: Optional[str],
pull: bool,
tag: str,
passthrough: Sequence[str] = (),
):
base_image = base_image or (
"langchain/langgraphjs-api"
if config_json.get("node_version")
else "langchain/langgraph-api"
)
# pull latest images
if pull:
runner.run(
subp_exec(
"docker",
"pull",
(
f"{base_image}:{config_json['node_version']}"
if config_json.get("node_version")
else f"{base_image}:{config_json['python_version']}"
),
verbose=True,
)
)
set("Building...")
# apply options
args = [
"-f",
"-", # stdin
"-t",
tag,
]
# apply config
stdin = langgraph_cli.config.config_to_docker(config, config_json, base_image)
# run docker build
runner.run(
subp_exec(
"docker",
"build",
*args,
*passthrough,
str(config.parent),
input=stdin,
verbose=True,
)
)
@OPT_CONFIG
@OPT_PULL
@click.option(
"--tag",
"-t",
help="""Tag for the docker image.
\b
Example:
langgraph build -t my-image
\b
""",
required=True,
)
@click.option(
"--base-image",
hidden=True,
)
@click.argument("docker_build_args", nargs=-1, type=click.UNPROCESSED)
@cli.command(
help="📦 Build LangGraph API server Docker image.",
context_settings=dict(
ignore_unknown_options=True,
),
)
@log_command
def build(
config: pathlib.Path,
docker_build_args: Sequence[str],
base_image: Optional[str],
pull: bool,
tag: str,
):
with Runner() as runner, Progress(message="Pulling...") as set:
if shutil.which("docker") is None:
raise click.UsageError("Docker not installed") from None
config_json = langgraph_cli.config.validate_config_file(config)
_build(
runner, set, config, config_json, base_image, pull, tag, docker_build_args
)
def _get_docker_ignore_content() -> str:
"""Return the content of a .dockerignore file.
This file is used to exclude files and directories from the Docker build context.
It may be overly broad, but it's better to be safe than sorry.
The main goal is to exclude .env files by default.
"""
return """\
# Ignore node_modules and other dependency directories
node_modules
bower_components
vendor
# Ignore logs and temporary files
*.log
*.tmp
*.swp
# Ignore .env files and other environment files
.env
.env.*
*.local
# Ignore git-related files
.git
.gitignore
# Ignore Docker-related files and configs
.dockerignore
docker-compose.yml
# Ignore build and cache directories
dist
build
.cache
__pycache__
# Ignore IDE and editor configurations
.vscode
.idea
*.sublime-project
*.sublime-workspace
.DS_Store # macOS-specific
# Ignore test and coverage files
coverage
*.coverage
*.test.js
*.spec.js
tests
"""
@OPT_CONFIG
@click.argument("save_path", type=click.Path(resolve_path=True))
@cli.command(
help="🐳 Generate a Dockerfile for the LangGraph API server, with Docker Compose options."
)
@click.option(
# Add a flag for adding a docker-compose.yml file as part of the output
"--add-docker-compose",
help=(
"Add additional files for running the LangGraph API server with "
"docker-compose. These files include a docker-compose.yml, .env file, "
"and a .dockerignore file."
),
is_flag=True,
)
@log_command
def dockerfile(save_path: str, config: pathlib.Path, add_docker_compose: bool) -> None:
save_path = pathlib.Path(save_path).absolute()
secho(f"🔍 Validating configuration at path: {config}", fg="yellow")
config_json = langgraph_cli.config.validate_config_file(config)
secho("✅ Configuration validated!", fg="green")
secho(f"📝 Generating Dockerfile at {save_path}", fg="yellow")
with open(str(save_path), "w", encoding="utf-8") as f:
f.write(
langgraph_cli.config.config_to_docker(
config,
config_json,
(
"langchain/langgraphjs-api"
if config_json.get("node_version")
else "langchain/langgraph-api"
),
)
)
secho("✅ Created: Dockerfile", fg="green")
if add_docker_compose:
# Add docker compose and related files
# Add .dockerignore file in the same directory as the Dockerfile
with open(str(save_path.parent / ".dockerignore"), "w", encoding="utf-8") as f:
f.write(_get_docker_ignore_content())
secho("✅ Created: .dockerignore", fg="green")
# Generate a docker-compose.yml file
path = str(save_path.parent / "docker-compose.yml")
with open(path, "w", encoding="utf-8") as f:
with Runner() as runner:
capabilities = langgraph_cli.docker.check_capabilities(runner)
compose_dict = langgraph_cli.docker.compose_as_dict(
capabilities,
port=8123,
)
# Add .env file to the docker-compose.yml for the langgraph-api service
compose_dict["services"]["langgraph-api"]["env_file"] = [".env"]
# Add the Dockerfile to the build context
compose_dict["services"]["langgraph-api"]["build"] = {
"context": ".",
"dockerfile": save_path.name,
}
f.write(langgraph_cli.docker.dict_to_yaml(compose_dict))
secho("✅ Created: docker-compose.yml", fg="green")
# Check if the .env file exists in the same directory as the Dockerfile
if not (save_path.parent / ".env").exists():
# Also add an empty .env file
with open(str(save_path.parent / ".env"), "w", encoding="utf-8") as f:
f.writelines(
[
"# Uncomment the following line to add your LangSmith API key",
"\n",
"# LANGSMITH_API_KEY=your-api-key",
"\n",
"# Or if you have a LangGraph Cloud license key, "
"then uncomment the following line: ",
"\n",
"# LANGGRAPH_CLOUD_LICENSE_KEY=your-license-key",
"\n",
"# Add any other environment variables go below...",
]
)
secho("✅ Created: .env", fg="green")
else:
# Do nothing since the .env file already exists. Not a great
# idea to overwrite it in case the user has already added custom
# env vars to the .env file.
secho("➖ Skipped: .env. It already exists!", fg="yellow")
secho(
f"🎉 Files generated successfully at path {save_path.parent}!",
fg="cyan",
bold=True,
)
@click.option(
"--host",
default="127.0.0.1",
help="Network interface to bind the development server to. Default 127.0.0.1 is recommended for security. Only use 0.0.0.0 in trusted networks",
)
@click.option(
"--port",
default=2024,
type=int,
help="Port number to bind the development server to. Example: langgraph dev --port 8000",
)
@click.option(
"--no-reload",
is_flag=True,
help="Disable automatic reloading when code changes are detected",
)
@click.option(
"--config",
type=click.Path(exists=True),
default="langgraph.json",
help="Path to configuration file declaring dependencies, graphs and environment variables",
)
@click.option(
"--n-jobs-per-worker",
default=None,
type=int,
help="Maximum number of concurrent jobs each worker process can handle. Default: 10",
)
@click.option(
"--no-browser",
is_flag=True,
help="Skip automatically opening the browser when the server starts",
)
@click.option(
"--debug-port",
default=None,
type=int,
help="Enable remote debugging by listening on specified port. Requires debugpy to be installed",
)
@click.option(
"--wait-for-client",
is_flag=True,
help="Wait for a debugger client to connect to the debug port before starting the server",
default=False,
)
@cli.command(
"dev",
help="🏃♀️➡️ Run LangGraph API server in development mode with hot reloading and debugging support",
)
@log_command
def dev(
host: str,
port: int,
no_reload: bool,
config: pathlib.Path,
n_jobs_per_worker: Optional[int],
no_browser: bool,
debug_port: Optional[int],
wait_for_client: bool,
):
"""CLI entrypoint for running the LangGraph API server."""
try:
from langgraph_api.cli import run_server
except ImportError:
try:
import pkg_resources
pkg_resources.require("langgraph-api-inmem")
except (ImportError, pkg_resources.DistributionNotFound):
raise click.UsageError(
"Required package 'langgraph-api-inmem' is not installed.\n"
"Please install it with:\n\n"
' pip install -U "langgraph-cli[inmem]"\n\n'
"If you're developing the langgraph-cli package locally, you can install in development mode:\n"
" pip install -e ."
) from None
raise click.UsageError(
"Could not import run_server. This likely means your installation is incomplete.\n"
"Please ensure langgraph-cli is installed with the 'inmem' extra: pip install -U \"langgraph-cli[inmem]\""
) from None
config_json = langgraph_cli.config.validate_config_file(config)
cwd = os.getcwd()
sys.path.append(cwd)
dependencies = config_json.get("dependencies", [])
for dep in dependencies:
dep_path = pathlib.Path(cwd) / dep
if dep_path.is_dir() and dep_path.exists():
sys.path.append(str(dep_path))
graphs = config_json.get("graphs", {})
run_server(
host,
port,
not no_reload,
graphs,
n_jobs_per_worker=n_jobs_per_worker,
open_browser=not no_browser,
debug_port=debug_port,
env=config_json.get("env"),
store=config_json.get("store"),
wait_for_client=wait_for_client,
)
@click.argument("path", required=False)
@click.option(
"--template",
type=str,
help=TEMPLATE_HELP_STRING,
)
@cli.command("new", help="🌱 Create a new LangGraph project from a template.")
@log_command
def new(path: Optional[str], template: Optional[str]) -> None:
"""Create a new LangGraph project from a template."""
return create_new(path, template)
def prepare_args_and_stdin(
*,
capabilities: DockerCapabilities,
config_path: pathlib.Path,
config: Config,
docker_compose: Optional[pathlib.Path],
port: int,
watch: bool,
debugger_port: Optional[int] = None,
debugger_base_url: Optional[str] = None,
postgres_uri: Optional[str] = None,
):
# prepare args
stdin = langgraph_cli.docker.compose(
capabilities,
port=port,
debugger_port=debugger_port,
debugger_base_url=debugger_base_url,
postgres_uri=postgres_uri,
)
args = [
"--project-directory",
str(config_path.parent),
]
# apply options
if docker_compose:
args.extend(["-f", str(docker_compose)])
args.extend(["-f", "-"]) # stdin
# apply config
stdin += langgraph_cli.config.config_to_compose(
config_path,
config,
watch=watch,
base_image=(
"langchain/langgraphjs-api"
if config.get("node_version")
else "langchain/langgraph-api"
),
)
return args, stdin
def prepare(
runner,
*,
capabilities: DockerCapabilities,
config_path: pathlib.Path,
docker_compose: Optional[pathlib.Path],
port: int,
pull: bool,
watch: bool,
verbose: bool,
debugger_port: Optional[int] = None,
debugger_base_url: Optional[str] = None,
postgres_uri: Optional[str] = None,
):
config_json = langgraph_cli.config.validate_config_file(config_path)
# pull latest images
if pull:
runner.run(
subp_exec(
"docker",
"pull",
(
f"langchain/langgraphjs-api:{config_json['node_version']}"
if config_json.get("node_version")
else f"langchain/langgraph-api:{config_json['python_version']}"
),
verbose=verbose,
)
)
args, stdin = prepare_args_and_stdin(
capabilities=capabilities,
config_path=config_path,
config=config_json,
docker_compose=docker_compose,
port=port,
watch=watch,
debugger_port=debugger_port,
debugger_base_url=debugger_base_url or f"http://127.0.0.1:{port}",
postgres_uri=postgres_uri,
)
return args, stdin
```
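The command group can be smoke-tested with click's test runner; a sketch assuming the `langgraph-cli` package is installed:
```py
from click.testing import CliRunner

from langgraph_cli.cli import cli

result = CliRunner().invoke(cli, ["--version"])
print(result.output)  # e.g. "LangGraph CLI, version 0.1.61"
```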
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/analytics.py`:
```py
import functools
import json
import os
import pathlib
import platform
import threading
import urllib.error
import urllib.request
from typing import Any, TypedDict
from langgraph_cli.constants import (
DEFAULT_CONFIG,
DEFAULT_PORT,
SUPABASE_PUBLIC_API_KEY,
SUPABASE_URL,
)
from langgraph_cli.version import __version__
class LogData(TypedDict):
os: str
os_version: str
python_version: str
cli_version: str
cli_command: str
params: dict[str, Any]
def get_anonymized_params(kwargs: dict[str, Any]) -> dict[str, bool]:
params = {}
# anonymize params with values
if config := kwargs.get("config"):
if config != pathlib.Path(DEFAULT_CONFIG).resolve():
params["config"] = True
if port := kwargs.get("port"):
if port != DEFAULT_PORT:
params["port"] = True
if kwargs.get("docker_compose"):
params["docker_compose"] = True
if kwargs.get("debugger_port"):
params["debugger_port"] = True
if kwargs.get("postgres_uri"):
params["postgres_uri"] = True
# pick up exact values for boolean flags
for boolean_param in ["recreate", "pull", "watch", "wait", "verbose"]:
if kwargs.get(boolean_param):
params[boolean_param] = kwargs[boolean_param]
return params
def log_data(data: LogData) -> None:
headers = {
"Content-Type": "application/json",
"apikey": SUPABASE_PUBLIC_API_KEY,
"User-Agent": "Mozilla/5.0",
}
supabase_url = SUPABASE_URL
req = urllib.request.Request(
f"{supabase_url}/rest/v1/logs",
data=json.dumps(data).encode("utf-8"),
headers=headers,
method="POST",
)
try:
urllib.request.urlopen(req)
except urllib.error.URLError:
pass
def log_command(func):
@functools.wraps(func)
def decorator(*args, **kwargs):
if os.getenv("LANGGRAPH_CLI_NO_ANALYTICS") == "1":
return func(*args, **kwargs)
data = {
"os": platform.system(),
"os_version": platform.version(),
"python_version": platform.python_version(),
"cli_version": __version__,
"cli_command": func.__name__,
"params": get_anonymized_params(kwargs),
}
background_thread = threading.Thread(target=log_data, args=(data,))
background_thread.start()
return func(*args, **kwargs)
return decorator
```
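A minimal sketch of how the `log_command` decorator above is typically attached to a CLI command. The `hello` command and its `--port` option are hypothetical; only `log_command` and the `LANGGRAPH_CLI_NO_ANALYTICS=1` opt-out come from the module above.
```py
import click

from langgraph_cli.analytics import log_command


@click.command("hello")
@click.option("--port", type=int, default=8123)
@log_command
def hello(port: int) -> None:
    # Unless LANGGRAPH_CLI_NO_ANALYTICS=1 is set, log_command posts anonymized
    # metadata (OS, Python/CLI versions, which params were used) from a
    # background thread before running the command body.
    click.echo(f"serving on port {port}")
```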
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/langgraph_cli/progress.py`:
```py
import sys
import threading
import time
from typing import Callable
class Progress:
delay: float = 0.1
@staticmethod
def spinning_cursor():
while True:
yield from "|/-\\"
def __init__(self, *, message=""):
self.message = message
self.spinner_generator = self.spinning_cursor()
def spinner_iteration(self):
message = self.message
sys.stdout.write(next(self.spinner_generator) + " " + message)
sys.stdout.flush()
time.sleep(self.delay)
# clear the spinner and message
sys.stdout.write(
"\b" * (len(message) + 2)
+ " " * (len(message) + 2)
+ "\b" * (len(message) + 2)
)
sys.stdout.flush()
def spinner_task(self):
while self.message:
message = self.message
sys.stdout.write(next(self.spinner_generator) + " " + message)
sys.stdout.flush()
time.sleep(self.delay)
# clear the spinner and message
sys.stdout.write(
"\b" * (len(message) + 2)
+ " " * (len(message) + 2)
+ "\b" * (len(message) + 2)
)
sys.stdout.flush()
def __enter__(self) -> Callable[[str], None]:
self.thread = threading.Thread(target=self.spinner_task)
self.thread.start()
def set_message(message):
self.message = message
if not message:
self.thread.join()
return set_message
def __exit__(self, exception, value, tb):
self.message = ""
try:
self.thread.join()
finally:
del self.thread
if exception is not None:
return False
```
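A usage sketch for the `Progress` spinner above, assuming it wraps some long-running work (the sleeps stand in for that work):
```py
import time

from langgraph_cli.progress import Progress

# Entering the context starts the spinner thread; leaving it clears the
# message, which makes the thread exit and get joined.
with Progress(message="Pulling images...") as set_message:
    time.sleep(2)  # stand-in for real work
    set_message("Starting services...")
    time.sleep(2)
```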
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-cli"
version = "0.1.61"
description = "CLI for interacting with LangGraph API"
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph_cli" }]
[tool.poetry.scripts]
langgraph = "langgraph_cli.cli:cli"
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
click = "^8.1.7"
langgraph-api = { version = ">=0.0.6,<0.1.0", optional = true, python = ">=3.11,<4.0" }
python-dotenv = { version = ">=0.8.0", optional = true }
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watch = "^4.2.0"
mypy = "^1.10.0"
[tool.poetry.extras]
inmem = ["langgraph-api", "python-dotenv"]
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# isort
"I",
]
lint.ignore = ["E501", "B008"]
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/conftest.py`:
```py
import os
from unittest.mock import patch
import pytest
@pytest.fixture(autouse=True)
def disable_analytics_env() -> None:
    """Disable analytics for unit tests via the LANGGRAPH_CLI_NO_ANALYTICS env var."""
# First check if the environment variable is already set, if so, log a warning prior
# to overriding it.
if "LANGGRAPH_CLI_NO_ANALYTICS" in os.environ:
print("⚠️ LANGGRAPH_CLI_NO_ANALYTICS is set. Overriding it for the test.")
    # log_command only skips analytics when the variable is set to "1".
    with patch.dict(os.environ, {"LANGGRAPH_CLI_NO_ANALYTICS": "1"}):
yield
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/cli/test_templates.py`:
```py
"""Unit tests for the 'new' CLI command.
This command creates a new LangGraph project using a specified template.
"""
import os
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch
from urllib import request
from zipfile import ZipFile
from click.testing import CliRunner
from langgraph_cli.cli import cli
from langgraph_cli.templates import TEMPLATE_ID_TO_CONFIG
@patch.object(request, "urlopen")
def test_create_new_with_mocked_download(mock_urlopen: MagicMock) -> None:
"""Test the 'new' CLI command with a mocked download response using urllib."""
# Mock the response content to simulate a ZIP file
mock_zip_content = BytesIO()
with ZipFile(mock_zip_content, "w") as mock_zip:
mock_zip.writestr("test-file.txt", "Test content.")
# Create a mock response that behaves like a context manager
mock_response = MagicMock()
mock_response.read.return_value = mock_zip_content.getvalue()
mock_response.__enter__.return_value = mock_response # Setup enter context
mock_response.status = 200
mock_urlopen.return_value = mock_response
with TemporaryDirectory() as temp_dir:
runner = CliRunner()
template = next(
iter(TEMPLATE_ID_TO_CONFIG)
) # Select the first template for the test
result = runner.invoke(cli, ["new", temp_dir, "--template", template])
# Verify CLI command execution and success
assert result.exit_code == 0, result.output
assert (
"New project created" in result.output
), "Expected success message in output."
# Verify that the directory is not empty
assert os.listdir(temp_dir), "Expected files to be created in temp directory."
# Check for a known file in the extracted content
extracted_files = [f.name for f in Path(temp_dir).glob("*")]
assert (
"test-file.txt" in extracted_files
), "Expected 'test-file.txt' in the extracted content."
def test_invalid_template_id() -> None:
"""Test that an invalid template ID passed via CLI results in a graceful error."""
runner = CliRunner()
result = runner.invoke(
cli, ["new", "dummy_path", "--template", "invalid-template-id"]
)
# Verify the command failed and proper message is displayed
assert result.exit_code != 0, "Expected non-zero exit code for invalid template."
assert (
"Template 'invalid-template-id' not found" in result.output
), "Expected error message in output."
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/cli/test_cli.py`:
```py
import json
import pathlib
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from click.testing import CliRunner
from langgraph_cli.cli import cli, prepare_args_and_stdin
from langgraph_cli.config import Config, validate_config
from langgraph_cli.docker import DEFAULT_POSTGRES_URI, DockerCapabilities, Version
from langgraph_cli.util import clean_empty_lines
DEFAULT_DOCKER_CAPABILITIES = DockerCapabilities(
version_docker=Version(26, 1, 1),
version_compose=Version(2, 27, 0),
healthcheck_start_interval=True,
)
@contextmanager
def temporary_config_folder(config_content: dict):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Define the path for the config.json file
config_path = Path(temp_dir) / "config.json"
# Write the provided dictionary content to config.json
with open(config_path, "w", encoding="utf-8") as config_file:
json.dump(config_content, config_file)
# Yield the temporary directory path for use within the context
yield config_path.parent
finally:
# Cleanup the temporary directory and its contents
shutil.rmtree(temp_dir)
def test_prepare_args_and_stdin() -> None:
    # this effectively serves as an end-to-end test of the config and docker helpers
config_path = pathlib.Path("./langgraph.json")
config = validate_config(
Config(dependencies=["."], graphs={"agent": "agent.py:graph"})
)
port = 8000
debugger_port = 8001
debugger_graph_url = f"http://127.0.0.1:{port}"
actual_args, actual_stdin = prepare_args_and_stdin(
capabilities=DEFAULT_DOCKER_CAPABILITIES,
config_path=config_path,
config=config,
docker_compose=pathlib.Path("custom-docker-compose.yml"),
port=port,
debugger_port=debugger_port,
debugger_base_url=debugger_graph_url,
watch=True,
)
expected_args = [
"--project-directory",
".",
"-f",
"custom-docker-compose.yml",
"-f",
"-",
]
expected_stdin = f"""volumes:
langgraph-data:
driver: local
services:
langgraph-redis:
image: redis:6
healthcheck:
test: redis-cli ping
interval: 5s
timeout: 1s
retries: 5
langgraph-postgres:
image: postgres:16
ports:
- "5433:5432"
environment:
POSTGRES_DB: postgres
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
volumes:
- langgraph-data:/var/lib/postgresql/data
healthcheck:
test: pg_isready -U postgres
start_period: 10s
timeout: 1s
retries: 5
interval: 60s
start_interval: 1s
langgraph-debugger:
image: langchain/langgraph-debugger
restart: on-failure
depends_on:
langgraph-postgres:
condition: service_healthy
ports:
- "{debugger_port}:3968"
environment:
VITE_STUDIO_LOCAL_GRAPH_URL: {debugger_graph_url}
langgraph-api:
ports:
- "8000:8000"
depends_on:
langgraph-redis:
condition: service_healthy
langgraph-postgres:
condition: service_healthy
environment:
REDIS_URI: redis://langgraph-redis:6379
POSTGRES_URI: {DEFAULT_POSTGRES_URI}
healthcheck:
test: python /api/healthcheck.py
interval: 60s
start_interval: 1s
start_period: 10s
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api:3.11
ADD . /deps/
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{{"agent": "agent.py:graph"}}'
WORKDIR /deps/
develop:
watch:
- path: langgraph.json
action: rebuild
- path: .
action: rebuild\
"""
assert actual_args == expected_args
assert clean_empty_lines(actual_stdin) == expected_stdin
def test_version_option() -> None:
"""Test the --version option of the CLI."""
runner = CliRunner()
result = runner.invoke(cli, ["--version"])
# Verify that the command executed successfully
assert result.exit_code == 0, "Expected exit code 0 for --version option"
# Check that the output contains the correct version information
assert (
"LangGraph CLI, version" in result.output
), "Expected version information in output"
def test_dockerfile_command_basic() -> None:
"""Test the 'dockerfile' command with basic configuration."""
runner = CliRunner()
config_content = {
"node_version": "20", # Add any other necessary configuration fields
"graphs": {"agent": "agent.py:graph"},
}
with temporary_config_folder(config_content) as temp_dir:
save_path = temp_dir / "Dockerfile"
result = runner.invoke(
cli,
["dockerfile", str(save_path), "--config", str(temp_dir / "config.json")],
)
# Assert command was successful
assert result.exit_code == 0, result.output
assert "✅ Created: Dockerfile" in result.output
# Check if Dockerfile was created
assert save_path.exists()
def test_dockerfile_command_with_docker_compose() -> None:
"""Test the 'dockerfile' command with Docker Compose configuration."""
runner = CliRunner()
config_content = {
"dependencies": ["./my_agent"],
"graphs": {"agent": "./my_agent/agent.py:graph"},
"env": ".env",
}
with temporary_config_folder(config_content) as temp_dir:
save_path = temp_dir / "Dockerfile"
# Add agent.py file
agent_path = temp_dir / "my_agent" / "agent.py"
agent_path.parent.mkdir(parents=True, exist_ok=True)
agent_path.touch()
result = runner.invoke(
cli,
[
"dockerfile",
str(save_path),
"--config",
str(temp_dir / "config.json"),
"--add-docker-compose",
],
)
# Assert command was successful
assert result.exit_code == 0
assert "✅ Created: Dockerfile" in result.output
assert "✅ Created: .dockerignore" in result.output
assert "✅ Created: docker-compose.yml" in result.output
assert (
"✅ Created: .env" in result.output or "➖ Skipped: .env" in result.output
)
assert "🎉 Files generated successfully" in result.output
# Check if Dockerfile, .dockerignore, docker-compose.yml, and .env were created
assert save_path.exists()
assert (temp_dir / ".dockerignore").exists()
assert (temp_dir / "docker-compose.yml").exists()
assert (temp_dir / ".env").exists() or "➖ Skipped: .env" in result.output
def test_dockerfile_command_with_bad_config() -> None:
    """Test the 'dockerfile' command with a missing config file."""
runner = CliRunner()
config_content = {
"node_version": "20" # Add any other necessary configuration fields
}
with temporary_config_folder(config_content) as temp_dir:
save_path = temp_dir / "Dockerfile"
result = runner.invoke(
cli,
["dockerfile", str(save_path), "--config", str(temp_dir / "conf.json")],
)
        # Assert command failed with a usage error
assert result.exit_code == 2
assert "conf.json' does not exist" in result.output
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/agent.py`:
```py
import asyncio
import os
from typing import Annotated, Sequence, TypedDict
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph, add_messages
# check that env var is present
os.environ["SOME_ENV_VAR"]
class AgentState(TypedDict):
some_bytes: bytes
some_byte_array: bytearray
dict_with_bytes: dict[str, bytes]
messages: Annotated[Sequence[BaseMessage], add_messages]
sleep: int
async def call_model(state, config):
if sleep := state.get("sleep"):
await asyncio.sleep(sleep)
messages = state["messages"]
if len(messages) > 1:
assert state["some_bytes"] == b"some_bytes"
assert state["some_byte_array"] == bytearray(b"some_byte_array")
assert state["dict_with_bytes"] == {"more_bytes": b"more_bytes"}
# hacky way to reset model to the "first" response
if isinstance(messages[-1], HumanMessage):
model.i = 0
response = await model.ainvoke(messages)
return {
"messages": [response],
"some_bytes": b"some_bytes",
"some_byte_array": bytearray(b"some_byte_array"),
"dict_with_bytes": {"more_bytes": b"more_bytes"},
}
def call_tool(state):
last_message_content = state["messages"][-1].content
return {
"messages": [
ToolMessage(
f"tool_call__{last_message_content}", tool_call_id="tool_call_id"
)
]
}
def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
if last_message.content == "end":
return END
else:
return "tool"
# NOTE: the model cycles through responses infinitely here
model = FakeListChatModel(responses=["begin", "end"])
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tool", call_tool)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
)
workflow.add_edge("tool", "agent")
graph = workflow.compile()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/test_config.py`:
```py
import json
import os
import pathlib
import tempfile
import click
import pytest
from langgraph_cli.config import (
config_to_compose,
config_to_docker,
validate_config,
validate_config_file,
)
from langgraph_cli.util import clean_empty_lines
PATH_TO_CONFIG = pathlib.Path(__file__).parent / "test_config.json"
def test_validate_config():
# minimal config
expected_config = {
"dependencies": ["."],
"graphs": {
"agent": "./agent.py:graph",
},
}
expected_config = {
"python_version": "3.11",
"pip_config_file": None,
"dockerfile_lines": [],
"env": {},
"store": None,
**expected_config,
}
actual_config = validate_config(expected_config)
assert actual_config == expected_config
# full config
env = ".env"
expected_config = {
"python_version": "3.12",
"pip_config_file": "pipconfig.txt",
"dockerfile_lines": ["ARG meow"],
"dependencies": [".", "langchain"],
"graphs": {
"agent": "./agent.py:graph",
},
"env": env,
"store": None,
}
actual_config = validate_config(expected_config)
assert actual_config == expected_config
expected_config["python_version"] = "3.13"
actual_config = validate_config(expected_config)
assert actual_config == expected_config
# check wrong python version raises
with pytest.raises(click.UsageError):
validate_config(
{
"python_version": "3.9",
}
)
# check missing dependencies key raises
with pytest.raises(click.UsageError):
validate_config(
{"python_version": "3.9", "graphs": {"agent": "./agent.py:graph"}},
)
# check missing graphs key raises
with pytest.raises(click.UsageError):
validate_config({"python_version": "3.9", "dependencies": ["."]})
with pytest.raises(click.UsageError) as exc_info:
validate_config({"python_version": "3.11.0"})
assert "Invalid Python version format" in str(exc_info.value)
with pytest.raises(click.UsageError) as exc_info:
validate_config({"python_version": "3"})
assert "Invalid Python version format" in str(exc_info.value)
with pytest.raises(click.UsageError) as exc_info:
validate_config({"python_version": "abc.def"})
assert "Invalid Python version format" in str(exc_info.value)
with pytest.raises(click.UsageError) as exc_info:
validate_config({"python_version": "3.10"})
assert "Minimum required version" in str(exc_info.value)
def test_validate_config_file():
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = pathlib.Path(tmpdir)
config_path = tmpdir_path / "langgraph.json"
node_config = {"node_version": "20", "graphs": {"agent": "./agent.js:graph"}}
with open(config_path, "w") as f:
json.dump(node_config, f)
validate_config_file(config_path)
package_json = {"name": "test", "engines": {"node": "20"}}
with open(tmpdir_path / "package.json", "w") as f:
json.dump(package_json, f)
validate_config_file(config_path)
package_json["engines"]["node"] = "20.18"
with open(tmpdir_path / "package.json", "w") as f:
json.dump(package_json, f)
with pytest.raises(click.UsageError, match="Use major version only"):
validate_config_file(config_path)
package_json["engines"] = {"node": "18"}
with open(tmpdir_path / "package.json", "w") as f:
json.dump(package_json, f)
with pytest.raises(click.UsageError, match="must be >= 20"):
validate_config_file(config_path)
package_json["engines"] = {"node": "20", "deno": "1.0"}
with open(tmpdir_path / "package.json", "w") as f:
json.dump(package_json, f)
with pytest.raises(click.UsageError, match="Only 'node' engine is supported"):
validate_config_file(config_path)
with open(tmpdir_path / "package.json", "w") as f:
f.write("{invalid json")
with pytest.raises(click.UsageError, match="Invalid package.json"):
validate_config_file(config_path)
python_config = {
"python_version": "3.11",
"dependencies": ["."],
"graphs": {"agent": "./agent.py:graph"},
}
with open(config_path, "w") as f:
json.dump(python_config, f)
validate_config_file(config_path)
for package_content in [
{"name": "test"},
{"engines": {"node": "18"}},
{"engines": {"node": "20", "deno": "1.0"}},
"{invalid json",
]:
with open(tmpdir_path / "package.json", "w") as f:
if isinstance(package_content, dict):
json.dump(package_content, f)
else:
f.write(package_content)
validate_config_file(config_path)
# config_to_docker
def test_config_to_docker_simple():
graphs = {"agent": "./agent.py:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs}),
"langchain/langgraph-api",
)
expected_docker_stdin = """\
FROM langchain/langgraph-api:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests\
"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
def test_config_to_docker_pipconfig():
graphs = {"agent": "./agent.py:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config(
{
"dependencies": ["."],
"graphs": graphs,
"pip_config_file": "pipconfig.txt",
}
),
"langchain/langgraph-api",
)
expected_docker_stdin = """\
FROM langchain/langgraph-api:3.11
ADD pipconfig.txt /pipconfig.txt
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests\
"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
def test_config_to_docker_invalid_inputs():
# test missing local dependencies
with pytest.raises(FileNotFoundError):
graphs = {"agent": "tests/unit_tests/agent.py:graph"}
config_to_docker(
PATH_TO_CONFIG,
validate_config({"dependencies": ["./missing"], "graphs": graphs}),
"langchain/langgraph-api",
)
# test missing local module
with pytest.raises(FileNotFoundError):
graphs = {"agent": "./missing_agent.py:graph"}
config_to_docker(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs}),
"langchain/langgraph-api",
)
def test_config_to_docker_local_deps():
graphs = {"agent": "./graphs/agent.py:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config(
{
"dependencies": ["./graphs"],
"graphs": graphs,
}
),
"langchain/langgraph-api-custom",
)
expected_docker_stdin = """\
FROM langchain/langgraph-api-custom:3.11
ADD ./graphs /deps/__outer_graphs/src
RUN set -ex && \\
for line in '[project]' \\
'name = "graphs"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_graphs/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_graphs/src/agent.py:graph"}'\
"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
def test_config_to_docker_pyproject():
pyproject_str = """[project]
name = "custom"
version = "0.1"
dependencies = ["langchain"]"""
pyproject_path = "tests/unit_tests/pyproject.toml"
with open(pyproject_path, "w") as f:
f.write(pyproject_str)
graphs = {"agent": "./graphs/agent.py:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config(
{
"dependencies": ["."],
"graphs": graphs,
}
),
"langchain/langgraph-api",
)
os.remove(pyproject_path)
expected_docker_stdin = """FROM langchain/langgraph-api:3.11
ADD . /deps/unit_tests
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/unit_tests/graphs/agent.py:graph"}'
WORKDIR /deps/unit_tests"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
def test_config_to_docker_end_to_end():
graphs = {"agent": "./graphs/agent.py:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config(
{
"python_version": "3.12",
"dependencies": ["./graphs/", "langchain", "langchain_openai"],
"graphs": graphs,
"pip_config_file": "pipconfig.txt",
"dockerfile_lines": ["ARG meow", "ARG foo"],
}
),
"langchain/langgraph-api",
)
expected_docker_stdin = """FROM langchain/langgraph-api:3.12
ARG meow
ARG foo
ADD pipconfig.txt /pipconfig.txt
RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt langchain langchain_openai
ADD ./graphs/ /deps/__outer_graphs/src
RUN set -ex && \\
for line in '[project]' \\
'name = "graphs"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_graphs/pyproject.toml; \\
done
RUN PIP_CONFIG_FILE=/pipconfig.txt PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_graphs/src/agent.py:graph"}'"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
# node.js build used for LangGraph Cloud
def test_config_to_docker_nodejs():
graphs = {"agent": "./graphs/agent.js:graph"}
actual_docker_stdin = config_to_docker(
PATH_TO_CONFIG,
validate_config(
{
"node_version": "20",
"graphs": graphs,
"dockerfile_lines": ["ARG meow", "ARG foo"],
}
),
"langchain/langgraphjs-api",
)
expected_docker_stdin = """FROM langchain/langgraphjs-api:20
ARG meow
ARG foo
ADD . /deps/unit_tests
RUN cd /deps/unit_tests && npm i
ENV LANGSERVE_GRAPHS='{"agent": "./graphs/agent.js:graph"}'
WORKDIR /deps/unit_tests
RUN (test ! -f /api/langgraph_api/js/build.mts && echo "Prebuild script not found, skipping") || tsx /api/langgraph_api/js/build.mts"""
assert clean_empty_lines(actual_docker_stdin) == expected_docker_stdin
# config_to_compose
def test_config_to_compose_simple_config():
graphs = {"agent": "./agent.py:graph"}
expected_compose_stdin = """\
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests
"""
actual_compose_stdin = config_to_compose(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs}),
"langchain/langgraph-api",
)
assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin
def test_config_to_compose_env_vars():
graphs = {"agent": "./agent.py:graph"}
expected_compose_stdin = """ OPENAI_API_KEY: "key"
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api-custom:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests
"""
openai_api_key = "key"
actual_compose_stdin = config_to_compose(
PATH_TO_CONFIG,
validate_config(
{
"dependencies": ["."],
"graphs": graphs,
"env": {"OPENAI_API_KEY": openai_api_key},
}
),
"langchain/langgraph-api-custom",
)
assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin
def test_config_to_compose_env_file():
graphs = {"agent": "./agent.py:graph"}
expected_compose_stdin = """\
env_file: .env
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests
"""
actual_compose_stdin = config_to_compose(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs, "env": ".env"}),
"langchain/langgraph-api",
)
assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin
def test_config_to_compose_watch():
graphs = {"agent": "./agent.py:graph"}
expected_compose_stdin = """\
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests
develop:
watch:
- path: test_config.json
action: rebuild
- path: .
action: rebuild\
"""
actual_compose_stdin = config_to_compose(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs}),
"langchain/langgraph-api",
watch=True,
)
assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin
def test_config_to_compose_end_to_end():
# test all of the above + langgraph API path
graphs = {"agent": "./agent.py:graph"}
expected_compose_stdin = """\
env_file: .env
pull_policy: build
build:
context: .
dockerfile_inline: |
FROM langchain/langgraph-api:3.11
ADD . /deps/__outer_unit_tests/unit_tests
RUN set -ex && \\
for line in '[project]' \\
'name = "unit_tests"' \\
'version = "0.1"' \\
'[tool.setuptools.package-data]' \\
'"*" = ["**/*"]'; do \\
echo "$line" >> /deps/__outer_unit_tests/pyproject.toml; \\
done
RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/*
ENV LANGSERVE_GRAPHS='{"agent": "/deps/__outer_unit_tests/unit_tests/agent.py:graph"}'
WORKDIR /deps/__outer_unit_tests/unit_tests
develop:
watch:
- path: test_config.json
action: rebuild
- path: .
action: rebuild\
"""
actual_compose_stdin = config_to_compose(
PATH_TO_CONFIG,
validate_config({"dependencies": ["."], "graphs": graphs, "env": ".env"}),
"langchain/langgraph-api",
watch=True,
)
assert clean_empty_lines(actual_compose_stdin) == expected_compose_stdin
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/helpers.py`:
```py
def clean_empty_lines(input_str: str):
return "\n".join(filter(None, input_str.splitlines()))
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/unit_tests/test_docker.py`:
```py
from langgraph_cli.docker import (
DEFAULT_POSTGRES_URI,
DockerCapabilities,
Version,
compose,
)
from langgraph_cli.util import clean_empty_lines
DEFAULT_DOCKER_CAPABILITIES = DockerCapabilities(
version_docker=Version(26, 1, 1),
version_compose=Version(2, 27, 0),
healthcheck_start_interval=False,
)
def test_compose_with_no_debugger_and_custom_db():
port = 8123
custom_postgres_uri = "custom_postgres_uri"
actual_compose_str = compose(
DEFAULT_DOCKER_CAPABILITIES, port=port, postgres_uri=custom_postgres_uri
)
expected_compose_str = f"""services:
langgraph-redis:
image: redis:6
healthcheck:
test: redis-cli ping
interval: 5s
timeout: 1s
retries: 5
langgraph-api:
ports:
- "{port}:8000"
depends_on:
langgraph-redis:
condition: service_healthy
environment:
REDIS_URI: redis://langgraph-redis:6379
POSTGRES_URI: {custom_postgres_uri}"""
assert clean_empty_lines(actual_compose_str) == expected_compose_str
def test_compose_with_no_debugger_and_custom_db_with_healthcheck():
port = 8123
custom_postgres_uri = "custom_postgres_uri"
actual_compose_str = compose(
DEFAULT_DOCKER_CAPABILITIES._replace(healthcheck_start_interval=True),
port=port,
postgres_uri=custom_postgres_uri,
)
expected_compose_str = f"""services:
langgraph-redis:
image: redis:6
healthcheck:
test: redis-cli ping
interval: 5s
timeout: 1s
retries: 5
langgraph-api:
ports:
- "{port}:8000"
depends_on:
langgraph-redis:
condition: service_healthy
environment:
REDIS_URI: redis://langgraph-redis:6379
POSTGRES_URI: {custom_postgres_uri}
healthcheck:
test: python /api/healthcheck.py
interval: 60s
start_interval: 1s
start_period: 10s"""
assert clean_empty_lines(actual_compose_str) == expected_compose_str
def test_compose_with_debugger_and_custom_db():
port = 8123
custom_postgres_uri = "custom_postgres_uri"
actual_compose_str = compose(
DEFAULT_DOCKER_CAPABILITIES,
port=port,
postgres_uri=custom_postgres_uri,
)
expected_compose_str = f"""services:
langgraph-redis:
image: redis:6
healthcheck:
test: redis-cli ping
interval: 5s
timeout: 1s
retries: 5
langgraph-api:
ports:
- "{port}:8000"
depends_on:
langgraph-redis:
condition: service_healthy
environment:
REDIS_URI: redis://langgraph-redis:6379
POSTGRES_URI: {custom_postgres_uri}"""
assert clean_empty_lines(actual_compose_str) == expected_compose_str
def test_compose_with_debugger_and_default_db():
port = 8123
actual_compose_str = compose(DEFAULT_DOCKER_CAPABILITIES, port=port)
expected_compose_str = f"""volumes:
langgraph-data:
driver: local
services:
langgraph-redis:
image: redis:6
healthcheck:
test: redis-cli ping
interval: 5s
timeout: 1s
retries: 5
langgraph-postgres:
image: postgres:16
ports:
- "5433:5432"
environment:
POSTGRES_DB: postgres
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
volumes:
- langgraph-data:/var/lib/postgresql/data
healthcheck:
test: pg_isready -U postgres
start_period: 10s
timeout: 1s
retries: 5
interval: 5s
langgraph-api:
ports:
- "{port}:8000"
depends_on:
langgraph-redis:
condition: service_healthy
langgraph-postgres:
condition: service_healthy
environment:
REDIS_URI: redis://langgraph-redis:6379
POSTGRES_URI: {DEFAULT_POSTGRES_URI}"""
assert clean_empty_lines(actual_compose_str) == expected_compose_str
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/tests/integration_tests/test_cli.py`:
```py
import pytest
import requests
from langgraph_cli.templates import TEMPLATE_ID_TO_CONFIG
@pytest.mark.parametrize("template_key", TEMPLATE_ID_TO_CONFIG.keys())
def test_template_urls_work(template_key: str) -> None:
"""Integration test to verify that all template URLs are reachable."""
_, _, template_url = TEMPLATE_ID_TO_CONFIG[template_key]
response = requests.head(template_url)
# Returns 302 on a successful HEAD request
assert response.status_code == 302, f"URL {template_url} is not reachable."
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-examples"
version = "0.1.0"
description = ""
authors = []
readme = "README.md"
packages = []
package-mode = false
[tool.poetry.dependencies]
python = "^3.9.0,<4.0"
langgraph-cli = {path = "../../cli", develop = true}
langgraph-sdk = {path = "../../sdk-py", develop = true}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/hello.py`:
```py
from graphs_submod.agent import graph # noqa
from utils.greeter import greet
greet()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/graphs_submod/agent.py`:
```py
from pathlib import Path
from typing import Annotated, Sequence, TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, add_messages
from langgraph.prebuilt import ToolNode
tools = [TavilySearchResults(max_results=1)]
model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
model_oai = ChatOpenAI(temperature=0)
model_anth = model_anth.bind_tools(tools)
model_oai = model_oai.bind_tools(tools)
prompt = open("prompt.txt").read()
subprompt = open(Path(__file__).parent / "subprompt.txt").read()
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
# Define the function that determines whether to continue or not
def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
# If there are no tool calls, then we finish
if not last_message.tool_calls:
return "end"
    # Otherwise, we continue
else:
return "continue"
# Define the function that calls the model
def call_model(state, config):
if config["configurable"].get("model", "anthropic") == "anthropic":
model = model_anth
else:
model = model_oai
messages = state["messages"]
response = model.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}
# Define the function to execute tools
tool_node = ToolNode(tools)
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
        # If `continue`, then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
graph = workflow.compile()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_b/utils/greeter.py`:
```py
def greet():
print("Hello, world!")
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs/storm.py`:
```py
import asyncio
import json
from typing import Annotated, List, Optional
from langchain_community.retrievers import WikipediaRetriever
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_core.documents import Document
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables import chain as as_runnable
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, StateGraph
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
# Uncomment for a Fireworks model
# fast_llm = ChatFireworks(model="accounts/fireworks/models/firefunction-v1", max_tokens=32_000)
long_context_llm = ChatOpenAI(model="gpt-4-turbo-preview")
direct_gen_outline_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a Wikipedia writer. Write an outline for a Wikipedia page about a user-provided topic. Be comprehensive and specific.",
),
("user", "{topic}"),
]
)
class Subsection(BaseModel):
subsection_title: str = Field(..., title="Title of the subsection")
description: str = Field(..., title="Content of the subsection")
@property
def as_str(self) -> str:
return f"### {self.subsection_title}\n\n{self.description}".strip()
class Section(BaseModel):
section_title: str = Field(..., title="Title of the section")
description: str = Field(..., title="Content of the section")
subsections: Optional[List[Subsection]] = Field(
default=None,
title="Titles and descriptions for each subsection of the Wikipedia page.",
)
@property
def as_str(self) -> str:
subsections = "\n\n".join(
f"### {subsection.subsection_title}\n\n{subsection.description}"
for subsection in self.subsections or []
)
return f"## {self.section_title}\n\n{self.description}\n\n{subsections}".strip()
class Outline(BaseModel):
page_title: str = Field(..., title="Title of the Wikipedia page")
sections: List[Section] = Field(
default_factory=list,
title="Titles and descriptions for each section of the Wikipedia page.",
)
@property
def as_str(self) -> str:
sections = "\n\n".join(section.as_str for section in self.sections)
return f"# {self.page_title}\n\n{sections}".strip()
generate_outline_direct = direct_gen_outline_prompt | fast_llm.with_structured_output(
Outline
)
gen_related_topics_prompt = ChatPromptTemplate.from_template(
"""I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics.
Please list as many subjects and urls as you can.
Topic of interest: {topic}
"""
)
class RelatedSubjects(BaseModel):
topics: List[str] = Field(
description="Comprehensive list of related subjects as background research.",
)
expand_chain = gen_related_topics_prompt | fast_llm.with_structured_output(
RelatedSubjects
)
class Editor(BaseModel):
affiliation: str = Field(
description="Primary affiliation of the editor.",
)
name: str = Field(
description="Name of the editor.",
)
role: str = Field(
description="Role of the editor in the context of the topic.",
)
description: str = Field(
description="Description of the editor's focus, concerns, and motives.",
)
@property
def persona(self) -> str:
return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n"
class Perspectives(BaseModel):
editors: List[Editor] = Field(
description="Comprehensive list of editors with their roles and affiliations.",
# Add a pydantic validation/restriction to be at most M editors
)
gen_perspectives_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You need to select a diverse (and distinct) group of Wikipedia editors who will work together to create a comprehensive article on the topic. Each of them represents a different perspective, role, or affiliation related to this topic.\
You can use other Wikipedia pages of related topics for inspiration. For each editor, add a description of what they will focus on.
Wiki page outlines of related topics for inspiration:
{examples}""",
),
("user", "Topic of interest: {topic}"),
]
)
gen_perspectives_chain = gen_perspectives_prompt | ChatOpenAI(
model="gpt-3.5-turbo"
).with_structured_output(Perspectives)
wikipedia_retriever = WikipediaRetriever(load_all_available_meta=True, top_k_results=1)
def format_doc(doc, max_length=1000):
related = "- ".join(doc.metadata["categories"])
return f"### {doc.metadata['title']}\n\nSummary: {doc.page_content}\n\nRelated\n{related}"[
:max_length
]
def format_docs(docs):
return "\n\n".join(format_doc(doc) for doc in docs)
@as_runnable
async def survey_subjects(topic: str):
related_subjects = await expand_chain.ainvoke({"topic": topic})
retrieved_docs = await wikipedia_retriever.abatch(
related_subjects.topics, return_exceptions=True
)
all_docs = []
for docs in retrieved_docs:
if isinstance(docs, BaseException):
continue
all_docs.extend(docs)
formatted = format_docs(all_docs)
return await gen_perspectives_chain.ainvoke({"examples": formatted, "topic": topic})
def add_messages(left, right):
if not isinstance(left, list):
left = [left]
if not isinstance(right, list):
right = [right]
return left + right
def update_references(references, new_references):
if not references:
references = {}
references.update(new_references)
return references
def update_editor(editor, new_editor):
# Can only set at the outset
if not editor:
return new_editor
return editor
class InterviewState(TypedDict):
messages: Annotated[List[AnyMessage], add_messages]
references: Annotated[Optional[dict], update_references]
editor: Annotated[Optional[Editor], update_editor]
gen_qn_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are an experienced Wikipedia writer and want to edit a specific page. \
Besides your identity as a Wikipedia writer, you have a specific focus when researching the topic. \
Now, you are chatting with an expert to get information. Ask good questions to get more useful information.
When you have no more questions to ask, say "Thank you so much for your help!" to end the conversation.\
Please only ask one question at a time and don't ask what you have asked before.\
Your questions should be related to the topic you want to write.
Be comprehensive and curious, gaining as much unique insight from the expert as possible.\
Stay true to your specific perspective:
{persona}""",
),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
def tag_with_name(ai_message: AIMessage, name: str):
ai_message.name = name.replace(" ", "_").replace(".", "_")
return ai_message
def swap_roles(state: InterviewState, name: str):
converted = []
for message in state["messages"]:
if isinstance(message, AIMessage) and message.name != name:
message = HumanMessage(**message.dict(exclude={"type"}))
converted.append(message)
return {"messages": converted}
@as_runnable
async def generate_question(state: InterviewState):
editor = state["editor"]
gn_chain = (
RunnableLambda(swap_roles).bind(name=editor.name)
| gen_qn_prompt.partial(persona=editor.persona)
| fast_llm
| RunnableLambda(tag_with_name).bind(name=editor.name)
)
result = await gn_chain.ainvoke(state)
return {"messages": [result]}
class Queries(BaseModel):
queries: List[str] = Field(
description="Comprehensive list of search engine queries to answer the user's questions.",
)
gen_queries_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful research assistant. Query the search engine to answer the user's questions.",
),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
gen_queries_chain = gen_queries_prompt | ChatOpenAI(
model="gpt-3.5-turbo"
).with_structured_output(Queries, include_raw=True)
class AnswerWithCitations(BaseModel):
answer: str = Field(
description="Comprehensive answer to the user's question with citations.",
)
cited_urls: List[str] = Field(
description="List of urls cited in the answer.",
)
@property
def as_str(self) -> str:
return f"{self.answer}\n\nCitations:\n\n" + "\n".join(
f"[{i+1}]: {url}" for i, url in enumerate(self.cited_urls)
)
gen_answer_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants\
to write a Wikipedia page on the topic you know. You have gathered the related information and will now use the information to form a response.
Make your response as informative as possible and make sure every sentence is supported by the gathered information.
Each response must be backed up by a citation from a reliable source, formatted as a footnote, reproducing the URLS after your response.""",
),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
gen_answer_chain = gen_answer_prompt | fast_llm.with_structured_output(
AnswerWithCitations, include_raw=True
).with_config(run_name="GenerateAnswer")
# Tavily is typically a better search engine, but your free queries are limited
tavily_search = TavilySearchResults(max_results=4)
@tool
async def search_engine(query: str):
    """Search the internet for the given query."""
results = tavily_search.invoke(query)
return [{"content": r["content"], "url": r["url"]} for r in results]
async def gen_answer(
state: InterviewState,
config: Optional[RunnableConfig] = None,
name: str = "Subject_Matter_Expert",
max_str_len: int = 15000,
):
swapped_state = swap_roles(state, name) # Convert all other AI messages
queries = await gen_queries_chain.ainvoke(swapped_state)
query_results = await search_engine.abatch(
queries["parsed"].queries, config, return_exceptions=True
)
successful_results = [
res for res in query_results if not isinstance(res, Exception)
]
all_query_results = {
res["url"]: res["content"] for results in successful_results for res in results
}
# We could be more precise about handling max token length if we wanted to here
dumped = json.dumps(all_query_results)[:max_str_len]
ai_message: AIMessage = queries["raw"]
tool_call = queries["raw"].tool_calls[0]
tool_id = tool_call["id"]
tool_message = ToolMessage(tool_call_id=tool_id, content=dumped)
swapped_state["messages"].extend([ai_message, tool_message])
# Only update the shared state with the final answer to avoid
# polluting the dialogue history with intermediate messages
generated = await gen_answer_chain.ainvoke(swapped_state)
cited_urls = set(generated["parsed"].cited_urls)
    # Save the retrieved information to the shared state for future reference
cited_references = {k: v for k, v in all_query_results.items() if k in cited_urls}
formatted_message = AIMessage(name=name, content=generated["parsed"].as_str)
return {"messages": [formatted_message], "references": cited_references}
max_num_turns = 5
def route_messages(state: InterviewState, name: str = "Subject_Matter_Expert"):
messages = state["messages"]
num_responses = len(
[m for m in messages if isinstance(m, AIMessage) and m.name == name]
)
if num_responses >= max_num_turns:
return END
last_question = messages[-2]
if last_question.content.endswith("Thank you so much for your help!"):
return END
return "ask_question"
builder = StateGraph(InterviewState)
builder.add_node("ask_question", generate_question)
builder.add_node("answer_question", gen_answer)
builder.add_conditional_edges("answer_question", route_messages)
builder.add_edge("ask_question", "answer_question")
builder.set_entry_point("ask_question")
interview_graph = builder.compile().with_config(run_name="Conduct Interviews")
refine_outline_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a Wikipedia writer. You have gathered information from experts and search engines. Now, you are refining the outline of the Wikipedia page. \
You need to make sure that the outline is comprehensive and specific. \
Topic you are writing about: {topic}
Old outline:
{old_outline}""",
),
(
"user",
"Refine the outline based on your conversations with subject-matter experts:\n\nConversations:\n\n{conversations}\n\nWrite the refined Wikipedia outline:",
),
]
)
# Using turbo preview since the context can get quite long
refine_outline_chain = refine_outline_prompt | long_context_llm.with_structured_output(
Outline
)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# reference_docs = [
# Document(page_content=v, metadata={"source": k})
# for k, v in final_state["references"].items()
# ]
# # This really doesn't need to be a vectorstore for this size of data.
# # It could just be a numpy matrix. Or you could store documents
# # across requests if you want.
# vectorstore = SKLearnVectorStore.from_documents(
# reference_docs,
# embedding=embeddings,
# )
# retriever = vectorstore.as_retriever(k=10)
vectorstore = SKLearnVectorStore(embedding=embeddings)
retriever = vectorstore.as_retriever(k=10)
class SubSection(BaseModel):
subsection_title: str = Field(..., title="Title of the subsection")
content: str = Field(
...,
title="Full content of the subsection. Include [#] citations to the cited sources where relevant.",
)
@property
def as_str(self) -> str:
return f"### {self.subsection_title}\n\n{self.content}".strip()
class WikiSection(BaseModel):
section_title: str = Field(..., title="Title of the section")
content: str = Field(..., title="Full content of the section")
subsections: Optional[List[Subsection]] = Field(
default=None,
title="Titles and descriptions for each subsection of the Wikipedia page.",
)
citations: List[str] = Field(default_factory=list)
@property
def as_str(self) -> str:
subsections = "\n\n".join(
subsection.as_str for subsection in self.subsections or []
)
citations = "\n".join([f" [{i}] {cit}" for i, cit in enumerate(self.citations)])
return (
f"## {self.section_title}\n\n{self.content}\n\n{subsections}".strip()
+ f"\n\n{citations}".strip()
)
section_writer_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an expert Wikipedia writer. Complete your assigned WikiSection from the following outline:\n\n"
"{outline}\n\nCite your sources, using the following references:\n\n<Documents>\n{docs}\n<Documents>",
),
("user", "Write the full WikiSection for the {section} section."),
]
)
async def retrieve(inputs: dict):
docs = await retriever.ainvoke(inputs["topic"] + ": " + inputs["section"])
formatted = "\n".join(
[
f'<Document href="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>'
for doc in docs
]
)
return {"docs": formatted, **inputs}
section_writer = (
retrieve
| section_writer_prompt
| long_context_llm.with_structured_output(WikiSection)
)
writer_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an expert Wikipedia author. Write the complete wiki article on {topic} using the following section drafts:\n\n"
"{draft}\n\nStrictly follow Wikipedia format guidelines.",
),
(
"user",
'Write the complete Wiki article using markdown format. Organize citations using footnotes like "[1]",'
" avoiding duplicates in the footer. Include URLs in the footer.",
),
]
)
writer = writer_prompt | long_context_llm | StrOutputParser()
class ResearchState(TypedDict):
topic: str
outline: Outline
editors: List[Editor]
interview_results: List[InterviewState]
# The final sections output
sections: List[WikiSection]
article: str
async def initialize_research(state: ResearchState):
topic = state["topic"]
coros = (
generate_outline_direct.ainvoke({"topic": topic}),
survey_subjects.ainvoke(topic),
)
results = await asyncio.gather(*coros)
return {
**state,
"outline": results[0],
"editors": results[1].editors,
}
async def conduct_interviews(state: ResearchState):
topic = state["topic"]
initial_states = [
{
"editor": editor,
"messages": [
AIMessage(
content=f"So you said you were writing an article on {topic}?",
name="Subject_Matter_Expert",
)
],
}
for editor in state["editors"]
]
# We call in to the sub-graph here to parallelize the interviews
interview_results = await interview_graph.abatch(initial_states)
return {
**state,
"interview_results": interview_results,
}
def format_conversation(interview_state):
messages = interview_state["messages"]
convo = "\n".join(f"{m.name}: {m.content}" for m in messages)
return f'Conversation with {interview_state["editor"].name}\n\n' + convo
async def refine_outline(state: ResearchState):
convos = "\n\n".join(
[
format_conversation(interview_state)
for interview_state in state["interview_results"]
]
)
updated_outline = await refine_outline_chain.ainvoke(
{
"topic": state["topic"],
"old_outline": state["outline"].as_str,
"conversations": convos,
}
)
return {**state, "outline": updated_outline}
async def index_references(state: ResearchState):
all_docs = []
for interview_state in state["interview_results"]:
reference_docs = [
Document(page_content=v, metadata={"source": k})
for k, v in interview_state["references"].items()
]
all_docs.extend(reference_docs)
await vectorstore.aadd_documents(all_docs)
return state
async def write_sections(state: ResearchState):
outline = state["outline"]
sections = await section_writer.abatch(
[
{
"outline": outline.as_str,
"section": section.section_title,
"topic": state["topic"],
}
for section in outline.sections
]
)
return {
**state,
"sections": sections,
}
async def write_article(state: ResearchState):
topic = state["topic"]
sections = state["sections"]
draft = "\n\n".join([section.as_str for section in sections])
article = await writer.ainvoke({"topic": topic, "draft": draft})
return {
**state,
"article": article,
}
builder_of_storm = StateGraph(ResearchState)
nodes = [
("init_research", initialize_research),
("conduct_interviews", conduct_interviews),
("refine_outline", refine_outline),
("index_references", index_references),
("write_sections", write_sections),
("write_article", write_article),
]
for i in range(len(nodes)):
name, node = nodes[i]
builder_of_storm.add_node(name, node)
if i > 0:
builder_of_storm.add_edge(nodes[i - 1][0], name)
builder_of_storm.set_entry_point(nodes[0][0])
builder_of_storm.set_finish_point(nodes[-1][0])
graph = builder_of_storm.compile()
```
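For reference, here is a minimal sketch (not part of the original file) of how the compiled STORM graph above could be invoked; the topic string is a hypothetical example.

```py
# Hypothetical usage of the compiled `builder_of_storm` graph defined above.
import asyncio

async def run_storm_example():
    # Only "topic" needs to be supplied; the nodes fill in the rest of ResearchState.
    results = await graph.ainvoke({"topic": "Graph neural networks"})
    print(results["article"])

# asyncio.run(run_storm_example())
```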
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs/agent.py`:
```py
from typing import Annotated, Literal, Sequence, TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, add_messages
from langgraph.prebuilt import ToolNode
tools = [TavilySearchResults(max_results=1)]
model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
model_oai = ChatOpenAI(temperature=0)
model_anth = model_anth.bind_tools(tools)
model_oai = model_oai.bind_tools(tools)
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
# Define the function that determines whether to continue or not
def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
# If there are no tool calls, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise, if there are, we continue
else:
return "continue"
# Define the function that calls the model
def call_model(state, config):
if config["configurable"].get("model", "anthropic") == "anthropic":
model = model_anth
else:
model = model_oai
messages = state["messages"]
response = model.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}
# Define the function to execute tools
tool_node = ToolNode(tools)
class ConfigSchema(TypedDict):
model: Literal["anthropic", "openai"]
# Define a new graph
workflow = StateGraph(AgentState, config_schema=ConfigSchema)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If "continue", then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `action` back to `agent`.
# This means that after the tools are executed, the `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
graph = workflow.compile()
```
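A short usage sketch (not in the original file) showing how the `model` entry in the config selects between the two bound models; the question text is illustrative.

```py
# Hypothetical invocation of the compiled agent graph defined above.
from langchain_core.messages import HumanMessage

inputs = {"messages": [HumanMessage(content="What is LangGraph?")]}
# call_model defaults to "anthropic"; pass "openai" to use the OpenAI model instead.
result = graph.invoke(inputs, config={"configurable": {"model": "openai"}})
print(result["messages"][-1].content)
```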
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_a/hello.py`:
```py
from graphs_reqs_a.graphs_submod.agent import graph # noqa
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/cli/examples/graphs_reqs_a/graphs_submod/agent.py`:
```py
from pathlib import Path
from typing import Annotated, Sequence, TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, add_messages
from langgraph.prebuilt import ToolNode
tools = [TavilySearchResults(max_results=1)]
model_anth = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
model_oai = ChatOpenAI(temperature=0)
model_anth = model_anth.bind_tools(tools)
model_oai = model_oai.bind_tools(tools)
prompt = open("prompt.txt").read()
subprompt = open(Path(__file__).parent / "subprompt.txt").read()
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
# Define the function that determines whether to continue or not
def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
# If there are no tool calls, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise, if there are, we continue
else:
return "continue"
# Define the function that calls the model
def call_model(state, config):
if config["configurable"].get("model", "anthropic") == "anthropic":
model = model_anth
else:
model = model_oai
messages = state["messages"]
response = model.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}
# Define the function to execute tools
tool_node = ToolNode(tools)
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If "continue", then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `action` back to `agent`.
# This means that after the tools are executed, the `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
graph = workflow.compile()
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py`:
```py
import random
import sqlite3
import threading
from contextlib import closing, contextmanager
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
from langgraph.checkpoint.sqlite.utils import search_where
_AIO_ERROR_MSG = (
"The SqliteSaver does not support async methods. "
"Consider using AsyncSqliteSaver instead.\n"
"from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver\n"
"Note: AsyncSqliteSaver requires the aiosqlite package to use.\n"
"Install with:\n`pip install aiosqlite`\n"
"See https://langchain-ai.github.io/langgraph/reference/checkpoints/asyncsqlitesaver"
"for more information."
)
class SqliteSaver(BaseCheckpointSaver[str]):
"""A checkpoint saver that stores checkpoints in a SQLite database.
Note:
This class is meant for lightweight, synchronous use cases
(demos and small projects) and does not
scale to multiple threads.
For a similar sqlite saver with `async` support,
consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver].
Args:
conn (sqlite3.Connection): The SQLite database connection.
serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializer.
Examples:
>>> import sqlite3
>>> from langgraph.checkpoint.sqlite import SqliteSaver
>>> from langgraph.graph import StateGraph
>>>
>>> builder = StateGraph(int)
>>> builder.add_node("add_one", lambda x: x + 1)
>>> builder.set_entry_point("add_one")
>>> builder.set_finish_point("add_one")
>>> conn = sqlite3.connect("checkpoints.sqlite")
>>> memory = SqliteSaver(conn)
>>> graph = builder.compile(checkpointer=memory)
>>> config = {"configurable": {"thread_id": "1"}}
>>> graph.get_state(config)
>>> result = graph.invoke(3, config)
>>> graph.get_state(config)
StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '0c62ca34-ac19-445d-bbb0-5b4984975b2a'}}, parent_config=None)
""" # noqa
conn: sqlite3.Connection
is_setup: bool
def __init__(
self,
conn: sqlite3.Connection,
*,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
self.jsonplus_serde = JsonPlusSerializer()
self.conn = conn
self.is_setup = False
self.lock = threading.Lock()
@classmethod
@contextmanager
def from_conn_string(cls, conn_string: str) -> Iterator["SqliteSaver"]:
"""Create a new SqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string.
Yields:
SqliteSaver: A new SqliteSaver instance.
Examples:
In memory:
with SqliteSaver.from_conn_string(":memory:") as memory:
...
To disk:
with SqliteSaver.from_conn_string("checkpoints.sqlite") as memory:
...
"""
with closing(
sqlite3.connect(
conn_string,
# https://ricardoanderegg.com/posts/python-sqlite-thread-safety/
check_same_thread=False,
)
) as conn:
yield cls(conn)
def setup(self) -> None:
"""Set up the checkpoint database.
This method creates the necessary tables in the SQLite database if they don't
already exist. It is called automatically when needed and should not be called
directly by the user.
"""
if self.is_setup:
return
self.conn.executescript(
"""
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);
CREATE TABLE IF NOT EXISTS writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
value BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);
"""
)
self.is_setup = True
@contextmanager
def cursor(self, transaction: bool = True) -> Iterator[sqlite3.Cursor]:
"""Get a cursor for the SQLite database.
This method returns a cursor for the SQLite database. It is used internally
by the SqliteSaver and should not be called directly by the user.
Args:
transaction (bool): Whether to commit the transaction when the cursor is closed. Defaults to True.
Yields:
sqlite3.Cursor: A cursor for the SQLite database.
"""
with self.lock:
self.setup()
cur = self.conn.cursor()
try:
yield cur
finally:
if transaction:
self.conn.commit()
cur.close()
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the SQLite database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
Examples:
Basic:
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
With checkpoint ID:
>>> config = {
... "configurable": {
... "thread_id": "1",
... "checkpoint_ns": "",
... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
... }
... }
>>> checkpoint_tuple = memory.get_tuple(config)
>>> print(checkpoint_tuple)
CheckpointTuple(...)
""" # noqa
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
with self.cursor(transaction=False) as cur:
# find the latest checkpoint for the thread_id
if checkpoint_id := get_checkpoint_id(config):
cur.execute(
"SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
checkpoint_id,
),
)
else:
cur.execute(
"SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1",
(str(config["configurable"]["thread_id"]), checkpoint_ns),
)
# if a checkpoint is found, return it
if value := cur.fetchone():
(
thread_id,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
) = value
if not get_checkpoint_id(config):
config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
}
# find any pending writes
cur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
str(config["configurable"]["checkpoint_id"]),
),
)
# deserialize the checkpoint and metadata
return CheckpointTuple(
config,
self.serde.loads_typed((type, checkpoint)),
self.jsonplus_serde.loads(metadata) if metadata is not None else {},
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
for task_id, channel, type, value in cur
],
)
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database.
This method retrieves a list of checkpoint tuples from the SQLite database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (RunnableConfig): The config to use for listing the checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.
Yields:
Iterator[CheckpointTuple]: An iterator of checkpoint tuples.
Examples:
>>> from langgraph.checkpoint.sqlite import SqliteSaver
>>> with SqliteSaver.from_conn_string(":memory:") as memory:
... # Run a graph, then list the checkpoints
>>> config = {"configurable": {"thread_id": "1"}}
>>> checkpoints = list(memory.list(config, limit=2))
>>> print(checkpoints)
[CheckpointTuple(...), CheckpointTuple(...)]
>>> config = {"configurable": {"thread_id": "1"}}
>>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}}
>>> with SqliteSaver.from_conn_string(":memory:") as memory:
... # Run a graph, then list the checkpoints
>>> checkpoints = list(memory.list(config, before=before))
>>> print(checkpoints)
[CheckpointTuple(...), ...]
"""
where, param_values = search_where(config, filter, before)
query = f"""SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata
FROM checkpoints
{where}
ORDER BY checkpoint_id DESC"""
if limit:
query += f" LIMIT {limit}"
with self.cursor(transaction=False) as cur, closing(self.conn.cursor()) as wcur:
cur.execute(query, param_values)
for (
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
) in cur:
wcur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(thread_id, checkpoint_ns, checkpoint_id),
)
yield CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
self.serde.loads_typed((type, checkpoint)),
self.jsonplus_serde.loads(metadata) if metadata is not None else {},
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
for task_id, channel, type, value in wcur
],
)
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the SQLite database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Examples:
>>> from langgraph.checkpoint.sqlite import SqliteSaver
>>> with SqliteSaver.from_conn_string(":memory:") as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
>>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}}
>>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
>>> print(saved_config)
{'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}}
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
serialized_metadata = self.jsonplus_serde.dumps(metadata)
with self.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
checkpoint["id"],
config["configurable"].get("checkpoint_id"),
type_,
serialized_checkpoint,
serialized_metadata,
),
)
return {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the SQLite database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
query = (
"INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
if all(w[0] in WRITES_IDX_MAP for w in writes)
else "INSERT OR IGNORE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
)
with self.cursor() as cur:
cur.executemany(
query,
[
(
str(config["configurable"]["thread_id"]),
str(config["configurable"]["checkpoint_ns"]),
str(config["configurable"]["checkpoint_id"]),
task_id,
WRITES_IDX_MAP.get(channel, idx),
channel,
*self.serde.dumps_typed(value),
)
for idx, (channel, value) in enumerate(writes)
],
)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database asynchronously.
Note:
This async method is not supported by the SqliteSaver class.
Use get_tuple() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver].
"""
raise NotImplementedError(_AIO_ERROR_MSG)
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously.
Note:
This async method is not supported by the SqliteSaver class.
Use list() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver].
"""
raise NotImplementedError(_AIO_ERROR_MSG)
yield
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously.
Note:
This async method is not supported by the SqliteSaver class.
Use put() instead, or consider using [AsyncSqliteSaver][langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver].
"""
raise NotImplementedError(_AIO_ERROR_MSG)
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
"""Generate the next version ID for a channel.
This method creates a new version identifier for a channel based on its current version.
Args:
current (Optional[str]): The current version identifier of the channel.
channel (ChannelProtocol): The channel being versioned.
Returns:
str: The next version identifier, which is guaranteed to be monotonically increasing.
"""
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
```
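A brief sketch (not part of the module) of reading checkpoints back with SqliteSaver; it assumes a `checkpoints.sqlite` file populated by an earlier graph run with thread_id "1".

```py
# Hypothetical read-back of previously saved checkpoints.
from langgraph.checkpoint.sqlite import SqliteSaver

with SqliteSaver.from_conn_string("checkpoints.sqlite") as saver:
    config = {"configurable": {"thread_id": "1"}}
    latest = saver.get_tuple(config)             # newest checkpoint for the thread, or None
    history = list(saver.list(config, limit=5))  # up to 5 checkpoints, newest first
    print(latest)
    print(len(history))
```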
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/utils.py`:
```py
import json
from typing import Any, Dict, Optional, Sequence, Tuple
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import get_checkpoint_id
def _metadata_predicate(
metadata_filter: Dict[str, Any],
) -> Tuple[Sequence[str], Sequence[Any]]:
"""Return WHERE clause predicates for (a)search() given metadata filter.
This method returns a tuple of a string and a tuple of values. The string
is the parametered WHERE clause predicate (excluding the WHERE keyword):
"column1 = ? AND column2 IS ?". The tuple of values contains the values
for each of the corresponding parameters.
"""
def _where_value(query_value: Any) -> Tuple[str, Any]:
"""Return tuple of operator and value for WHERE clause predicate."""
if query_value is None:
return ("IS ?", None)
elif (
isinstance(query_value, str)
or isinstance(query_value, int)
or isinstance(query_value, float)
):
return ("= ?", query_value)
elif isinstance(query_value, bool):
return ("= ?", 1 if query_value else 0)
elif isinstance(query_value, dict) or isinstance(query_value, list):
# query value for JSON object cannot have trailing space after separators (, :)
# SQLite json_extract() returns JSON string without whitespace
return ("= ?", json.dumps(query_value, separators=(",", ":")))
else:
return ("= ?", str(query_value))
predicates = []
param_values = []
# process metadata query
for query_key, query_value in metadata_filter.items():
operator, param_value = _where_value(query_value)
predicates.append(
f"json_extract(CAST(metadata AS TEXT), '$.{query_key}') {operator}"
)
param_values.append(param_value)
return (predicates, param_values)
def search_where(
config: Optional[RunnableConfig],
filter: Optional[Dict[str, Any]],
before: Optional[RunnableConfig] = None,
) -> Tuple[str, Sequence[Any]]:
"""Return WHERE clause predicates for (a)search() given metadata filter
and `before` config.
This method returns a tuple of a string and a tuple of values. The string
is the parametered WHERE clause predicate (including the WHERE keyword):
"WHERE column1 = ? AND column2 IS ?". The tuple of values contains the
values for each of the corresponding parameters.
"""
wheres = []
param_values = []
# construct predicate for config filter
if config is not None:
wheres.append("thread_id = ?")
param_values.append(config["configurable"]["thread_id"])
checkpoint_ns = config["configurable"].get("checkpoint_ns")
if checkpoint_ns is not None:
wheres.append("checkpoint_ns = ?")
param_values.append(checkpoint_ns)
if checkpoint_id := get_checkpoint_id(config):
wheres.append("checkpoint_id = ?")
param_values.append(checkpoint_id)
# construct predicate for metadata filter
if filter:
metadata_predicates, metadata_values = _metadata_predicate(filter)
wheres.extend(metadata_predicates)
param_values.extend(metadata_values)
# construct predicate for `before`
if before is not None:
wheres.append("checkpoint_id < ?")
param_values.append(get_checkpoint_id(before))
return ("WHERE " + " AND ".join(wheres) if wheres else "", param_values)
```
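A small sketch (not in the original file) of the WHERE clause `search_where` builds for a thread filter combined with a metadata filter; the values are hypothetical.

```py
# Hypothetical demonstration of search_where with a config and a metadata filter.
from langgraph.checkpoint.sqlite.utils import search_where

where, params = search_where(
    {"configurable": {"thread_id": "thread-1"}},
    {"source": "loop", "step": 1},
)
# where  -> "WHERE thread_id = ? AND json_extract(CAST(metadata AS TEXT), '$.source') = ?
#            AND json_extract(CAST(metadata AS TEXT), '$.step') = ?"
# params -> ["thread-1", "loop", 1]
print(where, params)
```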
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py`:
```py
import asyncio
import random
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
Optional,
Sequence,
Tuple,
TypeVar,
)
import aiosqlite
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
from langgraph.checkpoint.sqlite.utils import search_where
T = TypeVar("T", bound=Callable)
class AsyncSqliteSaver(BaseCheckpointSaver[str]):
"""An asynchronous checkpoint saver that stores checkpoints in a SQLite database.
This class provides an asynchronous interface for saving and retrieving checkpoints
using a SQLite database. It's designed for use in asynchronous environments and
offers better performance for I/O-bound operations compared to synchronous alternatives.
Attributes:
conn (aiosqlite.Connection): The asynchronous SQLite database connection.
serde (SerializerProtocol): The serializer used for encoding/decoding checkpoints.
Tip:
Requires the [aiosqlite](https://pypi.org/project/aiosqlite/) package.
Install it with `pip install aiosqlite`.
Warning:
While this class supports asynchronous checkpointing, it is not recommended
for production workloads due to limitations in SQLite's write performance.
For production use, consider a more robust database like PostgreSQL.
Tip:
Remember to **close the database connection** after executing your code,
otherwise, you may see the graph "hang" after execution (since the program
will not exit until the connection is closed).
The easiest way is to use the `async with` statement as shown in the examples.
```python
async with AsyncSqliteSaver.from_conn_string("checkpoints.sqlite") as saver:
# Your code here
graph = builder.compile(checkpointer=saver)
config = {"configurable": {"thread_id": "thread-1"}}
async for event in graph.astream_events(..., config, version="v1"):
print(event)
```
Examples:
Usage within StateGraph:
```pycon
>>> import asyncio
>>>
>>> from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
>>> from langgraph.graph import StateGraph
>>>
>>> builder = StateGraph(int)
>>> builder.add_node("add_one", lambda x: x + 1)
>>> builder.set_entry_point("add_one")
>>> builder.set_finish_point("add_one")
>>> async with AsyncSqliteSaver.from_conn_string("checkpoints.db") as memory:
>>> graph = builder.compile(checkpointer=memory)
>>> coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
>>> print(asyncio.run(coro))
Output: 2
```
Raw usage:
```pycon
>>> import asyncio
>>> import aiosqlite
>>> from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
>>>
>>> async def main():
>>> async with aiosqlite.connect("checkpoints.db") as conn:
... saver = AsyncSqliteSaver(conn)
... config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
... checkpoint = {"ts": "2023-05-03T10:00:00Z", "id": "0c62ca34-ac19-445d-bbb0-5b4984975b2a", "channel_values": {"key": "value"}}
... saved_config = await saver.aput(config, checkpoint, {}, {})
... print(saved_config)
>>> asyncio.run(main())
{"configurable": {"thread_id": "1", "checkpoint_id": "0c62ca34-ac19-445d-bbb0-5b4984975b2a"}}
```
"""
lock: asyncio.Lock
is_setup: bool
def __init__(
self,
conn: aiosqlite.Connection,
*,
serde: Optional[SerializerProtocol] = None,
):
super().__init__(serde=serde)
self.jsonplus_serde = JsonPlusSerializer()
self.conn = conn
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self.is_setup = False
@classmethod
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str
) -> AsyncIterator["AsyncSqliteSaver"]:
"""Create a new AsyncSqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string.
Yields:
AsyncSqliteSaver: A new AsyncSqliteSaver instance.
"""
async with aiosqlite.connect(conn_string) as conn:
yield cls(conn)
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.
This method retrieves a checkpoint tuple from the SQLite database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
try:
# check if we are in the main thread, only bg threads can block
# we don't check in other methods to avoid the overhead
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncSqliteSaver are only allowed from a "
"different thread. From the main thread, use the async interface."
"For example, use `await checkpointer.aget_tuple(...)` or `await "
"graph.ainvoke(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aget_tuple(config), self.loop
).result()
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously.
This method retrieves a list of checkpoint tuples from the SQLite database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
aiter_ = self.alist(config, filter=filter, before=before, limit=limit)
while True:
try:
yield asyncio.run_coroutine_threadsafe(
anext(aiter_),
self.loop,
).result()
except StopAsyncIteration:
break
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the SQLite database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
return asyncio.run_coroutine_threadsafe(
self.aput(config, checkpoint, metadata, new_versions), self.loop
).result()
def put_writes(
self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str
) -> None:
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()
async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
This method creates the necessary tables in the SQLite database if they don't
already exist. It is called automatically when needed and should not be called
directly by the user.
"""
async with self.lock:
if self.is_setup:
return
if not self.conn.is_alive():
await self.conn
async with self.conn.executescript(
"""
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);
CREATE TABLE IF NOT EXISTS writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
value BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);
"""
):
await self.conn.commit()
self.is_setup = True
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database asynchronously.
This method retrieves a checkpoint tuple from the SQLite database based on the
provided config. If the config contains a "checkpoint_id" key, the checkpoint with
the matching thread ID and checkpoint ID is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config (RunnableConfig): The config to use for retrieving the checkpoint.
Returns:
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
await self.setup()
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
async with self.lock, self.conn.cursor() as cur:
# find the latest checkpoint for the thread_id
if checkpoint_id := get_checkpoint_id(config):
await cur.execute(
"SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
checkpoint_id,
),
)
else:
await cur.execute(
"SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1",
(str(config["configurable"]["thread_id"]), checkpoint_ns),
)
# if a checkpoint is found, return it
if value := await cur.fetchone():
(
thread_id,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
) = value
if not get_checkpoint_id(config):
config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
}
# find any pending writes
await cur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
str(config["configurable"]["checkpoint_id"]),
),
)
# deserialize the checkpoint and metadata
return CheckpointTuple(
config,
self.serde.loads_typed((type, checkpoint)),
self.jsonplus_serde.loads(metadata) if metadata is not None else {},
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
async for task_id, channel, type, value in cur
],
)
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously.
This method retrieves a list of checkpoint tuples from the SQLite database based
on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).
Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
limit (Optional[int]): Maximum number of checkpoints to return.
Yields:
AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples.
"""
await self.setup()
where, params = search_where(config, filter, before)
query = f"""SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata
FROM checkpoints
{where}
ORDER BY checkpoint_id DESC"""
if limit:
query += f" LIMIT {limit}"
async with self.lock, self.conn.execute(
query, params
) as cur, self.conn.cursor() as wcur:
async for (
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
) in cur:
await wcur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(thread_id, checkpoint_ns, checkpoint_id),
)
yield CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
self.serde.loads_typed((type, checkpoint)),
self.jsonplus_serde.loads(metadata) if metadata is not None else {},
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
async for task_id, channel, type, value in wcur
],
)
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously.
This method saves a checkpoint to the SQLite database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
await self.setup()
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
serialized_metadata = self.jsonplus_serde.dumps(metadata)
async with self.lock, self.conn.execute(
"INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)",
(
str(config["configurable"]["thread_id"]),
checkpoint_ns,
checkpoint["id"],
config["configurable"].get("checkpoint_id"),
type_,
serialized_checkpoint,
serialized_metadata,
),
):
await self.conn.commit()
return {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
This method saves intermediate writes associated with a checkpoint to the database.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
task_id (str): Identifier for the task creating the writes.
"""
query = (
"INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
if all(w[0] in WRITES_IDX_MAP for w in writes)
else "INSERT OR IGNORE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
)
await self.setup()
async with self.lock, self.conn.cursor() as cur:
await cur.executemany(
query,
[
(
str(config["configurable"]["thread_id"]),
str(config["configurable"]["checkpoint_ns"]),
str(config["configurable"]["checkpoint_id"]),
task_id,
WRITES_IDX_MAP.get(channel, idx),
channel,
*self.serde.dumps_typed(value),
)
for idx, (channel, value) in enumerate(writes)
],
)
def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
"""Generate the next version ID for a channel.
This method creates a new version identifier for a channel based on its current version.
Args:
current (Optional[str]): The current version identifier of the channel.
channel (ChannelProtocol): The channel being versioned.
Returns:
str: The next version identifier, which is guaranteed to be monotonically increasing.
"""
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
```
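A brief sketch (not part of the module) of listing checkpoints asynchronously with AsyncSqliteSaver; the database path and thread_id are hypothetical.

```py
# Hypothetical async read-back of checkpoints saved by an earlier run.
import asyncio

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

async def list_recent_checkpoints():
    async with AsyncSqliteSaver.from_conn_string("checkpoints.db") as saver:
        config = {"configurable": {"thread_id": "thread-1"}}
        async for ckpt in saver.alist(config, limit=3):
            print(ckpt.config["configurable"]["checkpoint_id"])

# asyncio.run(list_recent_checkpoints())
```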
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/pyproject.toml`:
```toml
[tool.poetry]
name = "langgraph-checkpoint-sqlite"
version = "2.0.1"
description = "Library with a SQLite implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
readme = "README.md"
repository = "https://www.github.com/langchain-ai/langgraph"
packages = [{ include = "langgraph" }]
[tool.poetry.dependencies]
python = "^3.9.0"
langgraph-checkpoint = "^2.0.2"
aiosqlite = "^0.20.0"
[tool.poetry.group.dev.dependencies]
ruff = "^0.6.2"
codespell = "^2.2.0"
pytest = "^7.2.1"
pytest-asyncio = "^0.21.1"
pytest-mock = "^3.11.1"
pytest-watcher = "^0.4.1"
mypy = "^1.10.0"
langgraph-checkpoint = {path = "../checkpoint", develop = true}
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5 -vv"
asyncio_mode = "auto"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
[tool.pytest-watcher]
now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short"]
patterns = ["*.py"]
[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/tests/test_aiosqlite.py`:
```py
from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
class TestAsyncSqliteSaver:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
self.config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
self.config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
self.chkpnt_1: Checkpoint = empty_checkpoint()
self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1)
self.chkpnt_3: Checkpoint = empty_checkpoint()
self.metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
self.metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
self.metadata_3: CheckpointMetadata = {}
async def test_asearch(self) -> None:
async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {})
await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {})
await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
assert len(search_results_3) == 3
search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = [
c
async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
]
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
# TODO: test before and limit params
```
`/Users/malcolm/dev/langchain-ai/langgraph/libs/checkpoint-sqlite/tests/test_sqlite.py`:
```py
from typing import Any, cast
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.sqlite.utils import _metadata_predicate, search_where
class TestSqliteSaver:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
"thread_id": "thread-1",
# for backwards compatibility testing
"thread_ts": "1",
"checkpoint_ns": "",
}
}
self.config_2: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2",
"checkpoint_ns": "",
}
}
self.config_3: RunnableConfig = {
"configurable": {
"thread_id": "thread-2",
"checkpoint_id": "2-inner",
"checkpoint_ns": "inner",
}
}
self.chkpnt_1: Checkpoint = empty_checkpoint()
self.chkpnt_2: Checkpoint = create_checkpoint(self.chkpnt_1, {}, 1)
self.chkpnt_3: Checkpoint = empty_checkpoint()
self.metadata_1: CheckpointMetadata = {
"source": "input",
"step": 2,
"writes": {},
"score": 1,
}
self.metadata_2: CheckpointMetadata = {
"source": "loop",
"step": 1,
"writes": {"foo": "bar"},
"score": None,
}
self.metadata_3: CheckpointMetadata = {}
def test_search(self) -> None:
with SqliteSaver.from_conn_string(":memory:") as saver:
# set up test
# save checkpoints
saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})
# call method / assertions
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match
search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
assert search_results_1[0].metadata == self.metadata_1
search_results_2 = list(saver.list(None, filter=query_2))
assert len(search_results_2) == 1
assert search_results_2[0].metadata == self.metadata_2
search_results_3 = list(saver.list(None, filter=query_3))
assert len(search_results_3) == 3
search_results_4 = list(saver.list(None, filter=query_4))
assert len(search_results_4) == 0
# search by config (defaults to checkpoints across all namespaces)
search_results_5 = list(
saver.list({"configurable": {"thread_id": "thread-2"}})
)
assert len(search_results_5) == 2
assert {
search_results_5[0].config["configurable"]["checkpoint_ns"],
search_results_5[1].config["configurable"]["checkpoint_ns"],
} == {"", "inner"}
# TODO: test before and limit params
def test_search_where(self) -> None:
# call method / assertions
expected_predicate_1 = "WHERE json_extract(CAST(metadata AS TEXT), '$.source') = ? AND json_extract(CAST(metadata AS TEXT), '$.step') = ? AND json_extract(CAST(metadata AS TEXT), '$.writes') = ? AND json_extract(CAST(metadata AS TEXT), '$.score') = ? AND checkpoint_id < ?"
expected_param_values_1 = ["input", 2, "{}", 1, "1"]
assert search_where(
None, cast(dict[str, Any], self.metadata_1), self.config_1
) == (
expected_predicate_1,
expected_param_values_1,
)
def test_metadata_predicate(self) -> None:
# call method / assertions
expected_predicate_1 = [
"json_extract(CAST(metadata AS TEXT), '$.source') = ?",
"json_extract(CAST(metadata AS TEXT), '$.step') = ?",
"json_extract(CAST(metadata AS TEXT), '$.writes') = ?",
"json_extract(CAST(metadata AS TEXT), '$.score') = ?",
]
expected_predicate_2 = [
"json_extract(CAST(metadata AS TEXT), '$.source') = ?",
"json_extract(CAST(metadata AS TEXT), '$.step') = ?",
"json_extract(CAST(metadata AS TEXT), '$.writes') = ?",
"json_extract(CAST(metadata AS TEXT), '$.score') IS ?",
]
expected_predicate_3: list[str] = []
expected_param_values_1 = ["input", 2, "{}", 1]
expected_param_values_2 = ["loop", 1, '{"foo":"bar"}', None]
expected_param_values_3: list[Any] = []
assert _metadata_predicate(cast(dict[str, Any], self.metadata_1)) == (
expected_predicate_1,
expected_param_values_1,
)
assert _metadata_predicate(cast(dict[str, Any], self.metadata_2)) == (
expected_predicate_2,
expected_param_values_2,
)
assert _metadata_predicate(cast(dict[str, Any], self.metadata_3)) == (
expected_predicate_3,
expected_param_values_3,
)
async def test_informative_async_errors(self) -> None:
with SqliteSaver.from_conn_string(":memory:") as saver:
# call method / assertions
with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"):
await saver.aget(self.config_1)
with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"):
await saver.aget_tuple(self.config_1)
with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"):
async for _ in saver.alist(self.config_1):
pass
```
`/Users/malcolm/dev/langchain-ai/langgraph/examples/chatbot-simulation-evaluation/simulation_utils.py`:
```py
import functools
from typing import Annotated, Any, Callable, Dict, List, Optional, Union
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables import chain as as_runnable
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
def langchain_to_openai_messages(messages: List[BaseMessage]):
"""
Convert a list of LangChain base messages to a list of OpenAI messages.
Parameters:
messages (List[BaseMessage]): A list of langchain base messages.
Returns:
List[dict]: A list of openai messages.
"""
return [
convert_message_to_dict(m) if isinstance(m, BaseMessage) else m
for m in messages
]
def create_simulated_user(
system_prompt: str, llm: Runnable | None = None
) -> Runnable[Dict, AIMessage]:
"""
Creates a simulated user for chatbot simulation.
Args:
system_prompt (str): The system prompt to be used by the simulated user.
llm (Runnable | None, optional): The language model to be used for the simulation.
Defaults to gpt-3.5-turbo.
Returns:
Runnable[Dict, AIMessage]: The simulated user for chatbot simulation.
"""
return ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder(variable_name="messages"),
]
) | (llm or ChatOpenAI(model="gpt-3.5-turbo")).with_config(
run_name="simulated_user"
)
Messages = Union[list[AnyMessage], AnyMessage]
def add_messages(left: Messages, right: Messages) -> Messages:
if not isinstance(left, list):
left = [left]
if not isinstance(right, list):
right = [right]
return left + right
class SimulationState(TypedDict):
"""
Represents the state of a simulation.
Attributes:
messages (List[AnyMessage]): A list of messages in the simulation.
inputs (Optional[dict[str, Any]]): Optional inputs for the simulation.
"""
messages: Annotated[List[AnyMessage], add_messages]
inputs: Optional[dict[str, Any]]
def create_chat_simulator(
assistant: (
Callable[[List[AnyMessage]], str | AIMessage]
| Runnable[List[AnyMessage], str | AIMessage]
),
simulated_user: Runnable[Dict, AIMessage],
*,
input_key: str,
max_turns: int = 6,
should_continue: Optional[Callable[[SimulationState], str]] = None,
):
"""Creates a chat simulator for evaluating a chatbot.
Args:
assistant: The chatbot assistant function or runnable object.
simulated_user: The simulated user object.
input_key: The key for the input to the chat simulation.
max_turns: The maximum number of turns in the chat simulation. Default is 6.
should_continue: Optional function to determine if the simulation should continue.
If not provided, a default function will be used.
Returns:
The compiled chat simulation graph.
"""
graph_builder = StateGraph(SimulationState)
graph_builder.add_node(
"user",
_create_simulated_user_node(simulated_user),
)
graph_builder.add_node(
"assistant", _fetch_messages | assistant | _coerce_to_message
)
graph_builder.add_edge("assistant", "user")
graph_builder.add_conditional_edges(
"user",
should_continue or functools.partial(_should_continue, max_turns=max_turns),
)
# If your dataset has a 'leading question/input', then we route first to the assistant, otherwise, we let the user take the lead.
graph_builder.add_edge(START, "assistant" if input_key is not None else "user")
return (
RunnableLambda(_prepare_example).bind(input_key=input_key)
| graph_builder.compile()
)
## Private methods
def _prepare_example(inputs: dict[str, Any], input_key: Optional[str] = None):
if input_key is not None:
if input_key not in inputs:
raise ValueError(
f"Dataset's example input must contain the provided input key: '{input_key}'.\nFound: {list(inputs.keys())}"
)
messages = [HumanMessage(content=inputs[input_key])]
return {
"inputs": {k: v for k, v in inputs.items() if k != input_key},
"messages": messages,
}
return {"inputs": inputs, "messages": []}
def _invoke_simulated_user(state: SimulationState, simulated_user: Runnable):
"""Invoke the simulated user node."""
runnable = (
simulated_user
if isinstance(simulated_user, Runnable)
else RunnableLambda(simulated_user)
)
inputs = state.get("inputs", {})
inputs["messages"] = state["messages"]
return runnable.invoke(inputs)
def _swap_roles(state: SimulationState):
new_messages = []
for m in state["messages"]:
if isinstance(m, AIMessage):
new_messages.append(HumanMessage(content=m.content))
else:
new_messages.append(AIMessage(content=m.content))
return {
"inputs": state.get("inputs", {}),
"messages": new_messages,
}
@as_runnable
def _fetch_messages(state: SimulationState):
"""Invoke the simulated user node."""
return state["messages"]
def _convert_to_human_message(message: BaseMessage):
return {"messages": [HumanMessage(content=message.content)]}
def _create_simulated_user_node(simulated_user: Runnable):
"""Simulated user accepts a {"messages": [...]} argument and returns a single message."""
return (
_swap_roles
| RunnableLambda(_invoke_simulated_user).bind(simulated_user=simulated_user)
| _convert_to_human_message
)
def _coerce_to_message(assistant_output: str | BaseMessage):
if isinstance(assistant_output, str):
return {"messages": [AIMessage(content=assistant_output)]}
else:
return {"messages": [assistant_output]}
def _should_continue(state: SimulationState, max_turns: int = 6):
messages = state["messages"]
# TODO support other stop criteria
if len(messages) > max_turns:
return END
elif messages[-1].content.strip() == "FINISHED":
return END
else:
return "assistant"
```
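A minimal sketch (not part of the original module) of wiring these helpers together; the system prompt, dataset key, and echo assistant are hypothetical stand-ins, and it assumes an OpenAI API key is configured for the default simulated-user model.

```py
# Hypothetical end-to-end use of create_simulated_user / create_chat_simulator.
from langchain_core.messages import AIMessage

def my_assistant(messages):
    # Stand-in chatbot used purely for illustration.
    return AIMessage(content="Thanks for reaching out! How can I help?")

simulated_user = create_simulated_user(
    "You are a customer trying to get a refund for a flight you missed."
)
simulator = create_chat_simulator(
    my_assistant,
    simulated_user,
    input_key="input",
    max_turns=4,
)
final_state = simulator.invoke({"input": "I need a refund for my flight."})
for message in final_state["messages"]:
    print(f"{type(message).__name__}: {message.content}")
```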