diff --git a/api/entrypoints/routers.py b/api/entrypoints/routers.py index 1db32a85f3..800abd074d 100644 --- a/api/entrypoints/routers.py +++ b/api/entrypoints/routers.py @@ -149,6 +149,9 @@ from oss.src.core.triggers.registry import TriggersGatewayRegistry from oss.src.core.triggers.service import TriggersService from oss.src.apis.fastapi.triggers.router import TriggersRouter +from oss.src.tasks.asyncio.triggers.dispatcher import TriggersDispatcher +from oss.src.tasks.taskiq.triggers.worker import TriggersWorker +from taskiq_redis import RedisStreamBroker from oss.src.apis.fastapi.shared.utils import SupportHeadersMiddleware @@ -214,8 +217,12 @@ async def lifespan(*args, **kwargs): warn_deprecated_env_vars() validate_required_env_vars() + await _triggers_broker.startup() + yield + await _triggers_broker.shutdown() + for adapter in _composio_adapters.values(): await adapter.close() @@ -651,6 +658,26 @@ async def lifespan(*args, **kwargs): connections_service=connections_service, ) +# Producer side of the inbound dispatch pipeline: the ingress route enqueues +# `triggers.dispatch` tasks here; entrypoints/worker_triggers.py consumes them. 
+_triggers_broker = RedisStreamBroker( + url=env.redis.uri_durable, + queue_name="queues:triggers", + consumer_group_name="api-triggers-producer", + maxlen=100_000, + approximate=True, +) + +_triggers_dispatcher = TriggersDispatcher( + triggers_dao=triggers_dao, + workflows_service=workflows_service, +) + +_triggers_worker = TriggersWorker( + broker=_triggers_broker, + dispatcher=_triggers_dispatcher, +) + _t_services_done = time.perf_counter() - _t_services print(f"[STARTUP] Service initialization completed (+{_t_services_done:.3f}s)") _t_routers = time.perf_counter() @@ -767,6 +794,7 @@ async def lifespan(*args, **kwargs): triggers = TriggersRouter( triggers_service=triggers_service, + dispatch_task=_triggers_worker.dispatch_trigger, ) simple_traces = SimpleTracesRouter( diff --git a/api/entrypoints/worker_triggers.py b/api/entrypoints/worker_triggers.py new file mode 100644 index 0000000000..1b25bef5a7 --- /dev/null +++ b/api/entrypoints/worker_triggers.py @@ -0,0 +1,142 @@ +import sys + +from taskiq.cli.worker.run import run_worker +from taskiq.cli.worker.args import WorkerArgs +from taskiq_redis import RedisStreamBroker + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.helpers import warn_deprecated_env_vars, validate_required_env_vars +from oss.src.utils.env import env + +from oss.src.utils.common import is_ee +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.dbs.postgres.triggers.dao import TriggersDAO +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) +from oss.src.dbs.postgres.environments.dbes import ( + EnvironmentArtifactDBE, + EnvironmentVariantDBE, + EnvironmentRevisionDBE, +) +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.environments.service import EnvironmentsService +from oss.src.core.embeds.service import EmbedsService +from oss.src.tasks.asyncio.triggers.dispatcher import TriggersDispatcher +from 
oss.src.tasks.taskiq.triggers.worker import TriggersWorker + +# Guard EE imports — see worker_tracing.py for the rationale. +if is_ee(): + from ee.src.core.access.entitlements.service import bootstrap_entitlements_services + + +import agenta as ag + +log = get_module_logger(__name__) + +# Initialize Agenta SDK +ag.init( + api_url=env.agenta.api_url, +) + +# Bound the stream so acked entries are trimmed; without this it grows unbounded. +MAXLEN_QUEUES_TRIGGERS = 100_000 + +# BROKER ------------------------------------------------------------------- +broker = RedisStreamBroker( + url=env.redis.uri_durable, + queue_name="queues:triggers", + consumer_group_name="worker-triggers", + maxlen=MAXLEN_QUEUES_TRIGGERS, + approximate=True, +) + + +# WORKERS ------------------------------------------------------------------ +triggers_dao = TriggersDAO() + +workflows_dao = GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, +) + +environments_dao = GitDAO( + ArtifactDBE=EnvironmentArtifactDBE, + VariantDBE=EnvironmentVariantDBE, + RevisionDBE=EnvironmentRevisionDBE, +) + +workflows_service = WorkflowsService( + workflows_dao=workflows_dao, +) + +environments_service = EnvironmentsService( + environments_dao=environments_dao, +) + +embeds_service = EmbedsService( + workflows_service=workflows_service, + environments_service=environments_service, +) + +workflows_service.environments_service = environments_service +workflows_service.embeds_service = embeds_service +environments_service.embeds_service = embeds_service + +triggers_dispatcher = TriggersDispatcher( + triggers_dao=triggers_dao, + workflows_service=workflows_service, +) + +triggers_worker = TriggersWorker( + broker=broker, + dispatcher=triggers_dispatcher, +) + + +def main() -> int: + """ + Main entry point for the worker. 
+ + Returns: + Exit code (0 for success, non-zero for failure) + """ + try: + log.info("[TRIGGERS] Initializing Taskiq worker") + + # Validate environment + warn_deprecated_env_vars() + validate_required_env_vars() + + # Wire EE entitlement services so `check_entitlements` works in + # this worker process. Gated on `is_ee()` to match the import above. + if is_ee(): + bootstrap_entitlements_services() + + log.info("[TRIGGERS] Starting Taskiq worker with Redis Streams") + + # Run Taskiq worker + args = WorkerArgs( + broker="entrypoints.worker_triggers:broker", # Reference broker from this module + modules=[], + fs_discover=False, + workers=1, + max_async_tasks=50, + ) + + result = run_worker(args) + return result if result is not None else 0 + + except KeyboardInterrupt: + log.info("[TRIGGERS] Shutdown requested") + return 0 + except Exception as e: + log.error("[TRIGGERS] Fatal error", error=str(e)) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/oss/src/apis/fastapi/triggers/models.py b/api/oss/src/apis/fastapi/triggers/models.py index 9d671ac49d..9e13dd38f4 100644 --- a/api/oss/src/apis/fastapi/triggers/models.py +++ b/api/oss/src/apis/fastapi/triggers/models.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -91,3 +91,26 @@ class TriggerDeliveryResponse(BaseModel): class TriggerDeliveriesResponse(BaseModel): count: int = 0 deliveries: List[TriggerDelivery] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Trigger Ingress (inbound provider events) +# --------------------------------------------------------------------------- + + +class TriggerEventAck(BaseModel): + status: str = "accepted" + detail: Optional[str] = None + + +class ComposioEventEnvelope(BaseModel): + """Loose view of a Composio trigger webhook envelope (`{data, type, ...}`). 
+ + Demultiplexing keys live under ``metadata`` (``trigger_id``, ``id``); the rest + is passed through to the resolver as the inbound event. + """ + + type: Optional[str] = None + timestamp: Optional[str] = None + data: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None diff --git a/api/oss/src/apis/fastapi/triggers/router.py b/api/oss/src/apis/fastapi/triggers/router.py index a1e5281e8f..1bb1d66f80 100644 --- a/api/oss/src/apis/fastapi/triggers/router.py +++ b/api/oss/src/apis/fastapi/triggers/router.py @@ -1,5 +1,8 @@ +import hashlib +import hmac from functools import wraps -from typing import Optional +from json import JSONDecodeError, loads +from typing import Any, Optional from uuid import UUID import httpx @@ -10,6 +13,7 @@ from oss.src.utils.logging import get_module_logger from oss.src.utils.caching import get_cache, set_cache from oss.src.utils.common import is_ee +from oss.src.utils.env import env from oss.src.apis.fastapi.triggers.models import ( TriggerCatalogEventResponse, @@ -19,6 +23,7 @@ TriggerDeliveriesResponse, TriggerDeliveryQueryRequest, TriggerDeliveryResponse, + TriggerEventAck, TriggerSubscriptionCreateRequest, TriggerSubscriptionEditRequest, TriggerSubscriptionQueryRequest, @@ -76,16 +81,58 @@ async def wrapper(*args, **kwargs): return decorator +def _verify_composio_signature( + *, + body: bytes, + headers: Any, +) -> bool: + """HMAC-SHA256 verify over ``{id}.{ts}.{body}`` with ``COMPOSIO_WEBHOOK_SECRET``. + + Returns True when the secret is unset (no-op) or the signature matches. 
+ """ + secret = env.composio.webhook_secret + if not secret: + return True + + signature = headers.get("webhook-signature") or headers.get("x-composio-signature") + webhook_id = headers.get("webhook-id") or "" + timestamp = headers.get("webhook-timestamp") or "" + if not signature: + return False + + signed = f"{webhook_id}.{timestamp}.{body.decode('utf-8', errors='replace')}" + expected = hmac.new( + secret.encode("utf-8"), + signed.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + + provided = signature.split(",")[-1].strip() + return hmac.compare_digest(expected, provided) + + class TriggersRouter: def __init__( self, *, triggers_service: TriggersService, + dispatch_task: Optional[Any] = None, ): self.triggers_service = triggers_service + self.dispatch_task = dispatch_task self.router = APIRouter() + # --- Trigger Ingress (inbound provider events) --- + self.router.add_api_route( + "/composio/events", + self.ingest_composio_event, + methods=["POST"], + operation_id="ingest_composio_event", + response_model=TriggerEventAck, + status_code=status.HTTP_202_ACCEPTED, + ) + # --- Trigger Catalog --- self.router.add_api_route( "/catalog/providers/", @@ -711,3 +758,52 @@ async def fetch_delivery( count=1, delivery=delivery, ) + + # ----------------------------------------------------------------------- + # Trigger Ingress (inbound provider events) + # ----------------------------------------------------------------------- + + @intercept_exceptions() + async def ingest_composio_event( + self, + request: Request, + ) -> Any: + """Receive a Composio provider event; verify, demux, ack-fast, enqueue. + + Public (no Agenta auth) — mirrors the Stripe events receiver. Scope and + attribution are recovered downstream from the resolved subscription row. 
+ """ + body = await request.body() + + if not _verify_composio_signature(body=body, headers=request.headers): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"status": "error", "detail": "Signature verification failed"}, + ) + + try: + envelope = loads(body) if body else {} + except JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid payload", + ) + + metadata = envelope.get("metadata") or {} + trigger_id = metadata.get("trigger_id") or metadata.get("nano_id") + event_id = metadata.get("id") + + if not trigger_id or not event_id: + # Nothing to route — accept (no-op) so the provider does not retry. + return TriggerEventAck( + status="accepted", detail="No trigger_id/id to route" + ) + + if self.dispatch_task is not None: + await self.dispatch_task.kiq( + trigger_id=str(trigger_id), + event_id=str(event_id), + event=envelope, + ) + + return TriggerEventAck(status="accepted") diff --git a/api/oss/src/core/triggers/dtos.py b/api/oss/src/core/triggers/dtos.py index b2a302d6b9..2d7a1769f3 100644 --- a/api/oss/src/core/triggers/dtos.py +++ b/api/oss/src/core/triggers/dtos.py @@ -15,6 +15,13 @@ ) +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +TRIGGER_MAX_RETRIES = 5 + + # --------------------------------------------------------------------------- # Trigger Enums # --------------------------------------------------------------------------- diff --git a/api/oss/src/core/triggers/interfaces.py b/api/oss/src/core/triggers/interfaces.py index a7280c90ea..5221b52349 100644 --- a/api/oss/src/core/triggers/interfaces.py +++ b/api/oss/src/core/triggers/interfaces.py @@ -148,6 +148,15 @@ async def get_subscription_by_trigger_id( """FROZEN (WP4): resolve an inbound event's ``ti_*`` to its local row.""" ... 
+ @abstractmethod + async def get_project_and_subscription_by_trigger_id( + self, + *, + trigger_id: str, + ) -> Optional[Tuple[UUID, TriggerSubscription]]: + """Resolve a ``ti_*`` to its (project_id, subscription); the DTO omits project scope.""" + ... + # --- deliveries --------------------------------------------------------- # @abstractmethod diff --git a/api/oss/src/dbs/postgres/triggers/dao.py b/api/oss/src/dbs/postgres/triggers/dao.py index b3a1a51e3c..c53bf2b9eb 100644 --- a/api/oss/src/dbs/postgres/triggers/dao.py +++ b/api/oss/src/dbs/postgres/triggers/dao.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import List, Optional +from typing import List, Optional, Tuple from uuid import UUID from sqlalchemy import select @@ -233,6 +233,32 @@ async def get_subscription_by_trigger_id( subscription_dbe=subscription_dbe, ) + async def get_project_and_subscription_by_trigger_id( + self, + *, + trigger_id: str, + ) -> Optional[Tuple[UUID, TriggerSubscription]]: + async with self.engine.session() as session: + stmt = ( + select(TriggerSubscriptionDBE) + .filter( + TriggerSubscriptionDBE.data["ti_id"].astext == trigger_id, + ) + .limit(1) + ) + + result = await session.execute(stmt) + + subscription_dbe = result.scalars().first() + + if not subscription_dbe: + return None + + return ( + subscription_dbe.project_id, + map_subscription_dbe_to_dto(subscription_dbe=subscription_dbe), + ) + # --- DELIVERIES --------------------------------------------------------- # async def write_delivery( diff --git a/api/oss/src/middlewares/auth.py b/api/oss/src/middlewares/auth.py index 1cf4ab698b..bdbc1ee8c9 100644 --- a/api/oss/src/middlewares/auth.py +++ b/api/oss/src/middlewares/auth.py @@ -69,6 +69,11 @@ "/api/tools/connections/callback", "/preview/tools/connections/callback", "/api/preview/tools/connections/callback", + # TRIGGERS — inbound provider events arrive from Composio with no auth token + "/triggers/composio/events", + 
"/api/triggers/composio/events", + "/preview/triggers/composio/events", + "/api/preview/triggers/composio/events", ) _ADMIN_ENDPOINT_IDENTIFIER = "/admin/" diff --git a/api/oss/src/tasks/asyncio/triggers/__init__.py b/api/oss/src/tasks/asyncio/triggers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/src/tasks/asyncio/triggers/dispatcher.py b/api/oss/src/tasks/asyncio/triggers/dispatcher.py new file mode 100644 index 0000000000..3c2bcbdfe3 --- /dev/null +++ b/api/oss/src/tasks/asyncio/triggers/dispatcher.py @@ -0,0 +1,244 @@ +"""Trigger dispatcher — asyncio side of the inbound pipeline. + +The inbound dual of ``webhooks/dispatcher.py``. Given a verified Composio event +(``ti_*`` trigger id + ``metadata.id`` dedup key + raw payload), it resolves the +local subscription, dedups, maps ``inputs_fields`` into the workflow inputs, runs +the bound workflow, and records a single delivery row with the outcome. + +Self-contained so it can run inside its own TaskIQ worker process. 
+""" + +from typing import Any, Dict, Optional +from uuid import UUID + +import uuid_utils.compat as uuid_compat + +from oss.src.core.shared.dtos import Status +from oss.src.core.triggers.dtos import ( + TRIGGER_EVENT_FIELDS, + SUBSCRIPTION_CONTEXT_FIELDS, + TriggerDeliveryCreate, + TriggerDeliveryData, + TriggerSubscription, +) +from oss.src.core.triggers.interfaces import TriggersDAOInterface +from oss.src.core.workflows.service import WorkflowsService +from oss.src.utils.logging import get_module_logger + +from agenta.sdk.decorators.running import WorkflowServiceRequest +from agenta.sdk.models.workflows import WorkflowRequestData +from agenta.sdk.utils.resolvers import resolve_target_fields + +log = get_module_logger(__name__) + + +class TriggersDispatcher: + """Resolves and runs one inbound provider event against its bound workflow.""" + + def __init__( + self, + *, + triggers_dao: TriggersDAOInterface, + workflows_service: WorkflowsService, + ): + self.triggers_dao = triggers_dao + self.workflows_service = workflows_service + + def _build_context( + self, + *, + event: Dict[str, Any], + subscription: TriggerSubscription, + project_id: UUID, + ) -> Dict[str, Any]: + sub_dump = subscription.model_dump(mode="json", exclude_none=True) + return { + "event": {k: v for k, v in event.items() if k in TRIGGER_EVENT_FIELDS}, + "subscription": { + k: v for k, v in sub_dump.items() if k in SUBSCRIPTION_CONTEXT_FIELDS + }, + "scope": {"project_id": str(project_id)}, + } + + async def dispatch( + self, + *, + trigger_id: str, + event_id: str, + event: Dict[str, Any], + ) -> None: + """Run the bound workflow for one inbound event (idempotent on event_id).""" + resolved = await self.triggers_dao.get_project_and_subscription_by_trigger_id( + trigger_id=trigger_id, + ) + + if resolved is None: + log.info( + "[TRIGGERS DISPATCHER] Unknown trigger_id %s — skipping", trigger_id + ) + return + + project_id, subscription = resolved + + if not subscription.enabled: + log.info( + 
"[TRIGGERS DISPATCHER] Subscription %s disabled — skipping", + subscription.id, + ) + return + + already_seen = await self.triggers_dao.dedup_seen( + project_id=project_id, + subscription_id=subscription.id, + event_id=event_id, + ) + if already_seen: + log.info( + "[TRIGGERS DISPATCHER] Duplicate event %s for subscription %s — skipping", + event_id, + subscription.id, + ) + return + + context = self._build_context( + event=event, + subscription=subscription, + project_id=project_id, + ) + + # MAPPING — inputs-only template (default whole-context "$" like webhooks). + template = subscription.data.inputs_fields + inputs = resolve_target_fields( + template if template is not None else "$", context + ) + + references = ( + { + k: ref.model_dump(mode="json", exclude_none=True) + for k, ref in subscription.data.references.items() + } + if subscription.data.references + else None + ) + selector = ( + subscription.data.selector.model_dump(mode="json", exclude_none=True) + if subscription.data.selector + else None + ) + + delivery_id = uuid_compat.uuid7() + user_id = subscription.created_by_id # M6 — attribute to the creator, or None + + delivery_data = TriggerDeliveryData( + event_key=subscription.data.event_key, + references=subscription.data.references, + inputs=inputs if isinstance(inputs, dict) else {"value": inputs}, + ) + + if not references: + await self._write_delivery( + project_id=project_id, + user_id=user_id, + delivery_id=delivery_id, + subscription_id=subscription.id, + event_id=event_id, + status=Status(code="400", message="failed"), + data=delivery_data.model_copy( + update={"error": "Subscription has no bound workflow reference"} + ), + ) + return + + try: + request = WorkflowServiceRequest( + references=references, + selector=selector, + data=WorkflowRequestData( + inputs=inputs if isinstance(inputs, dict) else {"value": inputs}, + ), + ) + + response = await self.workflows_service.invoke_workflow( + project_id=project_id, + user_id=user_id, + 
request=request, + ) + except Exception as e: + await self._write_delivery( + project_id=project_id, + user_id=user_id, + delivery_id=delivery_id, + subscription_id=subscription.id, + event_id=event_id, + status=Status(code="500", message="failed"), + data=delivery_data.model_copy(update={"error": str(e)}), + ) + raise + + status_obj = getattr(response, "status", None) + status_code = getattr(status_obj, "code", None) + outputs = getattr(response, "outputs", None) or getattr( + getattr(response, "data", None), "outputs", None + ) + + if status_code not in (None, 200): + await self._write_delivery( + project_id=project_id, + user_id=user_id, + delivery_id=delivery_id, + subscription_id=subscription.id, + event_id=event_id, + status=Status(code=str(status_code), message="failed"), + data=delivery_data.model_copy( + update={ + "error": getattr(status_obj, "message", None) + or "Workflow failed", + "result": { + "trace_id": getattr(response, "trace_id", None), + "span_id": getattr(response, "span_id", None), + }, + } + ), + ) + return + + await self._write_delivery( + project_id=project_id, + user_id=user_id, + delivery_id=delivery_id, + subscription_id=subscription.id, + event_id=event_id, + status=Status(code="200", message="success"), + data=delivery_data.model_copy( + update={ + "result": { + "trace_id": getattr(response, "trace_id", None), + "span_id": getattr(response, "span_id", None), + "outputs": outputs, + } + } + ), + ) + + async def _write_delivery( + self, + *, + project_id: UUID, + user_id: Optional[UUID], + delivery_id: UUID, + subscription_id: UUID, + event_id: str, + status: Status, + data: TriggerDeliveryData, + ) -> None: + await self.triggers_dao.write_delivery( + project_id=project_id, + user_id=user_id, + delivery=TriggerDeliveryCreate( + id=delivery_id, + subscription_id=subscription_id, + event_id=event_id, + status=status, + data=data, + ), + ) diff --git a/api/oss/src/tasks/taskiq/triggers/__init__.py 
b/api/oss/src/tasks/taskiq/triggers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/src/tasks/taskiq/triggers/worker.py b/api/oss/src/tasks/taskiq/triggers/worker.py new file mode 100644 index 0000000000..8bcf26d553 --- /dev/null +++ b/api/oss/src/tasks/taskiq/triggers/worker.py @@ -0,0 +1,63 @@ +from typing import Any, Dict + +from taskiq import AsyncBroker, Context, TaskiqDepends + +from oss.src.core.triggers.dtos import TRIGGER_MAX_RETRIES +from oss.src.tasks.asyncio.triggers.dispatcher import TriggersDispatcher +from oss.src.utils.logging import get_module_logger + +log = get_module_logger(__name__) + + +class TriggersWorker: + """Registers and owns the TaskIQ trigger dispatch task. + + The dispatch task receives the verified Composio event inline and runs the + bound workflow, writing a single delivery row on the outcome. Idempotency + comes from the WP3 ``dedup_seen`` guard, so provider + TaskIQ retries are safe. + """ + + def __init__( + self, + *, + broker: AsyncBroker, + dispatcher: TriggersDispatcher, + ): + self.broker = broker + self.dispatcher = dispatcher + + self._register_tasks() + + def _register_tasks(self): + @self.broker.task( + task_name="triggers.dispatch", + retry_on_error=True, + max_retries=TRIGGER_MAX_RETRIES, + ) + async def dispatch_trigger( + *, + trigger_id: str, + event_id: str, + event: Dict[str, Any], + # + context: Context = TaskiqDepends(), + ) -> None: + retry_count_raw = context.message.labels.get("_taskiq_retry_count", 0) or 0 + try: + retry_count = int(retry_count_raw) + except (TypeError, ValueError): + retry_count = 0 + + log.info( + f"[TASK] triggers.dispatch " + f"trigger={trigger_id} event={event_id} " + f"attempt={retry_count}/{TRIGGER_MAX_RETRIES}" + ) + + await self.dispatcher.dispatch( + trigger_id=trigger_id, + event_id=event_id, + event=event, + ) + + self.dispatch_trigger = dispatch_trigger diff --git a/api/oss/src/utils/env.py b/api/oss/src/utils/env.py index 
585386c33e..993ab83725 100644 --- a/api/oss/src/utils/env.py +++ b/api/oss/src/utils/env.py @@ -510,6 +510,7 @@ class ComposioConfig(BaseModel): api_key: str | None = os.getenv("COMPOSIO_API_KEY") api_url: str = os.getenv("COMPOSIO_API_URL", "https://backend.composio.dev/api/v3") + webhook_secret: str | None = os.getenv("COMPOSIO_WEBHOOK_SECRET") @property def enabled(self) -> bool: diff --git a/api/oss/tests/pytest/acceptance/triggers/test_triggers_ingress.py b/api/oss/tests/pytest/acceptance/triggers/test_triggers_ingress.py new file mode 100644 index 0000000000..d76db95ed3 --- /dev/null +++ b/api/oss/tests/pytest/acceptance/triggers/test_triggers_ingress.py @@ -0,0 +1,163 @@ +"""Acceptance tests for POST /triggers/composio/events (inbound ingress). + +The ingress is the inbound dual of webhooks: a public (no Agenta auth) endpoint +that Composio POSTs provider events to. It ACKs fast (202) and enqueues dispatch +asynchronously; the actual workflow run + delivery write happen in a separate +worker, so the unconditional paths here are DB-free: + + - an event for an unknown trigger id is a clean 202 no-op (nothing to route); + - an event with no routable metadata is a clean 202 no-op. + +The signature-rejection path only bites when COMPOSIO_WEBHOOK_SECRET is set +(unset → 200/202 no-op, mirroring the Stripe receiver), so it is gated on that. +The full signed-event -> workflow-invoked -> single-delivery roundtrip needs the +live Composio adapter and a bound workflow, so it is gated on COMPOSIO_API_KEY. + +Requires a running API. 
+""" + +import os +from uuid import uuid4 + +import pytest + + +_COMPOSIO_ENABLED = bool(os.getenv("COMPOSIO_API_KEY")) +_WEBHOOK_SECRET = os.getenv("COMPOSIO_WEBHOOK_SECRET") + +_requires_composio = pytest.mark.skipif( + not _COMPOSIO_ENABLED, + reason="needs live Composio credentials (COMPOSIO_API_KEY)", +) +_requires_webhook_secret = pytest.mark.skipif( + not _WEBHOOK_SECRET, + reason="needs COMPOSIO_WEBHOOK_SECRET set to verify signature rejection", +) + + +# --------------------------------------------------------------------------- +# DB-only: unknown trigger / no metadata are clean 202 no-ops +# --------------------------------------------------------------------------- + + +class TestTriggerIngressNoOps: + def test_unknown_trigger_id_is_accepted_noop(self, unauthed_api): + response = unauthed_api( + "POST", + "/triggers/composio/events", + json={ + "type": "github_star_added_event", + "metadata": { + "trigger_id": f"ti_{uuid4().hex}", + "id": uuid4().hex, + }, + "data": {"repository": "acme/widgets"}, + }, + ) + assert response.status_code == 202, response.text + assert response.json()["status"] == "accepted" + + def test_no_routable_metadata_is_accepted_noop(self, unauthed_api): + response = unauthed_api( + "POST", + "/triggers/composio/events", + json={"type": "some_event", "data": {}}, + ) + assert response.status_code == 202, response.text + assert response.json()["status"] == "accepted" + + def test_empty_body_is_accepted_noop(self, unauthed_api): + response = unauthed_api("POST", "/triggers/composio/events", data=b"") + assert response.status_code == 202, response.text + + +@_requires_webhook_secret +class TestTriggerIngressSignature: + def test_forged_signature_is_rejected(self, unauthed_api): + response = unauthed_api( + "POST", + "/triggers/composio/events", + headers={ + "webhook-id": "msg_1", + "webhook-timestamp": "1700000000", + "webhook-signature": "v1,deadbeef", + }, + json={ + "metadata": {"trigger_id": f"ti_{uuid4().hex}", "id": 
uuid4().hex}, + }, + ) + assert response.status_code == 401, response.text + + +# --------------------------------------------------------------------------- +# Dedup (needs Composio) — a duplicate metadata.id does not double-write a +# delivery. Exercised end-to-end via a real subscription bound to a workflow. +# --------------------------------------------------------------------------- + + +@_requires_composio +class TestTriggerIngressDedup: + def test_duplicate_event_id_writes_single_delivery(self, authed_api, unauthed_api): + # Create a connection + subscription so an inbound ti_* resolves locally. + slug = f"acc-{uuid4().hex[:8]}" + conn = authed_api( + "POST", + "/tools/connections/", + json={ + "connection": { + "slug": slug, + "provider_key": "composio", + "integration_key": "github", + "data": {"auth_scheme": "oauth"}, + } + }, + ) + assert conn.status_code == 200, conn.text + connection_id = conn.json()["connection"]["id"] + + create = authed_api( + "POST", + "/triggers/subscriptions/", + json={ + "subscription": { + "name": f"sub-{uuid4().hex[:8]}", + "connection_id": connection_id, + "data": { + "event_key": "GITHUB_STAR_ADDED_EVENT", + "trigger_config": {}, + "inputs_fields": {"repo": "$.event.data.repository"}, + "references": {"workflow": {"slug": "triage"}}, + }, + } + }, + ) + assert create.status_code == 200, create.text + sub = create.json()["subscription"] + subscription_id = sub["id"] + ti_id = sub["data"]["ti_id"] + + event_id = uuid4().hex + envelope = { + "type": "github_star_added_event", + "metadata": {"trigger_id": ti_id, "id": event_id}, + "data": {"repository": "acme/widgets"}, + } + + # Post the same event twice (provider redelivery) — dedup must hold. + for _ in range(2): + ack = unauthed_api("POST", "/triggers/composio/events", json=envelope) + assert ack.status_code == 202, ack.text + + # The dispatch is async; the dedup guard means at most one delivery row + # exists for this (subscription, event_id). 
+ deliveries = authed_api( + "POST", + "/triggers/deliveries/query", + json={ + "delivery": {"subscription_id": subscription_id, "event_id": event_id} + }, + ).json()["deliveries"] + assert len(deliveries) <= 1 + + authed_api("DELETE", f"/triggers/subscriptions/{subscription_id}") + authed_api("DELETE", f"/tools/connections/{connection_id}") diff --git a/api/oss/tests/pytest/unit/triggers/__init__.py b/api/oss/tests/pytest/unit/triggers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/oss/tests/pytest/unit/triggers/test_triggers_dispatcher.py b/api/oss/tests/pytest/unit/triggers/test_triggers_dispatcher.py new file mode 100644 index 0000000000..e50fcf157b --- /dev/null +++ b/api/oss/tests/pytest/unit/triggers/test_triggers_dispatcher.py @@ -0,0 +1,149 @@ +"""Unit tests for the trigger dispatcher. + +The inbound dual of ``test_webhooks_dispatcher.py``. Stubs the DAO and workflows +service (no DB, no Composio) and pins the dispatch branches: unknown trigger, +disabled subscription, dedup, missing workflow reference, and the happy path. 
+"""
+
+from types import SimpleNamespace
+from uuid import uuid4
+
+from unittest.mock import AsyncMock, MagicMock
+
+from oss.src.core.shared.dtos import Reference
+from oss.src.tasks.asyncio.triggers.dispatcher import TriggersDispatcher
+
+
+def _make_subscription(*, enabled=True, references=None, inputs_fields=None):
+    data = SimpleNamespace(  # stand-in for the subscription's ``data`` payload
+        event_key="github.issue.opened",
+        inputs_fields=inputs_fields,
+        references=references,
+        selector=None,
+    )
+    return SimpleNamespace(  # duck-typed subscription stub; fresh ids per call
+        id=uuid4(),
+        enabled=enabled,
+        created_by_id=uuid4(),
+        data=data,
+        model_dump=lambda **_kwargs: {"id": "sub", "name": "watch"},
+    )
+
+
+def _make_dao(*, resolved, seen=False):
+    dao = MagicMock()  # only the three awaited DAO methods below are stubbed
+    dao.get_project_and_subscription_by_trigger_id = AsyncMock(return_value=resolved)
+    dao.dedup_seen = AsyncMock(return_value=seen)  # True in the duplicate test
+    dao.write_delivery = AsyncMock()  # tests inspect its await args afterwards
+    return dao
+
+
+_EVENT = {"type": "github.issue.opened", "data": {"issue": {"number": 7}}}
+
+
+async def test_unknown_trigger_id_is_skipped():
+    dao = _make_dao(resolved=None)  # lookup returns None for this trigger id
+    workflows = MagicMock()
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=workflows)
+
+    await dispatcher.dispatch(trigger_id="ti_unknown", event_id="e1", event=_EVENT)
+
+    dao.dedup_seen.assert_not_awaited()  # no dedup check for an unresolved trigger
+    dao.write_delivery.assert_not_awaited()  # and no delivery record either
+
+
+async def test_disabled_subscription_is_skipped():
+    project_id = uuid4()
+    subscription = _make_subscription(enabled=False)
+    dao = _make_dao(resolved=(project_id, subscription))
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=MagicMock())
+
+    await dispatcher.dispatch(trigger_id="ti_1", event_id="e1", event=_EVENT)
+
+    dao.dedup_seen.assert_not_awaited()  # disabled: dropped before dedup
+    dao.write_delivery.assert_not_awaited()  # nothing is recorded
+
+
+async def test_duplicate_event_is_skipped():
+    project_id = uuid4()
+    subscription = _make_subscription(references={"workflow": MagicMock()})
+    dao = _make_dao(resolved=(project_id, subscription), seen=True)
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=MagicMock())
+
+    await dispatcher.dispatch(trigger_id="ti_1", event_id="e1", event=_EVENT)
+
+    dao.dedup_seen.assert_awaited_once()  # dedup is consulted exactly once
+    dao.write_delivery.assert_not_awaited()  # already seen: nothing recorded
+
+
+async def test_missing_reference_writes_failed_delivery():
+    project_id = uuid4()
+    subscription = _make_subscription(references=None)  # no workflow bound
+    dao = _make_dao(resolved=(project_id, subscription))
+    workflows = MagicMock()
+    workflows.invoke_workflow = AsyncMock()
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=workflows)
+
+    await dispatcher.dispatch(trigger_id="ti_1", event_id="e1", event=_EVENT)
+
+    workflows.invoke_workflow.assert_not_awaited()  # nothing to invoke
+    dao.write_delivery.assert_awaited_once()  # the failure is still recorded
+    delivery = dao.write_delivery.await_args.kwargs["delivery"]  # passed by keyword
+    assert delivery.status.code == "400"  # status code is compared as a string
+    assert "no bound workflow" in delivery.data.error.lower()
+
+
+async def test_happy_path_invokes_workflow_and_writes_success():
+    project_id = uuid4()
+    reference = Reference(slug="wf-1")
+    subscription = _make_subscription(
+        references={"workflow": reference},
+        inputs_fields={"number": "$.event.data.issue.number"},
+    )
+    dao = _make_dao(resolved=(project_id, subscription))
+
+    response = SimpleNamespace(  # stubbed workflow response (success)
+        status=SimpleNamespace(code=200, message="success"),
+        outputs={"ok": True},
+        trace_id="tr-1",
+        span_id="sp-1",
+    )
+    workflows = MagicMock()
+    workflows.invoke_workflow = AsyncMock(return_value=response)
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=workflows)
+
+    await dispatcher.dispatch(trigger_id="ti_1", event_id="e1", event=_EVENT)
+
+    workflows.invoke_workflow.assert_awaited_once()
+    invoke_kwargs = workflows.invoke_workflow.await_args.kwargs
+    assert invoke_kwargs["project_id"] == project_id
+    assert invoke_kwargs["user_id"] == subscription.created_by_id  # run as creator
+
+    dao.write_delivery.assert_awaited_once()
+    delivery = dao.write_delivery.await_args.kwargs["delivery"]
+    assert delivery.status.code == "200"  # int 200 from the stub, recorded as str
+    assert delivery.event_id == "e1"
+    assert delivery.data.inputs == {"number": 7}  # 7 is _EVENT's issue number
+
+
+async def test_workflow_non_200_writes_failed_delivery():
+    project_id = uuid4()
+    reference = Reference(slug="wf-1")
+    subscription = _make_subscription(references={"workflow": reference})
+    dao = _make_dao(resolved=(project_id, subscription))
+
+    response = SimpleNamespace(  # stubbed workflow response (failure)
+        status=SimpleNamespace(code=500, message="boom"),
+        outputs=None,
+        trace_id="tr-1",
+        span_id="sp-1",
+    )
+    workflows = MagicMock()
+    workflows.invoke_workflow = AsyncMock(return_value=response)
+    dispatcher = TriggersDispatcher(triggers_dao=dao, workflows_service=workflows)
+
+    await dispatcher.dispatch(trigger_id="ti_1", event_id="e1", event=_EVENT)
+
+    dao.write_delivery.assert_awaited_once()
+    delivery = dao.write_delivery.await_args.kwargs["delivery"]
+    assert delivery.status.code == "500"  # the stub's failure code is recorded
diff --git a/api/oss/tests/pytest/unit/triggers/test_triggers_signature.py b/api/oss/tests/pytest/unit/triggers/test_triggers_signature.py
new file mode 100644
index 0000000000..d0d49ee0b7
--- /dev/null
+++ b/api/oss/tests/pytest/unit/triggers/test_triggers_signature.py
@@ -0,0 +1,106 @@
+"""Unit tests for Composio webhook signature verification.
+
+Pure HMAC logic, no network or database. The acceptance suite only exercises
+this path when ``COMPOSIO_WEBHOOK_SECRET`` is present in the runner; these tests
+pin the security contract (forged/missing signatures rejected) unconditionally.
+"""
+
+import hashlib
+import hmac
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from oss.src.apis.fastapi.triggers.router import _verify_composio_signature
+
+# Fixed fixture values shared by every test in this module.
+_SECRET = "whsec_test_secret"
+_WEBHOOK_ID = "wh-1"
+_TIMESTAMP = "1700000000"
+_BODY = b'{"type":"github.issue.opened"}'
+
+# Dotted import path of the router module's ``env`` name; every test replaces
+# it with a stand-in through ``patch``.
+_ENV_PATH = "oss.src.apis.fastapi.triggers.router.env"
+
+
+def _sign(secret: str, webhook_id: str, timestamp: str, body: bytes) -> str:
+    """Return the hex HMAC-SHA256 of ``<webhook_id>.<timestamp>.<body>``."""
+    # The signed payload is the id, timestamp and body joined by dots.
+    signed = f"{webhook_id}.{timestamp}.{body.decode('utf-8')}"
+    return hmac.new(
+        secret.encode("utf-8"),
+        signed.encode("utf-8"),
+        hashlib.sha256,
+    ).hexdigest()
+
+
+def _headers(signature: str, *, name: str = "webhook-signature") -> dict:
+    """Build webhook headers carrying ``signature`` under the header ``name``."""
+    return {
+        name: signature,
+        "webhook-id": _WEBHOOK_ID,
+        "webhook-timestamp": _TIMESTAMP,
+    }
+
+
+class _Env:
+    """Minimal stand-in for the shared env object's composio config."""
+
+    def __init__(self, webhook_secret=None):
+        # Per-instance namespace: the secret is never written to class-level
+        # state, so a value set for one test cannot leak into another.
+        self.composio = SimpleNamespace(webhook_secret=webhook_secret)
+
+
+def _env_with_secret(secret):
+    """Return a fresh env stand-in whose composio webhook secret is ``secret``."""
+    return _Env(webhook_secret=secret)
+
+
+class TestVerifyComposioSignature:
+    """Pin which inputs ``_verify_composio_signature`` accepts or rejects."""
+
+    def test_unset_secret_is_noop_accept(self):
+        # With no secret configured, even a request with no headers passes.
+        with patch(_ENV_PATH, _env_with_secret(None)):
+            assert _verify_composio_signature(body=_BODY, headers={}) is True
+
+    def test_valid_signature_accepted(self):
+        # Bare hex digest with no version prefix.
+        headers = _headers(_sign(_SECRET, _WEBHOOK_ID, _TIMESTAMP, _BODY))
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=_BODY, headers=headers) is True
+
+    def test_valid_signature_with_versioned_prefix_accepted(self):
+        # Composio sends "v1,"; only the last comma-part is the digest.
+        sig = _sign(_SECRET, _WEBHOOK_ID, _TIMESTAMP, _BODY)
+        headers = _headers(f"v1,{sig}")
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=_BODY, headers=headers) is True
+
+    def test_forged_signature_rejected(self):
+        # "deadbeef" is not the HMAC of the signed payload.
+        headers = _headers("deadbeef")
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=_BODY, headers=headers) is False
+
+    def test_missing_signature_header_rejected(self):
+        # Built by hand: no signature header of either name is present.
+        headers = {"webhook-id": _WEBHOOK_ID, "webhook-timestamp": _TIMESTAMP}
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=_BODY, headers=headers) is False
+
+    def test_tampered_body_rejected(self):
+        # The signature covers _BODY, so verifying a different body must fail.
+        headers = _headers(_sign(_SECRET, _WEBHOOK_ID, _TIMESTAMP, _BODY))
+        tampered = b'{"type":"tampered"}'
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=tampered, headers=headers) is False
+
+    def test_x_composio_signature_header_alias(self):
+        # The digest may also arrive under ``x-composio-signature``.
+        sig = _sign(_SECRET, _WEBHOOK_ID, _TIMESTAMP, _BODY)
+        headers = _headers(sig, name="x-composio-signature")
+        with patch(_ENV_PATH, _env_with_secret(_SECRET)):
+            assert _verify_composio_signature(body=_BODY, headers=headers) is True