Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions src/annbatch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(
self._train_datasets = []
self._shapes = []
self._sparse_dataset_elem_cache = {}
self._dense_split_buffer = None

def __len__(self) -> int:
return self._batch_sampler.n_batches(self.n_obs)
Expand Down Expand Up @@ -979,7 +980,9 @@ def __iter__(
if n > positions.size:
positions = np.arange(n, dtype=np.intp)
inv = inv_buffer[:n]
inv = positions[order]
inv[order] = positions[:n]
concat_splits = np.concatenate(splits)
sel = inv[concat_splits] # same as inv = positions[order] but reuses preallocated buffer

raw_out: CSRContainer | np.ndarray = zsync.sync(self._index_datasets(dataset_index_to_rows))

Expand All @@ -990,19 +993,43 @@ def __iter__(
dtype=_cupy_dtype(raw_out.dtype) if self._preload_to_gpu else raw_out.dtype,
)
else:
in_memory_data = self._np_module.asarray(raw_out)
raw_out_arr = self._np_module.asarray(raw_out)
needed_len = len(concat_splits)
feature_shape = raw_out_arr.shape[1:]
if self._dense_split_buffer is None or self._dense_split_buffer.dtype != raw_out_arr.dtype:
self._dense_split_buffer = self._np_module.empty(
(needed_len, *feature_shape),
raw_out_arr.dtype,
)
in_memory_data = self._dense_split_buffer[:needed_len]
self._np_module.take(

@ilan-gold ilan-gold Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I am not 100% sure this a safe operation on the GPU because AFAIK, operations happen asynchronously. Thus you may hit this line while your model is fitting on a batch derived from in_memory_data but you are then overriding in_memory_data. #105 It may make sense to have a pool

@selmanozleyen selmanozleyen Jun 16, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, you are right. But how does a pool solve this? How can we know if the model is done with that data? Isn't copying here our only option? in_memory_data[slice(start, end)].copy()

@ilan-gold ilan-gold Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess that would be the only way. Althought that is a good point, the normal indexing on the GPU may copy without .copy(). I am not sure. I hadn't considered that - it might be worth checking.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pool was just spitballing.

raw_out_arr,
sel,
axis=0,
out=in_memory_data,
**({"mode": "clip"} if not self._preload_to_gpu else {}), # cp.take doesn't have mode kwarg
)

concatenated_obs: None | pd.DataFrame = self._maybe_accumulate_obs(dataset_index_to_rows)
if concatenated_obs is not None:
concatenated_obs = concatenated_obs.iloc[sel]

in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(dataset_index_to_rows)
if in_memory_indices is not None:
in_memory_indices = in_memory_indices[sel]

start = 0
for split in splits:
sel = inv[split]
data = in_memory_data[sel]
end = start + len(split)
# if data is sparse splits isn't applied to it yet
data = in_memory_data[inv[split]] if is_sparse else in_memory_data[slice(start, end)].copy()
yield {
"X": data if not self._to_torch else to_torch(data, self._preload_to_gpu),
"obs": concatenated_obs.iloc[sel] if concatenated_obs is not None else None,
"obs": concatenated_obs.iloc[start:end] if concatenated_obs is not None else None,
"var": self._var,
"index": in_memory_indices[sel] if in_memory_indices is not None else None,
"index": in_memory_indices[start:end] if in_memory_indices is not None else None,
}
start = end

# https://github.com/cupy/cupy/issues/9625
if self._preload_to_gpu and is_sparse:
Expand Down
Loading