From 8ee692687f7e80ff688cf8baa5b0920c8be2146b Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Fri, 29 May 2026 16:24:26 -0400 Subject: [PATCH 1/7] fix(cases): async case duration sync --- frontend/src/client/schemas.gen.ts | 1 + frontend/src/client/services.gen.ts | 2 +- frontend/src/client/types.gen.ts | 2 + .../test_case_duration_benchmarks.py | 241 +++++++++++ tests/unit/test_case_duration_router.py | 37 ++ tests/unit/test_case_duration_service.py | 97 +++++ .../unit/test_case_duration_sync_consumer.py | 162 +++++++ tests/unit/test_case_events_service.py | 126 +++++- tests/unit/test_cases_service.py | 16 +- tracecat/api/app.py | 18 + tracecat/cases/durations/consumer.py | 407 ++++++++++++++++++ tracecat/cases/durations/router.py | 11 +- tracecat/cases/durations/schemas.py | 3 + tracecat/cases/durations/service.py | 13 + tracecat/cases/durations/sync_queue.py | 84 ++++ tracecat/cases/service.py | 34 +- tracecat/config.py | 41 ++ tracecat/identifiers/__init__.py | 1 + 18 files changed, 1272 insertions(+), 24 deletions(-) create mode 100644 tests/unit/test_case_duration_router.py create mode 100644 tests/unit/test_case_duration_sync_consumer.py create mode 100644 tracecat/cases/durations/consumer.py create mode 100644 tracecat/cases/durations/sync_queue.py diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index f8550ed32a..a63ef91f2b 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -18918,6 +18918,7 @@ export const $Role = { "tracecat-cli", "tracecat-executor", "tracecat-agent-executor", + "tracecat-case-duration-sync", "tracecat-case-triggers", "tracecat-llm-gateway", "tracecat-mcp", diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 3e57837079..6bf2cf188d 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -10735,7 +10735,7 @@ export const caseDurationsDeleteCaseDurationDefinition = ( /** * List Case Durations - * Sync and list case durations for the provided case. + * List materialized case durations for the provided case. * @param data The data for the request. * @param data.caseId * @param data.workspaceId diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 9a3d5cfb01..216c65f9c4 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -5780,6 +5780,7 @@ export type Role = { | "tracecat-cli" | "tracecat-executor" | "tracecat-agent-executor" + | "tracecat-case-duration-sync" | "tracecat-case-triggers" | "tracecat-llm-gateway" | "tracecat-mcp" @@ -5800,6 +5801,7 @@ export type service_id = | "tracecat-cli" | "tracecat-executor" | "tracecat-agent-executor" + | "tracecat-case-duration-sync" | "tracecat-case-triggers" | "tracecat-llm-gateway" | "tracecat-mcp" diff --git a/tests/integration/test_case_duration_benchmarks.py b/tests/integration/test_case_duration_benchmarks.py index b1d78684d6..cb245b78ab 100644 --- a/tests/integration/test_case_duration_benchmarks.py +++ b/tests/integration/test_case_duration_benchmarks.py @@ -13,12 +13,18 @@ TRACECAT_CASE_DURATION_BENCHMARK_UPDATES_PER_CASE TRACECAT_CASE_DURATION_BENCHMARK_HEALTH_INTERVAL_MS TRACECAT_CASE_DURATION_BENCHMARK_HEALTH_TIMEOUT_MS + TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_MUTATORS + TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_MUTATIONS + TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_LOADS + TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_BASELINE_LOADS + TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_LOAD_INTERVAL_MS TRACECAT_CASE_DURATION_BENCHMARK_OUTPUT """ from __future__ import annotations import asyncio +import contextlib import json import math import os @@ -39,6 +45,8 @@ from tracecat.api.app import app from tracecat.auth.types import Role from tracecat.authz.scopes import ADMIN_SCOPES +from tracecat.cases.durations import consumer as duration_sync_consumer +from tracecat.cases.durations.consumer import CaseDurationSyncConsumer from tracecat.cases.durations.schemas import ( CaseDurationAnchorSelection, CaseDurationDefinitionCreate, @@ -53,6 +61,7 @@ from tracecat.cases.schemas import CaseCreate, CaseUpdate from tracecat.cases.service import CasesService from tracecat.db.models import CaseEvent, Organization, Workspace +from tracecat.redis.client import get_redis_client RUN_BENCHMARKS = os.environ.get("TRACECAT_RUN_CASE_DURATION_BENCHMARKS") == "1" BENCHMARK_OUTPUT_PATH = os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_OUTPUT") @@ -94,6 +103,25 @@ class CaseDurationBurstBenchmarkConfig: ) / 1000 ) + hot_case_mutators: int = int( + os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_MUTATORS") or 4 + ) + hot_case_mutations: int = int( + os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_MUTATIONS") or 8 + ) + hot_case_loads: int = int( + os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_LOADS") or 12 + ) + hot_case_baseline_loads: int = int( + os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_BASELINE_LOADS") or 3 + ) + hot_case_load_interval_s: float = ( + int( + os.environ.get("TRACECAT_CASE_DURATION_BENCHMARK_HOT_CASE_LOAD_INTERVAL_MS") + or 10 + ) + / 1000 + ) def _percentile(values: list[float], percentile: float) -> float | None: @@ -297,6 +325,77 @@ async def update_one_case( ) +async def _sync_initial_case_durations(*, async_engine, role: Role, case_id: uuid.UUID): + async with AsyncSession(async_engine, expire_on_commit=False) as session: + await CaseDurationService(session=session, role=role).sync_case_durations( + case_id + ) + await session.commit() + + +async def _load_case_page_once( + *, + async_engine, + role: Role, + case_id: uuid.UUID, +) -> float: + async def load_case_detail() -> None: + async with AsyncSession(async_engine, expire_on_commit=False) as session: + case = await CasesService(session=session, role=role).get_case( + case_id, + track_view=True, + ) + if case is None: + raise AssertionError(f"Case {case_id} not found during benchmark") + + async def load_case_durations() -> None: + async with AsyncSession(async_engine, expire_on_commit=False) as session: + await CaseDurationService(session=session, role=role).list_durations( + case_id + ) + + started = time.perf_counter() + await asyncio.gather(load_case_detail(), load_case_durations()) + return time.perf_counter() - started + + +async def _load_case_page_repeatedly( + *, + async_engine, + role: Role, + case_id: uuid.UUID, + load_count: int, + interval_s: float, +) -> list[float]: + latencies: list[float] = [] + for _ in range(load_count): + latencies.append( + await _load_case_page_once( + async_engine=async_engine, + role=role, + case_id=case_id, + ) + ) + await asyncio.sleep(interval_s) + return latencies + + +async def _run_hot_case_update_burst( + *, + async_engine, + role: Role, + case_id: uuid.UUID, + mutators: int, + mutations_per_mutator: int, +) -> tuple[list[float], int]: + return await _run_case_update_burst( + async_engine=async_engine, + role=role, + case_ids=[case_id for _ in range(mutators)], + updates_per_case=mutations_per_mutator, + ) + + @pytest.mark.anyio async def test_case_duration_update_burst_health_latency( monkeypatch: pytest.MonkeyPatch, @@ -429,3 +528,145 @@ async def probe_health() -> None: assert health_latencies["burst"] finally: await async_engine.dispose() + + +@pytest.mark.anyio +async def test_hot_case_load_latency_during_async_duration_mutation_burst( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Measure case-load latency while same-case mutations enqueue async sync.""" + + cfg = CaseDurationBurstBenchmarkConfig(case_count=1) + stream_suffix = uuid.uuid4().hex[:8] + monkeypatch.setattr(config, "TRACECAT__CASE_TRIGGERS_ENABLED", False) + monkeypatch.setattr(config, "TRACECAT__CASE_DURATION_SYNC_ENABLED", False) + monkeypatch.setattr( + config, + "TRACECAT__CASE_DURATION_SYNC_STREAM_KEY", + f"case-duration-sync-benchmark-{stream_suffix}", + ) + monkeypatch.setattr( + config, + "TRACECAT__CASE_DURATION_SYNC_GROUP", + f"case-duration-sync-benchmark-{stream_suffix}", + ) + + async_engine = create_async_engine( + TEST_DB_CONFIG.test_url, + poolclass=NullPool, + ) + consumer_task: asyncio.Task[None] | None = None + + @contextlib.asynccontextmanager + async def benchmark_bypass_session(): + async with AsyncSession(async_engine, expire_on_commit=False) as session: + yield session + + try: + with ( + patch.object( + CaseDurationDefinitionService, + "has_entitlement", + new=AsyncMock(return_value=True), + ), + patch.object( + CaseDurationService, + "has_entitlement", + new=AsyncMock(return_value=True), + ), + patch.object( + CasesService, + "has_entitlement", + new=AsyncMock(return_value=False), + ), + patch.object( + duration_sync_consumer, + "get_async_session_bypass_rls_context_manager", + benchmark_bypass_session, + ), + ): + role = await _seed_benchmark_role(async_engine) + case_ids = await _seed_cases_definitions_and_history( + async_engine=async_engine, + role=role, + cfg=cfg, + ) + case_id = case_ids[0] + await _sync_initial_case_durations( + async_engine=async_engine, + role=role, + case_id=case_id, + ) + monkeypatch.setattr(config, "TRACECAT__CASE_DURATION_SYNC_ENABLED", True) + consumer = CaseDurationSyncConsumer( + await get_redis_client(), + consumer_name=f"duration-benchmark-{uuid.uuid4().hex[:8]}", + ) + consumer_task = asyncio.create_task(consumer.run()) + await asyncio.sleep(0.1) + + baseline_loads = await _load_case_page_repeatedly( + async_engine=async_engine, + role=role, + case_id=case_id, + load_count=cfg.hot_case_baseline_loads, + interval_s=cfg.hot_case_load_interval_s, + ) + + load_task = asyncio.create_task( + _load_case_page_repeatedly( + async_engine=async_engine, + role=role, + case_id=case_id, + load_count=cfg.hot_case_loads, + interval_s=cfg.hot_case_load_interval_s, + ) + ) + mutation_task = asyncio.create_task( + _run_hot_case_update_burst( + async_engine=async_engine, + role=role, + case_id=case_id, + mutators=cfg.hot_case_mutators, + mutations_per_mutator=cfg.hot_case_mutations, + ) + ) + burst_loads, (mutation_latencies, mutation_errors) = await asyncio.gather( + load_task, + mutation_task, + ) + await asyncio.sleep(0.5) + + summary: dict[str, object] = { + "config": { + "cases": cfg.case_count, + "definitions": cfg.definition_count, + "history_events_per_case": cfg.history_events_per_case, + "hot_case_mutators": cfg.hot_case_mutators, + "hot_case_mutations": cfg.hot_case_mutations, + "hot_case_loads": cfg.hot_case_loads, + "hot_case_baseline_loads": cfg.hot_case_baseline_loads, + "hot_case_load_interval_ms": round(cfg.hot_case_load_interval_s * 1000), + }, + "case_load_baseline": _latency_stats(baseline_loads), + "case_load_burst": _latency_stats(burst_loads), + "mutation_latencies": _latency_stats(mutation_latencies), + "mutation_errors": mutation_errors, + } + _write_summary_to_file(summary) + + print("\nHot case async duration sync benchmark:") + print(summary) + + assert baseline_loads + assert burst_loads + assert mutation_latencies + assert mutation_errors == 0 + finally: + if consumer_task is not None: + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + await async_engine.dispose() diff --git a/tests/unit/test_case_duration_router.py b/tests/unit/test_case_duration_router.py new file mode 100644 index 0000000000..005fa08cd1 --- /dev/null +++ b/tests/unit/test_case_duration_router.py @@ -0,0 +1,37 @@ +import uuid +from unittest.mock import AsyncMock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from tracecat.cases.durations.router import list_case_durations +from tracecat.cases.durations.service import CaseDurationService + +pytestmark = pytest.mark.usefixtures("db") + + +@pytest.mark.anyio +async def test_list_case_durations_is_read_only( + session: AsyncSession, + svc_role, + monkeypatch: pytest.MonkeyPatch, +) -> None: + sync_mock = AsyncMock() + list_mock = AsyncMock(return_value=[]) + commit_mock = AsyncMock() + + monkeypatch.setattr(CaseDurationService, "sync_case_durations", sync_mock) + monkeypatch.setattr(CaseDurationService, "list_durations", list_mock) + monkeypatch.setattr(session, "commit", commit_mock) + + case_id = uuid.uuid4() + result = await list_case_durations( + role=svc_role, + session=session, + case_id=case_id, + ) + + assert result == [] + sync_mock.assert_not_awaited() + commit_mock.assert_not_awaited() + list_mock.assert_awaited_once_with(case_id) diff --git a/tests/unit/test_case_duration_service.py b/tests/unit/test_case_duration_service.py index 4e3c9ab693..3c649b8eda 100644 --- a/tests/unit/test_case_duration_service.py +++ b/tests/unit/test_case_duration_service.py @@ -49,6 +49,18 @@ def stub_case_duration_entitlements() -> Iterator[None]: "has_entitlement", new=AsyncMock(return_value=False), ), + patch( + "tracecat.cases.durations.service.enqueue_case_duration_sync_after_commit", + return_value=None, + ), + patch( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + return_value=None, + ), + patch( + "tracecat.cases.service.publish_case_event_payload", + new=AsyncMock(return_value=None), + ), ): yield @@ -73,6 +85,85 @@ def make_case_create( ) +@pytest.mark.anyio +async def test_create_definition_enqueues_backfill_after_commit( + session: AsyncSession, + svc_role, + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, object]] = [] + + def fake_enqueue(*args, **kwargs) -> None: + calls.append(kwargs) + + monkeypatch.setattr( + "tracecat.cases.durations.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + + definition_service = CaseDurationDefinitionService(session=session, role=svc_role) + await definition_service.create_definition( + CaseDurationDefinitionCreate( + name="Backfilled Duration", + start_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.CASE_CREATED, + ), + end_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.STATUS_CHANGED, + filters=CaseDurationEventFilters( + new_values=[CaseStatus.RESOLVED.value] + ), + ), + ) + ) + + assert len(calls) == 1 + assert calls[0]["workspace_id"] == svc_role.workspace_id + assert calls[0]["reason"] == "duration_definition_created" + + +@pytest.mark.anyio +async def test_update_definition_enqueues_backfill_after_commit( + session: AsyncSession, + svc_role, + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, object]] = [] + + def fake_enqueue(*args, **kwargs) -> None: + calls.append(kwargs) + + definition_service = CaseDurationDefinitionService(session=session, role=svc_role) + definition = await definition_service.create_definition( + CaseDurationDefinitionCreate( + name="Backfilled Duration", + start_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.CASE_CREATED, + ), + end_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.STATUS_CHANGED, + filters=CaseDurationEventFilters( + new_values=[CaseStatus.RESOLVED.value] + ), + ), + ) + ) + calls.clear() + monkeypatch.setattr( + "tracecat.cases.durations.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + + await definition_service.update_definition( + definition.id, + CaseDurationDefinitionUpdate(description="Updated description"), + ) + + assert len(calls) == 1 + assert calls[0]["workspace_id"] == svc_role.workspace_id + assert calls[0]["reason"] == "duration_definition_updated" + + @pytest.mark.anyio async def test_compute_case_durations_from_events( session: AsyncSession, svc_role @@ -130,6 +221,7 @@ async def test_compute_case_durations_from_events( assert value.duration is not None assert value.duration.total_seconds() >= 0 + await duration_service.sync_case_durations(updated_case) duration_stmt = select(CaseDuration).where(CaseDuration.case_id == case.id) stored_duration = await session.execute(duration_stmt) record = stored_duration.scalar_one() @@ -1050,6 +1142,11 @@ def test_duration_anchor_allows_empty_filters_for_unfiltered_events( assert anchor.filters == CaseDurationEventFilters() +def test_duration_anchor_rejects_case_viewed() -> None: + with pytest.raises(ValidationError): + CaseDurationEventAnchor(event_type=CaseEventType.CASE_VIEWED) + + @pytest.mark.anyio async def test_duration_storage_maps_known_legacy_ui_filters( session: AsyncSession, svc_role diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py new file mode 100644 index 0000000000..b88ab9b38f --- /dev/null +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -0,0 +1,162 @@ +import uuid +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from tracecat.cases.durations.consumer import CaseDurationSyncConsumer +from tracecat.redis.client import RedisClient + + +class FakeRedisClient: + def __init__(self) -> None: + self.acked: list[list[str]] = [] + + async def xack( + self, + stream_key: str, + group_name: str, + message_ids: list[str], + ) -> None: + del stream_key, group_name + self.acked.append(message_ids) + + +@pytest.mark.anyio +async def test_consumer_coalesces_case_jobs_by_case( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + workspace_id = uuid.uuid4() + case_id = uuid.uuid4() + sync_mock = AsyncMock(return_value=True) + monkeypatch.setattr(consumer, "_sync_case_duration", sync_mock) + + await consumer._handle_entries( + [ + ( + "1-0", + { + "workspace_id": str(workspace_id), + "case_id": str(case_id), + "reason": "case_event", + "event_type": "case_updated", + }, + ), + ( + "2-0", + { + "workspace_id": str(workspace_id), + "case_id": str(case_id), + "reason": "case_event", + "event_type": "case_updated", + }, + ), + ] + ) + + sync_mock.assert_awaited_once_with( + workspace_id, + case_id, + event_types={"case_updated"}, + ) + assert client.acked == [["1-0", "2-0"]] + + +@pytest.mark.anyio +async def test_consumer_leaves_locked_case_jobs_pending( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + sync_mock = AsyncMock(return_value=False) + monkeypatch.setattr(consumer, "_sync_case_duration", sync_mock) + + await consumer._handle_entries( + [ + ( + "1-0", + { + "workspace_id": str(uuid.uuid4()), + "case_id": str(uuid.uuid4()), + "reason": "case_event", + }, + ) + ] + ) + + sync_mock.assert_awaited_once() + assert client.acked == [] + + +@pytest.mark.anyio +async def test_consumer_leaves_failed_case_jobs_pending( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + sync_mock = AsyncMock(side_effect=RuntimeError("transient db failure")) + logger_mock = MagicMock() + monkeypatch.setattr(consumer, "_sync_case_duration", sync_mock) + monkeypatch.setattr( + "tracecat.cases.durations.consumer.logger.exception", + logger_mock, + ) + + await consumer._handle_entries( + [ + ( + "1-0", + { + "workspace_id": str(uuid.uuid4()), + "case_id": str(uuid.uuid4()), + "reason": "case_event", + }, + ) + ] + ) + + sync_mock.assert_awaited_once() + logger_mock.assert_called_once() + assert client.acked == [] + + +@pytest.mark.anyio +async def test_consumer_acks_malformed_jobs() -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + + await consumer._handle_entries([("1-0", {"reason": "case_event"})]) + + assert client.acked == [["1-0"]] + + +@pytest.mark.anyio +async def test_consumer_acks_successful_backfill_jobs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + workspace_id = uuid.uuid4() + backfill_mock = AsyncMock(return_value=True) + monkeypatch.setattr(consumer, "_process_backfill_job", backfill_mock) + + await consumer._handle_entries( + [ + ( + "1-0", + { + "workspace_id": str(workspace_id), + "reason": "duration_definition_created", + }, + ) + ] + ) + + backfill_mock.assert_awaited_once() + await_args = backfill_mock.await_args + assert await_args is not None + job = await_args.args[0] + assert cast(Any, job).workspace_id == workspace_id + assert client.acked == [["1-0"]] diff --git a/tests/unit/test_case_events_service.py b/tests/unit/test_case_events_service.py index 55e176e89d..40b3f14170 100644 --- a/tests/unit/test_case_events_service.py +++ b/tests/unit/test_case_events_service.py @@ -1,12 +1,13 @@ import uuid from collections.abc import Iterator from datetime import UTC, datetime, timedelta -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession +from tracecat import config from tracecat.auth.types import Role from tracecat.authz.scopes import ADMIN_SCOPES, SERVICE_PRINCIPAL_SCOPES from tracecat.cases.enums import CaseEventType, CasePriority, CaseSeverity, CaseStatus @@ -34,9 +35,19 @@ @pytest.fixture(autouse=True) def stub_case_duration_sync() -> Iterator[None]: - with patch( - "tracecat.cases.service.CaseDurationService.sync_case_durations", - new=AsyncMock(return_value=None), + with ( + patch( + "tracecat.cases.service.CaseDurationService.sync_case_durations", + new=AsyncMock(return_value=None), + ), + patch( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + return_value=None, + ), + patch( + "tracecat.cases.service.publish_case_event_payload", + new=AsyncMock(return_value=None), + ), ): yield @@ -90,6 +101,101 @@ async def test_case(cases_service: CasesService): @pytest.mark.anyio class TestCaseEventsService: + async def test_create_event_queues_duration_sync_by_default( + self, + case_events_service: CaseEventsService, + test_case, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Case events should enqueue duration sync outside the request transaction.""" + enqueue_calls: list[dict[str, object]] = [] + sync_mock = AsyncMock(return_value=None) + + def fake_enqueue(*args, **kwargs) -> None: + enqueue_calls.append(kwargs) + + monkeypatch.setattr( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + monkeypatch.setattr( + "tracecat.cases.service.CaseDurationService.sync_case_durations", + sync_mock, + ) + + event_data = StatusChangedEvent( + type=CaseEventType.STATUS_CHANGED, + old=CaseStatus.NEW, + new=CaseStatus.IN_PROGRESS, + ) + await case_events_service.create_event(test_case, event_data) + + sync_mock.assert_not_awaited() + assert len(enqueue_calls) == 1 + assert enqueue_calls[0]["workspace_id"] == test_case.workspace_id + assert enqueue_calls[0]["case_id"] == test_case.id + assert enqueue_calls[0]["event_type"] == CaseEventType.STATUS_CHANGED.value + assert enqueue_calls[0]["reason"] == "case_event" + + async def test_create_event_can_sync_durations_inline( + self, + case_events_service: CaseEventsService, + test_case, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Create-case and explicit callers can still request inline materialization.""" + enqueue_mock = MagicMock() + sync_mock = AsyncMock(return_value=None) + + monkeypatch.setattr( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + enqueue_mock, + ) + monkeypatch.setattr( + "tracecat.cases.service.CaseDurationService.sync_case_durations", + sync_mock, + ) + + event_data = CreatedEvent(type=CaseEventType.CASE_CREATED) + await case_events_service.create_event( + test_case, + event_data, + duration_sync="inline", + ) + + sync_mock.assert_awaited_once() + enqueue_mock.assert_not_called() + + async def test_create_event_syncs_inline_when_async_duration_sync_disabled( + self, + case_events_service: CaseEventsService, + test_case, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Disabling the async worker should fall back to prior inline sync behavior.""" + enqueue_mock = MagicMock() + sync_mock = AsyncMock(return_value=None) + + monkeypatch.setattr(config, "TRACECAT__CASE_DURATION_SYNC_ENABLED", False) + monkeypatch.setattr( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + enqueue_mock, + ) + monkeypatch.setattr( + "tracecat.cases.service.CaseDurationService.sync_case_durations", + sync_mock, + ) + + event_data = StatusChangedEvent( + type=CaseEventType.STATUS_CHANGED, + old=CaseStatus.NEW, + new=CaseStatus.IN_PROGRESS, + ) + await case_events_service.create_event(test_case, event_data) + + sync_mock.assert_awaited_once() + enqueue_mock.assert_not_called() + async def test_create_case_created_event( self, case_events_service: CaseEventsService, test_case ) -> None: @@ -522,12 +628,24 @@ async def test_case_viewed_event_created_once_per_window( self, case_events_service: CaseEventsService, test_case, + monkeypatch: pytest.MonkeyPatch, ) -> None: """Case viewed events should be created at most once within the dedupe window.""" + enqueue_calls: list[dict[str, object]] = [] + + def fake_enqueue(*args, **kwargs) -> None: + enqueue_calls.append(kwargs) + + monkeypatch.setattr( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + first_event = await case_events_service.create_case_viewed_event(test_case) assert first_event is not None assert first_event.type == CaseEventType.CASE_VIEWED assert first_event.user_id == case_events_service.role.user_id + assert enqueue_calls == [] await case_events_service.session.commit() duplicate_event = await case_events_service.create_case_viewed_event( diff --git a/tests/unit/test_cases_service.py b/tests/unit/test_cases_service.py index cf8f2d8b94..c46dbcaac8 100644 --- a/tests/unit/test_cases_service.py +++ b/tests/unit/test_cases_service.py @@ -47,9 +47,19 @@ @pytest.fixture(autouse=True) def stub_case_duration_sync() -> Iterator[None]: - with patch( - "tracecat.cases.service.CaseDurationService.sync_case_durations", - new=AsyncMock(return_value=None), + with ( + patch( + "tracecat.cases.service.CaseDurationService.sync_case_durations", + new=AsyncMock(return_value=None), + ), + patch( + "tracecat.cases.service.enqueue_case_duration_sync_after_commit", + return_value=None, + ), + patch( + "tracecat.cases.service.publish_case_event_payload", + new=AsyncMock(return_value=None), + ), ): yield diff --git a/tracecat/api/app.py b/tracecat/api/app.py index c08c0807bb..b41b7852f1 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -88,6 +88,7 @@ from tracecat.cases.attachments.router import router as case_attachments_router from tracecat.cases.dropdowns.router import definitions_router as case_dropdowns_router from tracecat.cases.dropdowns.router import values_router as case_dropdown_values_router +from tracecat.cases.durations.consumer import start_case_duration_sync_consumer from tracecat.cases.durations.router import router as case_durations_router from tracecat.cases.internal_router import ( comments_router as internal_comments_router, @@ -245,6 +246,14 @@ async def lifespan(app: FastAPI): ) logger.debug("Spawned background task for case trigger consumer") + case_duration_sync_task = None + if config.TRACECAT__CASE_DURATION_SYNC_ENABLED: + case_duration_sync_task = asyncio.create_task( + start_case_duration_sync_consumer(), + name="case_duration_sync_consumer", + ) + logger.debug("Spawned background task for case duration sync consumer") + logger.info( "Feature flags", feature_flags=[f.value for f in config.TRACECAT__FEATURE_FLAGS] ) @@ -315,6 +324,15 @@ async def lifespan(app: FastAPI): except Exception as e: logger.warning("Case trigger consumer stopped with error", error=e) + if case_duration_sync_task is not None: + case_duration_sync_task.cancel() + try: + await case_duration_sync_task + except asyncio.CancelledError: + logger.debug("Case duration sync consumer task cancelled") + except Exception as e: + logger.warning("Case duration sync consumer stopped with error", error=e) + async def setup_org_settings(session: AsyncSession, admin_role: Role): settings_service = SettingsService(session, role=admin_role) diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py new file mode 100644 index 0000000000..e8c418d39c --- /dev/null +++ b/tracecat/cases/durations/consumer.py @@ -0,0 +1,407 @@ +"""Consumer for async case duration materialization jobs.""" + +from __future__ import annotations + +import asyncio +import os +import socket +import uuid +from dataclasses import dataclass +from time import monotonic +from typing import cast, get_args + +from redis.exceptions import ResponseError +from sqlalchemy import or_, select +from tenacity import RetryError + +from tracecat import config +from tracecat.auth.types import Role +from tracecat.authz.scopes import SERVICE_PRINCIPAL_SCOPES +from tracecat.cases.durations.service import CaseDurationService +from tracecat.cases.durations.sync_queue import ( + CaseDurationSyncReason, + publish_case_duration_sync, +) +from tracecat.cases.enums import CaseEventType +from tracecat.db.engine import get_async_session_bypass_rls_context_manager +from tracecat.db.locks import ( + derive_lock_key_from_parts, + pg_advisory_unlock, + try_pg_advisory_lock, +) +from tracecat.db.models import Case, Workspace +from tracecat.db.models import CaseDurationDefinition as CaseDurationDefinitionDB +from tracecat.exceptions import TracecatNotFoundError +from tracecat.logger import logger +from tracecat.redis.client import RedisClient, get_redis_client + +CASE_DURATION_SYNC_REASONS = frozenset( + cast(tuple[str, ...], get_args(CaseDurationSyncReason)) +) + + +@dataclass(frozen=True) +class CaseDurationSyncJob: + workspace_id: uuid.UUID + reason: CaseDurationSyncReason + case_id: uuid.UUID | None = None + event_type: str | None = None + cursor: int | None = None + + +class CaseDurationSyncConsumer: + """Consume and coalesce case duration sync jobs.""" + + def __init__( + self, client: RedisClient, *, consumer_name: str | None = None + ) -> None: + self.client = client + self.stream_key = config.TRACECAT__CASE_DURATION_SYNC_STREAM_KEY + self.group = config.TRACECAT__CASE_DURATION_SYNC_GROUP + self.block_ms = config.TRACECAT__CASE_DURATION_SYNC_BLOCK_MS + self.batch = config.TRACECAT__CASE_DURATION_SYNC_BATCH + self.claim_idle_ms = config.TRACECAT__CASE_DURATION_SYNC_CLAIM_IDLE_MS + self.backfill_batch = config.TRACECAT__CASE_DURATION_SYNC_BACKFILL_BATCH + self.consumer_name = consumer_name or f"{socket.gethostname()}:{os.getpid()}" + self._pending_check_interval = max(self.claim_idle_ms / 1000.0, 30.0) + + async def run(self) -> None: + if not config.TRACECAT__CASE_DURATION_SYNC_ENABLED: + logger.info("Case duration sync disabled; skipping consumer") + return + + await self._ensure_group() + logger.info( + "Case duration sync consumer started", + stream_key=self.stream_key, + group=self.group, + consumer=self.consumer_name, + ) + last_pending_check = monotonic() + try: + while True: + try: + messages = await self.client.xreadgroup( + group_name=self.group, + consumer_name=self.consumer_name, + streams={self.stream_key: ">"}, + count=self.batch, + block=self.block_ms, + ) + except (ResponseError, RetryError) as e: + if self._is_nogroup_error(e): + logger.warning( + "Redis case duration sync stream/group missing; recreating", + stream_key=self.stream_key, + group=self.group, + error=str(e), + ) + await self._ensure_group() + continue + raise + if messages: + for _stream, entries in messages: + await self._handle_entries(entries) + else: + now = monotonic() + if now - last_pending_check >= self._pending_check_interval: + await self._claim_idle_messages() + last_pending_check = now + await asyncio.sleep(0) + except asyncio.CancelledError: + logger.info("Case duration sync consumer cancelled") + raise + except Exception as e: + logger.error( + "Case duration sync consumer stopped due to error", error=str(e) + ) + raise + + def _is_nogroup_error(self, error: Exception) -> bool: + if isinstance(error, ResponseError): + return "NOGROUP" in str(error) + if isinstance(error, RetryError): + last_exc = error.last_attempt.exception() + return isinstance(last_exc, ResponseError) and "NOGROUP" in str(last_exc) + return False + + async def _ensure_group(self) -> None: + try: + await self.client.xgroup_create( + self.stream_key, + self.group, + id="$", + ignore_busygroup=True, + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + return + raise + + async def _handle_entries(self, entries: list[tuple[str, dict[str, str]]]) -> None: + case_jobs: dict[tuple[uuid.UUID, uuid.UUID], list[str]] = {} + case_event_types: dict[tuple[uuid.UUID, uuid.UUID], set[str]] = {} + for message_id, fields in entries: + job = self._parse_job(fields) + if job is None: + await self.client.xack(self.stream_key, self.group, [message_id]) + continue + + if job.case_id is None: + try: + should_ack = await self._process_backfill_job(job) + except Exception: + logger.exception( + "Failed to process case duration backfill job", + workspace_id=str(job.workspace_id), + reason=job.reason, + ) + should_ack = False + if should_ack: + await self.client.xack(self.stream_key, self.group, [message_id]) + continue + + key = (job.workspace_id, job.case_id) + case_jobs.setdefault(key, []).append(message_id) + if job.event_type: + case_event_types.setdefault(key, set()).add(job.event_type) + + for (workspace_id, case_id), message_ids in case_jobs.items(): + try: + synced = await self._sync_case_duration( + workspace_id, + case_id, + event_types=case_event_types.get((workspace_id, case_id)), + ) + except Exception: + logger.exception( + "Failed to process case duration sync job", + workspace_id=str(workspace_id), + case_id=str(case_id), + ) + continue + + if synced: + await self.client.xack(self.stream_key, self.group, message_ids) + + def _parse_job(self, fields: dict[str, str]) -> CaseDurationSyncJob | None: + workspace_id = fields.get("workspace_id") + reason = fields.get("reason") + if not (workspace_id and reason): + logger.warning("Malformed case duration sync message", fields=fields) + return None + if reason not in CASE_DURATION_SYNC_REASONS: + logger.warning("Unknown case duration sync reason", fields=fields) + return None + try: + return CaseDurationSyncJob( + workspace_id=uuid.UUID(workspace_id), + case_id=uuid.UUID(case_id) + if (case_id := fields.get("case_id")) + else None, + event_type=fields.get("event_type"), + reason=cast(CaseDurationSyncReason, reason), + cursor=int(cursor) if (cursor := fields.get("cursor")) else None, + ) + except (TypeError, ValueError): + logger.warning("Invalid IDs in case duration sync message", fields=fields) + return None + + async def _sync_case_duration( + self, + workspace_id: uuid.UUID, + case_id: uuid.UUID, + *, + event_types: set[str] | None = None, + ) -> bool: + lock_key = derive_lock_key_from_parts( + "case-duration-sync", + str(workspace_id), + str(case_id), + ) + async with get_async_session_bypass_rls_context_manager() as session: + role = await self._get_service_role(session, workspace_id) + if role is None: + return True + + if not await self._event_types_require_sync( + session, + workspace_id=workspace_id, + event_types=event_types or set(), + ): + logger.debug( + "Skipping case duration sync; no definitions use event types", + workspace_id=str(workspace_id), + case_id=str(case_id), + event_types=sorted(event_types or set()), + ) + return True + + locked = await try_pg_advisory_lock(session, lock_key) + if not locked: + await session.rollback() + logger.debug( + "Case duration sync already locked; leaving message pending", + workspace_id=str(workspace_id), + case_id=str(case_id), + ) + return False + + try: + await CaseDurationService( + session=session, role=role + ).sync_case_durations(case_id) + await session.commit() + return True + except TracecatNotFoundError: + await session.rollback() + logger.info( + "Skipping case duration sync for deleted case", + workspace_id=str(workspace_id), + case_id=str(case_id), + ) + return True + except Exception: + await session.rollback() + logger.exception( + "Failed to sync case durations", + workspace_id=str(workspace_id), + case_id=str(case_id), + ) + return False + finally: + try: + await pg_advisory_unlock(session, lock_key) + await session.commit() + except Exception: + await session.rollback() + logger.warning( + "Failed to release case duration sync advisory lock", + workspace_id=str(workspace_id), + case_id=str(case_id), + ) + + async def _event_types_require_sync( + self, + session, + *, + workspace_id: uuid.UUID, + event_types: set[str], + ) -> bool: + if not event_types: + return True + + parsed_event_types: list[CaseEventType] = [] + for event_type in event_types: + try: + parsed_event_types.append(CaseEventType(event_type)) + except ValueError: + logger.warning( + "Unknown case event type in duration sync job", + workspace_id=str(workspace_id), + event_type=event_type, + ) + return True + + stmt = ( + select(CaseDurationDefinitionDB.id) + .where( + CaseDurationDefinitionDB.workspace_id == workspace_id, + or_( + CaseDurationDefinitionDB.start_event_type.in_(parsed_event_types), + CaseDurationDefinitionDB.end_event_type.in_(parsed_event_types), + ), + ) + .limit(1) + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() is not None + + async def _process_backfill_job(self, job: CaseDurationSyncJob) -> bool: + async with get_async_session_bypass_rls_context_manager() as session: + stmt = ( + select(Case.surrogate_id, Case.id) + .where(Case.workspace_id == job.workspace_id) + .order_by(Case.surrogate_id.asc()) + .limit(self.backfill_batch) + ) + if job.cursor is not None: + stmt = stmt.where(Case.surrogate_id > job.cursor) + result = await session.execute(stmt) + case_rows = result.tuples().all() + + for _surrogate_id, case_id in case_rows: + await publish_case_duration_sync( + workspace_id=job.workspace_id, + case_id=case_id, + reason="duration_definition_backfill", + ) + + if len(case_rows) == self.backfill_batch: + next_cursor = case_rows[-1][0] + await publish_case_duration_sync( + workspace_id=job.workspace_id, + reason="duration_definition_backfill", + cursor=next_cursor, + ) + return True + + async def _get_service_role(self, session, workspace_id: uuid.UUID) -> Role | None: + result = await session.execute( + select(Workspace).where(Workspace.id == workspace_id) + ) + workspace = result.scalars().first() + if workspace is None: + logger.info( + "Skipping case duration sync for deleted workspace", + workspace_id=str(workspace_id), + ) + return None + return Role( + type="service", + workspace_id=workspace_id, + organization_id=workspace.organization_id, + user_id=None, + service_id="tracecat-case-duration-sync", + scopes=SERVICE_PRINCIPAL_SCOPES["tracecat-case-duration-sync"], + ) + + async def _claim_idle_messages(self) -> None: + pending = await self.client.xpending_range( + self.stream_key, + self.group, + min_id="-", + max_id="+", + count=self.batch, + idle=self.claim_idle_ms, + ) + if not pending: + return + + message_ids: list[str] = [] + for entry in pending: + msg_id = None + if isinstance(entry, dict): + msg_id = entry.get("message_id") or entry.get("id") + else: + msg_id = getattr(entry, "message_id", None) + if msg_id: + message_ids.append(msg_id) + + if not message_ids: + return + + claimed = await self.client.xclaim( + self.stream_key, + self.group, + self.consumer_name, + self.claim_idle_ms, + message_ids, + ) + await self._handle_entries(claimed) + + +async def start_case_duration_sync_consumer() -> None: + client = await get_redis_client() + consumer = CaseDurationSyncConsumer(client) + await consumer.run() diff --git a/tracecat/cases/durations/router.py b/tracecat/cases/durations/router.py index 55b658e541..bdb3338b79 100644 --- a/tracecat/cases/durations/router.py +++ b/tracecat/cases/durations/router.py @@ -191,17 +191,8 @@ async def list_case_durations( session: AsyncDBSession, case_id: uuid.UUID, ) -> list[CaseDurationRead]: - """Sync and list case durations for the provided case.""" + """List materialized case durations for the provided case.""" service = CaseDurationService(session=session, role=role) - try: - await service.sync_case_durations(case_id) - await session.commit() - except Exception: - await session.rollback() - logger.error( - "Failed to sync case durations before listing", - case_id=str(case_id), - ) try: return await service.list_durations(case_id) except TracecatNotFoundError as err: diff --git a/tracecat/cases/durations/schemas.py b/tracecat/cases/durations/schemas.py index 3c2f20aee6..822be2efb2 100644 --- a/tracecat/cases/durations/schemas.py +++ b/tracecat/cases/durations/schemas.py @@ -89,6 +89,9 @@ def validate_filters_for_event_type(self) -> CaseDurationEventAnchor: ): return self + if self.event_type is CaseEventType.CASE_VIEWED: + raise ValueError("case_viewed cannot be used as a duration anchor") + filters = self.filters allowed_fields = _allowed_filter_fields_for_event_type(self.event_type) active_fields = _active_filter_fields(filters) diff --git a/tracecat/cases/durations/service.py b/tracecat/cases/durations/service.py index d9a51ed101..0d231cf112 100644 --- a/tracecat/cases/durations/service.py +++ b/tracecat/cases/durations/service.py @@ -41,6 +41,9 @@ CaseDurationRead, CaseDurationUpdate, ) +from tracecat.cases.durations.sync_queue import ( + enqueue_case_duration_sync_after_commit, +) from tracecat.cases.enums import CaseEventType from tracecat.concurrency import cooperative_every from tracecat.db.models import Case, CaseDuration, CaseEvent @@ -101,6 +104,11 @@ async def create_definition( **self._anchor_attributes(params.end_anchor, "end"), ) self.session.add(entity) + enqueue_case_duration_sync_after_commit( + self.session, + workspace_id=self.workspace_id, + reason="duration_definition_created", + ) await self.session.commit() await self.session.refresh(entity) return self._to_read_model(entity) @@ -142,6 +150,11 @@ async def update_definition( self._apply_anchor(entity, end_anchor, "end") self.session.add(entity) + enqueue_case_duration_sync_after_commit( + self.session, + workspace_id=self.workspace_id, + reason="duration_definition_updated", + ) await self.session.commit() await self.session.refresh(entity) return self._to_read_model(entity) diff --git a/tracecat/cases/durations/sync_queue.py b/tracecat/cases/durations/sync_queue.py new file mode 100644 index 0000000000..cc5224d1ba --- /dev/null +++ b/tracecat/cases/durations/sync_queue.py @@ -0,0 +1,84 @@ +"""Redis-backed queue helpers for async case duration materialization.""" + +from __future__ import annotations + +import uuid +from typing import Literal + +from sqlalchemy.ext.asyncio import AsyncSession + +from tracecat import config +from tracecat.db.session_events import add_after_commit_callback +from tracecat.logger import logger +from tracecat.redis.client import get_redis_client + +CaseDurationSyncReason = Literal[ + "case_event", + "duration_definition_created", + "duration_definition_updated", + "duration_definition_backfill", +] + + +async def publish_case_duration_sync( + *, + workspace_id: uuid.UUID, + reason: CaseDurationSyncReason, + case_id: uuid.UUID | None = None, + event_type: str | None = None, + cursor: int | None = None, +) -> str | None: + """Publish a case duration sync job to Redis.""" + if not config.TRACECAT__CASE_DURATION_SYNC_ENABLED: + return None + + fields = { + "workspace_id": str(workspace_id), + "reason": reason, + } + if case_id is not None: + fields["case_id"] = str(case_id) + if event_type is not None: + fields["event_type"] = event_type + if cursor is not None: + fields["cursor"] = str(cursor) + + client = await get_redis_client() + message_id = await client.xadd( + stream_key=config.TRACECAT__CASE_DURATION_SYNC_STREAM_KEY, + fields=fields, + maxlen=config.TRACECAT__CASE_DURATION_SYNC_MAXLEN, + approximate=True, + expire_seconds=None, + ) + logger.debug( + "Queued case duration sync", + message_id=message_id, + workspace_id=str(workspace_id), + case_id=str(case_id) if case_id is not None else None, + reason=reason, + ) + return message_id + + +def enqueue_case_duration_sync_after_commit( + session: AsyncSession, + *, + workspace_id: uuid.UUID, + reason: CaseDurationSyncReason, + case_id: uuid.UUID | None = None, + event_type: str | None = None, + cursor: int | None = None, +) -> None: + """Register a duration sync publish after the current transaction commits.""" + + async def _publish() -> None: + await publish_case_duration_sync( + workspace_id=workspace_id, + case_id=case_id, + event_type=event_type, + reason=reason, + cursor=cursor, + ) + + add_after_commit_callback(session, _publish) diff --git a/tracecat/cases/service.py b/tracecat/cases/service.py index 91fc19b231..6c9de03dcd 100644 --- a/tracecat/cases/service.py +++ b/tracecat/cases/service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.elements import ColumnElement +from tracecat import config from tracecat.audit.enums import AuditEventStatus from tracecat.audit.logger import audit_log from tracecat.audit.service import AuditService @@ -30,6 +31,9 @@ from tracecat.cases.dropdowns.service import CaseDropdownValuesService from tracecat.cases.durations.schemas import CaseDurationRead from tracecat.cases.durations.service import CaseDurationService +from tracecat.cases.durations.sync_queue import ( + enqueue_case_duration_sync_after_commit, +) from tracecat.cases.enums import ( CaseEventType, CasePriority, @@ -869,6 +873,7 @@ async def create_case(self, params: CaseCreate) -> Case: await self.events.create_event( case=case, event=CreatedEvent(wf_exec_id=run_ctx.wf_exec_id if run_ctx else None), + duration_sync="inline", ) if params.dropdown_values is not None: @@ -2364,6 +2369,7 @@ async def create_event( event: CaseEventVariant, *, publish_case_trigger: bool = True, + duration_sync: Literal["async", "inline", "none"] = "async", ) -> CaseEvent: """Create a new activity record for a case with variant-specific data. @@ -2371,8 +2377,8 @@ async def create_event( wrapping operations in a transaction and committing once at the end to preserve atomicity across multi-step updates. - Duration sync is performed automatically after each event is created, - so callers do not need to call sync_case_durations separately. + Duration sync is queued after commit by default. Callers that need + immediate materialization can request inline sync. """ db_event = CaseEvent( @@ -2407,9 +2413,21 @@ async def _publish_case_event() -> None: add_after_commit_callback(self.session, _publish_case_event) - # Auto-sync durations whenever an event is created - durations_service = CaseDurationService(session=self.session, role=self.role) - await durations_service.sync_case_durations(case) + if duration_sync == "inline" or ( + duration_sync == "async" and not config.TRACECAT__CASE_DURATION_SYNC_ENABLED + ): + durations_service = CaseDurationService( + session=self.session, role=self.role + ) + await durations_service.sync_case_durations(case) + elif duration_sync == "async": + enqueue_case_duration_sync_after_commit( + self.session, + workspace_id=case.workspace_id, + case_id=case.id, + event_type=event_type, + reason="case_event", + ) return db_event @@ -2446,7 +2464,11 @@ async def create_case_viewed_event( if now_utc - last_created_at < dedupe_window: return None - return await self.create_event(case=case, event=CaseViewedEvent()) + return await self.create_event( + case=case, + event=CaseViewedEvent(), + duration_sync="none", + ) class CaseTasksService(BaseWorkspaceService): diff --git a/tracecat/config.py b/tracecat/config.py index 9e6314ccf8..2ef2259d8c 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -172,6 +172,7 @@ class RLSMode(StrEnum): "tracecat-runner", "tracecat-schedule-runner", "tracecat-case-triggers", + "tracecat-case-duration-sync", "tracecat-ui", ] TRACECAT__DEFAULT_USER_ID = uuid.UUID(int=0) @@ -967,6 +968,46 @@ def _parse_auth_types() -> set[AuthType]: ) """TTL for case trigger lock keys in seconds.""" +TRACECAT__CASE_DURATION_SYNC_ENABLED = ( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_ENABLED", "true").lower() == "true" +) +"""Enable async case duration materialization from case event writes.""" + +TRACECAT__CASE_DURATION_SYNC_STREAM_KEY = os.environ.get( + "TRACECAT__CASE_DURATION_SYNC_STREAM_KEY", "case-duration-sync" +) +"""Redis stream key for case duration sync jobs.""" + +TRACECAT__CASE_DURATION_SYNC_GROUP = os.environ.get( + "TRACECAT__CASE_DURATION_SYNC_GROUP", "case-duration-sync" +) +"""Redis consumer group for case duration sync processing.""" + +TRACECAT__CASE_DURATION_SYNC_BLOCK_MS = int( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_BLOCK_MS") or 2000 +) +"""XREADGROUP block timeout in milliseconds for duration sync jobs.""" + +TRACECAT__CASE_DURATION_SYNC_BATCH = int( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_BATCH") or 100 +) +"""Maximum number of duration sync jobs to read per batch.""" + +TRACECAT__CASE_DURATION_SYNC_CLAIM_IDLE_MS = int( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_CLAIM_IDLE_MS") or 300000 +) +"""Idle time before claiming pending duration sync jobs.""" + +TRACECAT__CASE_DURATION_SYNC_MAXLEN = int( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_MAXLEN") or 30000 +) +"""Approximate max length for the case duration sync stream.""" + +TRACECAT__CASE_DURATION_SYNC_BACKFILL_BATCH = int( + os.environ.get("TRACECAT__CASE_DURATION_SYNC_BACKFILL_BATCH") or 250 +) +"""Number of cases to enqueue per duration definition backfill job.""" + # === File limits === # TRACECAT__MAX_ATTACHMENT_SIZE_BYTES = int( os.environ.get("TRACECAT__MAX_ATTACHMENT_SIZE_BYTES") or 20 * 1024 * 1024 diff --git a/tracecat/identifiers/__init__.py b/tracecat/identifiers/__init__.py index 9472d1bc7f..4df2dce8fa 100644 --- a/tracecat/identifiers/__init__.py +++ b/tracecat/identifiers/__init__.py @@ -90,6 +90,7 @@ "tracecat-cli", "tracecat-executor", "tracecat-agent-executor", + "tracecat-case-duration-sync", "tracecat-case-triggers", "tracecat-llm-gateway", "tracecat-mcp", From 95f5e752bb73e468452ede8bf7efb12db62b0e38 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:03:50 -0700 Subject: [PATCH 2/7] fix(cases): match status duration aliases --- .../unit/test_case_duration_sync_consumer.py | 46 +++++++++++++++++++ tracecat/cases/durations/consumer.py | 17 ++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py index b88ab9b38f..f2bf391694 100644 --- a/tests/unit/test_case_duration_sync_consumer.py +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -5,6 +5,7 @@ import pytest from tracecat.cases.durations.consumer import CaseDurationSyncConsumer +from tracecat.cases.enums import CaseEventType from tracecat.redis.client import RedisClient @@ -22,6 +23,28 @@ async def xack( self.acked.append(message_ids) +class FakeScalarResult: + def __init__(self, value: uuid.UUID | None) -> None: + self.value = value + + def scalar_one_or_none(self) -> uuid.UUID | None: + return self.value + + +class FakeDefinitionMatchSession: + async def execute(self, stmt: Any) -> FakeScalarResult: + compiled = stmt.compile() + event_types = { + event_type + for value in compiled.params.values() + if isinstance(value, list) + for event_type in value + } + return FakeScalarResult( + uuid.uuid4() if CaseEventType.STATUS_CHANGED in event_types else None + ) + + @pytest.mark.anyio async def test_consumer_coalesces_case_jobs_by_case( monkeypatch: pytest.MonkeyPatch, @@ -160,3 +183,26 @@ async def test_consumer_acks_successful_backfill_jobs( job = await_args.args[0] assert cast(Any, job).workspace_id == workspace_id assert client.acked == [["1-0"]] + + +@pytest.mark.anyio +async def test_event_types_require_sync_matches_status_changed_aliases() -> None: + session = FakeDefinitionMatchSession() + consumer = CaseDurationSyncConsumer(cast(RedisClient, FakeRedisClient())) + workspace_id = uuid.uuid4() + + assert await consumer._event_types_require_sync( + session, + workspace_id=workspace_id, + event_types={"case_closed"}, + ) + assert await consumer._event_types_require_sync( + session, + workspace_id=workspace_id, + event_types={"case_reopened"}, + ) + assert not await consumer._event_types_require_sync( + session, + workspace_id=workspace_id, + event_types={"case_updated"}, + ) diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py index e8c418d39c..41d0ef4e06 100644 --- a/tracecat/cases/durations/consumer.py +++ b/tracecat/cases/durations/consumer.py @@ -38,6 +38,9 @@ CASE_DURATION_SYNC_REASONS = frozenset( cast(tuple[str, ...], get_args(CaseDurationSyncReason)) ) +STATUS_CHANGED_ALIASES = frozenset( + (CaseEventType.CASE_CLOSED, CaseEventType.CASE_REOPENED) +) @dataclass(frozen=True) @@ -303,13 +306,23 @@ async def _event_types_require_sync( ) return True + matching_event_types = list(parsed_event_types) + if ( + any( + event_type in STATUS_CHANGED_ALIASES + for event_type in parsed_event_types + ) + and CaseEventType.STATUS_CHANGED not in matching_event_types + ): + matching_event_types.append(CaseEventType.STATUS_CHANGED) + stmt = ( select(CaseDurationDefinitionDB.id) .where( CaseDurationDefinitionDB.workspace_id == workspace_id, or_( - CaseDurationDefinitionDB.start_event_type.in_(parsed_event_types), - CaseDurationDefinitionDB.end_event_type.in_(parsed_event_types), + CaseDurationDefinitionDB.start_event_type.in_(matching_event_types), + CaseDurationDefinitionDB.end_event_type.in_(matching_event_types), ), ) .limit(1) From 19974b485d3138c7437a6bf6957136b8435cd829 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:06:39 -0700 Subject: [PATCH 3/7] fix(cases): reclaim duration sync retries --- .../unit/test_case_duration_sync_consumer.py | 48 +++++++++++++++++++ tracecat/cases/durations/consumer.py | 9 ++-- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py index f2bf391694..110a57d4e6 100644 --- a/tests/unit/test_case_duration_sync_consumer.py +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -1,3 +1,4 @@ +import asyncio import uuid from typing import Any, cast from unittest.mock import AsyncMock, MagicMock @@ -206,3 +207,50 @@ async def test_event_types_require_sync_matches_status_changed_aliases() -> None workspace_id=workspace_id, event_types={"case_updated"}, ) + + +@pytest.mark.anyio +async def test_consumer_claims_idle_messages_while_stream_is_busy( + monkeypatch: pytest.MonkeyPatch, +) -> None: + entries = [ + ( + "1-0", + { + "workspace_id": str(uuid.uuid4()), + "case_id": str(uuid.uuid4()), + "reason": "case_event", + }, + ) + ] + client = AsyncMock() + client.xreadgroup = AsyncMock( + side_effect=[ + [("stream", entries)], + asyncio.CancelledError(), + ] + ) + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + consumer._pending_check_interval = 10 + ensure_group_mock = AsyncMock() + handle_entries_mock = AsyncMock() + claim_idle_mock = AsyncMock() + monotonic_mock = MagicMock(side_effect=[0.0, 11.0]) + monkeypatch.setattr( + "tracecat.cases.durations.consumer.config.TRACECAT__CASE_DURATION_SYNC_ENABLED", + True, + ) + monkeypatch.setattr( + "tracecat.cases.durations.consumer.monotonic", + monotonic_mock, + ) + monkeypatch.setattr(consumer, "_ensure_group", ensure_group_mock) + monkeypatch.setattr(consumer, "_handle_entries", handle_entries_mock) + monkeypatch.setattr(consumer, "_claim_idle_messages", claim_idle_mock) + + with pytest.raises(asyncio.CancelledError): + await consumer.run() + + ensure_group_mock.assert_awaited_once() + handle_entries_mock.assert_awaited_once_with(entries) + claim_idle_mock.assert_awaited_once() diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py index 41d0ef4e06..1423d6bf04 100644 --- a/tracecat/cases/durations/consumer.py +++ b/tracecat/cases/durations/consumer.py @@ -105,11 +105,10 @@ async def run(self) -> None: if messages: for _stream, entries in messages: await self._handle_entries(entries) - else: - now = monotonic() - if now - last_pending_check >= self._pending_check_interval: - await self._claim_idle_messages() - last_pending_check = now + now = monotonic() + if now - last_pending_check >= self._pending_check_interval: + await self._claim_idle_messages() + last_pending_check = now await asyncio.sleep(0) except asyncio.CancelledError: logger.info("Case duration sync consumer cancelled") From 4d7437ba7c9b24194c2ce218f9b30a61eee70126 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:07:57 -0400 Subject: [PATCH 4/7] fix(cases): preserve duration backfill sync --- .../unit/test_case_duration_sync_consumer.py | 43 +++++++++++++++++++ tracecat/cases/durations/consumer.py | 11 ++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py index 110a57d4e6..5dc51c88c6 100644 --- a/tests/unit/test_case_duration_sync_consumer.py +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -88,6 +88,49 @@ async def test_consumer_coalesces_case_jobs_by_case( assert client.acked == [["1-0", "2-0"]] +@pytest.mark.anyio +async def test_consumer_forces_sync_when_backfill_coalesces_with_case_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + workspace_id = uuid.uuid4() + case_id = uuid.uuid4() + sync_mock = AsyncMock(return_value=True) + monkeypatch.setattr(consumer, "_sync_case_duration", sync_mock) + + # A per-case backfill job (no event_type) coalesced with a non-matching + # case_event must still force an unconditional sync, not be filtered out. + await consumer._handle_entries( + [ + ( + "1-0", + { + "workspace_id": str(workspace_id), + "case_id": str(case_id), + "reason": "duration_definition_backfill", + }, + ), + ( + "2-0", + { + "workspace_id": str(workspace_id), + "case_id": str(case_id), + "reason": "case_event", + "event_type": "case_updated", + }, + ), + ] + ) + + sync_mock.assert_awaited_once_with( + workspace_id, + case_id, + event_types=None, + ) + assert client.acked == [["1-0", "2-0"]] + + @pytest.mark.anyio async def test_consumer_leaves_locked_case_jobs_pending( monkeypatch: pytest.MonkeyPatch, diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py index 1423d6bf04..5e2e62a240 100644 --- a/tracecat/cases/durations/consumer.py +++ b/tracecat/cases/durations/consumer.py @@ -143,6 +143,7 @@ async def _ensure_group(self) -> None: async def _handle_entries(self, entries: list[tuple[str, dict[str, str]]]) -> None: case_jobs: dict[tuple[uuid.UUID, uuid.UUID], list[str]] = {} case_event_types: dict[tuple[uuid.UUID, uuid.UUID], set[str]] = {} + force_sync_keys: set[tuple[uuid.UUID, uuid.UUID]] = set() for message_id, fields in entries: job = self._parse_job(fields) if job is None: @@ -167,13 +168,21 @@ async def _handle_entries(self, entries: list[tuple[str, dict[str, str]]]) -> No case_jobs.setdefault(key, []).append(message_id) if job.event_type: case_event_types.setdefault(key, set()).add(job.event_type) + else: + # A case-scoped job without an event type (e.g. a backfill job) + # means "sync unconditionally". Record the key so a coalesced, + # non-matching event type cannot make the event-type filter skip + # and ack it. + force_sync_keys.add(key) for (workspace_id, case_id), message_ids in case_jobs.items(): + key = (workspace_id, case_id) + event_types = None if key in force_sync_keys else case_event_types.get(key) try: synced = await self._sync_case_duration( workspace_id, case_id, - event_types=case_event_types.get((workspace_id, case_id)), + event_types=event_types, ) except Exception: logger.exception( From b21e7f525dc0198d674deb143bd5e2717a0e3de0 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:08:29 -0400 Subject: [PATCH 5/7] fix(cases): read duration sync backlog --- tests/unit/test_case_duration_sync_consumer.py | 17 +++++++++++++++++ tracecat/cases/durations/consumer.py | 6 +++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py index 5dc51c88c6..ca108b9d16 100644 --- a/tests/unit/test_case_duration_sync_consumer.py +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -229,6 +229,23 @@ async def test_consumer_acks_successful_backfill_jobs( assert client.acked == [["1-0"]] +@pytest.mark.anyio +async def test_ensure_group_reads_backlog_from_start() -> None: + client = AsyncMock() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + + await consumer._ensure_group() + + # "0" ensures jobs published during the startup gap (before the group + # exists) are still delivered, rather than being skipped by "$". + client.xgroup_create.assert_awaited_once_with( + consumer.stream_key, + consumer.group, + id="0", + ignore_busygroup=True, + ) + + @pytest.mark.anyio async def test_event_types_require_sync_matches_status_changed_aliases() -> None: session = FakeDefinitionMatchSession() diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py index 5e2e62a240..8082024a8d 100644 --- a/tracecat/cases/durations/consumer.py +++ b/tracecat/cases/durations/consumer.py @@ -129,10 +129,14 @@ def _is_nogroup_error(self, error: Exception) -> bool: async def _ensure_group(self) -> None: try: + # Read from the start of the stream ("0") rather than only new + # messages ("$"). The consumer is started as an unawaited background + # task, so jobs can be published before the group exists; "0" also + # lets the group reclaim retained jobs after a NOGROUP recovery. await self.client.xgroup_create( self.stream_key, self.group, - id="$", + id="0", ignore_busygroup=True, ) except ResponseError as e: From b25d6b6d60ce7bec7b9c0a58d13cd14253393c0a Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:40:55 -0400 Subject: [PATCH 6/7] fix(cases): use transaction duration sync locks --- .../unit/test_case_duration_sync_consumer.py | 54 +++++++++++++++++++ tracecat/cases/durations/consumer.py | 16 +----- tracecat/db/locks.py | 11 ++++ 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/tests/unit/test_case_duration_sync_consumer.py b/tests/unit/test_case_duration_sync_consumer.py index ca108b9d16..9e7cbaf75c 100644 --- a/tests/unit/test_case_duration_sync_consumer.py +++ b/tests/unit/test_case_duration_sync_consumer.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import uuid from typing import Any, cast from unittest.mock import AsyncMock, MagicMock @@ -157,6 +158,59 @@ async def test_consumer_leaves_locked_case_jobs_pending( assert client.acked == [] +@pytest.mark.anyio +async def test_sync_case_duration_uses_transaction_scoped_lock( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = FakeRedisClient() + consumer = CaseDurationSyncConsumer(cast(RedisClient, client)) + workspace_id = uuid.uuid4() + case_id = uuid.uuid4() + fake_session = MagicMock() + fake_session.commit = AsyncMock() + fake_session.rollback = AsyncMock() + role = MagicMock() + duration_service = MagicMock() + duration_service.sync_case_durations = AsyncMock(return_value=[]) + duration_service_cls = MagicMock(return_value=duration_service) + lock_mock = AsyncMock(return_value=True) + + @contextlib.asynccontextmanager + async def fake_session_context(): + yield fake_session + + monkeypatch.setattr( + "tracecat.cases.durations.consumer.get_async_session_bypass_rls_context_manager", + fake_session_context, + ) + monkeypatch.setattr( + "tracecat.cases.durations.consumer.try_pg_advisory_xact_lock", + lock_mock, + ) + monkeypatch.setattr( + "tracecat.cases.durations.consumer.CaseDurationService", + duration_service_cls, + ) + monkeypatch.setattr( + consumer, + "_get_service_role", + AsyncMock(return_value=role), + ) + monkeypatch.setattr( + consumer, + "_event_types_require_sync", + AsyncMock(return_value=True), + ) + + assert await consumer._sync_case_duration(workspace_id, case_id) + + lock_mock.assert_awaited_once() + duration_service_cls.assert_called_once_with(session=fake_session, role=role) + duration_service.sync_case_durations.assert_awaited_once_with(case_id) + fake_session.commit.assert_awaited_once() + fake_session.rollback.assert_not_awaited() + + @pytest.mark.anyio async def test_consumer_leaves_failed_case_jobs_pending( monkeypatch: pytest.MonkeyPatch, diff --git a/tracecat/cases/durations/consumer.py b/tracecat/cases/durations/consumer.py index 8082024a8d..554fcb9cc5 100644 --- a/tracecat/cases/durations/consumer.py +++ b/tracecat/cases/durations/consumer.py @@ -26,8 +26,7 @@ from tracecat.db.engine import get_async_session_bypass_rls_context_manager from tracecat.db.locks import ( derive_lock_key_from_parts, - pg_advisory_unlock, - try_pg_advisory_lock, + try_pg_advisory_xact_lock, ) from tracecat.db.models import Case, Workspace from tracecat.db.models import CaseDurationDefinition as CaseDurationDefinitionDB @@ -252,7 +251,7 @@ async def _sync_case_duration( ) return True - locked = await try_pg_advisory_lock(session, lock_key) + locked = await try_pg_advisory_xact_lock(session, lock_key) if not locked: await session.rollback() logger.debug( @@ -284,17 +283,6 @@ async def _sync_case_duration( case_id=str(case_id), ) return False - finally: - try: - await pg_advisory_unlock(session, lock_key) - await session.commit() - except Exception: - await session.rollback() - logger.warning( - "Failed to release case duration sync advisory lock", - workspace_id=str(workspace_id), - case_id=str(case_id), - ) async def _event_types_require_sync( self, diff --git a/tracecat/db/locks.py b/tracecat/db/locks.py index fb712ce5c9..21a62f7fe2 100644 --- a/tracecat/db/locks.py +++ b/tracecat/db/locks.py @@ -61,6 +61,17 @@ async def try_pg_advisory_lock(session: AsyncSession, key: int) -> bool: return result.scalar() is True +async def try_pg_advisory_xact_lock(session: AsyncSession, key: int) -> bool: + """Try to acquire a transaction-scoped PostgreSQL advisory lock.""" + if not (-(2**63) <= key < 2**63): + raise ValueError(f"Lock key {key} out of range for PostgreSQL advisory locks") + + result = await session.execute( + text("SELECT pg_try_advisory_xact_lock(:key)"), {"key": key} + ) + return result.scalar() is True + + async def pg_advisory_unlock(session: AsyncSession, key: int) -> bool: """Release a PostgreSQL advisory lock. From e8d25a0a4b705baa06eee31215718f87495da7c8 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:42:06 -0400 Subject: [PATCH 7/7] fix(cases): backfill durations without worker --- tests/unit/test_case_duration_service.py | 99 ++++++++++++++++++++++++ tracecat/cases/durations/service.py | 39 ++++++++-- 2 files changed, 130 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_case_duration_service.py b/tests/unit/test_case_duration_service.py index 3c649b8eda..c611d263cb 100644 --- a/tests/unit/test_case_duration_service.py +++ b/tests/unit/test_case_duration_service.py @@ -8,6 +8,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from tracecat import config from tracecat.cases.durations import ( CaseDurationAnchorSelection, CaseDurationComputation, @@ -122,6 +123,53 @@ def fake_enqueue(*args, **kwargs) -> None: assert calls[0]["reason"] == "duration_definition_created" +@pytest.mark.anyio +async def test_create_definition_syncs_inline_when_async_duration_sync_disabled( + session: AsyncSession, + svc_role, + monkeypatch: pytest.MonkeyPatch, +) -> None: + enqueue_calls: list[dict[str, object]] = [] + inline_sync_calls = 0 + + def fake_enqueue(*args, **kwargs) -> None: + enqueue_calls.append(kwargs) + + async def fake_inline_sync(self) -> None: + nonlocal inline_sync_calls + inline_sync_calls += 1 + + monkeypatch.setattr(config, "TRACECAT__CASE_DURATION_SYNC_ENABLED", False) + monkeypatch.setattr( + "tracecat.cases.durations.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + monkeypatch.setattr( + CaseDurationDefinitionService, + "_sync_existing_case_durations_inline", + fake_inline_sync, + ) + + definition_service = CaseDurationDefinitionService(session=session, role=svc_role) + await definition_service.create_definition( + CaseDurationDefinitionCreate( + name="Inline Backfilled Duration", + start_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.CASE_CREATED, + ), + end_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.STATUS_CHANGED, + filters=CaseDurationEventFilters( + new_values=[CaseStatus.RESOLVED.value] + ), + ), + ) + ) + + assert enqueue_calls == [] + assert inline_sync_calls == 1 + + @pytest.mark.anyio async def test_update_definition_enqueues_backfill_after_commit( session: AsyncSession, @@ -164,6 +212,57 @@ def fake_enqueue(*args, **kwargs) -> None: assert calls[0]["reason"] == "duration_definition_updated" +@pytest.mark.anyio +async def test_update_definition_syncs_inline_when_async_duration_sync_disabled( + session: AsyncSession, + svc_role, + monkeypatch: pytest.MonkeyPatch, +) -> None: + enqueue_calls: list[dict[str, object]] = [] + inline_sync_calls = 0 + + def fake_enqueue(*args, **kwargs) -> None: + enqueue_calls.append(kwargs) + + async def fake_inline_sync(self) -> None: + nonlocal inline_sync_calls + inline_sync_calls += 1 + + definition_service = CaseDurationDefinitionService(session=session, role=svc_role) + definition = await definition_service.create_definition( + CaseDurationDefinitionCreate( + name="Inline Updated Duration", + start_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.CASE_CREATED, + ), + end_anchor=CaseDurationEventAnchor( + event_type=CaseEventType.STATUS_CHANGED, + filters=CaseDurationEventFilters( + new_values=[CaseStatus.RESOLVED.value] + ), + ), + ) + ) + monkeypatch.setattr(config, "TRACECAT__CASE_DURATION_SYNC_ENABLED", False) + monkeypatch.setattr( + "tracecat.cases.durations.service.enqueue_case_duration_sync_after_commit", + fake_enqueue, + ) + monkeypatch.setattr( + CaseDurationDefinitionService, + "_sync_existing_case_durations_inline", + fake_inline_sync, + ) + + await definition_service.update_definition( + definition.id, + CaseDurationDefinitionUpdate(description="Updated inline description"), + ) + + assert enqueue_calls == [] + assert inline_sync_calls == 1 + + @pytest.mark.anyio async def test_compute_case_durations_from_events( session: AsyncSession, svc_role diff --git a/tracecat/cases/durations/service.py b/tracecat/cases/durations/service.py index 0d231cf112..3d9b8230d1 100644 --- a/tracecat/cases/durations/service.py +++ b/tracecat/cases/durations/service.py @@ -27,6 +27,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.elements import ColumnElement +from tracecat import config from tracecat.auth.types import Role from tracecat.cases.durations.schemas import ( CaseDurationAnchorSelection, @@ -42,6 +43,7 @@ CaseDurationUpdate, ) from tracecat.cases.durations.sync_queue import ( + CaseDurationSyncReason, enqueue_case_duration_sync_after_commit, ) from tracecat.cases.enums import CaseEventType @@ -104,10 +106,8 @@ async def create_definition( **self._anchor_attributes(params.end_anchor, "end"), ) self.session.add(entity) - enqueue_case_duration_sync_after_commit( - self.session, - workspace_id=self.workspace_id, - reason="duration_definition_created", + await self._backfill_after_definition_change( + reason="duration_definition_created" ) await self.session.commit() await self.session.refresh(entity) @@ -150,10 +150,8 @@ async def update_definition( self._apply_anchor(entity, end_anchor, "end") self.session.add(entity) - enqueue_case_duration_sync_after_commit( - self.session, - workspace_id=self.workspace_id, - reason="duration_definition_updated", + await self._backfill_after_definition_change( + reason="duration_definition_updated" ) await self.session.commit() await self.session.refresh(entity) @@ -167,6 +165,31 @@ async def delete_definition(self, duration_id: uuid.UUID) -> None: await self.session.delete(entity) await self.session.commit() + async def _backfill_after_definition_change( + self, *, reason: CaseDurationSyncReason + ) -> None: + if config.TRACECAT__CASE_DURATION_SYNC_ENABLED: + enqueue_case_duration_sync_after_commit( + self.session, + workspace_id=self.workspace_id, + reason=reason, + ) + return + + await self._sync_existing_case_durations_inline() + + async def _sync_existing_case_durations_inline(self) -> None: + await self.session.flush() + stmt = ( + select(Case.id) + .where(Case.workspace_id == self.workspace_id) + .order_by(Case.surrogate_id.asc()) + ) + result = await self.session.execute(stmt) + duration_service = CaseDurationService(session=self.session, role=self.role) + async for case_id in cooperative_every(result.scalars().all(), every=8): + await duration_service.sync_case_durations(case_id) + async def _get_definition_entity( self, duration_id: uuid.UUID ) -> CaseDurationDefinitionDB: