diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index d0803d3f..94fbdfe1 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -231,6 +231,7 @@ def __init__( self._train_datasets = [] self._shapes = [] self._sparse_dataset_elem_cache = {} + self._dense_split_buffer = None def __len__(self) -> int: return self._batch_sampler.n_batches(self.n_obs) @@ -979,7 +980,9 @@ def __iter__( if n > positions.size: positions = np.arange(n, dtype=np.intp) inv = inv_buffer[:n] - inv = positions[order] + inv[order] = positions[:n] + concat_splits = np.concatenate(splits) + sel = inv[concat_splits] # same as inv = positions[order] but reuses preallocated buffer raw_out: CSRContainer | np.ndarray = zsync.sync(self._index_datasets(dataset_index_to_rows)) @@ -990,19 +993,43 @@ def __iter__( dtype=_cupy_dtype(raw_out.dtype) if self._preload_to_gpu else raw_out.dtype, ) else: - in_memory_data = self._np_module.asarray(raw_out) + raw_out_arr = self._np_module.asarray(raw_out) + needed_len = len(concat_splits) + feature_shape = raw_out_arr.shape[1:] + if self._dense_split_buffer is None or self._dense_split_buffer.dtype != raw_out_arr.dtype: + self._dense_split_buffer = self._np_module.empty( + (needed_len, *feature_shape), + raw_out_arr.dtype, + ) + in_memory_data = self._dense_split_buffer[:needed_len] + self._np_module.take( + raw_out_arr, + sel, + axis=0, + out=in_memory_data, + **({"mode": "clip"} if not self._preload_to_gpu else {}), # cp.take doesn't have mode kwarg + ) concatenated_obs: None | pd.DataFrame = self._maybe_accumulate_obs(dataset_index_to_rows) + if concatenated_obs is not None: + concatenated_obs = concatenated_obs.iloc[sel] + in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(dataset_index_to_rows) + if in_memory_indices is not None: + in_memory_indices = in_memory_indices[sel] + + start = 0 for split in splits: - sel = inv[split] - data = in_memory_data[sel] + end = start + len(split) + # if data is sparse splits isn't applied to it yet + data = in_memory_data[inv[split]] if is_sparse else in_memory_data[slice(start, end)].copy() yield { "X": data if not self._to_torch else to_torch(data, self._preload_to_gpu), - "obs": concatenated_obs.iloc[sel] if concatenated_obs is not None else None, + "obs": concatenated_obs.iloc[start:end] if concatenated_obs is not None else None, "var": self._var, - "index": in_memory_indices[sel] if in_memory_indices is not None else None, + "index": in_memory_indices[start:end] if in_memory_indices is not None else None, } + start = end # https://github.com/cupy/cupy/issues/9625 if self._preload_to_gpu and is_sparse: