Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.test
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ PYTEST_ADMIN_PASSWORD=start123
PYTEST_DEFAULT_MASTER_IMAGE=python/base
PYTEST_ASYNC_MAX_RETRIES=5
PYTEST_ASYNC_RETRY_DELAY_MILLIS=500
PYTEST_HUB_VERSION=0.10.1
PYTEST_HUB_VERSION=0.10.2
23 changes: 11 additions & 12 deletions flame_hub/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,22 @@ class GetKwargs(te.TypedDict, total=False):
meta: bool


def build_page_params(page_params: PageParams = None, default_page_params: PageParams = None) -> dict:
def build_page_params(page_params: PageParams | None = None, default_page_params: PageParams | None = None) -> dict:
"""Build a dictionary of query parameters based on provided pagination parameters."""
# use empty dict if None is provided
if default_page_params is None:
default_page_params = DEFAULT_PAGE_PARAMS

if page_params is None:
page_params = {}
page_params: PageParams = {}

# overwrite default values with user-defined ones
page_params = default_page_params | page_params

return {f"page[{k}]": v for k, v in page_params.items()}


def build_filter_params(filter_params: FilterParams = None) -> dict:
def build_filter_params(filter_params: FilterParams | None = None) -> dict:
"""Build a dictionary of query parameters based on provided filter parameters."""
if filter_params is None:
filter_params = {}
Expand All @@ -317,7 +317,7 @@ def build_filter_params(filter_params: FilterParams = None) -> dict:
return query_params


def build_sort_params(sort_params: SortParams = None) -> dict:
def build_sort_params(sort_params: SortParams | None = None) -> dict:
if sort_params is None:
sort_params = {}

Expand All @@ -337,7 +337,7 @@ def build_sort_params(sort_params: SortParams = None) -> dict:
return query_params


def build_include_params(include_params: IncludeParams = None) -> dict:
def build_include_params(include_params: IncludeParams | None = None) -> dict:
if include_params is None:
include_params = () # empty tuple

Expand All @@ -353,7 +353,7 @@ def build_include_params(include_params: IncludeParams = None) -> dict:
return {"include": ",".join(include_params)}


def build_field_params(field_params: FieldParams = None) -> dict:
def build_field_params(field_params: FieldParams | None = None) -> dict:
if field_params is None:
field_params = () # empty tuple

