From 941c170c58c0c619ab1a837ee48f72276461f35b Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Wed, 17 Jun 2026 13:53:16 -0400 Subject: [PATCH 1/3] fix delta channel overwrite semantics --- .../spec/test_delta_channel_history.py | 44 ++++++++++++++++++ .../langgraph/checkpoint/postgres/base.py | 17 ++++--- .../langgraph/checkpoint/sqlite/_delta.py | 15 ++++-- .../tests/test_get_delta_channel_history.py | 38 +++++++++++++++ .../langgraph/checkpoint/base/__init__.py | 40 ++++++++++++++-- .../langgraph/checkpoint/memory/__init__.py | 12 ++++- libs/checkpoint/tests/test_memory.py | 46 +++++++++++++++++++ libs/langgraph/langgraph/channels/delta.py | 3 +- libs/langgraph/tests/test_channels.py | 12 +++++ 9 files changed, 207 insertions(+), 20 deletions(-) diff --git a/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py b/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py index 06fddd51dd7..b28cd0b790f 100644 --- a/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py +++ b/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py @@ -74,6 +74,49 @@ async def test_history_excludes_target_pending_writes( assert "extra" not in values, f"Target's writes should be excluded, got {values}" +async def test_history_overwrite_bypasses_same_step_writes( + saver: BaseCheckpointSaver, +) -> None: + tid = str(uuid4()) + channel = "ch" + + from langgraph.checkpoint.base import Checkpoint + from langgraph.checkpoint.base.id import uuid6 + + from langgraph.checkpoint.conformance.test_utils import generate_metadata + + parent_cfg = None + configs: list = [] + for step in range(2): + config = {"configurable": {"thread_id": tid, "checkpoint_ns": ""}} + if parent_cfg: + config["configurable"]["checkpoint_id"] = parent_cfg["configurable"][ + "checkpoint_id" + ] + cp = Checkpoint( + v=1, + id=str(uuid6(clock_seq=-1)), + ts="", + channel_values={}, + channel_versions={}, + versions_seen={}, + updated_channels=None, + ) + parent_cfg = await saver.aput(config, cp, generate_metadata(step=step), {}) + configs.append(parent_cfg) + + await saver.aput_writes( + configs[0], + [(channel, [1]), (channel, {"__overwrite__": [50]}), (channel, [2])], + str(uuid4()), + ) + result = await saver.aget_delta_channel_history( + config=configs[1], channels=[channel] + ) + values = [w[2] for w in result[channel]["writes"]] + assert values == [{"__overwrite__": [50]}], f"Expected overwrite only, got {values}" + + async def test_history_multi_channel( saver: BaseCheckpointSaver, ) -> None: @@ -212,6 +255,7 @@ async def test_history_migration_plain_value_as_seed( test_history_returns_writes_oldest_first, test_history_seed_is_nearest_snapshot, test_history_excludes_target_pending_writes, + test_history_overwrite_bypasses_same_step_writes, test_history_multi_channel, test_history_empty_channels_returns_empty, test_history_walk_to_root_no_seed, diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index beb1e99724b..ca2f333963d 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -13,6 +13,7 @@ ChannelVersions, DeltaChannelHistory, PendingWrite, + _apply_delta_history_overwrite_semantics, get_checkpoint_id, ) from langgraph.checkpoint.serde.types import TASKS @@ -450,10 +451,10 @@ def _build_delta_channels_writes_history( "tuple[str, bytes]", (r["type"], r["blob"]) ) - # Sort writes per (channel, cid) newest-first by (task_id, idx) + # Sort writes per (channel, cid) oldest-first by (task_id, idx). for cid_map in writes_by_ch_by_cid.values(): for ws in cid_map.values(): - ws.sort(key=lambda w: (w[2], w[3]), reverse=True) + ws.sort(key=lambda w: (w[2], w[3])) result: dict[str, DeltaChannelHistory] = {} for ch in channels: @@ -462,11 +463,13 @@ def _build_delta_channels_writes_history( collected: list[PendingWrite] = [] cid_writes = writes_by_ch_by_cid.get(ch, {}) - for cid in chain_cids: - for type_tag, write_blob, task_id, _idx in cid_writes.get(cid, []): - val = self.serde.loads_typed((type_tag, write_blob)) - collected.append((task_id, ch, val)) - collected.reverse() + # Chain is newest-first; iterate oldest-first for the public order. + for cid in reversed(chain_cids): + step_writes: list[PendingWrite] = [ + (task_id, ch, self.serde.loads_typed((type_tag, write_blob))) + for type_tag, write_blob, task_id, _idx in cid_writes.get(cid, []) + ] + collected.extend(_apply_delta_history_overwrite_semantics(step_writes)) entry: DeltaChannelHistory = {"writes": collected} if seed_version is not None: diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py index 1fe617ff7dd..c9cb450abe4 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py @@ -24,7 +24,11 @@ from collections.abc import Mapping, Sequence from typing import Any -from langgraph.checkpoint.base import DeltaChannelHistory, PendingWrite +from langgraph.checkpoint.base import ( + DeltaChannelHistory, + PendingWrite, + _apply_delta_history_overwrite_semantics, +) # Stage 1 streams ancestors of `target_cid` newest-first. The `<=` # predicate keeps target itself in the stream so we can read its @@ -161,10 +165,11 @@ def build_delta_channels_writes_history( collected: list[PendingWrite] = [] # Chain is newest-first; iterate oldest-first for the public order. for cid in reversed(chain_cids): - for type_tag, value_blob, task_id, _idx in cid_writes.get(cid, []): - collected.append( - (task_id, ch, serde.loads_typed((type_tag, value_blob))) - ) + step_writes: list[PendingWrite] = [ + (task_id, ch, serde.loads_typed((type_tag, value_blob))) + for type_tag, value_blob, task_id, _idx in cid_writes.get(cid, []) + ] + collected.extend(_apply_delta_history_overwrite_semantics(step_writes)) entry: DeltaChannelHistory = {"writes": collected} if ch in seeded: entry["seed"] = seed_val_by_ch[ch] diff --git a/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py b/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py index 8e19c8b857e..9346ea34f6b 100644 --- a/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py +++ b/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py @@ -33,6 +33,8 @@ pytest.importorskip("langgraph.graph", reason="langgraph core not installed") from langgraph.channels.delta import DeltaChannel # type: ignore[import-untyped] # noqa: E402,I001 +from langgraph.checkpoint.base import Checkpoint # noqa: E402 +from langgraph.checkpoint.base.id import uuid6 # noqa: E402 from langgraph.checkpoint.serde.types import _DeltaSnapshot # noqa: E402 from langgraph.graph import END, START, StateGraph # type: ignore[import-untyped] # noqa: E402 from typing_extensions import TypedDict # noqa: E402 @@ -213,6 +215,42 @@ def test_seed_omitted_when_walk_reaches_root_sync() -> None: assert entry["writes"] == [] +def test_overwrite_bypasses_same_step_writes_sync() -> None: + with SqliteSaver.from_conn_string(":memory:") as saver: + config: RunnableConfig = { + "configurable": {"thread_id": "overwrite-sync", "checkpoint_ns": ""} + } + cp1 = Checkpoint( + v=1, + id=str(uuid6(clock_seq=-1)), + ts="", + channel_values={}, + channel_versions={}, + versions_seen={}, + updated_channels=None, + ) + cfg1 = saver.put(config, cp1, {"source": "loop", "step": 0}, {}) + cp2 = Checkpoint( + v=1, + id=str(uuid6(clock_seq=-1)), + ts="", + channel_values={}, + channel_versions={}, + versions_seen={}, + updated_channels=None, + ) + cfg2 = saver.put(cfg1, cp2, {"source": "loop", "step": 1}, {}) + saver.put_writes( + cfg1, + [("items", [1]), ("items", {"__overwrite__": [50]}), ("items", [2])], + "task", + ) + + result = saver.get_delta_channel_history(config=cfg2, channels=["items"]) + values = [w[2] for w in result["items"]["writes"]] + assert values == [{"__overwrite__": [50]}] + + # --------------------------------------------------------------------------- # Async: AsyncSqliteSaver # --------------------------------------------------------------------------- diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index 6e42061190d..d0435a6ed96 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -34,6 +34,24 @@ logger = logging.getLogger(__name__) +_OVERWRITE_KEY = "__overwrite__" + + +def _is_overwrite_value(value: Any) -> bool: + if isinstance(value, dict) and len(value) == 1 and _OVERWRITE_KEY in value: + return True + return value.__class__.__name__ == "Overwrite" and hasattr(value, "value") + + +def _apply_delta_history_overwrite_semantics( + writes: Sequence[PendingWrite], +) -> list[PendingWrite]: + for write in writes: + if _is_overwrite_value(write[2]): + return [write] + return list(writes) + + # Marked as total=False to allow for future expansion. class CheckpointMetadata(TypedDict, total=False): """Metadata associated with a checkpoint.""" @@ -631,10 +649,17 @@ def get_delta_channel_history( if tup is None: break if tup.pending_writes: - for write in reversed(tup.pending_writes): + step_writes_by_ch: dict[str, list[PendingWrite]] = { + ch: [] for ch in remaining + } + for write in tup.pending_writes: ch = write[1] if ch in remaining: - collected_by_ch[ch].append(write) + step_writes_by_ch[ch].append(write) + for ch, step_writes in step_writes_by_ch.items(): + collected_by_ch[ch].extend( + reversed(_apply_delta_history_overwrite_semantics(step_writes)) + ) for ch in list(remaining): if ch in tup.checkpoint["channel_values"]: seed_by_ch[ch] = tup.checkpoint["channel_values"][ch] @@ -672,10 +697,17 @@ async def aget_delta_channel_history( if tup is None: break if tup.pending_writes: - for write in reversed(tup.pending_writes): + step_writes_by_ch: dict[str, list[PendingWrite]] = { + ch: [] for ch in remaining + } + for write in tup.pending_writes: ch = write[1] if ch in remaining: - collected_by_ch[ch].append(write) + step_writes_by_ch[ch].append(write) + for ch, step_writes in step_writes_by_ch.items(): + collected_by_ch[ch].extend( + reversed(_apply_delta_history_overwrite_semantics(step_writes)) + ) for ch in list(remaining): if ch in tup.checkpoint["channel_values"]: seed_by_ch[ch] = tup.checkpoint["channel_values"][ch] diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 80043c71060..8e0a6acd76f 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -23,6 +23,7 @@ DeltaChannelHistory, PendingWrite, SerializerProtocol, + _apply_delta_history_overwrite_semantics, get_checkpoint_id, get_checkpoint_metadata, ) @@ -200,8 +201,11 @@ def get_delta_channel_history( terminated_here.add(ch) step_writes = self.writes.get((thread_id, checkpoint_ns, cp_id), {}) + step_writes_by_ch: dict[str, list[PendingWrite]] = { + ch: [] for ch in remaining + } for (_task_id, _idx), (tid, ch, serialized, _) in sorted( - step_writes.items(), reverse=True + step_writes.items() ): if ch not in remaining: continue @@ -210,9 +214,13 @@ def get_delta_channel_history( blob_value, _DeltaSnapshot ): continue - collected_by_ch[ch].append( + step_writes_by_ch[ch].append( (tid, ch, self.serde.loads_typed(serialized)) ) + for ch, writes in step_writes_by_ch.items(): + collected_by_ch[ch].extend( + reversed(_apply_delta_history_overwrite_semantics(writes)) + ) for ch in terminated_here: seed_by_ch[ch] = blob_value_by_ch[ch] diff --git a/libs/checkpoint/tests/test_memory.py b/libs/checkpoint/tests/test_memory.py index 70e22e0d826..eec879b0c9d 100644 --- a/libs/checkpoint/tests/test_memory.py +++ b/libs/checkpoint/tests/test_memory.py @@ -387,6 +387,52 @@ def test_get_channel_writes_collects_ancestor_writes_only(self) -> None: values = [v for _, _, v in result["writes"]] assert values == [{"content": "hi"}] + def test_get_channel_writes_overwrite_bypasses_same_step_writes(self) -> None: + saver = InMemorySaver() + serde = JsonPlusSerializer() + + thread_id, ns, channel = "t1", "", "messages" + + cp1 = empty_checkpoint() + cp1["id"] = "cp1" + cp2 = empty_checkpoint() + cp2["id"] = "cp2" + saver.storage[thread_id][ns] = { + "cp1": (serde.dumps_typed(cp1), serde.dumps_typed({}), None), + "cp2": (serde.dumps_typed(cp2), serde.dumps_typed({}), "cp1"), + } + saver.writes[(thread_id, ns, "cp1")][("task1", 0)] = ( + "task1", + channel, + serde.dumps_typed([1]), + "", + ) + saver.writes[(thread_id, ns, "cp1")][("task2", 0)] = ( + "task2", + channel, + serde.dumps_typed({"__overwrite__": [50]}), + "", + ) + saver.writes[(thread_id, ns, "cp1")][("task3", 0)] = ( + "task3", + channel, + serde.dumps_typed([2]), + "", + ) + + config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": ns, + "checkpoint_id": "cp2", + } + } + result = saver.get_delta_channel_history(config=config, channels=[channel])[ + channel + ] + values = [v for _, _, v in result["writes"]] + assert values == [{"__overwrite__": [50]}] + def test_get_channel_writes_at_root_returns_empty(self) -> None: """Reconstructing the root checkpoint's state: no ancestors → [].""" saver = InMemorySaver() diff --git a/libs/langgraph/langgraph/channels/delta.py b/libs/langgraph/langgraph/channels/delta.py index 052ced50df5..0523d55bde9 100644 --- a/libs/langgraph/langgraph/channels/delta.py +++ b/libs/langgraph/langgraph/channels/delta.py @@ -177,8 +177,7 @@ def update(self, values: Sequence[Any]) -> bool: if overwrite_value is not None else self.typ() ) - remaining = [v for i, v in enumerate(values) if i != overwrite_idx] - self.value = self.reducer(base, remaining) if remaining else base + self.value = base return True base = self.typ() if self.value is MISSING else self.value self.value = self.reducer(base, list(values)) diff --git a/libs/langgraph/tests/test_channels.py b/libs/langgraph/tests/test_channels.py index 9dfa7158a41..0df856fae13 100644 --- a/libs/langgraph/tests/test_channels.py +++ b/libs/langgraph/tests/test_channels.py @@ -186,6 +186,18 @@ def test_delta_channel_overwrite() -> None: assert ch.get()[0].content == "new" +def test_delta_channel_overwrite_bypasses_same_step_reducer_writes() -> None: + def list_reducer(state: list, writes: list) -> list: + out = list(state) + for w in writes: + out.extend(w) + return out + + ch = DeltaChannel(list_reducer, list).from_checkpoint(MISSING) + ch.update([[1], Overwrite([50]), [2]]) + assert ch.get() == [50] + + def test_delta_channel_remove_message_and_replay() -> None: """RemoveMessage must round-trip correctly when writes are replayed.""" spec = DeltaChannel(_messages_delta_reducer, list) From 4210feccd98be2426dd5fb24964eb1208ee450a9 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Wed, 17 Jun 2026 14:07:19 -0400 Subject: [PATCH 2/3] simplify delta overwrite fix --- .../spec/test_delta_channel_history.py | 44 ------------------ .../langgraph/checkpoint/postgres/base.py | 17 +++---- .../langgraph/checkpoint/sqlite/_delta.py | 15 ++---- .../tests/test_get_delta_channel_history.py | 38 --------------- .../langgraph/checkpoint/base/__init__.py | 40 ++-------------- .../langgraph/checkpoint/memory/__init__.py | 12 +---- libs/checkpoint/tests/test_memory.py | 46 ------------------- 7 files changed, 18 insertions(+), 194 deletions(-) diff --git a/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py b/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py index b28cd0b790f..06fddd51dd7 100644 --- a/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py +++ b/libs/checkpoint-conformance/langgraph/checkpoint/conformance/spec/test_delta_channel_history.py @@ -74,49 +74,6 @@ async def test_history_excludes_target_pending_writes( assert "extra" not in values, f"Target's writes should be excluded, got {values}" -async def test_history_overwrite_bypasses_same_step_writes( - saver: BaseCheckpointSaver, -) -> None: - tid = str(uuid4()) - channel = "ch" - - from langgraph.checkpoint.base import Checkpoint - from langgraph.checkpoint.base.id import uuid6 - - from langgraph.checkpoint.conformance.test_utils import generate_metadata - - parent_cfg = None - configs: list = [] - for step in range(2): - config = {"configurable": {"thread_id": tid, "checkpoint_ns": ""}} - if parent_cfg: - config["configurable"]["checkpoint_id"] = parent_cfg["configurable"][ - "checkpoint_id" - ] - cp = Checkpoint( - v=1, - id=str(uuid6(clock_seq=-1)), - ts="", - channel_values={}, - channel_versions={}, - versions_seen={}, - updated_channels=None, - ) - parent_cfg = await saver.aput(config, cp, generate_metadata(step=step), {}) - configs.append(parent_cfg) - - await saver.aput_writes( - configs[0], - [(channel, [1]), (channel, {"__overwrite__": [50]}), (channel, [2])], - str(uuid4()), - ) - result = await saver.aget_delta_channel_history( - config=configs[1], channels=[channel] - ) - values = [w[2] for w in result[channel]["writes"]] - assert values == [{"__overwrite__": [50]}], f"Expected overwrite only, got {values}" - - async def test_history_multi_channel( saver: BaseCheckpointSaver, ) -> None: @@ -255,7 +212,6 @@ async def test_history_migration_plain_value_as_seed( test_history_returns_writes_oldest_first, test_history_seed_is_nearest_snapshot, test_history_excludes_target_pending_writes, - test_history_overwrite_bypasses_same_step_writes, test_history_multi_channel, test_history_empty_channels_returns_empty, test_history_walk_to_root_no_seed, diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index ca2f333963d..beb1e99724b 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -13,7 +13,6 @@ ChannelVersions, DeltaChannelHistory, PendingWrite, - _apply_delta_history_overwrite_semantics, get_checkpoint_id, ) from langgraph.checkpoint.serde.types import TASKS @@ -451,10 +450,10 @@ def _build_delta_channels_writes_history( "tuple[str, bytes]", (r["type"], r["blob"]) ) - # Sort writes per (channel, cid) oldest-first by (task_id, idx). + # Sort writes per (channel, cid) newest-first by (task_id, idx) for cid_map in writes_by_ch_by_cid.values(): for ws in cid_map.values(): - ws.sort(key=lambda w: (w[2], w[3])) + ws.sort(key=lambda w: (w[2], w[3]), reverse=True) result: dict[str, DeltaChannelHistory] = {} for ch in channels: @@ -463,13 +462,11 @@ def _build_delta_channels_writes_history( collected: list[PendingWrite] = [] cid_writes = writes_by_ch_by_cid.get(ch, {}) - # Chain is newest-first; iterate oldest-first for the public order. - for cid in reversed(chain_cids): - step_writes: list[PendingWrite] = [ - (task_id, ch, self.serde.loads_typed((type_tag, write_blob))) - for type_tag, write_blob, task_id, _idx in cid_writes.get(cid, []) - ] - collected.extend(_apply_delta_history_overwrite_semantics(step_writes)) + for cid in chain_cids: + for type_tag, write_blob, task_id, _idx in cid_writes.get(cid, []): + val = self.serde.loads_typed((type_tag, write_blob)) + collected.append((task_id, ch, val)) + collected.reverse() entry: DeltaChannelHistory = {"writes": collected} if seed_version is not None: diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py index c9cb450abe4..1fe617ff7dd 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/_delta.py @@ -24,11 +24,7 @@ from collections.abc import Mapping, Sequence from typing import Any -from langgraph.checkpoint.base import ( - DeltaChannelHistory, - PendingWrite, - _apply_delta_history_overwrite_semantics, -) +from langgraph.checkpoint.base import DeltaChannelHistory, PendingWrite # Stage 1 streams ancestors of `target_cid` newest-first. The `<=` # predicate keeps target itself in the stream so we can read its @@ -165,11 +161,10 @@ def build_delta_channels_writes_history( collected: list[PendingWrite] = [] # Chain is newest-first; iterate oldest-first for the public order. for cid in reversed(chain_cids): - step_writes: list[PendingWrite] = [ - (task_id, ch, serde.loads_typed((type_tag, value_blob))) - for type_tag, value_blob, task_id, _idx in cid_writes.get(cid, []) - ] - collected.extend(_apply_delta_history_overwrite_semantics(step_writes)) + for type_tag, value_blob, task_id, _idx in cid_writes.get(cid, []): + collected.append( + (task_id, ch, serde.loads_typed((type_tag, value_blob))) + ) entry: DeltaChannelHistory = {"writes": collected} if ch in seeded: entry["seed"] = seed_val_by_ch[ch] diff --git a/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py b/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py index 9346ea34f6b..8e19c8b857e 100644 --- a/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py +++ b/libs/checkpoint-sqlite/tests/test_get_delta_channel_history.py @@ -33,8 +33,6 @@ pytest.importorskip("langgraph.graph", reason="langgraph core not installed") from langgraph.channels.delta import DeltaChannel # type: ignore[import-untyped] # noqa: E402,I001 -from langgraph.checkpoint.base import Checkpoint # noqa: E402 -from langgraph.checkpoint.base.id import uuid6 # noqa: E402 from langgraph.checkpoint.serde.types import _DeltaSnapshot # noqa: E402 from langgraph.graph import END, START, StateGraph # type: ignore[import-untyped] # noqa: E402 from typing_extensions import TypedDict # noqa: E402 @@ -215,42 +213,6 @@ def test_seed_omitted_when_walk_reaches_root_sync() -> None: assert entry["writes"] == [] -def test_overwrite_bypasses_same_step_writes_sync() -> None: - with SqliteSaver.from_conn_string(":memory:") as saver: - config: RunnableConfig = { - "configurable": {"thread_id": "overwrite-sync", "checkpoint_ns": ""} - } - cp1 = Checkpoint( - v=1, - id=str(uuid6(clock_seq=-1)), - ts="", - channel_values={}, - channel_versions={}, - versions_seen={}, - updated_channels=None, - ) - cfg1 = saver.put(config, cp1, {"source": "loop", "step": 0}, {}) - cp2 = Checkpoint( - v=1, - id=str(uuid6(clock_seq=-1)), - ts="", - channel_values={}, - channel_versions={}, - versions_seen={}, - updated_channels=None, - ) - cfg2 = saver.put(cfg1, cp2, {"source": "loop", "step": 1}, {}) - saver.put_writes( - cfg1, - [("items", [1]), ("items", {"__overwrite__": [50]}), ("items", [2])], - "task", - ) - - result = saver.get_delta_channel_history(config=cfg2, channels=["items"]) - values = [w[2] for w in result["items"]["writes"]] - assert values == [{"__overwrite__": [50]}] - - # --------------------------------------------------------------------------- # Async: AsyncSqliteSaver # --------------------------------------------------------------------------- diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index d0435a6ed96..6e42061190d 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -34,24 +34,6 @@ logger = logging.getLogger(__name__) -_OVERWRITE_KEY = "__overwrite__" - - -def _is_overwrite_value(value: Any) -> bool: - if isinstance(value, dict) and len(value) == 1 and _OVERWRITE_KEY in value: - return True - return value.__class__.__name__ == "Overwrite" and hasattr(value, "value") - - -def _apply_delta_history_overwrite_semantics( - writes: Sequence[PendingWrite], -) -> list[PendingWrite]: - for write in writes: - if _is_overwrite_value(write[2]): - return [write] - return list(writes) - - # Marked as total=False to allow for future expansion. class CheckpointMetadata(TypedDict, total=False): """Metadata associated with a checkpoint.""" @@ -649,17 +631,10 @@ def get_delta_channel_history( if tup is None: break if tup.pending_writes: - step_writes_by_ch: dict[str, list[PendingWrite]] = { - ch: [] for ch in remaining - } - for write in tup.pending_writes: + for write in reversed(tup.pending_writes): ch = write[1] if ch in remaining: - step_writes_by_ch[ch].append(write) - for ch, step_writes in step_writes_by_ch.items(): - collected_by_ch[ch].extend( - reversed(_apply_delta_history_overwrite_semantics(step_writes)) - ) + collected_by_ch[ch].append(write) for ch in list(remaining): if ch in tup.checkpoint["channel_values"]: seed_by_ch[ch] = tup.checkpoint["channel_values"][ch] @@ -697,17 +672,10 @@ async def aget_delta_channel_history( if tup is None: break if tup.pending_writes: - step_writes_by_ch: dict[str, list[PendingWrite]] = { - ch: [] for ch in remaining - } - for write in tup.pending_writes: + for write in reversed(tup.pending_writes): ch = write[1] if ch in remaining: - step_writes_by_ch[ch].append(write) - for ch, step_writes in step_writes_by_ch.items(): - collected_by_ch[ch].extend( - reversed(_apply_delta_history_overwrite_semantics(step_writes)) - ) + collected_by_ch[ch].append(write) for ch in list(remaining): if ch in tup.checkpoint["channel_values"]: seed_by_ch[ch] = tup.checkpoint["channel_values"][ch] diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 8e0a6acd76f..80043c71060 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -23,7 +23,6 @@ DeltaChannelHistory, PendingWrite, SerializerProtocol, - _apply_delta_history_overwrite_semantics, get_checkpoint_id, get_checkpoint_metadata, ) @@ -201,11 +200,8 @@ def get_delta_channel_history( terminated_here.add(ch) step_writes = self.writes.get((thread_id, checkpoint_ns, cp_id), {}) - step_writes_by_ch: dict[str, list[PendingWrite]] = { - ch: [] for ch in remaining - } for (_task_id, _idx), (tid, ch, serialized, _) in sorted( - step_writes.items() + step_writes.items(), reverse=True ): if ch not in remaining: continue @@ -214,13 +210,9 @@ def get_delta_channel_history( blob_value, _DeltaSnapshot ): continue - step_writes_by_ch[ch].append( + collected_by_ch[ch].append( (tid, ch, self.serde.loads_typed(serialized)) ) - for ch, writes in step_writes_by_ch.items(): - collected_by_ch[ch].extend( - reversed(_apply_delta_history_overwrite_semantics(writes)) - ) for ch in terminated_here: seed_by_ch[ch] = blob_value_by_ch[ch] diff --git a/libs/checkpoint/tests/test_memory.py b/libs/checkpoint/tests/test_memory.py index eec879b0c9d..70e22e0d826 100644 --- a/libs/checkpoint/tests/test_memory.py +++ b/libs/checkpoint/tests/test_memory.py @@ -387,52 +387,6 @@ def test_get_channel_writes_collects_ancestor_writes_only(self) -> None: values = [v for _, _, v in result["writes"]] assert values == [{"content": "hi"}] - def test_get_channel_writes_overwrite_bypasses_same_step_writes(self) -> None: - saver = InMemorySaver() - serde = JsonPlusSerializer() - - thread_id, ns, channel = "t1", "", "messages" - - cp1 = empty_checkpoint() - cp1["id"] = "cp1" - cp2 = empty_checkpoint() - cp2["id"] = "cp2" - saver.storage[thread_id][ns] = { - "cp1": (serde.dumps_typed(cp1), serde.dumps_typed({}), None), - "cp2": (serde.dumps_typed(cp2), serde.dumps_typed({}), "cp1"), - } - saver.writes[(thread_id, ns, "cp1")][("task1", 0)] = ( - "task1", - channel, - serde.dumps_typed([1]), - "", - ) - saver.writes[(thread_id, ns, "cp1")][("task2", 0)] = ( - "task2", - channel, - serde.dumps_typed({"__overwrite__": [50]}), - "", - ) - saver.writes[(thread_id, ns, "cp1")][("task3", 0)] = ( - "task3", - channel, - serde.dumps_typed([2]), - "", - ) - - config: RunnableConfig = { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": ns, - "checkpoint_id": "cp2", - } - } - result = saver.get_delta_channel_history(config=config, channels=[channel])[ - channel - ] - values = [v for _, _, v in result["writes"]] - assert values == [{"__overwrite__": [50]}] - def test_get_channel_writes_at_root_returns_empty(self) -> None: """Reconstructing the root checkpoint's state: no ancestors → [].""" saver = InMemorySaver() From af62aff8ac5eae2afd838e4faa172a04cf3b5b0c Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Wed, 17 Jun 2026 14:11:23 -0400 Subject: [PATCH 3/3] add delta overwrite graph coverage --- libs/langgraph/tests/test_pregel.py | 106 ++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 0aae1318e13..f44a4b00b05 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -9281,6 +9281,13 @@ def tool_node(state: State) -> State: assert state.values.get("dictionary") == {"session_resource": "legal_value"} +def _delta_list_reducer(state: list, writes: Sequence[list]) -> list: + out = list(state) + for write in writes: + out.extend(write) + return out + + @pytest.mark.parametrize("as_json", [False, True]) def test_overwrite_sequential( sync_checkpointer: BaseCheckpointSaver, as_json: bool @@ -9388,6 +9395,105 @@ def node_c(state: State): graph.invoke({"messages": ["START"]}, config) +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_sequential( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-sequential"}} + result = graph.invoke({"messages": ["START"]}, config) + assert result == {"messages": ["b"]} + + +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_parallel( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + def node_c(state: State): + return {"messages": ["c"]} + + def node_d(state: State): + return {"messages": ["d"]} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_node("node_c", node_c) + builder.add_node("node_d", node_d) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + builder.add_edge("node_a", "node_c") + builder.add_edge("node_b", "node_d") + builder.add_edge("node_c", "node_d") + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-parallel"}} + result = graph.invoke({"messages": ["START"]}, config) + assert result == {"messages": ["b", "d"]} + + +@pytest.mark.parametrize("as_json", [False, True]) +def test_delta_channel_overwrite_parallel_error( + sync_checkpointer: BaseCheckpointSaver, as_json: bool +) -> None: + class State(TypedDict): + messages: Annotated[list, DeltaChannel(_delta_list_reducer)] + + def node_a(state: State): + return {"messages": ["a"]} + + def node_b(state: State): + overwrite = {"__overwrite__": ["b"]} if as_json else Overwrite(["b"]) + return {"messages": overwrite} + + def node_c(state: State): + overwrite = {"__overwrite__": ["c"]} if as_json else Overwrite(["c"]) + return {"messages": overwrite} + + builder = StateGraph(State) + builder.add_node("node_a", node_a) + builder.add_node("node_b", node_b) + builder.add_node("node_c", node_c) + builder.add_edge(START, "node_a") + builder.add_edge("node_a", "node_b") + builder.add_edge("node_a", "node_c") + builder.add_edge("node_b", END) + builder.add_edge("node_c", END) + + graph = builder.compile(checkpointer=sync_checkpointer) + config = {"configurable": {"thread_id": "delta-overwrite-parallel-error"}} + with pytest.raises( + InvalidUpdateError, match="Can receive only one Overwrite value per super-step." + ): + graph.invoke({"messages": ["START"]}, config) + + def test_fork_does_not_apply_pending_writes( sync_checkpointer: BaseCheckpointSaver, ) -> None: