diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 2286c761..d5c85dbb 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -998,6 +998,11 @@ def __iter__( in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(dataset_index_to_rows) for split in splits: sel = inv[split] + + # Use basic slicing for contiguous selections to avoid costly fancy indexing on the loaded memory + if len(sel) > 0 and (sel[-1] - sel[0] == len(sel) - 1 and np.all(np.diff(sel) == 1)): + sel = slice(sel[0], sel[-1] + 1) + data = in_memory_data[sel] yield { "X": data if not self._to_torch else to_torch(data, self._preload_to_gpu),