diff --git a/libs/langgraph/langgraph/channels/delta.py b/libs/langgraph/langgraph/channels/delta.py index 052ced50df5..eecf954770b 100644 --- a/libs/langgraph/langgraph/channels/delta.py +++ b/libs/langgraph/langgraph/channels/delta.py @@ -172,13 +172,11 @@ def update(self, values: Sequence[Any]) -> bool: overwrite_idx = i if overwrite_idx is not None: _, overwrite_value = _get_overwrite(values[overwrite_idx]) - base = ( + self.value = ( _copy.copy(overwrite_value) if overwrite_value is not None else self.typ() ) - remaining = [v for i, v in enumerate(values) if i != overwrite_idx] - self.value = self.reducer(base, remaining) if remaining else base return True base = self.typ() if self.value is MISSING else self.value self.value = self.reducer(base, list(values)) diff --git a/libs/langgraph/langgraph/pregel/_loop.py b/libs/langgraph/langgraph/pregel/_loop.py index 587f5905ae4..af8d1664dfa 100644 --- a/libs/langgraph/langgraph/pregel/_loop.py +++ b/libs/langgraph/langgraph/pregel/_loop.py @@ -70,6 +70,7 @@ GraphResumeEvent, ) from langgraph.channels.base import BaseChannel +from langgraph.channels.binop import _get_overwrite from langgraph.channels.delta import DeltaChannel from langgraph.channels.untracked_value import UntrackedValue from langgraph.constants import TAG_HIDDEN @@ -220,6 +221,11 @@ class PregelLoop: # under the saver's `ORDER BY task_id, idx` sorting. _exit_delta_writes: list[tuple[int, str, str, Any]] | None = None + # Delta channels that saw an Overwrite since the last checkpoint. These + # channels must snapshot after live update applies overwrite semantics so + # sparse replay starts from the same post-overwrite value. + _delta_channels_with_overwrite: set[str] + # The checkpoint_config that points at the parent loaded at `__enter__` # (or the synthetic-empty checkpoint, on first run). We capture it # eagerly because every `_put_checkpoint` advances `self.checkpoint_config` @@ -676,6 +682,11 @@ def tick(self) -> bool: def after_tick(self) -> None: # finish superstep writes = [w for t in self.tasks.values() for w in t.writes] + self._delta_channels_with_overwrite.update( + ch + for ch, v in writes + if isinstance(self.specs.get(ch), DeltaChannel) and _get_overwrite(v)[0] + ) # all tasks have finished self.updated_channels = apply_writes( self.checkpoint, @@ -979,6 +990,11 @@ def _first( manager=None, updated_channels=updated_channels, ) + self._delta_channels_with_overwrite.update( + c + for c, v in input_writes + if isinstance(self.specs.get(c), DeltaChannel) and _get_overwrite(v)[0] + ) # apply input writes updated_channels = apply_writes( self.checkpoint, @@ -1119,6 +1135,7 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: # create new checkpoint channels_to_snapshot = ( delta_channels_to_snapshot(self.channels, new_counters) + | self._delta_channels_with_overwrite if do_checkpoint else set() ) @@ -1135,6 +1152,8 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: ) for k in channels_to_snapshot: new_counters[k] = (0, 0) + if do_checkpoint: + self._delta_channels_with_overwrite.difference_update(channels_to_snapshot) non_zero = {k: v for k, v in new_counters.items() if v != (0, 0)} if non_zero: self.checkpoint_metadata["counters_since_delta_snapshot"] = non_zero @@ -1217,7 +1236,10 @@ def _put_exit_delta_writes(self) -> None: counters = dict( self.checkpoint_metadata.get("counters_since_delta_snapshot") or {} ) - channels_to_snapshot = delta_channels_to_snapshot(self.channels, counters) + channels_to_snapshot = ( + delta_channels_to_snapshot(self.channels, counters) + | self._delta_channels_with_overwrite + ) pending = [ (step, tid, ch, v) @@ -1661,6 +1683,7 @@ def __enter__(self) -> Self: ) self._delta_write_futs = [] self._error_handler_write_futs = [] + self._delta_channels_with_overwrite = set() self._exit_delta_writes = ( [] if self.durability == "exit" and self.checkpointer is not None else None ) @@ -1918,6 +1941,7 @@ async def __aenter__(self) -> Self: ) self._delta_write_futs = [] self._error_handler_write_futs = [] + self._delta_channels_with_overwrite = set() self._exit_delta_writes = ( [] if self.durability == "exit" and self.checkpointer is not None else None ) diff --git a/libs/langgraph/tests/test_channels.py b/libs/langgraph/tests/test_channels.py index 9dfa7158a41..d194f55fdd5 100644 --- a/libs/langgraph/tests/test_channels.py +++ b/libs/langgraph/tests/test_channels.py @@ -397,6 +397,93 @@ def respond(state: State) -> dict: assert len(state.values["messages"]) == 4 # 2 human + 2 AI +def test_delta_channel_overwrite_superstep_snapshots() -> None: + def reducer(state: list[str], writes: Sequence[list[str]]) -> list[str]: + result = list(state) + for write in writes: + result.extend(write) + return result + + class State(TypedDict): + items: Annotated[ + list[str], DeltaChannel(reducer, list, snapshot_frequency=1000) + ] + + def node_a(state: State) -> dict: + return {"items": ["a"]} + + def node_b(state: State) -> dict: + return {"items": Overwrite(["b"])} + + def node_c(state: State) -> dict: + return {"items": ["c"]} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_node("node_c", node_c) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + builder.add_edge("node_a", "node_c") + + saver = InMemorySaver() + graph = builder.compile(checkpointer=saver) + config = {"configurable": {"thread_id": "overwrite-snapshot"}} + + result = graph.invoke({"items": ["START"]}, config) + assert result == {"items": ["b"]} + + saved = saver.get_tuple(config) + assert saved is not None + snapshot = saved.checkpoint["channel_values"].get("items") + assert isinstance(snapshot, _DeltaSnapshot) + assert snapshot.value == ["b"] + assert saved.metadata.get("counters_since_delta_snapshot", {}).get("items") is None + + +def test_delta_channel_replay_after_overwrite_snapshot() -> None: + def reducer(state: list[str], writes: Sequence[list[str]]) -> list[str]: + result = list(state) + for write in writes: + result.extend(write) + return result + + class State(TypedDict): + items: Annotated[ + list[str], DeltaChannel(reducer, list, snapshot_frequency=1000) + ] + + calls = 0 + + def node(state: State) -> dict: + nonlocal calls + calls += 1 + if calls == 1: + return {"items": Overwrite(["reset"])} + return {"items": ["after"]} + + builder = StateGraph(State) + builder.add_node("node", node) + builder.add_edge(START, "node") + + saver = InMemorySaver() + graph = builder.compile(checkpointer=saver) + config = {"configurable": {"thread_id": "overwrite-replay"}} + + assert graph.invoke({"items": ["before"]}, config) == {"items": ["reset"]} + first_saved = saver.get_tuple(config) + assert first_saved is not None + assert isinstance( + first_saved.checkpoint["channel_values"].get("items"), _DeltaSnapshot + ) + + assert graph.invoke({"items": []}, config) == {"items": ["reset", "after"]} + second_saved = saver.get_tuple(config) + assert second_saved is not None + assert "items" not in second_saved.checkpoint["channel_values"] + assert graph.get_state(config).values == {"items": ["reset", "after"]} + + # --------------------------------------------------------------------------- # DeltaChannel — dict reducer # ---------------------------------------------------------------------------