Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 227 additions & 2 deletions tests/unit/executor/test_minimal_runner_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,38 @@
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from typing import Any
from typing import Annotated, Any
from uuid import UUID

import httpx
import orjson
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field

from tracecat.executor import minimal_runner

# --- Action Gateway compatibility regressions ---

_tag_factory_calls = 0


def _tag_factory() -> list[str]:
global _tag_factory_calls
_tag_factory_calls += 1
return [f"call-{_tag_factory_calls}"]


class _NonCopyableDefault:
def __deepcopy__(self, memo: dict[int, object]) -> object:
raise AssertionError("factory results should not be deep-copied")


_NON_COPYABLE_DEFAULT = _NonCopyableDefault()


def _non_copyable_factory() -> _NonCopyableDefault:
return _NON_COPYABLE_DEFAULT


def test_action_gateway_sdk_transport_patches_legacy_tracecat_client(
monkeypatch,
Expand Down Expand Up @@ -393,6 +413,211 @@ def boom_action() -> None:
assert result["error"]["message"] == "boom"


def test_materialize_annotated_field_defaults_applies_pydantic_field_default() -> None:
def create_ticket(
summary: Annotated[str, Field(..., description="Required summary")],
client_id: Annotated[
int | None, Field(default=None, description="Optional client")
],
) -> dict[str, str | int | None]:
return {"summary": summary, "client_id": client_id}

args = minimal_runner.materialize_annotated_field_defaults(
create_ticket,
{"summary": "Ticket without client_id"},
)

assert args == {"summary": "Ticket without client_id", "client_id": None}


def test_materialize_annotated_field_defaults_preserves_provided_args() -> None:
def create_ticket(
client_id: Annotated[
int | None, Field(default=None, description="Optional client")
],
priority: Annotated[int, Field(default=3, description="Priority")],
) -> dict[str, int | None]:
return {"client_id": client_id, "priority": priority}

args = minimal_runner.materialize_annotated_field_defaults(
create_ticket,
{"client_id": 42, "priority": 5},
)

assert args == {"client_id": 42, "priority": 5}


def test_materialize_annotated_field_defaults_keeps_nullable_without_default_required() -> None:
def create_ticket(client_id: int | None) -> int | None:
return client_id

args = minimal_runner.materialize_annotated_field_defaults(create_ticket, {})

assert args == {}


def test_materialize_annotated_field_defaults_copies_mutable_defaults() -> None:
def collect_tags(
tags: Annotated[list[str], Field(default=["base"])],
) -> list[str]:
return tags

first = minimal_runner.materialize_annotated_field_defaults(collect_tags, {})
second = minimal_runner.materialize_annotated_field_defaults(collect_tags, {})

first["tags"].append("first")

assert first == {"tags": ["base", "first"]}
assert second == {"tags": ["base"]}


def test_materialize_annotated_field_defaults_calls_default_factory_each_time() -> None:
global _tag_factory_calls
_tag_factory_calls = 0

def collect_tags(
tags: Annotated[list[str], Field(default_factory=_tag_factory)],
) -> list[str]:
return tags

first = minimal_runner.materialize_annotated_field_defaults(collect_tags, {})
second = minimal_runner.materialize_annotated_field_defaults(collect_tags, {})

first["tags"].append("first")

assert _tag_factory_calls == 2
assert first == {"tags": ["call-1", "first"]}
assert second == {"tags": ["call-2"]}


def test_materialize_annotated_field_defaults_does_not_copy_factory_result() -> None:
def collect_value(
value: Annotated[
_NonCopyableDefault, Field(default_factory=_non_copyable_factory)
],
) -> _NonCopyableDefault:
return value

args = minimal_runner.materialize_annotated_field_defaults(collect_value, {})

assert args == {"value": _NON_COPYABLE_DEFAULT}


def test_materialize_annotated_field_defaults_supports_validated_data_factory() -> None:
def assign_priority(
severity: int,
priority: Annotated[
int, Field(default_factory=lambda data: data["severity"] + 1)
],
) -> dict[str, int]:
return {"severity": severity, "priority": priority}

args = minimal_runner.materialize_annotated_field_defaults(
assign_priority,
{"severity": 2},
)

assert args == {"severity": 2, "priority": 3}


def test_main_minimal_applies_annotated_field_defaults_before_udf_call(
monkeypatch,
) -> None:
test_module: Any = types.ModuleType("test_module")

def create_ticket(
summary: Annotated[str, Field(..., description="Required summary")],
client_id: Annotated[
int | None, Field(default=None, description="Optional client")
],
priority: Annotated[int, Field(default=3, description="Priority")],
) -> dict[str, str | int | None]:
return {"summary": summary, "client_id": client_id, "priority": priority}

test_module.create_ticket = create_ticket

def import_test_module(path: str, *args: object, **kwargs: object) -> Any:
assert path == "test_module"
return test_module

monkeypatch.setattr(
minimal_runner.importlib,
"import_module",
import_test_module,
)

result = minimal_runner.main_minimal(
{
"resolved_context": {
"action_impl": {
"type": "udf",
"module": "test_module",
"name": "create_ticket",
},
"evaluated_args": {"summary": "Ticket without client_id"},
},
"secret_env": {},
}
)

assert result == {
"success": True,
"result": {
"summary": "Ticket without client_id",
"client_id": None,
"priority": 3,
},
}


@pytest.mark.anyio
async def test_run_action_minimal_async_applies_annotated_field_defaults(
monkeypatch,
) -> None:
test_module: Any = types.ModuleType("test_module")

async def create_ticket(
summary: Annotated[str, Field(..., description="Required summary")],
client_id: Annotated[
int | None, Field(default=None, description="Optional client")
],
priority: Annotated[int, Field(default=3, description="Priority")],
) -> dict[str, str | int | None]:
return {"summary": summary, "client_id": client_id, "priority": priority}

test_module.create_ticket = create_ticket

def import_test_module(path: str, *args: object, **kwargs: object) -> Any:
assert path == "test_module"
return test_module

monkeypatch.setattr(
minimal_runner.importlib,
"import_module",
import_test_module,
)

result = await minimal_runner.run_action_minimal_async(
action_impl={
"type": "udf",
"module": "test_module",
"name": "create_ticket",
},
args={"summary": "Ticket without client_id"},
secret_env={},
workspace_id="workspace",
workflow_id="workflow",
run_id="run",
executor_token="token",
)

assert result == {
"summary": "Ticket without client_id",
"client_id": None,
"priority": 3,
}


def test_capped_text_buffer_limits_memory_growth() -> None:
buf = minimal_runner._CappedTextBuffer(limit=5)

Expand Down
Loading