Skip to content
47 changes: 46 additions & 1 deletion src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,45 @@ class BucketStatus(NamedTuple):
error: str | None = None


def _anon_fallback(method):
    """Retry a Client method once with anonymous access on PermissionError.

    The wrapped method is first called unchanged. If it raises
    ``PermissionError`` and the client is not already anonymous (``anon`` is
    not set in ``fs_kwargs`` and the bucket is not cached as anon-needed),
    the client's filesystem is replaced by one created with ``anon=True``
    and the method is called a second time.

    Only marks the bucket as anon-needed if the retry actually succeeds, so
    genuinely inaccessible buckets keep raising clean errors instead of
    being silently cached as anon. If the retry fails for any reason, the
    previous filesystem is put back before the error propagates, so the
    client is never left on an anonymous filesystem that was not confirmed
    to work.

    Args:
        method: Instance method of an object exposing ``name``,
            ``fs_kwargs``, ``_fs``, ``create_fs``, ``_bucket_needs_anon``
            and ``_mark_bucket_anon``.

    Returns:
        The wrapping function; it forwards all arguments and the return
        value of ``method``.

    Raises:
        PermissionError: Re-raised when the client is already anonymous or
            when the anonymous retry is denied as well.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        try:
            return method(self, *args, **kwargs)
        except PermissionError:
            # Already anonymous: a retry could not change the outcome.
            if self.fs_kwargs.get("anon") or self._bucket_needs_anon(self.name):
                raise
            saved_fs = self._fs
            self._fs = type(self).create_fs(**{**self.fs_kwargs, "anon": True})
            try:
                result = method(self, *args, **kwargs)
            except BaseException:
                # Roll back on *any* failure, not only PermissionError, so a
                # failed retry never leaves the anon filesystem in place.
                self._fs = saved_fs
                raise
            self._mark_bucket_anon(self.name)
            return result

    return wrapper
Comment thread
shcheklein marked this conversation as resolved.
Outdated


class Client(ABC):
MAX_THREADS = multiprocessing.cpu_count()
FS_CLASS: ClassVar[type["AbstractFileSystem"]]
PREFIX: ClassVar[str]
protocol: ClassVar[str]
# client_config keys this backend treats as credentials.
CREDENTIAL_KEYS: ClassVar[frozenset[str]] = frozenset()
# Process-local cache of (protocol, bucket) pairs that have been
# resolved as needing anonymous access. Populated only after an anon
# retry actually succeeds.
_ANON_BUCKETS: ClassVar[set[tuple[str, str]]] = set()
Comment thread
shcheklein marked this conversation as resolved.
Outdated

@classmethod
def has_explicit_credentials(cls, client_config: dict | None) -> bool:
Expand All @@ -78,6 +110,14 @@ def has_explicit_credentials(cls, client_config: dict | None) -> bool:
return False
return any(k in client_config for k in cls.CREDENTIAL_KEYS)

@classmethod
def _bucket_needs_anon(cls, name: str) -> bool:
    """Tell whether ``name`` is cached as needing anonymous access.

    The lookup key is the pair of this class's ``protocol`` and the bucket
    name, checked against ``_ANON_BUCKETS``.
    """
    cache_key = (cls.protocol, name)
    return cache_key in cls._ANON_BUCKETS

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can there be a distinction by prefix — one prefix inside a bucket allows anon access while another prefix doesn't?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, this can be an edge case.
The problem is that implementing prefix based cache is much more complex than current one and has some open questions. Issues:

  • need to extract the path from every fsspec method call (different methods take the path in different positions)
  • need to pick a prefix granularity (per-segment? whole dir? configurable?) - any choice is a heuristic, but maybe we can take the last directory
  • need longest-prefix matching on lookup
  • need two fs instances alive at once (one auth, one anon) or pay the cost of rebuilding on each call

Altogether that is ~100-150 lines of code + new tests for a scenario most users don't hit (mixed-access bucket with scoped creds).

My suggestion is to simply not use the cache when creds are explicitly set, and that's it.


@classmethod
def _mark_bucket_anon(cls, name: str) -> None:
    """Record ``name`` in ``_ANON_BUCKETS`` under this class's ``protocol``."""
    cache_key = (cls.protocol, name)
    cls._ANON_BUCKETS.add(cache_key)

