Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/channels/delta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

import collections.abc
Expand Down Expand Up @@ -177,8 +177,7 @@
if overwrite_value is not None
else self.typ()
)
remaining = [v for i, v in enumerate(values) if i != overwrite_idx]
self.value = self.reducer(base, remaining) if remaining else base
self.value = base

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Checkpoint replay keeps discarded writes

This makes live execution treat an overwrite as a hard reset for the whole superstep, but the checkpoint history/replay path still records and replays the other writes from that same superstep. For example, a parallel step that writes Overwrite(["b"]) and ["c"] now returns ["b"] live at that step, but after the next checkpoint reload DeltaChannel.replay_writes() sees both writes and reconstructs ["b", "c"] (and subsequent deltas build on that wrong state). I reproduced this with the new parallel DeltaChannel graph shape: invoke() returned {'messages': ['b', 'd']} while get_state() reconstructed {'messages': ['b', 'c', 'd']}. Please update the replay/history side to drop the same-step deltas that live execution now discards, or otherwise make replay use the same hard-reset semantics.

(Refers to line 180)


Your feedback helps Open SWE learn. React with 👍 or 👎 to tell us if this review comment was useful.

return True
base = self.typ() if self.value is MISSING else self.value
self.value = self.reducer(base, list(values))
Expand Down
12 changes: 12 additions & 0 deletions libs/langgraph/tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,18 @@ def test_delta_channel_overwrite() -> None:
assert ch.get()[0].content == "new"


def test_delta_channel_overwrite_bypasses_same_step_reducer_writes() -> None:
def list_reducer(state: list, writes: list) -> list:
out = list(state)
for w in writes:
out.extend(w)
return out

ch = DeltaChannel(list_reducer, list).from_checkpoint(MISSING)
ch.update([[1], Overwrite([50]), [2]])
assert ch.get() == [50]


def test_delta_channel_remove_message_and_replay() -> None:
"""RemoveMessage must round-trip correctly when writes are replayed."""
spec = DeltaChannel(_messages_delta_reducer, list)
Expand Down
106 changes: 106 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9281,6 +9281,13 @@ def tool_node(state: State) -> State:
assert state.values.get("dictionary") == {"session_resource": "legal_value"}


def _delta_list_reducer(state: list, writes: Sequence[list]) -> list:
out = list(state)
for write in writes:
out.extend(write)
return out


@pytest.mark.parametrize("as_json", [False, True])
def test_overwrite_sequential(
sync_checkpointer: BaseCheckpointSaver, as_json: bool
Expand Down Expand Up @@ -9388,6 +9395,105 @@ def node_c(state: State):
graph.invoke({"messages": ["START"]}, config)


@pytest.mark.parametrize("as_json", [False, True])
def test_delta_channel_overwrite_sequential(
sync_checkpointer: BaseCheckpointSaver, as_json: bool
) -> None:
class State(TypedDict):
messages: Annotated[list, DeltaChannel(_delta_list_reducer)]

def node_a(state: State):
return {"messages": ["a"]}

def node_b(state: State):
overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"])
return {"messages": overwrite}

builder = StateGraph(State)
builder.add_node("node_a", node_a)
builder.add_node("node_b", node_b)
builder.add_edge(START, "node_a")
builder.add_edge("node_a", "node_b")

graph = builder.compile(checkpointer=sync_checkpointer)
config = {"configurable": {"thread_id": "delta-overwrite-sequential"}}
result = graph.invoke({"messages": ["START"]}, config)
assert result == {"messages": ["b"]}


@pytest.mark.parametrize("as_json", [False, True])
def test_delta_channel_overwrite_parallel(
sync_checkpointer: BaseCheckpointSaver, as_json: bool
) -> None:
class State(TypedDict):
messages: Annotated[list, DeltaChannel(_delta_list_reducer)]

def node_a(state: State):
return {"messages": ["a"]}

def node_b(state: State):
overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"])
return {"messages": overwrite}

def node_c(state: State):
return {"messages": ["c"]}

def node_d(state: State):
return {"messages": ["d"]}

builder = StateGraph(State)
builder.add_node("node_a", node_a)
builder.add_node("node_b", node_b)
builder.add_node("node_c", node_c)
builder.add_node("node_d", node_d)
builder.add_edge(START, "node_a")
builder.add_edge("node_a", "node_b")
builder.add_edge("node_a", "node_c")
builder.add_edge("node_b", "node_d")
builder.add_edge("node_c", "node_d")

graph = builder.compile(checkpointer=sync_checkpointer)
config = {"configurable": {"thread_id": "delta-overwrite-parallel"}}
result = graph.invoke({"messages": ["START"]}, config)
assert result == {"messages": ["b", "d"]}


@pytest.mark.parametrize("as_json", [False, True])
def test_delta_channel_overwrite_parallel_error(
sync_checkpointer: BaseCheckpointSaver, as_json: bool
) -> None:
class State(TypedDict):
messages: Annotated[list, DeltaChannel(_delta_list_reducer)]

def node_a(state: State):
return {"messages": ["a"]}

def node_b(state: State):
overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"])
return {"messages": overwrite}

def node_c(state: State):
overwrite = {"__overwrite__": ["c"]} if as_json else Overwrite(["c"])
return {"messages": overwrite}

builder = StateGraph(State)
builder.add_node("node_a", node_a)
builder.add_node("node_b", node_b)
builder.add_node("node_c", node_c)
builder.add_edge(START, "node_a")
builder.add_edge("node_a", "node_b")
builder.add_edge("node_a", "node_c")
builder.add_edge("node_b", END)
builder.add_edge("node_c", END)

graph = builder.compile(checkpointer=sync_checkpointer)
config = {"configurable": {"thread_id": "delta-overwrite-parallel-error"}}
with pytest.raises(
InvalidUpdateError, match="Can receive only one Overwrite value per super-step."
):
graph.invoke({"messages": ["START"]}, config)


def test_fork_does_not_apply_pending_writes(
sync_checkpointer: BaseCheckpointSaver,
) -> None:
Expand Down
Loading