Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/datachain/lib/dc/datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Literal,
TypeVar,
cast,
get_args,
get_origin,
overload,
)

Expand Down Expand Up @@ -567,6 +569,61 @@ def json_to_model(json_value: str | dict):

return self.map(json_to_model, params=col, output={column: model})

def unnest(self, col: str) -> "DataChain":
    """Expands a list column into rows, one row per element. Sibling columns are
    repeated for each emitted row. The element type is inferred from the column's
    declared type. Empty lists and ``None`` values drop the source row.

    Equivalent to SQL ``UNNEST``, pandas ``.explode()``, Spark ``explode``.

    Example:
        ```py
        # rows: {id: 1, items: [a, b]}, {id: 2, items: [c]}
        chain.unnest("items")
        # rows: {id: 1, items: a}, {id: 1, items: b}, {id: 2, items: c}
        ```

    Args:
        col: name of a top-level column whose type is ``list[T]`` or a
            homogeneous tuple (``tuple[T]`` / ``tuple[T, ...]``).

    Returns:
        DataChain: A new DataChain with one row per element of ``col``.

    Raises:
        SignalResolvingError: if ``col`` is not a top-level column.
        TypeError: if the column is not a list/tuple, declares no element
            type, or is a tuple whose positions have different types.
    """
    schema = self.signals_schema.clone_without_sys_signals()
    if col not in schema.values:
        raise SignalResolvingError([col], "is not a top-level column")

    col_type, _ = unwrap_optional(schema.values[col])
    origin = get_origin(col_type)
    if origin not in (list, tuple):
        raise TypeError(
            f"unnest() expects a list/tuple column; column {col!r} has type "
            f"{schema.values[col]!r}"
        )
    # ``tuple[T, ...]`` carries a trailing Ellipsis marker, which is not a type.
    type_args = tuple(arg for arg in get_args(col_type) if arg is not Ellipsis)
    if not type_args:
        raise TypeError(
            f"unnest() needs an element type for {col!r}; got bare {col_type!r}"
        )
    elem_type = type_args[0]
    # A fixed-shape tuple such as ``tuple[int, str]`` has no single element type;
    # taking the first argument would declare a wrong type for the output column.
    if any(arg != elem_type for arg in type_args[1:]):
        raise TypeError(
            f"unnest() needs a single element type for {col!r}; got "
            f"heterogeneous {col_type!r}"
        )
    elem_is_model = isinstance(elem_type, type) and issubclass(elem_type, BaseModel)

    # The unnested column is passed last so the generator can split it off
    # from the sibling values positionally.
    params = [name for name in schema.values if name != col]
    output = {name: schema.values[name] for name in params}
    params.append(col)
    output[col] = elem_type

    def _unnest(*args):
        items = args[-1]
        # ``None`` and empty sequences emit nothing, dropping the source row.
        if not items:
            return
        for item in items:
            # Model elements that arrive as plain dicts are rebuilt into the
            # declared model type before being emitted.
            if elem_is_model and isinstance(item, dict):
                item = elem_type(**item)
            yield (*args[:-1], item)

    return self.gen(_unnest, params=params, output=output)

@property
def namespace_name(self) -> str:
"""Current namespace name in which the chain is running."""
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,77 @@ def test_explode_raises_on_wrong_column_type(test_session):
chain.explode("f1.count")


def test_unnest_scalar_list(test_session):
    """Each scalar list element becomes its own row; the empty list yields none."""
    source = dc.read_values(
        id=[1, 2, 3],
        items=[["a", "b"], ["c"], []],
        session=test_session,
    )

    unnested = source.unnest("items")
    actual = sorted(unnested.to_list("id", "items"))

    # id=3 carries an empty list, so it must not appear in the result.
    expected = [(1, "a"), (1, "b"), (2, "c")]
    assert actual == expected


def test_unnest_drops_none(test_session):
    """A row whose list column is ``None`` is absent from the unnested output."""
    source = dc.read_values(
        id=[1, 2],
        items=[["a"], None],
        session=test_session,
    )

    unnested = source.unnest("items")
    actual = sorted(unnested.to_list("id", "items"))

    # Only id=1 had elements to emit.
    expected = [(1, "a")]
    assert actual == expected


def test_unnest_pydantic_elements(test_session):
    """Model elements are unnested and their fields stay addressable by path."""

    class Item(BaseModel):
        kind: str
        weight: int

    first_row_items = [Item(kind="x", weight=1), Item(kind="y", weight=2)]
    second_row_items = [Item(kind="z", weight=3)]
    source = dc.read_values(
        id=[1, 2],
        items=[first_row_items, second_row_items],
        session=test_session,
    )

    unnested = source.unnest("items")
    actual = sorted(unnested.to_list("id", "items.kind", "items.weight"))

    expected = [(1, "x", 1), (1, "y", 2), (2, "z", 3)]
    assert actual == expected


def test_unnest_preserves_multiple_siblings(test_session):
    """Every sibling column is repeated alongside each emitted element."""
    source = dc.read_values(
        id=[1, 2],
        name=["a", "b"],
        items=[[10, 20], [30]],
        session=test_session,
    )

    unnested = source.unnest("items")
    actual = sorted(unnested.to_list("id", "name", "items"))

    expected = [(1, "a", 10), (1, "a", 20), (2, "b", 30)]
    assert actual == expected


def test_unnest_no_siblings(test_session):
    """Unnesting works when the list column is the only column in the chain."""
    source = dc.read_values(
        items=[[1, 2], [3]],
        session=test_session,
    )

    unnested = source.unnest("items")
    actual = sorted(unnested.to_values("items"))

    expected = [1, 2, 3]
    assert actual == expected


def test_unnest_raises_on_unknown_column(test_session):
    """Asking to unnest a column that does not exist raises SignalResolvingError."""
    source = dc.read_values(id=[1, 2], session=test_session)

    with pytest.raises(SignalResolvingError):
        source.unnest("missing")


def test_unnest_raises_on_scalar_column(test_session):
    """Asking to unnest a non-list column raises TypeError."""
    source = dc.read_values(id=[1, 2], session=test_session)

    with pytest.raises(TypeError):
        source.unnest("id")


def test_to_json_features(tmp_dir, test_session):
dc_to = dc.read_values(f1=features, num=range(len(features)), session=test_session)
path = tmp_dir / "test.json"
Expand Down
Loading