def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
self.name = name
self.fs_kwargs = fs_kwargs
Expand Down Expand Up @@ -232,7 +272,10 @@ def split_url(cls, url: str) -> tuple[str, str]:
@property
def fs(self) -> "AbstractFileSystem":
    """Return this client's filesystem, creating it on first access.

    The filesystem is built exactly once from a copy of ``fs_kwargs`` (the
    stored dict is never mutated). When the bucket is cached as needing
    anonymous access, ``anon=True`` is forced on that copy so the instance
    starts out anonymous instead of failing first with credentials.
    """
    if not self._fs:
        kwargs = dict(self.fs_kwargs)
        if self._bucket_needs_anon(self.name):
            kwargs["anon"] = True
        self._fs = self.create_fs(**kwargs)
    return self._fs

def url(
Expand All @@ -251,6 +294,7 @@ async def get_current_etag(self, file: "File") -> str:
info = await self.fs._info(full_path, **self._file_info_kwargs(file.version))
return self.info_to_file(info, file.path).etag

@_anon_fallback
Comment thread
shcheklein marked this conversation as resolved.
Outdated
def get_file_info(self, path: str, version_id: str | None = None) -> "File":
self.validate_file_path(path)
full_path = self.get_uri(path)
Expand Down Expand Up @@ -435,6 +479,7 @@ def do_instantiate_object(self, file: "File", dst: str) -> None:
# Default to copy if reflinks are not supported
shutil.copy2(src, dst)

@_anon_fallback
def open_object(
self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
) -> BinaryIO:
Expand Down
4 changes: 3 additions & 1 deletion src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from datachain.lib.file import File
from datachain.progress import tqdm

from .fsspec import DELIMITER, BucketStatus, Client, ResultQueue
from .fsspec import DELIMITER, BucketStatus, Client, ResultQueue, _anon_fallback

# Patch gcsfs for consistency with s3fs
GCSFileSystem.set_session = GCSFileSystem._set_session
Expand Down Expand Up @@ -154,6 +154,7 @@ async def get_current_etag(self, file: File) -> str:
info = await self.fs._info(path)
return self.info_to_file(info, file.path).etag

@_anon_fallback
def get_file_info(self, path: str, version_id: str | None = None) -> File:
self.validate_file_path(path)
fs_path = self._path_with_generation(self.get_uri(path), version_id)
Expand All @@ -168,6 +169,7 @@ async def get_size(self, file: File) -> int:
raise FileNotFoundError(file.get_fs_path())
return int(size)

@_anon_fallback
def open_object(
self,
file: File,
Expand Down
119 changes: 119 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from datachain.client import Client
from datachain.client.gcs import GCSClient
from datachain.client.local import FileClient
from datachain.lib.file import File


def test_bad_protocol():
Expand All @@ -31,3 +34,119 @@ def test_parse_file_path_ends_with_slash(cloud_type):
uri, rel_part = Client.parse_url("./animals/".replace("/", os.sep))
assert uri == (Path().absolute() / Path("animals")).as_uri()
assert rel_part == ""


# Anonymous-access fallback (auto-retry on PermissionError) ---------------
Comment thread
shcheklein marked this conversation as resolved.
Outdated


@pytest.fixture
def _clear_anon_cache():
    """Empty ``Client._ANON_BUCKETS`` before and after the test body."""

    def _reset() -> None:
        Client._ANON_BUCKETS.clear()

    _reset()
    yield
    _reset()


def _gcs_client(bucket: str = "foo", **fs_kwargs) -> GCSClient:
    """Build a ``GCSClient`` for ``bucket`` with a ``MagicMock`` as its cache.

    Extra keyword arguments are passed to the client as its ``fs_kwargs``.
    """
    cache_stub = MagicMock()
    return GCSClient(bucket, fs_kwargs, cache_stub)


def test_anon_fallback_no_error_no_retry(monkeypatch, _clear_anon_cache):
    """A call that succeeds must not rebuild the fs or mark the bucket."""
    file_info = {
        "name": "gs://foo/x.txt",
        "size": 1,
        "etag": "e",
        "updated": "2024-01-01T00:00:00Z",
    }
    client = _gcs_client()
    working_fs = MagicMock()
    working_fs._info = AsyncMock(return_value=file_info)
    client._fs = working_fs
    fs_factory = MagicMock()
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    client.get_file_info("x.txt")

    fs_factory.assert_not_called()
    assert not GCSClient._bucket_needs_anon("foo")


def test_anon_fallback_retry_succeeds_marks_bucket(monkeypatch, _clear_anon_cache):
    """Denied first, anon retry works: fs is swapped and the bucket cached."""
    file_info = {
        "name": "gs://foo/x.txt",
        "size": 1,
        "etag": "e",
        "updated": "2024-01-01T00:00:00Z",
    }
    client = _gcs_client()
    denied_fs = MagicMock()
    denied_fs._info = AsyncMock(side_effect=PermissionError)
    client._fs = denied_fs

    public_fs = MagicMock()
    public_fs._info = AsyncMock(return_value=file_info)
    fs_factory = MagicMock(return_value=public_fs)
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    client.get_file_info("x.txt")

    assert GCSClient._bucket_needs_anon("foo")
    assert client._fs is public_fs
    # The replacement filesystem must have been requested with anon=True.
    assert fs_factory.call_args.kwargs.get("anon") is True


def test_anon_fallback_retry_also_fails_does_not_mark(monkeypatch, _clear_anon_cache):
    """Denied both ways: error propagates, fs is restored, nothing is cached."""
    client = _gcs_client()
    denied_fs = MagicMock()
    denied_fs._info = AsyncMock(side_effect=PermissionError)
    client._fs = denied_fs

    denied_anon_fs = MagicMock()
    denied_anon_fs._info = AsyncMock(side_effect=PermissionError)
    fs_factory = MagicMock(return_value=denied_anon_fs)
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    with pytest.raises(PermissionError):
        client.get_file_info("x.txt")

    assert not GCSClient._bucket_needs_anon("foo")
    assert client._fs is denied_fs


def test_anon_fallback_cached_bucket_uses_anon_directly(monkeypatch, _clear_anon_cache):
    """A bucket already cached as anon gets its fs created with anon=True."""
    GCSClient._mark_bucket_anon("foo")
    fs_factory = MagicMock()
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    # Accessing the property is what triggers filesystem creation.
    client = _gcs_client()
    _ = client.fs

    fs_factory.assert_called_once()
    requested_kwargs = fs_factory.call_args.kwargs
    assert requested_kwargs.get("anon") is True


def test_anon_fallback_explicit_anon_no_retry(monkeypatch, _clear_anon_cache):
    """With anon=True already set, PermissionError propagates with no retry."""
    client = _gcs_client(anon=True)
    denied_fs = MagicMock()
    denied_fs._info = AsyncMock(side_effect=PermissionError)
    client._fs = denied_fs
    fs_factory = MagicMock()
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    with pytest.raises(PermissionError):
        client.get_file_info("x.txt")

    fs_factory.assert_not_called()
    assert not GCSClient._bucket_needs_anon("foo")


def test_anon_fallback_open_object_retry_succeeds(monkeypatch, _clear_anon_cache):
    """``open_object`` denied first, anon retry works: the bucket is cached."""
    client = _gcs_client()
    denied_fs = MagicMock()
    denied_fs.open = MagicMock(side_effect=PermissionError)
    client._fs = denied_fs
    client.cache.get_path = MagicMock(return_value=None)

    public_fs = MagicMock()
    public_fs.open = MagicMock(return_value=MagicMock())
    fs_factory = MagicMock(return_value=public_fs)
    monkeypatch.setattr(GCSClient, "create_fs", fs_factory)

    target = File(source="gs://foo", path="x.txt")
    client.open_object(target)

    assert GCSClient._bucket_needs_anon("foo")
Loading