Expand Down Expand Up @@ -406,16 +406,15 @@ class BaseClient(object):
:py:class:`.AuthClient`, :py:class:`.CoreClient`, :py:class:`.StorageClient`
"""

def __init__(self, base_url: str, auth: PasswordAuth | ClientAuth = None, **kwargs: te.Unpack[ClientKwargs]):
def __init__(self, base_url: str, auth: PasswordAuth | ClientAuth | None = None, **kwargs: te.Unpack[ClientKwargs]):
client = kwargs.get("client", None)
# Set a read timeout of 20 seconds here because the endpoint for registry projects is slow.
self._client = client or httpx.Client(auth=auth, base_url=base_url, timeout=httpx.Timeout(5, read=20))
self._client = client or httpx.Client(auth=auth, base_url=base_url)

def _get_all_resources(
self,
resource_type: type[ResourceT],
*path: str,
include: IncludeParams = None,
include: IncludeParams | None = None,
expected_code: int = httpx.codes.OK.value,
**params: te.Unpack[GetKwargs],
) -> list[ResourceT] | tuple[list[ResourceT], ResourceListMeta]:
Expand All @@ -439,7 +438,7 @@ def _find_all_resources(
self,
resource_type: type[ResourceT],
*path: str,
include: IncludeParams = None,
include: IncludeParams | None = None,
expected_code: int = httpx.codes.OK.value,
**params: te.Unpack[FindAllKwargs],
) -> list[ResourceT] | tuple[list[ResourceT], ResourceListMeta]:
Expand Down Expand Up @@ -568,7 +567,7 @@ def _get_single_resource(
self,
resource_type: type[ResourceT],
*path: str | UuidIdentifiable,
include: IncludeParams = None,
include: IncludeParams | None = None,
expected_code: int = httpx.codes.OK.value,
**params: te.Unpack[GetKwargs],
) -> ResourceT | None:
Expand Down
28 changes: 24 additions & 4 deletions flame_hub/_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class CreateProject(BaseModel):
description: str | None
master_image_id: t.Annotated[uuid.UUID | None, Field(), WrapValidator(uuid_validator)]
name: str
display_name: str | None


class Project(CreateProject):
Expand All @@ -186,6 +187,7 @@ class UpdateProject(BaseModel):
description: str | None | UNSET_T = UNSET
master_image_id: t.Annotated[uuid.UUID | None | UNSET_T, Field(), WrapValidator(uuid_validator)] = UNSET
name: str | UNSET_T = UNSET
display_name: str | None | UNSET_T = UNSET


ProjectNodeApprovalStatus = t.Literal["rejected", "approved"]
Expand Down Expand Up @@ -229,6 +231,7 @@ class Log(BaseModel):
class CreateAnalysis(BaseModel):
description: str | None
name: str | None
display_name: str | None
project_id: t.Annotated[uuid.UUID, Field(), WrapValidator(uuid_validator)]
master_image_id: t.Annotated[uuid.UUID | None, Field(), WrapValidator(uuid_validator)]
registry_id: t.Annotated[uuid.UUID | None, Field(), WrapValidator(uuid_validator)]
Expand Down Expand Up @@ -271,7 +274,8 @@ class Analysis(CreateAnalysis):

class UpdateAnalysis(BaseModel):
description: str | None | UNSET_T = UNSET
name: str | None | UNSET_T = UNSET
name: str | UNSET_T = UNSET
display_name: str | None | UNSET_T = UNSET
master_image_id: t.Annotated[uuid.UUID | None | UNSET_T, Field(), WrapValidator(uuid_validator)] = UNSET
image_command_arguments: (
t.Annotated[
Expand Down Expand Up @@ -505,11 +509,20 @@ def build_master_image(self, master_image_id: MasterImage | uuid.UUID | str):
raise new_hub_api_error_from_response(r)

def create_project(
self, name: str, master_image_id: MasterImage | uuid.UUID | str = None, description: str = None
self,
name: str,
display_name: str = None,
master_image_id: MasterImage | uuid.UUID | str = None,
description: str = None,
) -> Project:
return self._create_resource(
Project,
CreateProject(name=name, master_image_id=master_image_id, description=description),
CreateProject(
name=name,
master_image_id=master_image_id,
description=description,
display_name=display_name,
),
"projects",
)

Expand All @@ -527,10 +540,13 @@ def update_project(
description: str | None | UNSET_T = UNSET,
master_image_id: MasterImage | str | uuid.UUID | None | UNSET_T = UNSET,
name: str | UNSET_T = UNSET,
display_name: str | None | UNSET_T = UNSET,
) -> Project:
return self._update_resource(
Project,
UpdateProject(description=description, master_image_id=master_image_id, name=name),
UpdateProject(
description=description, master_image_id=master_image_id, name=name, display_name=display_name
),
"projects",
project_id,
)
Expand Down Expand Up @@ -581,6 +597,7 @@ def create_analysis(
self,
project_id: Project | uuid.UUID | str,
name: str = None,
display_name: str = None,
description: str = None,
master_image_id: MasterImage | uuid.UUID | str = None,
registry_id: Registry | uuid.UUID | str = None,
Expand All @@ -591,6 +608,7 @@ def create_analysis(
CreateAnalysis(
project_id=project_id,
name=name,
display_name=display_name,
description=description,
master_image_id=master_image_id,
registry_id=registry_id,
Expand All @@ -617,6 +635,7 @@ def update_analysis(
self,
analysis_id: Analysis | uuid.UUID | str,
name: str | None | UNSET_T = UNSET,
display_name: str | None | UNSET_T = UNSET,
description: str | None | UNSET_T = UNSET,
master_image_id: MasterImage | uuid.UUID | str | None | UNSET_T = UNSET,
image_command_arguments: list[MasterImageCommandArgument] | UNSET_T = UNSET,
Expand All @@ -625,6 +644,7 @@ def update_analysis(
Analysis,
UpdateAnalysis(
name=name,
display_name=display_name,
description=description,
master_image_id=master_image_id,
image_command_arguments=image_command_arguments,
Expand Down
Loading