diff --git a/strix/core/execution.py b/strix/core/execution.py index 06dc3ddf5..78efeed25 100644 --- a/strix/core/execution.py +++ b/strix/core/execution.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import json import logging import uuid from collections.abc import Callable @@ -36,6 +37,7 @@ StreamEventSink = Callable[[str, Any], None] _INPUT_REJECTION_CODES = frozenset({400, 404, 422}) +_TOOL_ARGUMENT_KEYS = frozenset({"action_input", "arguments", "input", "parameters", "params"}) async def run_agent_loop( @@ -62,7 +64,7 @@ async def run_agent_loop( if not (start_parked and interactive): if interactive: - result = await _run_cycle( + result = await _run_interactive_until_tool_continuation_settled( agent, coordinator, agent_id, @@ -71,7 +73,6 @@ async def run_agent_loop( context=context, max_turns=max_turns, session=session, - interactive=interactive, event_sink=event_sink, hooks=hooks, ) @@ -103,7 +104,7 @@ async def run_agent_loop( raise BudgetExceededError("scan budget reached") await coordinator.consume_pending(agent_id) - result = await _run_cycle( + result = await _run_interactive_until_tool_continuation_settled( agent, coordinator, agent_id, @@ -112,12 +113,73 @@ async def run_agent_loop( context=context, max_turns=max_turns, session=session, - interactive=interactive, event_sink=event_sink, hooks=hooks, ) +async def _run_interactive_until_tool_continuation_settled( + agent: Any, + coordinator: AgentCoordinator, + agent_id: str, + *, + input_data: Any, + run_config: RunConfig, + context: dict[str, Any], + max_turns: int, + session: Session | None, + event_sink: StreamEventSink | None, + hooks: RunHooks[dict[str, Any]] | None, +) -> RunResultBase | None: + """Retry interactive turns when a model prints tool-call JSON as final text.""" + result: RunResultBase | None = None + invalid_final_outputs = 0 + invalid_final_output_limit = max(1, max_turns) + + while True: + if coordinator.budget_stopped: + await coordinator.set_status(agent_id, "stopped") + raise BudgetExceededError("scan budget reached") + + result = await _run_cycle( + agent, + coordinator, + agent_id, + input_data=input_data, + run_config=run_config, + context=context, + max_turns=max_turns, + session=session, + interactive=True, + event_sink=event_sink, + hooks=hooks, + ) + + status = await _agent_status(coordinator, agent_id) + if status != "waiting" or not _looks_like_unexecuted_tool_call(result): + return result + + invalid_final_outputs += 1 + logger.warning( + "agent %s produced tool-call-shaped final output in interactive mode; " + "forcing tool continuation (%d/%d): %s", + agent_id, + invalid_final_outputs, + invalid_final_output_limit, + _final_output_preview(result), + ) + + if invalid_final_outputs >= invalid_final_output_limit: + return result + + input_data = await _append_interactive_tool_required_message( + session=session, + context=context, + attempt=invalid_final_outputs, + limit=invalid_final_output_limit, + ) + + async def spawn_child_agent( *, coordinator: AgentCoordinator, @@ -468,6 +530,72 @@ def _final_output_preview(result: RunResultBase | None) -> str: return text[:300] +def _looks_like_unexecuted_tool_call(result: RunResultBase | None) -> bool: + final_output = getattr(result, "final_output", None) + if final_output is None: + return False + if isinstance(final_output, str): + parsed = _parse_json_final_output(final_output) + return parsed is not None and _is_tool_call_payload(parsed) + return _is_tool_call_payload(final_output) + + +def _parse_json_final_output(text: str) -> Any | None: + stripped = text.strip() + if not stripped: + return None + if stripped.startswith("```"): + lines = stripped.splitlines() + if len(lines) >= 2 and lines[-1].strip() == "```": + stripped = "\n".join(lines[1:-1]).strip() + try: + return json.loads(stripped) + except (TypeError, ValueError): + return None + + +def _is_tool_call_payload(payload: Any) -> bool: + if isinstance(payload, list): + return any(_is_tool_call_payload(item) for item in payload) + if not isinstance(payload, dict): + return False + + tool_calls = payload.get("tool_calls") + if isinstance(tool_calls, list) and any(_is_tool_call_payload(item) for item in tool_calls): + return True + + function = payload.get("function") + if isinstance(function, dict) and _is_tool_call_payload(function): + return True + + tool_name = payload.get("action") or payload.get("tool") or payload.get("name") + return ( + isinstance(tool_name, str) + and bool(tool_name.strip()) + and any(key in payload for key in _TOOL_ARGUMENT_KEYS) + ) + + +async def _append_interactive_tool_required_message( + *, + session: Session | None, + context: dict[str, Any], + attempt: int, + limit: int, +) -> list[dict[str, str]]: + finish_tool = "finish_scan" if context.get("parent_id") is None else "agent_finish" + message = ( + "Your previous response looked like a tool call, but it was returned as plain text " + "instead of being executed. Plain-text tool-call JSON is not executed by Strix. " + "Continue immediately and call exactly one tool. " + f"If your work is complete, call {finish_tool}. " + "If you are blocked waiting for another agent, call wait_for_message. " + "Otherwise use the appropriate execution or planning tool. " + f"This is recovery attempt {attempt}/{limit}." + ) + return await _append_tool_required_message(session=session, message=message) + + async def _append_noninteractive_tool_required_message( *, session: Session | None, @@ -485,6 +613,14 @@ async def _append_noninteractive_tool_required_message( "Otherwise use the appropriate execution or planning tool. " f"This is recovery attempt {attempt}/{limit}." ) + return await _append_tool_required_message(session=session, message=message) + + +async def _append_tool_required_message( + *, + session: Session | None, + message: str, +) -> list[dict[str, str]]: item = {"role": "user", "content": message} if session is None: return [item] diff --git a/tests/test_execution.py b/tests/test_execution.py index 59a37e651..069cad9e3 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -3,12 +3,65 @@ from __future__ import annotations import asyncio +from typing import TYPE_CHECKING, Any import pytest +from agents import RunConfig +from strix.core import execution from strix.core.agents import AgentCoordinator +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Callable + + +class FakeStream: + def __init__( + self, + final_output: Any, + on_stream: Callable[[], Awaitable[None]] | None = None, + ) -> None: + self.final_output = final_output + self.on_stream = on_stream + self.run_loop_exception: Exception | None = None + self.cancel_modes: list[str] = [] + + async def stream_events(self) -> AsyncIterator[Any]: + if self.on_stream is not None: + await self.on_stream() + for event in (): + yield event + + def cancel(self, mode: str) -> None: + self.cancel_modes.append(mode) + + +class FakeSession: + def __init__(self) -> None: + self.items: list[Any] = [] + + async def add_items(self, items: list[Any]) -> None: + self.items.extend(items) + + +def _install_streams(monkeypatch: pytest.MonkeyPatch, streams: list[FakeStream]) -> list[Any]: + inputs: list[Any] = [] + + def fake_run_streamed(*_args: Any, **kwargs: Any) -> FakeStream: + inputs.append(kwargs["input"]) + return streams.pop(0) + + monkeypatch.setattr(execution.Runner, "run_streamed", fake_run_streamed) + return inputs + + +async def _registered_coordinator() -> AgentCoordinator: + coordinator = AgentCoordinator() + await coordinator.register("root", "strix", parent_id=None) + return coordinator + + @pytest.mark.asyncio async def test_budget_stop_sets_flag() -> None: coordinator = AgentCoordinator() @@ -42,3 +95,152 @@ async def test_wait_for_message_returns_immediately_after_budget_stop() -> None: # No pending messages, but the stop flag short-circuits the wait. await asyncio.wait_for(coordinator.wait_for_message("agent"), timeout=1.0) + + +@pytest.mark.asyncio +async def test_interactive_tool_call_json_forces_retry(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + streams = [ + FakeStream('{"action": "exec_command", "params": {"cmd": "ls"}}'), + FakeStream("plain response after retry"), + ] + inputs = _install_streams(monkeypatch, streams) + coordinator = await _registered_coordinator() + + await execution._run_interactive_until_tool_continuation_settled( + object(), + coordinator, + "root", + input_data="start", + run_config=RunConfig(), + context={"parent_id": None}, + max_turns=3, + session=session, # type: ignore[arg-type] + event_sink=None, + hooks=None, + ) + + assert inputs == ["start", []] + assert len(session.items) == 1 + assert "call exactly one tool" in session.items[0]["content"] + assert await execution._agent_status(coordinator, "root") == "waiting" + + +@pytest.mark.asyncio +async def test_interactive_tool_call_retry_can_complete(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + coordinator = await _registered_coordinator() + + async def complete_scan() -> None: + await coordinator.set_status("root", "completed") + + streams = [ + FakeStream('{"action": "exec_command", "params": {"cmd": "ls"}}'), + FakeStream('{"success": true, "scan_completed": true}', on_stream=complete_scan), + ] + inputs = _install_streams(monkeypatch, streams) + + await execution._run_interactive_until_tool_continuation_settled( + object(), + coordinator, + "root", + input_data="start", + run_config=RunConfig(), + context={"parent_id": None}, + max_turns=3, + session=session, # type: ignore[arg-type] + event_sink=None, + hooks=None, + ) + + assert inputs == ["start", []] + assert await execution._agent_status(coordinator, "root") == "completed" + + +@pytest.mark.asyncio +async def test_interactive_tool_call_json_retry_limit_parks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = FakeSession() + streams = [ + FakeStream('{"action": "exec_command", "params": {"cmd": "ls"}}'), + FakeStream('{"action": "exec_command", "params": {"cmd": "pwd"}}'), + ] + inputs = _install_streams(monkeypatch, streams) + coordinator = await _registered_coordinator() + + await execution._run_interactive_until_tool_continuation_settled( + object(), + coordinator, + "root", + input_data="start", + run_config=RunConfig(), + context={"parent_id": None}, + max_turns=2, + session=session, # type: ignore[arg-type] + event_sink=None, + hooks=None, + ) + + assert inputs == ["start", []] + assert len(session.items) == 1 + assert await execution._agent_status(coordinator, "root") == "waiting" + + +@pytest.mark.asyncio +async def test_interactive_plain_text_still_parks(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + inputs = _install_streams(monkeypatch, [FakeStream("This is a normal answer.")]) + coordinator = await _registered_coordinator() + + await execution._run_interactive_until_tool_continuation_settled( + object(), + coordinator, + "root", + input_data="start", + run_config=RunConfig(), + context={"parent_id": None}, + max_turns=3, + session=session, # type: ignore[arg-type] + event_sink=None, + hooks=None, + ) + + assert inputs == ["start"] + assert session.items == [] + assert await execution._agent_status(coordinator, "root") == "waiting" + + +@pytest.mark.asyncio +async def test_noninteractive_plain_text_recovery_is_unchanged( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = FakeSession() + coordinator = await _registered_coordinator() + + async def complete_scan() -> None: + await coordinator.set_status("root", "completed") + + streams = [ + FakeStream("plain response"), + FakeStream('{"success": true, "scan_completed": true}', on_stream=complete_scan), + ] + inputs = _install_streams(monkeypatch, streams) + + await execution._run_noninteractive_until_lifecycle( + object(), + coordinator, + "root", + initial_input="start", + run_config=RunConfig(), + context={"parent_id": None}, + max_turns=3, + session=session, # type: ignore[arg-type] + event_sink=None, + hooks=None, + ) + + assert inputs == ["start", []] + assert len(session.items) == 1 + assert "non-interactive mode" in session.items[0]["content"] + assert await execution._agent_status(coordinator, "root") == "completed"