Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions libs/langgraph/langgraph/channels/delta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

import collections.abc
Expand Down Expand Up @@ -172,13 +172,11 @@
overwrite_idx = i
if overwrite_idx is not None:
_, overwrite_value = _get_overwrite(values[overwrite_idx])
base = (
self.value = (
_copy.copy(overwrite_value)
if overwrite_value is not None
else self.typ()
)
remaining = [v for i, v in enumerate(values) if i != overwrite_idx]
self.value = self.reducer(base, remaining) if remaining else base
return True
base = self.typ() if self.value is MISSING else self.value
self.value = self.reducer(base, list(values))
Expand Down
26 changes: 25 additions & 1 deletion libs/langgraph/langgraph/pregel/_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
GraphResumeEvent,
)
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import _get_overwrite
from langgraph.channels.delta import DeltaChannel
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.constants import TAG_HIDDEN
Expand Down Expand Up @@ -220,6 +221,11 @@ class PregelLoop:
# under the saver's `ORDER BY task_id, idx` sorting.
_exit_delta_writes: list[tuple[int, str, str, Any]] | None = None

# Delta channels that saw an Overwrite since the last checkpoint. These
# channels must snapshot after live update applies overwrite semantics so
# sparse replay starts from the same post-overwrite value.
_delta_channels_with_overwrite: set[str]

# The checkpoint_config that points at the parent loaded at `__enter__`
# (or the synthetic-empty checkpoint, on first run). We capture it
# eagerly because every `_put_checkpoint` advances `self.checkpoint_config`
Expand Down Expand Up @@ -676,6 +682,11 @@ def tick(self) -> bool:
def after_tick(self) -> None:
# finish superstep
writes = [w for t in self.tasks.values() for w in t.writes]
self._delta_channels_with_overwrite.update(
ch
for ch, v in writes
if isinstance(self.specs.get(ch), DeltaChannel) and _get_overwrite(v)[0]
)
# all tasks have finished
self.updated_channels = apply_writes(
self.checkpoint,
Expand Down Expand Up @@ -979,6 +990,11 @@ def _first(
manager=None,
updated_channels=updated_channels,
)
self._delta_channels_with_overwrite.update(
c
for c, v in input_writes
if isinstance(self.specs.get(c), DeltaChannel) and _get_overwrite(v)[0]
)
# apply input writes
updated_channels = apply_writes(
self.checkpoint,
Expand Down Expand Up @@ -1119,6 +1135,7 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
# create new checkpoint
channels_to_snapshot = (
delta_channels_to_snapshot(self.channels, new_counters)
| self._delta_channels_with_overwrite
if do_checkpoint
else set()
)
Expand All @@ -1135,6 +1152,8 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
)
for k in channels_to_snapshot:
new_counters[k] = (0, 0)
if do_checkpoint:
self._delta_channels_with_overwrite.difference_update(channels_to_snapshot)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's the same if we just clear it to empty set?
self._delta_channels_with_overwrite = set()?

non_zero = {k: v for k, v in new_counters.items() if v != (0, 0)}
if non_zero:
self.checkpoint_metadata["counters_since_delta_snapshot"] = non_zero
Expand Down Expand Up @@ -1217,7 +1236,10 @@ def _put_exit_delta_writes(self) -> None:
counters = dict(
self.checkpoint_metadata.get("counters_since_delta_snapshot") or {}
)
channels_to_snapshot = delta_channels_to_snapshot(self.channels, counters)
channels_to_snapshot = (
delta_channels_to_snapshot(self.channels, counters)
| self._delta_channels_with_overwrite
)

pending = [
(step, tid, ch, v)
Expand Down Expand Up @@ -1661,6 +1683,7 @@ def __enter__(self) -> Self:
)
self._delta_write_futs = []
self._error_handler_write_futs = []
self._delta_channels_with_overwrite = set()
self._exit_delta_writes = (
[] if self.durability == "exit" and self.checkpointer is not None else None
)
Expand Down Expand Up @@ -1918,6 +1941,7 @@ async def __aenter__(self) -> Self:
)
self._delta_write_futs = []
self._error_handler_write_futs = []
self._delta_channels_with_overwrite = set()
self._exit_delta_writes = (
[] if self.durability == "exit" and self.checkpointer is not None else None
)
Expand Down
87 changes: 87 additions & 0 deletions libs/langgraph/tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,93 @@ def respond(state: State) -> dict:
assert len(state.values["messages"]) == 4 # 2 human + 2 AI


def test_delta_channel_overwrite_superstep_snapshots() -> None:
def reducer(state: list[str], writes: Sequence[list[str]]) -> list[str]:
result = list(state)
for write in writes:
result.extend(write)
return result

class State(TypedDict):
items: Annotated[
list[str], DeltaChannel(reducer, list, snapshot_frequency=1000)
]

def node_a(state: State) -> dict:
return {"items": ["a"]}

def node_b(state: State) -> dict:
return {"items": Overwrite(["b"])}

def node_c(state: State) -> dict:
return {"items": ["c"]}

builder = StateGraph(State)
builder.add_node("node_a", node_a)
builder.add_node("node_b", node_b)
builder.add_node("node_c", node_c)
builder.add_edge(START, "node_a")
builder.add_edge("node_a", "node_b")
builder.add_edge("node_a", "node_c")

saver = InMemorySaver()
graph = builder.compile(checkpointer=saver)
config = {"configurable": {"thread_id": "overwrite-snapshot"}}

result = graph.invoke({"items": ["START"]}, config)
assert result == {"items": ["b"]}

saved = saver.get_tuple(config)
assert saved is not None
snapshot = saved.checkpoint["channel_values"].get("items")
assert isinstance(snapshot, _DeltaSnapshot)
assert snapshot.value == ["b"]
assert saved.metadata.get("counters_since_delta_snapshot", {}).get("items") is None


def test_delta_channel_replay_after_overwrite_snapshot() -> None:
def reducer(state: list[str], writes: Sequence[list[str]]) -> list[str]:
result = list(state)
for write in writes:
result.extend(write)
return result

class State(TypedDict):
items: Annotated[
list[str], DeltaChannel(reducer, list, snapshot_frequency=1000)
]

calls = 0

def node(state: State) -> dict:
nonlocal calls
calls += 1
if calls == 1:
return {"items": Overwrite(["reset"])}
return {"items": ["after"]}

builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")

saver = InMemorySaver()
graph = builder.compile(checkpointer=saver)
config = {"configurable": {"thread_id": "overwrite-replay"}}

assert graph.invoke({"items": ["before"]}, config) == {"items": ["reset"]}
first_saved = saver.get_tuple(config)
assert first_saved is not None
assert isinstance(
first_saved.checkpoint["channel_values"].get("items"), _DeltaSnapshot
)

assert graph.invoke({"items": []}, config) == {"items": ["reset", "after"]}
second_saved = saver.get_tuple(config)
assert second_saved is not None
assert "items" not in second_saved.checkpoint["channel_values"]
assert graph.get_state(config).values == {"items": ["reset", "after"]}


# ---------------------------------------------------------------------------
# DeltaChannel — dict reducer
# ---------------------------------------------------------------------------
Expand Down
Loading