diff --git a/libs/langgraph/langgraph/channels/delta.py b/libs/langgraph/langgraph/channels/delta.py index 052ced50df5..0523d55bde9 100644 --- a/libs/langgraph/langgraph/channels/delta.py +++ b/libs/langgraph/langgraph/channels/delta.py @@ -177,8 +177,7 @@ def update(self, values: Sequence[Any]) -> bool: if overwrite_value is not None else self.typ() ) - remaining = [v for i, v in enumerate(values) if i != overwrite_idx] - self.value = self.reducer(base, remaining) if remaining else base + self.value = base return True base = self.typ() if self.value is MISSING else self.value self.value = self.reducer(base, list(values)) diff --git a/libs/langgraph/tests/test_channels.py b/libs/langgraph/tests/test_channels.py index 9dfa7158a41..0df856fae13 100644 --- a/libs/langgraph/tests/test_channels.py +++ b/libs/langgraph/tests/test_channels.py @@ -186,6 +186,18 @@ def test_delta_channel_overwrite() -> None: assert ch.get()[0].content == "new" +def test_delta_channel_overwrite_bypasses_same_step_reducer_writes() -> None: + def list_reducer(state: list, writes: list) -> list: + out = list(state) + for w in writes: + out.extend(w) + return out + + ch = DeltaChannel(list_reducer, list).from_checkpoint(MISSING) + ch.update([[1], Overwrite([50]), [2]]) + assert ch.get() == [50] + + def test_delta_channel_remove_message_and_replay() -> None: """RemoveMessage must round-trip correctly when writes are replayed.""" spec = DeltaChannel(_messages_delta_reducer, list) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 0aae1318e13..f44a4b00b05 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -9281,6 +9281,13 @@ def tool_node(state: State) -> State: assert state.values.get("dictionary") == {"session_resource": "legal_value"} +def _delta_list_reducer(state: list, writes: Sequence[list]) -> list: + out = list(state) + for write in writes: + out.extend(write) + return out + + @pytest.mark.parametrize("as_json", [False, True]) def test_overwrite_sequential( sync_checkpointer: BaseCheckpointSaver, as_json: bool @@ -9388,6 +9395,105 @@ def node_c(state: State): graph.invoke({"messages": ["START"]}, config) +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_sequential( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-sequential"}} + result = graph.invoke({"messages": ["START"]}, config) + assert result == {"messages": ["b"]} + + +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_parallel( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + def node_c(state: State): + return {"messages": ["c"]} + + def node_d(state: State): + return {"messages": ["d"]} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_node("node_c", node_c) + builder.add_node("node_d", node_d) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + builder.add_edge("node_a", "node_c") + builder.add_edge("node_b", "node_d") + builder.add_edge("node_c", "node_d") + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-parallel"}} + result = graph.invoke({"messages": ["START"]}, config) + assert result == {"messages": ["b", "d"]} + + +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_parallel_error( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + def node_c(state: State): + overwrite = {"__overwrite__": ["c"]} if as_json else Overwrite(["c"]) + return {"messages": overwrite} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_node("node_c", node_c) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + builder.add_edge("node_a", "node_c") + builder.add_edge("node_b", END) + builder.add_edge("node_c", END) + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-parallel-error"}} + with pytest.raises( + InvalidUpdateError, match="Can receive only one Overwrite value per super-step." + ): + graph.invoke({"messages": ["START"]}, config) + + def test_fork_does_not_apply_pending_writes( sync_checkpointer: BaseCheckpointSaver, ) -> None: